In [1]:
import argparse
import copy
import collections
import time
import sklearn.metrics as metrics
import tensorboardX as tb
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils
import data
import es
import my
import lenet
import resnet
import rn

In [2]:
# Hard-coded experiment configuration (stands in for command-line parsing).
# Alternative values that were tried are listed next to the active setting.
args = argparse.Namespace(
    actor='linear',            # alternatives: 'lenet', 'resnet'
    alpha=0.5,
    average='binary',
    batch_size_actor=100,
    batch_size_criticx=1,
    batch_size_criticy=1,
    beta=2,
    ckpt_every=0,
    dataset='cifar10',         # alternative: 'mnist'
    gpu=0,
    guided_es=True,
    iw='none',                 # alternatives: 'sqrt', 'linear', 'quadratic'
    k=10,
    labelling='91',            # alternative: ''
    n_iterations=50,
    n_iterations_actor=25,
    n_iterations_critic=25,
    n_perturbations=50,
    report_every=10,
    resume=0,
    std=1,
    sample_size_critic=1,
    tau=0.1,
    tb=False,
    verbose=-1,
)

# Disabled command-line interface, kept as a bare string literal so it is never
# executed; the settings come from the hard-coded Namespace assigned to `args`.
# NOTE(review): this parser is out of sync with the active settings -- it has no
# flags for alpha, beta, guided_es, k, n_iterations_actor or tb, and it declares
# a single --batch-size-critic where the code reads batch_size_criticx/y.
'''
parser = argparse.ArgumentParser()
parser.add_argument('--average', type=str, default=None)
parser.add_argument('--batch-size-actor', type=int, default=None)
parser.add_argument('--batch-size-critic', type=int, default=None)
parser.add_argument('--ckpt-every', type=int, default=None)
parser.add_argument('--dataset', type=str, default=None)
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--iw', type=str, default=None)
parser.add_argument('--report-every', type=int, default=None)
parser.add_argument('--actor', type=str, default=None)
parser.add_argument('--labelling', type=str, default=None)
parser.add_argument('--n-iterations', type=int, default=None)
parser.add_argument('--n-iterations-critic', type=int, default=None)
parser.add_argument('--n-perturbations', type=int, default=None)
parser.add_argument('--resume', type=int, default=None)
parser.add_argument('--sample-size-critic', type=int, default=None)
parser.add_argument('--std', type=float, default=None)
parser.add_argument('--tau', type=float, default=None)
parser.add_argument('--verbose', type=int, default=-1)
args = parser.parse_args()
'''

# Build an identifier from the settings; the names listed in `excluded` are
# left out of it. The id names the TensorBoard run and the checkpoint files.
excluded = ('ckpt_every', 'gpu', 'report_every', 'n_iterations', 'resume', 'tb', 'verbose')
keys = sorted(vars(args).keys())
experiment_id = 'parameter-' + '-'.join(
    ['%s-%s' % (name, str(getattr(args, name))) for name in keys if name not in excluded]
)

# The summary writer only exists when TensorBoard logging is switched on.
if args.tb:
    writer = tb.SummaryWriter('runs/' + experiment_id)

In [3]:
# A negative GPU index selects the CPU; any other value selects that CUDA device.
cuda = not (args.gpu < 0)
if cuda:
    new_tensor = th.cuda.FloatTensor
    th.cuda.set_device(args.gpu)
else:
    new_tensor = th.FloatTensor

# An empty labelling string keeps the dataset's own labels; anything else uses
# the fixed two-entry mapping below.
# NOTE(review): presumably the (lo, hi) keys are class ranges merged into one
# label -- confirm against data.load_*.
if args.labelling == '':
    labelling = {}
else:
    labelling = {(0, 9) : 0, (9, 10) : 1}

# Flag forwarded to the loader for the convolutional actors.
rbg = args.actor in ('lenet', 'resnet')

load_dataset = getattr(data, 'load_%s' % args.dataset)
train_x, train_y, test_x, test_y = load_dataset(labelling, rbg, torch=True)

# Sequential loaders over the full splits, used for reporting.
train_set = utils.data.TensorDataset(train_x, train_y)
test_set = utils.data.TensorDataset(test_x, test_y)
train_loader = utils.data.DataLoader(train_set, batch_size=4096, drop_last=False)
test_loader = utils.data.DataLoader(test_set, batch_size=4096, drop_last=False)

# Loader that feeds the training loop.
loader = data.BalancedDataLoader(train_x, train_y, args.batch_size_actor, cuda)

