In [None]:
import os
# Work from the SNIP-it checkout. The location can be overridden with the
# SNIPIT_PROJECT_ROOT environment variable; the default is the original path.
# The project-local imports in the next cell presumably resolve relative to
# this directory — verify before changing it.
os.chdir(os.environ.get('SNIPIT_PROJECT_ROOT', '/nfs/homedirs/ayle/guided-research/SNIP-it/'))

In [None]:
import torch
from torchvision import datasets, transforms
import foolbox as fb
from experiments.main import load_checkpoint
from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *
from utils.attacks_utils import get_attack
from torch.utils.data.dataset import Dataset
import torch.nn as nn

from lipEstimation.lipschitz_utils import compute_module_input_sizes
from lipEstimation.lipschitz_approximations import lipschitz_spectral_ub

from copy import deepcopy

In [None]:
# Experiment configuration read by the model / criterion / trainer / tester
# factories in the cells below. Notes next to the entries are carried over
# from the original configuration; entries without a note had none.
arguments = {
    # -- run control --
    'eval_freq': 1000,  # evaluate every n batches
    'save_freq': 1e6,  # save model every n epochs, besides before and after training
    'batch_size': 256,  # size of batches (original note: 128 for Imagenette)
    'seed': 1234,  # random seed
    # Time budget in minutes (6120 min = 102 h). Original note: process is
    # killed after n minutes, once the running epoch has finished.
    'max_training_minutes': 6120,
    'plot_weights_freq': 50,  # plot pictures to tensorboard every n epochs

    # -- pruning schedule --
    # NOTE(review): the notes on 'prune_freq' and 'prune_delay' read as if they
    # were swapped relative to the key names — confirm against the trainer.
    'prune_freq': 1,  # if pruning during training: how long to wait before starting
    'prune_delay': 0,  # if pruning during training: 't' from algorithm box, interval between pruning events
    'prune_to': 5,
    'epochs': 1,  # original note: 200
    'rewind_to': 0,  # rewind to this epoch if rewinding is done
    'snip_steps': 5,  # 's' in algorithm box, number of pruning steps for 'rule of thumb' (marked TODO originally)
    'snip_iter': 5,
    'pruning_rate': 0.0,  # pruning rate passed to criterion at pruning event; most criterions override this
    'growing_rate': 0.0,  # grow back this much every epoch (for future criterions)
    'pruning_limit': 0.0,  # prune until here (nodes if structured, weights if unstructured); most criterions use this instead of pruning_rate
    'local_pruning': 0,

    # -- optimisation and regularisation --
    'learning_rate': 2e-3,  # original note: 0.1
    'grad_clip': 10,
    'grad_noise': 0,  # added gaussian noise to gradients
    'l2_reg': 5e-5,  # weight decay (original note: 5e-4)
    'l1_reg': 0,  # l1-norm regularisation
    'lp_reg': 0,  # lp regularisation with p < 1
    'l0_reg': 1.0,  # l0 reg lambda hyperparam
    'hoyer_reg': 0.001,  # hoyer reg lambda hyperparam
    'beta_ema': 0.999,  # l0 reg beta ema hyperparam
    'momentum': 0.9,

    # -- component selection (names are resolved by find_right_model) --
    'loss': 'CrossEntropy',
    'optimizer': 'ADAM',  # alternative noted originally: SGD (+ scheduler)
    'model': 'ResNet18',  # original note: ResNet not supported with structured
    'data_set': 'CIFAR10',
    'ood_data_set': 'SVHN',
    'ood_data_set_prune': 'SVHN',
    # Options listed originally: SNIP, SNIPit, SNIPitDuring, UnstructuredRandom,
    # GRASP, HoyerSquare, IMP, // SNAPit, StructuredRandom, GateDecorators,
    # EfficientConvNets, GroupHoyerSquare
    'prune_criterion': 'EmptyCrit',
    'train_scheme': 'DefaultTrainer',  # default: DefaultTrainer
    'attack': 'FGSM',
    'epsilon': 8,
    'eval_ood_data_sets': ['SVHN', 'CIFAR100', 'LSUN', 'OODOMAIN'],
    'eval_attacks': ['FGSM', 'L2FGSM'],
    'eval_epsilons': [8, 16],

    # -- runtime --
    'device': 'cuda',
    'results_dir': "tests",

    'checkpoint_name': None,
    'checkpoint_model': None,

    'disable_cuda_benchmark': 1,  # speedup (disable) vs reproducibility (leave it)
    'eval': 0,
    'disable_autoconfig': 0,  # for the brave
    'preload_all_data': 0,  # load all data into ram memory for speedups
    'tuning': 0,  # splits trainset into train and validationset, omits test set

    # -- model capabilities / experiment switches --
    'get_hooks': 0,
    'track_weights': 0,  # keep statistics on the weights through training
    'disable_masking': 1,  # disable the ability to prune unstructured
    'enable_rewinding': 0,  # enable the ability to rewind to previous weights
    'outer_layer_pruning': 1,  # allow to prune outer layers (unstructured) or not (structured)
    'first_layer_dense': 0,
    'random_shuffle_labels': 0,  # run with random-label experiment from zhang et al
    'l0': 0,  # run with l0 criterion, might overwrite some other arguments
    'hoyer_square': 0,  # run in unstructured DeephoyerSquare criterion, might overwrite some other arguments
    'group_hoyer_square': 0,  # run in unstructured Group-DeephoyerSquare criterion, might overwrite some other arguments

    # -- plotting / logging switches --
    'disable_histograms': 0,
    'disable_saliency': 0,
    'disable_confusion': 0,
    'disable_weightplot': 0,
    'disable_netplot': 0,
    'skip_first_plot': 0,
    'disable_activations': 0,

    # -- data geometry and normalisation --
    # A previously used alternative (left commented out in the original) was
    # input_dim [1, 28, 28] with the same output_dim, hidden_dim and N.
    'input_dim': [3, 32, 32],
    'output_dim': 10,
    'hidden_dim': [512],
    'N': 60000,
    'mean': (0.4914, 0.4822, 0.4465),
    'std': (0.2471, 0.2435, 0.2616),
}

