# Knowledge Distillation on MNIST
Knowledge distillation is the process of transferring the performance of a large, expensive model to a smaller, cheaper one.  In this notebook, we will explore this process on MNIST.  To begin with, I have provided access to a pre-trained model that is large but performant.  The exact architecture is not important (although you can inspect it easily if you wish).  It is straightforward to load in PyTorch:

In [2]:
import torch
device = 'cpu'

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(28**2,800)
        self.l2 = torch.nn.Linear(800,800)
        self.l3 = torch.nn.Linear(800,10)
        self.dropout2 = torch.nn.Dropout(0.5)
        self.dropout3 = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.l1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.l2(x)
        x = torch.relu(x)
        x = self.dropout3(x)
        x = self.l3(x)
        return x
    
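# The file stores the whole pickled module, so the Net class must be defined before loading.
# Newer PyTorch releases default torch.load to weights_only=True; if loading fails there,
# passing weights_only=False (only for a file you trust) may be needed.
big_model = torch.load('pretrained_model.pt').to(device)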


Net(
  (l1): Linear(in_features=784, out_features=800, bias=True)
  (l2): Linear(in_features=800, out_features=800, bias=True)
  (l3): Linear(in_features=800, out_features=10, bias=True)
  (dropout2): Dropout(p=0.5, inplace=False)
  (dropout3): Dropout(p=0.5, inplace=False)
)


First, let's establish the baseline performance of the big model on the MNIST test set.  Of course, we'll need access to the MNIST test set to do this.  At the same time, let's also get our transfer set, which in this case will be an $n=10{,}000$ subset of the full MNIST training set (using a subset speeds up training of the distilled models, and also helps showcase the improvement that distillation provides).

In [3]:
from torchvision import transforms, datasets
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ])

dataset_train = datasets.MNIST('./data', train=True, download=True, transform=transform)

dataset_test = datasets.MNIST('../data', train=False, download=True, transform=transform)

# torch.utils.data.Subset is a handy utility: it restricts a dataset to the given indices
first_10k = list(range(0, 10000))
dataset_transfer = torch.utils.data.Subset(dataset_train, first_10k)

batch_size = 32
num_workers = 4
transfer_loader = torch.utils.data.DataLoader(dataset_transfer,batch_size=batch_size,num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(dataset_test,batch_size=batch_size,num_workers=num_workers)



Here's a function that runs a model in evaluation mode and returns the number of correctly classified examples.

In [10]:
def test(model,test_loader):
    """Return the number of correctly classified examples in test_loader."""
    correct = 0
    model.eval()  # evaluation mode: disables dropout
    with torch.no_grad():
        for data,target in test_loader:
            data, target = data.to(device), target.to(device)
            # flatten each 28x28 image into a 784-vector
            data = data.reshape(data.shape[0],-1)
            logits = model(data)
            # predicted class = index of the largest logit
            pred = logits.argmax(dim=1,keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return correct

num_correct = test(big_model,test_loader)
print( "{} -> {:.2f}%".format( num_correct, num_correct/1e4 * 100.))

9833 -> 98.33%
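
The hard-coded `1e4` in the print statement above assumes the test set holds exactly 10,000 examples. A slightly more general sketch counts the examples as it iterates and reports the fraction directly. Here `accuracy` is a hypothetical helper (not part of the notebook), and the tiny random dataset and linear model are stand-ins that only make the example self-contained.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def accuracy(model, loader, device='cpu'):
    """Return (number correct, fraction correct) over everything in loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            # flatten each image into a vector before the linear layers
            logits = model(data.reshape(data.shape[0], -1))
            correct += (logits.argmax(dim=1) == target).sum().item()
            # count the actual batch size, so a short final batch is handled
            total += target.shape[0]
    return correct, correct / total

# Stand-in data and model (assumptions for illustration only)
torch.manual_seed(0)
toy_x = torch.randn(50, 1, 28, 28)
toy_y = torch.randint(0, 10, (50,))
toy_loader = DataLoader(TensorDataset(toy_x, toy_y), batch_size=32)
toy_model = torch.nn.Linear(28 * 28, 10)

n_correct, frac = accuracy(toy_model, toy_loader)
print("{} -> {:.2f}%".format(n_correct, 100.0 * frac))
```

With the real models, the same call would be `accuracy(big_model, test_loader)`.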


We find that the big model gets 167 examples wrong (not quite as good as the Hinton paper, but who cares). 

Now we would like to perform knowledge distillation by training a smaller model to approximate the larger model's performance on the transfer set.  First, let's build a smaller model.  You may use whatever architecture you choose, but I found that using two hidden layers, each with 200 units along with ReLU activations (and no regularization at all) worked fine.

In [15]:
class SmallNet(torch.nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        # A single hidden layer of 200 units
        # (smaller than the two-hidden-layer suggestion above)
        self.l1 = torch.nn.Linear(28**2,200)
        self.l2 = torch.nn.Linear(200,10)

    def forward(self, x):
        # Linear -> ReLU -> Linear; the output is raw logits (no softmax)
        x = self.l1(x)
        x = torch.relu(x)
        x = self.l2(x)
        return x
    
small_model = SmallNet()
small_model.to(device)

SmallNet(
  (l1): Linear(in_features=784, out_features=200, bias=True)
  (l2): Linear(in_features=200, out_features=10, bias=True)
)
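
To get a feel for how much smaller the small model is, one can count trainable parameters. The sketch below rebuilds the layer shapes printed above using `torch.nn.Sequential`; `count_parameters`, `big_like`, and `small_like` are hypothetical names introduced for illustration, and dropout is omitted because it has no parameters.

```python
import torch

def count_parameters(model):
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Same layer shapes as the big model: 784 -> 800 -> 800 -> 10
big_like = torch.nn.Sequential(
    torch.nn.Linear(28**2, 800), torch.nn.ReLU(),
    torch.nn.Linear(800, 800), torch.nn.ReLU(),
    torch.nn.Linear(800, 10),
)
# Same layer shapes as the small model: 784 -> 200 -> 10
small_like = torch.nn.Sequential(
    torch.nn.Linear(28**2, 200), torch.nn.ReLU(),
    torch.nn.Linear(200, 10),
)

n_big = count_parameters(big_like)
n_small = count_parameters(small_like)
print(n_big, n_small, n_big / n_small)
```

Under these shapes the big model has roughly eight times as many parameters as the small one.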

**To establish a baseline performance level, train the small model on the transfer set**  

In [24]:
# I'm giving you this training function: you'll need to modify it below to do knowledge distillation
def train(model,train_loader,n_epochs):
    optimizer = torch.optim.Adam(model.parameters(),1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(n_epochs):
        avg_l = 0.0
        counter = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0],-1)
            optimizer.zero_grad()
            logits = model(data)
            L = loss_fn(logits,target)
            L.backward()
            optimizer.step()
            with torch.no_grad():
                avg_l += L
                counter += 1
        print(epoch,avg_l/counter)

train(small_model,transfer_loader,50)

0 tensor(1.2626)
1 tensor(1.1937)
2 tensor(1.1214)
3 tensor(1.0694)
4 tensor(1.0307)
5 tensor(0.9929)
6 tensor(0.9606)
7 tensor(0.9316)
8 tensor(0.9073)
9 tensor(0.8826)
10 tensor(0.8567)
11 tensor(0.8310)
12 tensor(0.8051)
13 tensor(0.7834)
14 tensor(0.7644)
15 tensor(0.7460)
16 tensor(0.7274)
17 tensor(0.7076)
18 tensor(0.6774)
19 tensor(0.6501)
20 tensor(0.6294)
21 tensor(0.6140)
22 tensor(0.6016)
23 tensor(0.5909)
24 tensor(0.5797)
25 tensor(0.5709)
26 tensor(0.5648)
27 tensor(0.5573)
28 tensor(0.5499)
29 tensor(0.5422)
30 tensor(0.5354)
31 tensor(0.5269)
32 tensor(0.5166)
33 tensor(0.5066)
34 tensor(0.4993)
35 tensor(0.4931)
36 tensor(0.4877)
37 tensor(0.4861)
38 tensor(0.4794)
39 tensor(0.4749)
40 tensor(0.4697)
41 tensor(0.4654)
42 tensor(0.4614)
43 tensor(0.4529)
44 tensor(0.4463)
45 tensor(0.4375)
46 tensor(0.4302)
47 tensor(0.4243)
48 tensor(0.4187)
49 tensor(0.4137)


**Evaluate the small model on the test set, and comment on its accuracy relative to the big model.**  As you might expect, its accuracy is noticeably lower than the big model's.

In [26]:
num_correct = test(small_model,test_loader)
print( "{} -> {:.2f}%".format( num_correct, num_correct/1e4 * 100.))

8401 -> 84.01%


**The primary task of this notebook is now as follows: create a new training function similar to "train" above, but instead called "distill".**  "distill" should perform knowledge distillation as outlined in this week's paper.  It should accept a few additional arguments compared to train, namely the big model, the temperature hyperparameter, and a hyperparameter $\alpha$ that weights the relative magnitude of the soft target loss and the hard target loss.
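
As a rough illustration of what the temperature does, the sketch below softens a made-up logit vector at $T=1$ and $T=20$, then combines a soft-target and a hard-target cross-entropy. The logits, the helper names `soften` and `distillation_loss`, and the convention that $\alpha$ weights the hard-target term are all assumptions made for this example; it is one possible formulation, not necessarily the one intended by the paper.

```python
import torch
import torch.nn.functional as F

def soften(logits, T):
    """Softmax with temperature T; a larger T gives a flatter distribution."""
    return F.softmax(logits / T, dim=1)

def distillation_loss(student_logits, teacher_logits, target, T, alpha):
    """Weighted sum of soft-target and hard-target cross-entropy.

    Assumption: alpha weights the hard-target term and (1 - alpha)
    weights the soft-target term.
    """
    soft_targets = soften(teacher_logits, T)
    log_probs = F.log_softmax(student_logits / T, dim=1)
    # cross-entropy against a full probability distribution, batch-averaged
    soft_loss = -(soft_targets * log_probs).sum(dim=1).mean()
    # ordinary cross-entropy against integer class labels, at T = 1
    hard_loss = F.cross_entropy(student_logits, target)
    return (1 - alpha) * soft_loss + alpha * hard_loss

# Made-up teacher logits for one example with three classes
teacher_logits = torch.tensor([[10.0, 5.0, 0.0]])
p_sharp = soften(teacher_logits, T=1.0)   # nearly one-hot
p_soft = soften(teacher_logits, T=20.0)   # much closer to uniform
print(p_sharp)
print(p_soft)

# An untrained "student" that outputs all-zero logits
student_logits = torch.zeros(1, 3)
target = torch.tensor([0])
loss = distillation_loss(student_logits, teacher_logits, target, T=20.0, alpha=0.1)
print(loss.item())
```

The softened distribution keeps the ranking of the classes but exposes how the teacher spreads probability over the wrong ones, which is the extra signal the small model learns from.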

In [33]:
distilled_model = SmallNet()
distilled_model.to(device)

# distill starts from the train function above and adds the big model's soft targets:
# T is the softmax temperature, alpha weights the hard-target loss,
# and (1 - alpha) weights the soft-target loss
def distill(small_model,big_model,T,alpha,transfer_loader,n_epochs):
    F = torch.nn.functional
    optimizer = torch.optim.Adam(small_model.parameters(),1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    small_model.train()
    big_model.eval()  # the big model only supplies targets, so keep its dropout disabled
    for epoch in range(n_epochs):
        avg_l = 0.0
        counter = 0
        for batch_idx, (data, target) in enumerate(transfer_loader):
            
            # load and prep data
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0],-1)
            optimizer.zero_grad()
            
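            # eval data on both models; the big model only provides targets,
            # so no gradients are needed for it
            logits_small = small_model(data)
            with torch.no_grad():
                logits_big = big_model(data)

            # soften the logits with temperature T
            y      = logits_small
            y_t    = logits_small/T
            y_soft = F.softmax(logits_big/T, dim=1)
            
            # cross-entropy between the small model's softened predictions and the soft targets
            # (CrossEntropyLoss accepts class-probability targets in PyTorch >= 1.10)
            L1 = loss_fn( y_t,y_soft )
            # cross-entropy between the small model's predictions (T=1) and the hard labels
            L2 = loss_fn( y,target )

            # weighted avg of the two losses
            # (the paper also suggests scaling the soft-target term by T**2; not done here)
            L_wavg = (1-alpha)*L1 + alpha*L2
            
            # backpropagate the combined loss
            L_wavg.backward()
            # update the small model's parameters
            optimizer.step()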

            with torch.no_grad():
                avg_l += L_wavg
                counter += 1
        print(epoch,(avg_l/counter).item())
    
T = 20
alpha = 1e-1
distill(distilled_model,big_model,T,alpha,transfer_loader,50)

0 1.7858387231826782
1 1.6275298595428467
2 1.579005479812622
3 1.5460394620895386
4 1.5232347249984741
5 1.5070286989212036
6 1.4955511093139648
7 1.487117052078247
8 1.4802067279815674
9 1.4749900102615356
10 1.4710147380828857
11 1.4677759408950806
12 1.4649893045425415
13 1.4625779390335083
14 1.4608081579208374
15 1.4589192867279053
16 1.4575626850128174
17 1.4562504291534424
18 1.455114483833313
19 1.4541563987731934
20 1.4532109498977661
21 1.4524013996124268
22 1.4516109228134155
23 1.450930118560791
24 1.4503288269042969
25 1.4497402906417847
26 1.4492292404174805
27 1.4487626552581787
28 1.448309063911438
29 1.4479272365570068
30 1.4475839138031006
31 1.4472436904907227
32 1.4469393491744995
33 1.4466742277145386
34 1.4464104175567627
35 1.4461771249771118
36 1.4459580183029175
37 1.4457341432571411
38 1.4455450773239136
39 1.445366621017456
40 1.4452158212661743
41 1.4450560808181763
42 1.4449042081832886
43 1.4447740316390991
44 1.444637417793274
45 1.4445123672485352
46 1.

**Finally, test your distilled model (on the test set) and describe how it performs relative to both big and small models.**

In [35]:
num_correct = test(distilled_model,test_loader)
print( "{} -> {:.2f}%".format( num_correct, num_correct/1e4 * 100.))

9650 -> 96.50%
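
To make the comparison concrete, the short sketch below turns the three test results reported in this notebook into error counts and into the fraction of the small-to-big accuracy gap that distillation recovers. The dictionary layout and the `gap_closed` name are just one way to organize the arithmetic.

```python
n_test = 10000  # size of the MNIST test set

# Correct counts reported above for each model
correct = {"big": 9833, "small": 8401, "distilled": 9650}

# Number of misclassified test examples per model
errors = {name: n_test - c for name, c in correct.items()}

# Fraction of the small-to-big accuracy gap recovered by distillation
gap_closed = (correct["distilled"] - correct["small"]) / (correct["big"] - correct["small"])

print(errors)
print("{:.1f}% of the gap closed".format(100 * gap_closed))
```

In other words, the distilled model makes 350 errors versus 1599 for the small model trained on hard labels alone and 167 for the big model, recovering about 87% of the gap with the same small architecture.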
