In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack_copy import attack, test_model, parse_param
import random


In [2]:
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module, then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different convolution algorithm from one
    # run to the next, so it is switched off alongside the deterministic flag.
    torch.backends.cudnn.benchmark = False


setup_seed(3407)


In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
# Reference network returned by the project's load_cifar10_resnet50 helper.
# Its named parameters define which tensors are analysed below, and the same
# object is handed to attack().
model = load_cifar10_resnet50()


In [4]:
# Names of the parameters to analyse: every named parameter of the model except
# those whose name contains "bn" or "shortcut.1".
all_param_names = [
    param_name
    for param_name, _param in model.named_parameters()
    if "bn" not in param_name and "shortcut.1" not in param_name
]


In [5]:
# Drop the last two collected names.
# NOTE(review): presumably the final classifier's weight and bias — confirm
# against the model definition.
# This cell is not idempotent: each re-run trims two more names, so re-run the
# previous cell first if this one has to be executed again.
all_param_names = all_param_names[:-2]


In [6]:
# load_cifar10() hands back four loader objects; the two training ones feed the
# attack below, test_dataloader_all is used by the evaluation cells.
(train_loaders,
 test_dataloaders,
 train_dataloader_all,
 test_dataloader_all) = load_cifar10()

# One attack() result per class index 0-9, stored in class order.
all_totals = [
    attack(
        model,
        train_loaders,
        all_param_names,
        train_dataloader_all,
        alpha=0.00003,
        num_steps=2,
        op="add",
        clz=clz,
    )
    for clz in range(10)
]


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:12<00:00,  1.57it/s]
100%|██████████| 20/20 [00:08<00:00,  2.31it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:08<00:00,  2.32it/s]
100%|██████████| 20/20 [00:08<00:00,  2.29it/s]
100%|██████████| 20/20 [00:08<00:00,  2.24it/s]
100%|██████████| 20/20 [00:08<00:00,  2.29it/s]
100%|██████████| 20/20 [00:08<00:00,  2.36it/s]
100%|██████████| 20/20 [00:08<00:00,  2.36it/s]
100%|██████████| 20/20 [00:08<00:00,  2.37it/s]


0.0011041762246191502


100%|██████████| 20/20 [00:08<00:00,  2.36it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:08<00:00,  2.37it/s]
100%|██████████| 20/20 [00:08<00:00,  2.31it/s]
100%|██████████| 20/20 [00:08<00:00,  2.24it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]
100%|██████████| 20/20 [00:09<00:00,  2.19it/s]
100%|██████████| 20/20 [00:09<00:00,  2.07it/s]
100%|██████████| 20/20 [00:09<00:00,  2.06it/s]
100%|██████████| 20/20 [00:09<00:00,  2.03it/s]


1.4110886314223334


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:08<00:00,  2.22it/s]
100%|██████████| 20/20 [00:09<00:00,  2.19it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]
100%|██████████| 20/20 [00:09<00:00,  2.13it/s]
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]
100%|██████████| 20/20 [00:09<00:00,  2.14it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]


1.1602450956060737


100%|██████████| 20/20 [00:09<00:00,  2.13it/s]
100%|██████████| 20/20 [00:09<00:00,  2.11it/s]
100%|██████████| 20/20 [00:09<00:00,  2.10it/s]
100%|██████████| 20/20 [00:09<00:00,  2.12it/s]
100%|██████████| 20/20 [00:09<00:00,  2.15it/s]
100%|██████████| 20/20 [00:09<00:00,  2.01it/s]
100%|██████████| 20/20 [00:09<00:00,  2.02it/s]
100%|██████████| 20/20 [00:10<00:00,  2.00it/s]
100%|██████████| 20/20 [00:10<00:00,  1.98it/s]
100%|██████████| 20/20 [00:09<00:00,  2.01it/s]


0.44447259319350124


