# 4.3 Cifar-100 Transfer Task (but with MNIST)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision

In [2]:
EPS = torch.finfo(torch.float64).eps
# EPS = 1e-8

In [3]:
def distillation_loss_beta_liouville(teacher_output_prob, student_output_alpha_i, student_output_alpha, student_output_beta, verbose=False):
    """
        This is a basic version of the distillation loss.
        This does not scale the output using a temperature or incorporate a mixture of teacher loss and target loss.
        *** Negating the log-likelihood to follow the convention of minimizing the loss function. ***
    """
    
    teacher_output_prob = torch.clamp(teacher_output_prob, min=EPS)
    student_output_alpha_i = torch.clamp(student_output_alpha_i, min=EPS)
    student_output_alpha = torch.clamp(student_output_alpha, min=EPS)
    student_output_beta = torch.clamp(student_output_beta, min=EPS)
    teacher_output_prob_sum = torch.sum(teacher_output_prob, dim=1)
    teacher_output_prob_sum = torch.clamp(teacher_output_prob_sum, max=1-EPS)
    return -(torch.lgamma(torch.sum(student_output_alpha_i, dim=1)) + torch.lgamma(student_output_alpha + student_output_beta)
            - torch.lgamma(student_output_alpha) - torch.lgamma(student_output_beta) + (student_output_alpha - torch.sum(student_output_alpha_i, dim=1)) * torch.log(torch.sum(teacher_output_prob, dim=1))
            + (student_output_beta - 1) * torch.log(1 - teacher_output_prob_sum) + torch.sum((student_output_alpha_i - 1) * torch.log(teacher_output_prob), dim=1)
            - torch.sum(torch.lgamma(student_output_alpha_i), dim=1)).mean()


In [4]:
P = torch.tensor([[0.3, 0.3, 0.1]])

alpha_d = torch.tensor([[2., 2., 3.]])

alpha = torch.tensor([2.])

beta = torch.tensor([2.])

distillation_loss_beta_liouville(P, alpha_d, alpha, beta)

tensor(-1.2441)

In [5]:
device = torch.device('cpu')

In [6]:
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_VAL = 100
BATCH_SIZE_TEST = 100

