In [None]:
from __future__ import print_function
import os
import time
import logging
import argparse
import numpy as np
from visdom import Visdom
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils import *

# Teacher models:
# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlxNet, ResNet18, ResNet34, 
# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, 
# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, 
# PreActResNet50, PreActResNet101, PreActResNet152, 
# DenseNet121, DenseNet161, DenseNet169, DenseNet201, 
import models

# Student models:
# myNet, LeNet, FitNet

start_time = time.time()

# Number of output classes for each supported dataset; used when --n_class is
# not given explicitly so the classifier head always matches the labels.
NUM_CLASSES = {'CIFAR10': 10, 'CIFAR100': 100}

# Architectures accepted by --net1..--net3 and by --net4/--net5 respectively.
RESNET_CHOICES = ['ResNet20', 'ResNet32', 'ResNet50', 'ResNet56', 'ResNet110']
EXTENDED_CHOICES = ['ResNet20', 'ResNet32', 'ResNet34', 'ResNet50', 'ResNet56',
                    'ResNet110', 'VGG19', 'GoogleNet', 'DenseNet121']

# Training settings
parser = argparse.ArgumentParser(description='PyTorch DML 3S')

parser.add_argument('--dataset', choices=['CIFAR10', 'CIFAR100'], default='CIFAR10')
parser.add_argument('--net1', choices=RESNET_CHOICES, default='ResNet20')
parser.add_argument('--net2', choices=RESNET_CHOICES, default='ResNet20')
parser.add_argument('--net3', choices=RESNET_CHOICES, default='ResNet20')
parser.add_argument('--net4', choices=EXTENDED_CHOICES, default='ResNet20')
parser.add_argument('--net5', choices=EXTENDED_CHOICES, default='ResNet20')

parser.add_argument('--n_class', type=int, default=None, metavar='N',
                    help='num of classes (default: inferred from --dataset)')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')
parser.add_argument('--print_freq', type=int, default=40, metavar='N', help='how many batches to wait before logging training status')

# Notebook run: parse a fixed argument list instead of sys.argv.
config = ['--dataset', 'CIFAR10', '--epochs', '200', '--n_class', '10', '--device', 'cuda:0']
args = parser.parse_args(config)
if args.n_class is None:
    # Keep the classifier head consistent with the chosen dataset.
    args.n_class = NUM_CLASSES[args.dataset]

# Fall back to CPU when CUDA is unavailable, whatever --device says.
device = args.device if torch.cuda.is_available() else 'cpu'
save_dir = './checkpoint/' + args.dataset + '/'

# models: one peer network per --netK argument, built in order net1..net5.
def _build_peer(arch):
    """Instantiate ``models.<arch>`` with ``args.n_class`` outputs and move it to ``device``."""
    peer = getattr(models, arch)(num_classes=args.n_class)
    peer.to(device)
    return peer


net1 = _build_peer(args.net1)
net2 = _build_peer(args.net2)
net3 = _build_peer(args.net3)
net4 = _build_peer(args.net4)
net5 = _build_peer(args.net5)

# logging
# The log lives under save_dir; create it so the first log_out() call cannot
# fail with FileNotFoundError on a fresh checkout.
os.makedirs(save_dir, exist_ok=True)
logfile = save_dir + 'DML_3S_.log'
# Start every run with an empty log, since log_out() opens the file in append mode.
if os.path.exists(logfile):
    os.remove(logfile)
def log_out(info, path=None):
    """Append ``info`` plus a newline to the run log, then echo it to stdout.

    Args:
        info: Text to record.
        path: File to append to. Defaults to the module-level ``logfile``.
    """
    # Context manager so the handle is closed even if the write raises.
    with open(logfile if path is None else path, mode='a') as f:
        f.write(info)
        f.write('\n')
    print(info)
    
# visualizer
vis = Visdom(env='distill')


def _new_line_window(title, ylabel, legend, name, **axis_opts):
    """Open a five-series Visdom line plot starting at (0, 0).

    Returns whatever ``vis.line`` returns; the training loop later passes it
    back as ``win`` to append one point per series per epoch.
    """
    opts = dict(title=title, xlabel='epoch', xtickmin=0, ylabel=ylabel, ytickmin=0)
    opts.update(axis_opts)
    opts['legend'] = legend
    return vis.line(
        X=np.column_stack((0, 0, 0, 0, 0)),
        Y=np.column_stack((0, 0, 0, 0, 0)),
        opts=opts,
        name=name,
    )


loss_win = _new_line_window(
    'DML 5S Loss', 'loss',
    ['net1_loss', 'net2_loss', 'net3_loss', 'net4_loss', 'net5_loss'],
    "loss",
    ytickstep=0.5,
)
acc_win = _new_line_window(
    'DML 5S Acc', 'accuracy',
    ['net1_acc', 'net2_acc', 'net3_acc', 'net4_acc', 'net5_acc'],
    "acc",
    ytickmax=100,
)


