In [1]:

import matplotlib.pyplot as plt
import numpy as np

import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from torchvision import datasets, transforms
from torch.autograd import Variable

plt.ion()


In [2]:
# MNIST preprocessing shared by the train and test loaders: convert each image
# to a float tensor, then standardise it with a fixed mean (0.1307) and std (0.3081).
# Normalize expects one mean/std entry per channel as a sequence; note that
# `(0.3081)` is just a parenthesised float, so both values are written as 1-tuples.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# training data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=mnist_transform),
    batch_size=64, shuffle=True, num_workers=4)

# test data (download=True so this loader does not rely on the train dataset
# having fetched the files first)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, download=True,
                   transform=mnist_transform),
    batch_size=64, shuffle=True, num_workers=4)


In [3]:

class STNet(nn.Module):
    """Image classifier preceded by a learned spatial transformer.

    ``forward`` first resamples the input with a predicted affine warp
    (``warp``) and then classifies the warped image (``classify``).
    The hard-coded feature sizes (10*3*3 after the localisation net, 320 after
    the classifier convs) assume 1-channel 28x28 inputs — TODO confirm if the
    network is reused on other image sizes.
    """

    def __init__(self):
        super(STNet, self).__init__()

        # dropout probability shared by the conv Dropout2d and the fc dropout
        self.drop_prob = 0.6

        # classifier layers (used by `classify`)
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv_drop = nn.Dropout2d(p=self.drop_prob)

        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # spatial transformer localisation net
        self.localisation = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(inplace=True)
        )

        # regressor for the 2x3 affine transformation parameters
        self.trafo_est = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 2 * 3)
        )

        # Initialise the transform to the identity: zero weights make the last
        # layer ignore its input, so its output is just the bias, which holds the
        # flattened identity affine matrix [[1, 0, 0], [0, 1, 0]].
        # copy_ writes in place, keeping the parameter's existing device/dtype.
        self.trafo_est[2].weight.data.zero_()
        self.trafo_est[2].bias.data.copy_(torch.FloatTensor([1, 0, 0,
                                                             0, 1, 0]))

    def warp(self, x):
        """Resample ``x`` with an affine transform predicted from ``x`` itself.

        Args:
            x: input batch, assumed (N, 1, 28, 28) so the localisation net
               yields 10x3x3 features per sample — TODO confirm.

        Returns:
            The warped batch; ``affine_grid`` is given ``x.size()``, so the
            output has the same size as ``x``.
        """
        # localise, then flatten per sample (keeping the batch dimension
        # explicit so a wrong input size errors instead of mixing samples)
        xs = self.localisation(x)
        xs = xs.view(x.size(0), -1)

        # estimate one 2x3 affine matrix per sample
        theta = self.trafo_est(xs).view(-1, 2, 3)

        # build the sampling grid and resample the image
        grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, grid)

    def classify(self, x):
        """CNN classifier applied to the (already warped) images.

        Args:
            x: image batch, assumed (N, 1, 28, 28) so the conv stack yields
               20x4x4 = 320 features per sample — TODO confirm.

        Returns:
            (N, 10) log-probabilities (log_softmax over dim 1).
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # the dropout layer is registered as `conv_drop` in __init__
        x = F.relu(F.max_pool2d(self.conv_drop(self.conv2(x)), 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        # functional dropout follows the module's train/eval mode
        x = F.dropout(x, p=self.drop_prob, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

    def forward(self, x):
        """Warp the input with the spatial transformer, then classify it."""
        return self.classify(self.warp(x))

# Module.cuda() moves the parameters to the GPU in place and returns the module
# itself, so construction and device placement can be chained.
model = STNet().cuda()
model


STNet (
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv_drop): Dropout2d (p=0.6)
  (fc1): Linear (320 -> 50)
  (fc2): Linear (50 -> 10)
  (localisation): Sequential (
    (0): Conv2d(1, 8, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (2): ReLU (inplace)
    (3): Conv2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): ReLU (inplace)
  )
  (trafo_est): Sequential (
    (0): Linear (90 -> 32)
    (1): ReLU (inplace)
    (2): Linear (32 -> 6)
  )
)

In [None]:
optimiser = optim.SGD(model.parameters(), lr=0.01)

def train(epoch, log_interval=500):
    """Run one training epoch of the global ``model`` over ``train_loader``.

    Uses the module-level ``optimiser`` and NLL loss on the model's
    log-softmax output, printing progress every ``log_interval`` batches.

    Args:
        epoch: epoch number, used only in the progress message.
        log_interval: print progress every this many batches (default 500).
    """
    model.train()
    # Send batches to wherever the model's parameters live (the GPU after
    # `model.cuda()`) instead of hard-coding `.cuda()`. Tensors no longer need
    # wrapping in `Variable` (PyTorch >= 0.4).
    device = next(model.parameters()).device

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimiser.zero_grad()
        output = model(data)

        loss = F.nll_loss(output, target)
        loss.backward()
        optimiser.step()

        if batch_idx % log_interval == 0:
            # loss is a 0-dim tensor: `loss.item()` returns the Python float,
            # whereas `loss.data[0]` raises IndexError in PyTorch >= 0.5
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        
        