In [None]:
# Root folder for datasets; the dataset-shift cell below looks for a
# "cifar10_corrupted" sub-folder in it. Overridable through the
# SNIPIT_DATASET_PATH environment variable; the default is the original path.
DATASET_PATH = os.environ.get('SNIPIT_DATASET_PATH', '/nfs/students/ayle/guided-research/gitignored/data')

In [None]:
# Fail fast: the adversarial evaluation further down requests the
# '<data_set>_unnormalized' loader, which only the datasets listed here are
# accepted for.
unnormalized_datasets = ('CIFAR10', 'MNIST', 'FASHION', 'custom_CIFAR10')
if arguments['data_set'] not in unnormalized_datasets:
    raise NotImplementedError(f'Unnormalized loading not implemented for dataset {arguments["data_set"]}')

# Metrics object shared with the trainer; its log_line method doubles as the
# logging callback (`out`) used throughout the notebook.
metrics = Metrics()
metrics._eval_freq = arguments['eval_freq']
metrics._batch_size = arguments['batch_size']
out = metrics.log_line

set_results_dir(arguments["results_dir"])

In [None]:
# The outer mask is kept only for "Structured" criteria when outer-layer
# pruning is switched off.
maintain_outer_mask = (
    "Structured" in arguments['prune_criterion']
    and not arguments['outer_layer_pruning']
)

# NOTE(review): `is_maskable` is fed from the 'disable_masking' entry, which
# looks inverted by name — confirm the flag semantics in the project config.
model_options = dict(
    device=arguments['device'],
    hidden_dim=arguments['hidden_dim'],
    input_dim=arguments['input_dim'],
    output_dim=arguments['output_dim'],
    is_maskable=arguments['disable_masking'],
    is_tracking_weights=arguments['track_weights'],
    is_rewindable=arguments['enable_rewinding'],
    is_growable=arguments['growing_rate'] > 0,
    outer_layer_pruning=arguments['outer_layer_pruning'],
    maintain_outer_mask_anyway=maintain_outer_mask,
    l0=arguments['l0'],
    l0_reg=arguments['l0_reg'],
    N=arguments['N'],
    beta_ema=arguments['beta_ema'],
    l2_reg=arguments['l2_reg'],
)

# Resolve the network class by name, build it, and move it to the target device.
model: GeneralModel = find_right_model(NETWORKS_DIR, arguments['model'], **model_options)
model = model.to(arguments['device'])