# data
# NOTE(review): these look like the common ImageNet channel mean/std rather
# than CIFAR-specific statistics; confirm this is intentional.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),  # 32x32 crop with 4 pixels of padding
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

_dataset_cls = getattr(datasets, args.dataset)
train_set = _dataset_cls(root='../data', train=True, download=True, transform=train_transform)
# Only the train split triggers a download; the test split is presumably
# already on disk from that download.
test_set = _dataset_cls(root='../data', train=False, download=False, transform=test_transform)

train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)

# Per-network optimizers and LR schedules. Every peer gets the same SGD setup
# driven by --lr and --momentum. MultiStepLR uses its default gamma, so the LR
# is scaled by 0.1 after 100 and again after 150 scheduler steps.
WEIGHT_DECAY = 5e-4
LR_MILESTONES = [100, 150]


def _make_optimizer(net):
    """Build the SGD optimizer used for every peer network."""
    return optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                     weight_decay=WEIGHT_DECAY)


optimizer_1 = _make_optimizer(net1)
optimizer_2 = _make_optimizer(net2)
optimizer_3 = _make_optimizer(net3)
optimizer_4 = _make_optimizer(net4)
optimizer_5 = _make_optimizer(net5)

lr_scheduler_1 = optim.lr_scheduler.MultiStepLR(optimizer_1, milestones=LR_MILESTONES)
lr_scheduler_2 = optim.lr_scheduler.MultiStepLR(optimizer_2, milestones=LR_MILESTONES)
lr_scheduler_3 = optim.lr_scheduler.MultiStepLR(optimizer_3, milestones=LR_MILESTONES)
lr_scheduler_4 = optim.lr_scheduler.MultiStepLR(optimizer_4, milestones=LR_MILESTONES)
lr_scheduler_5 = optim.lr_scheduler.MultiStepLR(optimizer_5, milestones=LR_MILESTONES)


# train with multi-teacher
def train(epoch, net1, net2, net3, net4, net5, kl_weight=0.5):
    """Run one epoch of Deep Mutual Learning over ``train_loader``.

    Each peer k minimises
        cross_entropy(z_k, y) + kl_weight * sum_{j != k} KL(p_j || p_k),
    where p_j = softmax(z_j) is treated as a fixed (detached) soft target.
    Detaching keeps each loss's gradients confined to its own network, so the
    five optimizer steps are independent and no graph has to be retained.

    Args:
        epoch: Current epoch number (not used in the body; kept for callers).
        net1..net5: Peer models. Assumed to return a tuple whose last element
            is the classification logits of shape (batch, n_class).
        kl_weight: Weight of the summed mimicry term (0.5 matches the previous
            hard-coded coefficient).

    Returns:
        A 10-tuple: the epoch-average training loss of net1..net5, followed by
        the epoch-average top-1 training accuracy of net1..net5 as numpy arrays.
    """
    print('Training:')
    nets = (net1, net2, net3, net4, net5)
    optimizers = (optimizer_1, optimizer_2, optimizer_3, optimizer_4, optimizer_5)

    # switch to train mode
    for net in nets:
        net.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = [AverageMeter() for _ in nets]
    top1 = [AverageMeter() for _ in nets]

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images, target = images.to(device), target.to(device)
        batch_size = images.size(0)

        # Forward every peer before any parameters change.
        outputs = [net(images)[-1] for net in nets]
        soft_targets = [F.softmax(out, dim=1).detach() for out in outputs]

        for k, (output, optimizer) in enumerate(zip(outputs, optimizers)):
            # F.kl_div expects log-probabilities as input and probabilities as
            # target; 'batchmean' gives the mathematically correct KL value.
            log_probs = F.log_softmax(output, dim=1)
            mimicry = sum(
                F.kl_div(log_probs, soft_targets[j], reduction='batchmean')
                for j in range(len(nets)) if j != k
            )
            loss = kl_weight * mimicry + F.cross_entropy(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure accuracy and record loss
            losses[k].update(loss.item(), batch_size)
            top1[k].update(accuracy(output, target)[0], batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            log_out('[{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 ({top1_1.avg:.3f})'.format(
                        i, len(train_loader), batch_time=batch_time,
                        data_time=data_time, loss=losses[0], top1_1=top1[0]))

    # Report epoch averages rather than the accuracy of the final batch only.
    return (tuple(meter.avg for meter in losses)
            + tuple(torch.as_tensor(meter.avg).cpu().numpy() for meter in top1))


def test(model):
    """Evaluate ``model`` on ``test_loader`` without gradients, logging progress.

    Returns:
        A 3-tuple: the sample-weighted average cross-entropy loss, the top-1
        accuracy of the LAST batch as a numpy array, and the epoch-average
        top-1 accuracy as a numpy array.
    """
    print('Testing:')
    # switch to evaluate mode
    model.eval()
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    tick = time.time()
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader):
            images = images.to(device)
            labels = labels.to(device)

            # The model returns a 5-tuple; the last element holds the logits.
            _, _, _, _, logits = model(images)
            batch_loss = F.cross_entropy(logits, labels).float()
            logits = logits.float()

            batch_acc = accuracy(logits.data, labels.data)[0]
            n = images.size(0)
            loss_meter.update(batch_loss.item(), n)
            acc_meter.update(batch_acc, n)

            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                log_out('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                            step, len(test_loader), batch_time=timer,
                            loss=loss_meter, top1=acc_meter))

    log_out(' * {0} Prec@1 {top1.avg:.3f}'.format(model.model_name, top1=acc_meter))

    return loss_meter.avg, batch_acc.cpu().numpy(), acc_meter.avg.cpu().numpy()


