## Training a CNN for Image Classification (vs the fully connected NN)
### Working with the MNIST dataset ~ The *Hello World* of Machine Learning

This notebook is identical to [01_train_mnist_classifier_simple_NN.ipynb](01_train_mnist_classifier_simple_NN.ipynb), except for the model's architecture.

Because the two notebooks are so similar, here I only provide the absolutely necessary steps one needs to go through.

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from PIL import Image

import torch
from torch import nn
import torchvision as tv
from torch.utils.data import DataLoader
import torchvision.datasets as datasets 
from torchvision.transforms import v2

In [None]:
# Choose the compute backend: prefer a CUDA GPU, then Apple's MPS backend, else fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Directory layout: datasets/ for downloaded data, models/<model_name>/ for saved models
model_name = "mnist_cnn" # change when working with other datasets

data_dir = pathlib.Path("datasets")
models_dir = pathlib.Path("models")
mnist_dir = models_dir / model_name

# parents are listed before children, so plain mkdir (no parents=True) is enough
for folder in (data_dir, models_dir, mnist_dir):
    folder.mkdir(exist_ok=True)

### Data Processing ~ Image Transformations

In [None]:
# Convert each PIL image to a tensor image, then to float32 rescaled from [0, 255] to [0.0, 1.0]
# (v2.ToImage is the newer replacement for the older .ToTensor() conversion)
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# both splits share the same root folder, download flag and transform pipeline
mnist_kwargs = dict(root=data_dir, download=True, transform=transforms)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

### Observing our Data

In [None]:
# raw image tensor and label tensor shapes for each split
train_shapes = (train_data.data.shape, train_data.targets.shape)
test_shapes = (test_data.data.shape, test_data.targets.shape)
print(*train_shapes)
print("-----")
print(*test_shapes)

In [None]:
print(*train_data.classes, sep="\n") # one class name per line

### Visualising Data

In [None]:
img_num = 7 # change this number to have a glimpse into another item of the trainset

# `targets` is a tensor, so indexing it yields a 0-d tensor; .item() turns it into a
# plain int so the title reads "Label: 3" rather than "Label: tensor(3)"
label = train_data.targets[img_num].item()

# plotting for one image (raw, untransformed pixels from train_data.data)
plt.figure()
plt.title(f"Label: {label}")
plt.imshow(train_data.data[img_num], cmap='gray')
plt.show()

In [None]:
# plotting a rows x cols grid of randomly selected training images
cols, rows = 5, 5
figure = plt.figure(figsize=(10, 8))
for cell in range(cols * rows):
    # random index in [0, len(train_data) - 1]
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    # indexing the dataset applies `transforms` and returns (image tensor, label)
    img, label = train_data[sample_idx]
    # subplot positions are 1-based, hence cell + 1
    figure.add_subplot(rows, cols, cell + 1)
    plt.title(label)
    plt.axis("off")
    # squeeze() drops every size-1 dimension (e.g. the single channel) for imshow
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

### Dataloaders

In [None]:
batch_size = 64

# Reshuffle the training set at the start of every epoch so the model does not see
# the batches in the same fixed order each time; evaluation order does not matter,
# so the test loader keeps the dataset order.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# peek at a single test batch to check tensor shapes and the label dtype
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

## The Training Workflow

This workflow is heavily based on the work of Jérémie Wenger.

- **Defining our NN, our Model**
- **Defining our Optimizer and Loss Function** 
- **Implementing the Training Loop**   
- **Testing our Model**

### Defining our Convolutional Neural Network

To work out the input and output dimensions of the network below, remember that MNIST contains images of single digits, i.e. 10 classes (0 to 9). All images are grayscale (1 colour channel) and 28 × 28 pixels. For a CNN, the input size of the first convolutional layer is the number of colour channels, i.e. 1 for grayscale MNIST.

