In [None]:
from __future__ import print_function
import os
import time
import logging
import argparse
import numpy as np
from visdom import Visdom
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import *
from metric.loss import RKdAngle, RkdDistance

# Teacher models:
# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34,
# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, 
# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, 
# PreActResNet50, PreActResNet101, PreActResNet152, 
# DenseNet121, DenseNet161, DenseNet169, DenseNet201, 
import models

# Student models:
# myNet, LeNet, FitNet

start_time = time.time()

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Distill Example')

parser.add_argument('--dataset',
                    choices=['CIFAR10',
                             'CIFAR100'
                            ],
                    default='CIFAR10')
parser.add_argument('--teachers',
                    default=['ResNet32', 'ResNet56', 'ResNet110'],
                    nargs='+')
parser.add_argument('--student',
                    choices=['ResNet8',
                             'ResNet15',
                             'ResNet20',
                             'myNet'
                            ],
                    default='ResNet8')
# kd_ratio is passed as `alpha` to Adapter.loss: weight of the KD term (CE gets 1 - alpha)
parser.add_argument('--kd_ratio', default=0.7, type=float)
parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')
parser.add_argument('--T', type=float, default=20.0, metavar='TEMPERATURE', help='Temperature for distillation')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')
parser.add_argument('--print_freq', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')

# Notebook-style run: parse this fixed argument list instead of sys.argv.
# NOTE: --n_class must be set to match --dataset (10 for CIFAR10, 100 for CIFAR100).
config = ['--dataset', 'CIFAR10', '--epochs', '200', '--T', '5.0', '--n_class', '10', '--device', 'cuda:0']
args = parser.parse_args(config)

device = args.device if torch.cuda.is_available() else 'cpu'
load_dir = './checkpoint/' + args.dataset + '/'

# teacher models: build each architecture and load its pretrained checkpoint
teacher_models = []
for te in args.teachers:
    te_model = getattr(models, te)(num_classes=args.n_class)
    # map_location lets checkpoints saved on a GPU load onto the selected
    # device, including the CPU fallback chosen above
    state_dict = torch.load(load_dir + te_model.model_name + '.pth', map_location=device)
    te_model.load_state_dict(state_dict)
    te_model.to(device)
    te_model.eval()  # teachers are only run for inference
    teacher_models.append(te_model)

# student model, trained from scratch
st_model = getattr(models, args.student)(num_classes=args.n_class)
st_model.to(device)

# logging: every run starts with a fresh log file in the checkpoint directory
logfile = load_dir + 'adapter_distill_' + st_model.model_name + '.log'
if os.path.exists(logfile):
    os.remove(logfile)


def log_out(info):
    """Append ``info`` and a newline to ``logfile``, then echo it to stdout.

    The file is opened in a ``with`` block so it is closed even if the write fails.
    """
    with open(logfile, mode='a', encoding='utf-8') as f:
        f.write(info)
        f.write('\n')
    print(info)
    
# visualizer: one Visdom environment holding a loss plot and an accuracy plot
vis = Visdom(env='distill')

# training-loss curve; the driver appends one point per epoch
loss_win = vis.line(
    X=np.array([0]),
    Y=np.array([0]),
    opts={
        'title': 'multi ada. train loss',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'loss',
        'ytickmin': 0,
        'ytickstep': 0.5,
    },
    name="loss",
)

# two curves (train / test accuracy) on a 0-100 y-axis
acc_win = vis.line(
    X=np.column_stack((0, 0)),
    Y=np.column_stack((0, 0)),
    opts={
        'title': 'multi-KD ada. ACC',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'accuracy',
        'ytickmin': 0,
        'ytickmax': 100,
        'legend': ['train_acc', 'test_acc'],
    },
    name="acc",
)


# adapter model
class Adapter:
    """Learns per-sample weights for mixing the outputs of several teachers.

    The student's feature map is max-pooled to one value per channel
    (``beta``). That value is modulated by a learnable per-teacher channel
    vector (a row of ``theta``) and projected to a scalar score by the shared
    column ``W``. The scores are softmax-normalized across teachers, and
    ``forward`` returns the resulting weighted sum of the teacher outputs.

    ``theta`` and ``W`` are plain tensors with ``requires_grad=True`` (not
    ``nn.Parameter``s), so they must be handed to an optimizer explicitly.
    """

    def __init__(self, in_models, pool_size):
        """
        Args:
            in_models: sequence of teacher models; only its length is used.
            pool_size: size of the student's pooled feature map; index 1 is
                read as the channel count and index 2 as the pooling window.
        """
        pool_ch = pool_size[1]  # 64
        pool_w = pool_size[2]   # 8
        # NOTE: reseeds the global torch RNG so the adapter init is reproducible
        torch.manual_seed(1)
        self.theta = torch.randn(len(in_models), pool_ch).to(device)  # [n_teachers, C]
        self.theta.requires_grad_(True)

        # pooling window == spatial size, so each channel collapses to one value
        self.max_feat = nn.MaxPool2d(kernel_size=(pool_w, pool_w), stride=pool_w).to(device)
        self.W = torch.randn(pool_ch, 1).to(device)  # [C, 1]
        self.W.requires_grad_(True)
        self.val = False

    def loss(self, y, labels, weighted_logits, T=10.0, alpha=0.7):
        """Distillation loss of student logits ``y`` against the mixed teacher targets.

        The KL term between ``log_softmax(y / T)`` and ``weighted_logits`` is
        scaled by ``2 * alpha * T**2``. The hard-label cross entropy is scaled
        by ``1 - alpha``. Outside eval mode, ``0.1 * (||W||^2 + ||theta||^2)``
        is added as an L2 penalty.

        ``weighted_logits`` is used as the KL target, so it should hold
        probabilities (as produced by ``forward`` on softmaxed teacher outputs).
        """
        # nn.KLDivLoss() keeps its default reduction ('mean')
        kd = nn.KLDivLoss()(F.log_softmax(y / T, dim=1), weighted_logits)
        ls = kd * (T * T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
        if not self.val:
            ls = ls + 0.1 * (torch.sum(self.W * self.W) + torch.sum(self.theta * self.theta))
        return ls

    def gradient(self, lr=0.01):
        """Apply one manual SGD step to ``W`` and clear its gradient (``theta`` is untouched)."""
        with torch.no_grad():
            self.W -= lr * self.W.grad
        # Manually zero the gradient after updating the weights
        self.W.grad.zero_()

    def eval(self):
        """Enter evaluation mode: ``loss`` stops adding the L2 penalty.

        Only the ``val`` flag changes. ``theta`` and ``W`` keep
        ``requires_grad=True``, so gradients still reach them; they stay fixed
        only as long as their optimizers are not stepped.
        """
        self.val = True

    def train(self):
        """Return to training mode so ``loss`` includes the L2 penalty again."""
        self.val = False

    # input size: [B, C, w, w] feature map, [B, n_teachers, n_classes] teacher outputs
    def forward(self, conv_map, te_logits_list):
        """Mix the teachers' outputs with per-sample, feature-dependent weights.

        Args:
            conv_map: student feature map; its spatial size is expected to equal
                the pooling window, so max pooling yields ``[B, C, 1, 1]``.
            te_logits_list: stacked teacher outputs, ``[B, n_teachers, n_classes]``.

        Returns:
            ``[B, n_classes]``: teacher outputs weighted by a softmax over teachers.
        """
        # flatten(1) rather than squeeze(): squeeze() also drops the batch dim when B == 1
        beta = self.max_feat(conv_map).flatten(1)  # [B, C]
        # one score per (teacher, sample): (beta * theta_k) @ W
        alpha = torch.stack([(beta * t).mm(self.W) for t in self.theta], dim=0)  # [K, B, 1]
        alpha = alpha.squeeze(2).transpose(0, 1)  # [B, K]
        miu = F.softmax(alpha, dim=1).unsqueeze(2)  # [B, K, 1]: weights over teachers
        return torch.sum(miu * te_logits_list, dim=1)  # [B, n_classes]

# adapter instance
# Probe the student once to learn the shape of its pooled feature map.
# The 32x32 dummy matches the training crops (RandomCrop(32)). Eval mode plus
# no_grad keep this probe from updating BatchNorm running statistics or
# building an autograd graph.
st_model.eval()
with torch.no_grad():
    _, _, _, pool_m, _ = st_model(torch.randn(1, 3, 32, 32).to(device))
st_model.train()
# create adapter instance
adapter = Adapter(teacher_models, pool_m.size())


# data
# NOTE(review): these are ImageNet mean/std, not CIFAR statistics; presumably
# kept to match how the teacher checkpoints were trained -- confirm before changing.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])
train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)
test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)

