In [None]:
#| include: false
from fasterai.distill.all import *
from fastai.vision.all import *

We'll illustrate how to use Knowledge Distillation to transfer the knowledge of a ResNet34 (the teacher) to a ResNet18 (the student).

Let's grab some data.

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# In the Pets dataset, cat images have a capitalized filename, so cats are labeled True
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

The first step is to train the teacher model. We start from a pretrained model so that it reaches good results on our dataset.

In [None]:
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)

### Without KD

We'll now train a ResNet18 from scratch, without any help from the teacher model, to serve as a baseline.

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
#student = vision_learner(dls, resnet18, metrics=accuracy)
student.fit_one_cycle(10, 1e-3)

### With KD

And now we train the same model, but with the help of the teacher. The chosen loss is a combination of the regular classification loss (Cross-Entropy) and a loss pushing the student to learn from the teacher's predictions.

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
#student = vision_learner(dls, resnet18, metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget)
student.fit_one_cycle(10, 1e-3, cbs=kd)

When helped by the teacher, the student model performs better than the baseline trained without distillation!
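To make the combined loss concrete, here is a minimal sketch of a soft-target distillation loss for a single example, written in plain Python so the arithmetic is visible. It is not fasterai's implementation: the temperature `T=5.0`, the helper names, and the way the two terms are mixed through `weight` are assumptions chosen for illustration.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_target_loss(student_logits, teacher_logits, T=5.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    kl = sum(pt * (math.log(pt) - math.log(ps))
             for pt, ps in zip(p_teacher, p_student))
    return kl * T * T

def cross_entropy(logits, target):
    """Regular classification loss against the hard label."""
    return -math.log(softmax(logits)[target])

def kd_loss(student_logits, teacher_logits, target, weight=0.5, T=5.0):
    """Weighted mix of the hard-label loss and the soft-target loss.

    The linear interpolation through `weight` is an assumed mixing scheme.
    """
    hard = cross_entropy(student_logits, target)
    soft = soft_target_loss(student_logits, teacher_logits, T)
    return (1 - weight) * hard + weight * soft

# Illustrative logits for a two-class problem (made-up values).
student_logits = [1.0, 0.2]
teacher_logits = [2.5, -1.0]
print(kd_loss(student_logits, teacher_logits, target=0))
```

A higher temperature flattens both distributions, so the student also learns from the relative probabilities the teacher assigns to the wrong classes. In practice the same computation would run on batched tensors rather than Python lists.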

There exist more complicated KD losses, such as the one from ``Paying More Attention to Attention``, where the student tries to replicate the teacher's attention maps at intermediate layers.
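As a rough illustration of what "replicating attention maps" means, the sketch below computes an activation-based attention map and compares the student's to the teacher's. It follows the common formulation of attention transfer (channel-wise mean of squared activations, L2-normalized) and is not necessarily how fasterai's `Attention` loss is implemented. The function names and tensor shapes are hypothetical.

```python
import torch
import torch.nn.functional as F

def attention_map(feature_map):
    """Collapse a (batch, channels, H, W) activation into a normalized spatial map."""
    # Average the squared activations over channels: shape (batch, H, W)
    am = feature_map.pow(2).mean(dim=1)
    # Flatten the spatial dimensions and L2-normalize each example
    return F.normalize(am.flatten(start_dim=1), dim=1)

def attention_transfer_loss(student_fm, teacher_fm):
    """Mean squared difference between student and teacher attention maps."""
    return (attention_map(student_fm) - attention_map(teacher_fm)).pow(2).mean()

# Random tensors standing in for intermediate activations (hypothetical shapes).
student_fm = torch.randn(4, 64, 16, 16)
teacher_fm = torch.randn(4, 128, 16, 16)
loss = attention_transfer_loss(student_fm, teacher_fm)
print(loss.item())
```

Because the channels are averaged away, student and teacher may have different channel counts at a given layer, as long as the spatial sizes match.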

Using such a loss requires specifying the layers whose attention maps we want to replicate. Layers are specified by their `string` names, which can be obtained with the `get_model_layers` function.

For example, we apply the loss after each group of residual blocks (`layer1` to `layer4` in the student, `0.4` to `0.7` in the teacher):

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(
    teacher.model, Attention,
    ['layer1', 'layer2', 'layer3', 'layer4'],  # student layers
    ['0.4', '0.5', '0.6', '0.7'],              # matching teacher layers
    weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)
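The student and teacher layer names differ because of how the two models are built: the student is a plain torchvision model whose blocks are named attributes, while `vision_learner` wraps the backbone in a `Sequential`, so the same blocks are addressed by position (in a ResNet body, `layer1` is the fifth child, hence `0.4`). The sketch below reproduces this with a toy model and PyTorch's `named_modules`. `TinyNet` and the wrapping code are hypothetical stand-ins, not fastai or fasterai code.

```python
import torch.nn as nn

class TinyNet(nn.Module):
    """Toy stand-in for a torchvision ResNet: blocks are named attributes."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.layer1 = nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU())
        self.fc = nn.Linear(8, 2)

# Plain model: modules are addressed by attribute name.
plain = TinyNet()
plain_names = [name for name, _ in plain.named_modules() if name]

# Wrapped model (assumed to mimic fastai's body/head split):
# the body is a Sequential of the original children minus the final classifier.
body = nn.Sequential(*list(TinyNet().children())[:-1])
wrapped = nn.Sequential(body, nn.Linear(8, 2))
wrapped_names = [name for name, _ in wrapped.named_modules() if name]

print(plain_names)    # contains 'layer1'
print(wrapped_names)  # the same block now appears as '0.1'
```

Listing the names this way for both models before building the callback is a simple check that each student layer is paired with the intended teacher layer.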