In [None]:
### from __future__ import print_function
import os
import time
import logging
import argparse
import random
import numpy as np
from visdom import Visdom
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import *

# Teacher models:
# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34,
# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, 
# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, 
# PreActResNet50, PreActResNet101, PreActResNet152, 
# DenseNet121, DenseNet161, DenseNet169, DenseNet201, 
import models

# Student models:
# myNet, LeNet, FitNet

start_time = time.time()
# os.makedirs('./checkpoint', exist_ok=True)

# Training settings.
# Help strings that quote a default use argparse's %(default)s placeholder so
# the documented value cannot drift away from the real one.
parser = argparse.ArgumentParser(description='PyTorch multi_teacher_avg_distill')

parser.add_argument('--dataset',
                    choices=['CIFAR10',
                             'CIFAR100'
                            ],
                    default='CIFAR10')
parser.add_argument('--teachers',
                    choices=['ResNet32',
                             'ResNet50',
                             'ResNet56',
                             'ResNet110'
                            ],
                    default=['ResNet32', 'ResNet56', 'ResNet110'],
                    nargs='+')
parser.add_argument('--student',
                    choices=['ResNet20',
                             'myNet'
                            ],
                    default='ResNet20')

parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')
parser.add_argument('--T', type=float, default=20.0, metavar='Temperature', help='Temperature for distillation')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: %(default)s)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: %(default)s)')
parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')
parser.add_argument('--print_freq', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')

# The argument list is given explicitly instead of being read from sys.argv.
config = ['--dataset', 'CIFAR100', '--n_class', '100', '--epochs', '200', '--T', '5.0', '--device', 'cuda:1']
args = parser.parse_args(config)

# Use the requested device only when CUDA is available; otherwise run on CPU.
device = args.device if torch.cuda.is_available() else 'cpu'
load_dir = './checkpoint/' + args.dataset + '/'

# Teacher models: looked up by name in `models`, restored from
# '<load_dir><model_name>.pth' and switched to eval mode.
teacher_models = []
for te in args.teachers:
    te_model = getattr(models, te)(num_classes=args.n_class)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on another device (e.g. a GPU, while we fell back to CPU) still loads.
    te_model.load_state_dict(
        torch.load(load_dir + te_model.model_name + '.pth', map_location=device))
    te_model.to(device)
    te_model.eval()  # eval mode
    teacher_models.append(te_model)

# Student model: also looked up by name; no checkpoint is loaded for it here.
st_model = getattr(models, args.student)(num_classes=args.n_class)
st_model.to(device)

# logging: the log lives next to the checkpoints, and a log left over from a
# previous run is deleted so each run starts with a fresh file
logfile = ''.join([load_dir, 'avg_distill_', st_model.model_name, '.log'])
if os.path.exists(logfile):
    os.remove(logfile)
def log_out(info):
    """Append ``info`` as one line to the log file and echo it to stdout.

    Args:
        info: Text to record. A trailing newline is added in the file only.
    """
    # The context manager closes the handle even if a write raises.
    with open(logfile, mode='a') as f:
        f.write(info)
        f.write('\n')
    print(info)
    
# visualizer: one window for the training loss and one for the train/test
# accuracy curves; both are seeded with a single point at the origin and are
# appended to once per epoch.
vis = Visdom(env='distill')

loss_win = vis.line(
    X=np.array([0]),
    Y=np.array([0]),
    name="loss",
    opts={
        'title': 'multi avg. Loss',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'loss',
        'ytickmin': 0,
    },
)

# Two columns -> two traces (train_acc, test_acc) in the same window.
acc_win = vis.line(
    X=np.column_stack((0, 0)),
    Y=np.column_stack((0, 0)),
    name="acc",
    opts={
        'title': 'multi-KD avg. Acc',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'accuracy',
        'ytickmin': 0,
        'ytickmax': 100,
        'legend': ['train_acc', 'test_acc'],
    },
)


# data: both splits share the same tensor conversion + normalization; only the
# training split is augmented (random horizontal flip, 32x32 crop with padding 4).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# The dataset class is selected by name from torchvision.datasets.
dataset_cls = getattr(datasets, args.dataset)
train_set = dataset_cls(root='../data', train=True, download=True, transform=train_transform)
test_set = dataset_cls(root='../data', train=False, download=False, transform=test_transform)

# Shuffle only the training data; evaluation order stays fixed.
train_loader = DataLoader(train_set, shuffle=True, batch_size=args.batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=args.test_batch_size)

# Optimizers for the student.
# NOTE(review): train() only steps `optimizer_sgd`; the Adam optimizer is kept
# because the name may still be referenced elsewhere.
# optimizer = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = optim.Adam(st_model.parameters(), lr=args.lr)
# Honour the --momentum flag instead of a hard-coded 0.9. The flag defaults to
# 0.9, so runs that do not pass it behave exactly as before.
optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
# Decay the SGD learning rate at the 100 and 150 milestones (default decay factor).
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, milestones=[100, 150])