# optim: separate SGD optimizers for the adapter's W, its theta, and the student;
# momentum now follows --momentum instead of a hard-coded 0.9 (same default)
optimizer_W = optim.SGD([adapter.W], lr=args.lr, momentum=args.momentum)
optimizer_theta = optim.SGD([adapter.theta], lr=args.lr, momentum=args.momentum)
optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
# lr is multiplied by 0.1 (MultiStepLR's default gamma for the adapter) at each milestone
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, gamma=0.1, milestones=[100, 150])
lr_scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer_W, milestones=[40, 50])
lr_scheduler3 = optim.lr_scheduler.MultiStepLR(optimizer_theta, milestones=[40, 50])

# relational distillation losses (distance and angle terms) from metric.loss
dist_criterion = RkdDistance().to(device)
angle_criterion = RKdAngle().to(device)


def train_adapter(n_epochs=70, model=st_model):
    """Jointly train the student and the adapter's mixing tensors ``W`` and ``theta``.

    For every batch:
    - the teachers produce temperature-softened class probabilities;
    - the adapter mixes them using the student's pooled features;
    - the student (``optimizer_sgd``) and ``W``/``theta`` (``optimizer_W``,
      ``optimizer_theta``) are stepped on the adapter loss plus the RKD angle
      and distance losses.

    The last batch's loss is logged once per epoch.

    Args:
        n_epochs: number of passes over ``train_loader``.
        model: student network; its forward returns a 5-tuple whose 4th item
            (``pool``) feeds the adapter and whose last item is the logits.
    """
    print('Training adapter:')
    start_time = time.time()
    model.train()
    # NOTE(review): eval mode sets adapter.val, which drops the L2 penalty on
    # W/theta in Adapter.loss even though they are optimized below -- confirm
    # this is intended.
    adapter.eval()
    for te in teacher_models:
        te.eval()  # teachers only provide targets
    for ep in range(n_epochs):
        for inputs, target in train_loader:
            inputs, target = inputs.to(device), target.to(device)
            # compute outputs (student features: b1, b2, b3, pool, logits)
            _, _, _, pool, output = model(inputs)

            with torch.no_grad():
                te_scores_list = []
                for te in teacher_models:
                    _, _, _, _, t_output = te(inputs)
                    te_scores_list.append(F.softmax(t_output / args.T, dim=1))
            te_scores_Tensor = torch.stack(te_scores_list, dim=1)  # [B, n_teachers, n_classes]

            optimizer_sgd.zero_grad()
            optimizer_W.zero_grad()
            optimizer_theta.zero_grad()

            weighted_logits = adapter.forward(pool, te_scores_Tensor)

            angle_loss = angle_criterion(output, weighted_logits)
            dist_loss = dist_criterion(output, weighted_logits)
            ada_loss = adapter.loss(output, target, weighted_logits, T=args.T, alpha=args.kd_ratio)
            loss = ada_loss + angle_loss + dist_loss

            # each batch builds a fresh graph and backpropagates once, so
            # there is no need to retain it
            loss.backward()
            optimizer_sgd.step()
            optimizer_W.step()
            optimizer_theta.step()

        # PyTorch >= 1.1 order: step LR schedulers after the epoch's optimizer steps
        lr_scheduler2.step()
        lr_scheduler3.step()
        log_out('epoch[{}/{}]adapter Loss: {:.4f}'.format(ep, n_epochs, loss.item()))
    end_time = time.time()
    log_out("--- adapter training cost {:.3f} mins ---".format((end_time - start_time)/60))


