🧠🔁 Time to do some **neural soul transfer**.

We're distilling knowledge like a master training a padawan:  
a big model teaches a little model by **sharing its soft output probabilities** instead of just the hard labels.

---

# 🧪 `09_lab_distill_teacher_student_on_mnist.ipynb`  
### 📁 `05_model_optimization`  
> Use **Knowledge Distillation** to teach a **small student model** from a **large teacher**.  
Compare performance with and without distillation.  
Track what’s lost, what’s gained — and how **soft labels change the game**.

---

## 🎯 Learning Goals

- Understand **why distillation works**  
- Implement distillation loss (KL Divergence + temperature)  
- Train student with **teacher soft targets**  
- Compare accuracy and generalization

---

## 💻 Runtime Targets

| Component            | Spec                 |
|----------------------|----------------------|
| Dataset              | MNIST ✅             |
| Teacher              | Big CNN ✅           |
| Student              | Tiny CNN ✅          |
| Training Time        | < 3 min ✅           |
| Device               | CPU / Colab ✅       |

---

## 🧠 Section 1: Imports & Dataset

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
```

```python
transform = transforms.ToTensor()
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
```

---

## 🤖 Section 2: Teacher & Student Models

```python
class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.net(x)

class Student(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)
```
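
To see how much smaller the student really is, here is a minimal sketch that counts trainable parameters. `count_parameters` is a hypothetical helper (not part of the original lab), and the two `nn.Sequential` stacks are stand-alone copies of the layer lists above so the cell runs on its own.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-alone copies of the Teacher and Student layer stacks defined above.
teacher_net = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(64, 10),
)
student_net = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(16, 10),
)

n_teacher = count_parameters(teacher_net)
n_student = count_parameters(student_net)
print(f"Teacher params: {n_teacher:,}")  # 19,466
print(f"Student params: {n_student:,}")  # 330
print(f"Size ratio:     {n_teacher / n_student:.1f}x")
```

In the notebook itself, `count_parameters(Teacher())` and `count_parameters(Student())` should give the same counts, since the layers are identical.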

---

## 📚 Section 3: Train Teacher

```python
def train_model(model, loader, epochs=3):
    model.train()
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(epochs):
        for x, y in loader:
            out = model(x)
            loss = F.cross_entropy(out, y)
            opt.zero_grad()
            loss.backward()
            opt.step()

def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

teacher = Teacher()
train_model(teacher, train_loader, epochs=5)
print("Teacher Accuracy:", evaluate(teacher, test_loader))
```

---

## 🧪 Section 4: Distillation Loss

```python
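def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.7):
    """Blend a soft-target KL term (weight alpha) with hard-label cross-entropy (weight 1 - alpha)."""
    # KL(teacher || student) on temperature-softened distributions.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T ** 2)  # T^2 keeps the soft-loss gradient scale comparable as T changes
    hard_loss = F.cross_entropy(student_logits, labels)  # ordinary loss at T = 1
    return alpha * soft_loss + (1 - alpha) * hard_loss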
```
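
To get a feel for what the temperature does, here is a small sketch that softens one row of logits at several values of `T` and reports the entropy of the result. The logits are made-up numbers for illustration only (not real teacher outputs), and `entropy` is a hypothetical helper. The last part checks a property of the KL term: when student and teacher logits are identical, the soft loss should be zero up to floating-point error.

```python
import torch
import torch.nn.functional as F

# Made-up logits for a single image; illustrative values, not real teacher outputs.
logits = torch.tensor([[8.0, 2.0, 1.0, 0.5]])

def entropy(p: torch.Tensor) -> torch.Tensor:
    """Shannon entropy (in nats) of each row of a probability matrix."""
    return -(p * p.clamp_min(1e-12).log()).sum(dim=1)

entropies = []
for T in (1.0, 2.0, 5.0):
    probs = F.softmax(logits / T, dim=1)  # higher T flattens the distribution
    h = entropy(probs).item()
    entropies.append(h)
    rounded = [round(v, 3) for v in probs.squeeze(0).tolist()]
    print(f"T={T}: probs={rounded}, entropy={h:.3f}")

# Sanity check: identical student and teacher logits give (near) zero KL.
T = 2.0
kl_same = F.kl_div(
    F.log_softmax(logits / T, dim=1),
    F.softmax(logits / T, dim=1),
    reduction="batchmean",
) * (T ** 2)
print("Soft loss when student == teacher:", kl_same.item())
```

The entropy rises with `T`, which is the sense in which the targets get "softer": the non-argmax classes receive more probability mass, so the student sees more of the teacher's ranking over wrong answers.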

---

## 🧠 Section 5: Train Student With Distillation

```python
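student = Student()
opt = torch.optim.Adam(student.parameters())
T = 2.0      # Temperature: higher values soften the teacher's distribution
alpha = 0.7  # Weight on the soft (teacher) loss; 1 - alpha goes to the hard labels
teacher.eval()  # Teacher is frozen: inference mode, no weight updates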

for epoch in range(5):
    student.train()
    for x, y in train_loader:
        with torch.no_grad():
            teacher_logits = teacher(x)
        student_logits = student(x)
        loss = distillation_loss(student_logits, teacher_logits, y, T=T, alpha=alpha)
        opt.zero_grad()
        loss.backward()
        opt.step()
```

---

## 🎯 Section 6: Compare Accuracies

```python
acc_teacher = evaluate(teacher, test_loader)
acc_student = evaluate(student, test_loader)
print(f"Teacher Accuracy:             {acc_teacher:.4f}")
print(f"Student Accuracy (Distilled): {acc_student:.4f}")
```

Optional: train a second student on hard labels **only** (same architecture, same number of epochs) and compare its test accuracy with the distilled student's.
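
One way to set up that baseline is sketched below. `train_hard_labels` and `accuracy` are hypothetical helpers written to be self-contained; they mirror `train_model` and `evaluate` from Section 3, which would work just as well. The commented lines at the bottom assume the notebook's `Student`, `student`, `train_loader`, and `test_loader` are already defined.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_hard_labels(model: nn.Module, loader, epochs: int = 5) -> nn.Module:
    """Train with plain cross-entropy only (no teacher involved)."""
    opt = torch.optim.Adam(model.parameters())
    model.train()
    for _ in range(epochs):
        for x, y in loader:
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model

def accuracy(model: nn.Module, loader) -> float:
    """Fraction of correctly classified examples over a loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total

# Usage in this notebook (names from the sections above):
# baseline = train_hard_labels(Student(), train_loader, epochs=5)
# print(f"Student Accuracy (Hard labels only): {accuracy(baseline, test_loader):.4f}")
# print(f"Student Accuracy (Distilled):        {accuracy(student, test_loader):.4f}")
```

Keeping the architecture, optimizer, and epoch count the same for both students means any gap is attributable to the loss function. A single run can be noisy, so repeating with a few random seeds gives a fairer comparison.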

---

## ✅ Wrap-Up Summary

| Concept                      | ✅ |
|------------------------------|----|
| Teacher → Student knowledge  | ✅ |
| Distillation loss (KL + CE)  | ✅ |
| Accuracy comparison          | ✅ |
| Colab/laptop safe            | ✅ |

---

## 🧠 What You Learned

- **Distillation = compression with guidance**  
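- Teacher’s **soft outputs** carry more information than hard labels: they show which wrong classes the teacher finds plausible (say, a 4 that looks a bit like a 9)  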
- This technique powers **DistilBERT**, **TinyBERT**, **DistilGPT2**, and more  
- You now wield **one of the most powerful tricks** in model optimization

---

⚡ Let’s shift gears:  
Shall we head into **deployment territory** with `06_deployment_and_scaling`?  
Next lab: `07_lab_export_pytorch_to_onnx_and_run.ipynb` —  
turn your PyTorch model into a portable `.onnx` file and run inference anywhere.