## Training a NN from scratch for Image Classification
### Working with the MNIST dataset ~ The *Hello World* of Machine Learning

In this notebook we are going to get familiar with using [PyTorch](https://pytorch.org), a deep learning library, to train a simple neural network. The network will be trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which contains small images of handwritten digits. By the end of training, the model should be able to classify images of handwritten digits accurately.

Training a network on the MNIST dataset has become the 'hello world' of machine learning. 

More info on PyTorch and all the steps we go through in this notebook can be found in the [PyTorch Quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html).

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from PIL import Image

import torch
from torch import nn
import torchvision as tv
from torch.utils.data import DataLoader
import torchvision.datasets as datasets 
from torchvision.transforms import v2

In [None]:
# Get cpu, gpu or mps device for training

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

In [None]:
# Create your directory structure for your datasets and models

data_dir = pathlib.Path("datasets")
data_dir.mkdir(exist_ok=True)

models_dir = pathlib.Path("models")
models_dir.mkdir(exist_ok=True)

model_name = "mnist" # change when working with other datasets

mnist_dir = models_dir / model_name
mnist_dir.mkdir(exist_ok=True)

### Data Processing ~ Image Transformations

Data rarely comes in the final, processed form that training a machine learning algorithm requires. We use *transforms* to manipulate the data and make it suitable for training.

Images are typically stored as PIL images or NumPy n-dim arrays with pixel values in the range [0, 255] (integers). Neural networks in PyTorch expect tensors (multi-dimensional arrays) as inputs, with floating-point values, often normalized to [0., 1.] for better training stability and convergence. The transforms we apply below ensure the data is preprocessed into this "model-ready" format.

See [here](https://pytorch.org/vision/stable/transforms.html#performance-considerations) for more.
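To make the conversion concrete, here is a rough plain-PyTorch sketch of what these steps amount to for a single grayscale image: add a channel dimension and rescale the integers in [0, 255] to floats in [0, 1]. The array `fake_digit` is made up for the example, and this illustrates the idea rather than the exact code TorchVision runs internally.

```python
import numpy as np
import torch

# A hypothetical 28x28 grayscale image with integer pixel values in [0, 255]
fake_digit = np.zeros((28, 28), dtype=np.uint8)
fake_digit[10:18, 12:16] = 255  # a bright vertical stroke
fake_digit[0, 0] = 128          # one mid-gray pixel

# Convert to a tensor and add a channel dimension: [H, W] -> [C, H, W] = [1, 28, 28]
x = torch.from_numpy(fake_digit).unsqueeze(0)

# Cast to float and rescale from [0, 255] to [0.0, 1.0]
x = x.to(torch.float32) / 255.0

print(x.shape, x.dtype)                 # torch.Size([1, 28, 28]) torch.float32
print(x.min().item(), x.max().item())   # 0.0 1.0
```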

All TorchVision datasets have two parameters:
- *transform* to modify the features (here, converting the images into normalized float tensors)
- *target_transform* to modify the labels (e.g. turning them into one-hot encoded tensors; not needed in this case, because our loss function works with integer labels)

In [None]:
transforms = v2.Compose([
    v2.ToImage(),                         # PIL image -> tensor of shape [C, H, W]; replaces the older .ToTensor(), but does not rescale
    v2.ToDtype(torch.float32, scale=True) # scales the pixel values from [0, 255] to [0.0, 1.0]
])

train_data = datasets.MNIST(
    root=data_dir,
    train=True,
    download=True,
    transform=transforms,
)

test_data = datasets.MNIST(
    root=data_dir,
    train=False,
    download=True,
    transform=transforms,
)

As mentioned above, sometimes we need to turn the targets (labels) into one-hot encoded tensors. PyTorch's `nn.CrossEntropyLoss`, which we use in this notebook, accepts integer class labels directly (it also accepts class probabilities, such as one-hot vectors), whereas other loss implementations may require one-hot targets.

If we wanted to perform target transformations, this is what we would do:

In [None]:
# num_classes = 10

# target_transform = tv.transforms.Lambda(
#     lambda y: torch.zeros(num_classes, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
# )
# x = 3
# print(x)
# print(target_transform(x)) # one-hot representation
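A shorter way to build one-hot vectors is `torch.nn.functional.one_hot`. The sketch below also checks that, for hard labels, `nn.CrossEntropyLoss` returns the same value whether the targets are integer class indices or one-hot probability vectors (probability targets assume PyTorch 1.10 or later). The logits and labels are random, made-up values.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
n_classes = 10

logits = torch.randn(4, n_classes)   # hypothetical model outputs for a batch of 4
labels = torch.tensor([3, 0, 9, 3])  # integer class indices

# One-hot encode: shape [4, 10], cast to float so it can serve as a probability target
one_hot = F.one_hot(labels, num_classes=n_classes).to(torch.float32)
print(one_hot[0])  # 1.0 at index 3, 0.0 elsewhere

ce_loss = nn.CrossEntropyLoss()
loss_from_indices = ce_loss(logits, labels)
loss_from_one_hot = ce_loss(logits, one_hot)
# equal up to floating-point rounding
print(loss_from_indices.item(), loss_from_one_hot.item())
```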

### Observing our Data

In [None]:
print(train_data)
print("-----")
print(test_data)

In [None]:
print(train_data[0])    # returns a tuple (image tensor, label)
print("-----")
print(train_data[0][1]) # label only

In [None]:
print(train_data.data.shape, train_data.targets.shape)
print("-----")
print(test_data.data.shape, test_data.targets.shape)

In [None]:
print("\n".join(train_data.classes)) # join an array into a string

Note the difference between the original data type and range, and what you get when you index the dataset to extract a single sample, as the DataLoader does during training:

In [None]:
print(f"Original data type:    {train_data.data.dtype}")
print(f"range:                 [{train_data.data.min().item()}: {train_data.data.max().item()}]")
print(f"Transformed data type: {train_data[0][0].dtype}")
print(f"range:                 [{train_data[0][0].min().item()}: {train_data[0][0].max().item()}]")
print()

print(f"Original label shape:    {train_data.targets.shape} (60k integers)")
print(f"dtype:                   {train_data.targets.dtype}")

### Visualising Data

In [None]:
img_num = 7 # change this number to have a glimpse into another item of the trainset
torch.set_printoptions(linewidth=150) # add a wide linewidth to prevent wrapping
print(f"Label: {train_data.targets[img_num]}")
print()
print(train_data.data[img_num])

In [None]:
# plotting for one image
plt.figure()
plt.title(f"Label: {train_data.targets[img_num]}")
plt.imshow(train_data.data[img_num], cmap='gray')
plt.show()

In [None]:
# plotting for multiple images, randomly selected
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    # generate a random index between 0 and len(train_data)-1, inclusive
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    # retrieve the image and the respective label for that index
    img, label = train_data[sample_idx]
    # create the grid of subplots within the bigger plot
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    # squeeze() removes all dimensions with size 1
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

### Dataloaders

The batch size is limited only by GPU/CPU memory constraints. Generally, we start with a power of 2 (e.g. 32, 64, 128, 256, 512), which tends to make good use of GPU hardware. As a general principle, larger batches reduce the training time per epoch, since each epoch is processed in fewer, more parallel steps.

In [None]:
batch_size = 64

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True) # reshuffle the training data every epoch
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

## The Training Workflow

Heavily based on material by Jérémie Wenger.

- **Defining our NN, our Model**
- **Defining our Optimizer and Loss Function** 
- **Implementing the Training Loop**   
- **Testing our Model**

### Defining our Neural Network

Neural networks consist of layers/modules that perform operations on data. The `torch.nn` namespace provides all the building blocks you need to build your own neural network. Every module in PyTorch subclasses `nn.Module`. A neural network is itself a module that consists of other modules (layers), and this nested structure makes it easy to build and manage complex architectures.

[See here for more](https://docs.pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html)

To work out the input and output dimensions of the NN below, remember that MNIST contains images of single digits, so there are 10 classes, from 0 to 9. All images are grayscale (1 colour channel) with dimensions 28 * 28 pixels, which gives 28 * 28 = 784 input values once an image is flattened.

In [None]:
num_classes = 10

# Define our fully connected NN 
class NeuralNetwork(nn.Module):
    def __init__(self): # the constructor
        super().__init__()
        self.flatten = nn.Flatten() # [N, 1, 28, 28] -> [N, 28*28], where N is the batch size
        # a sequence of the layers below, where each layer's output is the next layer's input
        # output = input @ weights.T + bias
        # the ReLU activation introduces non-linearity, helping the model learn more complex patterns
        self.linear_relu_stack = nn.Sequential( 
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
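To see what the comments in the class describe, here is a small standalone check on a hypothetical batch of 8 random "images": `nn.Flatten` turns `[N, 1, 28, 28]` into `[N, 784]`, and `nn.Linear` computes `x @ W.T + b`, with `W` stored as `[out_features, in_features]`.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch = torch.rand(8, 1, 28, 28)  # hypothetical batch: N=8, C=1, H=W=28

flat = nn.Flatten()(batch)        # flattens every dimension except the batch dimension
print(flat.shape)                 # torch.Size([8, 784])

layer = nn.Linear(28 * 28, 128)
print(layer.weight.shape, layer.bias.shape)  # torch.Size([128, 784]) torch.Size([128])

# nn.Linear is an affine map: output = input @ weight.T + bias
manual = flat @ layer.weight.T + layer.bias
print(torch.allclose(layer(flat), manual, atol=1e-5))  # True
```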

In [None]:
model = NeuralNetwork().to(device)
print(model)

In [None]:
print("Layers and their initial weights/bias shapes:")
for name, param in model.named_parameters():
    print(f" - {name} | Shape: {param.shape} | Sample values: {param.data.flatten()[:5]}...")

print()
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

# Breakdown:
# Layer 0: 784*128 weights + 128 biases = 100352 + 128 = 100480
# Layer 2: 128*64 weights + 64 biases = 8192 + 64 = 8256
# Layer 4: 64*10 weights + 10 biases = 640 + 10 = 650
# Total: 100480 + 8256 + 650 = 109386
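The breakdown above follows one formula: a linear layer with `in_features` inputs and `out_features` outputs has `in_features * out_features` weights plus `out_features` biases. A small sketch that reproduces the total from the layer sizes and cross-checks it against PyTorch's own count for an identically shaped stack:

```python
from torch import nn

layer_sizes = [(28 * 28, 128), (128, 64), (64, 10)]

# weights (in * out) + biases (out) for each fully connected layer
per_layer = [n_in * n_out + n_out for n_in, n_out in layer_sizes]
print(per_layer)       # [100480, 8256, 650]
print(sum(per_layer))  # 109386

# Cross-check with an equivalent nn.Sequential (ReLU layers have no parameters)
stack = nn.Sequential(
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
print(sum(p.numel() for p in stack.parameters()))  # 109386
```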

### Optimizer and Loss Function

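The **loss** measures how far off our predictions are: the lower, the better. The cross-entropy loss measures how much two probability distributions differ. It calculates the 'distance' between our predictions (a probability distribution, obtained by applying a softmax to the logits the network outputs) and our labels (*also* a probability distribution, with a 1 where the ground truth is and zero everywhere else). PyTorch's `nn.CrossEntropyLoss` applies a log-softmax internally, which is why our network returns raw logits.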
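As a sketch of what `nn.CrossEntropyLoss` computes: it takes raw logits, turns each row into probabilities with a softmax, and averages `-log(probability of the correct class)` over the batch. The logits and targets below are made-up values for a batch of 2 samples and 3 classes.

```python
import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # hypothetical raw model outputs
targets = torch.tensor([0, 2])            # correct class for each sample

# softmax turns each row of logits into a probability distribution
probs = torch.softmax(logits, dim=1)

# cross-entropy against a one-hot label = -log(probability of the correct class)
manual = -torch.log(probs[torch.arange(len(targets)), targets]).mean()

builtin = nn.CrossEntropyLoss()(logits, targets)
print(manual.item(), builtin.item())  # equal up to floating-point rounding
```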

The **optimizer** will take this loss, and change the parameters of the network in order to improve its performance. You can try different [optimizers](https://pytorch.org/docs/stable/optim.html) from the PyTorch API.

`torch.optim.SGD` below creates a Stochastic Gradient Descent (SGD) optimizer, a standard algorithm for training NNs.

The learning rate is a hyperparameter that we can adjust. It controls the size of each weight update, which is applied after backpropagation has computed the gradients. Generally, a smaller rate leads to more stable convergence but takes more time, while a larger rate can speed things up but risks overshooting the minimum of the loss, making training unstable or even causing it to diverge.
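Plain SGD (no momentum, the default) updates each parameter as `w <- w - lr * grad`. The sketch below checks one `optimizer.step()` against that formula, then runs gradient descent on the toy function `f(w) = w^2` (gradient `2w`) to show the effect of the step size. The helper `run_gd` and the learning rates 0.1 and 1.1 are illustrative choices, not values used elsewhere in the notebook.

```python
import torch

# One SGD step on f(w) = sum(w^2), whose gradient is 2w
w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)
loss = (w ** 2).sum()
loss.backward()   # w.grad = [2.0, -4.0]
opt.step()        # w = w - 0.1 * grad = [0.8, -1.6]
print(w.detach())

def run_gd(lr, steps=20, x0=1.0):
    """Gradient descent on f(x) = x^2; each step multiplies x by (1 - 2*lr)."""
    x = x0
    for _ in range(steps):
        x = x - lr * 2 * x
    return x

print(abs(run_gd(0.1)))  # small: |1 - 0.2| < 1, so x shrinks towards the minimum at 0
print(abs(run_gd(1.1)))  # large: |1 - 2.2| > 1, so every step overshoots and x diverges
```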

In [None]:
learning_rate = 1e-3

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

### Implementing our Training Loop

In [None]:
# defining the training loop

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) # total number of samples in the dataset to track progress
    losses, accs = [], [] # to store loss and accuracy history

    model.train() # set the model to training mode

    for batch_idx, (data, target) in enumerate(dataloader):
        
        # 0: data & target to specified device
        X, y = data.to(device), target.to(device)
        
        # 1: prediction - getting logits for each class
        pred = model(X)

        # 2: loss
        loss = loss_fn(pred, y)
    
        # 3: backpropagation - computes gradients
        loss.backward()

        # 4: update parameters based on computed gradients
        optimizer.step()

        # 5: reset the gradients (otherwise PyTorch keeps adding new gradients to the old ones)
        optimizer.zero_grad()

        # Logging & saving history

        # save losses
        losses.append(loss.item())
        # save our accuracy
        accs.append((pred.argmax(1) == y).type(torch.float).mean().item())
        
        if batch_idx % 100 == 0:
            loss, current = loss.item(), (batch_idx + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    print()
    return losses, accs
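Step 5 of the loop matters because PyTorch *accumulates* gradients: each `backward()` call adds to `.grad` instead of overwriting it. A minimal standalone demonstration with a single made-up parameter:

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)  # only used here for zero_grad()

(w * 2).backward()
grad_first = w.grad.clone()         # tensor(2.)

(w * 2).backward()                  # no zero_grad(): the new gradient is added
grad_accumulated = w.grad.clone()   # tensor(4.)

opt.zero_grad()                     # clears the stored gradients
(w * 2).backward()
grad_after_reset = w.grad.clone()   # tensor(2.) again

print(grad_first, grad_accumulated, grad_after_reset)
```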

In [None]:
# defining the test loop

def test(dataloader, model, loss_fn):
    
    losses, accs = [], []
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    model.eval() # set the model to evaluation mode

    # no gradients since we are not training!
    with torch.no_grad():
        for (data, target) in dataloader:

            # 0: data & target to specified device
            X, y = data.to(device), target.to(device)

            # 1: prediction - getting logits for each class
            pred = model(X)
            
            # 2: loss and accuracy
            loss = loss_fn(pred, y)
            t_l = loss.item()
            test_loss += t_l

            # accumulate our accuracy
            a = (pred.argmax(1) == y).type(torch.float)
            correct += a.sum().item()

            # save loss and acc
            losses.append(t_l)
            accs.append(a.mean().item())
    
    # average loss & results
    test_loss /= num_batches
    correct /= size
    
    print("Test Error:")
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    print()
    
    return losses, accs, correct
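Note that `test` reports accuracy as `correct / size` (over all samples), while the per-batch values stored in `accs` weight every batch equally. The two can differ when the last batch is smaller: for MNIST's 10,000 test images and `batch_size = 64`, that is 156 full batches plus one batch of 16. A toy illustration with two hypothetical batches:

```python
# Hypothetical results: a full batch of 64 (all correct) and a last batch of 16 (none correct)
batch_sizes = [64, 16]
batch_correct = [64, 0]

# accuracy over all samples, as `test` computes it
overall_acc = sum(batch_correct) / sum(batch_sizes)                   # 64 / 80 = 0.8

# plain average of the per-batch accuracies, as averaging `accs` would give
per_batch_accs = [c / n for c, n in zip(batch_correct, batch_sizes)]  # [1.0, 0.0]
mean_of_batches = sum(per_batch_accs) / len(per_batch_accs)           # 0.5

print(overall_acc, mean_of_batches)
```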

**And now onto the actual training!**

There are two parameters we need to define: the `batch_size` (already defined in an earlier step) and the number of `epochs`.

The number of `epochs` defines how many full passes we make over the training dataset. The more epochs we train for, the longer training takes; more epochs often (but not always) lead to better performance.

The `batch_size` defines how many data samples we process in parallel during training. A bigger batch size speeds up training, but is limited by the memory of our computer (or GPU). Because the weights are updated based on the average loss over the whole batch, training is more stable than if we were to update the weights after every single example. Bigger is not automatically better, though: very large batches can generalise worse, and the noise that comes from training on smaller batches is often thought to act as a form of *regularisation*.
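With 60,000 training images and a batch size of 64, one epoch is `ceil(60000 / 64) = 938` batches (the last one holds only 32 images), so 5 epochs perform 4,690 parameter updates. A small sketch that checks this with a `DataLoader` over a dummy `TensorDataset` standing in for MNIST (only the number of samples matters here):

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

n_train, bs, n_epochs = 60_000, 64, 5

# Dummy stand-in for the MNIST training set
dummy = TensorDataset(torch.zeros(n_train, 1))
loader = DataLoader(dummy, batch_size=bs)

batches_per_epoch = len(loader)
print(batches_per_epoch, math.ceil(n_train / bs))  # 938 938

last_batch = list(loader)[-1][0]
print(last_batch.shape[0])  # 32 samples in the final, partial batch

print(n_epochs * batches_per_epoch)  # 4690 optimizer steps in total
```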

In [None]:
epochs = 5
train_losses, train_accs, test_losses, test_accs = [], [], [], []

for t in range(epochs):
    print(f"Epoch {t+1}")
    print("-------------------------------")
    train_l, train_a = train(train_dataloader, model, loss_fn, optimizer) # per-batch losses and accuracies
    test_l, test_a, _ = test(test_dataloader, model, loss_fn) # per-batch losses and accuracies
    # save history
    train_losses.extend(train_l)
    train_accs.extend(train_a)
    test_losses.extend(test_l)
    test_accs.extend(test_a)
print("Done!")

### Testing ~ Evaluating the Performance of our Model

In [None]:
plt.figure(figsize=(10,5))
plt.title("Train loss")
plt.plot(train_losses,label="train")
plt.xlabel("batch (iteration)")
plt.ylabel("loss")
plt.legend()
plt.show()
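Because `train_losses` holds one value per batch, the curve can be noisy. One option is to average it per epoch. The helper `epoch_means` below is a hypothetical addition, shown here on made-up numbers; with the notebook's variables you could call it as `epoch_means(train_losses, len(train_dataloader))` and plot the result.

```python
import numpy as np

def epoch_means(values, batches_per_epoch):
    """Average a flat per-batch history into one value per epoch (drops any incomplete epoch)."""
    values = np.asarray(values, dtype=float)
    n_epochs = len(values) // batches_per_epoch
    # reshape to [epochs, batches] and average over the batch axis
    return values[: n_epochs * batches_per_epoch].reshape(n_epochs, batches_per_epoch).mean(axis=1)

fake_losses = [2.0, 1.8, 1.6, 1.2, 1.0, 0.8]  # made-up history: 2 epochs x 3 batches
per_epoch = epoch_means(fake_losses, 3)
print(per_epoch)  # approximately [1.8, 1.0]
```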

In [None]:
_ = test(test_dataloader, model, loss_fn)

### Using our Model on an Input Image


See [Transforming and Augmenting Images](https://pytorch.org/vision/stable/transforms.html).

In [None]:
img = Image.open('images/3.png') # try also images/4.png

transforms = v2.Compose([  
    v2.Grayscale(num_output_channels=1),
    v2.Resize(size=(28,28), antialias=True),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True), # from [0,255] to [0,1]
])

input = transforms(img).unsqueeze(0) # add a batch dimension: [1, 28, 28] -> [1, 1, 28, 28]
input = input.to(device)

print(f"Input shape: {input.shape}")
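MNIST digits are light strokes on a dark background, while digits written on paper are usually the opposite, so the model may misread them. A minimal sketch, assuming you want to invert any image whose mean brightness is above 0.5 (a simple heuristic that is not part of the original notebook); `prepare_digit` is a hypothetical helper built only on PIL, NumPy and PyTorch, and it does not attempt MNIST's cropping and centring of the digit.

```python
import numpy as np
import torch
from PIL import Image, ImageDraw

def prepare_digit(img, threshold=0.5):
    """Grayscale, resize to 28x28, scale to [0, 1], and invert if the background looks light."""
    img = img.convert("L").resize((28, 28))
    x = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0)
    if x.mean() > threshold:  # mostly light pixels: probably dark ink on white paper
        x = 1.0 - x           # flip to a light digit on a dark background, like MNIST
    return x.reshape(1, 1, 28, 28)  # add batch and channel dims: [N, C, H, W]

# Synthetic test image: a black square "stroke" on a white background
demo = Image.new("L", (100, 100), color=255)
ImageDraw.Draw(demo).rectangle([30, 30, 70, 70], fill=0)

x = prepare_digit(demo)
print(x.shape, x[0, 0, 0, 0].item(), x[0, 0, 14, 14].item())
```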

In [None]:
model.eval()
with torch.no_grad():
    predictions = nn.Softmax(dim=-1)(model(input)).cpu().numpy()
print(f"Our predictions (shape: {predictions.shape})")
print(predictions)

In [None]:
# note that predictions is still *batched* (shape: (1,10)), we need to fetch the first array
predicted = np.argmax(predictions[0]) # argmax: the *index* of the highest prediction

plt.figure()
plt.title(f'Predicted number: {train_data.classes[predicted]}') # use the predicted category in the title
plt.imshow(img, cmap="gray")
plt.axis("off")
plt.show()

We can plot our predictions for all classes using a [bar chart](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.bar.html).

In [None]:
plt.figure(figsize=(14,5))
plt.title("Predictions")
xs = train_data.classes     # class names ("0 - zero" ... "9 - nine") on the x-axis; the bar heights are our predictions
plt.bar(xs, predictions[0]) # a bar chart
plt.xticks(xs)
plt.show()

### Saving & Loading Models

```python
# save (reload using torch.jit.load)
torch.jit.save(torch.jit.script(model), mnist_dir / f"{model_name}_scripted.pt")

# save (reload using model.load_state_dict, requires the model class!)
torch.save(model.state_dict(), mnist_dir / f"{model_name}.pt")
print(f"Saved PyTorch Model State to {mnist_dir / model_name}.pt")

# instantiate then load (you need to have defined NeuralNetwork)!
model_reloaded = NeuralNetwork().to(device)
model_reloaded.load_state_dict(torch.load(mnist_dir / f"{model_name}.pt", weights_only=True))
```
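A quick way to check that saving and reloading worked is to compare the outputs of the original and the reloaded model on the same input. A self-contained sketch using a small stand-in model, `TinyNet` (hypothetical), and a temporary directory instead of `models/mnist`:

```python
import pathlib
import tempfile

import torch
from torch import nn

class TinyNet(nn.Module):  # hypothetical stand-in for NeuralNetwork
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def forward(self, x):
        return self.net(x)

torch.manual_seed(0)
model_a = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "tiny.pt"
    torch.save(model_a.state_dict(), path)  # saves only the parameters

    model_b = TinyNet()                     # fresh, differently initialised instance
    model_b.load_state_dict(torch.load(path, weights_only=True))
    model_b.eval()                          # evaluation mode before inference

x = torch.rand(2, 1, 28, 28)
with torch.no_grad():
    same = torch.allclose(model_a(x), model_b(x))
print(same)  # True: the reloaded weights reproduce the original outputs
```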

---

## To Do:

**Task 1:** Run all the cells in this notebook to train a simple NN on the MNIST dataset. While doing so, try to understand what each cell does, even if you do not understand every command separately. The important thing is to start understanding the training process as a whole. Training a NN in PyTorch is a great opportunity to see the theory of training in practice. Consult [the PyTorch documentation](https://pytorch.org/docs/stable/index.html) for anything you are unsure about.

**Task 2:** Add other images of handwritten digits to the `images` folder and test your model on them in the last sections, where you visualise predictions on new input images. You can add images you download from the internet, or get more experimental, e.g. write digits on paper and load photos of them here, or create black-and-white digits in p5/py5 or other environments, and test how the model performs in all of these cases. Test the model's limits!

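**Task 3:** Test this network on the Fashion MNIST dataset. We are currently calling `tv.datasets.MNIST` (imported here as `datasets.MNIST`); you will instead need to call `tv.datasets.FashionMNIST`. Fashion MNIST also consists of 28 * 28 grayscale images in 10 classes, so the network itself does not need to change.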

**Task 4:** If you feel confident, you may start exploring how to train a NN on your own custom dataset by following the instructions below. Note that a fully connected NN like the one we use here is not ideal for image classification. Next week, we will use a CNN (convolutional neural network), which is much better suited to this task.

### Notes on Training on a custom dataset:

Provided that you have images in a folder like this:
```bash
main_directory/
...class_a/
......image_1.jpg
......image_2.jpg
...class_b/
......image_1.jpg
......image_2.jpg
```

You can then replace the data loading with:

```python
# Model / data parameters
transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.Resize(size=(28,28), antialias=True),
    v2.ToImage(),                          # PIL image -> tensor of shape [1, 28, 28]
    v2.ToDtype(torch.float32, scale=True), # from [0,255] to [0,1]
])

custom_data = tv.datasets.ImageFolder(
    data_dir / "custom_dataset",
    transform=transforms,
)

print(custom_data)
print("\n".join(custom_data.classes)) # should show the folder names

num_classes = len(custom_data.classes) # one class per sub-folder

train_data, test_data = torch.utils.data.random_split(custom_data, [.9,.1])
print(len(train_data), len(test_data))

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

# re-create the model and optimizer so that the output layer has num_classes units
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```

See [the documentation](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder).

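Inspecting the data, training and testing your net should then work as before. One caveat: `random_split` returns `Subset` objects, which have no `classes`, `data` or `targets` attributes, so use `custom_data.classes` wherever the notebook uses `train_data.classes`, and skip the cells that read `train_data.data` or `train_data.targets` directly.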
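`random_split` draws a different random split every time the cell runs. If you want the same train/test split across runs, one option is to pass a seeded generator. A small sketch on a dummy dataset of 100 items (the seed 42 is an arbitrary choice):

```python
import torch
from torch.utils.data import random_split

items = list(range(100))  # dummy dataset: any sequence with __len__ and __getitem__ works

split_a = random_split(items, [0.9, 0.1], generator=torch.Generator().manual_seed(42))
split_b = random_split(items, [0.9, 0.1], generator=torch.Generator().manual_seed(42))

subset_train, subset_test = split_a
print(len(subset_train), len(subset_test))   # 90 10
print(list(split_a[1]) == list(split_b[1]))  # True: same seed, same split
```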