In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from tqdm.notebook import tqdm
from attack import attack,parse_param,test_model
from utils import get_device,caculate_param_remove
import random
# Device handle returned by utils.get_device().
# NOTE(review): `device` is not referenced by any of the cells shown here — confirm it is still needed.
device = get_device()

In [2]:
def setup_seed(seed):
    """Seed every RNG the notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for deterministic behaviour.

    Args:
        seed: integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # documented as safe (silently ignored) without CUDA
    np.random.seed(seed)
    random.seed(seed)
    # Deterministic kernels alone are not enough: the cuDNN autotuner
    # (benchmark mode) may select different algorithms from run to run,
    # so it has to be switched off explicitly as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


setup_seed(3407)


In [3]:
# Project-local helpers: data loaders and the ResNet-50 factory used throughout.
from datasets import load_cifar10
from models.resnet import load_cifar10_resnet50
# Reference model instance; in this notebook it is only used to enumerate
# parameter names. Loading requires weights/CIFAR10_Resnet50_weights.pth —
# the recorded run failed with FileNotFoundError because that file was missing.
model = load_cifar10_resnet50()


FileNotFoundError: [Errno 2] No such file or directory: 'weights/CIFAR10_Resnet50_weights.pth'

In [None]:
# Parameters that will be pruned: every parameter whose name contains neither
# "bn" nor "shortcut.1" (presumably the BatchNorm layers and the shortcut's
# BatchNorm), minus the final two survivors.
# NOTE(review): the last two are presumably the classifier weight/bias — confirm.
all_param_names = [
    name
    for name, _ in model.named_parameters()
    if "bn" not in name and "shortcut.1" not in name
][:-2]

In [None]:
# Run the attack once per class-specific training loader and collect the
# results; they are turned into pruning masks further down.
# NOTE(review): assumes train_loaders[i] holds the training data for class i.
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()
ATTACK_ALPHA = 0.00001  # `alpha` hyper-parameter forwarded to attack()
ATTACK_NUM_STEPS = 4    # `num_steps` hyper-parameter forwarded to attack()
all_totals = [
    attack(train_loaders[i], all_param_names, load_cifar10_resnet50,
           norm=True, alpha=ATTACK_ALPHA, num_steps=ATTACK_NUM_STEPS, op="add")
    # Derive the number of loaders from the data instead of hard-coding 10.
    for i in range(len(train_loaders))
]


In [None]:
# Persist the per-class attack results so the expensive cell above can be
# skipped on later runs. The context manager guarantees the file is flushed
# and closed even if pickling raises.
TOTALS_PATH = "weights/totals.pkl"
with open(TOTALS_PATH, "wb") as totals_file:
    pkl.dump(all_totals, totals_file)
# To reload instead of recomputing (only unpickle files you trust — pickle can run arbitrary code):
# with open(TOTALS_PATH, "rb") as totals_file:
#     all_totals = pkl.load(totals_file)


In [None]:
# Convert the per-class attack results into per-parameter masks keyed by
# parameter name. The cells below treat True as "keep" and zero the False entries.
# NOTE(review): `thre` is forwarded to utils.caculate_param_remove; its exact
# meaning is defined there — confirm before tuning.
thre = 0.25
net = load_cifar10_resnet50()
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

In [None]:
# Print each layer's keep ratio (fraction of True mask entries) while keeping
# running totals of kept entries (`temp`) and of all entries (`all_num`).
temp = 0
all_num = 0
for param in param_remove:
    keep_mask = param_remove[param]
    print(param, keep_mask.mean())
    temp += keep_mask.sum()
    all_num += keep_mask.size


In [None]:
# Overall fraction of mask entries kept, across all pruned layers.
temp / all_num


In [None]:
# Baseline: accuracy of a freshly loaded, unpruned network on the full test
# set (the label "原始准确率" means "original accuracy").
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    baseline_correct = preds.argmax(-1) == labels
    print("原始准确率", baseline_correct.mean())


In [None]:
# Prune with the attack-derived masks: zero every weight whose mask entry is
# False, then measure test accuracy ("现在准确率" means "current accuracy").
# Parameters are looked up by name via named_parameters() instead of building
# attribute paths and exec-ing code strings. There is no bare `except:`
# fallback, so genuine errors such as a shape mismatch surface clearly.
with torch.no_grad():
    net = load_cifar10_resnet50()
    named = dict(net.named_parameters())
    for param in all_param_names:
        weight = named[param]
        keep = torch.as_tensor(param_remove[param], dtype=torch.bool, device=weight.device)
        if keep.shape != weight.shape[:keep.dim()]:
            raise ValueError(
                f"mask for {param} has shape {tuple(keep.shape)}, "
                f"incompatible with weight shape {tuple(weight.shape)}")
        # A mask covering only the leading dims zeros whole slices along them,
        # matching what the original `[mask, :]` fallback did.
        weight[~keep] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


In [None]:
# Comparison baseline ("对比实验准确率" means "comparison-experiment accuracy").
# Each layer is pruned to the SAME keep rate as its attack-derived mask, but
# the surviving weights are chosen by ranking the weight values themselves.
# NOTE(review): the original ranks *signed* values, so it removes the most
# negative weights rather than the smallest-magnitude ones. Standard magnitude
# pruning ranks |w|; set RANK_BY_MAGNITUDE = True if that was the intent.
RANK_BY_MAGNITUDE = False
with torch.no_grad():
    net = load_cifar10_resnet50()
    named = dict(net.named_parameters())
    for param in all_param_names:
        weight = named[param]
        keep_rate = param_remove[param].sum() / param_remove[param].size
        scores = weight.abs() if RANK_BY_MAGNITUDE else weight
        n_prune = int(weight.numel() * (1 - keep_rate))
        if n_prune >= weight.numel():
            # keep_rate == 0: the original indexed one past the end of the
            # sorted array (IndexError); pruning everything is the consistent result.
            weight.zero_()
        else:
            # Zero weights strictly below the n_prune-th smallest score;
            # ties at the threshold are kept, as in the original.
            threshold = scores.flatten().sort().values[n_prune]
            weight[scores < threshold] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())
