In [None]:
import torch as th
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.loss import MSELoss
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import my

In [None]:
class MLP(nn.Module):
    """Multi-layer perceptron that returns per-class log-probabilities.

    Args:
        D: sequence of layer widths; each consecutive pair ``(D[i], D[i + 1])``
            defines one ``nn.Linear`` layer.
        nonlinear: callable applied to the output of every layer except the last.
    """

    def __init__(self, D, nonlinear):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(D[i], D[i + 1]) for i in range(len(D) - 1)])
        self.nonlinear = nonlinear
        # Not read anywhere in this class; kept so code that sets or checks it keeps working.
        self.expose = False

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: tensor whose first dimension is the batch; any trailing
                dimensions are flattened into a single feature dimension.

        Returns:
            Tensor of shape ``(batch, D[-1])`` holding log-softmax scores over dim 1.
        """
        # `dim` has to be *called*: comparing the bound method `x.dim` with 2 is
        # always unequal, which would reshape every input (and make an empty
        # 2-D batch fail inside `view`, where -1 cannot be inferred).
        if x.dim() != 2:
            x = x.view(x.size(0), -1)
        last = len(self.linears) - 1
        for i, linear in enumerate(self.linears):
            x = linear(x)
            # No nonlinearity after the final layer: its output feeds log-softmax.
            if i < last:
                x = self.nonlinear(x)
        return F.log_softmax(x, 1)

In [None]:
N_TRAIN, N_TEST = 1000, 1000

# Alternative data source (disabled):
# train_data, train_labels, test_data, test_labels = my.unbalanced_mnist(N_TRAIN, N_TEST, pca=True)
# train_loader = DataLoader(th.utils.data.TensorDataset(train_data.data, train_labels.data), train_data.size()[0])

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.131,), (0.308,))])
mnist_train = MNIST('MNIST/', True, transform=trans)
mnist_test = MNIST('MNIST/', False, transform=trans)

# The raw image tensors are read directly from the datasets and flattened to
# (n_images, n_pixels); `trans` is therefore not applied to them here.
train_data = my.th_normalize(mnist_train.train_data).view(mnist_train.train_data.size()[0], -1)
# Merge every label greater than 3 into the single class 4. `clamp` returns a
# new tensor, so the labels stored inside `mnist_train` are left untouched.
train_labels = th.clamp(mnist_train.train_labels, max=4)
# One batch holds the whole sampled subset (batch size == N_TRAIN).
train_loader = DataLoader(
    th.utils.data.TensorDataset(*my.sample_subset(train_data, train_labels, N_TRAIN, False)), N_TRAIN)

# Held-out split, prepared exactly like the training split so that
# `test_loader` exists for the end-of-epoch evaluation.
# NOTE(review): `my.th_normalize` is applied to the test images on their own --
# confirm it does not need statistics from the training set.
test_data = my.th_normalize(mnist_test.test_data).view(mnist_test.test_data.size()[0], -1)
test_labels = th.clamp(mnist_test.test_labels, max=4)
test_loader = DataLoader(
    th.utils.data.TensorDataset(*my.sample_subset(test_data, test_labels, N_TEST, False)), N_TEST)

# Baseline classifier: with D = (784, 10) the MLP is a single linear layer
# followed by log-softmax (the nonlinearity is only applied between layers).
# NOTE(review): the labels were collapsed to the values 0-4 above, yet the model
# and the metrics below use 10 classes -- confirm this is intended.
mlp = MLP((28 * 28, 10), th.tanh)
optim = Adam(mlp.parameters(), lr=0.001)

N_EPOCHS = 1
for e in range(N_EPOCHS):
    for i, (X, y) in enumerate(train_loader):
        X, y = Variable(X), Variable(y)
        # Standard step: clear old gradients, forward, NLL loss on the
        # log-probabilities, backward, parameter update.
        optim.zero_grad()
        log_softmax = mlp(X)
        loss = F.nll_loss(log_softmax, y)
        loss.backward()
        optim.step()
        # Report loss and accuracy on the current batch every 100 batches.
        if (i + 1) % 100 == 0:
            accuracy = my.accuracy(my.predict(mlp, X), y)
            print('[iteration %d]nll loss: %f, accuracy %f' % (i + 1, float(loss), float(accuracy)))
        
    # End of epoch: evaluate on `test_loader`. `my.global_stats` is given a tuple
    # of four metric callables and its result is unpacked into four values.
    # Presumably `my.nd_curry(f, 10)` binds the number of classes -- confirm in `my`.
    accuracy, precision, recall, f1 = my.global_stats(mlp, test_loader, (
        my.accuracy,
        my.nd_curry(my.nd_precision, 10),
        my.nd_curry(my.nd_recall, 10),
        my.nd_curry(my.nd_f_beta, 10)))
    print('[epoch %d]accuracy: %f, precision: %f, recall: %f, f1: %f' % (
        e + 1, float(accuracy), float(precision), float(recall), float(f1)))

In [None]:
# Number of classes, taken as the width of the label range.
d = int(train_labels.max() - train_labels.min() + 1)
n_features = train_data.size(1)
c = MLP((n_features, d), th.tanh)  # the classifier
c_optim = SGD(c.parameters(), lr=0.01)

SAMPLE_SIZE = 16
# Width of one flattened row produced by `data`: for each of the SAMPLE_SIZE
# examples its features, one-hot label and classifier output, followed by
# every classifier parameter.
n_classifier_params = sum(p.numel() for p in c.parameters())
D = (n_features + d + d) * SAMPLE_SIZE + n_classifier_params
approx = nn.Sequential(nn.Linear(D, 256), nn.Tanh(), nn.Linear(256, 1))  # L_\theta
approx_optim = SGD(approx.parameters(), lr=0.01)
# Cell output: the metric of the freshly initialised classifier on the training loader.
float(my.global_stats(c, train_loader, my.nd_curry(my.nd_f_beta, d)))

In [None]:
OUTER = 50000  # classifier updates (outer-loop iterations)
INNER = 10     # approximator updates per outer iteration
STD = 5        # second argument passed to my.perturb
K = 10         # subsets drawn per update, SAMPLE_SIZE examples each


def L(c, loader):
    """Return `my.global_stats` of classifier `c` over `loader` with the d-class f-beta metric."""
    return my.global_stats(c, loader, my.nd_curry(my.nd_f_beta, d))
def data(classifier, pair):
    """Flatten one sampled subset plus the classifier's weights into a single row.

    Args:
        classifier: module called on the inputs; its parameters are appended to the row.
        pair: ``(X, y)`` tuple of inputs and labels for one subset.

    Returns:
        Tensor of shape ``(1, W)``: the per-example concatenation of inputs,
        one-hot labels and classifier outputs, flattened, followed by every
        classifier parameter flattened.
    """
    inputs, labels = pair
    outputs = classifier(inputs)
    # presumably my.onehot turns the (n, 1) label column into an (n, d) one-hot matrix -- confirm
    labels_onehot = my.onehot(labels.view(labels.size(0), 1), d)
    exposure_row = th.cat((inputs, labels_onehot, outputs), 1).view(1, -1)
    param_rows = [p.view(1, -1) for p in classifier.parameters()]
    return th.cat([exposure_row] + param_rows, 1)

def _stack_rows(classifier, samples):
    """Stack the `data` row of every sampled subset into one (len(samples), W) batch."""
    # torch.cat needs a real sequence of tensors; a lazy `map` object is
    # rejected on Python 3, so build a list explicitly.
    return th.cat([data(classifier, pair) for pair in samples], 0)


mse_loss = MSELoss()  # stateless, so one instance serves every step

for i in range(OUTER):
    # Inner loop: fit `approx` so that the mean difference of its outputs for
    # the classifier and a perturbed copy matches the measured metric gap.
    for j in range(INNER):
        # presumably a copy of `c` with noise of scale STD added -- confirm in `my`
        c_bar = my.perturb(c, STD)
        delta = L(c, train_loader) - L(c_bar, train_loader)

        samples = [my.sample_subset(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]
        c_d = _stack_rows(c, samples)
        c_bar_d = _stack_rows(c_bar, samples)
        delta_bar = th.mean(approx(c_d) - approx(c_bar_d), 0)

        # NOTE(review): `delta_bar` has shape (1,); confirm `delta` is compatible with it.
        mse = mse_loss(delta_bar, delta)
        approx_optim.zero_grad()
        mse.backward()
        approx_optim.step()

    # Outer step: move the classifier in the direction that increases the
    # approximator's output (minimise its negated mean).
    samples = [my.sample_subset(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]
    c_d = _stack_rows(c, samples)
    f1_bar = -th.mean(approx(c_d))
    # Clears gradients that the inner-loop backward passes left on c's parameters.
    c_optim.zero_grad()
    f1_bar.backward()
    c_optim.step()

    if (i + 1) % 10 == 0:
        f1 = L(c, train_loader)
        # `mse` is the loss of the last inner step of this outer iteration.
        print('[iteration %d]mse: %f, f1: %f' % ((i + 1), float(mse), float(f1)))