In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train
import evaluate

In [2]:
def weights_init(m: nn.Module) -> None:
    """Initialize ``nn.Linear`` layers in place; intended for ``nn.Module.apply``.

    Linear weights get Xavier/Glorot-uniform initialization and biases are set
    to zero. Modules of any other type are left untouched.

    Args:
        m: A (sub)module visited by ``Module.apply``.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions already run under torch.no_grad(), so the Parameter
        # can be passed directly instead of reaching into the discouraged `.data`.
        nn.init.xavier_uniform_(m.weight)
        # nn.Linear(..., bias=False) stores `bias` as None; skip it rather than
        # crashing with AttributeError.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [3]:
# Load data.
# Seed PyTorch's global RNG before anything random happens (presumably data
# shuffling inside load_data_MNIST, plus weight init and dropout in later cells)
# so Restart & Run All gives repeatable numbers. torch.manual_seed also seeds
# all CUDA devices.
SEED = 0
torch.manual_seed(SEED)

batch_size = 256
# MNIST images are resized to IMG_SIZE x IMG_SIZE (presumably to suit the
# ResNet-18 teacher). The MLP Student flattens its input, and its first Linear
# layer takes IMG_SIZE * IMG_SIZE = 50176 features, so change both together.
IMG_SIZE = 224
train_iter, test_iter = data_loader.load_data_MNIST(
    batch_size=batch_size, resize=IMG_SIZE
)

In [4]:
# Define teacher & student models, move them to the GPU (if available), initialize student weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher: ResNet-18 adapted to 1-channel input and 10 output classes, restored
# from a previously trained checkpoint.
# NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of
# `weights=None`; kept here for compatibility with older torchvision.
teacher = torchvision.models.resnet18(pretrained=False)
teacher.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
teacher.fc = nn.Linear(512, 10)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine;
# without it the CPU fallback above would fail right here.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
teacher.load_state_dict(torch.load("../models/resnet18_mnist.pth", map_location=device))
teacher = teacher.to(device)
# The teacher is only used for inference: eval() keeps its BatchNorm layers on
# their stored running statistics instead of updating them from batches.
# NOTE(review): train.train_distill is not visible here; confirm it never
# switches the teacher back to train mode.
teacher.eval()

# Two independently initialized Students with the same architecture: `student`
# is the label-only baseline and `student_distill` is trained from the teacher.
student = model.Student().to(device)
student_distill = model.Student().to(device)
student.apply(weights_init)
student_distill.apply(weights_init)



Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=50176, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [5]:
# Define loss and optimizer: plain SGD for both students, each with its own
# cross-entropy criterion instance.
LR_STUDENT = 0.01
# NOTE(review): 7x the baseline learning rate. This confounds the
# baseline-vs-distillation comparison below; confirm it is intentional
# (e.g. tuned to the distillation loss scale).
LR_STUDENT_DISTILL = 0.07

criterion_student = nn.CrossEntropyLoss()
criterion_student_distill = nn.CrossEntropyLoss()
optimizer_student = optim.SGD(student.parameters(), lr=LR_STUDENT)
optimizer_student_distill = optim.SGD(student_distill.parameters(), lr=LR_STUDENT_DISTILL)

In [6]:
# Sanity-check the restored teacher on the test set before distilling from it.
# evaluate.evaluate yields (loss, accuracy); the accuracy is scaled by 100 for
# display, so it is presumably a fraction in [0, 1].
teacher_loss, teacher_acc = evaluate.evaluate(teacher, test_iter, device)
teacher_acc_pct = 100. * teacher_acc
print(f"Test Loss: {teacher_loss}, Acc: {teacher_acc_pct}%")

Test Loss: 0.012864453472833404, Acc: 99.6484375%


In [7]:
# Baseline: train the student without any teacher signal, for comparison with
# the distilled student below. Epoch count comes from train.train_student's
# own defaults (10 epochs in the recorded run).
train.train_student(
    student,
    train_iter,
    test_iter,
    criterion_student,
    optimizer_student,
    device,
)

epoch 1, loss 0.4129, train acc 0.885, test loss 0.2419, test acc 0.934
epoch 2, loss 0.2244, train acc 0.939, test loss 0.1914, test acc 0.948
epoch 3, loss 0.1844, train acc 0.950, test loss 0.1665, test acc 0.955
epoch 4, loss 0.1598, train acc 0.957, test loss 0.1471, test acc 0.958
epoch 5, loss 0.1432, train acc 0.961, test loss 0.1389, test acc 0.962
epoch 6, loss 0.1300, train acc 0.965, test loss 0.1260, test acc 0.966
epoch 7, loss 0.1196, train acc 0.968, test loss 0.1183, test acc 0.966
epoch 8, loss 0.1112, train acc 0.971, test loss 0.1103, test acc 0.969
epoch 9, loss 0.1045, train acc 0.972, test loss 0.1065, test acc 0.971
epoch 10, loss 0.0984, train acc 0.974, test loss 0.1042, test acc 0.971
training finished


In [8]:
# Distillation: the pretrained teacher guides `student_distill`.
# NOTE: the argument order differs from train.train_student: optimizer comes
# before criterion, and both precede the data iterators.
# The reported "loss" (~1.9) is on a different scale from the baseline's
# (~0.1), presumably because it includes a distillation term. Compare test acc,
# not loss. TODO confirm against train.train_distill.
train.train_distill(
    teacher, student_distill,
    optimizer_student_distill, criterion_student_distill,
    train_iter, test_iter,
    device,
)

epoch 1, loss 2.0695, train acc 0.855, test loss 0.4842, test acc 0.914
epoch 2, loss 1.9574, train acc 0.915, test loss 0.4444, test acc 0.931
epoch 3, loss 1.9383, train acc 0.926, test loss 0.4166, test acc 0.934
epoch 4, loss 1.9260, train acc 0.934, test loss 0.3841, test acc 0.944
epoch 5, loss 1.9172, train acc 0.941, test loss 0.3507, test acc 0.949
epoch 6, loss 1.9106, train acc 0.945, test loss 0.3277, test acc 0.950
epoch 7, loss 1.9047, train acc 0.949, test loss 0.3132, test acc 0.953
epoch 8, loss 1.9002, train acc 0.951, test loss 0.2987, test acc 0.956
epoch 9, loss 1.8961, train acc 0.955, test loss 0.2846, test acc 0.958
epoch 10, loss 1.8926, train acc 0.957, test loss 0.2714, test acc 0.959
training finished
