# The big flaw of MLPs

Today we will experimentally show the **fundamental flaw** of MLPs when it comes to correctly detecting visual patterns in the input data, and present it as a motivation for **convolution-based solutions**. In particular, we will compare the performance of our previously implemented MLP on the [MNIST](https://pytorch.org/vision/stable/datasets.html#mnist) dataset with its performance on a **translated version** of the same data.

As usual, we start by importing the modules that we need.

In [None]:
import torch
import torchvision
from torchvision import transforms as T
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

## Translation function
In this block we define and implement the function that **transforms** the items in our dataset so as to obtain **visually translated** versions of the input digits.

In [None]:
def create_translation_transform():
  '''
  Creates a transformation that pads the original input so that the
  digit is not centered in the image.

  The returned pipeline applies, in order:
    1. T.ToTensor()
    2. a random zero padding: 28 pixels in total are added along each of the
       last two dimensions, split at random between the two sides
       (e.g. a 28x28 input becomes 56x56)
    3. a repetition of the tensor 3 times along its first dimension

  Returns:
    a T.Compose object wrapping the three steps above
  '''
  # total number of zero pixels added along each of the last two dimensions
  pad_size = 28

  def padding(x):
    # The offsets are drawn as plain Python ints: F.pad documents its `pad`
    # argument as a tuple of ints, so we do not rely on the implicit
    # conversion of one-element tensors. `high` is exclusive, hence the
    # offsets lie in [0, pad_size - 1].
    left_padding = torch.randint(low=0, high=pad_size, size=(1,)).item()
    top_padding = torch.randint(low=0, high=pad_size, size=(1,)).item()

    # pad order for the last two dimensions: (left, right, top, bottom)
    return F.pad(x, (left_padding,
                     pad_size - left_padding,
                     top_padding,
                     pad_size - top_padding), "constant", 0)

  return T.Compose([
    T.ToTensor(),
    T.Lambda(padding),
    T.Lambda(lambda x: x.repeat(3, 1, 1)),
  ])

## Data loading
In this block we create a function that returns the required **dataloading utility** over our dataset, with a parameter that lets us **choose** whether to apply the translation or not.

In [None]:
def get_data_for_visualization(batch_size, translate=False):
  '''
  Returns a shuffled dataloader over the MNIST training split.

  Input arguments:
    batch_size: number of samples in each mini-batch
    translate: if True, the images go through the pipeline returned by
               create_translation_transform(); otherwise they are only
               converted to tensors and zero-padded by 14 pixels on each side
  '''
  if translate:
    # randomly translated digits
    transform = create_translation_transform()
  else:
    # centered digits: 14 zero pixels are added on every side
    center_pad = T.Lambda(lambda x: F.pad(x, (14, 14, 14, 14), "constant", 0))
    transform = T.Compose([T.ToTensor(), center_pad])

  # training split only (downloaded into ./data when missing)
  mnist_train = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)

  return torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True)

## Data visualization
We want to obtain a visual representation that allows us to **compare** the original digits of the dataset with their translated version.

In [None]:
# one dataloader over the centered digits, one over the translated digits
train_loader = get_data_for_visualization(256, translate=False)
train_loader_translated = get_data_for_visualization(256, translate=True)

# iterators over the two dataloaders
train_iter = iter(train_loader)
train_iter_translated = iter(train_loader_translated)

# fetch a single mini-batch (images and labels) from each iterator
data, labels = next(train_iter)
data_translated, labels_translated = next(train_iter_translated)

# the label of the digit you want to visualize
digit_label = 8

# positions of (at most) the first 9 occurrences of the chosen digit
# inside each of the two mini-batches
get_idx = torch.nonzero(labels == digit_label).squeeze(-1)[:9]
get_idx_translated = torch.nonzero(labels_translated == digit_label).squeeze(-1)[:9]

# select the corresponding images and labels
get_data = data[get_idx]
get_labels = labels[get_idx]
get_data_translated = data_translated[get_idx_translated]
get_labels_translated = labels_translated[get_idx_translated]


### show the two 3-column grids side by side ###

