In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import norm
import scipy
import math

> ### Distillation Loss

In [None]:
def distillation_loss(source, target, margin):
    """Margin-aware partial L2 loss between student and teacher features.

    This appears to follow the "partial L2" distance of Heo et al.,
    "A Comprehensive Overhaul of Feature Distillation" (ICCV 2019).
    Element-wise, with s = source, t = target and m = margin, the loss is
    the sum of three mutually gated terms:

    * t <= m (teacher response at/below the margin): (s - m)^2, counted only
      where s > m. A student that is already below the margin is not pushed.
    * m < t <= 0 (teacher negative but above the margin): (s - t)^2,
      counted only where s > t.
    * t > 0 (teacher positive): (s - t)^2, always counted.

    When m <= 0 this is the same as (s - max(t, m))^2 masked to the positions
    where s > max(t, m) or t > 0. Intuitively, negative teacher responses are
    clamped up to the margin, and the student is only penalised for
    overshooting them. Positive responses are matched exactly.
    NOTE(review): the m <= 0 equivalence relies on margins being
    non-positive; for m > 0 the first and third terms can overlap.

    Args:
        source: student features. Presumably already mapped to the teacher's
            channel count by a connector; TODO confirm at call site.
        target: teacher features, same shape as ``source``.
        margin: tensor broadcastable against ``source``/``target``
            (e.g. per-channel, shaped (1, C, 1, 1)).

    Returns:
        0-dim tensor: the loss summed (not averaged) over all elements.
    """
    sq_err_to_margin = (source - margin) ** 2
    sq_err_to_target = (source - target) ** 2

    # Gate for case 1: teacher at/below margin and student above margin.
    clamp_mask = ((source > margin) & (target <= margin)).float()
    # Gate for case 2: teacher in (margin, 0] and student above teacher.
    overshoot_mask = ((source > target) & (target > margin) & (target <= 0)).float()
    # Gate for case 3: teacher strictly positive.
    positive_mask = (target > 0).float()

    loss = (sq_err_to_margin * clamp_mask
            + sq_err_to_target * overshoot_mask
            + sq_err_to_target * positive_mask)
    # abs() is a no-op here (every term is non-negative); kept for parity.
    return loss.abs().sum()

> ### Teacher / Student 간 Distill을 위한 Connector

In [None]:
def build_feature_connector(t_channel, s_channel):
    """Build the adapter that projects student features to the teacher's width.

    The adapter is a bias-free 1x1 convolution from ``s_channel`` to
    ``t_channel`` channels, followed by ``BatchNorm2d(t_channel)``. Its output
    therefore has the teacher's channel count and can be compared
    element-wise with the teacher feature map.

    Args:
        t_channel: channel count of the teacher feature map (output width).
        s_channel: channel count of the student feature map (input width).

    Returns:
        ``nn.Sequential(conv, bn)`` with freshly initialised weights.
    """
    conv = nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False)
    bn = nn.BatchNorm2d(t_channel)

    # He-normal initialisation using fan-out (kh * kw * out_channels).
    # Same intent as
    #   nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
    # but the std is computed exactly as sqrt(2 / fan_out).
    kernel_h, kernel_w = conv.kernel_size
    fan_out = kernel_h * kernel_w * conv.out_channels
    with torch.no_grad():
        conv.weight.normal_(0, math.sqrt(2. / fan_out))
        # Identity affine transform for the BN layer.
        bn.weight.fill_(1)
        bn.bias.zero_()

    return nn.Sequential(conv, bn)

In [None]:
def get_margin_from_BN(bn):
    """Compute the per-channel distillation margin from a teacher BatchNorm layer.

    Each channel's pre-ReLU response is modelled as a Gaussian. ``|weight|``
    is taken as its standard deviation and ``bias`` as its mean.
    NOTE(review): this presumes the normalised input is roughly N(0, 1);
    that assumption is not checked here.

    The margin is the mean of that Gaussian restricted to negative values:
        E[x | x < 0] = mean - std * pdf(mean / std) / cdf(-mean / std)
    Channels whose negative probability is <= 0.001 fall back to -3 * std.
    Zero-std channels (constant response ``mean``) use the limiting value:
    ``mean`` if it is negative, otherwise the -3 * std fallback (0).

    Args:
        bn: a BatchNorm layer with ``weight`` and ``bias`` tensors
            (one entry per channel).

    Returns:
        1-D float32 tensor of per-channel margins (non-positive by
        construction), on the same device as ``bn.weight``.
    """
    margins = []
    for gamma, beta in zip(bn.weight.data, bn.bias.data):
        s = abs(gamma.item())  # std of the response; the sign of gamma is irrelevant
        m = beta.item()        # mean of the response
        if s == 0:
            # Degenerate channel: avoid division by zero. The response is the
            # constant m, so its negative part is m itself when m < 0.
            margins.append(m if m < 0 else -3 * s)
            continue
        neg_prob = norm.cdf(-m / s)  # P(x < 0)
        if neg_prob > 0.001:
            margins.append(-s * math.exp(-(m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / neg_prob + m)
        else:
            # Negative responses are too rare for a reliable truncated mean;
            # use a fixed -3 sigma margin instead.
            margins.append(-3 * s)

    return torch.tensor(margins, dtype=torch.float32, device=bn.weight.device)

In [None]:
class Distiller(nn.Module):
    def __init__(self, t_net, s_net):
        super(Distiller, self).__init__()
        
        t_channels = t_net.get_channel_num()
        s_channels = s_net.get_channel_num()
        
        self.Connectors = nn.ModuleList([build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
        
        teacher_bns = t_net.get_bn_before_relu()
        margins = [get_margin_from_BN(bn) for bn in teacher_bns]
        for i, margin in enumerate(margins):
            self.register_buffer('margin%d' % (i + 1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
            
        self.t_net = t_net
        self.s_net = s_net
    
    def forward(self, x):
        t_feats, t_out = self.t_net.extract_feature(x)
        s_feats, s_out = self.s_net.extract_feature(x)
        feat_num = len(t_feats)
        
        loss_distill = 0
        for i in range(feat_num):
            s_feats[i] = self.Connectors[i](s_feats[i])