# Distillation
This notebook shows how the tool can be used to perform knowledge distillation.

## Setup
* Import dependencies
* Import data loaders
* Import models

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F

# Add thesis package to path
sys.path.append("../")
sys.path.append("../src/")

import src.general as general
import src.compression.distillation as distill
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
import src.dataset_models as data
from models.mnist import *  # expected to provide MnistModel, which torch.load needs to unpickle the teacher

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Get device
device = general.get_device()

# Load the dataset
dataset = data.supported_datasets["MNIST"]

In [4]:
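model_state = "../models/mnist.pt"
# The checkpoint holds a full pickled model, not just a state_dict. PyTorch 2.6+
# defaults to weights_only=True, which rejects such pickles, so opt out explicitly
# (only do this for checkpoints you trust).
teacher_model = torch.load(model_state, map_location=torch.device(device), weights_only=False)
print(teacher_model)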

MnistModel(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)
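The printed architecture gives `fc1` an input of 800 features, but the layers that shrink the image to that size are not listed, presumably because `forward()` applies pooling functionally. The sketch below traces a 28×28 MNIST image through the network under that assumption (2×2 max pooling with stride 2 after each convolution); `conv_out` is a hypothetical helper, not part of the tool.

```python
# Trace the spatial size of a 28x28 MNIST image through the teacher.
# Assumption: forward() applies 2x2 max pooling (stride 2) after each conv;
# functional pooling does not appear in print(model).

def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length after a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST images are 28x28
size = conv_out(size, kernel=5)             # conv1: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # pool:  24 -> 12
size = conv_out(size, kernel=5)             # conv2: 12 -> 8
size = conv_out(size, kernel=2, stride=2)   # pool:   8 -> 4

channels = 50                               # conv2 out_channels
fc1_in = channels * size * size
print(size, fc1_in)  # 4 800
```

The result matches `in_features=800` on `fc1`, which supports the pooling assumption.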


## Distillation
The original model acts as the teacher model. 

For the student model, the user can either supply their own model architecture, defined in a `.py` file, or let the tool design a student model automatically.
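The notebook does not document how `create_student_model` chooses an architecture. As one hypothetical illustration of what designing a student can mean, the sketch below shrinks every hidden width of the printed `MnistModel` by a constant factor and counts weights plus biases. `scale_widths`, `param_count`, and the 0.5 factor are assumptions for illustration, not the tool's API.

```python
# Hypothetical sketch: derive a student by scaling the teacher's hidden widths.
# Layer layout mirrors the printed MnistModel; this is NOT the tool's algorithm.

TEACHER = {"conv1": 20, "conv2": 50, "fc1": 500}  # hidden widths of the teacher
IN_CHANNELS, KERNEL, NUM_CLASSES, FINAL_SIDE = 1, 5, 10, 4

def param_count(widths: dict) -> int:
    """Count weights + biases for the conv-conv-fc-fc layout of MnistModel."""
    c1, c2, f1 = widths["conv1"], widths["conv2"], widths["fc1"]
    conv1 = IN_CHANNELS * c1 * KERNEL * KERNEL + c1
    conv2 = c1 * c2 * KERNEL * KERNEL + c2
    fc1 = c2 * FINAL_SIDE * FINAL_SIDE * f1 + f1
    fc2 = f1 * NUM_CLASSES + NUM_CLASSES
    return conv1 + conv2 + fc1 + fc2

def scale_widths(widths: dict, factor: float) -> dict:
    """Shrink every hidden width by `factor`, keeping at least one unit."""
    return {name: max(1, round(w * factor)) for name, w in widths.items()}

student = scale_widths(TEACHER, 0.5)
print(student)                # {'conv1': 10, 'conv2': 25, 'fc1': 250}
print(param_count(TEACHER))   # 431080
print(param_count(student))   # 109295
```

The teacher count matches the reported 431,080 parameters. Halving the widths yields 109,295, the same count reported for the student in the evaluation below; this is consistent with a width-halved student but does not confirm how the tool actually builds it. At 4 bytes per float32 parameter the two counts correspond to about 1.64 MiB and 0.42 MiB, close to the reported model sizes.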

In [5]:
# Let the tool design a student architecture from the teacher and the dataset
student_model = distill.create_student_model(teacher_model, dataset, fineTune=True)
print(student_model)

Train:  10%|█         | 97/938 [00:02<00:22, 37.26it/s]


KeyboardInterrupt: 

In [None]:
# Evaluate the teacher model to get a baseline for the comparison later on
teacher_results = eval.get_results(teacher_model, dataset)
plot.print_results(**teacher_results)

Using cuda: False


Test: 100%|██████████| 157/157 [00:01<00:00, 91.62it/s]

Average loss = 0.0778
Accuracy = 97.6115
Elapsed time = 1716.67 milliseconds (10.93 per batch, 0.68 per data point)
Could not calculate FLOPS
Loss: 0.077825
Score: 97.611465
Time per data point: 0.6834 ms
Model Size: 1.65 MB
Number of parameters: 431080
Number of FLOPs: -1
Number of MACs: 2307728
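The report prints `Number of FLOPs: -1` right after `Could not calculate FLOPS`, so `-1` appears to be a failure sentinel rather than a measurement. A rough substitute can be derived from the reported MACs, assuming the common convention that one multiply-accumulate counts as two floating-point operations (one multiply, one add); this ignores activations and pooling, so treat it as an estimate.

```python
def flops_from_macs(macs: int) -> int:
    """Approximate FLOPs as 2 x MACs (one multiply plus one add per MAC)."""
    return 2 * macs

# MAC counts as reported for the teacher and the student in this notebook
reported_macs = {"teacher": 2307728, "student": 653864}
for name, macs in reported_macs.items():
    print(f"{name}: ~{flops_from_macs(macs):,} FLOPs")
# teacher: ~4,615,456 FLOPs
# student: ~1,307,728 FLOPs
```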





In [None]:
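# Distillation loss: mean squared error between student and teacher outputs
distil_criterion = F.mse_loss
optimizer = optim.Adam(student_model.parameters(), lr=0.01)

# One epoch of distillation over the training loader; the returned loss is shown below
distill.train(teacher_model, student_model, dataset.train_loader, distil_criterion, optimizer)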

Distillation Training: 100%|██████████| 938/938 [00:15<00:00, 58.87it/s]


tensor(0.6245, grad_fn=<MseLossBackward0>)
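`distill.train` receives `F.mse_loss` as its criterion; judging by its arguments, it presumably compares the student's outputs with the teacher's on each batch, although the notebook does not show its internals. The plain-Python sketch below shows that MSE term and, for contrast, the temperature-scaled soft-target loss popularized by Hinton et al. The latter is a different technique, not what this cell uses; `temperature` and `alpha` are illustrative hyperparameters.

```python
import math

def mse(student_out, teacher_out):
    """Mean squared error between two output vectors
    (what F.mse_loss computes with its default reduction='mean')."""
    return sum((s - t) ** 2 for s, t in zip(student_out, teacher_out)) / len(student_out)

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                           # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.5):
    """Soft-target distillation loss (an alternative to MSE, not the tool's):
    alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(student, label)."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s) if pt > 0)
    ce = -math.log(softmax(student_logits)[label])
    return alpha * temperature ** 2 * kl + (1 - alpha) * ce

teacher = [2.0, 0.5, -1.0]
student = [1.5, 0.7, -0.8]
print(round(mse(student, teacher), 4))                # 0.11
print(kd_loss(teacher, teacher, label=0, alpha=1.0))  # 0.0
```

The `T^2` factor keeps the gradient scale of the soft term roughly independent of the temperature, which is why it appears in the soft-target formulation; plain MSE on outputs needs no such correction.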

In [None]:
settings = {
    "performance_target": 99,
    "fineTune": False,
    "epochs": 5,
}

distilled_model = distill.perform_distillation(teacher_model, dataset, settings)

Settings: {'performance_target': 99, 'fineTune': False, 'epochs': 5}
Fine-tuning: False




Distillation Validation: 100%|██████████| 157/157 [00:01<00:00, 115.08it/s]


Test loss: 0.11945149836361788, Test score: 96.89490445859873


Distillation Training: 100%|██████████| 938/938 [00:16<00:00, 57.83it/s]


Distillation loss: 5.034704208374023


Distillation Validation: 100%|██████████| 157/157 [00:01<00:00, 119.18it/s]


Test loss: 0.39766418582694546, Test score: 87.84832802547771
Stopped training because score started decreasing: from 96.89490445859873 to 87.84832802547771
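The log above suggests the stopping rule inside `perform_distillation`: validate the student once, then alternate training and validation, and stop when the score drops (here from 96.89 to 87.85) or, presumably, when `performance_target` is reached within the `epochs` budget. The following is a hypothetical reconstruction of that control flow, with the training and evaluation steps passed in as callables; `distill_with_early_stopping` is an illustrative name, not the tool's code.

```python
from typing import Callable, List, Tuple

def distill_with_early_stopping(
    train_epoch: Callable[[], float],
    evaluate: Callable[[], float],
    epochs: int,
    performance_target: float,
) -> Tuple[List[float], str]:
    """Hypothetical sketch: stop on a score drop, on reaching the target,
    or when the epoch budget runs out. Returns all scores and the reason."""
    scores = [evaluate()]                  # score before any distillation epoch
    if scores[0] >= performance_target:
        return scores, "target reached"
    for _ in range(epochs):
        train_epoch()                      # one epoch of distillation training
        scores.append(evaluate())
        if scores[-1] < scores[-2]:
            return scores, "score decreased"
        if scores[-1] >= performance_target:
            return scores, "target reached"
    return scores, "epoch budget exhausted"

# Replay the two validation scores from the log above
log_scores = iter([96.89, 87.85])
scores, reason = distill_with_early_stopping(
    train_epoch=lambda: 0.0,
    evaluate=lambda: next(log_scores),
    epochs=5,
    performance_target=99,
)
print(scores, reason)  # [96.89, 87.85] score decreased
```

The log does not show whether the tool returns the last model or the best-scoring one. Since the score fell here, a real implementation would likely want to keep a copy of the best checkpoint and return that.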


## Evaluation
Compare the metrics of the trained student model against the teacher baseline.

In [None]:
student_results = eval.get_results(student_model, dataset)
plot.print_before_after_results(teacher_results, student_results)

Test: 100%|██████████| 157/157 [00:01<00:00, 115.79it/s]

Test loss: 0.0925
Test score: 97.0939
Could not calculate FLOPS
Loss: 0.077825 -> 0.092510 (18.87%)
Score: 97.611465 -> 97.093949 (-0.53%)
Time per data point: 0.6834 ms -> 0.5410 ms (-20.84%)
Model Size: 1.65 MB -> 0.42 MB (-74.55%)
Number of parameters: 431080 -> 109295 (-74.65%)
Number of FLOPs: -1 -> -1 (-0.00%)
Number of MACs: 2307728 -> 653864 (-71.67%)
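Each percentage in the comparison is a relative change, (after - before) / before × 100. Recomputing them from the printed values reproduces the report:

```python
def pct_change(before: float, after: float) -> float:
    """Relative change in percent: (after - before) / before * 100."""
    return (after - before) / before * 100

# Before/after values copied from the comparison output above
reported = {
    "Loss": (0.077825, 0.092510),
    "Score": (97.611465, 97.093949),
    "Time per data point": (0.6834, 0.5410),
    "Model Size": (1.65, 0.42),
    "Number of parameters": (431080, 109295),
    "Number of MACs": (2307728, 653864),
}
for name, (before, after) in reported.items():
    print(f"{name}: {pct_change(before, after):.2f}%")
# Loss: 18.87%
# Score: -0.53%
# Time per data point: -20.84%
# Model Size: -74.55%
# Number of parameters: -74.65%
# Number of MACs: -71.67%
```

If the tool uses this formula, the `-1` FLOPs sentinel on both sides gives 0 / -1, which is negative zero in floating point; that would explain the odd `-0.00%` on the FLOPs line.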



