In [1]:
from __future__ import print_function
import copy
from threading import Condition, Thread
import numpy as np
import numpy.random as npr
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [2]:
import matplotlib.pylab as pl
%matplotlib inline

In [3]:
N_TRAIN, N_TEST = 0, 0
BATCH_SIZE = 64
# Use the GPU only when one is actually present; a hard-coded True crashes on CPU-only hosts.
cuda = th.cuda.is_available()
CUDA_DEVICE = 3  # GPU index selected when `cuda` is True

# train_data, train_labels, test_data, test_labels = my.unbalanced_cifar10(N_TRAIN, N_TEST, p=[0, 1, 10])
train_data, train_labels, test_data, test_labels = my.unbalanced_cifar10(N_TRAIN, N_TEST, p=[])

# Keep the NumPy versions around: `sample()` draws subsets from the *_np arrays.
train_data_np, train_labels_np, test_data_np, test_labels_np = \
    train_data, train_labels, test_data, test_labels

train_data = th.from_numpy(train_data).float()
train_labels = th.from_numpy(train_labels).long()
test_data = th.from_numpy(test_data).float()
test_labels = th.from_numpy(test_labels).long()


if cuda:
    th.cuda.set_device(CUDA_DEVICE)

train_loader = DataLoader(TensorDataset(train_data, train_labels), BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels), BATCH_SIZE)

N_FEATURES = train_data.size(1)
# Labels are used directly as class indices (one-hot encoding, argmax), so the class
# count must be max+1. `max - min + 1` undercounts whenever the smallest label present
# is not 0 (e.g. an unbalanced split that drops class 0).
N_CLASSES = int(train_labels.max()) + 1

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier for 3-channel images.

    conv(3->16, 3x3, stride 2, pad 1) -> tanh -> conv(16->8, 3x3, stride 2, pad 1)
    -> tanh -> global average pool -> linear(8 -> n_classes).

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 2, 1)
        self.conv2 = nn.Conv2d(16, 8, 3, 2, 1)
        self.linear = nn.Linear(8, n_classes)

    def forward(self, x):
        """Return logits of shape (batch, n_classes).

        `x` is either a 4-D image batch (batch, 3, H, W) or a batch of flattened
        rows, which are reshaped to (-1, 3, 32, 32).
        """
        if x.dim() != 4:
            x = x.view(-1, 3, 32, 32)
        # th.tanh replaces the deprecated F.tanh.
        x = th.tanh(self.conv1(x))
        x = th.tanh(self.conv2(x))
        # Global average pool. For 32x32 inputs the two stride-2 convs leave an 8x8 map,
        # so this equals avg_pool2d(x, 8); unlike a fixed kernel + view(-1, 8), it keeps
        # exactly one row per image for other spatial sizes too.
        x = F.adaptive_avg_pool2d(x, 1)
        return self.linear(x.view(x.size(0), -1))

In [None]:
# c = my.MLP((N_FEATURES,) + (64,) * 3 + (N_CLASSES,), F.relu)
# c_pretrained = CNN(N_CLASSES)
# if cuda:
#     c_pretrained.cuda()
# optim = Adam(c_pretrained.parameters(), lr=0.001)
# EPOCHS = 1
# for i in range(EPOCHS):
#     for x, y in train_loader:
#         if cuda:
#             x, y = x.cuda(), y.cuda()
#         x, y = Variable(x), Variable(y)
#         loss = CrossEntropyLoss()(c_pretrained(x), y)
#         optim.zero_grad()
#         loss.backward()
#         optim.step()
#     accuracy = my.global_stats(c_pretrained, test_loader, my.accuracy)
#     print('[epoch %d]cross-entropy loss: %f, accuracy: %f' % ((i + 1), float(loss), float(accuracy)))

In [None]:
# nd_stats = [my.accuracy] + [my.nd_curry(stat, N_CLASSES) for stat in (my.nd_precision, my.nd_recall, my.nd_f_beta)]
# accuracy, precision, recall, f1 = my.global_stats(c_pretrained, test_loader, nd_stats)
# 'accuracy: %f, precision: %f, recall: %f, f1: %f' % tuple(map(float, (accuracy, precision, recall, f1)))

In [5]:
def sample(sample_size, batch_size):
    """Draw `batch_size` random training subsets of `sample_size` examples each.

    Each subset comes from `my.sample_subset` on the NumPy training arrays and is
    presumably an (inputs, labels) tensor pair (TODO confirm). Pairs are moved to the
    GPU when `cuda` is set and wrapped in `Variable`.

    Returns:
        list of (Variable, Variable) pairs, one per subset.
    """
    drawn = [my.sample_subset(train_data_np, train_labels_np, sample_size)
             for _ in range(batch_size)]
    batch = []
    for inputs, labels in drawn:
        if cuda:
            inputs, labels = inputs.cuda(), labels.cuda()
        batch.append((Variable(inputs), Variable(labels)))
    return batch

