In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model, parse_param
import random
device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)


In [5]:
all_param_names = all_param_names[:-2]


In [6]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
for i in range(100):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar100_resnet50,
                      train_dataloader_all, alpha=0.00001, num_steps=4, op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 2/2 [00:03<00:00,  1.93s/it]


0.0031813300848007203


100%|██████████| 2/2 [00:00<00:00,  3.20it/s]


0.01804310417175293


100%|██████████| 2/2 [00:00<00:00,  3.26it/s]


0.2327025146484375


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


1.4608602905273438


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


0.0016323286294937134


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.015500137329101563


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.2118052978515625


100%|██████████| 2/2 [00:00<00:00,  3.23it/s]


1.4502610473632813


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.003915329337120056


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.050074665069580075


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


0.9679737243652343


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


4.147788818359375


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.002457998752593994


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.06849658393859863


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.9536296691894531


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


3.539983154296875


100%|██████████| 2/2 [00:00<00:00,  3.44it/s]


0.0028448475003242492


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.05297870826721191


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.8534326171875


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


3.51308203125


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0020534344911575316


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.03537442111968994


100%|██████████| 2/2 [00:00<00:00,  3.44it/s]


0.5405798797607422


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


2.5403856201171875


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.001525224506855011


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.023830718040466308


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.4486836853027344


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


2.52093115234375


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.0024352328777313232


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.03335247039794922


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.27660484313964845


100%|██████████| 2/2 [00:00<00:00,  3.44it/s]


1.6127510986328124


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.001614146411418915


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


0.009982490539550781


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.15415506744384766


100%|██████████| 2/2 [00:00<00:00,  3.44it/s]


1.1547877197265626


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0018572714924812318


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.016225821495056154


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.1505292510986328


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.7646273193359375


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.001700283944606781


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.03814146995544434


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.4685605010986328


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.4050953369140626


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.003962364077568054


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.07003918647766114


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.6729427185058594


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


3.2234678955078127


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.0021116113662719726


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


0.04582867431640625


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.4266449737548828


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


2.242073486328125


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


0.0022741904854774475


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.03872584533691406


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.49928488159179685


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


2.5383388671875


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0027605905532836913


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.052844209671020506


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.5538495635986328


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


2.6174725341796874


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0018909668326377868


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.03164355182647705


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.5378310852050782


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


2.5711434326171876


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.0015539608597755431


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.02441586685180664


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.3834039764404297


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


1.9401480712890624


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0018331934213638307


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


0.019736997604370116


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.36923162841796875


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


2.361777587890625


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.001789045751094818


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.02667636203765869


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.43800543212890625


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


2.208899658203125


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.002049513578414917


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.05065212821960449


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.8377701416015625


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


3.3496339111328126


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.0027038747072219847


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.03269764709472656


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.2854729461669922


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.6062218017578125


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.001924575686454773


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.028758222579956055


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.49864451599121096


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


2.6657978515625


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.00215793240070343


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.019010735511779786


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.2748946304321289


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


1.7247564697265625


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.002303774952888489


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.019509175300598144


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.24362542724609376


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


1.5202689208984375


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.002233712673187256


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.023307215690612792


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.3710262451171875


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


2.3234739990234377


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0024583723545074463


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.045529108047485355


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.5745510559082031


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


2.571861083984375


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.004602209627628326


100%|██████████| 2/2 [00:00<00:00,  3.15it/s]


0.021571642875671388


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.3106897430419922


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


2.01523095703125


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.0018749197125434876


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.059639917373657224


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


1.1474668579101563


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


4.295735107421875


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.0022435731291770935


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0236461820602417


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.23902468872070312


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


1.38641845703125


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0017808344960212707


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.03402294254302979


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.38658029174804687


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


1.9916712036132813


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.0026136074066162108


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.04765525245666504


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.8380601196289063


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


3.6758359375


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0023354610204696655


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


0.04827343559265137


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.8240353698730469


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


4.041028686523438


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


0.0020631833672523497


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.04974050712585449


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.6045473022460938


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


2.5600233154296874


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.002051503002643585


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.0429948787689209


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.6399642639160156


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


2.9363934326171877


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.0017985346913337707


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.016857285499572754


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.2530683059692383


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


1.6509220581054687


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.004099875450134277


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.06219531059265137


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.9730048522949218


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


4.548981689453125


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.002092049181461334


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.028981956481933593


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.4438250885009766


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.5961241455078126


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.0019017340540885925


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.03011072540283203


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.5494891357421875


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


2.6479361572265625


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]

0.0011850967407226562



100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.03267460346221924


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.6234982604980469


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.9363031005859375


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.00140402752161026


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.015214086055755615


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.17721123123168944


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.1030640869140624


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0019434980154037476


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.03391250228881836


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.3731060485839844


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


1.8347728881835939


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.002040235698223114


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.0274454402923584


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.37743043518066405


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.7749497680664061


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.0025093563795089723


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.04318592643737793


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.6131150512695313


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.839159423828125


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.002599960088729858


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.022429169178009033


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.25434517669677736


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


1.74479638671875


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.0033656004667282106


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.08319490432739257


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.8850337219238281


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


