In [None]:
import torch
import torch.nn as nn

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling unit: Conv2d -> optional BatchNorm2d -> LeakyReLU(0.2).

    With the default ``kernel_size=4, stride=2, padding=1`` the convolution
    halves an even spatial size (e.g. 128x128 -> 64x64).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Passed to ``nn.Conv2d``.
        stride: Passed to ``nn.Conv2d``.
        padding: Passed to ``nn.Conv2d``.
        use_norm: When True, insert ``BatchNorm2d`` between the convolution
            and the activation.

    Raises:
        ValueError: From ``forward`` when the input is not a ``torch.Tensor``.
    """

    def __init__(self, in_channels = 3, out_channels = 64, kernel_size = 4, stride = 2, padding = 1, use_norm = True):
        super().__init__()

        # Hyper-parameters are kept as attributes so ``block`` can read them.
        (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
        ) = (in_channels, out_channels, kernel_size, stride, padding)
        self.is_normalization = use_norm

        self.encoder_block = self.block()

    def block(self):
        """Assemble the layer list, store it on ``self.layers`` and wrap it in ``nn.Sequential``."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        )
        normalization = (
            [nn.BatchNorm2d(num_features=self.out_channels)]
            if self.is_normalization
            else []
        )
        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.layers = [convolution, *normalization, activation]
        return nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through the block; non-tensor input raises ValueError."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be in the format of the tensor".capitalize())
        return self.encoder_block(x)


if __name__ == "__main__":
    # Smoke test for the encoder stack: a 128x128 RGB image goes through five
    # stride-2 EncoderBlocks (128 -> 4 spatially) and a 1x1 projection to 4000
    # channels, so the result must be a (1, 4000, 4, 4) tensor.
    kernel_size, stride, padding = 4, 2, 1
    base_width = 64

    # (in_channels, out_channels, use_norm) for every EncoderBlock: two plain
    # blocks at the base width, then three BatchNorm blocks that double it.
    block_specs = [(3, base_width, False), (base_width, base_width, False)]
    width = base_width
    for _ in range(3):
        block_specs.append((width, 2 * width, True))
        width *= 2

    layers = [
        EncoderBlock(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            use_norm=norm,
        )
        for c_in, c_out, norm in block_specs
    ]
    layers.append(
        nn.Conv2d(in_channels=width, out_channels=4000, kernel_size=1, stride=1, padding=0)
    )

    model = nn.Sequential(*layers)

    assert model(torch.randn(1, 3, 128, 128)).size() == (1, 4000, 4, 4)

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling unit: ConvTranspose2d -> optional BatchNorm2d -> ReLU.

    With the default ``kernel_size=4, stride=2, padding=1`` the transposed
    convolution doubles the spatial size (e.g. 4x4 -> 8x8).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        kernel_size: Passed to ``nn.ConvTranspose2d``.
        stride: Passed to ``nn.ConvTranspose2d``.
        padding: Passed to ``nn.ConvTranspose2d``.
        use_norm: When True (the default), insert ``BatchNorm2d`` between the
            transposed convolution and the activation. Mirrors the flag of the
            same name on ``EncoderBlock``.

    Raises:
        ValueError: From ``forward`` when the input is not a ``torch.Tensor``.
    """

    def __init__(
        self,
        in_channels=4000,
        out_channels=512,
        kernel_size=4,
        stride=2,
        padding=1,
        use_norm=True,
    ):
        super(DecoderBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Same attribute name EncoderBlock uses for its normalization flag.
        self.is_normalization = use_norm

        self.decoder_block = self.block()

    def block(self):
        """Build the layer list, keep it on ``self.layers`` and return it as ``nn.Sequential``."""
        self.layers = [
            nn.ConvTranspose2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
            )
        ]

        if self.is_normalization:
            self.layers.append(nn.BatchNorm2d(num_features=self.out_channels))

        self.layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*self.layers)

    def forward(self, x):
        """Run ``x`` through the block; non-tensor input raises ValueError."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be in the format of the tensor".capitalize())
        return self.decoder_block(x)


