In [None]:
import numpy as np
import torch as th
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
from torchvision import transforms
import my

In [None]:
# --- Data --------------------------------------------------------------------
# Sizes are forwarded unchanged to my.unbalanced_dataset.
# NOTE(review): 0 presumably means "do not subsample" -- confirm in `my`.
N_TRAIN, N_TEST = 0, 0
train_data, train_labels, test_data, test_labels = my.unbalanced_dataset(
    'MNIST', N_TRAIN, N_TEST, pca=True, p=[0, 1, 10])
# Keep handles on the NumPy arrays before the names are rebound to tensors.
train_data_np, train_labels_np, test_data_np, test_labels_np = \
    train_data, train_labels, test_data, test_labels

train_data = th.from_numpy(train_data).float()
train_labels = th.from_numpy(train_labels).long()
test_data = th.from_numpy(test_data).float()
test_labels = th.from_numpy(test_labels).long()

# --- Device ------------------------------------------------------------------
# Preferred GPU index. It is only selected when that device actually exists,
# so the notebook still runs on hosts with fewer GPUs (default device is used)
# or with no CUDA at all (everything stays on the CPU).
CUDA_DEVICE = 3
cuda = th.cuda.is_available()
if cuda and CUDA_DEVICE < th.cuda.device_count():
    th.cuda.set_device(CUDA_DEVICE)
# Move a tensor/module to the GPU only when CUDA is in use.
cudalize = lambda x: x.cuda() if cuda else x

# --- Loaders -----------------------------------------------------------------
BATCH_SIZE = 64
train_loader = DataLoader(TensorDataset(train_data, train_labels), BATCH_SIZE)
test_loader = DataLoader(TensorDataset(test_data, test_labels), BATCH_SIZE)

# Feature width is the second dimension of the training matrix; the class
# count is derived from the span of label values seen in the training set.
N_FEATURES = train_data.size()[1]
N_CLASSES = int(train_labels.max() - train_labels.min() + 1)

In [None]:
# Baseline classifier: an MLP described by (N_FEATURES, N_CLASSES), trained
# with cross-entropy and Adam, then scored on the test loader after each epoch.
# NOTE(review): the second my.MLP argument is None here but th.tanh / F.relu
# elsewhere -- presumably the nonlinearity; confirm in `my`.
N_EPOCHS = 5

mlp = my.MLP((N_FEATURES, N_CLASSES), None)
if cuda:
    mlp.cuda()
optim = Adam(mlp.parameters(), lr=0.001)

for e in range(N_EPOCHS):
    for i, (X, y) in enumerate(train_loader):
        X = Variable(cudalize(X))
        y = Variable(cudalize(y))

        optim.zero_grad()
        loss = F.cross_entropy(mlp(X), y)
        loss.backward()
        optim.step()

        # Only report every 1000th mini-batch (on that batch alone).
        if (i + 1) % 1000:
            continue
        accuracy = my.accuracy(my.predict(mlp, X), y)
        print('[iteration %d]nll loss: %f, accuracy %f' % (i + 1, float(loss), float(accuracy)))

    # End-of-epoch evaluation; the metric callables are rebuilt every epoch.
    epoch_metrics = (
        my.accuracy,
        my.nd_curry(my.nd_precision, N_CLASSES),
        my.nd_curry(my.nd_recall, N_CLASSES),
        my.nd_curry(my.nd_f_beta, N_CLASSES),
    )
    accuracy, precision, recall, f1 = my.global_stats(mlp, test_loader, epoch_metrics)
    print('[epoch %d]accuracy: %f, precision: %f, recall: %f, f1: %f' % (
        e + 1, float(accuracy), float(precision), float(recall), float(f1)))

In [None]:
def L(c, loader):
    """Score classifier `c` over `loader`.

    Delegates to ``my.global_stats`` with ``my.nd_f_beta`` curried on
    ``N_CLASSES`` as the only statistic, and returns whatever that call yields.
    """
    f_beta = my.nd_curry(my.nd_f_beta, N_CLASSES)
    return my.global_stats(c, loader, f_beta)

