In [10]:
from collections import OrderedDict
import torch
import torch.nn as nn

In [13]:
class InputBlock(nn.Module):
    """Entry stage: one 9x9 convolution followed by a PReLU activation.

    The convolution uses stride 1 and padding 4, which for a 9x9 kernel
    leaves the height and width of the input unchanged.

    Args:
        in_channels: Channel count of the tensors passed to ``forward``.
        out_channels: Channel count produced by the convolution.

    Raises:
        ValueError: If the layers cannot be built from the given channel
            counts, e.g. when either one is left at its ``None`` default.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 9
        self.stride = 1
        self.padding = 4

        try:
            self.model = self.input_block()
        except Exception as e:
            # Fail at construction time. Printing and carrying on would leave
            # the module without a ``model`` attribute, and the real cause
            # would only resurface later as an AttributeError in forward().
            raise ValueError(
                "Input block not implemented: could not build layers with "
                "in_channels={!r}, out_channels={!r}".format(
                    self.in_channels, self.out_channels
                )
            ) from e

    def input_block(self):
        """Build the ``conv`` -> ``PReLU`` stack as a named ``nn.Sequential``."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride = self.stride,
            padding = self.padding
        )
        layers["PReLU"] = nn.PReLU()

        return nn.Sequential(layers)

    def forward(self, x=None):
        """Run ``x`` through the convolution and the PReLU.

        Args:
            x: Input tensor; its channel dimension must match ``in_channels``.

        Returns:
            The activated convolution output.

        Raises:
            ValueError: If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError(
                "Input block not implemented: forward() needs an input tensor"
            )
        return self.model(x)

In [26]:
class ResidualBlock(nn.Module):
    """Residual unit: conv -> batchnorm -> PReLU -> conv -> batchnorm, plus input.

    Both convolutions are 3x3 with stride 1 and padding 1, so height and
    width are unchanged. ``forward`` adds the input to the stack's output,
    which means the input and output shapes must be compatible for ``+``
    (normally ``in_channels == out_channels``).

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by both convolutions.
        index: Integer used to name the layers. The first conv/batchnorm/PReLU
            are suffixed with ``index`` and the second conv/batchnorm with
            ``index + 1``.

    Raises:
        ValueError: If the layers cannot be built from the given arguments,
            e.g. when any of them is left at its ``None`` default.
    """

    def __init__(self, in_channels = None, out_channels = None, index = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.index = index

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1

        try:
            self.model = self.residual_block()
        except Exception as e:
            # Surface the configuration error immediately. Swallowing it would
            # leave the module without ``model`` and defer the failure to an
            # unrelated AttributeError inside forward().
            raise ValueError(
                "Residual block not implemented: could not build layers with "
                "in_channels={!r}, out_channels={!r}, index={!r}".format(
                    self.in_channels, self.out_channels, self.index
                )
            ) from e

    def residual_block(self):
        """Build the five-layer stack as a named ``nn.Sequential``."""
        first = self.index
        second = self.index + 1  # requires ``index`` to support ``+ 1``

        layers = OrderedDict()

        layers["conv{}".format(first)] = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["batchnorm{}".format(first)] = nn.BatchNorm2d(
            num_features=self.out_channels)
        layers["PReLU{}".format(first)] = nn.PReLU()

        # The second convolution maps out_channels -> out_channels.
        layers["conv{}".format(second)] = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["batchnorm{}".format(second)] = nn.BatchNorm2d(
            num_features=self.out_channels)

        return nn.Sequential(layers)

    def forward(self, x=None):
        """Return ``x + model(x)``.

        Args:
            x: Input tensor; its channel dimension must match ``in_channels``.

        Raises:
            ValueError: If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError(
                "Residual block not implemented: forward() needs an input tensor"
            )
        return x + self.model(x)

In [40]:
class MiddleBlock(nn.Module):
    """3x3 convolution plus batch normalisation, summed with a skip tensor.

    The convolution uses stride 1 and padding 1, so height and width are
    unchanged. ``forward`` adds ``skip_info`` to the normalised output, so
    the two must have shapes that are compatible for ``+``.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution.

    Raises:
        ValueError: If the layers cannot be built from the given channel
            counts, e.g. when either one is left at its ``None`` default.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1

        try:
            self.model = self.middle_block()
        except Exception as e:
            # Raise instead of printing: a swallowed error here would leave
            # the module without ``model`` and only fail later in forward()
            # with a misleading AttributeError.
            raise ValueError(
                "Middle block not implemented: could not build layers with "
                "in_channels={!r}, out_channels={!r}".format(
                    self.in_channels, self.out_channels
                )
            ) from e

    def middle_block(self):
        """Build the ``conv`` -> ``batchnorm`` stack as a named ``nn.Sequential``."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride = self.stride,
            padding = self.padding
        )
        layers["batchnorm"] = nn.BatchNorm2d(num_features=self.out_channels)

        return nn.Sequential(layers)

    def forward(self, x = None, skip_info = None):
        """Return ``model(x) + skip_info``.

        Args:
            x: Input tensor; its channel dimension must match ``in_channels``.
            skip_info: Tensor added to the normalised convolution output.

        Raises:
            ValueError: If ``x`` or ``skip_info`` is ``None``.
        """
        if x is None or skip_info is None:
            raise ValueError(
                "Middle block not implemented: forward() needs both x and skip_info"
            )
        return self.model(x) + skip_info

