# Knowledge Distillation using KL Divergence Loss in PyTorch

This notebook demonstrates knowledge distillation (KD) using KL divergence loss on the CIFAR‑10 dataset in PyTorch. We first build and train a high‑capacity teacher CNN that achieves high accuracy on CIFAR‑10. Then, we define a simpler student model and train it in two ways:

1. Using standard cross‑entropy loss (normal training).
2. Using knowledge distillation (KD) that combines cross‑entropy loss with KL divergence loss between the teacher’s and student’s softened outputs.

At the end, we compare the test performance of the teacher, the normally trained student, and the KD-trained student, and provide a conclusion.

## 1. Setup and Imports

We import PyTorch, torchvision, and other necessary libraries. We set seeds for reproducibility and define the device.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import time

# Set device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# For reproducibility
# torch.manual_seed(42)
# np.random.seed(42)

Using device: cuda


## 2. Load and Preprocess the CIFAR‑10 Dataset

We load the CIFAR‑10 dataset using torchvision. The images are converted to tensors and normalized. Data loaders are created for training and testing.

In [16]:
import torchvision.transforms as transforms

transform =  transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

batch_size = 256
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0,pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

print("Training set size:", len(train_dataset))
print("Test set size:", len(test_dataset))


Files already downloaded and verified
Files already downloaded and verified
Training set size: 50000
Test set size: 10000


## 3. Build and Train the Teacher Model

The teacher model is a deep CNN with three convolutional blocks and a fully connected classifier. It is trained using standard cross‑entropy loss.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import time

class TeacherNet(nn.Module):
    def __init__(self, num_classes=10):
        super(TeacherNet, self).__init__()
        # Block 1
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Block 2
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Block 3
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Classification block
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )
        
    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.classifier(x)
        return x

# Initialize model and move to device
teacher = TeacherNet().to(device)
print(teacher)

