In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [2]:
SEED = 3407


def setup_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and torch's CPU and CUDA
    generators, then configures cuDNN for reproducible results.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms and may select
    # non-deterministic kernels; turn it off explicitly so `deterministic`
    # actually yields reproducible runs.
    torch.backends.cudnn.benchmark = False


setup_seed(SEED)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
# Reference CIFAR-10 ResNet-50 instance; in the visible notebook it is only
# used to enumerate parameter names in the next cell.
model = load_cifar10_resnet50()


In [4]:
# Names of the parameters considered for pruning, in model order.
# Skips names containing "bn" (presumably BatchNorm) and "shortcut.1"
# (presumably the BatchNorm of the projection shortcut; shortcut.0 is its conv),
# then drops the final two entries — assumed to be the classifier's weight and
# bias; TODO confirm against the model definition.
all_param_names = [
    name
    for name, _ in model.named_parameters()
    if "bn" not in name and "shortcut.1" not in name
][:-2]

In [5]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()

# One attack result per class; later cells look entries up as all_totals[class_id].
all_totals = [
    attack(
        train_loaders[class_id],
        all_param_names,
        load_cifar10_resnet50,
        train_dataloader_all,
        alpha=0.00001,
        num_steps=4,
        op="add",
    )
    for class_id in range(10)
]


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.02401313738822937


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.3929096221923828


100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


2.0104552856445315


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.00046284052580595017


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


0.009428091919422149


100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


0.16736998748779297


100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


1.0400620269775391


100%|██████████| 20/20 [00:06<00:00,  2.92it/s]


0.0005848605498671532


100%|██████████| 20/20 [00:06<00:00,  2.89it/s]


0.015827607131004333


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.2192234992980957


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


1.2632999237060547


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.0007869976699352264


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.02680395531654358


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.40782768859863283


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


2.0474692810058595


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.0006073615729808808


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.011672097051143646


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.18673147296905518


100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


1.1935657745361328


100%|██████████| 20/20 [00:06<00:00,  2.89it/s]


0.0005955995425581932


100%|██████████| 20/20 [00:06<00:00,  2.88it/s]


0.018224787282943727


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


0.304492170715332


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


1.718993899536133


100%|██████████| 20/20 [00:07<00:00,  2.84it/s]


0.0006207808420062065


100%|██████████| 20/20 [00:07<00:00,  2.85it/s]


0.011070780754089356


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


0.20569087142944337


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


1.476642462158203


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


0.00047607889324426654


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


0.011722956812381745


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


0.1806321975708008


100%|██████████| 20/20 [00:07<00:00,  2.77it/s]


1.160956201171875


100%|██████████| 20/20 [00:07<00:00,  2.82it/s]


0.0004115782298147678


100%|██████████| 20/20 [00:07<00:00,  2.79it/s]


0.010544170099496842


100%|██████████| 20/20 [00:07<00:00,  2.81it/s]


0.22035634117126465


100%|██████████| 20/20 [00:07<00:00,  2.80it/s]


1.4476825134277345


100%|██████████| 20/20 [00:07<00:00,  2.78it/s]


0.0005323766142129898


100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


0.016743292498588563


100%|██████████| 20/20 [00:07<00:00,  2.67it/s]


0.3350339324951172


100%|██████████| 20/20 [00:07<00:00,  2.66it/s]


1.9058428802490235


In [6]:
# Persist the per-class attack statistics. The context manager guarantees the
# file is flushed and closed even if pickling raises.
with open("weights/high_totals.pkl", "wb") as f:
    pkl.dump(all_totals, f)

In [37]:
# Build one boolean mask per parameter (param_remove[name]) selecting weights
# whose score |attack total * weight| is in the top `thre` fraction for
# `choose_class` but NOT in the top `other_thre` fraction for any other class.
thre = 0.2  # top fraction selected for the chosen class
other_thre = 0.2  # top fraction counted as "important" for each other class
choose_class = 0
net = load_cifar10_resnet50()