In [7]:
trainset = torchvision.datasets.MNIST('./', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
testset = torchvision.datasets.MNIST('./', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

In [8]:
train_subset, val_subset = torch.utils.data.random_split(
        trainset, [50000, 10000], generator=torch.Generator().manual_seed(1))

In [9]:
train_loader = DataLoader(dataset=train_subset, shuffle=False, batch_size=BATCH_SIZE_TRAIN)
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=BATCH_SIZE_VAL)
test_loader = DataLoader(dataset=testset, shuffle=False, batch_size=BATCH_SIZE_TEST)

In [10]:
def evaluate(model, dataset, batch_size=100, max_ex=0):
    model.eval()
    acc = 0
    N = len(dataset) * batch_size
    for i, (features, labels) in enumerate(dataset):
        features = features.view(batch_size, -1)
        features = features.to(device)
        features = features.double()
        labels = labels.to(device)
        scores = model(features)
        # print(labels)
        # print(scores)
        pred = torch.argmax(scores, dim=1)
        acc += torch.sum(torch.eq(pred, labels)).item()
        if max_ex != 0 and i >= max_ex:
            break
    # print(i)
    return (acc * 100 / ((i+1) * batch_size) )

In [11]:
def evaluate_beta_liouville(model, dataset, batch_size=100, max_ex=0):
    """
        output_alpha_i : (N, D)
        output_alpha   : (N, 1)
        output_beta    : (N, 1)
    """
    model.eval()
    acc = 0
    N = len(dataset) * batch_size
    for i, (features, labels) in enumerate(dataset):
        features = features.view(batch_size, -1)
        features = features.to(device)
        features = features.double()
        labels = labels.to(device)
        output_alpha_i, output_alpha, output_beta = model(features)
        output = output_alpha_i / torch.sum(output_alpha_i, dim=1, keepdim=True)
        output *= output_alpha / (output_alpha + output_beta)   
        torch.concat((output, 1 - torch.unsqueeze(torch.sum(output, dim=1), dim=1)), dim=1)
        pred = torch.argmax(output, dim=1)
        acc += torch.sum(torch.eq(pred, labels)).item()
        if max_ex != 0 and i >= max_ex:
            break
    # print(i)
    return (acc * 100 / ((i+1) * batch_size) )

In [12]:
def mean_squared_error(teacher_output_prob, output_alpha_i, output_alpha, output_beta):
    """
    teacher_output_prob: output of teacher model after applying softmax (N, d)
    student_output_alpha_i: output of student model (alpha) (N, d)
    student_output_alpha: output of student model (beta) (N, 1)
    student_output_beta: output of student model (beta) (N, 1)
    labels are of size d+1
    """
    output = output_alpha_i / torch.sum(output_alpha_i, dim=1, keepdim=True)
    output *= output_alpha / (output_alpha + output_beta)   
    
    return F.mse_loss(teacher_output_prob, output).item()

In [13]:
class TeacherModel(nn.Module):
    def __init__(self, dropout=0.5):
        super(TeacherModel, self).__init__()
        self.linear1 = nn.Linear(784, 400, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(400, 100, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(100, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        x = self.dropout(x)
        x = self.linear3(x)
        return x

In [14]:
NUM_EPOCHS = 10
lr = 0.001

In [15]:
teacher_model = TeacherModel()

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(teacher_model.parameters(), lr=lr)

In [17]:
teacher_model.to(device)
teacher_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        target = target.to(device)
        
        output = teacher_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.3166536372104405
Loss: 0.6096831454824725
Loss: 0.2794842294434579
Loss: 0.3598486027407903
Loss: 0.23256641982973916
Loss: 0.2474551808080626
Loss: 0.26510286677265277
Loss: 0.20394707081174768
Loss: 0.2083503831977444
Loss: 0.257821930062876
Loss: 0.13670311741724295
Loss: 0.14037771925195064
Loss: 0.12013769575218064
Loss: 0.24089642965531033
Loss: 0.141830523427644
Loss: 0.1423595666342025
Loss: 0.17602033481884738
Loss: 0.07475682925303895
Loss: 0.2294210890483884
Loss: 0.20855018842157072
Loss: 0.14121689585184424
Loss: 0.21926326944517435
Loss: 0.1070157050545755
Loss: 0.13027088615724092
Loss: 0.16279880097316637
Loss: 0.08233808401686264
Loss: 0.09893154585331951
Loss: 0.09998629630541542
Loss: 0.07275472173066447
Loss: 0.14144501940478652
Loss: 0.12667445742054684
Loss: 0.12189295337556347
Loss: 0.04849306983834963
Loss: 0.15693854422807696
Loss: 0.1288591851243214
Loss: 0.08829534126344674
Loss: 0.15612173284394987
Loss: 0.05003507554613689
Loss: 0.11261784429642384


In [18]:
evaluate(teacher_model, train_loader)

98.63

In [19]:
evaluate(teacher_model, test_loader)

97.81

In [20]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.linear2 = nn.Linear(50, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [21]:
student_model = StudentModel()
student_model.to(device)

StudentModel(
  (linear1): Linear(in_features=784, out_features=50, bias=True)
  (linear2): Linear(in_features=50, out_features=10, bias=True)
)

In [22]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(student_model.parameters(), lr=lr)

In [23]:
student_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        target = target.to(device)
        
        output = student_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.3334183619368853
Loss: 0.3843412755930081
Loss: 0.22942254939002354
Loss: 0.2952095209913666
Loss: 0.1964623295572778
Loss: 0.2851572507584291
Loss: 0.20263482257477122
Loss: 0.12397038462742362
Loss: 0.21246059051192503
Loss: 0.13559758253005763
Loss: 0.23261682621633292
Loss: 0.14247202790361588
Loss: 0.08625265980390091
Loss: 0.1793765864365005
Loss: 0.09336033104324734
Loss: 0.19079324671582285
Loss: 0.11654718262345724
Loss: 0.06642726074608657
Loss: 0.1580057123665614
Loss: 0.06611914462633756
Loss: 0.15769846118915365
Loss: 0.1102585485431388
Loss: 0.048222031133984676
Loss: 0.13162942893612198
Loss: 0.05111616437781475
Loss: 0.1387509206425659
Loss: 0.1073214775073215
Loss: 0.03540645063175957
Loss: 0.11192705310844085
Loss: 0.043931814169010146
Loss: 0.117218811226339
Loss: 0.10298231925746376
Loss: 0.028331323381798716
Loss: 0.09268116019086586
Loss: 0.031684248987343906
Loss: 0.09299906258216174
Loss: 0.09128393122569368
Loss: 0.02327276877655826
Loss: 0.088599807981

In [24]:
evaluate(student_model, train_loader)

98.394

In [25]:
evaluate(student_model, test_loader)

96.81

In [26]:
def il(x):
    return torch.where(x < 0, (1 / (1 - x)), x + 1) + EPS

In [27]:
class StudentModelBetaLiouville(nn.Module):
    def __init__(self):
        super(StudentModelBetaLiouville, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.output_alpha_i = nn.Linear(50, 9, dtype=torch.float64)
        self.output_alpha = nn.Linear(50, 1, dtype=torch.float64)
        self.output_beta = nn.Linear(50, 1, dtype=torch.float64) 

    def forward(self, x):
        x = F.relu(self.linear1(x))
        output_alpha_i = il(self.output_alpha_i(x))
        output_alpha = il(self.output_alpha(x))
        output_beta = il(self.output_beta(x))
        return output_alpha_i, output_alpha, output_beta

In [28]:
student_model_distilled_beta_liouville = StudentModelBetaLiouville()
student_model_distilled_beta_liouville.to(device)
optimizer = Adam(student_model_distilled_beta_liouville.parameters(), lr=0.001)

In [29]:
student_model_distilled_beta_liouville.train()
teacher_model.eval()

for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        target = target.to(device)
        teacher_output = teacher_model(data)
        teacher_output = F.softmax(teacher_output, dim=1)
        teacher_output = teacher_output[:, :-1]
        student_output_alpha_i, student_output_alpha, student_output_beta = student_model_distilled_beta_liouville(data)
        
        loss = distillation_loss_beta_liouville(teacher_output, student_output_alpha_i, student_output_alpha, student_output_beta, verbose=False)
        if i == 0:
            print("Loss:", loss.item())
            print("MSE:", mean_squared_error(teacher_output, student_output_alpha_i, student_output_alpha, student_output_beta))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Loss: -14.64621420965183
MSE: 0.09200795108182831
Loss: -133.65138811080232
MSE: 0.08281800974795
Loss: -134.71965126388503
MSE: 0.05381885780249965
Loss: -135.7686422031431
MSE: 0.025876374316457115
Loss: -136.1747335824253
MSE: 0.020288284638487804
Loss: -136.4463399402976
MSE: 0.01703018694016209
Loss: -136.59333410413174
MSE: 0.015050421932914766
Loss: -136.7562120804196
MSE: 0.013528319383203199
Loss: -136.88799376632855
MSE: 0.012678205576272715


In [30]:
evaluate_beta_liouville(student_model_distilled_beta_liouville, train_loader)

84.88

In [31]:
evaluate_beta_liouville(student_model_distilled_beta_liouville, test_loader)

84.73