In [1]:
import os

# Switch the working directory to the SNIP-it checkout (presumably so that the
# project-local imports in the next cell resolve -- confirm).
# The location defaults to the original hard-coded path and can be overridden with
# the SNIP_IT_ROOT environment variable, so the notebook also runs on other machines.
os.chdir(os.environ.get('SNIP_IT_ROOT', '/nfs/homedirs/ayle/guided-research/SNIP-it/'))

In [2]:
import torch
from torchvision import datasets, transforms
import foolbox as fb
from experiments.main import load_checkpoint
from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *
from utils.attacks_utils import get_attack
from torch.utils.data.dataset import Dataset
from copy import deepcopy
from utils.metrics import calculate_aupr, calculate_auroc
from utils.attacks_utils import construct_adversarial_examples
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from torch.distributions import Categorical

In [3]:
# Experiment configuration; the cells below read it through arguments[...] lookups.
arguments = {
    # --- schedule / bookkeeping ---
    'eval_freq': 1000,  # evaluate every n batches
    'save_freq': 1e6,  # save model every n epochs, besides before and after training
    'batch_size': 512,  # size of batches, for Imagenette 128
    'seed': 1234,  # random seed
    'max_training_minutes': 6120,  # process killed after n minutes (after finish of epoch); 6120 min = 102 h
    'plot_weights_freq': 50,  # plot pictures to tensorboard every n epochs

    # --- pruning ---
    'prune_freq': 1,  # if pruning during training: how long to wait before starting
    'prune_delay': 0,  # if pruning during training: 't' from algorithm box, interval between pruning events, default=0
    'prune_to': 0,
    'epochs': 0,
    'rewind_to': 0,  # rewind to this epoch if rewinding is done
    'snip_steps': 5,  # 's' in algorithm box, number of pruning steps for 'rule of thumb', TODO
    'snip_iter': 1000,
    'pruning_rate': 0.0,  # pruning rate passed to criterion at pruning event. however, most override this
    'growing_rate': 0.0,  # grow back so much every epoch (for future criterions)
    'pruning_limit': 0.0,  # prune until here, if structured in nodes, if unstructured in weights. most criterions use this instead of the pruning_rate
    'local_pruning': 0,

    # --- optimisation / regularisation ---
    'learning_rate': 2e-3,
    'grad_clip': 10,
    'grad_noise': 0,  # added gaussian noise to gradients
    'l2_reg': 5e-5,  # weight decay
    'l1_reg': 0,  # l1-norm regularisation
    'lp_reg': 0,  # lp regularisation with p < 1
    'l0_reg': 1.0,  # l0 reg lambda hyperparam
    'hoyer_reg': 0.001,  # hoyer reg lambda hyperparam
    'beta_ema': 0.999,  # l0 reg beta ema hyperparam

    # --- components looked up by name ---
    'loss': 'CrossEntropy',
    'optimizer': 'ADAM',
    'model': 'ResNet18',  # ResNet not supported with structured
    'data_set': 'CIFAR10',
    'ood_data_set': 'SVHN',
    'ood_data_set_prune': 'SVHN',
    'prune_criterion': 'WeightImportance',  # options: SNIP, SNIPit, SNIPitDuring, UnstructuredRandom, GRASP, HoyerSquare, IMP, // SNAPit, StructuredRandom, GateDecorators, EfficientConvNets, GroupHoyerSquare
    'train_scheme': 'DefaultTrainer',  # default: DefaultTrainer
    'attack': 'FGSM',
    'epsilon': 6,
    'eval_ood_data_sets': ['SVHN', 'CIFAR100'],
    'eval_attacks': ['FGSM'],
    'eval_epsilons': [8, 48],

    'device': 'cuda',
    'results_dir': "tmp",

    'checkpoint_name': None,
    'checkpoint_model': None,

    # --- switches ---
    'disable_cuda_benchmark': 1,  # speedup (disable) vs reproducibility (leave it)
    'eval': 0,
    'disable_autoconfig': 0,  # for the brave
    'preload_all_data': 0,  # load all data into ram memory for speedups
    'tuning': 0,  # splits trainset into train and validationset, omits test set

    'get_hooks': 0,
    'track_weights': 0,  # keep statistics on the weights through training
    'disable_masking': 1,  # disable the ability to prune unstructured
    'enable_rewinding': 0,  # enable the ability to rewind to previous weights
    'outer_layer_pruning': 1,  # allow to prune outer layers (unstructured) or not (structured)
    'first_layer_dense': 0,
    'random_shuffle_labels': 0,  # run with random-label experiment from zhang et al
    'l0': 0,  # run with l0 criterion, might overwrite some other arguments
    'hoyer_square': 0,  # run in unstructured DeephoyerSquare criterion, might overwrite some other arguments
    'group_hoyer_square': 0,  # run in unstructured Group-DeephoyerSquare criterion, might overwrite some other arguments

    # --- plotting switches ---
    'disable_histograms': 0,
    'disable_saliency': 0,
    'disable_confusion': 0,
    'disable_weightplot': 0,
    'disable_netplot': 0,
    'skip_first_plot': 0,
    'disable_activations': 0,

    # --- data shape and normalisation ---
    # (previous alternative: 'input_dim': [1, 28, 28], 'output_dim': 10, 'hidden_dim': [512], 'N': 60000)
    'input_dim': [3, 32, 32],
    'output_dim': 10,
    'hidden_dim': [512],
    'N': 60000,
    'mean': (0.4914, 0.4822, 0.4465),
    'std': (0.2471, 0.2435, 0.2616),
}

In [4]:
# Directory passed as `root=` to the torchvision dataset constructors further down.
# NOTE(review): later cells run `from utils.constants import DATASET_PATH`, which
# rebinds this name -- confirm both point at the same location.
DATASET_PATH = '/nfs/students/ayle/guided-research/gitignored/data'

