In [1]:
# Knowledge Distillation in PyTorch - Complete Tutorial
# This notebook demonstrates how to train a small "student" model to mimic a larger "teacher" model

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

# ============================================================================
# STEP 1: Define Teacher Model (Large, Complex Network)
# ============================================================================

class TeacherModel(nn.Module):
    """Large CNN 'expert' teacher for single-channel 28x28 (MNIST-sized) images.

    Three (3x3 conv, padding=1) -> ReLU -> 2x2 max-pool stages shrink the
    spatial size 28 -> 14 -> 7 -> 3 (floor), so each sample's feature map is
    flattened to 256 * 3 * 3 values. Two fully connected layers, with dropout
    in between, follow.
    forward() returns raw (unnormalised) logits over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. The previous view(-1, 2304) could silently fold
        # extra spatial positions into the batch dimension for wrongly-sized
        # inputs; flattening keeps the batch size and lets fc1 reject them.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# ============================================================================
# STEP 2: Define Student Model (Small, Simple Network)
# ============================================================================

class StudentModel(nn.Module):
    """Small CNN student for single-channel 28x28 (MNIST-sized) images.

    Two (3x3 conv, padding=1) -> ReLU -> 2x2 max-pool stages shrink the
    spatial size 28 -> 14 -> 7, so each sample is flattened to 32 * 7 * 7
    values. Two fully connected layers follow.
    forward() returns raw (unnormalised) logits over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample so a wrongly-sized input makes fc1 raise, rather
        # than view(-1, 1568) silently changing the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# ============================================================================
# STEP 3: Knowledge Distillation Loss Function
# ============================================================================

class DistillationLoss(nn.Module):
    """
    Knowledge-distillation objective combining two losses:
    1. Distillation loss: KL divergence between the temperature-softened
       teacher and student distributions, scaled by T**2.
    2. Student loss: cross-entropy between the student logits and true labels.

    Temperature (T): softens the probability distributions.
    - Higher T: flatter distribution (more information about non-target classes)
    - Lower T: sharper distribution (closer to hard labels)

    Alpha: weight of the distillation term; (1 - alpha) weights the student term.

    Args:
        temperature: Softening temperature; must be > 0.
        alpha: Mixing weight in [0, 1].

    Raises:
        ValueError: If temperature <= 0 or alpha is outside [0, 1].

    forward(student_logits, teacher_logits, labels) returns the tuple
    (total_loss, distillation_loss, student_loss).
    """
    def __init__(self, temperature=3.0, alpha=0.7):
        super().__init__()
        # T <= 0 divides logits by zero / flips their sign -> NaN or nonsense loss.
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # Outside [0, 1] one of the two terms would get a negative weight.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.temperature = temperature
        self.alpha = alpha
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        self.kl_div = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        t = self.temperature
        soft_student = F.log_softmax(student_logits / t, dim=1)
        soft_teacher = F.softmax(teacher_logits / t, dim=1)

        # Gradients of the softened term scale like 1/T**2; multiplying by
        # T**2 keeps its magnitude comparable to the hard-label term.
        distillation_loss = self.kl_div(soft_student, soft_teacher) * (t ** 2)

        student_loss = self.ce_loss(student_logits, labels)

        total_loss = self.alpha * distillation_loss + (1 - self.alpha) * student_loss

        return total_loss, distillation_loss, student_loss

# ============================================================================
# STEP 4: Data Loading (MNIST Dataset)
# ============================================================================

