In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import torchvision.datasets
from torchvision import transforms
import tqdm

import scipy.stats
import numpy as np

In [24]:
from models import instantiate_MLP_model
from training import Trainer, AdversarialTrainer
from data import SplitMNIST

In [27]:
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default styling to every matplotlib figure drawn after this point.
sns.set()

In [28]:
# Load the data
class ReshapeTransform:
    """Callable transform that reshapes its input tensor to a fixed shape.

    Meant to be placed in a ``transforms.Compose`` pipeline, e.g. to flatten an
    image tensor into a 1-D vector.
    """

    def __init__(self, new_shape):
        """Remember the target shape.

        Args:
            new_shape: Shape forwarded to ``Tensor.reshape`` on every call
                (a sequence of ints; one entry may be ``-1``).
        """
        self.new_shape = new_shape

    def __call__(self, x):
        """Return ``x`` reshaped to ``self.new_shape``.

        Args:
            x: Tensor whose number of elements is compatible with ``new_shape``.

        Returns:
            The reshaped tensor.
        """
        # ``reshape`` returns a view whenever ``view`` would have worked, but
        # falls back to a copy for non-contiguous input instead of raising.
        return x.reshape(self.new_shape)

# Preprocessing applied to every MNIST sample:
#   1. convert the image to a tensor,
#   2. normalise with the constants (0.1307, 0.3081) — presumably the MNIST
#      mean / std; TODO confirm,
#   3. flatten to a vector of 28*28 entries.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    ReshapeTransform([28*28]),
])

# Both splits are stored under ./data and downloaded on first use.
mnist_train = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=mnist_transform,
)
mnist_test = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=mnist_transform,
    target_transform=None,
)

In [29]:
# Split mnist
# Two class subsets are carved out of MNIST: digits 0-4 (the `id` sets) and
# digits 5-9 (the `ood` sets) — presumably in-distribution / out-of-distribution.
# The *_test sets are built from `mnist_test` so that evaluation never sees
# training samples (previously all four splits were taken from `mnist_train`).
mnist_id_train = SplitMNIST(mnist_train, classes=list(range(5)), transform=mnist_transform)
mnist_id_test = SplitMNIST(mnist_test, classes=list(range(5)), transform=mnist_transform)
mnist_ood_train = SplitMNIST(mnist_train, classes=list(range(5, 10)), transform=mnist_transform)
mnist_ood_test = SplitMNIST(mnist_test, classes=list(range(5, 10)), transform=mnist_transform)

In [30]:
# Experiment-wide settings used by the training cells.
num_models = 10    # how many models each training loop builds
batch_size = 100   # forwarded to the trainers as `batch_size`
n_epochs = 100     # argument passed to `trainer.train(...)`

#### Train vanilla ensemble

In [31]:
# Train `num_models` independently initialised MLPs on the digits-0-4 split and
# collect them in `vanilla_models`.
vanilla_models = []
for i in range(num_models):
    # Fresh model and optimizer per ensemble member.
    model = instantiate_MLP_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    trainer = Trainer(model, criterion, mnist_id_train, mnist_id_test, optimizer, scheduler=None,
                      batch_size=batch_size, num_workers=0)
    trainer.train(n_epochs)
    vanilla_models.append(model)
    # f-string so the progress counter is interpolated instead of printed literally.
    print(f'Finished training model {i+1}/{num_models}')

Epoch 1:	Test Loss: 0.169425;	Train Loss: 0.206985;	Test Acc.: 0.952; 	Train Acc.: 0.9042; 	Time per epoch: 18.7s


KeyboardInterrupt: 

In [None]:
# Train a second set of `num_models` models and collect them in `at_models`.
# NOTE(review): the name suggests adversarial training, but this loop still
# builds a plain `Trainer`. Switching to `AdversarialTrainer` needs a
# `data_range`, which is only computed further down the notebook — confirm the
# intended trainer and move that computation above this cell if so.
at_models = []
for i in range(num_models):
    # Fresh model and optimizer per ensemble member.
    model = instantiate_MLP_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    trainer = Trainer(model, criterion, mnist_id_train, mnist_id_test, optimizer, scheduler=None,
                      batch_size=batch_size, num_workers=0)
    trainer.train(n_epochs)
    # Store into this cell's own list (previously appended to `vanilla_models`,
    # leaving `at_models` empty and mixing the two ensembles).
    at_models.append(model)
    # f-string so the progress counter is interpolated instead of printed literally.
    print(f'Finished training model {i+1}/{num_models}')

In [9]:
# Calls `train(40)` on whichever `trainer` is currently bound in the kernel
# (presumably 40 epochs, by analogy with `trainer.train(n_epochs)`).
# NOTE(review): relies on `trainer` left over from an earlier cell — confirm
# this cell is still needed.
trainer.train(40)

KeyboardInterrupt: 

