In [1]:
import torch
from torch import nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, 
                      out_channels=out_channels, 
                      kernel_size=4, 
                      stride=stride,
                      padding=1,
                      bias="False",
                      padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )


    def forward(self, x):
        return self.conv(x)

ConvBlock(5,10,1)

ConvBlock(
  (conv): Sequential(
    (0): Conv2d(5, 10, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
)

In [2]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features = [64, 128, 256, 512]):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels*2, 
                out_channels=features[0], 
                kernel_size=4, 
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_c = features[0]
        for feature in features[1:]:
            layers.append(ConvBlock(in_c, feature, stride= 1 if feature == features[-1] else 2))
            in_c = feature

        layers.append(
            nn.Conv2d(
                in_c, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        # layers.append(nn.AdaptiveAvgPool2d((30, 30)))
        self.all = nn.Sequential(*layers)
    
    def forward(self, x, y):
        X = torch.cat([x, y], dim=1)
        print(X.shape)
        X = self.initial(X)
        print(X.shape)
        return self.all(X)
        

In [4]:
x = torch.rand(1, 3, 256, 256)
y = torch.rand(1, 3, 256, 256)

d = Discriminator(in_channels=3)
d(x, y).shape

torch.Size([1, 6, 256, 256])
torch.Size([1, 64, 128, 128])


torch.Size([1, 1, 30, 30])

In [5]:
d

Discriminator(
  (initial): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (all): Sequential(
    (0): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): ConvBlock(
      (conv): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.