In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
from attack import attack, test_model


In [2]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [3]:
all_param_names = list()
for name, param in model.named_parameters():
    
    name = name.replace(".0.", "[0].")
    name = name.replace(".1.", "[1].")
    name = name.replace(".2.", "[2].")
    name = name.replace(".3.", "[3].")
    name = name.replace(".4.", "[4].")
    name = name.replace(".5.", "[5].")
    name = name.replace(".6.", "[6].")
    name = name.replace(".7.", "[7].")
    name = name.replace(".8.", "[8].")
    name = name.replace(".9.", "[9].") 
    if name.endswith("weight"):
        all_param_names.append(name[:-7])
all_param_names

['conv1',
 'bn1',
 'layer1[0].conv1',
 'layer1[0].bn1',
 'layer1[0].conv2',
 'layer1[0].bn2',
 'layer1[0].conv3',
 'layer1[0].bn3',
 'layer1[0].shortcut[0]',
 'layer1[0].shortcut[1]',
 'layer1[1].conv1',
 'layer1[1].bn1',
 'layer1[1].conv2',
 'layer1[1].bn2',
 'layer1[1].conv3',
 'layer1[1].bn3',
 'layer1[2].conv1',
 'layer1[2].bn1',
 'layer1[2].conv2',
 'layer1[2].bn2',
 'layer1[2].conv3',
 'layer1[2].bn3',
 'layer2[0].conv1',
 'layer2[0].bn1',
 'layer2[0].conv2',
 'layer2[0].bn2',
 'layer2[0].conv3',
 'layer2[0].bn3',
 'layer2[0].shortcut[0]',
 'layer2[0].shortcut[1]',
 'layer2[1].conv1',
 'layer2[1].bn1',
 'layer2[1].conv2',
 'layer2[1].bn2',
 'layer2[1].conv3',
 'layer2[1].bn3',
 'layer2[2].conv1',
 'layer2[2].bn1',
 'layer2[2].conv2',
 'layer2[2].bn2',
 'layer2[2].conv3',
 'layer2[2].bn3',
 'layer2[3].conv1',
 'layer2[3].bn1',
 'layer2[3].conv2',
 'layer2[3].bn2',
 'layer2[3].conv3',
 'layer2[3].bn3',
 'layer3[0].conv1',
 'layer3[0].bn1',
 'layer3[0].conv2',
 'layer3[0].bn2',
 'la

In [4]:
for name in all_param_names:
    print(eval("model." + name + ".weight.shape"))

torch.Size([64, 3, 3, 3])
torch.Size([64])
torch.Size([64, 64, 1, 1])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([256, 64, 1, 1])
torch.Size([256])
torch.Size([256, 64, 1, 1])
torch.Size([256])
torch.Size([64, 256, 1, 1])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([256, 64, 1, 1])
torch.Size([256])
torch.Size([64, 256, 1, 1])
torch.Size([64])
torch.Size([64, 64, 3, 3])
torch.Size([64])
torch.Size([256, 64, 1, 1])
torch.Size([256])
torch.Size([128, 256, 1, 1])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([512, 128, 1, 1])
torch.Size([512])
torch.Size([512, 256, 1, 1])
torch.Size([512])
torch.Size([128, 512, 1, 1])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([512, 128, 1, 1])
torch.Size([512])
torch.Size([128, 512, 1, 1])
torch.Size([128])
torch.Size([128, 128, 3, 3])
torch.Size([128])
torch.Size([512, 128, 1, 1])
torch.Size([512])
torch.Size([128, 512, 1, 1])
torch.Si

In [5]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar10()
all_totals = list()
# for i in range(10):
all_totals.append(attack(train_dataloader_all, ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"], load_cifar10_resnet50, alpha=0.0001))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 196/196 [01:21<00:00,  2.39it/s]


0.0005741859122738243


100%|██████████| 196/196 [01:20<00:00,  2.44it/s]


0.0016722585937380791


100%|██████████| 196/196 [01:20<00:00,  2.43it/s]


0.013022370172739028


100%|██████████| 196/196 [01:20<00:00,  2.43it/s]


0.1361942886543274


100%|██████████| 196/196 [01:20<00:00,  2.42it/s]

0.6002613494110107





In [6]:
all_totals


[{'layer3[0].conv2': array([[[[0.04758077, 0.06342641, 0.06326441],
           [0.06420432, 0.07797959, 0.07982474],
           [0.05798601, 0.07569326, 0.08194256]],
  
          [[0.06142694, 0.0597029 , 0.07775103],
           [0.06908931, 0.07178441, 0.10876542],
           [0.06681299, 0.07226577, 0.11651919]],
  
          [[0.03806551, 0.03091723, 0.04360677],
           [0.04734807, 0.04001633, 0.05158048],
           [0.05294424, 0.05357284, 0.06962128]],
  
          ...,
  
          [[0.02719614, 0.01592337, 0.0172644 ],
           [0.02862031, 0.0262914 , 0.02412566],
           [0.02905377, 0.03771427, 0.03285838]],
  
          [[0.10237086, 0.11564587, 0.11009045],
           [0.12811995, 0.14845905, 0.12963414],
           [0.13323158, 0.15327935, 0.12508966]],
  
          [[0.04020063, 0.04502914, 0.05525846],
           [0.06773664, 0.06980892, 0.07862065],
           [0.0705502 , 0.07311662, 0.08450375]]],
  
  
         [[[0.02865841, 0.03064635, 0.0124202 ],
    

In [7]:
thre = 0.4
net = load_cifar10_resnet50()
layer_remove = dict()
for layer in ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]:
    layer_remove[layer] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[layer] for layer in ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]]
    layer_weights = [eval("net." + layer + ".weight.cpu().detach().numpy()")
                     for layer in ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]]
    combine = [np.abs(total * weight) for total, weight in zip(totals, layer_weights)]
    combine = np.array(combine)
    combine_flatten = combine.flatten()
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,layer in enumerate(["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]):
        if layer_remove[layer] is None:
            layer_remove[layer] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            layer_remove[layer] = layer_remove[layer] | t
    
# for i in range(len(totals)):
#     combine = np.abs(totals[i][layer] * eval("net." + layer + ".weight.cpu().detach().numpy()"))
#     threshold = np.sort(combine.flatten())[::-1][int(len(combine.flatten()) * thre)]
#     if remove is None:
#         remove = combine > threshold
#     else:
#         t = combine > threshold
#         remove = remove | t
# print(remove.sum())
# keep_rate = remove.sum() / remove.size
# print("保留的", keep_rate)


In [8]:
temp = 0
all_num = 0
for layer in layer_remove:
    temp += layer_remove[layer].sum()
    all_num += layer_remove[layer].size
    print(layer, layer_remove[layer].mean())

layer3[0].conv2 0.4185468885633681
layer3[1].conv2 0.2676408555772569
layer3[5].conv2 0.5138108995225694


In [9]:
temp / all_num

0.39999954788773145

In [10]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    correct, all = test_model(net, test_dataloader_all)
    print("原始准确率", correct / all)


原始准确率 0.954


In [11]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for layer in ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]:
        exec("net." + layer + ".weight[~layer_remove[layer]] = 0")
    # for layer in ["layer3[0].conv2"]:
    #     exec("net." + layer + ".weight[~remove] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("现在准确率", correct / all)


现在准确率 0.9229


In [12]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for layer in ["layer3[0].conv2", "layer3[1].conv2", "layer3[5].conv2"]:
        keep_rate = layer_remove[layer].sum() / layer_remove[layer].size
        weight_flatten = eval("net." + layer + ".weight.cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        exec("net." + layer + ".weight[eval('net.' + layer + '.weight.cpu().detach().numpy()') < threshold] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("去掉最大准确率", correct / all)
    # for layer in ["layer3[0].conv2"]:
    #     threshold = np.sort(eval("net." + layer + ".weight.cpu().detach().numpy()").flatten())[
    #         int(len(eval("net." + layer +
    #             ".weight.cpu().detach().numpy()").flatten()) * (1 - keep_rate))
    #     ]
    #     exec("net." + layer +
    #          ".weight[eval('net.' + layer + '.weight.cpu().detach().numpy()') < threshold] = 0")
    # correct, all = test_model(net, test_dataloader_all)
    # print("现在准确率", correct / all)


去掉最大准确率 0.1156
