In [1]:
import os
os.chdir('/nfs/homedirs/ayle/guided-research/SNIP-it/')

In [2]:
import torch
from torchvision import datasets, transforms
import foolbox as fb
from experiments.main import load_checkpoint
from models import GeneralModel
from models.statistics.Metrics import Metrics
from utils.config_utils import *
from utils.model_utils import *
from utils.system_utils import *
from utils.attacks_utils import get_attack
from torch.utils.data.dataset import Dataset
from copy import deepcopy
from utils.metrics import calculate_aupr, calculate_auroc
from utils.attacks_utils import construct_adversarial_examples
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
from torch.distributions import Categorical

In [3]:
arguments = dict({
'eval_freq': 1000,  # evaluate every n batches
    'save_freq': 1e6,  # save model every n epochs, besides before and after training
    'batch_size': 512,  # size of batches, for Imagenette 128
    'seed': 1234,  # random seed
    'max_training_minutes': 6120 , # one hour and a 45 minutes max, process killed after n minutes (after finish of epoch)
    'plot_weights_freq': 50, # plot pictures to tensorboard every n epochs
    'prune_freq': 1, # if pruning during training: how long to wait before starting
    'prune_delay': 0, # "if pruning during training: 't' from algorithm box, interval between pruning events, default=0
    'prune_to': 0,
    'epochs': 0,
    'rewind_to': 0, # rewind to this epoch if rewinding is done
    'snip_steps': 5, # 's' in algorithm box, number of pruning steps for 'rule of thumb', TODO
    'snip_iter': 1000,
    'pruning_rate': 0.0, # pruning rate passed to criterion at pruning event. however, most override this
    'growing_rate': 0.0000 , # grow back so much every epoch (for future criterions)
    'pruning_limit': 0.0,  # Prune until here, if structured in nodes, if unstructured in weights. most criterions use this instead of the pruning_rate
    'local_pruning': 0,
    'learning_rate': 2e-3,
    'grad_clip': 10,
    'grad_noise': 0 , # added gaussian noise to gradients
    'l2_reg': 5e-5 , # weight decay
    'l1_reg': 0 , # l1-norm regularisation
    'lp_reg': 0 , # lp regularisation with p < 1
    'l0_reg': 1.0 , # l0 reg lambda hyperparam
    'hoyer_reg': 0.001 , # hoyer reg lambda hyperparam
    'beta_ema': 0.999 , # l0 reg beta ema hyperparam

    'loss': 'CrossEntropy',
    'optimizer': 'ADAM',
    'model': 'ResNet18',  # ResNet not supported with structured
    'data_set': 'CIFAR10',
    'ood_data_set': 'SVHN',
    'ood_data_set_prune': 'SVHN',
    'prune_criterion': 'WeightImportance',  # options: SNIP, SNIPit, SNIPitDuring, UnstructuredRandom, GRASP, HoyerSquare, IMP, // SNAPit, StructuredRandom, GateDecorators, EfficientConvNets, GroupHoyerSquare
    'train_scheme': 'DefaultTrainer' , # default: DefaultTrainer
    'attack': 'FGSM',
    'epsilon': 6,
    'eval_ood_data_sets': ['SVHN', 'CIFAR100'],
    'eval_attacks': ['FGSM'],
    'eval_epsilons': [8, 48],

    'device': 'cuda',
    'results_dir': "tmp",

    'checkpoint_name': None,
    'checkpoint_model': None,

    'disable_cuda_benchmark': 1 , # speedup (disable) vs reproducibility (leave it)
    'eval': 0,
    'disable_autoconfig': 0 , # for the brave
    'preload_all_data': 0 , # load all data into ram memory for speedups
    'tuning': 0 , # splits trainset into train and validationset, omits test set

    'get_hooks': 0,
    'track_weights': 0 , # "keep statistics on the weights through training
    'disable_masking': 1 , # disable the ability to prune unstructured
    'enable_rewinding': 0, # enable the ability to rewind to previous weights
    'outer_layer_pruning': 1, # allow to prune outer layers (unstructured) or not (structured)
    'first_layer_dense': 0,
    'random_shuffle_labels': 0  ,# run with random-label experiment from zhang et al
    'l0': 0,  # run with l0 criterion, might overwrite some other arguments
    'hoyer_square': 0, # "run in unstructured DeephoyerSquare criterion, might overwrite some other arguments
    'group_hoyer_square': 0 ,# run in unstructured Group-DeephoyerSquare criterion, might overwrite some other arguments

    'disable_histograms': 0,
    'disable_saliency': 0,
    'disable_confusion': 0,
    'disable_weightplot': 0,
    'disable_netplot': 0,
    'skip_first_plot': 0,
    'disable_activations': 0,
    
#     'input_dim': [1, 28, 28],
#       'output_dim': 10,
#       'hidden_dim': [512],
#       'N': 60000,
    
    'input_dim': [3, 32, 32],
      'output_dim': 10,
      'hidden_dim': [512],
      'N': 60000
})

In [4]:
DATASET_PATH = '/nfs/students/ayle/guided-research/gitignored/data'

In [5]:
metrics = Metrics()
out = metrics.log_line
metrics._batch_size = arguments['batch_size']
metrics._eval_freq = arguments['eval_freq']
set_results_dir(arguments["results_dir"])

In [6]:
model: GeneralModel = find_right_model(
        NETWORKS_DIR, arguments['model'],
        device=arguments['device'],
        hidden_dim=arguments['hidden_dim'],
        input_dim=arguments['input_dim'],
        output_dim=arguments['output_dim'],
        is_maskable=arguments['disable_masking'],
        is_tracking_weights=arguments['track_weights'],
        is_rewindable=arguments['enable_rewinding'],
        is_growable=arguments['growing_rate'] > 0,
        outer_layer_pruning=arguments['outer_layer_pruning'],
        maintain_outer_mask_anyway=(
                                       not arguments['outer_layer_pruning']) and (
                                           "Structured" in arguments['prune_criterion']),
        l0=arguments['l0'],
        l0_reg=arguments['l0_reg'],
        N=arguments['N'],
        beta_ema=arguments['beta_ema'],
        l2_reg=arguments['l2_reg']
).to(arguments['device'])

10


In [7]:
load_checkpoint(arguments, model, out)

In [8]:
def load_checkpoint(path, model, out):
    with open(path, 'rb') as f:
        state = pickle.load(f)
    try:
        model.load_state_dict(state)
    except KeyError as e:
        print(list(state.keys()))
        raise e
    out(f"Loaded checkpoint {path}")

In [9]:
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-12_03.45.39_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-12_04.48.17_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.21.19_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.25.40_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/Conv6_finished.pickle'
# path = '/nfs/students/ayle/guided-research/results/Conv6/2021-07-15_19.25.40_model=Conv6_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/Conv6_finished.pickle'

# path = '/nfs/students/ayle/guided-research/results/LeNet5/2021-07-11_03.10.29_model=LeNet5_dataset=FASHION_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/LeNet5_finished.pickle'

# path = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-13_11.03.15_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_22.46.19_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'
# path2 = '/nfs/students/ayle/guided-research/results/ResNet18/2021-07-26_23.33.17_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=2345/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.35.18_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=3456/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.35.46_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=4567/models/ResNet18_finished.pickle'
# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-07-26_23.36.23_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=5678/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/VGG16/2021-08-22_11.02.10_model=VGG16_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/VGG16_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-29_13.57.22_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-29_18.19.04_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

# path = '/nfs/homedirs/ayle/guided-research/SNIP-it/gitignored/results/ResNet18/2021-08-24_11.36.47_model=ResNet18_dataset=CIFAR10_prune-criterion=EarlyJohn_pruning-limit=0.94_prune-freq=1_prune-delay=0_outer-layer-pruning=1_prune-to=10_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

path = '/nfs/students/ayle/guided-research/gitignored/results/tests/2021-09-21_10.19.16_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle'

load_checkpoint(path, model, out)


