In [None]:
import os

# Run from the SNIP-it repository root so that the project packages imported in the
# next cell (`experiments`, `models`, `utils`) and repo-relative paths resolve.
# Set the SNIP_IT_ROOT environment variable to run from a different checkout;
# the default is the original author's location.
SNIP_IT_ROOT = os.environ.get('SNIP_IT_ROOT', '/nfs/homedirs/ayle/guided-research/SNIP-it/')
os.chdir(SNIP_IT_ROOT)

In [None]:
import torch
from torchvision import datasets, transforms
import foolbox as fb
from experiments.main import load_checkpoint
from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *
from utils.attacks_utils import get_attack
from torch.utils.data.dataset import Dataset
from copy import deepcopy
import pickle
import time

In [None]:
# Experiment configuration consumed by the model / dataset / tester factories below.
arguments = {
    # --- schedule & logging ---
    'eval_freq': 1000,  # evaluate every n batches
    'save_freq': 1e6,  # save model every n epochs, besides before and after training
    'batch_size': 64,  # size of batches, for Imagenette 128
    'seed': 1234,  # random seed
    'max_training_minutes': 6120,  # process killed after n minutes (after finish of epoch); 6120 min = 102 h
    'plot_weights_freq': 50,  # plot pictures to tensorboard every n epochs
    'prune_freq': 1,  # if pruning during training: how long to wait before starting
    'prune_delay': 0,  # if pruning during training: 't' from algorithm box, interval between pruning events, default=0
    'prune_to': 10,
    'epochs': 20,
    'rewind_to': 0,  # rewind to this epoch if rewinding is done
    'snip_steps': 5,  # 's' in algorithm box, number of pruning steps for 'rule of thumb', TODO
    'pruning_rate': 0.0,  # pruning rate passed to criterion at pruning event; most criteria override this
    'growing_rate': 0.0000,  # grow back so much every epoch (for future criterions)
    'pruning_limit': 0.00,  # prune until here: nodes if structured, weights if unstructured; most criteria use this instead of pruning_rate
    # --- optimisation & regularisation ---
    'learning_rate': 2e-3,
    'grad_clip': 10,
    'grad_noise': 0,  # added gaussian noise to gradients
    'l2_reg': 5e-5,  # weight decay
    'l1_reg': 0,  # l1-norm regularisation
    'lp_reg': 0,  # lp regularisation with p < 1
    'l0_reg': 1.0,  # l0 reg lambda hyperparam
    'hoyer_reg': 0.001,  # hoyer reg lambda hyperparam
    'beta_ema': 0.999,  # l0 reg beta ema hyperparam

    # --- components ---
    'loss': 'CrossEntropy',
    'optimizer': 'ADAM',
    'model': 'ResNet18',  # ResNet not supported with structured
    'data_set': 'CIFAR10',
    'ood_data_set': 'SVHN',
    'prune_criterion': 'EmptyCrit',  # options: SNIP, SNIPit, SNIPitDuring, UnstructuredRandom, GRASP, HoyerSquare, IMP, // SNAPit, StructuredRandom, GateDecorators, EfficientConvNets, GroupHoyerSquare
    'train_scheme': 'DefaultTrainer',  # default: DefaultTrainer
    'attack': 'FGSM',
    'epsilon': 6,
    'eval_ood_data_sets': ['SVHN', 'CIFAR100', 'GAUSSIAN', 'OODOMAIN'],
    'eval_attacks': ['FGSM', 'L2FGSM'],
    'eval_epsilons': [4, 6, 12, 24, 48],

    'device': 'cuda',
    'results_dir': "tests",

    'checkpoint_name': '2021-05-30_19.59.39_model=ResNet18_dataset=CIFAR10_prune-criterion=SNIPitDuring_pruning-limit=0.98_prune-freq=4_prune-delay=8_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=1234',
    'checkpoint_model': 'ResNet18_finished',

    # --- runtime switches ---
    'disable_cuda_benchmark': 1,  # speedup (disable) vs reproducibility (leave it)
    'eval': 0,
    'disable_autoconfig': 0,  # for the brave
    'preload_all_data': 0,  # load all data into ram memory for speedups
    'tuning': 0,  # splits trainset into train and validationset, omits test set

    'track_weights': 0,  # keep statistics on the weights through training
    'disable_masking': 1,  # disable the ability to prune unstructured
    'enable_rewinding': 0,  # enable the ability to rewind to previous weights
    'outer_layer_pruning': 1,  # allow to prune outer layers (unstructured) or not (structured)
    'first_layer_dense': 0,
    'random_shuffle_labels': 0,  # run with random-label experiment from zhang et al
    'l0': 0,  # run with l0 criterion, might overwrite some other arguments
    'hoyer_square': 0,  # run in unstructured DeephoyerSquare criterion, might overwrite some other arguments
    'group_hoyer_square': 0,  # run in unstructured Group-DeephoyerSquare criterion, might overwrite some other arguments

    # --- plotting switches ---
    'disable_histograms': 0,
    'disable_saliency': 0,
    'disable_confusion': 0,
    'disable_weightplot': 0,
    'disable_netplot': 0,
    'skip_first_plot': 0,
    'disable_activations': 0,

    # MNIST alternative:
    #   'input_dim': [1, 28, 28], 'output_dim': 10, 'hidden_dim': [512], 'N': 60000,

    # --- network shape (CIFAR10) ---
    'input_dim': [3, 32, 32],
    'output_dim': 10,
    'hidden_dim': [512],
    # NOTE(review): 60000 matches the MNIST block above; the CIFAR10 training split has
    # 50000 images — confirm which value the N (training-set size) argument should get.
    'N': 60000,
}

