In [2]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

<torch._C.Generator at 0x7fccc852fab0>

In [3]:
import sys
sys.path.append('../unet')

In [4]:
import pytorch_lightning as pl

In [5]:
from unet_lightning import crop, show_tensor_images

In [29]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_dropout=False, use_bn=True):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(out_channels)
        self.use_bn = use_bn
        
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout
        
        
    def forward(self, x):
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

In [24]:
class DownConv(nn.Module):
    def __init__(self, in_channels, use_dropout=False, use_bn=True):
        super(DownConv, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(in_channels * 2)
        self.use_bn = use_bn
        
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout
            
        self.conv_block1 = ConvBlock(in_channels, in_channels*2, use_dropout, use_bn)
        self.conv_block2 = ConvBlock(in_channels*2, in_channels*2, use_dropout, use_bn)
            
    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.maxpool(x)
        return x

In [25]:
class UpConv(nn.Module):
    
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(UpConv, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)
        self.conv2 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(input_channels // 2, input_channels // 2, kernel_size=2, padding=1)
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x, skip_con_x):

        x = self.upsample(x)
        x = self.conv1(x)
        skip_con_x = crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], axis=1)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv3(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

In [40]:
class Generator(pl.LightningModule):
    def __init__(self, in_channels, out_channels, hidden_channels = 32, depth=6):
        super(Generator, self).__init__()
                
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        
        self.conv_final = nn.Conv2d(hidden_channels,
                                    out_channels,
                                    kernel_size=1)
        self.depth = depth

        self.contracting_layers = []
        self.expanding_layers = []
        self.sigmoid = nn.Sigmoid()
        
        # encoding/contracting path of the Generator
        for i in range(depth):
            self.contracting_layers += [
                DownConv(hidden_channels * 2**i, use_dropout=True if i<3 else False)
            ]
        
        # Upsampling/Expanding path of the Generator
        for i in range(1, depth + 1):
            self.expanding_layers += [UpConv(hidden_channels * 2**i)]

        self.contracting_layers = nn.ModuleList(self.contracting_layers)
        self.expanding_layers = nn.ModuleList(self.expanding_layers)
        
    
    def forward(self, x):
        depth = self.depth
        contractive_x = []

        x = self.conv1(x)
        contractive_x.append(x)

        for i in range(depth):
            x = self.contracting_layers[i](x)
            print(x.shape)
            contractive_x.append(x)
        
        for i in range(depth - 1, -1, -1):
            x = self.expanding_layers[i](x, contractive_x[i])
            print(x.shape)
        x = self.conv_final(x)
        
        return self.sigmoid(x)

In [11]:
class Discriminator(pl.LightningModule):

    def __init__(self, input_channels, hidden_channels=8):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=1)
        self.contract1 = DownConv(hidden_channels, use_bn=False)
        self.contract2 = DownConv(hidden_channels * 2)
        self.contract3 = DownConv(hidden_channels * 4)
        self.contract4 = DownConv(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)
        
    def forward(self, x, y):
        x = torch.cat([x, y], axis=1)
        x0 = self.conv1(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        xn = self.final(x4)
        return xn

In [11]:
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [27]:
class Pix2Pix(pl.LightningModule):
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=32,
                 depth=6,
                 learning_rate=0.0002,
                 lambda_recon=None):
        
        self.gen = Generator(in_channels, out_channels, hidden_channels, depth)
        self.disc = Discriminator(input_channels, hidden_channels=8)
        self.learning_rate = learning_rate

        # intializing weights
        self.gen = self.gen.apply(weights_init)
        self.disc = self.disc.apply(weights_init)

        self.adv_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        if lambda_recon is None:
            lambda_recon = 200
        self.lambda_recon = lambda_recon

In [50]:
unet = UNet(3, 3)

In [54]:
unet(torch.randn(1, 3, 256, 256))
pass

torch.Size([1, 32, 256, 256])
torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 16, 16])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 2048, 4, 4])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 512, 16, 16])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 128, 128])
torch.Size([1, 32, 256, 256])
torch.Size([1, 3, 256, 256])


In [52]:
gen = Generator(3, 3, depth=6)

In [53]:
gen(torch.randn(1, 3, 256, 256)).shape

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 16, 16])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 2048, 4, 4])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 512, 16, 16])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 128, 128])
torch.Size([1, 32, 256, 256])


torch.Size([1, 3, 256, 256])