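## Overall structure
Notation: Conv(f, n, c) is a convolution with kernel size f, n output channels and c input channels; Deconv(f, n, c) is the corresponding transposed convolution.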
Conv(5,d,1)
ReLU
Conv(1,s,d)
ReLU
m × Conv(3,s,s)
ReLU
Conv(1,d,s)
ReLU
Deconv(9,1,d)

Loss function: MSE
Learning rate: 0.00001
Hyperparameters: d=32, s=5, m=1

In [1]:
import numpy as np 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
from PIL import Image
from torchvision import transforms
import random
from tqdm.notebook import tqdm
from skimage.transform import resize

In [2]:
## Define the network blocks
class FeatureExtraction(nn.Module):
    """First FSRCNN stage: Conv(5, d, color_channels) -> ReLU.

    Extracts ``d`` feature maps from the low-resolution input image while
    keeping its spatial size (H x W) unchanged.

    Args:
        d: Number of output feature maps.
        color_channels: Number of channels in the input image.
    """

    def __init__(self, d, color_channels):
        super().__init__()
        kernel_size = 5
        # "Same" padding (kernel_size // 2) preserves H x W. Every stage before
        # the deconvolution must keep the spatial size, otherwise the final
        # output is not exactly upscale_factor x the input and cannot be
        # compared against the high-resolution target.
        self.conv1 = nn.Conv2d(
            in_channels=color_channels,
            out_channels=d,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an input with ``color_channels`` channels to ``d`` channels.

        The spatial size is unchanged.
        """
        x = self.conv1(x)
        x = self.relu(x)
        return x
    
#Shrinking block 
class ShrinkingBlock(nn.Module):
    """Shrinking stage: Conv(1, s, d) -> ReLU.

    Reduces the ``d`` feature maps produced by feature extraction to ``s``
    channels with a 1x1 convolution. The spatial size (H x W) is unchanged.

    Args:
        s: Number of output channels (the shrunk width).
        d: Number of input channels (the feature-extraction width).
    """

    def __init__(self, s, d):
        super().__init__()
        # The block consumes d channels and emits s channels. A 1x1 kernel
        # needs no padding to preserve H x W.
        self.conv1 = nn.Conv2d(
            in_channels=d,
            out_channels=s,
            kernel_size=1,
            padding=0,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an input with ``d`` channels to ``s`` channels, same spatial size."""
        x = self.conv1(x)
        x = self.relu(x)
        return x

#Mapping block 
class MappingBlock(nn.Module):
    """One non-linear mapping layer: Conv(3, s, s) -> ReLU.

    Operates in the shrunk ``s``-channel feature space. Both the channel count
    and the spatial size (H x W) are preserved, so blocks can be stacked.

    Args:
        s: Number of input and output channels.
    """

    def __init__(self, s):
        super().__init__()
        kernel_size = 3
        # padding = kernel_size // 2 preserves H x W. Without it, each stacked
        # mapping layer would enlarge the feature maps and the final upscaled
        # output would no longer match the high-resolution target size.
        self.conv1 = nn.Conv2d(
            kernel_size=kernel_size,
            in_channels=s,
            out_channels=s,
            padding=kernel_size // 2,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply one 3x3 conv + ReLU; output has the same shape as ``x``."""
        x = self.conv1(x)
        x = self.relu(x)
        return x

#Expanding block
class ExpandingBlock(nn.Module):
    """Expanding stage: Conv(1, d, s) -> ReLU.

    Restores the ``s`` mapped channels back to ``d`` channels with a 1x1
    convolution before the deconvolution. The spatial size (H x W) is
    unchanged.

    Args:
        d: Number of output channels.
        s: Number of input channels.
    """

    def __init__(self, d, s):
        super().__init__()
        # A 1x1 kernel needs no padding to preserve H x W.
        self.conv1 = nn.Conv2d(
            kernel_size=1,
            in_channels=s,
            out_channels=d,
            padding=0,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map an input with ``s`` channels to ``d`` channels, same spatial size."""
        x = self.conv1(x)
        x = self.relu(x)
        return x

#Deconvolution block
class DeconvolutionBlock(nn.Module):
    """Final FSRCNN stage: Deconv(9, color_channels, d), no activation.

    A transposed convolution with stride ``upscale_factor`` turns the ``d``
    expanded feature maps into a ``color_channels`` image. It uses kernel 9,
    padding 4 and output_padding ``upscale_factor - 1``, so an H x W input
    produces an (H * upscale_factor) x (W * upscale_factor) output.

    Args:
        color_channels: Channels of the reconstructed image.
        upscale_factor: Integer upscaling factor, used as the stride.
        d: Number of input feature maps.
    """

    def __init__(self, color_channels, upscale_factor, d):
        super().__init__()
        kernel = 9
        # Positional order: in_channels, out_channels, kernel_size.
        # output_padding = stride - 1 makes the output exactly stride x input.
        self.conv1 = nn.ConvTranspose2d(
            d,
            color_channels,
            kernel,
            stride=upscale_factor,
            padding=kernel // 2,
            output_padding=upscale_factor - 1,
        )

    def forward(self, x):
        """Upsample ``x`` by ``upscale_factor`` with the transposed convolution."""
        return self.conv1(x)

In [3]:
## Define the network
class FSRCNN(nn.Module):
    """Fast Super-Resolution CNN built from the five FSRCNN stage blocks.

    Pipeline: FeatureExtraction -> ShrinkingBlock -> m x MappingBlock ->
    ExpandingBlock -> DeconvolutionBlock.

    Args:
        input_size: Accepted for interface compatibility; not used when
            building the network.
        d: Width of the feature-extraction and expanding stages.
        s: Shrunk width used by the mapping layers.
        m: Number of stacked mapping layers (0 gives an identity mapping).
        color_channels: Channels of the input and output images.
        upscale_factor: Upscaling factor passed to the deconvolution stage.
    """

    def __init__(self, input_size, d, s, m, color_channels, upscale_factor):
        super().__init__()
        # Submodules are registered in pipeline order. The attribute names
        # become state_dict keys, so they must stay stable.
        self.feature_extraction = FeatureExtraction(d, color_channels)
        self.shrinking = ShrinkingBlock(s, d)
        self.mapping = nn.Sequential(*(MappingBlock(s) for _ in range(m)))
        self.expanding = ExpandingBlock(d, s)
        self.deconvolution = DeconvolutionBlock(color_channels, upscale_factor, d)

    def forward(self, x):
        """Pass ``x`` through every stage in order and return the result."""
        stages = (
            self.feature_extraction,
            self.shrinking,
            self.mapping,
            self.expanding,
            self.deconvolution,
        )
        for stage in stages:
            x = stage(x)
        return x