In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision
import time

In [2]:
device = torch.device('cpu')

In [3]:
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_VAL = 100
BATCH_SIZE_TEST = 100

In [4]:
trainset = torchvision.datasets.MNIST('./', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
testset = torchvision.datasets.MNIST('./', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

In [5]:
len(trainset)

60000

In [6]:
len(testset)

10000

In [7]:
train_subset, val_subset = torch.utils.data.random_split(
        trainset, [50000, 10000], generator=torch.Generator().manual_seed(1))

In [8]:
train_loader = DataLoader(dataset=train_subset, shuffle=True, batch_size=BATCH_SIZE_TRAIN)
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=BATCH_SIZE_VAL)
test_loader = DataLoader(dataset=testset, shuffle=False, batch_size=BATCH_SIZE_TEST)

In [68]:
def evaluate(model, dataset, batch_size=100, max_ex=0):
    """
    evaluate for dirichlet does not calculate the mean of the distribution since the mean is proportional to alpha_i.
    i.e., the probability with the highest mean = probability with highest alpha_i
    """
    model.eval()
    acc = 0
    N = len(dataset) * batch_size
    for i, (features, labels) in enumerate(dataset):
        features = features.double()
        features = features.view(batch_size, -1)
        features = features.to(device)
        labels = labels.to(device)
        scores = model(features)
        # print(labels)
        # print(scores)
        pred = torch.argmax(scores, dim=1)
        acc += torch.sum(torch.eq(pred, labels)).item()
        if max_ex != 0 and i >= max_ex:
            break
    # print(i)
    return (acc * 100 / ((i+1) * batch_size) )

In [10]:
def mean_squared_error(teacher_output_prob, student_output_alpha):
    output = student_output_alpha / torch.sum(student_output_alpha, dim=1, keepdim=True)
    return F.mse_loss(teacher_output_prob, output).item()

In [61]:
class TeacherModel(nn.Module):
    def __init__(self, dropout=0.5):
        super(TeacherModel, self).__init__()
        self.linear1 = nn.Linear(784, 1200, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(1200, 1200, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(1200, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        x = self.dropout(x)
        x = self.linear3(x)
        return x

In [62]:
NUM_EPOCHS = 10
lr = 0.001

In [63]:
teacher_model = TeacherModel()

In [64]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(teacher_model.parameters(), lr=lr)

In [66]:
start = time.time()
teacher_model.to(device)
teacher_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.double()
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        target = target.to(device)
        
        output = teacher_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())
            
end = time.time()

print(f"Time taken:{end-start}")

Loss: 2.303287425700386
Loss: 0.30852745333251136
Loss: 0.32154546260819733
Loss: 0.21399600120998166
Loss: 0.29092894384608675
Loss: 0.079235226919389
Loss: 0.13370553588615477
Loss: 0.1463429162356006
Loss: 0.1101149493041377
Loss: 0.19215927272523398
Loss: 0.13721820871697024
Loss: 0.1548088120899069
Loss: 0.12144178103976695
Loss: 0.036790668936964414
Loss: 0.1356627635968576
Loss: 0.08432089876704037
Loss: 0.09190448342900026
Loss: 0.19988806582759366
Loss: 0.1341253921082884
Loss: 0.09667001996075535
Loss: 0.14091394733867976
Loss: 0.1922969484298019
Loss: 0.06631771430028652
Loss: 0.07075102954024037
Loss: 0.10321670422037885
Loss: 0.0788475945244516
Loss: 0.04844425849865379
Loss: 0.13565104304835304
Loss: 0.061478311009363844
Loss: 0.09436288340114134
Loss: 0.11450734491835629
Loss: 0.06349670687926358
Loss: 0.050485011909864735
Loss: 0.11003119429740518
Loss: 0.11453528878319831
Loss: 0.04007923529414313
Loss: 0.19777858059428893
Loss: 0.2912708154669698
Loss: 0.1449028956977

In [69]:
evaluate(teacher_model, train_loader)

98.86

In [70]:
evaluate(teacher_model, test_loader)

97.81

In [71]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.linear2 = nn.Linear(50, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [72]:
student_model = StudentModel()
student_model.to(device)

StudentModel(
  (linear1): Linear(in_features=784, out_features=50, bias=True)
  (linear2): Linear(in_features=50, out_features=10, bias=True)
)

In [73]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(student_model.parameters(), lr=lr)

In [74]:
student_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.double()
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        target = target.to(device)
        
        output = student_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.347399054312487
Loss: 0.48183600122794135
Loss: 0.20560929226882405
Loss: 0.196136314844859
Loss: 0.25455717524880234
Loss: 0.11457593625663103
Loss: 0.16190092655927735
Loss: 0.10706857834372295
Loss: 0.20032116716740372
Loss: 0.11172894216186353
Loss: 0.19089811885242675
Loss: 0.18412750173453107
Loss: 0.2434642570425792
Loss: 0.15407006818880867
Loss: 0.19223732260201168
Loss: 0.061989243185216016
Loss: 0.051927957107854236
Loss: 0.06384545090455923
Loss: 0.12476227588829371
Loss: 0.08955146982993421
Loss: 0.09095117290575683
Loss: 0.19228158493777292
Loss: 0.13068685103707914
Loss: 0.12514300474909032
Loss: 0.15192215447259427
Loss: 0.05523088706264561
Loss: 0.08446146593085185
Loss: 0.05926083800673895
Loss: 0.0742524875561164
Loss: 0.11511030938619658
Loss: 0.017762347602045314
Loss: 0.10881318183425184
Loss: 0.022470427425304477
Loss: 0.07001898218221803
Loss: 0.08773400918760618
Loss: 0.02506472002957886
Loss: 0.03591148655564282
Loss: 0.04595865420028297
Loss: 0.067220

In [75]:
evaluate(student_model, train_loader)

98.598

In [76]:
evaluate(student_model, test_loader)

97.1

In [77]:
def distillation_loss(teacher_output, student_output, target, temp=1, alpha=0.5, beta=0.5):
    student_output_log_prob = F.log_softmax(student_output / temp, dim=1)
    teacher_output_prob = F.softmax(teacher_output / temp, dim=1)
    return alpha * F.cross_entropy(student_output, target) + beta * -(teacher_output_prob*student_output_log_prob).sum(dim=1).mean()

In [78]:
eps = torch.finfo(torch.float64).eps

In [79]:
def distillation_loss_dirichlet(teacher_output_prob, student_output, target):
    """
        This is a basic version of the distillation loss.
        This does not scale the output using a temperature or incorporate a mixture of teacher loss and target loss.
        *** Negating the log-likelihood to follow the convention of minimizing the loss function. ***
    """
    return -((torch.sum((student_output - 1) * torch.log(teacher_output_prob + eps), dim=1) + torch.lgamma(torch.sum(student_output, dim=1) + eps) - torch.sum(torch.lgamma(student_output + eps), dim=1))).mean() 

In [80]:
student_model_distilled = StudentModel()
student_model_distilled.to(device)

StudentModel(
  (linear1): Linear(in_features=784, out_features=50, bias=True)
  (linear2): Linear(in_features=50, out_features=10, bias=True)
)

In [81]:
optimizer = Adam(student_model_distilled.parameters(), lr=lr)

In [82]:
student_model_distilled.train()
teacher_model.eval()

for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.double()
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        target = target.to(device)        
        student_output = student_model_distilled(data)
        teacher_output = teacher_model(data)
        loss = distillation_loss(teacher_output, student_output, target, temp=3, alpha=0.4, beta=0.6)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.3580268431516727
Loss: 0.6107853313775403
Loss: 0.55463149456736
Loss: 0.6529141417873839
Loss: 0.4835884293264955
Loss: 0.49874910717174986
Loss: 0.5202839596862426
Loss: 0.4097256025562288
Loss: 0.3231338011833757
Loss: 0.3826480637080729
Loss: 0.3359389219432365
Loss: 0.27254007279638665
Loss: 0.3411947537710217
Loss: 0.5163640247243276
Loss: 0.33480293263263294
Loss: 0.40961410148506944
Loss: 0.41234721783908024
Loss: 0.29289605329798574
Loss: 0.41800605280718833
Loss: 0.4282348579891116
Loss: 0.38715913418606573
Loss: 0.3134492312038455
Loss: 0.3884770284082473
Loss: 0.30341899541394207
Loss: 0.2518020574232445
Loss: 0.3401087910735874
Loss: 0.29358301241776064
Loss: 0.27289855673361607
Loss: 0.3477010625883886
Loss: 0.3123522502314622
Loss: 0.3329393184233002
Loss: 0.32966112732820585
Loss: 0.36341430756847637
Loss: 0.2766805833454965
Loss: 0.27680178190757937
Loss: 0.3509351782688971
Loss: 0.2723616804139621
Loss: 0.31165877933740627
Loss: 0.2652171241584098
Loss: 0.3183

In [83]:
evaluate(student_model_distilled, train_loader)

98.342

In [84]:
evaluate(student_model_distilled, test_loader)

96.7

In [85]:
def il(x):
    return torch.where(x < 0, (1 / (1 - x)) + eps, x + 1 + eps)

In [86]:
class StudentModelDirichlet(nn.Module):
    def __init__(self):
        super(StudentModelDirichlet, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.linear2 = nn.Linear(50, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = il(self.linear2(x))
        return x

In [87]:
student_model_distilled_dirichlet = StudentModelDirichlet()
student_model_distilled_dirichlet.to(device)
optimizer = Adam(student_model_distilled_dirichlet.parameters(), lr=0.001)

In [88]:
student_model_distilled_dirichlet.train()
teacher_model.eval()

for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.double()
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        target = target.to(device)
        # target_one_hot_encoded = F.one_hot(target, num_classes=10)
        student_output = student_model_distilled_dirichlet(data)
        student_output.retain_grad()
        
        teacher_output = teacher_model(data)
        teacher_output = F.softmax(teacher_output, dim=1)
        loss = distillation_loss_dirichlet(teacher_output, student_output, target)
        optimizer.zero_grad()
        loss.backward()
        
        
        # print("teacher_output", teacher_output)
        # print("student_output", student_output)
        # print("student_output sum", student_output.sum())
        # print("loss", loss)
        # if i % 100 == 0:
        #     print("--------------")
        #     print(target)
        #     print(student_output)
        #     print("gradient", student_output.grad)
        # break
        optimizer.step()

        if i % 100 == 0:
            print("Loss:", loss.item())
            print("MSE:", mean_squared_error(teacher_output, student_output))

Loss: -17.717303986340067
MSE: 0.08658637135426829
Loss: -146.93780363835742
MSE: 0.0858176328888101
Loss: -160.1516467405612
MSE: 0.08502779704943841
Loss: -151.5732022085047
MSE: 0.08598026268752776
Loss: -145.09532692669617
MSE: 0.08080357296424963
Loss: -155.395347442335
MSE: 0.0786515033971438
Loss: -146.2075900220333
MSE: 0.0703130752894403
Loss: -145.61893155397954
MSE: 0.06273859935535837
Loss: -136.10336262863788
MSE: 0.051838923665585276
Loss: -144.29814220012756
MSE: 0.05042432139135245
Loss: -169.1971692530504
MSE: 0.046740488091019924
Loss: -166.5030239855897
MSE: 0.03755422379513372
Loss: -160.61559037124567
MSE: 0.04162129941551949
Loss: -147.13621136140145
MSE: 0.03671783803203076
Loss: -137.09141544414652
MSE: 0.038895164477791726
Loss: -154.72762177310403
MSE: 0.029404199037547384
Loss: -168.80343136838664
MSE: 0.02391109954544435
Loss: -162.06933735138256
MSE: 0.023925125080632815
Loss: -168.43892053034907
MSE: 0.02151434099695281
Loss: -142.8460532984871
MSE: 0.0242

In [89]:
evaluate(student_model_distilled_dirichlet, train_loader)

92.154

In [90]:
evaluate(student_model_distilled_dirichlet, test_loader, batch_size=100)

92.14