In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import time
import os
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Load and preprocess MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.42MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.91MB/s]


In [18]:
# Define the teacher model (larger CNN)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [19]:
# Define the student model (smaller CNN)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 14 * 14, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

In [20]:
# Knowledge distillation loss function
def distillation_loss(y_pred, y_true, teacher_logits, temperature=5.0, alpha=0.5):
    soft_teacher = torch.softmax(teacher_logits / temperature, dim=1)
    soft_student = torch.softmax(y_pred / temperature, dim=1)
    
    # Distillation loss (KL divergence)
    distillation_ce = nn.KLDivLoss(reduction='batchmean')(torch.log(soft_student), soft_teacher) * (temperature ** 2)
    
    # Standard cross-entropy loss
    standard_ce = nn.CrossEntropyLoss()(y_pred, y_true)
    
    return alpha * distillation_ce + (1 - alpha) * standard_ce

In [21]:
# Function to evaluate model
def evaluate_model(model, data_loader, desc="Evaluating"):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=desc, leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [22]:
# Train the teacher model 
teacher_model = TeacherModel().to(device)
optimizer = optim.Adam(teacher_model.parameters())
criterion = nn.CrossEntropyLoss()

In [23]:
print("Training teacher model...")
for epoch in range(5):
    teacher_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/5')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = teacher_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    avg_loss = running_loss / len(train_loader)
    epoch_acc = correct / total 
    
    print(f"Epoch {epoch+1}/5 - Loss: {avg_loss:.4f}, Accuracy: {epoch_acc:.4f}")

Training teacher model...


Epoch 1/5: 100%|██████████| 469/469 [00:11<00:00, 41.22it/s]


Epoch 1/5 - Loss: 0.1918, Accuracy: 0.9414


Epoch 2/5: 100%|██████████| 469/469 [00:11<00:00, 41.34it/s]


Epoch 2/5 - Loss: 0.0493, Accuracy: 0.9849


Epoch 3/5: 100%|██████████| 469/469 [00:11<00:00, 41.69it/s]


Epoch 3/5 - Loss: 0.0328, Accuracy: 0.9897


Epoch 4/5: 100%|██████████| 469/469 [00:11<00:00, 40.88it/s]


Epoch 4/5 - Loss: 0.0249, Accuracy: 0.9921


Epoch 5/5: 100%|██████████| 469/469 [00:11<00:00, 41.73it/s]

Epoch 5/5 - Loss: 0.0205, Accuracy: 0.9934





In [24]:
teacher_acc = evaluate_model(teacher_model, test_loader, "Testing teacher")
print(f"Teacher model test accuracy: {teacher_acc:.4f}")
# Get teacher model size
torch.save(teacher_model.state_dict(), 'teacher_model.pth')
teacher_size = os.path.getsize('teacher_model.pth') / (1024 * 1024)  # Size in MB
print(f"Teacher model size: {teacher_size:.2f} MB")

# Measure teacher inference time 
start_time = time.time()
teacher_model.eval()
with torch.no_grad():
    test_subset = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    for images, _ in tqdm(test_subset, desc="Teacher inference", leave=False):
        images = images.to(device)
        teacher_model(images)
        break  # Only process 1000 samples
teacher_inference_time = time.time() - start_time
print(f"Teacher inference time (1000 samples): {teacher_inference_time:.4f} seconds")

                                                                

Teacher model test accuracy: 0.9915
Teacher model size: 1.75 MB


                                                         

Teacher inference time (1000 samples): 0.1585 seconds




In [25]:
# Precompute teacher logits for training with tqdm
print("Computing teacher logits...")
teacher_logits = []
teacher_model.eval()
with torch.no_grad():
    for images, _ in tqdm(train_loader, desc="Computing logits", leave=False):
        images = images.to(device)
        logits = teacher_model(images)
        teacher_logits.append(logits.cpu())
teacher_logits = torch.cat(teacher_logits, dim=0)

Computing teacher logits...


                                                                   

In [26]:
# Train the student model with knowledge distillation and tqdm
student_model = StudentModel().to(device)
optimizer = optim.Adam(student_model.parameters())

In [27]:
# Train the student model with knowledge distillation
print("\nTraining student model with knowledge distillation...")
for epoch in range(5):
    student_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Student Epoch {epoch+1}/5')
    for i, (images, labels) in pbar:
        images, labels = images.to(device), labels.to(device)
        batch_logits = teacher_logits[i * 128:(i + 1) * 128].to(device)
        
        optimizer.zero_grad()
        outputs = student_model(images)
        loss = distillation_loss(outputs, labels, batch_logits)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    avg_loss = running_loss / len(train_loader)
    epoch_acc = correct / total
    print(f"Student Epoch {epoch+1}/5 - Loss: {avg_loss:.4f}, Accuracy: {epoch_acc:.4f}")


Training student model with knowledge distillation...


Student Epoch 1/5: 100%|██████████| 469/469 [00:10<00:00, 43.13it/s]


Student Epoch 1/5 - Loss: 16.8721, Accuracy: 0.6451


Student Epoch 2/5: 100%|██████████| 469/469 [00:10<00:00, 43.40it/s]


Student Epoch 2/5 - Loss: 16.7272, Accuracy: 0.8600


Student Epoch 3/5: 100%|██████████| 469/469 [00:11<00:00, 41.84it/s]


Student Epoch 3/5 - Loss: 16.7064, Accuracy: 0.9103


Student Epoch 4/5: 100%|██████████| 469/469 [00:11<00:00, 40.62it/s]


Student Epoch 4/5 - Loss: 16.6718, Accuracy: 0.9276


Student Epoch 5/5: 100%|██████████| 469/469 [00:11<00:00, 41.72it/s]

Student Epoch 5/5 - Loss: 16.6715, Accuracy: 0.9335





In [28]:
student_acc = evaluate_model(student_model, test_loader, "Testing student")
print(f"Student model test accuracy: {student_acc:.4f}")

                                                                

Student model test accuracy: 0.9484




In [29]:
# Get student model size
torch.save(student_model.state_dict(), 'student_model.pth')
student_size = os.path.getsize('student_model.pth') / (1024 * 1024)  # Size in MB
print(f"Student model size: {student_size:.2f} MB")

Student model size: 0.39 MB


In [30]:
# Measure student inference time with tqdm
start_time = time.time()
student_model.eval()
with torch.no_grad():
    test_subset = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    for images, _ in tqdm(test_subset, desc="Student inference", leave=False):
        images = images.to(device)
        student_model(images)
        break  # Only process 1000 samples
student_inference_time = time.time() - start_time
print(f"Student inference time (1000 samples): {student_inference_time:.4f} seconds")


                                                         

Student inference time (1000 samples): 0.1694 seconds




In [31]:

# Summary of benefits
print("\nKnowledge Distillation Benefits:")
print(f"- Teacher Accuracy: {teacher_acc:.4f}, Student Accuracy: {student_acc:.4f}")
print(f"- Model Size Reduction: {(1 - student_size / teacher_size) * 100:.2f}%")
print(f"- Inference Time Reduction: {(1 - student_inference_time / teacher_inference_time) * 100:.2f}%")


Knowledge Distillation Benefits:
- Teacher Accuracy: 0.9915, Student Accuracy: 0.9484
- Model Size Reduction: 77.89%
- Inference Time Reduction: -6.90%
