In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train
import evaluate

In [2]:
def weights_init(m):
    """Xavier-initialise every ``nn.Linear`` layer in a module tree.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    submodule. Modules that are not ``nn.Linear`` are left untouched.

    Args:
        m: a submodule visited by ``Module.apply``.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions already run under torch.no_grad(), so the Parameter can be
        # initialised directly instead of going through the legacy ``.data`` alias.
        nn.init.xavier_uniform_(m.weight)
        # A Linear built with bias=False has ``bias is None``; dereferencing it
        # (as ``m.bias.data``) would raise AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [3]:
# Load MNIST train/test iterators.
# resize=224 presumably yields 224x224 images — consistent with the Student's first
# Linear layer expecting 50176 (= 224 * 224) inputs, and with the ResNet-18 teacher.
batch_size = 256
train_iter, test_iter = data_loader.load_data_MNIST(batch_size=batch_size, resize=224)

In [4]:
# Define teacher & student models, move them to the GPU (if available), and initialise weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher: ResNet-18 adapted to single-channel MNIST with a 10-way head.
# ``num_classes=10`` builds the same ``fc = Linear(512, 10)`` that was previously assigned
# by hand, and omitting ``pretrained=`` avoids that deprecated keyword (no pretrained
# weights were requested either way).
teacher = torchvision.models.resnet18(num_classes=10)
teacher.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine,
# which the CPU fallback for ``device`` above would otherwise fail on.
teacher.load_state_dict(torch.load("../models/resnet18_mnist.pth", map_location=device))
teacher = teacher.to(device)
# The teacher is only used for inference here (evaluation and distillation targets);
# eval mode keeps its BatchNorm running statistics frozen.
# NOTE(review): train.train_distill may also switch the teacher to eval — not visible here.
teacher.eval()

# Two identically-built students: one trained alone, one distilled from the teacher.
student = model.Student().to(device)
student_distill = model.Student().to(device)
student.apply(weights_init)
student_distill.apply(weights_init)



Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=50176, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [10]:
# Define loss and optimizer.
# Each student gets its own criterion and optimizer so the two training runs stay independent.
criterion_student, criterion_student_distill = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
optimizer_student = optim.SGD(student.parameters(), lr=0.01)
# NOTE(review): the distilled student uses a 10x larger learning rate (0.1 vs 0.01) than the
# baseline, so the two runs are not trained under matching settings — confirm this is intended.
optimizer_student_distill = optim.SGD(student_distill.parameters(), lr=0.1)

In [6]:
# Sanity-check the loaded teacher on the test set before using it for distillation.
teacher_loss, teacher_acc = evaluate.evaluate(
    teacher, test_iter, device
)
print(f"Test Loss: {teacher_loss}, Acc: {100. * teacher_acc}%")

Test Loss: 0.012864453472833404, Acc: 99.6484375%


In [8]:
# Baseline: train the plain student without a teacher.
# NOTE(review): this run uses 40 epochs and lr=0.01, while the distilled student below uses
# 60 epochs and lr=0.1 — the comparison is not like-for-like; confirm this is intended.
train.train_student(
    student,
    train_iter,
    test_iter,
    criterion_student,
    optimizer_student,
    device,
    num_epochs=40,
)

epoch 1, loss 0.4096, train acc 0.886, test loss 0.2366, test acc 0.937
epoch 2, loss 0.2225, train acc 0.939, test loss 0.1860, test acc 0.948
epoch 3, loss 0.1827, train acc 0.951, test loss 0.1644, test acc 0.953
epoch 4, loss 0.1586, train acc 0.957, test loss 0.1431, test acc 0.959
epoch 5, loss 0.1420, train acc 0.962, test loss 0.1331, test acc 0.963
epoch 6, loss 0.1291, train acc 0.965, test loss 0.1237, test acc 0.965
epoch 7, loss 0.1195, train acc 0.968, test loss 0.1166, test acc 0.968
epoch 8, loss 0.1106, train acc 0.971, test loss 0.1088, test acc 0.969
epoch 9, loss 0.1035, train acc 0.973, test loss 0.1040, test acc 0.971
epoch 10, loss 0.0971, train acc 0.975, test loss 0.1004, test acc 0.972
epoch 11, loss 0.0919, train acc 0.976, test loss 0.0991, test acc 0.972
epoch 12, loss 0.0871, train acc 0.977, test loss 0.0935, test acc 0.972
epoch 13, loss 0.0829, train acc 0.978, test loss 0.0910, test acc 0.974
epoch 14, loss 0.0792, train acc 0.979, test loss 0.0891, te

In [11]:
# Distillation: train the second student using the teacher.
# NOTE(review): arguments are positional, and the order here (optimizer before criterion)
# differs from train.train_student (criterion before optimizer) — double-check against
# train.py whenever either call is edited.
train.train_distill(teacher, student_distill,
                    optimizer_student_distill, criterion_student_distill,
                    train_iter, test_iter, device,
                    num_epochs=60)

epoch 1, loss 1.9034, train acc 0.950, test loss 0.3070, test acc 0.956
epoch 2, loss 1.8971, train acc 0.954, test loss 0.2898, test acc 0.957
epoch 3, loss 1.8929, train acc 0.957, test loss 0.2691, test acc 0.960
epoch 4, loss 1.8890, train acc 0.960, test loss 0.2592, test acc 0.962
epoch 5, loss 1.8855, train acc 0.962, test loss 0.2475, test acc 0.964
epoch 6, loss 1.8822, train acc 0.964, test loss 0.2331, test acc 0.965
epoch 7, loss 1.8799, train acc 0.966, test loss 0.2228, test acc 0.967
epoch 8, loss 1.8769, train acc 0.968, test loss 0.2179, test acc 0.967
epoch 9, loss 1.8752, train acc 0.969, test loss 0.2062, test acc 0.969
epoch 10, loss 1.8729, train acc 0.971, test loss 0.1998, test acc 0.971
epoch 11, loss 1.8715, train acc 0.972, test loss 0.2031, test acc 0.969
epoch 12, loss 1.8700, train acc 0.973, test loss 0.1831, test acc 0.973
epoch 13, loss 1.8683, train acc 0.974, test loss 0.1824, test acc 0.972
epoch 14, loss 1.8666, train acc 0.976, test loss 0.1759, te