In [1]:
import os
os.chdir('/nfs/homedirs/ayle/guided-research/SNIP-it/')

In [2]:
import torch
from torchvision import datasets, transforms
import foolbox as fb
from experiments.main import load_checkpoint
from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *
from utils.attacks_utils import get_attack
from torch.utils.data.dataset import Dataset
from copy import deepcopy
import pickle
import time
import torch.nn.functional as F

In [3]:
arguments = dict({
'eval_freq': 1000,  # evaluate every n batches
    'save_freq': 1e6,  # save model every n epochs, besides before and after training
    'batch_size': 512,  # size of batches, for Imagenette 128
    'seed': 1234,  # random seed
    'max_training_minutes': 6120 , # one hour and a 45 minutes max, process killed after n minutes (after finish of epoch)
    'plot_weights_freq': 50, # plot pictures to tensorboard every n epochs
    'prune_freq': 1, # if pruning during training: how long to wait before starting
    'prune_delay': 0, # "if pruning during training: 't' from algorithm box, interval between pruning events, default=0
    'prune_to': 0,
    'epochs': 0,
    'rewind_to': 0, # rewind to this epoch if rewinding is done
    'snip_steps': 5, # 's' in algorithm box, number of pruning steps for 'rule of thumb', TODO
    'pruning_rate': 0.0, # pruning rate passed to criterion at pruning event. however, most override this
    'growing_rate': 0.0000 , # grow back so much every epoch (for future criterions)
    'pruning_limit': 0.5,  # Prune until here, if structured in nodes, if unstructured in weights. most criterions use this instead of the pruning_rate
    'local_pruning': 0,
    'learning_rate': 2e-3,
    'grad_clip': 10,
    'grad_noise': 0 , # added gaussian noise to gradients
    'l2_reg': 5e-5 , # weight decay
    'l1_reg': 0 , # l1-norm regularisation
    'lp_reg': 0 , # lp regularisation with p < 1
    'l0_reg': 1.0 , # l0 reg lambda hyperparam
    'hoyer_reg': 0.001 , # hoyer reg lambda hyperparam
    'beta_ema': 0.999 , # l0 reg beta ema hyperparam

    'loss': 'CrossEntropy',
    'optimizer': 'ADAM',
    'model': 'ResNet18',  # ResNet not supported with structured
    'data_set': 'CIFAR10',
    'ood_data_set': 'SVHN',
    'ood_data_set_prune': 'SVHN',
    'prune_criterion': 'EmptyCrit',  # options: SNIP, SNIPit, SNIPitDuring, UnstructuredRandom, GRASP, HoyerSquare, IMP, // SNAPit, StructuredRandom, GateDecorators, EfficientConvNets, GroupHoyerSquare
    'train_scheme': 'DefaultTrainer' , # default: DefaultTrainer
    'attack': 'FGSM',
    'epsilon': 12,
    'eval_ood_data_sets': ['SVHN', 'CIFAR100'],
    'eval_attacks': ['FGSM', 'L2FGSM'],
    'eval_epsilons': [6, 12, 48],

    'device': 'cuda',
    'results_dir': "tests",

    'checkpoint_name': None,
    'checkpoint_model': None,

    'disable_cuda_benchmark': 1 , # speedup (disable) vs reproducibility (leave it)
    'eval': 0,
    'disable_autoconfig': 0 , # for the brave
    'preload_all_data': 0 , # load all data into ram memory for speedups
    'tuning': 0 , # splits trainset into train and validationset, omits test set

    'get_hooks': 0,
    'track_weights': 0 , # "keep statistics on the weights through training
    'disable_masking': 1 , # disable the ability to prune unstructured
    'enable_rewinding': 0, # enable the ability to rewind to previous weights
    'outer_layer_pruning': 1, # allow to prune outer layers (unstructured) or not (structured)
    'first_layer_dense': 0,
    'random_shuffle_labels': 0  ,# run with random-label experiment from zhang et al
    'l0': 0,  # run with l0 criterion, might overwrite some other arguments
    'hoyer_square': 0, # "run in unstructured DeephoyerSquare criterion, might overwrite some other arguments
    'group_hoyer_square': 0 ,# run in unstructured Group-DeephoyerSquare criterion, might overwrite some other arguments

    'disable_histograms': 0,
    'disable_saliency': 0,
    'disable_confusion': 0,
    'disable_weightplot': 0,
    'disable_netplot': 0,
    'skip_first_plot': 0,
    'disable_activations': 0,
    
#     'input_dim': [1, 28, 28],
#       'output_dim': 10,
#       'hidden_dim': [512],
#       'N': 60000,
    
    'input_dim': [3, 32, 32],
      'output_dim': 10,
      'hidden_dim': [512],
      'N': 60000
})