# Number of classes, taken from the span of the training labels.
n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def batch(tensor, batch_sizex, batch_sizey):
    """
    Cut a 3-D tensor into blocks along its first two axes and flatten each
    block to 2-D.

    Parameters
    ----------
    tensor : (x, y, z)
    batch_sizex : int
        Maximum block length along axis 0.
    batch_sizey : int
        Maximum block length along axis 1.

    Returns
    -------
    list of tensors of shape (rows, z), with rows <= batch_sizex * batch_sizey.
    Blocks are ordered with axis-0 blocks outermost and axis-1 blocks innermost.
    When an axis length is not a multiple of its batch size, the trailing block
    is smaller.
    """
    # Unpacking also enforces that the input is exactly three-dimensional.
    _, _, shapez = tensor.shape
    # th.split takes the block *size* directly, so no block ever exceeds the
    # requested batch size and a batch size larger than the axis is fine.
    # reshape (rather than view) is required because a slice along axis 1 is
    # not contiguous whenever more than one axis-0 row is kept.
    return [
        block.reshape(-1, shapez)
        for rows in th.split(tensor, batch_sizex, 0)
        for block in th.split(rows, batch_sizey, 1)
    ]

def forward(actor, batch_list, yz=True, L=True):
    """
    Run the actor on several batches at once.

    Parameters
    ----------
    actor : callable applied to the concatenated inputs
    batch_list : list of (x, y) pairs
    yz : bool
        Include the per-batch row of one-hot labels joined with softmax outputs.
    L : bool
        Include the per-batch score computed by L_batch.

    Returns
    -------
    list holding, in this order, the requested items: the (n_batches, -1)
    label/prediction tensor and the (n_batches, 1) score tensor.
    """
    n_batches = len(batch_list)
    inputs, labels = zip(*batch_list)
    all_inputs, all_labels = th.cat(inputs), th.cat(labels)
    outputs = actor(all_inputs)

    results = []
    if yz:
        joined = th.cat([my.onehot(all_labels, n_classes), F.softmax(outputs, 1)], 1)
        results.append(joined.view(n_batches, -1))
    if L:
        # Split the outputs back into one piece per batch to score each separately.
        per_batch = th.chunk(outputs, n_batches)
        scores = [L_batch(y, z) for y, z in zip(labels, per_batch)]
        results.append(new_tensor(scores).unsqueeze(1))
    return results
    
def L_batch(y, y_bar):
    """
    F1 score of one batch.

    Parameters
    ----------
    y : true labels (tensor or array-like)
    y_bar : 2-D tensor of per-class scores; the predicted class is the argmax
        over axis 1

    Returns
    -------
    The value of sklearn's f1_score with average=args.average.
    """
    predicted = th.max(y_bar, 1)[1]
    # scikit-learn works on host arrays; a CUDA tensor cannot be converted
    # implicitly, so move tensors to the CPU first (a no-op for CPU tensors).
    if th.is_tensor(y):
        y = y.detach().cpu().numpy()
    return metrics.f1_score(y, predicted.cpu().numpy(), average=args.average)

# Transform applied to the loss differences before they are turned into
# softmax weights; the active one is selected by args.iw.
# NOTE(review): only 'none' and 'quadratic' exist here -- the 'sqrt' and
# 'linear' settings mentioned as alternatives for args.iw raise KeyError.
def _iw_none(x):
    # All-zero scores, so the softmax yields uniform weights.
    return th.zeros_like(x)

def _iw_quadratic(x):
    return x * x

iw = {'none' : _iw_none, 'quadratic' : _iw_quadratic}[args.iw]

In [5]:
def ckpt(actor, critic, actor_optim, critic_optim, i):
    """
    Save the state dicts of both networks and both optimizers.

    Files go under 'ckpt/' and are named after experiment_id, the component
    and the 1-based iteration number (i + 1).
    """
    step = i + 1
    targets = (
        (actor, 'ckpt/%s-actor-%d'),
        (critic, 'ckpt/%s-critic-%d'),
        (actor_optim, 'ckpt/%s-actor_optim-%d'),
        (critic_optim, 'ckpt/%s-critic_optim-%d'),
    )
    for component, pattern in targets:
        th.save(component.state_dict(), pattern % (experiment_id, step))

