In [1]:
from __future__ import print_function
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

import os
import time

from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_n

In [2]:
batch_size = 128
test_batch_size = 1000

In [3]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, 
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', train=False, 
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size,
    shuffle=True
)

### Model fit

In [4]:
class Teacher(nn.Module):
    def __init__(self):
        super(Teacher, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 600)
        self.fc4 = nn.Linear(600, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.fc4(x)
        return x

In [5]:
teacher = Teacher()

In [6]:
def train(model, optimizer, epochs, eval_per_epochs, train_loader, test_loader):
    for epoch in range(1, epochs + 1):
        model.train()

        epoch_loss = 0
        for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()

            data, target = Variable(data), Variable(target)
            output = model(data)

            loss = F.cross_entropy(output, target)
            loss.backward()
            epoch_loss += loss.item()

            optimizer.step()

        epoch_loss /= len(train_loader)
        print(f'Epoch: {epoch}, epoch loss: {epoch_loss:.6f}')

        if epoch % eval_per_epochs == 0:
            model.eval()

            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = Variable(data, volatile=True), Variable(target)
                output = model(data)
                test_loss += F.cross_entropy(output, target).item()

                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()

            print(f'Test loss: {test_loss / len(test_loader):.4f}, '
                  f'accuracy: {correct.item() / len(test_loader.dataset):.3f}')

In [7]:
epochs = 10
lr = 0.01
momentum = 0.5

optimizer = optim.SGD(teacher.parameters(), lr=lr, momentum=momentum)

In [8]:
epochs = 50
eval_per_epochs = 10

train(teacher, optimizer, epochs, eval_per_epochs, train_loader, test_loader)

100%|██████████| 469/469 [00:19<00:00, 24.01it/s]
  0%|          | 2/469 [00:00<00:24, 18.68it/s]

Epoch: 1, epoch loss: 1.903513


100%|██████████| 469/469 [00:26<00:00, 17.94it/s]
  0%|          | 2/469 [00:00<00:28, 16.66it/s]

Epoch: 2, epoch loss: 0.933519


100%|██████████| 469/469 [00:21<00:00, 22.01it/s]
  0%|          | 2/469 [00:00<00:33, 14.14it/s]

Epoch: 3, epoch loss: 0.642852


100%|██████████| 469/469 [00:22<00:00, 21.06it/s]
  1%|          | 3/469 [00:00<00:17, 27.34it/s]

Epoch: 4, epoch loss: 0.524164


100%|██████████| 469/469 [00:22<00:00, 21.30it/s]
  1%|          | 3/469 [00:00<00:17, 27.22it/s]

Epoch: 5, epoch loss: 0.445027


100%|██████████| 469/469 [00:21<00:00, 22.26it/s]
  1%|          | 3/469 [00:00<00:22, 20.37it/s]

Epoch: 6, epoch loss: 0.391232


100%|██████████| 469/469 [00:18<00:00, 24.90it/s]
  1%|          | 3/469 [00:00<00:18, 25.24it/s]

Epoch: 7, epoch loss: 0.353656


100%|██████████| 469/469 [00:21<00:00, 21.33it/s]
  0%|          | 2/469 [00:00<00:28, 16.35it/s]

Epoch: 8, epoch loss: 0.325232


100%|██████████| 469/469 [00:23<00:00, 20.03it/s]
  1%|          | 3/469 [00:00<00:19, 24.05it/s]

Epoch: 9, epoch loss: 0.299241


100%|██████████| 469/469 [00:18<00:00, 25.79it/s]


Epoch: 10, epoch loss: 0.278129


  1%|          | 3/469 [00:00<00:19, 23.90it/s]

Test loss: 0.1516, accuracy: 0.955


100%|██████████| 469/469 [00:20<00:00, 23.31it/s]
  1%|          | 4/469 [00:00<00:12, 36.44it/s]

Epoch: 11, epoch loss: 0.264148


100%|██████████| 469/469 [00:12<00:00, 36.51it/s]
  1%|          | 4/469 [00:00<00:12, 36.78it/s]

Epoch: 12, epoch loss: 0.248060


100%|██████████| 469/469 [00:17<00:00, 27.35it/s]
  1%|          | 4/469 [00:00<00:12, 36.54it/s]

Epoch: 13, epoch loss: 0.239690


100%|██████████| 469/469 [00:14<00:00, 32.44it/s]
  1%|          | 3/469 [00:00<00:16, 28.14it/s]

Epoch: 14, epoch loss: 0.227878


100%|██████████| 469/469 [00:18<00:00, 25.91it/s]
  0%|          | 2/469 [00:00<00:24, 19.24it/s]

Epoch: 15, epoch loss: 0.217054


100%|██████████| 469/469 [00:21<00:00, 21.79it/s]
  0%|          | 2/469 [00:00<00:27, 17.14it/s]

Epoch: 16, epoch loss: 0.209481


100%|██████████| 469/469 [00:20<00:00, 22.78it/s]
  1%|          | 3/469 [00:00<00:16, 27.97it/s]

Epoch: 17, epoch loss: 0.203830


100%|██████████| 469/469 [00:18<00:00, 24.80it/s]
  1%|          | 3/469 [00:00<00:17, 26.88it/s]

Epoch: 18, epoch loss: 0.193524


100%|██████████| 469/469 [00:17<00:00, 26.20it/s]
  1%|          | 3/469 [00:00<00:17, 26.93it/s]

Epoch: 19, epoch loss: 0.190762


100%|██████████| 469/469 [00:19<00:00, 24.09it/s]


Epoch: 20, epoch loss: 0.185100


  1%|          | 3/469 [00:00<00:20, 23.16it/s]

Test loss: 0.1026, accuracy: 0.972


100%|██████████| 469/469 [00:17<00:00, 26.13it/s]
  0%|          | 2/469 [00:00<00:25, 18.33it/s]

Epoch: 21, epoch loss: 0.175030


100%|██████████| 469/469 [00:19<00:00, 24.04it/s]
  1%|          | 3/469 [00:00<00:16, 28.63it/s]

Epoch: 22, epoch loss: 0.175381


100%|██████████| 469/469 [00:16<00:00, 27.62it/s]
  1%|          | 3/469 [00:00<00:16, 27.74it/s]

Epoch: 23, epoch loss: 0.170436


100%|██████████| 469/469 [00:16<00:00, 27.94it/s]
  1%|          | 3/469 [00:00<00:16, 28.78it/s]

Epoch: 24, epoch loss: 0.161836


100%|██████████| 469/469 [00:18<00:00, 25.81it/s]
  1%|          | 3/469 [00:00<00:16, 28.58it/s]

Epoch: 25, epoch loss: 0.158005


100%|██████████| 469/469 [00:18<00:00, 24.87it/s]
  0%|          | 2/469 [00:00<00:28, 16.32it/s]

Epoch: 26, epoch loss: 0.157334


100%|██████████| 469/469 [00:20<00:00, 22.52it/s]
  1%|          | 3/469 [00:00<00:16, 27.65it/s]

Epoch: 27, epoch loss: 0.146983


100%|██████████| 469/469 [00:19<00:00, 23.53it/s]
  1%|          | 3/469 [00:00<00:16, 28.56it/s]

Epoch: 28, epoch loss: 0.147488


100%|██████████| 469/469 [00:17<00:00, 26.45it/s]
  1%|          | 3/469 [00:00<00:21, 21.58it/s]

Epoch: 29, epoch loss: 0.144601


100%|██████████| 469/469 [00:18<00:00, 25.86it/s]


Epoch: 30, epoch loss: 0.141954


  1%|          | 3/469 [00:00<00:19, 23.58it/s]

Test loss: 0.0878, accuracy: 0.975


100%|██████████| 469/469 [00:23<00:00, 20.06it/s]
  0%|          | 2/469 [00:00<00:23, 19.73it/s]

Epoch: 31, epoch loss: 0.140263


100%|██████████| 469/469 [00:20<00:00, 23.25it/s]
  1%|          | 3/469 [00:00<00:16, 27.45it/s]

Epoch: 32, epoch loss: 0.136498


100%|██████████| 469/469 [00:18<00:00, 24.69it/s]
  1%|          | 3/469 [00:00<00:21, 21.60it/s]

Epoch: 33, epoch loss: 0.134875


100%|██████████| 469/469 [00:20<00:00, 23.08it/s]
  1%|          | 3/469 [00:00<00:22, 20.58it/s]

Epoch: 34, epoch loss: 0.131712


100%|██████████| 469/469 [00:19<00:00, 23.97it/s]
  1%|          | 3/469 [00:00<00:16, 28.06it/s]

Epoch: 35, epoch loss: 0.128886


100%|██████████| 469/469 [00:21<00:00, 22.20it/s]
  1%|          | 3/469 [00:00<00:16, 27.65it/s]

Epoch: 36, epoch loss: 0.128013


100%|██████████| 469/469 [00:19<00:00, 23.94it/s]
  1%|          | 3/469 [00:00<00:16, 28.02it/s]

Epoch: 37, epoch loss: 0.121535


100%|██████████| 469/469 [00:17<00:00, 26.62it/s]
  1%|          | 3/469 [00:00<00:16, 28.30it/s]

Epoch: 38, epoch loss: 0.121286


100%|██████████| 469/469 [00:17<00:00, 26.16it/s]
  1%|          | 3/469 [00:00<00:16, 27.90it/s]

Epoch: 39, epoch loss: 0.123179


100%|██████████| 469/469 [00:21<00:00, 21.81it/s]


Epoch: 40, epoch loss: 0.117798


  0%|          | 2/469 [00:00<00:33, 13.77it/s]

Test loss: 0.0795, accuracy: 0.978


100%|██████████| 469/469 [00:23<00:00, 19.91it/s]
  1%|          | 3/469 [00:00<00:17, 27.33it/s]

Epoch: 41, epoch loss: 0.115815


100%|██████████| 469/469 [00:19<00:00, 23.69it/s]
  1%|          | 3/469 [00:00<00:21, 21.37it/s]

Epoch: 42, epoch loss: 0.110434


100%|██████████| 469/469 [00:18<00:00, 24.85it/s]
  1%|          | 3/469 [00:00<00:16, 27.91it/s]

Epoch: 43, epoch loss: 0.112087


100%|██████████| 469/469 [00:18<00:00, 26.01it/s]
  1%|          | 3/469 [00:00<00:21, 22.14it/s]

Epoch: 44, epoch loss: 0.111542


100%|██████████| 469/469 [00:19<00:00, 23.91it/s]
  1%|          | 3/469 [00:00<00:16, 28.42it/s]

Epoch: 45, epoch loss: 0.112315


100%|██████████| 469/469 [00:16<00:00, 27.74it/s]
  1%|          | 3/469 [00:00<00:16, 28.37it/s]

Epoch: 46, epoch loss: 0.109563


100%|██████████| 469/469 [00:19<00:00, 23.75it/s]
  1%|          | 3/469 [00:00<00:17, 26.12it/s]

Epoch: 47, epoch loss: 0.110007


100%|██████████| 469/469 [00:19<00:00, 23.92it/s]
  1%|          | 3/469 [00:00<00:17, 27.13it/s]

Epoch: 48, epoch loss: 0.104591


100%|██████████| 469/469 [00:19<00:00, 23.73it/s]
  1%|          | 3/469 [00:00<00:16, 27.42it/s]

Epoch: 49, epoch loss: 0.105738


100%|██████████| 469/469 [00:18<00:00, 24.74it/s]


Epoch: 50, epoch loss: 0.100993
Test loss: 0.0757, accuracy: 0.980


In [9]:
torch.save(teacher.state_dict(), 'models/teacher.pth')

### Distillation

In [10]:
class Student(nn.Module):
    def __init__(self):
        super(Student, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [11]:
student = Student()

In [12]:
def distillation_loss(y, labels, teacher_scores, T=20, alpha=0.7):
    kldivloss = nn.KLDivLoss()(F.log_softmax(y/T), F.softmax(teacher_scores/T))
    temp_impact = (T*T * 2.0 * alpha)
    cross_entropy = F.cross_entropy(y, labels) * (1. - alpha)
    loss = kldivloss * temp_impact + cross_entropy
    return loss

In [13]:
def train(model, teacher_model, loss_fn, optimizer, epochs, eval_per_epochs):
    teacher_model.eval()

    epoch_loss = 0
    for epoch in range(1, epochs + 1):
        model.train()

        for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
            optimizer.zero_grad()

            data, target = Variable(data), Variable(target)
            output = model(data)
            teacher_output = Variable(teacher_model(data), requires_grad=False)

            loss = loss_fn(output, target, teacher_output)
            loss.backward()
            epoch_loss += loss.item()

            optimizer.step()

        epoch_loss /= len(train_loader)
        print(f'Epoch: {epoch}, epoch distillation loss: {epoch_loss:.6f}')

        if epoch % eval_per_epochs == 0:
            model.eval()

            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = Variable(data, volatile=True), Variable(target)
                output = model(data)
                test_loss += F.cross_entropy(output, target).item()

                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()

            print(f'Student test loss: {test_loss / len(test_loader):.4f}, '
                  f'student accuracy: {correct.item() / len(test_loader.dataset):.3f}')

In [14]:
epochs = 10
lr = 0.01
momentum = 0.9

optimizer = optim.SGD(student.parameters(), lr=lr, momentum=momentum)

In [None]:
epochs = 50
eval_per_epochs = 10

train(student, teacher, distillation_loss, optimizer, epochs, eval_per_epochs)

  from ipykernel import kernelapp as app
  from ipykernel import kernelapp as app
100%|██████████| 469/469 [00:13<00:00, 34.31it/s]
  1%|          | 3/469 [00:00<00:15, 29.34it/s]

Epoch: 1, epoch distillation loss: 1.001934


100%|██████████| 469/469 [00:13<00:00, 35.13it/s]
  1%|          | 5/469 [00:00<00:10, 45.25it/s]

Epoch: 2, epoch distillation loss: 0.239985


100%|██████████| 469/469 [00:14<00:00, 33.30it/s]
  1%|          | 5/469 [00:00<00:10, 45.14it/s]

Epoch: 3, epoch distillation loss: 0.150917


100%|██████████| 469/469 [00:13<00:00, 34.63it/s]
  1%|          | 3/469 [00:00<00:19, 23.82it/s]

Epoch: 4, epoch distillation loss: 0.113777


100%|██████████| 469/469 [00:13<00:00, 35.78it/s]
  1%|          | 5/469 [00:00<00:09, 46.46it/s]

Epoch: 5, epoch distillation loss: 0.093237


100%|██████████| 469/469 [00:12<00:00, 38.80it/s]
  1%|          | 5/469 [00:00<00:10, 43.65it/s]

Epoch: 6, epoch distillation loss: 0.079874


100%|██████████| 469/469 [00:16<00:00, 28.92it/s]
  1%|          | 5/469 [00:00<00:10, 45.12it/s]

Epoch: 7, epoch distillation loss: 0.070695


100%|██████████| 469/469 [00:11<00:00, 39.66it/s]
  1%|          | 5/469 [00:00<00:10, 43.28it/s]

Epoch: 8, epoch distillation loss: 0.063801


100%|██████████| 469/469 [00:12<00:00, 36.88it/s]
  1%|          | 3/469 [00:00<00:17, 27.07it/s]

Epoch: 9, epoch distillation loss: 0.058533


100%|██████████| 469/469 [00:15<00:00, 30.84it/s]


Epoch: 10, epoch distillation loss: 0.054227


  1%|          | 3/469 [00:00<00:18, 25.60it/s]

Student test loss: 0.0885, student accuracy: 0.975


100%|██████████| 469/469 [00:12<00:00, 36.84it/s]
  1%|          | 3/469 [00:00<00:16, 27.87it/s]

Epoch: 11, epoch distillation loss: 0.050821


100%|██████████| 469/469 [00:16<00:00, 28.93it/s]
  1%|          | 3/469 [00:00<00:17, 27.38it/s]

Epoch: 12, epoch distillation loss: 0.047963


100%|██████████| 469/469 [00:13<00:00, 33.72it/s]
  1%|          | 5/469 [00:00<00:10, 45.10it/s]

Epoch: 13, epoch distillation loss: 0.045596


 78%|███████▊  | 367/469 [00:09<00:02, 45.95it/s]

In [None]:
torch.save(student.state_dict(), 'models/student.pth.tar')