In [223]:
import random

import torch
import pytorch_lightning as pl

In [224]:
# Seed both RNGs that later cells draw from (torch.rand and random.randint).
torch.manual_seed(42)  # same function as torch.random.manual_seed
random.seed(42)

In [225]:
# Teacher logits: uniform noise centred on 0, scaled by a random integer in [10, 500].
teacher_logit = torch.rand(3, 6).sub(0.5).mul(random.randint(10, 500))
teacher_logit

tensor([[ 128.8247,  139.8563,  -39.4749,  154.7860,  -36.9190,   34.0017],
        [ -82.0351,   98.9571,  148.5400, -123.6163,  146.4595,   31.5363],
        [ 124.4893,   22.8201,   81.2487,  -23.7907,  129.8943,   24.9058]])

In [226]:
# Student logits: drawn the same way as the teacher's, with a fresh random scale.
student_logit = torch.rand(3, 6).sub(0.5).mul(random.randint(10, 500))
student_logit

tensor([[-15.6391,   8.5391, -15.4347,  -3.9286, -13.6063,  22.2229],
        [-26.4439, -15.4438,  -9.4596, -20.1426,   3.1618, -33.0872],
        [ 30.2542, -28.4572,  25.8629,   5.5750, -10.8776,  20.7013]])

In [227]:
# One label per row, chosen against the teacher logits printed above (top-k with k = 3):
#   row 0 -> 3: the label is the teacher's top-1 class
#   row 1 -> 4: the label is not the teacher's top-1 but is inside its top-k
#   row 2 -> 1: the label is outside the teacher's top-k
# Stored as an int64 column of shape (3, 1) so it can be used as a scatter index.
labels = torch.tensor([3, 4, 1], dtype=torch.int64).unsqueeze(1)
print(labels)

tensor([[3],
        [4],
        [1]])


In [228]:
# def DistillationLoss(student_logit, teacher_logit, T, threshold):
#     new_teacher_logit = teacher_logit + torch.abs(torch.min(teacher_logit, dim=1).values.reshape(-1, 1))
#     bar = torch.sort(new_teacher_logit, descending=True).values[:, threshold-1].reshape(-1, 1).repeat(1, teacher_logit.shape[1])
#     new_teacher_logit = torch.where(bar <= new_teacher_logit, new_teacher_logit, torch.zeros(1, device=torch.device('cuda')))
#     soft_label = F.softmax(new_teacher_logit / T, dim=1)
#     soft_prediction = F.log_softmax(student_logit / T, dim=1)
#     return F.kl_div(soft_prediction, soft_label)


# def FinalLoss(teacher_logit, student_logit, labels, T, alpha, threshold):
#     return (1. - alpha) * F.cross_entropy(student_logit, labels) \
#            + (alpha * T * T) * DistillationLoss(student_logit, teacher_logit, T, threshold)

In [229]:
import torch.nn.functional as F
import matplotlib.pyplot as plt


