In [30]:
import torch
import torch.nn as nn

In [31]:
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel=4, stride=2):
        super(DiscBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel,
                      stride=stride,
                      bias=False,
                      padding_mode='reflect'),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2)
        )
    def forward(self, x):
        return self.conv(x)

In [32]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=None, kernel=4, stride=2, padding=1):
        super().__init__()
        if not features:
            features = [64, 128, 256, 512]
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=kernel, stride=stride, padding=padding, padding_mode='reflect'),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(DiscBlock(in_channels=in_channels,
                                   out_channels=feature,
                                   stride=1 if feature==features[-1] else 2,
                                   kernel=kernel
                                   ))
            in_channels = feature

        layers.append(nn.Conv2d(
            in_channels=in_channels,
            out_channels=1,
            kernel_size=kernel,
            stride = (1, 1),
            padding=1,
            padding_mode='reflect'
        ))
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat((x,y), dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [33]:
def test_discriminator():
    a = Discriminator()
    img = torch.ones(10, 3, 256, 256)
    print(a.forward(img, img).shape)

In [34]:
test_discriminator()

torch.Size([10, 1, 26, 26])


In [None]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act='relu', use_dropout=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=(4,4),
                      stride=(2,2),
                      padding=(1,1),
                      padding_mode='reflect')
            if down
            else nn.ConvTranspose2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=(4,4),
                                    stride=(2,2),
                                    padding=(1,1),
                                    bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=='relu' else nn.LeakyReLU(0.2)
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv(x),
        x = self.dropout(x) if self.use_dropout else x
        return x