<a href="https://colab.research.google.com/github/Janindu-Muthunayaka/model-distillation/blob/main/Classificationv2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
!pip install torch torchvision numpy pandas

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import numpy as np
import random

# Pin every random-number source (Python, NumPy, PyTorch CPU and the current
# CUDA device) to a single seed so repeated runs draw the same numbers.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
del _seed_fn
# Request deterministic cuDNN kernels and disable cuDNN's autotuner, whose
# algorithm choice can differ between runs.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# -----------------------------
# Hyperparameters
# -----------------------------
epochs, batch_size = 12, 64    # passes over the training set; mini-batch size
lr = 1e-3                      # learning rate for all three Adam optimizers
temperature, alpha = 2.0, 0.5  # KD softening temperature; weight on the hard-label loss

# -----------------------------
# Data Preparation
# -----------------------------
# ToTensor yields pixel values in [0, 1]; Normalize with mean = std = 0.5 then
# maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Both splits share the download location and the preprocessing pipeline.
_fashion_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.FashionMNIST(train=True, **_fashion_kwargs)
test_dataset = datasets.FashionMNIST(train=False, **_fashion_kwargs)

# Only the training batches are shuffled; evaluation iterates in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# -----------------------------
# Teacher CNN
# -----------------------------
class TeacherCNN(nn.Module):
    """Larger CNN that serves as the distillation teacher.

    Architecture: two [3x3 conv (padding 1) -> ReLU -> 2x2 max-pool] stages
    going 1 -> 32 -> 64 channels, a 128-unit ReLU hidden layer and a 10-way
    linear head. ``fc1`` takes 64*7*7 features, which matches a 1x28x28 input
    halved twice by the pools. ``forward`` returns raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Registration order is deliberate. It fixes the state_dict key order
        # and the order in which default initialisation consumes the global
        # RNG, so seeded runs reproduce the same weights.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared by both conv stages

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> pool for each conv layer in turn.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Classifier head on the flattened (batch, 64*7*7) feature map.
        flat = x.view(x.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))

# -----------------------------
# Student CNN (smaller)
# -----------------------------
class StudentCNN(nn.Module):
    """Compact CNN used for the distilled student and the no-KD baseline.

    Architecture: one 3x3 conv (padding 1, 1 -> 16 channels) -> ReLU -> 4x4
    max-pool, a 64-unit ReLU hidden layer and a 10-way linear head.
    ``forward`` returns raw logits.

    The input is assumed to be 1x28x28, the size the teacher's 64*7*7 head
    implies. The 4x4 pool reduces 28x28 to the 7x7 map that
    ``fc1 = Linear(16*7*7, 64)`` expects. A single 2x2 pool would leave
    16*14*14 = 3136 features and ``fc1`` would fail with a shape mismatch.

    NOTE(review): if a 14x14 map was intended instead, the alternative fix is
    ``fc1 = nn.Linear(16*14*14, 64)``, which gives roughly 4x the parameters.
    """

    def __init__(self):
        super().__init__()
        # Keep the registration order stable; it fixes seeded initialisation.
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.fc1 = nn.Linear(16*7*7, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        # A 4x4 max-pool equals two successive 2x2 max-pools: 28x28 -> 7x7,
        # matching fc1's 16*7*7 input features.
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=4)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# -----------------------------
# Distillation Loss
# -----------------------------
def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend hard-label cross-entropy with a temperature-softened KL term.

    Args:
        student_logits: raw student outputs, with class scores along dim 1.
        teacher_logits: raw teacher outputs, same shape as ``student_logits``.
        labels: integer class targets for the cross-entropy term.
        T: temperature that softens both distributions in the KL term.
        alpha: weight of the hard-label term; ``1 - alpha`` weights the soft term.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * T*T * KL(teacher || student)``
        with the KL averaged over the batch.
    """
    # kl_div takes log-probabilities as input and probabilities as target.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    # The T*T factor keeps the soft-target gradients on a scale comparable to
    # the hard-label gradients as T changes.
    kd_term = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)
    ce_term = F.cross_entropy(student_logits, labels)
    return alpha * ce_term + (1 - alpha) * kd_term