def _top_fraction_threshold(values, frac):
    """Return the element at 0-based descending rank ``int(len(values) * frac)``.

    Same value as ``np.sort(values)[::-1][int(len(values) * frac)]``, but uses an
    O(n) partial sort instead of a full O(n log n) sort.

    Raises:
        ValueError: if the requested rank falls outside the array.
    """
    n = len(values)
    k = int(n * frac)
    if not 0 <= k < n:
        raise ValueError(f"fraction {frac} selects rank {k}, outside [0, {n})")
    pos = n - 1 - k  # descending rank k == ascending position n-1-k
    return np.partition(values, pos)[pos]


# Weights do not depend on the class, so read them once. Looking them up by
# name avoids eval() on constructed attribute paths.
named_params = dict(net.named_parameters())
param_weights = [named_params[param].detach().cpu().numpy() for param in all_param_names]

# Visit every other class first and the chosen class last, so the final step can
# subtract the accumulated "important for others" mask.
all_classes = list(range(10))
all_classes.remove(choose_class)
all_classes.append(choose_class)
print(all_classes)

param_remove = {param: None for param in all_param_names}
for class_ in all_classes:
    is_chosen = class_ == choose_class
    totals = [all_totals[class_][param] for param in all_param_names]
    # Keep as a list: parameters have different shapes, and np.array() on such a
    # ragged list warns on older NumPy and raises ValueError on NumPy >= 1.24.
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine_flatten = np.concatenate([score.ravel() for score in combine])
    threshold = _top_fraction_threshold(combine_flatten, thre if is_chosen else other_thre)
    for param, score in zip(all_param_names, combine):
        t = score > threshold
        if param_remove[param] is None:
            param_remove[param] = t
        elif is_chosen:
            # important for the chosen class AND not important for any other class
            param_remove[param] = ~param_remove[param] & t
        else:
            # union of the "important" masks of the other classes
            param_remove[param] = param_remove[param] | t

[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]


  combine = np.array(combine)


In [38]:
# NOTE(review): disabled alternative kept for reference; it is never executed.
# It uses a separate threshold per class (`thres`) and iterates classes in natural
# order, so class 9 is the "chosen" class handled by the `i == 9` branch.
# If revived, drop the `np.array(combine)` line: parameter shapes are ragged, and
# NumPy >= 1.24 raises ValueError on that conversion.
# # thre = 0.1
# thres = [0.15,0.15,0.15,0.15,0.15,0.15,0.15,0.15,0.15,0.4]
# net = load_cifar10_resnet50()
# param_remove = dict()
# for param in all_param_names:
#     param_remove[param] = None
# for i in range(len(all_totals)):
#     thre = thres[i]
#     totals = all_totals[i]
#     totals = [totals[param] for param in all_param_names]
#     param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
#                      for param in all_param_names]
#     combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
#     combine = np.array(combine)
#     combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
#     threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
#     for idx,param in enumerate(all_param_names):
#         if param_remove[param] is None:
#             param_remove[param] = combine[idx] > threshold
#         else:
#             t = combine[idx] > threshold
#             if i == 9:
#                 param_remove[param] = ~param_remove[param] & t
#             else:
#                 param_remove[param] = param_remove[param] | t

In [39]:
# Tally selected weights overall (temp) and in total (all_num), and print the
# selected fraction of each parameter. `temp`/`all_num` are read by a later cell.
temp, all_num = 0, 0
for name, mask in param_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(name, mask.mean())

