In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
from tqdm.notebook import tqdm
import pickle as pkl
from attack import attack, test_model,parse_param
from utils import caculate_param_remove
import random

In [7]:
def setup_seed(seed):
    """Seed every RNG used by this notebook and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU and all CUDA
    devices), then configures cuDNN for run-to-run reproducibility.

    Args:
        seed: integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not sufficient: with benchmark mode on, cuDNN may
    # autotune and pick different algorithms between runs.
    torch.backends.cudnn.benchmark = False


setup_seed(3407)

In [8]:
from datasets import load_cifar10
from models.resnet import load_cifar10_resnet50
# This instance is only used to enumerate parameter names in the next cell; later
# cells call load_cifar10_resnet50() again to get fresh, unmodified networks.
model = load_cifar10_resnet50()


In [9]:
# Parameter names the attack and pruning cells operate on. Names containing "bn"
# (BatchNorm) or "shortcut.1" (presumably the BatchNorm inside the residual
# shortcut — TODO confirm) are skipped. The last two remaining entries are then
# dropped (presumably the final classifier weight and bias — TODO confirm).
all_param_names = [
    name
    for name, _ in model.named_parameters()
    if "bn" not in name and "shortcut.1" not in name
][:-2]

In [10]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()

# For each of the 10 training loaders, run two attacks back to back and keep both
# results in order: (num_steps=2, op="minus") first, then (num_steps=4, op="add").
# The merge cell below relies on this (minus, add) pairing at even/odd indices.
all_totals = []
for i in tqdm(range(10)):
    all_totals.extend(
        attack(train_loaders[i], all_param_names, load_cifar10_resnet50,
               norm=False, alpha=0.00001, num_steps=steps, op=op)
        for steps, op in ((2, "minus"), (4, "add"))
    )


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.01127861852645874


  0%|          | 0/20 [00:00<?, ?it/s]

0.01076155652999878


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.009953450536727906


  0%|          | 0/20 [00:00<?, ?it/s]

0.013396762561798095


  0%|          | 0/20 [00:00<?, ?it/s]

0.3823133995056152


  0%|          | 0/20 [00:00<?, ?it/s]

1.9999164947509767


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011078875494003295


  0%|          | 0/20 [00:00<?, ?it/s]

0.010742098617553712


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010153193473815918


  0%|          | 0/20 [00:00<?, ?it/s]

-0.0011883275508880616


  0%|          | 0/20 [00:00<?, ?it/s]

0.15676350021362304


  0%|          | 0/20 [00:00<?, ?it/s]

1.0295449340820313


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011200893592834472


  0%|          | 0/20 [00:00<?, ?it/s]

0.010747618865966797


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010031175756454468


  0%|          | 0/20 [00:00<?, ?it/s]

0.00521306574344635


  0%|          | 0/20 [00:00<?, ?it/s]

0.20862971954345702


  0%|          | 0/20 [00:00<?, ?it/s]

1.252725799560547


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011403035116195679


  0%|          | 0/20 [00:00<?, ?it/s]

0.010727086162567139


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.009829042243957519


  0%|          | 0/20 [00:00<?, ?it/s]

0.01618811483383179


  0%|          | 0/20 [00:00<?, ?it/s]

0.39722023086547853


  0%|          | 0/20 [00:00<?, ?it/s]

2.036959246826172


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011223393964767457


  0%|          | 0/20 [00:00<?, ?it/s]

0.010756812143325805


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010008693647384643


  0%|          | 0/20 [00:00<?, ?it/s]

0.0010546480655670167


  0%|          | 0/20 [00:00<?, ?it/s]

0.17614929122924805


  0%|          | 0/20 [00:00<?, ?it/s]

1.1831684509277345


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011211630344390868


  0%|          | 0/20 [00:00<?, ?it/s]

0.010716063213348389


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010020458126068115


  0%|          | 0/20 [00:00<?, ?it/s]

0.007608432722091675


  0%|          | 0/20 [00:00<?, ?it/s]

0.2938665786743164


  0%|          | 0/20 [00:00<?, ?it/s]

1.7084033203125


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.01123682427406311


  0%|          | 0/20 [00:00<?, ?it/s]

0.010757255935668945


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.00999524381160736


  0%|          | 0/20 [00:00<?, ?it/s]

0.00045478434562683104


  0%|          | 0/20 [00:00<?, ?it/s]

0.19508163795471192


  0%|          | 0/20 [00:00<?, ?it/s]

1.4661389221191405


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.01109211769104004


  0%|          | 0/20 [00:00<?, ?it/s]

0.010739185810089112


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010139948272705078


  0%|          | 0/20 [00:00<?, ?it/s]

0.0011067546129226685


  0%|          | 0/20 [00:00<?, ?it/s]

0.1700308521270752


  0%|          | 0/20 [00:00<?, ?it/s]

1.1503700958251952


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.01102761287689209


  0%|          | 0/20 [00:00<?, ?it/s]

0.010712595891952515


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010204454612731933


  0%|          | 0/20 [00:00<?, ?it/s]

-7.338104248046875e-05


  0%|          | 0/20 [00:00<?, ?it/s]

0.20974421520233155


  0%|          | 0/20 [00:00<?, ?it/s]

1.4371024353027344


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.011148411130905152


  0%|          | 0/20 [00:00<?, ?it/s]

0.010732812929153443


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010083656120300292


  0%|          | 0/20 [00:00<?, ?it/s]

0.0061279740095138546


  0%|          | 0/20 [00:00<?, ?it/s]

0.3244103172302246


  0%|          | 0/20 [00:00<?, ?it/s]

1.8952925048828124


In [None]:
all_totals_temp = list()
from utils import normalization

# Merge each consecutive (op="minus", op="add") pair of attack results into one dict:
# sum the two per-parameter scores, take the absolute value, and normalize across
# parameters. Must run BEFORE `all_totals` is rebound to the merged list; re-running
# afterwards would silently merge already-merged entries pairwise.
assert len(all_totals) % 2 == 0, "expected (minus, add) pairs of attack results"
for i in range(0, len(all_totals), 2):
    total_0 = all_totals[i]
    total_1 = all_totals[i + 1]
    keys = list(total_0.keys())  # fixed order so keys and values stay aligned
    total_values = np.array([total_0[key] + total_1[key] for key in keys])
    total_values = normalization(abs(total_values))
    # Pair keys with normalized values positionally. This is O(n), unlike calling
    # list.index once per key.
    all_totals_temp.append(dict(zip(keys, total_values)))


In [None]:
# Shallow backup of the raw (unmerged) attack results, taken before the next cell
# rebinds `all_totals`. Not read anywhere else in this notebook.
all_totals_clones = list(all_totals)

In [None]:
# Replace the raw attack results with the merged/normalized ones. This rebinding is
# not idempotent: re-running the merge cell afterwards would pair already-merged
# entries. The raw list survives in all_totals_clones if that cell ran first.
all_totals = all_totals_temp

In [None]:
# Sanity check: 10 loaders x 2 attacks = 20 raw results, merged pairwise -> 10.
# Seeing 20 here means the merge/rebind cells were not applied.
len(all_totals)

In [None]:
# Persist the merged per-loader scores. The `with` block guarantees the file is
# flushed and closed even if pickling fails part-way.
with open("weights/op_totals.pkl", "wb") as f:
    pkl.dump(all_totals, f)

In [None]:
# Build the per-parameter masks. As used in later cells, param_remove[name] for each
# name in all_param_names acts as a boolean mask where True = keep: the pruning cell
# zeroes ~mask, and keep_rate = mask.sum() / mask.size.
# NOTE(review): the exact meaning of `thre` (score cutoff vs. ratio) lives in
# utils.caculate_param_remove — confirm there.
thre = 0.25
net = load_cifar10_resnet50()
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

In [None]:
# Print each mask's kept fraction (True = keep), then accumulate the totals that the
# next cell divides to obtain the overall kept fraction.
for param in param_remove:
    print(param, param_remove[param].mean())
temp = sum(param_remove[p].sum() for p in param_remove)
all_num = sum(param_remove[p].size for p in param_remove)

In [None]:
# Overall fraction of weights kept (True mask entries) across all masks in param_remove.
temp / all_num

In [None]:
# Reference point: test-set accuracy of the unpruned model
# (label "原始准确率" = "original accuracy").
# NOTE(review): assumes test_model returns numpy arrays (per-class scores, int labels)
# and puts the net in eval mode itself — confirm in attack.py.
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

In [None]:
# Prune with the attack-derived masks: zero every weight whose param_remove entry is
# False (True = keep), then evaluate on the full test set
# (label "现在准确率" = "current accuracy").
# NOTE(review): `~` is logical NOT only for boolean masks — assumes param_remove holds
# bool arrays, as the .mean()/.sum()/.size usage elsewhere suggests.
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        # Resolve the parsed attribute path to the tensor once, instead of exec'ing
        # a string-built assignment.
        weight = eval("net." + param_)
        prune_mask = ~param_remove[param]
        try:
            weight[prune_mask] = 0
        except Exception:
            # Fallback kept from the original: index the leading dimension explicitly.
            # Catching Exception (not a bare except) lets KeyboardInterrupt through.
            weight[prune_mask, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

In [None]:
# Comparison baseline (label "对比实验准确率" = "comparison-experiment accuracy"):
# per-layer magnitude pruning that keeps the same fraction of weights as the matching
# param_remove mask (True = keep), then evaluates on the full test set.
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        weight = eval("net." + param_)  # resolve the parsed attribute path once
        keep_rate = param_remove[param].sum() / param_remove[param].size
        magnitudes = weight.detach().abs().flatten().cpu().numpy()
        n_prune = int(len(magnitudes) * (1 - keep_rate))
        if n_prune >= len(magnitudes):
            # keep_rate == 0: the sorted-index lookup would be out of range.
            # Prune everything.
            weight.zero_()
        elif n_prune > 0:
            # Rank by |w|, not signed value; otherwise the most negative (large-
            # magnitude) weights would be pruned first.
            threshold = float(np.sort(magnitudes)[n_prune])
            weight[weight.abs() < threshold] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())