### option 1 - (student_logit, teacher_logit, T, T1, T2, K)
# Entries inside the teacher's top-K are suppressed less (temperature T1);
# entries outside the top-K are suppressed more (temperature T2).
# Summary: use a different temperature depending on top-K membership.
def option1(student_logit, teacher_logit, T, T1, T2, K, reduction='mean'):
    """KL distillation loss whose teacher targets use two temperatures split at the top-K.

    Each teacher row is shifted by |row min| and halved. Entries that are >= the
    row's K-th largest shifted value are divided by T1, all others by T2, and the
    result is passed through softmax to form the soft label. The student
    distribution is log_softmax(student_logit / T).

    Args:
        student_logit: student logits; softmax is taken over dim 1.
        teacher_logit: teacher logits, same shape as ``student_logit``
            (assumed (batch, classes)).
        T: student temperature.
        T1: teacher temperature for top-K entries (ties with the K-th value included).
        T2: teacher temperature for the remaining entries.
        K: number of top entries per row; must satisfy 1 <= K <= teacher_logit.shape[1].
        reduction: forwarded to ``F.kl_div``. The default ``'mean'`` keeps the
            original behaviour (average over every element); ``'batchmean'``
            gives the mathematical KL divergence per sample.

    Returns:
        ``F.kl_div`` of the student log-probabilities against the soft label.

    Raises:
        ValueError: if K is outside [1, number of classes].
    """
    num_classes = teacher_logit.shape[1]
    if not 1 <= K <= num_classes:
        raise ValueError(f"K must be in [1, {num_classes}], got {K}")

    # Shift by |row min| so every entry is non-negative. The row min becomes 0 when
    # it is negative. NOTE(review): a positive row min shifts the row up by min
    # instead of to 0; `teacher_logit - min` would zero the min in both cases.
    # Confirm intent.
    shifted = teacher_logit + teacher_logit.min(dim=1, keepdim=True).values.abs()
    shifted = shifted / 2  # constant halving kept from the original; purpose not stated

    # K-th largest value per row, kept as (batch, 1) so it broadcasts against `shifted`.
    bar = shifted.topk(K, dim=1).values[:, -1:]

    # Equivalent to top / T1 + bot / T2 with zero-filled masks. It avoids the CPU
    # `torch.zeros(1)` fill tensor, which is not a 0-dim scalar and so fails for
    # CUDA inputs.
    tempered = torch.where(shifted >= bar, shifted / T1, shifted / T2)

    soft_label = F.softmax(tempered, dim=1)
    soft_prediction = F.log_softmax(student_logit / T, dim=1)
    return F.kl_div(soft_prediction, soft_label, reduction=reduction)


### option 2 - (student_logit, teacher_logit, labels, T, T1, T2, K)
# Option 1, plus: leave the ground-truth (GT) entry alone. No temperature is
# applied to it, whether or not it falls inside the top-K.
def option2(student_logit, teacher_logit, labels, T, T1, T2, K, reduction='mean'):
    """Like option1, but the ground-truth class keeps its untempered shifted logit.

    Each teacher row is shifted by |row min| and halved. The GT entry (given by
    ``labels``) goes into softmax unchanged. Every other entry is divided by T1
    if it is >= the row's K-th largest shifted value, and by T2 otherwise. The
    K-th largest value is computed with the GT entry included.

    Args:
        student_logit: student logits; softmax is taken over dim 1.
        teacher_logit: teacher logits, same shape as ``student_logit``
            (assumed (batch, classes)).
        labels: int64 class index per row, shape (batch,) or (batch, 1).
        T: student temperature.
        T1: teacher temperature for non-GT entries inside the top-K.
        T2: teacher temperature for the remaining non-GT entries.
        K: top-K size; must satisfy 1 <= K <= teacher_logit.shape[1].
        reduction: forwarded to ``F.kl_div``. The default ``'mean'`` keeps the
            original behaviour; ``'batchmean'`` gives the per-sample KL.

    Returns:
        ``F.kl_div`` of the student log-probabilities against the soft label.

    Raises:
        ValueError: if K is outside [1, number of classes].
    """
    num_classes = teacher_logit.shape[1]
    if not 1 <= K <= num_classes:
        raise ValueError(f"K must be in [1, {num_classes}], got {K}")

    # Shift by |row min| (non-negative result), then halve, exactly as in option1.
    shifted = teacher_logit + teacher_logit.min(dim=1, keepdim=True).values.abs()
    shifted = shifted / 2

    # Threshold over all entries, GT included: a GT inside the top-K still
    # occupies one of the K slots.
    bar = shifted.topk(K, dim=1).values[:, -1:]

    # Accept (batch,) or (batch, 1) labels; scatter needs a (batch, 1) index.
    gt_index = labels.reshape(-1, 1)
    is_gt = torch.zeros_like(shifted).scatter(1, gt_index, 1.0).bool()

    # Non-GT entries: T1 inside the top-K, T2 outside. The GT entry is left untouched.
    # A device-local where avoids the CPU `torch.zeros(1)` fill that fails on CUDA.
    tempered = torch.where(shifted >= bar, shifted / T1, shifted / T2)
    target_logit = torch.where(is_gt, shifted, tempered)

    soft_label = F.softmax(target_logit, dim=1)
    soft_prediction = F.log_softmax(student_logit / T, dim=1)
    return F.kl_div(soft_prediction, soft_label, reduction=reduction)