# train with multi-teacher
def train(epoch, model):
    print('Training:')
    # switch to train mode
    model.train()
    adapter.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    
    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)
        
        # compute outputs
        b1, b2, b3, pool, output = model(input)
        
        te_scores_list = []
        for j,te in enumerate(teacher_models):
            te.eval()
            with torch.no_grad():
            t_b1, t_b2, t_b3, t_pool, t_output = te(input)
                t_output = F.softmax(t_output/args.T)
            te_scores_list.append(t_output)
        te_scores_Tensor = torch.stack(te_scores_list, dim=1)  # size: [128, 3, 10]
        weighted_logits = adapter.forward(pool, te_scores_Tensor)
        
        optimizer_sgd.zero_grad()
        
        angle_loss = angle_criterion(output, weighted_logits)
        dist_loss = dist_criterion(output, weighted_logits)

        weighted_logits = adapter.forward(pool, te_scores_Tensor)
        # compute gradient and do SGD step
        ada_loss = adapter.loss(output, target, weighted_logits, T=args.T, alpha=args.kd_ratio)
        loss = ada_loss + angle_loss + dist_loss

        loss.backward(retain_graph=True)
        optimizer_sgd.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        train_acc = accuracy(output.data, target.data)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(train_acc, input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            log_out('[{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
    return losses.avg, train_acc.cpu().numpy()


def test(model):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Progress is logged every ``args.print_freq`` batches, followed by the
    average top-1 precision over the whole test set.

    Returns:
        ``(mean cross-entropy loss, top-1 accuracy of the final batch,
        mean top-1 accuracy)``; both accuracies are numpy values.
    """
    print('Testing:')
    model.eval()  # evaluation behaviour for BatchNorm/Dropout layers
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    progress_fmt = ('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            _, _, _, _, logits = model(images)
            raw_loss = F.cross_entropy(logits, labels)

            logits = logits.float()
            batch_loss = raw_loss.float()

            batch_acc = accuracy(logits.data, labels.data)[0]
            n_samples = images.size(0)
            losses.update(batch_loss.item(), n_samples)
            top1.update(batch_acc, n_samples)

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                log_out(progress_fmt.format(step, len(test_loader),
                                            batch_time=batch_time,
                                            loss=losses, top1=top1))

    log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

    return losses.avg, batch_acc.cpu().numpy(), top1.avg.cpu().numpy()

# """
print('StudentNet:\n')
print(st_model)
st_model.apply(weights_init_normal)
# phase 1: fit the adapter's teacher-mixing tensors together with the student
train_adapter(n_epochs=80)

# phase 2: distill into the student; only optimizer_sgd is stepped in train()
best_acc = 0
for epoch in range(1, args.epochs + 1):
    log_out("\n===> epoch: {}/{}".format(epoch, args.epochs))
    # the lr logged here is the one actually used for this epoch
    log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))
    train_loss, train_acc = train(epoch, st_model)
    # visualize loss
    vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update="append")
    _, test_acc, top1 = test(st_model)
    vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update="append")
    if top1 > best_acc:
        best_acc = top1
    # advance the student's MultiStepLR once per completed epoch, after the
    # optimizer steps (avoids the deprecated step(epoch) form)
    lr_scheduler.step()

