In [8]:
from __future__ import print_function
import os
import time
import logging
import argparse
import numpy as np
from visdom import Visdom
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import *

# Teacher models:
# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlxNet, ResNet18, ResNet34, 
# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, 
# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, 
# PreActResNet50, PreActResNet101, PreActResNet152, 
# DenseNet121, DenseNet161, DenseNet169, DenseNet201, 
import models

# Student models:
# myNet, LeNet, FitNet

start_time = time.time()
# os.makedirs('./checkpoint', exist_ok=True)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch DML 2S')

parser.add_argument('--dataset',
                    choices=['CIFAR10',
                             'CIFAR100'
                            ],
                    default='CIFAR10')
parser.add_argument('--nets',
                    choices=['ResNet32',
                             'ResNet50',
                             'ResNet56',
                             'ResNet110'
                            ],
                    default=['ResNet20', 'ResNet56', 'ResNet110'],
                    nargs='+')

parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input test batch size for training')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.5)')
parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')
parser.add_argument('--print_freq', type=int, default=40, metavar='N', help='how many batches to wait before logging training status')

config = ['--epochs', '200', '--device', 'cuda:1']
args = parser.parse_args(config)

device = args.device if torch.cuda.is_available() else 'cpu'
save_dir = './checkpoint/' + args.dataset + '/'

# models
nets_list = []
for m in args.nets:
    net = getattr(models, m)()
    net.to(device)
    net.train()  # train mode
    nets_list.append(net)

K = len(nets_list)
    
# logging
logfile = save_dir + 'DML_3S_.log'
if os.path.exists(logfile):
    os.remove(logfile)
def log_out(info):
    f = open(logfile, mode='a')
    f.write(info)
    f.write('\n')
    f.close()
    print(info)
    
# visualizer
vis = Visdom(env='distill')
loss_win = vis.line(
    X=np.array([0]*K),
    Y=np.array([0]*K),
    opts=dict(
        title='DML_3S Loss',
        xlabel='epoch',
        xtickmin=0,
        ylabel='loss',
        ytickmin=0,
        ytickstep=0.5
    ),
    name="loss"
)

acc_win = vis.line(
    X=np.array([0]*K),
    Y=np.array([0]*K),
    opts=dict(
        title='DML_3S Acc',
        xlabel='epoch',
        xtickmin=0,
        ylabel='accuracy',
        ytickmin=0,
        ytickmax=100
    ),
    name="acc"
)


# data
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])
train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR10(root='../data', train=False, download=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)

# optimizer = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum)
optimizers_list = []
lr_scheduler_list = []
for m in nets_list:
    optimizer_m = optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    lr_scheduler_m = optim.lr_scheduler.MultiStepLR(optimizer_m, milestones=[100, 150])
    optimizers_list.append(optimizer_m)
    lr_scheduler_list.append(lr_scheduler_m)

    
