# Distillation
This notebook shows how the tool can be used to perform knowledge distillation.

## Set Up

In [1]:
import sys
import torch
import importlib
import inspect
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F

# Add thesis package to path
sys.path.append("../")

import src.general as general
import src.compression.distillation as distil

Import datasets

In [2]:
# Load MNIST dataset
batch_size = 8
test_batch_size = 1000
use_cuda = False

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
mnist_transform = transforms.ToTensor()
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform,),
    batch_size=test_batch_size, shuffle=True, **kwargs)

Import a trained model given by user.

In [3]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"

# Import the module classes
module = importlib.import_module(model_class)
classes = general.get_module_classes(module)
for cls in classes:
    globals()[cls.__name__] = cls

# Get device
device = general.get_device()

# Load the model
model = torch.load(model_state, map_location=torch.device(device))

Using cuda: False


## Creating Student Model
The original model acts as the teacher model. 

For the student model the user can either give a model architecture of their own, presented in a `.py` file, or use the the tool to intelligently design a student model. 

In [4]:
from models.mnist_student import MnistStudent

student_model = MnistStudent()

In [5]:
epochs = 3
lr = 0.01
momentum = 0.5
log_interval = 100
import src.compression.distillation as distil

optimizer = optimizer = optim.Adam(student_model.parameters(), lr=lr) # Important: use the student model parameters
distil_criterion = F.mse_loss
test_criterion = F.nll_loss
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)

distil.distillation_train_loop(model, student_model, train_loader, test_loader, distil_criterion, test_criterion, optimizer, scheduler, epochs)

Epoch: 0, Test loss: 11.095250765482584, Test accuracy: 0.09035
Epoch: 1, Test loss: 11.095250622431438, Test accuracy: 0.09035
Epoch: 2, Test loss: 11.09525071779887, Test accuracy: 0.09035


In [6]:
# Define the teacher and student models
teacher = torch.nn.Linear(10, 5)
student = torch.nn.Linear(10, 5)

# Define the loss function
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')

# Define the optimizer and learning rate schedule
optimizer = torch.optim.SGD(student.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
