In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train

In [2]:
def temperature_softmax(input_tensor, temperature=1.0, dim=1):
    """Softmax of ``input_tensor`` after dividing its values by ``temperature``.

    Args:
        input_tensor: Tensor of scores to normalise.
        temperature: Divisor applied to the scores before the softmax.
            ``1.0`` gives the ordinary softmax; larger values flatten the
            distribution and smaller positive values sharpen it. Plain
            Python numbers must be strictly positive; a tensor-valued
            temperature is passed through unchecked.
        dim: Dimension along which the outputs sum to one. Defaults to 1,
            the axis this helper used before the parameter existed.

    Returns:
        Tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: If ``temperature`` is a Python number that is zero or
            negative (zero would yield NaN/inf, negative would invert the
            ranking of the scores).
    """
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    return torch.softmax(input_tensor / temperature, dim=dim)


def weights_init(m):
    """Initialise ``nn.Linear`` layers; intended for ``Module.apply``.

    Linear weights are drawn from a Xavier-uniform distribution and the
    bias, when the layer has one, is set to zero. Modules of any other
    type are left untouched.

    Args:
        m: Module visited by ``apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    # nn.init functions already run without autograd tracking, so the
    # parameters can be passed directly instead of through ``.data``.
    nn.init.xavier_uniform_(m.weight)
    # ``bias`` is None for nn.Linear(..., bias=False); skip it rather than
    # fail with AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

In [3]:
# Load data: build the train/test iterators with the project-local
# data_loader.load_data_fashion_mnist helper. Both iterators are passed to
# every train.* call further down in the notebook.
batch_size = 256  # forwarded to the loader as its batch_size argument
train_iter, test_iter = data_loader.load_data_fashion_mnist(batch_size=batch_size)

In [4]:
# Define the teacher and student models, move them to the GPU when one is
# available (CPU otherwise), and initialise the student weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = model.TeacherModel().to(device)
# Two independent instances of the same Student class: `student` is passed to
# train.train_student and `student_distill` to train.train_distill.
student = model.Student().to(device)
student_distill = model.Student().to(device)
# weights_init is applied to the two students only; the teacher keeps whatever
# initialisation model.TeacherModel sets up itself.
student.apply(weights_init)
# Last expression of the cell, so the module returned by apply() is displayed.
student_distill.apply(weights_init)

Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [9]:
# Define loss and optimizer: one criterion/optimizer pair per model.
# The three criteria are separate but identically configured CrossEntropyLoss
# instances.
# NOTE(review): the logged losses stay around 1.6-1.9 while accuracy is ~0.8.
# That pattern would fit models that already apply a softmax before
# CrossEntropyLoss (which expects raw logits) - confirm in model.py.
criterion_teacher = nn.CrossEntropyLoss()
criterion_student = nn.CrossEntropyLoss()
criterion_student_distill = nn.CrossEntropyLoss()
# The teacher is optimised with RMSprop, both students with plain SGD.
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=1e-4)
# NOTE(review): the two students use different learning rates (0.01 vs 0.07),
# so any accuracy gap between them mixes the effect of distillation with the
# effect of the learning rate - confirm this is intended.
optimizer_student = optim.SGD(student.parameters(), lr=0.01)
optimizer_student_distill = optim.SGD(student_distill.parameters(), lr=0.07)

In [6]:
# Train the teacher with its RMSprop optimizer and cross-entropy criterion,
# passing num_epochs=60. The logged output reports loss, test_loss and
# accuracy once per epoch.
train.train_teacher(teacher, optimizer_teacher, criterion_teacher, train_iter, test_iter, device,num_epochs=60)

Epoch 1, loss = 1.8642696715415792, test_loss = 1.7459946632385255, accuracy = 0.7363
Epoch 2, loss = 1.7335356159413116, test_loss = 1.7093547224998473, accuracy = 0.762
Epoch 3, loss = 1.7068078888223526, test_loss = 1.6949671268463136, accuracy = 0.7694
Epoch 4, loss = 1.6921150567683767, test_loss = 1.677220532298088, accuracy = 0.7856
Epoch 5, loss = 1.6802283165302683, test_loss = 1.6674225866794585, accuracy = 0.7975
Epoch 6, loss = 1.6693116629377325, test_loss = 1.6567104518413545, accuracy = 0.8062
Epoch 7, loss = 1.661570366900018, test_loss = 1.648829346895218, accuracy = 0.8142
Epoch 8, loss = 1.6550766062229239, test_loss = 1.64264377951622, accuracy = 0.8205
Epoch 9, loss = 1.6506536204764184, test_loss = 1.6369083285331727, accuracy = 0.8271
Epoch 10, loss = 1.6453787382612837, test_loss = 1.633430165052414, accuracy = 0.8296
Epoch 11, loss = 1.6403011347385164, test_loss = 1.6241026699543, accuracy = 0.839
Epoch 12, loss = 1.6378359332997747, test_loss = 1.625820523500

In [11]:
# Baseline: train the plain student (no teacher involved in this call) with
# SGD and cross-entropy, passing num_epochs=150.
# NOTE(review): the saved output already shows ~0.81 accuracy at "Epoch 1" and
# ends with KeyboardInterrupt, which suggests this cell was re-run on an
# already-trained `student` and then stopped. Re-run from a fresh kernel
# before comparing against the distilled student.
train.train_student(student, optimizer_student, criterion_student, train_iter, test_iter, device,num_epochs=150)

Epoch 1, loss = 1.657535027950368, test_loss = 1.6618335247039795, accuracy = 0.8077
Epoch 2, loss = 1.6573216123783843, test_loss = 1.6618191599845886, accuracy = 0.8076
Epoch 3, loss = 1.6570985154902682, test_loss = 1.6617167830467223, accuracy = 0.8075
Epoch 4, loss = 1.6568438382858925, test_loss = 1.661665540933609, accuracy = 0.8073
Epoch 5, loss = 1.6569274395070177, test_loss = 1.6613876849412919, accuracy = 0.8074
Epoch 6, loss = 1.6567570427630811, test_loss = 1.661282330751419, accuracy = 0.808
Epoch 7, loss = 1.6567995771448663, test_loss = 1.6611857563257217, accuracy = 0.8077
Epoch 8, loss = 1.656761709172675, test_loss = 1.6610952943563462, accuracy = 0.8076
Epoch 9, loss = 1.6562935088543183, test_loss = 1.6609951615333558, accuracy = 0.8081
Epoch 10, loss = 1.6563612628490367, test_loss = 1.6608612060546875, accuracy = 0.8076
Epoch 11, loss = 1.6562673259288707, test_loss = 1.6607503592967987, accuracy = 0.808
Epoch 12, loss = 1.656211584679624, test_loss = 1.66069661

KeyboardInterrupt: 

In [10]:
# Distillation run: train `student_distill` with the teacher passed alongside
# it, passing num_epochs=150.
# Presumably train.train_distill uses the teacher's outputs as soft targets;
# the loss formulation lives in train.py - verify there.
train.train_distill(teacher, student_distill, optimizer_student_distill, criterion_student_distill, train_iter, test_iter, device, num_epochs=150)

Epoch 1, loss = 2.0600950033106704, test_loss = 1.6672202110290528, accuracy = 0.7899
Epoch 2, loss = 2.0494372281622377, test_loss = 1.6280801445245743, accuracy = 0.8321
Epoch 3, loss = 2.0101762583915224, test_loss = 1.62460158765316, accuracy = 0.8346
Epoch 4, loss = 2.004246626508997, test_loss = 1.6216260880231856, accuracy = 0.8384
Epoch 5, loss = 2.0022594822214006, test_loss = 1.6220321565866471, accuracy = 0.8382
Epoch 6, loss = 2.0001206986447597, test_loss = 1.6208615452051163, accuracy = 0.8393
Epoch 7, loss = 1.998136794820745, test_loss = 1.6195638865232467, accuracy = 0.8391
Epoch 8, loss = 1.9966309283642059, test_loss = 1.6192172318696976, accuracy = 0.8412
Epoch 9, loss = 1.9952711120564888, test_loss = 1.6177712678909302, accuracy = 0.8432
Epoch 10, loss = 1.994749869184291, test_loss = 1.6176768869161606, accuracy = 0.8424
Epoch 11, loss = 1.992976856738963, test_loss = 1.6170044481754302, accuracy = 0.8429
Epoch 12, loss = 1.9924973569017776, test_loss = 1.6163988