In [None]:
from __future__ import print_function
import copy
from threading import Condition, Thread
import numpy as np
import numpy.random as npr
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [None]:
import matplotlib.pylab as pl
%matplotlib inline

In [None]:
# Load CIFAR-10 through the project helper and convert the arrays to tensors.
N_TRAIN, N_TEST = 0, 0  # not referenced by any of the visible cells
# NOTE(review): the keyword is spelled `rbg` (not `rgb`) -- confirm against
# the signature of my.load_cifar10 before "fixing" it.
train_x, train_y, test_x, test_y = my.load_cifar10(rbg=True)

# Inputs are cast to float32; labels keep whatever dtype the loader returned.
train_x = th.from_numpy(train_x).float()
train_y = th.from_numpy(train_y)
test_x = th.from_numpy(test_x).float()
test_y = th.from_numpy(test_y)

# Preferred GPU index.  The notebook used to hard-code `cuda = True` and
# device 3, which crashes on hosts without CUDA or with fewer than four GPUs.
CUDA_DEVICE = 3
cuda = th.cuda.is_available()
if cuda:
    # Fall back to device 0 when the preferred index does not exist.
    th.cuda.set_device(CUDA_DEVICE if CUDA_DEVICE < th.cuda.device_count() else 0)

# Size of dimension 1 of the training inputs.
N_FEATURES = train_x.size()[1]
# Number of classes, assuming labels are consecutive integers -- TODO confirm.
N_CLASSES = int(train_y.max() - train_y.min() + 1)

In [None]:
class CNN(nn.Module):
    """Small two-layer convolutional classifier.

    Two stride-2 3x3 convolutions (3 -> 16 -> 8 channels, padding 1), each
    followed by tanh, then an 8x8 average pool and a linear layer mapping the
    8 pooled channels to ``n_classes`` logits.  With 32x32 inputs the feature
    map is 8x8 before pooling, so the pool reduces it to a single position.

    Args:
        n_classes: number of output logits.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 2, 1)
        self.conv2 = nn.Conv2d(16, 8, 3, 2, 1)
        self.linear = nn.Linear(8, n_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, n_classes)`` for a ``(batch, 3, 32, 32)`` input.

        Raises:
            RuntimeError: if the pooled feature map is not 1x1 (i.e. the input
                is not 32x32), instead of silently mixing samples.
        """
        # th.tanh replaces the deprecated F.tanh; the result is identical.
        x = th.tanh(self.conv1(x))
        x = th.tanh(self.conv2(x))
        x = F.avg_pool2d(x, 8)
        # Flatten per sample.  The previous `view(-1, 8)` folded any leftover
        # spatial positions into the batch dimension for larger inputs.
        x = self.linear(x.view(x.size(0), -1))
        return x

In [None]:
SAMPLE_SIZE = 50

th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

# Draw a single shuffled mini-batch of SAMPLE_SIZE training examples; every
# later cell evaluates classifiers on this fixed batch (x, y).
sample_loader = DataLoader(TensorDataset(train_x, train_y), SAMPLE_SIZE, shuffle=True)
x, y = next(iter(sample_loader))
if cuda:
    x, y = x.cuda(), y.cuda()
x, y = Variable(x), Variable(y)


def L(c):
    """Score classifier ``c`` on the fixed batch via ``my.nd_f_beta``.

    The predictions come from ``my.predict(c, x)`` and are compared against
    ``y`` over ``N_CLASSES`` classes.
    """
    return my.nd_f_beta(my.predict(c, x), y, N_CLASSES)

def forward(classifier):
    """Build the critic input for ``classifier`` on the fixed batch ``(x, y)``.

    The one-hot labels (``my.onehot(y, N_CLASSES)``) and the classifier's
    softmax outputs are concatenated along dimension 1 and the result is
    reshaped into a single row.
    """
    targets = my.onehot(y, N_CLASSES)
    probabilities = F.softmax(classifier(x), 1)
    paired = th.cat((targets, probabilities), 1)
    return paired.view(1, -1)

In [None]:
# Re-seed before constructing the networks so their initial weights are
# reproducible independently of the batch-sampling cell.
th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

