# FINETUNING

We take a complex model that has already been trained on a big dataset (such as ImageNet) and finetune it to classify our own dataset.
Finetuning requires the following steps:
- replace the last layer of the pretrained model so that its output size equals the number of classes in our dataset
- freeze (i.e. do not train) the initial layers of the model
- train the last layers of the model on the new data

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
# Let's start by defining the model to finetune.
# This example uses AlexNet; VGG11 and ResNet18 are left commented out as
# alternatives (they are the subject of Ex1 at the end of the notebook).
# NOTE: newer torchvision releases (>= 0.13) deprecate `pretrained=` in favour of `weights=`.
import torchvision.models as models

# vgg11 = models.vgg11(pretrained=True)  # VGG11 pretrained on ImageNet
# print(vgg11)

# resnet18 = models.resnet18(pretrained=True)  # ResNet18 pretrained on ImageNet
# print(resnet18)

alexnet = models.alexnet(pretrained=True)  # AlexNet with weights pretrained on ImageNet
print(alexnet)

In [None]:
# helper used to freeze (or unfreeze) whole sub-networks
def set_parameter_requires_grad(model, req_grad=False):
    """Set the ``requires_grad`` flag of every parameter returned by ``model.parameters()``.

    Args:
        model: an ``nn.Module``; its own parameters and those of all its submodules
            are affected.
        req_grad: new value of the flag. The default ``False`` freezes the module so
            the optimizer never updates it; pass ``True`` to make it trainable again.
    """
    for weight in model.parameters():
        weight.requires_grad_(req_grad)

In [None]:
NUM_CLASSES = 10  # CIFAR10 has 10 classes

# Freeze the convolutional part (alexnet.features): its pretrained weights stay fixed.
set_parameter_requires_grad(alexnet.features, req_grad=False)

# Swap the last classifier layer for a new, trainable Linear layer that keeps the
# same number of inputs but produces one score per CIFAR10 class.
num_ftrs = alexnet.classifier[6].in_features
new_head = nn.Linear(in_features=num_ftrs, out_features=NUM_CLASSES)
alexnet.classifier[6] = new_head

# Resolution the images will be resized to: 224x224 matches the input size the
# torchvision ImageNet-pretrained weights were trained with.
input_size = 224

alexnet = alexnet.to(device)
print(alexnet)

In [None]:
# Rebuild CIFAR10 so that its images match the resolution the pretrained model expects.
# Normalization uses the ImageNet per-channel mean and standard deviation
# (Normalize takes mean, then std), i.e. the statistics of the pretraining data.
transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

batch_size = 4  # images per mini-batch


def _cifar10_loader(train, shuffle):
    """Return (dataset, loader) for the CIFAR10 train or test split, downloading if needed."""
    dataset = torchvision.datasets.CIFAR10(root='./data', train=train,
                                           download=True, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=2)
    return dataset, loader


# training split is shuffled every epoch; test split keeps its fixed order
trainset, trainloader = _cifar10_loader(train=True, shuffle=True)
testset, testloader = _cifar10_loader(train=False, shuffle=False)

In [None]:
# TRAIN AGAIN!
# Finetune the modified AlexNet on CIFAR10 for a fixed number of mini-batches.

# define Loss and Optimizer
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
# Optimize only parameters that still require gradients; the ones frozen with
# set_parameter_requires_grad never receive a gradient anyway.
trainable_params = [p for p in alexnet.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)

MAX_ITERS = 2000  # number of mini-batches to train on
LOG_EVERY = 200   # print the average loss every LOG_EVERY mini-batches

# Put the model in training mode (dropout enabled). Done explicitly so that
# re-running this cell after the evaluation cell (which calls alexnet.eval())
# still trains with the correct layer behaviour.
alexnet.train()