Loaded checkpoint /nfs/students/ayle/guided-research/gitignored/results/tests/2021-09-21_10.19.16_model=ResNet18_dataset=CIFAR10_prune-criterion=EmptyCrit_pruning-limit=0.0_train-scheme=DefaultTrainer_seed=1234/models/ResNet18_finished.pickle


In [10]:
device = arguments['device']

In [11]:
# load data
train_loader, test_loader = find_right_model(
    DATASETS, arguments['data_set'],
    arguments=arguments
)

# load OOD data
_, ood_loader = find_right_model(
    DATASETS, arguments['ood_data_set'],
    arguments=arguments
)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/train_32x32.mat
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/test_32x32.mat


In [12]:
# load OOD data
ood_prune_loader, _ = find_right_model(
    DATASETS, arguments['ood_data_set_prune'],
    arguments=arguments
)

# get loss function
loss = find_right_model(
    LOSS_DIR, arguments['loss'],
    device=device,
    l1_reg=arguments['l1_reg'],
    lp_reg=arguments['lp_reg'],
    l0_reg=arguments['l0_reg'],
    hoyer_reg=arguments['hoyer_reg']
)

# get optimizer
optimizer = find_right_model(
    OPTIMS, arguments['optimizer'],
    params=model.parameters(),
    lr=arguments['learning_rate'],
    weight_decay=arguments['l2_reg'] if not arguments['l0'] else 0
)

Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/train_32x32.mat
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/test_32x32.mat


In [13]:
backup_model = deepcopy(model)

In [14]:
run_name = f'_model={arguments["model"]}_dataset={arguments["data_set"]}_ood-dataset={arguments["ood_data_set"]}' + \
           f'_attack={arguments["attack"]}_epsilon={arguments["epsilon"]}_prune-criterion={arguments["prune_criterion"]}' + \
           f'_pruning-limit={arguments["pruning_limit"]}_prune-freq={arguments["prune_freq"]}_prune-delay={arguments["prune_delay"]}' + \
           f'_rewind-to={arguments["rewind_to"]}_train-scheme={arguments["train_scheme"]}_seed={arguments["seed"]}'


criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        orig_scores=True,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )

# build trainer
trainer = find_right_model(
    TRAINERS_DIR, arguments['train_scheme'],
    model=model,
    loss=loss,
    optimizer=optimizer,
    device=device,
    arguments=arguments,
    train_loader=train_loader,
    test_loader=test_loader,
    ood_loader=ood_loader,
    ood_prune_loader=ood_prune_loader,
    metrics=metrics,
    criterion=criterion,
    run_name=run_name
)

trainer.train()

Made datestamp: 2021-09-22_10.50.19_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=WeightImportance_pruning-limit=0.0_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=1234


  0%|          | 0/98 [00:00<?, ?it/s]

[1mStarted training[0m


100%|██████████| 98/98 [00:23<00:00,  4.18it/s]
100%|██████████| 98/98 [00:20<00:00,  4.69it/s]


Saved results/tmp/2021-09-22_10.50.19_model=ResNet18_dataset=CIFAR10_ood-dataset=SVHN_attack=FGSM_epsilon=6_prune-criterion=WeightImportance_pruning-limit=0.0_prune-freq=1_prune-delay=0_rewind-to=0_train-scheme=DefaultTrainer_seed=1234/output/scores


In [15]:
orig_grads = criterion.grads_abs

In [16]:
orig_mean = criterion.scores_mean
orig_std = criterion.scores_std
layer_names = list(orig_grads.keys())

In [17]:
for name, val in orig_grads.items():
    orig_grads[name] = (val - orig_mean[name]) / (1e-8 + orig_std[name])

In [18]:
# backup_model = deepcopy(model)

In [19]:
arguments['batch_size'] = 1

# load data
train_loader, test_loader = find_right_model(
    DATASETS, arguments['data_set'],
    arguments=arguments
)

# load OOD data
_, ood_loader = find_right_model(
    DATASETS, arguments['ood_data_set'],
    arguments=arguments
)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/train_32x32.mat
Using downloaded and verified file: /nfs/students/ayle/guided-research/gitignored/data/test_32x32.mat


In [20]:
norms = []
per_layer_norms = []
for i, (x, y) in enumerate(tqdm(test_loader)):
    if i == int(len(test_loader)*0.1): break
        
    train_x, train_y = next(iter(train_loader))
    
    model = deepcopy(backup_model)
    model.eval()
        
    batch_x = torch.cat((train_x, x))
    batch_y = torch.cat((train_y, y))
    out = model(batch_x.cuda())
    preds = out.argmax(dim=-1, keepdim=True).view_as(batch_y)
        
    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )
    
    criterion.prune(arguments['pruning_limit'],
                        train_loader=[(batch_x, preds)],
                      ood_loader=None,
                      local=arguments['local_pruning'],
                      manager=None)
    
    layer_norms = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        grad3 = (grad2 - orig_mean[layer_names[j]]) /  (1e-8 + orig_std[layer_names[j]])
        layer_norms.append(torch.norm(grad1 - grad3, p=2).cpu().detach().numpy())

    train_x2, train_y2 = next(iter(train_loader))
    
    batch_x = torch.cat((train_x, train_x2))
    batch_y = torch.cat((train_y, train_y2))
    out = model(batch_x.cuda())
    preds = out.argmax(dim=-1, keepdim=True).view_as(batch_y)
        
    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )
    
    criterion.prune(arguments['pruning_limit'],
                        train_loader=[(batch_x, preds)],
                      ood_loader=None,
                      local=arguments['local_pruning'],
                      manager=None)
    
    layer_norms_train = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        grad3 = (grad2 - orig_mean[layer_names[j]]) /  (1e-8 + orig_std[layer_names[j]])
        layer_norms_train.append(torch.norm(grad1 - grad3, p=2).cpu().detach().numpy())
    
    norms.append(np.mean(layer_norms) - np.mean(layer_norms_train))

  0%|          | 0/10000 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 66.84it/s]

100%|██████████| 1/1 [00:00<00:00, 112.64it/s]
  0%|          | 1/10000 [00:01<4:09:37,  1.50s/it]
100%|██████████| 1/1 [00:00<00:00, 118.46it/s]

100%|██████████| 1/1 [00:00<00:00, 114.34it/s]
  0%|          | 2/10000 [00:02<3:39:43,  1.32s/it]
100%|██████████| 1/1 [00:00<00:00, 124.79it/s]

100%|██████████| 1/1 [00:00<00:00, 116.44it/s]
  0%|          | 3/10000 [00:03<3:29:47,  1.26s/it]
100%|██████████| 1/1 [00:00<00:00, 125.90it/s]

100%|██████████| 1/1 [00:00<00:00, 114.30it/s]
  0%|          | 4/10000 [00:05<3:25:26,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00, 125.70it/s]

100%|██████████| 1/1 [00:00<00:00, 111.15it/s]
  0%|          | 5/10000 [00:06<3:22:09,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.78it/s]

100%|██████████| 1/1 [00:00<00:00, 116.19it/s]
  0%|          | 6/10000 [00:07<3:21:17,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.43it/s]