In [None]:
# Root of the local datasets; the corrupted-CIFAR10 .npz files are read from its
# `cifar10_corrupted` subfolder. Override with the SNIP_IT_DATA environment variable.
DATASET_PATH = os.environ.get('SNIP_IT_DATA', '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/data')

In [None]:
# The Metrics instance supplies `out`, the logging callable used by the cells below.
metrics = Metrics()
metrics._batch_size, metrics._eval_freq = arguments['batch_size'], arguments['eval_freq']
out = metrics.log_line
set_results_dir(arguments["results_dir"])

In [None]:
# Build the network named by arguments['model'] from the config and move it to the device.
# No checkpoint is loaded into `model` in this cell.
model_kwargs = dict(
    device=arguments['device'],
    hidden_dim=arguments['hidden_dim'],
    input_dim=arguments['input_dim'],
    output_dim=arguments['output_dim'],
    # NOTE(review): `is_maskable` is fed from the `disable_masking` flag, whose name suggests
    # the opposite meaning — confirm this mapping is intended.
    is_maskable=arguments['disable_masking'],
    is_tracking_weights=arguments['track_weights'],
    is_rewindable=arguments['enable_rewinding'],
    is_growable=arguments['growing_rate'] > 0,
    outer_layer_pruning=arguments['outer_layer_pruning'],
    # True only when outer layers are not pruned AND the criterion name contains "Structured"
    maintain_outer_mask_anyway=(
        (not arguments['outer_layer_pruning']) and ("Structured" in arguments['prune_criterion'])
    ),
    l0=arguments['l0'],
    l0_reg=arguments['l0_reg'],
    N=arguments['N'],
    beta_ema=arguments['beta_ema'],
    l2_reg=arguments['l2_reg'],
)
model: GeneralModel = find_right_model(NETWORKS_DIR, arguments['model'], **model_kwargs).to(arguments['device'])

In [None]:
# Put the base model in inference mode; the cell displays the module repr as its output.
model.eval()