# Classifier (actor) and critic.  The critic consumes the output of
# `forward`, which carries 2 * N_CLASSES values per sample.
# NOTE(review): the meaning of the two width tuples is defined by my.RN --
# confirm there.
critic_widths = (64, 64, 64, 1)
c = CNN(N_CLASSES)
critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (32, 64), critic_widths, F.relu)

if cuda:
    for module in (c, critic):
        module.cuda()

# Optimizers are created after the (optional) move to the GPU.
learning_rate = 1e-3
c_optim = Adam(c.parameters(), lr=learning_rate)
critic_optim = Adam(critic.parameters(), lr=learning_rate)

# Score of the freshly initialised classifier on the fixed batch.
float(L(c))

In [None]:
# TODO optimize!
#
# Actor-critic style training on the fixed batch.  Each outer iteration:
#   1. scores N_PERTURBATIONS perturbed copies of the classifier `c`,
#   2. fits the critic so that critic(forward(c)) - critic(forward(c_bar))
#      matches the score difference L(c) - L(c_bar), weighted by importance,
#   3. updates `c` by gradient steps that increase the critic's output.
# Everything needed by the plotting cells is recorded in `hist`.

std = 1e-1  # passed to my.perturb; also bounds the per-parameter actor move
tau = 1e-2  # temperature of the importance weights

N_ITERATIONS = 500
N_PERTURBATIONS = 50
CRITIC_ITERATIONS = 10
ACTOR_ITERATIONS = 10
LOG_INTERVAL = 1  # print the score every LOG_INTERVAL outer iterations

# The loss module holds no state, so build it once instead of on every step.
mse_loss = MSELoss()

