# Lesson 5: Knowledge Distillation

**Module 4: Model Development & Optimization**  
**Estimated Time**: 1-2 hours  
**Difficulty**: Advanced

---

## 🎯 Learning Objectives

By the end of this lesson, you will:

✅ Understand the Teacher-Student configuration  
✅ Learn why "Dark Knowledge" (Soft Targets) helps small models learn  
✅ Implement a custom Distillation Loss function in PyTorch  
✅ Answer interview questions on model compression  

---

## 📚 Table of Contents

1. [The Concept: Teacher & Student](#1-the-concept-teacher--student)
2. [The Magic: Soft Targets & Temperature](#2-the-magic-soft-targets--temperature)
3. [Hands-On: Distillation Loss Implementation](#3-hands-on-distillation-loss-implementation)
4. [Interview Preparation](#4-interview-preparation)

---

## 1. The Concept: Teacher & Student

**Goal**: We want accuracy close to that of a large model (e.g., BERT-base) at the speed and size of a small one (e.g., DistilBERT, which was distilled from BERT-base).

**Method**: Train the small model (Student) to mimic the large model (Teacher).
The Student learns not just from the ground-truth labels (Hard Targets), but also from the Teacher's full probability distribution over classes (Soft Targets).

## 2. The Magic: Soft Targets & Temperature

**Hard Target**: `[0, 1, 0]` (It's a Dog).  
**Teacher Output**: `[0.05, 0.90, 0.05]` (It's mostly a Dog, but looks 5% like a Cat).

That 5% is **Dark Knowledge**. It tells the Student "Dogs look a bit like Cats sometimes". This richness helps the Student generalize better than just learning "Dog".

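**Temperature (T)**:
To make the distribution softer (reveal more dark knowledge), we divide the logits by T > 1 before Softmax. T = 1 gives the standard Softmax; larger T flattens the distribution while keeping the class ranking unchanged.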
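
To make the effect of T concrete, here is a minimal sketch in plain Python. The helper name `softmax_with_temperature` and the logit values are illustrative assumptions, not part of any library. It prints the same logits softened at a few temperatures.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax over logits / T (hypothetical helper for illustration)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


teacher_logits = [2.1, 5.8, 1.2]  # illustrative values

for T in (1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As T grows, the largest probability shrinks and the smaller ones grow, but the most likely class stays the same.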

## 3. Hands-On: Distillation Loss Implementation

The loss is a combination of:
1. **Student vs Ground Truth** (CrossEntropy)
2. **Student vs Teacher** (KL Divergence)
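
Written out, the combination can be sketched as follows. The symbols are notation assumed here for illustration: $z_s$ and $z_t$ are the Student and Teacher logits, $y$ is the ground-truth label, $T$ is the temperature, and $\alpha$ is the weight on the soft-target term.

$$
\mathcal{L} = (1 - \alpha)\,\mathrm{CE}\big(y,\ \mathrm{softmax}(z_s)\big) + \alpha\, T^2\, \mathrm{KL}\big(\mathrm{softmax}(z_t / T)\ \|\ \mathrm{softmax}(z_s / T)\big)
$$

The $T^2$ factor is there because the gradients of the soft-target term shrink roughly as $1/T^2$ when the logits are divided by $T$. Multiplying by $T^2$ keeps the two terms on a comparable scale when the temperature changes.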

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    student_logits: Output of student model, shape (batch, num_classes)
    teacher_logits: Output of teacher model (frozen), same shape
    labels: Ground truth class indices, shape (batch,)
    T: Temperature
    alpha: Weight for Soft Target loss
    """
    # 1. Hard Target Loss (Student vs Label)
    hard_loss = F.cross_entropy(student_logits, labels)

    # 2. Soft Target Loss (Student vs Teacher)
    # Apply Temperature. F.kl_div expects log-probabilities as input
    # and probabilities as target.
    soft_student = F.log_softmax(student_logits / T, dim=1)
    soft_teacher = F.softmax(teacher_logits / T, dim=1)

    # KL Divergence, scaled by T^2 so its gradients stay comparable across temperatures.
    # Named soft_loss so it does not shadow the enclosing function's name.
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)

    # Combine
    total_loss = (1.0 - alpha) * hard_loss + alpha * soft_loss
    return total_loss


# Simulation
s_logits = torch.tensor([[2.0, 5.0, 1.0]], requires_grad=True)  # Student output (learns)
t_logits = torch.tensor([[2.1, 5.8, 1.2]])  # Teacher output (fixed)
labels = torch.tensor([1])

loss = distillation_loss(s_logits, t_logits, labels)
print(f"Distillation Loss: {loss.item():.4f}")
```
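
To show where this loss fits in training, here is a minimal sketch of a distillation training step. The toy `teacher` and `student` models, the random batch, and the `train_step` helper are hypothetical stand-ins rather than part of the lesson's code. The loss is inlined with the same T and alpha defaults so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy models: any two classifiers with the same number of outputs would do.
teacher = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 3))
student = nn.Linear(8, 3)

# Freeze the Teacher: eval mode and no gradient tracking for its parameters.
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# Only the Student's parameters are optimized.
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
T, alpha = 2.0, 0.5


def train_step(x, y):
    """One optimization step on the Student (illustrative helper)."""
    with torch.no_grad():  # Teacher forward pass builds no graph
        teacher_logits = teacher(x)
    student_logits = student(x)

    hard_loss = F.cross_entropy(student_logits, y)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    loss = (1.0 - alpha) * hard_loss + alpha * soft_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


# Random stand-in batch: 16 samples, 8 features, 3 classes.
x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

losses = [train_step(x, y) for _ in range(20)]
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

The two details that matter are that the Teacher's forward pass runs under `torch.no_grad()` and that only the Student's parameters are passed to the optimizer, so the Teacher stays fixed throughout.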

## 4. Interview Preparation

### Common Questions

#### Q1: "Why use Temperature > 1?"
**Answer**: "Standard Softmax pushes probabilities towards 0 and 1 (sharp peaks). High temperature softens the distribution, spreading probability mass to the incorrect classes. This reveals the structural similarity between classes (e.g., Truck is more similar to Car than to Frog), which provides more information to the student."

#### Q2: "Can a Student outperform a Teacher?"
**Answer**: "Rarely in raw accuracy if the Teacher has significantly more capacity. However, a distilled Student often outperforms the same small architecture trained **only** on labels, because the Student gets extra supervision ('Dark Knowledge') from the Teacher's soft targets."