In [None]:
class ConvNetwork(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Shape walk-through for an input of shape [batch, 1, 28, 28]:
        conv1 (1 -> 32 filters, 3x3, padding 1) + ReLU   -> [batch, 32, 28, 28]
        2x2 max-pool                                    -> [batch, 32, 14, 14]
        conv2 (32 -> 64 filters, 3x3, padding 1) + ReLU  -> [batch, 64, 14, 14]
        2x2 max-pool                                    -> [batch, 64, 7, 7]
        flatten                                         -> [batch, 64 * 7 * 7]
        fc1 + ReLU                                      -> [batch, 64]
        fc2                                             -> [batch, 10] raw logits

    No softmax is applied: the training loss (nn.CrossEntropyLoss) expects logits.
    """

    def __init__(self):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size; only pooling halves it
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        # one pooling module, reused after each conv stage
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 64 channels * 7 * 7 spatial positions remain after the two poolings
        self.fc1 = nn.Linear(64 * 7 * 7, 64)
        self.fc2 = nn.Linear(64, 10)  # one logit per digit class
        self.relu = nn.ReLU()

    def forward(self, x):
        # feature extraction: conv -> ReLU -> pool, twice
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # classification head on the flattened feature maps (batch dim kept)
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# class ConvNetwork(nn.Module):
#     def __init__(self):
#         super(ConvNetwork, self).__init__()
#         # Convolutional feature extraction layers
#         # Input: [batch, 1, 28, 28]
#         # After conv1: [batch, 32, 28, 28]
#         # After pool1: [batch, 32, 14, 14]
#         # After conv2: [batch, 64, 14, 14]
#         # After pool2: [batch, 64, 7, 7]
#         self.features = nn.Sequential(
#             nn.Conv2d(1, 32, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2),
#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.MaxPool2d(2, 2)
#         )
#         # Fully connected classification layers
#         self.classifier = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(64 * 7 * 7, 64),
#             nn.ReLU(),
#             nn.Linear(64, 10)
#         )

#     def forward(self, x):
#         x = self.features(x)
#         x = self.classifier(x)
#         return x

In [None]:
# build the CNN and move its parameters to the selected device (Module.to returns the module itself)
model = ConvNetwork()
model = model.to(device)
print(model)

In [None]:
print("Layers and their initial weights/bias shapes:")
for layer_name, tensor in model.named_parameters():
    preview = tensor.data.flatten()[:5]  # first five initial values of this parameter
    print(f" - {layer_name} | Shape: {tensor.shape} | Sample values: {preview}...")

parameter_count = sum(tensor.numel() for tensor in model.parameters())
print()
print(f"Total parameters: {parameter_count}")

### Optimizer and Loss Function

In [None]:
# Cross-entropy on the raw logits returned by the model, optimised with plain SGD.
# NOTE(review): SGD without momentum at lr=1e-3 may converge slowly within a few
# epochs; consider momentum or Adam if accuracy plateaus.
learning_rate = 0.001
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

### Implementing our Training Loop

In [None]:
# defining the training loop

def train(dataloader, model, loss_fn, optimizer, *, device=None):
    """Run one training epoch over ``dataloader`` and return per-batch history.

    Args:
        dataloader: yields ``(data, target)`` batches; ``len(dataloader.dataset)``
            is used only for the progress printout.
        model: the ``nn.Module`` being trained; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model), so the function no
            longer depends on a notebook-global ``device`` variable.

    Returns:
        ``(losses, accs)``: lists with one float per batch, holding the batch loss
        and the fraction of correctly classified samples in that batch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset) # total number of samples, to track progress
    losses, accs = [], [] # loss and accuracy history, one entry per batch
    seen = 0 # samples processed so far in this epoch

    model.train() # set the model to training mode

    for batch_idx, (data, target) in enumerate(dataloader):

        # 0: data & target to the target device
        X, y = data.to(device), target.to(device)

        # 1: clear gradients *before* backprop, so stale gradients (e.g. left over
        #    from an interrupted run) never leak into this update
        optimizer.zero_grad()

        # 2: prediction - logits for each class
        pred = model(X)

        # 3: loss
        loss = loss_fn(pred, y)

        # 4: backpropagation - computes gradients
        loss.backward()

        # 5: update parameters based on the computed gradients
        optimizer.step()

        # Logging & saving history
        batch_loss = loss.item()
        seen += len(X)
        losses.append(batch_loss)
        # pred was computed before optimizer.step(), so this is the accuracy of
        # the pre-update weights on this batch
        accs.append((pred.argmax(1) == y).float().mean().item())

        if batch_idx % 100 == 0:
            print(f"loss: {batch_loss:>7f}  [{seen:>5d}/{size:>5d}]")

    print()
    return losses, accs

In [None]:
# defining the test loop

