In [1]:
import os
import pickle as pkl
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from tqdm.notebook import tqdm

from attack import attack, parse_param, test_model
from utils import get_device
# Device chosen by the project helper `utils.get_device`.
# NOTE(review): `device` is not referenced by any cell shown here — confirm whether
# `attack` / `test_model` rely on it or whether it can be removed.
device = get_device()

In [2]:
def setup_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU and all
    CUDA devices). ``cudnn.deterministic`` alone is not enough: with
    ``cudnn.benchmark`` enabled, cuDNN may autotune and pick different
    convolution algorithms between runs, so benchmarking is disabled as well.

    Args:
        seed: integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setup_seed(3407)


In [3]:
from datasets import load_cifar100
from models.resnet import load_cifar100_resnet50
# Reference model. In the cells shown, it is only used to enumerate parameter
# names (next cell); every experiment below loads its own fresh copy.
model = load_cifar100_resnet50()


In [4]:
# Names of the weight tensors considered for pruning. Parameters whose name
# contains "bn" (BatchNorm) or "shortcut.1" (presumably the BatchNorm inside the
# projection shortcut — TODO confirm) are skipped.
all_param_names = [
    param_name
    for param_name, _ in model.named_parameters()
    if "bn" not in param_name and "shortcut.1" not in param_name
]
# Drop the last two entries — assumed to be the classifier head's weight and bias
# (TODO confirm against the model definition).
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar100()

# Hyper-parameters forwarded to `attack` for every class.
ATTACK_ALPHA = 1e-5
ATTACK_NUM_STEPS = 4

# One `attack` result per class loader, in class order. Results are appended one
# at a time, so if the loop is interrupted `all_totals` keeps the classes already
# completed (the pickle cell below would then save only those).
all_totals = list()
for class_idx in range(len(train_loaders)):  # one iteration per class loader
    all_totals.append(attack(train_loaders[class_idx], all_param_names,
                             load_cifar100_resnet50, norm=True, alpha=ATTACK_ALPHA,
                             num_steps=ATTACK_NUM_STEPS, op="add"))


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

0.0006626494631171226


  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
# Cache the per-class scores so the expensive attack loop need not be re-run.
# `with` guarantees the file is flushed and closed before anything reads it back.
os.makedirs("weights", exist_ok=True)
with open("weights/totals.pkl", "wb") as totals_file:
    pkl.dump(all_totals, totals_file)
# To restore instead of recomputing (only unpickle files you produced yourself —
# pickle.load can execute arbitrary code):
# with open("weights/totals.pkl", "rb") as totals_file:
#     all_totals = pkl.load(totals_file)


In [17]:
# Build, per parameter, a boolean mask of the weights to KEEP. A weight is kept if,
# for at least one class, its score |total * weight| is strictly above that class's
# cut-off (the score at rank `thre * n` in descending order). NOTE: despite its
# name, `param_remove` marks weights to keep — later cells zero where it is False.
thre = 0.25  # fraction of each class's scores that lies above the cut-off
net = load_cifar100_resnet50()
named_params = dict(net.named_parameters())
# Weights are identical for every class, so read them once, outside the class loop,
# by direct name lookup instead of eval on a generated attribute string.
param_weights = [named_params[param].detach().cpu().numpy() for param in all_param_names]
param_remove = {param: None for param in all_param_names}
for totals in all_totals:
    # Keep the per-parameter score arrays in a plain list: their shapes differ, and
    # np.array() on them builds a ragged object array (deprecated; an error on
    # NumPy >= 1.24).
    combine = [np.abs(totals[param] * weight)
               for param, weight in zip(all_param_names, param_weights)]
    combine_flatten = np.concatenate([score.ravel() for score in combine])
    n_scores = combine_flatten.size
    # Descending-order rank of the cut-off, clamped so thre == 1.0 cannot overflow.
    k = min(int(n_scores * thre), n_scores - 1)
    # The k-th largest value equals the (n-1-k)-th smallest; np.partition finds it
    # without fully sorting every score.
    cut = n_scores - 1 - k
    threshold = np.partition(combine_flatten, cut)[cut]
    for param, score in zip(all_param_names, combine):
        keep = score > threshold
        previous = param_remove[param]
        # Union across classes: kept if important for any class seen so far.
        param_remove[param] = keep if previous is None else previous | keep


  combine = np.array(combine)