In [5]:
# Metrics object handed to the trainer later on; its `log_line` method doubles
# as the notebook's logging callable `out`.
metrics = Metrics()
metrics._eval_freq = arguments['eval_freq']
metrics._batch_size = arguments['batch_size']
out = metrics.log_line

# register the results directory named in the config
set_results_dir(arguments["results_dir"])

In [6]:
# The outer mask is only maintained when outer-layer pruning is off AND a
# "Structured" criterion is selected.
keep_outer_mask = (not arguments['outer_layer_pruning']) and ("Structured" in arguments['prune_criterion'])

# Constructor keyword arguments, all taken from the config dict.
model_kwargs = dict(
    device=arguments['device'],
    hidden_dim=arguments['hidden_dim'],
    input_dim=arguments['input_dim'],
    output_dim=arguments['output_dim'],
    is_maskable=arguments['disable_masking'],
    is_tracking_weights=arguments['track_weights'],
    is_rewindable=arguments['enable_rewinding'],
    is_growable=arguments['growing_rate'] > 0,
    outer_layer_pruning=arguments['outer_layer_pruning'],
    maintain_outer_mask_anyway=keep_outer_mask,
    l0=arguments['l0'],
    l0_reg=arguments['l0_reg'],
    N=arguments['N'],
    beta_ema=arguments['beta_ema'],
    l2_reg=arguments['l2_reg'],
)

# Look the network up by name and move it to the configured device.
model: GeneralModel = find_right_model(
    NETWORKS_DIR, arguments['model'], **model_kwargs
).to(arguments['device'])

10


In [7]:
# At this point `load_checkpoint` is still the function imported from
# experiments.main; the notebook-local definition in the next cell replaces it.
# NOTE(review): 'checkpoint_name' and 'checkpoint_model' are both None in
# `arguments` -- confirm this call is still needed.
load_checkpoint(arguments, model, out)

In [8]:
def load_checkpoint(path, model, out):
    """Load a pickled state dict from ``path`` into ``model`` and report it via ``out``.

    Args:
        path: path of a pickle file whose content is passed to
            ``model.load_state_dict``.
        model: object exposing ``load_state_dict(state)``.
        out: callable taking one string; receives the confirmation message.

    Raises:
        KeyError, RuntimeError: propagated from ``model.load_state_dict`` after
            the keys found in the checkpoint have been printed.

    SECURITY: ``pickle.load`` can execute arbitrary code stored in the file; only
    use this on checkpoints from a trusted source.
    """
    # Local import: no visible cell imports pickle explicitly.
    import pickle

    with open(path, 'rb') as f:
        state = pickle.load(f)
    try:
        model.load_state_dict(state)
    except (KeyError, RuntimeError):
        # torch modules report missing/unexpected keys as RuntimeError, so both
        # kinds of mismatch get the diagnostic key listing before propagating.
        print(list(state.keys()))
        raise
    out(f"Loaded checkpoint {path}")
    
# def load_checkpoint(path, model, out):
#     state_dict = torch.load(path)
#     new_state_dict = {}
#     for key, val in state_dict.items():
#         if key == 'aug.width': continue
        
#         new_key = '.'.join(['m'] + key.split('.')[1:])
#         new_state_dict[new_key] = val
#     model.load_state_dict(new_state_dict)
#     out(f"Loaded checkpoint {path}")

In [9]:
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-12_03.45.39_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-12_04.48.17_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.21.19_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.25.40_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.25.40_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/Conv6_finished.pickle'

# path = '/nfs/students/ayle/guided-research/results/LeNet5/2021-07-11_03.10.29_model=LeNet5_dataset=FASHION_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/LeNet5_finished.pickle'

# path = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-13_11.03.15_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# Active checkpoint. Per its directory name: ResNet18 on CIFAR10,
# prune-criterion=EmptyCrit, seed=1234 (run stamped 2021-07-26_22.46.19).
path = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_22.46.19_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'
# path2 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_23.33.17_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.35.18_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.35.46_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.36.23_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-08-22_11.02.10_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/VGG16_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-29_13.57.22_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-29_18.19.04_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-24_11.36.47_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.94_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/students/ayle/guided-research/gitignored/results/tests/2021-09-21_10.19.16_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# model trained on augerino augmentations
# path = '/nfs/students/ayle/guided-research/gitignored/results/tests/2021-09-24_11.30.01_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# augerino
# path = '/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_no_trans_trained.pt'
# path = '/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_fixed_trans_new_hyperparam_b512_trained.pt'

# Uses the notebook-local load_checkpoint(path, model, out) defined in the
# previous cell; `out` must still be the logging callable at this point.
load_checkpoint(path, model, out)


Loaded checkpoint /nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_22.46.19_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle


In [10]:
# Device string from the config ('cuda' in `arguments`); passed to the loss
# and trainer lookups below.
device = arguments['device']

In [11]:
# augmentation stuff
# Width tensor for the Augerino uniform augmentation (presumably learned in a
# separate Augerino run -- confirm). Alternative files kept for reference:
# width = torch.load("/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_fixed_trans_trained_width.pt")
# width = torch.load("/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_no_trans_new_hyperparam_b512_trained_width.pt")
#
# map_location='cpu' deserialises straight to host memory, so loading does not
# depend on the device the tensor was saved from being available.
width = torch.load(
    "/nfs/students/ayle/guided-research/gitignored/results/invariances/aug_fixed_trans_new_hyperparam_b512_trained_width.pt",
    map_location='cpu',
)

# Redundant after map_location, kept so `width` is on the CPU in any case.
width = width.cpu()
from augerino import models
aug = models.UniformAug()
aug.set_width(width.data)

In [12]:
# Keyword arguments shared by both dataset lookups in this cell.
dataset_kwargs = dict(
    arguments=arguments,
    mean=arguments['mean'],
    std=arguments['std'],
)

# load data: both loaders of the configured data set
train_loader, test_loader = find_right_model(DATASETS, arguments['data_set'], **dataset_kwargs)