def test(dataloader, model, loss_fn):
    
    losses, accs = [], []
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    model.eval() # set the model to test mode

    # no gradients since we are not training!
    with torch.no_grad():
        for (data, target) in dataloader:

            # 0: data & target to specified device
            X, y = data.to(device), target.to(device)

            # 1: prediction - getting logits for each class
            pred = model(X)
            
            # 2: loss and accuracy
            loss = loss_fn(pred, y)
            t_l = loss.item()
            test_loss += t_l

            # accumulate our accuracy
            a = (pred.argmax(1) == y).type(torch.float)
            correct += a.sum().item()

            # save loss and acc
            losses.append(t_l)
            accs.append(a.mean().item())
    
    # average loss & results
    test_loss /= num_batches
    correct /= size
    
    print("Test Error:")
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    print()
    
    return losses, accs, correct

In [None]:
epochs = 5
# histories are flat lists with one entry per batch, concatenated across epochs
train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}")
    print("-------------------------------")
    batch_train_losses, batch_train_accs = train(train_dataloader, model, loss_fn, optimizer)
    batch_test_losses, batch_test_accs, _ = test(test_dataloader, model, loss_fn)
    train_losses += batch_train_losses
    train_accs += batch_train_accs
    test_losses += batch_test_losses
    test_accs += batch_test_accs
print("Done!")

### Testing ~ Evaluating the Performance of our Model

In [None]:
# `train_losses` holds one value per training *batch* (train() appends loss.item()
# for every batch and the epoch loop concatenates the epochs). The x axis is
# therefore the batch index, not the epoch, and each y value is a single batch's
# loss, not a cumulative sum.
plt.figure(figsize=(10,5))
plt.title("Train loss")
plt.plot(train_losses,label="train")
plt.xlabel("batches")
plt.ylabel("loss (per batch)")
plt.legend()
plt.show()

In [None]:
# final evaluation on the test set; binding the result keeps the notebook from echoing the tuple
final_test_losses, final_test_accs, final_test_accuracy = test(test_dataloader, model, loss_fn)

### Using our Model on an Input Image


See [Transforming and Augmenting Images](https://pytorch.org/vision/stable/transforms.html).

In [None]:
img = Image.open('images/3.png') # try also images/4.png

# Bring an arbitrary image into the model's input format: 1 grayscale channel,
# 28x28 pixels, float32 values rescaled from [0, 255] to [0, 1].
# NOTE(review): MNIST digits are light strokes on a dark background; a dark-on-light
# image may need inverting first - confirm against the images used here.
transforms = v2.Compose(
    [
        v2.Grayscale(num_output_channels=1),
        v2.Resize(size=(28, 28), antialias=True),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# unsqueeze(0) adds the batch dimension -> [1, 1, 28, 28]
input = transforms(img).unsqueeze(0).to(device)

print(f"Input shape: {input.shape}")

In [None]:
model.eval()
with torch.no_grad():
    logits = model(input)
    # softmax over the class dimension turns the logits into probabilities
    predictions = logits.softmax(dim=-1).cpu().numpy()
print(f"Our predictions (shape: {predictions.shape})")
print(predictions)

In [None]:
# `predictions` is still batched (shape (1, 10)), so take its first row
class_probs = predictions[0]
predicted = np.argmax(class_probs) # index of the most probable class
predicted_name = train_data.classes[predicted]

plt.figure()
plt.title(f'Predicted number: {predicted_name}') # show the predicted category in the title
plt.imshow(img, cmap="gray")
plt.axis("off")
plt.show()

We can plot our predictions for all classes using a [bar chart](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.bar.html).

In [None]:
class_labels = train_data.classes # one bar per class, labelled with the class name
class_probs = predictions[0]      # probabilities for our single (batched) input image

plt.figure(figsize=(14, 5))
plt.title("Predictions")
plt.bar(class_labels, class_probs)
plt.xticks(class_labels)
plt.show()

In [None]:
# compile the model to TorchScript, then write it to models/<model_name>/my_<model_name>.pt
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, mnist_dir / f"my_{model_name}.pt")

### Saving & Loading Models

```python
# save (reload using torch.jit.load)
torch.jit.save(torch.jit.script(model), mnist_dir / f"my_{model_name}.pt")

# save (reload using model.load_state_dict, requires the model class, ie you need to redefine your model architecture)
torch.save(model.state_dict(), mnist_dir / f"{model_name}.pt")
print(f"Saved PyTorch Model State to {mnist_dir / model_name}.pt")

# instantiate then load (you need to have defined ConvNetwork)!
model_reloaded = ConvNetwork().to(device)
model_reloaded.load_state_dict(torch.load(mnist_dir / f"{model_name}.pt", weights_only=True))
```
The `jit`-only method is ideal for using the model for inference. **However**, if you want to fine-tune your model after reloading it, prefer the full method above (class definition + loading the weights).