In [1]:
import os
import pickle as pkl
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, TensorDataset

from attack import attack, test_model, parse_param

In [2]:
def setup_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python ``random``, NumPy's global RNG and
            PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # `deterministic=True` alone is not sufficient: with `benchmark=True` the
    # cuDNN autotuner may select different (non-deterministic) kernels per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setup_seed(3407)

In [3]:
# Project-local data and model factories (only the CIFAR-100 variants are used below).
from datasets import (
    load_cifar10,
    load_cifar100,
)
from models.resnet import (
    load_cifar10_resnet50,
    load_cifar100_resnet50,
)

# Reference network; below it is only used to enumerate parameter names.
model = load_cifar100_resnet50()


In [4]:
# Names of the parameters eligible for masking: skip names containing "bn" or
# "shortcut.1" (presumably BatchNorm layers — TODO confirm against models.resnet),
# then drop the last two entries (presumably the classifier weight/bias — confirm).
all_param_names = [
    name
    for name, _ in model.named_parameters()
    if "bn" not in name and "shortcut.1" not in name
][:-2]

In [5]:
train_loaders, test_dataloaders, test_dataloader_all = load_cifar100()
# One attack() result per class, each computed on that class's training loader.
all_totals = [
    attack(
        train_loaders[class_idx],
        all_param_names,
        load_cifar100_resnet50,
        alpha=0.0001,
        num_steps=5,
        op="add",
    )
    for class_idx in range(100)
]


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


13.35895234375


100%|██████████| 20/20 [00:06<00:00,  2.88it/s]


57.21819111328125


100%|██████████| 20/20 [00:06<00:00,  2.89it/s]


258.32373125


100%|██████████| 20/20 [00:07<00:00,  2.85it/s]


1193.916146875


  param_totals = np.array(param_totals)
100%|██████████| 20/20 [00:07<00:00,  2.77it/s]


0.0004628417499363422


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


11.755080200195312


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


45.515598828125


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


185.044208984375


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


835.832775


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.0005848916381597519


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


10.5305044921875


100%|██████████| 20/20 [00:07<00:00,  2.68it/s]


33.2222689453125


100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


152.38730234375


100%|██████████| 20/20 [00:07<00:00,  2.63it/s]


658.150940625


100%|██████████| 20/20 [00:07<00:00,  2.64it/s]


0.0007870208606123925


100%|██████████| 20/20 [00:07<00:00,  2.66it/s]


15.081411669921875


100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


77.0086833984375


100%|██████████| 20/20 [00:07<00:00,  2.64it/s]


323.35151328125


100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


1453.376259375


100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


0.0006073515847325325


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]


11.401686108398437


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]


44.97334072265625


100%|██████████| 20/20 [00:07<00:00,  2.59it/s]


197.41837109375


100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


940.188328125


100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


0.0005956038966774941


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]


12.052161889648438


100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


36.23437021484375


100%|██████████| 20/20 [00:07<00:00,  2.55it/s]


126.497278125


100%|██████████| 20/20 [00:07<00:00,  2.56it/s]


515.428834375


100%|██████████| 20/20 [00:07<00:00,  2.54it/s]


0.0006207873910665512


100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


11.935619506835938


100%|██████████| 20/20 [00:07<00:00,  2.52it/s]


35.94758642578125


100%|██████████| 20/20 [00:07<00:00,  2.56it/s]


155.9597890625


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]


860.8029546875


100%|██████████| 20/20 [00:07<00:00,  2.64it/s]


0.0004760837726294994


100%|██████████| 20/20 [00:07<00:00,  2.61it/s]


10.379177197265625


100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


33.36720546875


100%|██████████| 20/20 [00:07<00:00,  2.58it/s]


118.5051861328125


100%|██████████| 20/20 [00:07<00:00,  2.67it/s]


489.364015625


100%|██████████| 20/20 [00:07<00:00,  2.56it/s]


0.000411579966545105


100%|██████████| 20/20 [00:08<00:00,  2.30it/s]


9.783391284179688


100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


36.93264345703125


100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


162.01926171875


100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


762.30031875


100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


0.0005323766872286797