def perturb_y(y, std):
    """Return `y` plus i.i.d. Gaussian noise with standard deviation `std`.

    The noise is drawn from the CPU generator (the same draws as before) and then
    cast to `y`'s tensor type. That puts it on the same backend (CPU/CUDA) and dtype
    as `y`, instead of relying on the global `cuda` flag agreeing with where `y` lives.
    """
    noise = (th.randn(y.size()) * std).type_as(y.data)
    return y + Variable(noise)

def L(y_bar, y):
    """Score each sampled subset with `my.nd_f_beta` on its argmax predictions.

    Args:
        y_bar: sequence of per-subset logit tensors; the argmax over dim 1 gives the predicted class.
        y: sequence of per-subset label tensors, aligned with `y_bar`.

    Returns:
        tensor of shape (1, len(y)): one score per subset, concatenated along dim 1.
    """
    scores = []
    for logits, labels in zip(y_bar, y):
        predictions = logits.max(1)[1]
        score = my.nd_f_beta(predictions, labels, N_CLASSES)
        scores.append(score.view(1, 1))
    return th.cat(tuple(scores), 1)

def critic_forward(y_bar, y, detach=False):
    """Evaluate the global `critic` on a batch of (prediction, label) subsets.

    For each subset, softmax(logits) and the one-hot labels are concatenated along
    dim 1 and flattened into one row. The rows are stacked along dim 0 and fed to `critic`.

    Args:
        y_bar: sequence of per-subset logit tensors.
        y: sequence of per-subset label tensors.
        detach: if True, cut the graph at the critic input, so no gradient flows back
            into whatever produced `y_bar` (the critic's own parameters still get gradients).
    """
    probabilities = [F.softmax(logits, 1) for logits in y_bar]
    onehots = [my.onehot(labels, N_CLASSES) for labels in y]
    rows = [th.cat(pair, 1).view(1, -1) for pair in zip(probabilities, onehots)]
    features = th.cat(tuple(rows), 0)
    if detach:
        features = features.detach()
    return critic(features)

def MSE(y_bar, y_perturbed, y, target):
    """Mean squared error between the critic's predicted change and the true change.

    The critic is evaluated (with its inputs detached) on the unperturbed and on the
    perturbed predictions. Its difference is regressed onto `target`, the true change
    in `L` for the same subsets.
    """
    l_bar = critic_forward(y_bar, y, True)
    l_perturbed = critic_forward(y_perturbed, y, True)
    # `target` comes from L() with shape (1, batch), while the critic output shape is
    # set by my.RN (presumably (batch, 1) - TODO confirm). Subtracting them directly
    # would broadcast to (batch, batch) and pair every subset with every other subset's
    # target, so flatten both to (batch,) first.
    delta = (l_perturbed - l_bar).view(-1)
    # Alternative weighting that was tried:
    #     th.mean(th.exp(target ** 2 / 0.1) * (delta - target) ** 2)
    return th.mean((delta - target.view(-1)) ** 2)

In [6]:
SAMPLE_SIZE = 16
SEED = 1

# Seed every RNG the experiment touches: torch (weight init, perturbation noise) and
# NumPy (subset sampling works on the NumPy arrays; presumably my.sample_subset uses
# np.random - TODO confirm).
th.random.manual_seed(SEED)
th.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

# c = my.MLP((N_FEATURES,) + (512,) * 1 + (N_CLASSES,), F.relu)
c = CNN(N_CLASSES)
# c.load_state_dict(c_pretrained.state_dict())
# Critic input per subset: SAMPLE_SIZE rows of [softmax probs | one-hot labels].
critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (512,) * 3 + (1,), F.relu)

if cuda:
    c.cuda()
    critic.cuda()

# c_optim = SGD(c.parameters(), 0.1, momentum=0.5)
# critic_optim = SGD(critic.parameters(), 0.1, momentum=0.5)
c_optim = Adam(c.parameters(), 1e-3)
critic_optim = Adam(critic.parameters(), 1e-3)

# Baseline test-set metrics of the untrained classifier.
nd_stats = [my.accuracy] + [my.nd_curry(stat, N_CLASSES) for stat in (my.nd_precision, my.nd_recall, my.nd_f_beta)]
accuracy, precision, recall, f1 = my.global_stats(c, test_loader, nd_stats)
'accuracy: %f, precision: %f, recall: %f, f1: %f' % tuple(map(float, (accuracy, precision, recall, f1)))

