In [1]:
import argparse
import copy
import collections
import time
import numpy as np
import sklearn.metrics as metrics
import torch as th
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils
import tensorboardX as tb
import data
import my
import lenet
import resnet
import rn

In [5]:
# Experiment hyper-parameters; a Namespace stands in for parsed command-line flags.
args = argparse.Namespace(
    batch_size=50,
    gpu=0,                    # CUDA device index; a negative value keeps everything on the CPU
    n_iterations=1000,        # outer training iterations
    critic_n_iterations=100,  # critic training iterations (NOTE(review): the training cell must read this exact name)
    n_perturbations=25,       # perturbed copies of the classifier drawn per iteration
    std=1e-1,                 # perturbation scale handed to my.perturb
    tau=1e-1,                 # divisor of t**2 in the perturbation weights
)

# Name the TensorBoard run after every hyper-parameter, in alphabetical order.
keys = sorted(vars(args))
run_id = 'standard-' + '-'.join('%s-%s' % (key, str(vars(args)[key])) for key in keys)
writer = tb.SummaryWriter('runs/' + run_id)

# verbose == 0 makes the training loop print per-phase timings; None turns them off.
verbose = None

In [7]:
# CIFAR-10 as torch tensors via the project loader.
# NOTE(review): rbg=True presumably selects RGB images -- confirm in data.load_cifar10.
train_x, train_y, test_x, test_y = data.load_cifar10(rbg=True, torch=True)
# Alternative loader: my.load_cifar10(rbg=False, torch=True)

train_set = utils.data.TensorDataset(train_x, train_y)
test_set = utils.data.TensorDataset(test_x, test_y)
# Evaluation batches of 4096 samples; the final partial batch is kept.
test_loader = utils.data.DataLoader(test_set, batch_size=4096, drop_last=False)

# Class count taken as the span of the training labels (assumes contiguous integer labels).
n_classes = int(1 + train_y.max() - train_y.min())

In [8]:
def TrainLoader():
    """Yield (x, y) training batches forever, reshuffling at the start of every epoch.

    Reads the notebook globals ``train_set``, ``args.batch_size`` and ``cuda``.
    ``cuda`` is looked up each time a batch is produced, so it only has to exist
    by the time the first batch is requested (it is set in the device-setup cell).
    """
    def to_device(x, y):
        # Copy the batch to the current GPU when CUDA is enabled.
        if cuda:
            return x.cuda(), y.cuda()
        return x, y

    while True:
        # A fresh shuffled DataLoader per epoch; when it runs dry, start the next epoch.
        for x, y in utils.data.DataLoader(train_set, args.batch_size, shuffle=True):
            yield to_device(x, y)

train_loader = TrainLoader()

def forward(c, batch):
    """Build the single critic input row for one (x, y) batch.

    The one-hot targets and the softmax of c's outputs are concatenated along
    dim 1, and the result is flattened into one row of shape (1, -1).
    Assumes my.onehot(y, n_classes) and c(x) are both (batch, n_classes) -- TODO confirm.
    """
    inputs, labels = batch
    targets = my.onehot(labels, n_classes)
    probabilities = F.softmax(c(inputs), dim=1)
    paired = th.cat([targets, probabilities], dim=1)
    return paired.view(1, -1)

def global_scores(c, loader):
    """Score classifier c over loader.

    Returns an OrderedDict with the keys 'accuracy', 'precision', 'recall', 'f1'
    (in that order), each mapped to ``.item()`` of the value the project helper
    my.global_scores produced for the corresponding metric callable.
    """
    def micro(score_fn):
        # Bind average='micro' for the sklearn metric.
        # NOTE(review): for single-label multiclass data, micro-averaged precision,
        # recall and F1 are all identical to accuracy, so these entries coincide.
        return lambda y, y_bar: score_fn(y, y_bar, average='micro')

    named_scores = collections.OrderedDict([
        ('accuracy', metrics.accuracy_score),
        ('precision', micro(metrics.precision_score)),
        ('recall', micro(metrics.recall_score)),
        ('f1', micro(metrics.f1_score)),
    ])
    raw_values = my.global_scores(c, loader, tuple(named_scores.values()))
    values = [value.item() for value in raw_values]
    return collections.OrderedDict(zip(named_scores.keys(), values))

def L_batch(c, batch):
    """Micro-averaged F1 score of classifier c on one (x, y) batch, as a Python float.

    Predictions are the argmax over dim 1 of c(x). Both label tensors are copied
    to the CPU as NumPy arrays before scoring, so the batch may live on the GPU
    (the training loader moves batches there when CUDA is enabled).
    """
    x, y = batch
    y_pred = th.max(c(x), 1)[1]
    # sklearn's signature is (y_true, y_pred), and it cannot convert CUDA tensors
    # implicitly, so hand it host-side NumPy arrays in the documented order.
    score = metrics.f1_score(y.cpu().numpy(), y_pred.cpu().numpy(), average='micro')
    # float() accepts both NumPy scalars and the plain floats newer sklearn versions return.
    return float(score)