def get_data_loaders(batch_size=64, root='./data'):
    """Prepare MNIST train/test DataLoaders.

    Images are converted to tensors and normalised with mean 0.1307 and
    std 0.3081. The training loader shuffles every epoch; the test loader
    keeps a fixed order.

    Args:
        batch_size: Samples per batch for both loaders.
        root: Directory where MNIST is stored or downloaded to.

    Returns:
        (train_loader, test_loader)
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root, train=True, download=True, transform=transform)
    # download=True is a no-op when the files already exist. Passing it here
    # too means the test split does not depend on the train split having been
    # constructed (and downloaded) first.
    test_dataset = datasets.MNIST(root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# ============================================================================
# STEP 5: Training Functions
# ============================================================================

def train_teacher(model, train_loader, epochs=5, device=None, lr=0.001):
    """Train ``model`` from scratch with cross-entropy on the hard labels.

    This is a generic supervised loop (Adam + CrossEntropyLoss). The script
    also uses it to train the non-distilled baseline student.

    Args:
        model: Network producing class logits.
        train_loader: Iterable of (data, target) batches.
        epochs: Number of passes over train_loader.
        device: Device to train on. None picks CUDA when available, otherwise
            CPU (a hard-coded 'cuda' default crashes on CPU-only machines).
        lr: Adam learning rate.

    Returns:
        The trained model, moved to ``device``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Training Teacher Model...")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Guard against an empty loader instead of dividing by zero.
        if total == 0:
            print(f'Epoch {epoch+1} - no training samples\n')
            continue

        accuracy = 100. * correct / total
        avg_loss = total_loss / num_batches
        print(f'Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')

    return model