def global_scores(c, loader):
    """
    Evaluate model `c` on every batch of `loader` through my.global_scores.

    Returns
    -------
    OrderedDict with the keys 'accuracy', 'precision', 'recall' and 'f1'.
    Precision, recall and F1 use average=args.average.
    """
    def precision_recall_f1(y, y_bar):
        return metrics.precision_recall_fscore_support(y, y_bar, average=args.average)

    scorers = [metrics.accuracy_score, precision_recall_f1]
    accuracy, prf = my.global_scores(c, loader, scorers)
    # The fourth element (support) is not reported.
    precision, recall, f1, _ = prf

    scores = collections.OrderedDict()
    scores['accuracy'] = accuracy
    scores['precision'] = precision
    scores['recall'] = recall
    scores['f1'] = f1
    return scores

def log_stats(tensor, tag, i):
    """
    Write the minimum, maximum and mean of `tensor` to the summary writer.

    The scalars are named 'th.min(<tag>)', 'th.max(<tag>)' and 'th.mean(<tag>)'
    and are recorded at step i + 1. Relies on the module-level `writer`.
    """
    step = i + 1
    reductions = (
        ('th.min(%s)', th.min),
        ('th.max(%s)', th.max),
        ('th.mean(%s)', th.mean),
    )
    for pattern, reduce_fn in reductions:
        writer.add_scalar(pattern % tag, reduce_fn(tensor), step)
    
def report(actor, i):
    """
    Print train/test scores for `actor` and, if enabled, log them.

    The printed line starts with the 1-based iteration number (i + 1),
    left-padded with zeros to the width of args.n_iterations, followed by one
    'name train/test' entry per score. With args.tb set, every score is also
    written to the summary writer under 'train-<name>' / 'test-<name>'.
    """
    train_scores = global_scores(actor, train_loader)
    test_scores = global_scores(actor, test_loader)

    step = i + 1
    padding = '0' * (len(str(args.n_iterations)) - len(str(step)))
    header = '[iteration %s%d]' % (padding, step)
    entries = []
    for key, value in train_scores.items():
        entries.append('%s %0.3f/%0.3f' % (key, value, test_scores[key]))
    print(header + ' | '.join(entries))

    if not args.tb:
        return
    for split_prefix, scores in (('train-', train_scores), ('test-', test_scores)):
        for key, value in scores.items():
            writer.add_scalar(split_prefix + key, value, step)

In [6]:
# Re-import the local helper modules so the kernel picks up their current
# source. importlib.reload replaces imp.reload: the imp module is deprecated
# and no longer exists from Python 3.12 onwards.
import importlib
es = importlib.reload(es)
data = importlib.reload(data)

In [None]:
# Seed the CPU generator (and every CUDA generator when running on GPU).
th.random.manual_seed(1)
if cuda:
    th.cuda.manual_seed_all(1)

# Image geometry: MNIST is 1 x 28 x 28, anything else is treated as 3 x 32 x 32.
if args.dataset == 'mnist':
    n_channels, size = 1, 28
else:
    n_channels, size = 3, 32

# Every candidate is instantiated, in this order, before args.actor picks one.
# NOTE(review): LeNet is built with a literal 3 channels rather than
# n_channels -- confirm this is intended for single-channel data.
actor = {
    'linear' : nn.Linear(n_channels * size ** 2, n_classes),
    'lenet'  : lenet.LeNet(3, n_classes, size),
    'resnet' : resnet.ResNet(depth=18, n_classes=n_classes),
}[args.actor]

# Layer sizes handed to the relation-network critic.
unary = [2 * n_classes, 64]
binary = [2 * unary[-1], 64]
terminal = [64, 64, 1]
critic = rn.RN(args.batch_size_actor, unary[0], unary, binary, terminal, F.relu, triu=True)

if cuda:
    actor.cuda()
    critic.cuda()

# The guided-ES perturbation helper is only created when it is enabled.
if args.guided_es:
    guided_es = es.GuidedES(args.std, args.alpha, my.n_parameters(actor), args.k)

actor_optim = optim.Adam(actor.parameters(), amsgrad=True)
critic_optim = optim.Adam(critic.parameters(), amsgrad=True)

# Restore networks and optimizers from the checkpoint written at iteration
# args.resume. The actor and its optimizer were previously referred to by the
# undefined names `c` / `c_optim`, which raised NameError on every resume.
if args.resume > 0:
    actor.load_state_dict(th.load('ckpt/%s-actor-%d' % (experiment_id, args.resume)))
    critic.load_state_dict(th.load('ckpt/%s-critic-%d' % (experiment_id, args.resume)))
    actor_optim.load_state_dict(th.load('ckpt/%s-actor_optim-%d' % (experiment_id, args.resume)))
    critic_optim.load_state_dict(th.load('ckpt/%s-critic_optim-%d' % (experiment_id, args.resume)))