100%|██████████| 1/1 [00:00<0

100%|██████████| 1/1 [00:00<00:00, 125.47it/s]

100%|██████████| 1/1 [00:00<00:00, 117.55it/s]
  1%|          | 56/10000 [01:07<3:19:51,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.70it/s]

100%|██████████| 1/1 [00:00<00:00, 117.12it/s]
  1%|          | 57/10000 [01:08<3:19:42,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.42it/s]

100%|██████████| 1/1 [00:00<00:00, 117.97it/s]
  1%|          | 58/10000 [01:09<3:19:33,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.39it/s]

100%|██████████| 1/1 [00:00<00:00, 117.82it/s]
  1%|          | 59/10000 [01:11<3:19:55,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.36it/s]

100%|██████████| 1/1 [00:00<00:00, 115.39it/s]
  1%|          | 60/10000 [01:12<3:19:51,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.15it/s]

100%|██████████| 1/1 [00:00<00:00, 117.16it/s]
  1%|          | 61/10000 [01:13<3:18:51,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.79it/s]

100%|██████████| 1/1 [00:00<00:00, 117.75it/s]
  1%|          | 

100%|██████████| 1/1 [00:00<00:00, 125.71it/s]

100%|██████████| 1/1 [00:00<00:00, 118.37it/s]
  1%|          | 111/10000 [02:13<3:17:50,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.26it/s]

100%|██████████| 1/1 [00:00<00:00, 117.93it/s]
  1%|          | 112/10000 [02:14<3:17:23,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.99it/s]

100%|██████████| 1/1 [00:00<00:00, 115.82it/s]
  1%|          | 113/10000 [02:16<3:17:18,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.21it/s]

100%|██████████| 1/1 [00:00<00:00, 117.58it/s]
  1%|          | 114/10000 [02:17<3:17:06,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.59it/s]

100%|██████████| 1/1 [00:00<00:00, 117.67it/s]
  1%|          | 115/10000 [02:18<3:17:24,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 126.18it/s]

100%|██████████| 1/1 [00:00<00:00, 114.28it/s]
  1%|          | 116/10000 [02:19<3:17:28,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.12it/s]

100%|██████████| 1/1 [00:00<00:00, 118.35it/s]
  1%|      

  2%|▏         | 165/10000 [03:18<3:17:26,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.35it/s]

100%|██████████| 1/1 [00:00<00:00, 116.49it/s]
  2%|▏         | 166/10000 [03:19<3:20:23,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 125.05it/s]

100%|██████████| 1/1 [00:00<00:00, 115.52it/s]
  2%|▏         | 167/10000 [03:21<3:19:34,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 124.62it/s]

100%|██████████| 1/1 [00:00<00:00, 115.77it/s]
  2%|▏         | 168/10000 [03:22<3:18:44,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.93it/s]

100%|██████████| 1/1 [00:00<00:00, 118.23it/s]
  2%|▏         | 169/10000 [03:23<3:18:21,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.18it/s]

100%|██████████| 1/1 [00:00<00:00, 115.43it/s]
  2%|▏         | 170/10000 [03:24<3:18:04,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.78it/s]

100%|██████████| 1/1 [00:00<00:00, 113.26it/s]
  2%|▏         | 171/10000 [03:25<3:17:49,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.73it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 117.61it/s]
  2%|▏         | 220/10000 [04:24<3:15:28,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.46it/s]

100%|██████████| 1/1 [00:00<00:00, 117.71it/s]
  2%|▏         | 221/10000 [04:25<3:14:37,  1.19s/it]
100%|██████████| 1/1 [00:00<00:00, 125.12it/s]

100%|██████████| 1/1 [00:00<00:00, 118.81it/s]
  2%|▏         | 222/10000 [04:27<3:14:40,  1.19s/it]
100%|██████████| 1/1 [00:00<00:00, 124.60it/s]

100%|██████████| 1/1 [00:00<00:00, 118.17it/s]
  2%|▏         | 223/10000 [04:28<3:15:12,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 121.04it/s]

100%|██████████| 1/1 [00:00<00:00, 114.11it/s]
  2%|▏         | 224/10000 [04:29<3:15:39,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.24it/s]

100%|██████████| 1/1 [00:00<00:00, 115.51it/s]
  2%|▏         | 225/10000 [04:30<3:15:41,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.67it/s]

100%|██████████| 1/1 [00:00<00:00, 116.98it/s]
  2%|▏         | 226/10000 [04:31<3:15:23,  1.20s/it]
100%|

100%|██████████| 1/1 [00:00<00:00, 122.88it/s]

100%|██████████| 1/1 [00:00<00:00, 118.51it/s]
  3%|▎         | 275/10000 [05:31<3:14:56,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.80it/s]

100%|██████████| 1/1 [00:00<00:00, 116.17it/s]
  3%|▎         | 276/10000 [05:32<3:14:41,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.16it/s]

100%|██████████| 1/1 [00:00<00:00, 115.02it/s]
  3%|▎         | 277/10000 [05:33<3:14:56,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.32it/s]

100%|██████████| 1/1 [00:00<00:00, 115.42it/s]
  3%|▎         | 278/10000 [05:34<3:14:53,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.09it/s]

100%|██████████| 1/1 [00:00<00:00, 117.53it/s]
  3%|▎         | 279/10000 [05:35<3:14:42,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 119.63it/s]

100%|██████████| 1/1 [00:00<00:00, 117.51it/s]
  3%|▎         | 280/10000 [05:37<3:14:27,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.49it/s]

100%|██████████| 1/1 [00:00<00:00, 112.64it/s]
  3%|▎     

  3%|▎         | 329/10000 [06:37<3:12:01,  1.19s/it]
100%|██████████| 1/1 [00:00<00:00, 125.02it/s]

100%|██████████| 1/1 [00:00<00:00, 115.56it/s]
  3%|▎         | 330/10000 [06:38<3:12:53,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.76it/s]

100%|██████████| 1/1 [00:00<00:00, 118.61it/s]
  3%|▎         | 331/10000 [06:39<3:13:02,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.52it/s]

100%|██████████| 1/1 [00:00<00:00, 117.73it/s]
  3%|▎         | 332/10000 [06:41<3:12:39,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.32it/s]

100%|██████████| 1/1 [00:00<00:00, 116.11it/s]
  3%|▎         | 333/10000 [06:42<3:12:52,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.41it/s]

100%|██████████| 1/1 [00:00<00:00, 117.38it/s]
  3%|▎         | 334/10000 [06:43<3:13:14,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.22it/s]

100%|██████████| 1/1 [00:00<00:00, 117.85it/s]
  3%|▎         | 335/10000 [06:44<3:13:09,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.24it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 117.83it/s]
  4%|▍         | 384/10000 [07:43<3:12:36,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.35it/s]

100%|██████████| 1/1 [00:00<00:00, 117.04it/s]
  4%|▍         | 385/10000 [07:44<3:13:24,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.38it/s]

100%|██████████| 1/1 [00:00<00:00, 116.22it/s]
  4%|▍         | 386/10000 [07:46<3:13:38,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.39it/s]

100%|██████████| 1/1 [00:00<00:00, 111.21it/s]
  4%|▍         | 387/10000 [07:47<3:13:53,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 117.56it/s]

100%|██████████| 1/1 [00:00<00:00, 115.25it/s]
  4%|▍         | 388/10000 [07:48<3:12:55,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.11it/s]

100%|██████████| 1/1 [00:00<00:00, 114.07it/s]
  4%|▍         | 389/10000 [07:49<3:13:16,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.60it/s]

100%|██████████| 1/1 [00:00<00:00, 118.28it/s]
  4%|▍         | 390/10000 [07:50<3:12:14,  1.20s/it]
100%|

100%|██████████| 1/1 [00:00<00:00, 118.03it/s]
  4%|▍         | 439/10000 [08:49<3:11:44,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.26it/s]

100%|██████████| 1/1 [00:00<00:00, 117.61it/s]
  4%|▍         | 440/10000 [08:51<3:11:33,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.61it/s]

100%|██████████| 1/1 [00:00<00:00, 115.66it/s]
  4%|▍         | 441/10000 [08:52<3:11:16,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.98it/s]

100%|██████████| 1/1 [00:00<00:00, 114.96it/s]
  4%|▍         | 442/10000 [08:53<3:11:44,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 118.61it/s]

100%|██████████| 1/1 [00:00<00:00, 116.32it/s]
  4%|▍         | 443/10000 [08:54<3:12:06,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.18it/s]

100%|██████████| 1/1 [00:00<00:00, 115.69it/s]
  4%|▍         | 444/10000 [08:56<3:12:05,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.18it/s]

100%|██████████| 1/1 [00:00<00:00, 115.42it/s]
  4%|▍         | 445/10000 [08:57<3:11:54,  1.21s/it]
100%|

100%|██████████| 1/1 [00:00<00:00, 123.86it/s]

100%|██████████| 1/1 [00:00<00:00, 115.55it/s]
  5%|▍         | 494/10000 [09:56<3:10:19,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 119.86it/s]

100%|██████████| 1/1 [00:00<00:00, 116.49it/s]
  5%|▍         | 495/10000 [09:57<3:10:14,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.15it/s]

100%|██████████| 1/1 [00:00<00:00, 115.74it/s]
  5%|▍         | 496/10000 [09:58<3:10:28,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.66it/s]

100%|██████████| 1/1 [00:00<00:00, 117.88it/s]
  5%|▍         | 497/10000 [09:59<3:10:50,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.52it/s]

100%|██████████| 1/1 [00:00<00:00, 115.60it/s]
  5%|▍         | 498/10000 [10:01<3:11:13,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.09it/s]

100%|██████████| 1/1 [00:00<00:00, 117.88it/s]
  5%|▍         | 499/10000 [10:02<3:11:21,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.71it/s]

100%|██████████| 1/1 [00:00<00:00, 117.87it/s]
  5%|▌     

  5%|▌         | 548/10000 [11:01<3:10:05,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.93it/s]

100%|██████████| 1/1 [00:00<00:00, 114.83it/s]
  5%|▌         | 549/10000 [11:02<3:10:52,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.03it/s]

100%|██████████| 1/1 [00:00<00:00, 115.29it/s]
  6%|▌         | 550/10000 [11:03<3:09:54,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.95it/s]

100%|██████████| 1/1 [00:00<00:00, 115.99it/s]
  6%|▌         | 551/10000 [11:05<3:10:03,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.83it/s]

100%|██████████| 1/1 [00:00<00:00, 113.45it/s]
  6%|▌         | 552/10000 [11:06<3:10:01,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 119.86it/s]

100%|██████████| 1/1 [00:00<00:00, 117.37it/s]
  6%|▌         | 553/10000 [11:07<3:10:07,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 119.81it/s]

100%|██████████| 1/1 [00:00<00:00, 113.61it/s]
  6%|▌         | 554/10000 [11:08<3:10:01,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.57it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 122.72it/s]

100%|██████████| 1/1 [00:00<00:00, 117.16it/s]
  6%|▌         | 599/10000 [12:03<3:09:19,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.72it/s]

100%|██████████| 1/1 [00:00<00:00, 114.45it/s]
  6%|▌         | 600/10000 [12:04<3:09:22,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 117.97it/s]

100%|██████████| 1/1 [00:00<00:00, 114.68it/s]
  6%|▌         | 601/10000 [12:05<3:09:50,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.62it/s]

100%|██████████| 1/1 [00:00<00:00, 114.66it/s]
  6%|▌         | 602/10000 [12:06<3:09:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.88it/s]

100%|██████████| 1/1 [00:00<00:00, 116.11it/s]
  6%|▌         | 603/10000 [12:07<3:08:31,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 121.16it/s]

100%|██████████| 1/1 [00:00<00:00, 116.40it/s]
  6%|▌         | 604/10000 [12:09<3:08:13,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.41it/s]

100%|██████████| 1/1 [00:00<00:00, 115.19it/s]
  6%|▌     

  7%|▋         | 653/10000 [13:08<3:07:24,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.69it/s]

100%|██████████| 1/1 [00:00<00:00, 117.48it/s]
  7%|▋         | 654/10000 [13:09<3:06:48,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.43it/s]

100%|██████████| 1/1 [00:00<00:00, 118.68it/s]
  7%|▋         | 655/10000 [13:10<3:06:41,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 122.44it/s]

100%|██████████| 1/1 [00:00<00:00, 118.23it/s]
  7%|▋         | 656/10000 [13:11<3:07:17,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 122.95it/s]

100%|██████████| 1/1 [00:00<00:00, 115.65it/s]
  7%|▋         | 657/10000 [13:13<3:08:14,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.21it/s]

100%|██████████| 1/1 [00:00<00:00, 115.20it/s]
  7%|▋         | 658/10000 [13:14<3:07:46,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.64it/s]

100%|██████████| 1/1 [00:00<00:00, 116.47it/s]
  7%|▋         | 659/10000 [13:15<3:08:25,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 120.87it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 116.67it/s]
  7%|▋         | 708/10000 [14:14<3:06:24,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.37it/s]

100%|██████████| 1/1 [00:00<00:00, 118.37it/s]
  7%|▋         | 709/10000 [14:15<3:05:59,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.80it/s]

100%|██████████| 1/1 [00:00<00:00, 116.16it/s]
  7%|▋         | 710/10000 [14:17<3:06:16,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.17it/s]

100%|██████████| 1/1 [00:00<00:00, 117.13it/s]
  7%|▋         | 711/10000 [14:18<3:06:41,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.15it/s]

100%|██████████| 1/1 [00:00<00:00, 115.61it/s]
  7%|▋         | 712/10000 [14:19<3:06:19,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.24it/s]

100%|██████████| 1/1 [00:00<00:00, 116.60it/s]
  7%|▋         | 713/10000 [14:20<3:08:06,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 123.92it/s]

100%|██████████| 1/1 [00:00<00:00, 117.82it/s]
  7%|▋         | 714/10000 [14:21<3:07:55,  1.21s/it]
100%|


100%|██████████| 1/1 [00:00<00:00, 115.41it/s]
  8%|▊         | 763/10000 [15:21<3:08:46,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00, 125.52it/s]

100%|██████████| 1/1 [00:00<00:00, 117.66it/s]
  8%|▊         | 764/10000 [15:22<3:07:14,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 123.07it/s]

100%|██████████| 1/1 [00:00<00:00, 117.89it/s]
  8%|▊         | 765/10000 [15:23<3:06:29,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.68it/s]

100%|██████████| 1/1 [00:00<00:00, 116.30it/s]
  8%|▊         | 766/10000 [15:24<3:06:47,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.06it/s]

100%|██████████| 1/1 [00:00<00:00, 117.59it/s]
  8%|▊         | 767/10000 [15:25<3:05:57,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.54it/s]

100%|██████████| 1/1 [00:00<00:00, 118.79it/s]
  8%|▊         | 768/10000 [15:27<3:29:39,  1.36s/it]
100%|██████████| 1/1 [00:00<00:00, 126.34it/s]

100%|██████████| 1/1 [00:00<00:00, 115.09it/s]
  8%|▊         | 769/10000 [15:30<4:17:17,  1.67s/it]
100%

100%|██████████| 1/1 [00:00<00:00, 122.83it/s]

100%|██████████| 1/1 [00:00<00:00, 116.25it/s]
  8%|▊         | 818/10000 [16:29<3:04:01,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 122.84it/s]

100%|██████████| 1/1 [00:00<00:00, 114.96it/s]
  8%|▊         | 819/10000 [16:30<3:04:11,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.02it/s]

100%|██████████| 1/1 [00:00<00:00, 118.50it/s]
  8%|▊         | 820/10000 [16:31<3:03:39,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 122.23it/s]

100%|██████████| 1/1 [00:00<00:00, 116.67it/s]
  8%|▊         | 821/10000 [16:32<3:03:33,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.01it/s]

100%|██████████| 1/1 [00:00<00:00, 116.03it/s]
  8%|▊         | 822/10000 [16:33<3:03:40,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.90it/s]

100%|██████████| 1/1 [00:00<00:00, 116.74it/s]
  8%|▊         | 823/10000 [16:35<3:03:38,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.27it/s]

100%|██████████| 1/1 [00:00<00:00, 117.62it/s]
  8%|▊     

  9%|▊         | 872/10000 [17:34<3:04:34,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.10it/s]

100%|██████████| 1/1 [00:00<00:00, 116.26it/s]
  9%|▊         | 873/10000 [17:35<3:04:30,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.88it/s]

100%|██████████| 1/1 [00:00<00:00, 115.46it/s]
  9%|▊         | 874/10000 [17:36<3:04:10,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.89it/s]

100%|██████████| 1/1 [00:00<00:00, 117.44it/s]
  9%|▉         | 875/10000 [17:37<3:04:02,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.26it/s]

100%|██████████| 1/1 [00:00<00:00, 117.28it/s]
  9%|▉         | 876/10000 [17:39<3:03:46,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.91it/s]

100%|██████████| 1/1 [00:00<00:00, 114.46it/s]
  9%|▉         | 877/10000 [17:40<3:03:47,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.12it/s]

100%|██████████| 1/1 [00:00<00:00, 117.69it/s]
  9%|▉         | 878/10000 [17:41<3:03:13,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.39it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 118.10it/s]
  9%|▉         | 927/10000 [18:40<3:01:52,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.69it/s]

100%|██████████| 1/1 [00:00<00:00, 116.52it/s]
  9%|▉         | 928/10000 [18:41<3:01:47,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 126.40it/s]

100%|██████████| 1/1 [00:00<00:00, 117.27it/s]
  9%|▉         | 929/10000 [18:42<3:01:53,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.80it/s]

100%|██████████| 1/1 [00:00<00:00, 116.62it/s]
  9%|▉         | 930/10000 [18:44<3:01:47,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.57it/s]

100%|██████████| 1/1 [00:00<00:00, 116.25it/s]
  9%|▉         | 931/10000 [18:45<3:01:56,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 117.56it/s]

100%|██████████| 1/1 [00:00<00:00, 116.88it/s]
  9%|▉         | 932/10000 [18:46<3:02:05,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 120.92it/s]

100%|██████████| 1/1 [00:00<00:00, 117.58it/s]
  9%|▉         | 933/10000 [18:47<3:02:53,  1.21s/it]
100%|

100%|██████████| 1/1 [00:00<00:00, 118.38it/s]
 10%|▉         | 982/10000 [19:46<3:01:44,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 120.66it/s]

100%|██████████| 1/1 [00:00<00:00, 116.40it/s]
 10%|▉         | 983/10000 [19:48<3:02:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.22it/s]

100%|██████████| 1/1 [00:00<00:00, 117.76it/s]
 10%|▉         | 984/10000 [19:49<3:01:39,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.72it/s]

100%|██████████| 1/1 [00:00<00:00, 116.73it/s]
 10%|▉         | 985/10000 [19:50<3:01:08,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.38it/s]

100%|██████████| 1/1 [00:00<00:00, 117.22it/s]
 10%|▉         | 986/10000 [19:51<3:01:35,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.46it/s]

100%|██████████| 1/1 [00:00<00:00, 116.63it/s]
 10%|▉         | 987/10000 [19:53<3:00:58,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.92it/s]

100%|██████████| 1/1 [00:00<00:00, 116.75it/s]
 10%|▉         | 988/10000 [19:54<3:00:56,  1.20s/it]
100%|

In [21]:
np.mean(norms)

-32948.875

In [None]:
# Load CIFAR100 OOD data
_, ood_loader = find_right_model(
    DATASETS, "CIFAR100",
    arguments=arguments
)

In [22]:
# OOD 

ood_norms = []
for i, (x, y) in enumerate(tqdm(ood_loader)):
    if i == int(len(test_loader)*0.1): break
        
    train_x, train_y = next(iter(train_loader))
    
    model = deepcopy(backup_model)
    model.eval()
        
    batch_x = torch.cat((train_x, x))
    batch_y = torch.cat((train_y, y))
    out = model(batch_x.cuda())
    preds = out.argmax(dim=-1, keepdim=True).view_as(batch_y)
        
    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )
    
    criterion.prune(arguments['pruning_limit'],
                        train_loader=[(batch_x, preds)],
                      ood_loader=None,
                      local=arguments['local_pruning'],
                      manager=None)
    
    layer_norms = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        grad3 = (grad2 - orig_mean[layer_names[j]]) /  (1e-8 + orig_std[layer_names[j]])
        layer_norms.append(torch.norm(grad1 - grad3, p=2).cpu().detach().numpy())

    train_x2, train_y2 = next(iter(train_loader))
    
    batch_x = torch.cat((train_x, train_x2))
    batch_y = torch.cat((train_y, train_y2))
    out = model(batch_x.cuda())
    preds = out.argmax(dim=-1, keepdim=True).view_as(batch_y)
        
    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )
    
    criterion.prune(arguments['pruning_limit'],
                        train_loader=[(batch_x, preds)],
                      ood_loader=None,
                      local=arguments['local_pruning'],
                      manager=None)
    
    layer_norms_train = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        grad3 = (grad2 - orig_mean[layer_names[j]]) /  (1e-8 + orig_std[layer_names[j]])
        layer_norms_train.append(torch.norm(grad1 - grad3, p=2).cpu().detach().numpy())
    
    ood_norms.append(np.mean(layer_norms) - np.mean(layer_norms_train))

  0%|          | 0/26032 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 117.61it/s]

