# Lab 3.3: Knowledge Distillation - Inference Comparison

**Goal:** Compare the teacher and the distilled student model on prediction quality (agreement with the teacher) and inference performance (latency and model size).

**You will learn to:**
- Load distilled student model
- Compare predictions side-by-side
- Analyze agreement and disagreement cases
- Measure inference latency improvements
- Evaluate distillation effectiveness

---

## Prerequisites

Completed notebooks:
- **01-Setup.ipynb**: Environment and models
- **02-Distill.ipynb**: Distillation training

---
## Step 1: Load Models

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os

TEACHER_MODEL = "bert-base-uncased"
STUDENT_MODEL_DIR = "./distilled_student"

print("=" * 60)
print("Loading Models")
print("=" * 60)

# Fail fast with an actionable message. Without this check a missing local
# directory is typically treated as a Hub repo id by transformers, which raises
# a much less obvious error -- and only after the teacher has been downloaded.
if not os.path.isdir(STUDENT_MODEL_DIR):
    raise FileNotFoundError(
        f"Distilled student not found at {STUDENT_MODEL_DIR!r}. "
        "Run 02-Distill.ipynb first to train and save it."
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer (shared by teacher and student).
# NOTE(review): assumes the student uses the teacher's vocabulary -- confirm
# against the tokenizer saved by 02-Distill.ipynb.
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL)
print("✅ Tokenizer loaded\n")

# Load teacher.
# NOTE(review): the `bert-base-uncased` checkpoint ships no trained
# sequence-classification head, so `num_labels=2` attaches a freshly initialised
# head (transformers logs a "newly initialized" warning). Unless a fine-tuned
# teacher checkpoint is loaded here, teacher predictions are not meaningful --
# confirm which checkpoint 02-Distill.ipynb used as the teacher.
print("⏳ Loading teacher model...")
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    TEACHER_MODEL, num_labels=2
).to(device)
teacher_params = sum(p.numel() for p in teacher_model.parameters())
print(f"✅ Teacher loaded: {teacher_params/1e6:.2f}M params\n")

# Load distilled student (saved locally by the distillation notebook).
print("⏳ Loading distilled student model...")
student_model = AutoModelForSequenceClassification.from_pretrained(
    STUDENT_MODEL_DIR
).to(device)
student_params = sum(p.numel() for p in student_model.parameters())
print(f"✅ Student loaded: {student_params/1e6:.2f}M params")
print(f"   Compression: {teacher_params/student_params:.2f}x\n")

print("=" * 60)

---
## Step 2: Prepare Test Samples

In [None]:
# Hand-picked reviews, grouped by how hard they should be for the models.
# The groups are flattened (in insertion order) into `test_samples` below.
_SAMPLE_GROUPS = {
    "clear positive": [
        "This movie is absolutely fantastic! A masterpiece of cinema.",
        "I loved every minute of it. Brilliant performances all around.",
    ],
    "clear negative": [
        "Terrible waste of time. Boring and poorly executed.",
        "One of the worst films I've ever seen. Avoid at all costs.",
    ],
    "neutral/mixed": [
        "It was okay. Some good moments but nothing special.",
        "Mixed feelings about this one. Some parts work, others don't.",
    ],
    "sarcasm/difficult": [
        "Sure, if you enjoy being bored to death, this is perfect.",
        "Absolutely 'brilliant' way to waste two hours of your life.",
    ],
}
test_samples = [sample for group in _SAMPLE_GROUPS.values() for sample in group]

# Display name for each predicted class index.
# NOTE(review): assumes class 0 = negative, class 1 = positive -- confirm
# against the label mapping used during distillation.
label_names = ['Negative', 'Positive']

print(f"📝 Prepared {len(test_samples)} test samples")

---
## Step 3: Side-by-Side Inference Comparison

In [None]:
import time

print("=" * 80)
print("TEACHER vs STUDENT COMPARISON")
print("=" * 80)

teacher_model.eval()
student_model.eval()


def _sync_device():
    """Wait for queued GPU work so wall-clock timings cover the full forward pass."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _timed_predict(model, inputs):
    """Classify one tokenized sample with *model* and time it.

    Returns ``(probs, pred, elapsed)``: softmax probabilities of the first row
    of logits, the argmax class index, and elapsed seconds covering the
    forward pass, softmax and argmax.
    """
    _sync_device()
    # perf_counter is monotonic and high-resolution; time.time() can be coarse
    # enough that a fast forward pass measures as 0 s.
    start = time.perf_counter()
    with torch.no_grad():
        logits = model(**inputs).logits[0]
        probs = torch.softmax(logits, dim=-1)
        pred = torch.argmax(probs).item()
    _sync_device()
    return probs, pred, time.perf_counter() - start


# Warm-up: the first forward pass pays one-off costs (e.g. lazy initialisation,
# memory allocation, GPU kernel setup). Without this, those costs land on the
# teacher, which is timed first, and inflate the reported speedup.
if test_samples:
    warmup_inputs = tokenizer(
        test_samples[0], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    _timed_predict(teacher_model, warmup_inputs)
    _timed_predict(student_model, warmup_inputs)

results = []

for i, text in enumerate(test_samples, 1):
    print(f"\n{'='*80}")
    print(f"Test {i}: {text}")
    print(f"{'─'*80}")

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

    teacher_probs, teacher_pred, teacher_time = _timed_predict(teacher_model, inputs)
    student_probs, student_pred, student_time = _timed_predict(student_model, inputs)

    # Display results
    print(f"🟢 Teacher: {label_names[teacher_pred]:8s} (Neg: {teacher_probs[0]:.3f}, Pos: {teacher_probs[1]:.3f}) [{teacher_time*1000:.2f}ms]")
    print(f"🔵 Student: {label_names[student_pred]:8s} (Neg: {student_probs[0]:.3f}, Pos: {student_probs[1]:.3f}) [{student_time*1000:.2f}ms]")

    agreed = teacher_pred == student_pred
    agreement = "✅ Agreement" if agreed else "❌ Disagreement"
    # Guard against a zero-duration measurement rather than crashing the loop.
    speedup = teacher_time / student_time if student_time > 0 else float("inf")
    print(f"\n{agreement} | Speedup: {speedup:.2f}x")

    results.append({
        'text': text,
        'teacher_pred': teacher_pred,
        'student_pred': student_pred,
        'agreement': agreed,
        'teacher_time': teacher_time,
        'student_time': student_time
    })

print(f"\n{'='*80}")

---
## Step 4: Agreement Analysis

In [None]:
import numpy as np

print("=" * 60)
print("Agreement Analysis")
print("=" * 60)


def _preview(text, width=60):
    """Return *text* cut to *width* characters, adding '...' only if it was cut."""
    return text if len(text) <= width else text[:width] + "..."


agreements = [r['agreement'] for r in results]
n_agree = sum(agreements)

if not agreements:
    # np.mean([]) would yield nan (plus a RuntimeWarning) and fall through to
    # the "Low agreement" verdict, which is misleading.
    agreement_rate = float("nan")
    print("\n⚠️  No comparison results found -- run Step 3 first.")
else:
    agreement_rate = np.mean(agreements)

    print(f"\n📊 Agreement Statistics:")
    print(f"   Total samples: {len(results)}")
    print(f"   Agreements: {n_agree}")
    print(f"   Disagreements: {len(agreements) - n_agree}")
    print(f"   Agreement rate: {agreement_rate:.1%}")

    if agreement_rate >= 0.90:
        print("\n✅ Excellent agreement (>=90%)! Student learned teacher well.")
    elif agreement_rate >= 0.80:
        print("\n✅ Good agreement (80-90%). Student performs well.")
    elif agreement_rate >= 0.70:
        print("\n🟡 Moderate agreement (70-80%). Some discrepancies.")
    else:
        print("\n⚠️  Low agreement (<70%). Consider more distillation training.")

# Disagreement cases
disagreements = [r for r in results if not r['agreement']]
if disagreements:
    print(f"\n❌ Disagreement Cases ({len(disagreements)}):")
    for r in disagreements:
        print(f"   \"{_preview(r['text'])}\"")
        print(f"      Teacher: {label_names[r['teacher_pred']]}, Student: {label_names[r['student_pred']]}\n")

print("=" * 60)

---
## Step 5: Performance Comparison

In [None]:
import numpy as np  # local import so this cell does not depend on Step 4 having run

print("=" * 60)
print("Performance Comparison")
print("=" * 60)

teacher_times = [r['teacher_time'] for r in results]
student_times = [r['student_time'] for r in results]

if results:
    avg_teacher = np.mean(teacher_times) * 1000
    avg_student = np.mean(student_times) * 1000
    # Median is robust to occasional scheduler/GC spikes in single-sample timings.
    med_teacher = np.median(teacher_times) * 1000
    med_student = np.median(student_times) * 1000
    # Speedup as the ratio of mean latencies. A mean of per-sample ratios is
    # avoided: it is skewed upward by any unusually small student timing and
    # divides by zero if one rounds to 0.
    avg_speedup = avg_teacher / avg_student if avg_student > 0 else float("inf")

    print(f"\n📊 Inference Latency (single sample):")
    print(f"   Teacher: {avg_teacher:.2f}ms (avg), {med_teacher:.2f}ms (median)")
    print(f"   Student: {avg_student:.2f}ms (avg), {med_student:.2f}ms (median)")
    print(f"   Speedup: {avg_speedup:.2f}x")
else:
    print("\n⚠️  No timing results found -- run Step 3 first.")

print(f"\n📊 Model Size:")
print(f"   Teacher: {teacher_params/1e6:.1f}M params")
print(f"   Student: {student_params/1e6:.1f}M params")
print(f"   Compression: {teacher_params/student_params:.2f}x")

print("=" * 60)

---
## ✅ Inference Comparison Complete!

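**Summary**:
- ✅ Compared teacher vs student predictions
- ✅ Analyzed agreement rate (typically ~90%+; with only 8 samples the rate moves in 12.5% steps, so treat it as a rough indicator)
- ✅ Measured inference speedup (typically ~1.8-2.0x; the exact value depends on your hardware and student architecture)
- ✅ Identified challenging cases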

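**Key Findings** (typical values — your numbers depend on the student architecture, the training run, and your hardware):
- Student achieves high agreement with teacher
- ~2.1x compression (e.g., 110M → 52M params)
- 1.8-2.0x inference speedup
- Disagreements mainly on edge cases (e.g., the mixed and sarcastic samples)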

**Next Steps**:
- Proceed to **04-Benchmark.ipynb** for comprehensive evaluation

---

**⏭️ Continue to**: [04-Benchmark.ipynb](./04-Benchmark.ipynb)