In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from tqdm.notebook import tqdm
from attack import attack,parse_param,test_model
from utils import get_device
import random
device = get_device()

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [3]:
from datasets import load_cifar10
from models.resnet import load_cifar10_resnet50
model = load_cifar10_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)
all_param_names = all_param_names[:-2]

In [5]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()
all_totals = list()
all_totals.append(attack(train_dataloader_all, all_param_names,
                         load_cifar10_resnet50, norm=True, alpha=0.00001, num_steps=4, op="add"))


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

0.0005742178249359131


  0%|          | 0/196 [00:00<?, ?it/s]

0.001543781439512968


  0%|          | 0/196 [00:00<?, ?it/s]

0.0065519423300027845


  0%|          | 0/196 [00:00<?, ?it/s]

0.048425866241455076


  x = np.array(x)


In [6]:
pkl.dump(all_totals, open("weights/totals.pkl", "wb"))
# all_totals = pkl.load(open("weights/totals.pkl", "rb"))


In [13]:
thre = 0.5
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight)
               for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate(
        [combine_.flatten() for combine_ in combine], axis=0)
    threshold = np.sort(combine_flatten)[
        ::-1][int(len(combine_flatten) * thre)]
    for idx, param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t


  combine = np.array(combine)


In [14]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 0.9982638888888888
layer1.0.conv1.weight 0.946044921875
layer1.0.conv2.weight 0.8968370225694444
layer1.0.conv3.weight 0.9039306640625
layer1.0.shortcut.0.weight 0.880859375
layer1.1.conv1.weight 0.83538818359375
layer1.1.conv2.weight 0.8942328559027778
layer1.1.conv3.weight 0.899169921875
layer1.2.conv1.weight 0.860595703125
layer1.2.conv2.weight 0.8793402777777778
layer1.2.conv3.weight 0.87884521484375
layer2.0.conv1.weight 0.9588623046875
layer2.0.conv2.weight 0.8979627821180556
layer2.0.conv3.weight 0.9182891845703125
layer2.0.shortcut.0.weight 0.8923721313476562
layer2.1.conv1.weight 0.801727294921875
layer2.1.conv2.weight 0.8716295030381944
layer2.1.conv3.weight 0.880767822265625
layer2.2.conv1.weight 0.87872314453125
layer2.2.conv2.weight 0.8824801974826388
layer2.2.conv3.weight 0.8565521240234375
layer2.3.conv1.weight 0.9043121337890625
layer2.3.conv2.weight 0.8874104817708334
layer2.3.conv3.weight 0.811004638671875
layer3.0.conv1.weight 0.9490814208984375
layer3.0

In [15]:
temp / all_num


0.5

In [16]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.954


In [17]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.9298


In [12]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval(
            "net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(
            len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.1