100%|██████████| 1/1 [00:00<00:00, 110.57it/s]
  0%|          | 1/26032 [00:01<10:48:06,  1.49s/it]
100%|██████████| 1/1 [00:00<00:00, 115.35it/s]

100%|██████████| 1/1 [00:00<00:00, 106.63it/s]
  0%|          | 2/26032 [00:02<9:45:29,  1.35s/it] 
100%|██████████| 1/1 [00:00<00:00, 125.99it/s]

100%|██████████| 1/1 [00:00<00:00, 109.35it/s]
  0%|          | 3/26032 [00:03<9:23:51,  1.30s/it]
100%|██████████| 1/1 [00:00<00:00, 121.39it/s]

100%|██████████| 1/1 [00:00<00:00, 111.44it/s]
  0%|          | 4/26032 [00:05<9:12:43,  1.27s/it]
100%|██████████| 1/1 [00:00<00:00, 121.40it/s]

100%|██████████| 1/1 [00:00<00:00, 117.06it/s]
  0%|          | 5/26032 [00:06<9:01:35,  1.25s/it]
100%|██████████| 1/1 [00:00<00:00, 125.10it/s]

100%|██████████| 1/1 [00:00<00:00, 118.32it/s]
  0%|          | 6/26032 [00:07<8:53:20,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00, 123.61it/s]