100%|██████████| 20/20 [00:09<00:00,  2.03it/s]
100%|██████████| 20/20 [00:10<00:00,  1.99it/s]
100%|██████████| 20/20 [00:09<00:00,  2.02it/s]
100%|██████████| 20/20 [00:10<00:00,  1.99it/s]
100%|██████████| 20/20 [00:09<00:00,  2.08it/s]
100%|██████████| 20/20 [00:07<00:00,  2.56it/s]
100%|██████████| 20/20 [00:07<00:00,  2.61it/s]
100%|██████████| 20/20 [00:07<00:00,  2.68it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


0.26487416280418635


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
100%|██████████| 20/20 [00:07<00:00,  2.59it/s]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


0.013274485169053078


100%|██████████| 20/20 [00:07<00:00,  2.69it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.54it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]


0.25106165497988464


100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


0.11164860276013613


100%|██████████| 20/20 [00:07<00:00,  2.70it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.63it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


0.5508921235223859


100%|██████████| 20/20 [00:07<00:00,  2.67it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.67it/s]
100%|██████████| 20/20 [00:07<00:00,  2.67it/s]
100%|██████████| 20/20 [00:07<00:00,  2.69it/s]
100%|██████████| 20/20 [00:07<00:00,  2.66it/s]
100%|██████████| 20/20 [00:07<00:00,  2.65it/s]
100%|██████████| 20/20 [00:07<00:00,  2.74it/s]
100%|██████████| 20/20 [00:07<00:00,  2.77it/s]
100%|██████████| 20/20 [00:07<00:00,  2.71it/s]


0.04254670130491257


100%|██████████| 20/20 [00:07<00:00,  2.85it/s]
100%|██████████| 20/20 [00:07<00:00,  2.80it/s]
100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
100%|██████████| 20/20 [00:07<00:00,  2.85it/s]
100%|██████████| 20/20 [00:07<00:00,  2.84it/s]
100%|██████████| 20/20 [00:07<00:00,  2.84it/s]
100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
100%|██████████| 20/20 [00:07<00:00,  2.82it/s]
100%|██████████| 20/20 [00:07<00:00,  2.82it/s]


0.02256250250056386


100%|██████████| 20/20 [00:07<00:00,  2.85it/s]
100%|██████████| 20/20 [00:07<00:00,  2.81it/s]
100%|██████████| 20/20 [00:07<00:00,  2.81it/s]
100%|██████████| 20/20 [00:07<00:00,  2.74it/s]
100%|██████████| 20/20 [00:07<00:00,  2.67it/s]
100%|██████████| 20/20 [00:07<00:00,  2.69it/s]
100%|██████████| 20/20 [00:07<00:00,  2.69it/s]
100%|██████████| 20/20 [00:07<00:00,  2.69it/s]
100%|██████████| 20/20 [00:07<00:00,  2.68it/s]
100%|██████████| 20/20 [00:07<00:00,  2.59it/s]


3.120060114093274


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:08<00:00,  2.35it/s]
100%|██████████| 20/20 [00:08<00:00,  2.35it/s]
100%|██████████| 20/20 [00:08<00:00,  2.35it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:07<00:00,  2.68it/s]
100%|██████████| 20/20 [00:07<00:00,  2.66it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


1.6427902479723842


100%|██████████| 20/20 [00:07<00:00,  2.55it/s]
100%|██████████| 20/20 [00:07<00:00,  2.64it/s]
100%|██████████| 20/20 [00:07<00:00,  2.53it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]


5.5501550300758336


100%|██████████| 20/20 [00:08<00:00,  2.44it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]


4.771516485924833


100%|██████████| 20/20 [00:08<00:00,  2.43it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.39it/s]


11.895325623950809


100%|██████████| 20/20 [00:08<00:00,  2.45it/s]
100%|██████████| 20/20 [00:07<00:00,  2.51it/s]
100%|██████████| 20/20 [00:08<00:00,  2.45it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.43it/s]
100%|██████████| 20/20 [00:08<00:00,  2.42it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:07<00:00,  2.55it/s]
100%|██████████| 20/20 [00:08<00:00,  2.50it/s]


7.7451128073335065


100%|██████████| 20/20 [00:08<00:00,  2.43it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]
100%|██████████| 20/20 [00:08<00:00,  2.40it/s]


14.64521257217817


100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.43it/s]
100%|██████████| 20/20 [00:08<00:00,  2.41it/s]
100%|██████████| 20/20 [00:08<00:00,  2.38it/s]
100%|██████████| 20/20 [00:07<00:00,  2.51it/s]
100%|██████████| 20/20 [00:07<00:00,  2.55it/s]
100%|██████████| 20/20 [00:07<00:00,  2.56it/s]
100%|██████████| 20/20 [00:07<00:00,  2.56it/s]


12.792322454189081


100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.56it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.57it/s]
100%|██████████| 20/20 [00:07<00:00,  2.61it/s]
100%|██████████| 20/20 [00:07<00:00,  2.62it/s]
100%|██████████| 20/20 [00:07<00:00,  2.60it/s]
100%|██████████| 20/20 [00:07<00:00,  2.59it/s]


23.23955159180466


In [7]:
# To refresh the cache from the attack results computed above, run once:
# with open("weights/totals_copy.pkl", "wb") as totals_file:
#     pkl.dump(all_totals, totals_file)

# Reload the cached per-class totals; this replaces any in-memory `all_totals`.
# The context manager closes the file handle, which a bare open() call leaves
# to the garbage collector.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
with open("weights/totals_copy.pkl", "rb") as totals_file:
    all_totals = pkl.load(totals_file)

In [38]:
thre = 0.1
net = load_cifar10_resnet50()

# The weights are the same for every class, so extract them once instead of
# once per class.
# NOTE: eval() only sees attribute paths that parse_param derives from the
# model's own parameter names; never feed it untrusted strings.
param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                 for param in all_param_names]