# train with multi-teacher
def train(epoch, nets_list):
    print('Training:')
    K = len(nets_list)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_list = [AverageMeter()] * K
    top1_list = [AverageMeter()] * K
    
    end = time.time()
    for i, (input, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.to(device), target.to(device)
        
        # compute outputs
        output_list = []
        logits_list = []
        for net in nets_list:
            _,_,_,_, output_m = net(input)
            logits_m = F.softmax(output_m)
            output_list.append(output_m)
            logits_list.append(logits_m)
        
        for j in range(K):
            loss_j = 0
            
            optimizers_list[j].zero_grad()
            for h in range(K):
                if h != j:
                    loss_j += nn.KLDivLoss()(logits_list[h], logits_list[j]) 
            loss_j /= K - 1
            loss_j += F.cross_entropy(output_list[j], target)
            loss_j.backward()  # retain_graph=True
            optimizers_list[j].step()
            
            # measure accuracy and record loss
            netj_acc = accuracy(output_list[j], target)[0]
            losses_list[j].update(loss_j.item(), input.size(0))
            top1_list[j].update(netj_acc, input.size(0))
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            for j in range(K):
                log_out('[{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 ({top1_1.avg:.3f})'.format(
                          i, len(train_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses_list[j], top1_1=top1_list[j]))
    
    losses_list = [losses_list[j].avg for j in range(K)]
    top1_list = [top1_list[j].avg.cpu().numpy() for j in range(K)]
    
    return losses_list, top1_list


def test(model):
    print('Testing:')
    # switch to evaluate mode
    model.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(test_loader):
            input, target = input.to(device), target.to(device)

            # compute output
            _,_,_,_,output = model(input)
            loss = F.cross_entropy(output, target)

            output = output.float()
            loss = loss.float()

            # measure accuracy and record loss
            test_acc = accuracy(output.data, target.data)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(test_acc, input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                log_out('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(test_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    log_out(' * {0} Prec@1 {top1.avg:.3f}'.format(model.model_name, top1=top1))

    return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()


print('*-----------------DML----------------*')
best_acc_list = [0] * K
for epoch in range(1, args.epochs + 1):
    log_out("\n===> epoch: {}/{}".format(epoch, args.epochs))
#     log_out('current lr {:.5e}'.format(optimizer_1.param_groups[0]['lr']))
    for j in range(K):
        lr_scheduler_list[j].step()
    train_loss_list, top1_list = train(epoch, nets_list)
    # visaulize loss
    vis.line(np.column_stack(np.array(train_loss_list)), np.column_stack((epoch) * K), loss_win, update="append")
    top1_list = []
    for j in range(K):
        _, _, top1 = test(nets_list[j])
        best_acc_list[j] = max(top1, best_acc_list[j])
        top1_list.append(top1)
    
    vis.line(np.column_stack(np.array(top1_list))),  np.column_stack((epoch) * K), acc_win, update="append")
    
for j in range(K):
    log_out("@ [{}] BEST Prec: {:.4f}".format(nets_list[j].model_name, best_acc_list[j]))
log_out("--- {:.3f} mins ---".format((time.time() - start_time)/60))




Files already downloaded and verified
*-----------------DML----------------*

===> epoch: 1/200
Training:




[0/391]	Time 0.343 (0.343)	Data 0.019 (0.019)	Loss 6.4467 (4.5754)	Prec@1 (7.031)
[0/391]	Time 0.343 (0.343)	Data 0.019 (0.019)	Loss 6.4467 (4.5754)	Prec@1 (7.031)
[0/391]	Time 0.343 (0.343)	Data 0.019 (0.019)	Loss 6.4467 (4.5754)	Prec@1 (7.031)
[40/391]	Time 0.317 (0.319)	Data 0.018 (0.018)	Loss 2.3183 (2.7125)	Prec@1 (13.993)
[40/391]	Time 0.317 (0.319)	Data 0.018 (0.018)	Loss 2.3183 (2.7125)	Prec@1 (13.993)
[40/391]	Time 0.317 (0.319)	Data 0.018 (0.018)	Loss 2.3183 (2.7125)	Prec@1 (13.993)
[80/391]	Time 0.328 (0.319)	Data 0.018 (0.018)	Loss 1.9819 (2.3065)	Prec@1 (16.917)
[80/391]	Time 0.328 (0.319)	Data 0.018 (0.018)	Loss 1.9819 (2.3065)	Prec@1 (16.917)
[80/391]	Time 0.328 (0.319)	Data 0.018 (0.018)	Loss 1.9819 (2.3065)	Prec@1 (16.917)
[120/391]	Time 0.328 (0.320)	Data 0.018 (0.018)	Loss 1.9261 (2.1333)	Prec@1 (19.551)
[120/391]	Time 0.328 (0.320)	Data 0.018 (0.018)	Loss 1.9261 (2.1333)	Prec@1 (19.551)
[120/391]	Time 0.328 (0.320)	Data 0.018 (0.018)	Loss 1.9261 (2.1333)	Prec@1 (19.

KeyboardInterrupt: 