In [None]:
# Calls the `load_checkpoint` imported from experiments.main, passing the whole
# arguments dict. With 'checkpoint_name' set to None this is presumably a
# no-op — verify in experiments.main.
# Caution: a later cell redefines `load_checkpoint` with a (path, model, out)
# signature; re-running this cell after that one would pass `arguments` as the
# checkpoint path.
load_checkpoint(arguments, model, out)

In [None]:
# def load_checkpoint(path, model, out):
#     with open(path, 'rb') as f:
#         state = pickle.load(f)
#     try:
#         model.load_state_dict(state)
#     except KeyError as e:
#         print(list(state.keys()))
#         raise e
#     out(f"Loaded checkpoint {path}")
    
def load_checkpoint(path, model, out):
    """Load a state dict saved with ``torch.save`` into ``model``, renaming its keys.

    Every key except ``'aug.width'`` (which is dropped) has its first
    dot-separated component replaced by ``'m'``, e.g. ``'net.conv1.weight'``
    becomes ``'m.conv1.weight'``. The remapped dict is then passed to
    ``model.load_state_dict``.

    Note: this definition shadows the ``load_checkpoint`` imported from
    ``experiments.main`` at the top of the notebook, which is called with a
    different first argument.

    Args:
        path: File path of the serialized state dict.
        model: Object exposing ``load_state_dict``; receives the remapped weights.
        out: Callable taking one string; used to log the loaded path.

    Raises:
        Whatever ``torch.load`` or ``model.load_state_dict`` raise (missing
        file, missing or unexpected keys, shape mismatch).
    """
    # Deserialize onto the CPU so the file loads even when the device it was
    # saved from is unavailable; load_state_dict copies the values into the
    # model's own parameters, wherever those live.
    state_dict = torch.load(path, map_location='cpu')

    new_state_dict = {
        '.'.join(['m'] + key.split('.')[1:]): val
        for key, val in state_dict.items()
        if key != 'aug.width'
    }

    model.load_state_dict(new_state_dict)
    out(f"Loaded checkpoint {path}")

In [None]:
# path = '/nfs/students/ayle/guided-research/gitignored/results/tests/2021-09-21_10.19.16_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_22.46.19_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'
# Checkpoint to load for this run (the two paths commented out above are
# earlier alternatives).
path = '/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_fixed_trans_trained.pt'

# Uses the notebook-local load_checkpoint(path, model, out) defined in the
# previous cell, which renames the state-dict keys before loading.
load_checkpoint(path, model, out)

In [None]:
# Shorthand for the configured device string, passed to the loss, trainer and
# tester factories below.
device = arguments['device']

In [None]:
# Pruning criterion, resolved by name from the configuration.
criterion_options = dict(
    model=model,
    limit=arguments['pruning_limit'],
    start=0.5,
    steps=arguments['snip_steps'],
    device=arguments['device'],
    arguments=arguments,
)
criterion = find_right_model(CRITERION_DIR, arguments['prune_criterion'], **criterion_options)

# Every dataset factory call in this cell takes the same keyword arguments.
dataset_options = dict(
    arguments=arguments,
    mean=arguments['mean'],
    std=arguments['std'],
)

# In-distribution train / test loaders.
train_loader, test_loader = find_right_model(DATASETS, arguments['data_set'], **dataset_options)

# Out-of-distribution data: only the second loader of the returned pair is kept.
_, ood_loader = find_right_model(DATASETS, arguments['ood_data_set'], **dataset_options)

In [None]:
# OOD data handed to the trainer as `ood_prune_loader`: the first loader of
# the returned pair is kept here.
ood_prune_loader, _ = find_right_model(
    DATASETS, arguments['ood_data_set_prune'],
    **dict(arguments=arguments, mean=arguments['mean'], std=arguments['std'])
)

# Loss function, resolved by name, with the regularisation weights from the config.
loss_options = dict(
    device=device,
    l1_reg=arguments['l1_reg'],
    lp_reg=arguments['lp_reg'],
    l0_reg=arguments['l0_reg'],
    hoyer_reg=arguments['hoyer_reg'],
)
loss = find_right_model(LOSS_DIR, arguments['loss'], **loss_options)

# Optimizer. Weight decay is switched off when the l0 criterion is active.
# (`momentum` is deliberately not passed; it was commented out in the original call.)
weight_decay = 0 if arguments['l0'] else arguments['l2_reg']
optimizer = find_right_model(
    OPTIMS, arguments['optimizer'],
    params=model.parameters(),
    lr=arguments['learning_rate'],
    weight_decay=weight_decay
)