# load OOD data: only the second returned loader is kept
_, ood_loader = find_right_model(DATASETS, arguments['ood_data_set'], **dataset_kwargs)

Using mean (0.4914, 0.4822, 0.4465)
Files already downloaded and verified
Files already downloaded and verified




Using mean (0.4914, 0.4822, 0.4465)
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/train_32x32.mat
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/test_32x32.mat


In [13]:
# load OOD data used for pruning: only the first returned loader is kept
ood_prune_loader, _ = find_right_model(
    DATASETS,
    arguments['ood_data_set_prune'],
    arguments=arguments,
    mean=arguments['mean'],
    std=arguments['std'],
)

# get loss function, parameterised with the regularisation weights from the config
regularisation = {key: arguments[key] for key in ('l1_reg', 'lp_reg', 'l0_reg', 'hoyer_reg')}
loss = find_right_model(LOSS_DIR, arguments['loss'], device=device, **regularisation)

# get optimizer; weight decay is switched off when the l0 flag is set
weight_decay = 0 if arguments['l0'] else arguments['l2_reg']
optimizer = find_right_model(
    OPTIMS,
    arguments['optimizer'],
    params=model.parameters(),
    lr=arguments['learning_rate'],
    weight_decay=weight_decay,
)

Using mean (0.4914, 0.4822, 0.4465)
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/train_32x32.mat
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/test_32x32.mat


In [14]:
# Copy of the loaded model taken before the trainer/criterion run below; the
# analysis loops later reset `model` from this copy on every iteration.
backup_model = deepcopy(model)

In [15]:
# Run identifier assembled from the config values.
run_name = (
    f'_model={arguments["model"]}_dataset={arguments["data_set"]}_ood-dataset={arguments["ood_data_set"]}'
    f'_attack={arguments["attack"]}_epsilon={arguments["epsilon"]}_prune-criterion={arguments["prune_criterion"]}'
    f'_pruning-limit={arguments["pruning_limit"]}_prune-freq={arguments["prune_freq"]}_prune-delay={arguments["prune_delay"]}'
    f'_rewind-to={arguments["rewind_to"]}_train-scheme={arguments["train_scheme"]}_seed={arguments["seed"]}'
)

# Pruning criterion, looked up by name; its grads_abs / scores_mean / scores_std
# attributes are read in the following cells.
criterion_kwargs = dict(
    model=model,
    limit=arguments['pruning_limit'],
    start=0.5,
    orig_scores=True,
    steps=arguments['snip_steps'],
    device=arguments['device'],
    arguments=arguments,
)
criterion = find_right_model(CRITERION_DIR, arguments['prune_criterion'], **criterion_kwargs)

# build trainer
trainer_kwargs = dict(
    model=model,
    loss=loss,
    optimizer=optimizer,
    device=device,
    arguments=arguments,
    train_loader=train_loader,
    test_loader=test_loader,
    ood_loader=ood_loader,
    ood_prune_loader=ood_prune_loader,
    metrics=metrics,
    criterion=criterion,
    run_name=run_name,
)
trainer = find_right_model(TRAINERS_DIR, arguments['train_scheme'], **trainer_kwargs)

trainer.train()

Made datestamp: 2021-10-10_19.52.31_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=WeightImportance_pruning-limit=0.0_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=1234


  0%|          | 0/98 [00:00<?, ?it/s]

[1mStarted training[0m


100%|██████████| 98/98 [02:44<00:00,  1.67s/it]


Saved results/tmp/2021-10-10_19.52.31_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=WeightImportance_pruning-limit=0.0_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/output/scores


In [16]:
# `grads_abs` of the criterion that was handed to the trainer above; the later
# loops use it as the reference for the per-batch comparison.
orig_grads = criterion.grads_abs

In [17]:
# Reference statistics of the same criterion, plus the layer names in
# the iteration order of `orig_grads`.
orig_mean, orig_std = criterion.scores_mean, criterion.scores_std
layer_names = [name for name in orig_grads.keys()]

In [18]:
# for name, val in orig_grads.items():
#     orig_grads[name] = (val.cpu() - orig_mean[name]) / (1e-8 + orig_std[name])

In [19]:
# backup_model = deepcopy(model)

In [20]:
# Switch to two images per batch (this mutates the shared config dict) and
# rebuild the loaders with the new setting.
arguments['batch_size'] = 2

# load data
small_batch_kwargs = dict(arguments=arguments, mean=arguments['mean'], std=arguments['std'])
train_loader, test_loader = find_right_model(DATASETS, arguments['data_set'], **small_batch_kwargs)

Using mean (0.4914, 0.4822, 0.4465)
Files already downloaded and verified
Files already downloaded and verified


In [21]:
from torchvision import datasets
from utils.constants import DATASET_PATH
mean=arguments['mean']
std=arguments['std']
# Test images are loaded un-normalised (ToTensor only); normalisation (trans4) is
# applied per view inside the loop.
test_set = datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=arguments['batch_size'],
    shuffle=False,
    pin_memory=True,
    num_workers=4
)

trans = transforms.ToPILImage()
trans1 = transforms.RandomHorizontalFlip(p=1.0)
trans2 = transforms.RandomCrop(32, padding=4)
trans3 = transforms.ToTensor()
trans4 = transforms.Normalize(mean, std)
trans5 = transforms.RandomRotation(30)
trans6 = transforms.RandomPerspective(p=1.0, distortion_scale=0.3)
trans7 = transforms.RandomVerticalFlip(p=1.0)

norms = []
per_layer_norms = []
# Only the first 10% of the test batches are analysed.
max_batches = int(len(test_loader)*0.1)
for i, (x, y) in enumerate(tqdm(test_loader)):
    if i == max_batches: break

    # Start every batch from the untouched backup copy of the model.
    model = deepcopy(backup_model)
    model.eval()

    # Per image: the normalised original, 5 `aug` samples, a horizontal flip, a
    # vertical flip and the 90/180/270 degree rotations (11 views per image).
    all_x = []
    for im in x:
        x0 = trans4(deepcopy(im).squeeze()).unsqueeze(0)
        all_x.append(x0)
        for _ in range(5):
            image = aug(deepcopy(x0))
            all_x.append(image.cpu())
        x1 = trans4(trans3(trans1(trans(deepcopy(im).squeeze())))).unsqueeze(0)
        x2 = trans4(trans3(trans7(trans(deepcopy(im).squeeze())))).unsqueeze(0)
        x3 = torch.rot90(trans4(deepcopy(im).squeeze()).unsqueeze(0), 1, [2, 3])
        x4 = torch.rot90(deepcopy(x3), 1, [2, 3])
        x5 = torch.rot90(deepcopy(x4), 1, [2, 3])
        all_x.extend([x1, x2, x3, x4, x5])
    x = torch.cat(all_x).to(device)

    # The model's own predictions serve as labels. The result is named `logits` so
    # the notebook-level `out` logger (metrics.log_line) is not overwritten, and no
    # autograd graph is built for this label-only forward pass.
    with torch.no_grad():
        logits = model(x)
    preds = logits.argmax(dim=-1, keepdim=True).flatten()

    # One-element "loader" holding the whole augmented batch.
    train_loader = [(x, preds)]

    # get criterion (fresh instance bound to the fresh model copy)
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )

    criterion.prune(arguments['pruning_limit'],
                    train_loader=train_loader,
                    ood_loader=None,
                    local=arguments['local_pruning'],
                    manager=None)

    # Standardise this batch's per-layer grads_abs with the reference mean/std and
    # take the 5-norm of its difference to the reference per layer.
    # NOTE(review): `orig_grads` itself is not standardised here (the cell that did so
    # is commented out) -- confirm the two sides are on the same scale.
    layer_norms = []
    for name, grad1, grad2 in zip(layer_names, orig_grads.values(), criterion.grads_abs.values()):
        grad3 = (grad2.cpu() - orig_mean[name]) / (1e-8 + orig_std[name])
        layer_norms.append(torch.norm(grad1.cpu() - grad3, p=5).cpu().detach().numpy())
    per_layer_norms.append(layer_norms)
    norms.append(np.mean(layer_norms))

  0%|          | 0/5000 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 12.07it/s]
  0%|          | 1/5000 [00:05<7:49:30,  5.64s/it]