# -----------------------------
# Training Functions
# -----------------------------
def train_model(model, train_loader, optimizer, criterion, epochs):
    """Train ``model`` with ordinary supervised learning.

    Makes ``epochs`` passes over ``train_loader`` (an iterable of ``(x, y)``
    batches that also supports ``len``) and minimises
    ``criterion(model(x), y)`` with ``optimizer``. After every third epoch
    (3, 6, ...) it prints the mean batch loss for that epoch. The model is put
    into training mode and left there. Returns ``None``.
    """
    model.train()
    for epoch_idx in range(1, epochs + 1):
        running = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()
        # Report only on epochs divisible by 3.
        if epoch_idx % 3:
            continue
        print(f"Epoch [{epoch_idx}/{epochs}], Loss: {running/len(train_loader):.4f}")

def train_student_with_distillation(student, teacher, train_loader, optimizer, epochs, T=2.0, alpha=0.5):
    """Train ``student`` on the labels plus a fixed ``teacher``'s soft targets.

    For each of ``epochs`` passes over ``train_loader``, every batch is scored
    by the teacher under ``torch.no_grad()``. The student is then updated on
    ``distillation_loss(student_logits, teacher_logits, y, T, alpha)``. After
    every third epoch (3, 6, ...) it prints the mean batch loss.

    Args:
        student: model being trained; switched to training mode here.
        teacher: already-trained model; switched to eval mode, and its forward
            pass builds no autograd graph.
        train_loader: iterable of ``(x, y)`` batches that also supports ``len``.
        optimizer: optimizer expected to hold the student's parameters.
        epochs: number of passes over ``train_loader``.
        T: distillation temperature forwarded to ``distillation_loss``.
        alpha: hard-label weight forwarded to ``distillation_loss``.
    """
    teacher.eval()
    # Set the student's mode explicitly, as train_model does for its model.
    # Otherwise the student trains in whatever mode the caller left it in,
    # e.g. eval mode after evaluate_accuracy(). In architectures with dropout
    # or batch norm, that silently changes training behaviour.
    student.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            # The teacher only supplies targets; no gradients flow into it.
            with torch.no_grad():
                teacher_logits = teacher(x)
            student_logits = student(x)
            loss = distillation_loss(student_logits, teacher_logits, y, T, alpha)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if (epoch+1) % 3 == 0:
            print(f"Student KD Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}")

# -----------------------------
# Evaluation Functions
# -----------------------------
def evaluate_accuracy(model, test_loader):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and run without gradient tracking. A
    sample counts as correct when the arg-max over dim 1 of the model output
    equals its label. Raises ZeroDivisionError if the loader yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            hits = torch.argmax(model(inputs), dim=1) == targets
            n_correct += int(hits.sum())
            n_seen += len(targets)
    return n_correct / n_seen

def evaluate_inference_time(model, test_loader, repeats=5, warmup=1):
    """Measure the forward-pass time of ``model`` over all batches of ``test_loader``.

    Inputs are pulled from the loader once, before timing starts, so the timed
    loops measure only ``model(x)`` and not data loading or transforms.
    ``warmup`` untimed passes run first to absorb one-time start-up costs.
    Then ``repeats`` timed passes are averaged.

    Args:
        model: module to benchmark; it is switched to eval mode.
        test_loader: iterable of ``(x, y)`` batches; the labels are ignored.
        repeats: number of timed passes (must be >= 1).
        warmup: number of untimed passes run before timing.

    Returns:
        ``(avg_total_time, avg_per_sample)`` in seconds: the mean time of one
        full pass, and that time divided by the number of samples.

    Raises:
        ValueError: if ``repeats`` < 1 or the loader yields no samples.
    """
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")
    model.eval()
    # Materialise the inputs once. Iterating the DataLoader inside the timed
    # region would also time per-image transforms and batch collation, which
    # for small CNNs can dominate and blur the difference between models.
    batches = [x for x, _ in test_loader]
    # Count the samples actually seen, so any iterable of batches works.
    n_samples = sum(x.size(0) for x in batches)
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    times = []
    with torch.no_grad():
        for _ in range(warmup):
            for x in batches:
                model(x)
        for _ in range(repeats):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            # NOTE: if the models are ever moved to CUDA, call
            # torch.cuda.synchronize() before each clock read, because kernel
            # launches are asynchronous.
            start = time.perf_counter()
            for x in batches:
                model(x)
            times.append(time.perf_counter() - start)
    avg_total_time = np.mean(times)
    avg_per_sample = avg_total_time / n_samples
    return avg_total_time, avg_per_sample