# Scores before any training step of this run (printed as iteration 0).
report(actor, -1)

[iteration 00]accuracy 0.454/0.457 | precision 0.267/0.291 | recall 0.054/0.058 | f1 0.089/0.097


In [None]:
# Main loop. Each outer iteration (1) samples batches, (2) scores the actor and
# a set of perturbed copies on them, (3) fits the critic to the score
# differences and (4) updates the actor to maximise the critic's output.
for i in range(args.resume, args.resume + args.n_iterations):
    batch_list = [next(loader) for j in range(args.sample_size_critic)]

    # Reference point: the unperturbed actor.
    # presumably set_requires_grad toggles requires_grad on the parameters -- confirm in `my`
    my.set_requires_grad(actor, False)
    yz, L = forward(actor, batch_list)

    # Perturbed copies evaluated on the very same batches.
    actor_bar = copy.deepcopy(actor)
    my.set_requires_grad(actor_bar, False)
    yzbar_list, Lbar_list, delta_list = [], [], []
    for j in range(args.n_perturbations):
        # presumably resets actor_bar to the actor's current parameters -- confirm in `my`
        my.copy(actor, actor_bar)
        if args.guided_es:
            guided_es.perturb(actor_bar)
        else:
            my.perturb(actor_bar, args.std)
        yz_bar, L_bar = forward(actor_bar, batch_list)
        yzbar_list.append(yz_bar)
        Lbar_list.append(L_bar)
        delta_list.append(L - L_bar)
    if args.tb:
        log_stats(th.cat(Lbar_list, 1), 'L_bar', i)

    # Stack along a new axis 1: one slot per perturbation.
    yzbar_tensor = th.cat([yz_bar.unsqueeze(1) for yz_bar in yzbar_list], 1)
    delta_tensor = th.cat(delta_list, 1)
    iw_tensor = iw(delta_tensor)
    weight_tensor = F.softmax(iw_tensor, 1)
    # Entropy is -sum(p * log p); the leading minus was missing, so the logged
    # value had the wrong sign. log_softmax avoids the NaN that 0 * log(0)
    # produces when a weight underflows to zero.
    entropy = -th.sum(weight_tensor * F.log_softmax(iw_tensor, 1)) / args.sample_size_critic
    if args.tb:
        writer.add_scalar('entropy', entropy, i + 1)

    # Cut the three tensors into matching mini-batches for the critic.
    delta_tensor, weight_tensor = delta_tensor.unsqueeze(2), weight_tensor.unsqueeze(2)
    lambda_batch = lambda tensor: batch(tensor, args.batch_size_criticx, args.batch_size_criticy)
    yzbar_list, delta_list, weight_list = list(map(lambda_batch, [yzbar_tensor, delta_tensor, weight_tensor]))

    # Critic update: its output difference should match the observed score difference.
    my.set_requires_grad(critic, True)
    for j in range(args.n_iterations_critic):
        for yz_bar, delta, weight in zip(yzbar_list, delta_list, weight_list):
            mse = th.sum(weight * (delta - (critic(yz) - critic(yz_bar))) ** 2)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
        if args.tb:
            # Logs the loss of the last mini-batch of this pass.
            writer.add_scalar('mse', mse, i * args.n_iterations_critic + j + 1)

    # Actor update against the frozen critic.
    my.set_requires_grad(actor, True)
    my.set_requires_grad(critic, False)
    for j in range(args.n_iterations_actor):
        if args.tb:
            yz, L = forward(actor, batch_list)
            log_stats(L, 'L', i * args.n_iterations_actor + j)
        else:
            # The score is only needed for logging; skip computing it otherwise.
            yz, = forward(actor, batch_list, L=False)

        # Reduce to a scalar so backward() also works when the critic returns
        # more than one value (sample_size_critic > 1); for a single value the
        # mean is that value.
        objective = -th.mean(critic(yz))
        actor_optim.zero_grad()
        objective.backward()
        actor_optim.step()

    if args.report_every > 0 and (i + 1) % args.report_every == 0:
        report(actor, i)

    if args.ckpt_every > 0 and (i + 1) % args.ckpt_every == 0:
        ckpt(actor, critic, actor_optim, critic_optim, i)