# 4.3 Cifar-100 Transfer Task (but with MNIST)

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision

In [2]:
EPS = torch.finfo(torch.float64).eps
# EPS = 1e-8

In [3]:
def distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta, verbose=False):
    """
        This is a basic version of the distillation loss.
        This does not scale the output using a temperature or incorporate a mixture of teacher loss and target loss.
        *** Negating the log-likelihood to follow the convention of minimizing the loss function. ***
    """
    teacher_output_prob = torch.clamp(teacher_output_prob, min=EPS)
    student_output_alpha = torch.clamp(student_output_alpha, min=EPS)
    student_output_beta = torch.clamp(student_output_beta, min=EPS)
    d = student_output_alpha.shape[1]
    loss = 0.0
    for l in range(0, d - 1):
        scaler = student_output_beta[:, l] - student_output_alpha[:, l + 1] - student_output_beta[:, l + 1]
        sum_p = 0.0
        for j in range(0, l + 1):
            sum_p += teacher_output_prob[:, j]
        sum_p = torch.clamp(sum_p, max=1-EPS)
        if verbose:
            print("sum_p:", sum_p)
        loss += scaler * torch.log(1 - sum_p)
        
    if verbose:
        print("loss1:", loss)
    
    sum_teacher_output_prob = torch.sum(teacher_output_prob, dim=1)
    sum_teacher_output_prob = torch.clamp(sum_teacher_output_prob, max=1-EPS)
    loss += (student_output_beta[:, d - 1] - 1) * torch.log(1 - sum_teacher_output_prob)
    
    if verbose:
        print("loss2:", loss)
    
    loss += torch.sum(torch.lgamma(student_output_alpha + student_output_beta), dim=1) - torch.sum(torch.lgamma(student_output_alpha), dim=1) - \
            torch.sum(torch.lgamma(student_output_beta), dim=1) + torch.sum((student_output_alpha - 1) * torch.log(teacher_output_prob), dim=1)
    
    if verbose:
        print("loss3:", loss)
    
    return -loss.mean()

In [4]:
alpha = torch.tensor([[0.0876, 0.0360, 0.0444, 0.0248, 0.0153, 0.0206, 0.0174, 0.0242, 0.0112]])
beta = torch.tensor([[7.3526e+00, 2.2434e+01, 9.9446e+00, 9.6688e-02, 1.6363e-02, 4.5852e-02,
        1.2404e-01, 1.5865e-02, 5.2503e+00]])

P = torch.tensor([[2.5702e-07, 2.5177e-06, 2.2985e-05, 4.5485e-04, 2.8297e-07, 1.6700e-06,
        1.8722e-11, 9.9125e-01, 1.8576e-07]])

distillation_loss_generalized_dirichlet(P, alpha, beta)

tensor(-81.5249)

In [5]:
P = torch.tensor([[0.3, 0.3, 0.1]])

alpha = torch.tensor([[2., 2., 3.]])

beta = torch.tensor([[2., 3., 4.]])

distillation_loss_generalized_dirichlet(P, alpha, beta)

tensor(-2.4812)

In [6]:
teacher_output_prob = torch.tensor([[0.25, 0.25, 0.3, 0.1]])
student_output_alpha = torch.tensor([[1, 1, 1, 1]])
student_output_beta = torch.tensor([[1, 1, 1, 1]])

print(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta))
# torch.isclose(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta), torch.tensor(-3.2834))

tensor(-2.5903)


In [7]:
teacher_output_prob = torch.tensor([[0.25, 0.25, 0.25, 0.20]])
student_output_alpha = torch.tensor([[1, 1, 1, 1]])
student_output_beta = torch.tensor([[1, 1, 1, 1]])

print(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta))
# torch.isclose(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta), torch.tensor(-2.3671))

tensor(-2.3671)


In [8]:
teacher_output_prob = torch.tensor([[0.25, 0.25, 0.25, 0.20],
                                   [0.25, 0.25, 0.3, 0.1]])
student_output_alpha = torch.tensor([[1, 1, 1, 1],
                                    [1, 1, 1, 1]])
student_output_beta = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 1, 1]])

print(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta))
# torch.isclose(distillation_loss_generalized_dirichlet(teacher_output_prob, student_output_alpha, student_output_beta), torch.tensor(-2.3671))

tensor(-2.4787)


In [9]:
device = torch.device('cpu')

In [10]:
BATCH_SIZE_TRAIN = 100
BATCH_SIZE_VAL = 100
BATCH_SIZE_TEST = 100

