HHU Deep Learning, SS2022/23, 19.05.2023, Prof. Dr. Markus Kollmann

Lectures and tutoring are given by Tim Kaiser, Nikolas Adaloglou and Felix Michels.

# Assignment 07 - Knowledge distillation on CIFAR100 with Vision Transformers


Copyright © 2023 Nikolas Adaloglou, Tim Kaiser and Felix Michels

Link to Imagenet-pretrained + fine-tuned teacher model on CIFAR100: https://uni-duesseldorf.sciebo.de/s/Y6yGbC9AHpxMKBJ


## Contents

1. Imports and preparation
2. Load the timm models
3. Applying image augmentations with torchvision
4. Fine-tune the teacher model on CIFAR100
5. Train ViT-tiny (student) from scratch (random init) with cross-entropy (without knowledge distillation)
6. Implement knowledge distillation loss
7. Implement MultiScaleData loading
8. Adjust training code to support knowledge distillation from the teacher model
9. Train student with knowledge distillation

# Introduction

The following method is partially based on the paper: ["Be Your Own Teacher: Improve the Performance of Convolutional Neural
Networks via Self Distillation"](https://arxiv.org/pdf/1905.08094.pdf)

Briefly, knowledge distillation trains a new, randomly initialized model to match the predictions of another, already trained model.

The output of the so-called teacher model is a soft mixture over the real labels, e.g. 88% cat, 7% tiger, 5% dog.


#### Overview
The teacher model is usually a larger model pretrained on the same dataset. However, for this exercise we will take a larger ImageNet-pretrained model (ViT-Base) and fine-tune it on CIFAR100. This will be our teacher.

We will then try to improve a much smaller model, trained from scratch, with the losses proposed in https://arxiv.org/pdf/1905.08094.pdf (equations 2 and 3).


`Note`: I used the pretrained models from timm; you can install it with `!pip install timm`.

# Part I. Imports and preparation


Below we provide the imports and some necessary data-loading utilities. We will experiment with CIFAR100 this time!

You may need to change the path (root='../data') where the data will be downloaded.

In [None]:
import os 
import torch
import torchvision
import numpy as np
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import timm
import torch.optim as optim

# local imports
from utils import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'
reproducibility(99)

def get_transform_plain():
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    normalize = transforms.Normalize(mean.tolist(), std.tolist())
    return transforms.Compose(
                    [transforms.ToTensor(),
                    normalize])

def load_cifar100_data(transform=None, train=True, batch_size=64, num_workers=2, shuffle=False):
    pin = True if train else False
    if transform is None:
        transform = get_transform_plain()
        
    dataset = torchvision.datasets.CIFAR100(root='../data',  transform=transform, train=train, download=True)
    loader = torch.utils.data.DataLoader(dataset, 
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         pin_memory=pin)
    return loader

# Part II. Load the timm models

- Teacher (ViT-B): use the supervised ImageNet-1k pretrained weights of ViT-Base from `timm`. This model will be fine-tuned on CIFAR100 in a supervised manner.
- Student: create a vision transformer for CIFAR100 images with a patch size of 4, an embedding dimension of 192, 12 layers with 3 heads per layer, and pre-norm, i.e. layer normalization is applied before the attention and before the MLP. This model will be trained from scratch, so use random weights!
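To make the pre-norm requirement concrete, here is a minimal pure-PyTorch sketch of one pre-norm transformer block with the student's sizes (embedding dimension 192, 3 heads). The class name `PreNormBlock` and the MLP ratio of 4 are assumptions for illustration; timm's ViT blocks are already pre-norm, so you would normally not write this by hand. With a patch size of 4, a 32x32 image yields (32/4)^2 = 64 patch tokens, plus one class token if the model uses one, i.e. 65 tokens.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Hypothetical pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, dim=192, num_heads=3, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        h = self.norm1(x)                                    # LayerNorm BEFORE attention
        x = x + self.attn(h, h, h, need_weights=False)[0]    # residual connection
        x = x + self.mlp(self.norm2(x))                      # LayerNorm BEFORE the MLP
        return x

block = PreNormBlock()
tokens = torch.randn(2, 65, 192)   # (batch, 64 patches + 1 class token, embed_dim)
print(block(tokens).shape)         # torch.Size([2, 65, 192])
```

Stacking 12 such blocks (with a patch embedding in front and a classification head behind) gives the student's overall shape; the post-norm variant would instead apply `LayerNorm` after each residual addition.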


In [None]:
### START CODE HERE ### (≈ 3 lines of code) 
teacher_vitb = timm.create_model(...)
model_args = ...
student = ...
### END CODE HERE ###

# Part III. Applying image augmentations with torchvision

Augmentations are very important in visual representation learning. 

In natural language processing (NLP), augmentations play a much smaller role. The pretext tasks are quite straightforward: the most common one is to predict missing words in a sentence, as in BERT.

Why are augmentations important in representation learning?

Augmentations are an indirect way to inject human prior knowledge into the model. Strong augmentations are typically applied when pretraining backbone models with self-supervised methods. For natural images, it is well established that colour distortion and cropping are the key transformations for creating augmented views. In effect, augmentations tell the model: whatever gets transformed, don't pay attention to it! Thus, in representation learning, augmentations must preserve the image semantics (i.e. label-related information).

Your task here is to understand how to make the API calls to create a transformation pipeline with multiple image augmentations.

The augmentations to apply (listed in **random order**):

- Randomly crop 40% to 100% of the image area, then resize the crop back to the original dimension.
- Apply hue jittering of 10% with 20% probability.
- Apply ImageNet mean/std normalization.
- Apply brightness, contrast and saturation jittering of 30% with 80% probability.
- Apply Gaussian blur with a kernel that is 20% of the image size and sigma in [0.1, 2] with 10% probability.
- Apply a horizontal flip with 50% probability.
- Convert the resulting image to greyscale with 10% probability.

The final pipeline should combine all of the above augmentations.

`Hint`: use [torchvision](https://pytorch.org/vision/stable/)!
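Two numbers in the list are easy to misread. The `scale` argument of `RandomResizedCrop` is a fraction of the image *area*, so a 40% minimum area corresponds to roughly 63% of the side length for a square crop (the transform also samples an aspect ratio, 3/4 to 4/3 by default). And torchvision's `GaussianBlur` needs an odd kernel size, so 20% of the image size has to be rounded to an odd integer. The plain-Python helpers below (hypothetical names) only do this arithmetic; the odd-rounding rule is the same one used in the `Augment` cell.

```python
import math

def blur_kernel_size(img_size, frac=0.2):
    """Odd kernel size derived from frac * img_size (assumes img_size >= 5)."""
    k = int(img_size * frac)            # e.g. int(32 * 0.2) = 6
    return k - 1 if k % 2 == 0 else k   # even sizes are decremented to the next odd value

def min_square_crop_side(img_size, min_area=0.4):
    """Side length of a square crop that covers `min_area` of the image area."""
    return img_size * math.sqrt(min_area)

for size in (32, 224):
    print(size, blur_kernel_size(size), round(min_square_crop_side(size), 1))
```

For CIFAR100 images (32x32) this gives a 5x5 blur kernel and a smallest square crop of about 20x20 pixels before resizing back to 32x32.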

In [None]:
import torchvision.transforms as T
import torch

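def Augment(img_size=32):
### START CODE HERE ### (≈ 11 lines of code)
    # brightness/contrast/saturation jitter of 30%, applied with probability 0.8
    color_jitter = T.RandomApply([T.ColorJitter(0.3, 0.3, 0.3, 0)], p=0.8)
    # hue jitter of 10%, applied with probability 0.2
    hue = T.RandomApply([T.ColorJitter(0, 0, 0, 0.1)], p=0.2)
    # Gaussian blur kernel of ~20% of the image size, made odd because GaussianBlur requires odd sizes
    kernel = int(img_size * 0.2)
    kernel = kernel - 1 if kernel % 2 == 0 else kernel
    sigma = (0.1, 2.0)
    blur = T.RandomApply([T.GaussianBlur((kernel, kernel), sigma)], p=0.1)
    # crop 40%-100% of the image area and resize it back to img_size x img_size
    crop = T.RandomResizedCrop(size=img_size, scale=(0.4, 1.0))
    hflip = T.RandomHorizontalFlip(p=0.5)
    grey = T.RandomGrayscale(p=0.1)
    norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # PIL-level augmentations first, then conversion to a tensor, ImageNet normalization last
    return T.Compose([
                color_jitter,
                hue,
                blur,
                crop,
                hflip,
                grey,
                T.ToTensor(),
                norm])
### END CODE HERE ###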


def test_transform(transform):
    dl_transf = load_cifar100_data(transform=transform, train=False, batch_size=16)
    dl = load_cifar100_data(transform=None, train=False, batch_size=16)
    plt.figure(figsize=(8,8))
    img_id = 15
    plt.subplot(1,2,1)
    imshow(next(iter(dl))[0][img_id,...])
    plt.title('Original')
    plt.subplot(1,2,2)
    imshow(next(iter(dl_transf))[0][img_id,...])
    plt.title('Transformed')

transform = Augment(img_size=32)
test_transform(transform)

# Part IV. Fine-tune the teacher model on CIFAR100

- Fine-tune the teacher model for 10-20 epochs with the proposed augmentations.
- Keep the model with the best performance on CIFAR100 validation split.

`Hint`: You should be able to train the model with a batch size of 32-48 on Google Colab (~11.2 GB of VRAM with a batch size of 48; ~7 minutes per training epoch with Adam).
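Keeping the best model boils down to evaluating after every epoch and remembering the weights with the highest validation accuracy. Below is a minimal, self-contained sketch, assuming a standard `(images, labels)` loader and cross-entropy; `evaluate` and `finetune_keep_best` are hypothetical names (the notebook's own `finetune`/`validate` helpers come from `utils` and may work differently). It keeps the best weights in memory with `copy.deepcopy`; writing them to disk with `torch.save` instead works the same way.

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def evaluate(model, loader, device):
    """Return (accuracy in %, mean cross-entropy) over the whole loader."""
    model.eval()
    ce = nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss_sum += ce(logits, y).item()
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.numel()
    return 100.0 * correct / total, loss_sum / total

def finetune_keep_best(model, optimizer, num_epochs, train_loader, val_loader, device):
    """Train for num_epochs, then restore the weights with the best validation accuracy."""
    ce = nn.CrossEntropyLoss()
    model.to(device)
    best_acc, best_state = -1.0, None
    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = ce(model(x), y)
            loss.backward()
            optimizer.step()
        val_acc, val_loss = evaluate(model, val_loader, device)
        if val_acc > best_acc:                       # new best epoch: snapshot the weights
            best_acc = val_acc
            best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)                # restore the best snapshot
    return best_acc
```

Selecting on validation accuracy is one choice; the given `train_distill` function further down stores both the lowest-validation-loss and the highest-validation-accuracy checkpoints.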

In [None]:
### START CODE HERE ### (≈ 6 lines of code)

# train 
#...
# validate
val_acc, val_loss_epoch = validate(teacher, val_loader, device)
print(f"Validation accuracy: {val_acc:.2f}%, Validation loss: {val_loss_epoch:.4f}")
### END CODE HERE ###

### Expected result
```
Model teacher_cifar100_vitb.pth is loaded from epoch 14 , loss 0.8871352076530457
Validation accuracy: 77.89%, Validation loss: 0.8871
```



# Part V. Train ViT-tiny (student) from scratch (random init) with cross-entropy (without knowledge distillation)

- Use the implemented augmentation pipeline for training the model from scratch 
- 40 epochs should be sufficient to get the expected performance

`Hint`:

```python
dict_log = finetune(student, optimizer, 40, train_dl, val_dl, device)
```

In [None]:
### START CODE HERE ### (≈ 14 lines of code)

# train code
dict_log = finetune(student, optimizer, 40, train_dl, val_dl, device)
#val code....
val_acc, val_loss = validate(student, val_dl, device)

### END CODE HERE ###
print("ViT Tiny trained from random init on CIFAR100 with augmentations: Val acc", val_acc, "Val loss" , val_loss )

### Expected result

```
Model vit-tiny.pth is loaded from epoch 37 , loss 1.7089426517486572
ViT Tiny trained from random init on CIFAR100 with augmentations: Val acc 55.76 Val loss 1.7089427
```

# Part VI. Implement knowledge distillation loss

$$ \mathcal{L} = (1-\alpha)\,\mathrm{CE}(p_s, y) + \alpha\,\mathrm{KL}(p_s, p_t) $$

where $y$ is the ground-truth label and $p_s, p_t$ are the student and teacher predicted class probabilities.

$\mathrm{KL}$ is the Kullback-Leibler divergence and $\mathrm{CE}$ is the cross-entropy.
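A minimal sketch of the loss above, written as a plain function (hypothetical name `distill_loss`) rather than the `CriterioDistill` class. Two choices here are assumptions: the KL term is computed as $\mathrm{KL}(p_t \,\|\, p_s)$, the usual distillation direction in which the teacher distribution acts as the target, and it is averaged over the batch. The notation $\mathrm{KL}(p_s, p_t)$ does not fix the direction, so this sketch is not guaranteed to reproduce the expected `tensor(3.0764)` below.

```python
import torch
import torch.nn.functional as F

def kl_div(log_p, log_q):
    """KL(p || q) from log-probabilities: sum over classes, mean over the batch."""
    return (log_p.exp() * (log_p - log_q)).sum(dim=1).mean()

def distill_loss(logits_teacher, logits_student, labels, alpha=0.5):
    """(1 - alpha) * CE(p_s, y) + alpha * KL(p_t || p_s); the KL direction is an assumption."""
    log_p_s = F.log_softmax(logits_student, dim=1)
    log_p_t = F.log_softmax(logits_teacher.detach(), dim=1)  # teacher is a fixed target
    ce = F.cross_entropy(logits_student, labels)             # hard-label term
    kl = kl_div(log_p_t, log_p_s)                            # soft-label term
    return (1 - alpha) * ce + alpha * kl
```

With `alpha=0` this is plain cross-entropy; with `alpha=1` the student learns from the teacher's soft labels only. Working in log-space avoids `log(0)` issues, and the hand-written `kl_div` matches `F.kl_div(log_p_s, log_p_t, reduction="batchmean", log_target=True)`.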

In [None]:
import torch.nn.functional as F
class CriterioDistill():
    ### START CODE HERE ### (≈ 11 lines of code)
    def __init__(self, alpha=0.5):
        ...

    def kl_div(self, p, q):
        ...
    
    def __call__(self, logits_teacher, logits_student, labels):
        ...
    ### END CODE HERE ###

def test_criterio_distill():
    torch.manual_seed(42)
    criterion = CriterioDistill()
    logits_teacher = torch.randn(10, 100)
    logits_student = torch.randn(10, 100)
    labels = torch.randint(0, 100, (10,))
    loss = criterion(logits_teacher, logits_student, labels)
    print(loss)
    assert loss.shape == torch.Size([]), "Wrong shape for loss"

test_criterio_distill()

### Expected results


```
tensor(3.0764)
```

# Part VII. Implement MultiScaleData loading

The new dataset class should return the image in its original dimension (32x32), the image rescaled to 224x224, and the label.

Why? Because the student ViT-Tiny takes 32x32 images as input, while the teacher ViT-Base expects 224x224 images.
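One way to design the wrapper is to apply the augmentation **once** to the 32x32 image and derive the 224x224 view from the augmented result, so that student and teacher see exactly the same crop, flip and colour jitter. The sketch below assumes the wrapped dataset (after `transform`) yields tensors, and upsamples with `F.interpolate`; the class name `TwoScaleView` and the bilinear mode are assumptions, and the exercise's `MultiScaleData` could equally use a torchvision `Resize`. Because bilinear interpolation weights sum to one, resizing after the per-channel normalization gives the same result as resizing before it.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset

class TwoScaleView(Dataset):
    """Hypothetical wrapper returning (img, img_rescaled, label) from (image, label) pairs."""

    def __init__(self, dataset, scale=224, transform=None):
        self.dataset = dataset
        self.scale = scale
        self.transform = transform

    def __getitem__(self, index):
        img, label = self.dataset[index]
        if self.transform is not None:
            img = self.transform(img)          # augment ONCE, shared by both views
        # F.interpolate expects a batch dimension: (1, C, H, W) -> (1, C, scale, scale)
        img_rescaled = F.interpolate(img.unsqueeze(0), size=(self.scale, self.scale),
                                     mode="bilinear", align_corners=False).squeeze(0)
        return img, img_rescaled, label

    def __len__(self):
        return len(self.dataset)

# toy usage with random "images" instead of CIFAR100
view = TwoScaleView(TensorDataset(torch.rand(4, 3, 32, 32), torch.arange(4)))
img, img_rescaled, label = view[0]
print(img.shape, img_rescaled.shape)  # torch.Size([3, 32, 32]) torch.Size([3, 224, 224])
```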

In [None]:
class MultiScaleData(torch.utils.data.Dataset):
    def __init__(self, dataset, scale=224, transform=None):
        self.dataset = dataset
        ### START CODE HERE ### (≈ 3 lines of code) 
        
        ### END CODE HERE ###
    def __getitem__(self, index):
        ### START CODE HERE ### (≈ 4 lines of code)
        
        ### END CODE HERE ###
        return img, img_rescaled, label
    def __len__(self):
        return len(self.dataset)

def load_multiscale_data(transform=None, batch_size=64, num_workers=2, shuffle=False):
    pin = True
    if transform is None:
        transform = get_transform_plain()
    dataset = torchvision.datasets.CIFAR100(root='../data',  transform=None, train=True, download=True)
    dataset_rescaled = MultiScaleData(dataset, scale=224, transform=transform)
    loader = torch.utils.data.DataLoader(dataset_rescaled, 
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=num_workers,
                                         pin_memory=pin)
    return loader

def test_MultiScaleData():
    transform = get_transform_plain()
    loader = load_multiscale_data(transform=transform, batch_size=4, num_workers=2, shuffle=False)
    imgs, imgs_rescaled, labels = next(iter(loader))
    assert imgs.shape == (4, 3, 32, 32), "Wrong shape for img"
    assert imgs_rescaled.shape == (4, 3, 224, 224), "Wrong shape for img_rescaled"
    print("Success!")

test_MultiScaleData()

# Part VIII. Adjust training code to support knowledge distillation from the teacher model


In [None]:
def train_one_epoch_distill(model, optimizer, train_loader, device, teacher, alpha=0.5):
    model.train()
    teacher.eval()
    criterion_distill = CriterioDistill(alpha)
    loss_step = []
    correct, total = 0, 0
    for data in train_loader:
        ### START CODE HERE ### (≈ 14 lines of code)
        
        ### END CODE HERE ###
    loss_curr_epoch = np.mean(loss_step)
    train_acc = (100 * correct / total).cpu()
    return loss_curr_epoch, train_acc

# GIVEN.
def train_distill(model, teacher, optimizer, num_epochs, train_loader, val_loader, device, prefix='model', alpha=0.5):
    best_val_loss = 1e8
    best_val_acc = 0
    model, teacher = model.to(device), teacher.to(device)
    dict_log = {"train_acc_epoch":[], "val_acc_epoch":[], "loss_epoch":[], "val_loss":[]}
    pbar = tqdm(range(num_epochs))
    for epoch in pbar:
        loss_curr_epoch, train_acc = train_one_epoch_distill(model, optimizer, train_loader, device, teacher, alpha)
        val_acc, val_loss = validate(model, val_loader, device)

        # Print epoch results to screen 
        msg = (f'Ep {epoch}/{num_epochs}: Accuracy : Train:{train_acc:.2f} \t Val:{val_acc:.2f} || Loss: Train {loss_curr_epoch:.3f} \t Val {val_loss:.3f}')
        pbar.set_description(msg)
        # Track stats
        dict_log["train_acc_epoch"].append(train_acc)
        dict_log["val_acc_epoch"].append(val_acc)
        dict_log["loss_epoch"].append(loss_curr_epoch)
        dict_log["val_loss"].append(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': val_loss,
                  }, f'{prefix}_best_model_min_val_loss.pth')
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': val_loss,
                  }, f'{prefix}_best_model_max_val_acc.pth')
    return dict_log
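The core of `train_one_epoch_distill` is a single step in which the frozen teacher sees the 224x224 view and the student sees the 32x32 view. Below is a minimal sketch of that step under the hypothetical name `distill_step`. It inlines the loss instead of calling `CriterioDistill` (the $\mathrm{KL}(p_t \,\|\, p_s)$ direction is an assumption) and unpacks a batch in the `img, img_rescaled, label` order returned by `MultiScaleData`. The number of correct predictions is returned as a tensor because the given skeleton calls `.cpu()` on `100 * correct / total`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_step(student, teacher, optimizer, batch, device, alpha=0.5):
    """One hypothetical distillation step; returns (loss value, #correct tensor, batch size)."""
    img, img_rescaled, labels = (t.to(device) for t in batch)
    with torch.no_grad():                     # frozen teacher: no graph, no gradients
        logits_t = teacher(img_rescaled)      # teacher sees the 224x224 view
    logits_s = student(img)                   # student sees the 32x32 view
    ce = F.cross_entropy(logits_s, labels)
    kl = F.kl_div(F.log_softmax(logits_s, dim=1), F.log_softmax(logits_t, dim=1),
                  reduction="batchmean", log_target=True)  # KL(p_t || p_s)
    loss = (1 - alpha) * ce + alpha * kl
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    correct = (logits_s.argmax(dim=1) == labels).sum()
    return loss.item(), correct, labels.size(0)
```

Inside the epoch loop one would call it as `loss, c, n = distill_step(model, teacher, optimizer, data, device, alpha)` and then update `loss_step`, `correct` and `total`; `model.train()` and `teacher.eval()` are already set once per epoch by the skeleton.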

# Part IX. Train student with knowledge distillation


- Plug everything together and show the training curves, even if you train for a small number of epochs.
- Print the validation accuracy of the best model.
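For the "Load teacher" and "Load student" steps: the checkpoints written by `train_distill` are dictionaries with the keys `epoch`, `model_state_dict`, `optimizer_state_dict` and `loss`. Below is a minimal loader for that format (hypothetical helper `load_checkpoint`), assuming your fine-tuned teacher was saved the same way; if it was saved as a bare `state_dict`, pass that directly to `load_state_dict` instead.

```python
import torch

def load_checkpoint(model, path, device="cpu"):
    """Hypothetical helper: restore a checkpoint saved in train_distill's dict format."""
    # weights_only=False because the stored "loss" may be a NumPy scalar;
    # only do this for checkpoint files you created yourself.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    print(f"Model {path} is loaded from epoch {ckpt['epoch']} , loss {ckpt['loss']}")
    return model.to(device)
```

Usage would look like `teacher = load_checkpoint(teacher_vitb, "teacher_cifar100_vitb.pth", device)`, with the file name depending on how you saved the fine-tuned teacher.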

In [None]:
# Comments will be kept for help!
### START CODE HERE ### (≈ 16 lines of code)  
# Load teacher

# Load student

# Define optimizer

# Define transforms and loaders



# Launch training!
dict_log = train_distill(student, teacher, optimizer, num_epochs, train_dl, val_dl, device, prefix='distill_vit_t_kl_ts', alpha=alpha)
plt.figure(figsize=(10,10))
plot_stats(dict_log, modelname="distilled_ViT-tiny_", title="distilled_ViT-tiny_")
plt.savefig("ViT-tiny_CIFAR100_DISTILL___.png")

val_acc, val_loss = validate(student, val_dl, device)
print(f"Validation accuracy: {val_acc:.2f} \t Validation loss: {val_loss:.3f}")


### END CODE HERE ###

### Expected results
Training for 40 epochs (~120 min) should already give a validation accuracy of ~60%, noticeably better than training ViT-Tiny from scratch on CIFAR100 without distillation.

Here are our results for reference, in case you are curious about the outcome of training for more epochs. We don't expect you to reproduce them with the Google Colab resources, but this is what we got (with alpha=0.7, which may not be the optimal choice):

Our best experiment for reference:
```
Validation accuracy: 64.88 	 Validation loss: 1.410
```
For comparison, ViT-Tiny trained from random init on CIFAR100 with augmentations reached `Val acc 55.76 || Val loss 1.71`.

# Conclusion and Bonus reads

That's the end of this exercise. If you reached this point, **congratulations**!


Can you find a better alpha? How can alpha be interpreted? 

What additional losses did the paper "Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation" introduce?