100%|██████████| 1/1 [00:00<00:00, 12.77it/s]
  0%|          | 2/5000 [00:06<3:43:10,  2.68s/it]
100%|██████████| 1/1 [00:00<00:00, 12.90it/s]
  0%|          | 3/5000 [00:06<2:24:09,  1.73s/it]
100%|██████████| 1/1 [00:00<00:00, 12.99it/s]
  0%|          | 4/5000 [00:07<1:46:46,  1.28s/it]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  0%|          | 5/5000 [00:08<1:26:14,  1.04s/it]
100%|██████████| 1/1 [00:00<00:00, 13.02it/s]
  0%|          | 6/5000 [00:08<1:13:35,  1.13it/s]
100%|██████████| 1/1 [00:00<00:00, 12.80it/s]
  0%|          | 7/5000 [00:09<1:05:59,  1.26it/s]
100%|██████████| 1/1 [00:00<00:00, 12.78it/s]
  0%|          | 8/5000 [00:09<1:00:42,  1.37it/s]
100%|██████████| 1/1 [00:00<00:00, 12.99it/s]
  0%|          | 9/5000 [00:10<57:02,  1.46it/s]  
100%|██████████| 1/1 [00:00<00:00, 12.98it/s]
  0%|          | 10/5000 [00:11<54:52,  

  2%|▏         | 80/5000 [00:52<48:29,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.94it/s]
  2%|▏         | 81/5000 [00:53<48:40,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.88it/s]
  2%|▏         | 82/5000 [00:53<48:35,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  2%|▏         | 83/5000 [00:54<48:36,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  2%|▏         | 84/5000 [00:55<48:31,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.89it/s]
  2%|▏         | 85/5000 [00:55<48:31,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.92it/s]
  2%|▏         | 86/5000 [00:56<48:29,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.83it/s]
  2%|▏         | 87/5000 [00:56<48:48,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.60it/s]
  2%|▏         | 88/5000 [00:57<48:39,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  2%|▏         | 89/5000 [00:58<48:30,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  2%|▏         | 90/5000 [00:58<48:39,  

100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  3%|▎         | 165/5000 [01:43<47:49,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.66it/s]
  3%|▎         | 166/5000 [01:43<47:58,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  3%|▎         | 167/5000 [01:44<47:50,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.78it/s]
  3%|▎         | 168/5000 [01:45<47:47,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.80it/s]
  3%|▎         | 169/5000 [01:45<48:00,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  3%|▎         | 170/5000 [01:46<48:09,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
  3%|▎         | 171/5000 [01:46<47:54,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.85it/s]
  3%|▎         | 172/5000 [01:47<47:45,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  3%|▎         | 173/5000 [01:48<47:40,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  3%|▎         | 174/5000 [01:48<47:42,  1.69it/s]
100%|██████████| 1/1 [00:00<00

100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  5%|▍         | 249/5000 [02:33<47:14,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.64it/s]
  5%|▌         | 250/5000 [02:34<47:09,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  5%|▌         | 251/5000 [02:34<47:00,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]
  5%|▌         | 252/5000 [02:35<46:58,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.39it/s]
  5%|▌         | 253/5000 [02:35<47:00,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.66it/s]
  5%|▌         | 254/5000 [02:36<47:12,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  5%|▌         | 255/5000 [02:37<47:05,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  5%|▌         | 256/5000 [02:37<47:00,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  5%|▌         | 257/5000 [02:38<46:52,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
  5%|▌         | 258/5000 [02:38<46:51,  1.69it/s]
100%|██████████| 1/1 [00:00<00

100%|██████████| 1/1 [00:00<00:00, 12.74it/s]
  7%|▋         | 333/5000 [03:23<46:08,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.66it/s]
  7%|▋         | 334/5000 [03:24<46:04,  1.69it/s]
100%|██████████| 1/1 [00:00<00:00, 12.43it/s]
  7%|▋         | 335/5000 [03:24<46:12,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.84it/s]
  7%|▋         | 336/5000 [03:25<46:11,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.66it/s]
  7%|▋         | 337/5000 [03:25<46:12,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  7%|▋         | 338/5000 [03:26<46:11,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  7%|▋         | 339/5000 [03:27<46:33,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  7%|▋         | 340/5000 [03:27<46:22,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  7%|▋         | 341/5000 [03:28<46:14,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
  7%|▋         | 342/5000 [03:28<46:54,  1.66it/s]
100%|██████████| 1/1 [00:00<00

100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  8%|▊         | 417/5000 [04:13<45:36,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.80it/s]
  8%|▊         | 418/5000 [04:14<46:02,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  8%|▊         | 419/5000 [04:14<45:49,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.74it/s]
  8%|▊         | 420/5000 [04:15<45:39,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.54it/s]
  8%|▊         | 421/5000 [04:16<46:03,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.77it/s]
  8%|▊         | 422/5000 [04:16<45:47,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.74it/s]
  8%|▊         | 423/5000 [04:17<45:39,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  8%|▊         | 424/5000 [04:17<45:53,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  8%|▊         | 425/5000 [04:18<45:39,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.64it/s]
  9%|▊         | 426/5000 [04:19<45:34,  1.67it/s]
100%|██████████| 1/1 [00:00<00

In [22]:
# Average over the analysed test batches of the per-batch score (each entry of
# `norms` is already the mean of that batch's per-layer norms).
np.mean(norms)

2069.6052

In [34]:
# One image per batch from here on (this mutates the shared config dict).
arguments['batch_size'] = 1

# Un-normalised CIFAR10 test images, served in fixed order.
test_set = datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transforms.ToTensor())
single_image_options = dict(
    batch_size=arguments['batch_size'],
    shuffle=False,
    pin_memory=True,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(test_set, **single_image_options)

# Explicit iterator: the OOD loop pulls one test image per OOD image with next().
test_iter = iter(test_loader)

In [39]:
# OOD data

from torchvision import datasets
from utils.constants import DATASET_PATH
# SVHN
# ood_set = datasets.SVHN(root=DATASET_PATH, split='test', transform=transforms.ToTensor())

# CIFAR100
ood_set = datasets.CIFAR100(root=DATASET_PATH, train=False, transform=transforms.ToTensor())

# LSUN
# ood_set = datasets.LSUN(root=DATASET_PATH, classes='test',
#                         transform=transforms.Compose([
#                             transforms.Resize(32),
#                             transforms.CenterCrop(32),
#                             transforms.ToTensor()]))

# Common
ood_loader = torch.utils.data.DataLoader(
    ood_set,
    batch_size=arguments['batch_size'],
    shuffle=False,
    pin_memory=True,
    num_workers=0
)

trans = transforms.ToPILImage()
trans1 = transforms.RandomHorizontalFlip(p=1.0)
trans2 = transforms.RandomCrop(32, padding=4)
trans3 = transforms.ToTensor()
# Read the normalisation constants from the config rather than from `mean`/`std`
# variables left behind by an earlier cell.
trans4 = transforms.Normalize(arguments['mean'], arguments['std'])
trans5 = transforms.RandomRotation(30)
trans6 = transforms.RandomPerspective(p=1.0, distortion_scale=0.3)
trans7 = transforms.RandomVerticalFlip(p=1.0)

ood_norms = []
ood_per_layer_norms = []

# Stop after half of 10% of the test loader's length; every iteration consumes
# one OOD image and one test image.
max_batches = int(len(test_loader)*0.1 / 2)
for i, (x, y) in enumerate(tqdm(ood_loader)):
    if i == max_batches: break

    # Start every batch from the untouched backup copy of the model.
    model = deepcopy(backup_model)
    model.eval()

    # Pair the OOD image with the next in-distribution test image.
    x = [x, next(test_iter)[0]]

    # Per image: the normalised original, 5 `aug` samples, a horizontal flip, a
    # vertical flip and the 90/180/270 degree rotations (11 views per image).
    all_x = []
    for im in x:
        x0 = trans4(deepcopy(im).squeeze()).unsqueeze(0)
        all_x.append(x0)
        for _ in range(5):
            image = aug(deepcopy(x0))
            all_x.append(image.cpu())
        x1 = trans4(trans3(trans1(trans(deepcopy(im).squeeze())))).unsqueeze(0)
        x2 = trans4(trans3(trans7(trans(deepcopy(im).squeeze())))).unsqueeze(0)
        x3 = torch.rot90(trans4(deepcopy(im).squeeze()).unsqueeze(0), 1, [2, 3])
        x4 = torch.rot90(deepcopy(x3), 1, [2, 3])
        x5 = torch.rot90(deepcopy(x4), 1, [2, 3])
        all_x.extend([x1, x2, x3, x4, x5])
    x = torch.cat(all_x).to(device)

    # The model's own predictions serve as labels. The result is named `logits` so
    # the notebook-level `out` logger (metrics.log_line) is not overwritten, and no
    # autograd graph is built for this label-only forward pass.
    with torch.no_grad():
        logits = model(x)
    preds = logits.argmax(dim=-1, keepdim=True).flatten()

    # One-element "loader" holding the whole augmented batch.
    train_loader = [(x, preds)]

    # get criterion (fresh instance bound to the fresh model copy)
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )

    criterion.prune(arguments['pruning_limit'],
                    train_loader=train_loader,
                    ood_loader=None,
                    local=arguments['local_pruning'],
                    manager=None)

    # Standardise this batch's per-layer grads_abs with the reference mean/std and
    # take the 5-norm of its difference to the reference per layer.
    layer_norms = []
    for name, grad1, grad2 in zip(layer_names, orig_grads.values(), criterion.grads_abs.values()):
        grad3 = (grad2.cpu() - orig_mean[name]) / (1e-8 + orig_std[name])
        layer_norms.append(torch.norm(grad1.cpu() - grad3, p=5).cpu().detach().numpy())
    # Record into the OOD list; appending to `per_layer_norms` here would mix OOD
    # rows into the in-distribution results and leave `ood_per_layer_norms` empty.
    ood_per_layer_norms.append(layer_norms)
    ood_norms.append(np.mean(layer_norms))

  0%|          | 0/10000 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 12.31it/s]
  0%|          | 1/10000 [00:00<1:44:34,  1.59it/s]
100%|██████████| 1/1 [00:00<00:00, 12.36it/s]
  0%|          | 2/10000 [00:01<2:01:15,  1.37it/s]
100%|██████████| 1/1 [00:00<00:00, 12.73it/s]
  0%|          | 3/10000 [00:02<1:50:56,  1.50it/s]
100%|██████████| 1/1 [00:00<00:00, 12.76it/s]
  0%|          | 4/10000 [00:02<1:46:06,  1.57it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
  0%|          | 5/10000 [00:03<1:45:01,  1.59it/s]
100%|██████████| 1/1 [00:00<00:00, 12.65it/s]
  0%|          | 6/10000 [00:03<1:43:11,  1.61it/s]
100%|██████████| 1/1 [00:00<00:00, 12.59it/s]
  0%|          | 7/10000 [00:04<1:42:05,  1.63it/s]
100%|██████████| 1/1 [00:00<00:00, 12.16it/s]
  0%|          | 8/10000 [00:05<1:42:26,  1.63it/s]
100%|██████████| 1/1 [00:00<00:00, 12.76it/s]
  0%|          | 9/10000 [00:05<1:41:43,  1.64it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]
  0%|          | 10/10000 [00:

100%|██████████| 1/1 [00:00<00:00, 12.61it/s]
  1%|          | 83/10000 [00:49<1:38:28,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  1%|          | 84/10000 [00:50<1:38:31,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  1%|          | 85/10000 [00:51<1:38:23,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.74it/s]
  1%|          | 86/10000 [00:51<1:38:31,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  1%|          | 87/10000 [00:52<1:38:29,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.46it/s]
  1%|          | 88/10000 [00:52<1:38:44,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.68it/s]
  1%|          | 89/10000 [00:53<1:38:40,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  1%|          | 90/10000 [00:54<1:38:25,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]
  1%|          | 91/10000 [00:54<1:39:08,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.78it/s]
  1%|          | 92/10000 [00:55<1:38:56,  1.67it/s]
100%|█████

100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  2%|▏         | 165/10000 [01:38<1:37:50,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  2%|▏         | 166/10000 [01:39<1:37:58,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.68it/s]
  2%|▏         | 167/10000 [01:40<1:37:39,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  2%|▏         | 168/10000 [01:40<1:37:31,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  2%|▏         | 169/10000 [01:41<1:37:39,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  2%|▏         | 170/10000 [01:41<1:37:40,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  2%|▏         | 171/10000 [01:42<1:37:40,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  2%|▏         | 172/10000 [01:43<1:37:36,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.76it/s]
  2%|▏         | 173/10000 [01:43<1:37:23,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.61it/s]
  2%|▏         | 174/10000 [01:44<1:37:22,  1.68it/s]


  2%|▏         | 246/10000 [02:27<1:37:25,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  2%|▏         | 247/10000 [02:28<1:37:21,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.55it/s]
  2%|▏         | 248/10000 [02:28<1:37:24,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.64it/s]
  2%|▏         | 249/10000 [02:29<1:37:34,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  2%|▎         | 250/10000 [02:30<1:37:22,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
  3%|▎         | 251/10000 [02:30<1:37:06,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  3%|▎         | 252/10000 [02:31<1:37:09,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  3%|▎         | 253/10000 [02:31<1:36:55,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  3%|▎         | 254/10000 [02:32<1:36:56,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.84it/s]
  3%|▎         | 255/10000 [02:33<1:37:02,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]


100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  3%|▎         | 328/10000 [03:16<1:37:25,  1.65it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]
  3%|▎         | 329/10000 [03:17<1:36:49,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  3%|▎         | 330/10000 [03:18<1:36:34,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.78it/s]
  3%|▎         | 331/10000 [03:18<1:36:09,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.66it/s]
  3%|▎         | 332/10000 [03:19<1:36:06,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.67it/s]
  3%|▎         | 333/10000 [03:19<1:36:09,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  3%|▎         | 334/10000 [03:20<1:36:10,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.84it/s]
  3%|▎         | 335/10000 [03:20<1:36:07,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  3%|▎         | 336/10000 [03:21<1:36:01,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.64it/s]
  3%|▎         | 337/10000 [03:22<1:36:06,  1.68it/s]


  4%|▍         | 409/10000 [04:05<1:37:36,  1.64it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  4%|▍         | 410/10000 [04:06<1:37:03,  1.65it/s]
100%|██████████| 1/1 [00:00<00:00, 12.53it/s]
  4%|▍         | 411/10000 [04:06<1:36:45,  1.65it/s]
100%|██████████| 1/1 [00:00<00:00, 12.68it/s]
  4%|▍         | 412/10000 [04:07<1:35:54,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.88it/s]
  4%|▍         | 413/10000 [04:07<1:35:42,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.86it/s]
  4%|▍         | 414/10000 [04:08<1:35:53,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]
  4%|▍         | 415/10000 [04:09<1:36:01,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]
  4%|▍         | 416/10000 [04:09<1:35:55,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.74it/s]
  4%|▍         | 417/10000 [04:10<1:35:50,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]
  4%|▍         | 418/10000 [04:10<1:35:38,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.74it/s]


100%|██████████| 1/1 [00:00<00:00, 12.75it/s]
  5%|▍         | 491/10000 [04:54<1:34:52,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.90it/s]
  5%|▍         | 492/10000 [04:55<1:35:18,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.68it/s]
  5%|▍         | 493/10000 [04:55<1:35:01,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.65it/s]
  5%|▍         | 494/10000 [04:56<1:34:35,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.48it/s]
  5%|▍         | 495/10000 [04:56<1:34:32,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
  5%|▍         | 496/10000 [04:57<1:34:34,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.80it/s]
  5%|▍         | 497/10000 [04:58<1:34:20,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.82it/s]
  5%|▍         | 498/10000 [04:58<1:34:30,  1.68it/s]
100%|██████████| 1/1 [00:00<00:00, 12.50it/s]
  5%|▍         | 499/10000 [04:59<1:34:45,  1.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
  5%|▌         | 500/10000 [04:59<1:34:56,  1.67it/s]


In [None]:
# Attacks
from torchvision import datasets
from utils.constants import DATASET_PATH
# Per-channel normalisation statistics, read from the run configuration.
mean, std = arguments['mean'], arguments['std']

# CIFAR-10 test split. Only ToTensor is applied here, so the loader yields
# un-normalised images; Normalize(mean, std) is applied separately via `trans4`.
test_set = datasets.CIFAR10(
    root=DATASET_PATH,
    train=False,
    transform=transforms.ToTensor(),
)

# Fixed (unshuffled) iteration order over the test split.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=arguments['batch_size'],
    shuffle=False,
    pin_memory=True,
    num_workers=4,
)

# Conversions between tensors and PIL images, plus dataset normalisation.
trans = transforms.ToPILImage()
trans3 = transforms.ToTensor()
trans4 = transforms.Normalize(mean, std)

# Flips that are always applied (p=1.0).
trans1 = transforms.RandomHorizontalFlip(p=1.0)
trans7 = transforms.RandomVerticalFlip(p=1.0)

# Randomised geometric augmentations.
trans2 = transforms.RandomCrop(32, padding=4)
trans5 = transforms.RandomRotation(30)
trans6 = transforms.RandomPerspective(p=1.0, distortion_scale=0.3)

# Attacks
# For each of the first ~10% of test batches: build FGSM adversarial examples,
# expand every adversarial image into a group of augmented views, pseudo-label
# the group with the model's own predictions, run the pruning criterion on it,
# and record how far the resulting per-layer scores are from the reference
# scores (`orig_grads`, standardised with `orig_mean` / `orig_std`, all of
# which are defined in earlier cells).
ood_norms = []
attack_per_layer_norms = []

# Number of batches to process; hoisted because it does not change in the loop.
max_batches = int(len(test_loader) * 0.1)

for i, (x, y) in enumerate(tqdm(test_loader)):
    if i == max_batches: break

    # Work on a fresh copy of the reference model for every batch.
    # NOTE(review): presumably because criterion.prune modifies the model — confirm.
    model = deepcopy(backup_model)
    model.eval()

    # NOTE(review): 8 is presumably the FGSM epsilon (the plot cell labels this
    # curve "FGSM ε = 8"); the two trailing flags are opaque from here — confirm.
    adv_results, predictions = construct_adversarial_examples(x, y, 'FGSM', model, model.device, 8, False, False)
    _, advs, success = adv_results
    x = advs.cpu()

    all_x = []
    for im in x:
        im = im.squeeze()
        # The transforms and torch.rot90 below return new tensors instead of
        # modifying their input, so `im` needs no defensive copy per use.
        x0 = trans4(im).unsqueeze(0)
        all_x.append(x0)
        for _ in range(5):
            # `aug` is defined elsewhere and may work in place, so it still
            # receives its own copy of the normalised image.
            image = aug(deepcopy(x0))
            all_x.append(image.cpu())
        # Horizontal / vertical flips go through a PIL round trip.
        x1 = trans4(trans3(trans1(trans(im)))).unsqueeze(0)
        x2 = trans4(trans3(trans7(trans(im)))).unsqueeze(0)
        # Three successive 90-degree rotations of the normalised image.
        x3 = torch.rot90(x0, 1, [2, 3])
        x4 = torch.rot90(x3, 1, [2, 3])
        x5 = torch.rot90(x4, 1, [2, 3])
        all_x.extend([x1, x2, x3, x4, x5])
    x = torch.cat(all_x)

    x = x.cuda()

    # Pseudo-labels only: no gradients are needed for this forward pass, so
    # skip building the autograd graph for the (much enlarged) batch.
    with torch.no_grad():
        out = model(x)
    preds = out.argmax(dim=-1, keepdim=True).flatten()

    # The criterion consumes an iterable of (inputs, labels) batches.
    train_loader = [(x, preds)]

    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )

    criterion.prune(arguments['pruning_limit'],
                    train_loader=train_loader,
                    ood_loader=None,
                    local=arguments['local_pruning'],
                    manager=None)

    # Per-layer distance between the reference scores and this batch's scores,
    # after standardising the latter with the reference mean / std.
    layer_norms = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        layer_name = layer_names[j]
        grad3 = (grad2.cpu() - orig_mean[layer_name]) / (1e-8 + orig_std[layer_name])
        layer_norms.append(torch.norm(grad1.cpu() - grad3, p=5).cpu().detach().numpy())
    attack_per_layer_norms.append(layer_norms)
    ood_norms.append(np.mean(layer_norms))

In [None]:
# DS
# Distribution shifts: for every corrupted CIFAR-10 archive whose file name
# ends in '5.npz', score the first ~10% of its batches with the pruning
# criterion and compute an AUROC of those scores against `norms` (defined in
# an earlier cell; labelled 0 here, with the shifted scores labelled 1).

ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")
aurocs = []
ds_per_layer_norms = []

# os.listdir returns entries in arbitrary order; sort so the printed AUROCs
# (and the `ood_norms` left over from the last dataset) are reproducible.
for ds_dataset_name in sorted(os.listdir(ds_path)):
    if not ds_dataset_name.endswith('5.npz'):
        continue

    npz_dataset = np.load(os.path.join(ds_path, ds_dataset_name))

    ds_dataset = CIFAR10C(npz_dataset["images"], npz_dataset["labels"], arguments["mean"], arguments["std"])
    ds_loader = torch.utils.data.DataLoader(
        ds_dataset,
        batch_size=arguments['batch_size'],
        shuffle=False,
        pin_memory=True,
        num_workers=4
    )

    # Cap at 10% of THIS loader's batches, rather than relying on the
    # `test_loader` that another cell happens to have left in the kernel.
    max_batches = int(len(ds_loader) * 0.1)

    ood_norms = []
    for i, (x, y) in enumerate(tqdm(ds_loader)):
        if i == max_batches: break

        # Work on a fresh copy of the reference model for every batch.
        # NOTE(review): presumably because criterion.prune modifies the model — confirm.
        model = deepcopy(backup_model)
        model.eval()

        x = x.cuda()

        # The original batch followed by five augmented views of every image.
        new_x = [x]
        for im in x:
            for _ in range(5):
                image = aug(im.unsqueeze(0))
                new_x.append(image)
        x = torch.cat(new_x)

        # Pseudo-labels only: no gradients are needed for this forward pass.
        with torch.no_grad():
            out = model(x)
        preds = out.argmax(dim=-1, keepdim=True).flatten()

        # The criterion consumes an iterable of (inputs, labels) batches.
        train_loader = [(x, preds)]

        # get criterion
        criterion = find_right_model(
            CRITERION_DIR, arguments['prune_criterion'],
            model=model,
            limit=arguments['pruning_limit'],
            start=0.5,
            steps=arguments['snip_steps'],
            device=arguments['device'],
            arguments=arguments
        )

        criterion.prune(arguments['pruning_limit'],
                        train_loader=train_loader,
                        ood_loader=None,
                        local=arguments['local_pruning'],
                        manager=None)

        # Per-layer distance between the reference scores and this batch's
        # scores, after standardising the latter with the reference mean / std.
        layer_norms = []
        for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
            layer_name = layer_names[j]
            grad3 = (grad2.cpu() - orig_mean[layer_name]) / (1e-8 + orig_std[layer_name])
            # grad1 is moved to CPU as in the other scoring loops, so both
            # operands are on the same device (a no-op if it already is).
            layer_norms.append(torch.norm(grad1.cpu() - grad3, p=5).cpu().detach().numpy())
        ds_per_layer_norms.append(layer_norms)
        ood_norms.append(np.mean(layer_norms))

    auroc = calculate_auroc(np.concatenate((np.zeros_like(norms), np.ones_like(ood_norms))), np.concatenate((norms, ood_norms)))
    print('AUROC', auroc)
    aurocs.append(auroc)

print('Mean AUROC', np.mean(aurocs))

In [40]:
# Average of the collected out-of-distribution scores.
np.asarray(ood_norms).mean()

2910.608

In [41]:
# Convert both score collections to NumPy arrays for the metric code below.
norms, ood_norms = (np.array(scores) for scores in (norms, ood_norms))

In [42]:
# Label the `norms` scores 0 and the `ood_norms` scores 1, then compute the
# AUROC of the concatenated scores against those labels.
auroc_labels = np.concatenate((np.zeros_like(norms), np.ones_like(ood_norms)))
auroc_scores = np.concatenate((norms, ood_norms))
calculate_auroc(auroc_labels, auroc_scores)

0.7484600000000001

In [None]:
# NOTE(review): bare literal with no computation attached — presumably a score
# (AUROC?) copied by hand from an earlier run; confirm, or move it into a
# markdown note that states which configuration produced it.
0.932336

In [None]:
# 0.9999251152073733 batch size 10 dropout eval model train, not layers 0 3 7
# 0.9999873271889401 batch size 10 model train, not layers 0 3 7
# 0.9999592933947773 batch size 10 model train, all layers

In [None]:
# One curve per data source: the per-layer distances averaged over the first
# axis of each collection (one row per appended entry).
plt.plot(np.array(per_layer_norms).mean(0), label='In-distribution')
plt.plot(np.array(ood_per_layer_norms).mean(0), label='Out-Of-Distribution')
# Optional extra curves; uncomment once the corresponding cells have been run.
# plt.plot(np.array(attack_per_layer_norms).mean(0), label='Attack (FGSM \u03B5 = 8)')
# plt.plot(np.array(ds_per_layer_norms).mean(0), label='Distribution Shifts')
plt.legend()
plt.xlabel('ResNet18 layers')
# The distances are computed with torch.norm(..., p=5) in the scoring loops,
# so the axis is labelled L5 rather than L2. The trailing ';' hides the
# Text(...) repr that would otherwise be echoed as cell output.
plt.ylabel('L5 distance to original scores');

In [None]:
# np.array(ood_per_layer_norms).mean(0)

In [None]:
# Largest 5% of the `norms` scores, returned in descending order (sorted=True).
# Always keep at least one value: with fewer than 20 scores int(len * 0.05) is
# 0, which would give an empty tensor and break the threshold lookup that
# follows.
top_k = max(1, int(len(norms) * 0.05))
thresholds, _ = torch.topk(torch.tensor(norms), top_k, sorted=True)

In [None]:
# Detection threshold: the final entry of `thresholds`.
last_index = len(thresholds) - 1
threshold = thresholds[last_index]

In [None]:
# Count the out-of-distribution scores that reach or exceed the threshold.
num_detected = sum(1 for score in ood_norms if score >= threshold)

In [None]:
# Share of out-of-distribution scores counted in `num_detected`.
detection_rate = num_detected / len(ood_norms)
detection_rate