In [1]:
import argparse
import copy
import collections
import pickle
import time
import sklearn.metrics as metrics
import tensorboardX as tb
import torch as th
import torch.nn.functional as F
import torch.optim as optim
import torch.utils as utils
import data
import my
import lenet
import resnet
import rn

In [11]:
# Hyperparameters for this run, built in one call.  The equivalent command-line
# parser is kept (disabled) in the string literal that follows.
args = argparse.Namespace(
    batch_size_c=100,         # minibatch size for data.BalancedDataLoader and rn.RN
    batch_size_critic=1,      # perturbations per critic update
    ckpt_every=1000,
    ckpt_id='',
    gpu=0,                    # a negative value selects the CPU
    iw='none',                # other choices: 'sqrt', 'linear', 'quadratic'
    log_every=1,
    n_iterations=50,
    n_iterations_critic=25,
    n_perturbations=50,
    resume=0,
    std=0.1,
    tau=0.1,
    topk=0,
)

'''
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size-c', type=int, default=None)
parser.add_argument('--batch-size-critic', type=int, default=None)
parser.add_argument('--ckpt-every', type=int, default=None)
parser.add_argument('--ckpt-id', type=str, default='')
parser.add_argument('--gpu', type=int, default=None)
parser.add_argument('--iw', type=str, default=None)
parser.add_argument('--log-every', type=int, default=None)
parser.add_argument('--n-iterations', type=int, default=None)
parser.add_argument('--n-iterations-critic', type=int, default=None)
parser.add_argument('--n-perturbations', type=int, default=None)
parser.add_argument('--resume', type=int, default=None)
parser.add_argument('--std', type=float, default=None)
parser.add_argument('--tau', type=float, default=None)
parser.add_argument('--topk', type=int, default=0)
args = parser.parse_args()
'''

# The training loop prints per-phase timings only when verbose == 0.
verbose = None

# Run identity: every hyperparameter except pure bookkeeping options, in sorted
# order, appended to a fixed prefix as "name-value" pairs.  The same id names
# the TensorBoard run directory and the checkpoint files.
keys = sorted(vars(args).keys())
excluded = ('ckpt_every', 'ckpt_id', 'gpu', 'log_every', 'n_iterations', 'resume')
identity_keys = [k for k in keys if k not in excluded]
identity_parts = ['%s-%s' % (k, str(getattr(args, k))) for k in identity_keys]
experiment_id = 'cifar10-9-1-parameter-' + '-'.join(identity_parts)
log_dir = 'runs/' + experiment_id
writer = tb.SummaryWriter(log_dir)

In [3]:
# A negative args.gpu selects the CPU.  Otherwise the given CUDA device becomes
# current, and new_tensor builds float tensors on it.
cuda = args.gpu >= 0
new_tensor = th.cuda.FloatTensor if cuda else th.FloatTensor
if cuda:
    th.cuda.set_device(args.gpu)

# Binary relabelling of CIFAR-10.  The keys are presumably half-open label ranges
# ([0, 9) -> 0, [9, 10) -> 1); confirm against data.load_cifar10.
labelling = {(0, 9) : 0, (9, 10) : 1}
train_x, train_y, test_x, test_y = data.load_cifar10(labelling, rbg=True, torch=True)
# train_x, train_y, test_x, test_y = data.load_cifar10(labelling, rbg=False, torch=True)

# Non-shuffled, non-dropping loaders used for whole-set evaluation (global_scores).
eval_batch_size = 4096
train_set = utils.data.TensorDataset(train_x, train_y)
test_set = utils.data.TensorDataset(test_x, test_y)
train_loader = utils.data.DataLoader(train_set, eval_batch_size, drop_last=False)
test_loader = utils.data.DataLoader(test_set, eval_batch_size, drop_last=False)

# Training minibatches are drawn from this loader with next() in the training loop.
loader = data.BalancedDataLoader(train_x, train_y, args.batch_size_c, cuda)

# Class count as the span of the (assumed contiguous, integer) label values.
n_classes = int(train_y.max() - train_y.min() + 1)

In [4]:
def forward(y, y_bar):
    """Encode one batch of labels and classifier logits as a single critic input row.

    The one-hot encoding of ``y`` (``my.onehot`` with the global ``n_classes``) is
    concatenated along dim 1 with the softmax of ``y_bar`` over dim 1.  The
    result is then flattened into one row of shape ``(1, -1)``.
    """
    # TODO batchify
    label_part = my.onehot(y, n_classes)
    prediction_part = F.softmax(y_bar, 1)
    joined = th.cat((label_part, prediction_part), 1)
    return joined.view(1, -1)

