In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders, test_dataloader_all = load_cifar10()
all_totals = list()
for i in range(10):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar10_resnet50, alpha=0.00001,num_steps=3,op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:20<00:00,  1.02s/it]


0.0006625839769840241


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.02401313738822937


100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


0.3929096221923828


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


0.00046284100264310836


100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


0.00942666541337967


100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


0.16735419273376465


100%|██████████| 20/20 [00:09<00:00,  2.07it/s]


0.0005848659977316856


100%|██████████| 20/20 [00:09<00:00,  2.04it/s]


0.0158255264043808


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


0.21922230644226073


100%|██████████| 20/20 [00:09<00:00,  2.01it/s]


0.000786997240781784


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.026806430077552797


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.40782703857421876


100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


0.0006073593527078628


100%|██████████| 20/20 [00:09<00:00,  2.03it/s]


0.011670439040660858


100%|██████████| 20/20 [00:09<00:00,  2.05it/s]


0.18673768348693848


100%|██████████| 20/20 [00:09<00:00,  2.16it/s]


0.0005955943539738655


100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


0.018226849114894868


100%|██████████| 20/20 [00:09<00:00,  2.18it/s]


0.3045005058288574


100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


0.0006207883909344674


100%|██████████| 20/20 [00:09<00:00,  2.02it/s]


0.011072832882404328


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.20568238143920897


100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


0.0004760829672217369


100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


0.011725876915454864


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.18064193229675293


100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


0.0004115740329027176


100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


0.010544984757900238


100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


0.22035617218017578


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.0005323756381869317


100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


0.01674596917629242


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.335018994140625


In [7]:
# pkl.dump(all_totals, open("weights/totals.pkl", "wb"))
# all_totals = pkl.load(open("weights/totals.pkl", "rb"))

In [12]:
thre = 0.1
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [total for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [13]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9936342592592593
layer1.0.conv1.weight 0.783203125
layer1.0.conv2.weight 0.5836046006944444
layer1.0.conv3.weight 0.75164794921875
layer1.0.shortcut.0.weight 0.7431640625
layer1.1.conv1.weight 0.41241455078125
layer1.1.conv2.weight 0.5375434027777778
layer1.1.conv3.weight 0.74359130859375
layer1.2.conv1.weight 0.5316162109375
layer1.2.conv2.weight 0.6970757378472222
layer1.2.conv3.weight 0.686279296875
layer2.0.conv1.weight 0.881134033203125
layer2.0.conv2.weight 0.5897962782118056
layer2.0.conv3.weight 0.7638092041015625
layer2.0.shortcut.0.weight 0.6974945068359375
layer2.1.conv1.weight 0.2981109619140625
layer2.1.conv2.weight 0.5779351128472222
layer2.1.conv3.weight 0.6712799072265625
layer2.2.conv1.weight 0.443450927734375
layer2.2.conv2.weight 0.5477769639756944
layer2.2.conv3.weight 0.5927581787109375
layer2.3.conv1.weight 0.549774169921875
layer2.3.conv2.weight 0.6331515842013888
layer2.3.conv3.weight 0.5400238037109375
layer3.0.conv1.weight 0.8659820556640625
lay

In [14]:
temp / all_num

0.5183197743767793

In [18]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.954


In [19]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.1


In [20]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.1