# param_remove[name] ends up as a boolean array that is True wherever the
# element's score exceeds the threshold of at least one class.
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None

for class_totals in all_totals:
    # Element-wise score |total * weight| per parameter. This stays a plain
    # list: the per-parameter arrays differ in shape, and np.array() on such a
    # ragged list is deprecated and raises ValueError on NumPy >= 1.24.
    combine = [np.abs(class_totals[param] * weight)
               for param, weight in zip(all_param_names, param_weights)]
    combine_flatten = np.concatenate(
        [scores.flatten() for scores in combine], axis=0)
    # Score at descending rank int(N * thre); the index is clamped so that
    # thre == 1.0 does not run one past the end.
    cut = min(int(len(combine_flatten) * thre), len(combine_flatten) - 1)
    threshold = np.sort(combine_flatten)[::-1][cut]
    for param, scores in zip(all_param_names, combine):
        selected = scores > threshold
        if param_remove[param] is None:
            param_remove[param] = selected
        else:
            # Union with the masks gathered for the previous classes.
            param_remove[param] = param_remove[param] | selected


  combine = np.array(combine)


In [39]:
# Count the flagged entries overall and print the flagged share per parameter.
temp = 0
all_num = 0
for param, mask in param_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(param, mask.mean())


conv1.weight 0.9901620370370371
layer1.0.conv1.weight 0.863037109375
layer1.0.conv2.weight 0.6993815104166666
layer1.0.conv3.weight 0.79901123046875
layer1.0.shortcut.0.weight 0.77581787109375
layer1.1.conv1.weight 0.61553955078125
layer1.1.conv2.weight 0.7137586805555556
layer1.1.conv3.weight 0.76641845703125
layer1.2.conv1.weight 0.6905517578125
layer1.2.conv2.weight 0.7470160590277778
layer1.2.conv3.weight 0.697265625
layer2.0.conv1.weight 0.884246826171875
layer2.0.conv2.weight 0.6812811957465278
layer2.0.conv3.weight 0.758636474609375
layer2.0.shortcut.0.weight 0.66748046875
layer2.1.conv1.weight 0.4094390869140625
layer2.1.conv2.weight 0.5862630208333334
layer2.1.conv3.weight 0.677215576171875
layer2.2.conv1.weight 0.5778656005859375
layer2.2.conv2.weight 0.6144476996527778
layer2.2.conv3.weight 0.628021240234375
layer2.3.conv1.weight 0.64520263671875
layer2.3.conv2.weight 0.6556939019097222
layer2.3.conv3.weight 0.5778045654296875
layer3.0.conv1.weight 0.8221282958984375
layer3.

In [40]:
# Overall share of flagged entries: flagged count divided by the total number
# of elements, both accumulated in the previous cell.
temp / all_num


0.25380953282673197

In [41]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        # Resolve the tensor once and index it directly instead of exec()-ing
        # an assignment string.
        # NOTE: eval() only sees attribute paths that parse_param derives from
        # the model's own parameter names.
        weight = eval("net." + param_)
        # In this cell the mask marks the entries to keep: everything outside
        # it is zeroed.
        zero_mask = ~param_remove[param]
        try:
            weight[zero_mask] = 0
        except Exception:
            # Fallback indexing carried over from the original cell. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            weight[zero_mask, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.1036


In [48]:
thre = 0.2          # top fraction considered for the chosen class
other_thre = 0.2    # top fraction considered for every other class
choose_class = 3
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None

# Visit every other class first and the chosen class last, so the final step
# can subtract what the other classes already claimed.
all_classes = list(range(len(all_totals)))
all_classes.remove(choose_class)
all_classes.append(choose_class)
print(all_classes)
last_position = len(all_classes) - 1

# The weights are the same for every class, so extract them once.
# NOTE: eval() only sees attribute paths that parse_param derives from the
# model's own parameter names; never feed it untrusted strings.
param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                 for param in all_param_names]

for i, class_ in enumerate(all_classes):
    is_chosen = i == last_position
    class_totals = all_totals[class_]
    # Element-wise score |total * weight| per parameter. This stays a plain
    # list: the per-parameter arrays differ in shape, and np.array() on such a
    # ragged list is deprecated and raises ValueError on NumPy >= 1.24.
    combine = [np.abs(class_totals[param] * weight)
               for param, weight in zip(all_param_names, param_weights)]
    combine_flatten = np.concatenate([scores.flatten() for scores in combine], axis=0)
    fraction = thre if is_chosen else other_thre
    # Score at descending rank int(N * fraction); clamped so a fraction of 1.0
    # does not index one past the end.
    cut = min(int(len(combine_flatten) * fraction), len(combine_flatten) - 1)
    threshold = np.sort(combine_flatten)[::-1][cut]
    for param, scores in zip(all_param_names, combine):
        selected = scores > threshold
        if param_remove[param] is None:
            param_remove[param] = selected
        elif is_chosen:
            # Keep only entries selected for the chosen class that no other
            # class selected.
            param_remove[param] = ~param_remove[param] & selected
        else:
            # Accumulate the union over the other classes.
            param_remove[param] = param_remove[param] | selected