100%|██████████| 1/1 [00:0

100%|██████████| 1/1 [00:00<00:00, 123.79it/s]

100%|██████████| 1/1 [00:00<00:00, 115.21it/s]
  0%|          | 56/26032 [01:08<8:42:54,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 121.46it/s]

100%|██████████| 1/1 [00:00<00:00, 114.16it/s]
  0%|          | 57/26032 [01:09<8:42:25,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.87it/s]

100%|██████████| 1/1 [00:00<00:00, 109.23it/s]
  0%|          | 58/26032 [01:10<8:44:31,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.03it/s]

100%|██████████| 1/1 [00:00<00:00, 117.67it/s]
  0%|          | 59/26032 [01:12<8:52:07,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00, 125.41it/s]

100%|██████████| 1/1 [00:00<00:00, 114.45it/s]
  0%|          | 60/26032 [01:13<8:49:31,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 126.02it/s]

100%|██████████| 1/1 [00:00<00:00, 116.31it/s]
  0%|          | 61/26032 [01:14<8:46:48,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 124.29it/s]

100%|██████████| 1/1 [00:00<00:00, 114.27it/s]
  0%|          | 

100%|██████████| 1/1 [00:00<00:00, 113.77it/s]

100%|██████████| 1/1 [00:00<00:00, 115.40it/s]
  0%|          | 111/26032 [02:15<8:42:04,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.67it/s]

100%|██████████| 1/1 [00:00<00:00, 115.42it/s]
  0%|          | 112/26032 [02:16<8:41:47,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.87it/s]

100%|██████████| 1/1 [00:00<00:00, 116.76it/s]
  0%|          | 113/26032 [02:17<8:41:41,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 115.12it/s]

100%|██████████| 1/1 [00:00<00:00, 115.40it/s]
  0%|          | 114/26032 [02:18<8:40:42,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.87it/s]

100%|██████████| 1/1 [00:00<00:00, 112.22it/s]
  0%|          | 115/26032 [02:19<8:41:32,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.47it/s]

100%|██████████| 1/1 [00:00<00:00, 114.39it/s]
  0%|          | 116/26032 [02:21<8:40:13,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.20it/s]

100%|██████████| 1/1 [00:00<00:00, 115.13it/s]
  0%|      

  1%|          | 165/26032 [03:20<8:42:07,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.58it/s]

100%|██████████| 1/1 [00:00<00:00, 115.71it/s]
  1%|          | 166/26032 [03:21<8:42:20,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.06it/s]

100%|██████████| 1/1 [00:00<00:00, 112.79it/s]
  1%|          | 167/26032 [03:23<8:41:33,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 118.06it/s]

100%|██████████| 1/1 [00:00<00:00, 115.04it/s]
  1%|          | 168/26032 [03:24<8:42:21,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.57it/s]

100%|██████████| 1/1 [00:00<00:00, 115.22it/s]
  1%|          | 169/26032 [03:25<8:43:40,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.66it/s]

100%|██████████| 1/1 [00:00<00:00, 116.35it/s]
  1%|          | 170/26032 [03:26<8:42:43,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 117.52it/s]

100%|██████████| 1/1 [00:00<00:00, 112.23it/s]
  1%|          | 171/26032 [03:27<8:42:39,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.29it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 113.50it/s]
  1%|          | 220/26032 [04:29<12:43:34,  1.77s/it]
100%|██████████| 1/1 [00:00<00:00, 125.38it/s]

100%|██████████| 1/1 [00:00<00:00, 117.05it/s]
  1%|          | 221/26032 [04:30<11:30:04,  1.60s/it]
100%|██████████| 1/1 [00:00<00:00, 121.37it/s]

100%|██████████| 1/1 [00:00<00:00, 116.36it/s]
  1%|          | 222/26032 [04:31<10:38:27,  1.48s/it]
100%|██████████| 1/1 [00:00<00:00, 124.63it/s]

100%|██████████| 1/1 [00:00<00:00, 116.74it/s]
  1%|          | 223/26032 [04:33<10:02:37,  1.40s/it]
100%|██████████| 1/1 [00:00<00:00, 125.48it/s]

100%|██████████| 1/1 [00:00<00:00, 115.86it/s]
  1%|          | 224/26032 [04:34<9:34:24,  1.34s/it] 
100%|██████████| 1/1 [00:00<00:00, 122.08it/s]

100%|██████████| 1/1 [00:00<00:00, 114.92it/s]
  1%|          | 225/26032 [04:35<9:18:09,  1.30s/it]
100%|██████████| 1/1 [00:00<00:00, 121.00it/s]

100%|██████████| 1/1 [00:00<00:00, 115.60it/s]
  1%|          | 226/26032 [04:36<9:07:22,  1.27s/it]


100%|██████████| 1/1 [00:00<00:00, 125.65it/s]

100%|██████████| 1/1 [00:00<00:00, 115.74it/s]
  1%|          | 275/26032 [05:35<8:35:05,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.33it/s]

100%|██████████| 1/1 [00:00<00:00, 117.23it/s]
  1%|          | 276/26032 [05:36<8:34:58,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.98it/s]

100%|██████████| 1/1 [00:00<00:00, 117.55it/s]
  1%|          | 277/26032 [05:38<8:37:25,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.31it/s]

100%|██████████| 1/1 [00:00<00:00, 114.44it/s]
  1%|          | 278/26032 [05:39<8:38:38,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.76it/s]

100%|██████████| 1/1 [00:00<00:00, 115.73it/s]
  1%|          | 279/26032 [05:40<8:37:35,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.72it/s]

100%|██████████| 1/1 [00:00<00:00, 115.19it/s]
  1%|          | 280/26032 [05:41<8:37:31,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.63it/s]

100%|██████████| 1/1 [00:00<00:00, 115.72it/s]
  1%|      

  1%|▏         | 329/26032 [06:40<8:36:23,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.02it/s]

100%|██████████| 1/1 [00:00<00:00, 116.28it/s]
  1%|▏         | 330/26032 [06:41<8:38:19,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.12it/s]

100%|██████████| 1/1 [00:00<00:00, 113.68it/s]
  1%|▏         | 331/26032 [06:43<8:37:53,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.39it/s]

100%|██████████| 1/1 [00:00<00:00, 114.52it/s]
  1%|▏         | 332/26032 [06:44<8:36:44,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.42it/s]

100%|██████████| 1/1 [00:00<00:00, 116.80it/s]
  1%|▏         | 333/26032 [06:45<8:37:02,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.75it/s]

100%|██████████| 1/1 [00:00<00:00, 117.69it/s]
  1%|▏         | 334/26032 [06:46<8:37:41,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.33it/s]

100%|██████████| 1/1 [00:00<00:00, 114.27it/s]
  1%|▏         | 335/26032 [06:47<8:36:50,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.87it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 116.00it/s]
  1%|▏         | 384/26032 [07:47<8:34:38,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.28it/s]

100%|██████████| 1/1 [00:00<00:00, 115.68it/s]
  1%|▏         | 385/26032 [07:48<8:35:33,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.47it/s]

100%|██████████| 1/1 [00:00<00:00, 117.95it/s]
  1%|▏         | 386/26032 [07:49<8:35:59,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.29it/s]

100%|██████████| 1/1 [00:00<00:00, 117.04it/s]
  1%|▏         | 387/26032 [07:50<8:35:07,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.51it/s]

100%|██████████| 1/1 [00:00<00:00, 114.07it/s]
  1%|▏         | 388/26032 [07:52<8:34:54,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.94it/s]

100%|██████████| 1/1 [00:00<00:00, 116.72it/s]
  1%|▏         | 389/26032 [07:53<8:34:40,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 119.76it/s]

100%|██████████| 1/1 [00:00<00:00, 115.47it/s]
  1%|▏         | 390/26032 [07:54<8:37:10,  1.21s/it]
100%|

100%|██████████| 1/1 [00:00<00:00, 125.10it/s]

100%|██████████| 1/1 [00:00<00:00, 118.15it/s]
  2%|▏         | 439/26032 [08:53<8:34:58,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.16it/s]

100%|██████████| 1/1 [00:00<00:00, 117.50it/s]
  2%|▏         | 440/26032 [08:54<8:35:05,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.67it/s]

100%|██████████| 1/1 [00:00<00:00, 117.30it/s]
  2%|▏         | 441/26032 [08:56<8:34:35,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.76it/s]

100%|██████████| 1/1 [00:00<00:00, 115.14it/s]
  2%|▏         | 442/26032 [08:57<8:34:37,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.77it/s]

100%|██████████| 1/1 [00:00<00:00, 116.90it/s]
  2%|▏         | 443/26032 [08:58<8:36:13,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.58it/s]

100%|██████████| 1/1 [00:00<00:00, 116.80it/s]
  2%|▏         | 444/26032 [08:59<8:36:37,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.86it/s]

100%|██████████| 1/1 [00:00<00:00, 116.25it/s]
  2%|▏     

  2%|▏         | 493/26032 [09:59<8:35:05,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.67it/s]

100%|██████████| 1/1 [00:00<00:00, 114.47it/s]
  2%|▏         | 494/26032 [10:00<8:35:20,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.10it/s]

100%|██████████| 1/1 [00:00<00:00, 111.34it/s]
  2%|▏         | 495/26032 [10:01<8:34:34,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.48it/s]

100%|██████████| 1/1 [00:00<00:00, 114.68it/s]
  2%|▏         | 496/26032 [10:02<8:32:26,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 120.86it/s]

100%|██████████| 1/1 [00:00<00:00, 116.06it/s]
  2%|▏         | 497/26032 [10:03<8:33:38,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.12it/s]

100%|██████████| 1/1 [00:00<00:00, 117.70it/s]
  2%|▏         | 498/26032 [10:05<8:34:30,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.73it/s]

100%|██████████| 1/1 [00:00<00:00, 117.20it/s]
  2%|▏         | 499/26032 [10:06<8:37:03,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 124.28it/s]

100%

100%|██████████| 1/1 [00:00<00:00, 117.33it/s]
  2%|▏         | 548/26032 [11:05<8:32:42,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.63it/s]

100%|██████████| 1/1 [00:00<00:00, 115.58it/s]
  2%|▏         | 549/26032 [11:06<8:32:44,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.88it/s]

100%|██████████| 1/1 [00:00<00:00, 116.72it/s]
  2%|▏         | 550/26032 [11:08<8:32:10,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.27it/s]

100%|██████████| 1/1 [00:00<00:00, 116.55it/s]
  2%|▏         | 551/26032 [11:09<8:33:17,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.79it/s]

100%|██████████| 1/1 [00:00<00:00, 116.96it/s]
  2%|▏         | 552/26032 [11:10<8:32:08,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.55it/s]

100%|██████████| 1/1 [00:00<00:00, 92.07it/s]
  2%|▏         | 553/26032 [11:11<8:35:50,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.27it/s]

100%|██████████| 1/1 [00:00<00:00, 117.24it/s]
  2%|▏         | 554/26032 [11:13<8:41:59,  1.23s/it]
100%|█


100%|██████████| 1/1 [00:00<00:00, 117.04it/s]
  2%|▏         | 603/26032 [12:12<8:34:13,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 119.23it/s]

100%|██████████| 1/1 [00:00<00:00, 115.90it/s]
  2%|▏         | 604/26032 [12:13<8:34:36,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.11it/s]

100%|██████████| 1/1 [00:00<00:00, 116.18it/s]
  2%|▏         | 605/26032 [12:14<8:34:30,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.33it/s]

100%|██████████| 1/1 [00:00<00:00, 116.07it/s]
  2%|▏         | 606/26032 [12:16<8:34:15,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.22it/s]

100%|██████████| 1/1 [00:00<00:00, 117.17it/s]
  2%|▏         | 607/26032 [12:17<8:32:23,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.13it/s]

100%|██████████| 1/1 [00:00<00:00, 117.43it/s]
  2%|▏         | 608/26032 [12:18<8:33:49,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.20it/s]

100%|██████████| 1/1 [00:00<00:00, 117.62it/s]
  2%|▏         | 609/26032 [12:19<8:35:47,  1.22s/it]
100%

100%|██████████| 1/1 [00:00<00:00, 125.36it/s]

100%|██████████| 1/1 [00:00<00:00, 116.71it/s]
  3%|▎         | 658/26032 [13:19<8:33:42,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.72it/s]

100%|██████████| 1/1 [00:00<00:00, 114.61it/s]
  3%|▎         | 659/26032 [13:20<8:34:55,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 125.04it/s]

100%|██████████| 1/1 [00:00<00:00, 117.76it/s]
  3%|▎         | 660/26032 [13:21<8:32:59,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.68it/s]

100%|██████████| 1/1 [00:00<00:00, 115.96it/s]
  3%|▎         | 661/26032 [13:22<8:33:01,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.28it/s]

100%|██████████| 1/1 [00:00<00:00, 116.48it/s]
  3%|▎         | 662/26032 [13:24<8:33:06,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.70it/s]

100%|██████████| 1/1 [00:00<00:00, 114.86it/s]
  3%|▎         | 663/26032 [13:25<8:33:25,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.18it/s]

100%|██████████| 1/1 [00:00<00:00, 117.28it/s]
  3%|▎     

  3%|▎         | 712/26032 [14:26<8:29:02,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 119.99it/s]

100%|██████████| 1/1 [00:00<00:00, 114.03it/s]
  3%|▎         | 713/26032 [14:27<8:29:56,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.01it/s]

100%|██████████| 1/1 [00:00<00:00, 116.97it/s]
  3%|▎         | 714/26032 [14:29<8:27:45,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.90it/s]

100%|██████████| 1/1 [00:00<00:00, 117.34it/s]
  3%|▎         | 715/26032 [14:30<8:28:02,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 123.41it/s]

100%|██████████| 1/1 [00:00<00:00, 116.70it/s]
  3%|▎         | 716/26032 [14:31<8:27:53,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.48it/s]

100%|██████████| 1/1 [00:00<00:00, 116.55it/s]
  3%|▎         | 717/26032 [14:32<8:26:39,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 126.06it/s]

100%|██████████| 1/1 [00:00<00:00, 114.93it/s]
  3%|▎         | 718/26032 [14:33<8:26:59,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 122.13it/s]

100%

  3%|▎         | 767/26032 [15:33<8:29:11,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.59it/s]

100%|██████████| 1/1 [00:00<00:00, 117.45it/s]
  3%|▎         | 768/26032 [15:34<8:30:48,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.24it/s]

100%|██████████| 1/1 [00:00<00:00, 117.80it/s]
  3%|▎         | 769/26032 [15:35<8:30:52,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.54it/s]

100%|██████████| 1/1 [00:00<00:00, 96.66it/s]
  3%|▎         | 770/26032 [15:36<8:31:47,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 124.52it/s]

100%|██████████| 1/1 [00:00<00:00, 116.48it/s]
  3%|▎         | 771/26032 [15:38<8:30:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.90it/s]

100%|██████████| 1/1 [00:00<00:00, 115.90it/s]
  3%|▎         | 772/26032 [15:39<8:29:14,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.66it/s]

100%|██████████| 1/1 [00:00<00:00, 113.33it/s]
  3%|▎         | 773/26032 [15:40<8:29:54,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.78it/s]

100%|

100%|██████████| 1/1 [00:00<00:00, 115.68it/s]
  3%|▎         | 822/26032 [16:39<8:26:32,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.98it/s]

100%|██████████| 1/1 [00:00<00:00, 116.88it/s]
  3%|▎         | 823/26032 [16:41<8:26:15,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.30it/s]

100%|██████████| 1/1 [00:00<00:00, 114.23it/s]
  3%|▎         | 824/26032 [16:42<8:27:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.46it/s]

100%|██████████| 1/1 [00:00<00:00, 116.93it/s]
  3%|▎         | 825/26032 [16:43<8:28:59,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.03it/s]

100%|██████████| 1/1 [00:00<00:00, 116.21it/s]
  3%|▎         | 826/26032 [16:44<8:30:18,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.46it/s]

100%|██████████| 1/1 [00:00<00:00, 117.18it/s]
  3%|▎         | 827/26032 [16:45<8:30:52,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 119.48it/s]

100%|██████████| 1/1 [00:00<00:00, 115.61it/s]
  3%|▎         | 828/26032 [16:47<8:31:47,  1.22s/it]
100%|


100%|██████████| 1/1 [00:00<00:00, 117.47it/s]
  3%|▎         | 877/26032 [17:46<8:28:03,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.85it/s]

100%|██████████| 1/1 [00:00<00:00, 116.82it/s]
  3%|▎         | 878/26032 [17:47<8:28:31,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.22it/s]

100%|██████████| 1/1 [00:00<00:00, 116.05it/s]
  3%|▎         | 879/26032 [17:49<8:29:08,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.39it/s]

100%|██████████| 1/1 [00:00<00:00, 115.85it/s]
  3%|▎         | 880/26032 [17:50<8:26:12,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.00it/s]

100%|██████████| 1/1 [00:00<00:00, 116.40it/s]
  3%|▎         | 881/26032 [17:51<8:28:36,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.10it/s]

100%|██████████| 1/1 [00:00<00:00, 116.54it/s]
  3%|▎         | 882/26032 [17:52<8:25:55,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.09it/s]

100%|██████████| 1/1 [00:00<00:00, 117.04it/s]
  3%|▎         | 883/26032 [17:53<8:25:33,  1.21s/it]
100%

100%|██████████| 1/1 [00:00<00:00, 125.97it/s]

100%|██████████| 1/1 [00:00<00:00, 116.28it/s]
  4%|▎         | 932/26032 [18:53<8:24:04,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 124.23it/s]

100%|██████████| 1/1 [00:00<00:00, 117.63it/s]
  4%|▎         | 933/26032 [18:54<8:26:34,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.08it/s]

100%|██████████| 1/1 [00:00<00:00, 116.81it/s]
  4%|▎         | 934/26032 [18:55<8:26:56,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 123.46it/s]

100%|██████████| 1/1 [00:00<00:00, 117.06it/s]
  4%|▎         | 935/26032 [18:56<8:28:22,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 123.95it/s]

100%|██████████| 1/1 [00:00<00:00, 117.10it/s]
  4%|▎         | 936/26032 [18:58<8:29:44,  1.22s/it]
100%|██████████| 1/1 [00:00<00:00, 124.77it/s]

100%|██████████| 1/1 [00:00<00:00, 117.64it/s]
  4%|▎         | 937/26032 [18:59<8:28:10,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 122.33it/s]

100%|██████████| 1/1 [00:00<00:00, 116.40it/s]
  4%|▎     

  4%|▍         | 986/26032 [19:58<8:24:07,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.25it/s]

100%|██████████| 1/1 [00:00<00:00, 116.19it/s]
  4%|▍         | 987/26032 [20:00<8:24:17,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.15it/s]

100%|██████████| 1/1 [00:00<00:00, 116.75it/s]
  4%|▍         | 988/26032 [20:01<8:23:31,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 124.62it/s]

100%|██████████| 1/1 [00:00<00:00, 116.69it/s]
  4%|▍         | 989/26032 [20:02<8:23:28,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.20it/s]

100%|██████████| 1/1 [00:00<00:00, 118.82it/s]
  4%|▍         | 990/26032 [20:03<8:22:37,  1.20s/it]
100%|██████████| 1/1 [00:00<00:00, 125.07it/s]

100%|██████████| 1/1 [00:00<00:00, 116.87it/s]
  4%|▍         | 991/26032 [20:04<8:23:50,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 126.31it/s]

100%|██████████| 1/1 [00:00<00:00, 117.04it/s]
  4%|▍         | 992/26032 [20:06<8:24:02,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00, 125.41it/s]

100%

In [None]:
# Attacks
from torchvision import datasets
from utils.constants import DATASET_PATH
mean = (0.4914, 0.4822, 0.4465)
std = (0.2471, 0.2435, 0.2616)
test_set = datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=arguments['batch_size'],
        shuffle=False,
        pin_memory=True,
        num_workers=4
    )

