In [1]:
import time
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import math
import torch.nn.functional as F
from scipy.stats import norm
import scipy
import torch.optim as optim

In [2]:
class AverageMeter(object):
    '''
    computes and stores the average and current value
    '''
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1, )):
    '''
    Computes the precision for the specified values of k
    '''
    maxk = max(topk)
    batch_size = target.size(0)
    
    return

In [2]:
def build_dataset(dataset, bs):
    print("==> Preparing data..")
    
    if dataset == 'cifar100':
        mean = [x/255 for x in [129.3, 124.1, 112.4]]
        std = [x/255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, f"Unknown dataset : {dataset}"
        
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    
    if dataset == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train)
        
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=4)
    
    if dataset == 'cifar100':
        testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test)
        
    test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)
    
    return train_loader, test_loader

In [3]:
def test(net, device, data_loader, criterion):
    
    top1 = AverageMeter()
    net.eval()
    
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(data_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            outputs = outputs.float()
            prec1 = accuracy(outputs.data, targets)[0]
            top1.update(prec1.item(), inputs.size(0))
    return top1.avg

In [4]:
def conv3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion=1
    def __init__(self, in_planes, planes, stride=1, downsample=None, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = conv3(in_planes, planes, stride) # 들어가는 차원의 크기 in_planes 에서 나가는 planes
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=False) # inplace=False 원본 데이터 보존하는 대신 추가메모리 사용,
        # 만약, inplace=True라면 들어가는 인수 또한 값이 output과 동일하게 바뀌는 현상 발생
        self.conv2 = conv3(planes, planes) # 들낙하는 차원은 planes로 fix
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        # downsample이 왜 필요한가? : 입력 텐서와 출력 텐서의 차원 불일치 문제 해소
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        # out은 잔차와 더해져서 다음 out은 relu를 거친것이고, preact는 relu 거치기전(pre-relu)
        preact = out.clone() # out.clone()은 무슨 의미인가? copy()와 같은 의미인가? Pytorch에서 텐서 복사할때 clone 메서드 사용
        out = self.relu(out)
        
        if self.is_last:
            # is_last가 각 블록마다의 distillation해야할 feature를 내보내야할 시기인가?
            return out, preact
        else:
            return out

In [5]:
class ResNet(nn.Module):
    def __init__(self, depth, num_classes=10):
        super(ResNet, self).__init__()
        
        # assert 가 경고한다는건데, ResNet-20, 32, 44 이런식으로 맞추는게 필요한가?
        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        block_num = (depth - 2) // 6 # block_num??
        self.in_planes = 16
        self.conv1 = conv3(in_planes=3, out_planes=16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=False)
        self.layer1 = self._make_layer(planes=16, block_num=block_num)
        self.layer2 = self._make_layer(planes=32, block_num=block_num, stride=2)
        self.layer3 = self._make_layer(planes=64, block_num=block_num, stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)
        
        for m in self.modules(): # ResNet 모델의 모든 Conv층, BN층을 접근해서 가중치 초기화 수행
            if isinstance(m, nn.Conv2d):
                # Conv층 같은 경우, fan_out(출력 유닛에 초점) <-> fan_in(입력 유닛에 초점) 을 사용
                # 그리고 ReLU 활성화함수를 거침에 따라 nonlinearity를 'relu'로 설정
                # kaiming_normal_ == He초기화
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            
    
    def _make_layer(self, planes, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes:
            # stride가 1이 아니거나 in_planes와 planes가 같지 않다면 downsample 변동
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(BasicBlock(self.in_planes, planes, stride, downsample))
        self.in_planes = planes
        for i in range(1, block_num):
            layers.append(BasicBlock(self.in_planes, planes))
        return nn.Sequential(*layers) 
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.relu(x) ## nn.ReLU() 은 모듈 객체를 생성함에 따라 __init__에서 주로 쓰이고, F.relu는 모듈 객체를 생성하지 않으므로 forward에서 자주 쓰임
        x = self.avgpool(x)
        x = x.view(x.size(0), -1) # == x.flatten(start_dim=1), 1차원으로 평탄화, view 쓰는 경우에는 복잡한 차원 재구성이 필요할때
        x = self.fc(x)
        return x
    
    def get_bn_before_relu(self):
        
        if isinstance(self.layer1[0], BasicBlock):
            bn1 = self.layer1[-1].bn2
            bn2 = self.layer2[-1].bn2
            bn3 = self.layer3[-1].bn2
        else:
            print('ResNet unknown block error')
        
        return [bn1, bn2, bn3]
    
    def get_channel_num(self):
        
        return [16, 32, 64]
    
    def extract_feature(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        
        feat1 = self.layer1(x)
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        
        x = nn.ReLU(inplace=False)(feat3)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        
        return [feat1, feat2, feat3], out

In [6]:
def resnet20(class_num=10):
    return ResNet(20, class_num)
    
def resnet32(class_num=10):
    return ResNet(32, class_num)
    
def resnet44(class_num=10):
    return ResNet(44, class_num)

def resnet56(class_num=10):
    return ResNet(56, class_num)

# 110

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.benchmark = True

In [8]:
train_loader, test_loader = build_dataset('cifar100', 128)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [9]:
def distillation_loss(source, target, margin):
    loss = ((source - margin)**2 * ((source > margin) & (target <= margin)).float()
            + (source - target) ** 2 * ((source > target) & (target > margin) & (target <= 0)).float() + 
            (source - target) ** 2 * (target > 0).float())
    # loss function을 어떻게 이해하면 될까...
    return torch.abs(loss).sum()

In [10]:
def build_feature_connector(t_channel, s_channel):
    # Teacher와 Student 간의 Feature Distillation을 위한 connector 함수
    C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
         nn.BatchNorm2d(t_channel)
         ]

    for m in C:
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            # He가중치를 쓰는 상황이므로 아래와 같이도 쓸 수 있음
            # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    return nn.Sequential(*C)

In [11]:
def get_margin_from_BN(bn):
    margin = []
    std = bn.weight.data
    mean = bn.bias.data
    for (s, m) in zip(std, mean):
        s = abs(s.item())
        m = m.item()
        if norm.cdf(-m / s) > 0.001:
            margin.append(-s * math.exp(-(m/s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m)
        else:
            margin.append(-3 * s)
    
    return torch.FloatTensor(margin).to(std.device)

In [12]:
class Distiller(nn.Module):
    def __init__(self, t_net, s_net):
        super(Distiller, self).__init__()
        
        t_channels = t_net.get_channel_num()
        s_channels = s_net.get_channel_num()
        
        # build_feature_connector에서 리스트로 묶은 레이어가 있어서 ModuleList로 정의해줘야함.
        # zip을 쓰는 이유는 동시순회를 위해서임
        self.Connectors = nn.ModuleList([build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
        
        teacher_bns = t_net.get_bn_before_relu()
        margins = [get_margin_from_BN(bn) for bn in teacher_bns]
        for i, margin in enumerate(margins):
            self.register_buffer('margin%d' % (i + 1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
            
        self.t_net = t_net
        self.s_net = s_net
    
    def forward(self, x):
        t_feats, t_out = self.t_net.extract_feature(x)
        s_feats, s_out = self.s_net.extract_feature(x)
        feat_num = len(t_feats)
        
        loss_distill = 0
        for i in range(feat_num):
            s_feats[i] = self.Connectors[i](s_feats[i])
            loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' %(i+1))) / 2 ** (feat_num - i - 1)
        
        return s_out, loss_distill

In [16]:
train_accuracies=[]
test_accuracies=[]
class_num = 100

# Teacher / Student Network 정의 
t_net = resnet56(100).to(device)
s_net = resnet20(100).to(device)

criterion = nn.CrossEntropyLoss()

# Teacher / Student Network 끼리의 Distillation을 위한 Distiller 정의
d_net = Distiller(t_net, s_net)

optimizer = optim.Adam([{'params' : s_net.parameters()}, {'params' : d_net.Connectors.parameters()}], lr=0.05) # Momentum, weight_decay는 생략

In [None]:
def train_with_distill(d_net, optimizer, device, train_loader, criterion):
    d_net.to(device)
    d_net.s_net.to(device)
    d_net.t_net.to(device)
    d_net.train()
    d_net.s_net.train()
    d_net.t_net.train()
    top1 = AverageMeter()
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        
        inputs, targets = inputs.to(device), targets.to(device)
        
        batch_size = inputs.shape[0]
        outputs, loss_distill = d_net(inputs)
        loss_CE = criterion(outputs, targets)
        loss = loss_CE + 1e-4 * loss_distill.sum() / batch_size
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        outputs = outputs.float()
        prec1 = accuracy(outputs.data, targets)[0]
        top1.update(prec1.item(), inputs.size(0))
    return top1.avg

In [17]:
def train_teacher(t_net, optimizer, device, train_loader, criterion):
    t_net.to(device)
    t_net.train()
    
    print("training...")
    for epoch in range(150):
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            outputs = t_net(inputs)
            loss = criterion(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        print(f"Epoch {epoch+1}/{150}, loss: {running_loss / len(train_loader)}")
            
    return t_net

In [18]:
trained_t_net = train_teacher(t_net, optimizer, device, train_loader, criterion)

training...
Epoch 1/150, loss: 5.350280089756412
Epoch 2/150, loss: 5.350695907612286
Epoch 3/150, loss: 5.351594646873377
Epoch 4/150, loss: 5.350470336806744
Epoch 5/150, loss: 5.349426378069631
Epoch 6/150, loss: 5.351610732505389
Epoch 7/150, loss: 5.349905806734129
Epoch 8/150, loss: 5.352293780392698
Epoch 9/150, loss: 5.351639112243262
Epoch 10/150, loss: 5.349537050632565
Epoch 11/150, loss: 5.35118992920117
Epoch 12/150, loss: 5.350918945449088
Epoch 13/150, loss: 5.348647307861796
Epoch 14/150, loss: 5.350719018970304
Epoch 15/150, loss: 5.349038360673753
Epoch 16/150, loss: 5.351611809352475
Epoch 17/150, loss: 5.349895082166433
Epoch 18/150, loss: 5.350416955435672
Epoch 19/150, loss: 5.351166100758116
Epoch 20/150, loss: 5.350278601926916
Epoch 21/150, loss: 5.351544472872448
Epoch 22/150, loss: 5.351359635667728
Epoch 23/150, loss: 5.35009429643831
Epoch 24/150, loss: 5.3509192893572175
Epoch 25/150, loss: 5.3514827255092925
Epoch 26/150, loss: 5.351053752557701
Epoch 27/

In [32]:
torch.save(trained_t_net.state_dict(), "../models/model")

In [None]:
start_epoch=0
start = time.time()
for epoch in range(start_epoch, 150):
    if epoch in [80, 120]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1 
    
    