print('*-----------------DML----------------*')
best_acc = 0
lr_schedulers = (lr_scheduler_1, lr_scheduler_2, lr_scheduler_3, lr_scheduler_4, lr_scheduler_5)
for epoch in range(1, args.epochs + 1):
    log_out("\n===> epoch: {}/{}".format(epoch, args.epochs))
    (train_loss_1, train_loss_2, train_loss_3, train_loss_4, train_loss_5,
     net1_acc, net2_acc, net3_acc, net4_acc, net5_acc) = train(epoch, net1, net2, net3, net4, net5)
    epoch_x = np.column_stack((epoch, epoch, epoch, epoch, epoch))
    # visualize loss
    vis.line(np.column_stack((train_loss_1, train_loss_2, train_loss_3, train_loss_4, train_loss_5)),
             epoch_x, loss_win, update="append")
    _, test_acc_1, top1_1 = test(net1)
    _, test_acc_2, top1_2 = test(net2)
    _, test_acc_3, top1_3 = test(net3)
    _, test_acc_4, top1_4 = test(net4)
    _, test_acc_5, top1_5 = test(net5)
    vis.line(np.column_stack((top1_1, top1_2, top1_3, top1_4, top1_5)),
             epoch_x, acc_win, update="append")
    # test() returns numpy arrays; .item() collapses a single-element array of
    # any rank to a float so the final "{:.4f}" format cannot fail.
    # Only net1's accuracy is tracked as "best".
    best_acc = max(best_acc, np.asarray(top1_1).item())
    # Step the schedules AFTER this epoch's optimizer.step() calls, as PyTorch
    # >= 1.1 expects; stepping first moves every milestone one epoch early.
    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()

log_out("@ BEST Prec: {:.4f}".format(best_acc))
log_out("--- {:.3f} mins ---".format((time.time() - start_time)/60))


  init.kaiming_normal(m.weight)


Files already downloaded and verified
*-----------------DML----------------*

===> epoch: 1/200
Training:




[0/391]	Time 0.473 (0.473)	Data 0.039 (0.039)	Loss 2.3204 (2.3204)	Prec@1 (7.812)
[40/391]	Time 0.275 (0.281)	Data 0.019 (0.020)	Loss 1.5083 (1.6540)	Prec@1 (22.847)
[80/391]	Time 0.275 (0.280)	Data 0.019 (0.019)	Loss 1.2872 (1.5041)	Prec@1 (28.241)
[120/391]	Time 0.284 (0.280)	Data 0.019 (0.019)	Loss 1.2726 (1.4197)	Prec@1 (31.431)
[160/391]	Time 0.289 (0.280)	Data 0.019 (0.019)	Loss 1.1916 (1.3726)	Prec@1 (33.438)
[200/391]	Time 0.289 (0.282)	Data 0.020 (0.020)	Loss 1.0543 (1.3281)	Prec@1 (35.393)
[240/391]	Time 0.284 (0.283)	Data 0.019 (0.019)	Loss 1.0713 (1.2909)	Prec@1 (37.017)
[280/391]	Time 0.293 (0.285)	Data 0.019 (0.020)	Loss 1.0685 (1.2523)	Prec@1 (38.723)
[320/391]	Time 0.306 (0.286)	Data 0.021 (0.020)	Loss 0.8693 (1.2183)	Prec@1 (40.301)
[360/391]	Time 0.284 (0.287)	Data 0.019 (0.020)	Loss 0.8677 (1.1847)	Prec@1 (41.915)
Testing:
Test: [0/79]	Time 0.020 (0.020)	Loss 1.2756 (1.2756)	Prec@1 53.125 (53.125)
Test: [40/79]	Time 0.020 (0.020)	Loss 1.4258 (1.3573)	Prec@1 49.219 (5