In [None]:
#| default_exp distill.distillation_callback

In [None]:
#| include: false
import warnings
warnings.filterwarnings('ignore')
from nbdev.showdoc import *
import torch
import torch.nn.functional as F

%config InlineBackend.figure_format = 'retina'

Knowledge Distillation, sometimes called teacher-student training, is a compression method in which a small model (the student) is trained to mimic the behaviour of a larger model (the teacher).

The main goal is to reveal what is called the **Dark Knowledge** hidden in the teacher model.

Let's follow the [example](https://www.ttic.edu/dl/dark14.pdf) given by Geoffrey Hinton et al.

The main problem with classification is that the output activation function (softmax), by design, pushes a single value very high and squashes the others.

$$
p_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)}
$$

Here, $p_i$ is the probability of class $i$, computed from the logits $z$.
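The formula translates directly into code. Below is a plain-Python sketch with a hypothetical `softmax` helper (not part of fasterai). Subtracting the largest logit before exponentiating leaves every $p_i$ unchanged, because the factor $e^{-\max_j z_j}$ cancels between numerator and denominator, but it prevents `math.exp` from overflowing on large logits.

```python
import math

def softmax(logits):
    """p_i = exp(z_i) / sum_j exp(z_j), computed in a numerically stable way."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # shift by the max: same result, no overflow
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs], sum(probs))
print(softmax([1000.0, 1001.0]))  # a naive math.exp(1000.0) would raise OverflowError
```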

Here is an example to illustrate this phenomenon:

Let's say that we have trained a model to discriminate between the following 5 classes: [cow, dog, plane, cat, car].

And here is the output of the final layer (the logits) when the model is fed a new input image: 

In [None]:
logits = torch.tensor([1.3, 3.1, 0.2, 1.9, -0.3])

Judging by these logits, the model seems confident that the input image is a dog and quite confident that it is neither a plane nor a car, with moderately high scores for cow and cat.

So the model has not only learned to recognize a dog in the image, but also that a dog is very different from a car or a plane and shares similarities with cats and cows. This information is what is called **dark knowledge**!

Passing these logits through a softmax gives:

In [None]:
predictions = F.softmax(logits, dim=-1); predictions

tensor([0.1063, 0.6431, 0.0354, 0.1937, 0.0215])

This accentuates the differences we observed earlier and discards some of the dark knowledge. One way to keep this knowledge is to "soften" the softmax outputs by adding a **temperature** parameter: the higher the temperature, the softer the predictions.
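Written out, the softened probability of class $i$ at temperature $\tau$ divides each logit by $\tau$ before applying the softmax (the cell below uses $\tau = 3$):

$$
p_{i}^{\tau}=\frac{\exp \left(z_{i} / \tau\right)}{\sum_{j} \exp \left(z_{j} / \tau\right)}
$$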

In [None]:
soft_predictions = F.softmax(logits/3, dim=-1); soft_predictions

tensor([0.1879, 0.3423, 0.1302, 0.2294, 0.1102])

:::{.callout-note}

If the temperature is equal to 1, we recover the regular softmax.

:::
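To see how the temperature reshapes the distribution, the sketch below (plain Python, hypothetical `softmax_t` helper) applies the temperature softmax to the same logits at a few temperatures. As the temperature grows, the probability of the top class shrinks and the distribution tends toward uniform, while the ranking of the classes is preserved, which is exactly the information the student should pick up.

```python
import math

def softmax_t(logits, temperature=1.0):
    """Softmax of logits / temperature, stabilized by subtracting the max."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.3, 3.1, 0.2, 1.9, -0.3]  # cow, dog, plane, cat, car

for t in (1, 3, 10):
    probs = softmax_t(logits, t)
    print(f"T={t:>2}:", [round(p, 4) for p in probs])
```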

When applying Knowledge Distillation, we want to keep the **Dark Knowledge** that the teacher model acquired during its training, without relying on it entirely. We therefore combine two losses:

- The Teacher loss between the softened predictions of the teacher and the softened predictions of the student
- The Classification loss, which is the regular loss between hard labels and hard predictions

The two losses are combined through an additional weighting parameter α:

$$
L_{K D}=\alpha  * \text { CrossEntropy }\left(p_{S}^{\tau}, p_{T}^{\tau}\right)+(1-\alpha) * \text { CrossEntropy }\left(p_{S}, y_{\text {true }}\right)
$$

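where $p_{S}^{\tau}$ and $p_{T}^{\tau}$ are the softened predictions of the student and the teacher at temperature $\tau$, and $p_{S}$ are the regular (unsoftened) student predictions.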
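The equation can be evaluated by hand for a single sample. The sketch below is plain Python with hypothetical helpers (`softmax`, `cross_entropy`, `kd_loss`), not fasterai code; it reads $\text{CrossEntropy}(p, q)$ as $-\sum_i q_i \log p_i$, with the prediction first and the target second. The teacher logits reuse the example above, and the student logits are made up for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(pred_probs, target_probs):
    """CrossEntropy(pred, target) = -sum_i target_i * log(pred_i)."""
    return -sum(t * math.log(p) for p, t in zip(pred_probs, target_probs) if t > 0)

def kd_loss(student_logits, teacher_logits, true_class, alpha=0.5, tau=3.0):
    """alpha * CE(p_S^tau, p_T^tau) + (1 - alpha) * CE(p_S, y_true)."""
    p_s_tau = softmax(student_logits, tau)
    p_t_tau = softmax(teacher_logits, tau)
    p_s = softmax(student_logits)
    y_true = [1.0 if i == true_class else 0.0 for i in range(len(student_logits))]
    soft = cross_entropy(p_s_tau, p_t_tau)  # match the softened teacher
    hard = cross_entropy(p_s, y_true)       # match the hard label
    return alpha * soft + (1 - alpha) * hard

teacher_logits = [1.3, 3.1, 0.2, 1.9, -0.3]  # cow, dog, plane, cat, car
student_logits = [0.5, 2.0, 0.1, 0.8, 0.0]   # hypothetical student output
loss = kd_loss(student_logits, teacher_logits, true_class=1, alpha=0.5, tau=3.0)
print(f"L_KD = {loss:.4f}")
```

Setting `alpha=0` recovers plain classification training, while `alpha=1` trains the student only on the teacher's softened outputs.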

:::{.callout-note}

In practice, the distillation loss is implemented [a bit differently](http://cs230.stanford.edu/files_winter_2018/projects/6940224.pdf).

:::
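One common practical variant, sketched here as an assumption and not as a description of fasterai's `SoftTarget`, replaces the soft cross-entropy with the KL divergence between the softened teacher and student distributions and multiplies it by $\tau^2$. Hinton et al. suggest this scaling because the soft-target gradients shrink as $1/\tau^2$. The KL divergence differs from the soft cross-entropy only by the teacher's entropy, which does not depend on the student, so both give the student the same gradients. The helper names below are hypothetical.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (stabilized by subtracting the max)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(target, pred):
    """KL(target || pred) = sum_i target_i * log(target_i / pred_i)."""
    return sum(t * math.log(t / p) for t, p in zip(target, pred) if t > 0)

def soft_target_loss(student_logits, teacher_logits, tau=3.0):
    """Hypothetical helper: tau^2-scaled KL between softened teacher and student."""
    p_t = softmax(teacher_logits, tau)
    p_s = softmax(student_logits, tau)
    return kl_divergence(p_t, p_s) * tau ** 2

teacher_logits = [1.3, 3.1, 0.2, 1.9, -0.3]
student_logits = [0.5, 2.0, 0.1, 0.8, 0.0]  # hypothetical student output
print(soft_target_loss(student_logits, teacher_logits, tau=3.0))
```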

![](../imgs/distill.png "Knowledge Distillation")

In [None]:
#| export
from __future__ import annotations
from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from functools import reduce
from typing import Callable, Any
from fasterai.core.schedule import Schedule

This can be implemented in fastai using its Callback system:

In [None]:
#| export
class KnowledgeDistillationCallback(Callback):
    def __init__(self, 
                 teacher: nn.Module,                                           # Teacher model
                 loss: Callable,                                               # Distillation loss function
                 activations_student: str | list[str] | None = None,           # Student activation layers to match
                 activations_teacher: str | list[str] | None = None,           # Teacher activation layers to match
                 weight: float = 0.5,                                          # Weight for distillation loss
                 schedule: Schedule | None = None                              # Optional schedule for weight progression
    ):
        "Implement knowledge distillation from a teacher model to the student being trained"
        self.stored_activation_student, self.stored_activation_teacher  = {}, {}
        store_attr()
        if self.activations_student is not None:
            self.activations_student, self.activations_teacher = listify(activations_student), listify(activations_teacher)
        self.current_weight = weight
        
    def before_fit(self) -> None:
        "Setup hooks and prepare teacher before training"
        if self.activations_student and self.activations_teacher: self.register_hooks()
        self.teacher.eval()

    def before_batch(self) -> None:
        "Update distillation weight if scheduled"
        if self.schedule is not None:
            progress = self.schedule.progress(self.pct_train)
            self.current_weight = self.weight * progress

    def after_batch(self) -> None:
        "Clear activations after each batch to prevent memory buildup"
        self.stored_activation_student.clear()
        self.stored_activation_teacher.clear()

    def after_loss(self) -> None:
        "Apply distillation loss using teacher predictions"
        with torch.no_grad():  # the teacher is frozen: no graph or gradients needed for it
            teacher_pred = self.teacher(self.x)
        new_loss = self.loss(pred=self.pred, teacher_pred=teacher_pred, fm_s=self.stored_activation_student, fm_t=self.stored_activation_teacher)
        # lerp(a, b, w) = a + w * (b - a): blend the task loss with the distillation loss
        self.learn.loss_grad = torch.lerp(self.learn.loss_grad, new_loss, self.current_weight)
        self.learn.loss = self.learn.loss_grad.clone()
    
    def register_hooks(self) -> None:
        "Set up forward hooks to capture activations"
        self.handles_st, self.handles_t = {}, {}
        for name_st, name_t in zip(self.activations_student, self.activations_teacher):
            self.handles_st[name_st] = get_module_by_name(self.learn, name_st).register_forward_hook(self.get_activation(self.stored_activation_student, name_st))
            self.handles_t[name_t] = get_module_by_name(self.teacher, name_t).register_forward_hook(self.get_activation(self.stored_activation_teacher, name_t))
        
    def get_activation(self, 
                       activation: dict[str, torch.Tensor],  # Dictionary to store activations
                       name: str                             # Name of the layer
    ) -> Callable:
        "Create a hook function to store activations"
        def hook(model, input, output):
            activation[name] = output
        return hook
    
    def find_hook(self, 
                  m: nn.Module
    ) -> list[tuple[str, int, str]]:
        "Find all hooks registered in a module"
        save = []
        module_name = type(m).__name__
        for k, v in m._forward_hooks.items():
            save.append((module_name, k, v.__name__))
        return save
    
    def remove_hooks(self, 
                     handles: dict[str, Any]
    ) -> None:
        "Remove all registered hooks"
        for handle in handles.values():
            handle.remove()
    
    def after_fit(self) -> None:
        "Clean up hooks after training"
        if self.activations_student and self.activations_teacher:
            self.remove_hooks(self.handles_t)
            self.remove_hooks(self.handles_st)

In [None]:
show_doc(KnowledgeDistillationCallback)



---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/distillation_callback.py#L19){target="_blank" style="float:right; font-size:smaller"}

### KnowledgeDistillationCallback

```python

def KnowledgeDistillationCallback(
    teacher:nn.Module, # Teacher model
    loss:Callable, # Distillation loss function
    activations_student:str | list[str] | None=None, # Student activation layers to match
    activations_teacher:str | list[str] | None=None, # Teacher activation layers to match
    weight:float=0.5, # Weight for distillation loss
    schedule:Schedule | None=None, # Optional schedule for weight progression
):


```

*Basic class handling tweaks of the training loop by changing a `Learner` in various events*
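In `after_loss`, the callback blends the student's regular loss with the distillation loss through `torch.lerp(loss, new_loss, w)`, which computes `loss + w * (new_loss - loss)`, i.e. `(1 - w) * loss + w * new_loss`. The `weight` argument therefore plays the role of α in the equation above. When a schedule is given, `before_batch` sets the effective weight to `weight * schedule.progress(pct_train)`. The arithmetic is sketched below in plain Python, with hypothetical `lerp` and `mixed_loss` helpers standing in for the tensor operations.

```python
def lerp(start, end, weight):
    """Scalar version of torch.lerp: start + weight * (end - start)."""
    return start + weight * (end - start)

def mixed_loss(task_loss, distill_loss, weight, progress=1.0):
    """Combine losses the way KnowledgeDistillationCallback does.

    `progress` stands in for schedule.progress(pct_train); it is 1.0 when no schedule is used.
    """
    current_weight = weight * progress
    return lerp(task_loss, distill_loss, current_weight)

# weight=0.5 and no schedule: a plain average of the two losses
print(mixed_loss(task_loss=2.0, distill_loss=1.0, weight=0.5))  # 1.5
```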

In [None]:
#| export
def get_model_layers(
    model: nn.Module,             # Model to inspect
    getLayerRepr: bool = False    # Whether to return layer representations
) -> list[str] | dict[str, str]:
    "Get all layer names in a model, optionally with their representations"
    layers = OrderedDict() if getLayerRepr else []
    
    def get_layers(net, prefix=[]):
        if hasattr(net, "_modules"):
            for name, layer in net._modules.items():
                if layer is None:
                    continue
                if getLayerRepr:
                    layers[".".join(prefix+[name])] = layer.__repr__()
                else:
                    layers.append(".".join(prefix + [name]))
                get_layers(layer, prefix=prefix+[name])

    get_layers(model)
    return layers



def get_module_by_name(
    module: torch.Tensor | nn.Module,  # Module to search in
    access_string: str                 # Dot-separated path to the submodule
) -> nn.Module | None:
    "Access a nested submodule by its name path"
    try:
        names = access_string.split(sep='.')
        return reduce(getattr, names, module)
    except AttributeError:
        return None

In [None]:
show_doc(get_model_layers)
show_doc(get_module_by_name)

---

[source](https://github.com/FasterAI-Labs/fasterai/tree/master/blob/master/fasterai/distill/distillation_callback.py#L121){target="_blank" style="float:right; font-size:smaller"}

### get_module_by_name

```python

def get_module_by_name(
    module:torch.Tensor | nn.Module, # Module to search in
    access_string:str, # Dot-separated path to the submodule
)->nn.Module | None:


```

*Access a nested submodule by its name path*
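`get_module_by_name` folds `getattr` over the components of a dot-separated path, so `'layer1.conv1'` resolves to `getattr(getattr(model, 'layer1'), 'conv1')`, and a missing component yields `None`. The names returned by `get_model_layers` are dot-joined in the same way, which is what `activations_student` and `activations_teacher` expect. The same mechanism is shown below on plain objects; `get_by_path` and the `SimpleNamespace` stand-ins are hypothetical and only mimic nested modules.

```python
from functools import reduce
from types import SimpleNamespace

def get_by_path(root, access_string):
    """Follow a dot-separated attribute path, returning None if a step is missing."""
    try:
        return reduce(getattr, access_string.split("."), root)
    except AttributeError:
        return None

# Hypothetical stand-in for a model with a nested block
conv = SimpleNamespace(kind="conv")
model = SimpleNamespace(layer1=SimpleNamespace(conv1=conv))

print(get_by_path(model, "layer1.conv1") is conv)  # True
print(get_by_path(model, "layer1.missing"))        # None
```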

The choice of loss function depends on the use case. For classification, we usually use the one presented above, named `SoftTarget` in fasterai. For regression, we may instead want to regress directly on the logits.
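The callback calls its loss as `loss(pred=..., teacher_pred=..., fm_s=..., fm_t=...)`, so a custom loss has to accept those keyword arguments even if it ignores the stored feature maps. Below is a minimal sketch of a logit-regression loss: a hypothetical `logit_mse` (not a fasterai function) computing the mean squared error between student and teacher logits. It operates on plain Python lists for illustration; with real tensors the body would simply be `F.mse_loss(pred, teacher_pred)`.

```python
def logit_mse(pred, teacher_pred, fm_s=None, fm_t=None):
    """Mean squared error between student and teacher logits.

    fm_s / fm_t (stored feature maps) are accepted to match the keyword
    arguments passed by the callback, but are not used here.
    """
    diffs = [(s - t) ** 2 for s, t in zip(pred, teacher_pred)]
    return sum(diffs) / len(diffs)

print(logit_mse(pred=[1.0, 2.0, 3.0], teacher_pred=[1.0, 2.0, 5.0], fm_s={}, fm_t={}))  # ≈ 1.333
```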

In [None]:
#| hide
from fastcore.test import *

def _test_model():
    return nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(32, 10)
    )

model = _test_model()

# get_model_layers returns list of strings
layers = get_model_layers(model)
assert isinstance(layers, list)
assert all(isinstance(n, str) for n in layers)
assert len(layers) > 0

# get_model_layers with repr returns dict (OrderedDict)
layers_d = get_model_layers(model, getLayerRepr=True)
assert isinstance(layers_d, OrderedDict)
assert len(layers_d) > 0

# get_module_by_name returns correct module
m0 = get_module_by_name(model, '0')
test_is(m0, model[0])

# get_module_by_name returns None for nonexistent
test_eq(get_module_by_name(model, 'nonexistent'), None)

# KnowledgeDistillationCallback construction
teacher = _test_model()
from fasterai.distill.losses import SoftTarget
cb = KnowledgeDistillationCallback(
    teacher=teacher,
    loss=SoftTarget,
    weight=0.5
)
test_eq(cb.weight, 0.5)
test_eq(cb.current_weight, 0.5)
assert cb.activations_student is None

In [None]:
#| hide
#| slow
# Teacher-student training with KnowledgeDistillationCallback
from torch.utils.data import TensorDataset
from fastai.data.core import DataLoaders
from fasterai.distill.losses import SoftTarget

_teacher = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)
)
_student = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)
)

_X = torch.randn(64, 3, 8, 8)
_y = torch.randint(0, 10, (64,))
_dls = DataLoaders.from_dsets(
    TensorDataset(_X[:48], _y[:48]),
    TensorDataset(_X[48:], _y[48:]),
    bs=16, device='cpu'
)

_cb = KnowledgeDistillationCallback(teacher=_teacher, loss=SoftTarget)
_learn = Learner(_dls, _student, loss_func=nn.CrossEntropyLoss(), cbs=[_cb])
_learn.fit(2)  # verify it runs end-to-end without error

epoch,train_loss,valid_loss,time
0,1.191925,1.23174,00:00
1,1.190237,1.229488,00:00


---

## Usage with Schedule

You can now gradually increase the distillation weight during training:

```python
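from fasterai.core.schedule import cos
from fasterai.distill.losses import SoftTarget

# `learn` (the student Learner) and `teacher_model` are assumed to be defined already
# Gradually increase teacher influence from 0 to 0.8 using cosine schedule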
cb = KnowledgeDistillationCallback(
    teacher=teacher_model,
    loss=SoftTarget,
    weight=0.8,
    schedule=cos
)

learn.fit(10, cbs=[cb])
```
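With a schedule, `before_batch` sets `current_weight = weight * schedule.progress(pct_train)`. To picture the resulting ramp, the sketch below assumes a cosine progress curve `0.5 * (1 - cos(pi * pct_train))` running from 0 to 1. The exact curve behind fasterai's `cos` may differ, so `cosine_progress` is a hypothetical illustration of the mechanism only.

```python
import math

def cosine_progress(pct_train):
    """Assumed cosine ramp from 0 (start of training) to 1 (end of training)."""
    return 0.5 * (1 - math.cos(math.pi * pct_train))

weight = 0.8
for pct in (0.0, 0.25, 0.5, 0.75, 1.0):
    current_weight = weight * cosine_progress(pct)
    print(f"pct_train={pct:.2f} -> distillation weight {current_weight:.3f}")
```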

---

## See Also

- [Distillation Losses](losses.html) - Available loss functions (Attention, FitNet, PKT, etc.)
- [Distillation Tutorial](../tutorials/distill/distill_callback.html) - Step-by-step guide to knowledge distillation
- [Schedules](../core/schedules.html) - Control distillation weight progression
- [Pruner](../prune/pruner.html) - Combine with pruning for maximum compression