In [16]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

import model
import data_loader
import train
import evaluate

In [2]:
def weights_init(m):
    """Initialize an ``nn.Linear`` layer in place; leave any other module untouched.

    Written to be passed to ``nn.Module.apply``, which calls it once per submodule.

    Args:
        m: Any ``nn.Module``. Only instances of ``nn.Linear`` are modified.

    The weight matrix is re-drawn with Xavier (Glorot) uniform initialization and
    the bias, when the layer has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return
    # nn.init functions already run without gradient tracking, so the parameter
    # can be passed directly instead of going through ``.data``.
    nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False expose ``m.bias`` as None; skip them instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

In [3]:
# Load data: build the Fashion-MNIST train/test iterators.
# resize=224 agrees with the Student's first layer shown in the model summary
# (in_features=50176 = 224 * 224); presumably the loader rescales every image
# to 224x224 -- confirm in data_loader.
batch_size = 256
train_iter, test_iter = data_loader.load_data_fashion_mnist(resize=224, batch_size=batch_size)

In [4]:
# Define teacher & student models, move them to the target device, initialize student weights.

# Prefer the second GPU when the machine has one; otherwise fall back to the
# first GPU, then to the CPU. Requesting "cuda:1" unconditionally fails on
# single-GPU machines.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Teacher: ResNet-18 with a 1-channel input stem and a 10-way classifier head,
# restored from a previously saved state dict.
teacher = torchvision.models.resnet18(pretrained=False)
teacher.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
teacher.fc = nn.Linear(512, 10)
# map_location remaps tensors that were saved on a device which may not exist
# here (e.g. a checkpoint written from a GPU, loaded on a CPU-only machine).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
teacher.load_state_dict(
    torch.load("../models/resnet18_fashion-mnist.pth", map_location=device)
)
teacher = teacher.to(device)
# The teacher is never optimized in this notebook (only the students get
# optimizers), so keep its BatchNorm layers in inference mode.
teacher.eval()

# Two identical students: one is trained on its own, the other is passed to
# train.train_distill together with the teacher.
student = model.Student().to(device)
student_distill = model.Student().to(device)
student.apply(weights_init)
student_distill.apply(weights_init)



Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=50176, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [20]:
# Define loss and optimizer -- one (criterion, optimizer) pair per student.

# Baseline student (the one handed to train.train_student): SGD, lr=0.01.
criterion_student = nn.CrossEntropyLoss()
optimizer_student = optim.SGD(params=student.parameters(), lr=0.01)

# Student handed to train.train_distill: SGD, lr=0.05.
# NOTE(review): the two students use different learning rates (0.01 vs 0.05),
# so accuracy differences cannot be attributed to distillation alone --
# confirm this is intentional.
criterion_student_distill = nn.CrossEntropyLoss()
optimizer_student_distill = optim.SGD(params=student_distill.parameters(), lr=0.05)

In [6]:
# Reference point: score the restored teacher on the test iterator.
# evaluate.evaluate returns a (loss, accuracy) pair; the accuracy is scaled by
# 100 for display, so it is presumably a fraction in [0, 1] -- confirm in evaluate.
teacher_loss, teacher_acc = evaluate.evaluate(
    teacher,
    test_iter,
    device,
)
print(f"Test Loss: {teacher_loss}, Acc: {100. * teacher_acc}%")

Test Loss: 0.18911720365285872, Acc: 93.14453125%


In [18]:
# Train the baseline student (no teacher is passed in) for 50 epochs; the helper
# prints per-epoch train/test loss and accuracy.
train.train_student(
    student,
    train_iter,
    test_iter,
    criterion_student,
    optimizer_student,
    device,
    num_epochs=50,
)

epoch 1, loss 0.7374, train acc 0.754, test loss 0.5825, test acc 0.792
epoch 2, loss 0.4923, train acc 0.829, test loss 0.4965, test acc 0.822
epoch 3, loss 0.4392, train acc 0.849, test loss 0.4476, test acc 0.841
epoch 4, loss 0.4133, train acc 0.855, test loss 0.4308, test acc 0.850
epoch 5, loss 0.3964, train acc 0.861, test loss 0.4151, test acc 0.852
epoch 6, loss 0.3786, train acc 0.868, test loss 0.4093, test acc 0.855
epoch 7, loss 0.3686, train acc 0.870, test loss 0.3910, test acc 0.862
epoch 8, loss 0.3577, train acc 0.874, test loss 0.3859, test acc 0.864
epoch 9, loss 0.3475, train acc 0.878, test loss 0.4050, test acc 0.853
epoch 10, loss 0.3409, train acc 0.880, test loss 0.3995, test acc 0.857
epoch 11, loss 0.3346, train acc 0.882, test loss 0.3810, test acc 0.867
epoch 12, loss 0.3290, train acc 0.885, test loss 0.3899, test acc 0.861
epoch 13, loss 0.3211, train acc 0.886, test loss 0.4142, test acc 0.848
epoch 14, loss 0.3173, train acc 0.888, test loss 0.3685, te

In [21]:
# Train the second student with the teacher passed alongside it, for 5 epochs.
# NOTE(review): in the recorded log the loss stays flat at ~2.50 and train
# accuracy is already 0.837 in epoch 1, and the cell execution counts are out
# of order -- this suggests student_distill was not freshly initialized when
# the cell ran. Re-run from a fresh kernel (Restart & Run All) to confirm.
# NOTE(review): the baseline student above runs for 50 epochs versus 5 here,
# so the two logs are not directly comparable.
train.train_distill(
    teacher, student_distill,
    optimizer_student_distill, criterion_student_distill,
    train_iter, test_iter,
    device,
    num_epochs=5,
)

epoch 1, loss 2.5018, train acc 0.837, test loss 4.5353, test acc 0.825
epoch 2, loss 2.5015, train acc 0.838, test loss 4.7478, test acc 0.825
epoch 3, loss 2.5013, train acc 0.839, test loss 4.5536, test acc 0.824
epoch 4, loss 2.5018, train acc 0.838, test loss 4.5484, test acc 0.824
epoch 5, loss 2.5009, train acc 0.839, test loss 4.5500, test acc 0.825
training finished
