In [1]:
import torch
import helper
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [2]:
# !pip install helper

# Implementing th neural network of the paper pieces by pieces

## Convolution block

In [5]:
class ConvolutionBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, maxpool_kernel):
        super().__init__()
        
        ### Convolution layer
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxpool_kernel = maxpool_kernel
        
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=(3,2), padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(self.maxpool_kernel)
#             nn.Conv2d(in_channels=self.out_channels, 
#                       out_channels=self.out_channels, 
#                       kernel_size=3, stride=2, padding=1)
        )
        
    def forward(self, x):
        return self.block(x)

In [6]:
ConvolutionBlock(8, 32, (3, 2))

ConvolutionBlock(
  (block): Sequential(
    (0): Conv2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(3, 2), stride=(3, 2), padding=0, dilation=1, ceil_mode=False)
  )
)

In [32]:
class UpConvolutionBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels, maxpool_kernel):
        super().__init__()
        
        ### Convolution layer
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.maxpool_kernel = maxpool_kernel
        
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.in_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=(3,2), padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.MaxPool2d(self.maxpool_kernel),
            nn.Upsample()
        )
        
    def forward(self, x):
        return self.block(x)

In [33]:
UpConvolutionBlock(8, 32, (2, 2))

UpConvolutionBlock(
  (block): Sequential(
    (0): ConvTranspose2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (4): Upsample(size=None, mode=nearest)
  )
)

## Encoder part

In [17]:
class Encoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        self.cbe1 = ConvolutionBlock(8, 32, (5, 2))
        self.cbe2 = ConvolutionBlock(32, 128, (4, 2))
        self.cbe3 = ConvolutionBlock(128, 256, (2, 2))
        
    def forward(self, x):
        
        x = self.cbe1(x)
        x = self.cbe2(x)
        x = self.cbe3(x)
        return x

In [23]:
Encoder()

Encoder(
  (cbe1): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(8, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(5, 2), stride=(5, 2), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbe2): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(32, 128, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(4, 2), stride=(4, 2), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbe3): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), pa

## Decoder part

In [26]:
# Add upsample to the decoder
class Decoder(nn.Module):
    
    def __init__(self, in_channels):
        
        super().__init__()
        self.in_channels = in_channels
        self.cbd1 = UpConvolutionBlock(self.in_channels, 256, (5, 4))
        self.cbd2 = UpConvolutionBlock(256, 128, (4,2))
        self.cbd3 = UpConvolutionBlock(128, 32, (2, 2))
        
    def forward(self, x):
        
        x = self.cbd1(x)
        x = self.cbd2(x)
        x = self.cbd3(x)
        return x

In [27]:
Decoder()

Decoder(
  (cbd1): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(5, 4), stride=(5, 4), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbd2): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(4, 2), stride=(4, 2), padding=0, dilation=1, ceil_mode=False)
    )
  )
  (cbd3): ConvolutionBlock(
    (block): Sequential(
      (0): Conv2d(128, 32, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2),