In [10]:
# Device selection: a negative args.gpu keeps everything on the CPU, otherwise that GPU
# is used. new_tensor builds float tensors on the selected device.
cuda = args.gpu >= 0
new_tensor = th.cuda.FloatTensor if cuda else th.FloatTensor
if cuda:
    th.cuda.set_device(args.gpu)
    th.cuda.manual_seed_all(1)

# Seed torch's CPU generator (GPU generators are seeded above when CUDA is on).
th.random.manual_seed(1)

# Classifier being trained. Other architectures that were tried:
#   mlp.MLP((3072,) + (1024,) * k + (n_classes,), F.relu)  for k in 0..3  (mlp is not imported here)
#   resnet.ResNet(depth=18, n_classes=n_classes)
c = lenet.LeNet(3, n_classes)

# Critic applied to the flattened (one-hot label, softmax prediction) rows produced by forward().
# NOTE(review): the meaning of each positional argument depends on rn.RN's signature,
# which is not visible here.
critic = rn.RN(
    args.batch_size,
    2 * n_classes,
    (),
    (4 * n_classes, 64, 256),
    (256, 64, 1),
    F.relu,
    triu=True,
)

if cuda:
    c.cuda()
    critic.cuda()

# Adam for both networks; the classifier uses a much larger epsilon (1e-3) than the default.
c_optim = optim.Adam(c.parameters(), eps=1e-3)
critic_optim = optim.Adam(critic.parameters())

# Baseline: test-set scores of the untrained classifier.
for key, value in global_scores(c, test_loader).items():
    print(key, value)

TypeError: argument 0 is not a Variable

In [None]:
# Training loop. Each iteration:
#   (1) scores randomly perturbed copies of c on one batch (t = L(c) - L(c_bar)),
#   (2) fits the critic so that critic(y) - critic(y_bar) regresses t,
#   (3) updates c to maximise the critic's output on its own predictions.
hist = []

for i in range(args.n_iterations):
    hist.append({})

    if verbose == 0:
        t0 = time.time()

    # (1) Perturbation phase: c is frozen while its copies are scored.
    batch = next(train_loader)
    my.set_requires_grad(c, False)
    L_c = L_batch(c, batch)
    c_bar_list = []
    L_bar_list = []
    t_list = []
    for j in range(args.n_perturbations):
        c_bar = copy.deepcopy(c)
        my.set_requires_grad(c_bar, False)
        # my.perturb presumably adds noise of scale args.std and returns the model -- confirm.
        c_bar_list.append(my.perturb(c_bar, args.std))
        L_bar_list.append(L_batch(c_bar_list[-1], batch))
        t_list.append(L_c - L_bar_list[-1])
    L_bar_tensor = new_tensor(L_bar_list)
    t_tensor = new_tensor(t_list)
    # Per-perturbation weights for the critic loss, normalised to sum to 1.
    # NOTE(review): exp(+t**2 / tau) up-weights perturbations with a LARGE |t|;
    # confirm the sign is intended (a Gaussian-style weight would use exp(-t**2 / tau)).
    w_tensor = th.exp(t_tensor ** 2 / args.tau)
    w_tensor /= th.sum(w_tensor)
    w_list = w_tensor.tolist()

    # Spread of the perturbed scores and entropy of the weight distribution
    # (tag names kept from the original runs).
    writer.add_scalar('th.min(L_bar_tensorz)', th.min(L_bar_tensor).item(), i)
    writer.add_scalar('th.max(L_bar_tensorz)', th.max(L_bar_tensor).item(), i)
    writer.add_scalar('entropy', -th.sum(w_tensor * th.log(w_tensor)).item(), i)

    if verbose == 0:
        t1 = time.time()
        print('[iteration %d]t1 - t0: %f' % (i + 1, t1 - t0))

    # (2) Critic phase: one weighted optimiser step per perturbation, repeated
    #     args.critic_n_iterations times (the attribute name defined in the config cell).
    y = forward(c, batch).detach()
    y_bar_list = [forward(c_bar, batch) for c_bar in c_bar_list]
    for j in range(args.critic_n_iterations):
        for y_bar, t, w in zip(y_bar_list, t_list, w_list):
            delta = critic(y) - critic(y_bar)
            mse = th.sum(w * (t - delta) ** 2)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
        # Log the loss of the last step of this critic iteration as a plain number.
        writer.add_scalar('mse', mse.item(), i * args.critic_n_iterations + j)

    if verbose == 0:
        t2 = time.time()
        print('[iteration %d]t2 - t1: %f' % (i + 1, t2 - t1))

    # (3) Classifier phase: gradient ascent on the critic's score of c's predictions.
    my.set_requires_grad(c, True)
    y_bar = forward(c, batch)
    objective = -th.mean(critic(y_bar))
    c_optim.zero_grad()
    objective.backward()
    c_optim.step()

    if verbose == 0:
        t3 = time.time()
        print('[iteration %d]t3 - t2: %f' % (i + 1, t3 - t2))

    hist[-1]['stats'] = global_scores(c, test_loader)
    for key, value in hist[-1]['stats'].items():
        writer.add_scalar(key, value, i)

    # Report interval of 1: prints the test F1 after every iteration.
    if (i + 1) % 1 == 0:
        print('[iteration %d]%f' % (i + 1, hist[-1]['stats']['f1']))