Model Distillation: Compressing Large Models into Smaller, Efficient Ones
Model distillation (also called knowledge distillation) is a deep learning technique for compressing a large, complex model (the teacher) into a smaller, faster model (the student) that retains most of the teacher's performance.

Why Use Model Distillation?
🔹 Reduces computation and memory usage, enabling deployment on edge devices (e.g., mobile phones, IoT devices, self-driving cars).
🔹 Speeds up inference while retaining most of the teacher's accuracy.
🔹 Transfers knowledge from large models (GPT, BERT, ResNet) into smaller ones.

How Model Distillation Works
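1. Train a large, high-performance model (the teacher).
2. Train a smaller model (the student) using two kinds of targets:
   - Soft labels: the teacher's probability distribution over classes, which also encodes how similar the teacher considers the wrong classes to be.
   - Hard labels: the standard ground-truth labels.
3. Optimize the student by minimizing a weighted sum of two terms:
   - the difference between the student's predictions and the hard labels;
   - the difference between the student's predictions and the teacher's predictions (the knowledge-transfer term).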
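The difference between the two kinds of targets can be seen on a single example. This pure-Python sketch uses made-up teacher logits for three hypothetical classes (cat, dog, car). The `softmax` helper and the numbers are illustrative assumptions, not values from a real model. Raising the temperature T spreads probability onto the wrong classes, and that spread is the extra information the student gets from soft labels.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: divide logits by T, then normalize."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one image over [cat, dog, car]
teacher_logits = [5.0, 3.0, -2.0]
hard_label = [1.0, 0.0, 0.0]  # one-hot ground truth: "cat"

soft_t1 = softmax(teacher_logits, temperature=1.0)  # ≈ [0.880, 0.119, 0.001]
soft_t3 = softmax(teacher_logits, temperature=3.0)  # ≈ [0.621, 0.319, 0.060]

print("hard label:     ", hard_label)
print("soft label, T=1:", [round(p, 3) for p in soft_t1])
print("soft label, T=3:", [round(p, 3) for p in soft_t3])
```

The hard label only says "cat". The soft labels also say the image looks far more like a dog than a car, and at T = 3 that signal is much stronger.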
Distillation Loss Function

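$$\mathcal{L} = (1 - \alpha)\,\mathcal{L}_{\text{hard}} + \alpha\,T^{2}\,\mathcal{L}_{\text{soft}}$$

where:

$\mathcal{L}_{\text{hard}}$ = cross-entropy loss between the student's predictions and the ground-truth labels.
$\mathcal{L}_{\text{soft}}$ = KL divergence between the teacher's and the student's temperature-softened distributions, $\mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big)$, where $z_t$ and $z_s$ are the teacher and student logits.
$\alpha$ = trade-off parameter between the two terms ($\alpha = 0$ is ordinary supervised training; $\alpha = 1$ uses only the teacher's soft labels).
$T$ = temperature; $T > 1$ smooths the probability distributions so that the teacher's small probabilities on wrong classes carry more signal.

The $T^{2}$ factor compensates for the gradients of $\mathcal{L}_{\text{soft}}$ shrinking roughly as $1/T^{2}$. It keeps the soft term's contribution comparable to the hard term as $T$ changes.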
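To make the formula concrete, here is a pure-Python sketch that evaluates it term by term for a single example:

- cross-entropy at T = 1 for the hard term;
- KL(teacher ‖ student) on temperature-softened probabilities for the soft term;
- the T² scaling.

The function names and example logits are hypothetical. A real implementation would operate on batched tensors.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Hard loss: -log p(target) under the student's T=1 softmax."""
    return -math.log(softmax(logits)[target])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions (terms with p_i = 0 contribute 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, target, alpha=0.5, T=3.0):
    """L = (1 - alpha) * L_hard + alpha * T^2 * L_soft for one example."""
    l_hard = cross_entropy(student_logits, target)
    l_soft = kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T))
    return (1 - alpha) * l_hard + alpha * T**2 * l_soft

student = [2.0, 0.5, -1.0]  # hypothetical student logits
teacher = [5.0, 3.0, -2.0]  # hypothetical teacher logits
print(distillation_loss(student, teacher, target=0))
```

If the student's logits equal the teacher's, the soft term is zero and only (1 − α) times the cross-entropy remains. Setting α = 0 recovers ordinary supervised training.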

Step 1: Define Teacher and Student Models

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets

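# Load the teacher: ResNet-50 with ImageNet-pretrained weights
teacher_model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
teacher_model.fc = nn.Linear(2048, 10)  # New 10-class head for CIFAR-10 (randomly initialized)
# NOTE: the new head is untrained, so the teacher must be fine-tuned on CIFAR-10
# before distillation; otherwise the student would learn from random logits.
teacher_model.eval()

# Define the student: ResNet-18 trained from scratch
# (~11.2M parameters vs ~23.5M for the teacher, both with a 10-class head)
student_model = models.resnet18(weights=None)
student_model.fc = nn.Linear(512, 10)  # Same 10-class output as the teacher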
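Step 1 of the recipe assumes a trained teacher. The code above, however, only swaps in a new, randomly initialized `fc` layer, so `teacher_model` cannot yet produce meaningful CIFAR-10 predictions. Below is a minimal sketch of fine-tuning the teacher first with plain cross-entropy. It assumes a labeled `DataLoader` like the `train_loader` built in the training step. The helper name `finetune_teacher` and its defaults (`epochs`, `lr`) are hypothetical, illustrative choices, not part of the original notebook.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def finetune_teacher(model, loader, device, epochs=1, lr=1e-4):
    """Fine-tune `model` with cross-entropy so it can serve as a teacher.

    Hypothetical helper: `epochs` and `lr` are illustrative, not tuned values.
    """
    model.to(device)
    model.train()  # enable batch-norm statistic updates while fine-tuning
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    model.eval()  # fix batch-norm statistics for distillation
    for p in model.parameters():
        p.requires_grad_(False)  # the teacher is not updated during distillation
    return model
```

Once `train_loader` and `device` exist, a call such as `teacher_model = finetune_teacher(teacher_model, train_loader, device, epochs=3)` would run before the distillation loop.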


Step 2: Distillation Loss Function

class DistillationLoss(nn.Module):
    """L = (1 - alpha) * CE(student, labels) + alpha * T^2 * KL(teacher_T || student_T)."""

    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')  # Soft-label term
        self.ce_loss = nn.CrossEntropyLoss()  # Hard-label term

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        # KLDivLoss expects log-probabilities as input and probabilities as target
        soft_loss = self.kl_loss(
            nn.functional.log_softmax(student_logits / T, dim=1),
            nn.functional.softmax(teacher_logits / T, dim=1)
        )
        hard_loss = self.ce_loss(student_logits, labels)
        # Scale the soft term by T^2, as in the distillation loss defined earlier
        return self.alpha * (T ** 2) * soft_loss + (1 - self.alpha) * hard_loss
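`nn.KLDivLoss` (and `F.kl_div`) takes log-probabilities as its input and, by default, probabilities as its target. It computes Σ target · (log target − input), i.e. KL(target ‖ input). That is why the student side goes through `log_softmax` and the teacher side through `softmax`. The quick sanity check below runs on random logits and compares the built-in against the definition written out by hand. The batch shape, seed, and T = 3 are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 3.0
student_logits = torch.randn(4, 10)  # batch of 4, 10 classes (random, for illustration)
teacher_logits = torch.randn(4, 10)

log_p_student = F.log_softmax(student_logits / T, dim=1)
p_teacher = F.softmax(teacher_logits / T, dim=1)

# Built-in: input = student log-probs, target = teacher probs, averaged over the batch
builtin = F.kl_div(log_p_student, p_teacher, reduction="batchmean")

# Same quantity by hand: sum_c p_t * (log p_t - log p_s), then mean over the batch
manual = (p_teacher * (p_teacher.log() - log_p_student)).sum(dim=1).mean()

print(f"built-in: {builtin.item():.6f}  manual: {manual.item():.6f}")
```

Passing raw logits or plain probabilities as the input would silently compute the wrong quantity, so this convention is worth checking whenever the loss is modified.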


Step 3: Train the Student Model

# Load CIFAR-10 Dataset
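transform = transforms.Compose([
    transforms.Resize(224),  # Upsample 32x32 CIFAR images to the ResNet input size
    transforms.ToTensor(),
    # ImageNet statistics, matching the teacher's pretraining
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])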

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root="./data", train=True, transform=transform, download=True),
    batch_size=32, shuffle=True
)

# Optimizer and Loss Function
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
loss_fn = DistillationLoss(temperature=3, alpha=0.5)

# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model.to(device)
student_model.to(device)
teacher_model.eval()  # Teacher stays frozen in inference mode
student_model.train()

for epoch in range(10):  # Train for 10 epochs
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass: teacher without gradients, student with gradients
        with torch.no_grad():
            teacher_logits = teacher_model(images)
        student_logits = student_model(images)

        # Compute distillation loss
        loss = loss_fn(student_logits, teacher_logits, labels)

        # Backpropagation (updates the student only)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean loss over the epoch rather than the last batch
    print(f"Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}")
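The training loop reports only the loss. Whether distillation worked is usually judged by comparing the student's and the teacher's accuracy on held-out data. Below is a minimal sketch of such a check. The helper name `evaluate` is hypothetical, and the test loader in the usage note is an assumption built the same way as `train_loader`.

```python
import torch

@torch.no_grad()
def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` on `loader` (hypothetical helper)."""
    model.eval()  # note: switches the model out of training mode
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total
```

For example, build `test_loader` from `datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)`. Then compare `evaluate(teacher_model, test_loader, device)` with `evaluate(student_model, test_loader, device)`. Call `student_model.train()` again before any further training.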


Variants of Model Distillation

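| Type | Description |
|------|-------------|
| Logit Matching | Student mimics the teacher's (temperature-softened) softmax outputs. |
| Feature Distillation | Student learns to match the teacher's intermediate-layer representations. |
| Self-Distillation | A single model teaches itself, e.g., its deeper layers or a previous copy of the model act as the teacher (used with BERT and ViTs). |
| Data-Free Distillation | Synthetic data is generated to train the student when the original training data is unavailable. |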
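Feature distillation adds a loss on intermediate activations. Below is a minimal sketch in the spirit of FitNets-style "hint" training. It assumes the student's feature map has fewer channels than the teacher's, so a learned 1×1 convolution projects student features into the teacher's channel dimension before an MSE loss. The class name `FeatureDistillationLoss` is hypothetical. Which layers to match, and how to weight this term against the logit loss, are design choices not specified here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureDistillationLoss(nn.Module):
    """MSE between adapted student features and (detached) teacher features."""

    def __init__(self, student_channels, teacher_channels):
        super().__init__()
        # 1x1 conv maps student channels -> teacher channels; trained with the student
        self.adapter = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        projected = self.adapter(student_feat)
        # Match spatial size if the two backbones downsample differently
        if projected.shape[-2:] != teacher_feat.shape[-2:]:
            projected = F.interpolate(projected, size=teacher_feat.shape[-2:],
                                      mode="bilinear", align_corners=False)
        # Teacher features are targets only; no gradient flows into the teacher
        return F.mse_loss(projected, teacher_feat.detach())
```

For the ResNet-18 student and ResNet-50 teacher used above, `layer4` outputs 512 and 2048 channels respectively (7×7 at 224×224 input), so `FeatureDistillationLoss(512, 2048)` would fit. The features can be captured with forward hooks. The adapter's parameters must also be passed to the optimizer alongside the student's.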