In [None]:
import torch as th
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import my

In [None]:
# --- Configuration -----------------------------------------------------------
N_TRAIN, N_TEST = 0, 0
BATCH_SIZE = 64
GPU_ID = 3  # preferred CUDA device index
# Only take the GPU path when CUDA is actually usable, so the notebook also
# runs on CPU-only machines instead of crashing at the first .cuda() call.
cuda = th.cuda.is_available()

# --- Data --------------------------------------------------------------------
# Alternative call kept for reference:
# train_data, train_labels, test_data, test_labels = my.unbalanced_cifar10(N_TRAIN, N_TEST, p=[0, 1, 10])
# NOTE(review): the meaning of N_TRAIN/N_TEST == 0 and of `p` is defined by
# my.unbalanced_cifar10, which is not visible here -- confirm against `my`.
train_data, train_labels, test_data, test_labels = my.unbalanced_cifar10(N_TRAIN, N_TEST, p=[])

# Keep handles to the NumPy arrays before the names are rebound to tensors.
train_data_np, train_labels_np, test_data_np, test_labels_np = \
    train_data, train_labels, test_data, test_labels

train_data = th.from_numpy(train_data).float()
train_labels = th.from_numpy(train_labels).long()
test_data = th.from_numpy(test_data).float()
test_labels = th.from_numpy(test_labels).long()

if cuda:
    # Fall back to device 0 when the preferred index does not exist on this machine.
    th.cuda.set_device(GPU_ID if GPU_ID < th.cuda.device_count() else 0)

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = DataLoader(TensorDataset(train_data, train_labels), BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(test_data, test_labels), BATCH_SIZE)

# Size of dimension 1 of the training tensor.
N_FEATURES = train_data.size()[1]
# Number of classes, taken as the width of the observed label range.
N_CLASSES = int(train_labels.max() - train_labels.min() + 1)

In [None]:
class CNN(nn.Module):
    """Small two-convolution classifier for 3x32x32 images.

    Two 3x3 convolutions with stride 2 and padding 1 (3 -> 16 -> 8 channels),
    each followed by tanh, reduce a 32x32 input to 8x8. An 8x8 average pool
    then leaves one value per channel, which a linear layer maps to
    `n_classes` outputs. No softmax is applied to the result.
    """

    def __init__(self, n_classes):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16,
                               kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=8,
                               kernel_size=3, stride=2, padding=1)
        self.linear = nn.Linear(in_features=8, out_features=n_classes)

    def forward(self, x):
        """Return the linear-layer outputs for a batch of images.

        `x` may be 4-D, or any other shape that can be viewed as
        (-1, 3, 32, 32), e.g. flattened images.
        """
        images = x if x.dim() == 4 else x.view(-1, 3, 32, 32)
        hidden = F.tanh(self.conv1(images))
        hidden = F.tanh(self.conv2(hidden))
        # For 32x32 inputs the feature map is 8x8 here, so this pool is global.
        pooled = F.avg_pool2d(hidden, 8)
        return self.linear(pooled.view(-1, 8))

In [None]:
# Pre-train a CNN with cross-entropy using Adam.
c_pretrained = CNN(N_CLASSES)
if cuda:
    c_pretrained.cuda()
