## Training a CNN for Image Classification on your Custom Dataset

Before you proceed with this notebook, you need a custom dataset to train your model on. You may use one of the methods suggested in [04_collect_data.ipynb](04_collect_data.ipynb); make sure your custom dataset is in your `./datasets` folder.

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pathlib
from PIL import Image

import torch
from torch import nn
import torchvision as tv
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import torchvision.datasets as datasets 
from torchvision.transforms import v2

In [None]:
# Get cpu, gpu or mps device for training

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

In [None]:
# Create your directory structure for your datasets and models

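data_dir = pathlib.Path("datasets/gallery_dl_dataset")  # change to the folder of your own dataset
data_dir.mkdir(parents=True, exist_ok=True)  # parents=True also creates ./datasets if it is missing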

models_dir = pathlib.Path("models")
models_dir.mkdir(exist_ok=True)

model_name = "image_classifier"  # change when working with other datasets

model_dir = models_dir / model_name
model_dir.mkdir(exist_ok=True)

### Data Processing ~ Image Transformations

Could you augment your training data by adding more transformations to them?

You could randomly change their brightness, contrast, saturation, and hue.

You could flip them horizontally or vertically with a 0.5 probability.

You could randomly rotate them.

See the [transforms documentation](https://pytorch.org/vision/stable/transforms.html) and the [illustrated transforms examples](https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py) for references and examples.

Do you also need to add these transformations to your validation set, or are the existing ones enough? Consider what the purpose of each dataset is.

In [None]:
num_classes = 3 # your number of classes; the model's output layer must match this

train_transform = v2.Compose([
        # v2.Resize(size=(64, 64), antialias=True),
        v2.RandomResizedCrop(size=(64, 64), antialias=True),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True), 
    ])

val_transform = v2.Compose([
        v2.Resize(size=(64, 64), antialias=True),
        v2.CenterCrop(size=(64, 64)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True), 
    ])

# create train, validation and test datasets over the same folder, each with its own transforms
# (they are split into disjoint subsets further down)
train_dataset = datasets.ImageFolder(data_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(data_dir, transform=val_transform)
test_dataset = datasets.ImageFolder(data_dir, transform=val_transform)

print("\n".join(train_dataset.classes)) # should show the folder names

Here we create our train, validation and test datasets by splitting the full input dataset into three subsets. A 70-20-10 split is quite common.

By setting a `random_state`, we split the data randomly but deterministically: as long as we use the same `random_state`, `train_test_split` always returns the same split.
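
As a worked example, assuming a hypothetical dataset of 1000 images, the two-step split produces a 70-20-10 split. The second call only sees the 80% left over after the validation split, so its `test_size` is rescaled to 0.1 / 0.8 = 0.125, which is 100 of the remaining 800 samples.

```python
from sklearn.model_selection import train_test_split

n_samples = 1000  # hypothetical dataset size, for illustration only
val_size, test_size = 0.2, 0.1
all_idx = list(range(n_samples))

# first split: 20% of all samples go to validation
train_idx, val_idx = train_test_split(all_idx, test_size=val_size, random_state=42)
# second split: 0.125 of the remaining 800 samples go to test (10% overall)
train_idx, test_idx = train_test_split(train_idx, test_size=test_size / (1 - val_size), random_state=42)
print(len(train_idx), len(val_idx), len(test_idx))  # -> 700 200 100

# the same random_state always gives the same split
first, _ = train_test_split(all_idx, test_size=val_size, random_state=42)
second, _ = train_test_split(all_idx, test_size=val_size, random_state=42)
print(first == second)  # -> True
```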

In [None]:
# get the length of the full dataset before the split
num_train = len(train_dataset)

# define the fractions of the full dataset used for validation and testing
val_size = 0.2
test_size = 0.1

# create a list of indices, one for each element of the full dataset
idx = list(range(num_train))
print(num_train, idx[:10])  # dataset size and the first few indices

In [None]:
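# split the indices into train / val, then split the remaining train indices into train / test
train_indices, val_indices = train_test_split(idx, test_size=val_size, random_state=42)
# rescale test_size so the test set is 10% of the *full* dataset (0.1 / 0.8 = 0.125 of the remainder)
train_indices, test_indices = train_test_split(train_indices, test_size=test_size/(1 - val_size), random_state=42)

# restrict each dataset to the samples of its own split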
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)
test_dataset = Subset(test_dataset, test_indices)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

### Observing our Data

In [None]:
# Check dataset sizes and sample shape
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

img_num = 92 # change this number to view a different sample

# Get a sample to check shape
sample_img, sample_label = train_dataset[img_num]
print(f"\nSample image shape: {sample_img.shape}")
print(f"Sample label: {sample_label}")
print(f"Classes: {train_dataset.dataset.classes}")

In [None]:
train_dataset[img_num]

In [None]:
print("\n".join(train_dataset.dataset.classes)) # join an array into a string

### Visualising Data

In [None]:
# Get a sample image and its label
sample_img, sample_label = train_dataset[img_num]
label_name = train_dataset.dataset.classes[sample_label]

# Plot it
plt.figure(figsize=(6, 6))
plt.title(f"Label: {label_name} (index: {sample_label})")
plt.imshow(sample_img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C) for display
plt.axis("off")
plt.show()

In [None]:
# plotting for multiple images, randomly selected
figure = plt.figure(figsize=(12, 10))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    # generate a random index
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    # retrieve the image and the respective label for that index
    img, label = train_dataset[sample_idx]
    label_name = train_dataset.dataset.classes[label]
    
    # create the grid of subplots
    figure.add_subplot(rows, cols, i)
    plt.title(label_name, fontsize=8)
    plt.axis("off")
    # Convert from (C, H, W) to (H, W, C) for display
    plt.imshow(img.permute(1, 2, 0))
plt.tight_layout()
plt.show()

### Dataloaders

In [None]:
batch_size = 6

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for X, y in val_loader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
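
Note that `len(train_loader)` counts *batches*, not samples: with N samples and batch size B it is ceil(N / B), and the last batch may be smaller than B (unless you pass `drop_last=True`). The sketch below uses hypothetical random tensors in place of your images to show this.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in data: 20 random 3x64x64 "images" with labels in {0, 1, 2}
images = torch.rand(20, 3, 64, 64)
labels = torch.randint(0, 3, (20,))
loader = DataLoader(TensorDataset(images, labels), batch_size=6, shuffle=False)

batch_sizes = [X.shape[0] for X, y in loader]
print(len(loader), math.ceil(20 / 6), batch_sizes)  # -> 4 4 [6, 6, 6, 2]
```

This is why the training loop divides the summed losses by `len(train_loader)`: the result is an average loss per batch.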

### Defining our Convolutional Neural Network

In [None]:
class ConvNetwork(nn.Module):
    def __init__(self, num_classes=3):  # set num_classes to your number of classes
        super().__init__()
        # Input shape: [batch, 3, 64, 64]
        # Breaking down the first conv layer:
        #   > 3 input channels for RGB images
        #   > 64 different filters to output
        #   > 3x3 kernel size
        #   > 1 padding (keeps height and width unchanged)
        # output shape: [batch, 64, 64, 64]
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        # 2x2 max pooling halves height and width, e.g. [batch, 64, 64, 64] -> [batch, 64, 32, 32]
        self.pool = nn.MaxPool2d(2, 2)
        # conv2: [batch, 64, 32, 32] -> [batch, 128, 32, 32]; then pool -> [batch, 128, 16, 16]
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # conv3: [batch, 128, 16, 16] -> [batch, 128, 16, 16]; then pool -> [batch, 128, 8, 8]
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        # flattened feature maps: 128 channels * 8 * 8 = 8192 features
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)  # one output (logit) per class
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x  # raw logits; nn.CrossEntropyLoss applies softmax internally
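
Where does `128 * 8 * 8` come from? For `Conv2d` and `MaxPool2d` (with dilation 1), the spatial output size is floor((size + 2 * padding - kernel_size) / stride) + 1. The sketch below applies that formula to the three conv + pool blocks; `conv2d_out` is a hypothetical helper, not a PyTorch function. If you change the input size from 64x64, `fc1` has to change too.

```python
def conv2d_out(size, kernel_size, padding=0, stride=1):
    """Spatial output size of a Conv2d / MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 64  # input height and width
for block in range(3):
    size = conv2d_out(size, kernel_size=3, padding=1)  # 3x3 conv, padding 1: size unchanged
    size = conv2d_out(size, kernel_size=2, stride=2)   # 2x2 max pool: size halved
    print(f"after block {block + 1}: {size}x{size}")

flat_features = 128 * size * size  # 128 channels from conv3
print(flat_features)  # -> 8192, the in_features of fc1
```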

In [None]:
model = ConvNetwork().to(device)
print(model)

In [None]:
print("Layers and their initial weights/bias shapes:")
for name, param in model.named_parameters():
    print(f" - {name} | Shape: {param.shape} | Sample values: {param.data.flatten()[:5]}...")

print()
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
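
The total can be checked by hand: a `Conv2d` layer has `in_channels * out_channels * k * k` weights plus one bias per output channel, and a `Linear` layer has `in_features * out_features` weights plus one bias per output. The helper names below are hypothetical; the total should match the count printed above when the model has 3 output classes.

```python
def conv_params(c_in, c_out, k):
    # one k x k x c_in kernel per output channel, plus one bias per output channel
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    # an n_out x n_in weight matrix, plus one bias per output
    return n_in * n_out + n_out

layers = {
    "conv1": conv_params(3, 64, 3),
    "conv2": conv_params(64, 128, 3),
    "conv3": conv_params(128, 128, 3),
    "fc1": linear_params(128 * 8 * 8, 256),
    "fc2": linear_params(256, 3),
}
for name, n in layers.items():
    print(f"{name}: {n:,}")
print(f"total: {sum(layers.values()):,}")  # -> total: 2,321,411
```

About 90% of the parameters sit in `fc1`, which is typical for a small CNN that flattens large feature maps straight into a dense layer.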

### Optimizer and Loss Function

In [None]:
learning_rate = 0.001

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
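
`nn.CrossEntropyLoss` takes the raw, unnormalised model outputs (logits) of shape `[N, C]` and integer class indices of shape `[N]`, and applies log-softmax internally. That is why `ConvNetwork` ends with `fc2` and no softmax. The sketch below, with made-up logits, computes the same loss by hand as the mean negative log-probability of the correct class.

```python
import torch
from torch import nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # 2 samples, 3 classes (made-up values)
targets = torch.tensor([0, 2])            # correct class index for each sample

loss = nn.CrossEntropyLoss()(logits, targets)

# by hand: log-softmax over classes, pick the true class, average the negatives
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(loss.item(), manual.item())
```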

### Implementing our Training Loop

In [None]:
epochs = 100
train_losses = []
val_losses = []

for epoch in range(epochs): 
    train_loss = 0.0
    
    # training loop
    model.train()  # training mode (matters for layers such as dropout or batch norm)
    for batch_idx, (data, target) in enumerate(train_loader):
        # get data
        inputs = data.to(device)
        labels = target.to(device)
        
        # zero the gradients
        optimizer.zero_grad()
        # forward pass
        predictions = model(inputs)
        # compute the loss
        loss = loss_fn(predictions, labels)
        # backpropagate
        loss.backward()
        # update the parameters, i.e. weights
        optimizer.step()

        # save statistics to plot later
        train_loss += loss.item()
    
    # validation loop
    model.eval()  # evaluation mode for validation
    with torch.no_grad():
        val_loss = 0.0
        for batch_idx, (data, target) in enumerate(val_loader):
            # get data
            inputs = data.to(device)
            labels = target.to(device)
            # forward pass only: no backpropagation or optimisation
            predictions = model(inputs)
            # compute the loss
            loss = loss_fn(predictions, labels)
            # save statistics to plot later
            val_loss += loss.item()
    
    # average the summed losses over the number of batches
    train_loss = train_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    
    # add the average losses to lists to plot later
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f'Epoch {epoch + 1}, train loss: {train_loss:.3f}, val loss: {val_loss:.3f}')

### Testing ~ Evaluating the Performance of our Model

In [None]:
plt.figure(figsize=(10,5))
plt.title("Train vs validation loss")
plt.plot(train_losses,label="train")
plt.plot(val_losses,label="val")
plt.xlabel("epochs")
plt.ylabel("average loss per batch")
plt.legend()
plt.show()

In [None]:
def test(dataloader, model, loss_fn, device=device):
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # average loss across batches and accuracy across samples
    test_loss = test_loss / num_batches
    accuracy = correct / size
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, accuracy

# Run test on the test loader
test_loss, test_acc = test(test_loader, model, loss_fn)

### Using our Model on an Input Image


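See [Transforming and Augmenting Images](https://pytorch.org/vision/stable/transforms.html). A new input image must be preprocessed the same way as the validation images (resized to 64x64 and converted to a float tensor with values in [0, 1]) and given a batch dimension before it is passed to the model.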

In [None]:
img = Image.open('images/colorful-carpet-sample.png').convert("RGB") # try also images/4.png; convert("RGB") drops any alpha channel

transforms = v2.Compose([
    # same preprocessing as the validation set
    v2.Resize(size=(64,64), antialias=True),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

input = transforms(img).unsqueeze(0)  # add a batch dimension: [1, 3, 64, 64]
input = input.to(device)

print(f"Input shape: {input.shape}")

In [None]:
model.eval()
with torch.no_grad():
    predictions = nn.Softmax(dim=-1)(model(input)).cpu().numpy()
print(f"Our predictions (shape: {predictions.shape})")
print(predictions)

In [None]:
predicted = np.argmax(predictions[0]) # argmax: the *index* of the highest prediction

plt.figure()
plt.title(f'Predicted class: {train_dataset.dataset.classes[predicted]}') # use the predicted category in the title
plt.imshow(img)
plt.axis("off")
plt.show()

We can plot our predictions for all classes using a [bar chart](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.bar.html).

In [None]:
plt.figure(figsize=(14,5))
plt.title("Predictions")
xs = train_dataset.dataset.classes     # class names on the x-axis; the bar heights are the predicted probabilities
plt.bar(xs, predictions[0]) # a bar chart
plt.xticks(xs)
plt.show()

In [None]:
torch.jit.save(torch.jit.script(model), model_dir / f"my_{model_name}_01.pt")
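
A TorchScript file stores the code and the weights together, so it can be loaded later without the `ConvNetwork` class definition, for example in a new notebook via `torch.jit.load(models_dir / model_name / f"my_{model_name}_01.pt")`. The sketch below round-trips a tiny stand-in model (a hypothetical placeholder for the trained network) through a temporary file; `map_location="cpu"` lets a model trained on a GPU be loaded on a CPU-only machine.

```python
import pathlib
import tempfile

import torch
from torch import nn

# tiny stand-in model; in the notebook this would be the trained ConvNetwork
tiny = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 3)).eval()

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "my_image_classifier_01.pt"
    torch.jit.save(torch.jit.script(tiny), str(path))

    # loading needs only the file, not the Python class definition
    loaded = torch.jit.load(str(path), map_location="cpu")
    loaded.eval()

    x = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        same = torch.allclose(tiny(x), loaded(x))
print(same)  # -> True
```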

## To Do:

**Task 1:** Create your own dataset with 2, 3, or more classes, using one of the suggested approaches. Ideally, have at least 1000 images per class. Create a folder for each of your classes and save the respective images there. Then move all of the class folders into `./datasets/name_of_custom_dataset`. Make sure you manually clean up your data before training, to remove any irrelevant or corrupted images.

**Task 2:** Run all the cells in this notebook to train a classifier on your custom dataset.

**Task 3:** Add image transformations to the training dataset. See the [transforms documentation](https://pytorch.org/vision/stable/transforms.html) and the [illustrated transforms examples](https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py) for references and examples.

**Task 4:** Create a new notebook where you load the model that you just saved and test it on some new, unseen data.
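
For the clean-up in Task 1, a script can flag files that are actually broken before you review the rest by eye. This is a sketch with a hypothetical helper, `find_broken_images`: it reports every file under a folder that Pillow cannot open and verify, which includes non-image files. It only lists files and deletes nothing, and it cannot detect images that are valid but irrelevant.

```python
import pathlib

from PIL import Image


def find_broken_images(root):
    """Return files under `root` that Pillow cannot open and verify (hypothetical helper)."""
    broken = []
    for path in sorted(pathlib.Path(root).rglob("*")):
        if not path.is_file():
            continue
        try:
            with Image.open(path) as im:
                im.verify()  # checks file integrity without decoding all pixel data
        except Exception:  # unreadable, truncated or non-image file
            broken.append(path)
    return broken


# Usage: list suspicious files, then inspect and remove them by hand
# for p in find_broken_images("datasets/gallery_dl_dataset"):
#     print(p)
```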

**Bonus Challenges:**

**Bonus 1:** Look into the concept of early stopping. What is it? Could it be useful for our training? How? Attempt to implement it by adding the following lines at the end of each epoch, after the validation loop:

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best_img_classifier.pt')

For this to work, you will have to initialise `best_loss` with a high value (for example `float("inf")`) before you enter the training loop.
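
The lines above keep the best checkpoint; early stopping additionally ends training once the validation loss has stopped improving for a while. One common variant uses a "patience" counter, sketched below with a hypothetical `EarlyStopping` helper and simulated losses. The `patience` and `min_delta` values are assumptions to tune.

```python
class EarlyStopping:
    """Hypothetical helper: stop when val loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta            # minimum decrease that counts as improvement
        self.best_loss = float("inf")         # start from a very high value
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return (improved, should_stop)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
            return True, False
        self.epochs_without_improvement += 1
        return False, self.epochs_without_improvement >= self.patience


# simulated validation losses, for illustration only
stopper = EarlyStopping(patience=3)
for epoch, val_loss in enumerate([1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.6]):
    improved, stop = stopper.step(val_loss)
    # in the notebook: if improved, save the checkpoint with torch.save(model.state_dict(), ...)
    if stop:
        print(f"stopping after epoch {epoch + 1}, best val loss {stopper.best_loss}")
        break
```

Here training stops after epoch 6, because epochs 4 to 6 did not beat the loss of 0.7 from epoch 3.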

**Bonus 2:** In this example you are building your classifier from scratch, i.e. you decide the architecture of the network yourself and train it from the very beginning. Could you instead train your classifier on top of a pre-trained model? There are many pre-trained models available in [the torchvision models library](https://pytorch.org/vision/stable/models.html), such as [ResNet](https://arxiv.org/abs/1512.03385), which is trained on the [ImageNet dataset](https://www.image-net.org/). This approach will require a few changes and additions to your notebook. Attempt it if you are feeling adventurous!