In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack_op import attack, test_model,parse_param
import random

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar10_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar10()
all_totals = list()
for i in range(10):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar10_resnet50, alpha=0.00001,num_steps=2,op="minus"))
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar10_resnet50, alpha=0.00001,num_steps=4,op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:12<00:00,  1.60it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


0.00014571376331150532


  param_totals = np.array(param_totals)
100%|██████████| 20/20 [00:06<00:00,  2.91it/s]


0.0006625839993357658


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.02401344358921051


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.3929175445556641


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


2.0105186614990234


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.0004628409862518311


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.00012629709225147964


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.00046284108981490136


100%|██████████| 20/20 [00:07<00:00,  2.86it/s]


0.009427620553970338


100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


0.16737778759002686


100%|██████████| 20/20 [00:06<00:00,  2.88it/s]


1.0400880523681642


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


0.000584859086573124


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


0.00013180603235960007


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


0.000584858762472868


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


0.01582834048271179


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.21922850971221924


100%|██████████| 20/20 [00:07<00:00,  2.59it/s]


1.2632976104736329


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


0.0007870006538927555


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.00011125453654676676


100%|██████████| 20/20 [00:07<00:00,  2.77it/s]


0.0007869922697544098


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.02680502429008484


100%|██████████| 20/20 [00:07<00:00,  2.69it/s]


0.40782339935302736


100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


2.047495520019531


100%|██████████| 20/20 [00:07<00:00,  2.71it/s]


0.0006073593929409981


100%|██████████| 20/20 [00:07<00:00,  2.77it/s]


0.0001410262878984213


100%|██████████| 20/20 [00:07<00:00,  2.68it/s]


0.0006073409274220467


100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


0.01167050416469574


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.18673283653259276


100%|██████████| 20/20 [00:07<00:00,  2.73it/s]


1.1936228332519532


100%|██████████| 20/20 [00:07<00:00,  2.69it/s]


0.0005955957502126694


100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


0.00010021753143519164


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


0.0005955764353275299


100%|██████████| 20/20 [00:07<00:00,  2.74it/s]


0.018224940705299376


100%|██████████| 20/20 [00:07<00:00,  2.77it/s]


0.3044577560424805


100%|██████████| 20/20 [00:07<00:00,  2.72it/s]


1.718875112915039


100%|██████████| 20/20 [00:07<00:00,  2.74it/s]


0.0006207897856831551


100%|██████████| 20/20 [00:07<00:00,  2.76it/s]


0.00014143160320818424


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


0.0006207907885313034


100%|██████████| 20/20 [00:07<00:00,  2.69it/s]


0.011071184360980988


100%|██████████| 20/20 [00:07<00:00,  2.74it/s]


0.2057021308898926


100%|██████████| 20/20 [00:07<00:00,  2.75it/s]


1.4767296478271483


100%|██████████| 20/20 [00:07<00:00,  2.62it/s]


0.00047608321607112885


100%|██████████| 20/20 [00:08<00:00,  2.49it/s]


0.00012338142544031144


100%|██████████| 20/20 [00:08<00:00,  2.50it/s]


0.0004760863184928894


100%|██████████| 20/20 [00:07<00:00,  2.59it/s]


0.011724818956851958


100%|██████████| 20/20 [00:07<00:00,  2.52it/s]


0.18064493103027343


100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


1.160967333984375


100%|██████████| 20/20 [00:09<00:00,  2.06it/s]


0.0004115783102810383


100%|██████████| 20/20 [00:09<00:00,  2.10it/s]


9.683336485177279e-05


100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


0.00041157990023493767


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.010544743859767914


100%|██████████| 20/20 [00:09<00:00,  2.06it/s]


0.22034988822937013


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]


1.4476864288330078


100%|██████████| 20/20 [00:09<00:00,  2.06it/s]


0.0005323765903711319


100%|██████████| 20/20 [00:09<00:00,  2.07it/s]


0.00011699054837226867


100%|██████████| 20/20 [00:09<00:00,  2.07it/s]


0.0005323783457279205


100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


0.016746693491935728


100%|██████████| 20/20 [00:09<00:00,  2.06it/s]


0.3350173500061035


100%|██████████| 20/20 [00:09<00:00,  2.13it/s]


1.9058331298828124


In [18]:
all_totals_temp = list()
from utils import normalization
for i in range(0,len(all_totals),2):
    total_0 = all_totals[i]
    total_1 = all_totals[i+1]
    total = dict()
    total_values = list()
    for key in list(total_0.keys()):
        total_values.append(total_0[key] + total_1[key])
    total_values = np.array(total_values)
    total_values = normalization(abs(total_values))
    for key in list(total_0.keys()):
        total[key] = total_values[list(total_0.keys()).index(key)]
    all_totals_temp.append(total)


  total_values = np.array(total_values)


In [19]:
all_totals_clones = all_totals.copy()

In [20]:
all_totals = all_totals_temp

In [21]:
len(all_totals)

5

In [22]:
pkl.dump(all_totals, open("weights/op_totals.pkl", "wb"))

In [23]:
thre = 0.25
net = load_cifar10_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [24]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.911865234375
layer1.0.conv2.weight 0.801025390625
layer1.0.conv3.weight 0.8663330078125
layer1.0.shortcut.0.weight 0.83758544921875
layer1.1.conv1.weight 0.748779296875
layer1.1.conv2.weight 0.8286404079861112
layer1.1.conv3.weight 0.84515380859375
layer1.2.conv1.weight 0.7855224609375
layer1.2.conv2.weight 0.8327907986111112
layer1.2.conv3.weight 0.79705810546875
layer2.0.conv1.weight 0.9393310546875
layer2.0.conv2.weight 0.8355780707465278
layer2.0.conv3.weight 0.8726043701171875
layer2.0.shortcut.0.weight 0.8238296508789062
layer2.1.conv1.weight 0.6333160400390625
layer2.1.conv2.weight 0.7771063910590278
layer2.1.conv3.weight 0.8152923583984375
layer2.2.conv1.weight 0.75860595703125
layer2.2.conv2.weight 0.7984551323784722
layer2.2.conv3.weight 0.7757110595703125
layer2.3.conv1.weight 0.810272216796875
layer2.3.conv2.weight 0.8111165364583334
layer2.3.conv3.weight 0.7202606201171875
layer3.0.conv1.weight 0.9180068969726562
laye

In [25]:
temp / all_num

0.36837990087699907

In [26]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.954


In [27]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.921


In [28]:
with torch.no_grad():
    net = load_cifar10_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.1