3.2345242919921877


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.0013990659713745118


100%|██████████| 2/2 [00:00<00:00,  3.28it/s]


0.032636924743652346


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.5951671142578125


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


2.79584521484375


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.0032793992757797243


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.06887889862060546


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.8339978637695312


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


3.54195849609375


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.003709802746772766


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.07838018798828125


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.7680115051269532


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


3.03258349609375


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.002534903407096863


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.04995415306091309


100%|██████████| 2/2 [00:00<00:00,  3.18it/s]


0.9297553405761719


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


3.9706959228515624


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0016906291842460632


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.015552757263183593


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.28199166107177737


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


1.8576693115234375


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.0034003520607948303


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.031548357963562014


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.40592514038085936


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


2.3138944091796874


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.0012693257331848145


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.029391000747680664


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.4354347076416016


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.3601778564453126


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0027725930213928223


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.03186031818389892


100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


0.37288125610351563


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


2.32049267578125


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0016589387655258179


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.008312098026275635


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.1588567886352539


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


1.3100635986328124


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.0019475035667419433


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.016126205444335937


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.2812620391845703


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


1.8909940795898437


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.007237210512161255


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.06774565505981445


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.773346923828125


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


3.401802978515625


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.0023020942211151125


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.044201203346252445


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.3860505218505859


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


1.8929332275390625


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.002362324118614197


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.021412975311279298


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.28298975372314455


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


1.6154237670898437


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.0016837249398231507


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.010528790473937987


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.16061515426635742


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.1390914916992188


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.002461145281791687


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.04657129955291748


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.4880654754638672


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


2.421908203125


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.002895339488983154


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.026953490734100343


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.17686556243896484


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


1.0098727416992188


100%|██████████| 2/2 [00:00<00:00,  3.43it/s]


0.0025112645626068114


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.038903873443603514


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.7984542541503906


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


3.8964908447265625


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.0018911018967628479


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.03303088188171387


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.8649474182128907


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


3.854303955078125


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.004508473396301269


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.03596896362304688


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.3235175170898438


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.7533405151367187


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0018520236015319824


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.034898126602172855


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.5759335327148437


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


2.6583804931640627


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0019630788564682007


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.06315837860107422


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.7972942199707032


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


2.9906888427734377


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.0020534401535987853


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.029087573051452636


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.36177252197265625


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.8685194091796875


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.0016453341841697693


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.04199362945556641


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.6484411926269531


100%|██████████| 2/2 [00:00<00:00,  3.04it/s]


2.8219779052734375


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.0027457895278930663


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.021097450256347655


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.2699742736816406


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


1.4283143310546875


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.002637876272201538


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.033667774200439454


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.48246197509765626


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


2.1032646484375


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.003765004634857178


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.05015884685516357


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.5297950897216797


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


2.8442032470703125


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.002347975492477417


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.024272380828857423


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.24469078826904297


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


1.30290234375


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.007596997499465942


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.07619099807739257


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.9236793823242188


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


3.6669002685546874


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0037457281351089477


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.032863707542419435


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.3879471893310547


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


1.9847998657226562


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.004996059656143188


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.04690264320373535


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.5805343627929688


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


2.96462939453125


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.0016973633766174316


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.019658483505249024


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.25254929351806643


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


1.5919912109375


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.0020919685363769533


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.023053266525268554


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.363121337890625


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


2.036203857421875


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.0019912768602371217


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.04170324325561523


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.68587451171875


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


2.9452088623046877


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.006525643587112427


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.038679574966430666


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.5097637023925782


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


2.610959228515625


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.00360040819644928


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


0.04740168762207031


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.4798617401123047


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


2.1964271240234376


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.0013743156790733337


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.029702077865600585


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.5129892425537109


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


2.650689453125


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0022922998666763308


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.05508687591552734


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


0.986726806640625


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


4.029283569335938


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.002233386516571045


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.01771187400817871


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.2537506942749023


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


1.3500543212890626


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


0.0014083506464958192


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.02121392822265625


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.4856187133789063


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


2.7812598876953123


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


0.0014320658445358277


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.02999569320678711


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.5506112060546875


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


2.402754150390625


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


0.002093624770641327


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.02265994930267334


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.36457510375976565


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


2.1890963134765626


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.0019848989844322203


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.018363498687744142


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.252481689453125


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


1.4506549072265624


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.0018527382016181946


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.02383841323852539


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.3363791809082031


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


1.8140671997070312


100%|██████████| 2/2 [00:00<00:00,  3.39it/s]


0.0019447194933891296


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


0.015094626903533935


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.2378989028930664


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


1.8432951049804687


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.001900745451450348


100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


0.02556918239593506


100%|██████████| 2/2 [00:00<00:00,  3.34it/s]


0.4314019927978516


100%|██████████| 2/2 [00:00<00:00,  3.28it/s]


2.3914493408203126


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.0022753387689590453


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.040166807174682614


100%|██████████| 2/2 [00:00<00:00,  3.28it/s]


0.5037739715576172


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]


2.458405029296875


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.002490786552429199


