## Demonstration of Spatial Transformer Network

In this notebook, I am going to showcase the popular **Spatial Transformer Network** using PyTorch.

A **Spatial Transformer Network** is an add-on module for an ordinary neural network architecture. It learns to spatially transform its input (for example by rotating, scaling, or shifting it) so that the rest of the network sees the data in a form that is easier to classify.

In this notebook, we are going to use it on the EMNIST dataset (with the letters split) and place it right after the input images.

For a better intuition about Spatial Transformers you can have a look at the paper: https://arxiv.org/pdf/1506.02025.pdf

In [None]:
## Importing the necessary packages ##

import torch
import torch.nn as nn
from tqdm import tqdm
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset , DataLoader
from torchvision.transforms import transforms
from torchvision.datasets import EMNIST
from torchvision.utils import make_grid


import numpy as np
import matplotlib.pyplot as plt

As always, the first thing to do is to set up the dataset.

Since we are going to use the EMNIST dataset, which is available through torchvision, our task is very easy: we just use the <code>EMNIST</code> class from the <code>torchvision.datasets</code> package.

But since we want the network to see letters at varying rotations and scales, we define some transforms to pass as an argument to the EMNIST class: a random affine transform (rotation of up to 30 degrees in either direction, scaling between 0.5x and 1.5x) followed by a conversion to a tensor.

In [None]:
## Defining transformations ##

aug = transforms.Compose([
    transforms.RandomAffine(degrees = 30,
                            scale = (0.5 , 1.5)
                           ),
    transforms.ToTensor()
])

## Downloading the train dataset to disk and loading it ##

emnist_letter_train_dataset = EMNIST(root = 'train_data' ,
                               split = 'letters' ,
                               train = True , 
                               download = True , 
                               transform = aug)


With the data loaded, let's check our data.

In [None]:
## Checking the length of the dataset ##

print('Training dataset length :' , len(emnist_letter_train_dataset))

Let's check the structure of a datapoint. I assume it is a tuple, where the first part is the image and the second part is the label.

In [None]:
## Getting a random integer ##

rand_idx = int(np.random.randint(low = 0 , high = len(emnist_letter_train_dataset)))

## Getting a datapoint from the training set ##

train_datapoint = emnist_letter_train_dataset[rand_idx]

print('Length of the train datapoint is :' , len(train_datapoint))
print('Datatype of the train datapoint is :' , type(train_datapoint))

As guessed, it is a tuple of size 2.

Now let's see what's inside the tuple. 

In [None]:
## Checking the inside of the datapoint ##

## Extracting the image data ##

img = train_datapoint[0]

print('The shape of the image is :' , img.shape)

## Extracting the label ##

label = train_datapoint[1]

print('The label is :' , label)

Since this is similar to the very popular MNIST dataset, the images are grayscale with dimensions 28 × 28.

Let's visualize the image.

In [None]:
## Visualizing the image ##

plt.title(label)

plt.imshow(img.squeeze(0) , cmap = 'gray')

plt.show()

Now, to feed the data into a network, we need to create a dataloader that serves it in batches.

In [None]:
## Creating the train dataloader ##

train_dataloader = DataLoader(dataset = emnist_letter_train_dataset,
                              batch_size = 16 , 
                              shuffle = True)

Let's check the length of the dataloader!

In [None]:
## Checking the length of the dataloader ##

print('Length of the train dataloader :' , len(train_dataloader))


Okay cool.
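
Note that the length of a dataloader is the number of batches, not the number of images. Here is a minimal sketch on a made-up dataset (`toy_dataset`, with random tensors standing in for EMNIST) that shows the relationship and the shape of one batch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Made-up stand-in dataset: 100 random "images" with integer labels.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(low=1, high=27, size=(100,))
toy_dataset = TensorDataset(images, labels)

toy_loader = DataLoader(toy_dataset, batch_size=16, shuffle=True)

# len() of a DataLoader counts batches: ceil(100 / 16) = 7.
print(len(toy_loader), math.ceil(len(toy_dataset) / 16))

# One batch is a pair (images, labels) with the batch dimension in front.
img_batch, label_batch = next(iter(toy_loader))
print(img_batch.shape, label_batch.shape)
```

The last batch holds only the leftover 4 samples, since `drop_last` defaults to `False`.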

But it would be really nice if we could visualize a batch of data from the dataloader. So, let's do that.

In [None]:
## Creating a utility function to visualize a set of data ##

def visualize(img_batch):
    '''
    Function to visualize a batch (taken as 16) of image data.
    '''
    
    fig , ax = plt.subplots(figsize = (4 , 4))
    plt.imshow(make_grid(img_batch.detach().to('cpu') , 4).permute(1 , 2 , 0))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    plt.show()
    
    
## Visualizing a batch of images ##

for img , _ in train_dataloader:
    
    visualize(img)
    
    break

With that out of the way, let's head down to the main part of the project-- the creation of the network.

Our main network is going to be very simple, but the highlight of the model is the Spatial Transformer module, which consists of three parts, namely:

- Localization Network
- Grid Generator
- Sampler

![](stn.png)
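
For reference, the paper linked above writes the affine case as follows. The notation is the paper's: $(x_i^t , y_i^t)$ are the coordinates of a point on the regular grid of the output, and $(x_i^s , y_i^s)$ are the coordinates in the input from which that point is sampled.

$$
\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix}
= A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
= \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix}
\begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}
$$

The six entries of $A_\theta$ are what the localization network predicts, which is why its last layer has 6 outputs that are reshaped to 2 rows and 3 columns. Setting $A_\theta$ to $\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}$ gives the identity transformation.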

The 2nd and 3rd parts, namely the Grid Generator and the Sampler, are easily taken care of by the ```affine_grid``` and ```grid_sample``` functions of the ```torch.nn.functional``` package, respectively.
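
Here is a minimal sketch of how these two functions work together, using a made-up random batch `inp` and a hand-written identity matrix `theta` instead of the output of a localization network. With the identity transform, the sampler should give back (almost exactly) the input.

```python
import torch
import torch.nn.functional as F

# Made-up batch of two grayscale 28 x 28 images.
inp = torch.rand(2, 1, 28, 28)

# One 2 x 3 affine matrix per image; this one is the identity transform.
theta = torch.tensor([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]]).repeat(2, 1, 1)

# Grid generator: an (x, y) sampling location for every output pixel.
grid = F.affine_grid(theta, inp.size(), align_corners=False)

# Sampler: bilinear interpolation of the input at those locations.
out = F.grid_sample(inp, grid, align_corners=False)

print(grid.shape)                           # one (x, y) pair per output pixel
print(torch.allclose(out, inp, atol=1e-4))  # identity transform reproduces the input
```

Changing the entries of `theta` (for example the last column, which shifts the sampling locations) is what lets the network move, rotate, or rescale the image.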

In [None]:
## Creating the network module ##

class network(nn.Module):
    '''
    The network incorporating the Spatial Transformer Network module along with the primary backbone.
    '''
    
    def __init__(self):
        '''
        The constructor method. In this the general backbone of the network is created with the class variable net,
        and the localization network of spatial transformer module is created with the class variable localization_network.
        '''
        super().__init__()
        
        ## Creating the spine network ##
        
        self.net = nn.Sequential(nn.Conv2d(in_channels = 1,
                                           out_channels = 8,
                                           kernel_size = 3,
                                           stride = 1,
                                           padding = 1),
                                 nn.BatchNorm2d(num_features = 8),
                                 nn.ReLU(),
                                 nn.MaxPool2d(kernel_size = 2,
                                              stride = 2),   ## (14 , 14 , 8)
                                 
                                 ############################################
                                 
                                 nn.Conv2d(in_channels = 8,
                                           out_channels = 16,
                                           kernel_size = 3,
                                           stride = 1,
                                           padding = 1),
                                 nn.BatchNorm2d(num_features = 16),
                                 nn.ReLU(),
                                 nn.MaxPool2d(kernel_size = 2,
                                              stride = 2), ## (7 , 7 , 16)
                                 
                                 ############################################
                                 
                                 nn.Conv2d(in_channels = 16,
                                           out_channels = 32,
                                           kernel_size = 3,
                                           stride = 1,
                                           padding = 1),
                                 nn.BatchNorm2d(num_features = 32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(kernel_size = 2,
                                              stride = 2), ## (3 , 3 , 32)
                                 
                                 
                                 ############################################
                                 
                                 nn.Conv2d(in_channels = 32,
                                           out_channels = 64,
                                           kernel_size = 3,
                                           stride = 1,
                                           padding = 1),
                                 nn.BatchNorm2d(num_features = 64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(kernel_size = 2,
                                              stride = 2), ## (1 , 1 , 64)
                                 
                                 ############################################
                                 
                                 nn.Flatten(),
                                 nn.Linear(64 , 27)
                                )
        
        ## Creating the localization network of spatial transformer module ##
        
        self.localization_network = nn.Sequential(nn.Conv2d(in_channels = 1,
                                                            out_channels = 8,
                                                            kernel_size = 9,
                                                            stride = 1),
                                    nn.BatchNorm2d(num_features = 8),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size = 2,
                                                 stride = 2),   ## (10 , 10 , 8)
                                                
                                    ############################################
                                                  
                                    nn.Conv2d(in_channels = 8,
                                                            out_channels = 16,
                                                            kernel_size = 7,
                                                            stride = 1),
                                    nn.BatchNorm2d(num_features = 16),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size = 2,
                                                 stride = 2),   ## (2 , 2 , 16)
                                                
                                    ############################################
                                    
                                    nn.Flatten(),
                                    nn.Linear(2 * 2 * 16 , 16),
                                    nn.ReLU(),
                                    nn.Linear(16 , 6)              
                                    )
        
        ## Initializing the weights and bias of the output of the localization network with identity transformation ##
        
        self.localization_network[-1].weight.data.zero_()
        self.localization_network[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        
    
    def forward_stn(self , inp):
        '''
        Defines the forward pass of the Spatial Transformer Network.
        This is necessary to finally visualize the output of the STN.
        '''
        
        out = self.localization_network(inp)
        
        ## Reshape the output to have the shape (batch_num , 2 rows , 3 columns) ##

        out = out.reshape(-1 , 2 , 3)
        
        ## Grid Generator ##
        
        generated_grid = F.affine_grid(out , inp.size() , align_corners = False)
        
        ## Sampler ##
        
        out = F.grid_sample(inp , generated_grid , align_corners = False)
        
        return out
    
    def forward(self , inp):
        '''
        Defines one forward pass through the network.
        '''
        
        ## We decided to put the STN after the input image ##
        
        inp = self.forward_stn(inp)
        
        out = self.net(inp)
        
        return out

Done. Our model is created.
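
The shape comments next to the pooling layers can be double-checked with a bit of arithmetic. The sketch below uses the usual output-size formula for convolution and pooling layers (dilation 1, rounding down); the helper names `conv_out` and `trace` are made up for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(size, layers):
    """Apply (kernel, stride, padding) layers in order and record each output size."""
    sizes = []
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


# Backbone: four blocks of 3x3 conv (padding 1) followed by 2x2 max pooling.
backbone = [(3, 1, 1), (2, 2, 0)] * 4

# Localization network: 9x9 conv, pool, 7x7 conv, pool (no padding).
localization = [(9, 1, 0), (2, 2, 0), (7, 1, 0), (2, 2, 0)]

backbone_sizes = trace(28, backbone)
localization_sizes = trace(28, localization)

print(backbone_sizes)      # [28, 14, 14, 7, 7, 3, 3, 1]
print(localization_sizes)  # [20, 10, 4, 2]
```

So the backbone ends with a 1 × 1 × 64 feature map (hence `nn.Linear(64 , 27)`), and the localization network ends with a 2 × 2 × 16 feature map (hence `nn.Linear(2 * 2 * 16 , 16)`).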

Now let's instantiate our model and move it to the GPU.

In [None]:
## Get the device ##

def get_device():
    '''
    Returns the torch.device to use: cuda if available, otherwise cpu.
    '''
    if torch.cuda.is_available():
        return torch.device('cuda')
    
    return torch.device('cpu')

## Setting the device ##

device = get_device()

## Putting the model to the device ##

model = network().to(device)

Done. Our model object is created and moved to the GPU (I am certain it's a GPU, because I have one. :P)

Now let's set our loss function and our optimizer.

In [None]:
## Setting the loss function ##

loss_func = torch.nn.CrossEntropyLoss()

## Setting the optimizer ##

optim = torch.optim.Adam(model.parameters() , lr = 3e-4)
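
As a quick illustration of what `CrossEntropyLoss` expects, here is a sketch with made-up values: the predictions are raw scores (logits) of shape (batch , classes) with no softmax applied, and the targets are integer class indices. The label values below are an assumption: they are taken to lie between 1 and 26, which would fit the 27 outputs of the last linear layer.

```python
import math
import torch
import torch.nn as nn

toy_loss_func = nn.CrossEntropyLoss()

# Made-up batch: 4 samples with 27 raw scores each (all equal here).
logits = torch.zeros(4, 27)

# Integer class indices, assumed here to lie in 1..26.
labels = torch.tensor([1, 5, 13, 26])

loss = toy_loss_func(logits, labels)

# Equal scores mean a uniform prediction, so the loss is ln(27), about 3.2958.
print(loss.item(), math.log(27))
```

A loss close to ln(27) at the start of training therefore just means the network is still guessing uniformly.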

And we are all set to train our model.

In [None]:
## Setting the training phase ##

epochs = 50

for epoch in range(epochs):
    
    loop = tqdm(train_dataloader)
    
    for img , label in loop:
        
        img = img.to(device)
        
        label = label.to(device)
        
        pred = model(img)
        
        #print('Maximum label value :' , torch.max(label))
        
        loss = loss_func(pred , label)
        
        ## Clearing the old gradients before computing the new ones ##
        
        optim.zero_grad()
        
        loss.backward()
        
        optim.step()
        
        loop.set_description('Epoch : {} / {}'.format(epoch + 1 , epochs))
        
        loop.set_postfix(loss = loss.item())
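
The order of the three optimizer-related calls matters: the gradients have to be cleared before `backward()`, not between `backward()` and `step()`, otherwise the optimizer has nothing to apply. Here is a minimal sketch of one training step on a made-up toy model (`toy_model` is a hypothetical stand-in, not the network above) that checks the weights really change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network, plus made-up data.
toy_model = nn.Linear(4, 3)
toy_optim = torch.optim.Adam(toy_model.parameters(), lr=3e-4)
toy_loss_func = nn.CrossEntropyLoss()

x = torch.rand(8, 4)
y = torch.randint(low=0, high=3, size=(8,))

before = toy_model.weight.detach().clone()

toy_optim.zero_grad()                  # 1. clear gradients from the previous step
loss = toy_loss_func(toy_model(x), y)
loss.backward()                        # 2. compute fresh gradients
toy_optim.step()                       # 3. update the weights with those gradients

weights_changed = not torch.equal(before, toy_model.weight.detach())
print(weights_changed)  # True
```

A check like this on a single batch is a cheap way to confirm that a training loop is actually updating the model before committing to 50 epochs.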