# 🎓 Knowledge Distillation From Scratch: Teacher to Student

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/distillation_demo.ipynb)

## 📖 The Theory: Transferring Dark Knowledge

Knowledge Distillation (KD) is a compression method in which a compact model (the **Student**) is trained to mimic a larger model (the **Teacher**). Hinton et al. (2015) introduced the concept of "Dark Knowledge": the relative probabilities the teacher assigns to *incorrect* classes, which reveal how it generalizes (for example, which wrong classes it considers similar to the correct one).

### The Loss Function
The student is trained to minimize a weighted sum of two losses:
1.  **Distillation Loss**: KL divergence between the teacher's and the student's temperature-softened output distributions (the teacher's distribution serves as the *soft targets*).
2.  **Student Loss**: Standard cross-entropy between the student's output and the ground truth (hard labels).
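Written as one formula (the symbols are introduced for this sketch: $v$ and $z$ are the teacher and student logits, $y$ is the hard label, $T$ is the temperature described below, and $\alpha$ weights the distillation term, as in `manual_distillation_loss`):

$$\mathcal{L} = \alpha \, T^2 \, \mathrm{KL}\big(p^{(T)} \,\|\, q^{(T)}\big) + (1 - \alpha)\big(-\log q^{(1)}_y\big), \qquad \mathrm{KL}(p \,\|\, q) = \sum_i p_i \left(\log p_i - \log q_i\right)$$

where $p^{(T)} = \mathrm{softmax}(v / T)$ is the teacher's softened distribution and $q^{(T)} = \mathrm{softmax}(z / T)$ is the student's.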

### Temperature (T)
We divide the logits $z_i$ by a temperature $T$ before applying the softmax, which softens the resulting distribution ($T = 1$ recovers the standard softmax). A higher $T$ makes the "dark knowledge" (the small probabilities of incorrect classes) more visible.

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
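A quick look at how the temperature reshapes a distribution, using made-up logits for a 4-class problem (the values and the helper name `soften` are illustrative assumptions, not part of the notebook):

```python
import torch
import torch.nn.functional as F

# Illustrative teacher logits for a 4-class problem (made-up values)
logits = torch.tensor([6.0, 3.0, 1.0, -2.0])

def soften(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Temperature-scaled softmax: q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    return F.softmax(logits / T, dim=-1)

for T in [1.0, 2.0, 5.0]:
    q = soften(logits, T)
    print(f"T={T}: {[round(p, 4) for p in q.tolist()]}")
```

At $T = 1$ the first class takes about 95% of the probability mass and the least likely class gets about 0.0003; at $T = 5$ that smallest probability rises to roughly 0.095. Dividing by a positive $T$ never changes the ranking of the classes, only how sharply the mass is concentrated.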

Because the gradients produced by the soft targets scale as $1/T^2$, the distillation loss is multiplied by $T^2$ so that its contribution stays roughly comparable when $T$ changes.
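A small autograd check of that scaling, using toy logits (the helper `soft_loss_grad` is hypothetical, written only for this sketch). For $q = \mathrm{softmax}(z/T)$ and $p = \mathrm{softmax}(v/T)$, the gradient of $\mathrm{KL}(p \,\|\, q)$ with respect to $z_i$ is $(q_i - p_i)/T$, which shrinks like $1/T^2$ for large $T$:

```python
import torch
import torch.nn.functional as F

def soft_loss_grad(s_logits, t_logits, T, scale_by_T2=True):
    """Gradient of the soft-target KL loss w.r.t. the student logits (optionally scaled by T^2)."""
    z = s_logits.clone().requires_grad_(True)
    log_q = F.log_softmax(z / T, dim=-1)   # student log-probabilities at temperature T
    p = F.softmax(t_logits / T, dim=-1)     # teacher probabilities at temperature T
    loss = F.kl_div(log_q, p, reduction="batchmean")
    if scale_by_T2:
        loss = loss * T * T
    loss.backward()
    return z.grad

# Same toy logits as the worked example further down (batch of 1, 3 classes)
t_logits = torch.tensor([[2.0, 1.0, 0.0]])
s_logits = torch.tensor([[1.5, 0.5, 0.0]])

for T in [1.0, 2.0, 4.0, 8.0, 16.0]:
    raw = soft_loss_grad(s_logits, t_logits, T, scale_by_T2=False).norm().item()
    scaled = soft_loss_grad(s_logits, t_logits, T).norm().item()
    print(f"T={T:>4}: |grad| unscaled={raw:.6f}   scaled by T^2={scaled:.6f}")
```

The unscaled gradient norm drops by roughly a factor of 4 each time $T$ doubles, while the $T^2$-scaled norm levels off as $T$ grows. The $1/T^2$ behaviour is a large-$T$ approximation, so at $T = 1$ the scaled norm is still noticeably smaller than its large-$T$ value.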

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer

def manual_distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    Distillation loss implemented from scratch.

    student_logits, teacher_logits: [batch, seq, vocab], already aligned with `labels`
    labels: [batch, seq] token ids used as hard targets
    T: temperature; alpha: weight of the distillation term
    """
    vocab = student_logits.size(-1)
    # Flatten to [batch*seq, vocab] so that "batchmean" averages the KL per token,
    # on the same scale as the per-token mean computed by cross_entropy below
    s_flat = student_logits.reshape(-1, vocab)
    t_flat = teacher_logits.reshape(-1, vocab)

    # 1. Soft targets: KLDivLoss expects log-probabilities for the student (input)
    #    and probabilities for the teacher (target)
    soft_targets = F.softmax(t_flat / T, dim=-1)
    soft_prob = F.log_softmax(s_flat / T, dim=-1)

    # Multiply by T^2 to keep the soft-target gradients on a comparable scale as T changes
    distillation_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T * T)

    # 2. Standard cross-entropy against the hard labels
    student_loss = F.cross_entropy(s_flat, labels.reshape(-1))

    # 3. Weighted combination
    return alpha * distillation_loss + (1 - alpha) * student_loss

# Set up the models
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
teacher = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Student: 2 layers, 4 heads, 256-dim hidden size (GPT-2 small uses 12 layers, 12 heads, 768)
student_cfg = GPT2Config(n_layer=2, n_head=4, n_embd=256, vocab_size=tokenizer.vocab_size)
student = GPT2LMHeadModel(student_cfg)

print(f"Teacher Size: {sum(p.numel() for p in teacher.parameters())/1e6:.1f}M")
print(f"Student Size: {sum(p.numel() for p in student.parameters())/1e6:.1f}M")
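`nn.KLDivLoss` takes log-probabilities as its input and probabilities as its target, and `reduction="batchmean"` divides the summed KL by the size of the input's first dimension only. This self-contained check (random logits, arbitrary sizes) compares it with the explicit formula and shows that on `[batch, seq, vocab]` tensors "batchmean" is a per-sequence sum, while on a flattened `[batch*seq, vocab]` tensor it is a per-token mean, the same scale as `F.cross_entropy`, which averages over all rows by default:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, vocab = 2, 5, 7
s_logits = torch.randn(batch, seq, vocab)   # stand-in student logits
t_logits = torch.randn(batch, seq, vocab)   # stand-in teacher logits

log_q = F.log_softmax(s_logits, dim=-1)     # input: student log-probabilities
p = F.softmax(t_logits, dim=-1)             # target: teacher probabilities

# Explicit per-token KL(p || q) = sum_i p_i (log p_i - log q_i), shape [batch, seq]
kl_per_token = (p * (p.log() - log_q)).sum(dim=-1)

kl_3d = nn.KLDivLoss(reduction="batchmean")(log_q, p)
kl_flat = nn.KLDivLoss(reduction="batchmean")(log_q.reshape(-1, vocab), p.reshape(-1, vocab))

print(f"sum / batch    : {kl_per_token.sum().item() / batch:.6f}   3-D batchmean       : {kl_3d.item():.6f}")
print(f"per-token mean : {kl_per_token.mean().item():.6f}   flattened batchmean : {kl_flat.item():.6f}")
```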

## 🔢 Worked Example with Numbers

A step-by-step trace of `manual_distillation_loss` with a tiny input so you can verify every value by hand.

- Vocabulary of **3 tokens**, batch size **1**, sequence length **1**
- Temperature **T = 2.0**, mixing weight **α = 0.5**

In [None]:
import torch, torch.nn as nn, torch.nn.functional as F

T, alpha = 2.0, 0.5

# Tiny inputs: batch=1, seq=1, vocab=3
teacher_logits = torch.tensor([[[2.0, 1.0, 0.0]]])   # shape [1, 1, 3]
student_logits = torch.tensor([[[1.5, 0.5, 0.0]]])   # shape [1, 1, 3]
labels         = torch.tensor([[0]])                  # correct token = index 0

# ── Step 1 : Soft targets (Teacher) ──────────────────────────────────────
# teacher_logits / T  =>  [1.0000,  0.5000,  0.0000]
# exp(...)            =>  [2.7183,  1.6487,  1.0000]   sum = 5.3670
soft_targets = F.softmax(teacher_logits / T, dim=-1)
# soft_targets        =>  [0.5065,  0.3072,  0.1863]

# ── Step 2 : Soft log-probabilities (Student) ────────────────────────────
# student_logits / T  =>  [0.7500,  0.2500,  0.0000]
# exp(...)            =>  [2.1170,  1.2840,  1.0000]   sum = 4.4010,  log(sum) = 1.4818
soft_prob = F.log_softmax(student_logits / T, dim=-1)
# soft_prob           =>  [-0.7318, -1.2318, -1.4818]

# ── Step 3 : KL-Divergence loss (scaled by T²) ───────────────────────────
# KLDiv = Σ soft_targets × (log(soft_targets) − soft_prob)
# log(soft_targets)               =>  [-0.6803, -1.1803, -1.6803]
# log(soft_targets) − soft_prob   =>  [ 0.0516,  0.0516, -0.1984]
#       = 0.5065×0.0516 + 0.3072×0.0516 + 0.1863×(−0.1984)
#       ≈ 0.004987   (at full precision; the 4-digit terms above are rounded)
# × T²  = 0.004987 × 4                                        ≈  0.0199
# (batch=1, so "batchmean" divides the sum by 1)
distillation_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (T * T)
# distillation_loss   =>  ≈ 0.0199

# ── Step 4 : Cross-Entropy loss (Student vs hard label) ──────────────────
# student_logits flat  =>  [[1.5, 0.5, 0.0]],  label = [0]
# exp(...)             =>  [4.4817, 1.6487, 1.0000]   sum = 7.1304
# P(class=0)          =   4.4817 / 7.1304  ≈  0.6285
# CE = −log(0.6285)                         ≈  0.4644
student_loss = F.cross_entropy(
    student_logits.view(-1, student_logits.size(-1)),  # => shape [1, 3]
    labels.view(-1)                                    # => shape [1]
)
# student_loss        =>  ≈ 0.4644

# ── Step 5 : Combined loss ───────────────────────────────────────────────
# = 0.5 × 0.0199 + 0.5 × 0.4644
# ≈ 0.0100       + 0.2322         ≈  0.2422
combined_loss = alpha * distillation_loss + (1 - alpha) * student_loss
# combined_loss       =>  ≈ 0.2422

print(f"soft_targets      : {[round(v,4) for v in soft_targets.squeeze().tolist()]}")
# => [0.5065, 0.3072, 0.1863]
print(f"soft_prob         : {[round(v,4) for v in soft_prob.squeeze().tolist()]}")
# => [-0.7318, -1.2318, -1.4818]
print(f"distillation_loss : {distillation_loss.item():.4f}")   # => 0.0199
print(f"student_loss      : {student_loss.item():.4f}")        # => 0.4644
print(f"combined_loss     : {combined_loss.item():.4f}")       # => 0.2422
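To check the trace without PyTorch, here is the same computation using only Python's standard `math` module. This is a sketch that mirrors the cell above; `log_softmax` here is a hand-written helper, not a library function:

```python
import math

T, alpha = 2.0, 0.5
teacher_z = [2.0, 1.0, 0.0]   # teacher logits from the worked example
student_z = [1.5, 0.5, 0.0]   # student logits from the worked example
label = 0                     # index of the correct token

def log_softmax(logits, T=1.0):
    """Numerically stable log-softmax of a list of logits at temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating
    log_sum = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_sum for s in scaled]

log_p = log_softmax(teacher_z, T)   # teacher log-probabilities at temperature T
log_q = log_softmax(student_z, T)   # student log-probabilities at temperature T
kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
distill = kl * T * T
ce = -log_softmax(student_z)[label]  # the hard-label term uses T = 1
combined = alpha * distill + (1 - alpha) * ce

print(f"distillation_loss : {distill:.4f}")   # prints 0.0199
print(f"student_loss      : {ce:.4f}")        # prints 0.4644
print(f"combined_loss     : {combined:.4f}")  # prints 0.2422
```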

## 🔄 The Distillation Loop
A single distillation training step, written out manually.
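GPT-2 is a causal language model: the logits at position $t$ are a prediction for the token at position $t+1$. A minimal sketch of aligning logits and labels for that next-token objective, with random tensors standing in for model outputs and tokenized text (all sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, vocab = 2, 6, 11
logits = torch.randn(batch, seq, vocab)            # stand-in for model outputs
input_ids = torch.randint(0, vocab, (batch, seq))  # stand-in for tokenized text

# Logits at position t are scored against the token at position t+1
shift_logits = logits[:, :-1, :].contiguous()      # drop the last prediction (no next token)
shift_labels = input_ids[:, 1:].contiguous()       # drop the first token (nothing predicts it)

loss = F.cross_entropy(shift_logits.reshape(-1, vocab), shift_labels.reshape(-1))
print(shift_logits.shape, shift_labels.shape, loss.item())
```

Hugging Face's `GPT2LMHeadModel` applies this shift internally when `labels` are passed to it, but a hand-written loss has to do it explicitly. When distilling, the teacher's logits would be sliced the same way as the student's so the KL term stays aligned position by position.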

In [None]:
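optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
input_ids = inputs["input_ids"]

# Step 1: Get teacher knowledge (no gradients needed for the frozen teacher)
with torch.no_grad():
    teacher_logits = teacher(**inputs).logits

# Step 2: Student forward pass
student.train()
student_logits = student(**inputs).logits

# Step 3: Align for next-token prediction: logits at position t predict token t+1,
# so drop the last position's logits and the first token's label
shift_student = student_logits[:, :-1, :].contiguous()
shift_teacher = teacher_logits[:, :-1, :].contiguous()
shift_labels = input_ids[:, 1:].contiguous()

# Step 4: Compute the loss and update the student only
optimizer.zero_grad()
loss = manual_distillation_loss(shift_student, shift_teacher, shift_labels)
loss.backward()
optimizer.step()

print(f"Training step loss: {loss.item():.4f}")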
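In practice the same step runs over many batches. Below is a self-contained sketch of that loop in plain PyTorch, so it needs no GPT-2 download. It rests on loud assumptions: the "teacher" is just a randomly initialised, wider hypothetical MLP, the data is random, and the teacher's argmax stands in for ground-truth labels. It therefore only illustrates the mechanics (zero the gradients, forward, combined loss, backward, step), not real distillation quality:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, alpha, n_classes = 2.0, 0.5, 5

# Hypothetical toy models: a wider "teacher" MLP and a much narrower "student" MLP
teacher_net = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, n_classes)).eval()
student_net = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, n_classes))

x = torch.randn(256, 16)               # random stand-in inputs
with torch.no_grad():
    t_logits = teacher_net(x)          # frozen teacher: computed once, no gradients
y = t_logits.argmax(dim=-1)            # stand-in hard labels (no real dataset here)

optimizer = torch.optim.Adam(student_net.parameters(), lr=1e-2)
losses = []
for step in range(200):
    s_logits = student_net(x)
    soft = F.kl_div(F.log_softmax(s_logits / T, dim=-1),
                    F.softmax(t_logits / T, dim=-1),
                    reduction="batchmean") * T * T
    hard = F.cross_entropy(s_logits, y)
    loss = alpha * soft + (1 - alpha) * hard

    optimizer.zero_grad()   # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"step 0 loss: {losses[0]:.4f}   step {len(losses) - 1} loss: {losses[-1]:.4f}")
```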