In [4]:
import logging
from sacred import Experiment
import numpy as np
import seml

import sys
import warnings

sys.path.append('.')

from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *

import torch
from torch.utils.data.dataset import Dataset

from torchvision import transforms

from lipEstimation.lipschitz_utils import compute_module_input_sizes
from lipEstimation.lipschitz_approximations import lipschitz_spectral_ub


def main(
        arguments,
        metrics: Metrics
):

    global out
    out = metrics.log_line
    out(f"starting at {get_date_stamp()}")

    # hardware
    device = configure_device(arguments)

    if arguments['disable_cuda_benchmark']:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # for reproducibility
    configure_seeds(arguments, device)

    # filter for incompatible properties
    assert_compatibilities(arguments)

    # load pre-trained weights if specified
#     path1 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_21.34.45_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/AlexNet_mod_finished.pickle'
#     path2 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_22.22.00_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/AlexNet_mod_finished.pickle'
#     path3 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-14_05.03.33_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/AlexNet_mod_finished.pickle'
#     path4 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-14_05.15.43_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/AlexNet_mod_finished.pickle'
#     path5 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-14_05.49.55_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.9_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/AlexNet_mod_finished.pickle'

#     path1 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_21.22.12_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.75_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/AlexNet_mod_finished.pickle'
#     path2 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_22.06.10_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.75_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/AlexNet_mod_finished.pickle'
#     path3 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-14_22.44.25_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.75_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/AlexNet_mod_finished.pickle'
#     path4 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-14_23.44.04_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.75_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/AlexNet_mod_finished.pickle'
#     path5 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-15_00.43.53_model=AlexNet_dataset=CIFAR10_prune-criterion=StructuredEFGit_pruning-limit=0.75_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/AlexNet_mod_finished.pickle'

#     path1 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_22.50.22_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/AlexNet_mod_finished.pickle'

#     path1 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/AlexNet/2021-07-18_12.43.42_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/AlexNet_mod_finished.pickle'
#     path2 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/AlexNet/2021-07-18_12.43.42_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/AlexNet_mod_finished.pickle'
#     path3 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/AlexNet/2021-07-18_12.43.50_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/AlexNet_mod_finished.pickle'
#     path4 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_22.50.22_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/AlexNet_mod_finished.pickle'
#     path5 = '/nfs/students/ayle/guided-research/results/AlexNet/2021-07-13_22.52.11_model=AlexNet_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/AlexNet_mod_finished.pickle'
    
    path1 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNetAfter/2021-08-05_11.29.38_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=SNIP_pruning-limit=0.94_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_mod_finished.pickle'
    path2 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNetAfter/2021-08-05_11.23.29_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=SNIP_pruning-limit=0.94_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_mod_finished.pickle'
    path3 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNetAfter/2021-08-05_11.36.02_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=SNIP_pruning-limit=0.94_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_mod_finished.pickle'
    path4 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNetAfter/2021-08-05_11.40.55_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=SNIP_pruning-limit=0.94_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_mod_finished.pickle'
    path5 = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNetAfter/2021-08-05_11.45.37_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=SNIP_pruning-limit=0.94_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_mod_finished.pickle'
    
    model1 = load_checkpoint(path1).eval()
    model2 = load_checkpoint(path2).eval()
    model3 = load_checkpoint(path3).eval()
    model4 = load_checkpoint(path4).eval()
    model5 = load_checkpoint(path5).eval()