conv1.weight 0.0023148148148148147
layer1.0.conv1.weight 0.08056640625
layer1.0.conv2.weight 0.1787651909722222
layer1.0.conv3.weight 0.09033203125
layer1.0.shortcut.0.weight 0.1180419921875
layer1.1.conv1.weight 0.20196533203125
layer1.1.conv2.weight 0.15833875868055555
layer1.1.conv3.weight 0.126953125
layer1.2.conv1.weight 0.18035888671875
layer1.2.conv2.weight 0.15733506944444445
layer1.2.conv3.weight 0.1873779296875
layer2.0.conv1.weight 0.053558349609375
layer2.0.conv2.weight 0.15193006727430555
layer2.0.conv3.weight 0.1154937744140625
layer2.0.shortcut.0.weight 0.16058349609375
layer2.1.conv1.weight 0.33868408203125
layer2.1.conv2.weight 0.2028469509548611
layer2.1.conv3.weight 0.1745452880859375
layer2.2.conv1.weight 0.2266693115234375
layer2.2.conv2.weight 0.18778483072916666
layer2.2.conv3.weight 0.2118377685546875
layer2.3.conv1.weight 0.167572021484375
layer2.3.conv2.weight 0.1712646484375
layer2.3.conv3.weight 0.2668304443359375
layer3.0.conv1.weight 0.0747222900390625
lay

In [40]:
# Target fraction (0.2) minus the fraction of weights actually selected.
removed_fraction = temp / all_num
0.2 - removed_fraction

-0.3980436837917584

In [41]:
from sklearn.metrics import confusion_matrix
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    # Predicted class per sample, computed once and reused below.
    pred_labels = preds.argmax(-1)
    print("原始准确率", (pred_labels == labels).mean())
    print(confusion_matrix(labels, pred_labels))
    # Print per-class accuracy
    for i in range(10):
        print(f"类别{i}准确率", (pred_labels[labels == i] == i).mean())


原始准确率 0.954
[[958   4   7   3   1   0   2   1  19   5]
 [  1 977   1   1   0   0   0   1   1  18]
 [  8   0 948  10  10   9  10   2   3   0]
 [  5   1  17 889  11  53  17   3   4   0]
 [  1   0   8  11 962   5   2  11   0   0]
 [  3   0   6  50   8 922   4   7   0   0]
 [  1   0   5   7   3   1 982   0   0   1]
 [  3   0   3   6   8   8   0 971   1   0]
 [ 14   6   3   1   0   1   0   0 969   6]
 [  5  23   1   2   0   1   0   0   6 962]]
类别0准确率 0.958
类别1准确率 0.977
类别2准确率 0.948
类别3准确率 0.889
类别4准确率 0.962
类别5准确率 0.922
类别6准确率 0.982
类别7准确率 0.971
类别8准确率 0.969
类别9准确率 0.962


In [42]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    # Zero the selected weights in place. Parameters are looked up by name rather
    # than exec()-ing attribute paths, and a mask/weight shape mismatch fails
    # loudly instead of being retried with different indexing under a bare
    # `except:`.
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = named_params[param]
        mask = torch.as_tensor(param_remove[param], dtype=torch.bool, device=weight.device)
        if mask.shape != weight.shape:
            raise ValueError(
                f"mask shape {tuple(mask.shape)} != weight shape {tuple(weight.shape)} for {param}"
            )
        weight[mask] = 0
    preds, labels = test_model(net, test_dataloader_all)
    # Predicted class per sample, computed once and reused below.
    pred_labels = preds.argmax(-1)
    print("现在准确率", (pred_labels == labels).mean())
    print(confusion_matrix(labels, pred_labels))
    for i in range(10):
        print(f"类别{i}准确率", (pred_labels[labels == i] == i).mean())

现在准确率 0.9201
[[765  16   6  34  23   9  44  35  30  38]
 [  0 982   0   0   2   0   1   2   1  12]
 [  4   2 863  11  36  16  53  10   3   2]
 [  0   5   4 809  27  73  53  14   4  11]
 [  0   0   0   2 973   3  11  11   0   0]
 [  0   1   4  30  21 919  11  12   0   2]
 [  0   0   0   1   4   0 994   0   0   1]
 [  0   1   1   2  10   7   1 975   0   3]
 [  2  11   0   2   2   1  10   2 960  10]
 [  1  27   0   2   1   0   3   1   4 961]]
类别0准确率 0.765
类别1准确率 0.982
类别2准确率 0.863
类别3准确率 0.809
类别4准确率 0.973
类别5准确率 0.919
类别6准确率 0.994
类别7准确率 0.975
类别8准确率 0.96
类别9准确率 0.961