100%|██████████| 20/20 [00:07<00:00,  2.63it/s]


11.345707299804687


100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


50.0259662109375


100%|██████████| 20/20 [00:07<00:00,  2.68it/s]


244.3629484375


100%|██████████| 20/20 [00:07<00:00,  2.68it/s]


1226.9912875


In [6]:
# Persist the per-class attack results to disk. The `with` block guarantees the
# file is flushed and closed; the directory is created if it does not exist yet.
os.makedirs("weights", exist_ok=True)
with open("weights/high_totals.pkl", "wb") as f:
    pkl.dump(all_totals, f)

In [7]:
# Build a boolean mask per parameter: True for weights whose score
# |total * weight| is above the top-`thre` threshold for `choose_class` while
# being above the top-`other_thre` threshold for none of the other classes.
thre = 0.2        # approx. top fraction of scores counted as important for the target class
other_thre = 0.2  # approx. top fraction of scores counted as important for each other class
choose_class = 0
net = load_cifar100_resnet50()


def _top_fraction_threshold(values, fraction):
    """Return the ``int(len(values) * fraction)``-th largest entry (0-based) of a 1-D array.

    Same value as ``np.sort(values)[::-1][k]`` but computed with ``np.partition``
    (linear time), and ``k`` is clamped so ``fraction >= 1`` does not raise IndexError.
    """
    k = min(int(len(values) * fraction), len(values) - 1)
    pos = len(values) - 1 - k  # k-th largest == (n-1-k)-th smallest
    return np.partition(values, pos)[pos]


# The weights do not depend on the class, so read them once (by name, without eval).
named_params = dict(net.named_parameters())
param_weights = [named_params[param].cpu().detach().numpy() for param in all_param_names]

num_classes = len(all_totals)
assert 0 <= choose_class < num_classes, "choose_class must index into all_totals"
# Visit the target class last so its mask is combined with the union of all others.
all_classes = [c for c in range(num_classes) if c != choose_class] + [choose_class]
print(all_classes)

param_remove = {param: None for param in all_param_names}
for class_ in all_classes:
    # Identify the target by class id rather than a fixed position (the old
    # `i == 9` check was only correct when there were exactly 10 classes).
    is_target = class_ == choose_class
    totals = all_totals[class_]
    # Keep per-layer scores as a plain list: the arrays have different shapes and
    # np.array() on ragged input raises ValueError on NumPy >= 1.24.
    combine = [np.abs(totals[param] * weight)
               for param, weight in zip(all_param_names, param_weights)]
    combine_flatten = np.concatenate([score.ravel() for score in combine])
    threshold = _top_fraction_threshold(combine_flatten, thre if is_target else other_thre)
    for param, score in zip(all_param_names, combine):
        important = score > threshold
        if param_remove[param] is None:
            param_remove[param] = important
        elif is_target:
            # Important for the target class AND not important for any other class.
            param_remove[param] = ~param_remove[param] & important
        else:
            # Accumulate the union of weights important to any non-target class.
            param_remove[param] = param_remove[param] | important

[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]


  combine = np.array(combine)


In [9]:
# Tally masked weights overall and print the masked fraction of each parameter.
temp = 0     # number of weights selected by the mask (used below)
all_num = 0  # number of weights covered by the mask (used below)
for param, mask in param_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(param, mask.mean())

conv1.weight 0.0
layer1.0.conv1.weight 0.00244140625
layer1.0.conv2.weight 0.003228081597222222
layer1.0.conv3.weight 0.00225830078125
layer1.0.shortcut.0.weight 0.00396728515625
layer1.1.conv1.weight 0.008056640625
layer1.1.conv2.weight 0.008870442708333334
layer1.1.conv3.weight 0.00341796875
layer1.2.conv1.weight 0.0086669921875
layer1.2.conv2.weight 0.009412977430555556
layer1.2.conv3.weight 0.0018310546875
layer2.0.conv1.weight 0.001922607421875
layer2.0.conv2.weight 0.005438910590277778
layer2.0.conv3.weight 0.0023345947265625
layer2.0.shortcut.0.weight 0.00354766845703125
layer2.1.conv1.weight 0.0128936767578125
layer2.1.conv2.weight 0.005093044704861111
layer2.1.conv3.weight 0.005706787109375
layer2.2.conv1.weight 0.008941650390625
layer2.2.conv2.weight 0.004564073350694444
layer2.2.conv3.weight 0.00360107421875
layer2.3.conv1.weight 0.00469970703125
layer2.3.conv2.weight 0.002312554253472222
layer2.3.conv3.weight 0.004974365234375
layer3.0.conv1.weight 0.0018768310546875
layer3

