# Distillation
This notebook shows how the tool can be used to perform knowledge distillation.

## Set Up
* Import dependencies
* Import data loaders
* Import models

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F

# Add thesis package to path
sys.path.append("../")

import src.general as general
import src.compression.distillation as distill
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
from models.mnist import *

In [3]:
# Load MNIST dataset
batch_size = 64
test_batch_size = 1000
use_cuda = False

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
mnist_transform = transforms.ToTensor()
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=test_batch_size, shuffle=True, **kwargs)

In [4]:
model_state = "../models/mnist.pt"

device = general.get_device()
teacher_model = torch.load(model_state, map_location=torch.device(device))

Using cuda: False


In [5]:
teacher_model

MnistModel(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

## Distillation
The original model acts as the teacher model. 

For the student model the user can either give a model architecture of their own, presented in a `.py` file, or use the the tool to intelligently design a student model. 

In [6]:
# Load the student model
student_model = MnistSmallLinear()

In [7]:
input_batch = next(iter(test_loader))
example_input = input_batch[0][0]

In [8]:
# Test performance of student model before training
loss, score, duration, batch_duration, data_duration = general.test(student_model, device,  test_loader, criterion=F.nll_loss, metric = lambda x,y: metrics.accuracy_topk(x,y,topk=(1,))[0].item())

Test: 100%|██████████| 60/60 [00:01<00:00, 30.52it/s]

Average loss = 0.0010
Metric = 7.5550
Elapsed time = 1975.40 milliseconds (32.92 per batch, 0.03 per data point)





In [9]:
before_training_student_evaluation_metrics = {
    "loss": loss,
    "score": score,
    "duration": duration,
    "batch_duration": batch_duration,
    "data_duration": data_duration,
    "model": student_model,
    "batch_size": batch_size,
    "example_input": example_input,
}

plot.print_metrics(**before_training_student_evaluation_metrics)

Loss: 0.001031
Score: 7.555000
Time per batch: 32.9233 ms (64 per batch)
Time per data point: 0.0329 ms
Model Size: 0.15 MB
Number of parameters: 39760
Number of FLOPS: 39.75K


In [10]:

epochs = 5
lr = 0.01

optimizer = optim.Adam(student_model.parameters(), lr=lr) # Important: use the student model parameters
distil_criterion = F.mse_loss
eval_criterion = F.cross_entropy


distill.distillation_train_loop(teacher_model, student_model, train_loader, test_loader, distil_criterion, eval_criterion, optimizer, epochs)

Distillation Training: 100%|██████████| 938/938 [00:08<00:00, 115.56it/s]
Distillation Validation: 100%|██████████| 60/60 [00:01<00:00, 31.99it/s]


Epoch: 0
Distillation loss: 10.551363945007324
Test loss: 0.34507241000731786, Test accuracy: 0.9132666666666667


Distillation Training: 100%|██████████| 938/938 [00:08<00:00, 113.07it/s]
Distillation Validation: 100%|██████████| 60/60 [00:01<00:00, 31.89it/s]


Epoch: 1
Distillation loss: 8.376518249511719
Test loss: 0.24120648329456648, Test accuracy: 0.93515


Distillation Training: 100%|██████████| 938/938 [00:08<00:00, 116.86it/s]
Distillation Validation: 100%|██████████| 60/60 [00:01<00:00, 31.92it/s]


Epoch: 2
Distillation loss: 8.546671867370605
Test loss: 0.22435658077398937, Test accuracy: 0.9376833333333333


Distillation Training: 100%|██████████| 938/938 [00:08<00:00, 112.65it/s]
Distillation Validation: 100%|██████████| 60/60 [00:02<00:00, 27.06it/s]


Epoch: 3
Distillation loss: 11.111848831176758
Test loss: 0.2286367081105709, Test accuracy: 0.9371333333333334


Distillation Training: 100%|██████████| 938/938 [00:08<00:00, 107.03it/s]
Distillation Validation: 100%|██████████| 60/60 [00:01<00:00, 31.61it/s]

Epoch: 4
Distillation loss: 9.468725204467773
Test loss: 0.20371440450350445, Test accuracy: 0.9424666666666667





In [11]:
# Evaluate after training
print("Student model performance:")
loss, score, duration, batch_duration, data_duration = general.test(student_model, device,  test_loader, criterion=F.nll_loss, metric = metrics.accuracy)
after_training_student_evaluation_metrics = {
    "loss": loss,
    "score": score,
    "duration": duration,
    "batch_duration": batch_duration,
    "data_duration": data_duration,
    "model": student_model,
    "batch_size": batch_size,
    "example_input": example_input,
}
plot.print_before_after_metrics(before_training_student_evaluation_metrics, after_training_student_evaluation_metrics)

Student model performance:


Test: 100%|██████████| 60/60 [00:01<00:00, 31.19it/s]

Average loss = 1.3820
Accuracy = 0.9425
Elapsed time = 1924.83 milliseconds (32.08 per batch, 0.03 per data point)
Loss: 0.001031 -> 1.381984
Score: 7.555000 -> 0.942467 
Time per batch: 32.9233 ms -> 32.0805 ms (64 per batch)
Time per data point: 0.0329 ms -> 0.0321 ms
Model Size: 0.15 MB -> 0.15 MB
Number of parameters: 39760 -> 39760
Number of FLOPS: 39.75K -> 39.75K





## Evaluation
Analyze the metrics of the new student model

In [13]:
# Test model performance after distillation
print("Teacher model performance:")
loss, score, duration, batch_duration, data_duration = general.test(teacher_model, device, test_loader, criterion=F.nll_loss, metric = metrics.accuracy)
teacher_evaluation_metrics = {
    "model": teacher_model,
    "loss": loss,
    "score": score,
    "duration": duration,
    "batch_duration": batch_duration,
    "data_duration": data_duration,
    "batch_size": batch_size,
    # "example_input": example_input,
}
plot.print_metrics(**teacher_evaluation_metrics)
print("Student model performance:")
loss, score, duration, batch_duration, data_duration = general.test(student_model, device, test_loader, criterion=F.nll_loss, metric = metrics.accuracy)
student_evaluation_metrics = {
    "model": student_model,
    "loss": loss,
    "score": score,
    "duration": duration,
    "batch_duration": batch_duration,
    "data_duration": data_duration,
    "batch_size": batch_size,
    "example_input": example_input,
}
print('\n\n')


# # Compare the number of parameters of the teacher and student model
# teacher_params = eval.get_model_parameters(teacher_model)
# student_params = eval.get_model_parameters(student_model)
# print('Number of parameters: {} (Teacher) -> {} (Student)'.format(teacher_params, student_params))

# # Compare the model size of the teacher and student model
# teacher_size = eval.get_model_size(teacher_model)
# student_size = eval.get_model_size(student_model)
# print('Model Size: {} MB (Teacher) -> {} MB (Student)'.format(teacher_size, student_size))


plot.print_before_after_metrics(teacher_evaluation_metrics, student_evaluation_metrics)



Teacher model performance:


Test: 100%|██████████| 60/60 [00:04<00:00, 12.48it/s]


Average loss = 0.0174
Accuracy = 0.9946
Elapsed time = 4808.84 milliseconds (80.15 per batch, 0.08 per data point)
Loss: 0.017429
Score: 0.994583
Time per batch: 80.1473 ms (64 per batch)
Time per data point: 0.0801 ms
Model Size: 1.65 MB
Number of parameters: 431080
Student model performance:


Test: 100%|██████████| 60/60 [00:01<00:00, 31.41it/s]

Average loss = 1.3820
Accuracy = 0.9425
Elapsed time = 1911.41 milliseconds (31.86 per batch, 0.03 per data point)



Loss: 0.017429 -> 1.381984
Score: 0.994583 -> 0.942467 
Time per batch: 80.1473 ms -> 31.8569 ms (64 per batch)
Time per data point: 0.0801 ms -> 0.0319 ms
Model Size: 1.65 MB -> 0.15 MB
Number of parameters: 431080 -> 39760