# Suffix identifying this run, built from the main configuration entries.
run_name = (
    f'_model={arguments["model"]}_dataset={arguments["data_set"]}_prune-criterion={arguments["prune_criterion"]}'
    f'_pruning-limit={arguments["pruning_limit"]}_train-scheme={arguments["train_scheme"]}_seed={arguments["seed"]}'
)

# Assemble the trainer from everything built above and run the training.
trainer_options = dict(
    model=model,
    loss=loss,
    optimizer=optimizer,
    device=device,
    arguments=arguments,
    train_loader=train_loader,
    test_loader=test_loader,
    ood_loader=ood_loader,
    ood_prune_loader=ood_prune_loader,
    metrics=metrics,
    criterion=criterion,
    run_name=run_name,
)
trainer = find_right_model(TRAINERS_DIR, arguments['train_scheme'], **trainer_options)

trainer.train()

In [None]:
# Empty accumulator for evaluation results. Note that the next cell rebinds
# `results` to a fresh dict, so this initialisation only matters when that
# cell is skipped.
results = {}

In [None]:
# Start the results record from the trainer's summary attributes and the
# stamp of the data manager.
results = {
    'train_acc': trainer.train_acc,
    'sparsity': trainer.sparsity,
    'filename': DATA_MANAGER.stamp,
}

# The CKA mean is recorded only when 'get_hooks' is truthy.
if arguments['get_hooks']:
    results.update({'cka': trainer.cka_mean})

In [None]:
# Put the model into evaluation mode before the evaluation cells below.
model = model.eval()

In [None]:
# In-distribution evaluation on the test loader. Besides the metric dict, the
# tester returns labels, predictions and entropies, which the later
# evaluation cells pass on (as deep copies) to their own testers.
in_tester = find_right_model(TESTERS_DIR, 'InEvaluation', test_loader=test_loader, device=device, model=model)
in_res, true_labels, all_preds, entropies = in_tester.evaluate()

# Merge the in-distribution metrics into the results record.
results.update(in_res.items())

In [None]:
# Display the results gathered so far (training summary + in-distribution metrics).
results

In [None]:
# Adversarial evaluation: every configured attack is run at every configured
# epsilon, and the metrics of each run are merged into `results`.

# The unnormalized test loader (plus the mean/std returned with it) depends
# neither on the attack nor on epsilon, so it is built once instead of once
# per (attack, epsilon) pair.
(_, un_test_loader), mean, std = find_right_model(
    DATASETS, arguments['data_set'] + '_unnormalized',
    arguments=arguments,
    mean=arguments['mean'],
    std=arguments['std']
)

for attack in arguments['eval_attacks']:
    for epsilon in arguments['eval_epsilons']:
        out("Attack {}".format(attack))
        # A fresh tester per run, as before; only the data loading is shared.
        tester = find_right_model(
            TESTERS_DIR, 'AdversarialEvaluation',
            attack=attack,
            model=model,
            device=device,
            test_loader=un_test_loader,
            mean=mean,
            std=std
        )
        # Deep copies keep the in-distribution reference values untouched in
        # case the tester modifies its inputs.
        res = tester.evaluate(epsilon=epsilon, true_labels=deepcopy(true_labels), all_preds=deepcopy(all_preds),
                              entropies=deepcopy(entropies))
        results.update(res.items())

In [None]:
# OOD evaluation: one pass per configured out-of-distribution dataset, with
# gradient tracking disabled. Note that this rebinds `ood_loader`, the name
# that was passed to the trainer earlier.
with torch.no_grad():
    for ood_data_set in arguments['eval_ood_data_sets']:
        out("OOD Dataset: {}".format(ood_data_set))

        # Only the second loader of the returned pair is used.
        _, ood_loader = find_right_model(
            DATASETS, ood_data_set,
            **dict(arguments=arguments, mean=arguments['mean'], std=arguments['std'])
        )

        tester = find_right_model(
            TESTERS_DIR, 'OODEvaluation',
            **dict(model=model, device=device, ood_loader=ood_loader, ood_dataset=ood_data_set)
        )

        # The in-distribution reference values are handed over as deep copies.
        reference = dict(
            true_labels=deepcopy(true_labels),
            all_preds=deepcopy(all_preds),
            entropies=deepcopy(entropies),
        )
        res = tester.evaluate(**reference)

        results.update(res.items())