trans = transforms.Normalize(mean, std)

# Attacks
ood_norms = []
attack_per_layer_norms = []

for i, (x, y) in enumerate(tqdm(test_loader)):
    
    model = deepcopy(backup_model)
    model.eval()
        
    adv_results, predictions = construct_adversarial_examples(x, y, 'FGSM', model, model.device, 8, False, False)
    _, advs, success = adv_results
    x = advs.cpu()
        
    new_x = []
    for image in x:
        image = trans(image.squeeze()).unsqueeze(0)
        new_x.append(image)
    x = torch.cat(new_x)
        
    model.eval()
    
    out = model(x.cuda())
    preds = out.argmax(dim=-1, keepdim=True).view_as(y)
    
    # get criterion
    criterion = find_right_model(
        CRITERION_DIR, arguments['prune_criterion'],
        model=model,
        limit=arguments['pruning_limit'],
        start=0.5,
        steps=arguments['snip_steps'],
        device=arguments['device'],
        arguments=arguments
    )
    
    criterion.prune(arguments['pruning_limit'],
                    train_loader=[(x, preds)],
                      ood_loader=None,
                      local=arguments['local_pruning'],
                      manager=None)
    
    layer_norms = []
    for j, (grad1, grad2) in enumerate(zip(orig_grads.values(), criterion.grads_abs.values())):
        if True:
