# Transfer Learning



Transfer learning (TL) is the process of taking features learned on one task and reusing them to solve a new but similar problem, instead of starting the learning process from scratch. The technique is popular because it allows us to build accurate models without training a network for days. It is usually the way to go when the training dataset has only a small number of samples.

A transfer learning workflow for image classification includes the following steps:

1. **Take a pre-trained model**: choose a model that was trained on a large dataset to solve a similar problem. A common practice is to pick a model from the literature, such as VGG, ResNet, or MobileNet.
2. **Chop the classifier**: remove the old classifier.
3. **Add a new classifier**: adapt the architecture to solve the new task.
4. **Use the convolutional block as a feature extractor**: train **only** the new classifier on the new dataset and exclude the feature extractor from back-propagation (freezing).
5. **Fine-tuning**: an optional last step is to **fine-tune** the new network. It consists of unfreezing parts of the pre-trained model and continuing to train it on the new dataset, in order to adapt the pre-trained features to the new data. To avoid overfitting, we usually run this step only if the new dataset is **large**.
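
The steps above can be sketched on a toy network. This is only a minimal sketch under stated assumptions: `backbone` and `classifier` are hypothetical stand-ins for a pre-trained feature extractor and a new head, and no pre-trained weights are loaded here.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pre-trained feature extractor (steps 1 and 2).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Step 3: a new classifier for a 2-class problem.
classifier = nn.Linear(8, 2)
model = nn.Sequential(backbone, classifier)

# Step 4: freeze the feature extractor so that only the classifier is trained.
for p in backbone.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable} / {total}")

# Step 5 (optional): unfreeze the backbone before fine-tuning end-to-end.
for p in backbone.parameters():
    p.requires_grad = True
trainable_after_unfreeze = sum(p.numel() for p in model.parameters() if p.requires_grad)
```

With the backbone frozen, only the 18 parameters of the linear head remain trainable; after unfreezing, all 242 parameters are trainable again.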


**A part comparing the results of a fine-tuned ViT configuration with those of a fine-tuned CNN has been added.**


Credits to [Giuseppe Lisanti](https://www.unibo.it/sitoweb/giuseppe.lisanti/en), Samuele Salti and Riccardo Spezialetti.

Thanks to [Lorenzo Stacchio](https://www.unibo.it/sitoweb/lorenzo.stacchio2/en) for the ViT notebook part.



## Import Dependencies



In [None]:
! pip install torchinfo

import os
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from timeit import default_timer as timer
from torchvision import models, transforms
from torchinfo import summary  # Formerly known as torch-summary
from typing import Callable, Dict, List, Tuple, Union
from torch.optim import lr_scheduler
from torch.utils.tensorboard import SummaryWriter

## Reproducibility
Remember that deterministic operations tend to have slower performance than non-deterministic operations.

In [None]:
def fix_random(seed: int) -> None:
    """Fix all the possible sources of randomness.

    Args:
        seed: the seed to use.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

fix_random(seed=7)
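
A quick way to see what seeding buys us is to draw random numbers twice with the same seed. The sketch below only re-seeds the CPU generator of PyTorch; `sample_with_seed` is a hypothetical helper introduced for illustration, not part of the notebook.

```python
import torch

def sample_with_seed(seed: int) -> torch.Tensor:
    """Draws three random numbers after re-seeding the CPU generator."""
    torch.manual_seed(seed)
    return torch.rand(3)

first = sample_with_seed(7)
second = sample_with_seed(7)
other = sample_with_seed(8)

print(torch.equal(first, second))  # True: same seed, same numbers
print(torch.equal(first, other))
```

The same idea applies to `fix_random`: calling it with the same seed before each experiment makes runs comparable.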

## Runtime Settings

Let's check that our environment is properly configured.

In [None]:
device = torch.device("cpu")
if torch.cuda.is_available():  # call the function; the bare attribute is always truthy
  print('Gpu available')
  device = torch.device("cuda:0")
else:
  print('Please set GPU via Edit -> Notebook Settings.')

!nvidia-smi

## Dataset






In [None]:
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -qq hymenoptera_data.zip

In [None]:
path_ds = 'hymenoptera_data'
path_ds_train = os.path.join(path_ds, 'train')
path_ds_val = os.path.join(path_ds, 'val')

# Means and standard deviations of the RGB channels of the ImageNet dataset
mean_image_net = [0.485, 0.456, 0.406]
std_image_net = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean_image_net, std_image_net)

size_image = 224  # try 64 or 128 (only with CNN)
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(size_image), # Data augmentation
                                                transforms.RandomHorizontalFlip(),        # Data augmentation
                                                transforms.ToTensor(),
                                                normalize]),

                   'val': transforms.Compose([transforms.Resize(int(size_image*1.2)),
                                              transforms.CenterCrop(size_image),
                                              transforms.ToTensor(),
                                              normalize])}


# ImageFolder is a generic dataloader where the images are arranged in this way:
#     root/class_1/xxx.png
#     root/class_1/xxy.png
#     ...
#     root/class_2/123.png
#     root/class_2/nsdf3.png
data_train = torchvision.datasets.ImageFolder(path_ds_train, data_transforms['train'])
data_val = torchvision.datasets.ImageFolder(path_ds_val, data_transforms['val'])

classes = data_train.classes
num_classes = len(classes)

print(f'Samples -> Train = {len(data_train)} | Val = {len(data_val)} '
      f'| Classes = {classes}')

Visualize some examples from the validation dataset.

In [None]:
# For visualization purposes, we need to denormalize the images
# to show them in the correct range of values.
class NormalizeInverse(torchvision.transforms.Normalize):
    def __init__(self, mean: List[float], std: List[float]) -> None:
        """Reconstructs the images in the input domain by inverting
        the normalization transformation.

        Args:
            mean: the mean used to normalize the images.
            std: the standard deviation used to normalize the images.
        """
        mean = torch.as_tensor(mean)
        std = torch.as_tensor(std)
        std_inv = 1 / (std + 1e-7) # 1e-7 is a small value to avoid division by zero.
        mean_inv = -mean * std_inv
        super().__init__(mean=mean_inv, std=std_inv)

    def __call__(self, tensor):
        return super().__call__(tensor.clone())

def show_grid(dataset: torchvision.datasets.ImageFolder,
              process: Callable = None) -> None:
    """Shows a grid with random images taken from the dataset.

    Args:
        dataset: the dataset containing the images.
        process: a function to apply on the images before showing them.
    """
    fig = plt.figure(figsize=(15, 5))
    # Sample 10 random indices over the whole dataset
    indices_random = np.random.randint(0, len(dataset), size=10)

    for count, idx in enumerate(indices_random):
        fig.add_subplot(2, 5, count + 1)
        title = dataset.classes[dataset[idx][1]]
        plt.title(title)
        image_processed = process(dataset[idx][0]) if process is not None else dataset[idx][0]
        plt.imshow(transforms.ToPILImage()(image_processed))
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Show some examples
denormalize = NormalizeInverse(mean_image_net, std_image_net)
show_grid(data_val, process=denormalize)
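
To see why `NormalizeInverse` recovers the original image, the arithmetic can be checked with plain tensors. This is a minimal sketch that leaves out the small epsilon used in the class above and uses a random tensor as a fake image.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

image = torch.rand(3, 4, 4)         # fake image with values in [0, 1)
normalized = (image - mean) / std   # what transforms.Normalize computes

# NormalizeInverse passes mean_inv = -mean / std and std_inv = 1 / std
# to Normalize, which then computes (x - mean_inv) / std_inv = x * std + mean.
mean_inv = -mean / std
std_inv = 1 / std
restored = (normalized - mean_inv) / std_inv

print(torch.allclose(restored, image, atol=1e-5))
```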

Plot the distribution of data.

In [None]:
def plot_histograms(dataset_train: torchvision.datasets.ImageFolder,
                    dataset_test: torchvision.datasets.ImageFolder,
                    title: str,
                    classes_as_ticks: bool = True) -> None:
    """Plot histograms with train and test or validation data distributions.

    Args:
        dataset_train: the train dataset.
        dataset_test: the test or validation dataset.
        title: the title of the plot.
        classes_as_ticks: if true, the class names are shown on the x axis.
    """
    classes = len(dataset_train.classes)
    train_data = [label for _, label in dataset_train]
    test_data = [label for _, label in dataset_test]

    plt.figure(figsize=(5, 4))
    plt.hist([train_data, test_data], bins=np.arange(classes + 1) - 0.5, rwidth=0.8, align='mid')
    plt.title(title)
    plt.xlabel('Classes')
    plt.ylabel('Number of images')
    plt.legend(['Train', 'Test'])

    if classes_as_ticks:
        plt.xticks(np.arange(classes), dataset_train.classes)

    plt.show()

plot_histograms(data_train, data_val, "Data Distribution")

## Train Functionalities


In [None]:
num_workers = 2
size_batch = 64

loader_train = torch.utils.data.DataLoader(data_train, batch_size=size_batch,
                                           shuffle=True,
                                           pin_memory=True, # speed-up CPU-GPU transfer
                                           num_workers=num_workers)

loader_val = torch.utils.data.DataLoader(data_val, batch_size=size_batch,
                                         shuffle=False,
                                         num_workers=num_workers)

In [None]:
# Accuracy
def get_correct_samples(scores: torch.Tensor, labels: torch.Tensor) -> int:
    """Gets the number of correctly classified examples.

    Args:
        scores: the scores predicted with the network.
        labels: the class labels.

    Returns:
        the number of correct samples.
    """
    # The predicted class is the index of the highest score along the class dimension
    classes_predicted = torch.argmax(scores, 1)
    return (classes_predicted == labels).sum().item()

# Train one epoch
def train(writer: utils.tensorboard.writer.SummaryWriter,
          model: nn.Module,
          train_loader: utils.data.DataLoader,
          device: torch.device,
          optimizer: torch.optim.Optimizer,
          criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
          log_interval: int,
          epoch: int) -> Tuple[float, float]:
    """Trains a neural network for one epoch.

    Args:
        writer: the summary writer for tensorboard.
        model: the model to train.
        train_loader: the data loader containing the training data.
        device: the device to use to train the model.
        optimizer: the optimizer to use to train the model.
        criterion: the loss to optimize.
        log_interval: the log interval.
        epoch: the number of the current epoch.

    Returns:
        the cross entropy Loss value on the training data.
        the accuracy on the training data.
    """
    correct = 0
    samples_train = 0
    loss_train = 0
    size_ds_train = len(train_loader.dataset)
    num_batches = len(train_loader)

    model.train()
    for idx_batch, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        # Reset the gradients to zero
        optimizer.zero_grad()

        scores = model(images)

        loss = criterion(scores, labels)
        loss_train += loss.item() * len(images)
        samples_train += len(images)

        loss.backward()
        optimizer.step()
        correct += get_correct_samples(scores, labels)

        if log_interval > 0:
            if idx_batch % log_interval == 0:
                running_loss = loss_train / samples_train
                global_step = idx_batch + (epoch * num_batches)
                writer.add_scalar('Metrics/Loss_Train_IT', running_loss, global_step)
                # Visualize images on tensorboard
                indices_random = torch.randperm(images.size(0))[:4]
                writer.add_images('Samples/Train', denormalize(images[indices_random]), global_step)

    loss_train /= samples_train
    accuracy_training = 100. * correct / samples_train
    return loss_train, accuracy_training

# Validate one epoch
def validate(model: nn.Module,
             data_loader: utils.data.DataLoader,
             device: torch.device,
             criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> Tuple[float, float]:
    """Evaluates the model.

    Args:
        model: the model to evaluate.
        data_loader: the data loader containing the validation or test data.
        device: the device to use to evaluate the model.
        criterion: the loss function.

    Returns:
        the loss value on the validation data.
        the accuracy on the validation data.
    """
    correct = 0
    samples_val = 0
    loss_val = 0.
    
    # Always set the model to eval mode when evaluating
    model.eval() 
    with torch.no_grad():
        for idx_batch, (images, labels) in enumerate(data_loader):
            images, labels = images.to(device), labels.to(device)
            scores = model(images)

            loss = criterion(scores, labels)
            loss_val += loss.item() * len(images)
            samples_val += len(images)
            correct += get_correct_samples(scores, labels)

    loss_val /= samples_val
    accuracy = 100. * correct / samples_val
    return loss_val, accuracy
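
Both `train` and `validate` multiply each batch loss by the batch size before accumulating it. `nn.CrossEntropyLoss` returns the mean over the batch by default, so this weighting yields a per-sample average even when the last batch is smaller. A small sketch with made-up numbers (the loss values and batch sizes below are hypothetical):

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller).
batch_losses = [0.50, 0.70, 0.90]
batch_sizes = [64, 64, 16]

# Weighted accumulation, as done in `train` and `validate`.
loss_sum = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
weighted_mean = loss_sum / sum(batch_sizes)

# Naive mean over batches, which gives the small last batch too much weight.
naive_mean = sum(batch_losses) / len(batch_losses)

print(f"weighted = {weighted_mean:.4f}, naive = {naive_mean:.4f}")
```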

In [None]:
def training_loop(writer: utils.tensorboard.writer.SummaryWriter,
                  num_epochs: int,
                  optimizer: torch.optim.Optimizer,
                  lr_scheduler: torch.optim.lr_scheduler,
                  log_interval: int,
                  model: nn.Module,
                  loader_train: utils.data.DataLoader,
                  loader_val: utils.data.DataLoader,
                  verbose: bool=True) -> Dict:
    """Executes the training loop.

        Args:
            writer: the summary writer for tensorboard.
            num_epochs: the number of epochs.
            optimizer: the optimizer to use.
            lr_scheduler: the scheduler for the learning rate.
            log_interval: the interval (in batches) between tensorboard logs.
            model: the model to train.
            loader_train: the data loader containing the training data.
            loader_val: the data loader containing the validation data.
            verbose: if true print the value of loss.

        Returns:
            A dictionary with the statistics computed during training:
            the values for the train loss for each epoch.
            the values for the train accuracy for each epoch.
            the values for the validation accuracy for each epoch.
            the time of execution in seconds for the entire loop.
    """
    criterion = nn.CrossEntropyLoss()
    loop_start = timer()

    losses_values = []
    train_acc_values = []
    val_acc_values = []
    for epoch in range(1, num_epochs + 1):
        time_start = timer()
        loss_train, accuracy_train = train(writer, model, loader_train, device,
                                           optimizer, criterion, log_interval,
                                           epoch)
        loss_val, accuracy_val = validate(model, loader_val, device, criterion)
        time_end = timer()

        losses_values.append(loss_train)
        train_acc_values.append(accuracy_train)
        val_acc_values.append(accuracy_val)

        lr = optimizer.param_groups[0]['lr']

        if verbose:
            print(f'Epoch: {epoch} '
                  f' Lr: {lr:.10f} '
                  f' Loss: Train = [{loss_train:.4f}] - Val = [{loss_val:.4f}] '
                  f' Accuracy: Train = [{accuracy_train:.2f}%] - Val = [{accuracy_val:.2f}%] '
                  f' Time one epoch (s): {(time_end - time_start):.4f} ')

        # Plot to tensorboard
        writer.add_scalar('Hyperparameters/Learning Rate', lr, epoch)
        writer.add_scalars('Metrics/Losses', {"Train": loss_train, "Val": loss_val}, epoch)
        writer.add_scalars('Metrics/Accuracy', {"Train": accuracy_train, "Val": accuracy_val}, epoch)
        writer.flush()

        # Increases the internal counter
        if lr_scheduler:
            lr_scheduler.step()

    loop_end = timer()
    time_loop = loop_end - loop_start
    if verbose:
        print(f'Time for {num_epochs} epochs (s): {(time_loop):.3f}')

    return {'loss_values': losses_values,
            'train_acc_values': train_acc_values,
            'val_acc_values': val_acc_values,
            'time': time_loop}

In [None]:
def execute(name_train: str, network: nn.Module, starting_lr: float,
            num_epochs: int,
            data_loader_train: torch.utils.data.DataLoader,
            data_loader_val: torch.utils.data.DataLoader) -> None:
    """Executes the training loop.

    Args:
        name_train: the name for the log subfolder.
        network: the network to train.
        starting_lr: the starting learning rate.
        num_epochs: the number of epochs.
        data_loader_train: the data loader with training data.
        data_loader_val: the data loader with validation data.
    """
    # Visualization
    log_interval = 20
    log_dir = os.path.join("logs", name_train)
    writer = torch.utils.tensorboard.writer.SummaryWriter(log_dir)

    # Optimization
    optimizer = optim.SGD(network.parameters(), lr=starting_lr, momentum=0.9,
                          weight_decay=0.0001)
    
    # Learning Rate schedule: decays the learning rate by a factor of `gamma`
    # every `step_size` epochs
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # With adaptive optimizers such as Adam or AdamW, a learning rate scheduler
    # is often less critical, since each parameter gets its own adaptive step
    # size; a scheduler can still be combined with them.
    # optimizer = optim.AdamW(network.parameters(), lr=starting_lr)

    statistics = training_loop(writer, num_epochs, optimizer, scheduler,
                               log_interval, network, data_loader_train,
                               data_loader_val)
    writer.close()

    best_epoch = np.argmax(statistics['val_acc_values']) + 1
    best_accuracy = statistics['val_acc_values'][best_epoch - 1]

    print(f'Best val accuracy: {best_accuracy:.2f} epoch: {best_epoch}.')
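
The step decay configured in `execute` (`step_size=5`, `gamma=0.1`) can be worked out by hand. The sketch below reproduces the arithmetic in plain Python rather than calling `StepLR`; `step_lr` is a hypothetical helper, and it assumes the scheduler is stepped once at the end of each epoch, as in `training_loop`.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 5, gamma: float = 0.1) -> float:
    """Learning rate used during a 1-indexed epoch under a step decay."""
    # The scheduler has been stepped once per completed epoch.
    num_scheduler_steps = epoch - 1
    return base_lr * gamma ** (num_scheduler_steps // step_size)

schedule = {epoch: step_lr(0.001, epoch) for epoch in (1, 5, 6, 11, 16)}
for epoch, lr_value in schedule.items():
    print(f"epoch {epoch:2d}: lr = {lr_value:.0e}")
```

With a starting learning rate of 0.001, epochs 1 to 5 run at 1e-3, epochs 6 to 10 at 1e-4, and so on, so after 15 epochs the updates become very small.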

## Train the network from scratch

To build our classifier we can use one of the off-the-shelf models provided in the [model zoo](https://pytorch.org/serve/model_zoo.html) of PyTorch. In this case, we will rely on a [ResNet18](https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18). In this first example we train the network from scratch, so we set the `num_classes` parameter of the `models.resnet18` constructor to the number of classes in our dataset (i.e., 2).

In [None]:
# Deprecated since version 0.13 of torchvision, may be removed in the future.
# net_from_scratch = models.resnet18(pretrained=False, num_classes=num_classes)

# Correct way
net_from_scratch = models.resnet18(weights=None, num_classes=num_classes)
net_from_scratch.to(device)

# You can use the batch size to preview the total memory impact of the model
# during forward and backward pass.
summary(net_from_scratch, input_size=(size_batch, 3, size_image, size_image))

We can use hyperparameters similar to those used in the [paper](https://arxiv.org/pdf/1512.03385.pdf).


In [None]:
name_train = "resnet_from_scratch"
lr = 0.001
num_epochs = 20  # try a higher number
execute(name_train, net_from_scratch, lr, num_epochs, loader_train, loader_val)

## Transfer Learning




Because our training dataset is so small, the solution learned by the network is not very accurate. This is a case where transfer learning can help: we can use the feature extractor trained on the ImageNet dataset as a plug-and-play module and add a new classifier.

See the PyTorch [implementation](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L173) of [ResNet18](https://pytorch.org/hub/pytorch_vision_resnet/).






In [None]:
from torchvision.models import ResNet18_Weights

def get_model(pretrained: bool, num_classes: int) -> nn.Module:
    """Gets an image classifier based on ResNet18.

    Args:
        pretrained: if true initializes the network with ImageNet weights.
        num_classes: the number of classes.

    Returns:
        The required network.
    """
    # Honor the `pretrained` flag instead of always loading ImageNet weights
    weights = ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)
    # Here we override the old classifier
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

To exclude some parameters of our model from gradient computation, we can set their `requires_grad` attribute to `False`, as explained in the [autograd](https://pytorch.org/docs/stable/notes/autograd.html) page of PyTorch.
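
The effect of `requires_grad` can be observed on two small layers. This is a minimal sketch: `frozen` and `head` are hypothetical layers that play the roles of the feature extractor and of the new classifier.

```python
import torch
import torch.nn as nn

frozen = nn.Linear(4, 3)   # plays the role of the feature extractor
head = nn.Linear(3, 2)     # plays the role of the new classifier

# Freeze the first layer.
for p in frozen.parameters():
    p.requires_grad = False

x = torch.randn(5, 4)
loss = head(frozen(x)).sum()
loss.backward()

print(frozen.weight.grad)             # None: no gradient was computed
print(head.weight.grad is not None)   # True: the head still receives gradients
```

Since the frozen parameters never get a gradient, the optimizer leaves them untouched.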

In [None]:
def set_requires_grad_for_layer(layer: torch.nn.Module, train: bool) -> None:
    """Sets the attribute requires_grad to True or False for each parameter.

        Args:
            layer: the layer to freeze or unfreeze.
            train: if true, the layer's parameters are trained; otherwise they are frozen.
    """
    for p in layer.parameters():
        p.requires_grad = train

In [None]:
net_feat_ext = get_model(True, num_classes)
net_feat_ext.to(device)

set_requires_grad_for_layer(net_feat_ext.conv1, False)
set_requires_grad_for_layer(net_feat_ext.bn1, False)
set_requires_grad_for_layer(net_feat_ext.layer1, False)
set_requires_grad_for_layer(net_feat_ext.layer2, False)
set_requires_grad_for_layer(net_feat_ext.layer3, False)
set_requires_grad_for_layer(net_feat_ext.layer4, False)

summary(net_feat_ext, input_size=(size_batch, 3, size_image, size_image))

In [None]:
name_train = "resnet_feat_ext"
execute(name_train, net_feat_ext, lr, num_epochs, loader_train, loader_val)

In [None]:
# random.randint includes both endpoints, so the upper bound must be len - 1
index_sample = random.randint(0, len(data_val) - 1)
image, label = data_val[index_sample]
batch_image = image.unsqueeze(0)

net_feat_ext.eval()
with torch.no_grad():
    output = net_feat_ext(batch_image.to(device))
    _, preds = torch.max(output, 1)

fig = plt.figure()
cax = plt.imshow(transforms.ToPILImage()(denormalize(image)))

title = f'Prediction: {classes[preds[0].item()]} - Label: {classes[label]}'
title_obj = plt.title(title)
plt.setp(title_obj, color=("green" if preds[0]==label else "red"))

plt.show()

## Fine Tuning the Network


Once the new classifier has been trained on the new dataset, you can continue training the whole model end-to-end for a few epochs using a lower learning rate. A common practice is to make the initial learning rate 10 times smaller than the one used to train the network from scratch.

> **Good Practice**: *use a smaller learning rate than the one used for training from scratch*.

> **Good Practice**: *fine-tune for a few epochs*.

Check the trainable parameters of the frozen model.

In [None]:
print("Trainable parameters: ", summary(net_feat_ext, input_size=(size_batch, 3, size_image, size_image)).trainable_params)
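
The same count can be obtained without `torchinfo` by summing the sizes of the parameters that require gradients. The sketch below uses a toy model; `count_trainable` is a hypothetical helper, not a PyTorch function.

```python
import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    """Counts the parameters that will receive gradient updates."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Toy model standing in for the network; only the last layer stays trainable.
toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
for p in toy[0].parameters():
    p.requires_grad = False

print(count_trainable(toy))  # 12 = 5 * 2 weights + 2 biases
```

Applied to the frozen ResNet18 above, this should report only the parameters of the new `fc` layer (512 × 2 weights plus 2 biases).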

In [None]:
set_requires_grad_for_layer(net_feat_ext.conv1, True)
set_requires_grad_for_layer(net_feat_ext.bn1, True)
set_requires_grad_for_layer(net_feat_ext.layer1, True)
set_requires_grad_for_layer(net_feat_ext.layer2, True)
set_requires_grad_for_layer(net_feat_ext.layer3, True)
set_requires_grad_for_layer(net_feat_ext.layer4, True)

Check the trainable parameters again after unfreezing.

In [None]:
print("Trainable parameters: ", summary(net_feat_ext, input_size=(size_batch, 3, size_image, size_image)).trainable_params)

In [None]:
name_train = "resnet_fine_tuning"
lr_ft = lr * 0.1
num_epochs_ft = 10
execute(name_train, net_feat_ext, lr_ft, num_epochs_ft, loader_train, loader_val)

## Vision Transformer (ViT)

In [None]:
# https://github.com/rwightman/pytorch-image-models
! pip install timm

In [None]:
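def train_only_classification_layer(model: nn.Module, classf_layer_name: str) -> nn.Module:
  """Freezes every parameter except those of the classification layer.

  Args:
    model: the model to modify in place.
    classf_layer_name: substring identifying the classification layer's parameters.
  """
  for name, param in model.named_parameters():
    param.requires_grad = classf_layer_name in name
  return model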
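
The name-based rule used above can be tried on a toy module before applying it to a large model. In this sketch `ToyNet` is a hypothetical network whose `head` attribute mimics the classification layer of the ViT.

```python
import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        return self.head(self.body(x))

toy = ToyNet()
for name, param in toy.named_parameters():
    # Same rule as `train_only_classification_layer`: substring match on the name.
    param.requires_grad = "head" in name

trainable_names = [n for n, p in toy.named_parameters() if p.requires_grad]
print(trainable_names)  # ['head.weight', 'head.bias']
```

Because the match is done by substring, any parameter whose name happens to contain the given string would also stay trainable, so printing the selected names is a cheap safeguard.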

In [None]:
import timm
vit_model_name = 'vit_tiny_patch16_224' # try changing "tiny" with "small", check results and number of params
vit_model = timm.create_model(vit_model_name, pretrained=True)
vit_model.head = torch.nn.Linear(in_features=vit_model.head.in_features, out_features=num_classes)
vit_model = train_only_classification_layer(vit_model, classf_layer_name = "head").to(device)


Check the trainable parameters.

In [None]:
summary(vit_model, input_size=(size_batch, 3, 224, 224))

In [None]:
name_train = vit_model_name+"_feat_ext"
lr_ft = lr
num_epochs_ft = 10
execute(name_train, vit_model, lr_ft, num_epochs_ft, loader_train, loader_val)

## Tensorboard

In [None]:
%load_ext tensorboard
%tensorboard --logdir="logs"