def L_batch(y, y_bar, average='binary'):
    """Batch F1 score of the arg-max predictions in ``y_bar`` against labels ``y``.

    Args:
        y: ground-truth labels for the batch (a tensor in this notebook).
        y_bar: per-class scores with classes along dim 1; the predicted class is
            the arg-max over that dimension.
        average: forwarded unchanged to ``sklearn.metrics.f1_score``.

    Returns:
        The value ``metrics.f1_score`` returns for these labels and predictions.
    """
    y_bar = th.max(y_bar, 1)[1]
    # sklearn converts its inputs to NumPy arrays.  Current PyTorch releases
    # refuse to do that implicitly for CUDA tensors (the gpu >= 0 path), so copy
    # both sides to host memory explicitly.  This is a no-op copy for CPU tensors.
    if th.is_tensor(y):
        y = y.detach().cpu().numpy()
    y_bar = y_bar.detach().cpu().numpy()
    return metrics.f1_score(y, y_bar, average=average)

def global_scores(c, loader, average='binary'):
    """Whole-dataset accuracy, precision, recall and F1 of classifier ``c``.

    The metric callables are handed to ``my.global_scores`` together with ``c``
    and ``loader``.  Every value it yields is converted with ``.item()``.

    Args:
        c: the classifier being evaluated.
        loader: iterable of evaluation batches.
        average: ``average`` argument for the precision/recall/F1 metrics.

    Returns:
        OrderedDict mapping 'accuracy', 'precision', 'recall', 'f1' to floats.
    """
    def with_average(metric_name):
        # Bind `average` while resolving the sklearn metric lazily, at call time.
        return lambda y, y_bar: getattr(metrics, metric_name)(y, y_bar, average=average)

    named_metrics = (
        ('accuracy', metrics.accuracy_score),
        ('precision', with_average('precision_score')),
        ('recall', with_average('recall_score')),
        ('f1', with_average('f1_score')),
    )
    names = tuple(name for name, _ in named_metrics)
    metric_fns = tuple(fn for _, fn in named_metrics)
    raw_values = my.global_scores(c, loader, metric_fns)
    numbers = [raw.item() for raw in raw_values]
    return collections.OrderedDict(zip(names, numbers))

def _iw_none(x):
    """Zero for every element of x (exp(0) makes all weights equal)."""
    return th.zeros_like(x)


def _iw_sqrt(x):
    """Square root of the magnitude of x; the sign of x is discarded."""
    return x.abs().sqrt()


def _iw_linear(x):
    """Identity transform: returns x itself."""
    return x


def _iw_quadratic(x):
    """Element-wise square of x."""
    return x * x


# Importance-weighting transforms, selected by args.iw.  The training loop
# feeds the score differences t through iw[args.iw] before dividing by tau.
iw = dict(
    none=_iw_none,
    sqrt=_iw_sqrt,
    linear=_iw_linear,
    quadratic=_iw_quadratic,
)

In [12]:
# Seed torch's CPU RNG (and, when running on the GPU, all CUDA RNGs) so model
# initialisation and parameter perturbations are repeatable.
th.random.manual_seed(1)
if cuda:
    th.cuda.manual_seed_all(1)

# Alternative classifiers tried for this experiment:
# c = my.MLP((3072, n_classes), F.relu)
# c = my.MLP((3072,) + (1024,) + (n_classes,), F.relu)
# c = my.MLP((3072,) + (1024,) * 2 + (n_classes,), F.relu)
# c = my.MLP((3072,) + (1024,) * 3 + (n_classes,), F.relu)
# c = lenet.LeNet(3, n_classes)
c = resnet.ResNet(depth=18, n_classes=n_classes)

# The critic reads the rows built by forward(): for each of the args.batch_size_c
# examples, a one-hot label plus a softmax prediction (2 * n_classes values).
# The meaning of the remaining rn.RN arguments is defined in rn -- not visible here.
critic = rn.RN(args.batch_size_c, 2 * n_classes, tuple(), (4 * n_classes, 64, 64), (64,) * 3 + (1,), F.relu, triu=True)

if cuda:
    c.cuda()
    critic.cuda()

c_optim = optim.Adam(c.parameters(), amsgrad=True)
critic_optim = optim.Adam(critic.parameters(), amsgrad=True)

