In [2]:
from __future__ import print_function
import os
import time
import logging
import argparse
from visdom import Visdom
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Teacher models
from models.teacher import *

# Student models
from models.student import *


start_time = time.time()

# The log file is written under ./checkpoint, so the directory has to exist
# before the first log write.
os.makedirs('./checkpoint', exist_ok=True)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Distill Example')
parser.add_argument('--teacher', type=str, default='VGG19', help='teacher net: AlexNet, VGG11/13/16/19, GoogLeNet')
parser.add_argument('--student', type=str, default='FitNet11', help='student net: LeNet5, ')
parser.add_argument('--T', type=float, default=20.0, metavar='Temperature', help='Temperature for distillation')
parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')
parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
# NOTE: the default is already True when CUDA is available, so passing --cuda
# can only confirm it; the device selection below falls back to CPU otherwise.
parser.add_argument('--cuda', action='store_true', default=torch.cuda.is_available(), help='use CUDA training')
parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')

# Arguments are fixed here instead of being read from sys.argv.
config = ['--epochs', '50', '--teacher', 'Hint7', '--student', 'FitNet11', '--T', '10', '--cuda']
args = parser.parse_args(config)

device = 'cuda:0' if args.cuda and torch.cuda.is_available() else 'cpu'

# logging: start every run with a fresh log file
logfile = './checkpoint/distill_' + args.teacher + '_' + args.student + '.log'
try:
    os.remove(logfile)
except FileNotFoundError:
    pass  # nothing left over from a previous run

def log_out(info):
    """Append ``info`` and a newline to the run's log file, then print it.

    Args:
        info: Text of the log entry (must be a ``str``).
    """
    # The context manager closes the file even if a write raises.
    with open(logfile, mode='a') as log_file:
        log_file.write(info)
        log_file.write('\n')
    print(info)
    
# visualizer: one window for the training loss, one for train/test accuracy
vis = Visdom(env='distill')

# Loss plot, seeded with a single point at the origin.
loss_win = vis.line(
    X=np.array([0]),
    Y=np.array([0]),
    opts={
        'title': 'train loss',
        'xtickmin': 0,
        'xtickstep': 5,
        'ytickmin': 0,
        'ytickstep': 0.5,
        'markers': True,
        'markersymbol': 'dot',
        'markersize': 5,
    },
    name="loss",
)

# Accuracy plot with two series (one column each), both seeded at the origin.
acc_win = vis.line(
    X=np.array([[0, 0]]),
    Y=np.array([[0, 0]]),
    opts={
        'title': 'ACC',
        'xtickmin': 0,
        'xtickstep': 5,
        'ytickmin': 0,
        'ytickmax': 100,
        'markers': True,
        'markersymbol': 'dot',
        'markersize': 5,
        'legend': ['train_acc', 'test_acc'],
    },
    name="acc",
)