#     ensembles = [model1, model2, model3, model4, model5]
#     ensembles = [model1, model2, model3]
    ensembles = [model1]

    # load data
    train_loader, test_loader = find_right_model(
        DATASETS, arguments['data_set'],
        arguments=arguments
    )

    # load OOD data
    _, ood_loader = find_right_model(
        DATASETS, arguments['ood_data_set'],
        arguments=arguments
    )

    results = {}
    
    import time
    acc = []
    inf_time = []
    with torch.no_grad():
        for batch_num, batch in enumerate(test_loader):
            x, y = batch
            x, y = x.to(device), y.to(device)

            output = 0
            start = time.time()
            for model in ensembles:
                output += model(x)
            end = time.time()
            inf_time.append(end-start)
            output /= len(ensembles)
            probs = F.softmax(output, dim=-1)

            predictions = probs.argmax(dim=-1, keepdim=True).view_as(y)
            correct = y.eq(predictions).sum().item()
            acc.append(correct / output.shape[0])
        
    results['accuracy'] = np.mean(acc)
    results['inference_time'] = np.mean(inf_time)
    
    breakpoint()

    out("EVALUATING...")

    with torch.no_grad():
        for ood_data_set in arguments['eval_ood_data_sets']:
            out("OOD Dataset: {}".format(ood_data_set))
            # load data
            _, test_loader = find_right_model(
                DATASETS, arguments['data_set'],
                arguments=arguments
            )

            # load OOD data
            _, ood_loader = find_right_model(
                DATASETS, ood_data_set,
                arguments=arguments
            )
            # build tester
            tester = find_right_model(
                TESTERS_DIR, 'OODEvaluation',
                model=ensembles,
                device=device,
                arguments=None,
                test_loader=test_loader,
                ood_loader=ood_loader,
                ood_dataset=ood_data_set,
                ensemble=True
            )
            res = tester.evaluate()

            for key, value in res.items():
                results[key] = value

    class DS(Dataset):

        def __init__(self, images, labels):
            self.images = images
            self.labels = labels
            self.mean = [0.4914, 0.4822, 0.4465]
            self.std = [0.2471, 0.2435, 0.2616]
            self.transforms = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=self.mean, std=self.std)
                ]
            )

        def __getitem__(self, item):
            image = self.images[item] / 255
            image = self.transforms(image.transpose((1, 2, 0)))
            return image.to(torch.float32), torch.tensor(self.labels[item], dtype=torch.float32)

        def __len__(self):
            return len(self.images)

    with torch.no_grad():
        if arguments["data_set"] == "CIFAR10":
            avg_acc = np.zeros(5)
            avg_entropy = np.zeros(5)
            avg_auroc = np.zeros(5)
            avg_aupr = np.zeros(5)
            avg_auroc_ent = np.zeros(5)
            avg_aupr_ent = np.zeros(5)
            ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")
            for ds_dataset_name in os.listdir(ds_path):
                npz_dataset = np.load(os.path.join(ds_path, ds_dataset_name))

                ds_dataset = DS(npz_dataset["images"], npz_dataset["labels"])
                ds_loader = torch.utils.data.DataLoader(
                    ds_dataset,
                    batch_size=arguments['batch_size'],
                    shuffle=False,
                    pin_memory=True,
                    num_workers=4
                )

                # build tester
                tester = find_right_model(
                    TESTERS_DIR, 'DSEvaluation',
                    model=ensembles,
                    device=device,
                    arguments=None,
                    test_loader=test_loader,
                    ds_loader=ds_loader,
                    ds_dataset=ds_dataset_name.split('.')[0],
                    ensemble=True
                )
                res = tester.evaluate()

                severity = int(ds_dataset_name.split('.')[0].split('_')[-1]) - 1
                for key, value in res.items():
                    if key.startswith('acc'):
                        avg_acc[severity] += value
                    elif key.startswith('auroc_entropy'):
                        avg_auroc_ent[severity] += value
                    elif key.startswith('aupr_entropy'):
                        avg_aupr_ent[severity] += value
                    elif key.startswith('auroc'):
                        avg_auroc[severity] += value
                    elif key.startswith('aupr'):
                        avg_aupr[severity] += value
                    elif key.startswith('entropy_'):
                        avg_entropy[severity] += value

                    results[key] = value
            avg_acc = avg_acc / 15
            avg_auroc_ent = avg_auroc_ent / 15
            avg_aupr_ent = avg_aupr_ent / 15
            avg_auroc = avg_auroc / 15
            avg_aupr = avg_aupr / 15
            avg_entropy = avg_entropy / 15
            for i in range(len(avg_acc)):
                name = 'avg_acc_' + str(i + 1)
                results[name] = avg_acc[i]
            for i in range(len(avg_acc)):
                name = 'avg_auroc_ent_' + str(i + 1)
                results[name] = avg_auroc_ent[i]
            for i in range(len(avg_acc)):
                name = 'avg_aupr_ent_' + str(i + 1)
                results[name] = avg_aupr_ent[i]
            for i in range(len(avg_acc)):
                name = 'avg_auroc_' + str(i + 1)
                results[name] = avg_auroc[i]
            for i in range(len(avg_acc)):
                name = 'avg_aupr_' + str(i + 1)
                results[name] = avg_aupr[i]
            for i in range(len(avg_acc)):
                name = 'avg_entropy_' + str(i + 1)
                results[name] = avg_entropy[i]

    return results


