In [1]:
import os
import numpy as np

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn

In [3]:
# Defining the Distillation model along with its loss, this will train the student model from the teacher model
class Distiller(nn.Module):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        self.optimizer = optimizer
        self.metrics = metrics
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(self, x, y, y_pred, sample_weight=None, allow_empty=False):
        teacher_pred = self.teacher(x)
        student_loss = self.student_loss_fn(y_pred, y)

        teacher_softened = F.softmax(teacher_pred / self.temperature, dim=1)
        student_softened = F.softmax(y_pred / self.temperature, dim=1)

        distillation_loss = self.distillation_loss_fn(
            student_softened,
            teacher_softened
        ) * (self.temperature ** 2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def forward(self, x):
        return self.student(x)

In [4]:
import matplotlib.pyplot as plt

In [5]:
# Defining the student class
class Student(nn.Module):
    ''' Models a simple Convolutional Neural Network'''

# Instantiate the network layers

    def __init__(self):
        super(Student, self).__init__()
	# 3 input image channel, 8 output channels,
  # The image is 32x32 so after applying 5x5 kernel we get 28x28 image
	# 5x5 square convolution kernel, represented by the last argument 5

        self.conv1 = nn.Conv2d(3, 8, 5)

	# Max pooling over a (2, 2) window

        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 18, 5)

        # after this we apply 2X2 pooling which basically halves the dimension of image along height and width, here we have just defined it, we will actually apply in the forward method
        # That is why image dimemsion becomes 5*5 from 10*10
        self.fc1 = nn.Linear(18 * 5 * 5, 120) # 5x5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 18 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
# Instantiate the model
student = Student()

# Clone student for later comparison
import copy
student_scratch = copy.deepcopy(student)

In [7]:
# python image library of range [0, 1]
# transform them to tensors of normalized range[-1, 1]

transform = transforms.Compose( # composing several transforms together
    [transforms.ToTensor(), # to tensor object
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # mean = 0.5, std = 0.5

# set batch_size
batch_size = 4

# set number of workers
num_workers = 2

# load train data
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=num_workers)

# load test data
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=num_workers)

# put 10 classes into a set
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 13022375.95it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
import torch.optim as optim


In [9]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(student.parameters(), lr=0.001, momentum=0.9)

In [10]:
import torchvision.models as models

resnet18 = models.resnet18(pretrained = True)
squeezenet = models.squeezenet1_0(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 191MB/s]
Downloading: "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_0-b66bff10.pth
100%|██████████| 4.78M/4.78M [00:00<00:00, 168MB/s]


In [11]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.4.3-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.7-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.4.3-py3-none-any.whl (869 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/869.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m869.5/869.5 kB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.7-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.7 torchmetrics-1.4.3


In [12]:
teacher = models.resnet18(pretrained=True)
teacher.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(512, 10)
)  # Modify the output layer to match the number of classes in your dataset

# Freeze the parameters of the teacher model
for param in teacher.parameters():
    param.requires_grad = False

In [13]:
import torchmetrics

# Initialize and compile distiller

distiller = Distiller(student=student, teacher = teacher)

optimizer = optim.Adam(distiller.parameters())
metrics = {'accuracy': torchmetrics.Accuracy(task='MULTICLASS', num_classes = 10)}

def student_loss_fn(outputs, labels):
    return F.cross_entropy(outputs, labels)

def distillation_loss_fn(student_logits, teacher_probs):
    return torch.nn.KLDivLoss()(F.log_softmax(student_logits / distiller.temperature, dim=1), teacher_probs)

In [14]:
distiller.compile(
    optimizer=optimizer,
    metrics=metrics,
    student_loss_fn=student_loss_fn,
    distillation_loss_fn=distillation_loss_fn,
    alpha=0.1,
    temperature=10
)

In [15]:
# Train distiller
def train_distiller(model, train_loader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            student_logits = model(inputs)
            teacher_probs = teacher(inputs)
            loss = criterion(inputs, labels, student_logits, teacher_probs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

In [None]:
epochs = 4
train_distiller(distiller, trainloader, distiller.compute_loss, optimizer, epochs)


  self.pid = os.fork()


Epoch [1/4], Loss: 0.1785
Epoch [2/4], Loss: 0.1517


In [16]:
# Evaluate distiller on test dataset
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")

In [17]:
evaluate(distiller, testloader)

Test Accuracy: 0.1000


In [18]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = teacher(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 9 %


In [19]:
# Training student model without the teacher model
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = student_scratch(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

# whatever you are timing goes here
end.record()

# Waits for everything to finish running
torch.cuda.synchronize()

print('Finished Training')
print(start.elapsed_time(end))  # milliseconds

[1,  2000] loss: 2.304
[1,  4000] loss: 2.303
[1,  6000] loss: 2.304
[1,  8000] loss: 2.303
[1, 10000] loss: 2.304
[1, 12000] loss: 2.305
[2,  2000] loss: 2.303
[2,  4000] loss: 2.304
[2,  6000] loss: 2.304
[2,  8000] loss: 2.304
[2, 10000] loss: 2.304
[2, 12000] loss: 2.304
Finished Training
130384.203125


In [20]:
# Performance of student model
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = student(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 10 %