# release GPU memory
torch.cuda.empty_cache()
log_out("BEST ACC: {:.3f}".format(best_acc))
log_out("--- {:.3f} mins ---".format((time.time() - start_time)/60))
# """

  init.kaiming_normal(m.weight)
  init.kaiming_normal(m.weight)


Files already downloaded and verified
StudentNet:

ResNet(
  (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2



epoch[0/70]adapter Loss: 1.9912
epoch[1/70]adapter Loss: 1.7895
epoch[2/70]adapter Loss: 1.4977
epoch[3/70]adapter Loss: 1.3384
epoch[4/70]adapter Loss: 1.2228
epoch[5/70]adapter Loss: 1.3414
epoch[6/70]adapter Loss: 1.2330
epoch[7/70]adapter Loss: 0.8899
epoch[8/70]adapter Loss: 1.0103
epoch[9/70]adapter Loss: 0.8968
epoch[10/70]adapter Loss: 0.8606
epoch[11/70]adapter Loss: 1.1236
epoch[12/70]adapter Loss: 0.8837
epoch[13/70]adapter Loss: 0.7829
epoch[14/70]adapter Loss: 0.8977
epoch[15/70]adapter Loss: 0.6793
epoch[16/70]adapter Loss: 0.8539
epoch[17/70]adapter Loss: 0.9417
epoch[18/70]adapter Loss: 0.8477
epoch[19/70]adapter Loss: 0.7535
epoch[20/70]adapter Loss: 0.7869
epoch[21/70]adapter Loss: 0.9029
epoch[22/70]adapter Loss: 0.7928
epoch[23/70]adapter Loss: 0.6864
epoch[24/70]adapter Loss: 0.9132
epoch[25/70]adapter Loss: 0.9764
epoch[26/70]adapter Loss: 0.8576
epoch[27/70]adapter Loss: 0.9804
epoch[28/70]adapter Loss: 0.6362
epoch[29/70]adapter Loss: 0.8282
epoch[30/70]adapter 



[10/391]	Time 0.066 (0.067)	Data 0.019 (0.019)	Loss 0.8254 (0.7442)	Prec@1 78.906 (81.179)
[20/391]	Time 0.068 (0.067)	Data 0.019 (0.019)	Loss 0.7429 (0.7140)	Prec@1 78.125 (81.845)
[30/391]	Time 0.069 (0.067)	Data 0.019 (0.019)	Loss 0.7722 (0.7420)	Prec@1 77.344 (81.048)
[40/391]	Time 0.067 (0.068)	Data 0.018 (0.019)	Loss 0.8041 (0.7452)	Prec@1 76.562 (81.136)
[50/391]	Time 0.072 (0.068)	Data 0.019 (0.019)	Loss 0.6607 (0.7487)	Prec@1 81.250 (80.974)
[60/391]	Time 0.066 (0.068)	Data 0.019 (0.019)	Loss 0.6713 (0.7427)	Prec@1 84.375 (81.301)
[70/391]	Time 0.066 (0.068)	Data 0.019 (0.019)	Loss 0.5587 (0.7462)	Prec@1 86.719 (81.173)
[80/391]	Time 0.067 (0.068)	Data 0.019 (0.019)	Loss 0.7752 (0.7487)	Prec@1 78.906 (81.096)
[90/391]	Time 0.071 (0.068)	Data 0.018 (0.019)	Loss 0.7360 (0.7443)	Prec@1 82.812 (81.216)
[100/391]	Time 0.066 (0.068)	Data 0.018 (0.019)	Loss 0.6633 (0.7422)	Prec@1 82.812 (81.250)
[110/391]	Time 0.066 (0.068)	Data 0.019 (0.019)	Loss 0.6944 (0.7402)	Prec@1 82.812 (81.28

Test: [40/79]	Time 0.017 (0.018)	Loss 0.8105 (0.9277)	Prec@1 81.250 (77.115)
Test: [50/79]	Time 0.017 (0.017)	Loss 0.8751 (0.9163)	Prec@1 77.344 (77.206)
Test: [60/79]	Time 0.017 (0.017)	Loss 0.8806 (0.9182)	Prec@1 78.906 (77.267)
Test: [70/79]	Time 0.017 (0.017)	Loss 0.6500 (0.9187)	Prec@1 81.250 (77.113)
 * Prec@1 77.170

===> epoch: 3/200
current lr 1.00000e-01
Training:
[0/391]	Time 0.077 (0.077)	Data 0.023 (0.023)	Loss 0.8637 (0.8637)	Prec@1 78.125 (78.125)
[10/391]	Time 0.069 (0.070)	Data 0.019 (0.019)	Loss 0.6988 (0.7661)	Prec@1 78.906 (80.611)
[20/391]	Time 0.065 (0.068)	Data 0.018 (0.019)	Loss 0.7433 (0.7608)	Prec@1 81.250 (80.729)
[30/391]	Time 0.067 (0.068)	Data 0.019 (0.019)	Loss 0.8730 (0.7586)	Prec@1 78.906 (80.922)
[40/391]	Time 0.066 (0.067)	Data 0.019 (0.019)	Loss 0.8082 (0.7353)	Prec@1 80.469 (81.669)
[50/391]	Time 0.066 (0.067)	Data 0.019 (0.019)	Loss 0.7477 (0.7365)	Prec@1 83.594 (81.464)
[60/391]	Time 0.066 (0.067)	Data 0.018 (0.019)	Loss 0.7732 (0.7512)	Prec@1 78.

[380/391]	Time 0.066 (0.067)	Data 0.018 (0.019)	Loss 0.6819 (0.7465)	Prec@1 81.250 (81.312)
[390/391]	Time 0.048 (0.067)	Data 0.012 (0.019)	Loss 0.5979 (0.7468)	Prec@1 87.500 (81.314)
Testing:
Test: [0/79]	Time 0.018 (0.018)	Loss 0.8818 (0.8818)	Prec@1 75.781 (75.781)
Test: [10/79]	Time 0.018 (0.017)	Loss 1.1540 (1.1219)	Prec@1 74.219 (74.006)
Test: [20/79]	Time 0.017 (0.017)	Loss 0.9263 (1.1869)	Prec@1 80.469 (73.251)
Test: [30/79]	Time 0.017 (0.017)	Loss 0.9475 (1.1852)	Prec@1 81.250 (73.564)
Test: [40/79]	Time 0.018 (0.017)	Loss 0.9422 (1.1819)	Prec@1 71.875 (73.819)
Test: [50/79]	Time 0.018 (0.017)	Loss 1.1334 (1.1700)	Prec@1 72.656 (74.127)
Test: [60/79]	Time 0.017 (0.017)	Loss 1.2141 (1.1657)	Prec@1 70.312 (74.219)
Test: [70/79]	Time 0.017 (0.017)	Loss 0.9716 (1.1686)	Prec@1 76.562 (74.120)
 * Prec@1 74.130

===> epoch: 5/200
current lr 1.00000e-01
Training:
[0/391]	Time 0.072 (0.072)	Data 0.020 (0.020)	Loss 0.7315 (0.7315)	Prec@1 81.250 (81.250)
[10/391]	Time 0.072 (0.071)	Data 