# To train for several epochs, wrap the loop below in `for epoch in range(n_epochs):`.
running_loss = 0.0
for i, (inputs, labels) in enumerate(trainloader):
    # i counts from 0, so this stops after exactly MAX_ITERS mini-batches
    if i >= MAX_ITERS:
        break

    # put data on correct device
    inputs, labels = inputs.to(device), labels.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = alexnet(inputs)  # finetuned model
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # report the mean loss over the last LOG_EVERY mini-batches
    running_loss += loss.item()
    if i % LOG_EVERY == LOG_EVERY - 1:
        print(f'[it: {i + 1}] loss: {running_loss / LOG_EVERY:.3f}')
        running_loss = 0.0

print('Finished Training')

In [None]:
import os

# Path of the finetuned weights.
PATH = './res/finetuned_cifar_net.pth'

# if you want to save the model (the ./res directory must already exist)
#torch.save(alexnet.state_dict(), PATH)

# If a saved checkpoint exists, load it into the model. map_location lets a
# checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa).
# Loading is optional: without a checkpoint the weights trained above are kept.
if os.path.exists(PATH):
    alexnet.load_state_dict(torch.load(PATH, map_location=device))
else:
    print(f'No checkpoint found at {PATH}; keeping the current weights.')

In [None]:
# now let's evaluate the model on the test set
correct = 0
total = 0

# eval mode switches layers such as dropout to inference behaviour
alexnet.eval()
# no gradients are needed for inference; skipping them saves memory and time
with torch.no_grad():
    for i, (inputs, labels) in enumerate(testloader):
        # put data on correct device
        inputs, labels = inputs.to(device), labels.to(device)
        # calculate outputs by running images through the network
        outputs = alexnet(inputs)
        # the predicted class is the index of the highest score
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if i % 200 == 199:    # progress report every 200 mini-batches
            print(f'It: {i + 1}')

print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

## Feature visualization
Let's visualize (with t-SNE) the features the network extracts from the test images: first with the finetuned network, then with an untrained (randomly initialized) one.

In [None]:
# we will use TSNE as a tool to visualize high-dimensional data
from sklearn.manifold import TSNE #pip install --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn

In [None]:
import numpy as np

NUM_FEATURE_BATCHES = 200  # number of test mini-batches to embed

# Make the feature extraction deterministic: eval mode switches dropout (torchvision's
# AlexNet classifier contains Dropout layers) to inference behaviour. Set it here
# rather than relying on a previous cell having done so.
alexnet.eval()

# Pieces of the finetuned AlexNet whose composition gives the features:
# conv features -> average pooling -> flatten -> first three classifier modules.
features_extractor = alexnet.features
avg_pool = alexnet.avgpool
out_features = alexnet.classifier[:3]

# collect per-batch feature arrays and the matching labels
features_list = []
labels_list = []
with torch.no_grad():
    for i, (inputs, labels) in enumerate(testloader):
        # put data on correct device
        inputs = inputs.to(device)

        pooled = avg_pool(features_extractor(inputs))
        # already 2-D (batch, features) after flatten + Linear, no reshape needed
        outputs = out_features(torch.flatten(pooled, 1))

        features_list.append(outputs.cpu().numpy())
        labels_list.append(labels.numpy())
        if i + 1 >= NUM_FEATURE_BATCHES:  # keep only the first NUM_FEATURE_BATCHES batches
            break

In [None]:
# Stack the per-batch arrays: one (n_samples, n_features) matrix and one label vector.
features_list_cat = np.concatenate(features_list)  # axis 0 is the default
labels_list_cat = np.concatenate(labels_list)
for stacked in (features_list_cat, labels_list_cat):
    print(stacked.shape)

In [None]:
# Embed the features in 2-D with t-SNE; a fixed random_state makes the embedding
# (and therefore the plot) reproducible across runs.
tsne = TSNE(n_components=2, random_state=0).fit_transform(features_list_cat)

In [None]:
# Sanity check: expect (number of embedded samples, 2), since n_components=2.
print(tsne.shape)