100%|██████████| 2/2 [00:00<00:00,  3.20it/s]


0.024261000633239747


100%|██████████| 2/2 [00:00<00:00,  3.14it/s]


0.2861810760498047


100%|██████████| 2/2 [00:00<00:00,  3.06it/s]


1.5949205932617188


100%|██████████| 2/2 [00:00<00:00,  3.11it/s]


0.0020627627968788148


100%|██████████| 2/2 [00:00<00:00,  3.10it/s]


0.03333019351959229


100%|██████████| 2/2 [00:00<00:00,  3.07it/s]


0.5030301818847657


100%|██████████| 2/2 [00:00<00:00,  3.11it/s]


2.7011004638671876


100%|██████████| 2/2 [00:00<00:00,  3.22it/s]


0.0022285321950912475


100%|██████████| 2/2 [00:00<00:00,  3.20it/s]


0.04771722030639648


100%|██████████| 2/2 [00:00<00:00,  3.14it/s]


0.7083925476074219


100%|██████████| 2/2 [00:00<00:00,  3.16it/s]


3.0778175048828125


100%|██████████| 2/2 [00:00<00:00,  3.17it/s]


0.002759297013282776


100%|██████████| 2/2 [00:00<00:00,  3.15it/s]


0.03204127311706543


100%|██████████| 2/2 [00:00<00:00,  3.14it/s]


0.3730016326904297


100%|██████████| 2/2 [00:00<00:00,  3.12it/s]


1.6673712158203124


100%|██████████| 2/2 [00:00<00:00,  3.14it/s]


0.004253222703933716


100%|██████████| 2/2 [00:00<00:00,  3.17it/s]


0.02120113182067871


100%|██████████| 2/2 [00:00<00:00,  3.07it/s]


0.2525477066040039


100%|██████████| 2/2 [00:00<00:00,  3.12it/s]


1.691107666015625


100%|██████████| 2/2 [00:00<00:00,  3.22it/s]


0.0023945645093917845


100%|██████████| 2/2 [00:00<00:00,  3.19it/s]


0.05380408477783203


100%|██████████| 2/2 [00:00<00:00,  3.14it/s]


0.7839984130859375


100%|██████████| 2/2 [00:00<00:00,  3.13it/s]


3.24824072265625


100%|██████████| 2/2 [00:00<00:00,  3.17it/s]


0.0014396038055419923


100%|██████████| 2/2 [00:00<00:00,  3.27it/s]


0.01606134557723999


100%|██████████| 2/2 [00:00<00:00,  3.25it/s]


0.2913730773925781


100%|██████████| 2/2 [00:00<00:00,  3.35it/s]


1.7850098266601562


100%|██████████| 2/2 [00:00<00:00,  3.33it/s]


0.001446643352508545


100%|██████████| 2/2 [00:00<00:00,  3.31it/s]


0.03825680351257324


100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.7454845581054688


100%|██████████| 2/2 [00:00<00:00,  3.30it/s]

3.905684814453125



100%|██████████| 2/2 [00:00<00:00,  3.29it/s]


0.008622812032699585


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.03505203628540039


100%|██████████| 2/2 [00:00<00:00,  3.32it/s]


0.39028611755371095


100%|██████████| 2/2 [00:00<00:00,  3.28it/s]


1.9754710693359374


In [7]:
pkl.dump(all_totals, open("weights/totals.pkl", "wb"))


In [20]:
thre = 0.07
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight)
               for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate(
        [combine_.flatten() for combine_ in combine], axis=0)
    threshold = np.sort(combine_flatten)[
        ::-1][int(len(combine_flatten) * thre)]
    for idx, param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t


  combine = np.array(combine)


In [21]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 0.9936342592592593
layer1.0.conv1.weight 0.883544921875
layer1.0.conv2.weight 0.6957736545138888
layer1.0.conv3.weight 0.76568603515625
layer1.0.shortcut.0.weight 0.80279541015625
layer1.1.conv1.weight 0.65869140625
layer1.1.conv2.weight 0.7159016927083334
layer1.1.conv3.weight 0.73040771484375
layer1.2.conv1.weight 0.66961669921875
layer1.2.conv2.weight 0.6825900607638888
layer1.2.conv3.weight 0.53460693359375
layer2.0.conv1.weight 0.8699951171875
layer2.0.conv2.weight 0.6604682074652778
layer2.0.conv3.weight 0.7593841552734375
layer2.0.shortcut.0.weight 0.655487060546875
layer2.1.conv1.weight 0.345794677734375
layer2.1.conv2.weight 0.5167914496527778
layer2.1.conv3.weight 0.6204986572265625
layer2.2.conv1.weight 0.4469451904296875
layer2.2.conv2.weight 0.5791219075520834
layer2.2.conv3.weight 0.6329803466796875
layer2.3.conv1.weight 0.540374755859375
layer2.3.conv2.weight 0.6033528645833334
layer2.3.conv3.weight 0.5842437744140625
layer3.0.conv1.weight 0.8491668701171875

In [22]:
temp / all_num


0.4880513827815582

In [23]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.7929


In [24]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.4417


In [25]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval(
            "net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(
            len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.01
