# Knowledge Distillation on MNIST
Knowledge distillation is the process of transferring the performance of a large, expensive model to a smaller, cheaper one. In this notebook, we will apply this process to MNIST. To begin with, I have provided a pre-trained model that is large but performant. Its exact architecture is not essential (although you can easily inspect it if you wish). It is straightforward to load in PyTorch with:

In [1]:
import torch
device = 'cpu'

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(28**2,800)
        self.l2 = torch.nn.Linear(800,800)
        self.l3 = torch.nn.Linear(800,10)
        self.dropout2 = torch.nn.Dropout(0.5)
        self.dropout3 = torch.nn.Dropout(0.5)

    def forward(self, x):
        x = self.l1(x)
        x = torch.relu(x)
        x = self.dropout2(x)
        x = self.l2(x)
        x = torch.relu(x)
        x = self.dropout3(x)
        x = self.l3(x)
        return x
    
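# The file stores the whole pickled module, so the Net class above must be defined before loading.
# weights_only=False is required on PyTorch >= 2.6, where torch.load defaults to weights_only=True.
big_model = torch.load('pretrained_model.pt', map_location=device, weights_only=False).to(device)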

First, let's establish the baseline performance of the big model on the MNIST test set, which means we need access to that test set. At the same time, let's also get our transfer set, which in this case is the first $n = 10{,}000$ examples of the MNIST training set (using a subset speeds up training of the distilled models, and also helps showcase the improvement that distillation brings).

In [2]:
from torchvision import transforms, datasets
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ])

dataset_train = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Use the same root as the training set so MNIST is only downloaded once
dataset_test = datasets.MNIST('./data', train=False, download=True, transform=transform)

# torch.utils.data.Subset restricts a dataset to the given indices (here, the first 10k training images)
first_10k = list(range(0, 10000))
dataset_transfer = torch.utils.data.Subset(dataset_train, first_10k)

batch_size = 32
num_workers = 4
transfer_loader = torch.utils.data.DataLoader(dataset_transfer,batch_size=batch_size,num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(dataset_test,batch_size=batch_size,num_workers=num_workers)
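The cell above takes the first 10,000 training images and iterates over them in the same order every epoch (transfer_loader is built without shuffle=True). A hypothetical variant that draws a random subset and reshuffles it each epoch is sketched below. make_transfer_loader and its seed argument are illustrative names, not part of the original notebook, and every result in this notebook was produced with the fixed, unshuffled loader above.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

def make_transfer_loader(dataset, n, batch_size=32, seed=0):
    """Hypothetical helper: random n-example subset of `dataset`, reshuffled every epoch."""
    g = torch.Generator().manual_seed(seed)
    # Pick n distinct indices at random (reproducible through the seeded generator)
    idx = torch.randperm(len(dataset), generator=g)[:n].tolist()
    # shuffle=True draws a new order each epoch; passing the generator keeps that order reproducible
    return DataLoader(Subset(dataset, idx), batch_size=batch_size, shuffle=True, generator=g)

# Toy stand-in for dataset_train: 100 one-pixel "images" whose label equals their index
toy = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1), torch.arange(100))
loader = make_transfer_loader(toy, n=10, batch_size=4)
seen = [int(y) for _, labels in loader for y in labels]
print(len(loader), sorted(seen))
```

With the real data this would be called as make_transfer_loader(dataset_train, 10000, batch_size=batch_size).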

Here's a function that runs a model in evaluation mode over a data loader and returns the number of correctly classified examples.

In [3]:
def test(model,test_loader):
    correct = 0
    model.eval()  # evaluation mode: disables dropout
    with torch.no_grad():
        for data,target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0],-1)  # flatten 28x28 images into 784-vectors
            logits = model(data)
            pred = logits.argmax(dim=1,keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    return correct

test(big_model,test_loader)

9833

We find that the big model gets 167 of the 10,000 test examples wrong, i.e. 98.33% accuracy (not quite as good as the result in the Hinton paper, but good enough for our purposes).

Now we would like to perform knowledge distillation by training a smaller model to approximate the larger model's behavior on the transfer set. First, let's build a smaller model. You may use whatever architecture you choose, but I found that two hidden layers of 200 units each, with ReLU activations (and no regularization at all), worked fine. The SmallNet below is smaller still: a single hidden layer of 200 ReLU units, again without regularization.

In [4]:
class SmallNet(torch.nn.Module):
    def __init__(self):
        super(SmallNet, self).__init__()
        self.L1 = torch.nn.Linear(784, 200)
        self.L2 = torch.nn.Linear(200, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.L1(x)
        x = self.relu(x)
        x = self.L2(x)
        return x
    
small_model = SmallNet()
small_model.to(device)

SmallNet(
  (L1): Linear(in_features=784, out_features=200, bias=True)
  (L2): Linear(in_features=200, out_features=10, bias=True)
  (relu): ReLU()
)
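The text calls the pre-trained model "large", and a quick parameter count makes the gap to SmallNet concrete. This plain-Python sketch assumes fully connected layers with biases (ReLU and dropout contribute no parameters); mlp_param_count is a hypothetical helper, not part of the notebook.

```python
def mlp_param_count(layer_sizes):
    """Hypothetical helper: weights + biases of a stack of fully connected layers.

    layer_sizes lists the width of each layer, input first, e.g. [784, 200, 10].
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

big_params = mlp_param_count([28**2, 800, 800, 10])  # Net (teacher)
small_params = mlp_param_count([784, 200, 10])        # SmallNet (student)
print(big_params, small_params, round(big_params / small_params, 1))  # 1276810 159010 8.0
```

So the teacher has roughly eight times as many parameters as the student. In PyTorch the same totals should come out of sum(p.numel() for p in model.parameters()).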

**To establish a baseline performance level, train the small model on the transfer set**  

In [5]:
# I'm giving you this training function: you'll need to modify it below to do knowledge distillation
def train(model,train_loader,n_epochs):
    optimizer = torch.optim.Adam(model.parameters(),1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(n_epochs):
        avg_l = 0.0
        counter = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0],-1)
            optimizer.zero_grad()
            logits = model(data)
            L = loss_fn(logits,target)
            L.backward()
            optimizer.step()
            with torch.no_grad():
                avg_l += L
                counter += 1
        print(epoch,avg_l/counter)

train(small_model,transfer_loader,50)

0 tensor(0.4295)
1 tensor(0.2024)
2 tensor(0.1324)
3 tensor(0.0890)
4 tensor(0.0607)
5 tensor(0.0390)
6 tensor(0.0257)
7 tensor(0.0185)
8 tensor(0.0180)
9 tensor(0.0173)
10 tensor(0.0188)
11 tensor(0.0137)
12 tensor(0.0141)
13 tensor(0.0153)
14 tensor(0.0133)
15 tensor(0.0127)
16 tensor(0.0066)
17 tensor(0.0080)
18 tensor(0.0125)
19 tensor(0.0095)
20 tensor(0.0090)
21 tensor(0.0031)
22 tensor(0.0037)
23 tensor(0.0127)
24 tensor(0.0184)
25 tensor(0.0083)
26 tensor(0.0082)
27 tensor(0.0016)
28 tensor(0.0004)
29 tensor(0.0001)
30 tensor(7.0514e-05)
31 tensor(5.7303e-05)
32 tensor(4.9143e-05)
33 tensor(4.3081e-05)
34 tensor(3.8200e-05)
35 tensor(3.4161e-05)
36 tensor(3.0645e-05)
37 tensor(2.7562e-05)
38 tensor(2.4797e-05)
39 tensor(2.2340e-05)
40 tensor(2.0122e-05)
41 tensor(1.8134e-05)
42 tensor(1.6321e-05)
43 tensor(1.4683e-05)
44 tensor(1.3192e-05)
45 tensor(1.1857e-05)
46 tensor(1.0617e-05)
47 tensor(9.4939e-06)
48 tensor(8.4623e-06)
49 tensor(7.5294e-06)


**Evaluate the small model on the test set, and comment on its accuracy relative to the big model.**  As you might expect, its performance is worse.

In [6]:
test(small_model,test_loader)
# The small net gets 355 test examples wrong, 188 more than the big model. That isn't drastically worse,
# but on a dataset as easy as MNIST we should be able to do much better

9645

**The primary task of this notebook is now as follows: create a new training function similar to "train" above, but instead called "distill".**  "distill" should perform knowledge distillation as outlined in this week's paper.  It should accept a few additional arguments compared to train, namely the big model, the temperature hyperparameter, and a hyperparameter $\alpha$ that weights the relative magnitude of the soft target loss and the hard target loss.
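For reference, the objective that the distill function below implements can be written as follows, where $z_s$ and $z_t$ are the student and teacher logits, $y$ is the one-hot true label, $\sigma$ is the softmax, and $\mathrm{CE}(p, q) = -\sum_{k=1}^{10} p_k \log q_k$ (averaged over the mini-batch in the code):

$$
\mathcal{L} = \alpha\,\mathrm{CE}\big(y,\ \sigma(z_s)\big) + (1-\alpha)\,\mathrm{CE}\big(\sigma(z_t/T),\ \sigma(z_s/T)\big)
$$

The gradient of the soft-target term with respect to a student logit is

$$
\frac{\partial}{\partial z_{s,i}}\,\mathrm{CE}\big(\sigma(z_t/T),\ \sigma(z_s/T)\big) = \frac{1}{T}\Big(\sigma(z_s/T)_i - \sigma(z_t/T)_i\Big),
$$

and since both softened distributions flatten toward uniform as $T$ grows, the difference in parentheses also shrinks, so the soft-target gradients scale roughly as $1/T^2$. The Hinton et al. paper therefore multiplies the soft-target term by $T^2$; the distill function in this notebook does not.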

In [7]:
distilled_model = SmallNet()
distilled_model.to(device)

# Knowledge distillation: train small_model on a weighted sum of
#   - the usual hard-target loss against the true labels, and
#   - a soft-target loss against the big model's temperature-softened predictions.
def distill(small_model,big_model,T,alpha,transfer_loader,n_epochs):
    optimizer = torch.optim.Adam(small_model.parameters(),1e-3)
    # CrossEntropyLoss expects raw logits (it applies log-softmax internally), so don't softmax its input.
    # Its target can be integer class labels or (PyTorch >= 1.10) a probability vector shaped like the logits.
    loss_fn = torch.nn.CrossEntropyLoss()
    sf = torch.nn.Softmax(dim = -1)
    small_model.train()
    big_model.eval()  # the teacher is only used for inference, so keep its dropout off
    for epoch in range(n_epochs):
        avg_l = 0.0
        counter = 0
        for batch_idx, (data, target) in enumerate(transfer_loader):
            data, target = data.to(device), target.to(device)
            data = data.reshape(data.shape[0],-1)
            optimizer.zero_grad()
            # Student: hard-target loss on the ordinary logits
            small_logits = small_model(data)
            L_student = loss_fn(small_logits,target)
            # Teacher: only its predictions are needed, so skip building a graph for it
            with torch.no_grad():
                big_logits = big_model(data)
            # Soft-target loss: student logits / T against the teacher's softened probabilities
            L_teacher = loss_fn(small_logits / T, sf(big_logits / T))
            # Weight the two losses with alpha (alpha is the weight on the hard-target loss)
            L = alpha*L_student + (1-alpha)*L_teacher
            L.backward()
            optimizer.step()
            with torch.no_grad():
                avg_l += L
                counter += 1
        print(epoch,avg_l/counter)
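To see what sf(big_logits / T) hands the student, here is a minimal, framework-free sketch of a temperature-scaled softmax applied to one made-up set of teacher logits (both the logits and softmax_with_temperature are illustrative, not taken from big_model). Raising T leaves the predicted class unchanged but moves probability onto the runner-up classes, which carries the "this 7 looks a bit like a 1" information that distillation tries to transfer.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for one image: a confident "7" with some weight on "1" and "9"
teacher_logits = [0.0, 4.0, 0.5, 1.0, 0.0, 0.0, 0.0, 10.0, 0.5, 3.0]
for T in (1, 3, 20):
    probs = softmax_with_temperature(teacher_logits, T)
    print(T, [round(p, 3) for p in probs])
```

At T = 1 nearly all the mass sits on class 7; at T = 3 the entries for "1" and "9" become clearly visible; at T = 20 the distribution is much flatter, with the largest probability falling to about 0.15.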


In [8]:
T = 3
alpha = 0.2
distill(distilled_model,big_model,T,alpha,transfer_loader,50)

0 tensor(0.5976)
1 tensor(0.3325)
2 tensor(0.2589)
3 tensor(0.2130)
4 tensor(0.1843)
5 tensor(0.1647)
6 tensor(0.1510)
7 tensor(0.1409)
8 tensor(0.1329)
9 tensor(0.1271)
10 tensor(0.1228)
11 tensor(0.1192)
12 tensor(0.1167)
13 tensor(0.1154)
14 tensor(0.1153)
15 tensor(0.1203)
16 tensor(0.1261)
17 tensor(0.1203)
18 tensor(0.1157)
19 tensor(0.1146)
20 tensor(0.1134)
21 tensor(0.1126)
22 tensor(0.1145)
23 tensor(0.1127)
24 tensor(0.1131)
25 tensor(0.1121)
26 tensor(0.1120)
27 tensor(0.1099)
28 tensor(0.1097)
29 tensor(0.1107)
30 tensor(0.1105)
31 tensor(0.1093)
32 tensor(0.1092)
33 tensor(0.1098)
34 tensor(0.1090)
35 tensor(0.1083)
36 tensor(0.1098)
37 tensor(0.1090)
38 tensor(0.1072)
39 tensor(0.1079)
40 tensor(0.1074)
41 tensor(0.1074)
42 tensor(0.1078)
43 tensor(0.1069)
44 tensor(0.1065)
45 tensor(0.1066)
46 tensor(0.1075)
47 tensor(0.1073)
48 tensor(0.1071)
49 tensor(0.1068)


**Finally, test your distilled model (on the test set) and describe how it performs relative to both big and small models.**
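With $T = 3$ and $\alpha = 0.2$, my distilled model did only marginally better than the small model (9668 vs. 9645 correct, i.e. 332 vs. 355 errors), and both remain worse than the big model. Increasing the temperature to $T = 20$ actually made performance worse (9629 correct). Making $\alpha$ very small ($0.01$), so that most of the training signal comes from the large "teacher" model, gave the biggest improvement (9688 correct). Note that the later runs reuse distilled_model without re-initializing it, so each one continues training from the previous run's weights rather than starting from a fresh SmallNet.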

In [9]:
test(distilled_model,test_loader)

9668

In [10]:
T = 20
alpha = 0.2
distill(distilled_model,big_model,T,alpha,transfer_loader,50)
test(distilled_model,test_loader)

0 tensor(1.3143)
1 tensor(1.3028)
2 tensor(1.2989)
3 tensor(1.2962)
4 tensor(1.2946)
5 tensor(1.2932)
6 tensor(1.2921)
7 tensor(1.2911)
8 tensor(1.2903)
9 tensor(1.2896)
10 tensor(1.2889)
11 tensor(1.2884)
12 tensor(1.2879)
13 tensor(1.2874)
14 tensor(1.2870)
15 tensor(1.2866)
16 tensor(1.2863)
17 tensor(1.2860)
18 tensor(1.2857)
19 tensor(1.2855)
20 tensor(1.2852)
21 tensor(1.2850)
22 tensor(1.2848)
23 tensor(1.2846)
24 tensor(1.2845)
25 tensor(1.2843)
26 tensor(1.2842)
27 tensor(1.2840)
28 tensor(1.2839)
29 tensor(1.2838)
30 tensor(1.2837)
31 tensor(1.2836)
32 tensor(1.2835)
33 tensor(1.2834)
34 tensor(1.2833)
35 tensor(1.2832)
36 tensor(1.2832)
37 tensor(1.2831)
38 tensor(1.2830)
39 tensor(1.2829)
40 tensor(1.2829)
41 tensor(1.2828)
42 tensor(1.2828)
43 tensor(1.2827)
44 tensor(1.2827)
45 tensor(1.2826)
46 tensor(1.2826)
47 tensor(1.2826)
48 tensor(1.2825)
49 tensor(1.2825)


9629
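One possible contributor to the weak $T = 20$ result: the soft-target gradients shrink roughly as $1/T^2$, which is why the Hinton et al. paper multiplies the soft-target loss by $T^2$, and distill above does not. A minimal sketch of that scaled loss is below. distillation_loss is a hypothetical helper name, it uses torch.nn.functional instead of the module-style losses in the notebook, and it has not been run on this notebook's models, so no results are claimed for it.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, target, T, alpha):
    """alpha * hard CE + (1 - alpha) * T^2 * soft CE (T^2 scaling as in Hinton et al.)."""
    # Hard-target term: ordinary cross-entropy against the integer labels
    hard = F.cross_entropy(student_logits, target)
    # Soft-target term: cross-entropy between temperature-softened distributions
    # (probability targets for cross_entropy require PyTorch >= 1.10)
    soft_targets = F.softmax(teacher_logits.detach() / T, dim=-1)
    soft = F.cross_entropy(student_logits / T, soft_targets)
    # Multiply by T^2 so the soft-target gradients keep roughly the same scale as T changes
    return alpha * hard + (1 - alpha) * (T ** 2) * soft
```

Inside distill, the two loss computations and the weighting line could be replaced by L = distillation_loss(small_logits, big_logits, target, T, alpha). Because the $T^2$ factor shifts the balance between the two terms, $\alpha$ would likely need re-tuning, and the printed loss values would no longer be comparable to the ones shown here.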

In [11]:
T = 3
alpha = 0.01
distill(distilled_model,big_model,T,alpha,transfer_loader,50)
test(distilled_model,test_loader)

0 tensor(0.1347)
1 tensor(0.1319)
2 tensor(0.1299)
3 tensor(0.1287)
4 tensor(0.1285)
5 tensor(0.1263)
6 tensor(0.1269)
7 tensor(0.1276)
8 tensor(0.1274)
9 tensor(0.1278)
10 tensor(0.1261)
11 tensor(0.1256)
12 tensor(0.1256)
13 tensor(0.1271)
14 tensor(0.1271)
15 tensor(0.1253)
16 tensor(0.1250)
17 tensor(0.1241)
18 tensor(0.1239)
19 tensor(0.1251)
20 tensor(0.1251)
21 tensor(0.1264)
22 tensor(0.1252)
23 tensor(0.1253)
24 tensor(0.1246)
25 tensor(0.1234)
26 tensor(0.1233)
27 tensor(0.1241)
28 tensor(0.1243)
29 tensor(0.1237)
30 tensor(0.1236)
31 tensor(0.1243)
32 tensor(0.1253)
33 tensor(0.1243)
34 tensor(0.1238)
35 tensor(0.1231)
36 tensor(0.1229)
37 tensor(0.1244)
38 tensor(0.1238)
39 tensor(0.1230)
40 tensor(0.1231)
41 tensor(0.1235)
42 tensor(0.1232)
43 tensor(0.1226)
44 tensor(0.1229)
45 tensor(0.1228)
46 tensor(0.1231)
47 tensor(0.1230)
48 tensor(0.1233)
49 tensor(0.1239)


9688
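Collecting the test() outputs above in one place makes the comparison easier. This sketch only re-tabulates the recorded counts (out of 10,000 test images); "continued" marks the two runs that kept training the same distilled_model rather than starting from a fresh SmallNet.

```python
# Correct-classification counts copied from the test() outputs above
results = {
    "big model (teacher)": 9833,
    "small model, hard labels only": 9645,
    "distilled, T=3, alpha=0.2": 9668,
    "distilled, T=20, alpha=0.2 (continued)": 9629,
    "distilled, T=3, alpha=0.01 (continued)": 9688,
}
n_test = 10000
for name, correct in results.items():
    errors = n_test - correct
    print(f"{name:40s} errors={errors:4d} accuracy={correct / n_test:.2%}")
```

The best distilled run makes 312 errors, closing roughly a quarter of the 188-error gap between the small model (355) and the big model (167).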