# Section 4 - Computer vision based machine learning #
## Building CNNs with PyTorch ##

In this second part we will:

Start with:
* Look at the MNIST dataset
* Construct datasets and loaders, and visualise the data
* Build an MLP to classify the MNIST dataset

and then:
* Build a CNN based on the MNIST dataset features

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from time import time
import copy
import pandas as pd 

Here we are going to use the datasets readily available in PyTorch through ```torchvision.datasets```. We need to specify whether the dataset is used for training or for testing (the ```train``` argument). We will also add some transformations that are already available (```ToTensor()``` and ```Normalize```).

In [None]:
# MNIST training split stored under ./data (download requested); every image
# goes through the ToTensor transform.
train = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
# A normalisation step could be appended to the Compose list above; the
# commented-out variant of this call that hinted at it was left unfinished.
 

Now that we have a training set based on MNIST, we can instantiate a dataloader that will be used to provide training data to our model:

In [None]:
# Serve the training set in shuffled batches.
train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=10,  # we can set the batch size for each iteration
    shuffle=True,
)

It is now easy to visualise and explore our dataset using the dataloader:

In [None]:
# Plot one batch from train_loader on a 2 x 5 grid, each image titled with its label.
fig, ax = plt.subplots(2, 5, figsize=(14, 8))
ax = ax.flatten()
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators in current PyTorch releases
# do not provide a .next() method.
images, labels = next(dataiter)
print("image shape:", images.shape, "\n label shape:", labels.shape)
# zip stops at the shortest sequence, so a batch with fewer than 10 images
# leaves the remaining axes empty instead of raising IndexError.
for axis, image, label in zip(ax, images, labels):
    # squeeze() removes size-1 dimensions so imshow receives a 2-D array
    # (assumes single-channel images).
    axis.imshow(image.squeeze(), cmap='gray')
    axis.set_title("Label: {}".format(label))
fig.tight_layout()

Exercise:
* Write a ```load_MNIST``` function that loads the MNIST dataset for both training and testing and returns four objects: ```train, test, train_loader, test_loader```.
* Write a short function ```visualize_MNIST``` to visualise any number of digits using the dataloader.

In [None]:
# Load and visualize MNIST
# load_MNIST and visualize_MNIST are the functions requested in the exercise;
# they are not defined in the visible part of this file and must be written first.
# NOTE: the names `train` and `test` bound here are rebound later by the
# train()/test() function definitions further down.
train, test, train_loader, test_loader = load_MNIST()
visualize_MNIST(train_loader)

### Let's first build an MLP with a few fully connected layers, before moving on to a CNN

In [None]:
class Net(nn.Module):
  """Fully connected classifier made of five bias-free linear layers.

  Inputs with 784 features are mapped through layers of width 1000, 1000,
  500 and 200 (each followed by ReLU) to 10 outputs, which are returned as
  log-probabilities.
  """

  def __init__(self):
    super().__init__()
    # Layers are created in the same order and under the same attribute
    # names, so parameter names and seeded initialisation are unchanged.
    self.input_layer = nn.Linear(784, 1000, bias=False)
    self.hidden1_layer = nn.Linear(1000, 1000, bias=False)
    self.hidden2_layer = nn.Linear(1000, 500, bias=False)
    self.hidden3_layer = nn.Linear(500, 200, bias=False)
    self.hidden4_layer = nn.Linear(200, 10, bias=False)

  def forward(self, x):
    """Return log-softmax scores (taken over dim 1) for a batch `x` with 784 features."""
    relu_layers = (
      self.input_layer,
      self.hidden1_layer,
      self.hidden2_layer,
      self.hidden3_layer,
    )
    for layer in relu_layers:
      x = F.relu(layer(x))

    # The last layer has no ReLU; its 10 outputs are normalised directly.
    return F.log_softmax(self.hidden4_layer(x), dim=1)

In [None]:
def train(model, train_loader, epochs=3, learning_rate=0.001):
  """Train `model` with the Adam optimizer and a cross-entropy loss.

  Args:
    model: network called on flattened image batches; its output is passed
      to nn.CrossEntropyLoss together with the labels.
    train_loader: iterable of (images, labels) batches; it is iterated once
      per epoch.
    epochs: number of passes over `train_loader`.
    learning_rate: learning rate handed to the Adam optimizer.

  Returns:
    None. Progress, the elapsed time in minutes and the number of samples
    processed are printed.
  """
  # Make sure the model is in training mode, even if the caller left it in
  # evaluation mode (this matters for layers such as dropout or batch norm).
  model.train()

  lossFunction = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  time0 = time()
  total_samples = 0

  for epoch in range(epochs):
    print("Starting epoch", epoch)
    total_loss = 0  # loss accumulated over the batches of this epoch

    for idx, (images, labels) in enumerate(train_loader):
      images = images.view(images.shape[0], -1)  # flatten each image to a vector
      optimizer.zero_grad()  # clear gradients left from the previous step
      output = model(images)  # forward pass
      loss = lossFunction(output, labels)  # calculate loss
      loss.backward()  # backpropagate
      optimizer.step()  # update weights

      total_samples += labels.size(0)
      total_loss += loss.item()

      # Report the cumulative epoch loss every 100 batches.
      if idx % 100 == 0:
        print("Running loss:", total_loss)

  final_time = (time() - time0) / 60
  print("Model trained in ", final_time, "minutes on ", total_samples, "samples")


In [None]:
# Instantiate the MLP and train it on the training loader
model = Net()
train(model, train_loader, 3) # 3 epochs

In [None]:
def test(model, test_loader):
  """Test neural net"""

  correct = 0
  total = 0 

  with torch.no_grad():
    for idx, (images, labels) in enumerate(test_loader):
      images = images.view(images.shape[0],-1) # flatten
      output = model(images)
      values, indices = torch.max(output.data, 1)
      total += labels.size(0)
      correct += (labels == indices).sum().item()

    acc = correct / total * 100
    # print("Accuracy: ", acc, "% for ", total, "training samples")

  return acc


In [None]:
# Evaluate the trained model on the test loader and report its accuracy (a percentage)
acc = test(model, test_loader)
print("The accuracy of our vanilla NN is", acc, "%")