In [None]:
import os
import pickle as pkl
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, TensorDataset

from attack import attack, test_model, parse_param
from utils import caculate_param_remove


In [None]:
def setup_seed(seed):
    """Seed every RNG the notebook touches and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG, and PyTorch on the CPU and on
    all CUDA devices.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not sufficient: with benchmark=True cuDNN
    # autotunes and may select a different (non-deterministic) algorithm per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setup_seed(3407)

In [None]:
# CIFAR-10 data loader and ResNet-50 builder (presumably project-local modules).
from datasets import load_cifar10
from models.resnet import load_cifar10_resnet50
# This instance is only used to enumerate parameter names in the next cell;
# later cells build fresh networks with load_cifar10_resnet50().
model = load_cifar10_resnet50()


In [None]:
# Names of the parameters to score and prune. Anything whose name contains "bn"
# or "shortcut.1" is skipped (presumably BatchNorm layers, including the BN
# inside the shortcut branch — confirm against models.resnet). The trailing
# [:-2] drops the last two collected names (presumably the final classifier's
# weight and bias — TODO confirm).
all_param_names = [
    param_name
    for param_name, _ in model.named_parameters()
    if "bn" not in param_name and "shortcut.1" not in param_name
][:-2]

In [None]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()

# One attack run per entry, in order: the "minus" run first, then the "add" run.
# The combining cell below relies on this (minus, add) pairing.
# NOTE(review): minus uses 2 steps while add uses 4 — confirm the asymmetry is intended.
attack_configs = [
    {"num_steps": 2, "op": "minus"},
    {"num_steps": 4, "op": "add"},
]

all_totals = []
for attack_cfg in attack_configs:
    all_totals.append(
        attack(
            train_dataloader_all,
            all_param_names,
            load_cifar10_resnet50,
            norm=False,
            alpha=0.00001,
            **attack_cfg,
        )
    )


In [None]:
# Combine each consecutive (minus, add) pair of attack results into one score dict:
# per key, add the two results, take the absolute value, and rescale all keys
# together with utils.normalization (presumably min-max scaling — TODO confirm).
from utils import normalization

if len(all_totals) % 2:
    raise ValueError(
        f"expected attack results in (minus, add) pairs, got {len(all_totals)} entries"
    )

all_totals_temp = list()
for total_0, total_1 in zip(all_totals[0::2], all_totals[1::2]):
    keys = list(total_0.keys())
    # Stack the per-key sums in key order so normalization sees every key at once.
    total_values = np.array([total_0[key] + total_1[key] for key in keys])
    total_values = normalization(abs(total_values))
    # zip keeps key <-> value alignment; replaces an O(n^2) list.index() lookup per key.
    all_totals_temp.append(dict(zip(keys, total_values)))


In [None]:
# Shallow snapshot of the raw attack results, taken before the next cell
# rebinds `all_totals` to the combined scores.
all_totals_clones = list(all_totals)

In [None]:
# From here on, `all_totals` holds the combined, normalized score dicts.
# NOTE(review): this rebinding makes the combining cell non-idempotent. Re-running
# it afterwards would operate on the already-combined list; with the current two
# attack runs that list has one entry, so all_totals[i+1] raises IndexError.
all_totals = all_totals_temp

In [None]:
# Number of combined score dicts (one per consecutive pair of attack runs).
len(all_totals)

In [None]:
# Persist the combined scores. The `with` block guarantees the file is flushed
# and closed; the directory is created if it does not exist yet.
os.makedirs("weights", exist_ok=True)
with open("weights/op_totals.pkl", "wb") as totals_file:
    pkl.dump(all_totals, totals_file)

In [None]:
# Threshold handed to caculate_param_remove; its exact meaning is defined in utils
# (presumably a cut-off on the normalized scores — TODO confirm).
thre = 0.25
net = load_cifar10_resnet50()
# param_remove maps parameter name -> boolean mask (supports .sum/.size/.mean).
# Later cells keep True entries and zero the ~mask entries.
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

In [None]:
# Print the per-parameter fraction of kept weights, and accumulate the
# kept / total element counts used by the global ratio in the next cell.
temp = 0
all_num = 0
for layer_name, keep_mask in param_remove.items():
    print(layer_name, keep_mask.mean())
    temp += keep_mask.sum()
    all_num += keep_mask.size

In [None]:
# Overall fraction of weights kept across every masked parameter.
temp / all_num

In [None]:
# Baseline: test accuracy of a freshly built, unpruned network.
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    predicted_classes = preds.argmax(-1)
    print("原始准确率", (predicted_classes == labels).mean())

In [None]:
# Zero every weight that param_remove marks False, then re-evaluate accuracy.
with torch.no_grad():
    net = load_cifar10_resnet50()
    # Look parameters up by the same names named_parameters() yielded for
    # all_param_names, instead of string-building attribute paths for exec().
    net_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = net_params[param]
        # True in param_remove means "keep"; zero the complement. A boolean mask
        # covering only leading dims zeros whole slices, the same elements the
        # old `[mask, :]` fallback targeted. A shape mismatch now raises instead
        # of being hidden by a bare except.
        drop_mask = torch.as_tensor(~param_remove[param], device=weight.device)
        weight[drop_mask] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

In [None]:
# Control experiment: per parameter, keep the same fraction of weights as the
# param_remove mask does, but choose them by weight magnitude instead.
with torch.no_grad():
    net = load_cifar10_resnet50()
    net_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = net_params[param]
        keep_rate = param_remove[param].sum() / param_remove[param].size
        # Rank by |w|. The previous version ranked signed values, which zeroed
        # the most negative weights rather than the smallest-magnitude ones.
        magnitudes = weight.abs().flatten()
        n_prune = int(magnitudes.numel() * (1 - keep_rate))
        if n_prune <= 0:
            continue  # nothing to prune in this parameter
        if n_prune >= magnitudes.numel():
            # keep_rate == 0: prune everything (old code indexed past the end here).
            weight.zero_()
            continue
        threshold = torch.sort(magnitudes).values[n_prune]
        weight[weight.abs() < threshold] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())