In [1]:
import argparse
import collections
import copy
import itertools
import time
import numpy as np
import sklearn.metrics as metrics
import torch as th
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils
import data
import my
import lenet
import resnet
import rn

In [2]:
# Experiment hyper-parameters, collected in a Namespace so the notebook mirrors a CLI run.
args = argparse.Namespace(
    actor_iterations=25,    # upper bound on classifier updates per outer iteration (loop may stop early)
    c_batch_size=8,         # number of fresh training batches drawn per actor step (passed as n_batches)
    critic_batch_size=8,    # number of fixed batches used to score perturbations / fit the critic (n_batches)
    critic_iterations=25,   # passes over all perturbations when fitting the critic
    gpu=3,                  # CUDA device index; a negative value selects the CPU
    n_iterations=150,       # number of outer training iterations
    n_perturbations=100,    # perturbed copies of the classifier per outer iteration
    sample_size=50,         # examples per batch (also passed to the critic constructor)
    std=1e-1,               # perturbation scale, and the max per-weight drift allowed per actor phase
    tau=1e-1,               # temperature of the exp(t**2 / tau) perturbation weights
)

# 0 enables the per-phase timing prints in the training loop; None disables them.
verbose = None

In [3]:
# Device selection: any negative GPU index means "run on the CPU".
cuda = args.gpu >= 0
if cuda:
    new_tensor = th.cuda.FloatTensor
    th.cuda.set_device(args.gpu)
else:
    new_tensor = th.FloatTensor

train_x, train_y, test_x, test_y = data.load_cifar10(rbg=True, torch=True)

train_set = utils.data.TensorDataset(train_x, train_y)
test_set = utils.data.TensorDataset(test_x, test_y)
# Evaluation loader: large fixed batches, unshuffled (DataLoader default), final partial batch kept.
test_loader = utils.data.DataLoader(test_set, batch_size=4096, drop_last=False)

# Presumably labels are contiguous integers (0..9 for CIFAR-10); the class count is their span.
n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def forward(c, batch):
    """Flatten one labelled batch into a single feature row for the critic.

    Concatenates the one-hot labels with ``c``'s softmax outputs along dim 1,
    then reshapes the whole result into a single row of shape ``(1, -1)``.

    Args:
        c: classifier module applied to the batch inputs.
        batch: an ``(x, y)`` pair of inputs and integer labels.

    Returns:
        A ``(1, k)`` tensor; presumably ``k == batch_size * 2 * n_classes`` if
        ``my.onehot`` yields ``(batch_size, n_classes)`` -- TODO confirm.
    """
    inputs, labels = batch
    one_hot = my.onehot(labels, n_classes)
    probabilities = F.softmax(c(inputs), 1)
    row = th.cat((one_hot, probabilities), 1)
    return row.view(1, -1)

def global_scores(c, loader):
    """Evaluate ``c`` over ``loader`` and return named classification metrics.

    The metric callables are passed to ``my.global_scores``, which presumably
    gathers labels and predictions over the whole loader before scoring.

    Returns:
        OrderedDict mapping 'accuracy', 'precision', 'recall', 'f1' (in that
        order) to the ``.item()`` of each returned value. Precision, recall
        and F1 are micro-averaged.
    """
    def micro(metric):
        # Bind average='micro' while keeping the (y, y_bar) call signature.
        return lambda y, y_bar: metric(y, y_bar, average='micro')

    named_metrics = (
        ('accuracy', metrics.accuracy_score),
        ('precision', micro(metrics.precision_score)),
        ('recall', micro(metrics.recall_score)),
        ('f1', micro(metrics.f1_score)),
    )
    names = tuple(name for name, _ in named_metrics)
    scorers = tuple(scorer for _, scorer in named_metrics)
    results = [result.item() for result in my.global_scores(c, loader, scorers)]
    return collections.OrderedDict(zip(names, results))