# avg distill
def distillation_loss(y, labels, logits, T, alpha=0.7):
    """Knowledge-distillation loss: softened KL term plus hard-label cross-entropy.

    Args:
        y: Student logits with the class scores along dim 1.
        labels: Ground-truth class indices for the cross-entropy term.
        logits: Teacher target distribution. Despite the name it must hold
            probabilities, since it is passed as the target of ``nn.KLDivLoss``.
        T: Temperature the student logits are divided by before the
            log-softmax.
        alpha: Weight of the distillation term; the cross-entropy term is
            weighted ``1 - alpha``.

    Returns:
        Scalar loss tensor:
        ``KL * (T*T * 2.0 * alpha) + CE * (1 - alpha)``.
    """
    # dim=1 is what the deprecated implicit-dim call resolved to for 2-D
    # logits; naming it removes the deprecation warning and the ambiguity.
    soft_student = F.log_softmax(y / T, dim=1)
    # KLDivLoss keeps its default reduction on purpose so the loss scale (and
    # with it the tuned T / alpha / lr) is unchanged.
    kd_term = nn.KLDivLoss()(soft_student, logits) * (T * T * 2.0 * alpha)
    ce_term = F.cross_entropy(y, labels) * (1. - alpha)
    return kd_term + ce_term

# triplet loss
# Margin-based triplet criterion (margin 0.2, p=2 norm) moved to `device`.
# train() calls it with (anchor, positive, negative) student triplets to get
# the relation loss.
triplet_loss = nn.TripletMarginLoss(margin=0.2, p=2).to(device)

# get max information-entropy scores
# input: Tensor[3, 128, 10]
def maxInfo_logits(te_scores_Tensor):
    """Select, for every sample, the scores of the teacher with the highest entropy.

    Args:
        te_scores_Tensor: Teacher probabilities indexed as
            [teacher, sample, class], e.g. [3, 128, 10].

    Returns:
        Tensor of shape [sample, class] on the same device and with the same
        dtype as the input. Row ``i`` is copied from the teacher whose
        distribution for sample ``i`` has the largest base-2 entropy.
    """
    # Treat 0 * log2(0) as 0. Evaluating it directly gives NaN, which would
    # poison the entropy of any teacher that assigns a class exactly zero.
    safe_probs = torch.where(te_scores_Tensor > 0,
                             te_scores_Tensor,
                             torch.ones_like(te_scores_Tensor))
    plogp = te_scores_Tensor * torch.log2(safe_probs)
    ents = -plogp.sum(dim=2)  # [teacher, sample]

    # Index of the max-entropy teacher for each sample (max of each column).
    max_ent_index = ents.argmax(dim=0)
    sample_index = torch.arange(te_scores_Tensor.size(1), device=te_scores_Tensor.device)
    # One advanced-indexing gather replaces the per-sample Python loop.
    return te_scores_Tensor[max_ent_index, sample_index]
    