In [18]:
# Print the kept fraction of each parameter's mask and accumulate the overall
# kept count (`temp`) and total entry count (`all_num`) for the next cell.
temp = 0
all_num = 0
for param_name, keep_mask in param_remove.items():
    temp += keep_mask.sum()
    all_num += keep_mask.size
    print(param_name, keep_mask.mean())


conv1.weight 0.9797453703703703
layer1.0.conv1.weight 0.761962890625
layer1.0.conv2.weight 0.5234375
layer1.0.conv3.weight 0.68951416015625
layer1.0.shortcut.0.weight 0.67645263671875
layer1.1.conv1.weight 0.46636962890625
layer1.1.conv2.weight 0.5396864149305556
layer1.1.conv3.weight 0.64544677734375
layer1.2.conv1.weight 0.5447998046875
layer1.2.conv2.weight 0.634765625
layer1.2.conv3.weight 0.5819091796875
layer2.0.conv1.weight 0.777557373046875
layer2.0.conv2.weight 0.4875691731770833
layer2.0.conv3.weight 0.618011474609375
layer2.0.shortcut.0.weight 0.5057907104492188
layer2.1.conv1.weight 0.221160888671875
layer2.1.conv2.weight 0.3875935872395833
layer2.1.conv3.weight 0.521728515625
layer2.2.conv1.weight 0.363433837890625
layer2.2.conv2.weight 0.4190673828125
layer2.2.conv3.weight 0.4657745361328125
layer2.3.conv1.weight 0.432037353515625
layer2.3.conv2.weight 0.4587741427951389
layer2.3.conv3.weight 0.413116455078125
layer3.0.conv1.weight 0.6927871704101562
layer3.0.conv2.weight

In [19]:
# Overall fraction of mask entries that are True (weights kept), across all selected parameters.
temp / all_num


0.14116536229095186

In [13]:
# Baseline: accuracy of a freshly loaded, unpruned model on `test_dataloader_all`.
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
correct = preds.argmax(-1) == labels
print("原始准确率", correct.mean())


原始准确率 0.954


In [16]:
# Prune with the importance mask: zero every selected weight whose mask entry is
# False, then measure accuracy on `test_dataloader_all`.
with torch.no_grad():
    net = load_cifar100_resnet50()
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = named_params[param]
        # A boolean mask with fewer dims than the weight indexes its leading dims,
        # so a single assignment covers both cases the old try/except handled.
        keep = torch.as_tensor(param_remove[param], dtype=torch.bool, device=weight.device)
        weight[~keep] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.1


In [15]:
# Control experiment: per-layer magnitude pruning at the same keep rate as the
# importance mask — for each weight tensor, zero the smallest-|w| entries so that
# roughly the fraction `keep_rate` survives, then measure accuracy.
with torch.no_grad():
    net = load_cifar100_resnet50()
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = named_params[param]
        keep_rate = param_remove[param].mean()
        # Rank by magnitude: sorting the signed values would keep only the largest
        # positive weights and prune every negative one.
        magnitude = weight.detach().abs().cpu().numpy()
        num_prune = int(magnitude.size * (1 - keep_rate))
        if num_prune >= magnitude.size:
            # keep_rate == 0: nothing survives (indexing the sort would overflow).
            weight.zero_()
            continue
        threshold = np.sort(magnitude, axis=None)[num_prune]
        prune = torch.as_tensor(magnitude < threshold, device=weight.device)
        weight[prune] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.1