In [11]:
trainset = torchvision.datasets.MNIST('./', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
testset = torchvision.datasets.MNIST('./', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

In [12]:
train_subset, val_subset = torch.utils.data.random_split(
        trainset, [50000, 10000], generator=torch.Generator().manual_seed(1))

In [13]:
train_loader = DataLoader(dataset=train_subset, shuffle=False, batch_size=BATCH_SIZE_TRAIN)
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=BATCH_SIZE_VAL)
test_loader = DataLoader(dataset=testset, shuffle=False, batch_size=BATCH_SIZE_TEST)

In [14]:
def evaluate(model, dataset, batch_size=100, max_ex=0):
    model.eval()
    acc = 0
    N = len(dataset) * batch_size
    for i, (features, labels) in enumerate(dataset):
        features = features.view(batch_size, -1)
        features = features.to(device)
        features = features.double()
        labels = labels.to(device)
        scores = model(features)
        # print(labels)
        # print(scores)
        pred = torch.argmax(scores, dim=1)
        acc += torch.sum(torch.eq(pred, labels)).item()
        if max_ex != 0 and i >= max_ex:
            break
    # print(i)
    return (acc * 100 / ((i+1) * batch_size) )

In [15]:
def evaluate_generalized_dirichlet(model, dataset, batch_size=100, max_ex=0):
    model.eval()
    acc = 0
    N = len(dataset) * batch_size
    for i, (features, labels) in enumerate(dataset):
        features = features.view(batch_size, -1)
        features = features.to(device)
        features = features.double()
        labels = labels.to(device)
        output_alpha, output_beta = model(features)
        
        output = output_alpha / (output_alpha + output_beta)
        d = output_alpha.shape[1]
        for l in range(0, d):
            prod = 1.
            for k in range(0, l):
                prod *= output_beta[:, k] / (output_alpha[:, k] + output_beta[:, k])
            output[:, l] *= prod
        
        # print(output.shape)
        # print((1 - torch.unsqueeze(torch.sum(output, dim=1), dim=1)).shape)
        torch.concat((output, 1 - torch.unsqueeze(torch.sum(output, dim=1), dim=1)), dim=1)
        pred = torch.argmax(output, dim=1)
        acc += torch.sum(torch.eq(pred, labels)).item()
        if max_ex != 0 and i >= max_ex:
            break
    # print(i)
    return (acc * 100 / ((i+1) * batch_size) )

In [16]:
def mean_squared_error(teacher_output_prob, student_output_alpha, student_output_beta):
    """
    teacher_output_prob: output of teacher model after applying softmax (N, d)
    student_output_alpha: output of student model (alpha) (N, d)
    student_output_beta: output of student model (beta) (N, d)
    labels are of size d+1
    """
    
    output = student_output_alpha / (student_output_alpha + student_output_beta)
    d = student_output_alpha.shape[1]
    for l in range(0, d):
        prod = 1.
        for k in range(0, l):
            prod *= student_output_beta[:, k] / (student_output_alpha[:, k] + student_output_beta[:, k])
        output[:, l] *= prod
    # torch.concat((output, 1 - torch.unsqueeze(torch.sum(output, dim=1), dim=1)), dim=1)
    return F.mse_loss(teacher_output_prob, output)

In [17]:
class TeacherModel(nn.Module):
    def __init__(self, dropout=0.5):
        super(TeacherModel, self).__init__()
        self.linear1 = nn.Linear(784, 400, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(400, 100, dtype=torch.float64)
        self.dropout = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(100, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = F.relu(self.linear2(x))
        x = self.dropout(x)
        x = self.linear3(x)
        return x

In [18]:
NUM_EPOCHS = 10
lr = 0.001

In [19]:
teacher_model = TeacherModel()

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(teacher_model.parameters(), lr=lr)

In [21]:
teacher_model.to(device)
teacher_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        target = target.to(device)
        
        output = teacher_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.321927346979167
Loss: 0.5406929891843757
Loss: 0.2381379094473123
Loss: 0.3982020766564197
Loss: 0.30027667412171044
Loss: 0.2506617631915267
Loss: 0.25247357530419334
Loss: 0.26705201346340723
Loss: 0.2682749314215985
Loss: 0.24918368343072367
Loss: 0.1270754774954089
Loss: 0.14980813776985377
Loss: 0.20738678384495265
Loss: 0.25045025578714764
Loss: 0.17218106027782698
Loss: 0.14713575216159125
Loss: 0.21087226725638808
Loss: 0.1050023911637876
Loss: 0.15179182134546823
Loss: 0.09892041876898022
Loss: 0.13389585536363366
Loss: 0.08284455345968224
Loss: 0.08929186971024118
Loss: 0.24426680015342034
Loss: 0.12078413082162433
Loss: 0.10699673931127923
Loss: 0.1339071434258485
Loss: 0.04593023289810023
Loss: 0.1692386692115836
Loss: 0.08393953972107006
Loss: 0.17554591433020306
Loss: 0.24404950196258698
Loss: 0.06333964940669734
Loss: 0.11967236043496766
Loss: 0.1140977839833482
Loss: 0.15496087640197134
Loss: 0.14016759577865012
Loss: 0.08243407957765939
Loss: 0.0603129583703319

In [22]:
evaluate(teacher_model, train_loader)

98.69

In [23]:
evaluate(teacher_model, test_loader)

97.69

In [24]:
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.linear2 = nn.Linear(50, 10, dtype=torch.float64)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [25]:
student_model = StudentModel()
student_model.to(device)

StudentModel(
  (linear1): Linear(in_features=784, out_features=50, bias=True)
  (linear2): Linear(in_features=50, out_features=10, bias=True)
)

In [26]:
criterion = nn.CrossEntropyLoss()
optimizer = Adam(student_model.parameters(), lr=lr)

In [27]:
student_model.train()
for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        target = target.to(device)
        
        output = student_model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("Loss:", loss.item())

Loss: 2.270961151718689
Loss: 0.41900789647262604
Loss: 0.2354292438152886
Loss: 0.2867501562507435
Loss: 0.22119320720610222
Loss: 0.2604503886990442
Loss: 0.22001783687898618
Loss: 0.1398574933871171
Loss: 0.20669469511616356
Loss: 0.18165706807890777
Loss: 0.22949117268237146
Loss: 0.16986984730832538
Loss: 0.11333973193393738
Loss: 0.18944355834871293
Loss: 0.13694604067818372
Loss: 0.19334442112836392
Loss: 0.13648027180562702
Loss: 0.08832691675831245
Loss: 0.16430023327569676
Loss: 0.10590618842016253
Loss: 0.1659931083403088
Loss: 0.11721715770561181
Loss: 0.07155797104486906
Loss: 0.14652881334524653
Loss: 0.08275959090987998
Loss: 0.14321553095394457
Loss: 0.10039981932920196
Loss: 0.05246813914993325
Loss: 0.1482203384941135
Loss: 0.06945801219798503
Loss: 0.1248082182149465
Loss: 0.08646373454426357
Loss: 0.04665274424654235
Loss: 0.13533362445742536
Loss: 0.054295771366164286
Loss: 0.11142811551188653
Loss: 0.07861776679348634
Loss: 0.03950754803068454
Loss: 0.133166492854

In [28]:
evaluate(student_model, train_loader)

98.212

In [29]:
evaluate(student_model, test_loader)

96.62

In [30]:
def il(x):
    return torch.where(x < 0, (1 / (1 - x)), x + 1) + EPS

In [31]:
class StudentModelGeneralizedDirichlet(nn.Module):
    def __init__(self):
        super(StudentModelGeneralizedDirichlet, self).__init__()
        self.linear1 = nn.Linear(784, 50, dtype=torch.float64)
        self.output_alpha = nn.Linear(50, 9, dtype=torch.float64)
        self.output_beta = nn.Linear(50, 9, dtype=torch.float64) 

    def forward(self, x):
        x = F.relu(self.linear1(x))
        output_alpha = il(self.output_alpha(x))
        output_beta = il(self.output_beta(x))
        return output_alpha, output_beta

In [32]:
student_model_distilled_generalized_dirichlet = StudentModelGeneralizedDirichlet()
student_model_distilled_generalized_dirichlet.to(device)
optimizer = Adam(student_model_distilled_generalized_dirichlet.parameters(), lr=0.001)

In [33]:
student_model_distilled_generalized_dirichlet.train()
teacher_model.eval()

for epoch in range(1, NUM_EPOCHS):
    for i, (data, target) in enumerate(train_loader):
        data = data.view(BATCH_SIZE_TRAIN, -1)
        data = data.to(device)
        data = data.double()
        # print(data.dtype)
        # print(target.dtype)
        target = target.to(device)
        # target_one_hot_encoded = F.one_hot(target, num_classes=10)
        teacher_output = teacher_model(data)
        teacher_output = F.softmax(teacher_output, dim=1)
        teacher_output = teacher_output[:, :-1]
        student_output_alpha, student_output_beta = student_model_distilled_generalized_dirichlet(data)
        # print(student_output_alpha)
        # print(student_output_beta)
        # print(teacher_output[0])
        # student_output.retain_grad()
        
        
        loss = distillation_loss_generalized_dirichlet(teacher_output, student_output_alpha, student_output_beta, verbose=False)
        if i == 0:
            print("Loss:", loss.item())
            print("MSE:", mean_squared_error(teacher_output, student_output_alpha, student_output_beta).item())
            # print("alpha", student_output_alpha[0,])
            # print("beta", student_output_beta[0,])
            # print("P", teacher_output[0, ])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Loss: -52.17779732390217
MSE: 0.11166792284739137
Loss: -147.89911393376394
MSE: 0.04025670339046927
Loss: -149.11625976377948
MSE: 0.024466450215527343
Loss: -149.70077063133738
MSE: 0.01530498505556994
Loss: -150.07120507661134
MSE: 0.011611725505929073
Loss: -150.3110254001964
MSE: 0.009305463769942159
Loss: -150.48851102127549
MSE: 0.008052968771008931
Loss: -150.63948959655426
MSE: 0.006918096175341358
Loss: -150.76266200454933
MSE: 0.006408594751734392


In [34]:
evaluate_generalized_dirichlet(student_model_distilled_generalized_dirichlet, train_loader)

84.768

In [35]:
evaluate_generalized_dirichlet(student_model_distilled_generalized_dirichlet, test_loader)

84.47