In [None]:
def load_checkpoint(path, model, out, data_manager=None):
    """Load a saved state dict into ``model`` in place and log the path.

    Note: this redefinition shadows ``experiments.main.load_checkpoint`` imported
    at the top of the notebook.

    Args:
        path: checkpoint location passed to ``data_manager.load_python_obj``.
        model: ``torch.nn.Module`` that receives the weights via ``load_state_dict``.
        out: logging callable; called with a single message string on success.
        data_manager: object exposing ``load_python_obj(path)``. Defaults to the
            global ``DATA_MANAGER`` (presumably provided by the ``utils`` star imports).

    Raises:
        KeyError, RuntimeError: if the checkpoint does not match ``model``. The
            checkpoint's keys are printed first to help diagnose the mismatch.
    """
    manager = DATA_MANAGER if data_manager is None else data_manager
    state = manager.load_python_obj(path)
    try:
        model.load_state_dict(state)
    except (KeyError, RuntimeError):
        # load_state_dict reports missing/unexpected keys and shape mismatches as
        # RuntimeError, so catching only KeyError never reached this diagnostic.
        print(list(state.keys()))
        raise
    out(f"Loaded checkpoint {path}")

In [None]:
def load_model(path, out):
    """Unpickle a whole model object from ``path`` and log the path.

    NOTE: ``pickle`` can execute arbitrary code while loading, so only use it on
    files you produced or otherwise trust.

    Args:
        path: filesystem path of the pickle file.
        out: logging callable; called with a single message string.

    Returns:
        The unpickled object.
    """
    with open(path, 'rb') as handle:
        payload = handle.read()
    loaded = pickle.loads(payload)
    out(f"Loaded checkpoint {path}")
    return loaded

In [None]:
# Pickled full VGG16 models (StructuredEFGit, pruning-limit 0.9 per the run names)
# for seeds 3456, 2345, 1234, 4567 and 5678.
# NOTE(review): path1..path5 are reassigned to ResNet18 checkpoints in later cells,
# and the models unpickled from these files are replaced there too. When running
# top to bottom, only the later assignments take effect.
path1 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-06-02_21.28.54_model=VGG16_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/VGG16_mod_finished.pickle'
path2 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-06-02_21.28.55_model=VGG16_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/VGG16_mod_finished.pickle'
path3 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-06-02_21.28.56_model=VGG16_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/VGG16_mod_finished.pickle'
path4 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-06-02_21.48.55_model=VGG16_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/VGG16_mod_finished.pickle'
path5 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-06-02_22.32.42_model=VGG16_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/VGG16_mod_finished.pickle'

In [None]:
# Unpickle the five models whose paths were set in the previous cell.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# model1..model5 are overwritten by the ResNet18 checkpoint cells that follow.
model1 = load_model(path1, out)
model2 = load_model(path2, out)
model3 = load_model(path3, out)
model4 = load_model(path4, out)
model5 = load_model(path5, out)

In [None]:
# Ensemble member 1: ResNet18 checkpoint (EmptyCrit, pruning-limit 0.0, seed 1234 per the run name).
# The commented-out paths are alternative runs that were tried for this slot
# (EarlyJohn pruning-limit 0.98 / 0.8, VGG16 EmptyCrit 0.93).
# Weights are loaded into a deep copy, so the base `model` is left unchanged.
path1 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-30_16.14.35_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=4_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished'
#path1 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-30_21.19.53_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.98_prune-freq=4_prune-delay=8_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished'
# path1 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-30_20.36.55_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.8_prune-freq=4_prune-delay=8_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished'
# path1 = '/nfs/students/ayle/guided-research/results/VGG16/2021-05-31_00.45.30_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.93_prune-freq=4_prune-delay=8_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/VGG16_finished'
model1 = deepcopy(model) 
load_checkpoint(path1, model1, out)

In [None]:
# Ensemble member 2: ResNet18 checkpoint (EmptyCrit, pruning-limit 0.0, seed 2345 per the run name).
# Commented-out paths are alternative runs for this slot (EarlyJohn 0.98 / 0.8, VGG16 EmptyCrit 0.93).
# NOTE(review): the two EarlyJohn alternatives share the timestamp 2021-06-01_00.40.07 but
# differ in pruning-limit — one of them is probably a copy/paste slip; verify before using.
path2 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-31_16.23.52_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_finished'
#path2 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_00.40.07_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.98_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_finished'
# path2 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_00.40.07_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.8_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_finished'
# path2 = '/nfs/students/ayle/guided-research/results/VGG16/2021-05-31_01.03.02_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.93_prune-freq=4_prune-delay=8_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/VGG16_finished'
model2 = deepcopy(model) 
load_checkpoint(path2, model2, out)

