In [None]:
from __future__ import print_function
import copy
from threading import Condition, Thread
import numpy as np
import numpy.random as npr
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [None]:
import matplotlib.pylab as pl
%matplotlib inline

In [None]:
# NOTE(review): N_TRAIN / N_TEST are not used in the cells shown here.
N_TRAIN, N_TEST = 0, 0
BATCH_SIZE = 64
cuda = True
CUDA_DEVICE = 2  # GPU index selected below when `cuda` is True

# `partition=[]` and `rbg=True` are forwarded unchanged to the project loader `my.load_cifar10`.
train_x, train_y, test_x, test_y = my.load_cifar10(partition=[], rbg=True)

train_x, test_x = th.from_numpy(train_x).float(), th.from_numpy(test_x).float()
# Both label splits are cast to int64 so they are handled identically downstream
# (arg-max comparison in `L`, `my.onehot` in `critic_forward`); previously only test_y was cast.
train_y, test_y = th.from_numpy(train_y).long(), th.from_numpy(test_y).long()

N_FEATURES = train_x.size()[1]
# Assumes labels are contiguous integers from min to max — TODO confirm for the loader output.
N_CLASSES = int(train_y.max() - train_y.min() + 1)

if cuda:
    th.cuda.set_device(CUDA_DEVICE)

