In [135]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
from sklearn.metrics import roc_auc_score

from importlib import reload
import matplotlib.pyplot as plt

import utils.models as models
import utils.plotting as plotting
import utils.dataloaders as dl
import utils.traintest as tt
import utils.adversarial as adv
import utils.eval as ev
import model_params as params
import utils.resnet_orig as resnet
import utils.gmm_helpers as gmm_helpers
import utils.odin as odin

from tensorboardX import SummaryWriter
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [3]:
device = torch.device('cuda:0')

dataset = 'MNIST'

model_params = params.params_dict[dataset](augm_flag=True, batch_size=128)

file = 'base_MNIST_lr0.001_augm_flagTrue_train_typeplain.pth'
base_model = torch.load('SavedModels/base/' + file).to(device)
ODIN_model = models.LeNetODIN(base_model, 10.)

In [47]:
base_model = torch.load('SavedModels/base/' + file).to(device)
ODIN_model = odin.LeNetODIN(base_model, 10.)

In [82]:
def FGSM(model, device, seed, epsilon=0.1):
    with torch.no_grad():
        data = seed.clone().to(device).requires_grad_()

    with torch.enable_grad():
        y = model(data)
        losses = y.max(1)[0]
        losses.sum().backward()

    with torch.no_grad():
        data += epsilon * data.grad
        data = torch.clamp(data, 0, 1).requires_grad_()
    return data.detach()

In [128]:
seed = torch.rand(10, 1, 28, 28, device=device)

In [131]:
adv_noise = FGSM(odin, device, seed)

In [116]:
def get_auroc(model_list, model_params, stats, device):
    auroc = []
    fp95 = []
    for i, model in enumerate(model_list):
        with torch.no_grad():
            conf = []
            for data, _ in model_params.test_loader:
                data = data.to(device)

                output = model(data).max(1)[0].exp()

                conf.append(output.cpu())

        conf = torch.cat(conf, 0)

        y_true = torch.cat([torch.ones_like(conf.cpu()), 
                            torch.zeros_like(stats[i])]).cpu().numpy()
        y_scores = torch.cat([conf.cpu(), 
                              stats[i]]).cpu().numpy()

        auroc.append(roc_auc_score(y_true, y_scores))
        fp95.append( ((stats[i] > 0.95).float().mean()).item() )
    return auroc, fp95

In [108]:
def aggregate_stats(model_list, device, shape, classes=10, 
                    batches=10, batch_size=100):
    stats = []

    for _ in range(batches):
        seed = torch.rand((batch_size,)+tuple(shape), device=device)
        
        batch_stats = []
        for i, model in enumerate(model_list):
            batch_stats.append(model(seed).max(1)[0].exp().detach().cpu().clone())
            
        batch_stats = torch.stack(batch_stats, 0)
        stats.append(batch_stats.clone())

    stats = torch.cat(stats, -1)
    
    return stats

In [117]:
a, b = get_auroc(model_list, model_params, stats, device)

In [None]:
temperatures = [1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]

epsilons = np.linspace(0, 0.004, 21)

grid = []

for T in temperatures:
    vec = []
    for eps in epsilons:
        model = odin.ModelODIN(odin.LeNetTemp(base_model, T), 
                                    eps, device=device)
        stats = aggregate_stats([model], device, seed[0].shape, 
                                classes=model_params.classes)
        auroc, fp95 = get_auroc([model], model_params, stats, device)
        vec.append(auroc + fp95)
    grid.append(vec)

In [None]:
auroc = torch.tensor(grid)[:,:,0]
fp95 = torch.tensor(grid)[:,:,1]

ind = auroc.view(-1).argmax().item()

xv, yv = np.meshgrid(epsilons, temperatures)
T = yv.reshape(-1)[ind]
eps = xv.reshape(-1)[ind]

In [None]:
odin_model = odin.ModelODIN(odin.LeNetTemp(base_model, T), 
                                           eps, device=device)
stats = aggregate_stats([model], device, seed[0].shape, 
                        classes=model_params.classes)
a, b = get_auroc([model], model_params, stats, device)

In [178]:
a

[0.9923061000000001]

In [179]:
b

[0.0]

In [182]:
auroc

tensor([[0.9712, 0.9723, 0.9711],
        [0.9824, 0.9820, 0.9818],
        [0.9907, 0.9909, 0.9909],
        [0.9925, 0.9924, 0.9925]])

In [180]:
T

10

In [181]:
eps

0.0