# avg logits
# input: Tensor[128, 3, 10] (train() stacks the per-teacher scores on dim 1)
def avg_logits(te_scores_Tensor):
    """Average the stacked teacher scores over dim 1.

    Args:
        te_scores_Tensor: Scores with the axis to average at dim 1.

    Returns:
        The input reduced with a mean over dim 1 (that dimension is dropped).
    """
    return te_scores_Tensor.mean(dim=1)
    
# random logits
def random_logits(te_scores_Tensor):
    """Return one uniformly chosen slice of the input along dim 0.

    Args:
        te_scores_Tensor: Stacked scores; dim 0 is the axis sampled from
            (presumably the teacher axis -- confirm against the caller).

    Returns:
        Tensor of shape [1, ...]: the chosen slice with dim 0 kept, which is
        the shape this function has always returned.
    """
    # randint's upper bound is exclusive: the former (0, 1) range could only
    # ever produce index 0, so nothing was random.
    pick = int(np.random.randint(0, te_scores_Tensor.size(0)))
    # Indexing with a one-element list keeps the leading dimension.
    return te_scores_Tensor[[pick]]

# input: t1, t2 - the two members of a triplet pair
def triplet_distance(t1, t2):
    """Return the squared Euclidean distance between ``t1`` and ``t2``.

    The element-wise difference is squared and summed over every dimension,
    so the result is a single-element tensor.
    """
    return torch.sum(torch.pow(t1 - t2, 2))
    
# get triplets
def random_triplets(st_maps, te_maps):
    """Sample random student triplets whose positive/negative roles follow the teacher.

    ``B = st_maps.size(0)`` triplets are drawn. Each draw picks three distinct
    sample indices; the first is the anchor. Of the other two, the one closer
    to the anchor in the teacher feature space (squared Euclidean distance)
    becomes the positive and the other one the negative.

    Args:
        st_maps: Student feature maps with the batch on dim 0.
        te_maps: Teacher feature maps for the same batch (same dim-0 size).

    Returns:
        Tensor of shape [3, B, ...] built from ``st_maps``: index 0 holds the
        anchors, index 1 the positives and index 2 the negatives.

    Raises:
        ValueError: If the batch holds fewer than three samples.
    """
    triplet_set_size = st_maps.size(0)
    if triplet_set_size < 3:
        raise ValueError('need at least 3 samples to build a triplet, got {}'.format(triplet_set_size))

    batch_list = list(range(triplet_set_size))
    sampled = [random.sample(batch_list, 3) for _ in range(triplet_set_size)]
    triplet_index = torch.tensor(sampled, dtype=torch.long, device=st_maps.device)
    anchor_idx, first_idx, second_idx = triplet_index.unbind(dim=1)

    # The teacher distances only decide the ordering, so no gradient is needed.
    # Scoring all triplets at once also avoids one device sync per draw.
    with torch.no_grad():
        te_flat = te_maps.reshape(triplet_set_size, -1)
        te_index = triplet_index.to(te_flat.device)
        te_anchor = te_flat[te_index[:, 0]]
        distance_01 = (te_anchor - te_flat[te_index[:, 1]]).pow(2).sum(dim=1)
        distance_02 = (te_anchor - te_flat[te_index[:, 2]]).pow(2).sum(dim=1)
        swap = (distance_01 > distance_02).to(st_maps.device)

    # Swap positive and negative by choosing indices, never by assigning rows
    # in place: a tuple swap of two tensor rows overwrites the first row
    # before it is read, which left positive == negative.
    positive_idx = torch.where(swap, second_idx, first_idx)
    negative_idx = torch.where(swap, first_idx, second_idx)

    return torch.stack(
        [st_maps[anchor_idx], st_maps[positive_idx], st_maps[negative_idx]], dim=0)
    