In [20]:
# Compute the data range
def compute_dataset_range(dataset):
    """Compute the element-wise value range (max - min) over a dataset.

    Args:
        dataset: Indexable, sized collection whose items are sequences with the
            input at position 0. Each input must expose ``.shape`` and
            ``.numpy()`` (e.g. a CPU ``torch.Tensor``), and all inputs must
            share the same shape.

    Returns:
        ``numpy.ndarray`` of dtype float64 with the shape of a single input,
        holding ``max - min`` of every element position across the dataset.
    """
    sample_shape = dataset[0][0].shape
    # Start from -inf / +inf so the first sample always replaces the initial values.
    data_max = np.full(sample_shape, -np.inf, dtype=np.float64)
    data_min = np.full(sample_shape, np.inf, dtype=np.float64)
    for i in tqdm.tqdm(range(len(dataset))):
        # Read from the `dataset` argument (previously hard-wired to the
        # global `mnist_train`, which silently ignored the parameter).
        x = dataset[i][0].numpy()
        data_max = np.maximum(data_max, x)
        data_min = np.minimum(data_min, x)
    return data_max - data_min

In [21]:
# Element-wise (max - min) over the full MNIST training set; used further down
# as the `data_range` argument of `AdversarialTrainer`.
mnist_range = compute_dataset_range(mnist_train)


  0%|          | 0/60000 [00:00<?, ?it/s][A
  1%|          | 513/60000 [00:00<00:11, 5126.11it/s][A
  2%|▏         | 1194/60000 [00:00<00:10, 5535.95it/s][A
  3%|▎         | 1868/60000 [00:00<00:09, 5848.42it/s][A
  4%|▍         | 2520/60000 [00:00<00:09, 6033.30it/s][A
  5%|▌         | 3209/60000 [00:00<00:09, 6266.71it/s][A
  6%|▋         | 3896/60000 [00:00<00:08, 6436.03it/s][A
  8%|▊         | 4592/60000 [00:00<00:08, 6582.79it/s][A
  9%|▉         | 5257/60000 [00:00<00:08, 6602.34it/s][A
 10%|▉         | 5929/60000 [00:00<00:08, 6635.28it/s][A
 11%|█         | 6595/60000 [00:01<00:08, 6641.74it/s][A
 12%|█▏        | 7247/60000 [00:01<00:08, 6574.99it/s][A
 13%|█▎        | 7914/60000 [00:01<00:07, 6602.05it/s][A
 14%|█▍        | 8569/60000 [00:01<00:07, 6517.07it/s][A
 15%|█▌        | 9232/60000 [00:01<00:07, 6549.96it/s][A
 17%|█▋        | 9916/60000 [00:01<00:07, 6633.70it/s][A
 18%|█▊        | 10578/60000 [00:01<00:07, 6532.88it/s][A
 19%|█▊        | 11231/600

In [16]:
# Set up adversarial training on the full MNIST train/test sets.
# NOTE(review): `model` is not created in this cell — it is whatever an earlier
# cell left bound in the kernel.
batch_size = 100
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Data range as a 2-D tensor with a leading dimension of size 1.
adv_data_range = torch.Tensor(mnist_range).view([1, -1])

trainer = AdversarialTrainer(
    model, criterion, mnist_train, mnist_test, optimizer,
    scheduler=None,
    batch_size=batch_size,
    num_workers=0,
    adv_example_epsilon=0.01,
    data_range=adv_data_range,
)

In [17]:
# Run the adversarial trainer with `train(40)` (presumably 40 epochs, by analogy
# with `trainer.train(n_epochs)`); the captured output below shows the run was
# interrupted manually after epoch 7.
trainer.train(40)

Epoch 1:	Test Loss: 0.151879;	Train Loss: 0.322535;	Test Acc.: 0.9558; 	Train Acc.: 0.9238; 	Time per epoch: 26.0s
Epoch 2:	Test Loss: 0.132886;	Train Loss: 0.119982;	Test Acc.: 0.9626; 	Train Acc.: 0.9534; 	Time per epoch: 28.9s
Epoch 3:	Test Loss: 0.099735;	Train Loss: 0.115155;	Test Acc.: 0.9737; 	Train Acc.: 0.9646; 	Time per epoch: 29.1s
Epoch 4:	Test Loss: 0.103664;	Train Loss: 0.107569;	Test Acc.: 0.9735; 	Train Acc.: 0.97; 	Time per epoch: 28.6s
Epoch 5:	Test Loss: 0.102947;	Train Loss: 0.102566;	Test Acc.: 0.9739; 	Train Acc.: 0.9722; 	Time per epoch: 28.5s
Epoch 6:	Test Loss: 0.090544;	Train Loss: 0.107967;	Test Acc.: 0.9783; 	Train Acc.: 0.9754; 	Time per epoch: 28.2s
Epoch 7:	Test Loss: 0.100607;	Train Loss: 0.096192;	Test Acc.: 0.9761; 	Train Acc.: 0.9772; 	Time per epoch: 28.4s


KeyboardInterrupt: 