In [None]:
from collections import OrderedDict
import torch
import torch.nn as nn

In [None]:
class InputBlock(nn.Module):
    """Generator stem: a 9x9 convolution followed by a PReLU activation.

    With kernel_size=9, stride=1 and padding=4 the height and width are preserved.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of feature maps produced.

    Raises:
        ValueError: if either channel count is missing.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super(InputBlock, self).__init__()

        # Fail fast instead of leaving ``self.model`` undefined, which would only
        # surface later as an AttributeError on the first forward call.
        if in_channels is None or out_channels is None:
            raise ValueError(
                "InputBlock requires in_channels and out_channels, got "
                "in_channels={!r}, out_channels={!r}".format(in_channels, out_channels))

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 9
        self.stride = 1
        self.padding = 4  # (kernel_size - 1) // 2 keeps height/width unchanged at stride 1

        self.model = self.input_block()

    def input_block(self):
        """Build the ``conv -> PReLU`` stack (layer names kept stable for state_dict keys)."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride = self.stride,
            padding = self.padding
        )
        layers["PReLU"] = nn.PReLU()

        return nn.Sequential(layers)

    def forward(self, x=None):
        """Apply the stem to ``x``.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("InputBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + BN(conv(PReLU(BN(conv(x)))))``.

    Both convolutions use 3x3 kernels, stride 1 and padding 1, so spatial size is preserved.

    Args:
        in_channels: channels of the input tensor; must equal ``out_channels`` because
            the block output is added back to its input.
        out_channels: channels produced by both convolutions.
        index: suffix used in layer names (``conv{index}``, ``batchnorm{index}``,
            ``PReLU{index}``, ``conv{index+1}``, ``batchnorm{index+1}``). ``None`` is
            treated as 0.

    Raises:
        ValueError: if a channel count is missing or the two channel counts differ.
    """

    def __init__(self, in_channels = None, out_channels = None, index = None):
        super(ResidualBlock, self).__init__()

        if in_channels is None or out_channels is None:
            raise ValueError(
                "ResidualBlock requires in_channels and out_channels, got "
                "in_channels={!r}, out_channels={!r}".format(in_channels, out_channels))
        # The skip connection ``x + model(x)`` only works when the shapes match.
        if in_channels != out_channels:
            raise ValueError(
                "ResidualBlock needs in_channels == out_channels for its skip connection, "
                "got {} and {}".format(in_channels, out_channels))

        self.in_channels = in_channels
        self.out_channels = out_channels
        # ``index`` only affects layer names; a None index would otherwise fail on ``index + 1``.
        self.index = 0 if index is None else index

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1

        self.model = self.residual_block()

    def residual_block(self):
        """Build ``conv -> BN -> PReLU -> conv -> BN`` (layer names kept stable for state_dict keys)."""
        i = self.index
        layers = OrderedDict()

        layers["conv{}".format(i)] = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["batchnorm{}".format(i)] = nn.BatchNorm2d(num_features=self.out_channels)
        layers["PReLU{}".format(i)] = nn.PReLU()
        layers["conv{}".format(i + 1)] = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["batchnorm{}".format(i + 1)] = nn.BatchNorm2d(num_features=self.out_channels)

        return nn.Sequential(layers)

    def forward(self, x=None):
        """Return ``x + model(x)``.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("ResidualBlock.forward requires an input tensor, got None")
        return x + self.model(x)

In [None]:
class MiddleBlock(nn.Module):
    """``conv -> BN`` stage whose output is summed with a skip tensor.

    The convolution uses a 3x3 kernel, stride 1 and padding 1 (spatial size preserved).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.

    Raises:
        ValueError: if either channel count is missing.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super(MiddleBlock, self).__init__()

        # Fail fast instead of leaving ``self.model`` undefined.
        if in_channels is None or out_channels is None:
            raise ValueError(
                "MiddleBlock requires in_channels and out_channels, got "
                "in_channels={!r}, out_channels={!r}".format(in_channels, out_channels))

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1

        self.model = self.middle_block()

    def middle_block(self):
        """Build the ``conv -> BN`` stack (layer names kept stable for state_dict keys)."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride = self.stride,
            padding = self.padding
        )
        layers["batchnorm"] = nn.BatchNorm2d(num_features=self.out_channels)

        return nn.Sequential(layers)

    def forward(self, x = None, skip_info = None):
        """Return ``model(x) + skip_info``.

        ``skip_info`` must be addable to the block output (same shape, or broadcastable).

        Raises:
            ValueError: if ``x`` or ``skip_info`` is None.
        """
        if x is None or skip_info is None:
            raise ValueError("MiddleBlock.forward requires both x and skip_info tensors")
        return self.model(x) + skip_info

In [None]:
class UpSampleBlock(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv, then ``PixelShuffle``, plus PReLU on the first block.

    The conv emits ``out_channels`` maps; ``PixelShuffle(r)`` rearranges them into
    ``out_channels // r**2`` maps with ``r`` times the height and width.

    Args:
        in_channels: channels of the input tensor.
        out_channels: conv output channels; must be divisible by ``upscale_factor ** 2``.
        is_first_block: if True, a PReLU is appended after the pixel shuffle.
        index: suffix used in layer names (``conv{index}``, ``pixel_shuffle{index}``).
        upscale_factor: spatial upscaling factor ``r`` (default 2).

    Raises:
        ValueError: for a missing channel count, a factor below 1, or an ``out_channels``
            that is not divisible by ``upscale_factor ** 2``.
    """

    def __init__(self, in_channels = None, out_channels = None, is_first_block = False, index = None, upscale_factor = 2):
        super(UpSampleBlock, self).__init__()

        # Fail fast instead of leaving ``self.model`` undefined.
        if in_channels is None or out_channels is None:
            raise ValueError(
                "UpSampleBlock requires in_channels and out_channels, got "
                "in_channels={!r}, out_channels={!r}".format(in_channels, out_channels))
        if upscale_factor < 1:
            raise ValueError("upscale_factor must be >= 1, got {}".format(upscale_factor))
        # PixelShuffle would otherwise only fail at forward time.
        if out_channels % (upscale_factor ** 2) != 0:
            raise ValueError(
                "out_channels ({}) must be divisible by upscale_factor**2 ({})".format(
                    out_channels, upscale_factor ** 2))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_first_block = is_first_block
        self.index = index

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1
        self.factor = upscale_factor

        self.model = self.up_sample_block()

    def up_sample_block(self):
        """Build ``conv -> PixelShuffle [-> PReLU]`` (layer names kept stable for state_dict keys)."""
        layers = OrderedDict()
        layers["conv{}".format(self.index)] = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["pixel_shuffle{}".format(self.index)] = nn.PixelShuffle(upscale_factor=self.factor)

        # NOTE(review): only the first block gets an activation, so later upsampling stages
        # stay linear — confirm this is intended. The key has no index suffix; it is kept
        # as-is so existing checkpoints still load.
        if self.is_first_block:
            layers["PReLU"] = nn.PReLU()

        return nn.Sequential(layers)

    def forward(self, x):
        """Upsample ``x``.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("UpSampleBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class OutputBlock(nn.Module):
    """Generator head: a 9x9 convolution followed by tanh.

    With kernel_size=9, stride=1 and padding=4 the height and width are preserved;
    tanh bounds every output value to [-1, 1].

    Args:
        in_channels: channels of the input feature maps.
        out_channels: channels of the produced image.

    Raises:
        ValueError: if either channel count is missing.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super(OutputBlock, self).__init__()

        # Fail fast instead of leaving ``self.model`` undefined.
        if in_channels is None or out_channels is None:
            raise ValueError(
                "OutputBlock requires in_channels and out_channels, got "
                "in_channels={!r}, out_channels={!r}".format(in_channels, out_channels))

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 9
        self.stride = 1
        self.padding = 4

        self.model = self.output_block()

    def output_block(self):
        """Build the ``conv -> tanh`` stack (layer names kept stable for state_dict keys)."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride = self.stride,
            padding = self.padding
        )
        layers["tanh"] = nn.Tanh()

        return nn.Sequential(layers)

    def forward(self, x):
        """Map feature maps to the output image.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("OutputBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class Generator(nn.Module):
    """Generator: input block -> residual trunk -> middle block (long skip) -> upsampling -> output.

    Each UpSampleBlock upsamples by 2 via PixelShuffle, so with the default two upsample
    blocks the output has 4x the input height and width.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the generated image (default 3).
        num_features: width of the internal feature maps (default 64).
        num_residual_blocks: number of ResidualBlocks in the trunk (default 16).
        num_upsample_blocks: number of x2 UpSampleBlocks (default 2).

    NOTE: ``InputBlock`` is looked up by name when a Generator is constructed. This notebook
    later redefines ``InputBlock`` for the discriminator, so construct Generators before that
    cell runs (or rename one of the classes); otherwise the generator silently gets the
    discriminator's input block.
    """

    def __init__(self, in_channels=3, out_channels=3, num_features=64,
                 num_residual_blocks=16, num_upsample_blocks=2):
        super(Generator, self).__init__()

        self.num_repetitive = num_residual_blocks

        self.input_block = InputBlock(in_channels=in_channels, out_channels=num_features)

        self.residual_block = nn.Sequential(
            *[ResidualBlock(in_channels=num_features, out_channels=num_features, index=index)
              for index in range(self.num_repetitive)])

        self.middle_block = MiddleBlock(in_channels=num_features, out_channels=num_features)

        # Each UpSampleBlock's conv emits 4 * num_features maps, which PixelShuffle(2) folds
        # back into num_features maps at twice the resolution. Only the first gets a PReLU.
        self.up_sample = nn.Sequential(
            *[UpSampleBlock(in_channels=num_features, out_channels=num_features * 4,
                            is_first_block=(index == 0), index=index)
              for index in range(num_upsample_blocks)])

        self.out_block = OutputBlock(in_channels=num_features, out_channels=out_channels)

    def forward(self, x):
        """Generate an upsampled image from ``x``.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("Generator.forward requires an input tensor, got None")

        stem = self.input_block(x)
        trunk = self.residual_block(stem)
        # Long skip connection: the middle block adds the stem output back in.
        merged = self.middle_block(trunk, stem)
        return self.out_block(self.up_sample(merged))

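#### Discriminator

> **Note:** the first cell below defines a new `InputBlock`, replacing the generator's class of the same name. `Generator.__init__` looks that name up when it runs, so instantiate any `Generator` before executing these cells (or rename one of the two classes).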

In [None]:
class InputBlock(nn.Module):
    """Discriminator stem: conv -> LeakyReLU(0.2) -> conv -> BN -> LeakyReLU(0.2).

    Both convolutions use 3x3 kernels with stride 1 and padding 1, so the height and
    width are preserved.

    NOTE(review): this class reuses the name ``InputBlock`` and therefore replaces the
    generator's ``InputBlock`` in the notebook namespace; any ``Generator`` constructed
    after this cell runs picks up this class instead. Consider renaming it (e.g. to
    ``DiscriminatorInputBlock``) together with its use in ``Discriminator``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super(InputBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1

        self.model = self.input_block()

    def input_block(self):
        """Build the two-conv stem as an unnamed ``nn.Sequential`` (keys ``model.0`` ... ``model.4``)."""
        return nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, self.kernel_size, self.stride, self.padding),
            nn.BatchNorm2d(num_features=self.out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

    def forward(self, x):
        """Apply the stem to ``x``.

        Raises:
            ValueError: if ``x`` is None (previously None was returned silently and
                failed later in an unrelated layer).
        """
        if x is None:
            raise ValueError("InputBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class FeatureBlock(nn.Module):
    """Discriminator feature stage: conv -> BN -> LeakyReLU(0.2).

    With the defaults (kernel 3, stride 2, padding 1) each block roughly halves the
    height and width.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 2).
        padding: convolution zero-padding (default 1).
    """

    def __init__(self, in_channels = None, out_channels = None, kernel_size=3, stride=2, padding=1):
        super(FeatureBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.model = self.feature_block()

    def feature_block(self):
        """Build the ``conv -> BN -> LeakyReLU`` stack as an unnamed ``nn.Sequential``."""
        return nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding),
            nn.BatchNorm2d(num_features=self.out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

    def forward(self, x):
        """Apply the feature stage to ``x``.

        Raises:
            ValueError: if ``x`` is None (previously None was returned silently).
        """
        if x is None:
            raise ValueError("FeatureBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class OutBlock(nn.Module):
    """Discriminator head: 1x1 conv -> LeakyReLU(0.2) -> 1x1 conv -> tanh.

    Args:
        in_channels: the discriminator's *base* filter count. The head expects
            ``in_channels * 8`` input channels and widens to ``in_channels * 16``
            before projecting to ``out_channels``.
        out_channels: number of output maps of the final conv.

    NOTE(review): tanh bounds scores to [-1, 1]. A probability-style loss such as
    ``BCELoss`` would need a sigmoid (or raw logits with ``BCEWithLogitsLoss``)
    instead — confirm which loss this head is paired with.
    """

    def __init__(self, in_channels = None, out_channels=None):
        super(OutBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # 1x1 kernels act as per-position linear layers. The kernel size is fixed rather
        # than tied to out_channels, so out_channels can be any number of output maps.
        self.kernel_size = 1

        self.model = self.output_block()

    def output_block(self):
        """Build the head as an unnamed ``nn.Sequential``."""
        return nn.Sequential(
            nn.Conv2d(self.in_channels*8, self.in_channels*16, self.kernel_size),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(self.in_channels*16, self.out_channels, self.kernel_size),
            nn.Tanh()
        )

    def forward(self, x):
        """Score the pooled features ``x``.

        Raises:
            ValueError: if ``x`` is None (previously None was returned silently).
        """
        if x is None:
            raise ValueError("OutBlock.forward requires an input tensor, got None")
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Discriminator: input block, seven feature blocks, global pooling, then a 1x1-conv head.

    Odd-indexed feature blocks double the channel width, so after construction
    ``self.out_channels`` holds the final width (8x the base for the seven blocks here).

    Args:
        in_channels: channels of the input image.
        out_channels: base filter count of the input block.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super(Discriminator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Remember the base width separately; self.out_channels is grown in the loop below.
        self.filters = out_channels

        self.kernel_size, self.stride, self.padding = 3, 1, 1
        self.num_repetitive = 7
        self.layers = []

        self.input = InputBlock(in_channels=self.in_channels, out_channels=self.out_channels)

        for index in range(self.num_repetitive):
            # Odd-indexed blocks double the width; even-indexed ones keep it unchanged.
            next_channels = self.out_channels * 2 if index % 2 == 1 else self.out_channels
            self.layers.append(FeatureBlock(in_channels = self.out_channels, out_channels=next_channels))
            self.out_channels = next_channels

        self.features = nn.Sequential(*self.layers)

        # NOTE(review): the attribute is named avg_pool but this is a *max* pool — confirm
        # which reduction is intended.
        self.avg_pool = nn.AdaptiveMaxPool2d(output_size=1)

        self.output = OutBlock(in_channels=self.filters, out_channels=1)

    def forward(self, x):
        """Return the head's scores for the image batch ``x``."""
        stem = self.input(x)
        pooled = self.avg_pool(self.features(stem))
        return self.output(pooled)

In [None]:
netD = Discriminator(in_channels=3, out_channels=64)

# Shape smoke test only: gradients are not needed, so skip recording the autograd graph.
# 256 px halves through seven stride-2 feature blocks to 2 px, then pooling reduces it to 1.
with torch.no_grad():
    d_scores = netD(torch.randn(1, 3, 256, 256))

assert d_scores.shape == (1, 1, 1, 1), d_scores.shape
d_scores.size()