In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train

In [2]:
def temperature_softmax(input_tensor, temperature=1.0, dim=1):
    """Temperature-scaled softmax: ``softmax(input_tensor / temperature)`` along ``dim``.

    A temperature above 1.0 flattens the resulting distribution (softer
    probabilities), below 1.0 sharpens it; 1.0 is the plain softmax.

    Args:
        input_tensor: tensor of unnormalized scores (logits).
        temperature: positive scaling factor applied before the softmax.
            A Python number must be > 0; a tensor is passed through unchecked
            so that broadcasting per-sample temperatures keeps working.
        dim: dimension to normalize over (default 1, i.e. the class dimension
            of a ``(batch, classes)`` tensor — the original hard-coded choice).

    Returns:
        Tensor of the same shape as ``input_tensor`` whose slices along ``dim`` sum to 1.

    Raises:
        ValueError: if ``temperature`` is a non-positive Python number
            (0 would divide by zero and yield NaNs; negatives invert the ranking).
    """
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    return torch.softmax(input_tensor / temperature, dim=dim)


def weights_init(m):
    """Xavier-uniform initialise the weight and zero the bias of an ``nn.Linear``.

    Intended for ``module.apply(weights_init)``, which calls it on every
    submodule; modules that are not ``nn.Linear`` are left untouched.

    Args:
        m: any ``nn.Module``.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions already run under torch.no_grad(), so the
        # parameters can be passed directly instead of via `.data`.
        nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have `m.bias is None`; the original
        # `m.bias.data` raised AttributeError for them.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [3]:
# Fix torch's global RNG (CPU and all CUDA devices) so the data shuffling,
# weight initialisation and dropout below are reproducible under
# Restart & Run All. NOTE(review): if data_loader shuffles with `random` or
# numpy instead of torch, seed those generators as well.
SEED = 0
torch.manual_seed(SEED)

# Load the MNIST train/test iterators.
batch_size = 256
train_iter, test_iter = data_loader.load_data_MNIST(batch_size=batch_size)

In [4]:
# Build the networks on the best available device: one teacher plus two
# students of identical architecture (`student` is later trained with
# train.train_student, `student_distill` with train.train_distill).
# Only the students are Xavier-initialised here; the teacher keeps
# model.TeacherModel's own initialisation. Batches are not moved in this
# cell; presumably the train helpers move them to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
teacher = model.TeacherModel().to(device)
student, student_distill = (model.Student().to(device) for _ in range(2))
student.apply(weights_init)
# Module.apply returns the module itself, so this last expression also
# displays the student architecture as the cell output.
student_distill.apply(weights_init)

Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [5]:
# Shrink both students' hidden layer from the 1000 units shown above to 128.
# Input size (784) and output size (10) match the original layers' repr.
STUDENT_HIDDEN_UNITS = 128

for net in (student, student_distill):
    net.linear1 = nn.Linear(784, STUDENT_HIDDEN_UNITS).to(device)
    net.linear2 = nn.Linear(STUDENT_HIDDEN_UNITS, 10).to(device)
    # The freshly created layers carry PyTorch's default init. Without this
    # call the Xavier/zero-bias init applied in the previous cell would be
    # silently discarded for exactly the layers that get trained.
    net.apply(weights_init)

In [12]:
# One loss function and one optimizer per training run.
#
# NOTE(review): the logged losses plateau near 1.46-1.48 even at ~98% test
# accuracy. CrossEntropyLoss fed one-hot *probabilities* bottoms out at
# log(e + 9) - 1 ≈ 1.461, so check whether the models' forward() already
# applies softmax. CrossEntropyLoss expects raw logits.
# NOTE(review): the two students use different learning rates (0.01 vs
# 0.07), which confounds the distillation-vs-baseline comparison.
TEACHER_LR = 1e-4
STUDENT_LR = 0.01
STUDENT_DISTILL_LR = 0.07

criterion_teacher, criterion_student, criterion_student_distill = (
    nn.CrossEntropyLoss() for _ in range(3)
)
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=TEACHER_LR)
optimizer_student = optim.SGD(student.parameters(), lr=STUDENT_LR)
optimizer_student_distill = optim.SGD(
    student_distill.parameters(), lr=STUDENT_DISTILL_LR
)

In [7]:
# Train the teacher; per-epoch train loss, test loss and accuracy are printed below.
TEACHER_EPOCHS = 50
train.train_teacher(
    teacher,
    optimizer_teacher,
    criterion_teacher,
    train_iter,
    test_iter,
    device,
    num_epochs=TEACHER_EPOCHS,
)

Epoch 1, loss = 1.8411239487059572, test_loss = 1.654077634215355, accuracy = 0.8334
Epoch 2, loss = 1.636415818397035, test_loss = 1.5368891268968583, accuracy = 0.942
Epoch 3, loss = 1.5632253175086164, test_loss = 1.5110902428627013, accuracy = 0.9595
Epoch 4, loss = 1.5404838633030018, test_loss = 1.500614231824875, accuracy = 0.9664
Epoch 5, loss = 1.528732084213419, test_loss = 1.495979180932045, accuracy = 0.9699
Epoch 6, loss = 1.5213004690535525, test_loss = 1.490612342953682, accuracy = 0.9747
Epoch 7, loss = 1.516134158601152, test_loss = 1.487645548582077, accuracy = 0.9768
Epoch 8, loss = 1.5117817452613345, test_loss = 1.4852753251791, accuracy = 0.978
Epoch 9, loss = 1.5083054197595474, test_loss = 1.4835948586463927, accuracy = 0.9811
Epoch 10, loss = 1.5052019697554568, test_loss = 1.4821049749851227, accuracy = 0.9807
Epoch 11, loss = 1.504144040067145, test_loss = 1.4820381820201873, accuracy = 0.9808
Epoch 12, loss = 1.5021797043211917, test_loss = 1.480410784482956

In [10]:
# Baseline: train the small student on its own, without the teacher.
STUDENT_EPOCHS = 150
train.train_student(
    student,
    optimizer_student,
    criterion_student,
    train_iter,
    test_iter,
    device,
    num_epochs=STUDENT_EPOCHS,
)

Epoch 1, loss = 1.6401697924796572, test_loss = 1.6227310657501222, accuracy = 0.8472
Epoch 2, loss = 1.639638295579464, test_loss = 1.6222314596176148, accuracy = 0.8478
Epoch 3, loss = 1.6384812025313682, test_loss = 1.6215544492006302, accuracy = 0.8479
Epoch 4, loss = 1.637824492758893, test_loss = 1.6206644356250763, accuracy = 0.849
Epoch 5, loss = 1.637358168845481, test_loss = 1.6194400846958161, accuracy = 0.8503
Epoch 6, loss = 1.6347491355652506, test_loss = 1.616957676410675, accuracy = 0.8518
Epoch 7, loss = 1.6314541020291917, test_loss = 1.6113352656364441, accuracy = 0.86
Epoch 8, loss = 1.6254181775640935, test_loss = 1.6006428956985475, accuracy = 0.8768
Epoch 9, loss = 1.618741026837775, test_loss = 1.5932772159576416, accuracy = 0.8908
Epoch 10, loss = 1.6145416660511747, test_loss = 1.589525282382965, accuracy = 0.8953
Epoch 11, loss = 1.6111023984056838, test_loss = 1.5867505997419358, accuracy = 0.8988
Epoch 12, loss = 1.608606711854326, test_loss = 1.58438452184

In [13]:
# Train the second student by distillation from the trained teacher, with the
# same epoch budget (150) as the baseline student above.
DISTILL_EPOCHS = 150
train.train_distill(
    teacher,
    student_distill,
    optimizer_student_distill,
    criterion_student_distill,
    train_iter,
    test_iter,
    device,
    num_epochs=DISTILL_EPOCHS,
)

Epoch 1, loss = 2.0462179691233535, test_loss = 1.6124901294708252, accuracy = 0.8493
Epoch 2, loss = 2.043569054502122, test_loss = 1.6120561093091965, accuracy = 0.849
Epoch 3, loss = 2.042707153584095, test_loss = 1.611604428291321, accuracy = 0.8495
Epoch 4, loss = 2.0411331506485633, test_loss = 1.6111511081457137, accuracy = 0.8503
Epoch 5, loss = 2.041393507795131, test_loss = 1.6107615232467651, accuracy = 0.8505
Epoch 6, loss = 2.0390332044439115, test_loss = 1.6099882900714875, accuracy = 0.8515
Epoch 7, loss = 2.0389205821017002, test_loss = 1.6095828264951706, accuracy = 0.8517
Epoch 8, loss = 2.0373021470739485, test_loss = 1.6088697552680968, accuracy = 0.8527
Epoch 9, loss = 2.036311720787211, test_loss = 1.6082497090101242, accuracy = 0.8535
Epoch 10, loss = 2.0357838087893545, test_loss = 1.6079275995492934, accuracy = 0.8536
Epoch 11, loss = 2.0349967428978455, test_loss = 1.6072347998619079, accuracy = 0.8544
Epoch 12, loss = 2.033318361830204, test_loss = 1.60734866