TeacherNet(
  (block1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block3): Sequential(
    (0): Conv2d(

In [5]:
def train_teacher(model, train_loader, test_loader, num_epochs=50, lr=0.1):
    # Use SGD with momentum and weight decay for better generalization
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    
    # Cosine Annealing LR Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        start_time = time.time()
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Step the scheduler at the end of the epoch
        scheduler.step()
        
        # Evaluate on test set
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total
        print(f"Teacher Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}, Test Acc: {acc:.2f}%, Time: {time.time()-start_time:.1f}s")
        
        if acc > best_acc:
            best_acc = acc
    print(f"Best teacher accuracy: {best_acc:.2f}%")

# Uncomment to train teacher model
train_teacher(teacher, train_loader, test_loader, num_epochs=50, lr=0.1)


Teacher Epoch 1/50 - Loss: 0.6567, Test Acc: 77.80%, Time: 13.0s
Teacher Epoch 2/50 - Loss: 0.5928, Test Acc: 78.66%, Time: 13.1s
Teacher Epoch 3/50 - Loss: 0.5557, Test Acc: 79.82%, Time: 13.9s
Teacher Epoch 4/50 - Loss: 0.5289, Test Acc: 80.47%, Time: 13.2s
Teacher Epoch 5/50 - Loss: 0.5013, Test Acc: 81.20%, Time: 13.7s
Teacher Epoch 6/50 - Loss: 0.4774, Test Acc: 81.63%, Time: 13.8s
Teacher Epoch 7/50 - Loss: 0.4557, Test Acc: 82.13%, Time: 13.7s
Teacher Epoch 8/50 - Loss: 0.4340, Test Acc: 82.82%, Time: 14.1s
Teacher Epoch 9/50 - Loss: 0.4161, Test Acc: 82.93%, Time: 13.6s
Teacher Epoch 10/50 - Loss: 0.3926, Test Acc: 82.97%, Time: 13.9s
Teacher Epoch 11/50 - Loss: 0.3764, Test Acc: 83.18%, Time: 13.5s
Teacher Epoch 12/50 - Loss: 0.3623, Test Acc: 83.61%, Time: 13.8s
Teacher Epoch 13/50 - Loss: 0.3452, Test Acc: 83.26%, Time: 13.5s
Teacher Epoch 14/50 - Loss: 0.3291, Test Acc: 84.36%, Time: 13.7s
Teacher Epoch 15/50 - Loss: 0.3139, Test Acc: 84.44%, Time: 13.8s
Teacher Epoch 16/50

## 4. Build and Train the Student Model (Without KD)

The student model is a simpler CNN. It is first trained using standard cross‑entropy loss.

In [26]:
class StudentNet(nn.Module):
    def __init__(self, num_classes=10):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

student_normal = StudentNet().to(device)
print(student_normal)

StudentNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [27]:
import warnings
warnings.filterwarnings("ignore")

def train_student_normal(model, train_loader, test_loader, num_epochs=25, lr=0.005):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            targets_onehot = F.one_hot(targets, num_classes=10).float().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets_onehot)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total
        print(f"Student Normal Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}, Test Acc: {acc:.2f}%")
  
# Uncomment to train student model using MSELoss
train_student_normal(student_normal, train_loader, test_loader, num_epochs=25, lr=0.0001)


Student Normal Epoch 1/25 - Loss: 0.0804, Test Acc: 41.81%
Student Normal Epoch 2/25 - Loss: 0.0736, Test Acc: 46.45%
Student Normal Epoch 3/25 - Loss: 0.0708, Test Acc: 49.49%
Student Normal Epoch 4/25 - Loss: 0.0686, Test Acc: 51.46%
Student Normal Epoch 5/25 - Loss: 0.0667, Test Acc: 53.92%
Student Normal Epoch 6/25 - Loss: 0.0651, Test Acc: 54.65%
Student Normal Epoch 7/25 - Loss: 0.0636, Test Acc: 56.18%
Student Normal Epoch 8/25 - Loss: 0.0624, Test Acc: 56.55%
Student Normal Epoch 9/25 - Loss: 0.0612, Test Acc: 57.45%
Student Normal Epoch 10/25 - Loss: 0.0602, Test Acc: 58.88%
Student Normal Epoch 11/25 - Loss: 0.0593, Test Acc: 58.45%
Student Normal Epoch 12/25 - Loss: 0.0585, Test Acc: 59.66%
Student Normal Epoch 13/25 - Loss: 0.0576, Test Acc: 61.10%
Student Normal Epoch 14/25 - Loss: 0.0569, Test Acc: 60.64%
Student Normal Epoch 15/25 - Loss: 0.0562, Test Acc: 60.80%
Student Normal Epoch 16/25 - Loss: 0.0556, Test Acc: 62.56%
Student Normal Epoch 17/25 - Loss: 0.0548, Test A

## 5. Train the Student Model with Knowledge Distillation (KD)

The student model is re‑trained using knowledge distillation. In each training step:
- The teacher (in eval mode) produces softened predictions using a temperature parameter.
- The student produces predictions.
- We compute the hard loss using cross‑entropy with true labels.
- We compute the KD loss using KL divergence between the student’s and teacher’s softened predictions.
- The total loss is a weighted sum:


total_loss = \$ \alpha \$ * CE_loss + (1 - \$ \alpha \$) * (temperature**2) * KD_loss


KLDivLoss in PyTorch expects the input as log‑probabilities and the target as probabilities.

In [8]:
import warnings
warnings.filterwarnings("ignore")
student_kd = StudentNet().to(device)

def train_student_kd(student, teacher, train_loader, test_loader, num_epochs=25, lr=0.001, temperature=5.0, alpha=0.5):
    optimizer = optim.Adam(student.parameters(), lr=lr)
    ce_loss_fn = nn.CrossEntropyLoss()
    kd_loss_fn = nn.KLDivLoss(reduction='batchmean')
    
    teacher.eval()  # Freeze teacher
    
    for epoch in range(num_epochs):
        student.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            
            with torch.no_grad():
                teacher_outputs = teacher(inputs)
                teacher_soft = F.softmax(teacher_outputs / temperature, dim=1)
            
            student_outputs = student(inputs)
            ce_loss = ce_loss_fn(student_outputs, targets)
            student_log_soft = F.log_softmax(student_outputs / temperature, dim=1)
            kd_loss = kd_loss_fn(student_log_soft, teacher_soft)
            loss = alpha * ce_loss + (1 - alpha) * (temperature**2) * kd_loss
            
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            
        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = student(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total
        print(f"Student KD Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}, Test Acc: {acc:.2f}%")
        
# Uncomment to train student model with KD
train_student_kd(student_kd, teacher, train_loader, test_loader, num_epochs=30, lr=0.001, temperature=4.5, alpha=0.8)


Student KD Epoch 1/30 - Loss: 3.4799, Test Acc: 51.18%
Student KD Epoch 2/30 - Loss: 2.7296, Test Acc: 55.03%
Student KD Epoch 3/30 - Loss: 2.4162, Test Acc: 60.62%
Student KD Epoch 4/30 - Loss: 2.2296, Test Acc: 63.10%
Student KD Epoch 5/30 - Loss: 2.0980, Test Acc: 62.96%
Student KD Epoch 6/30 - Loss: 1.9840, Test Acc: 65.68%
Student KD Epoch 7/30 - Loss: 1.8619, Test Acc: 66.69%
Student KD Epoch 8/30 - Loss: 1.7834, Test Acc: 68.14%
Student KD Epoch 9/30 - Loss: 1.7143, Test Acc: 69.56%
Student KD Epoch 10/30 - Loss: 1.6416, Test Acc: 70.69%
Student KD Epoch 11/30 - Loss: 1.5712, Test Acc: 70.61%
Student KD Epoch 12/30 - Loss: 1.5153, Test Acc: 71.22%
Student KD Epoch 13/30 - Loss: 1.4909, Test Acc: 71.46%
Student KD Epoch 14/30 - Loss: 1.4410, Test Acc: 72.72%
Student KD Epoch 15/30 - Loss: 1.4134, Test Acc: 72.68%
Student KD Epoch 16/30 - Loss: 1.3639, Test Acc: 73.95%
Student KD Epoch 17/30 - Loss: 1.3435, Test Acc: 73.83%
Student KD Epoch 18/30 - Loss: 1.3235, Test Acc: 74.48%
S

## 6. Comparison and Conclusion

The following cell compares the test accuracies of:
- The Teacher Model
- The Student Model trained without KD (normal training)
- The Student Model trained with KD

Based on the results, we conclude how knowledge distillation (using KL divergence loss) can help the student model achieve performance closer to the teacher model.

In [28]:
def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total

teacher_acc = evaluate(teacher, test_loader)
student_normal_acc = evaluate(student_normal, test_loader)
student_kd_acc = evaluate(student_kd, test_loader)

print(f"Teacher Model Test Accuracy: {teacher_acc:.2f}%")
print(f"Student Model (Normal Training) Test Accuracy: {student_normal_acc:.2f}%")
print(f"Student Model (KD Training) Test Accuracy: {student_kd_acc:.2f}%")

Teacher Model Test Accuracy: 88.77%
Student Model (Normal Training) Test Accuracy: 65.59%
Student Model (KD Training) Test Accuracy: 76.84%


## Conclusion

In this notebook:
- A high‑capacity teacher model was trained on CIFAR‑10 using standard cross‑entropy loss.
- A simpler student model was trained in two ways:
  - Using normal training with MSE loss.
  - Using knowledge distillation, where the student was trained with a combination of cross‑entropy loss and KL divergence loss (comparing softened predictions of the teacher and student).
- The comparison shows that knowledge distillation can help the student model achieve improved performance, narrowing the gap between the student and the teacher.

Adjusting hyperparameters (such as temperature and the loss weight alpha) and training duration can further enhance performance.