In [None]:
# scale and move the coordinates so they fit [-1; 1] range
def scale_to_11_range(x):
    """Linearly rescale the values of ``x`` so they span exactly [-1, 1].

    The minimum maps to -1 and the maximum to 1. If all values are equal the
    range is zero; instead of dividing by zero (which would yield NaN), an
    array of zeros is returned.

    Args:
        x: array-like of numbers (a numpy array or anything ``np.asarray`` accepts).

    Returns:
        numpy array with the same shape as ``x``.
    """
    x = np.asarray(x)
    x_min = np.min(x)
    value_range = np.max(x) - x_min
    if value_range == 0:
        # constant input: put every point in the middle of the range
        return np.zeros(x.shape)
    # subtract the minimum so values start at 0, scale to [0, 1], then map to [-1, 1]
    return 2 * ((x - x_min) / value_range) - 1

# Rescale the x (column 0) and y (column 1) t-SNE coordinates independently to
# [-1, 1] and write them back into `tsne`.
tx, ty = (scale_to_11_range(tsne[:, col]) for col in (0, 1))
tsne[:, 0], tsne[:, 1] = tx, ty

In [None]:
from res.plot_lib import plot_data, plot_data_np, plot_model, set_default
# Initialize the default plotting parameters (set_default comes from res.plot_lib).
set_default()

In [None]:
# plot classes: the 2-D t-SNE embedding of the finetuned features together with
# their class labels (plot_data_np is a helper from res.plot_lib)
plot_data_np(tsne, labels_list_cat)

In [None]:
# let's try with the untrained network
# AlexNet with randomly initialized weights (pretrained=False): nothing is loaded from ImageNet.
alexnet_un = models.alexnet(pretrained=False).to(device)
# eval mode so that layers such as dropout behave deterministically during extraction
alexnet_un.eval()

NUM_FEATURE_BATCHES = 200  # number of test mini-batches to embed

# Same feature pipeline as for the finetuned model:
# conv features -> average pooling -> flatten -> first three classifier modules.
features_extractor = alexnet_un.features
avg_pool = alexnet_un.avgpool
out_features = alexnet_un.classifier[:3]

# collect per-batch feature arrays and the matching labels
features_list = []
labels_list = []
with torch.no_grad():
    for i, (inputs, labels) in enumerate(testloader):
        # put data on correct device
        inputs = inputs.to(device)

        pooled = avg_pool(features_extractor(inputs))
        # already 2-D (batch, features) after flatten + Linear, no reshape needed
        outputs = out_features(torch.flatten(pooled, 1))

        features_list.append(outputs.cpu().numpy())
        labels_list.append(labels.numpy())
        if i + 1 >= NUM_FEATURE_BATCHES:  # keep only the first NUM_FEATURE_BATCHES batches
            break

In [None]:
# Stack the per-batch arrays: one (n_samples, n_features) matrix and one label vector.
features_list_cat = np.concatenate(features_list)  # axis 0 is the default
labels_list_cat = np.concatenate(labels_list)
for stacked in (features_list_cat, labels_list_cat):
    print(stacked.shape)

In [None]:
# Embed the untrained-network features in 2-D with t-SNE; a fixed random_state
# makes the embedding (and therefore the plot) reproducible across runs.
tsne = TSNE(n_components=2, random_state=0).fit_transform(features_list_cat)

In [None]:
# Rescale the x (column 0) and y (column 1) t-SNE coordinates independently to
# [-1, 1] and write them back into `tsne`.
tx, ty = (scale_to_11_range(tsne[:, col]) for col in (0, 1))
tsne[:, 0], tsne[:, 1] = tx, ty

In [None]:
# plot classes: the 2-D t-SNE embedding of the untrained-network features together
# with their class labels (plot_data_np is a helper from res.plot_lib)
plot_data_np(tsne, labels_list_cat)

In [None]:
# Ex1: repeat the finetuning with a ResNet18 and with a VGG11

In [None]:
# Ex2: try training the full model (both pretrained and not pretrained), not just the last layers.
# How do the results differ?