In [67]:
import torch as th
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.loss import MSELoss
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import my

In [46]:
# ToTensor output standardized with mean 0.131 / std 0.308 (presumably the MNIST
# training-set pixel statistics); shuffled batches of 64 for training, 1024 for test.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.131,), (0.308,)),
])
mnist_train = MNIST('MNIST/', train=True, transform=trans)
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
mnist_test = MNIST('MNIST/', train=False, transform=trans)
test_loader = DataLoader(mnist_test, batch_size=1024)

In [42]:
# Raw pixel range check: `.train_data` holds the stored images without `transform`
# applied, so the maximum is still 255 (see output).
mnist_train.train_data.max()

255

In [60]:
class CNN(nn.Module):
    """One conv layer + tanh + linear head producing 10-way log-probabilities.

    Args:
        out_channels: number of convolution filters.
        input_size: side length of the square, single-channel input images
            (default 28, which gives the 12x12 feature map of the original model).
        kernel_size: convolution kernel size (default 5).
        stride: convolution stride (default 2).

    Attributes:
        in_features: size of one flattened tanh feature map
            (out_channels * side * side).
        expose: when True, ``forward`` returns ``(features, log_probs)`` where
            ``features`` is the per-sample flattened tanh activation; otherwise
            it returns only ``log_probs``.
    """
    def __init__(self, out_channels, input_size=28, kernel_size=5, stride=2):
        super(CNN, self).__init__()
        self.conv2d = nn.Conv2d(1, out_channels, kernel_size, stride)
        # Output side of an unpadded, undilated conv: floor((W - k) / s) + 1.
        side = (input_size - kernel_size) // stride + 1
        if side < 1:
            raise ValueError('input_size %d is too small for kernel_size %d'
                             % (input_size, kernel_size))
        self.in_features = out_channels * side * side
        self.linear = nn.Linear(self.in_features, 10)
        self.expose = False

    def forward(self, X):
        """Classify X of shape (N, H, W) or (N, 1, H, W); see class doc for ``expose``."""
        if X.dim() == 3:
            X = X.unsqueeze(1)  # add the single channel dimension
        features = th.tanh(self.conv2d(X))
        # Flatten per sample (not view(-1, in_features)): a wrongly sized input then
        # fails in the Linear layer instead of being silently regrouped into more rows.
        flat = features.view(features.size(0), -1)
        log_probs = F.log_softmax(self.linear(flat), 1)
        if self.expose:
            return flat, log_probs
        return log_probs

In [54]:
# Baseline classifier for the supervised run, optimized with Adam.
cnn = CNN(out_channels=16)
optim = Adam(cnn.parameters(), lr=1e-3)

In [6]:
# Supervised baseline: train `cnn` on NLL loss, logging stats on the current
# training batch every 100 iterations and test-set metrics after each epoch.
N_EPOCHS = 1

for e in range(N_EPOCHS):
    for i, (X, y) in enumerate(train_loader):
        X, y = Variable(X), Variable(y)
        optim.zero_grad()
        log_softmax = cnn(X)
        loss = F.nll_loss(log_softmax, y)
        loss.backward()
        optim.step()
        if (i + 1) % 100 == 0:
            # Accuracy on this training batch, measured after the step's update.
            accuracy = my.accuracy(my.predict(cnn, X), y)
            print('[iteration %d]nll loss: %f, accuracy %f' % (i + 1, float(loss), float(accuracy)))
        
    # Test-set metrics; the 10 given to my.nd_curry is presumably the class count — TODO confirm.
    # NOTE(review): the recorded run printed f1 = 0.64 next to precision/recall ~0.97,
    # which no standard F1 averaging of those values can produce; check how
    # my.nd_f_beta is curried (the same metric is the meta-training target later).
    accuracy, precision, recall, f1 = my.global_stats(cnn, test_loader, (
        my.accuracy,
        my.nd_curry(my.nd_precision, 10),
        my.nd_curry(my.nd_recall, 10),
        my.nd_curry(my.nd_f_beta, 10)))
    print('[epoch %d]accuracy: %f, precision: %f, recall: %f, f1: %f' % (
        e + 1, float(accuracy), float(precision), float(recall), float(f1)))

[iteration 100]nll loss: 0.045671, accuracy 1.000000
[iteration 200]nll loss: 0.226835, accuracy 0.953125
[iteration 300]nll loss: 0.068923, accuracy 0.984375
[iteration 400]nll loss: 0.151858, accuracy 0.921875
[iteration 500]nll loss: 0.170338, accuracy 0.953125
[iteration 600]nll loss: 0.069826, accuracy 1.000000
[iteration 700]nll loss: 0.039108, accuracy 1.000000
[iteration 800]nll loss: 0.063316, accuracy 0.968750
[iteration 900]nll loss: 0.133404, accuracy 0.953125
[epoch 1]accuracy: 0.971800, precision: 0.971106, recall: 0.970572, f1: 0.640810


In [21]:
# Large, unshuffled loaders, presumably for full-pass scoring via L(c, loader)
# in the meta-training cell.
# NOTE(review): this rebinds train_loader and replaces the shuffled batch-64
# loader; re-running the supervised training cell afterwards would train on
# unshuffled 4096-image batches.
train_loader = DataLoader(mnist_train, batch_size=4096)
test_loader = DataLoader(mnist_test, batch_size=4096)

In [64]:
out_channels = 16
SAMPLE_SIZE = 16  # images per sampled mini-dataset D_i
N_CLASSES = 10    # width of the classifier's log-softmax output

