In [None]:
from __future__ import print_function
import os
import time
import logging
import argparse
import numpy as np
from visdom import Visdom
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import *
from metric.loss import FitNet, AttentionTransfer, RKdAngle, RkdDistance

# Teacher models:
# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlxNet, ResNet18, ResNet34, 
# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, 
# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, 
# PreActResNet50, PreActResNet101, PreActResNet152, 
# DenseNet121, DenseNet161, DenseNet169, DenseNet201, 
import models

# Student models:
# myNet, LeNet, FitNet

start_time = time.time()

# Training settings
parser = argparse.ArgumentParser(description='PyTorch LR_adaptive_AT')

parser.add_argument('--dataset',
                    choices=['CIFAR10',
                             'CIFAR100'
                            ],
                    default='CIFAR10')
parser.add_argument('--teacher',
                    choices=['ResNet32',
                             'ResNet50',
                             'ResNet56',
                             'ResNet110'
                            ],
                    default='ResNet110')
parser.add_argument('--student',
                    choices=['ResNet20',
                             'myNet'
                            ],
                    default='ResNet20')
# Weights of the three distillation terms that train() adds to the cross-entropy loss.
parser.add_argument('--dist_ratio', default=1, type=float)
parser.add_argument('--angle_ratio', default=2, type=float)
parser.add_argument('--at_ratio', default=1, type=float)

# --n_class defaults to None so it can be inferred from --dataset below; a fixed
# default of 100 did not match the default dataset (CIFAR10).
parser.add_argument('--n_class', type=int, default=None, metavar='N',
                    help='num of classes (default: inferred from --dataset)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')
parser.add_argument('--print_freq', type=int, default=40, metavar='N', help='how many batches to wait before logging training status')

# Notebook-style run: the arguments are hard-coded instead of read from sys.argv.
config = ['--dataset', 'CIFAR100', '--epochs', '200', '--at_ratio', '1', '--device', 'cuda:0']
args = parser.parse_args(config)
if args.n_class is None:
    args.n_class = {'CIFAR10': 10, 'CIFAR100': 100}[args.dataset]

# Fall back to the CPU when CUDA is not available.
device = args.device if torch.cuda.is_available() else 'cpu'
load_dir = './checkpoint/' + args.dataset + '/'

# teacher model: pre-trained weights are loaded from load_dir and the model is put in eval mode
te_model = getattr(models, args.teacher)(num_classes=args.n_class)
# map_location remaps tensors saved on another device (e.g. a GPU) onto the
# selected device, so the CPU fallback above can still load the checkpoint.
te_model.load_state_dict(torch.load(load_dir + te_model.model_name + '.pth', map_location=device))
te_model.to(device)
te_model.eval()  # eval mode

# student model: constructed fresh from the models package, selected by name
st_model = getattr(models, args.student)(num_classes=args.n_class)
st_model.to(device)

# logging: each run starts from an empty log file stored next to the checkpoints
logfile = ''.join((load_dir, 'RKD_', st_model.model_name, '.log'))
if os.path.exists(logfile):
    os.remove(logfile)  # discard the log left by a previous run
def log_out(info):
    """Append ``info`` and a newline to ``logfile``, then echo it to stdout.

    Args:
        info: the message string to record.
    """
    # ``with`` guarantees the handle is closed even if a write raises.
    with open(logfile, mode='a') as f:
        f.write(info)
        f.write('\n')
    print(info)
    
# visualizer: one Visdom environment holding a loss curve and a train/test accuracy curve
vis = Visdom(env='distill')

loss_win = vis.line(
    Y=np.array([0]),
    X=np.array([0]),
    name="loss",
    opts={
        'title': 'RKD Loss',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'loss',
        'ytickmin': 0,
        'ytickstep': 0.5,
    },
)

acc_win = vis.line(
    Y=np.column_stack((0, 0)),
    X=np.column_stack((0, 0)),
    name="acc",
    opts={
        'title': 'RKD Acc',
        'xlabel': 'epoch',
        'xtickmin': 0,
        'ylabel': 'accuracy',
        'ytickmin': 0,
        'ytickmax': 100,
        # one legend entry per column of Y
        'legend': ['train_acc', 'test_acc'],
    },
)


# data
# NOTE(review): these are the ImageNet channel statistics rather than CIFAR's;
# presumably the teacher checkpoint was trained with the same preprocessing —
# confirm before changing them.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])
# Only the training split requests a download; the test split is expected to
# already be present under ../data afterwards.
train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)
test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)
# optim: momentum comes from --momentum (previously hard-coded, silently ignoring the flag);
# the lr is multiplied by 0.1 at scheduler milestones 100 and 150.
optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, gamma=0.1, milestones=[100, 150])