def forward(classifier, X, y, std=0):
    """Encode one batch as a single flat "action" row for the critic.

    Args:
        classifier: callable applied to the (optionally CUDA) batch `X`.
        X: input tensor for one batch.
        y: label tensor for the same batch; one-hot encoded through
            ``my.onehot(y, N_CLASSES)``.
        std: when greater than 0, Gaussian noise scaled by `std` is added in
            place to the classifier output before the softmax.

    Returns:
        The one-hot labels and the softmax of the classifier output,
        concatenated along dimension 1 and reshaped with ``view(1, -1)``.
    """
    if cuda:
        X = X.cuda()
        y = y.cuda()
    inputs = Variable(X)
    targets = my.onehot(Variable(y), N_CLASSES)

    logits = classifier(inputs)
    if std > 0:
        noise = th.randn(logits.size()) * std
        # randn allocates on the CPU; follow the logits onto the GPU if needed.
        noise = noise.cuda() if logits.is_cuda else noise
        logits += Variable(noise)

    probabilities = F.softmax(logits, 1)
    return th.cat((targets, probabilities), 1).view(1, -1)

def create_sample_loader():
    """Return a fresh iterator over the training set.

    Batches hold BATCH_SIZE items, are drawn in shuffled order, and a trailing
    incomplete batch is dropped.
    """
    dataset = TensorDataset(train_data, train_labels)
    loader = DataLoader(dataset, BATCH_SIZE, shuffle=True, drop_last=True)
    return iter(loader)

def sample(loader, K):
    """Pull `K` consecutive items from the iterator `loader`.

    When `loader` runs dry it is replaced by ``create_sample_loader()`` and
    the draw is retried once on the new iterator.

    Returns:
        (items, loader): the list of drawn items and the iterator to use for
        the next call (either the one passed in or its replacement).
    """
    drawn = []
    for _ in range(K):
        try:
            item = next(loader)
        except StopIteration:
            # Exhausted: start a new pass and take its first item.
            loader = create_sample_loader()
            item = next(loader)
        drawn.append(item)
    return drawn, loader

In [None]:
# Classifier that the critic-driven loop will train.
c = my.MLP((N_FEATURES, N_CLASSES), th.tanh)
if cuda:
    c.cuda()
c_optim = SGD(c.parameters(), 1e-1, momentum=0.9)
# c_optim = Adam(c.parameters(), 1e-3)

# The critic consumes the flat row produced by forward(): per example the
# one-hot label and the softmax output are concatenated and the whole batch is
# flattened, hence (N_CLASSES + N_CLASSES) * SAMPLE_SIZE inputs.
# NOTE(review): this requires every sampled batch to hold SAMPLE_SIZE examples
# (BATCH_SIZE is also 64, and the sample loader drops incomplete batches).
SAMPLE_SIZE = 64
# D is not referenced by any cell visible here; kept for compatibility.
D = (train_data.size()[1] + N_CLASSES + N_CLASSES) * SAMPLE_SIZE
critic = my.MLP(((N_CLASSES + N_CLASSES) * SAMPLE_SIZE,) + (1024,) * 3 + (1,), F.relu)
if cuda:
    critic.cuda()
critic_optim = SGD(critic.parameters(), 1e-3, momentum=0.9)
# critic_optim = Adam(critic.parameters(), 1e-3)