In [None]:
# Ensemble member 3: ResNet18 checkpoint (EmptyCrit, pruning-limit 0.0, seed 3456 per the run name).
# Commented-out paths are alternative runs for this slot (EarlyJohn 0.98 / 0.8, VGG16 EmptyCrit 0.93).
path3 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-31_16.23.51_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_finished'
#path3 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_02.14.47_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.98_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_finished'
# path3 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_01.02.17_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.8_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_finished'
# path3 = '/nfs/students/ayle/guided-research/results/VGG16/2021-05-31_02.18.47_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.93_prune-freq=4_prune-delay=8_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/VGG16_finished'
model3 = deepcopy(model) 
load_checkpoint(path3, model3, out)

In [None]:
# Ensemble member 4: ResNet18 checkpoint (EmptyCrit, pruning-limit 0.0, seed 4567 per the run name).
# Commented-out paths are alternative runs for this slot (EarlyJohn 0.98 / 0.8, VGG16 EmptyCrit 0.93).
path4 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-31_16.23.49_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_finished'
#path4 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_03.04.38_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.98_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_finished'
# path4 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-06-01_02.23.39_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.8_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_finished'
# path4 = '/nfs/students/ayle/guided-research/results/VGG16/2021-05-31_02.56.03_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.93_prune-freq=4_prune-delay=8_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/VGG16_finished'
model4 = deepcopy(model) 
load_checkpoint(path4, model4, out)

In [None]:
# Ensemble member 5: ResNet18 checkpoint (EmptyCrit, pruning-limit 0.0, seed 5678 per the run name).
# Commented-out paths are alternative runs for this slot (EarlyJohn 0.98 / 0.8, VGG16 EmptyCrit 0.93);
# note the EarlyJohn alternatives live under /nfs/homedirs rather than /nfs/students.
path5 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-05-31_16.23.50_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_finished'
#path5 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-06-01_16.37.04_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.98_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_finished'
# path5 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-06-01_15.27.48_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.8_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_finished'
# path5 = '/nfs/students/ayle/guided-research/results/VGG16/2021-05-31_03.46.55_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.93_prune-freq=4_prune-delay=8_outer-layer-pruning=0_prune-to=5_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/VGG16_finished'
model5 = deepcopy(model) 
load_checkpoint(path5, model5, out)

In [None]:
# Device that evaluation batches are moved to; the models were built with `.to(arguments['device'])`.
device = arguments['device']

In [None]:
# load data
# In-distribution train/test loaders for arguments['data_set'] (CIFAR10 in this config).
train_loader, test_loader = find_right_model(
    DATASETS, arguments['data_set'],
    arguments=arguments
)

In [None]:
# load OOD data
# Only the test split of arguments['ood_data_set'] (SVHN in this config) is kept.
_, ood_loader = find_right_model(
    DATASETS, arguments['ood_data_set'],
    arguments=arguments
)

In [None]:
# Switch every ensemble member to inference mode (e.g. dropout / batch-norm behaviour).
for member in (model1, model2, model3, model4, model5):
    member.eval()

In [None]:
from sklearn import metrics
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt


def calculate_auroc(correct, predictions):
    """Area under the ROC curve, also drawing the curve on the current pyplot axes.

    Args:
        correct: binary ground-truth labels (1 = positive class).
        predictions: scores where larger means "more positive".

    Returns:
        The AUROC as computed by ``sklearn.metrics.auc`` over the ROC points.
    """
    false_pos_rate, true_pos_rate, _thresholds = metrics.roc_curve(correct, predictions)
    area = metrics.auc(false_pos_rate, true_pos_rate)
    # side effect: plots onto whatever figure is current, so repeated calls overlay
    plt.plot(false_pos_rate, true_pos_rate)
    return area


def calculate_aupr(correct, predictions):
    """Area under the precision-recall curve (sklearn's average precision).

    Args:
        correct: binary ground-truth labels (1 = positive class).
        predictions: scores where larger means "more positive".
    """
    return metrics.average_precision_score(correct, predictions)