def assert_compatibilities(arguments):
    check_incompatible_props([arguments['loss'] != "L0CrossEntropy", arguments['l0']], "l0", arguments['loss'])
    check_incompatible_props([arguments['train_scheme'] != "L0Trainer", arguments['l0']], "l0", arguments['train_scheme'])
    check_incompatible_props([arguments['l0'], arguments['group_hoyer_square'], arguments['hoyer_square']],
                             "Choose one mode, not multiple")
    check_incompatible_props(
        ["Structured" in arguments['prune_criterion'], "Group" in arguments['prune_criterion'], "ResNet" in arguments['model']],
        "structured", "residual connections")
    # todo: add more


def load_checkpoint(path):
    with open(path, 'rb') as f:
        model = pickle.load(f)
    return model


def log_start_run(arguments, out):
    arguments.PyTorch_version = torch.__version__
    arguments.PyThon_version = sys.version
    arguments.pwd = os.getcwd()
    out("PyTorch version:", torch.__version__, "Python version:", sys.version)
    out("Working directory: ", os.getcwd())
    out("CUDA avalability:", torch.cuda.is_available(), "CUDA version:", torch.version.cuda)
    out(arguments)

def run(arguments):
    metrics = Metrics()
    out = metrics.log_line
    metrics._batch_size = arguments['batch_size']
    metrics._eval_freq = arguments['eval_freq']
    set_results_dir(arguments["results_dir"])
    return main(arguments, metrics)

In [None]:
results = run(arguments)

starting at 2021-08-05_12.27.33
Using downloaded and verified file: gitignored/data/train_32x32.mat
Using downloaded and verified file: gitignored/data/test_32x32.mat
> [0;32m<ipython-input-4-2f31f856c020>[0m(124)[0;36mmain[0;34m()[0m
[0;32m    122 [0;31m    [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    123 [0;31m[0;34m[0m[0m
[0m[0;32m--> 124 [0;31m    [0mout[0m[0;34m([0m[0;34m"EVALUATING..."[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    125 [0;31m[0;34m[0m[0m
[0m[0;32m    126 [0;31m    [0;32mwith[0m [0mtorch[0m[0;34m.[0m[0mno_grad[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> np.mean(acc)
0.5842543658088235


In [None]:
results