# Module 6 - Fine-tuning ResNet toward plankton data

We have seen that a neural network trained on a dataset unrelated to plankton (like ImageNet) still produces features that allow the classification of plankton data.
Now, we can go a step further and *fine-tune* such a network to do plankton classification.
This is akin to teaching a person without prior oceanographic experience how to recognize different types of fish, assuming that they are able to recognize other kinds of objects.

In practice, CNNs are almost always fine-tuned from pre-trained weights rather than trained from scratch, because starting from useful features usually converges faster and needs less labeled data.

In [None]:
import copy
import os
import time
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.optim import lr_scheduler
from torch.utils.data import RandomSampler
from torchvision import datasets, models, transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm.notebook import tqdm, trange
import random

sys.path.insert(1, os.path.join(os.getcwd(),'computer-vision-workshop/utilities'))
from display_utils import imshow_tensor, make_confmat
from split import stratified_random_split

TRAINING_PATH = "/groups/cv-workshop/ZooScan/train"
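# NOTE: this currently points at the training folder; set it to a held-out validation folder if one is available
VALIDATION_PATH = "/groups/cv-workshop/ZooScan/train"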

## Data loading and transformation

Image datasets can be conveniently loaded with [`torchvision.datasets.ImageFolder`](https://docs.pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html).
It assumes one folder for each class where the images are located.
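
As a rough illustration of that layout, the following stdlib-only sketch builds a throwaway folder tree and derives the class-to-index mapping in the way `ImageFolder` exposes it through `classes` and `class_to_idx`: the class names are the sorted sub-folder names. The class and file names are made up for the example, and `find_classes` here is a simplified stand-in, not the torchvision code itself.

```python
import os
import tempfile

def find_classes(root):
    """Return (classes, class_to_idx) from the sub-folders of root, sorted by name."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

with tempfile.TemporaryDirectory() as root:
    # Hypothetical class folders, each holding one empty placeholder file
    for class_name in ["detritus", "copepoda", "appendicularia"]:
        os.makedirs(os.path.join(root, class_name))
        open(os.path.join(root, class_name, "0001.jpg"), "w").close()

    classes, class_to_idx = find_classes(root)

print(classes)
print(class_to_idx)
```

Because the indices follow the sorted folder names, two datasets only share the same label numbering if they contain the same set of class folders.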

CNNs have a fixed input size. ResNets happen to be trained with 224x224 images.
Therefore, we need to make sure that each image has the correct dimensions.
`ImageFolder` has a `transform` parameter for that.
After resizing, the images need to be converted to a PyTorch [`Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor).
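
To make the conversion concrete: `ToTensor` turns an 8-bit image stored as height x width x channel into a float tensor laid out as channel x height x width, with values scaled to the range [0, 1]. The sketch below mimics that on a made-up 2x2 RGB image using plain lists; `to_chw_float` is a hypothetical helper for illustration, not the torchvision implementation.

```python
# A made-up 2x2 RGB image in height x width x channel layout, values 0..255
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [128, 128, 128]],
]

def to_chw_float(image):
    """Reorder H x W x C to C x H x W and scale 0..255 to 0.0..1.0."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

tensor_like = to_chw_float(image_hwc)

# Channels, height, width
print(len(tensor_like), len(tensor_like[0]), len(tensor_like[0][0]))
# The red channel as a 2x2 grid
print(tensor_like[0])
```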

We will use the training set for training the network and the validation set for evaluating the model.

In [None]:
transform = Compose([
    # Resize every image to a 224x224 square
    Resize((224,224)),
    # Convert to a tensor that PyTorch can work with
    ToTensor()
])

# Images are located at {dataset_path}/{class_name}/{objid}.jpg
dataset_train = ImageFolder(TRAINING_PATH, transform)
dataset_val = ImageFolder(VALIDATION_PATH, transform)

# Make sure that the class names are identical
assert dataset_train.classes == dataset_val.classes

Now let's look at the first example.

In [None]:
# Extract the tensor and the label of the first example
tensor, label = dataset_train[0]

print("Class: {:d} ({})".format(label, dataset_train.classes[label]))
imshow_tensor(tensor)

## Preparing the model

We start with a pre-trained ResNet18 model.
It was initially trained on ImageNet, which contains 1000 classes. Our plankton dataset contains a different number of classes (given by `len(dataset_train.classes)`). Therefore, we have to replace the classifier layer with one that has the correct number of outputs.

In [None]:
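# Load the ImageNet weights (the `pretrained=True` argument is deprecated in recent torchvision versions)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)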

# get the number of features that are input to the fully connected layer
num_ftrs = model.fc.in_features

# reset the fully connected layer
model.fc = nn.Linear(num_ftrs, len(dataset_train.classes))

# On a shared multi-GPU system, everyone using the default GPU 0
# would overload it. Pick a GPU at random to spread the users
# across the available devices.

gpu_id = random.randint(0,torch.cuda.device_count()-1)
print("using GPU", gpu_id)

# Transfer model to GPU
model = model.cuda(device=gpu_id)

## Preparing the optimizer

We will train the network using [Stochastic Gradient Descent (SGD)](https://en.wikipedia.org/wiki/Stochastic_gradient_descent).
In each iteration, the network parameters are updated in order to minimize a training criterion, in our case the [Cross Entropy](https://en.wikipedia.org/wiki/Cross_entropy) Loss.
The better the predictions, the smaller the loss.
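
For a single example, the cross-entropy loss is the negative log of the softmax probability that the network assigns to the true class. The following sketch computes it by hand for made-up logits of a 3-class problem; `nn.CrossEntropyLoss` does the same on tensors and, by default, averages over the batch.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one example: -log(softmax(logits)[target])."""
    # Subtract the maximum logit for numerical stability
    largest = max(logits)
    shifted = [z - largest for z in logits]
    log_sum_exp = math.log(sum(math.exp(z) for z in shifted))
    return -(shifted[target] - log_sum_exp)

# Made-up logits; the true class is index 0 in all three cases
confident_correct = cross_entropy([4.0, 0.0, 0.0], target=0)
undecided = cross_entropy([0.0, 0.0, 0.0], target=0)
confident_wrong = cross_entropy([0.0, 4.0, 0.0], target=0)

print(confident_correct, undecided, confident_wrong)
```

The undecided prediction yields a loss of log(3), a confident correct prediction pushes the loss toward zero, and a confident wrong prediction is penalized most.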

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
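
The `momentum=0.9` argument makes each update follow a running accumulation of past gradients instead of only the current one. Below is a minimal scalar sketch of that update rule, assuming no dampening, weight decay, or Nesterov momentum; the toy objective and the larger learning rate are chosen only to keep the example short.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    """One update: the velocity accumulates gradients, the parameter follows it."""
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity

# Toy objective f(x) = x**2 with gradient 2*x, starting at x = 1.0
x, v = 1.0, 0.0
for step in range(200):
    x, v = sgd_momentum_step(x, 2 * x, v, lr=0.1, momentum=0.9)

# x has moved close to the minimum at 0
print(x)
```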

## Train

In [None]:
loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=110,
                                           shuffle=True, num_workers=1)

In [None]:
# Activate training mode
model.train()

# Train for 5 epochs
for epoch in trange(5, desc="Epoch"):
    # tqdm displays a progress bar for the batches of this epoch
    with tqdm(loader_train, desc="Training Epoch #{:d}".format(epoch + 1)) as t:
        for inputs, labels in t:
            # Copy data to GPU
            inputs = inputs.cuda(device=gpu_id)
            labels = labels.cuda(device=gpu_id)

            # Zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            # print statistics
            t.set_postfix(loss=loss.item())

print('Finished Training')
# empty the GPU cache memory to free memory for other users
torch.cuda.empty_cache()

## Evaluate

Let's see how well our model performs.

First, display some example images together with their ground-truth and predicted labels.

In [None]:
%matplotlib inline

# Activate evaluation mode
model.eval()

# A data loader for the validation set with a batch size of 4 for demonstration purposes
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=4, shuffle=True)

# Extract one batch
images, labels = next(iter(loader_val))

# Show images of the batch
imshow_tensor(torchvision.utils.make_grid(images))
print('Ground truth:', ', '.join('%5s' % dataset_val.classes[labels[j]] for j in range(4)))

# Run the batch through the model
outputs = model(images.cuda(device=gpu_id))

# Collect the predicted classes
_, predicted = torch.max(outputs, 1)

print('Predicted:', ', '.join('%5s' % dataset_val.classes[predicted[j]]
                              for j in range(4)))

Now we do a thorough evaluation of the whole dataset. In order to do that, we need to run the whole validation set through the network and record the predictions.

In [None]:
labels_true = []
labels_predicted = []

# Validation data loader with a reasonable batch size
loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=110, num_workers=6, shuffle=True)

# Activate evaluation mode
model.eval()

# We don't need to calculate gradients
with torch.no_grad():
    with tqdm(loader_val, desc="Evaluating") as t:
        for inputs_batch, labels_batch in t:
            # Copy data to GPU
            inputs_batch = inputs_batch.cuda(device=gpu_id)

            outputs = model(inputs_batch)
            _, predicted = torch.max(outputs.data, 1)

            labels_true.extend(labels_batch.tolist())
            labels_predicted.extend(predicted.tolist())

In [None]:
acc = accuracy_score(labels_true, labels_predicted)
print("Accuracy:", acc)
print(classification_report(labels_true,
                            labels_predicted,
                            labels=np.arange(len(dataset_val.classes)),
                            target_names=dataset_val.classes))
# empty the GPU cache memory to free memory for other users
torch.cuda.empty_cache()

Do you know what these scores are?
Make yourself familiar with [precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) and the [F-Score](https://en.wikipedia.org/wiki/F1_score). These are important metrics for the evaluation of a classifier.
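
As a small worked example, the sketch below computes precision, recall, and F1 for one class from made-up labels, using the counts of true positives, false positives, and false negatives. `precision_recall_f1` is a hypothetical helper written for this illustration; `classification_report` reports the same quantities for every class.

```python
def precision_recall_f1(labels_true, labels_predicted, positive):
    """Precision, recall and F1, treating `positive` as the class of interest."""
    pairs = list(zip(labels_true, labels_predicted))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Made-up labels: class 0 is frequent, class 1 is rare
toy_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
toy_pred = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]

precision, recall, f1 = precision_recall_f1(toy_true, toy_pred, positive=1)
print(precision, recall, f1)
```

Here two of the three objects predicted as class 1 are correct (precision 2/3), but only two of the four true class-1 objects are found (recall 1/2).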

You may notice that classes with a larger support (number of examples) tend to get higher scores. Can you guess why?

In [None]:
%matplotlib widget
make_confmat(labels_true, labels_predicted, acc, labels=dataset_val.classes)

Hopefully you see a diagonal of correct predictions. You may also notice vertical stripes, which occur when a wide range of different objects is assigned to the same class. This often happens in datasets with a skewed class distribution, where a few classes contain most of the objects. In this case, the classifier learns that it is a relatively safe bet to predict these majority classes most of the time. Module 7 will take care of this.
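
The stripe effect can be reproduced with a few made-up labels. The sketch below assumes that rows are true classes and columns are predicted classes (the convention of `sklearn.metrics.confusion_matrix`); `confusion_matrix_counts` is a hypothetical helper for illustration only.

```python
def confusion_matrix_counts(labels_true, labels_predicted, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels_true, labels_predicted):
        matrix[t][p] += 1
    return matrix

# Made-up example: class 0 dominates, and the classifier over-predicts it
toy_true = [0] * 8 + [1] * 2 + [2] * 2
toy_pred = [0] * 8 + [0, 1] + [0, 0]

matrix = confusion_matrix_counts(toy_true, toy_pred, num_classes=3)
for row in matrix:
    print(row)

# Column sums: how often each class was predicted, regardless of the truth
predicted_counts = [sum(row[c] for row in matrix) for c in range(3)]
toy_accuracy = sum(matrix[c][c] for c in range(3)) / len(toy_true)
print(predicted_counts, toy_accuracy)
```

The first column collects almost all predictions, which shows up as a vertical stripe, and the accuracy still looks acceptable even though class 2 is never predicted.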

## Exercises

- Apply this notebook to the SPC dataset.
- Compare the results to the previous classifiers.
- What happens if you use a randomly initialized network (`model = models.resnet18(weights=None)`)?
- Try different [transformations](https://docs.pytorch.org/vision/0.22/transforms.html).

## Conclusion

In this module, you learned how to use a folder of images to fine-tune a model in PyTorch.

## Bonus: Visualization of the feature space

How are the classes distributed in the feature space?

In [None]:
# Reuse the layers of the fine-tuned model, but drop the final fully connected layer
feat_extractor = nn.Sequential(*list(model.children())[:-1])

features = []
labels = []
# We don't need to calculate gradients
with torch.no_grad():
    with tqdm(loader_val, desc="Evaluating") as t:
        for input_batch, label_batch in t:
            # Copy input batch to GPU
            input_batch = input_batch.cuda(device=gpu_id)

            features_batch = feat_extractor(input_batch)
            
            features.extend(features_batch.cpu().numpy())
            labels.extend(label_batch.cpu().numpy())
            
features = np.array(features)
labels = np.array(labels)

We project the features from 512 dimensions to 2 dimensions using [t-SNE](https://lvdmaaten.github.io/tsne/). This will take a while.

In [None]:
%%time

from sklearn.manifold import TSNE

tsne = TSNE()
features_2d = tsne.fit_transform(np.squeeze(features)[:1000])

In [None]:
%matplotlib widget
fig, ax = plt.subplots()

scat = ax.scatter(features_2d[:,0], features_2d[:,1], c=labels[:1000])
cbar = fig.colorbar(scat)
cbar.set_ticks(np.arange(len(dataset_val.classes)))
cbar.set_ticklabels(dataset_val.classes)

Ideally, the different classes form clusters in the feature space.

## Free up GPU memory when done

GPU memory will remain in use until the kernel stops. Other users might need this memory to run their jobs. You can free all your memory by clicking "Kernel -> Shutdown Kernel". You will still be able to see any results you had visualised in your notebook but the state of all variables will be lost. 

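We can also free some memory by running `torch.cuda.empty_cache()`. This releases cached GPU memory that PyTorch is no longer using, while the kernel keeps running and no variable state is lost. Memory held by live tensors and models stays allocated.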


In [None]:
# empty the GPU cache memory to free memory for other users
torch.cuda.empty_cache()