c = CNN(out_channels)  # the classifier
# Input width of L_theta: for each of the SAMPLE_SIZE images, the classifier's
# flattened tanh features (c.in_features) followed by its N_CLASSES log-probabilities.
# Derived from the model instead of re-hardcoding its 12 * 12 feature-map size.
D = (c.in_features + N_CLASSES) * SAMPLE_SIZE

approx = nn.Sequential(
    nn.Linear(D, 256),
    nn.Tanh(),
    nn.Linear(256, 1),
)  # L_\theta
c_optim = SGD(c.parameters(), lr=0.01)
approx_optim = SGD(approx.parameters(), lr=0.01)

In [68]:
OUTER = 50000      # classifier (outer) updates
INNER = 10         # approximator (inner) updates per classifier update
STD = 1            # perturbation scale passed to my.perturb
K = 10             # mini-datasets D_1, ..., D_K sampled per estimate
EVAL_EVERY = 1000  # report the test-set score every EVAL_EVERY outer iterations


def L(c, loader):
    """True objective L(c, D): the my.global_stats f-beta score of c over `loader`.

    Requires `c.expose = False` so c returns plain log-probabilities.
    NOTE(review): this metric printed f1 = 0.64 beside precision/recall ~0.97 in
    the supervised run; confirm my.nd_curry(my.nd_f_beta, 10) is really F1.
    """
    return my.global_stats(c, loader, my.nd_curry(my.nd_f_beta, 10))


def data(c, X):
    """Concatenate c's exposed (features, log-probs) on images X and flatten to one row.

    Requires `c.expose = True`.
    """
    return th.cat(c(X), 1).view(1, -1)


def batch_data(c, samples):
    """Stack data(c, X_i) for every (X_i, y_i) in `samples` into a (K, D) input for L_theta.

    NOTE(review): the labels y_i are dropped, so L_theta only sees the classifier's
    outputs, not the ground truth a label-dependent score depends on; D (setup cell)
    has no room for them yet.
    """
    return th.cat([data(c, X) for X, _ in samples], 0)


train_data, train_labels = my.th_normalize(mnist_train.train_data), mnist_train.train_labels

for i in range(OUTER):
    total_mse = 0
    total_delta, total_delta_ = 0, 0
    for j in range(INNER):
        c_bar = my.perturb(c, STD)

        # \delta = L(c, D) - L(\bar{c}, D) on the whole training loader.
        c.expose = False
        c_bar.expose = False
        delta = L(c, train_loader) - L(c_bar, train_loader)
        c.expose = True
        c_bar.expose = True

        samples = [my.sample(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]  # D_1, ..., D_K
        # Only theta is updated in this loop: detach the classifier outputs so
        # mse.backward() does not also back-propagate into c and c_bar.
        c_d = batch_data(c, samples).detach()          # (c, D_1), ..., (c, D_K)
        c_bar_d = batch_data(c_bar, samples).detach()  # (c_bar, D_1), ..., (c_bar, D_K)
        # \frac1K \sum_{i = 1}^K \delta_i, where \delta_i = L_\theta(c, D_i) - L_\theta(\bar{c}, D_i)
        delta_ = th.mean(approx(c_d) - approx(c_bar_d), 0)

        total_delta += abs(float(delta))
        total_delta_ += abs(float(delta_))

        # \arg \min_\theta (\delta - \frac1K \sum_{i = 1}^K \delta_i)^2
        mse = MSELoss()(delta_, delta)
        approx_optim.zero_grad()
        mse.backward()
        approx_optim.step()
        total_mse += float(mse)

    # if (i + 1) % 100 == 0:
    #     print('[iteration %d]mse: %f, delta: %f, delta_: %f' % (
    #         (i + 1), total_mse / (j + 1), total_delta / (j + 1), total_delta_ / (j + 1)))

    # Set explicitly: data() needs the exposed tuple even if the inner loop did not run.
    c.expose = True
    samples = [my.sample(train_data, train_labels, SAMPLE_SIZE) for k in range(K)]  # D_1, ..., D_K
    c_d = batch_data(c, samples)  # not detached: the gradient must reach c's parameters
    # Minimizing the negated surrogate, i.e. \arg \max_c \frac1K \sum_{i = 1}^K L_\theta(c, D_i)
    objective = -th.mean(approx(c_d))
    c_optim.zero_grad()
    objective.backward()
    c_optim.step()

    if (i + 1) % EVAL_EVERY == 0:
        # True objective on the held-out test loader; L needs plain log-probabilities.
        c.expose = False
        f1 = L(c, test_loader)
        c.expose = True
        print('[iteration %d]objective: %f, f1: %f' % ((i + 1), float(objective), float(f1)))

KeyboardInterrupt: 

In [27]:
# Scratch check: draw one mini-dataset of SAMPLE_SIZE images.
# NOTE(review): unlike the meta-training cell, this samples the raw ByteTensor
# `mnist_train.train_data` rather than the my.th_normalize(...) version.
X, y = my.sample(mnist_train.train_data, mnist_train.train_labels, SAMPLE_SIZE)

In [28]:
# Shapes of the sampled images and labels (output: (16, 28, 28) and (16,)).
X.size(), y.size()

(torch.Size([16, 28, 28]), torch.Size([16]))

In [39]:
# Storage type of the raw training images (output: ByteTensor) — presumably why the
# meta-training cell passes them through my.th_normalize before use.
type(mnist_train.train_data)

torch.ByteTensor

In [51]:
# Pick up edits to the local `my` helper module without restarting the kernel.
try:
    from importlib import reload  # Python 3.4+: reload is no longer a builtin
except ImportError:
    pass  # Python 2: reload is a builtin and importlib has no reload
my = reload(my)