# Knowledge Distillation in PyTorch



## Basic Setup


In [54]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Load Dataset

In [55]:
# Use 224×224 resize for ResNet compatibility
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load full CIFAR-10 train set
full_trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)

# Calculate split sizes for train and validation sets
train_size = int(0.9 * len(full_trainset))
val_size = len(full_trainset) - train_size

# Perform split
train_subset, val_subset = random_split(full_trainset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_subset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False)

# CIFAR-10 test set and loader for accuracy evaluation
test_set = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
test_loader = DataLoader(testset, batch_size=128, shuffle=False)


## Define Models

In [56]:
# Teacher: ResNet50 pretrained on ImageNet, re-headed for CIFAR-10
teacher = models.resnet50(pretrained=True)
teacher.fc = nn.Linear(2048, 10)

# Student: ResNet18 from scratch
student = models.resnet18(pretrained=False)
student.fc = nn.Linear(512, 10)

teacher.to(device)
student.to(device)





MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

## Fine-Tune the Teacher

In [57]:
# Train the teacher on CIFAR-10 for a few epochs
optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-3)
teacher.train()
for epoch in range(10):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = teacher(inputs)
        loss = F.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"(Fine-tuning teacher)\tEpoch {epoch+1}: loss={loss.item():.4f}")


(Fine-tuning teacher)	Epoch 1: loss=0.6591
(Fine-tuning teacher)	Epoch 2: loss=0.5280
(Fine-tuning teacher)	Epoch 3: loss=0.5504
(Fine-tuning teacher)	Epoch 4: loss=0.2639
(Fine-tuning teacher)	Epoch 5: loss=0.2542
(Fine-tuning teacher)	Epoch 6: loss=0.3202
(Fine-tuning teacher)	Epoch 7: loss=0.0955
(Fine-tuning teacher)	Epoch 8: loss=0.2393
(Fine-tuning teacher)	Epoch 9: loss=0.1304
(Fine-tuning teacher)	Epoch 10: loss=0.0534


## Intermediate Feature Extraction for Matching


In [None]:
# Extract final logits and selected intermediate outputs
def extract_features(model, x, layers=[1, 2, 3, 4]):
    features = []
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)
    for i, block in enumerate([model.layer1, model.layer2, model.layer3, model.layer4]):
        x = block(x)
        if (i + 1) in layers:
            features.append(x)
    return x, features


## Distillation Loss

In [58]:
# Combine soft and hard targets using KL divergence and cross-entropy
# T = temperature, alpha = weighting between soft and hard losses
def distillation_loss(student_logits, teacher_logits, targets, T=5.0, alpha=0.7):
    # Soft target loss (teacher softmax vs student softmax)
    soft_targets = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    # Hard label loss
    hard_loss = F.cross_entropy(student_logits, targets)
    return alpha * soft_targets + (1 - alpha) * hard_loss


## Evaluation Functions

In [59]:
# Function to count trainable parameters
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Function to measure average inference latency over multiple runs
def measure_latency(model, input_size=(1, 3, 32, 32), device='cuda', repetitions=50):
    model.eval()
    inputs = torch.randn(input_size).to(device)
    with torch.no_grad():
        # Warm-up
        for _ in range(10):
            _ = model(inputs)
        # Measure
        times = []
        for _ in range(repetitions):
            start = time.time()
            _ = model(inputs)
            end = time.time()
            times.append(end - start)
    return (sum(times) / repetitions) * 1000  # ms

# Function to evaluate accuracy on test set
def evaluate_accuracy(model, dataloader, device):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

## Train the Student via Distillation

In [None]:
# Train the student using the teacher's output as soft targets
teacher.eval()  # Teacher in eval mode
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# Reduce LR if validation loss doesn't improve for 3 epochs
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

for epoch in range(50):  # Adjust as needed
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Teacher outputs (frozen)
        with torch.no_grad():
            teacher_logits, teacher_feats = extract_features(teacher, inputs)

        # Student outputs
        student_logits, student_feats = extract_features(student, inputs)

        # Feature matching loss (intermediate distillation)
        feat_loss = sum(F.mse_loss(s, t.detach()) for s, t in zip(student_feats, teacher_feats))

        # Soft + hard loss
        loss = distillation_loss(student_logits, teacher_logits, labels) + 0.1 * feat_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_accuracy = evaluate_accuracy(student, val_loader, device)
    print(f"(Training student)\tEpoch {epoch+1}: Loss = {loss.item():.4f}, Val Accuracy={val_accuracy*100:.2f}%")
    scheduler.step(loss)

## Model Comparison Code

In [61]:

# Compare size, latency, and accuracy
teacher_params = count_params(teacher)
student_params = count_params(student)

teacher_latency = measure_latency(teacher, device=device)
student_latency = measure_latency(student, device=device)

teacher_acc = evaluate_accuracy(teacher, test_loader, device)
student_acc = evaluate_accuracy(student, test_loader, device)

print(f"Teacher Params: {teacher_params / 1e6:.2f}M")
print(f"Student Params: {student_params / 1e6:.2f}M")
print(f"Teacher Latency: {teacher_latency:.2f} ms")
print(f"Student Latency: {student_latency:.2f} ms")
print(f"Teacher Test Accuracy: {teacher_acc * 100:.2f}%")
print(f"Student Test Accuracy: {student_acc * 100:.2f}%")


Teacher Params: 11.18M
Student Params: 2.24M
Teacher Latency: 2.21 ms
Student Latency: 5.59 ms
Teacher Accuracy: 80.84%
Student Accuracy: 68.57%
