In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train

In [2]:
def temperature_softmax(input_tensor, temperature=1.0):
    return torch.softmax(input_tensor / temperature, dim=1)


def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.0)

In [3]:
# load data
batch_size = 256
train_iter, test_iter = data_loader.load_data_MNIST(batch_size=batch_size)
# train_iter, test_iter = data_loader.load_data_fashion_mnist(batch_size=batch_size)

In [4]:
# Define teacher & stduent model, Move models and data to GPU, Initialize weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = model.TeacherModel().to(device)
student = model.Student().to(device)
student.apply(weights_init)

Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [5]:
# Define loss and optimizer
criterion_teacher = nn.CrossEntropyLoss()
criterion_student = nn.CrossEntropyLoss()
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=1e-4)
optimizer_student = optim.SGD(student.parameters(), lr=0.01)

In [6]:
train.train_teacher(teacher, optimizer_teacher, criterion_teacher, train_iter, test_iter, device,num_epochs=50)

Epoch 1, loss = 2.03665840067762, test_loss = 1.7900316953659057, accuracy = 0.7306
Epoch 2, loss = 1.7965107993876681, test_loss = 1.6787387818098067, accuracy = 0.7965
Epoch 3, loss = 1.7082879695486515, test_loss = 1.5990016728639602, accuracy = 0.8751
Epoch 4, loss = 1.64574545697963, test_loss = 1.5691695004701613, accuracy = 0.8995
Epoch 5, loss = 1.6146515475942733, test_loss = 1.5547476530075073, accuracy = 0.9126
Epoch 6, loss = 1.5920708808493107, test_loss = 1.53950754404068, accuracy = 0.9261
Epoch 7, loss = 1.5786194182456807, test_loss = 1.5288008540868758, accuracy = 0.9359
Epoch 8, loss = 1.568010933348473, test_loss = 1.5239405423402785, accuracy = 0.9387
Epoch 9, loss = 1.5592817971046935, test_loss = 1.5202448457479476, accuracy = 0.9428
Epoch 10, loss = 1.552722997361041, test_loss = 1.514493039250374, accuracy = 0.9485
Epoch 11, loss = 1.5463431064118731, test_loss = 1.5104084849357604, accuracy = 0.9516
Epoch 12, loss = 1.5402489657097675, test_loss = 1.5059545248

In [None]:
train.train_student(student, optimizer_student, criterion_student, train_iter, test_iter, device,num_epochs=50)

In [8]:
train.train_distill(teacher, student, optimizer_student, criterion_student, train_iter, test_iter, device, num_epochs=50)

Epoch 1, loss = 2.388189326955917, test_loss = 1.560771644115448, accuracy = 0.9066
Epoch 2, loss = 2.3876115596040766, test_loss = 1.5600969284772872, accuracy = 0.9071
Epoch 3, loss = 2.3856506317219837, test_loss = 1.559677955508232, accuracy = 0.9069
Epoch 4, loss = 2.3843143676189666, test_loss = 1.5591090261936187, accuracy = 0.9078
Epoch 5, loss = 2.382295178352518, test_loss = 1.5586395114660263, accuracy = 0.9081
Epoch 6, loss = 2.381769061595836, test_loss = 1.5581533402204513, accuracy = 0.909
Epoch 7, loss = 2.380413240067502, test_loss = 1.557728934288025, accuracy = 0.9095
Epoch 8, loss = 2.3792416521843442, test_loss = 1.5573432177305222, accuracy = 0.9095
Epoch 9, loss = 2.377765774219594, test_loss = 1.556966060400009, accuracy = 0.9097
Epoch 10, loss = 2.377044689908941, test_loss = 1.5565327137708664, accuracy = 0.9098
Epoch 11, loss = 2.3762858106734903, test_loss = 1.555995348095894, accuracy = 0.9105
Epoch 12, loss = 2.37478045098325, test_loss = 1.555648392438888