<a href="https://colab.research.google.com/github/arkajyotimitra/mini_projects/blob/main/knowledgedistill.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the teacher model
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        x = self.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.view(-1, 64 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the student model
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        x = x.view(-1, 16 * 14 * 14)
        x = self.fc1(x)
        return x

# Hyperparameters
batch_size = 64
learning_rate = 0.01
num_epochs = 10
temperature = 5
alpha = 0.5

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# Initialize models and optimizers
teacher_model = TeacherNet().to(device)
student_model = StudentNet().to(device)
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=learning_rate)
student_optimizer = optim.SGD(student_model.parameters(), lr=learning_rate)
criterion_hard = nn.CrossEntropyLoss()
criterion_soft = nn.KLDivLoss(reduction='batchmean')




In [2]:
# Train the teacher model
for epoch in range(num_epochs):
    teacher_model.train()
    for images, labels in trainloader:
        teacher_optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)
        outputs = teacher_model(images)
        loss = criterion_hard(outputs, labels)
        loss.backward()
        teacher_optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

# Train the student model with knowledge distillation
for epoch in range(num_epochs):
    student_model.train()
    for images, labels in trainloader:
        student_optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)
        student_outputs = student_model(images)
        with torch.no_grad():
            teacher_outputs = teacher_model(images)
        loss_hard = criterion_hard(student_outputs, labels)
        loss_soft = criterion_soft(
            nn.functional.log_softmax(student_outputs / temperature, dim=1),
            nn.functional.softmax(teacher_outputs / temperature, dim=1)
        ) * (temperature ** 2)
        loss = alpha * loss_hard + (1 - alpha) * loss_soft
        loss.backward()
        student_optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))



Epoch [1/10], Loss: 0.3879
Epoch [2/10], Loss: 0.0459
Epoch [3/10], Loss: 0.0800
Epoch [4/10], Loss: 0.0338
Epoch [5/10], Loss: 0.1130
Epoch [6/10], Loss: 0.1544
Epoch [7/10], Loss: 0.0021
Epoch [8/10], Loss: 0.0234
Epoch [9/10], Loss: 0.0287
Epoch [10/10], Loss: 0.0202
Epoch [1/10], Loss: 0.9584
Epoch [2/10], Loss: 0.4507
Epoch [3/10], Loss: 0.3283
Epoch [4/10], Loss: 0.2648
Epoch [5/10], Loss: 0.2650
Epoch [6/10], Loss: 0.2633
Epoch [7/10], Loss: 0.2860
Epoch [8/10], Loss: 0.2910
Epoch [9/10], Loss: 0.2011
Epoch [10/10], Loss: 0.2465


In [3]:
# Evaluate the student model
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = student_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print('Accuracy of the student model on the 10000 test images: {} %'.format(accuracy))



Accuracy of the student model on the 10000 test images: 97.79 %


In [4]:
teacher_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = teacher_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print('Accuracy of the student model on the 10000 test images: {} %'.format(accuracy))

Accuracy of the student model on the 10000 test images: 98.69 %
