In [1]:
import argparse
import collections
import copy
import math
import os
import pickle
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

import my
import lenet
import resnet

In [2]:
# Hyper-parameters for this run, set in place of command-line parsing.
# Trailing comments list the values that have been tried.
args = argparse.Namespace(
    n_iterations_critic=100,  # 25/50/100 -- critic passes per outer iteration
    iw='quadratic',           # ''/linear/quadratic -- importance-weight transform
    gpu=0,                    # CUDA device index; a negative value selects the CPU
    n_iterations=5000,        # outer training iterations
    n_perturbations=25,       # 25/50/100 -- perturbed classifiers per iteration
    batch_size=50,
    std=1e-1,                 # 1/0.1/0.01 -- passed to my.perturb
    tau=1e-1,                 # temperature used when forming importance weights
)

# Command-line equivalent of the namespace above, kept disabled for notebook use.
'''
parser = argparse.ArgumentParser()
parser.add_argument('--n-iterations-critic', type=int)
parser.add_argument('--iw', type=str, default='')
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-perturbations', type=int)
parser.add_argument('--batch-size', type=int)
parser.add_argument('--std', type=float)
parser.add_argument('--tau', type=float)
args = parser.parse_args()
'''

# Set to 0 to print per-stage timings inside the training loop.
verbose = None

In [3]:
# A non-negative args.gpu selects that CUDA device; a negative value keeps
# everything on the CPU.
cuda = args.gpu >= 0
if cuda:
    th.cuda.set_device(args.gpu)

# Fix the random seed (and the CUDA generators when a GPU is in use) so runs
# are repeatable.
th.manual_seed(1)
if cuda:
    th.cuda.manual_seed_all(1)

# Load CIFAR-10 through the project helper.
train_x, train_y, test_x, test_y = my.load_cifar10(rbg=True, torch=True)
# train_x, train_y, test_x, test_y = my.load_cifar10(rbg=False, torch=True)

# Evaluation loader: large batches, in dataset order, keeping the last partial batch.
test_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=4096, drop_last=False)

# Number of classes derived from the label range seen in the training set.
n_classes = int(train_y.max() - train_y.min() + 1)

In [20]:
def TrainLoader():
    """Endlessly yield training mini-batches of ``args.batch_size`` samples.

    Batches come from a shuffled ``DataLoader`` over ``(train_x, train_y)``.
    When one pass over the data is exhausted a new shuffled loader is created,
    so the generator never finishes on its own. Each batch is moved to the GPU
    when the module-level ``cuda`` flag is set.

    Yields:
        An ``(x, y)`` tuple for one mini-batch.
    """
    def new_pass():
        # A fresh DataLoader per pass gives a fresh shuffle.
        dataset = TensorDataset(train_x, train_y)
        return iter(DataLoader(dataset, args.batch_size, shuffle=True))

    def to_device(x, y):
        if cuda:
            return x.cuda(), y.cuda()
        return x, y

    batches = new_pass()
    while True:
        try:
            batch = next(batches)
        except StopIteration:
            # Current pass finished: start the next one and take its first batch.
            batches = new_pass()
            batch = next(batches)
        yield to_device(*batch)

train_loader = TrainLoader()

def forward(y_bar, y):
    """Build the critic input for one mini-batch.

    Applies a softmax over dim 1 of the classifier outputs ``y_bar``, one-hot
    encodes the labels ``y`` with ``my.onehot``, concatenates the two along
    dim 1 (labels first) and flattens the result into a single row.

    Returns:
        A tensor of shape ``(1, -1)`` holding the whole batch.
    """
    probabilities = F.softmax(y_bar, dim=1)
    targets = my.onehot(y, n_classes)
    paired = th.cat((targets, probabilities), dim=1)
    return paired.view(1, -1)

def L_mini_batch(y_bar, y):
    """Score one mini-batch with ``my.nd_f_beta``.

    The predicted class is the arg-max of ``y_bar`` along dim 1; the score
    against the labels ``y`` is returned reshaped to a one-element tensor so
    that several scores can be concatenated.
    """
    _, predicted = th.max(y_bar, 1)
    score = my.nd_f_beta(predicted, y, n_classes)
    return score.view(1)

def global_stats(c, loader):
    """Evaluate classifier ``c`` on ``loader`` via ``my.global_stats``.

    Four statistics are requested: ``my.accuracy`` as-is, and
    ``my.nd_precision`` / ``my.nd_recall`` / ``my.nd_f_beta`` with
    ``n_classes`` supplied as their third argument.

    Returns:
        An ``OrderedDict`` with keys ``accuracy``, ``precision``, ``recall``
        and ``f1`` mapped to Python scalars.
    """
    def with_n_classes(stat):
        # Adapt a three-argument statistic to the (y_bar, y) signature.
        def bound(y_bar, y):
            return stat(y_bar, y, n_classes)
        return bound

    stats = (
        my.accuracy,
        with_n_classes(my.nd_precision),
        with_n_classes(my.nd_recall),
        with_n_classes(my.nd_f_beta),
    )
    keys = ('accuracy', 'precision', 'recall', 'f1')
    values = []
    for result in my.global_stats(c, loader, stats):
        values.append(result.item())
    return collections.OrderedDict(zip(keys, values))

In [5]:
# Classifier under training. Alternative architectures are kept for reference.
# c = my.MLP((3072, n_classes), F.relu)
# c = my.MLP((3072,) + (1024,) + (n_classes,), F.relu)
# c = my.MLP((3072,) + (1024,) * 2 + (n_classes,), F.relu)
# c = my.MLP((3072,) + (1024,) * 3 + (n_classes,), F.relu)
c = lenet.LeNet(3, n_classes)
# c = resnet.ResNet(depth=18, n_classes=n_classes)