# get the smallest conflicts index
def smallest_conflict_teacher(st_maps, te_maps_list):
    """Return the index of the teacher with the lowest triplet conflict rate.

    ``B = st_maps.size(0)`` random triplets (anchor, first, second) are drawn
    once and shared by all teachers. For a teacher, a triplet counts as a
    conflict when the anchor is farther (squared Euclidean distance) from the
    first member than from the second. The conflict rate ``c`` is folded to
    ``min(c, 1 - c)`` and the teacher with the smallest folded rate wins; on a
    tie the earliest teacher is kept.

    Args:
        st_maps: Student feature maps. Only the batch size is read from them.
            NOTE(review): the student features never enter the score --
            confirm that this is the intended conflict measure.
        te_maps_list: Teacher feature maps, one tensor per teacher, each with
            the batch on dim 0.

    Returns:
        Integer index into ``te_maps_list`` (0 when the list is empty).

    Raises:
        ValueError: If the batch holds fewer than three samples.
    """
    triplet_set_size = st_maps.size(0)
    if triplet_set_size < 3:
        raise ValueError('need at least 3 samples to build a triplet, got {}'.format(triplet_set_size))

    # Draw a fresh triplet per slot. Scoring one triplet B times made every
    # rate exactly 0 or 1, which then always folded to 0.
    batch_list = list(range(triplet_set_size))
    sampled = [random.sample(batch_list, 3) for _ in range(triplet_set_size)]

    index = 0
    min_conflict = 1
    for idx, te_maps in enumerate(te_maps_list):
        with torch.no_grad():
            triplet_index = torch.tensor(sampled, dtype=torch.long, device=te_maps.device)
            te_flat = te_maps.reshape(triplet_set_size, -1)
            te_anchor = te_flat[triplet_index[:, 0]]
            distance_01 = (te_anchor - te_flat[triplet_index[:, 1]]).pow(2).sum(dim=1)
            distance_02 = (te_anchor - te_flat[triplet_index[:, 2]]).pow(2).sum(dim=1)
            conflict = (distance_01 > distance_02).float().mean().item()
        conflict = min(conflict, (1 - conflict))
        if conflict < min_conflict:
            # Remember the best rate so far; without this update every teacher
            # beat the initial value and the last one was always returned.
            min_conflict = conflict
            index = idx
    return index