if args.resume > 0:
    ckpt_id = args.ckpt_id if args.ckpt_id else experiment_id
    # File names match those written by the training loop:
    # ckpt/<id>-<part>-<iteration>.  Each file is loaded onto the CPU first, so a
    # run saved on one GPU can be resumed on another GPU or on the CPU.
    # load_state_dict then copies the values onto the device of each model or
    # optimizer.
    for ckpt_part, ckpt_target in (('c', c), ('critic', critic),
                                   ('c_optim', c_optim), ('critic_optim', critic_optim)):
        ckpt_path = 'ckpt/%s-%s-%d' % (ckpt_id, ckpt_part, args.resume)
        ckpt_target.load_state_dict(th.load(ckpt_path, map_location=lambda storage, loc: storage))

# Baseline test-set scores of the freshly initialised (or resumed) classifier.
for key, value in global_scores(c, test_loader).items():
    print(key, value)

accuracy 0.8992
precision 0.0
recall 0.0
f1 0.0


In [None]:
# Each iteration does four things:
#   1. Compare classifier c with args.n_perturbations randomly perturbed copies
#      of itself on one minibatch.
#   2. Regress the critic onto the resulting score differences.
#   3. Step c to maximise the critic's score of c's own outputs.
#   4. Log metrics and, periodically, save checkpoints.
for i in range(args.resume, args.resume + args.n_iterations):
    if verbose == 0:
        t0 = time.time()

    x, y = next(loader)

    y_c = c(x)
    L_c = L_batch(y, y_c)

    # Perturbed, gradient-free copies of c.  For each copy, record how much
    # lower it scores than c on this batch: t = L_c - L_bar.
    y_bar_listz, L_bar_listz, t_listz = [], [], []
    for j in range(args.n_perturbations):
        c_bar = copy.deepcopy(c)
        my.set_requires_grad(c_bar, False)
        my.perturb(c_bar, args.std)
        y_bar_listz.append(c_bar(x))
        L_bar_listz.append(L_batch(y, y_bar_listz[-1]))
        t_listz.append(L_c - L_bar_listz[-1])

    y_bar_tensorz = th.stack(y_bar_listz, 0)
    L_bar_tensorz = new_tensor(L_bar_listz)
    t_tensorz = new_tensor(t_listz)
    # Normalised weights, w proportional to exp(iw(t) / tau).  log_softmax is the
    # numerically stable form of exp(s) / sum(exp(s)): it cannot overflow for a
    # small tau.  It also keeps log(w) finite, so the entropy below can never
    # hit 0 * log(0) = nan when a weight underflows.
    log_w_tensorz = F.log_softmax(iw[args.iw](t_tensorz) / args.tau, dim=0)
    w_tensorz = th.exp(log_w_tensorz)

    writer.add_scalar('th.min(L_bar_tensorz)', th.min(L_bar_tensorz).item(), i + 1)
    writer.add_scalar('th.max(L_bar_tensorz)', th.max(L_bar_tensorz).item(), i + 1)
    writer.add_scalar('entropy', -th.sum(w_tensorz * log_w_tensorz).item(), i + 1)

    # Optionally keep only the args.topk most heavily weighted perturbations.
    # Their weights are not renormalised.
    if args.topk > 0:
        w_tensor, topk = th.topk(w_tensorz, args.topk)
        y_bar_tensor, t_tensor = y_bar_tensorz[topk], t_tensorz[topk]
        y_bar_list = [y_bar.squeeze(0) for y_bar in th.chunk(y_bar_tensor, args.topk, 0)]
    else:
        w_tensor, y_bar_tensor, t_tensor = w_tensorz, y_bar_tensorz, t_tensorz
        y_bar_list = y_bar_listz

    if verbose == 0:
        t1 = time.time()
        print('[iteration %d]t1 - t0: %f' % (i + 1, t1 - t0))

    z_c = forward(y, y_c)
    z_detached = z_c.detach()
    z_bar_list = [forward(y, y_bar).detach() for y_bar in y_bar_list] # TODO batchify
    z_bar_tensor = th.cat(z_bar_list, 0)

    # Critic minibatches of exactly args.batch_size_critic perturbations (the
    # last may be smaller).  The former th.chunk(x, int(N / batch_size)) gave
    # differently sized batches when N was not a multiple of the batch size,
    # and failed outright (zero chunks) when batch_size_critic > N.
    split = lambda tensor: th.split(tensor, args.batch_size_critic, 0)
    z_bar_batches, t_batches, w_batches = tuple(map(split, (z_bar_tensor, t_tensor, w_tensor)))
    for j in range(args.n_iterations_critic):
        for z_bar, t, w in zip(z_bar_batches, t_batches, w_batches):
            # Flatten delta so it lines up element-wise with t (shape (k,)).
            # Assumes the critic returns one value per input row; with a (k, 1)
            # output, `t - delta` would otherwise broadcast to (k, k).
            delta = (critic(z_detached) - critic(z_bar)).view(-1)
            mse = th.sum(w * (t - delta) ** 2)
            # mse = th.sum(w * delta * (th.sign(delta) - th.sign(t)))
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
        # Monitoring only: weighted mse over all kept perturbations.  No graph is needed.
        with th.no_grad():
            delta = (critic(z_detached) - critic(z_bar_tensor)).view(-1)
            mse = th.sum(w_tensor * (t_tensor - delta) ** 2)
        writer.add_scalar('mse', mse.item(), i * args.n_iterations_critic + j + 1)

    if verbose == 0:
        t2 = time.time()
        print('[iteration %d]t2 - t1: %f' % (i + 1, t2 - t1))

    # Step c to increase the critic's score of c's own (label, prediction) row.
    objective = -critic(z_c)
    c_optim.zero_grad()
    objective.backward()
    c_optim.step()

    if verbose == 0:
        t3 = time.time()
        print('[iteration %d]t3 - t2: %f' % (i + 1, t3 - t2))

    # Post-step batch score, for logging only.
    with th.no_grad():
        L_c = L_batch(y, c(x))
    writer.add_scalar('L_c', L_c, i + 1)

    if args.log_every > 0 and (i + 1) % args.log_every == 0:
        train_scores = global_scores(c, train_loader)
        test_scores = global_scores(c, test_loader)

        # Zero-pad the iteration number to the width of the last iteration of
        # this run (including any resumed offset) so the log lines align.
        width = len(str(args.resume + args.n_iterations))
        print('[iteration %0*d]' % (width, i + 1) +
              ' | '.join('%s %0.3f/%0.3f' % (key, value, test_scores[key]) for key, value in train_scores.items()))

        for key, value in train_scores.items():
            writer.add_scalar('train-' + key, value, i + 1)

        for key, value in test_scores.items():
            writer.add_scalar('test-' + key, value, i + 1)

    if args.ckpt_every > 0 and (i + 1) % args.ckpt_every == 0:
        # File names: ckpt/<experiment_id>-<part>-<iteration> (read back when resuming).
        for ckpt_part, ckpt_source in (('c', c), ('critic', critic),
                                       ('c_optim', c_optim), ('critic_optim', critic_optim)):
            th.save(ckpt_source.state_dict(), 'ckpt/%s-%s-%d' % (experiment_id, ckpt_part, i + 1))

  'precision', 'predicted', average, warn_for)


