In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
for i in range(100):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar100_resnet50,train_dataloader_all, alpha=0.00001,num_steps=2,op="minus"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 2/2 [00:07<00:00,  3.58s/it]


0.0031813300848007203


100%|██████████| 2/2 [00:00<00:00,  2.13it/s]


0.0007379670739173889


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0016323286294937134


100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


0.00047551922500133515


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0039153296947479245


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.000615882471203804


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.002457998633384705


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.00018176114559173583


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0028448475599288942


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.00027353176474571227


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.002053434431552887


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0003307967036962509


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.001525224506855011


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.00027876313030719756


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0024352331161499024


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0004993379414081574


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.00161414635181427


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.00040503956377506255


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0018572715520858765


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.00043104073405265806


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0017002838253974914


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.00023673734813928603


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.003962363600730896


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0005002691447734833


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0021116116642951963


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.00034318630397319795


100%|██████████| 2/2 [00:00<00:00,  3.21it/s]


0.002274190425872803


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


0.00044476985931396484


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


0.002760591149330139


100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


0.0004781426936388016


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.0018909668922424316


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.00035780085623264313


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0015539610385894775


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.00032711589336395264


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0018331933617591859


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0005597283542156219


100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


0.0017890458106994628


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


0.00037653426826000214


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.002049513578414917


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0002936095595359802


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0027038747072219847


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0006775761246681214


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.001924575686454773


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0003870377838611603


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0021579325199127196


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0005379759669303894


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.002303774952888489


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0007180988192558289


100%|██████████| 2/2 [00:00<00:00,  3.69it/s]


0.002233712673187256


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0006439407765865326


100%|██████████| 2/2 [00:00<00:00,  3.69it/s]


0.0024583725929260255


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0004214601069688797


100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


0.00460221004486084


100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


0.0008426664471626282


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0018749197721481324


100%|██████████| 2/2 [00:00<00:00,  3.66it/s]


0.0002621779888868332


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.002243573069572449


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.00045618076622486116


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0017808343768119812


100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


0.0003079681545495987


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0026136074066162108


100%|██████████| 2/2 [00:00<00:00,  3.61it/s]


0.0005500466823577881


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0023354610204696655


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.00044055356085300444


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0020631835460662843


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0001897636875510216


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.002051502823829651


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


0.00032470887899398805


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0017985348105430603


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.00036199951171875


100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


0.004099875926971436


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0006523131132125854


100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


0.0020920491218566896


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0004467261880636215


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.0019017342329025269


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0004107163995504379


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.0011850967407226562


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.00015262065082788468


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0014040275812149047


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0003283279836177826


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0019434978365898132


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0002686207145452499


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0020402355194091798


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0004267439693212509


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0025093562602996828


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0004038912057876587


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


0.002599960207939148


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.00043743957579135895


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.0033656002283096315


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0003246729224920273


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0013990658521652222


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.00014462468028068542


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0032793991565704347


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0005161431729793548


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


0.0037098028659820555


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0006431311368942261


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.002534903883934021


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0006388308703899384


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0016906294822692871


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.0004556894451379776


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0034003517031669615


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.000348372608423233


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


0.0012693256735801697


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.00020931584388017655


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.0027725929021835327


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0005476654917001725


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.001658939003944397


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.0005671007633209228


100%|██████████| 2/2 [00:00<00:00,  3.18it/s]


0.0019475035071372987


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0005845302045345307


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.007237210750579834


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.0012638156414031983


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.002302093803882599


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.0005577147006988525


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.002362324118614197


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0005249511301517487


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.001683724582195282


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.00046737182140350344


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.0024611446857452394


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.00042090250551700593


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0028953393697738646


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.000994887888431549


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


0.0025112648010253906


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0006520422399044037


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.0018911019563674926


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.00043745553493499756


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.004508472740650177


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.0007812778949737549


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.0018520235419273376


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0002586245760321617


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0019630789756774902


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


0.00016604235023260117


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.002053440272808075


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.000419806569814682


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0016453340649604797


100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


0.00020455296337604521


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.002745789289474487


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.000684714525938034


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


0.002637876272201538


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.0005753113925457001


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.003765004277229309


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.000883078783750534


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.002347975492477417


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0006305276453495025


100%|██████████| 2/2 [00:00<00:00,  3.55it/s]


0.0075969982147216794


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.0011544549837708474


100%|██████████| 2/2 [00:00<00:00,  3.47it/s]


0.003745728850364685


100%|██████████| 2/2 [00:00<00:00,  3.50it/s]


