## Classifying handwritten digits with a CNN


In the previous exercises, we worked with fully-connected neural networks, which are good at handling tabular data, where the inputs and targets are easily presented as vectors.

However, in the case of images, or image-like objects, such models are less efficient for reasons we have discussed in the slides. When inputs are images, or image-like data, a more natural choice of model is a convolutional neural network (CNN), in particular a model which uses 2D convolutional layers.

Before we start worrying about choosing models, let's first acquaint ourselves with the MNIST data.

### Task 1: Access the MNIST dataset.

PyTorch has a (sort of) sister Python library for dealing with images: [``Torchvision``](https://pytorch.org/vision/stable/index.html) (take a look at the website for a few minutes).

In the previous exercises, we used a custom ``Dataset`` object created specifically for this event, but ``Torchvision`` comes with a bunch of easy-to-use datasets, one of which is MNIST.

#### Task 1—(a)

- Look at the arguments of the MNIST dataset: what options do we have?
- Instantiate the (training) dataset.
- Iterate over it: how are the inputs and targets presented to us?
  - As before, we must supply transforms to map between the raw data and ``torch.Tensor``s.
  - Look at how we want to prepare the image tensors first.
  - Look at how we want to prepare the target tensors second.
- Plot some image tensors, and set their targets as the title, to make sure the data we feed to the model makes sense.
- Create a validation dataset, too.

#### Task 2—(a)

Now that we have a functioning dataset, as before, we could proceed and train our model. However, with images, we can often improve our model's ability to generalise, and therefore produce more meaningful results, by employing data augmentation.

- Are CNNs rotationally invariant?
  - If we want our model to still work on images which are not in a regular orientation, we must use random rotations as a form of augmentation.
- If we train a model on purely black-and-white images, how will it fare on more colourful data?
- Go to ``Torchvision``'s [transforms](https://pytorch.org/vision/stable/transforms.html) and look at the available forms of image augmentation.
  - Take a few minutes to pick ones you think might be relevant.
  - Let's discuss and choose some.
- Now plot your augmented data, as before, and make sure everything is in order.
  - It's _always_ _always_ _always_ a good idea to do this.
- Note: applying data augmentations to the validation set is unconventional, so we can create our transforms in a function which optionally applies augmentations.

In [None]:
from torchvision.datasets import MNIST

from torchvision.transforms import ToTensor, Compose

# Part (b) imports
# from torchvision.transforms import ...

from torch import eye

import matplotlib.pyplot as plt

### Task 3: ``Dataset`` $\to$ ``DataLoader``

As before, wrap the ``Dataset``s in ``DataLoader``s.

In [None]:
from torch.utils.data import DataLoader
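
As a rough sketch of the wrapping step, the snippet below uses a small ``TensorDataset`` of random tensors as a stand-in for the MNIST datasets, so it runs without any download. The batch size of 16 and the variable names are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the real dataset: 64 random "images" with one-hot targets.
images = torch.rand(64, 1, 28, 28)
targets = torch.eye(10)[torch.randint(0, 10, (64,))]
train_set = TensorDataset(images, targets)

# Shuffle the training data each epoch; validation data is usually left unshuffled.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

batch_images, batch_targets = next(iter(train_loader))
print(batch_images.shape, batch_targets.shape)
```

With the real data, ``train_set`` would simply be replaced by the training and validation ``Dataset`` objects from Task 1.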

### Task 4: Choose a model architecture

Torchvision provides a collection of models, [here](https://pytorch.org/vision/stable/models.html).

Since we are all on laptops, some of which don't have CUDA-enabled GPUs, I suggest we choose a modest neural network that won't melt any of our hardware. One such network, designed with mobile phones in mind, is [``MOBILENET``](https://pytorch.org/vision/stable/models/mobilenetv3.html).

#### Task 4 (a): instantiate the small one, and print it out.

Note:
- Torchvision's models can optionally be endowed with pretrained weights (from corresponding models pretrained on the ImageNet dataset).
- We can (optionally) supply these weights.
  - Using weights from a model, trained on one problem, as an initial condition in another problem, is called transfer learning.
  - Why do you think this might be advantageous, even in disparate problems?

In [8]:
from torchvision.models import mobilenet_v3_small
from torchvision.models import MobileNet_V3_Small_Weights

model = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
print(model)

MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), 

#### Task 4 (b): Uh oh, we've a problem.

- Look at the final linear layer.
  - How many output classes are there?
  - How many do we need?
- We need to "overload" the final layer to produce the correct number of output features for our problem. Fortunately this is easy.
- Uncomment the code below, choose the correct number of output features, and print the model again.

In [15]:
from torch.nn import Linear

# model.classifier[3] = Linear(model.classifier[3].in_features, ???)

### Task 5: Set up the remaining PyTorch bits and bobs

- We need to choose a loss function appropriate for classification.
  - Can you remember what we chose previously?
- We need an optimiser, too.
  - Remember our friend, Adam?
- Instantiate the loss function and the optimiser.

In [None]:
# from torch.nn import what_the_func?
# from torch.optim import Adam
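
One common pairing for classification is cross-entropy loss with the Adam optimiser, sketched below on a tiny stand-in model. Whether this matches what we chose previously is an assumption, as are the learning rate of ``1e-3`` and the one-hot targets.

```python
import torch
from torch.nn import CrossEntropyLoss, Linear
from torch.optim import Adam

# Tiny stand-in model: 4 input features, 10 output classes.
model = Linear(4, 10)

loss_func = CrossEntropyLoss()
optimiser = Adam(model.parameters(), lr=1e-3)  # assumption: lr of 1e-3

# One optimisation step on random inputs with one-hot targets.
inputs = torch.rand(8, 4)
targets = torch.eye(10)[torch.randint(0, 10, (8,))]

optimiser.zero_grad()
loss = loss_func(model(inputs), targets)
loss.backward()
optimiser.step()

print(loss.item())
```

``CrossEntropyLoss`` expects raw, unnormalised model outputs (logits), so no softmax is applied to the model's output before computing the loss.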

### Task 6: Set the device

We could have done this when we created the model, but I forgot.

In [18]:
from torch.cuda import is_available

DEVICE = "cuda" if is_available() else "cpu"

model.to(DEVICE);
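
Once the model lives on ``DEVICE``, each batch must be moved there too before the forward pass. A minimal sketch, using a small stand-in model and a random batch rather than the real data:

```python
import torch
from torch.cuda import is_available

DEVICE = "cuda" if is_available() else "cpu"

# Stand-in model and batch; both must end up on the same device.
model = torch.nn.Linear(4, 2).to(DEVICE)
batch = torch.rand(3, 4).to(DEVICE)

output = model(batch)
print(output.device)
```

In a training loop, the same ``.to(DEVICE)`` call would be applied to both the inputs and the targets of every batch.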