In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

In [None]:
"""
Section 3 on MNIST 

Large NN with 2 hidden layers and 1200 hidden units EACH (60k training cases from MNIST set?)
Strongly regularized with dropout and weight constraints (listed somewhere)
The MNIST training images were jittered by up to 2 pixels in a random direction

SMALLER student network had 2 hidden layers of 800 hidden units and NO regularization
Regularized by adding task of matching soft targets from large network at temperature 20 (fusing?)
T >= 8 at 300 or more units

So main idea is that you create a distillation loss which compares softmax soft output to teacher soft output
- e.g. Logit matching... (nope this is a special case?)
- How to get teacher soft output instead of the class? guess you just don't modify it?


Steps: (softmax vs logit = prob vs logodds)
1. Train teacher on MNIST using low temperature
2. Train student on MNIST using distillation loss with high temperature on teacher softmax output...
"""

In [None]:
"""
Main idea of paper is to train smaller model using SOFT output distribution from larger model
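During distillation the SAME temperature T is applied to both the teacher logits and the student logits, so both distributions are softened

T = 1 is used when the teacher is trained on hard labels and when the student is used for inference
T is raised (e.g. T=20, a hyperparameter) only for the soft-target term; higher T = softer (flatter) probability distribution over classes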

Implement 2 stage training (just nest normal training loop into the student train function)

Set parent to eval then student to train mode
Transfer set? How to build this?
"""
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """
    Blend a soft-target (teacher-matching) loss with a hard-label loss.

    Both sets of logits are divided by the temperature ``T`` before being
    normalised along ``dim=1``; the hard-label term uses the student's
    un-scaled logits.

    Args:
        student_logits: Raw (pre-softmax) outputs of the student model
        teacher_logits: Raw (pre-softmax) outputs of the teacher model
        labels: Ground-truth class indices for the cross-entropy term
        T: Temperature dividing both sets of logits for the soft term
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``

    Returns:
        ``alpha * soft_loss + (1 - alpha) * hard_loss`` as a scalar tensor
    """
    # Teacher side: plain probabilities. Student side: log-probabilities,
    # because F.kl_div expects its first argument in log space.
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    student_log_probs = F.log_softmax(student_logits / T, dim=1)

    # KL(teacher || student), summed over classes and averaged over the batch.
    kl_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

    # Soft-target gradients shrink as 1/T^2, so multiply back by T^2 to keep
    # the two terms on a comparable scale.
    soft_loss = kl_term * (T * T)

    # Ordinary cross-entropy against the true labels (temperature 1).
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Training function for the student model
def train_student(teacher_model, student_model, train_loader, optimizer, device, T=2.0, alpha=0.5):
    """
    Train the student for one pass over ``train_loader`` using distillation.

    The teacher is switched to eval mode and only queried under
    ``torch.no_grad()``; the student is switched to train mode and updated
    once per batch through ``optimizer``.

    Args:
        teacher_model: Model providing the target logits (not updated here)
        student_model: Model being optimised
        train_loader: Iterable yielding ``(data, target)`` batches; it does
            not need to support ``len()``
        optimizer: Optimizer used for the zero_grad/step calls on each batch
        device: Device each batch is moved to before the forward passes
        T: Temperature forwarded to ``distillation_loss``
        alpha: Soft/hard loss weight forwarded to ``distillation_loss``

    Returns:
        Mean per-batch loss as a float, or 0.0 if the loader yielded no batches
    """
    teacher_model.eval()   # freeze dropout / batch-norm statistics in the teacher
    student_model.train()  # enable training-time behaviour in the student

    running_loss = 0.0
    # Count batches as they are seen instead of calling len(train_loader), so
    # empty loaders and loaders without __len__ do not raise.
    num_batches = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # The teacher only supplies targets, so skip building its autograd graph.
        with torch.no_grad():
            teacher_logits = teacher_model(data)

        student_logits = student_model(data)

        loss = distillation_loss(student_logits, teacher_logits, target, T, alpha)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    return running_loss / num_batches

In [None]:
def _train_teacher_epoch(model, train_loader, optimizer, device):
    """Run one epoch of plain cross-entropy training and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    return running_loss / max(num_batches, 1)


def _evaluate_accuracy(model, data_loader, device):
    """Return the fraction of samples in ``data_loader`` that ``model`` classifies correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            predictions = model(data).argmax(dim=1)
            correct += (predictions == target).sum().item()
            total += target.size(0)
    return correct / max(total, 1)


def main():
    """
    Run the full two-stage distillation experiment on MNIST.

    Stage 1 obtains a trained teacher: weights are loaded from
    ``teacher_model.pth`` when that file exists, otherwise the teacher is
    trained with plain cross-entropy and saved there. Stage 2 trains the
    student against the teacher with ``train_student`` and saves it to
    ``distilled_student_model.pth``. Test-set accuracy of both models is
    printed.

    NOTE(review): ``SimpleCNN`` is not defined in the visible part of this
    file; it is assumed to be provided elsewhere and to accept ``num_hidden``.
    """
    # Local import so this function does not depend on edits to the import cell.
    from pathlib import Path

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MNIST dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000)

    # Initialize teacher and student models
    teacher_model = SimpleCNN(num_hidden=1200).to(device)  # Larger teacher model
    student_model = SimpleCNN(num_hidden=800).to(device)   # Smaller student model

    num_epochs = 10

    # Stage 1: the teacher must be trained before distillation; a randomly
    # initialised teacher would give the student meaningless soft targets.
    teacher_checkpoint = Path('teacher_model.pth')
    if teacher_checkpoint.exists():
        teacher_model.load_state_dict(torch.load(teacher_checkpoint, map_location=device))
    else:
        teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9)
        for epoch in range(num_epochs):
            loss = _train_teacher_epoch(teacher_model, train_loader, teacher_optimizer, device)
            print(f"Teacher epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")
        # Cache the trained teacher so later runs can skip stage 1.
        torch.save(teacher_model.state_dict(), teacher_checkpoint)

    teacher_accuracy = _evaluate_accuracy(teacher_model, test_loader, device)
    print(f"Teacher test accuracy: {teacher_accuracy:.4f}")

    # Create an optimizer for the student model
    optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9)

    # Stage 2: train the student model with knowledge distillation
    for epoch in range(num_epochs):
        loss = train_student(teacher_model, student_model, train_loader, optimizer, device, T=20.0, alpha=0.5)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

    student_accuracy = _evaluate_accuracy(student_model, test_loader, device)
    print(f"Student test accuracy: {student_accuracy:.4f}")

    # Save the distilled student model
    torch.save(student_model.state_dict(), 'distilled_student_model.pth')

# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == '__main__':
    main()