In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [39]:
import model
import data_loader

In [40]:
def temperature_softmax(input_tensor, temperature=1.0):
    return torch.softmax(input_tensor / temperature, dim=1)

In [41]:
teacher = model.TeacherModel()
student = model.Student()

In [42]:
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.0)

student.apply(weights_init)

Student(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (linear2): Linear(in_features=1000, out_features=10, bias=True)
)

In [43]:
batch_size = 256
train_iter, test_iter = data_loader.load_data_MNIST(batch_size=batch_size)

In [44]:
# Move models and data to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher = teacher.to(device)
student = student.to(device)

In [45]:
# Define loss and optimizer
criterion_teacher = nn.CrossEntropyLoss()
criterion_student = nn.CrossEntropyLoss()
optimizer_teacher = optim.RMSprop(teacher.parameters(), lr=1e-4)
optimizer_student = optim.SGD(student.parameters(), lr=0.01)

In [46]:
# Define teacher evaluation function
def evaluate_teacher(net, criterion, testloader, device):
    net.eval()	# Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            y_hat = temperature_softmax(net(X),) 
            loss = criterion(y_hat, y)
            
            running_loss += loss.item()
            _, predicted = y_hat.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    print(f"Test loss = {running_loss / len(testloader)}, accuracy = {correct / total}")

In [47]:
# Define teacher training function
def train_teacher(net, optimizer, criterion, trainloader, testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        net.train()	# Set model to training mode
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            # y_hat = net(X)
            y_hat = temperature_softmax(net(X))
            loss = criterion(y_hat, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_teacher(net, criterion, testloader, device)

In [49]:
train_teacher(teacher, optimizer_teacher, criterion_teacher, train_iter, test_iter, device,num_epochs=20)

Epoch 1, loss = 1.4996710412045742
Epoch 2, loss = 1.4985767963084768
Epoch 3, loss = 1.497438501804433
Epoch 4, loss = 1.497542148955325
Epoch 5, loss = 1.4983872002743661
Test loss = 1.4822357922792435, accuracy = 0.9794
Epoch 6, loss = 1.4960177030969173
Epoch 7, loss = 1.4951664777512246
Epoch 8, loss = 1.4951716321579953
Epoch 9, loss = 1.4950845895929539
Epoch 10, loss = 1.4941560106074556
Test loss = 1.4808611243963241, accuracy = 0.9804
Epoch 11, loss = 1.4930315550337447
Epoch 12, loss = 1.4934779796194524
Epoch 13, loss = 1.4928328534390063
Epoch 14, loss = 1.4922175057390903
Epoch 15, loss = 1.492471792342815
Test loss = 1.479246386885643, accuracy = 0.9821
Epoch 16, loss = 1.4908364174213815
Epoch 17, loss = 1.4913386162291182
Epoch 18, loss = 1.491286019061474
Epoch 19, loss = 1.4890501057848018
Epoch 20, loss = 1.4898812745479828
Test loss = 1.478385517001152, accuracy = 0.9825


In [None]:
# Define student evaluation function
def evaluate_student(net, criterion, testloader, device):
    net.eval()	# Set model to evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in testloader:
            X = X.to(device)
            y = y.to(device)
            # y_hat = net(X)
            y_hat = temperature_softmax(net(X))
            loss = criterion(y_hat, y)
            
            running_loss += loss.item()
            _, predicted = y_hat.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
    print(f"Test loss = {running_loss / len(testloader)}, accuracy = {correct / total}")

In [None]:
# Define student training function
def train_student(net, optimizer, criterion, trainloader,testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        net.train()	# Set model to training mode
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            # y_hat = net(X)
            y_hat = temperature_softmax(net(X))
            loss = criterion(y_hat, y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_student(net, criterion, testloader, device)

In [None]:
train_student(student, optimizer_student, criterion_student, train_iter, test_iter, device)

Epoch 1, loss = 1.2536832418847592
Epoch 2, loss = 0.8337080724695896
Epoch 3, loss = 0.7318104987448835
Epoch 4, loss = 0.6757268994412524
Epoch 5, loss = 0.6392246447979135
Test loss = 0.5993871130049229, accuracy = 0.798
Epoch 6, loss = 0.6138397891470726
Epoch 7, loss = 0.5910799625072073
Epoch 8, loss = 0.5727639648508518
Epoch 9, loss = 0.5575131676298507
Epoch 10, loss = 0.5464943954285155
Test loss = 0.5272816389799118, accuracy = 0.8224


In [None]:
def train_distill(teacher, student, optimizer_student, criterion_student, trainloader, testloader, device, num_epochs=10):
    for epoch in range(num_epochs):
        teacher.eval()
        student.train()
        running_loss = 0.0
        for X, y in trainloader:
            X = X.to(device)
            y = y.to(device)
            y_teacher = teacher(X)
            y_hat = student(X)
            loss = 5 * criterion_student(y_hat, y) + F.kl_div(y_hat.log(), y_teacher, reduction="batchmean")
            
            optimizer_student.zero_grad()
            loss.backward()
            optimizer_student.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, loss = {running_loss / len(trainloader)}")
        if epoch % 5 == 5 - 1:
            evaluate_student(student, criterion_student, testloader, device)

In [None]:
train_distill(teacher, student, optimizer_student, criterion_student, train_iter, test_iter, device, num_epochs=50)

NameError: name 'trainloader' is not defined