# weights init
def weights_init_normal(m):
    """Initialise one module in place; meant to be passed to ``Module.apply``.

    Matching is done on the class name:

    * names containing ``Conv``: weight ~ N(0, 0.02)
    * names containing ``BatchNorm2d``: weight ~ N(1, 0.02), bias = 0
    * names containing ``Linear``: weight ~ N(0, 0.02), bias = 0

    Modules that match none of these are left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        # weight/bias are None when the layer was built with affine=False
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0.0)
    elif 'Linear' in classname:
        # The class is spelled "Linear"; a lowercase match never fires.
        # Weights are centred on 0 (a mean of 1 only suits a BatchNorm scale).
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0.0)




In [None]:
# class ConvRegressor(nn.Module):
#     def __init__(self, teacher, hint_layer, student, guided_layer):
#         self.hint_layer = teacher.
        

In [4]:
# NOTE(review): eval() turns the --teacher / --student strings into class
# objects brought in by the wildcard model imports. It executes arbitrary
# expressions, so these arguments must never come from an untrusted source.
teacher_model = eval(args.teacher)().to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
teacher_model.load_state_dict(
    torch.load('./checkpoint/' + args.teacher + '_cifar10.pth', map_location=device))
st_model = eval(args.student)().to(device)
st_model.apply(weights_init_normal)  # init student

# Most recent intermediate activations captured by the forward hooks.
st_features = None
te_features = None

def st_hook(module, input, output):
    """Forward hook: store the hooked layer's output in ``st_features``.

    The tensor is kept attached to the autograd graph, so a loss computed on
    ``st_features`` can back-propagate into the module that produced it.

    Args:
        module: The hooked module (unused).
        input: Inputs the module was called with (unused).
        output: Output of the module's forward pass.
    """
    global st_features
    st_features = output
    
def te_hook(module, input, output):
    """Forward hook: store a detached view of the layer's output in ``te_features``.

    The stored tensor shares storage with ``output`` but carries no gradient
    history.

    Args:
        module: The hooked module (unused).
        input: Inputs the module was called with (unused).
        output: Output of the module's forward pass.
    """
    global te_features
    # detach() is the supported replacement for the legacy ``.data`` attribute
    te_features = output.detach()

class Regressor(nn.Module):
    """Convolutional adapter that reshapes one feature map into another.

    Three 3x3 conv / BatchNorm / ReLU stages change the channel count
    ``in_channels -> 512 -> 128 -> out_channels``. The first two stages are
    each followed by a 2x2 max-pool, so height and width shrink by 4 overall.

    Args:
        in_channels: Channels of the input feature map (default 2048).
        out_channels: Channels of the produced feature map (default 80).
    """

    def __init__(self, in_channels=2048, out_channels=80):
        super(Regressor, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 512, 3, 1, 1),   # ch: in_channels -> 512
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves H and W

            nn.Conv2d(512, 128, 3, 1, 1),           # ch: 512 -> 128
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # halves H and W again

            nn.Conv2d(128, out_channels, 3, 1, 1),  # ch: 128 -> out_channels
            # BatchNorm width has to equal the preceding conv's output channels
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return ``self.features`` applied to ``x``."""
        return self.features(x)

# Adapter network applied to the captured teacher features during guided training.
regressor = Regressor().to(device)

# Capture intermediate activations from both networks on every forward pass.
teacher_model.features[11].register_forward_hook(te_hook)
st_model.features[15].register_forward_hook(st_hook)

# data: CIFAR-10; random horizontal flips are applied to the training split only
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
])
train_set = datasets.CIFAR10(
    root='../data', train=True, download=True, transform=train_transform)
test_set = datasets.CIFAR10(
    root='../data', train=False, download=False, transform=test_transform)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)