'accuracy: 0.100000, precision: 0.010000, recall: 0.100000, f1: 0.018180'

In [None]:
N_ITERATIONS = 5
N_PERTURBATIONS = 25
# NOTE: shadows the DataLoader BATCH_SIZE from the data cell; here it is the number of
# sampled subsets per iteration.
BATCH_SIZE = 4
STD = 1e-1
CRITIC_N_ITERATIONS = 10
CLASSIFIER_N_ITERATIONS = 25
# Trust region: stop classifier updates once any parameter has moved more than this
# (max-abs) from its value at the start of the classifier phase.
MAX_PARAM_DELTA = 0.1
LOG_EVERY = 1

# TODO replay buffer
hist = []
for i in range(N_ITERATIONS):
    record = {'c_state_dict': copy.deepcopy(c.state_dict())}
    hist.append(record)

    # Critic phase: fit critic(y_perturbed) - critic(y_bar) to the true change in L
    # under random perturbations of the classifier outputs.
    c.eval()
    critic.train()
    x, y = zip(*sample(SAMPLE_SIZE, BATCH_SIZE))
    record['x'], record['y'] = x, y
    y_bar = tuple(map(c, x))
    L_bar = L(y_bar, y)

    p_list = []
    for j in range(N_PERTURBATIONS):
        y_perturbed = tuple(perturb_y(z_bar, STD) for z_bar in y_bar)  # TODO perturb in simplex
        target = L(y_perturbed, y) - L_bar
        p_list.append((y_perturbed, target))

    for j in range(CRITIC_N_ITERATIONS):
        mse = 0
        for y_perturbed, target in p_list:
            mse += MSE(y_bar, y_perturbed, y, target)
        mse /= len(p_list)
        critic_optim.zero_grad()
        mse.backward()
        critic_optim.step()
    record['critic_state_dict'] = copy.deepcopy(critic.state_dict())

    # Classifier phase: maximise the critic's score on the same subsets.
    c.train()
    critic.eval()
    # Plain-tensor snapshot (no autograd history) of the starting parameters.
    c_params = [p.data.clone() for p in c.parameters()]
    for j in range(CLASSIFIER_N_ITERATIONS):
        y_bar = tuple(map(c, x))
        objective = -th.mean(critic_forward(y_bar, y))
        c_optim.zero_grad()
        objective.backward()
        c_optim.step()
        if any(float((p0 - p.data).abs().max()) > MAX_PARAM_DELTA
               for p0, p in zip(c_params, c.parameters())):
            break

    if (i + 1) % LOG_EVERY == 0:
        f1 = my.global_stats(c, test_loader, my.nd_curry(my.nd_f_beta, N_CLASSES))
        print('[iteration %d]mse: %f; f1: %f' % (i + 1, float(mse), float(f1)))

In [None]:
# Line search along the critic-induced descent direction for every recorded iteration:
# how does test F1 change as the classifier moves along -grad(objective)?
alpha = np.linspace(0, 1)  # 50 step sizes in [0, 1]

# critic_forward() reads the module-level `critic`, so it must be rebound for each
# history entry. Keep the trained critic and restore it afterwards, so that re-running
# the training cell still uses the module that `critic_optim` owns.
critic_trained = critic
try:
    for h in hist:
        if 'critic_state_dict' not in h:
            continue  # entry from an interrupted training iteration
        # Fresh name: rebinding the global `c` would silently detach it from `c_optim`.
        c_h, c_bar = CNN(N_CLASSES), CNN(N_CLASSES)
        c_h.load_state_dict(h['c_state_dict'])
        c_bar.load_state_dict(h['c_state_dict'])
        critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (512,) * 3 + (1,), F.relu)
        critic.load_state_dict(h['critic_state_dict'])
        if cuda:
            c_h.cuda()
            c_bar.cuda()
            critic.cuda()

        # TODO gradient in parameter space/simplex
        y_bar = tuple(map(c_h, h['x']))
        objective = -th.mean(critic_forward(y_bar, h['y']))
        objective.backward()

        f1_list = []
        for a in alpha:
            # The training loop *minimises* `objective`, so the update direction is -grad.
            for p, p_bar in zip(c_h.parameters(), c_bar.parameters()):
                p_bar.data = p.data - float(a) * p.grad.data
            # TODO global f1/batch f1
            f1 = my.global_stats(c_bar, test_loader, my.nd_curry(my.nd_f_beta, N_CLASSES))
            f1_list.append(float(f1))

        fig, ax = pl.subplots()
        ax.plot(alpha, f1_list)
        ax.set(xlabel='step size alpha', ylabel='test F1',
               title='Test F1 along the critic descent direction')
finally:
    critic = critic_trained