# Critic that scores the flattened (one-hot label, softmax output) batch
# produced by forward(). NOTE(review): the 2 * n_classes argument presumably
# is the per-sample width of that input -- confirm against my.RN.
critic = my.RN(
    args.batch_size,
    2 * n_classes,
    (),
    (4 * n_classes, 64, 256),
    (256, 64, 1),
    F.relu,
    triu=True,
)

if cuda:
    c.cuda()
    critic.cuda()

# Separate optimisers; the classifier's Adam uses a larger eps than the default.
c_optim = Adam(c.parameters(), eps=1e-3)
critic_optim = Adam(critic.parameters())

# Report test-set statistics of the untrained classifier.
for key, value in global_stats(c, test_loader).items():
    print(key, value)

accuracy 0.09609999507665634
precision 0.03016090951859951
recall 0.09610000252723694
f1 0.04590865224599838


In [27]:
# Per-iteration training history; one dict is appended per outer iteration.
hist = []

# Transform applied to t = L_c - L_bar before the temperature-scaled softmax
# that produces the importance weights. '' yields uniform weights.
iw = {
    '' : lambda x: th.zeros_like(x),
    'linear' : lambda x: x,
    'quadratic' : lambda x: x ** 2
}

for i in range(args.n_iterations):
    hist.append({})
#     hist[-1]['critic_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(critic.state_dict()))
#     hist[-1]['critic_optim_state_dict'] = my.optim_state_dict_gpu2cpu(critic_optim.state_dict())
#     hist[-1]['c_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(c.state_dict()))
#     hist[-1]['c_optim_state_dict'] = my.optim_state_dict_gpu2cpu(c_optim.state_dict())

    if verbose == 0:
        t0 = time.time()

    x, y = next(train_loader)

    # Unperturbed classifier: outputs stay in the autograd graph because the
    # classifier update below back-propagates through them.
    y_c = c(x)
    L_c = L_mini_batch(y_c, y)

    # Evaluate randomly perturbed copies of the classifier on the same batch.
    y_bar_list, L_bar_list, t_list = [], [], []
    for j in range(args.n_perturbations):
        c_bar = my.perturb(c, args.std)
        # The perturbed outputs are only consumed through arg-max scores and
        # detached critic inputs, so no autograd graph is needed for them.
        with th.no_grad():
            y_bar = c_bar(x)
        y_bar_list.append(y_bar)
        L_bar_list.append(L_mini_batch(y_bar, y))
        t_list.append(L_c - L_bar_list[-1])
    t_tensor = th.cat(t_list, 0)

    # Importance weights: softmax over the perturbations of iw(t) / tau.
    # The temperature must sit inside the exponential; dividing exp(.) by tau
    # and then normalising cancels tau completely. softmax is also safe
    # against overflow when tau is small.
    w_tensor = F.softmax(iw[args.iw](t_tensor) / args.tau, 0)

    hist[-1]['L_bar_tensor'] = th.cat(L_bar_list, 0)
    hist[-1]['w_tensor'] = w_tensor

    if verbose == 0:
        t1 = time.time()
        print('[iteration %d]t1 - t0: %f' % (i + 1, t1 - t0))

    # Fit the critic so that critic(z_c) - critic(z_bar) regresses onto
    # t = L_c - L_bar, one weighted step per perturbation.
    z_c = forward(y_c, y)
    z_detached = z_c.detach()
    z_bar_list = [forward(y_bar, y).detach() for y_bar in y_bar_list]
    w_list = w_tensor.tolist()  # loop-invariant, so converted once
    for j in range(args.n_iterations_critic):
        for z_bar, t, w in zip(z_bar_list, t_list, w_list):
            # critic(z_detached) is recomputed every step because the critic
            # parameters change after each optimiser step.
            delta = critic(z_detached) - critic(z_bar)
            mse = th.sum(w * (t - delta) ** 2)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
#     assert not my.module_isnan(critic)

    if verbose == 0:
        t2 = time.time()
        print('[iteration %d]t2 - t1: %f' % (i + 1, t2 - t1))

    # Update the classifier to increase the critic's value of its own output.
    # Gradients also accumulate on the critic here, but they are cleared by
    # critic_optim.zero_grad() before the critic's next step.
    objective = -th.mean(critic(z_c))
    c_optim.zero_grad()
    objective.backward()
    c_optim.step()
#     assert not my.module_isnan(c)

    if verbose == 0:
        t3 = time.time()
        print('[iteration %d]t3 - t2: %f' % (i + 1, t3 - t2))

    # Full test-set evaluation after every iteration.
    hist[-1]['stats'] = global_stats(c, test_loader)

    if (i + 1) % 1 == 0:
        print('[iteration %d]f1: %f' % (i + 1, hist[-1]['stats']['f1']))

KeyboardInterrupt: 

In [None]:
# Persist the training history. The file name encodes every hyper-parameter
# in args (sorted by name) so that runs with different settings do not collide.
keys = sorted(vars(args).keys())
path = 'hist/' + '-'.join('%s-%s' % (key, str(getattr(args, key))) for key in keys)
# Create the output directory if needed so a long run is not lost at save time.
os.makedirs(os.path.dirname(path), exist_ok=True)
# The context manager guarantees the file is flushed and closed.
with open(path, 'wb') as hist_file:
    pickle.dump(hist, hist_file)