# Notebook 26: Fine-Tuning vs Distillation

## Inference Engineering Course

---

## Overview

When a pre-trained model doesn't perform well enough for your specific use case, you have two main options for improvement:

### Fine-Tuning
Adapt the model's weights directly on your task-specific data. With **LoRA** (Low-Rank Adaptation), you can do this efficiently by training only a small number of new parameters while the original weights stay frozen.

### Knowledge Distillation
Train a smaller **student** model to mimic a larger **teacher** model. The goal is to keep as much of the teacher's quality as possible while gaining the student's speed and lower serving cost.

```
Fine-Tuning:                    Distillation:

Pre-trained Model               Teacher (Large)
     |                               |
     | + Task Data                   | Soft Labels
     v                               v
Fine-tuned Model               Student (Small)
(same size, better)            (faster, nearly as good)
```

### What You'll Learn

| Topic | Description |
|-------|-------------|
| LoRA Setup | Configure Parameter-Efficient Fine-Tuning (PEFT) |
| Fine-Tuning | Train a model on custom data |
| Distillation | Transfer knowledge from teacher to student |
| Comparison | Quality, cost, and speed trade-offs |
| Loss Curves | Visualize and interpret training dynamics |

### Prerequisites
- Understanding of neural networks and gradient descent
- Google Colab with GPU runtime (T4 is sufficient)

> **Important**: Enable GPU: `Runtime > Change runtime type > T4 GPU`

In [None]:
# ============================================================
# Install dependencies
# ============================================================
!pip install transformers datasets peft accelerate bitsandbytes -q
!pip install matplotlib numpy pandas scikit-learn -q

print("Installation complete!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import DataLoader, Dataset
import time
import warnings
warnings.filterwarnings('ignore')

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU. This notebook will use CPU (slower but functional).")

---

## Section 1: Understanding LoRA (Low-Rank Adaptation)

LoRA is one of the most widely used parameter-efficient fine-tuning methods. Instead of updating all model weights, it freezes them and adds small trainable **low-rank matrices** to selected layers.

### How LoRA Works

For a weight matrix $W \in \mathbb{R}^{d \times k}$, LoRA adds:

$$W_{new} = W_{frozen} + \frac{\alpha}{r} \cdot B \cdot A$$

Where:
- $A \in \mathbb{R}^{r \times k}$ (down projection)
- $B \in \mathbb{R}^{d \times r}$ (up projection)
- $r$ = rank (typically 4-64, much smaller than $\min(d, k)$)
- $\alpha$ = scaling factor
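
A dependency-free sketch of the forward pass above, using nested lists in place of tensors. It assumes the initialization described in the LoRA paper (A small random, B zero), so the adapted layer starts out identical to the frozen one. It also shows that after training the update can be folded into a single matrix, which is why a merged LoRA model has no extra inference cost. `lora_forward` and `merge` are illustrative names, not PEFT functions.

```python
import random


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x): frozen path plus the low-rank update."""
    base = matvec(W, x)                  # d outputs from the frozen d x k weight
    update = matvec(B, matvec(A, x))     # project k -> r with A, then r -> d with B
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def merge(W, A, B, alpha, r):
    """Fold the adapter into one d x k matrix: W + (alpha / r) * B A."""
    scale = alpha / r
    return [[W[i][j] + scale * sum(B[i][p] * A[p][j] for p in range(r))
             for j in range(len(W[0]))]
            for i in range(len(W))]


random.seed(0)
d, k, r, alpha = 6, 5, 2, 4.0
W = [[random.gauss(0, 1) for _ in range(k)] for _ in range(d)]       # frozen, d x k
A = [[random.gauss(0, 0.01) for _ in range(k)] for _ in range(r)]    # r x k, small random init
B = [[0.0] * r for _ in range(d)]                                     # d x r, zero init
x = [random.gauss(0, 1) for _ in range(k)]

# With B = 0 the adapted layer reproduces the frozen layer exactly.
y0 = lora_forward(W, A, B, x, alpha, r)

# Pretend training moved B away from zero: the merged matrix gives the same output.
B = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d)]
y1 = lora_forward(W, A, B, x, alpha, r)
y_merged = matvec(merge(W, A, B, alpha, r), x)
print("max |adapter - merged| =", max(abs(a - b) for a, b in zip(y1, y_merged)))
```

In PEFT, the merge step is exposed as `merge_and_unload()` on a trained `PeftModel`.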

### Parameter Savings

```
Full fine-tuning:  d * k parameters     (e.g., 4096 * 4096 = 16.7M)
LoRA:              r * (d + k) parameters (e.g., 8 * (4096 + 4096) = 65K)
Savings:           ~256x fewer parameters!
```
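
The arithmetic in the block above, as a small helper you can rerun for other shapes and ranks (function names are illustrative):

```python
def full_params(d: int, k: int) -> int:
    """Trainable parameters when updating the whole d x k weight matrix."""
    return d * k


def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters for B (d x r) plus A (r x k)."""
    return r * (d + k)


d = k = 4096
for r in (4, 8, 16, 64):
    ratio = full_params(d, k) / lora_params(d, k, r)
    print(f"r={r:>3}: {lora_params(d, k, r):>9,} trainable vs "
          f"{full_params(d, k):,} -> {ratio:.0f}x fewer")
```

For a square matrix the ratio simplifies to $d / 2r$, so doubling the rank halves the savings.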

In [None]:
# ============================================================
# Visual explanation of LoRA
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Plot 1: Full Fine-Tuning vs LoRA parameter count
d = 4096  # Hidden dimension
ranks = [1, 2, 4, 8, 16, 32, 64, 128]

full_params = d * d  # Full weight matrix
lora_params = [r * (d + d) for r in ranks]
savings = [full_params / lp for lp in lora_params]

ax = axes[0]
ax.bar(range(len(ranks)), [lp / 1e6 for lp in lora_params], 
       color='steelblue', alpha=0.8)
ax.axhline(y=full_params / 1e6, color='red', linestyle='--', 
           linewidth=2, label=f'Full Fine-Tuning ({full_params/1e6:.1f}M)')
ax.set_xlabel('LoRA Rank', fontsize=12)
ax.set_ylabel('Trainable Parameters (M)', fontsize=12)
ax.set_title('LoRA vs Full Fine-Tuning\n(Trainable Parameters)', fontsize=13, fontweight='bold')
ax.set_xticks(range(len(ranks)))
ax.set_xticklabels(ranks)
ax.legend(fontsize=10)
ax.set_yscale('log')
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: LoRA Architecture Diagram
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('LoRA Architecture', fontsize=13, fontweight='bold')

# Frozen weights
frozen = plt.Rectangle((0.5, 3), 3, 4, fill=True, facecolor='#E3F2FD', 
                        edgecolor='#1565C0', linewidth=2)
ax.add_patch(frozen)
ax.text(2, 5, 'W\n(Frozen)', ha='center', va='center', fontsize=14, fontweight='bold')

# LoRA A matrix
lora_a = plt.Rectangle((5.5, 4.5), 1, 2.5, fill=True, facecolor='#FFF3E0',
                        edgecolor='#E65100', linewidth=2)
ax.add_patch(lora_a)
ax.text(6, 5.75, 'A', ha='center', va='center', fontsize=14, fontweight='bold', color='#E65100')

# LoRA B matrix
lora_b = plt.Rectangle((7.5, 3), 2.5, 1, fill=True, facecolor='#FFF3E0',
                        edgecolor='#E65100', linewidth=2)
ax.add_patch(lora_b)
ax.text(8.75, 3.5, 'B', ha='center', va='center', fontsize=14, fontweight='bold', color='#E65100')

# Arrows and labels
ax.annotate('', xy=(5.5, 5.75), xytext=(3.5, 5.75),
            arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate('', xy=(8.75, 4.5), xytext=(8.75, 7),
            arrowprops=dict(arrowstyle='->', color='gray', lw=1.5))
ax.annotate('', xy=(8.75, 4), xytext=(6, 4.5),
            arrowprops=dict(arrowstyle='->', color='#E65100', lw=1.5))

ax.text(5, 2, 'x → W·x + (α/r)·B·A·x', ha='center', fontsize=11, style='italic',
        bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))
ax.text(5, 1, 'Only A and B are trained!', ha='center', fontsize=11, 
        fontweight='bold', color='#E65100')

# Plot 3: Memory comparison
ax = axes[2]
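model_sizes = ['GPT-2\n(124M)', 'Phi-2\n(2.7B)', 'Llama-7B\n(7B)', 'Llama-13B\n(13B)']
n_params_b = np.array([0.124, 2.7, 7.0, 13.0])  # parameter counts in billions
# Rough rule-of-thumb estimates, NOT measurements (activations and framework overhead excluded):
#  - Full fine-tuning with Adam in mixed precision: ~16 bytes/param
#    (fp16 weights + fp16 grads + fp32 master weights + two fp32 Adam moments)
#  - LoRA: frozen fp16 base at ~2 bytes/param; rank-8 adapters add well under 1%
full_memory = n_params_b * 16  # GB
lora_memory = n_params_b * 2   # GB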

x = np.arange(len(model_sizes))
width = 0.35
bars1 = ax.bar(x - width/2, full_memory, width, label='Full Fine-Tuning', 
               color='#E57373', alpha=0.8)
bars2 = ax.bar(x + width/2, lora_memory, width, label='LoRA (r=8)',
               color='#4CAF50', alpha=0.8)

ax.axhline(y=16, color='blue', linestyle=':', alpha=0.5, label='T4 GPU (16GB)')
ax.set_yscale('log')  # log scale keeps small and large models readable
ax.set_xlabel('Model', fontsize=12)
ax.set_ylabel('Est. GPU Memory, excl. activations (GB)', fontsize=12)
ax.set_title('GPU Memory: Full vs LoRA', fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(model_sizes)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('lora_overview.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 2: LoRA Fine-Tuning Setup with PEFT

Let's fine-tune a small model using the PEFT library from HuggingFace. We'll use `GPT-2` (124M params) to keep things fast on free Colab.
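
Before running PEFT, it helps to predict how many parameters LoRA will train. This sketch assumes GPT-2 small's published configuration (12 blocks, hidden size 768, MLP width 3072) and that PEFT matches each `target_modules` entry against the end of a module's name, so `"c_proj"` adapts both `attn.c_proj` and `mlp.c_proj`. Under those assumptions the count should line up with what `print_trainable_parameters()` reports for `r=8`.

```python
# Assumed GPT-2 small dimensions: 12 blocks, hidden size 768, MLP width 4 * 768.
N_LAYER, N_EMBD, N_INNER = 12, 768, 4 * 768

# (in_features, out_features) of each adapted projection in one transformer block.
# Assumption: "c_proj" matches both the attention and the MLP output projection.
TARGETS = {
    "attn.c_attn": (N_EMBD, 3 * N_EMBD),   # fused Q, K, V projection
    "attn.c_proj": (N_EMBD, N_EMBD),       # attention output projection
    "mlp.c_proj": (N_INNER, N_EMBD),       # MLP down projection
}


def lora_trainable(rank: int) -> int:
    """Count A (r x in) + B (out x r) parameters over all adapted projections."""
    per_block = sum(rank * (fan_in + fan_out) for fan_in, fan_out in TARGETS.values())
    return N_LAYER * per_block


print(f"r=8 -> {lora_trainable(8):,} trainable LoRA parameters")
```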

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType

# ============================================================
# Load base model and tokenizer
# ============================================================

MODEL_NAME = "gpt2"  # 124M parameters, fits easily on T4

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"\nBase model loaded: {MODEL_NAME}")
print(f"Total parameters: {total_params:,} ({total_params/1e6:.1f}M)")

In [None]:
# ============================================================
# Configure LoRA
# ============================================================

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                    # Rank - lower = fewer params, higher = more capacity
    lora_alpha=32,          # Scaling factor
    lora_dropout=0.05,      # Dropout for regularization
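    target_modules=["c_attn", "c_proj"],  # GPT-2 layer names; "c_proj" matches attention and MLP output projections
    fan_in_fan_out=True,    # GPT-2's Conv1D layers store weights as (in, out)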
    bias="none",
)

# Apply LoRA to the model
peft_model = get_peft_model(model, lora_config)

# Compare parameter counts
trainable_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in peft_model.parameters())

print("\nLoRA Configuration:")
print(f"  Rank: {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Target modules: {lora_config.target_modules}")
print(f"\nParameter Summary:")
print(f"  Total parameters:     {all_params:>12,}")
print(f"  Trainable (LoRA):     {trainable_params:>12,}")
print(f"  Frozen:               {all_params - trainable_params:>12,}")
print(f"  Trainable %:          {100 * trainable_params / all_params:>11.2f}%")
print(f"  Parameter reduction:  {all_params / trainable_params:>11.0f}x")

peft_model.print_trainable_parameters()

In [None]:
# ============================================================
# Create a custom fine-tuning dataset
# ============================================================
# Task: Customer support response generation

training_data = [
    {"input": "Customer: My order hasn't arrived yet.\nAgent:",
     "output": " I apologize for the delay. Let me look up your order number and check the shipping status for you right away."},
    {"input": "Customer: I want to return this product.\nAgent:",
     "output": " I'd be happy to help with your return. Could you provide your order number? Our return policy allows returns within 30 days."},
    {"input": "Customer: This product is broken!\nAgent:",
     "output": " I'm sorry to hear that. We'll get this resolved. Can you describe the issue? We can arrange a replacement or refund."},
    {"input": "Customer: How do I change my shipping address?\nAgent:",
     "output": " You can update your shipping address in your account settings. If you have a pending order, I can update it for you."},
    {"input": "Customer: I was charged twice for my order!\nAgent:",
     "output": " I sincerely apologize for the billing error. Let me investigate this immediately and process a refund for the duplicate charge."},
    {"input": "Customer: Can I get a discount on this item?\nAgent:",
     "output": " While I can't offer individual discounts, let me check for any active promotions or coupon codes that might apply to your purchase."},
    {"input": "Customer: The product doesn't match the description.\nAgent:",
     "output": " I'm sorry for the confusion. Can you tell me what's different from what you expected? We want to make sure our listings are accurate."},
    {"input": "Customer: I need help setting up my new device.\nAgent:",
     "output": " I'd be happy to walk you through the setup process. What device did you purchase and what step are you currently on?"},
    {"input": "Customer: When will the item be back in stock?\nAgent:",
     "output": " Let me check our inventory system for an estimated restock date. I can also set up a notification for you when it becomes available."},
    {"input": "Customer: I'm very disappointed with your service.\nAgent:",
     "output": " I'm truly sorry to hear about your experience. Your feedback is important to us. Please tell me more so I can make this right."},
] * 5  # Repeat 5x (50 examples); repetition adds no new information, only more passes per epoch

class CustomerSupportDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.examples = []
        for item in data:
            text = item['input'] + item['output'] + tokenizer.eos_token
            encoding = tokenizer(
                text,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
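            input_ids = encoding['input_ids'].squeeze(0)
            attention_mask = encoding['attention_mask'].squeeze(0)
            labels = input_ids.clone()
            # pad_token == eos_token here, so identify padding via attention_mask;
            # -100 tells the loss to ignore those positions
            labels[attention_mask == 0] = -100
            self.examples.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
            })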
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        return self.examples[idx]

dataset = CustomerSupportDataset(training_data, tokenizer)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

print(f"Dataset size: {len(dataset)} examples")
print(f"Batch size: 4")
print(f"Number of batches: {len(dataloader)}")
print(f"Sample input: {training_data[0]['input']}")

---

## Section 3: Fine-Tuning Training Loop
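
The loop below passes `labels` to the model and lets it compute the loss. Hugging Face causal-LM heads such as GPT-2's shift the labels one position internally (position *t* is scored on token *t+1*) and use PyTorch cross-entropy's default `ignore_index=-100`, so any label set to -100 contributes nothing. This is why padded positions are usually labeled -100: otherwise the model is trained to predict the pad token (here the EOS token) over and over. A dependency-free sketch of that computation, with toy log-probabilities in place of model output and hypothetical helper names:

```python
import math

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def causal_lm_loss(log_probs, labels):
    """Mean next-token NLL: position t is scored on labels[t + 1]; IGNORE_INDEX targets are skipped.

    log_probs: per-position lists of log-probabilities over the vocabulary.
    labels:    token ids of the same length, with IGNORE_INDEX on padding.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):      # the last position has no next token to predict
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue
        total -= log_probs[t][target]
        count += 1
    return total / count if count else float("nan")


# Toy example: 3-token vocabulary, uniform predictions, last slot is padding.
uniform = [math.log(1 / 3)] * 3
log_probs = [uniform] * 4
labels = [0, 2, 1, IGNORE_INDEX]
print(causal_lm_loss(log_probs, labels))  # ≈ 1.0986 (= log 3); only 2 of 3 targets count
```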

In [None]:
# ============================================================
# Training loop for LoRA fine-tuning
# ============================================================

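# Only the LoRA parameters require gradients; pass just those to the optimizer
optimizer = torch.optim.AdamW(
    [p for p in peft_model.parameters() if p.requires_grad], lr=3e-4, weight_decay=0.01
)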
num_epochs = 5

# Track metrics
train_losses = []
epoch_losses = []
step_times = []

peft_model.train()
print("Starting LoRA fine-tuning...")
print("=" * 60)

total_start = time.time()

for epoch in range(num_epochs):
    epoch_loss = 0
    num_batches = 0
    
    for batch in dataloader:
        step_start = time.time()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = peft_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(peft_model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        
        step_time = time.time() - step_start
        step_times.append(step_time)
        train_losses.append(loss.item())
        epoch_loss += loss.item()
        num_batches += 1
    
    avg_epoch_loss = epoch_loss / num_batches
    epoch_losses.append(avg_epoch_loss)
    
    print(f"Epoch {epoch+1}/{num_epochs}: Loss = {avg_epoch_loss:.4f} | "
          f"Avg step time: {np.mean(step_times[-num_batches:]):.3f}s")

total_time = time.time() - total_start
print(f"\nTraining complete in {total_time:.1f}s")
print(f"Final loss: {epoch_losses[-1]:.4f}")
print(f"Loss reduction: {epoch_losses[0] - epoch_losses[-1]:.4f} ({(1 - epoch_losses[-1]/epoch_losses[0])*100:.1f}%)")

In [None]:
# ============================================================
# Test the fine-tuned model
# ============================================================

peft_model.eval()

test_prompts = [
    "Customer: My package was damaged during shipping.\nAgent:",
    "Customer: Can I cancel my subscription?\nAgent:",
    "Customer: The app keeps crashing on my phone.\nAgent:",
]

print("Fine-tuned Model Outputs:")
print("=" * 60)

for prompt in test_prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    with torch.no_grad():
        output = peft_model.generate(
            **inputs,
            max_new_tokens=60,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
    
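    # Decode only the newly generated tokens (everything after the prompt)
    prompt_len = inputs['input_ids'].shape[1]
    generated = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    print(f"\nPrompt: {prompt}")
    print(f"Response: {generated[:200]}")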
    print("-" * 60)

---

## Section 4: Knowledge Distillation - Teacher to Student

Knowledge distillation transfers the "knowledge" of a larger model (teacher) to a smaller model (student).

### Key Concept: Soft Labels

Instead of training only on hard labels (one-hot vectors), the student also learns from the teacher's full **probability distribution** over all tokens. That distribution encodes which alternatives the teacher considers plausible, information a one-hot label throws away.

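$$\mathcal{L}_{distill} = \alpha \cdot T^2 \cdot \text{KL}\left(\text{softmax}(z_t/T) \,\|\, \text{softmax}(z_s/T)\right) + (1-\alpha) \cdot \mathcal{L}_{CE}(z_s, y)$$

Where:
- $z_t, z_s$ = teacher and student logits; $y$ = true labels
- $T$ = temperature (higher = softer distributions)
- $T^2$ rescales the KL term so its gradients stay comparable in magnitude as $T$ changes (the `distillation_loss` function below applies this factor)
- $\alpha$ = balance between distillation and task loss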
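
A pure-Python sketch of this loss for a single example, mirroring the notebook's `distillation_loss` (including its $T^2$ factor on the KL term). The teacher logits reuse the first four values from the "cat sat on the" example below; the student logits are made up for illustration, and the helper names are hypothetical.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax (max subtracted for numerical stability)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i), skipping zero-probability terms."""
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def distill_loss(z_t, z_s, label, T=3.0, alpha=0.7):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, label)."""
    soft = kl(softmax(z_t, T), softmax(z_s, T)) * T ** 2
    hard = -math.log(softmax(z_s)[label])
    return alpha * soft + (1 - alpha) * hard


z_teacher = [5.2, 3.8, 3.1, 2.5]   # 'mat', 'floor', 'chair', 'bed'
z_student = [2.0, 1.5, 0.5, 0.0]   # made-up student logits
for T in (1.0, 3.0):
    top = max(softmax(z_teacher, T))
    print(f"T={T}: teacher top prob {top:.3f}, "
          f"loss {distill_loss(z_teacher, z_student, 0, T):.3f}")
```

Raising $T$ lowers the teacher's top probability and spreads mass onto the alternatives, which is exactly the extra signal the student is meant to pick up.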

In [None]:
# ============================================================
# Visualize: Hard vs Soft Labels
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Simulated logits for the word after "The cat sat on the"
tokens = ['mat', 'floor', 'chair', 'bed', 'table', 'roof', 'car', 'moon', 'code', 'xyz']
logits = np.array([5.2, 3.8, 3.1, 2.5, 2.0, 1.2, 0.5, -0.3, -1.5, -3.0])

# Hard label (one-hot)
hard = np.zeros(len(tokens))
hard[0] = 1.0
axes[0].bar(tokens, hard, color='#E57373', alpha=0.8)
axes[0].set_title('Hard Labels (One-Hot)', fontsize=13, fontweight='bold')
axes[0].set_ylabel('Probability', fontsize=11)
axes[0].set_ylim(0, 1.1)
axes[0].tick_params(axis='x', rotation=45)

# Soft labels (T=1, standard softmax)
soft_t1 = np.exp(logits) / np.sum(np.exp(logits))
axes[1].bar(tokens, soft_t1, color='#4CAF50', alpha=0.8)
axes[1].set_title('Soft Labels (T=1)', fontsize=13, fontweight='bold')
axes[1].set_ylabel('Probability', fontsize=11)
axes[1].set_ylim(0, 1.1)
axes[1].tick_params(axis='x', rotation=45)

# Soft labels with temperature (T=3)
T = 3
soft_t3 = np.exp(logits/T) / np.sum(np.exp(logits/T))
axes[2].bar(tokens, soft_t3, color='#2196F3', alpha=0.8)
axes[2].set_title('Soft Labels (T=3, "Softer")', fontsize=13, fontweight='bold')
axes[2].set_ylabel('Probability', fontsize=11)
axes[2].set_ylim(0, 1.1)
axes[2].tick_params(axis='x', rotation=45)

# Add annotations
for ax in axes:
    ax.grid(True, alpha=0.3, axis='y')

fig.suptitle('Why Soft Labels Carry More Information',
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('soft_labels.png', dpi=150, bbox_inches='tight')
plt.show()

print("With hard labels, the student only learns 'mat' is correct.")
print("With soft labels, it also learns 'floor' and 'chair' are reasonable alternatives.")
print("Higher temperature reveals more of this inter-token similarity.")

In [None]:
# ============================================================
# Implement distillation with simple neural networks
# ============================================================
# For demonstration, we use small feedforward networks on a synthetic 10-class task.
# For LLMs the same loss is applied at every token position over the vocabulary.

# Create a classification task (a stand-in for next-token prediction)
np.random.seed(42)
torch.manual_seed(42)

# Generate synthetic data
n_samples = 2000
n_features = 50
n_classes = 10

X = torch.randn(n_samples, n_features)
# Create labels with some structure
W_true = torch.randn(n_features, n_classes)
logits_true = X @ W_true
y = logits_true.argmax(dim=1)

# Split data
train_X, test_X = X[:1600], X[1600:]
train_y, test_y = y[:1600], y[1600:]

# ---- Teacher Model (Large) ----
class TeacherModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes),
        )
    
    def forward(self, x):
        return self.layers(x)

# ---- Student Model (Small) ----
class StudentModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )
    
    def forward(self, x):
        return self.layers(x)

teacher = TeacherModel(n_features, 256, n_classes).to(device)
teacher_params = sum(p.numel() for p in teacher.parameters())

student_standalone = StudentModel(n_features, 64, n_classes).to(device)
student_distilled = StudentModel(n_features, 64, n_classes).to(device)
student_params = sum(p.numel() for p in student_standalone.parameters())

print(f"Teacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"Size ratio: {teacher_params / student_params:.1f}x")

In [None]:
# ============================================================
# Train the Teacher Model
# ============================================================

def train_model(model, train_X, train_y, epochs=50, lr=1e-3, verbose=True):
    """Full-batch training loop: the whole training set is one batch per epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []
    
    X = train_X.to(device)
    y = train_y.to(device)
    
    model.train()
    for epoch in range(epochs):
        logits = model(X)
        loss = criterion(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        if verbose and (epoch + 1) % 10 == 0:
            acc = (logits.argmax(dim=1) == y).float().mean()
            print(f"  Epoch {epoch+1}/{epochs}: Loss={loss.item():.4f}, Acc={acc:.3f}")
    
    return losses

print("Training Teacher Model...")
teacher_losses = train_model(teacher, train_X, train_y, epochs=50)

# Evaluate teacher
teacher.eval()
with torch.no_grad():
    teacher_acc = (teacher(test_X.to(device)).argmax(dim=1) == test_y.to(device)).float().mean()
print(f"\nTeacher test accuracy: {teacher_acc:.3f}")

In [None]:
# ============================================================
# Train Student (standalone - no distillation)
# ============================================================

print("Training Student (standalone, no distillation)...")
student_standalone_losses = train_model(student_standalone, train_X, train_y, epochs=50)

student_standalone.eval()
with torch.no_grad():
    standalone_acc = (student_standalone(test_X.to(device)).argmax(dim=1) == test_y.to(device)).float().mean()
print(f"\nStudent (standalone) test accuracy: {standalone_acc:.3f}")

In [None]:
# ============================================================
# Train Student with Distillation
# ============================================================

def distillation_loss(student_logits, teacher_logits, labels, 
                      temperature=3.0, alpha=0.7):
    """
    Combined distillation + task loss.
    
    Args:
        student_logits: Student model output logits
        teacher_logits: Teacher model output logits
        labels: True labels
        temperature: Softening temperature
        alpha: Weight for distillation loss (1-alpha for task loss)
    """
    # Soft targets from teacher
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    
    # KL divergence loss (distillation)
    distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
    
    # Standard cross-entropy loss (task)
    task_loss = F.cross_entropy(student_logits, labels)
    
    # Combined loss
    total_loss = alpha * distill_loss + (1 - alpha) * task_loss
    
    return total_loss, distill_loss.item(), task_loss.item()


def train_with_distillation(student, teacher, train_X, train_y, 
                            epochs=50, lr=1e-3, temperature=3.0, alpha=0.7):
    """Train student with knowledge distillation."""
    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    
    X = train_X.to(device)
    y = train_y.to(device)
    
    losses = []
    distill_losses = []
    task_losses = []
    
    teacher.eval()
    student.train()
    
    for epoch in range(epochs):
        # Get teacher predictions (no gradient needed)
        with torch.no_grad():
            teacher_logits = teacher(X)
        
        # Student forward pass
        student_logits = student(X)
        
        # Compute combined loss
        loss, d_loss, t_loss = distillation_loss(
            student_logits, teacher_logits, y,
            temperature=temperature, alpha=alpha
        )
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        distill_losses.append(d_loss)
        task_losses.append(t_loss)
        
        if (epoch + 1) % 10 == 0:
            acc = (student_logits.argmax(dim=1) == y).float().mean()
            print(f"  Epoch {epoch+1}/{epochs}: Loss={loss.item():.4f} "
                  f"(Distill={d_loss:.4f}, Task={t_loss:.4f}), Acc={acc:.3f}")
    
    return losses, distill_losses, task_losses


print("Training Student (with distillation from Teacher)...")
distill_losses, distill_kl_losses, distill_task_losses = train_with_distillation(
    student_distilled, teacher, train_X, train_y,
    epochs=50, temperature=3.0, alpha=0.7
)

student_distilled.eval()
with torch.no_grad():
    distilled_acc = (student_distilled(test_X.to(device)).argmax(dim=1) == test_y.to(device)).float().mean()
print(f"\nStudent (distilled) test accuracy: {distilled_acc:.3f}")

---

## Section 5: Comparing Results

In [None]:
# ============================================================
# Visualization: Loss Curves Comparison
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

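# Plot 1: All loss curves
# Note: the distilled curve is the combined loss (alpha * T^2 * KL + (1 - alpha) * CE),
# so its scale is not directly comparable to the plain cross-entropy curves.
axes[0].plot(teacher_losses, label='Teacher', linewidth=2, color='#2196F3')
axes[0].plot(student_standalone_losses, label='Student (standalone)', 
             linewidth=2, color='#FF9800', linestyle='--')
axes[0].plot(distill_losses, label='Student (distilled, combined loss)', 
             linewidth=2, color='#4CAF50')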
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss Curves', fontsize=13, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Plot 2: Distillation loss components
axes[1].plot(distill_kl_losses, label='KL Divergence (Distill)', 
             linewidth=2, color='#9C27B0')
axes[1].plot(distill_task_losses, label='Cross-Entropy (Task)', 
             linewidth=2, color='#FF5722')
axes[1].plot(distill_losses, label='Combined', linewidth=2, 
             color='#4CAF50', linestyle='--')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Loss', fontsize=12)
axes[1].set_title('Distillation Loss Components', fontsize=13, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

# Plot 3: Accuracy comparison bar chart
models = ['Teacher\n(Large)', 'Student\n(Standalone)', 'Student\n(Distilled)']
accuracies = [teacher_acc.item(), standalone_acc.item(), distilled_acc.item()]
colors = ['#2196F3', '#FF9800', '#4CAF50']

bars = axes[2].bar(models, accuracies, color=colors, alpha=0.8, width=0.5,
                   edgecolor='white', linewidth=2)
for bar, acc in zip(bars, accuracies):
    axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
                f'{acc:.1%}', ha='center', va='bottom', fontweight='bold', fontsize=13)

axes[2].set_ylabel('Test Accuracy', fontsize=12)
axes[2].set_title('Model Accuracy Comparison', fontsize=13, fontweight='bold')
axes[2].set_ylim(0, 1.1)
axes[2].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nKey Takeaway:")
print(f"  Standalone student: {standalone_acc:.1%} | Distilled student: {distilled_acc:.1%}")
print(f"  Difference from distillation: {(distilled_acc - standalone_acc)*100:+.1f} percentage points")
print(f"  The student has only {student_params/teacher_params*100:.1f}% of the teacher's parameters.")

In [None]:
# ============================================================
# Measure inference speed comparison
# ============================================================

def benchmark_inference(model, X, n_runs=100):
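    """Return mean and std inference latency in milliseconds.

    For tiny MLPs on a GPU, kernel-launch overhead dominates, so measured
    speedups can be much smaller than the difference in FLOPs suggests.
    """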
    model.eval()
    X_device = X.to(device)
    
    # Warmup
    for _ in range(10):
        with torch.no_grad():
            _ = model(X_device)
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()  # monotonic, high-resolution timer
        with torch.no_grad():
            _ = model(X_device)
        if device.type == 'cuda':
            torch.cuda.synchronize()  # wait for the GPU before stopping the clock
        times.append(time.perf_counter() - start)
    
    return np.mean(times) * 1000, np.std(times) * 1000  # ms

# Benchmark all models
test_batch = test_X[:32]  # Use batch of 32

teacher_time, teacher_std = benchmark_inference(teacher, test_batch)
student_time, student_std = benchmark_inference(student_standalone, test_batch)
distill_time, distill_std = benchmark_inference(student_distilled, test_batch)

print("\nInference Speed Benchmark (32 samples):")
print("=" * 50)
print(f"Teacher:           {teacher_time:.2f} +/- {teacher_std:.2f} ms")
print(f"Student:           {student_time:.2f} +/- {student_std:.2f} ms")
print(f"Student (distill): {distill_time:.2f} +/- {distill_std:.2f} ms")
print(f"\nSpeedup (teacher vs distilled): {teacher_time/distill_time:.2f}x")

In [None]:
# ============================================================
# Comprehensive comparison: Quality, Speed, Cost
# ============================================================

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

model_names = ['Teacher', 'Student\n(Standalone)', 'Student\n(Distilled)', 'LoRA\n(hypothetical)']
colors = ['#2196F3', '#FF9800', '#4CAF50', '#9C27B0']

# HYPOTHETICAL placeholders, not measurements: no LoRA model was trained on this
# synthetic task. We assume task accuracy near the teacher's and a small latency
# cost for unmerged adapters (merged adapters add no latency).
lora_acc = min(teacher_acc.item() * 0.98, 1.0)
lora_time = teacher_time * 1.05

# Plot 1: Accuracy
accs = [teacher_acc.item(), standalone_acc.item(), distilled_acc.item(), lora_acc]
bars = axes[0].bar(model_names, accs, color=colors, alpha=0.8)
for bar, acc in zip(bars, accs):
    axes[0].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.005,
                f'{acc:.1%}', ha='center', fontweight='bold', fontsize=11)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('Model Quality', fontsize=13, fontweight='bold')
axes[0].set_ylim(0, 1.15)
axes[0].grid(True, alpha=0.3, axis='y')

# Plot 2: Inference Speed
times = [teacher_time, student_time, distill_time, lora_time]
bars = axes[1].bar(model_names, times, color=colors, alpha=0.8)
for bar, t in zip(bars, times):
    axes[1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.02,
                f'{t:.2f}ms', ha='center', fontweight='bold', fontsize=11)
axes[1].set_ylabel('Inference Time (ms)', fontsize=12)
axes[1].set_title('Inference Speed', fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

# Plot 3: Parameter Count
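# A merged LoRA model has the same size as its base (the teacher here). The GPT-2
# `trainable_params` from Section 2 is not comparable to these small MLPs.
params = [teacher_params, student_params, student_params, teacher_params]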
bars = axes[2].bar(model_names, [p/1000 for p in params], color=colors, alpha=0.8)
for bar, p in zip(bars, params):
    axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.5,
                f'{p/1000:.1f}K', ha='center', fontweight='bold', fontsize=11)
axes[2].set_ylabel('Parameters (K)', fontsize=12)
axes[2].set_title('Model Size', fontsize=13, fontweight='bold')
axes[2].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('full_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# Effect of temperature on distillation quality
# ============================================================

temperatures = [1.0, 2.0, 3.0, 5.0, 8.0, 10.0, 15.0, 20.0]
temp_accuracies = []

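print("Sweeping distillation temperature (single seed, so small differences may be noise)...")
for temp in temperatures:
    torch.manual_seed(42)  # identical student initialization at every temperature
    student_temp = StudentModel(n_features, 64, n_classes).to(device)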
    train_with_distillation(
        student_temp, teacher, train_X, train_y,
        epochs=50, temperature=temp, alpha=0.7
    )
    
    student_temp.eval()
    with torch.no_grad():
        acc = (student_temp(test_X.to(device)).argmax(dim=1) == test_y.to(device)).float().mean()
    temp_accuracies.append(acc.item())
    print(f"  T={temp:>5.1f}: Accuracy = {acc:.3f}")

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(temperatures, temp_accuracies, 'b-o', linewidth=2, markersize=8)
ax.axhline(y=teacher_acc.item(), color='red', linestyle='--', 
           label=f'Teacher ({teacher_acc:.3f})', alpha=0.7)
ax.axhline(y=standalone_acc.item(), color='orange', linestyle='--',
           label=f'Student standalone ({standalone_acc:.3f})', alpha=0.7)

best_temp = temperatures[np.argmax(temp_accuracies)]
best_acc = max(temp_accuracies)
ax.annotate(f'Best: T={best_temp}, Acc={best_acc:.3f}',
            xy=(best_temp, best_acc),
            xytext=(best_temp + 3, best_acc - 0.02),
            arrowprops=dict(arrowstyle='->', color='green', lw=2),
            fontsize=12, fontweight='bold', color='green')

ax.set_xlabel('Temperature', fontsize=12)
ax.set_ylabel('Test Accuracy', fontsize=12)
ax.set_title('Distillation Temperature Sweep', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('temperature_sweep.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Section 6: When to Use What

| Method | Best When | Limitations |
|--------|-----------|-------------|
| **LoRA Fine-Tuning** | You need task-specific quality gains on a limited compute budget | Model size and inference speed are unchanged (once adapters are merged) |
| **Full Fine-Tuning** | You need maximum quality and have the compute and memory to update every weight | Expensive; risk of catastrophic forgetting |
| **Distillation** | You need a smaller, faster model | Requires a strong teacher; more complex training setup |
| **Distillation + LoRA** | You need a small, fast model that is also adapted to your task | Most complex to set up |
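Why does LoRA leave inference speed unchanged? Here is a minimal sketch in plain PyTorch with made-up shapes. A LoRA layer computes `W x + (alpha / r) * B A x`. Because `B A` is an ordinary `d_out x d_in` matrix, it can be folded into `W` after training. This leaves a single dense layer of the original shape. For LoRA adapters, `peft` exposes this folding as `PeftModel.merge_and_unload()`.

```python
import torch

torch.manual_seed(0)
d_in, d_out, r, lora_alpha = 64, 32, 8, 16
scaling = lora_alpha / r  # standard LoRA scaling factor (alpha / r)

W_base = torch.randn(d_out, d_in)        # frozen pre-trained weight
lora_A = torch.randn(r, d_in) * 0.01     # low-rank down-projection (r x d_in)
lora_B = torch.randn(d_out, r) * 0.01    # up-projection; zero at init in practice, random here so the update is non-trivial

x_demo = torch.randn(5, d_in)

# Unmerged: base path plus a separate low-rank path (extra matmuls on every call)
y_unmerged = x_demo @ W_base.T + scaling * (x_demo @ lora_A.T) @ lora_B.T

# Merged: fold the low-rank update into one dense matrix with the original shape
W_merged = W_base + scaling * (lora_B @ lora_A)
y_merged = x_demo @ W_merged.T

print(W_merged.shape)                                   # torch.Size([32, 64])
print(torch.allclose(y_unmerged, y_merged, atol=1e-5))  # True
```

The trade-off is that keeping adapters unmerged lets you swap tasks at runtime, at the cost of a small per-call overhead.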

### Decision Framework

```
Need better quality on specific task?
├── YES: Is inference speed/cost critical?
│   ├── YES: Use distillation (smaller student)
│   └── NO: Use LoRA fine-tuning (fastest to set up)
└── NO: Use the base model with better prompting
```

---

## Summary & Key Takeaways

| Concept | Key Insight |
|---------|-------------|
| **LoRA** | Typically trains <1% of the parameters for task-specific adaptation |
| **Rank (r)** | Controls the capacity vs efficiency trade-off; trainable parameters grow linearly with r |
| **Distillation** | Soft labels carry more information than hard labels |
| **Temperature** | Higher T softens the teacher's distribution and exposes its relative preferences among wrong classes, but too high a T flattens it toward uniform; tune T empirically |
| **Combined Loss** | Balance the distillation and task losses with the alpha parameter |
| **Speed vs Quality** | Distillation yields smaller, faster models; LoRA improves task quality at the base model's size and speed |
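Here is a small numeric illustration of the Temperature row, using made-up teacher logits for one example. Dividing the logits by T before the softmax never changes which class ranks first. What it does change is how much probability mass the "wrong" classes receive, and that relative information is what the student learns from. At very large T the distribution approaches uniform, which is one reason a temperature sweep need not improve monotonically.

```python
import torch
import torch.nn.functional as F

# Made-up teacher logits for a single example over 4 classes
demo_logits = torch.tensor([6.0, 2.0, 1.0, -1.0])

demo_temps = [1.0, 2.0, 5.0, 20.0]
entropies, top_classes = [], []
for T in demo_temps:
    probs = F.softmax(demo_logits / T, dim=-1)
    # Shannon entropy in nats; the maximum for 4 classes is ln(4), about 1.386
    entropy = -(probs * probs.log()).sum().item()
    entropies.append(entropy)
    top_classes.append(probs.argmax().item())
    print(f"T={T:>4}: probs={[round(p, 3) for p in probs.tolist()]}, entropy={entropy:.3f} nats")
```

At T=1 the second-best class gets under 2% of the mass. At T=5 its share is large enough to carry a useful training signal.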

---

## Exercises

### Exercise 1: Rank Ablation
Try different LoRA ranks (1, 4, 8, 16, 32, 64) and plot accuracy vs trainable parameters. What's the sweet spot?
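Before training anything, you can predict the x-axis of this plot in closed form. A LoRA adapter on a `d_in -> d_out` layer adds `r * (d_in + d_out)` parameters (matrices A and B, no bias), so the total is linear in r. The sketch below assumes `MODEL_NAME` is GPT-2 small (hidden size 768, 12 blocks). It also assumes `target_modules=["c_attn", "c_proj"]` matches both the attention and MLP `c_proj` layers, since `peft` matches on module-name suffixes. If those assumptions hold, the counts should agree with what the starter code prints.

```python
# Assumed GPT-2 small shapes: hidden size 768, 12 transformer blocks
N_LAYERS = 12
HIDDEN = 768
# (d_in, d_out) of each LoRA-targeted projection inside one block (assumption)
TARGETED = {
    "attn.c_attn": (HIDDEN, 3 * HIDDEN),   # fused Q/K/V projection
    "attn.c_proj": (HIDDEN, HIDDEN),       # attention output projection
    "mlp.c_proj": (4 * HIDDEN, HIDDEN),    # MLP down-projection, also matched by "c_proj"
}
BASE_PARAMS_APPROX = 124e6  # GPT-2 small is roughly 124M parameters


def lora_params(r):
    """Trainable LoRA parameters: each layer gets A (r x d_in) and B (d_out x r)."""
    per_block = sum(r * (d_in + d_out) for d_in, d_out in TARGETED.values())
    return N_LAYERS * per_block


for r in [1, 4, 8, 16, 32, 64]:
    n = lora_params(r)
    print(f"rank {r:>2}: {n:>10,} trainable params (~{n / BASE_PARAMS_APPROX:.2%} of base)")
```

If your printed counts differ, `target_modules` is probably matching a different set of layers than assumed here.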

### Exercise 2: Alpha Sweep
Vary the alpha parameter in distillation (0.1 to 0.9) and find the optimal balance between distillation and task loss.
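For reference while sweeping, here is a minimal sketch of a Hinton-style combined loss. It assumes, as is common, that `alpha` weights the soft-target (distillation) term and `1 - alpha` weights the hard-label task term. The notebook's `train_with_distillation` may use a different convention, so check it before comparing numbers. The `temperature ** 2` factor is the usual correction that keeps the soft-term gradients on a comparable scale as T changes. `combined_kd_loss` and the toy tensors are hypothetical.

```python
import torch
import torch.nn.functional as F


def combined_kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, labels)."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # kl_div expects log-probabilities as input and probabilities as target
    soft = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard


# Toy batch: 16 examples, 3 classes (illustrative values only)
torch.manual_seed(0)
s_logits = torch.randn(16, 3)
t_logits = torch.randn(16, 3) * 3
y_demo = t_logits.argmax(dim=-1)
for a in [0.1, 0.3, 0.5, 0.7, 0.9]:
    print(f"alpha={a:.1f}: loss={combined_kd_loss(s_logits, t_logits, y_demo, alpha=a).item():.4f}")
```

At `alpha=0` this reduces to plain cross-entropy (training the student standalone). At `alpha=1` the student ignores the ground-truth labels entirely.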

### Exercise 3: Student Architecture Search
Try different student sizes (hidden dims: 16, 32, 64, 128). At what size does distillation stop helping?

### Exercise 4: Real LLM Distillation
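Using the Hugging Face `transformers` library, implement distillation from GPT-2 Medium (355M) to GPT-2 Small (124M) on a text classification task. Fine-tune the teacher on the task first, since a freshly initialized classification head carries no task knowledge to transfer.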

In [None]:
# ============================================================
# Exercise 1 Starter: Rank Ablation Study
# ============================================================

# lora_ranks = [1, 4, 8, 16, 32, 64]
# rank_results = []
#
# for rank in lora_ranks:
#     config = LoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         r=rank,
#         lora_alpha=32,  # fixed alpha: the effective scale alpha/r shrinks as r grows
#         lora_dropout=0.05,
#         target_modules=["c_attn", "c_proj"],
#     )
#     
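#     # Reload a fresh base model so adapters from earlier ranks don't accumulate
#     base = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
#     peft_model = get_peft_model(base, config)  # avoid shadowing the `peft` package name
#     
#     n_trainable = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
#     
#     # Train and evaluate...
#     rank_results.append({'rank': rank, 'params': n_trainable})
#     print(f"Rank {rank}: {n_trainable:,} trainable parameters")
#
#     # Free GPU memory before loading the next model
#     del peft_model, base
#     torch.cuda.empty_cache()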

print("Uncomment the code above to run the rank ablation study!")