In [1]:
from __future__ import print_function
import copy
from threading import Condition, Thread
import numpy as np
import numpy.random as npr
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [2]:
import matplotlib.pylab as pl
%matplotlib inline

In [3]:
# Load the (unbalanced) CIFAR-10 split from the project helper module `my`.
# NOTE(review): 0 / 0 / p=[] look like "use the helper's defaults" sentinels --
# confirm against my.unbalanced_cifar10.
N_TRAIN, N_TEST = 0, 0
train_data, train_labels, test_data, test_labels = my.unbalanced_cifar10(N_TRAIN, N_TEST, p=[])

# Keep handles on the NumPy arrays before the names below are rebound to tensors.
train_data_np, train_labels_np, test_data_np, test_labels_np = \
    train_data, train_labels, test_data, test_labels

train_data = th.from_numpy(train_data).float()
train_labels = th.from_numpy(train_labels).long()
test_data = th.from_numpy(test_data).float()
test_labels = th.from_numpy(test_labels).long()

# Use the GPU only when one exists, and clamp the preferred device index to the
# devices actually present, so the notebook also runs on CPU-only or
# single-GPU machines instead of crashing in set_device.
CUDA_DEVICE = 3
cuda = th.cuda.is_available()
if cuda:
    th.cuda.set_device(min(CUDA_DEVICE, th.cuda.device_count() - 1))

# Size of dim 1 of the training tensor, and the width of the label range.
N_FEATURES = train_data.size()[1]
N_CLASSES = int(train_labels.max() - train_labels.min() + 1)

