# Knowledge Distillation in PyTorch



## Basic Setup


In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

print(f"PyTorch Version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device used: {device.type}")

PyTorch Version: 2.8.0.dev20250319+cu128
Device used: cuda


## Load Dataset

In [2]:
# # Use 224×224 resize for ResNet compatibility
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor()
# ])
# transform = models.ResNet50_Weights.IMAGENET1K_V2.transforms()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],  # CIFAR-10 means
                         std=[0.2023, 0.1994, 0.2010])
])

# Load full CIFAR-10 train set
full_trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)

# Calculate split sizes for train and validation sets
train_size = int(0.9 * len(full_trainset))
val_size = len(full_trainset) - train_size

# Perform split
train_subset, val_subset = random_split(full_trainset, [train_size, val_size])
print(f"Train samples: {train_size}")
print(f"Validation samples: {val_size}")

# Create DataLoaders
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)

# CIFAR-10 test set and loader for accuracy evaluation
test_set = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)
print(f"Test samples: {len(test_set)}")


Train samples: 45000
Validation samples: 5000
Test samples: 10000


## Define Models

In [17]:
# Teacher: ResNet50 pretrained on ImageNet, re-headed for CIFAR-10
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
teacher.fc = nn.Linear(2048, 10)

# Student: ResNet18 from scratch
student = models.resnet18(weights=None)
student.fc = nn.Linear(512, 10)

teacher = teacher.to(device)
student = student.to(device)


## Define Train Function

In [18]:

def train(model, loader, epochs, tag, lr=1e-3, save_path="model.pth", silent=False):
    """
    Trains a model with Adam and cross-entropy loss.
    Loads from save_path if it exists.
    """
        
    if os.path.exists(save_path):
        if not silent:
            print(f"Model already trained. Loading from {save_path}")
        model.load_state_dict(torch.load(save_path))
        return

    # no saved model found. training from given model state

    optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-3)
    model.train()

    for epoch in range(epochs):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = teacher(inputs)
            loss = F.cross_entropy(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        if not silent:
            print(f"({tag})\tEpoch {epoch+1}: loss={loss.item():.4f}")

            evaluate_accuracy(model, val_loader, f"Epoch {epoch+1}")
            model.train()
            
    if save_path:
        torch.save(model.state_dict(), save_path)
        if not silent:
            print(f"Training complete. Model saved to {save_path}")

# Function to evaluate accuracy given loader
def evaluate_accuracy(model, dataloader, tag):
    """
    Evaluate and print accuracy given loader
    """
    model.eval()
    model.to(device)
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print(f"Accuracy ({tag}): {accuracy*100:.2f}%")
    

## Fine-Tune the Teacher

In [19]:
# Train the teacher on CIFAR-10 for a few epochs
train(teacher, train_loader, epochs=10, tag="Fine-tuning teacher", save_path="resnet50_ImageNet1K_pretrained_CIFAR10_tuned.pth")

Model already trained. Loading from resnet50_ImageNet1K_pretrained_CIFAR10_tuned.pth


## Intermediate Feature Extraction for Matching


In [20]:
# Extract final logits and selected intermediate outputs
def extract_features(model, x, layers=[1, 2, 3, 4]):
    features = []
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    for i, block in enumerate([model.layer1, model.layer2, model.layer3, model.layer4]):
        x = block(x)
        if (i + 1) in layers:
            features.append(x)
    return x, features


## Intermediate Distillation with Projections

In [21]:
class FeatureProjector(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, target_shape):
        # Resize spatial dims if needed (e.g., from 8x8 to 4x4)
        if x.shape[2:] != target_shape[2:]:
            x = F.adaptive_avg_pool2d(x, output_size=target_shape[2:])
        return self.proj(x)

student_channels = [64, 128, 256, 512]
teacher_channels = [256, 512, 1024, 2048]

proj_layers = nn.ModuleList([
    FeatureProjector(in_c, out_c).to(device)
    for in_c, out_c in zip(student_channels, teacher_channels)
])


## Distillation Loss

In [22]:
# Combine soft and hard targets using KL divergence and cross-entropy
# T = temperature, alpha = weighting between soft and hard losses
def distillation_loss(student_logits, teacher_logits, targets, T=5.0, alpha=0.7):
    # Soft target loss (teacher softmax vs student softmax)
    soft_targets = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    # Hard label loss
    hard_loss = F.cross_entropy(student_logits, targets)
    return alpha * soft_targets + (1 - alpha) * hard_loss


## Evaluation Functions

In [23]:
# Function to count trainable parameters
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Function to measure average inference latency over multiple runs
def measure_latency(model, input_size=(1, 3, 32, 32), device='cuda', repetitions=50):
    model.eval()
    inputs = torch.randn(input_size).to(device)
    with torch.no_grad():
        # Warm-up
        for _ in range(10):
            _ = model(inputs)
        # Measure
        times = []
        for _ in range(repetitions):
            start = time.time()
            _ = model(inputs)
            end = time.time()
            times.append(end - start)
    return (sum(times) / repetitions) * 1000  # ms


## Train the Student via Distillation

In [24]:
# Train the student using the teacher's output as soft targets
teacher.eval()  # Teacher in eval mode
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# Reduce LR if validation loss doesn't improve for 3 epochs
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

for epoch in range(50):  # Adjust as needed
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Teacher outputs (frozen)
        with torch.no_grad():
            teacher_logits, teacher_feats = extract_features(teacher, inputs)

        # Student outputs
        student_logits, student_feats = extract_features(student, inputs)





        
        # Project and match all intermediate layers
        feat_loss = 0
        for s_feat, t_feat, proj in zip(student_feats, teacher_feats, proj_layers):
            s_proj = proj(s_feat, t_feat.shape)
            feat_loss += F.mse_loss(s_proj, t_feat.detach())
        
        # Soft + hard loss
        loss = distillation_loss(student_logits, teacher_logits, labels) + 0.1 * feat_loss




        
        # # Feature matching loss (intermediate distillation)
        # feat_loss = sum(F.mse_loss(s, t.detach()) for s, t in zip(student_feats, teacher_feats))

        # loss = distillation_loss(student_logits, teacher_logits, labels) + 0.1 * feat_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_accuracy = evaluate_accuracy(student, val_loader, device)
    print(f"(Training student)\tEpoch {epoch+1}: Loss = {loss.item():.4f}, Val Accuracy={val_accuracy*100:.2f}%")
    scheduler.step(loss)

RuntimeError: The size of tensor a (2048) must match the size of tensor b (512) at non-singleton dimension 1

## Model Comparison Code

In [None]:

# Compare size, latency, and accuracy
teacher_params = count_params(teacher)
student_params = count_params(student)

teacher_latency = measure_latency(teacher, device=device)
student_latency = measure_latency(student, device=device)

teacher_acc = evaluate_accuracy(teacher, test_loader, device)
student_acc = evaluate_accuracy(student, test_loader, device)

print(f"Teacher Params: {teacher_params / 1e6:.2f}M")
print(f"Student Params: {student_params / 1e6:.2f}M")
print(f"Teacher Latency: {teacher_latency:.2f} ms")
print(f"Student Latency: {student_latency:.2f} ms")
print(f"Teacher Test Accuracy: {teacher_acc * 100:.2f}%")
print(f"Student Test Accuracy: {student_acc * 100:.2f}%")
