# Transfer learning -- VGG-16 on Cifar-10

We will use transfer learning to retrain VGG-16 on the CIFAR-10 dataset (https://www.cs.toronto.edu/~kriz/cifar.html).

## Imports

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassAccuracy, MulticlassConfusionMatrix

## Settings and Dataset

In [None]:
##########################
### SETTINGS
##########################

# Mini-batch size shared by the training and test data loaders.
BATCH_SIZE = 256
# Number of full passes over the training data.
NUM_EPOCHS = 50
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
##########################
### CIFAR-10 DATASET
##########################


# Channel-wise mean / std handed to Normalize. They are defined once so the
# training and evaluation pipelines are guaranteed to scale inputs the same
# way (the evaluation pipeline previously used 0.2255 for the last std value).
# NOTE(review): these look like the usual ImageNet statistics matching the
# pre-trained weights -- confirm against the torchvision weights documentation.
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize to 70x70, then take a random 64x64 crop so the
# network sees slightly different views of each image across epochs.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((70, 70)),
    torchvision.transforms.RandomCrop((64, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Evaluation pipeline: same resize and normalization, but a deterministic
# center crop so repeated evaluations see identical inputs.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((70, 70)),
    torchvision.transforms.CenterCrop((64, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

# Training split, stored under ./data and downloaded when requested.
train_dataset = datasets.CIFAR10(
    root='data', train=True, download=True, transform=train_transforms)

# Test split from the same root directory. No download flag is passed here --
# presumably it reuses the files fetched for the training split above.
test_dataset = datasets.CIFAR10(
    root='data', train=False, transform=test_transforms)

# Training batches are reshuffled on each pass and the final incomplete batch
# is dropped, so every training batch has exactly BATCH_SIZE samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

# Test batches keep the dataset order and include the final partial batch.
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)

# Sanity check: report the shapes and a few labels of the first training
# batch only, then stop iterating.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print('Class labels of 10 examples:', labels[:10])
    break

## Load Pre-Trained Model

You can find a list of available models on the [website of torchvision](https://pytorch.org/vision/stable/models.html). Here we load the VGG16 model with pre-trained weights (trained on ImageNet).

In [None]:
# Select the default pre-trained weight set that torchvision ships for VGG16,
# then build the network initialised with those weights.
pretrained_weights = torchvision.models.VGG16_Weights.DEFAULT
model = torchvision.models.vgg16(weights=pretrained_weights)
# Last expression of the cell: shows the layer structure (and the layer
# indices referred to further down).
model

## Freezing the Model

We freeze all parameters of the model by setting their `requires_grad` attribute to `False`. This prevents gradients from being computed for these parameters, and thus prevents the parameters from being updated.

In [None]:
# Turn off gradient tracking for every weight and bias of the network, so
# back-propagation neither computes gradients for them nor updates them.
for frozen_param in model.parameters():
    frozen_param.requires_grad_(False)

Assuming we want to fine-tune (train) the last 3 layers, we un-freeze them (the indices correspond to the ones in the model printed above):

In [None]:
# Re-enable gradients for the parameters of classifier layers 0 and 3.
# `requires_grad` has to be set on the parameters themselves: assigning
# `layer.requires_grad = True` on the module object only creates an unrelated
# attribute and leaves the layer's weights and biases frozen.
for layer_idx in (0, 3):
    for param in model.classifier[layer_idx].parameters():
        param.requires_grad = True

For the last layer, because the number of class labels differs from that of ImageNet, we replace the output layer with our own output layer:

In [None]:
# Swap the final classifier layer for a freshly initialised linear layer with
# 10 outputs. A newly constructed layer has trainable parameters by default.
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=10)

## Then, Training as Usual

In [None]:
model = model.to(DEVICE)

# Adam is a possible alternative: torch.optim.Adam(model.parameters()).
# All parameters are handed to the optimizer; parameters that never receive a
# gradient (requires_grad=False) are skipped when it takes a step.
optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.001)

metric = MulticlassAccuracy(num_classes=10).to(DEVICE)


def _compute_accuracy(net, data_loader, accuracy_metric, device):
    """Return the value of *accuracy_metric* for *net* over *data_loader*.

    The metric is reset first, so the result reflects only this loader.
    Targets are passed with their original integer dtype; the caller is
    responsible for putting *net* into evaluation mode.
    """
    accuracy_metric.reset()
    with torch.no_grad():  # save memory during inference
        for features, targets in data_loader:
            logits = net(features.to(device))
            accuracy_metric.update(logits, targets.to(device))
    return accuracy_metric.compute()


minibatch_loss_list, train_acc_list, valid_acc_list = [], [], []
for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        # FORWARD AND BACK PROP
        logits = model(features)
        loss = torch.nn.functional.cross_entropy(logits, targets)
        # Clear gradients left over from the previous step before
        # accumulating the new ones.
        optimizer.zero_grad()

        loss.backward()

        # UPDATE MODEL PARAMETERS
        optimizer.step()

        # LOGGING
        # .item() already returns a plain Python number, no detach needed.
        minibatch_loss_list.append(loss.item())
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} '
                  f'| Batch {batch_idx:04d}/{len(train_loader):04d} '
                  f'| Loss: {loss:.4f}')

    # End-of-epoch evaluation on the training and the test loader.
    model.eval()
    train_acc = _compute_accuracy(model, train_loader, metric, DEVICE)
    test_acc = _compute_accuracy(model, test_loader, metric, DEVICE)

    print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} '
          f'| Train: {train_acc*100 :.2f}% '
          f'| Validation: {test_acc*100 :.2f}%')
    train_acc_list.append(train_acc.item())
    valid_acc_list.append(test_acc.item())


In [None]:
conf_matrix = MulticlassConfusionMatrix(num_classes=10).to(DEVICE)
metric = MulticlassAccuracy(num_classes=10).to(DEVICE)

# Display names for the ten class indices; used as tick labels in the
# confusion-matrix plot below.
class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog',
              7: 'horse',
              8: 'ship',
              9: 'truck'}
# Test
model.eval()

# Collect the per-batch tensors in lists and concatenate once at the end.
# This avoids re-copying the growing result on every batch and keeps the
# class indices as integers instead of promoting them to float.
target_batches, prediction_batches = [], []
with torch.no_grad():  # save memory during inference
    for features, targets in test_loader:
        logits = model(features.to(DEVICE))
        prediction_batches.append(torch.argmax(logits, dim=1))
        target_batches.append(targets.to(DEVICE))

all_targets = torch.cat(target_batches)
all_predictions = torch.cat(prediction_batches)

print(f"Accuracy: {metric(all_predictions, all_targets)*100:.2f}")
conf_matrix.update(all_predictions, all_targets)
# Pass a real list rather than a dict view as the tick labels.
fig_, ax_ = conf_matrix.plot(labels=list(class_dict.values()))
plt.show()