In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pickle as pkl
import torchvision
from attack import attack, test_model,parse_param


In [2]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()

In [3]:
all_param_names = list()
for name, param in model.named_parameters():
    all_param_names.append(name)

In [4]:
train_loaders, test_dataloaders, test_dataloader_all = load_cifar100()
all_totals = list()
for i in range(100):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar100_resnet50, alpha=0.00002))


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
pkl.dump(all_totals, open("weights/cifar100_resnet50_totals.pkl", "wb"))

In [None]:
thre = 0.1
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [None]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9965277777777778
layer1.0.conv1.weight 0.912841796875
layer1.0.conv2.weight 0.7533094618055556
layer1.0.conv3.weight 0.81182861328125
layer1.0.shortcut.0.weight 0.837890625
layer1.1.conv1.weight 0.71002197265625
layer1.1.conv2.weight 0.7708333333333334
layer1.1.conv3.weight 0.77099609375
layer1.2.conv1.weight 0.71893310546875
layer1.2.conv2.weight 0.7283799913194444
layer1.2.conv3.weight 0.57537841796875
layer2.0.conv1.weight 0.901153564453125
layer2.0.conv2.weight 0.7180650499131944
layer2.0.conv3.weight 0.8049774169921875
layer2.0.shortcut.0.weight 0.709259033203125
layer2.1.conv1.weight 0.4111480712890625
layer2.1.conv2.weight 0.5914849175347222
layer2.1.conv3.weight 0.677459716796875
layer2.2.conv1.weight 0.5162811279296875
layer2.2.conv2.weight 0.6452297634548612
layer2.2.conv3.weight 0.68939208984375
layer2.3.conv1.weight 0.6061248779296875
layer2.3.conv2.weight 0.66510009765625
layer2.3.conv3.weight 0.634124755859375
layer3.0.conv1.weight 0.8852996826171875
layer3

In [None]:
temp / all_num

0.6133051014294566

In [None]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    correct, all = test_model(net, test_dataloader_all)
    print("原始准确率", correct / all)


原始准确率 0.7931


In [None]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("现在准确率", correct / all)

现在准确率 0.7012


In [None]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("对比实验准确率", correct / all)

对比实验准确率 0.01
