<a href="https://colab.research.google.com/github/Janindu-Muthunayaka/model-distillation/blob/main/Classificationv2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
# Commented out IPython magic to ensure Python compatibility.
# %pip install numpy pandas scikit-learn torch torchvision

import torch
import numpy as np
import random

# One seed shared by every random number generator used in this notebook.
seed = 42

# cuDNN: disable the benchmarking auto-tuner and request deterministic kernels
# so that GPU runs are repeatable.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed PyTorch (CPU, then the current CUDA device), NumPy and Python's `random`.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Number of training epochs shared by every model trained below.
epochs = 20

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Dataset & Preparation

In [None]:
# Fetch MNIST (downloading into ./data when it is not there yet).
transform = transforms.ToTensor()

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Features: read the raw image tensors, flatten each image to 784 values,
# cast to float32 and divide by 255 to rescale the pixel values.
xTrain = train_data.data.reshape(-1, 784).to(torch.float32).div(255.0)
xTest = test_data.data.reshape(-1, 784).to(torch.float32).div(255.0)

# Labels: the class targets stored on each dataset object.
yTrain = train_data.targets
yTest = test_data.targets


# Data Preparation for PyTorch

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# MNIST already ships with a train/test split, so train_test_split is not needed,
# and the pixels were scaled to [0, 1] above, so no further scaling is required.

# The data is already in PyTorch tensors; it only needs wrapping in TensorDatasets.

# Mini-batch size shared by the training and the test loader.
batch_size = 128

# Training split: reshuffled every epoch, with the shuffle order driven by a
# dedicated generator seeded from the notebook-wide `seed`.
train_dataset = TensorDataset(xTrain, yTrain)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)

# Test split: iterated in its stored order (no shuffling).
test_dataset = TensorDataset(xTest, yTest)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Teacher Creation


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class mnistTeacher(nn.Module):
    """Large fully connected classifier used as the distillation teacher.

    Layer widths: 784 -> 1024 -> 512 -> 256 -> 128 -> 64 -> 10, with a ReLU
    after every layer except the last one.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)
        self.fc6 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw (un-normalised) 10-way logits for a batch ``x``."""
        hidden_layers = (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5)
        for layer in hidden_layers:
            x = F.relu(layer(x))
        # No activation on the output layer: callers receive logits.
        return self.fc6(x)



import torch.optim as optim

# Teacher network with its classification loss and Adam optimizer.
teacher = mnistTeacher()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=0.001)

# `epochs` was declared in the first cell.
for epoch in range(epochs):
    teacher.train()
    batch_losses = []
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = teacher(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    # Sum of the per-batch losses seen during this epoch.
    running_loss = sum(batch_losses)

    # Report the mean batch loss only on every 5th epoch to keep the log short.
    if (epoch + 1) % 5 != 0:
        continue
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}')


# Creating & Training the 'Student' with Knowledge Distillation

In [None]:
class mnistStudent(nn.Module):
    """Small fully connected classifier (784 -> 128 -> 64 -> 10).

    ReLU follows the two hidden layers; the output layer is left linear.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw 10-way logits for a batch ``x``."""
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return self.fc3(x)


def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend a hard-label loss with a temperature-softened teacher-matching loss.

    Args:
        student_logits: Raw student outputs; classes are on dim 1.
        teacher_logits: Raw teacher outputs for the same batch.
        labels: Ground-truth class indices for the batch.
        T: Temperature dividing both sets of logits before the softmax.
        alpha: Weight of the hard-label term; the soft term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * hard + (1 - alpha) * soft``.
    """
    # Soft term: KL divergence between the temperature-scaled distributions.
    # `F.kl_div` expects log-probabilities for the input and probabilities
    # for the target; the result is rescaled by T * T.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    kd_term = (T * T) * F.kl_div(log_p_student, p_teacher, reduction="batchmean")

    # Hard term: ordinary cross-entropy against the true labels (no temperature).
    ce_term = F.cross_entropy(student_logits, labels)

    return alpha * ce_term + (1 - alpha) * kd_term

# Training the Student


# Student network and its own Adam optimizer.
student = mnistStudent()
optimizer_s = optim.Adam(student.parameters(), lr=0.001)

# The teacher only provides fixed targets from here on, so switch it to
# inference mode. It makes no difference for the current all-Linear teacher,
# but keeps the targets stable if dropout or batch-norm layers are ever added.
teacher.eval()

# `epochs` was declared in the first cell.
for epoch in range(epochs):
    student.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer_s.zero_grad()

        # Teacher predictions are targets only: no gradients are tracked.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        # Student predictions for the same batch.
        student_outputs = student(inputs)

        # Combined hard-label + soft-target (distillation) loss.
        loss = distillation_loss(student_outputs, teacher_outputs, labels, T=2.0, alpha=0.5)

        loss.backward()
        optimizer_s.step()
        running_loss += loss.item()

    # Report the mean batch loss every 5 epochs.
    if (epoch+1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], KD Loss: {running_loss / len(train_loader):.4f}")


# Creating & Training 'Person' Model with regular training

In [None]:
class mnistPerson(nn.Module):
    """Baseline fully connected classifier (784 -> 128 -> 64 -> 10).

    ReLU follows the two hidden layers; the output layer is left linear.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return the raw 10-way logits for a batch ``x``."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc3(h2)