hist = []
for i in range(N_ITERATIONS):
    # Append first so that a partially filled record survives an interrupt.
    record = {}
    hist.append(record)
    record['c_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(c.state_dict()))

    # --- 1. score perturbed classifiers --------------------------------
    c.eval()
    critic.train()
    L_c = L(c)
    c_bar_list, t_list = [], []
    f1_list = []
    for j in range(N_PERTURBATIONS):
        # presumably a noisy copy of `c`; verify against my.perturb
        c_bar = my.perturb(c, std)
        c_bar_list.append(c_bar)
        L_bar = L(c_bar)
        f1_list.append(float(L_bar))
        t = L_c - L_bar
        # NOTE(review): `t[0]` assumes L returns a 1-element (not 0-dim)
        # tensor -- confirm against my.nd_f_beta and the installed torch.
        t_list.append(t[0])
    record['f1_list'] = f1_list

    # Importance weights proportional to exp(t^2 / tau), normalised to sum
    # to one.  The largest exponent is subtracted before exponentiating: the
    # normalised weights are mathematically unchanged, but exp() can no
    # longer overflow to inf (which would turn every weight into NaN).
    energy_list = [t ** 2 / tau for t in t_list]
    max_energy = max([float(e) for e in energy_list] or [0.0])
    w_list = [th.exp(e - max_energy) for e in energy_list]
    w_sum = sum(w_list)
    w_list = [(w / w_sum).detach() for w in w_list]
    record['w_list'] = w_list

    # --- 2. fit the critic ---------------------------------------------
    # Critic inputs are detached: only the critic is trained in this phase.
    z_c = forward(c).detach()
    z_bar_list = [forward(c_bar).detach() for c_bar in c_bar_list]
    for j in range(CRITIC_ITERATIONS):
        for z_bar, t, w in zip(z_bar_list, t_list, w_list):
            delta = th.mean(critic(z_c) - critic(z_bar), 0)
            mse = w * mse_loss(delta, t)
            critic_optim.zero_grad()
            mse.backward()
            critic_optim.step()
    record['critic_state_dict'] = copy.deepcopy(my.state_dict_gpu2cpu(critic.state_dict()))

    # --- 3. update the classifier --------------------------------------
    c.train()
    critic.eval()
    # Snapshot of the parameters, used to limit how far the actor may move.
    c_param = copy.deepcopy(tuple(c.parameters()))
    for j in range(ACTOR_ITERATIONS):
        z = forward(c)
        objective = -critic(z)
        c_optim.zero_grad()
        # This also accumulates gradients in the critic's parameters; they
        # are cleared by critic_optim.zero_grad() before the next critic step.
        objective.backward()
        c_optim.step()
        # Stop as soon as any single parameter entry has moved more than `std`.
        if any(float(th.max(th.abs(p - q))) > std for p, q in zip(c_param, c.parameters())):
            break

    record['f1'] = float(L(c))

    if (i + 1) % LOG_INTERVAL == 0:
        print('[iteration %d]f1: %f' % (i + 1, record['f1']))

In [None]:
# h = hist[-1]
# c, c_bar = CNN(N_CLASSES), CNN(N_CLASSES)
# c.load_state_dict(my.state_dict_cpu2gpu(h['c_state_dict']))
# c_bar.load_state_dict(h['c_state_dict'])
# critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (1024,) * 3 + (1,), F.relu)
# critic.load_state_dict(my.state_dict_cpu2gpu(h['critic_state_dict']))
# if cuda:
#     c.cuda()
#     c_bar.cuda()
#     critic.cuda()

# critic.eval()

# # TODO gradient in parameter space/simplex
# objective = -th.mean(critic(forward(c)))
# objective.backward()

# alpha = np.linspace(0, 10)
# f1_list = []
# critic_list = []
# for a in alpha:
#     for p, p_bar in zip(c.parameters(), c_bar.parameters()):
#         p_bar.data = p.data - float(a) * p.grad.data
#     f1_list.append(float(L(c_bar)))
#     critic_list.append(float(critic(forward(c_bar))))

# pl.figure()
# pl.plot(alpha, f1_list)
# pl.figure()
# pl.plot(alpha, critic_list)

In [None]:
# Scores of the perturbed classifiers per iteration (one array per record),
# summarised by their minimum, maximum and standard deviation.
f1_list = [np.array([float(score) for score in h['f1_list']]) for h in hist]
iterations = range(len(f1_list))
min_list = tuple(np.min(scores) for scores in f1_list)
max_list = tuple(np.max(scores) for scores in f1_list)
std_list = tuple(np.std(scores) for scores in f1_list)

# Figure 1: range of the sampled scores against the updated classifier's score.
pl.plot(iterations, min_list, label='min sample')
pl.plot(iterations, max_list, label='max sample')
classifier_f1 = [h['f1'] for h in hist]
pl.plot(range(len(hist)), classifier_f1, label='classifier')
pl.title('f1')
pl.legend(framealpha=0)

# Figure 2: spread of the sampled scores.
pl.figure()
pl.title('std')
pl.plot(iterations, std_list)

In [None]:
# Shannon entropy of the importance-weight distribution at every iteration,
# compared with the entropy of a uniform distribution over the perturbations.
entropy_list = []
for h in hist:
    w_array = np.array(list(map(float, h['w_list'])))
    # Drop exact zeros so an underflowed weight contributes 0 instead of
    # 0 * log(0) = NaN.  NaN weights are kept (NaN != 0) and still show up.
    nonzero = w_array[w_array != 0]
    entropy_list.append(-np.sum(nonzero * np.log(nonzero)))
n = len(hist)
pl.plot(range(n), entropy_list, label='importance distribution')
# log(N) equals -log(1 / N) but does not depend on true division: under
# Python 2 (only print_function is imported from __future__) `1 / N` is 0.
pl.plot(range(n), np.ones(n) * np.log(N_PERTURBATIONS), label='uniform distribution')
pl.title('entropy')
pl.legend(framealpha=0)

In [None]:
# Normalised importance weights per iteration (one array per record),
# summarised by their minimum, maximum and standard deviation.
w_list = [np.array([float(weight) for weight in h['w_list']]) for h in hist]
iterations = range(len(w_list))
min_list = tuple(np.min(weights) for weights in w_list)
max_list = tuple(np.max(weights) for weights in w_list)
std_list = tuple(np.std(weights) for weights in w_list)

# Figure 1: smallest and largest weight at each iteration.
for values, label in ((min_list, 'min'), (max_list, 'max')):
    pl.plot(iterations, values, label=label)
pl.title('importance weight')
pl.legend(framealpha=0)

# Figure 2: spread of the weights at each iteration.
pl.figure()
pl.title('importance weight')
pl.plot(iterations, std_list, label='std')
pl.legend(framealpha=0)