[0, 1, 2, 4, 5, 6, 7, 8, 9, 3]


  combine = np.array(combine)


In [49]:
# Count the flagged entries overall and print the flagged share per parameter.
temp = 0
all_num = 0
for param, mask in param_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(param, mask.mean())


conv1.weight 0.0
layer1.0.conv1.weight 0.000244140625
layer1.0.conv2.weight 0.00035264756944444444
layer1.0.conv3.weight 0.00018310546875
layer1.0.shortcut.0.weight 0.00018310546875
layer1.1.conv1.weight 0.00054931640625
layer1.1.conv2.weight 0.0006510416666666666
layer1.1.conv3.weight 0.000244140625
layer1.2.conv1.weight 6.103515625e-05
layer1.2.conv2.weight 0.00010850694444444444
layer1.2.conv3.weight 0.00042724609375
layer2.0.conv1.weight 0.0
layer2.0.conv2.weight 0.00035264756944444444
layer2.0.conv3.weight 0.000213623046875
layer2.0.shortcut.0.weight 0.0004119873046875
layer2.1.conv1.weight 0.0006103515625
layer2.1.conv2.weight 0.00035942925347222225
layer2.1.conv3.weight 0.0006256103515625
layer2.2.conv1.weight 0.000152587890625
layer2.2.conv2.weight 0.00023057725694444444
layer2.2.conv3.weight 0.0003204345703125
layer2.3.conv1.weight 0.0001373291015625
layer2.3.conv2.weight 0.00031195746527777775
layer2.3.conv3.weight 0.00067138671875
layer3.0.conv1.weight 0.00040435791015625
la

In [50]:
# Overall share of flagged entries: flagged count divided by the total number
# of elements, both accumulated in the previous cell.
temp / all_num


0.0018371038423639942

In [33]:
# Baseline: accuracy of a freshly loaded, unmodified network on test_dataloader_all.
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    correct = preds.argmax(-1) == labels
    print("原始准确率", correct.mean())


原始准确率 0.954


In [33]:
from sklearn.metrics import confusion_matrix
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        # Resolve the tensor once and index it directly instead of exec()-ing
        # an assignment string.
        # NOTE: eval() only sees attribute paths that parse_param derives from
        # the model's own parameter names.
        weight = eval("net." + param_)
        # In this cell the mask marks the entries to remove: they are zeroed.
        zero_mask = param_remove[param]
        try:
            weight[zero_mask] = 0
        except Exception:
            # Fallback indexing carried over from the original cell. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            weight[zero_mask, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())
    # Rows are true labels, columns are predicted labels.
    print(confusion_matrix(labels,preds.argmax(-1)))


现在准确率 0.8106
[[724  11   4  38   1   0 192   0  28   2]
 [  0 983   0   3   0   0  10   0   1   3]
 [  5   1 810  43   6   0 132   0   3   0]
 [  1   3   3 892   7   4  87   0   3   0]
 [  0   0   9  24 838   0 124   1   4   0]
 [  2   8  11 324   9 466 177   0   3   0]
 [  0   0   1   9   0   0 990   0   0   0]
 [  1   3   8  48   8   0 208 723   1   0]
 [  3   9   1   5   0   0  51   0 931   0]
 [  1 118   0  11   1   0 108   0  12 749]]


In [35]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        # Share of entries flagged in param_remove for this parameter; the
        # comparison keeps the same share of weights and zeroes the rest.
        keep_rate = param_remove[param].sum() / param_remove[param].size
        # NOTE: eval() only sees attribute paths that parse_param derives from
        # the model's own parameter names.
        weight = eval("net." + param_)
        weight_np = weight.cpu().detach().numpy()
        weight_sorted = np.sort(weight_np.flatten())
        first_kept = int(len(weight_sorted) * (1 - keep_rate))
        if first_kept >= len(weight_sorted):
            # keep_rate == 0: the threshold lookup would index one past the end
            # (IndexError). Nothing is kept, so every entry is zeroed.
            zero_mask = np.ones(weight_np.shape, dtype=bool)
        else:
            threshold = weight_sorted[first_kept]
            # NOTE(review): this compares signed values, not magnitudes, so all
            # negative weights fall below the threshold — confirm this is the
            # intended baseline.
            zero_mask = weight_np < threshold
        try:
            weight[zero_mask] = 0
        except Exception:
            # Fallback indexing carried over from the original cell. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            weight[zero_mask, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.1