In [None]:
# DS (dataset-shift) evaluation on the files in <DATASET_PATH>/cifar10_corrupted.
# Runs only when the configured dataset is CIFAR10. Each file is evaluated
# separately; its metrics go into `results` unchanged and are additionally
# averaged per severity level and over all levels.
# File names are expected to look like "<name>_<severity>.<ext>" with severity
# in 1..5 (the severity is parsed from the last "_"-separated token).
# NOTE(review): `np` and `CIFAR10C` are not imported explicitly anywhere in
# this notebook; presumably they arrive through the wildcard imports — verify.
with torch.no_grad():
    if arguments["data_set"] == "CIFAR10":
        n_severities = 5

        # (key prefix in the tester output, metric name used in the summary keys).
        # Order matters twice: the first matching prefix wins (so the
        # '*_entropy' prefixes must precede 'auroc' / 'aupr'), and the summary
        # keys are written in this order.
        metric_prefixes = [
            ('acc', 'acc'),
            ('auroc_entropy', 'auroc_ent'),
            ('aupr_entropy', 'aupr_ent'),
            ('auroc', 'auroc'),
            ('aupr', 'aupr'),
            ('entropy_', 'entropy'),
        ]
        # metric name -> one list of collected values per severity level
        per_severity = {
            metric: [[] for _level in range(n_severities)]
            for _prefix, metric in metric_prefixes
        }

        ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")

        # sorted(): os.listdir returns entries in arbitrary order; sorting makes
        # the processing (and hence accumulation) order reproducible.
        for ds_dataset_name in sorted(os.listdir(ds_path)):
            corruption_name = ds_dataset_name.split('.')[0]

            # Get corruption loader
            npz_dataset = np.load(os.path.join(ds_path, ds_dataset_name))
            ds_dataset = CIFAR10C(npz_dataset["images"], npz_dataset["labels"], arguments["mean"], arguments["std"])
            ds_loader = torch.utils.data.DataLoader(
                ds_dataset,
                batch_size=arguments['batch_size'],
                shuffle=False,
                pin_memory=True,
                num_workers=4
            )

            # build tester
            tester = find_right_model(
                TESTERS_DIR, 'DSEvaluation',
                model=model,
                device=device,
                ds_loader=ds_loader,
                ds_dataset=corruption_name
            )
            res = tester.evaluate(true_labels=deepcopy(true_labels), all_preds=deepcopy(all_preds),
                                  entropies=deepcopy(entropies))

            # zero-based severity index taken from the file name
            severity = int(corruption_name.split('_')[-1]) - 1
            for key, value in res.items():
                for prefix, metric in metric_prefixes:
                    if key.startswith(prefix):
                        per_severity[metric][severity].append(value)
                        break

                results[key] = value

        # Mean over all corruption files, per metric and severity level.
        severity_means = {
            metric: [np.mean(values) for values in buckets]
            for metric, buckets in per_severity.items()
        }

        # 'avg_<metric>_<level>' with levels numbered from 1.
        for metric, means in severity_means.items():
            for level, mean_value in enumerate(means, start=1):
                results['avg_' + metric + '_' + str(level)] = mean_value

        # 'avg_<metric>_cifar10c': mean over the per-severity means.
        for metric, means in severity_means.items():
            results['avg_' + metric + '_cifar10c'] = np.mean(means)

In [None]:
# Spectral upper bound on the Lipschitz constant of the model.
# Gradients are not needed here, so all parameters are frozen first (original
# note: this speeds up the computation).
for param in model.parameters():
    param.requires_grad_(False)

# Take the first sample of the first training batch and add a leading batch
# dimension of 1; only its size is needed to register per-module input sizes.
for batch_images, _batch_targets in train_loader:
    input_size = batch_images[0].unsqueeze(0).size()
    break
compute_module_input_sizes(model, input_size)

# Side effect: model.cpu() leaves the model on the CPU from here on, while
# `device` still names the configured device.
# NOTE(review): `.data[0]` assumes the returned value is a tensor that can be
# indexed; on a 0-dim tensor this indexing raises in current PyTorch — confirm
# what lipschitz_spectral_ub returns.
spectral_bound = lipschitz_spectral_ub(model.cpu())
lip_spec = spectral_bound.data[0]
results['lip_spec'] = lip_spec

In [None]:
# Display the complete results record collected by the cells above.
results