In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model, parse_param
import random
device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)


In [5]:
all_param_names = all_param_names[:-2]


In [6]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
all_totals.append(attack(train_dataloader_all, all_param_names, load_cifar100_resnet50,
                      train_dataloader_all, alpha=0.00001, num_steps=4, op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 196/196 [01:24<00:00,  2.33it/s]


0.0025499204021692274


100%|██████████| 196/196 [01:18<00:00,  2.49it/s]


0.0053379214388132095


100%|██████████| 196/196 [01:17<00:00,  2.53it/s]


0.013816522450447082


100%|██████████| 196/196 [01:17<00:00,  2.52it/s]


0.04495756593704223


  param_totals = np.array(param_totals)
  x = np.array(x)


In [7]:
pkl.dump(all_totals, open("weights/totals.pkl", "wb"))


In [26]:
thre = 0.9
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight)
               for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate(
        [combine_.flatten() for combine_ in combine], axis=0)
    threshold = np.sort(combine_flatten)[
        ::-1][int(len(combine_flatten) * thre)]
    for idx, param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t


  combine = np.array(combine)


In [27]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 0.9988425925925926
layer1.0.conv1.weight 0.986083984375
layer1.0.conv2.weight 0.9622938368055556
layer1.0.conv3.weight 0.93389892578125
layer1.0.shortcut.0.weight 0.9384765625
layer1.1.conv1.weight 0.92041015625
layer1.1.conv2.weight 0.9716254340277778
layer1.1.conv3.weight 0.9515380859375
layer1.2.conv1.weight 0.94036865234375
layer1.2.conv2.weight 0.949462890625
layer1.2.conv3.weight 0.8822021484375
layer2.0.conv1.weight 0.986785888671875
layer2.0.conv2.weight 0.9630194769965278
layer2.0.conv3.weight 0.9636993408203125
layer2.0.shortcut.0.weight 0.9525222778320312
layer2.1.conv1.weight 0.907867431640625
layer2.1.conv2.weight 0.9338853624131944
layer2.1.conv3.weight 0.935699462890625
layer2.2.conv1.weight 0.9400482177734375
layer2.2.conv2.weight 0.9505954318576388
layer2.2.conv3.weight 0.9428253173828125
layer2.3.conv1.weight 0.9481658935546875
layer2.3.conv2.weight 0.9437391493055556
layer2.3.conv3.weight 0.92510986328125
layer3.0.conv1.weight 0.9844284057617188
layer3.0

In [28]:
temp / all_num


0.8999999658808341

In [29]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.7929


In [30]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.7893


In [31]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval(
            "net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(
            len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比实验准确率 0.01