# left panel: original (centered) digits
display_grid = torchvision.utils.make_grid(get_data, nrow=3, padding=2, pad_value=1)
plt.subplot(1, 2, 1)
# imshow wants the channel dimension last
plt.imshow(display_grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title('Centered Digits')

# right panel: translated digits
display_grid_translated = torchvision.utils.make_grid(get_data_translated, nrow=3, padding=2, pad_value=1)
plt.subplot(1, 2, 2)
plt.imshow(display_grid_translated.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title('Translated Digits')

plt.tight_layout()
plt.show()

## MLP architecture
We will use the same architecture that we discussed and implemented in the previous lab session.

In [None]:
class MyFirstNetwork(torch.nn.Module):
  '''
  MLP with a single hidden layer: linear -> sigmoid -> linear.
  '''

  def __init__(self, input_dim, hidden_dim, output_dim):
    '''
    Input arguments:
      input_dim: number of features of the flattened input
      hidden_dim: number of hidden neurons
      output_dim: number of output neurons
    '''
    super().__init__()

    # layers, listed in the order in which forward() applies them
    self.input_to_hidden = torch.nn.Linear(input_dim, hidden_dim)
    self.activation = torch.nn.Sigmoid()
    self.hidden_to_output = torch.nn.Linear(hidden_dim, output_dim)

    # both linear layers start from zero biases
    for layer in (self.input_to_hidden, self.hidden_to_output):
      layer.bias.data.fill_(0.)

  def forward(self, x):
    '''
    Flattens every sample of the batch and feeds it through the network.
    Returns the raw output of the last linear layer.
    '''
    # reshape the input to (batch_size, input_dim)
    flat = x.view(x.size(0), -1)

    hidden = self.activation(self.input_to_hidden(flat))
    return self.hidden_to_output(hidden)

## Cost function and optimizer
Also for these two components, we stick to those employed in the previous lab.

In [None]:
def get_cost_function():
  '''
  Returns a new torch.nn.CrossEntropyLoss instance (default settings).
  '''
  return torch.nn.CrossEntropyLoss()

def get_optimizer(net, lr, wd, momentum):
  '''
  Returns an SGD optimizer over all the parameters of `net`.

  Input arguments:
    net: the model whose parameters have to be optimized
    lr: learning rate
    wd: weight decay coefficient
    momentum: momentum coefficient
  '''
  sgd_options = dict(lr=lr, weight_decay=wd, momentum=momentum)
  return torch.optim.SGD(net.parameters(), **sgd_options)

## Training and test steps
We already know the drill!

In [None]:
def training_step(net, data_loader, optimizer, cost_function, device='cuda'):
  '''
  Trains `net` for one full pass over `data_loader`.

  Input arguments:
    net: the model to train
    data_loader: iterable yielding (inputs, targets) mini-batches
    optimizer: optimizer updating the parameters of `net`
    cost_function: callable computing the loss from (outputs, targets)
    device: device on which every mini-batch is moved before the forward pass

  Returns:
    a tuple (average loss per sample, accuracy in percentage), both computed
    over all the samples seen during the pass
  '''
  samples = 0.
  cumulative_loss = 0.
  cumulative_accuracy = 0.

  # A loss with reduction='mean' (the default of torch.nn.CrossEntropyLoss)
  # is already averaged over the mini-batch: it must be weighted by the batch
  # size before being divided by the total number of samples, otherwise the
  # reported value shrinks with the batch size and is not comparable across
  # loaders. Cost functions without a `reduction` attribute are assumed to
  # return a batch mean as well.
  loss_is_batch_mean = getattr(cost_function, 'reduction', 'mean') == 'mean'

  # set the network to training mode
  net.train()

  # iterate over the training set
  for inputs, targets in data_loader:

    # load data into the chosen device
    inputs = inputs.to(device)
    targets = targets.to(device)
    batch_size = inputs.shape[0]

    # gradients reset (done first, so that gradients left over from before
    # this call can never leak into the first update)
    optimizer.zero_grad()

    # forward pass
    outputs = net(inputs)

    # loss computation
    loss = cost_function(outputs, targets)

    # backward pass
    loss.backward()

    # parameters update
    optimizer.step()

    # fetch prediction and loss value
    samples += batch_size
    batch_loss = loss.item()
    cumulative_loss += batch_loss * batch_size if loss_is_batch_mean else batch_loss
    _, predicted = outputs.max(dim=1) # max() returns (maximum_value, index_of_maximum_value)

    # compute training accuracy
    cumulative_accuracy += predicted.eq(targets).sum().item()

  return cumulative_loss/samples, cumulative_accuracy/samples*100

def test_step(net, data_loader, cost_function, device='cuda'):

  samples = 0.
  cumulative_loss = 0.
  cumulative_accuracy = 0.

  # set the network to evaluation mode
  net.eval() 

  # disable gradient computation (we are only testing, we do not want our model to be modified in this step!)
  with torch.no_grad():

    # iterate over the test set
    for batch_idx, (inputs, targets) in enumerate(data_loader):
      
      # load data into GPU
      inputs = inputs.to(device)
      targets = targets.to(device)
        
      # forward pass
      outputs = net(inputs)

      # loss computation
      loss = cost_function(outputs, targets)

      # fetch prediction and loss value
      samples+=inputs.shape[0]
      cumulative_loss += loss.item() # Note: the .item() is needed to extract scalars from tensors
      _, predicted = outputs.max(1)

      # compute accuracy
      cumulative_accuracy += predicted.eq(targets).sum().item()

  return cumulative_loss/samples, cumulative_accuracy/samples*100

## Data loading
Let us now define a compact method to return the dataloaders that we need to perform our experiment.

In [None]:
def get_data(batch_size, test_batch_size=128, translate=False):
  '''
  Returns the training, validation and test dataloaders over MNIST.

  Input arguments:
    batch_size: mini-batch size of the training dataloader
    test_batch_size: mini-batch size of the validation and test dataloaders
    translate: if True, every image is zero-padded by 28 pixels in total along
               each of its last two dimensions, split at random between the
               two sides; otherwise 14 zero pixels are added on each side.
               The same transformation is used for all the three splits.

  Returns:
    a tuple (train_loader, val_loader, test_loader); only the training
    dataloader is shuffled
  '''
  if translate:
    # total number of zero pixels added along each of the last two dimensions
    pad_size = 28

    def padding(x):
      # The offsets are drawn as plain Python ints: F.pad documents its `pad`
      # argument as a tuple of ints, so we do not rely on the implicit
      # conversion of one-element tensors. `high` is exclusive, hence the
      # offsets lie in [0, pad_size - 1].
      left_padding = torch.randint(low=0, high=pad_size, size=(1,)).item()
      top_padding = torch.randint(low=0, high=pad_size, size=(1,)).item()

      # pad order for the last two dimensions: (left, right, top, bottom)
      return F.pad(x, (left_padding,
                       pad_size - left_padding,
                       top_padding,
                       pad_size - top_padding), "constant", 0)

    spatial_transform = T.Lambda(padding)
  else:
    # centered digits: 14 zero pixels are added on every side
    spatial_transform = T.Lambda(lambda x: F.pad(x, (14, 14, 14, 14), "constant", 0))

  transform = T.Compose([T.ToTensor(), spatial_transform])

  # load data
  full_training_data = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)
  test_data = torchvision.datasets.MNIST('./data', train=False, transform=transform, download=True)

  # split the original training data into training and validation sets
  # (half of the samples plus one go to the training set)
  num_samples = len(full_training_data)
  training_samples = int(num_samples*0.5+1)
  validation_samples = num_samples - training_samples

  training_data, validation_data = torch.utils.data.random_split(full_training_data, [training_samples, validation_samples])

  # initialize dataloaders
  train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True)
  val_loader = torch.utils.data.DataLoader(validation_data, test_batch_size, shuffle=False)
  test_loader = torch.utils.data.DataLoader(test_data, test_batch_size, shuffle=False)

  return train_loader, val_loader, test_loader