[iteration 01]accuracy 0.721/0.719 | precision 0.789/0.783 | recall 0.234/0.232 | f1 0.361/0.358
[iteration 02]accuracy 0.715/0.713 | precision 0.785/0.775 | recall 0.229/0.227 | f1 0.355/0.351
[iteration 03]accuracy 0.715/0.716 | precision 0.775/0.772 | recall 0.228/0.228 | f1 0.353/0.352
[iteration 04]accuracy 0.723/0.725 | precision 0.760/0.755 | recall 0.231/0.231 | f1 0.354/0.354
[iteration 05]accuracy 0.730/0.733 | precision 0.730/0.723 | recall 0.231/0.232 | f1 0.351/0.351
[iteration 06]accuracy 0.746/0.750 | precision 0.715/0.709 | recall 0.240/0.243 | f1 0.360/0.362
[iteration 07]accuracy 0.764/0.767 | precision 0.708/0.700 | recall 0.255/0.257 | f1 0.375/0.376
[iteration 08]accuracy 0.786/0.788 | precision 0.697/0.685 | recall 0.275/0.275 | f1 0.394/0.393
[iteration 09]accuracy 0.794/0.795 | precision 0.698/0.687 | recall 0.284/0.284 | f1 0.404/0.402
[iteration 10]accuracy 0.796/0.798 | precision 0.667/0.658 | recall 0.281/0.282 | f1 0.395/0.395
[iteration 11]accuracy 0.800/0