In [None]:
#| include: false
from fasterai.core.all import *
from fasterai.distill.all import *
from fastai.vision.all import *

## Overview

**Knowledge Distillation** transfers knowledge from a large, accurate "teacher" model to a smaller, faster "student" model. The student learns not just from ground truth labels, but also from the teacher's soft predictions, which capture the teacher's learned relationships between classes.

### Why Use Knowledge Distillation?

| Approach | Model Size | Training Data | Accuracy |
|----------|------------|---------------|----------|
| Train small model alone | Small | Labels only | Lower |
| **Distillation** | Small | Labels + Teacher knowledge | **Higher** |

### Key Benefits

- **Smaller deployment models** - Student can be much smaller than teacher
- **Better than training from scratch** - Teacher provides richer supervision
- **No additional labeled data needed** - Uses existing training set
- **Flexible loss functions** - Soft targets, attention transfer, feature matching

In this tutorial, we'll distill a ResNet-34 (teacher) into a ResNet-18 (student).

## 1. Setup and Data

In [None]:
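# Download the Oxford-IIIT Pet dataset and collect its image files
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat breeds have capitalized filenames in this dataset, so True means "cat"
def label_func(f): return f[0].isupper()

# Binary cat-vs-dog task, with every image resized to 64x64
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))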

## 2. Train the Teacher Model

First, we train the larger teacher model (ResNet-34) to achieve good accuracy on our dataset:

In [None]:
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(10, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,0.663302,0.38265,0.881597,00:02
1,0.444977,1.731543,0.723951,00:02
2,0.456336,0.390448,0.847091,00:02
3,0.463871,0.31498,0.864005,00:02
4,0.399526,0.548,0.845061,00:03
5,0.267582,0.222926,0.903248,00:02
6,0.177511,0.180466,0.933694,00:02
7,0.121694,0.195583,0.927605,00:02
8,0.077676,0.192459,0.936401,00:02
9,0.047532,0.180056,0.936401,00:02


## 3. Baseline: Student Without Distillation

Let's train a ResNet-18 student **without** distillation, from scratch and with only ground truth labels, to establish a baseline:

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,0.611359,0.660552,0.67659,00:02
1,0.565523,0.669257,0.70433,00:02
2,0.537007,0.567621,0.728011,00:02
3,0.498747,0.541553,0.741543,00:02
4,0.449077,0.455508,0.783491,00:02
5,0.399169,0.393245,0.828823,00:02
6,0.342478,0.369859,0.834912,00:02
7,0.272756,0.334547,0.853857,00:02
8,0.187447,0.346933,0.859269,00:02
9,0.147805,0.358428,0.859946,00:02


## 4. Student With Knowledge Distillation

Now let's train the same architecture with help from the teacher, using the `SoftTarget` loss.

The `SoftTarget` loss combines:
- **Classification loss** (Cross-Entropy with ground truth)
- **Distillation loss** (KL divergence between student and teacher soft predictions)
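
As a rough illustration of the two terms above, the sketch below computes a temperature-softened KL term and a cross-entropy term for a single example in plain Python. The temperature `T`, the `T**2` scaling (taken from the common formulation of soft-target distillation), and the logit values are illustrative assumptions, not fasterai's actual implementation or defaults.

```python
import math

def softmax(logits, T=1.0):
    """Softmax with temperature T; larger T gives a flatter distribution."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cross_entropy(logits, target):
    """Standard cross-entropy of the logits against a hard label index."""
    return -math.log(softmax(logits)[target])

# Hypothetical logits for one image in a two-class problem.
teacher_logits = [4.0, 1.0]
student_logits = [2.0, 1.5]
target = 0
T = 4.0  # assumed temperature

soft_teacher = softmax(teacher_logits, T)
soft_student = softmax(student_logits, T)

# Distillation term: KL from teacher to student on softened outputs,
# scaled by T**2 so its gradients stay comparable across temperatures.
distill_loss = (T ** 2) * kl_divergence(soft_teacher, soft_student)
# Classification term: ordinary cross-entropy with the ground truth label.
class_loss = cross_entropy(student_logits, target)

print(f"softened teacher: {soft_teacher}")
print(f"distillation term: {distill_loss:.4f}, classification term: {class_loss:.4f}")
```

Raising `T` flattens the teacher's distribution, so the less likely class carries more signal than it would in a near one-hot prediction.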

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, SoftTarget, schedule=cos)
student.fit_one_cycle(10, 1e-3, cbs=kd)

epoch,train_loss,valid_loss,accuracy,time
0,0.622423,0.658045,0.692828,00:03
1,0.65433,1.211342,0.677267,00:02
2,0.736943,0.75777,0.736807,00:03
3,0.830559,0.949577,0.698241,00:02
4,0.882739,0.915873,0.79364,00:03
5,0.890884,0.799081,0.824763,00:02
6,0.817516,1.475584,0.737483,00:02
7,0.687356,0.73007,0.866035,00:02
8,0.523237,0.718984,0.866035,00:03
9,0.452811,0.703519,0.870771,00:03


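With teacher guidance, the student reaches 87.1% validation accuracy, compared with 86.0% for the baseline. The loss values are not directly comparable with the baseline run, because the training loss now includes the distillation term.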

## 5. Advanced: Attention Transfer

Beyond soft targets, fasterai supports more sophisticated distillation losses like **Attention Transfer**, from "Paying More Attention to Attention" (Zagoruyko & Komodakis, 2017). Here, the student learns to replicate the teacher's attention maps at intermediate layers.
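
To make "attention map" concrete, here is a minimal sketch in plain Python, using nested lists in place of tensors so the arithmetic is visible. It assumes the common variant in which the map is the channel-wise sum of squared activations, flattened and L2-normalized, and the loss is the mean squared difference between maps. The function names and activation values are hypothetical, and fasterai's `Attention` loss may differ in its details.

```python
import math

def attention_map(activation):
    """Collapse a C x H x W activation into a flattened, L2-normalized H*W map.

    Each spatial position gets the sum of squared activations over channels.
    """
    channels = len(activation)
    height = len(activation[0])
    width = len(activation[0][0])
    flat = [
        sum(activation[c][i][j] ** 2 for c in range(channels))
        for i in range(height)
        for j in range(width)
    ]
    norm = math.sqrt(sum(v * v for v in flat)) or 1.0  # avoid dividing by zero
    return [v / norm for v in flat]

def attention_loss(student_act, teacher_act):
    """Mean squared difference between the two normalized attention maps."""
    s, t = attention_map(student_act), attention_map(teacher_act)
    return sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)

# Hypothetical activations: the teacher has 3 channels, the student 2,
# but both share the same 2 x 2 spatial size, so the maps are comparable.
teacher_act = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.5, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, 2.0]],
]
student_act = [
    [[0.9, 0.1], [0.0, 1.8]],
    [[0.4, 0.0], [0.1, 1.1]],
]

print(attention_map(teacher_act))
print(attention_loss(student_act, teacher_act))
```

Because the channel dimension is summed out, the student and teacher layers can have different channel counts as long as their spatial sizes agree.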

To use intermediate layer losses, specify which layers to match using their string names. Use `get_model_layers` to discover available layers.

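Here we match attention maps after each of the four residual stages. The student is a plain torchvision ResNet-18, so its stages are named `layer1` to `layer4`. The teacher was built with `vision_learner`, which nests the ResNet body as submodule `0`, so the same stages appear as `0.4` to `0.7`: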

In [None]:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
kd = KnowledgeDistillationCallback(teacher.model, Attention, ['layer1', 'layer2', 'layer3', 'layer4'], ['0.4', '0.5', '0.6', '0.7'], weight=0.9)
student.fit_one_cycle(10, 1e-3, cbs=kd)

epoch,train_loss,valid_loss,accuracy,time
0,0.092506,0.091555,0.67862,00:03
1,0.083053,0.084819,0.648173,00:03
2,0.071733,0.073612,0.705007,00:02
3,0.062212,0.059138,0.815291,00:03
4,0.055396,0.053225,0.82747,00:03
5,0.047694,0.052672,0.82138,00:03
6,0.041354,0.048255,0.860622,00:03
7,0.031322,0.042128,0.874831,00:03
8,0.024217,0.042546,0.879567,00:03
9,0.019581,0.042967,0.886333,00:03
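
To put the three student runs side by side, the snippet below takes the final-epoch validation accuracies from the training logs in this tutorial and computes each gain over the baseline in percentage points. The dictionary keys are just labels chosen for this comparison.

```python
# Final-epoch validation accuracy copied from the training logs in this tutorial.
final_accuracy = {
    "baseline (no distillation)": 0.859946,
    "SoftTarget": 0.870771,
    "Attention": 0.886333,
}

baseline = final_accuracy["baseline (no distillation)"]
gains = {
    name: round((acc - baseline) * 100, 2)  # gain in percentage points
    for name, acc in final_accuracy.items()
}

for name, acc in final_accuracy.items():
    print(f"{name:28s} {acc:.2%}  ({gains[name]:+.2f} points vs baseline)")
```

Each figure comes from a single training run, so small differences may fall within run-to-run variation.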


## 6. Parameter Guide

### KnowledgeDistillationCallback Parameters

| Parameter | Description |
|-----------|-------------|
| `teacher` | The trained teacher model |
| `loss` | Distillation loss function (`SoftTarget`, `Attention`, `FitNet`, etc.) |
| `student_layers` | (For intermediate losses) Layers in student to extract features from |
| `teacher_layers` | (For intermediate losses) Corresponding layers in teacher |
| `weight` | Weight of distillation loss vs classification loss |
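
One plausible reading of `weight` is a linear interpolation between the classification and distillation losses. The sketch below assumes that form. The function name `blend_losses` and the loss values are hypothetical, and fasterai's actual combination may differ.

```python
def blend_losses(class_loss, distill_loss, weight):
    """Linearly interpolate between classification and distillation loss.

    weight=0 keeps only the classification loss; weight=1 keeps only
    the distillation loss. This interpolation form is an assumption.
    """
    if not 0.0 <= weight <= 1.0:
        raise ValueError("weight must be between 0 and 1")
    return (1.0 - weight) * class_loss + weight * distill_loss

# Hypothetical per-batch loss values.
for w in (0.0, 0.5, 0.9, 1.0):
    print(w, blend_losses(0.60, 0.05, w))
```

Under this assumption, `weight=0.9` scales the classification term by 0.1, so the reported training loss would be dominated by the distillation term.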

### Available Loss Functions

| Loss | Type | Description |
|------|------|-------------|
| `SoftTarget` | Output | Match teacher's softened predictions |
| `Attention` | Intermediate | Match attention maps (spatial activation patterns) |
| `FitNet` | Intermediate | Directly match feature maps (requires same dimensions) |
| `RKD` | Relational | Match distance/angle relationships between samples |
| `PKT` | Probabilistic | Match probability distributions in feature space |

## Summary

| Concept | Description |
|---------|-------------|
| **Knowledge Distillation** | Training a small student to mimic a large teacher |
| **KnowledgeDistillationCallback** | fasterai callback that adds distillation to a fastai training loop |
| **SoftTarget** | Basic distillation using teacher's soft predictions |
| **Attention Transfer** | Advanced distillation using intermediate attention maps |
| **Typical Benefit** | Roughly 1-3 percentage points of accuracy over training the student alone, as in the runs in this tutorial |

---

## See Also

- [Distillation Losses](../../distill/losses.html) - All available distillation loss functions
- [Pruner](../../prune/pruner.html) - Combine distillation with pruning for even smaller models
- [Sparsifier](../../sparse/sparsifier.html) - Add sparsity to distilled models