In [None]:
#| default_exp distill.losses

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

In [None]:
#| export
def SoftTarget(pred,          # Student logits; softmax is taken over dim 1
               teacher_pred,  # Teacher logits, laid out like `pred`
               T=5,           # Softmax temperature; larger values flatten both distributions
               **kwargs       # Unused here (presumably so every loss can share one call signature)
):
    "Hinton-style distillation: batch-mean KL(teacher || student) on temperature-softened distributions, rescaled by T*T."
    scale = T * T  # T*T rescaling, as in Hinton et al.
    log_q_student = F.log_softmax(pred / T, dim=1)  # kl_div wants the input side in log-space
    q_teacher = F.softmax(teacher_pred / T, dim=1)  # ...and the target side as probabilities
    kl = F.kl_div(log_q_student, q_teacher, reduction='batchmean')
    return kl * scale

In [None]:
#| export
def Logits(pred,          # Student logits
           teacher_pred,  # Teacher logits, compared element-wise with `pred`
           **kwargs       # Unused
):
    "Logit regression: mean squared error between the raw student and teacher outputs."
    squared_error = F.mse_loss(input=pred, target=teacher_pred)
    return squared_error

In [None]:
#| export
def Mutual(pred,          # Student logits; softmax is taken over dim 1
           teacher_pred,  # Teacher logits, laid out like `pred`
           **kwargs       # Unused
):
    "Batch-mean KL(teacher || student) between the untempered softmax distributions."
    return F.kl_div(
        F.log_softmax(pred, dim=1),      # student side must be log-probabilities
        F.softmax(teacher_pred, dim=1),  # target side stays as plain probabilities
        reduction='batchmean',
    )

In [None]:
#| export
def Attention(fm_s,      # Student feature maps {name: tensor}; paired with `fm_t` by iteration order
              fm_t,      # Teacher feature maps {name: tensor}
              p=2,       # Power applied to activation magnitudes when building attention maps
              **kwargs   # Unused
):
    "Attention transfer loss (Zagoruyko & Komodakis)"
    def _attention_map(fm):
        # Channel-wise (dim 1) mean of |A|^p, flattened per sample and L2-normalised.
        # abs() follows the |A|^p formulation: for odd `p`, negative activations would
        # otherwise cancel positive ones; for even `p` it changes nothing.
        # Flattening (instead of normalising over fixed dims (1, 2)) also accepts
        # feature maps of rank 3, e.g. (B, C, L), not just (B, C, H, W).
        att = fm.abs().pow(p).mean(1)
        return F.normalize(att.flatten(1), dim=1)

    total_loss = 0.0
    # zip pairs maps by position (names are not matched) and stops at the shorter dict.
    for name_st, name_t in zip(fm_s, fm_t):
        total_loss += F.mse_loss(_attention_map(fm_s[name_st]), _attention_map(fm_t[name_t]))
    return total_loss

In [None]:
#| export
def ActivationBoundaries(fm_s,  # Student feature maps {name: tensor}; paired with `fm_t` by iteration order
                         fm_t,  # Teacher feature maps {name: tensor}
                         m=2,   # Margin the student must clear on the teacher's side of zero
                         **kwargs
):
    "Boundary-based knowledge distillation (Heo et al.)"
    def _boundary_penalty(s, t):
        # Teacher inactive (t <= 0): penalise student values above -m with (s + m)^2.
        off_mask = ((s > -m) & (t <= 0)).float()
        # Teacher active (t > 0): penalise student values at or below +m with (s - m)^2.
        on_mask = ((s <= m) & (t > 0)).float()
        return ((s + m).pow(2) * off_mask + (s - m).pow(2) * on_mask).mean()

    total_loss = 0.0
    for key_s, key_t in zip(fm_s, fm_t):
        total_loss += _boundary_penalty(fm_s[key_s], fm_t[key_t])
    return total_loss

In [None]:
#| export
def FitNet(fm_s,  # Student feature maps {name: tensor}
           fm_t,  # Teacher feature maps {name: tensor}, each compared element-wise with its student counterpart
           **kwargs
):
    "FitNets hint loss (Romero et al.): sum over layers of the MSE between paired feature maps."
    paired_keys = zip(fm_s, fm_t)  # pairs maps by position, not by name
    return sum((F.mse_loss(fm_s[key_s], fm_t[key_t]) for key_s, key_t in paired_keys), 0.0)

In [None]:
#| export
def Similarity(fm_s,       # Student feature maps {name: tensor}; paired with `fm_t` by iteration order
               fm_t,       # Teacher feature maps {name: tensor}; batch size must match the student's
               pred=None,  # Student predictions (unused, kept for API consistency; optional)
               p=2,        # Norm order used to row-normalise the batch similarity matrices
               **kwargs
):
    "Similarity-preserving knowledge distillation (Tung & Mori)"
    def _similarity(fm):
        # (B, B) Gram matrix of per-sample flattened features, then row-normalised.
        # reshape (not view) so non-contiguous maps, e.g. channels_last, are accepted.
        flat = fm.reshape(fm.size(0), -1)
        return F.normalize(flat @ flat.t(), p=p, dim=1)

    total_loss = 0.0
    for name_st, name_t in zip(fm_s, fm_t):
        total_loss += F.mse_loss(_similarity(fm_s[name_st]), _similarity(fm_t[name_t]))
    return total_loss