### option 3 - (student_logit, teacher_logit, T, K)
# Set every entry outside the top-K to 0 (the GT entry included).
def option3(student_logit, teacher_logit, T, K, reduction='mean'):
    """KL distillation loss against a teacher whose non-top-K entries are zeroed.

    Each teacher row is shifted by |row min| and halved. Entries below the row's
    K-th largest shifted value are set to 0. Because the shifted row is
    non-negative, 0 is no larger than any remaining entry. The result is
    softened with the same temperature T as the student.

    Args:
        student_logit: student logits; softmax is taken over dim 1.
        teacher_logit: teacher logits, same shape as ``student_logit``
            (assumed (batch, classes)).
        T: temperature for both teacher and student.
        K: top-K size; must satisfy 1 <= K <= teacher_logit.shape[1].
        reduction: forwarded to ``F.kl_div``. The default ``'mean'`` keeps the
            original behaviour; ``'batchmean'`` gives the per-sample KL.

    Returns:
        ``F.kl_div`` of the student log-probabilities against the soft label.

    Raises:
        ValueError: if K is outside [1, number of classes].
    """
    num_classes = teacher_logit.shape[1]
    if not 1 <= K <= num_classes:
        raise ValueError(f"K must be in [1, {num_classes}], got {K}")

    shifted = teacher_logit + teacher_logit.min(dim=1, keepdim=True).values.abs()
    shifted = shifted / 2  # same halving as option1/option2

    # K-th largest per row as a (batch, 1) column that broadcasts over classes.
    bar = shifted.topk(K, dim=1).values[:, -1:]
    # zeros_like keeps the fill on the input's device/dtype (CPU zeros(1) fails on CUDA).
    kept = torch.where(bar <= shifted, shifted, torch.zeros_like(shifted))

    soft_label = F.softmax(kept / T, dim=1)
    soft_prediction = F.log_softmax(student_logit / T, dim=1)
    return F.kl_div(soft_prediction, soft_label, reduction=reduction)


def _class_index_targets(student_logit, labels):
    """Flatten a (batch, 1) integer label column to (batch,) for ``F.cross_entropy``.

    With 2-D logits, ``F.cross_entropy`` only accepts 1-D class indices, while the
    labels built in this notebook are (batch, 1) for ``scatter``. Any other target
    (already 1-D, float class probabilities, or higher-rank logits) is returned
    unchanged.
    """
    if (student_logit.dim() == 2 and labels.dim() == 2
            and labels.shape[1] == 1 and not labels.is_floating_point()):
        return labels.reshape(-1)
    return labels


def FinalLoss_option1(teacher_logit, student_logit, labels, T, T1, T2, alpha, K):
    """Total loss: (1 - alpha) * CE(student, labels) + alpha * T1 * T2 * option1(...).

    ``labels`` may be (batch,) or (batch, 1) class indices.
    NOTE(review): the distillation term is scaled by T1 * T2 rather than the
    T * T used by FinalLoss_option3 for the student temperature T. Confirm
    which scaling is intended.
    """
    ce = F.cross_entropy(student_logit, _class_index_targets(student_logit, labels))
    kd = option1(student_logit, teacher_logit, T, T1, T2, K)
    return (1. - alpha) * ce + (alpha * T1 * T2) * kd


def FinalLoss_option2(teacher_logit, student_logit, labels, T, T1, T2, alpha, K):
    """Total loss: (1 - alpha) * CE(student, labels) + alpha * T1 * T2 * option2(...).

    ``labels`` may be (batch,) or (batch, 1) int64 class indices. Cross-entropy
    receives them flattened; option2 receives a (batch, 1) column for its scatter.
    NOTE(review): same T1 * T2 scaling question as FinalLoss_option1.
    """
    class_labels = _class_index_targets(student_logit, labels)
    ce = F.cross_entropy(student_logit, class_labels)
    kd = option2(student_logit, teacher_logit, class_labels.reshape(-1, 1), T, T1, T2, K)
    return (1. - alpha) * ce + (alpha * T1 * T2) * kd


