In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import pickle as pkl
import torch.nn.functional as F
import torchvision
import random
from attack import attack, test_model,parse_param


In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar10()
all_totals = list()
all_totals.append(attack(train_dataloader_all, all_param_names, load_cifar10_resnet50,train_dataloader_all, alpha=0.00001,num_steps=4,op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 196/196 [01:18<00:00,  2.48it/s]


0.0005741067615151406


100%|██████████| 196/196 [01:13<00:00,  2.66it/s]


0.4299012036895752


100%|██████████| 196/196 [01:14<00:00,  2.61it/s]


10.475829932861329


  param_totals = np.array(param_totals)
  x = np.array(x)


In [7]:
all_totals


[{'conv1.weight': array([[[[0.0360588 , 0.09237533, 0.06430576],
           [0.02092605, 0.0336416 , 0.02365122],
           [0.03448137, 0.05658536, 0.00850736]],
  
          [[0.04909461, 0.10951702, 0.08440155],
           [0.00963671, 0.05140426, 0.00981765],
           [0.0490398 , 0.07459333, 0.02150341]],
  
          [[0.06820728, 0.12422253, 0.09693772],
           [0.03306965, 0.07456466, 0.0291067 ],
           [0.07984389, 0.1023901 , 0.04436048]]],
  
  
         [[[0.11709377, 0.06977938, 0.07355948],
           [0.20905936, 0.05876227, 0.0875394 ],
           [0.23460387, 0.03738863, 0.12143967]],
  
          [[0.14199853, 0.04384269, 0.10328057],
           [0.22218966, 0.04106204, 0.11287365],
           [0.24189201, 0.04715927, 0.13803318]],
  
          [[0.14973895, 0.04099445, 0.10030424],
           [0.23234071, 0.03244056, 0.11842395],
           [0.26025745, 0.06926417, 0.15088947]]],
  
  
         [[[0.26321858, 0.25459212, 0.22776978],
           [0.2232554

In [8]:
pkl.dump(all_totals, open("weights/totals.pkl", "wb"))

In [15]:
thre = 0.25
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [16]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9971064814814815
layer1.0.conv1.weight 0.938720703125
layer1.0.conv2.weight 0.8323838975694444
layer1.0.conv3.weight 0.88739013671875
layer1.0.shortcut.0.weight 0.87091064453125
layer1.1.conv1.weight 0.766357421875
layer1.1.conv2.weight 0.8091634114583334
layer1.1.conv3.weight 0.87457275390625
layer1.2.conv1.weight 0.8057861328125
layer1.2.conv2.weight 0.8533257378472222
layer1.2.conv3.weight 0.84320068359375
layer2.0.conv1.weight 0.919921875
layer2.0.conv2.weight 0.7741970486111112
layer2.0.conv3.weight 0.83294677734375
layer2.0.shortcut.0.weight 0.7821884155273438
layer2.1.conv1.weight 0.66607666015625
layer2.1.conv2.weight 0.79119873046875
layer2.1.conv3.weight 0.8011016845703125
layer2.2.conv1.weight 0.734466552734375
layer2.2.conv2.weight 0.716064453125
layer2.2.conv3.weight 0.740509033203125
layer2.3.conv1.weight 0.796905517578125
layer2.3.conv2.weight 0.7399766710069444
layer2.3.conv3.weight 0.7042999267578125
layer3.0.conv1.weight 0.856658935546875
layer3.0.conv2

In [17]:
temp / all_num

0.5

In [18]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.954


In [19]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.8911


In [20]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval(
            "net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(
            len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ +
                 "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())


对比试验准确率 0.1