def L_batches(c, batches):
    """Score classifier ``c`` on each batch with micro-averaged F1.

    Args:
        c: classifier; ``th.max(c(x), 1)[1]`` is taken as the predicted class.
        batches: iterable of ``(x, y)`` tensor pairs (inputs, integer labels).

    Returns:
        ``new_tensor`` of shape ``(n_batches, 1)``: one F1 score per batch.
    """
    scores = []
    # Scoring only: no autograd graph is needed for these forward passes.
    with th.no_grad():
        for x, y in batches:
            y_pred = th.max(c(x), 1)[1]
            # sklearn's signature is f1_score(y_true, y_pred), so labels go first.
            # Move to the CPU so sklearn can convert the tensors to NumPy arrays.
            f1 = metrics.f1_score(y.cpu(), y_pred.cpu(), average='micro')
            # float() works whether sklearn returns a NumPy scalar or a Python float.
            scores.append([float(f1)])
    return new_tensor(scores)

def sample(dataset, batch_size, n_batches, cuda):
    """Draw up to ``n_batches`` shuffled mini-batches from ``dataset``.

    Args:
        dataset: a ``torch.utils.data.Dataset`` whose batches unpack as ``(x, y)``.
        batch_size: number of examples per batch.
        n_batches: maximum number of batches to return. Fewer are returned if
            the dataset is exhausted first.
        cuda: if true, each batch's tensors are moved to the current CUDA device.

    Returns:
        A list of ``(x, y)`` tuples. It is always a materialized list, never a
        lazy iterator, so callers can iterate it repeatedly and concatenate it with ``+``.
    """
    loader = utils.data.DataLoader(dataset, batch_size, shuffle=True)
    batches = []
    # islice stops after exactly n_batches. takewhile would fetch (and
    # discard) one extra batch just to test its predicate.
    for x, y in itertools.islice(loader, n_batches):
        if cuda:
            x, y = x.cuda(), y.cuda()
        batches.append((x, y))
    return batches

In [5]:
# Seed torch's CPU RNG (and every CUDA device's RNG when enabled) for reproducibility.
SEED = 1
th.random.manual_seed(SEED)
if cuda:
    th.cuda.manual_seed_all(SEED)

# Classifier ("actor"). Previously tried alternatives: mlp.MLP stacks with 0-3
# hidden 1024-unit layers (the `mlp` module is not imported here) and
# resnet.ResNet(depth=18, n_classes=n_classes).
c = lenet.LeNet(3, n_classes)

# Critic over `args.sample_size` items of 2 * n_classes features each. The
# layer-size tuples' exact roles are defined by rn.RN -- see the `rn` module.
critic = rn.RN(
    args.sample_size,
    2 * n_classes,
    (),
    (4 * n_classes, 64, 256),
    (256, 64, 1),
    F.relu,
    triu=True,
)

if cuda:
    c.cuda()
    critic.cuda()

# Adam with an enlarged epsilon for the classifier; default settings for the critic.
c_optim = optim.Adam(c.parameters(), eps=1e-3)
critic_optim = optim.Adam(critic.parameters())

# Baseline test-set scores of the untrained classifier.
for name, score in global_scores(c, test_loader).items():
    print(name, score)

accuracy 0.0961
precision 0.0961
recall 0.0961
f1 0.0961


In [6]:
def _timestamp():
    """Return ``time.time()`` after waiting for queued CUDA work to finish.

    CUDA kernels launch asynchronously. Without a synchronize, the timing
    prints below would largely measure launch overhead rather than compute.
    """
    if cuda:
        th.cuda.synchronize()
    return time.time()