# train with multi-teacher
def train(epoch, st_model):
    """Train the student for one epoch with averaged multi-teacher distillation.

    For every batch the student and all teachers are run on the same inputs.
    The teachers' temperature-softened probabilities are averaged and used as
    the distillation target. In addition, triplets built from the student's
    ``b2`` feature map (ordered by the teacher picked by
    ``smallest_conflict_teacher``) feed the triplet loss. The sum of both
    losses is optimized with ``optimizer_sgd``.

    Args:
        epoch: Current epoch number (not used inside the function).
        st_model: Student network; its forward pass must return the 5-tuple
            ``(b1, b2, b3, pool, output)``.

    Returns:
        Tuple ``(average loss over the epoch, accuracy of the last batch as a
        NumPy value)``.
    """
    print('Training:')
    # switch to train mode
    st_model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    # `inputs` rather than `input`, to avoid shadowing the builtin.
    for i, (inputs, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        inputs, target = inputs.to(device), target.to(device)

        # student forward pass; only the b2 map and the class scores are used
        _, b2, _, _, output = st_model(inputs)

        # teacher forward passes: nothing here needs gradients
        te_scores_list = []
        hint_maps = []
        with torch.no_grad():
            for te in teacher_models:
                te.eval()
                _, t_b2, _, _, t_output = te(inputs)
                hint_maps.append(t_b2)
                # dim=1 is what the deprecated implicit-dim softmax resolved
                # to for 2-D class scores.
                te_scores_list.append(F.softmax(t_output / args.T, dim=1))
        te_scores_Tensor = torch.stack(te_scores_list, dim=1)  # size: [128, 3, 10]
        mean_logits = avg_logits(te_scores_Tensor)

        te_index = smallest_conflict_teacher(b2, hint_maps)
        st_tripets = random_triplets(b2, hint_maps[te_index])

        optimizer_sgd.zero_grad()

        # compute gradient and do SGD step
        kd_loss = distillation_loss(output, target, mean_logits, T=args.T, alpha=0.7)
        relation_loss = triplet_loss(st_tripets[0], st_tripets[1], st_tripets[2])

        loss = kd_loss + relation_loss

        # The graph is used for exactly one backward pass, so it does not need
        # to be retained (retaining it only holds on to memory).
        loss.backward()
        optimizer_sgd.step()

        output = output.float()
        loss = loss.float()
        # measure accuracy and record loss
        train_acc = accuracy(output.detach(), target.detach())[0]
        losses.update(loss.item(), inputs.size(0))
        top1.update(train_acc, inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            log_out('[{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
    # NOTE(review): the accuracy returned is that of the final batch only,
    # not the running average kept in `top1`.
    return losses.avg, train_acc.cpu().numpy()


def test(model):
    """Evaluate ``model`` on the test loader without tracking gradients.

    Args:
        model: Network whose forward pass returns a 5-tuple with the class
            scores as the last element.

    Returns:
        Tuple ``(average loss, accuracy of the final batch, average top-1
        accuracy)``; both accuracies are converted to NumPy values.
    """
    print('Testing:')
    # evaluation mode for the whole pass
    model.eval()
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress_fmt = ('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')

    tick = time.time()
    with torch.no_grad():
        for step, (images, target) in enumerate(test_loader):
            images = images.to(device)
            target = target.to(device)

            # forward pass: only the class scores are needed here
            _b1, _b2, _b3, _pool, scores = model(images)
            batch_loss = F.cross_entropy(scores, target).float()
            scores = scores.float()

            # record accuracy and loss, weighted by the batch size
            n_samples = images.size(0)
            test_acc = accuracy(scores.data, target.data)[0]
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(test_acc, n_samples)

            # time spent on this batch
            now = time.time()
            timer.update(now - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                log_out(progress_fmt.format(
                    step, len(test_loader),
                    batch_time=timer, loss=loss_meter, top1=acc_meter))

    log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=acc_meter))

    return loss_meter.avg, test_acc.cpu().numpy(), acc_meter.avg.cpu().numpy()


# Main loop: train and evaluate once per epoch, plot the curves, track the
# best test accuracy and checkpoint the student.
print('StudentNet:\n')
print(st_model)
best_acc = 0
for epoch in range(1, args.epochs + 1):
    log_out("\n===> epoch: {}/{}".format(epoch, args.epochs))
    log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))
    train_loss, train_acc = train(epoch, st_model)
    # visualize loss
    vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update="append")
    _, test_acc, top1 = test(st_model)
    vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update="append")
    if top1 > best_acc:
        best_acc = top1
        # NOTE(review): weights are only written after epoch 150, so a best
        # score reached earlier (or any run of <= 150 epochs) saves nothing.
        if epoch > 150:
            torch.save(st_model.state_dict(), load_dir + st_model.model_name + '_avg.pth')
    # Step the scheduler after the epoch's optimizer updates. Stepping first
    # skips the initial schedule value, moves every milestone one epoch early
    # and makes the lr logged above differ from the one actually used.
    lr_scheduler.step()

# release GPU memory
torch.cuda.empty_cache()
log_out("@ BEST ACC = {:.4f}%".format(best_acc))
log_out("--- {:.3f} mins ---".format((time.time() - start_time)/60))


  init.kaiming_normal(m.weight)


Files already downloaded and verified
StudentNet:

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2



[0/391]	Time 0.889 (0.889)	Data 0.041 (0.041)	Loss 2.2553 (2.2553)	Prec@1 0.781 (0.781)
