## Import libraries

In [1]:
import torch
import torch.nn as nn
from torchviz import make_dot

## Discriminator

In [2]:
class CNNBlock(nn.Module):
    """4x4 convolution followed by batch normalisation and LeakyReLU(0.2).

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
        padding: Padding added to each spatial side before the convolution.
            Defaults to 0, which reproduces the block's previous behaviour.

    Note:
        ``padding_mode="reflect"`` only changes anything when ``padding`` is
        greater than 0. With the default of 0 no padding is applied at all and
        each block shrinks the feature map by the full kernel extent.
        NOTE(review): the reflect mode suggests a padding of 1 was intended;
        confirm before changing the default, since that alters output shapes.
    """

    def __init__(self, in_channels, out_channels, stride, padding=0):
        super(CNNBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=padding,
                # No conv bias: the BatchNorm2d that follows has its own shift.
                bias=False,
                padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Apply conv -> batch norm -> LeakyReLU to ``x`` and return the result."""
        return self.conv(x)

In [3]:
class Discriminator(nn.Module):
    """Conditional convolutional discriminator that scores an (x, y) image pair.

    The two images are concatenated along the channel dimension, passed
    through an initial strided convolution, a stack of ``CNNBlock`` layers and
    a final convolution that produces a single-channel score map.

    Args:
        in_channels: Channels of *each* input image; the first convolution
            therefore receives ``in_channels * 2`` channels.
        out_channels: Sequence of feature widths. The first entry is the width
            of the initial convolution; every later entry adds one
            ``CNNBlock``. All blocks use stride 2 except the last one, which
            uses stride 1.

    Note:
        The notebook's recorded run maps two (1, 3, 256, 256) inputs to a
        (1, 1, 26, 26) score map.
        NOTE(review): ``CNNBlock`` is built here without padding; if a
        30x30 PatchGAN-style output was intended, confirm the block's padding.
    """

    def __init__(self, in_channels=3, out_channels=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        # Copy into a list so tuples, lists and other sequences all work and
        # the caller's object is never shared.
        features = list(out_channels)

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect'
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        prev_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride by position, not by value: comparing the width
            # with features[-1] would also give stride 1 to any earlier layer
            # that happens to share the last layer's width.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(prev_channels, feature, stride=stride))
            prev_channels = feature

        # Final projection to a single-channel score map.
        layers.append(
            nn.Conv2d(
                prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        Args:
            x: First image batch.
            y: Second image batch, concatenated with ``x`` on dimension 1.

        Returns:
            The single-channel output of the final convolution.
        """
        pair = torch.cat([x, y], dim=1)
        pair = self.initial(pair)
        return self.model(pair)

In [41]:
def visualize_model(x, y):
    """Draw the autograd graph of a freshly built Discriminator's forward pass.

    Args:
        x: First input batch passed to the discriminator.
        y: Second input batch passed to the discriminator.

    The graph is built with ``make_dot`` from the discriminator output and its
    named parameters, then ``.view()`` is called on the result (presumably
    rendering and opening it in an external viewer - confirm against torchviz).
    """
    discriminator = Discriminator(in_channels=3)
    scores = discriminator(x, y)
    named_params = dict(discriminator.named_parameters())
    graph = make_dot(scores, params=named_params)
    graph.view()

In [42]:
def test(x, y):
    """Smoke-test the Discriminator: print its output shape and its layout.

    Args:
        x: First input batch passed to the discriminator.
        y: Second input batch passed to the discriminator.
    """
    discriminator = Discriminator(in_channels=3)
    scores = discriminator(x, y)
    # Shape first, then the module tree.
    print(scores.shape)
    print(discriminator)
    

In [43]:
# Two random 3-channel 256x256 images (batch size 1) stand in for an image pair.
x = torch.randn(1, 3, 256, 256)
y = torch.randn(1, 3, 256, 256)

# Draw the computation graph, then print the output shape and module layout.
visualize_model(x, y)
test(x, y)

torch.Size([1, 1, 26, 26])
Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(512, eps=1e-0

<img src="Discriminator.png" alt="Discriminator" >

## Generator

In [7]:
class Block(nn.Module):
    """Resampling block: (transposed) 4x4 convolution, batch norm, activation.

    Args:
        in_channels: Number of channels in the incoming feature map.
        out_channels: Number of channels produced by the block.
        down: If True use a stride-2 ``Conv2d`` with reflect padding;
            otherwise use a stride-2 ``ConvTranspose2d``.
        act: ``"relu"`` selects ``ReLU``; any other value selects
            ``LeakyReLU(0.2)``.
        use_dropout: If True, ``Dropout(0.5)`` is applied to the block output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()

        if down:
            resample = nn.Conv2d(
                in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect"
            )
        else:
            resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        if act == "relu":
            activation = nn.ReLU()
        else:
            activation = nn.LeakyReLU(0.2)

        self.conv = nn.Sequential(resample, nn.BatchNorm2d(out_channels), activation)
        self.use_dropout = use_dropout
        # The dropout module is always created; forward() only uses it when
        # use_dropout is set.
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        """Run the block on ``x``, applying dropout only if it was enabled."""
        features = self.conv(x)
        if self.use_dropout:
            return self.dropout(features)
        return features

In [36]:
class Generator(nn.Module):
    """Encoder-decoder generator with skip connections between mirrored stages.

    The encoder halves the spatial size eight times (``initial_down``,
    ``down1``-``down6`` and ``bottleneck``); the decoder doubles it eight
    times (``up1``-``up7`` and ``final_up``). Every decoder stage after
    ``up1`` receives the previous decoder output concatenated on the channel
    dimension with the encoder output of matching resolution.

    Args:
        in_channel: Channels of the input image; the output image has the
            same number of channels.
        out_channel: Base feature width. Encoder widths grow from
            ``out_channel`` up to ``out_channel * 8``.

    Note:
        The notebook's recorded run maps a (1, 3, 256, 256) input to a
        (1, 3, 256, 256) output. With eight stride-2 stages a 256x256 input
        reaches 1x1 at the bottleneck.
        Dropout is enabled only on the first three decoder blocks
        (``up1``-``up3``); the later, higher-resolution blocks run without it.
    """

    def __init__(self, in_channel=3, out_channel=64):
        super(Generator, self).__init__()
        base = out_channel

        # ---- Encoder -----------------------------------------------------
        # The first stage has no batch norm, so it keeps the conv bias.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channel, base, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )
        self.down1 = Block(base, base * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(base * 2, base * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(base * 4, base * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(base * 8, base * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(base * 8, base * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(base * 8, base * 8, down=True, act="leaky", use_dropout=False)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(base * 8, base * 8, 4, 2, 1),
            nn.ReLU()
        )

        # ---- Decoder -----------------------------------------------------
        # From up2 onwards the input width is doubled by the skip concatenation.
        self.up1 = Block(base * 8, base * 8, down=False, act='relu', use_dropout=True)
        self.up2 = Block(base * 8 * 2, base * 8, down=False, act='relu', use_dropout=True)
        self.up3 = Block(base * 8 * 2, base * 8, down=False, act='relu', use_dropout=True)
        # Previously every decoder block used dropout; it is now limited to
        # the three deepest blocks so the high-resolution stages stay clean.
        self.up4 = Block(base * 8 * 2, base * 8, down=False, act='relu', use_dropout=False)
        self.up5 = Block(base * 8 * 2, base * 4, down=False, act='relu', use_dropout=False)
        self.up6 = Block(base * 4 * 2, base * 2, down=False, act='relu', use_dropout=False)
        self.up7 = Block(base * 2 * 2, base, down=False, act='relu', use_dropout=False)

        # Tanh bounds the generated image to [-1, 1].
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(base * 2, in_channel, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate the input batch ``x`` into an output image batch.

        Args:
            x: Input image batch with ``in_channel`` channels.

        Returns:
            The ``Tanh`` output of ``final_up``, with ``in_channel`` channels.
        """
        # Encoder: keep every intermediate output for the skip connections.
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)

        # Decoder: each stage consumes the previous decoder output joined
        # with the encoder output of the same resolution.
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))

In [44]:
def test(x):
    """Smoke-test the Generator: print its output shape and its layout.

    Redefines the earlier two-argument ``test`` used for the discriminator.

    Args:
        x: Input image batch passed to the generator.
    """
    generator = Generator(3, 64)
    generated = generator(x)
    # Shape first, then the module tree.
    print(generated.shape)
    print(generator)

In [45]:
def visualize_model(x):
    """Draw the autograd graph of a freshly built Generator's forward pass.

    Redefines the earlier two-argument ``visualize_model`` used for the
    discriminator.

    Args:
        x: Input image batch passed to the generator.
    """
    generator = Generator(3, 64)
    generated = generator(x)
    named_params = dict(generator.named_parameters())
    graph = make_dot(generated, params=named_params)
    graph.view()

In [46]:
# One random 3-channel 256x256 image (batch size 1) as generator input.
x = torch.randn(1, 3, 256, 256)

# Draw the computation graph, then print the output shape and module layout.
visualize_model(x)
test(x)

torch.Size([1, 3, 256, 256])
Generator(
  (initial_down): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (down1): Block(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (down2): Block(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (down3): Block(
    (conv): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), p

<img src="Generator.png" alt="Generator" >

## Config

## Utils

## Dataset