In [None]:
# In-distribution pass of the 5-model ensemble over `test_loader`.
# - Starts fresh `ood_labels` / `ood_scores` lists with the ensemble's max-softmax
#   confidence (label 1 = in-distribution); the OOD cell appends its scores after these.
# - Prints mean per-element prediction variance across members, ensemble accuracy
#   and average forward time per batch.
ensemble = [model1, model2, model3, model4, model5]

ood_labels = []
ood_scores = []

total_var = 0.0      # sum over elements of the across-member variance
total_elems = 0      # number of (sample, class) elements seen
total_correct = 0
total_samples = 0
total_time = 0.0
count = 0            # number of batches

kl_loss = torch.nn.KLDivLoss(reduction='none')

with torch.no_grad():
    for data, labels in test_loader:
        count += 1
        data = data.to(device)

        member_probs = []
        t = 0.0
        for member in ensemble:
            start = time.perf_counter()
            # `logits`, not `out`: `out` is the logging function used by later cells.
            logits = member(data)
            if logits.is_cuda:
                # CUDA runs asynchronously; wait so the timing covers the forward pass.
                torch.cuda.synchronize()
            t += time.perf_counter() - start
            member_probs.append(F.softmax(logits, 1).cpu())

        preds = torch.stack(member_probs)   # (members, batch, classes)
        all_probs = preds.mean(0)

        # Per-member KL(ensemble mean || member), kept for the alternative OOD score below.
        disagreement = kl_loss(torch.log(preds), all_probs.unsqueeze(0).expand_as(preds)).sum(-1)

        max_probs, max_indices = all_probs.max(1)
        max_probs = max_probs.numpy()

        ood_labels.append(np.ones_like(max_probs))
        ood_scores.append(max_probs)
        # Alternative (disagreement as OOD score, higher = more OOD):
        # ood_labels.append(np.zeros_like(disagreement.mean(0).numpy()))
        # ood_scores.append(disagreement.mean(0).numpy())

        total_time += t
        # Accumulate sample counts so a short final batch is weighted correctly.
        total_correct += (labels == max_indices).sum().item()
        total_samples += len(labels)
        total_var += torch.var(preds, dim=0).sum().item()
        total_elems += preds[0].numel()

mean_var = total_var / max(total_elems, 1)
accuracy = total_correct / max(total_samples, 1)
print(mean_var)
print(accuracy)
print(total_time / max(count, 1))

In [None]:
# Out-of-distribution pass of the same ensemble over `ood_loader`.
# Appends max-softmax confidences with label 0 to the `ood_labels` / `ood_scores`
# lists started by the in-distribution cell, so run that cell first.
# Prints mean per-element prediction variance and average forward time per batch.
ensemble = [model1, model2, model3, model4, model5]

total_var = 0.0
total_elems = 0
total_time = 0.0
count = 0

kl_loss = torch.nn.KLDivLoss(reduction='none')

with torch.no_grad():
    for data, _labels in ood_loader:
        count += 1
        data = data.to(device)

        member_probs = []
        t = 0.0
        for member in ensemble:
            start = time.perf_counter()
            # `logits`, not `out`: `out` is the logging function used by later cells.
            logits = member(data)
            if logits.is_cuda:
                # CUDA runs asynchronously; wait so the timing covers the forward pass.
                torch.cuda.synchronize()
            t += time.perf_counter() - start
            member_probs.append(F.softmax(logits, 1).cpu())

        preds = torch.stack(member_probs)   # (members, batch, classes)
        all_probs = preds.mean(0)

        # Per-member KL(ensemble mean || member), kept for the alternative OOD score below.
        disagreement = kl_loss(torch.log(preds), all_probs.unsqueeze(0).expand_as(preds)).sum(-1)

        max_probs, max_indices = all_probs.max(1)
        max_probs = max_probs.numpy()

        ood_labels.append(np.zeros_like(max_probs))
        ood_scores.append(max_probs)
        # Alternative (disagreement as OOD score):
        # ood_labels.append(np.ones_like(disagreement.mean(0).numpy()))
        # ood_scores.append(disagreement.mean(0).numpy())

        total_time += t
        total_var += torch.var(preds, dim=0).sum().item()
        total_elems += preds[0].numel()