optim = Adam(c_pretrained.parameters(), lr=0.001)
EPOCHS = 50
criterion = CrossEntropyLoss()  # stateless, so a single instance is reused for every batch
for i in range(EPOCHS):
    for x, y in train_loader:
        if cuda:
            x = x.cuda()
            y = y.cuda()
        x = Variable(x)
        y = Variable(y)
        loss = criterion(c_pretrained(x), y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    # Evaluated on the test loader after every epoch, reported every fifth.
    accuracy = my.global_stats(c_pretrained, test_loader, my.accuracy)
    epoch = i + 1
    if epoch % 5 == 0:
        # The printed loss is that of the last mini-batch of the epoch.
        print('[epoch %d]cross-entropy loss: %f, accuracy: %f' % (epoch, float(loss), float(accuracy)))

In [None]:
# Score the pre-trained model on the test loader with my.nd_f_beta bound to N_CLASSES.
f_beta_metric = my.nd_curry(my.nd_f_beta, N_CLASSES)
f1 = my.global_stats(c_pretrained, test_loader, f_beta_metric)
float(f1)

In [None]:
# Working copy `c`: same architecture, initialised from the pre-trained weights,
# with its own Adam optimiser.
c = CNN(N_CLASSES)
if cuda:
    c.cuda()
pretrained_weights = c_pretrained.state_dict()
c.load_state_dict(pretrained_weights)
c_optim = Adam(c.parameters(), lr=0.01)
# Loader over the training set that yields shuffled batches of 4096 samples.
sample_dataset = TensorDataset(train_data, train_labels)
sample_loader = DataLoader(sample_dataset, batch_size=4096, shuffle=True)

In [None]:
def perturb_y(y, std):
    """Add Gaussian noise to `y` and clamp the result to [0, 1].

    Args:
        y: tensor (or Variable) to perturb.
        std: scale applied to standard-normal noise.

    Returns:
        `y + noise` clamped element-wise to the interval [0, 1], with the
        same shape as `y`.
    """
    # Shape the noise from the argument itself rather than from a notebook
    # global, so the function works for any input and on a fresh kernel.
    noise = th.randn(y.size()) * std
    # Follow the device of `y` instead of relying on a global flag.
    if y.is_cuda:
        noise = noise.cuda(y.get_device())
    return th.clamp(y + Variable(noise), 0, 1)

def fst(loader):
    """Return the first `(x, y)` pair produced by iterating `loader`.

    Args:
        loader: any iterable whose items unpack into two values.

    Returns:
        The tuple `(x, y)` from the first item.

    Raises:
        ValueError: if `loader` yields no items.
    """
    for x, y in loader:
        return x, y
    # Without this, an empty loader surfaced as an opaque unbound-variable error.
    raise ValueError('fst() received an empty loader')

In [None]:
std = 1e-1

# Take one batch from the sample loader and move it to the GPU when enabled.
x, y = fst(sample_loader)
if cuda:
    x = x.cuda()
    y = y.cuda()
x = Variable(x)
y = Variable(y)

# Softmax output of the current model and the score of its argmax predictions.
y_bar = F.softmax(c(x), 1)
f1_bar = my.nd_f_beta(y_bar.max(1)[1], y, N_CLASSES)

# Score 5000 independent random perturbations of y_bar; each entry of `stats`
# is (score, perturbed tensor). Every perturbed tensor is retained.
stats = []
for i in range(5000):
    y_per = perturb_y(y_bar, std).detach()
    labels_per = y_per.max(1)[1]
    f1_per = my.nd_f_beta(labels_per, y, N_CLASSES)
    stats.append((float(f1_per), y_per))

In [None]:
# Order the (score, tensor) entries by score, highest first; ties keep insertion order.
s = list(stats)
s.sort(key=lambda entry: entry[0], reverse=True)
s[0][0]

In [None]:
# Fit `c` on the sampled batch `x` towards the top-ranked perturbed target.
# NOTE(review): c(x) is the raw linear-layer output, whereas s[0][1] was derived
# from softmax output clamped to [0, 1] -- confirm this mismatch is intended.
N_ITERATIONS = 1000
for i in range(N_ITERATIONS):
    residual = c(x) - s[0][1]
    mse = (residual ** 2).mean()
    c_optim.zero_grad()
    mse.backward()
    c_optim.step()
    if (i + 1) % 1000 != 0:
        continue
    print('[iteration %d]mse: %f' % (i + 1, float(mse)))

In [None]:
# Scores on the sampled batch (pre-trained model vs. fitted model) and of the
# fitted model over the test loader.
pred_pre_x = c_pretrained(x).max(1)[1]
f1_c_pre_x = my.nd_f_beta(pred_pre_x, y, N_CLASSES)
pred_x = c(x).max(1)[1]
f1_c_x = my.nd_f_beta(pred_x, y, N_CLASSES)
f_beta_metric = my.nd_curry(my.nd_f_beta, N_CLASSES)
f1_c = my.global_stats(c, test_loader, f_beta_metric)
tuple(float(score) for score in (f1_c_pre_x, f1_c_x, f1_c))

In [None]:
# Collect one score per batch of sample_loader for the fitted model.
# Note: this rebinds `stats`, `x` and `y` from the earlier cells.
stats = []
for x, y in sample_loader:
    if cuda:
        x = x.cuda()
        y = y.cuda()
    x = Variable(x)
    y = Variable(y)
    predicted = c(x).max(1)[1]
    f1 = my.nd_f_beta(predicted, y, N_CLASSES)
    stats.append(float(f1))

In [None]:
import matplotlib.pylab as pl
%matplotlib inline
pl.hist(stats)