In [44]:
class UpSampleBlock(nn.Module):
    """3x3 convolution followed by a 2x pixel shuffle and an optional PReLU.

    ``nn.PixelShuffle`` with ``upscale_factor=2`` rearranges groups of four
    channels into a 2x larger spatial grid. ``out_channels`` therefore has to
    be a multiple of 4; otherwise the shuffle fails when ``forward`` runs.
    The block's output has ``out_channels / 4`` channels and twice the
    height and width of the input.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution (before the
            pixel shuffle).
        is_first_block: When true, a PReLU is appended after the pixel
            shuffle under the fixed name ``"PReLU"``.
        index: Value used to suffix the conv and pixel-shuffle layer names.

    Raises:
        ValueError: If the layers cannot be built from the given arguments,
            e.g. when a channel count is left at its ``None`` default.
    """

    def __init__(self, in_channels = None, out_channels = None, is_first_block = False, index = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_first_block = is_first_block
        self.index = index

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1
        self.factor = 2

        try:
            self.model = self.up_sample_block()
        except Exception as e:
            # Raise instead of printing: otherwise the module is left without
            # ``model`` and the failure shows up later as an AttributeError
            # in forward(), far from its actual cause.
            raise ValueError(
                "Up sample block not implemented: could not build layers with "
                "in_channels={!r}, out_channels={!r}, index={!r}".format(
                    self.in_channels, self.out_channels, self.index
                )
            ) from e

    def up_sample_block(self):
        """Build the conv -> pixel shuffle (-> PReLU) stack as ``nn.Sequential``."""
        layers = OrderedDict()

        layers["conv{}".format(self.index)] = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding
        )
        layers["pixel_shuffle{}".format(self.index)] = nn.PixelShuffle(
            upscale_factor=self.factor)

        if self.is_first_block:
            layers["PReLU"] = nn.PReLU()

        return nn.Sequential(layers)

    def forward(self, x):
        """Run ``x`` through the block.

        Args:
            x: Input tensor; its channel dimension must match ``in_channels``.

        Raises:
            ValueError: If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError(
                "Up sample block not implemented: forward() needs an input tensor"
            )
        return self.model(x)

In [46]:
class OutputBlock(nn.Module):
    """Final stage: one 9x9 convolution followed by ``tanh``.

    The convolution uses stride 1 and padding 4, which for a 9x9 kernel
    leaves height and width unchanged. The ``tanh`` bounds every output
    value to the range [-1, 1].

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the convolution.

    Raises:
        ValueError: If the layers cannot be built from the given channel
            counts, e.g. when either one is left at its ``None`` default.
    """

    def __init__(self, in_channels = None, out_channels = None):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 9
        self.stride = 1
        self.padding = 4

        try:
            self.model = self.output_block()
        except Exception as e:
            # Raise instead of printing: a swallowed error leaves the module
            # without ``model`` and postpones the failure to an unrelated
            # AttributeError inside forward().
            raise ValueError(
                "Output block not implemented: could not build layers with "
                "in_channels={!r}, out_channels={!r}".format(
                    self.in_channels, self.out_channels
                )
            ) from e

    def output_block(self):
        """Build the ``conv`` -> ``tanh`` stack as a named ``nn.Sequential``."""
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = self.kernel_size,
            stride=self.stride,
            padding = self.padding
        )
        layers["tanh"] = nn.Tanh()

        return nn.Sequential(layers)

    def forward(self, x):
        """Run ``x`` through the convolution and ``tanh``.

        Args:
            x: Input tensor; its channel dimension must match ``in_channels``.

        Raises:
            ValueError: If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError(
                "Output block not implemented: forward() needs an input tensor"
            )
        return self.model(x)

In [None]:
class Generator(nn.Module):
    """Placeholder generator module.

    No sub-modules are registered yet, and ``forward`` ignores its argument
    and returns ``None``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Accept ``x`` and return ``None``; the forward pass is not written yet."""
        return None