# Evaluate on the first N examples only, split into two large batches.
# Floor division keeps the batch size an int: `N / 2` is a float on Python 3
# and DataLoader rejects non-integer batch sizes.
N = 4096 * 2
train_loader = DataLoader(TensorDataset(train_data[:N], train_labels[:N]), N // 2)
test_loader = DataLoader(TensorDataset(test_data[:N], test_labels[:N]), N // 2)

sample_loader = create_sample_loader()

# Score of the untrained classifier on the training subset (cell output).
float(L(c, train_loader))

In [None]:
# Alternating optimisation: each outer iteration first fits the critic for
# INNER steps, then takes one classifier step that maximises the critic output.
OUTER = 2000  # number of classifier updates
INNER = 10  # critic updates per classifier update
STD = 0.1  # forwarded to my.perturb when building the perturbed classifier
K = 5  # batches drawn per sample() call
critic_aware = False  # enables noisy forward passes and the exp(-err^2 / tau) weighting
tau = 1  # divisor inside the exponent of that weighting

# One dict per outer iteration: 'mse' (list of critic losses) plus 'mean' and
# 'std' of the gradient reaching c_action, filled in by the hook below.
stats = [{} for _ in range(OUTER)]
for i in range(OUTER):
    stats[i]['mse'] = []
    for j in range(INNER):
        # Regression target: change in L on train_loader between c and a
        # perturbed copy produced by my.perturb.
        c_bar = my.perturb(c, STD)
        delta = L(c, train_loader) - L(c_bar, train_loader)
        
        # Critic estimate of the same difference, averaged over the K sampled
        # batches (one row per batch, see forward()).
        samples, sample_loader = sample(sample_loader, K)
        c_action = th.cat([forward(c, X, y) for X, y in samples], 0)
        c_bar_action = th.cat([forward(c_bar, X, y) for X, y in samples], 0)
        delta_bar = th.mean(critic(c_action) - critic(c_bar_action), 0)
        
        mse = MSELoss()(delta_bar, delta)
        stats[i]['mse'].append(float(mse))
        critic_optim.zero_grad()
        mse.backward()
        critic_optim.step()
        
    # Classifier step on a fresh draw of K batches.
    samples, sample_loader = sample(sample_loader, K)
    y_list = [cudalize(y) for X, y in samples]
    # NOTE(review): the noise scale here is the literal 1, not STD -- confirm intended.
    std = 1 if critic_aware else 0
    c_action = [forward(c, X, y, std) for X, y in samples]
    if critic_aware:
        # Recover hard predictions from the last N_CLASSES columns of each
        # action row; the view assumes every batch holds SAMPLE_SIZE examples.
        y_bar_list = [th.max(a.data.view(SAMPLE_SIZE, N_CLASSES * 2)[:, -N_CLASSES:], 1)[1] for a in c_action]
        f1 = [my.nd_f_beta(y_bar, y, N_CLASSES) for y_bar, y in zip(y_bar_list, y_list)]
        f1 = cudalize(th.from_numpy(np.array(f1)).float())
    c_action = th.cat(c_action, 0)
    # Record gradient statistics for this iteration; the hook fires during the
    # objective.backward() call below, while `i` still names this iteration.
    c_action.register_hook(lambda g: stats[i].update({'mean': float(th.mean(g)), 'std': float(th.std(g))}))
    f1_bar = critic(c_action)
    # `f1` is only evaluated when critic_aware is true; otherwise weight is 1.
    weight = th.exp(-(Variable(f1) - f1_bar.view(K)) ** 2 / tau).detach() if critic_aware else 1
    # Minimising the negated critic output maximises the critic's score.
    objective = -th.mean(weight * f1_bar)
    c_optim.zero_grad()
    objective.backward()
    # Only the classifier is stepped here; critic gradients produced by this
    # backward pass are cleared by critic_optim.zero_grad() in the inner loop.
    c_optim.step()

    if (i + 1) % 100 == 0:
        # `mse` is the value from the last inner step of this iteration.
        f1 = L(c, test_loader)
        print('[iteration %d]mse: %f, f1: %f' % ((i + 1), float(mse), float(f1)))

In [None]:
# Final test-set report for the critic-trained classifier `c`.
final_metrics = (
    my.accuracy,
    my.nd_curry(my.nd_precision, N_CLASSES),
    my.nd_curry(my.nd_recall, N_CLASSES),
    my.nd_curry(my.nd_f_beta, N_CLASSES),
)
accuracy, precision, recall, f1 = my.global_stats(c, test_loader, final_metrics)
# Cell output: (accuracy, precision, recall, f1) as plain floats.
tuple(map(float, (accuracy, precision, recall, f1)))

In [None]:
from matplotlib import pylab as pl
%matplotlib inline

In [None]:
# Critic MSE per outer iteration: the recorded inner-step losses are summed
# and divided by INNER.
mse_list = []
for stat in stats:
    mse_list.append(sum(stat['mse']) / INNER)
pl.plot(range(len(mse_list)), mse_list)

In [None]:
def _plot_per_iteration(values, label):
    """Draw `values` against their index in a new figure with a y-label of `label`."""
    pl.figure()
    pl.xlabel('iteration')
    pl.ylabel(label)
    return pl.plot(range(len(values)), values)


# Mean and standard deviation of the gradient recorded by the hook on
# c_action, one value per outer iteration.
means = [stat['mean'] for stat in stats]
_plot_per_iteration(means, 'mean')
stds = [stat['std'] for stat in stats]
_plot_per_iteration(stds, 'std')