In [1]:
import torch as th
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.loss import MSELoss
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import my

In [2]:
class MLP(nn.Module):
    def __init__(self, D, nonlinear):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(D[i], D[i + 1]) for i in range(len(D) - 1)])
        self.nonlinear = nonlinear
        self.expose = False
    
    def forward(self, x):
        if x.dim != 2:
            x = x.view(x.size()[0], -1)
        for i, linear in enumerate(self.linears):
            x = linear(x)
            if i < len(self.linears) - 1:
                x = self.nonlinear(x)
        return F.log_softmax(x, 1)

In [3]:
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.131,), (0.308,))])
mnist_train = MNIST('MNIST/', True, transform=trans)
train_loader = DataLoader(mnist_train, 64, True)
mnist_test = MNIST('MNIST/', False, transform=trans)
test_loader = DataLoader(mnist_test, 1024)

mlp = MLP((28 * 28, 10), th.tanh)
optim = Adam(mlp.parameters(), lr=0.001)

N_EPOCHS = 1
for e in range(N_EPOCHS):
    for i, (X, y) in enumerate(train_loader):
        X, y = Variable(X), Variable(y)
        optim.zero_grad()
        log_softmax = mlp(X)
        loss = F.nll_loss(log_softmax, y)
        loss.backward()
        optim.step()
        if (i + 1) % 100 == 0:
            accuracy = my.accuracy(my.predict(mlp, X), y)
            print('[iteration %d]nll loss: %f, accuracy %f' % (i + 1, float(loss), float(accuracy)))
        
    accuracy, precision, recall, f1 = my.global_stats(mlp, test_loader, (
        my.accuracy,
        my.nd_curry(my.nd_precision, 10),
        my.nd_curry(my.nd_recall, 10),
        my.nd_curry(my.nd_f_beta, 10)))
    print('[epoch %d]accuracy: %f, precision: %f, recall: %f, f1: %f' % (
        e + 1, float(accuracy), float(precision), float(recall), float(f1)))

[iteration 100]nll loss: 0.475699, accuracy 0.859375
[iteration 200]nll loss: 0.216485, accuracy 0.953125
[iteration 300]nll loss: 0.240146, accuracy 0.890625
[iteration 400]nll loss: 0.467797, accuracy 0.859375
[iteration 500]nll loss: 0.473764, accuracy 0.875000
[iteration 600]nll loss: 0.388439, accuracy 0.890625
[iteration 700]nll loss: 0.447029, accuracy 0.906250
[iteration 800]nll loss: 0.273371, accuracy 0.937500
[iteration 900]nll loss: 0.167352, accuracy 0.968750
[epoch 1]accuracy: 0.915000, precision: 0.914516, recall: 0.913974, f1: 0.913725


In [4]:
N_TRAIN, N_TEST = 1000, 1000
# train_data, train_labels, test_data, test_labels = my.unbalanced_mnist(N_TRAIN, N_TEST, pca=True)
# train_loader = DataLoader(th.utils.data.TensorDataset(train_data.data, train_labels.data), train_data.size()[0])
train_data = my.th_normalize(mnist_train.train_data).view(mnist_train.train_data.size()[0], -1)
train_labels = mnist_train.train_labels
# train_labels = train_labels.numpy()
# train_labels[train_labels != 0] = 1
# train_labels = th.from_numpy(train_labels)
train_loader = DataLoader(
    th.utils.data.TensorDataset(*my.sample_subset(train_data, train_labels, N_TRAIN, False)), N_TRAIN)

d = int(train_labels.max() - train_labels.min() + 1)
c = MLP((train_data.size()[1], d), th.tanh) # the classifier
c_optim = SGD(c.parameters(), 0.01)

SAMPLE_SIZE = 16
D = (train_data.size()[1] + d + d) * SAMPLE_SIZE + sum(p.numel() for p in c.parameters())
approx = nn.Sequential(
    nn.Linear(D, 256),
    nn.Tanh(),
    nn.Linear(256, 1)
) # L_\theta
approx_optim = SGD(approx.parameters(), 0.01)
float(my.global_stats(c, train_loader, my.nd_curry(my.nd_f_beta, d)))

0.08324562013149261

In [None]:
OUTER = 50000
INNER = 10
STD = 5
K = 10

L = lambda c, loader: my.global_stats(c, loader, my.nd_curry(my.nd_f_beta, d))
def data(classifier, pair):
    X, y = pair
    y_bar = classifier(X)
    exposure = th.cat((X, my.onehot(y.view(y.size()[0], 1), d), y_bar), 1).view(1, -1)
    parameters = list(map(lambda p: p.view(1, -1), classifier.parameters()))
    return th.cat([exposure] + parameters, 1)

for i in range(OUTER):
    total_mse = 0
    total_delta, total_delta_ = 0, 0
    for j in range(INNER):
        c_bar = my.perturb(c, STD)
        delta = L(c, train_loader) - L(c_bar, train_loader)
        
        samples = [my.sample_subset(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]
        c_d = th.cat(map(lambda X: data(c, X), samples), 0)
        c_bar_d = th.cat(map(lambda X: data(c_bar, X), samples), 0)
        delta_bar = th.mean(approx(c_d) - approx(c_bar_d), 0)
        
        mse = MSELoss()(delta_bar, delta)
        approx_optim.zero_grad()
        mse.backward()
        approx_optim.step()
        
    samples = [my.sample_subset(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]
    c_d = th.cat(map(lambda X: data(c, X), samples), 0)
    f1_bar = -th.mean(approx(c_d))
    c_optim.zero_grad()
    f1_bar.backward()
    c_optim.step()

    if (i + 1) % 10 == 0:
        f1 = L(c, train_loader)
        print('[iteration %d]mse: %f, f1: %f' % ((i + 1), float(mse), float(f1)))

[iteration 10]mse: 0.005802, f1: 0.084248
[iteration 20]mse: 0.002567, f1: 0.083122
[iteration 30]mse: 0.000086, f1: 0.082517
[iteration 40]mse: 0.005454, f1: 0.083266
[iteration 50]mse: 0.000002, f1: 0.082339
[iteration 60]mse: 0.000501, f1: 0.082291
[iteration 70]mse: 0.004007, f1: 0.086820
[iteration 80]mse: 0.000006, f1: 0.088180
[iteration 90]mse: 0.000291, f1: 0.089575
[iteration 100]mse: 0.000137, f1: 0.091936
[iteration 110]mse: 0.000001, f1: 0.089893
[iteration 120]mse: 0.000015, f1: 0.089094
[iteration 130]mse: 0.009502, f1: 0.090232
[iteration 140]mse: 0.000210, f1: 0.091078
[iteration 150]mse: 0.000024, f1: 0.092304
[iteration 160]mse: 0.000030, f1: 0.094909
[iteration 170]mse: 0.001438, f1: 0.094989
[iteration 180]mse: 0.000172, f1: 0.095354
[iteration 190]mse: 0.001220, f1: 0.096929
[iteration 200]mse: 0.000019, f1: 0.096896
[iteration 210]mse: 0.000203, f1: 0.097210
[iteration 220]mse: 0.000034, f1: 0.097107
[iteration 230]mse: 0.000372, f1: 0.097958
[iteration 240]mse: 