# attention transfer loss, distance loss, angular loss
at_criterion = AttentionTransfer().to(device)
dist_criterion = RkdDistance().to(device)
angle_criterion = RKdAngle().to(device)


# train with teacher
def train(epoch, model):
    """Run one epoch of distillation from ``te_model`` into ``model``.

    The minimised loss is cross-entropy on the labels plus three weighted
    distillation terms:
    - attention transfer between the student's and teacher's intermediate
      outputs b1..b3;
    - RKD angle and distance losses between their logits.

    Args:
        epoch: current epoch number; not used inside the loop, kept for the caller.
        model: student network whose forward returns ``(b1, b2, b3, pool, logits)``.

    Returns:
        ``(average training loss, epoch-average training accuracy as numpy)``.
    """
    print('Training:')
    # switch to train mode; the teacher is kept in eval mode
    model.train()
    te_model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        images, target = images.to(device), target.to(device)

        # compute outputs; the teacher forward pass needs no gradients
        b1, b2, b3, pool, output = model(images)
        with torch.no_grad():
            t_b1, t_b2, t_b3, t_pool, t_output = te_model(images)

        optimizer_sgd.zero_grad()

        # NOTE(review): the RKD angle/distance terms are applied to the logits,
        # not to the pooled features (pool / t_pool are unused) — confirm intended.
        angle_loss = args.angle_ratio * angle_criterion(output, t_output)
        dist_loss = args.dist_ratio * dist_criterion(output, t_output)
        # attention loss over the three intermediate outputs
        at_loss = args.at_ratio * (at_criterion(b1, t_b1) + at_criterion(b2, t_b2) + at_criterion(b3, t_b3))
        entropy_loss = F.cross_entropy(output, target)
        loss = at_loss + angle_loss + dist_loss + entropy_loss

        # The graph is backpropagated exactly once, so there is no need to
        # retain it; retain_graph=True only kept its buffers alive longer.
        loss.backward()
        optimizer_sgd.step()

        # measure accuracy and record loss
        train_acc = accuracy(output.detach().float(), target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(train_acc, images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            log_out('[{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
    # Report accuracy averaged over the whole epoch (as test() does for its
    # plotted value), not just the last, possibly partial, batch.
    return losses.avg, top1.avg.cpu().numpy()


def test(model):
    """Evaluate ``model`` on ``test_loader`` using plain cross-entropy.

    Args:
        model: student network whose forward returns five values, the last
            being the logits.

    Returns:
        ``(average loss, last-batch accuracy, epoch-average accuracy)``; both
        accuracies are converted to numpy via ``.cpu().numpy()``.
    """
    print('Testing:')
    model.eval()  # switch to evaluate mode
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            # forward pass: only the logits (last output) are needed here
            _, _, _, _, logits = model(images)
            batch_loss = F.cross_entropy(logits, labels).float()
            logits = logits.float()

            # accuracy and loss bookkeeping, weighted by batch size
            test_acc = accuracy(logits.data, labels.data)[0]
            n_samples = images.size(0)
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(test_acc, n_samples)

            # wall-clock time per batch
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                log_out('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                            step, len(test_loader), batch_time=timer,
                            loss=loss_meter, top1=acc_meter))

    log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=acc_meter))

    return loss_meter.avg, test_acc.cpu().numpy(), acc_meter.avg.cpu().numpy()


print('StudentNet:\n')
print(st_model)
st_model.apply(weights_init_normal)
best_acc = 0
for epoch in range(1, args.epochs + 1):
    log_out("\n===> epoch: {}/{}".format(epoch, args.epochs))
    # Logged before training, so this is the lr actually used for this epoch.
    log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))
    train_loss, train_acc = train(epoch, st_model)
    # visualize loss
    vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update="append")
    _, test_acc, top1 = test(st_model)
    vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update="append")
    if top1 > best_acc:
        best_acc = top1
    # Advance the lr schedule once per epoch, after this epoch's optimizer
    # updates and without the deprecated epoch argument. Milestones therefore
    # count completed epochs, and the lr logged above matches the lr in use.
    lr_scheduler.step()

# release GPU memory
torch.cuda.empty_cache()
log_out("BEST ACC: {:.3f}".format(best_acc))
log_out("--- {:.3f} mins ---".format((time.time() - start_time)/60))


  init.kaiming_normal(m.weight)


Files already downloaded and verified
StudentNet:

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2



[40/391]	Time 0.055 (0.056)	Data 0.018 (0.019)	Loss 4.5492 (4.9636)	Prec@1 3.906 (3.030)