if __name__ == "__main__":
    # Smoke test for the decoder: four stride-2 DecoderBlocks upsample a
    # 4000-channel 4x4 map to 64 channels at 64x64, then a stride-1
    # ConvTranspose2d (kernel 3, padding 1) maps it to 3 channels at the same
    # 64x64 resolution.
    layers = []

    in_channels = 4000
    out_channels = 512
    kernel_size = 4
    stride = 2
    padding = 1

    for _ in range(4):
        layers.append(
            DecoderBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        )
        in_channels = out_channels
        out_channels = in_channels // 2

    layers.append(
        nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=3,
            kernel_size=kernel_size - 1,  # 3x3 kernel
            stride=1,  # no further upsampling
            padding=padding,
        )
    )

    model = nn.Sequential(*layers)

    output_size = model(torch.randn(1, 4000, 4, 4)).size()
    print(output_size)
    # Like the encoder cell, fail loudly instead of only printing the shape.
    assert output_size == (1, 3, 64, 64)

In [None]:
class Generator(nn.Module):
    """Encoder -> 1x1 bottleneck -> decoder network built from Encoder/DecoderBlocks.

    For an ``(N, in_channels, 128, 128)`` input the network yields an
    ``(N, 3, 64, 64)`` output:

    * five stride-2 ``EncoderBlock``s shrink 128x128 to 4x4 (the first two
      without BatchNorm at ``out_channels`` width, the next three with
      BatchNorm, each doubling the width);
    * a 1x1 ``Conv2d`` projects to ``latent_channels`` channels;
    * four stride-2 ``DecoderBlock``s upsample 4x4 to 64x64, halving the width
      each time starting from the encoder's final width;
    * a stride-1 3x3 ``ConvTranspose2d`` maps to 3 channels.

    Args:
        in_channels: Channels of the input image.
        out_channels: Base width of the first encoder blocks.
        latent_channels: Channels of the 1x1 bottleneck (defaults to 4000).

    Raises:
        ValueError: From ``forward`` when the input is not a ``torch.Tensor``.
    """

    def __init__(self, in_channels=3, out_channels=64, latent_channels=4000):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        self.kernel_size = 4
        self.stride = 2
        self.padding = 1

        # Modules are created in network order (encoder, bottleneck, decoder,
        # head) so initialization consumes the RNG in that same sequence.
        self.layers, encoded_channels = self._encoder_layers(in_channels, out_channels)

        # 1x1 projection into the bottleneck; spatial size is unchanged.
        self.layers.append(
            nn.Conv2d(
                in_channels=encoded_channels,
                out_channels=latent_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )

        self.layers.extend(self._decoder_layers(latent_channels, encoded_channels))

        self.model = nn.Sequential(*self.layers)

    def _encoder_layers(self, in_channels, width):
        """Return the five EncoderBlocks and the channel count they end with."""
        layers = []

        # Two blocks at the base width, without normalization.
        for _ in range(2):
            layers.append(
                EncoderBlock(
                    in_channels=in_channels,
                    out_channels=width,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    use_norm=False,
                )
            )
            in_channels = width

        # Three BatchNorm blocks, each doubling the channel count.
        for _ in range(3):
            layers.append(
                EncoderBlock(
                    in_channels=in_channels,
                    out_channels=in_channels * 2,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    use_norm=True,
                )
            )
            in_channels *= 2

        return layers, in_channels

    def _decoder_layers(self, in_channels, width):
        """Return the four DecoderBlocks followed by the 3-channel output layer."""
        layers = []

        # Each block upsamples 2x; widths go width, width/2, width/4, width/8.
        for _ in range(4):
            layers.append(
                DecoderBlock(
                    in_channels=in_channels,
                    out_channels=width,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                )
            )
            in_channels = width
            width //= 2

        # Kernel 3, stride 1, padding 1 keeps the spatial size.
        # NOTE(review): output is always 3 channels, even when in_channels != 3
        # -- confirm this is intended.
        layers.append(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=3,
                kernel_size=self.kernel_size - 1,
                stride=1,
                padding=self.padding,
            )
        )

        return layers

    def forward(self, x):
        """Run ``x`` through the whole network; non-tensor input raises ValueError."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be in the format of the tensor".capitalize())
        return self.model(x)


if __name__ == "__main__":
    # Smoke test: a 128x128 RGB input should come out as a 3-channel 64x64 map.
    netG = Generator()

    output_size = netG(torch.randn(1, 3, 128, 128)).size()
    print(output_size)
    # Like the encoder cell, fail loudly instead of only printing the shape.
    assert output_size == (1, 3, 64, 64)