# Baseline model trained on the hard labels only (no teacher involved).
person = mnistPerson()

# `criterion` and `optimizer` are rebound here for the baseline model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(person.parameters(), lr=0.001)

for epoch in range(epochs):
    person.train()
    batch_losses = []
    for inputs, labels in train_loader:
        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = person(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    # Sum of the per-batch losses seen during this epoch.
    running_loss = sum(batch_losses)

    # Log the mean batch loss on every 5th epoch only.
    if not (epoch + 1) % 5:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")


# Creating Test functions

In [None]:
import time

# 1. Accuracy function
def evaluate_accuracy(model, xTest, yTest):
    """Return the fraction of samples in ``xTest`` that ``model`` classifies correctly.

    The model is switched to eval mode and run once on the whole of ``xTest``
    without gradient tracking; the predicted class is the arg-max over dim 1.

    Args:
        model: Callable module producing per-class scores with classes on dim 1.
        xTest: Input batch fed to the model in a single forward pass.
        yTest: Ground-truth class indices, one per sample.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    model.eval()
    with torch.no_grad():
        scores = model(xTest)
    hits = scores.argmax(dim=1).eq(yTest)
    return hits.float().mean().item()

import time
#2.Time Function
def evaluate_inference_time_stable(model, xTest, repeats=10):
    """Measure the mean wall-clock time of a full forward pass over ``xTest``.

    One untimed warm-up pass is run first, then ``repeats`` timed passes.
    Timing uses ``time.perf_counter`` (monotonic, high resolution), and when
    ``xTest`` lives on a CUDA device the device is synchronised around each
    measurement so that asynchronously queued kernels are included.

    Args:
        model: Module to benchmark; it is switched to eval mode.
        xTest: Input batch passed to the model in one call.
        repeats: Number of timed passes to average over (must be >= 1).

    Returns:
        ``(total_time, avg_time)``: the mean seconds per full pass, and that
        value divided by the number of samples (``xTest.size(0)``).

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        raise ValueError("repeats must be a positive integer")

    def _sync():
        # CUDA work is queued asynchronously; wait for it so the timer
        # brackets the real computation. No-op for CPU tensors.
        if xTest.is_cuda:
            torch.cuda.synchronize()

    model.eval()
    times = []
    with torch.no_grad():
        # Warm-up pass: excluded from the measurements.
        model(xTest)
        _sync()

        for _ in range(repeats):
            start = time.perf_counter()
            model(xTest)
            _sync()
            times.append(time.perf_counter() - start)

    total_time = sum(times) / repeats
    avg_time = total_time / xTest.size(0)
    return total_time, avg_time


# 3. Model size function
def evaluate_model_size(model):
    """Count the trainable parameters of ``model`` and their memory footprint.

    Only parameters with ``requires_grad=True`` are counted. The size uses each
    parameter's actual element size, so half or double precision models are
    reported correctly (float32 parameters still count 4 bytes each).

    Args:
        model: Module whose parameters are inspected.

    Returns:
        ``(num_params, param_size_MB)``: the number of trainable scalars and
        their total size in MiB (bytes / 1024**2).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in trainable)
    size_bytes = sum(p.numel() * p.element_size() for p in trainable)
    param_size_MB = size_bytes / (1024**2)
    return num_params, param_size_MB

# Individual Results

In [None]:
# Gather every metric first (teacher, then student, then person), then print.

# Teacher metrics
teacher_acc = evaluate_accuracy(teacher, xTest, yTest)
teacher_time, teacher_avg_time = evaluate_inference_time_stable(teacher, xTest)
teacher_params, teacher_size = evaluate_model_size(teacher)

# Student metrics
student_acc = evaluate_accuracy(student, xTest, yTest)
student_time, student_avg_time = evaluate_inference_time_stable(student, xTest)
student_params, student_size = evaluate_model_size(student)

# Person metrics
person_acc = evaluate_accuracy(person, xTest, yTest)
person_time, person_avg_time = evaluate_inference_time_stable(person, xTest)
person_params, person_size = evaluate_model_size(person)

# Teacher report
print(f"Teacher Accuracy: {teacher_acc:.4f}")
print(f"Teacher Inference Time: {teacher_time:.6f}s ({teacher_avg_time:.6f}s per sample)")
print(f"Teacher Params: {teacher_params}, Size: {teacher_size:.6f} MB")

# Student report
print(f"Student Accuracy: {student_acc:.4f}")
print(f"Student Inference Time: {student_time:.6f}s ({student_avg_time:.6f}s per sample)")
print(f"Student Params: {student_params}, Size: {student_size:.6f} MB")

# Person report
print(f"Person Accuracy: {person_acc:.4f}")
print(f"Person Inference Time: {person_time:.6f}s ({person_avg_time:.6f}s per sample)")
print(f"Person Params: {person_params}, Size: {person_size:.6f} MB")



# Results Comparison

In [None]:
def percent_change(val1, val2):
    """Return the change of ``val1`` relative to the baseline ``val2``, in percent.

    Args:
        val1: New value.
        val2: Baseline value the change is measured against.

    Returns:
        ``(val1 - val2) / val2 * 100`` for a non-zero baseline. For a zero
        baseline: ``0.0`` when ``val1`` is also zero (nothing changed),
        otherwise an infinity carrying the sign of ``val1``.
    """
    if val2 != 0:
        return ((val1 - val2) / val2) * 100
    # Zero baseline: the ratio is undefined, so report "no change" or a
    # correctly signed infinity instead of always returning +inf.
    if val1 == 0:
        return 0.0
    return float('inf') if val1 > 0 else float('-inf')

print("--- Results Comparison ---")

# Student relative to Teacher (Teacher is the baseline).
(
    acc_change_s_t,
    time_change_s_t,
    avg_time_change_s_t,
    params_change_s_t,
    size_change_s_t,
) = (
    percent_change(student_acc, teacher_acc),
    percent_change(student_time, teacher_time),
    percent_change(student_avg_time, teacher_avg_time),
    percent_change(student_params, teacher_params),
    percent_change(student_size, teacher_size),
)

print("\n--- Student vs Teacher ---")
print(f"Accuracy Change: {acc_change_s_t:.2f}%")
print(f"Total Inference Time Change: {time_change_s_t:.2f}%")
print(f"Avg Inference Time per Sample Change: {avg_time_change_s_t:.2f}%")
print(f"Params Change: {params_change_s_t:.2f}%")
print(f"Model Size Change: {size_change_s_t:.2f}%")

# Person relative to Teacher (Teacher is the baseline).
(
    acc_change_p_t,
    time_change_p_t,
    avg_time_change_p_t,
    params_change_p_t,
    size_change_p_t,
) = (
    percent_change(person_acc, teacher_acc),
    percent_change(person_time, teacher_time),
    percent_change(person_avg_time, teacher_avg_time),
    percent_change(person_params, teacher_params),
    percent_change(person_size, teacher_size),
)

print("\n--- Person vs Teacher ---")
print(f"Accuracy Change: {acc_change_p_t:.2f}%")
print(f"Total Inference Time Change: {time_change_p_t:.2f}%")
print(f"Avg Inference Time per Sample Change: {avg_time_change_p_t:.2f}%")
print(f"Params Change: {params_change_p_t:.2f}%")
print(f"Model Size Change: {size_change_p_t:.2f}%")

# Student relative to Person (Person is the baseline).
(
    acc_change_s_p,
    time_change_s_p,
    avg_time_change_s_p,
    params_change_s_p,
    size_change_s_p,
) = (
    percent_change(student_acc, person_acc),
    percent_change(student_time, person_time),
    percent_change(student_avg_time, person_avg_time),
    percent_change(student_params, person_params),
    percent_change(student_size, person_size),
)

print("\n--- Student vs Person ---")
print(f"Accuracy Change: {acc_change_s_p:.2f}%")
print(f"Total Inference Time Change: {time_change_s_p:.2f}%")
print(f"Avg Inference Time per Sample Change: {avg_time_change_s_p:.2f}%")
print(f"Params Change: {params_change_s_p:.2f}%")
print(f"Model Size Change: {size_change_s_p:.2f}%")



100%|██████████| 9.91M/9.91M [00:00<00:00, 42.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.09MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.93MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.72MB/s]


Epoch [5/20], Loss: 0.0434
Epoch [10/20], Loss: 0.0207
Epoch [15/20], Loss: 0.0103
Epoch [20/20], Loss: 0.0080
Epoch [5/20], KD Loss: 0.2123
Epoch [10/20], KD Loss: 0.1007
Epoch [15/20], KD Loss: 0.0597
Epoch [20/20], KD Loss: 0.0391
Teacher Accuracy: 0.9807
Teacher Inference Time: 0.497231s (0.000050s per sample)
Teacher Params: 1501770, Size: 5.728798 MB
Student Accuracy: 0.9786
Student Inference Time: 0.041835s (0.000004s per sample)
Student Params: 109386, Size: 0.417274 MB

--- Percentage Change (Student vs Teacher) ---
Accuracy Change: -0.21%
Total Inference Time Change: -91.59%
Avg Inference Time per Sample Change: -91.59%
Params Change: -92.72%
Model Size Change: -92.72%
Epoch [5/20], Loss: 0.0773
Epoch [10/20], Loss: 0.0288
Epoch [15/20], Loss: 0.0133
Epoch [20/20], Loss: 0.0110
Person Accuracy: 0.9783
Person Inference Time: 0.044387s (0.000004s per sample)
Person Params: 109386, Size: 0.417274 MB

--- Percentage Change(Person vs Teacher) ---
Accuracy Change: -0.24%
Total Infe