#         if j not in [0, 7, 11, 20]:
            layer_norms.append(torch.norm(grad1 - grad2, p=2).cpu().numpy())
    attack_per_layer_norms.append(layer_norms)
    ood_norms.append(np.mean(layer_norms))

In [None]:
# DS
class DS(Dataset):

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
        self.mean = [0.4914, 0.4822, 0.4465]
        self.std = [0.2471, 0.2435, 0.2616]
        self.transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=self.mean, std=self.std)
            ]
        )

    def __getitem__(self, item):
        image = self.images[item] / 255
        image = self.transforms(image.transpose((1, 2, 0)))
        return image.to(torch.float32), torch.tensor(self.labels[item], dtype=torch.float32)

    def __len__(self):
        return len(self.images)

ds_path = os.path.join(DATASET_PATH, "cifar10_corrupted")
aurocs = []
ds_per_layer_norms = []
        
for ds_dataset_name in os.listdir(ds_path):
    if ds_dataset_name.endswith('5.npz'):
        _, test_loader = find_right_model(
            DATASETS, arguments['data_set'],
            arguments=arguments
        )
        npz_dataset = np.load(os.path.join(ds_path, ds_dataset_name))

        ds_dataset = DS(npz_dataset["images"], npz_dataset["labels"])
        ds_loader = torch.utils.data.DataLoader(
            ds_dataset,
            batch_size=arguments['batch_size'],
            shuffle=False,
            pin_memory=True,
            num_workers=4
        )

        ood_norms = []
        for i, (x, y) in enumerate(ds_loader):
            
            model = deepcopy(backup_model)
            model.eval()
            
            out = model(x.cuda())
            preds = out.argmax(dim=-1, keepdim=True).view_as(y)

            # get criterion
            criterion = find_right_model(
                CRITERION_DIR, arguments['prune_criterion'],
                model=model,
                limit=arguments['pruning_limit'],
                start=0.5,
                steps=arguments['snip_steps'],
                device=arguments['device'],
                arguments=arguments
            )

            criterion.prune(arguments['pruning_limit'],
                              train_loader=[(x, preds)],
                              ood_loader=None,
                              local=arguments['local_pruning'],
                              manager=None)

            layer_norms = []
            for grad1, grad2 in zip(orig_grads.values(), criterion.grads_abs.values()):
                layer_norms.append(torch.norm(grad1 - grad2, p=2).cpu().numpy())
            ds_per_layer_norms.append(layer_norms)
            ood_norms.append(np.mean(layer_norms))

        auroc = calculate_auroc(np.concatenate((np.zeros_like(norms), np.ones_like(ood_norms))), np.concatenate((norms, ood_norms)))
        print('AUROC', auroc)
        aurocs.append(auroc)
        
