In [None]:
from __future__ import print_function
import numpy as np
import numpy.random as npr
import torch as th
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.modules.loss import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [None]:
# Load MNIST via the project helper, convert to torch tensors and (optionally) move to the GPU.
N_TRAIN, N_TEST = 0, 0  # passed straight to my.unbalanced_dataset; the meaning of 0 is defined there
BATCH_SIZE = 64
CUDA_DEVICE = 2  # preferred GPU index; adjust for the machine at hand
# Only use CUDA when it is actually available so the notebook also runs on CPU-only hosts.
cuda = th.cuda.is_available()

train_data, train_labels, test_data, test_labels = my.unbalanced_dataset(
    'MNIST', N_TRAIN, N_TEST, pca=False, p=[0, 1, 10])

# Keep references to the original numpy arrays: sample() draws its subsets from these.
train_data_np, train_labels_np, test_data_np, test_labels_np = \
    train_data, train_labels, test_data, test_labels

train_data = th.from_numpy(train_data).float()
train_labels = th.from_numpy(train_labels).long()
test_data = th.from_numpy(test_data).float()
test_labels = th.from_numpy(test_labels).long()

if cuda:
    # Fall back to the highest-numbered visible GPU when CUDA_DEVICE does not exist here,
    # instead of failing with an invalid-device error.
    th.cuda.set_device(min(CUDA_DEVICE, th.cuda.device_count() - 1))
    train_data, train_labels = train_data.cuda(), train_labels.cuda()
    test_data, test_labels = test_data.cuda(), test_labels.cuda()

train_loader = DataLoader(TensorDataset(train_data, train_labels), BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels), BATCH_SIZE)

N_FEATURES = train_data.size(1)
# CrossEntropyLoss expects targets that are class indices in [0, N_CLASSES), so the class
# count must cover 0..max. The previous `max - min + 1` undercounts whenever min > 0.
assert int(train_labels.min()) >= 0, 'labels must be non-negative class indices'
N_CLASSES = int(train_labels.max()) + 1

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier for single-channel images (28x28 by default).

    Two 3x3 stride-2 convolutions (1 -> 16 -> 8 channels) with tanh activations,
    global average pooling over the remaining spatial map, then a linear layer
    producing ``n_classes`` unnormalised logits.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        # Conv2d(in, out, kernel_size=3, stride=2, padding=1) halves H and W: 28 -> 14 -> 7.
        self.conv1 = nn.Conv2d(1, 16, 3, 2, 1)
        self.conv2 = nn.Conv2d(16, 8, 3, 2, 1)
        self.linear = nn.Linear(8, n_classes)

    def forward(self, x):
        """Return logits of shape (batch, n_classes).

        Any input that is not 4-D (e.g. flattened 784-pixel rows) is reshaped to
        (batch, 1, 28, 28); 4-D NCHW input is used as-is.
        """
        if x.dim() != 4:
            x = x.view(-1, 1, 28, 28)
        # torch.tanh replaces the deprecated F.tanh.
        x = th.tanh(self.conv1(x))
        x = th.tanh(self.conv2(x))
        # Global average pool: identical to avg_pool2d(x, 7) for 28x28 input, but for larger
        # 4-D inputs it still yields one 8-vector per example instead of folding the extra
        # spatial positions into the batch dimension.
        x = F.adaptive_avg_pool2d(x, 1)
        return self.linear(x.view(x.size(0), -1))

In [None]:
# Baseline: train the CNN with plain cross-entropy. (Alternative classifier tried:)
# c = my.MLP((N_FEATURES,) + (64,) * 3 + (N_CLASSES,), F.relu)
c = CNN(N_CLASSES)
if cuda:
    c.cuda()
optim = Adam(c.parameters(), lr=0.001)
criterion = CrossEntropyLoss()  # stateless; build once instead of once per minibatch
EPOCHS = 10
for i in range(EPOCHS):
    epoch_loss, n_seen = 0.0, 0
    for X, y in train_loader:
        if cuda:
            X, y = X.cuda(), y.cuda()
        X, y = Variable(X), Variable(y)
        loss = criterion(c(X), y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        # Accumulate a batch-size-weighted sum so the report is the epoch mean,
        # not just the (noisy, possibly short) last minibatch.
        epoch_loss += float(loss) * y.size(0)
        n_seen += y.size(0)
    accuracy = my.global_stats(c, test_loader, my.accuracy)
    print('[epoch %d]cross-entropy loss: %f, accuracy: %f' % ((i + 1), epoch_loss / max(n_seen, 1), float(accuracy)))

In [None]:
# Test-set metrics for the cross-entropy-trained classifier `c`.
y_bar = my.predict(c, test_data)
accuracy, precision, recall, f1 = (
    my.accuracy(y_bar, test_labels),
    my.nd_precision(y_bar, test_labels, N_CLASSES),
    my.nd_recall(y_bar, test_labels, N_CLASSES),
    my.nd_f_beta(y_bar, test_labels, N_CLASSES),
)
print('accuracy: %f, precision: %f, recall: %f, f1: %f'
      % tuple(float(metric) for metric in (accuracy, precision, recall, f1)))

In [None]:
# Critic inputs: each one is built from SAMPLE_SIZE labelled examples, and BATCH_SIZE
# such subsets are drawn per step. NOTE: this rebinds the BATCH_SIZE (64) that the
# data-loading cell used for its DataLoaders.
SAMPLE_SIZE = 64
BATCH_SIZE = 16

def L(classifier, X, y):
    """Score `classifier` on (X, y) with my.nd_f_beta over the predictions from my.predict."""
    predictions = my.predict(classifier, X)
    score = my.nd_f_beta(predictions, y, N_CLASSES)
    return score

def forward(classifier, pair):
    """Encode one labelled subset as a single critic input row.

    Concatenates the one-hot true labels with the classifier's softmax outputs
    along dim 1, then flattens the result to shape (1, -1).
    """
    inputs, labels = pair
    targets = my.onehot(labels, N_CLASSES)
    probs = F.softmax(classifier(inputs), 1)
    row = th.cat((targets, probs), 1)
    return row.view(1, -1)
    
def sample():
    """Draw BATCH_SIZE random subsets of SAMPLE_SIZE training examples.

    Returns a list of (X, y) pairs wrapped in Variables, moved to the GPU first
    when `cuda` is set.
    """
    subsets = []
    for _ in range(BATCH_SIZE):
        X, y = my.sample_subset(train_data_np, train_labels_np, SAMPLE_SIZE)
        if cuda:
            X, y = X.cuda(), y.cuda()
        subsets.append((Variable(X), Variable(y)))
    return subsets

In [None]:
# Classifier trained through the learned critic: a plain linear model on the raw features.
# Alternatives tried: c = CNN(N_CLASSES) for the classifier, and for the critic
# my.MLP(((N_CLASSES +  N_CLASSES) * SAMPLE_SIZE,) + (1024,) * 3 +(1,), F.relu).
c = nn.Linear(N_FEATURES, N_CLASSES)
# NOTE(review): argument meaning inferred from the call site — presumably (set size,
# per-element width = one-hot label + class probabilities, layer sizes, activation);
# confirm against my.RN.
critic = my.RN(SAMPLE_SIZE, 2 * N_CLASSES, (1024,) * 3 + (1,), F.relu)

if cuda:
    for net in (c, critic):
        net.cuda()

# Adam for both networks; SGD(params, 0.1, momentum=0.9) was the earlier alternative.
c_optim, critic_optim = Adam(c.parameters(), 1e-3), Adam(critic.parameters(), 1e-3)

# Test-set score of the freshly initialised classifier (shown as the cell output).
float(L(c, test_data, test_labels))

In [None]:
# Meta-training. In each outer iteration, the critic first learns to predict how much a
# random perturbation of `c` (my.perturb with STD) changes the train-set score. The
# classifier is then updated to maximise the critic's output.
STD = 0.1    # perturbation scale passed to my.perturb
OUTER = 500  # classifier updates
INNER = 10   # critic updates per classifier update


def _stack_outputs(classifier, subsets):
    """Stack forward(classifier, s) for each subset into one batch along dim 0."""
    return th.cat([forward(classifier, s) for s in subsets], 0)


mse_loss = MSELoss()  # stateless; build once
stats = []
for i in range(OUTER):
    f1 = L(c, train_data, train_labels)
    f1_list = []
    for j in range(INNER):
        c_bar = my.perturb(c, STD)
        f1_bar = L(c_bar, train_data, train_labels)
        delta = f1 - f1_bar
        # NOTE(review): `f1` does not change inside this loop, so f1_list gets INNER copies
        # of the same value (the plots below then show repeated points). If the intent was
        # to track perturbed scores, append float(f1_bar) instead — confirm.
        f1_list.append(float(f1))

        samples = sample()
        # The critic step needs gradients only w.r.t. critic parameters. Detaching stops
        # backward from running through c and c_bar. The gradients it used to put into c
        # were discarded by c_optim.zero_grad() anyway, so results are unchanged.
        y = _stack_outputs(c, samples).detach()
        y_bar = _stack_outputs(c_bar, samples).detach()
        delta_ = th.mean(critic(y) - critic(y_bar), 0)

        mse = mse_loss(delta_, delta)
        critic_optim.zero_grad()
        mse.backward()
        critic_optim.step()

    # Classifier step: ascend the critic's score (not detached — gradients must reach c).
    samples = sample()
    y = _stack_outputs(c, samples)
    objective = -th.mean(critic(y))
    c_optim.zero_grad()
    objective.backward()
    c_optim.step()
    stats.append(f1_list)

    if (i + 1) % 100 == 0:
        y_bar = my.predict(c, test_data)
        f1 = my.nd_f_beta(y_bar, test_labels, N_CLASSES)
        print('[iteration %d]mse: %f, objective: %f, f1: %f' % ((i + 1), float(mse), float(objective), float(f1)))

In [None]:
# Test-set metrics for the critic-trained classifier, using the same metrics as the baseline.
y_bar = my.predict(c, test_data)
accuracy = my.accuracy(y_bar, test_labels)
precision, recall, f1 = [score_fn(y_bar, test_labels, N_CLASSES)
                         for score_fn in (my.nd_precision, my.nd_recall, my.nd_f_beta)]
print('accuracy: %f, precision: %f, recall: %f, f1: %f'
      % (float(accuracy), float(precision), float(recall), float(f1)))

In [None]:
import matplotlib.pylab as pl
%matplotlib inline

In [None]:
# Every value recorded in f1_list during meta-training, in order (INNER entries per outer
# iteration). A nested comprehension flattens in linear time; sum(stats, []) re-copies
# the accumulated list on every addition. The x-axis defaults to the index.
recorded_f1 = [value for row in stats for value in row]
pl.plot(recorded_f1)
pl.xlabel('critic step')
pl.ylabel('recorded train F1')
pl.title('Recorded train F1 per critic step');

In [None]:
# Largest value recorded in f1_list for each outer iteration (x-axis defaults to the index).
pl.plot([max(row) for row in stats])
pl.xlabel('outer iteration')
pl.ylabel('max recorded train F1')
pl.title('Max recorded train F1 per outer iteration');

In [None]:
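# Uncomment to save the trained critic's weights to the file 'mnist_critic' in the
# working directory (the classifier `c` is not saved by this line):
# th.save(critic.state_dict(), 'mnist_critic')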