# Separate SGD optimizers for the student and the regressor, same hyper-parameters.
optimizer = optim.SGD(
    st_model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer_r = optim.SGD(
    regressor.parameters(), lr=args.lr, momentum=args.momentum)

def distillation(y, labels, teacher_scores, T, alpha):
    """Combine a softened teacher-matching loss with the hard-label loss.

    Computes
    ``KLDiv(log_softmax(y / T), softmax(teacher_scores / T)) * (T*T * 2 * alpha)
    + cross_entropy(y, labels) * (1 - alpha)``.

    Args:
        y: Student logits; the softmax is taken over dim 1.
        labels: Ground-truth class indices for the cross-entropy term.
        teacher_scores: Teacher logits, same shape as ``y``.
        T: Temperature dividing both sets of logits.
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``.

    Returns:
        Scalar loss tensor.
    """
    # dim is given explicitly: relying on the implicit softmax dim is deprecated.
    soft_log_probs = F.log_softmax(y / T, dim=1)
    soft_targets = F.softmax(teacher_scores / T, dim=1)
    # NOTE(review): KLDivLoss() keeps its default reduction, which averages
    # over every element rather than per sample; switching to 'batchmean'
    # would rescale this term, so it is left as is.
    soft_loss = nn.KLDivLoss()(soft_log_probs, soft_targets)
    hard_loss = F.cross_entropy(y, labels)
    return soft_loss * (T * T * 2.0 * alpha) + hard_loss * (1. - alpha)


# guided train
def guided_train(model, loss_fn_guided=nn.MSELoss(), epochs=10):
    """Hint stage: pull the student's hooked features towards the teacher's.

    For every training batch both networks run a forward pass so that the
    forward hooks refresh ``st_features`` / ``te_features``. The regressor maps
    the teacher features, and ``loss_fn_guided`` compares the result with the
    student features. Both the regressor optimizer and the student optimizer
    are stepped.

    Args:
        model: Student network to train.
        loss_fn_guided: Loss called as ``loss_fn_guided(regressed_teacher, student)``.
        epochs: Number of passes over ``train_loader`` (default 10).
    """
    model.train()
    regressor.train()
    teacher_model.eval()
    for epoch in range(epochs):
        guided_loss = None
        for data, _ in train_loader:  # labels are not used in this stage
            data = data.to(device)
            optimizer.zero_grad()
            optimizer_r.zero_grad()
            model(data)
            # The teacher is frozen; it is run only so its hook fires, so no
            # autograd graph is needed for this pass.
            with torch.no_grad():
                teacher_model(data)
            te_output = regressor(te_features)

            # NOTE: gradients reach the student only if st_features is still
            # attached to the autograd graph when the hook stores it.
            guided_loss = loss_fn_guided(te_output, st_features)
            guided_loss.backward()
            optimizer_r.step()
            optimizer.step()
        if guided_loss is not None:  # an empty loader produces no loss to report
            print('guided_epoch:[{}]\tLoss:{:.4f}'.format(epoch, guided_loss.item()))


def train(epoch, model, loss_fn):
    """Run one distillation epoch over ``train_loader``.

    Args:
        epoch: Epoch number, used only in the log line.
        model: Student network to train.
        loss_fn: Called as ``loss_fn(output, target, teacher_output, T=..., alpha=...)``.

    Returns:
        Loss of the last batch as a float, or NaN if the loader yielded no batch.
    """
    model.train()
    teacher_model.eval()
    loss = None
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # The teacher only supplies targets, so skip building its autograd graph.
        with torch.no_grad():
            teacher_output = teacher_model(data)
        loss = loss_fn(output, target, teacher_output, T=args.T, alpha=0.6)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            log_out('Train Epoch: {} [{}/{} ({:.4f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return loss.item() if loss is not None else float('nan')

def train_evaluate(model):
    """Measure the model's loss and accuracy on ``train_loader`` and log them.

    Args:
        model: Network to evaluate; it is switched to eval mode.

    Returns:
        Accuracy in percent as a Python float.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = len(train_loader.dataset)
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the division below gives a true average.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(dim=1)[1]
            # .item() keeps the counter a plain int instead of a tensor
            correct += (pred == target).sum().item()

    avg_loss = total_loss / num_samples if num_samples else 0.0
    accuracy = 100. * correct / num_samples if num_samples else 0.0
    log_out('\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        avg_loss, correct, num_samples, accuracy))
    return accuracy

def test(model):
    """Measure the model's loss and accuracy on ``test_loader`` and log them.

    Args:
        model: Network to evaluate; it is switched to eval mode.

    Returns:
        Accuracy in percent as a Python float.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Accumulate the summed loss; without this the logged average is always 0.
            total_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.max(dim=1)[1]  # index of the largest logit
            # .item() keeps the counter a plain int instead of a tensor
            correct += (pred == target).sum().item()

    avg_loss = total_loss / num_samples if num_samples else 0.0
    accuracy = 100. * correct / num_samples if num_samples else 0.0
    log_out('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        avg_loss, correct, num_samples, accuracy))
    return accuracy

print('StudentNet:\n')
print(st_model)

# Stage 1: hint-based pre-training of the student.
guided_train(st_model)

# Stage 2: distillation epochs, plotting loss and accuracy after each one.
for epoch in range(1, args.epochs + 1):
    train_loss = train(epoch, st_model, loss_fn=distillation)
    vis.line(
        Y=np.array([train_loss]),
        X=np.array([epoch]),
        win=loss_win,
        update="append",
    )

    train_acc = train_evaluate(st_model)
    test_acc = test(st_model)
    vis.line(
        Y=np.column_stack((train_acc, test_acc)),
        X=np.column_stack((epoch, epoch)),
        win=acc_win,
        update="append",
    )

# Persist the student's weights only.
torch.save(
    st_model.state_dict(),
    ''.join(['./checkpoint/', args.teacher, '_distill_', args.student, '.pth']),
)

log_out("--- {:.3f} seconds ---".format(time.time() - start_time))


FileNotFoundError: [Errno 2] No such file or directory: './checkpoint/Hint7_cifar10.pth'