In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders, test_dataloader_all = load_cifar10()
all_totals = list()
for i in range(10):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar10_resnet50, alpha=0.00005,num_steps=3,op="minus"))
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar10_resnet50, alpha=0.0001,num_steps=3,op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:13<00:00,  1.45it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


3.283968003233895e-06


100%|██████████| 20/20 [00:10<00:00,  1.96it/s]


0.0


  param_totals = np.array(param_totals)
100%|██████████| 20/20 [00:10<00:00,  1.95it/s]


0.000662586422264576


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


13.406303759765626


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


57.52458779296875


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


0.0004628409862518311


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


1.8679110246011987e-06


100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


0.0


100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


0.0004628450289368629


100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


11.796388305664063


100%|██████████| 20/20 [00:09<00:00,  2.01it/s]


45.74559716796875


100%|██████████| 20/20 [00:10<00:00,  1.96it/s]


0.000584859086573124


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


3.024666366400197e-06


100%|██████████| 20/20 [00:10<00:00,  1.85it/s]


0.0


100%|██████████| 20/20 [00:10<00:00,  1.86it/s]


0.0005848580315709114


100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


10.575895336914062


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


33.47259736328125


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0007870006538927555


100%|██████████| 20/20 [00:11<00:00,  1.77it/s]


1.998564408859238e-06


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.74it/s]


0.000786986194550991


100%|██████████| 20/20 [00:11<00:00,  1.77it/s]


15.139920263671875


100%|██████████| 20/20 [00:11<00:00,  1.73it/s]


77.4741080078125


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]


0.0006073593929409981


100%|██████████| 20/20 [00:11<00:00,  1.68it/s]


6.083607708569616e-06


100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


0.0


100%|██████████| 20/20 [00:10<00:00,  1.84it/s]


0.0006073602437973023


100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


11.448587841796876


100%|██████████| 20/20 [00:10<00:00,  1.82it/s]


45.2471705078125


100%|██████████| 20/20 [00:11<00:00,  1.82it/s]


0.0005955957502126694


100%|██████████| 20/20 [00:10<00:00,  1.82it/s]


9.24586207838729e-07


100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


0.0005955934397876263


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]


12.1032484375


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


36.4782953125


100%|██████████| 20/20 [00:10<00:00,  1.82it/s]


0.0006207897856831551


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


3.8041478488594293e-06


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0006207873709499835


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


11.983464624023437


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


36.16382587890625


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


0.00047608321607112885


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


2.68241839366965e-06


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


0.00047608350813388824


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


10.423259057617187


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]


33.5777029296875


100%|██████████| 20/20 [00:11<00:00,  1.81it/s]


0.0004115783102810383


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


1.2100443593226373e-06


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.0004115825138986111


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


9.824226342773438


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]


37.14941015625


100%|██████████| 20/20 [00:11<00:00,  1.76it/s]


0.0005323765903711319


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


1.3311370828887448e-06


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]


0.0


100%|██████████| 20/20 [00:11<00:00,  1.79it/s]


0.000532409642636776


100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


11.393523974609375


100%|██████████| 20/20 [00:11<00:00,  1.75it/s]


50.31952021484375


In [17]:
all_totals_temp = list()
from utils import normalization
for i in range(0,len(all_totals),2):
    total_0 = all_totals[i]
    total_1 = all_totals[i+1]
    total = dict()
    total_values = list()
    for key in list(total_0.keys()):
        total_values.append(total_0[key] + total_1[key])
    total_values = np.array(total_values)
    total_values = normalization(abs(total_values))
    for key in list(total_0.keys()):
        total[key] = total_values[list(total_0.keys()).index(key)]
    all_totals_temp.append(total)


  total_values = np.array(total_values)
  x = np.array(x)


In [20]:
all_totals_clones = all_totals.copy()

In [21]:
all_totals = all_totals_temp

In [22]:
len(all_totals)

10

In [23]:
pkl.dump(all_totals, open("weights/op_totals.pkl", "wb"))

In [137]:
thre = 0.25
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [138]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.0
layer1.0.conv1.weight 0.00244140625
layer1.0.conv2.weight 0.007649739583333333
layer1.0.conv3.weight 0.00262451171875
layer1.0.shortcut.0.weight 0.00225830078125
layer1.1.conv1.weight 0.0074462890625
layer1.1.conv2.weight 0.005506727430555556
layer1.1.conv3.weight 0.00177001953125
layer1.2.conv1.weight 0.00152587890625
layer1.2.conv2.weight 0.001193576388888889
layer1.2.conv3.weight 0.00250244140625
layer2.0.conv1.weight 0.001678466796875
layer2.0.conv2.weight 0.004489474826388889
layer2.0.conv3.weight 0.003509521484375
layer2.0.shortcut.0.weight 0.00431060791015625
layer2.1.conv1.weight 0.0086212158203125
layer2.1.conv2.weight 0.006266276041666667
layer2.1.conv3.weight 0.0044403076171875
layer2.2.conv1.weight 0.0070343017578125
layer2.2.conv2.weight 0.003167046440972222
layer2.2.conv3.weight 0.0043487548828125
layer2.3.conv1.weight 0.005859375
layer2.3.conv2.weight 0.002509223090277778
layer2.3.conv3.weight 0.0048370361328125
layer3.0.conv1.weight 0.003509521484375
la

In [140]:
temp / all_num

0.015241999202138154

In [130]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.954


In [143]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.9539


In [119]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.1
