In [1]:
import torch
import helper
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [2]:
# !pip install helper

# Implementing th neural network of the paper pieces by pieces

## Convolution block

In [3]:
class ConvolutionBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, maxpool_kernel):
        super().__init__()
        
        ### Convolution layer
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxpool_kernel = maxpool_kernel
        
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=(3,2), padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(self.maxpool_kernel)
#             nn.Conv2d(in_channels=self.out_channels, 
#                       out_channels=self.out_channels, 
#                       kernel_size=3, stride=2, padding=1)
        )
        
    def forward(self, x):
        return self.block(x)

In [4]:
ConvolutionBlock(8, 32, (3, 2))

ConvolutionBlock(
  (block): Sequential(
    (0): Conv2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
  )
)

In [5]:
class UpConvolutionBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, maxpool_kernel):
        super().__init__()
        
        ### Convolution layer
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxpool_kernel = maxpool_kernel
        
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.in_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=(3,2), padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.Upsample()
        )
        
    def forward(self, x):
        return self.block(x)

In [6]:
UpConvolutionBlock(8, 32, (2, 2))

UpConvolutionBlock(
  (block): Sequential(
    (0): ConvTranspose2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Upsample(size=None, mode=nearest)
  )
)

## Encoder part

In [7]:
class Encoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        self.cbe1 = ConvolutionBlock(8, 32, (5, 2))
        self.cbe2 = ConvolutionBlock(32, 128, (4, 2))
        self.cbe3 = ConvolutionBlock(128, 256, (2, 2))
        
    def forward(self, x):
        
        x = self.cbe1(x)
        x = self.cbe2(x)
        x = self.cbe3(x)
        return x

In [8]:
Encoder()

Encoder(
  (cbe1): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(5, 2), stride=(5, 2), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbe2): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(32, 128, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(4, 2), stride=(4, 2), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbe3): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), pa

## Decoder part

In [16]:
# Add upsample to the decoder, not correct yet
class Decoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        #self.in_channels = in_channels
        self.cbd1 = UpConvolutionBlock(256, 256, (5, 4))
        self.cbd2 = UpConvolutionBlock(256, 128, (4,2))
        self.cbd3 = UpConvolutionBlock(128, 32, (2, 2))
        
    def forward(self, x):
        
        x = self.cbd1(x)
        x = self.cbd2(x)
        x = self.cbd3(x)
        return x

In [17]:
Decoder()

Decoder(
  (cbd1): UpConvolutionBlock(
    (block): Sequential(
      (0): ConvTranspose2d(256, 256, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Upsample(size=None, mode=nearest)
    )
  )
  (cbd2): UpConvolutionBlock(
    (block): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Upsample(size=None, mode=nearest)
    )
  )
  (cbd3): UpConvolutionBlock(
    (block): Sequential(
      (0): ConvTranspose2d(128, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Upsample(size=None, mode=nearest)
    )
  )
)

# Residual block

source: https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        
        self.conv3 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

In [None]:
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x