In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model, parse_param
import random


In [4]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [5]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [6]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)


In [7]:
all_param_names = all_param_names[:-2]


In [8]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()
# all_totals = list()
# for i in range(10):
#     all_totals.append(attack(train_loaders[i], all_param_names,
#                       load_cifar10_resnet50,train_dataloader_all, alpha=0.00001, num_steps=4, op="add"))


Files already downloaded and verified
Files already downloaded and verified


In [9]:
# pkl.dump(all_totals, open("weights/totals.pkl", "wb"))
all_totals = pkl.load(open("weights/totals.pkl", "rb"))

In [10]:
thre = 0.25
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight)
               for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate(
        [combine_.flatten() for combine_ in combine], axis=0)
    threshold = np.sort(combine_flatten)[
        ::-1][int(len(combine_flatten) * thre)]
    for idx, param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t


  combine = np.array(combine)


In [11]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.943359375
layer1.0.conv2.weight 0.8700900607638888
layer1.0.conv3.weight 0.89874267578125
layer1.0.shortcut.0.weight 0.8687744140625
layer1.1.conv1.weight 0.8094482421875
layer1.1.conv2.weight 0.8846842447916666
layer1.1.conv3.weight 0.8861083984375
layer1.2.conv1.weight 0.84033203125
layer1.2.conv2.weight 0.870361328125
layer1.2.conv3.weight 0.84906005859375
layer2.0.conv1.weight 0.963531494140625
layer2.0.conv2.weight 0.8958875868055556
layer2.0.conv3.weight 0.9161529541015625
layer2.0.shortcut.0.weight 0.8838653564453125
layer2.1.conv1.weight 0.7527923583984375
layer2.1.conv2.weight 0.8577677408854166
layer2.1.conv3.weight 0.86981201171875
layer2.2.conv1.weight 0.8435211181640625
layer2.2.conv2.weight 0.8680826822916666
layer2.2.conv3.weight 0.8392791748046875
layer2.3.conv1.weight 0.88409423828125
layer2.3.conv2.weight 0.8778483072916666
layer2.3.conv3.weight 0.791015625
layer3.0.conv1.weight 0.9491806030273438
layer3.0.conv2.

In [12]:
temp / all_num


0.4675693489107797

In [13]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.954


In [14]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.9446


In [15]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval(
            "net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(
            len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.1