In [None]:
class CNN(nn.Module):
    """Small two-convolution classifier producing unnormalised class scores.

    Layout: conv(3->16, k3, s2, p1) -> tanh -> conv(16->8, k3, s2, p1) -> tanh
    -> 8x8 average pool -> linear(8 -> n_classes).
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        # Each stride-2 / padding-1 conv halves the spatial size (32 -> 16 -> 8 for 32x32 input).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, stride=2, padding=1)
        self.linear = nn.Linear(in_features=8, out_features=n_classes)

    def forward(self, x):
        """Return class scores; any non-4D input is reshaped to (-1, 3, 32, 32) first."""
        images = x if x.dim() == 4 else x.view(-1, 3, 32, 32)
        hidden = F.tanh(self.conv1(images))
        hidden = F.tanh(self.conv2(hidden))
        # For 32x32 inputs the 8x8 window reduces the 8x8 feature map to 1x1 per channel.
        pooled = F.avg_pool2d(hidden, 8)
        return self.linear(pooled.view(-1, 8))

In [None]:
BATCH_SIZE = 64

th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

# Take one shuffled mini-batch; every later cell trains and evaluates on this fixed batch.
x, y = next(iter(DataLoader(TensorDataset(train_x, train_y), BATCH_SIZE, shuffle=True)))
if cuda:
    x, y = x.cuda(), y.cuda()
x, y = Variable(x), Variable(y)


def L(y_bar):
    """Score the arg-max predictions of `y_bar` against the fixed batch labels `y`.

    Delegates to the project helper `my.nd_f_beta` (presumably an F-beta score — confirm).
    """
    predicted = th.max(y_bar, 1)[1]
    return my.nd_f_beta(predicted, y, N_CLASSES)

def perturb_y(y, std):
    """Return `y` plus element-wise Gaussian noise scaled by `std`.

    The noise is sampled on the CPU and moved to the GPU when the global `cuda` flag is set.
    """
    noise = Variable(th.randn(y.size()) * std)
    if cuda:
        noise = noise.cuda()
    return y + noise

def critic_forward(y_bar, detach=False):
    """Feed predictions `y_bar` and the fixed batch labels `y` to the global `critic`.

    Softmax probabilities of `y_bar` are concatenated (along dim 1) with
    `my.onehot(y, N_CLASSES)`, and the whole batch is flattened into a single row.
    With `detach=True` that combined input is cut from the autograd graph, so a
    backward pass does not reach whatever produced `y_bar`.
    """
    probs = F.softmax(y_bar, 1)
    targets = my.onehot(y, N_CLASSES)
    critic_input = th.cat((probs, targets), 1).view(1, -1)
    if detach:
        critic_input = critic_input.detach()
    return critic(critic_input)

def critic_mse(y_bar, y_perturbed, target):
    """Mean squared gap between the critic's predicted score change and `target`.

    The predicted change is critic(y_perturbed) - critic(y_bar); both critic inputs
    are detached (see `critic_forward`).
    """
    score_clean = critic_forward(y_bar, True)
    score_noisy = critic_forward(y_perturbed, True)
    residual = (score_noisy - score_clean) - target
    return th.mean(residual ** 2)

In [None]:
# Re-seed so both networks receive the same initial weights every time this cell runs.
th.random.manual_seed(1)
th.cuda.manual_seed_all(1)

c = CNN(N_CLASSES)
# Project-local critic. Argument meaning is inferred from the call site, TODO confirm
# against my.RN: batch size, per-sample input width (probabilities + one-hot labels),
# layer widths, activation.
critic = my.RN(BATCH_SIZE, 2 * N_CLASSES, (512, 512, 512, 1), F.relu)

if cuda:
    c.cuda()
    critic.cuda()

c_optim = Adam(c.parameters(), lr=1e-3)
critic_optim = Adam(critic.parameters(), lr=1e-3)

# Baseline: score of the untrained classifier on the fixed batch.
float(L(c(x)))

In [None]:
CLASSIFIER_N_ITERATIONS = 10  # max classifier steps per outer iteration
CRITIC_N_ITERATIONS = 10      # critic steps per outer iteration
N_ITERATIONS = 250
N_PERTURBATIONS = 25          # noisy copies of y_bar used to train the critic
std = 1e-1                    # noise scale in perturb_y; also the per-phase parameter-move limit
tau = 1e-2                    # temperature of the perturbation weights exp(t**2 / tau)
LOG_EVERY = 1                 # print progress every LOG_EVERY outer iterations
SAVE_STATE_DICTS = False      # if True, deep-copy both models' state_dicts into `hist`

hist = []
for i in range(N_ITERATIONS):
    hist.append({})
    if SAVE_STATE_DICTS:
        hist[-1]['c_state_dict'] = copy.deepcopy(c.state_dict())

    # Critic phase: fit the critic to the change in L caused by perturbing y_bar.
    c.eval()
    critic.train()
    y_bar = c(x)
    L_bar = L(y_bar)

    y_list, t_list, score_list = [], [], []
    for j in range(N_PERTURBATIONS):
        y_ptrbd = perturb_y(y_bar, std)  # TODO perturb in simplex
        t = L(y_ptrbd) - L_bar
        y_list.append(y_ptrbd)
        t_list.append(t)
        score_list.append(t ** 2 / tau)

    # Weights are proportional to exp(t**2 / tau) and normalised once, after every
    # perturbation is drawn. Renormalising inside the loop would repeatedly shrink
    # earlier weights and hand the last sample at least half the total weight.
    # Subtracting the max score before exp() leaves the normalised weights unchanged
    # but avoids float overflow when t**2 / tau is large.
    max_score = max(float(s) for s in score_list)
    w_list = [th.exp(s - max_score) for s in score_list]
    z = sum(w_list)
    w_list = [(w / z).detach() for w in w_list]

    for j in range(CRITIC_N_ITERATIONS):
        mse = 0
        for y_ptrbd, t, w in zip(y_list, t_list, w_list):
            mse += w * critic_mse(y_bar, y_ptrbd, t)
        mse /= N_PERTURBATIONS
        critic_optim.zero_grad()
        mse.backward()
        critic_optim.step()
    if SAVE_STATE_DICTS:
        hist[-1]['critic_state_dict'] = copy.deepcopy(critic.state_dict())

    # Classifier phase: maximise the critic's score, stopping early once any parameter
    # has moved more than `std` from its value at the start of this phase.
    # The critic also accumulates gradients here; they are cleared by critic_optim.zero_grad()
    # before the next critic step.
    c.train()
    critic.eval()
    c_param = copy.deepcopy(tuple(c.parameters()))
    for j in range(CLASSIFIER_N_ITERATIONS):
        y_bar = c(x)
        objective = -th.mean(critic_forward(y_bar))
        c_optim.zero_grad()
        objective.backward()
        c_optim.step()
        if any(float(th.max(th.abs(p - q))) > std for p, q in zip(c_param, c.parameters())):
            break

    if (i + 1) % LOG_EVERY == 0:
        f1 = L(c(x))
        print('[iteration %d]mse: %f; f1: %f' % (i + 1, mse, f1))