In this notebook, we explore the "Rotated MNIST + background images" dataset from https://sites.google.com/a/lisa.iro.umontreal.ca/public_static_twiki/variations-on-the-mnist-digits and build a data module that loads it into Dataset objects. In an accompanying notebook, we use a similar neural architecture on the classical MNIST dataset and find very good performance: training reaches $>95\%$ accuracy on a validation batch within the first few hundred iterations (roughly one epoch). Achieving similar results on the variant of MNIST used here appears to be more challenging.

A wonderful reference for convolutional arithmetic is __Dumoulin, Vincent, and Francesco Visin. "A guide to convolution arithmetic for deep learning." arXiv preprint arXiv:1603.07285 (2016)__. One of the basic takeaways is that most of the dimensioning for multivariate tensors can be analyzed by considering single dimensions. This makes accounting for kernel sizes, strides, padding, pooling, etc. more intuitive.
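
To make the one-dimension-at-a-time accounting concrete, here is a minimal sketch of the output-size relation $o = \lfloor (i + 2p - k)/s \rfloor + 1$ applied to a single axis. The helper name `conv_output_size` is hypothetical, and the trace assumes 3x3 convolutions with padding 1 and 2x2 max pooling with stride 2 on a 28-pixel axis.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Trace one spatial axis of a 28x28 image (assumed layer settings, see above).
h0 = 28
h1 = conv_output_size(h0, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1
h2 = conv_output_size(h1, kernel=2, stride=2)             # 2x2 pooling, stride 2
h3 = conv_output_size(h2, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1
h4 = conv_output_size(h3, kernel=2, stride=2)             # 2x2 pooling, stride 2

print(h0, h1, h2, h3, h4)
```

Pooling follows the same relation with zero padding, so one function covers both cases. Because height and width are handled independently, the same trace applies to each axis.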


In [None]:
import requests
import zipfile
import matplotlib.pyplot as plt
import os

import torch
from torch import nn
import torch.nn.functional as F
import argparse
from typing import List, Dict, Any

In [None]:
class BaseDataset(torch.utils.data.Dataset):

  def __init__(self, x,y):
    super().__init__()
    self.x = x
    self.y = y

  def __len__(self):
    return self.x.shape[0]

  def __getitem__(self, index):
    'produces a single data point'
    xs = self.x[index]
    ys = self.y[index]

    return xs, ys

class MNISTDataModule:
  '''
  Helper module that downloads, loads, and partitions the MNIST dataset.
  '''
  url = "http://www.iro.umontreal.ca/~lisa/icml2007data/"
  filename = "mnist_rotation_back_image_new.zip"

  def __init__(self, dir, batch_size=32):
      self.dir = dir 
      self.batch_size = batch_size
      self.path = self.dir + '/' +self.filename

  def download_data(self):
    # create directories and download dataset
    os.makedirs(self.dir, exist_ok=True)
    if not os.path.exists(self.path):
      response = requests.get(self.url + self.filename)
      response.raise_for_status() # fail early rather than writing an error page to disk
      with open(self.path, "wb") as f:
        f.write(response.content)
    with zipfile.ZipFile(self.path) as f:
      f.extractall(path=self.dir)
      
  def setup(self):
    # load data
    with open(self.dir+'/mnist_all_background_images_rotation_normalized_test.amat', 'r') as f1:
      ds_te = [[float(a) for a in line.split()] for line in f1]
    with open(self.dir+'/mnist_all_background_images_rotation_normalized_train_valid.amat', 'r') as f2:
      ds_tr_val = [[float(a) for a in line.split()]  for line in f2]
      
    ds_te, ds_tr_val = map(torch.tensor, (ds_te, ds_tr_val))

    # hardwired 80%-20% split into training and validation
    n1 = int(0.8*ds_tr_val.shape[0])

    Xtr, Ytr = ds_tr_val[:n1,:-1], ds_tr_val[:n1,-1]
    Xval, Yval = ds_tr_val[n1:,:-1], ds_tr_val[n1:,-1]
    Xte, Yte = ds_te[:,:-1], ds_te[:,-1]
    
    self.train_ds = BaseDataset(Xtr, Ytr)
    self.valid_ds = BaseDataset(Xval, Yval)
    self.test_ds = BaseDataset(Xte, Yte)

  def train_dataloader(self):
    return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

  def val_dataloader(self):
    return torch.utils.data.DataLoader(self.valid_ds, batch_size=3*self.batch_size, shuffle=False)

In [None]:
torch.manual_seed(42)

dir = "/content/data"
bs = 32
dm = MNISTDataModule(dir, batch_size = bs)
dm.download_data()

In [None]:
dm.setup()

In [None]:
train_dl = dm.train_dataloader()

In [None]:
#load a batch of data
dataiter = iter(train_dl)
X,y = next(dataiter) # DataLoader iterators no longer provide a .next() method

In [None]:
#verify that the data is normalized
print(X.min().item(), X.max().item())

In [None]:
## let's look at some data samples

fig, ax = plt.subplots(figsize=(15, 5), nrows=2, ncols=5)
i = 0
for row in ax:
    for col in row:
        
        col.imshow(X[i].reshape(28,28), cmap='Greys')
        col.set_title('label = '+str(y[i].item()))
        col.set_axis_off()
        i+=1

plt.show()

The digits appear rotated and flipped/mirrored (both horizontally and vertically). Moreover, the digits have been overlaid on various textured backgrounds. Naturally, these backgrounds are irrelevant to the label. The combination of these corruptions renders some of the digits beyond human recognition. But are they beyond machine recognition?

How should we think of this data? Let us suppose the data are random variables $X$ (image) and $Y$ (label) with unknown joint distribution $P_{XY}(x,y)$. We further assume that our training, validation, and test datasets are constructed from i.i.d. samples from $P_{XY}$. The marginals $P_{X}(x)$, $P_{Y}(y)$ and conditional distributions $P_{X|Y}(x|Y=y)$, $P_{Y|X}(y|X=x)$ describe important information about the dataset. For example, $P_{X|Y}(x|Y=y)$ describes the within-class variability in $X$. Presumably, this variability will be quite high given the differing backgrounds, rotations, and other transformations applied to each digit. In contrast, $P_{Y|X}(y|X=x)$ is essentially what we are trying to model when constructing a classifier.
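
One rough way to probe within-class variability is the average per-pixel variance inside each class. The sketch below is only illustrative: the helper name `within_class_variance` is hypothetical, and it runs on random stand-in data of shape `(N, 784)` rather than on the actual images.

```python
import torch


def within_class_variance(X: torch.Tensor, y: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Mean per-pixel variance of X within each class; NaN for classes with fewer than 2 samples."""
    out = torch.full((num_classes,), float("nan"))
    for c in range(num_classes):
        Xc = X[y == c]                      # all samples with label c
        if Xc.shape[0] > 1:
            out[c] = Xc.var(dim=0).mean()   # variance per pixel, averaged over pixels
    return out


# Stand-in data (assumption): uniform noise images with random labels.
torch.manual_seed(0)
X_demo = torch.rand(500, 784)
y_demo = torch.randint(0, 10, (500,))
wcv = within_class_variance(X_demo, y_demo)
print(wcv)
```

On real data, one would pass the flattened training images and their integer labels. A single scalar per class is a crude summary, but it gives a first indication of how spread out each class is.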

In [None]:
#how balanced is this dataset? Approximate P_Y(y).
Ytr = dm.train_ds.y
# count each label with bincount (labels are stored as floats, so cast to integers first)
label_freqs = torch.bincount(Ytr.long(), minlength=10).float() / Ytr.shape[0]
plt.bar(range(10), label_freqs)
plt.xticks(range(10))
plt.title('histogram of labels')
plt.show()

The implication of a balanced dataset is that a probability model in which $X$ and $Y$ are statistically independent and based purely on estimated marginals, i.e., $\hat{P}_{XY}(x,y) = \hat{P}_{X}(x) \hat{P}_{Y}(y)$, will perform poorly on the training data. This is because any classifier based on the conditional $\hat{P}_{Y|X}(y|X=x) = \hat{P}_Y(y)$ is no better than guessing a label uniformly at random. Since such a trivial model performs poorly on the training dataset, we can expect our fitting process to avoid this configuration and produce a model that actually learns some of the dependence structure in the training data. As an aside, if we learn too much from the training dataset, then we are likely to overfit. 

Conversely, if the dataset were severely unbalanced, then the above trivial model could perform well on the training dataset by just guessing the labels that appear more frequently in the training dataset. Such a model is very easy to learn, but will not generalize beyond the training regime.

Note that the model $\hat{P}_{Y|X}(y|X=x) = \hat{P}_Y(y)$ is perhaps the simplest model that one can build from the training data. In the context of MNIST, it contains only $10$ parameters -- the empirical probabilities of each of the $10$ class labels. 
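
The trivial model can be sketched in a few lines: estimate the class priors and always predict the most frequent class. The function name `prior_baseline_accuracy` and the synthetic label vectors below are assumptions for illustration, not part of the notebook's pipeline.

```python
import torch


def prior_baseline_accuracy(y_train: torch.Tensor, y_eval: torch.Tensor, num_classes: int = 10):
    """Estimate class priors from y_train, then score 'always predict the most frequent class' on y_eval."""
    priors = torch.bincount(y_train, minlength=num_classes).float() / y_train.numel()
    majority = int(torch.argmax(priors))
    acc = (y_eval == majority).float().mean().item()
    return priors, acc


# Synthetic labels (assumption): one balanced set, one dominated by class 0.
y_bal = torch.arange(10).repeat(100)                       # 100 samples per class
y_skew = torch.cat([torch.zeros(900, dtype=torch.long),    # 900 extra samples of class 0
                    torch.arange(10).repeat(10)])          # plus 10 samples per class

_, acc_bal = prior_baseline_accuracy(y_bal, y_bal)
_, acc_skew = prior_baseline_accuracy(y_skew, y_skew)
print(f"balanced: {acc_bal:.2f}, skewed: {acc_skew:.2f}")
```

With balanced labels the baseline sits at chance level, whereas on the skewed labels it scores highly without ever looking at an image.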

## Building the Neural Network

In [None]:
CONV_DIM = 64
FC_DIM = 256


class ResBlock(nn.Module):
    """
    Residual block with two 3x3 convolutional layers
    Padding size 1 to preserve input dimensionality
    ReLU nonlinearity
    """

    def __init__(self, indim: int, outdim: int, ksize=3, s=1) -> None:
        super().__init__()
        # indim / outdim = number of input / output channels (filters of size ksize x ksize)


        self.conv1 = nn.Conv2d(indim, outdim, kernel_size=ksize, stride=s, padding=1)
        self.bn = nn.BatchNorm2d(outdim)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=ksize, stride=s, padding=1)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> ReLU -> conv, then add the skip connection.

        Inputs
        ----------
        x
            (B, C_in, H_in, W_in) tensor

        Outputs
        -------
        torch.Tensor
            (B, C_out, H_out, W_out) tensor
        """
        z = self.conv1(x)  # (B, C_out, H_in, W_in)
        z = self.bn(z)     # (B, C_out, H_in, W_in)  
        z = self.relu(z)   # (B, C_out, H_in, W_in)
        z = self.conv2(z)  # (B, C_out , H_in, W_in) 
        # skip connection; requires C_in == C_out, or C_in == 1 so that x broadcasts across channels
        outs = x+z
        return outs



In [None]:
class ResNet(nn.Module):
    """
    Simple ResNet for the Rotated MNIST dataset
    Recall, the MNIST dataset takes as input (channels, height, width) = (1,28,28) images, i.e., 784 dimensional feature vectors
    and has 10 classes (for the digits 0,1,...,9)
    """

    def __init__(self, args: argparse.Namespace = None) -> None:
        super().__init__()
        self.args = vars(args) if args is not None else {}

        #hardwired for MNIST
        self.input_height, self.input_width = 28,28 
        num_classes = 10

        conv_dim = self.args.get("conv_dim", CONV_DIM) #number of filters / channels
        fc_dim = self.args.get("fc_dim", FC_DIM)


        ## input = (B,C=1,H,W), C=1 as MNIST has just a single channel
        self.res1 = ResBlock(1, conv_dim)   
        self.res2 = ResBlock(conv_dim, conv_dim)
        self.max_pool = nn.MaxPool2d(2)
        self.res3 = ResBlock(conv_dim, conv_dim)
        self.res4 = ResBlock(conv_dim, conv_dim)
        self.res5 = ResBlock(conv_dim, conv_dim)
        self.res6 = ResBlock(conv_dim, conv_dim)
        self.bn2d = nn.BatchNorm2d(conv_dim)
        self.drop = nn.Dropout(0.5)
        conv_output_height, conv_output_width = self.input_height // 4, self.input_width // 4
        self.flatten = nn.Flatten()
        fc_input_dim = int(conv_output_height * conv_output_width * conv_dim)
        self.fc1 = nn.Linear(fc_input_dim, fc_dim)
        self.bn1d = nn.BatchNorm1d(fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)
        self.relu = nn.ReLU()

        #adjust inits
        with torch.no_grad():
          self.fc2.weight *= 0.001
          self.fc2.bias *= 0.001
        

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x
            (B, Ch, H, W) tensor, where H and W must equal the hardwired input height and width (28, 28).

        Returns
        -------
        torch.Tensor
            (B, Classes) tensor
        """
        B, C, H, W = x.shape
        x = self.res1(x)      # (B, CONV_DIM, H, W)
        x = self.res2(x)      # (B, CONV_DIM, H, W)
        x = self.max_pool(x)  # (B, CONV_DIM, H // 2, W // 2)
        x = self.drop(x)
        x = self.res3(x)      # (B, CONV_DIM, H // 2, W // 2)
        x = self.res4(x)      # (B, CONV_DIM, H // 2, W // 2)
        x = self.max_pool(x)  # (B, CONV_DIM, H // 4, W // 4)
        x = self.drop(x)
        x = self.res5(x)      # (B, CONV_DIM, H // 4, W // 4)
        x = self.res6(x)      # (B, CONV_DIM, H // 4, W // 4)
        x = self.bn2d(x)
        x = self.relu(x)
        x = self.flatten(x)   # (B, CONV_DIM * (H // 4) * (W // 4))
        x = self.fc1(x)       # (B, FC_DIM)
        x = self.bn1d(x)      # (B, FC_DIM)
        x = self.relu(x)      # (B, FC_DIM)
        x = self.fc2(x)       # (B, Classes)
        return x

In [None]:
def accuracy(out: torch.Tensor, yb: torch.Tensor) -> torch.Tensor:
    preds = torch.argmax(out, dim=1) 
    return (preds == yb).float().mean() #converts from 0-1 bool to float and then averages = the fraction of correct classifications. 

In [None]:
def fit(net: nn.Module, config_opt: Dict[str, Any]) -> List:
  train_dataloader = config_opt['train_dataloader']
  val_dataloader = config_opt['val_dataloader']
  epochs = config_opt['epochs']
  lr = config_opt['lr']
  wd = config_opt['wd']
  loss_func = F.cross_entropy
  optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=wd)

  valiter = iter(val_dataloader)

  net.train()
  iters = 0
  outfreq = 100
  lossi=[]

  for p in net.parameters():
    p.requires_grad = True

  for epoch in range(epochs):
    for xb, yb in train_dataloader:
      # xb = (B, 28*28), yb = (B,)
      xb = xb.view(-1,1,28,28) # reshape to (B, 1, 28, 28), which adds the channel dimension
      yb = yb.long()
      
      logits = net(xb)
      loss = loss_func(logits, yb)

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      # tracking
      lossi.append(loss.log10().item())

      # outputs
      if iters % outfreq == 0:
        with torch.no_grad():

          try:
            xv,yv = next(valiter)
          except StopIteration:
            valiter = iter(val_dataloader)
            xv,yv = next(valiter)
          
          xv = xv.view(-1,1,28,28) # reshape to (B, 1, 28, 28), which adds the channel dimension
          yv = yv.long()

          net.eval()
          logitsv = net(xv)
          net.train()
          lossv = loss_func(logitsv,yv)
          accv = accuracy(logitsv, yv)
          acctr = accuracy(logits, yb)
          print(f'{iters:7d} | epoch {epoch:7d} | loss  {loss.item():.4f} (val: {lossv.item():.4f}) | acc {acctr:.4f} (val: {accv:.4f})') 
      
      iters +=1

  return lossi

In [None]:
model = ResNet()
print(sum(p.nelement() for p in model.parameters()))

In [None]:
train_dl = dm.train_dataloader()
val_dl = dm.val_dataloader()
#opts = {'epochs': 10, 'train_dataloader': train_dl, 'val_dataloader': val_dl, 'learning_rate': 3e-4}
opts = {'epochs': 5, 'train_dataloader': train_dl, 'val_dataloader': val_dl, 'lr': 1e-4, 'wd': 1e-4}
l = fit(model, opts)

In [None]:
#so far, the best model seems to have a validation accuracy of around 0.9 (on a chosen batch)

In [None]:
@torch.no_grad()
def performance(net: nn.Module, datamodule: MNISTDataModule):
  metrics = {'train': {}, 'val': {}, 'test': {}}

  #colab crashes due to limited RAM on the full test set, so we will reduce the test set to the training size
  # ds = {'train': (datamodule.train_ds.x, datamodule.train_ds.y),
  #       'val': (datamodule.valid_ds.x, datamodule.valid_ds.y),
  #       'test': (datamodule.test_ds.x, datamodule.test_ds.y)}
  ds = {'train': (datamodule.train_ds.x, datamodule.train_ds.y),
      'val': (datamodule.valid_ds.x, datamodule.valid_ds.y)}
  train_length = ds['train'][0].shape[0] 
  ds['test'] = (datamodule.test_ds.x[:train_length], datamodule.test_ds.y[:train_length])

  net.eval()
  for key,data in ds.items():
    x = data[0].unsqueeze(1)
    y = data[1].long()
    x = x.view(-1,1,28,28)
    logits = net(x)
    loss = F.cross_entropy(logits, y)
    acc = accuracy(logits, y) 
    metrics[key] = {'loss': loss.item(), 'acc': acc.item()} # store plain floats rather than tensors
  return metrics


perf = performance(model, dm)

In [None]:
## let's look at some example predictions from the test dataset
Xt, yt = dm.test_ds.x, dm.test_ds.y

model.eval()

Xt = Xt[:32].view(-1,1,28,28) # (B, 1, 28, 28)
yt = yt[:32].long()
with torch.no_grad(): # inference only, so skip building the autograd graph
    logits = model(Xt) # (B, num_classes)
ypred = torch.argmax(logits, dim=1) #(B,) 


Xt = Xt.squeeze(1)
fig, ax = plt.subplots(figsize=(15, 5), nrows=2, ncols=5)
i = 0
for row in ax:
    for col in row:
        col.imshow(Xt[i].reshape(28,28), cmap='Greys')
        col.set_title('pred = '+str(ypred[i].item()))
        col.set_axis_off()
        i+=1

plt.show()

While the results are good, a natural question to ask is what exactly did the model learn? Specifically, what do the filters in each convolutional layer look like? For a given layer, are they diverse or redundant?
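
One possible first step toward the redundancy question is to compare the filters of a layer pairwise with cosine similarity. This is a minimal sketch, not an analysis of the trained model: the helper `filter_cosine_similarity` is hypothetical, and a freshly initialised `nn.Conv2d` stands in for a trained layer such as `model.res1.conv1`.

```python
import torch
from torch import nn
import torch.nn.functional as F


def filter_cosine_similarity(conv: nn.Conv2d) -> torch.Tensor:
    """Pairwise cosine similarity between the filters of a conv layer, shape (C_out, C_out)."""
    w = conv.weight.detach().flatten(start_dim=1)  # (C_out, C_in * k * k)
    w = F.normalize(w, dim=1)                      # unit-norm rows
    return w @ w.T


# Stand-in layer (assumption): random initial weights, not trained filters.
torch.manual_seed(0)
conv_demo = nn.Conv2d(1, 64, kernel_size=3, padding=1)
sim = filter_cosine_similarity(conv_demo)

# Exclude the diagonal (each filter compared with itself) before summarising.
off_diag = sim[~torch.eye(sim.shape[0], dtype=torch.bool)]
print("mean |cosine| between distinct filters:", off_diag.abs().mean().item())
```

Values near 1 in absolute terms would suggest near-duplicate filters (up to sign and scale), while values near 0 would suggest more diverse ones. Plotting each first-layer filter as a small image with `imshow` would be a natural complement.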