## Main function
Let's now define our wrapper to actually train and evaluate our model

In [None]:
def main_MLP(batch_size=64, input_dim=56*56, hidden_dim=100, output_dim=10, device='cuda:0', 
             learning_rate=0.01, weight_decay=0.000001, momentum=0.9, epochs=50, 
             translate=False, visualization_name='centered'):
  '''
  Trains and evaluates the MLP, printing the metrics and logging the training
  and validation curves under runs/<visualization_name>.

  Training, validation and test metrics are printed before and after the
  training; training and validation metrics are printed and logged at every
  epoch. Test metrics are printed only, they are not logged.

  Input arguments:
    batch_size: the size of a mini-batch that is used for training
    input_dim: flattened size of the input image vector
    hidden_dim: number of hidden neurons in the network
    output_dim: the number of output neurons
    device: device (e.g. a GPU) where you want to train your network; it is
            used both for the model and for every mini-batch
    learning_rate: learning rate for the optimizer
    weight_decay: weight decay coefficient for regularization of weights
    momentum: momentum for SGD optimizer
    epochs: number of epochs for training the network
    translate: whether to translate the images that are fed to the network
    visualization_name: name of the graph for visualizing in tensorboards
                        (always remember to use an unique visualization_name
                         for each training, otherwise it will mess up the visualization!)
  '''
  # creates a logger for the experiment
  writer = SummaryWriter(log_dir=f"runs/{visualization_name}")

  # the logger is closed even if something fails along the way
  try:
    # get dataloaders
    train_loader, val_loader, test_loader = get_data(batch_size=batch_size, translate=translate)

    # correctly set the device
    device = torch.device(device)

    # instantiate the model and send it to the device
    net = MyFirstNetwork(input_dim, hidden_dim, output_dim).to(device)

    # instantiate optimizer & cost function
    optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)
    cost_function = get_cost_function()

    # perform a single test step beforehand and print metrics
    # (the device is always forwarded: the steps would otherwise fall back to
    # their own default and move the data away from the model's device)
    print('Before training:')
    train_loss, train_accuracy = test_step(net, train_loader, cost_function, device=device)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function, device=device)

    print('\t Training loss {:.5f}, Training accuracy {:.2f}'.format(train_loss, train_accuracy))
    print('\t Validation loss {:.5f}, Validation accuracy {:.2f}'.format(val_loss, val_accuracy))
    print('\t Test loss {:.5f}, Test accuracy {:.2f}'.format(test_loss, test_accuracy))
    print('-----------------------------------------------------')

    # add values to logger (step 0 = before training)
    writer.add_scalar('Loss/train_loss', train_loss, 0)
    writer.add_scalar('Loss/val_loss', val_loss, 0)
    writer.add_scalar('Accuracy/train_accuracy', train_accuracy, 0)
    writer.add_scalar('Accuracy/val_accuracy', val_accuracy, 0)

    # iterate over the epochs number
    for e in range(epochs):
      train_loss, train_accuracy = training_step(net, train_loader, optimizer, cost_function, device=device)
      val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
      print('Epoch: {:d}'.format(e+1))
      print('\t Training loss {:.5f}, Training accuracy {:.2f}'.format(train_loss, train_accuracy))
      print('\t Validation loss {:.5f}, Validation accuracy {:.2f}'.format(val_loss, val_accuracy))
      print('-----------------------------------------------------')

      # add values to logger
      writer.add_scalar('Loss/train_loss', train_loss, e + 1)
      writer.add_scalar('Loss/val_loss', val_loss, e + 1)
      writer.add_scalar('Accuracy/train_accuracy', train_accuracy, e + 1)
      writer.add_scalar('Accuracy/val_accuracy', val_accuracy, e + 1)

    # compute and print final metrics
    print('After training:')
    train_loss, train_accuracy = test_step(net, train_loader, cost_function, device=device)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function, device=device)

    print('\t Training loss {:.5f}, Training accuracy {:.2f}'.format(train_loss, train_accuracy))
    print('\t Validation loss {:.5f}, Validation accuracy {:.2f}'.format(val_loss, val_accuracy))
    print('\t Test loss {:.5f}, Test accuracy {:.2f}'.format(test_loss, test_accuracy))
    print('-----------------------------------------------------')
  finally:
    # close the logger
    writer.close()

## Run!
Let's first run the MLP on the original dataset

In [None]:
# remove the logs of previous runs, so that only the new curves get plotted
# (rm reports an error when the 'runs' folder does not exist yet)
! rm -r runs
# train and evaluate on the centered digits (logs go to runs/centered)
main_MLP(translate=False, visualization_name='centered')

Let us now run it on the translated version

In [None]:
# train and evaluate on the randomly translated digits (logs go to runs/translated)
main_MLP(translate=True, visualization_name='translated')

And see the results!

In [None]:
# load the TensorBoard notebook extension and open it on the 'runs' folder,
# where main_MLP writes one sub-folder per visualization_name
%load_ext tensorboard
%tensorboard --logdir=runs