def train_student_with_distillation(student, teacher, train_loader, epochs=10,
                                    temperature=3.0, alpha=0.7, device=None,
                                    lr=0.001):
    """Train ``student`` to mimic ``teacher`` via knowledge distillation.

    The teacher is frozen. It is put in eval mode and its forward pass runs
    under torch.no_grad(), so only the student's parameters are updated. Each
    step minimises DistillationLoss, a blend of the temperature-softened KL
    term against the teacher and cross-entropy against the true labels.

    Args:
        student: Network being trained; produces class logits.
        teacher: Pre-trained network whose logits serve as soft targets.
        train_loader: Iterable of (data, target) batches.
        epochs: Number of passes over train_loader.
        temperature: Softening temperature passed to DistillationLoss.
        alpha: Weight of the distillation term passed to DistillationLoss.
        device: Device to train on. None picks CUDA when available, otherwise
            CPU (a hard-coded 'cuda' default crashes on CPU-only machines).
        lr: Adam learning rate for the student.

    Returns:
        The trained student, moved to ``device``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Training Student Model with Knowledge Distillation...")
    student = student.to(device)
    teacher = teacher.to(device)
    # Eval mode makes layers such as dropout deterministic, so soft targets are stable.
    teacher.eval()

    optimizer = optim.Adam(student.parameters(), lr=lr)
    distillation_criterion = DistillationLoss(temperature=temperature, alpha=alpha)

    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        total_distill_loss = 0.0
        total_student_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()

            student_logits = student(data)
            # The teacher only supplies targets; no autograd graph is needed.
            with torch.no_grad():
                teacher_logits = teacher(data)

            loss, distill_loss, student_loss = distillation_criterion(
                student_logits, teacher_logits, target
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_distill_loss += distill_loss.item()
            total_student_loss += student_loss.item()

            pred = student_logits.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Guard against an empty loader instead of dividing by zero.
        if total == 0:
            print(f'Epoch {epoch+1} - no training samples\n')
            continue

        accuracy = 100. * correct / total
        avg_loss = total_loss / num_batches
        avg_distill = total_distill_loss / num_batches
        avg_student = total_student_loss / num_batches

        print(f'Epoch {epoch+1} - Total Loss: {avg_loss:.4f}, '
              f'Distill Loss: {avg_distill:.4f}, Student Loss: {avg_student:.4f}, '
              f'Accuracy: {accuracy:.2f}%\n')

    return student

# ============================================================================
# STEP 6: Evaluation Function
# ============================================================================

def evaluate_model(model, test_loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to eval mode (and left there), and inference runs
    without gradient tracking.

    Args:
        model: Network producing class logits.
        test_loader: Iterable of (data, target) batches.
        device: Device each batch is moved to. None uses the device of the
            model's first parameter (CPU if it has none), so batches land
            where the model already lives.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If test_loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    return 100. * correct / total

# ============================================================================
# STEP 7: Main Execution
# ============================================================================

if __name__ == "__main__":
    # Seed the global RNG so weight init and DataLoader shuffling are
    # reproducible across runs, making the final comparison repeatable.
    torch.manual_seed(0)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Load data
    train_loader, test_loader = get_data_loaders(batch_size=64)

    # Initialize models
    teacher = TeacherModel()
    student = StudentModel()
    # Start the baseline from exactly the same initial weights as the distilled
    # student. Otherwise the "improvement" below also measures the luck of a
    # different random init, not just the effect of distillation.
    baseline_student = StudentModel()
    baseline_student.load_state_dict(student.state_dict())

    # Count parameters
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    print(f"Teacher Parameters: {teacher_params:,}")
    print(f"Student Parameters: {student_params:,}")
    print(f"Compression Ratio: {teacher_params/student_params:.2f}x\n")

    # Train teacher model
    teacher = train_teacher(teacher, train_loader, epochs=5, device=device)
    teacher_accuracy = evaluate_model(teacher, test_loader, device=device)
    print(f"Teacher Test Accuracy: {teacher_accuracy:.2f}%\n")

    # Train student with distillation
    student = train_student_with_distillation(
        student, teacher, train_loader,
        epochs=10, temperature=3.0, alpha=0.7, device=device
    )
    student_accuracy = evaluate_model(student, test_loader, device=device)
    print(f"Student Test Accuracy: {student_accuracy:.2f}%\n")

    # Baseline: same architecture and initial weights, hard labels only.
    # train_teacher is a generic supervised loop (its log header still says "Teacher").
    print("Training baseline student without distillation for comparison...")
    baseline_student = train_teacher(baseline_student, train_loader, epochs=10, device=device)
    baseline_accuracy = evaluate_model(baseline_student, test_loader, device=device)
    print(f"Baseline Student Test Accuracy: {baseline_accuracy:.2f}%\n")

    # Summary
    print("=" * 60)
    print("FINAL RESULTS:")
    print(f"Teacher Accuracy: {teacher_accuracy:.2f}%")
    print(f"Student (with distillation) Accuracy: {student_accuracy:.2f}%")
    print(f"Student (baseline) Accuracy: {baseline_accuracy:.2f}%")
    # A difference of two percentages is in percentage points, not percent.
    print(f"Improvement from distillation: {student_accuracy - baseline_accuracy:.2f} percentage points")
    print("=" * 60)

  import pynvml  # type: ignore[import]


Using device: cuda



100%|██████████| 9.91M/9.91M [00:17<00:00, 575kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 120kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 856kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.45MB/s]


Teacher Parameters: 1,554,954
Student Parameters: 206,922
Compression Ratio: 7.51x

Training Teacher Model...
Epoch 1/5, Batch 0, Loss: 2.3112
Epoch 1/5, Batch 100, Loss: 0.0697
Epoch 1/5, Batch 200, Loss: 0.0817
Epoch 1/5, Batch 300, Loss: 0.1090
Epoch 1/5, Batch 400, Loss: 0.3740
Epoch 1/5, Batch 500, Loss: 0.0923
Epoch 1/5, Batch 600, Loss: 0.0862
Epoch 1/5, Batch 700, Loss: 0.0426
Epoch 1/5, Batch 800, Loss: 0.0191
Epoch 1/5, Batch 900, Loss: 0.0517
Epoch 1 - Avg Loss: 0.1398, Accuracy: 95.51%

Epoch 2/5, Batch 0, Loss: 0.0055
Epoch 2/5, Batch 100, Loss: 0.0049
Epoch 2/5, Batch 200, Loss: 0.0608
Epoch 2/5, Batch 300, Loss: 0.0351
Epoch 2/5, Batch 400, Loss: 0.1175
Epoch 2/5, Batch 500, Loss: 0.0055
Epoch 2/5, Batch 600, Loss: 0.0160
Epoch 2/5, Batch 700, Loss: 0.0268
Epoch 2/5, Batch 800, Loss: 0.0196
Epoch 2/5, Batch 900, Loss: 0.1173
Epoch 2 - Avg Loss: 0.0447, Accuracy: 98.69%

Epoch 3/5, Batch 0, Loss: 0.0078
Epoch 3/5, Batch 100, Loss: 0.0058
Epoch 3/5, Batch 200, Loss: 0.0557