mean_var = total_var / max(total_elems, 1)
print(mean_var)
print(total_time / max(count, 1))

In [None]:
# Mean ensemble confidence on in-distribution samples. Selected by label (1) rather
# than assuming the first 10000 entries are the CIFAR10 test set.
np.concatenate(ood_scores)[np.concatenate(ood_labels) == 1].mean()

In [None]:
# Mean ensemble confidence on out-of-distribution samples. Selected by label (0)
# rather than assuming everything after the first 10000 entries is OOD.
np.concatenate(ood_scores)[np.concatenate(ood_labels) == 0].mean()

In [None]:
# OOD detection quality of the max-softmax score (in-distribution = positive class).
all_ood_labels = np.concatenate(ood_labels)
all_ood_scores = np.concatenate(ood_scores)
print(calculate_auroc(all_ood_labels, all_ood_scores))
print(calculate_aupr(all_ood_labels, all_ood_scores))

In [None]:
class DS(Dataset):
    """Map-style dataset over in-memory image and label arrays.

    Used for the corrupted-CIFAR10 ``.npz`` files. Each image is divided by 255,
    transposed with axes (1, 2, 0) — i.e. assumes images are stored channels-first,
    TODO confirm — because ``ToTensor`` expects H x W x C ndarrays. It is then
    normalised per channel. Labels are returned as float32 tensors.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
        # NOTE(review): these are the standard ImageNet statistics; confirm they match the
        # normalisation used by the in-distribution CIFAR10 loader.
        self.mean = [0.485, 0.456, 0.406]  # avg 0.449
        self.std = [0.229, 0.224, 0.225]  # avg 0.226
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ]
        self.transforms = transforms.Compose(steps)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        scaled = self.images[item] / 255
        channels_last = scaled.transpose((1, 2, 0))
        image_tensor = self.transforms(channels_last).to(torch.float32)
        label_tensor = torch.tensor(self.labels[item], dtype=torch.float32)
        return image_tensor, label_tensor

In [None]:
# Distribution-shift pass: score the ensemble on every corrupted-CIFAR10 .npz file.
# For each corruption it prints the dataset name, mean prediction variance across
# members, ensemble accuracy, and AUROC/AUPR of in-distribution (label 1) vs. that
# corruption (label 0). Needs the in-distribution cell's `ood_labels`/`ood_scores`
# and leaves those global lists unchanged.
with torch.no_grad():
    if arguments["data_set"] == "CIFAR10":
        ensemble = [model1, model2, model3, model4, model5]
        ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")

        # In-distribution scores only, so each corruption is compared against them
        # alone (not against SVHN or previously processed corruptions).
        id_scores = np.concatenate(ood_scores)[np.concatenate(ood_labels) == 1]

        # sorted: os.listdir order is arbitrary
        for ds_dataset_name in sorted(os.listdir(ds_path)):
            print(ds_dataset_name)
            # NpzFile keeps the file open; the arrays are read when indexed, so it can be closed here.
            with np.load(os.path.join(ds_path, ds_dataset_name)) as npz_dataset:
                ds_dataset = DS(npz_dataset["images"], npz_dataset["labels"])
            ds_loader = torch.utils.data.DataLoader(
                ds_dataset,
                batch_size=arguments['batch_size'],
                shuffle=False,
                pin_memory=True,
                num_workers=4
            )

            total_var = 0.0
            total_elems = 0
            total_correct = 0
            total_samples = 0
            shift_scores = []

            for data, labels in ds_loader:
                data = data.to(device)
                # `logits` naming avoids clobbering the `out` logging function.
                member_probs = [F.softmax(member(data), 1).cpu() for member in ensemble]
                preds = torch.stack(member_probs)   # (members, batch, classes)
                all_probs = preds.mean(0)

                max_probs, max_indices = all_probs.max(1)
                shift_scores.append(max_probs.numpy())

                total_var += torch.var(preds, dim=0).sum().item()
                total_elems += preds[0].numel()
                # sample-weighted so a short final batch does not skew accuracy
                total_correct += (labels == max_indices).sum().item()
                total_samples += len(labels)

            shift_scores = np.concatenate(shift_scores)
            eval_labels = np.concatenate([np.ones_like(id_scores), np.zeros_like(shift_scores)])
            eval_scores = np.concatenate([id_scores, shift_scores])

            print(total_var / max(total_elems, 1))
            print(total_correct / max(total_samples, 1))
            print(calculate_auroc(eval_labels, eval_scores))
            print(calculate_aupr(eval_labels, eval_scores))

In [None]:
# Collected metrics from the tester cells, keyed by the testers' own result names.
results = dict()

In [None]:
# Adversarial robustness sweep: every attack in `eval_attacks` at every budget in
# `eval_epsilons`. A fresh AdversarialEvaluation tester is built per (attack, epsilon)
# and its metrics are merged into `results`.
for attack_name in arguments['eval_attacks']:
    for eps in arguments['eval_epsilons']:
        out("Attack {}".format(attack_name))
        tester = find_right_model(
            TESTERS_DIR, 'AdversarialEvaluation',
            attack=attack_name,
            # NOTE(review): `model` is the base network; checkpoints were loaded only into
            # the deep copies model1..model5. Confirm which model should be attacked.
            model=model,
            device=device,
            arguments=None,
            test_loader=test_loader,
        )

        out("Epsilon {}".format(str(eps)))
        res = tester.evaluate(epsilon=eps)
        results.update(res)

In [None]:
# OOD-detection sweep over `eval_ood_data_sets` with the project's OODEvaluation
# tester; each tester's metrics are merged into `results`.
with torch.no_grad():
    # The in-distribution test loader does not depend on the OOD set, so build it once
    # instead of once per OOD dataset.
    _, test_loader = find_right_model(
        DATASETS, arguments['data_set'],
        arguments=arguments
    )

    for ood_data_set in arguments['eval_ood_data_sets']:
        out("OOD Dataset: {}".format(ood_data_set))

        # load OOD data
        _, ood_loader = find_right_model(
            DATASETS, ood_data_set,
            arguments=arguments
        )
        tester = find_right_model(
            TESTERS_DIR, 'OODEvaluation',
            # NOTE(review): `model` is the base network; checkpoints were loaded only into
            # the deep copies model1..model5. Confirm which model should be evaluated.
            model=model,
            device=device,
            arguments=None,
            test_loader=test_loader,
            ood_loader=ood_loader,
            ood_dataset=ood_data_set
        )
        res = tester.evaluate()
        results.update(res)

In [None]:


# Distribution-shift evaluation with the project's DSEvaluation tester on the
# corrupted-CIFAR10 .npz files; each tester's metrics are merged into `results`.
max_ds_datasets = None  # set to an int (e.g. 1) for a quick smoke run over the first files only

with torch.no_grad():
    if arguments["data_set"] == "CIFAR10":
        ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")
        # sorted: os.listdir order is arbitrary, which made the evaluated subset irreproducible
        for ds_dataset_name in sorted(os.listdir(ds_path))[:max_ds_datasets]:
            # NpzFile keeps the file open; the arrays are read when indexed, so it can be closed here.
            with np.load(os.path.join(ds_path, ds_dataset_name)) as npz_dataset:
                ds_dataset = DS(npz_dataset["images"], npz_dataset["labels"])
            ds_loader = torch.utils.data.DataLoader(
                ds_dataset,
                batch_size=arguments['batch_size'],
                shuffle=False,
                pin_memory=True,
                num_workers=4
            )

            # build tester
            tester = find_right_model(
                TESTERS_DIR, 'DSEvaluation',
                # `trainer` is never defined in this notebook (`trainer._model` raised
                # NameError); use `model` like the other tester cells.
                # NOTE(review): `model` carries no loaded checkpoint — confirm intent.
                model=model,
                device=device,
                arguments=None,
                test_loader=test_loader,
                ds_loader=ds_loader,
                ds_dataset=ds_dataset_name.split('.')[0]
            )
            res = tester.evaluate()
            results.update(res)

In [None]:
# Display every metric collected by the adversarial, OOD and distribution-shift tester cells.
results