In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np

In [2]:
'''
This ResNet code is from https://github.com/megvii-research/mdistiller/blob/master/mdistiller/models/cifar/resnet.py
'''

# NOTE(review): "resnet" is not a name defined in the visible cells (the
# definitions are ResNet, resnet8 and resnet20), so a star-import driven by this
# list would fail if this code were used as a module — confirm the intended exports.
__all__ = ["resnet"]

def conv3x3(in_planes, out_planes, stride=1):
    """
    Build a bias-free 3x3 convolution with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The configured nn.Conv2d module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
    
class BasicBlock(nn.Module):
    """
    Residual block built from two 3x3 convolutions plus a shortcut connection.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to produce the shortcut;
            when None the input itself is used as the shortcut.
        is_last: when true, forward() also returns the pre-activation tensor.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        # inplace=True writes the result into the input tensor, which saves memory.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """
        Run the block.

        Returns:
            The activated output, or the tuple (activated output, pre-activation
            output) when the block was created with is_last set.
        """
        # Main path: conv -> bn -> relu -> conv -> bn
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut path: project the input only when a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        preact = main
        # The functional relu is not in-place here, so `preact` keeps the
        # values from before the activation.
        activated = F.relu(main)
        return (activated, preact) if self.is_last else activated

class ResNet(nn.Module):
    """
    ResNet made of a 3x3 stem convolution, three residual stages and a linear
    classifier.

    Args:
        depth: total depth; must be of the form 6n + 2 (8, 20, 32, 44, 56, 110, ...),
            which yields n blocks per stage.
        num_filters: four channel counts, indexed as [stem, stage1, stage2, stage3].
        block_name: block type; only "basicblock" (case-insensitive) is supported.
        num_classes: number of outputs of the final fully connected layer.

    Raises:
        AssertionError: if depth is not of the form 6n + 2.
        ValueError: if block_name is not "basicblock".
    """

    def __init__(self, depth, num_filters, block_name="BasicBlock", num_classes=10):
        super(ResNet, self).__init__()

        if block_name.lower() == "basicblock":
            assert(
                depth - 2
            ) % 6 == 0, "Basic block depth should be 6n+2, 20, 32, 44, 56, 110 등"
            n = (depth - 2) // 6
            block = BasicBlock
        else:
            raise ValueError("block_name should be Basicblock")

        # self.inplanes tracks the channel count entering the next stage and is
        # updated by _make_layer as stages are built.
        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, num_filters[1], n)
        self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
        self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
        # Fixed 8x8 pooling window: the fc layer below expects exactly
        # num_filters[3] * expansion features, i.e. a 1x1 map after pooling
        # (which is what a 32x32 input produces after the two stride-2 stages).
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
        self.stage_channels = num_filters

        # Kaiming init for every convolution; normalization layers start as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """
        Build one stage of `blocks` residual blocks with `planes` output channels.

        Only the first block uses `stride`; it also receives a 1x1 conv + batch
        norm projection shortcut whenever the stride or the channel count changes.
        The final block of the stage is created with is_last=True so the stage
        returns (output, pre-activation output).
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [
            block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))
        ]
        # Every following block in this stage sees the stage's output width.
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, is_last=(i == blocks - 1)))

        return nn.Sequential(*layers)

    def get_feat_modules(self):
        """Return a new nn.ModuleList holding the stem modules and the three stages."""
        feat_m = nn.ModuleList([])
        feat_m.append(self.conv1)
        feat_m.append(self.bn1)
        feat_m.append(self.relu)
        feat_m.append(self.layer1)
        feat_m.append(self.layer2)
        feat_m.append(self.layer3)
        return feat_m

    def forward(self, x, return_feats=False):
        """
        Compute the classification logits.

        Args:
            x: input image batch with 3 channels.
            return_feats: when True, also return the intermediate features.
                Defaults to False, which returns the logits only.

        Returns:
            The logits of shape (batch, num_classes); or, when return_feats is
            True, the tuple (logits, feats) where feats is a dict with keys
            "feats" (stem and stage outputs), "preact_feats" (stem output and
            the stages' pre-activation outputs) and "pooled_feat" (the flattened
            pooled features fed to the classifier).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        f0 = x

        # Each stage ends with an is_last block, so it yields (output, pre-activation).
        x, f1_pre = self.layer1(x)
        f1 = x
        x, f2_pre = self.layer2(x)
        f2 = x
        x, f3_pre = self.layer3(x)
        f3 = x

        x = self.avgpool(x)
        avg = x.reshape(x.size(0), -1)
        out = self.fc(avg)

        if not return_feats:
            return out

        feats = {
            "feats": [f0, f1, f2, f3],
            "preact_feats": [f0, f1_pre, f2_pre, f3_pre],
            "pooled_feat": avg,
        }
        return out, feats

In [3]:
def resnet8(**kwargs):
    """
    Create a depth-8 ResNet (one basic block per stage).

    Keyword arguments are forwarded to ResNet. Returns the tuple (model, "resnet8").
    """
    model = ResNet(8, [16, 16, 32, 64], "basicblock", **kwargs)
    return model, "resnet8"

def resnet20(**kwargs):
    """
    Create a depth-20 ResNet (three basic blocks per stage).

    Keyword arguments are forwarded to ResNet. Returns the tuple (model, "resnet20").
    """
    model = ResNet(20, [16, 16, 32, 64], "basicblock", **kwargs)
    return model, "resnet20"


In [4]:
def load_dataset(bz=64, root='./../../data', num_workers=0):
    """
    Build the CIFAR-100 training and test data loaders.

    Args:
        bz: batch size used by both loaders.
        root: directory where the dataset is stored (downloaded there if missing).
        num_workers: number of worker processes used by both loaders.

    Returns:
        The tuple (train_loader, test_loader). The training loader reshuffles
        every epoch; the test loader iterates in a fixed order so evaluation
        runs are repeatable and do not depend on the sampler's random state.
    """
    # Same normalization for both splits.
    # presumably the CIFAR-100 per-channel mean / std — TODO confirm
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    # Training-only augmentation: padded random crop and horizontal flip.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            normalize,
        ]
    )
    trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=test_transform)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bz, shuffle=True, num_workers=num_workers)
    # Shuffling the evaluation split serves no purpose, so it is left in dataset order.
    test_loader = torch.utils.data.DataLoader(testset, batch_size=bz, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

In [8]:
# Experiment setup: seed torch, pick the device, then build data loaders, models and optimizer.
# The statement order matters: everything below the seeding draws from torch's global RNG.
# NOTE(review): only torch's generators are seeded in this cell; NumPy's and Python's
# `random` generators are not.
seed = 2021
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"device : {device}")
train_loader, test_loader = load_dataset()
criterion = nn.CrossEntropyLoss()
# Each factory returns a (model, name) tuple; num_classes=100 sizes the final fc layer.
student, student_name = resnet8(num_classes=100)
# NOTE(review): the teacher is freshly initialized here; no checkpoint is loaded for it
# in the visible cells — confirm whether pretrained weights are meant to be loaded.
teacher, teacher_name = resnet20(num_classes=100)

student.to(device)
teacher.to(device)

# Only the student's parameters are handed to the optimizer; the teacher is never updated by it.
optimizer = optim.SGD(student.parameters(), lr=0.05, momentum=0.9, weight_decay=0.0001)

device : cuda
Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Pull a single mini-batch through both networks so the distillation losses
# can be tried out on real logits.
student.train()
# The teacher only provides targets: eval mode keeps its BatchNorm running
# statistics from being updated by this forward pass.
teacher.eval()
for batch_idx, (inputs, labels) in enumerate(train_loader):
    inputs, labels = inputs.to(device), labels.to(device)
    batch_size = inputs.size(0)
    optimizer.zero_grad()

    output_stu = student(inputs)
    # No autograd graph is needed for the teacher; its logits are constants
    # from the student's point of view.
    with torch.no_grad():
        output_tea = teacher(inputs)
    break  # one batch is enough for this experiment

In [72]:
# instance-wise Distillation loss

def instance_distill_loss(student_logits, teacher_logits):
    """
    Mean squared error between the L2-normalized student and teacher logits.

    Each tensor is normalized along dim=1 (p=2 selects the L2 norm; p=1 would
    select the L1 norm) before the element-wise MSE is taken.

    Returns:
        A scalar tensor holding the loss.
    """
    unit_student, unit_teacher = (
        F.normalize(logits, p=2, dim=1)
        for logits in (student_logits, teacher_logits)
    )
    return F.mse_loss(unit_student, unit_teacher)
    
# class-wise distillation loss
def class_distill_loss(t_student_logits, t_teacher_logits):
    """
    Mean squared error between the L2-normalized inputs, normalized along dim=1.

    The computation itself does not transpose anything: the caller is expected
    to pass the already transposed logits so that dim=1 is the axis to normalize.

    Returns:
        A scalar tensor holding the loss.
    """
    unit_student = F.normalize(t_student_logits, p=2, dim=1)
    unit_teacher = F.normalize(t_teacher_logits, p=2, dim=1)
    return F.mse_loss(unit_student, unit_teacher)

# class correlation loss 
def class_correlation_loss(student_logits, teacher_logits):
    """
    Squared Frobenius distance between the student's and the teacher's
    correlation matrices, scaled by 1 / C**2.

    For logits of shape (N, C), every class column is centered by its mean over
    the N rows. The matrix is the sum over the C columns of the outer product of
    each centered column with itself, divided by (C - 1); this is computed as a
    single matrix product and has shape (N, N).

    Args:
        student_logits: 2-D tensor of shape (N, C).
        teacher_logits: 2-D tensor with the same shape as student_logits.

    Returns:
        A scalar tensor on the same device as the inputs.

    Raises:
        ValueError: if the inputs are not 2-D or their shapes differ.

    Note:
        C must be at least 2 because of the (C - 1) divisor.
    """
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            "student_logits and teacher_logits must have the same shape, got "
            f"{tuple(student_logits.shape)} and {tuple(teacher_logits.shape)}"
        )
    # Unpacking enforces 2-D input; only the class count is needed below.
    _, C = student_logits.shape

    def _correlation(logits):
        # Center each class column by its batch mean, then sum the per-class
        # outer products in one matmul: sum_j d_j d_j^T == D @ D^T.
        centered = logits - logits.mean(dim=0, keepdim=True)
        return centered @ centered.t() / (C - 1)

    diff = _correlation(student_logits) - _correlation(teacher_logits)
    # Squared Frobenius norm is the sum of squared entries (no sqrt needed).
    return diff.pow(2).sum() / (C ** 2)
    

In [73]:
# Weights of the instance-wise, class-wise and class-correlation terms.
lamb, mu, nu = 0.1, 0.5, 0.4

# The class-wise loss normalizes along dim=1, so it is fed the transposed logits.
t_output_stu = output_stu.t()
t_output_tea = output_tea.t()

ins = instance_distill_loss(output_stu, output_tea)
cla = class_distill_loss(t_output_stu, t_output_tea)
clcor = class_correlation_loss(output_stu, output_tea)

# Weighted sum of the three distillation terms.
total_loss = lamb * ins + mu * cla + nu * clcor