print('Mean AUROC', np.mean(aurocs))

In [23]:
np.mean(ood_norms)

4990.221

In [24]:
norms = np.array(norms)
ood_norms = np.array(ood_norms)

In [25]:
calculate_auroc(np.concatenate((np.zeros_like(norms), np.ones_like(ood_norms))), np.concatenate((norms, ood_norms)))

0.853601

In [None]:
# 0.9999251152073733 batch size 10 dropout eval model train, not layers 0 3 7
# 0.9999873271889401 batch size 10 model train, not layers 0 3 7
# 0.9999592933947773 batch size 10 model train, all layers

In [None]:
plt.plot(np.array(per_layer_norms).mean(0), label='In-distribution')
plt.plot(np.array(ood_per_layer_norms).mean(0), label='Out-Of-Distribution')
# plt.plot(np.array(attack_per_layer_norms).mean(0), label='Attack (FGSM \u03B5 = 8)')
# plt.plot(np.array(ds_per_layer_norms).mean(0), label='Distribution Shifts')
plt.legend()
plt.xlabel('ResNet18 layers')
plt.ylabel('L2 distance to original scores')

In [None]:
# np.array(ood_per_layer_norms).mean(0)

In [None]:
thresholds, _ = torch.topk(torch.tensor(norms), int(len(norms)*0.05), sorted=True)

In [None]:
threshold = thresholds[-1]

In [None]:
num_detected = 0
for ood_norm in ood_norms:
    if ood_norm >= threshold:
        num_detected += 1

In [None]:
num_detected / len(ood_norms)