# Each outer iteration does three things:
#   (1) score random perturbations of the classifier on a fixed set of batches;
#   (2) fit the critic to predict the resulting score differences;
#   (3) update the classifier to maximize the critic's output.
hist = []
# Fixed batches, reused by every outer iteration for scoring and critic fitting.
critic_batches = sample(train_set, args.sample_size, args.critic_batch_size, cuda)
for i in range(args.n_iterations):
    hist.append({})

    if verbose == 0:
        t0 = _timestamp()

    # (1) Perturb the classifier and record per-batch F1 differences.
    L_c = L_batches(c, critic_batches)
    c_bar_list = []
    L_bar_list = []
    t_list = []
    for j in range(args.n_perturbations):
        c_bar = copy.deepcopy(c)
        my.set_requires_grad(c_bar, False)
        c_bar_list.append(my.perturb(c_bar, args.std))
        L_bar_list.append(L_batches(c_bar_list[-1], critic_batches))
        # Critic target: score of c minus score of the perturbed copy.
        t_list.append(L_c - L_bar_list[-1])
    # Weights exp(t**2 / tau), normalized across perturbations (dim 1) for each
    # batch row, so larger score changes receive more weight.
    w_tensor = th.cat([th.exp(t ** 2 / args.tau) for t in t_list], 1)
    w_list = th.chunk((w_tensor / th.sum(w_tensor, 1, keepdim=True)).detach(), args.n_perturbations, 1)

    hist[-1]['L_bar_list'] = L_bar_list
    hist[-1]['w_list'] = w_list

    if verbose == 0:
        t1 = _timestamp()
        print('[iteration %d]t1 - t0: %f' % (i + 1, t1 - t0))

    # (2) Fit the critic so that critic(y) - critic(y_bar) matches t (weighted squared error).
    y = th.cat([forward(c, batch) for batch in critic_batches], 0).detach()
    y_bar_list = [th.cat([forward(c_bar, batch) for batch in critic_batches], 0).detach() for c_bar in c_bar_list]
    for j in range(args.critic_iterations):
        for y_bar, t, w in zip(y_bar_list, t_list, w_list):
            delta = critic(y) - critic(y_bar)
            mse = th.sum(w * (t - delta)**2)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()

    if verbose == 0:
        t2 = _timestamp()
        print('[iteration %d]t2 - t1: %f' % (i + 1, t2 - t1))

    # (3) Update the classifier to maximize the critic's mean output.
    # Snapshot the parameters: stop early once any weight drifts more than args.std from them.
    c_params = copy.deepcopy(tuple(c.parameters()))
    for j in range(args.actor_iterations):
        batches = critic_batches + sample(train_set, args.sample_size, args.c_batch_size, cuda)
        y_bar = th.cat([forward(c, batch) for batch in batches], 0)
        objective = -th.mean(critic(y_bar))
        c_optim.zero_grad()
        objective.backward()
        c_optim.step()
        if any(float(th.max(th.abs(p - q))) > args.std for p, q in zip(c_params, c.parameters())):
            break

    if verbose == 0:
        t3 = _timestamp()
        print('[iteration %d]t3 - t2: %f' % (i + 1, t3 - t2))

    hist[-1]['stats'] = global_scores(c, test_loader)
    # 1-based numbering, consistent with the timing prints above.
    print('[iteration %d]%f' % (i + 1, hist[-1]['stats']['f1']))

[iteration 0]0.100500
[iteration 1]0.123600
[iteration 2]0.195000
[iteration 3]0.225700
[iteration 4]0.217200
[iteration 5]0.211500
[iteration 6]0.235500
[iteration 7]0.276500
[iteration 8]0.300300
[iteration 9]0.315100
[iteration 10]0.323700
[iteration 11]0.338200
[iteration 12]0.354500
[iteration 13]0.365500
[iteration 14]0.376800
[iteration 15]0.385200
[iteration 16]0.391300
[iteration 17]0.403000
[iteration 18]0.417200
[iteration 19]0.418500
[iteration 20]0.412300
[iteration 21]0.431400
[iteration 22]0.436200
[iteration 23]0.440800
[iteration 24]0.455700
[iteration 25]0.460100
[iteration 26]0.468000
[iteration 27]0.465000
[iteration 28]0.478400
[iteration 29]0.478900
[iteration 30]0.485800
[iteration 31]0.487000
[iteration 32]0.489400
[iteration 33]0.494400
[iteration 34]0.498500
[iteration 35]0.507500
[iteration 36]0.504300
[iteration 37]0.506400
[iteration 38]0.501800
[iteration 39]0.499700
[iteration 40]0.503800
[iteration 41]0.512700
[iteration 42]0.504000
[iteration 43]0.50490