0.0007458879500627518


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.004996059417724609


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.0008259548544883728


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0016973634958267212


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0003401660621166229


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.002091968297958374


100%|██████████| 2/2 [00:00<00:00,  3.62it/s]


0.0004942563772201538


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0019912765622138976


100%|██████████| 2/2 [00:00<00:00,  3.57it/s]


0.0003322077244520187


100%|██████████| 2/2 [00:00<00:00,  3.61it/s]


0.0065256433486938475


100%|██████████| 2/2 [00:00<00:00,  3.61it/s]


0.0014070499390363693


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.00360040819644928


100%|██████████| 2/2 [00:00<00:00,  3.66it/s]


0.0007562032341957093


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0013743156194686889


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.00019718803465366364


100%|██████████| 2/2 [00:00<00:00,  3.66it/s]


0.00229229998588562


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0004102607071399689


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


0.002233386516571045


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0006007132530212403


100%|██████████| 2/2 [00:00<00:00,  3.61it/s]


0.0014083505868911744


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.00031627567112445834


100%|██████████| 2/2 [00:00<00:00,  3.64it/s]


0.001432065784931183


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0001882425993680954


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0020936247110366823


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.0005493069887161255


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.0019848989248275758


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.00036268402636051175


100%|██████████| 2/2 [00:00<00:00,  3.61it/s]


0.0018527384996414186


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.00044000692665576937


100%|██████████| 2/2 [00:00<00:00,  3.53it/s]


0.0019447197318077088


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


0.000470348060131073


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0019007456302642823


100%|██████████| 2/2 [00:00<00:00,  3.56it/s]


0.000404536172747612


100%|██████████| 2/2 [00:00<00:00,  3.59it/s]


0.002275338888168335


100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


0.00043727871775627134


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.002490786552429199


100%|██████████| 2/2 [00:00<00:00,  3.60it/s]


0.0006170926988124847


100%|██████████| 2/2 [00:00<00:00,  3.54it/s]


0.0020627627968788148


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.00044557179510593413


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.0022285325527191162


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


0.0003089916855096817


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.0027592967748641967


100%|██████████| 2/2 [00:00<00:00,  3.21it/s]


0.0009354405105113983


100%|██████████| 2/2 [00:00<00:00,  3.05it/s]


0.0042532234191894535


100%|██████████| 2/2 [00:00<00:00,  3.37it/s]


0.00102148100733757


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.002394564628601074


100%|██████████| 2/2 [00:00<00:00,  3.48it/s]


0.0003036593347787857


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


0.0014396040439605712


100%|██████████| 2/2 [00:00<00:00,  3.36it/s]


0.00027472166717052457


100%|██████████| 2/2 [00:00<00:00,  3.45it/s]


0.0014466434121131898


100%|██████████| 2/2 [00:00<00:00,  3.51it/s]


0.00021125011891126632


100%|██████████| 2/2 [00:00<00:00,  3.40it/s]


0.008622812509536743


100%|██████████| 2/2 [00:00<00:00,  3.44it/s]


0.0023981708288192747


In [7]:
pkl.dump(all_totals, open("weights/minus_totals.pkl", "wb"))

In [8]:
thre = 0.25
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [9]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.961181640625
layer1.0.conv2.weight 0.8777126736111112
layer1.0.conv3.weight 0.88641357421875
layer1.0.shortcut.0.weight 0.8897705078125
layer1.1.conv1.weight 0.846923828125
layer1.1.conv2.weight 0.8989529079861112
layer1.1.conv3.weight 0.851806640625
layer1.2.conv1.weight 0.84613037109375
layer1.2.conv2.weight 0.8301866319444444
layer1.2.conv3.weight 0.67303466796875
layer2.0.conv1.weight 0.9505615234375
layer2.0.conv2.weight 0.8661092122395834
layer2.0.conv3.weight 0.8997344970703125
layer2.0.shortcut.0.weight 0.8480072021484375
layer2.1.conv1.weight 0.70001220703125
layer2.1.conv2.weight 0.80914306640625
layer2.1.conv3.weight 0.8157196044921875
layer2.2.conv1.weight 0.7675628662109375
layer2.2.conv2.weight 0.8313259548611112
layer2.2.conv3.weight 0.8278045654296875
layer2.3.conv1.weight 0.82562255859375
layer2.3.conv2.weight 0.8433702256944444
layer2.3.conv3.weight 0.778533935546875
layer3.0.conv1.weight 0.9455184936523438
layer

In [10]:
temp / all_num

0.8576080963416065

In [11]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.7929


In [12]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.7918


In [13]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.01
