## Imports libraries

In [1]:
import torch
import torch.nn as nn
from torchviz import make_dot
from graphviz import Source
from IPython.display import display

## Discriminator

In [2]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels,stride):
        super(CNNBlock,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 4,stride,bias=False,padding_mode="reflect"
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    def forward(self, x):
        return self.conv(x)

In [3]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, out_channels=[64,128,256,512] ):
        super(Discriminator,self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels *2,
                out_channels[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode='reflect'
            ),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = out_channels[0]
        for feature in out_channels[1:]:
            layers.append(
                CNNBlock(in_channels, feature,stride=1 if feature == out_channels[-1] else 2)
            ),
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,1,kernel_size=4,stride=1,padding=1,padding_mode="reflect"
            ),
        )
        
        self.model = nn.Sequential(*layers)
        
    def forward(self, x,y):
        x = torch.cat([x,y],dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [4]:
def visualize_model(x,y):
    model = Discriminator(in_channels=3)
    dummy_input = (x, y)
    dot = make_dot(model(dummy_input[0], dummy_input[1]), params=dict(model.named_parameters()))
    dot.view()

In [5]:
def test(x,y):
    model = Discriminator(in_channels=3)
    preds = model(x, y)
    print(model)

In [6]:

x = torch.randn((1, 3, 256, 256))
y = torch.randn((1, 3, 256, 256))
visualize_model(x,y)
test(x,y)

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (model): Sequential(
    (0): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): CNNBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), bias=False, padding_mode=reflect)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=Tru

<img src="Discriminator.png" alt="Discriminator" width="600" heigth='500'>

## Generator

## Config

## Utils

## Dataset