# Classical LLM Distillation (BERT → DistilBERT)

## Overview
This is a step-by-step outline of how Large Language Model (LLM) distillation is performed, using **BERT** as the teacher model and **DistilBERT** as the student model.

---

## Process Breakdown

| Block                   | Purpose                                                                                                                              |
| ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------ |
| **0. Imports & Config** | Import all required Python libraries (Transformers, Torch, etc.), detect GPU/CPU, and set hyperparameters (batch size, learning rate, epochs, temperature, alpha). |
| **1. Dataset**          | Load a demo text dataset and tokenize it. Teacher and student must share the same tokenizer (vocabulary) so both see identical input IDs; the data collator pads each batch to a uniform length. |
| **2. Models**           | Load a **BERT** model already fine-tuned on the task as the teacher (parameters frozen), and a pretrained **DistilBERT** as the student (trainable). |
| **3. Losses**           | Use two loss functions: <br>• **Cross-entropy (CE)** for hard labels (ground truth) <br>• **KL divergence** for soft labels (teacher logits passed through a temperature-scaled softmax). |
| **4. Optimizer**        | Use the **AdamW** optimizer with a **linear learning-rate scheduler** for smoother, more stable training.                          |
| **5. distill_epoch()**  | For each batch: <br>• Get teacher logits and create soft targets with temperature <br>• Get student logits <br>• Compute soft loss and hard loss, combine them using α <br>• Backpropagate **only** through the student model. |
| **6. evaluate()**       | Evaluate the student model’s accuracy on the validation set to monitor performance after each epoch.                              |
| **7. Loop**             | For each epoch: run `distill_epoch()` followed by `evaluate()`.                                                                     |
| **8. Save**             | Save the distilled student model to disk for later inference.                                                                       |
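
Blocks 0, 2 and 4 can be sketched as below. This is a minimal sketch, not the notebook's code: `DistillConfig`, its default values, `steps_per_epoch`, and the small MLP / `nn.Linear` stand-ins for BERT and DistilBERT are all hypothetical, and the linear decay is built with PyTorch's `LambdaLR` (Transformers' `get_linear_schedule_with_warmup` is a common alternative).

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class DistillConfig:
    """Block 0 hyperparameters (hypothetical defaults; tune for your task)."""
    batch_size: int = 64
    lr: float = 5e-5
    epochs: int = 3
    temperature: float = 2.0
    alpha: float = 0.5


cfg = DistillConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins (hypothetical): any nn.Module mapping inputs to class logits.
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
student = nn.Linear(16, 2).to(device)

# Block 2: freeze the teacher (eval mode + no gradients for its weights).
teacher.eval()
teacher.requires_grad_(False)

# Block 4: AdamW over the student's parameters only, with linear decay to zero.
steps_per_epoch = 100  # hypothetical; use len(train_loader) in a real run
total_steps = cfg.epochs * steps_per_epoch
optimizer = torch.optim.AdamW(student.parameters(), lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / total_steps)
)
# In the training loop: call optimizer.step() and then scheduler.step() once per batch.
```

Passing only `student.parameters()` to AdamW is what guarantees the teacher is never updated; `requires_grad_(False)` additionally stops PyTorch from computing gradients for the teacher's weights.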

---

# Distillation: step-by-step (concise reference)

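> **Tip:** For very large datasets or multi-GPU/TPU training, move this loop into a subclass of Hugging Face’s `Trainer` that overrides `compute_loss` (tutorials often call it `DistillationTrainer`), or use `accelerate`. The core algorithm below stays the same.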

---

## Overview (what’s happening)
1. Train a large **teacher** model (high capacity) normally and freeze it.  
2. Train a smaller **student** model to mimic the teacher **and** the ground-truth labels.  
3. Student loss = weighted combination of a **soft** (teacher) loss and a **hard** (label) loss.

---

## Step-by-step

### 1. Teacher output (`t_soft`)
- Compute teacher logits: `z_T = teacher(x)` (teacher is frozen; run under `torch.no_grad()`).
- Apply **temperature** `T` and softmax to get *soft targets*:
  \[
  p_T = \text{softmax}\!\left(\frac{z_T}{T}\right)
  \]
- `T > 1` “softens” the distribution (reveals class similarities).
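
Step 1 as a small PyTorch function, assuming a generic `nn.Module` that returns logits directly (a Hugging Face classifier would instead be called with `**batch` and its `.logits` read). The name `teacher_soft_targets` and the toy `nn.Linear` teacher are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def teacher_soft_targets(teacher: nn.Module, x: torch.Tensor, T: float) -> torch.Tensor:
    """Return p_T = softmax(z_T / T) from a frozen teacher."""
    teacher.eval()                 # disable dropout so the targets are deterministic
    with torch.no_grad():          # no graph is built for the teacher's forward pass
        z_T = teacher(x)
    return F.softmax(z_T / T, dim=-1)


torch.manual_seed(0)
teacher = nn.Linear(8, 3)          # toy teacher: 8 features -> 3 class logits
x = torch.randn(4, 8)
p_T = teacher_soft_targets(teacher, x, T=2.0)
print(p_T.sum(dim=-1))             # every row sums to 1 (up to float rounding)
```

Raising `T` makes each row of `p_T` flatter (higher entropy) without changing which class the teacher ranks first.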

### 2. Student output (`s_soft`)
- Compute student logits: `z_S = student(x)`.
- Convert to log-probabilities at the same temperature:
  \[
  \log q_S = \log\text{softmax}\!\left(\frac{z_S}{T}\right)
  \]
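
Why `log_softmax` rather than `log(softmax(...))`: a sketch with deliberately extreme, made-up logits. `F.log_softmax` uses the log-sum-exp trick internally, while the naive version underflows to a probability of exactly zero and then to `-inf`.

```python
import torch
import torch.nn.functional as F

T = 2.0
z_S = torch.tensor([[200.0, 0.0, -200.0]])      # deliberately extreme student logits

log_q_S = F.log_softmax(z_S / T, dim=-1)        # log-sum-exp inside: stays finite
naive = torch.log(F.softmax(z_S / T, dim=-1))   # exp(-200) underflows to 0, log(0) = -inf

print(log_q_S)
print(naive)
```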

### 3. Distillation (soft) loss
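- Use the KL divergence KL(p_T ‖ q_S), which measures how far the student distribution q_S is from the teacher distribution p_T.
- In PyTorch: `nn.KLDivLoss(reduction='batchmean')(log_q_S, p_T)`; the first argument must be the student's **log**-probabilities and the second the teacher's probabilities.
- Multiply the KL loss by `T^2` (Hinton et al.): the gradients of the soft loss scale as `1/T^2`, so this factor keeps them on the same scale as the hard-loss gradients when `T` changes: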
  \[
  L_{\text{soft}} = T^2 \cdot \text{KL}(p_T \,\|\, q_S)
  \]
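
A quick check, on random toy logits, that `nn.KLDivLoss(reduction='batchmean')` computes the KL formula above averaged over the batch, followed by the `T^2` scaling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 2.0
z_T = torch.randn(4, 3)  # toy teacher logits: batch of 4, 3 classes
z_S = torch.randn(4, 3)  # toy student logits

p_T = F.softmax(z_T / T, dim=-1)
log_q_S = F.log_softmax(z_S / T, dim=-1)

# Library form: input = student log-probs, target = teacher probs.
kl = nn.KLDivLoss(reduction="batchmean")(log_q_S, p_T)

# Same quantity written out: (1/B) * sum over examples and classes of p_T * (log p_T - log q_S).
kl_manual = (p_T * (p_T.log() - log_q_S)).sum() / z_T.size(0)

loss_soft = T**2 * kl
```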

### 4. Hard (label) loss
- Standard cross-entropy between the student logits and the true labels, computed at `T = 1` (no temperature):
  \[
  L_{\text{hard}} = \text{CE}(z_S, y)
  \]
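
Step 4 in PyTorch, with made-up logits and labels. `F.cross_entropy` expects raw logits and integer class indices and applies `log_softmax` itself, so no softmax or temperature is applied beforehand.

```python
import torch
import torch.nn.functional as F

z_S = torch.tensor([[2.0, 0.5], [0.1, 1.5]])   # raw student logits, batch of 2
y = torch.tensor([0, 1])                        # ground-truth class indices

loss_hard = F.cross_entropy(z_S, y)             # log_softmax + NLL, averaged over the batch

# Equivalent by hand: mean negative log-probability of the true class.
manual = -F.log_softmax(z_S, dim=-1)[torch.arange(len(y)), y].mean()
print(loss_hard.item())
```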

### 5. Combine
- Weighted sum:
  \[
  L = \alpha \cdot L_{\text{soft}} + (1-\alpha)\cdot L_{\text{hard}}
  \]
- Typical choices: `T ∈ [2,5]`, `α ≈ 0.5` (tune for your task).
- If `α = 1` → pure distillation (no hard labels). If `α = 0` → normal fine-tuning (no distillation).
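
The five steps collapse into a single loss function. The name `distillation_loss` and its defaults (`T=2.0`, `alpha=0.5`, the typical values above) are illustrative, not a library API; `z_T.detach()` is a defensive extra in case the teacher logits were computed with gradients enabled.

```python
import torch
import torch.nn.functional as F


def distillation_loss(z_S, z_T, y, T=2.0, alpha=0.5):
    """L = alpha * T^2 * KL(p_T || q_S) + (1 - alpha) * CE(z_S, y)."""
    p_T = F.softmax(z_T.detach() / T, dim=-1)          # teacher soft targets, no gradient
    log_q_S = F.log_softmax(z_S / T, dim=-1)           # student log-probs at the same T
    loss_soft = T**2 * F.kl_div(log_q_S, p_T, reduction="batchmean")
    loss_hard = F.cross_entropy(z_S, y)                # hard-label loss at T = 1
    return alpha * loss_soft + (1 - alpha) * loss_hard


torch.manual_seed(0)
z_S = torch.randn(8, 3, requires_grad=True)   # toy student logits
z_T = torch.randn(8, 3)                        # toy teacher logits
y = torch.randint(0, 3, (8,))                  # toy labels
loss = distillation_loss(z_S, z_T, y)
loss.backward()                                # gradients flow into z_S only
```

Setting `alpha=0.0` reduces it to plain cross-entropy and `alpha=1.0` to the pure `T^2`-scaled KL term, matching the two limiting cases above.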

---

## Implementation notes / best practices
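- Freeze the teacher: `teacher.eval()` disables dropout so its outputs are deterministic, and `with torch.no_grad()` around the computation of `z_T` avoids storing a computation graph, which saves memory. The teacher's weights stay fixed because only the student's parameters are given to the optimizer.
- Use `F.softmax(z_T / T, dim=-1)` for teacher targets and `F.log_softmax(z_S / T, dim=-1)` for the student input to `KLDivLoss`.
- In PyTorch, use `nn.KLDivLoss(reduction='batchmean')`: it divides the summed KL by the batch size, matching the mathematical definition, whereas `reduction='mean'` also divides by the number of classes.
- Multiply the KL term by `T**2`; otherwise the soft-target gradients are scaled down by `1/T**2`.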
- Monitor both components (`L_soft`, `L_hard`) and validation accuracy.
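
The difference between `'batchmean'` and `'mean'`, shown on random toy distributions with `B = 4` examples and `C = 10` classes: `'mean'` comes out exactly `C` times smaller than the per-example KL.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C = 4, 10
log_q = F.log_softmax(torch.randn(B, C), dim=-1)   # toy student log-probs
p = F.softmax(torch.randn(B, C), dim=-1)           # toy teacher probs

per_example_kl = F.kl_div(log_q, p, reduction="sum") / B   # true mean KL per example
batchmean = F.kl_div(log_q, p, reduction="batchmean")       # sum / B: matches the definition
mean = F.kl_div(log_q, p, reduction="mean")                 # sum / (B * C): C times too small
```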

---

## Small numeric example (temperature effect)
- Teacher logits: `[10, 2]`
  - `T = 1` → `softmax([10,2]) ≈ [0.9997, 0.0003]` (very peaked)
  - `T = 2` → `softmax([5,1]) ≈ [0.982, 0.018]` (softer, reveals second choice)
  - `T = 5` → `softmax([2,0.4]) ≈ [0.83, 0.17]` (much softer)
- Softer distributions reveal the teacher’s relative beliefs and help the student learn nuanced class relations.
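
The numbers above can be reproduced with a few lines of plain Python; `softmax_T` is a hypothetical helper written for this check.

```python
import math


def softmax_T(logits, T=1.0):
    """Temperature-scaled softmax, subtracting the max for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


for T in (1, 2, 5):
    print(T, [round(p, 4) for p in softmax_T([10, 2], T)])
# 1 [0.9997, 0.0003]
# 2 [0.982, 0.018]
# 5 [0.832, 0.168]
```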

---

## Why distillation helps
- **Soft targets** encode “dark knowledge”: relative similarities between classes that hard labels hide.  
- The student learns both the dataset labels **and** the teacher’s nuanced behavior, which often lets a much smaller model generalize better than the same model trained on hard labels alone.

---

## Short pseudocode (conceptual)
1. `z_T = teacher(x)`  (no grad)
2. `p_T = softmax(z_T / T)`
3. `z_S = student(x)`
4. `log_q_S = log_softmax(z_S / T)`
5. `loss_soft = T^2 * KLDiv(log_q_S, p_T)`
6. `loss_hard = CrossEntropy(z_S, y)`
7. `loss = alpha * loss_soft + (1 - alpha) * loss_hard`
8. `optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()` (the optimizer holds only the student's parameters, so only the student is updated)
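
A runnable version of the pseudocode, with `distill_epoch()` and `evaluate()` named after the blocks in the table at the top. This is a sketch under toy assumptions, not the original notebook: random 16-feature inputs replace the tokenized dataset, a small MLP and a single `nn.Linear` stand in for BERT and DistilBERT, and the teacher is trained for a few epochs first so it has something to teach. With Hugging Face models, `teacher(x)` / `student(x)` would become `model(**batch).logits`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def distill_epoch(teacher, student, loader, optimizer, T=2.0, alpha=0.5):
    """One epoch of distillation; returns the mean combined loss per batch."""
    teacher.eval()
    student.train()
    running = 0.0
    for x, y in loader:
        with torch.no_grad():                          # 1. teacher logits, no graph
            z_T = teacher(x)
        p_T = F.softmax(z_T / T, dim=-1)               # 2. soft targets
        z_S = student(x)                               # 3. student logits
        log_q_S = F.log_softmax(z_S / T, dim=-1)       # 4. student log-probs
        loss_soft = T**2 * F.kl_div(log_q_S, p_T, reduction="batchmean")  # 5.
        loss_hard = F.cross_entropy(z_S, y)            # 6. hard loss at T = 1
        loss = alpha * loss_soft + (1 - alpha) * loss_hard  # 7.
        optimizer.zero_grad()                          # 8. update the student only
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(loader)


@torch.no_grad()
def evaluate(model, loader):
    """Classification accuracy of `model` over `loader`."""
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        correct += (model(x).argmax(dim=-1) == y).sum().item()
        total += y.numel()
    return correct / total


# Toy task (hypothetical): the label is 1 when the first feature is positive.
torch.manual_seed(0)
X = torch.randn(512, 16)
Y = (X[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(X[:384], Y[:384]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(X[384:], Y[384:]), batch_size=64)

teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 2))  # stands in for BERT
student = nn.Linear(16, 2)                                               # stands in for DistilBERT

# Train the teacher normally first, then freeze it.
t_opt = torch.optim.AdamW(teacher.parameters(), lr=1e-2)
for _ in range(10):
    for x, y in train_loader:
        t_opt.zero_grad()
        F.cross_entropy(teacher(x), y).backward()
        t_opt.step()
teacher.requires_grad_(False)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-2)  # student parameters only
for epoch in range(3):
    train_loss = distill_epoch(teacher, student, train_loader, optimizer)
    acc = evaluate(student, val_loader)
    print(f"epoch {epoch}: distill loss {train_loss:.4f}, student val acc {acc:.3f}")
```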

---

## Quick hyperparameter suggestions
- `T = 2` (good starting point), try `2–5`.  
- `alpha = 0.3–0.7` depending on trust in teacher vs labels.  
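- Batch size: 64–256 (depends on the task and GPU memory).  
- Make sure the teacher is already accurate on the task before distilling; the student learns the teacher’s mistakes along with its knowledge.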

---

## Final note
Distillation is **not** just fine-tuning: it explicitly transfers the teacher’s learned distributional knowledge (soft targets) into a compact student while still respecting hard labels. For large-scale runs, a custom `Trainer` subclass (often named `DistillationTrainer`) or `accelerate` lets the same loop scale cleanly across devices.


```python
# !pip install --upgrade datasets fsspec transformers
```