In [10]:
# Difference between 0.2 and the fraction of weights actually masked.
# NOTE(review): 0.2 presumably mirrors `thre` / `other_thre` — confirm.
removed_fraction = temp / all_num
0.2 - removed_fraction

0.18667514357345039

In [11]:
from sklearn.metrics import confusion_matrix
# Evaluate the unmodified network: overall accuracy, confusion matrix, per-class accuracy.
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    pred_classes = preds.argmax(-1)  # computed once, reused below
    print("原始准确率", (pred_classes == labels).mean())
    print(confusion_matrix(labels, pred_classes))
    # Per-class accuracy; the class count comes from the prediction width instead
    # of a hard-coded 100, so it stays correct for CIFAR-10 and CIFAR-100 models.
    for i in range(preds.shape[-1]):
        print(f"类别{i}准确率", (pred_classes[labels == i] == i).mean())


原始准确率 0.954
[[958   4   7   3   1   0   2   1  19   5]
 [  1 977   1   1   0   0   0   1   1  18]
 [  8   0 948  10  10   9  10   2   3   0]
 [  5   1  17 889  11  53  17   3   4   0]
 [  1   0   8  11 962   5   2  11   0   0]
 [  3   0   6  50   8 922   4   7   0   0]
 [  1   0   5   7   3   1 982   0   0   1]
 [  3   0   3   6   8   8   0 971   1   0]
 [ 14   6   3   1   0   1   0   0 969   6]
 [  5  23   1   2   0   1   0   0   6 962]]
类别0准确率 0.958
类别1准确率 0.977
类别2准确率 0.948
类别3准确率 0.889
类别4准确率 0.962
类别5准确率 0.922
类别6准确率 0.982
类别7准确率 0.971
类别8准确率 0.969
类别9准确率 0.962


In [12]:
# Zero out the masked weights of a fresh network and re-evaluate it.
with torch.no_grad():
    net = load_cifar100_resnet50()
    # Look parameters up by name instead of exec(): a boolean mask over the leading
    # dims already covers the old `[mask, :]` fallback, and a genuine shape mismatch
    # now raises immediately instead of being swallowed by a bare `except:`.
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = named_params[param]
        mask = torch.as_tensor(param_remove[param], dtype=torch.bool, device=weight.device)
        weight[mask] = 0
    preds, labels = test_model(net, test_dataloader_all)
    pred_classes = preds.argmax(-1)  # computed once, reused below
    print("现在准确率", (pred_classes == labels).mean())
    print(confusion_matrix(labels, pred_classes))
    # Per-class accuracy; class count derived from the prediction width.
    for i in range(preds.shape[-1]):
        print(f"类别{i}准确率", (pred_classes[labels == i] == i).mean())

现在准确率 0.9535
[[939   3  13   6   3   0   1   1  27   7]
 [  1 977   1   1   0   0   0   1   1  18]
 [  5   0 953  13   8   7   8   2   4   0]
 [  3   1  18 901  10  45  11   5   6   0]
 [  1   0   8  11 962   5   2  11   0   0]
 [  3   0   8  56   5 917   3   8   0   0]
 [  1   0   8  12   3   1 974   0   0   1]
 [  1   0   3   6   9   7   0 973   1   0]
 [  7   5   3   2   0   1   0   0 978   4]
 [  3  21   2   2   0   0   0   0  11 961]]
类别0准确率 0.939
类别1准确率 0.977
类别2准确率 0.953
类别3准确率 0.901
类别4准确率 0.962
类别5准确率 0.917
类别6准确率 0.974
类别7准确率 0.973
类别8准确率 0.978
类别9准确率 0.961