def evaluate_model_size(model):
    """Return ``(num_params, size_MB)`` for all of ``model``'s parameters.

    ``num_params`` counts every parameter element, whether trainable or frozen.
    A model frozen with ``requires_grad_(False)`` therefore still reports its
    true size. ``size_MB`` is the parameter storage in MiB (1024**2 bytes),
    computed from each tensor's actual element size rather than assuming
    4-byte float32. Buffers are not included.
    """
    params = list(model.parameters())
    num_params = sum(p.numel() for p in params)
    # element_size() is bytes per element (4 for float32, 8 for float64,
    # 2 for float16, ...).
    size_bytes = sum(p.numel() * p.element_size() for p in params)
    size_MB = size_bytes / (1024**2)
    return num_params, size_MB

# -----------------------------
# Create Models
# -----------------------------
# Construction order (teacher, student, person) determines which seeded random
# draws initialise each model, so keep it stable. "person" has the student
# architecture but is trained without the teacher, as the no-KD baseline.
teacher, student, person = TeacherCNN(), StudentCNN(), StudentCNN()

# -----------------------------
# Optimizers
# -----------------------------
# One independent Adam optimizer per model, all with the same learning rate.
optimizer_teacher, optimizer_student, optimizer_person = (
    optim.Adam(m.parameters(), lr=lr) for m in (teacher, student, person)
)

# -----------------------------
# Train Teacher
# -----------------------------
# The teacher goes first because the distillation step below uses its logits.
print("Training Teacher...")
train_model(
    model=teacher,
    train_loader=train_loader,
    optimizer=optimizer_teacher,
    criterion=nn.CrossEntropyLoss(),
    epochs=epochs,
)

# -----------------------------
# Train Student with KD
# -----------------------------
print("\nTraining Student with KD...")
train_student_with_distillation(
    student,
    teacher,
    train_loader,
    optimizer_student,
    epochs,
    T=temperature,
    alpha=alpha,
)

# -----------------------------
# Train Person: student architecture, labels only (the no-KD baseline)
# -----------------------------
print("\nTraining Person (same as Student, normal)...")
train_model(
    model=person,
    train_loader=train_loader,
    optimizer=optimizer_person,
    criterion=nn.CrossEntropyLoss(),
    epochs=epochs,
)

# -----------------------------
# Evaluate
# -----------------------------
# Gather accuracy, inference timing and parameter size for every model,
# keyed by the display name used in the printed report.
models = dict(Teacher=teacher, Student=student, Person=person)
results = {}

for name, model in models.items():
    acc = evaluate_accuracy(model, test_loader)
    total_time, avg_time = evaluate_inference_time(model, test_loader)
    params, size = evaluate_model_size(model)
    results[name] = dict(
        Accuracy=acc,
        TotalTime=total_time,
        AvgTimePerSample=avg_time,
        Params=params,
        Size_MB=size,
    )

# -----------------------------
# Display Results
# -----------------------------
for name, res in results.items():
    # One section per model; the leading "\n" leaves a blank line before it.
    print(
        f"\n{name} Metrics:",
        f"Accuracy: {res['Accuracy']:.4f}",
        f"Inference Time: {res['TotalTime']:.4f}s ({res['AvgTimePerSample']:.6f}s per sample)",
        f"Parameters: {res['Params']}, Size: {res['Size_MB']:.4f} MB",
        sep="\n",
    )

# -----------------------------
# Compare % difference
# -----------------------------
def percent_change(student_val, teacher_val):
    """Return the percentage change of ``student_val`` relative to ``teacher_val``.

    Computed as ``((student_val - teacher_val) / teacher_val) * 100``.
    A zero ``teacher_val`` makes the ratio undefined, so:

    - equal values (0 vs 0) report ``0.0``, meaning no change;
    - otherwise a signed infinity in the direction of the change is returned,
      instead of always ``+inf``.
    """
    if teacher_val == 0:
        diff = student_val - teacher_val
        if diff == 0:
            return 0.0
        return float('inf') if diff > 0 else float('-inf')
    return ((student_val - teacher_val)/teacher_val)*100

print("\n--- Percentage Change vs Teacher ---")
# (printed label, results key) pairs, in report order.
_metric_rows = (
    ("Accuracy", 'Accuracy'),
    ("Total Inference Time", 'TotalTime'),
    ("Avg Time per Sample", 'AvgTimePerSample'),
    ("Params", 'Params'),
    ("Model Size", 'Size_MB'),
)
baseline = results['Teacher']
for name in ['Student', 'Person']:
    print(f"\n{name} vs Teacher:")
    for label, key in _metric_rows:
        print(f"{label}: {percent_change(results[name][key], baseline[key]):.2f}%")