In [4]:
class CNN(nn.Module):
    """Small convolutional classifier for 3x32x32 images.

    Two stride-2 3x3 convolutions (3 -> 16 -> 8 channels, padding 1) with tanh
    activations, an 8x8 average pool, and a linear layer on the 8 pooled
    channels.

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 2, 1)
        self.conv2 = nn.Conv2d(16, 8, 3, 2, 1)
        self.linear = nn.Linear(8, n_classes)

    def forward(self, x):
        """Return unnormalised class scores for a batch of images.

        Args:
            x: either a 4-D tensor, used as given, or any other tensor, which
                is reshaped to (-1, 3, 32, 32).

        Returns:
            Tensor with `n_classes` columns; one row per image for 32x32
            inputs (32 -> 16 -> 8 through the convolutions, then the 8x8
            pool leaves a single 8-vector per image).
        """
        if x.dim() != 4:
            x = x.view(-1, 3, 32, 32)
        # th.tanh replaces the deprecated F.tanh; the values are identical.
        x = th.tanh(self.conv1(x))
        x = th.tanh(self.conv2(x))
        x = F.avg_pool2d(x, 8)
        x = self.linear(x.view(-1, 8))
        return x

In [5]:
SAMPLE_SIZE = 64

# Seed the CPU and every CUDA generator so the shuffled batch drawn below is
# the same on each run.
th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

# Take the first shuffled mini-batch and keep it in the globals x / y; the
# helpers defined in this notebook all read this one batch.
batch_loader = DataLoader(TensorDataset(train_data, train_labels), SAMPLE_SIZE, shuffle=True)
for x, y in batch_loader:
    if cuda:
        x = x.cuda()
        y = y.cuda()
    x = Variable(x)
    y = Variable(y)
    break


def L(c):
    """Score classifier `c` on the fixed batch (x, y).

    Feeds `my.predict(c, x)` and the labels `y` to `my.nd_f_beta`; the
    training loop reports this value as "f1".
    """
    return my.nd_f_beta(my.predict(c, x), y, N_CLASSES)

def forward(classifier):
    """Build the critic's input for `classifier` on the fixed batch.

    Concatenates `my.onehot(y, N_CLASSES)` with the softmax of
    `classifier(x)` along dim 1 and flattens the result into a single row.
    """
    targets = my.onehot(y, N_CLASSES)
    probabilities = F.softmax(classifier(x), 1)
    joint = th.cat((targets, probabilities), 1)
    return joint.view(1, -1)

# def intrinsic_mse(critic):
#     z = my.onehot(y, N_CLASSES)
#     return (1 - critic(th.cat((z, z), 1).view(1, -1))) ** 2

def state_dict_cpu2gpu(state_dict):
    """Return a new plain dict with every value of `state_dict` moved via `.cuda()`."""
    return dict((key, value.cuda()) for key, value in state_dict.items())


def state_dict_gpu2cpu(state_dict):
    """Return a new plain dict with every value of `state_dict` moved via `.cpu()`."""
    return dict((key, value.cpu()) for key, value in state_dict.items())

In [10]:
# Re-seed so parameter initialisation does not depend on what ran earlier.
th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

# Classifier under training (alternatives kept for reference) and the critic.
# c = nn.Linear(N_FEATURES, N_CLASSES)
# c = my.MLP((N_FEATURES,) + (512,) * 1 + (N_CLASSES,), F.relu)
c = CNN(N_CLASSES)
critic_layer_sizes = (1024, 1024, 1024, 1)
critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, critic_layer_sizes, F.relu)

# Move both networks to the GPU before their optimizers are created.
if cuda:
    for module in (c, critic):
        module.cuda()

# c_optim = SGD(c.parameters(), 0.1, momentum=0.5)
# critic_optim = SGD(critic.parameters(), 0.1, momentum=0.5)
c_optim = Adam(c.parameters(), lr=1e-3)
critic_optim = Adam(critic.parameters(), lr=1e-3)

# Score of the freshly initialised classifier on the fixed batch.
float(L(c))

0.019716529175639153

In [None]:
# Alternating critic / actor training on the fixed batch.
#
# Each outer iteration:
#   1. perturbs the classifier N_PERTURBATIONS times and records the score
#      difference t = L(c) - L(c_bar) for every perturbation,
#   2. fits the critic so that critic(forward(c)) - critic(forward(c_bar))
#      matches t, each sample weighted by a normalised exp(-t^2 / tau),
#   3. updates the classifier to maximise the critic's output, stopping early
#      once any parameter has moved by more than `std`.
std = 1e-1  # passed to my.perturb; also the per-parameter step limit in the actor phase
tau = 1e-2  # temperature of the importance weights

N_ITERATIONS = 500
N_PERTURBATIONS = 50
CRITIC_ITERATIONS = 10
ACTOR_ITERATIONS = 5
LOG_EVERY = 1  # print the score every LOG_EVERY outer iterations

mse_loss = MSELoss()  # stateless, so a single instance serves every critic step

hist = []
for i in range(N_ITERATIONS):
    # Filled in step by step and appended to `hist` only once complete, so an
    # interrupted run never leaves a half-filled record behind for the
    # plotting cells to trip over.
    record = {}
    record['c_state_dict'] = copy.deepcopy(state_dict_gpu2cpu(c.state_dict()))

    # ---- 1. sample perturbations -------------------------------------------
    c.eval()
    critic.train()
    L_c = L(c)
    c_bar_list, t_list, f1_list = [], [], []
    for j in range(N_PERTURBATIONS):
        c_bar = my.perturb(c, std)
        L_bar = L(c_bar)
        c_bar_list.append(c_bar)
        f1_list.append(float(L_bar))
        t_list.append((L_c - L_bar)[0])
    record['f1_list'] = f1_list

#     print(min(map(float, t_list)), max(map(float, t_list)))

    # Normalised importance weights; detached so they act as constants.
    w_list = [th.exp(-t ** 2 / tau) for t in t_list] # TODO proposal distribution
    w_total = sum(w_list)
    w_list = [(w / w_total).detach() for w in w_list]
    record['w_list'] = w_list

    # ---- 2. critic phase ---------------------------------------------------
    # Neither c nor the perturbed copies change while the critic is trained,
    # so their critic inputs are computed once per outer iteration instead of
    # once per critic step (assumes a deterministic forward pass, which holds
    # for CNN: it has no dropout or batch norm).
    z = forward(c).detach()
    z_bar_list = [forward(c_bar).detach() for c_bar in c_bar_list]
    for j in range(CRITIC_ITERATIONS):
#         critic_optim.zero_grad()
#         intrinsic_mse(critic).backward()
#         critic_optim.step()
        for z_bar, t, w in zip(z_bar_list, t_list, w_list):
            delta = th.mean(critic(z) - critic(z_bar), 0)
            mse = w * mse_loss(delta, t)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
    record['critic_state_dict'] = copy.deepcopy(state_dict_gpu2cpu(critic.state_dict()))

    # ---- 3. actor phase ----------------------------------------------------
    c.train()
    critic.eval()
    # Snapshot of the parameters, used to bound how far the actor may move.
    c_parameters = copy.deepcopy(tuple(c.parameters()))
    for j in range(ACTOR_ITERATIONS):
        objective = -th.mean(critic(forward(c))) # TODO remove th.mean
        c_optim.zero_grad()
        objective.backward()
        c_optim.step()
        # Stop as soon as any parameter entry has moved further than `std`.
        if any(float(th.max(th.abs(p - q))) > std for p, q in zip(c_parameters, c.parameters())):
            break

    record['f1'] = float(L(c))
    hist.append(record)

    if (i + 1) % LOG_EVERY == 0:
        print('[iteration %d]f1: %f' % (i + 1, record['f1']))

[iteration 1]f1: 0.064086
[iteration 2]f1: 0.040906
[iteration 3]f1: 0.053655
[iteration 4]f1: 0.057710
[iteration 5]f1: 0.056277
[iteration 6]f1: 0.055719
[iteration 7]f1: 0.061244
[iteration 8]f1: 0.105319
[iteration 9]f1: 0.080248
[iteration 10]f1: 0.085937
[iteration 11]f1: 0.104718
[iteration 12]f1: 0.118358
[iteration 13]f1: 0.258465
[iteration 14]f1: 0.263717
[iteration 15]f1: 0.161076
[iteration 16]f1: 0.143018
[iteration 17]f1: 0.143018
[iteration 18]f1: 0.156580


In [None]:
# h = hist[-1]
# c, c_bar = CNN(N_CLASSES), CNN(N_CLASSES)
# c.load_state_dict(state_dict_cpu2gpu(h['c_state_dict']))
# c_bar.load_state_dict(h['c_state_dict'])
# critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (1024,) * 3 + (1,), F.relu)
# critic.load_state_dict(state_dict_cpu2gpu(h['critic_state_dict']))
# if cuda:
#     c.cuda()
#     c_bar.cuda()
#     critic.cuda()

# critic.eval()

# # TODO gradient in parameter space/simplex
# objective = -th.mean(critic(forward(c)))
# objective.backward()

# alpha = np.linspace(0, 10)
# f1_list = []
# critic_list = []
# for a in alpha:
#     for p, p_bar in zip(c.parameters(), c_bar.parameters()):
#         p_bar.data = p.data - float(a) * p.grad.data
#     f1_list.append(float(L(c_bar)))
#     critic_list.append(float(critic(forward(c_bar))))

# pl.figure()
# pl.plot(alpha, f1_list)
# pl.figure()
# pl.plot(alpha, critic_list)

In [None]:
# Scores of the perturbed classifiers, one array per recorded iteration.
f1_list = [np.array([float(score) for score in h['f1_list']]) for h in hist]
min_list = tuple(np.min(scores) for scores in f1_list)
max_list = tuple(np.max(scores) for scores in f1_list)
std_list = tuple(np.std(scores) for scores in f1_list)

# Figure 1: range of the sampled scores against the classifier's own score.
iterations = range(len(f1_list))
pl.plot(iterations, min_list, label='min sample')
pl.plot(iterations, max_list, label='max sample')
pl.plot(range(len(hist)), [h['f1'] for h in hist], label='classifier')
pl.title('f1')
pl.legend(framealpha=0)

# Figure 2: spread of the sampled scores.
pl.figure()
pl.title('std')
pl.plot(iterations, std_list)

In [None]:
# Entropy of the normalised importance weights per iteration, compared with
# the entropy of a uniform distribution over N_PERTURBATIONS samples.
entropy_list = []
for h in hist:
    w_array = np.array(list(map(float, h['w_list'])))
    # Use the convention 0 * log(0) = 0: a weight that underflowed to zero
    # would otherwise turn the whole sum into nan. NaN weights are kept
    # (nan != 0) so a broken distribution still shows up as nan.
    nonzero = w_array[w_array != 0]
    entropy_list.append(-np.sum(nonzero * np.log(nonzero)))
n = len(hist)
pl.plot(range(n), entropy_list, label='importance distribution')
# -log(1 / N) == log(N). Written this way because `1 / N` is integer division
# under Python 2 (only print_function is imported from __future__), which
# would make the reference line infinite.
pl.plot(range(n), np.ones(n) * np.log(N_PERTURBATIONS), label='uniform distribution')
pl.title('entropy')
pl.legend(framealpha=0)

In [None]:
# Importance weights as one array per recorded iteration.
w_list = [np.array([float(weight) for weight in h['w_list']]) for h in hist]
min_list = tuple(np.min(weights) for weights in w_list)
max_list = tuple(np.max(weights) for weights in w_list)
std_list = tuple(np.std(weights) for weights in w_list)

# Figure 1: smallest and largest weight per iteration.
iterations = range(len(w_list))
pl.plot(iterations, min_list, label='min')
pl.plot(iterations, max_list, label='max')
pl.title('importance weight')
pl.legend(framealpha=0)

# Figure 2: spread of the weights per iteration.
pl.figure()
pl.title('importance weight')
pl.plot(iterations, std_list, label='std')
pl.legend(framealpha=0)