def FinalLoss_option3(teacher_logit, student_logit, labels, T, alpha, K):
    """Total loss: (1 - alpha) * CE(student, labels) + alpha * T^2 * option3(...).

    ``labels`` may be (batch,) or (batch, 1) class indices.
    """
    ce = F.cross_entropy(student_logit, _class_index_targets(student_logit, labels))
    kd = option3(student_logit, teacher_logit, T, K)
    return (1. - alpha) * ce + (alpha * T * T) * kd

In [230]:
# Evaluate one distillation variant at a time; swap in a commented call to compare.
# option1(student_logit, teacher_logit, T=2, T1=10, T2=20, K=3)
# option3(student_logit, teacher_logit, T=2, K=3)
option2(student_logit, teacher_logit, labels, T=2, T1=10, T2=20, K=3)

tensor(2.3638)

In [None]:
def option1_plt(student_logit, teacher_logit, T1, T2, K):
    """Print option1-style intermediate teacher tensors and plot them per row.

    For each row of ``teacher_logit``, one figure with two bar charts is drawn.
    The left panel shows softmax(teacher_logit / T2). The right panel shows the
    two-temperature softmax: entries >= the row's K-th largest shifted value are
    divided by T1, the rest by T2. Both panels are in percent.

    Args:
        student_logit: unused; kept so the signature mirrors option1.
        teacher_logit: teacher logits, assumed 2-D (rows, classes).
        T1: temperature for entries at or above the K-th largest value.
        T2: temperature for the remaining entries; also used for the left panel.
        K: number of top entries per row that receive T1.

    NOTE(review): unlike option1, the shifted logits are not halved here, so the
    right panel is not exactly option1's soft label. Confirm which is intended.
    """
    print(f'teacher_logit:\n{teacher_logit}')
    new_teacher_logit = teacher_logit + torch.abs(torch.min(teacher_logit, dim=1).values.reshape(-1, 1))
    print(f'new_teacher_logit:\n{new_teacher_logit}')
    bar = torch.sort(new_teacher_logit, descending=True).values[:, K - 1].reshape(-1, 1).repeat(1, teacher_logit.shape[1])
    print(f'bar:\n{bar}')

    # Device/dtype-matched fill; a CPU torch.zeros(1) would fail for CUDA inputs.
    zeros = torch.zeros_like(new_teacher_logit)
    top = torch.where(bar <= new_teacher_logit, new_teacher_logit, zeros)
    bot = torch.where(bar > new_teacher_logit, new_teacher_logit, zeros)
    print(f'top:\n{top}')
    print(f'bot:\n{bot}')
    print(f'top / T1:\n{top / T1}')
    print(f'bot / T2:\n{bot / T2}')
    tempered = (top / T1) + (bot / T2)
    print(f'final:\n{tempered}')

    # detach + cpu so matplotlib gets plain host data even for CUDA / grad-tracking tensors.
    original = F.softmax(teacher_logit / T2, dim=1).detach().cpu()
    modified = F.softmax(tempered, dim=1).detach().cpu()

    # Iterate over the actual batch/class sizes instead of a hard-coded 3 x 6.
    num_rows, num_classes = original.shape
    class_ids = list(range(num_classes))
    for row in range(num_rows):
        fig, (ax_orig, ax_mod) = plt.subplots(1, 2, sharey=True, figsize=(10, 4))
        ax_orig.bar(class_ids, (original[row] * 100).tolist())
        ax_orig.set(title=f'row {row}: softmax(teacher / T2)', xlabel='class', ylabel='probability (%)')
        ax_mod.bar(class_ids, (modified[row] * 100).tolist())
        ax_mod.set(title=f'row {row}: top-{K} / T1, rest / T2', xlabel='class')
        ax_orig.set_ylim(top=100)  # shared y-axis, so this caps both panels
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across rows