In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from tqdm.notebook import tqdm
from attack import attack, parse_param, test_model
from utils import get_device, caculate_param_remove
import random
# Compute device chosen by the project helper `get_device`; the model loader
# defined below moves every freshly built network onto it.
device = get_device()


In [None]:
def setup_seed(seed):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module. cuDNN is forced onto deterministic kernels, and its
    autotuner (``benchmark``) is explicitly switched off: with benchmarking on,
    cuDNN may select a different algorithm from run to run, which undermines
    the determinism flag.

    Args:
        seed: integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Explicit rather than relying on the default, in case something upstream
    # enabled autotuning.
    torch.backends.cudnn.benchmark = False


setup_seed(3407)


In [None]:
from datasets import load_imagenet
from torchvision.models import efficientnet_b0,EfficientNet_B0_Weights

def load_efficientnet_b0():
    """Build torchvision's EfficientNet-B0 with IMAGENET1K_V1 pretrained weights.

    A new model instance is created on every call and moved onto the
    module-level ``device``.

    Returns:
        The freshly constructed model, already placed on ``device``.
    """
    pretrained = EfficientNet_B0_Weights.IMAGENET1K_V1
    return efficientnet_b0(weights=pretrained).to(device)
model = load_efficientnet_b0()
parameters = list(model.named_parameters())

# Prunable parameter names: every parameter except normalization-layer affine
# parameters, kept in `named_parameters()` order.
#
# The previous loop found those positionally: "a 1-D tensor whose name contains
# 'weight' is a norm weight, so skip it AND the next entry (its bias)". That
# silently drops an unrelated parameter after any other bias-less 1-D weight
# (e.g. nn.PReLU). Looking the parameters up by module type does not depend on
# that ordering.
# NOTE(review): only BatchNorm layers are treated as normalization layers here;
# extend `_norm_layer_types` if the architecture uses LayerNorm/GroupNorm.
_norm_layer_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
norm_param_names = {
    f"{module_name}.{param_name}" if module_name else param_name
    for module_name, module in model.named_modules()
    if isinstance(module, _norm_layer_types)
    for param_name, _ in module.named_parameters(recurse=False)
}
all_param_names = [name for name, _ in parameters if name not in norm_param_names]
# Drop the last two remaining entries (presumably the classifier head's weight
# and bias — TODO confirm); they are thereby excluded from the attack and the
# pruning cells below.
all_param_names = all_param_names[:-2]

In [None]:
# Load the ImageNet data through the project helper. Later cells index
# `train_loaders[i]` for i in range(100) (one attack per loader) and evaluate
# models with `test_model(net, test_dataloader_all)`.
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_imagenet()


In [None]:
# Run the attack once for each of the first 100 training loaders and collect
# the per-loader results in order.
all_totals = []
for i in tqdm(range(100)):
    loader_totals = attack(
        train_loaders[i],
        all_param_names,
        load_efficientnet_b0,
        norm=True,
        alpha=1e-5,
        num_steps=4,
        op="add",
    )
    all_totals.append(loader_totals)

In [6]:
import os

# Persist the attack results. The context manager flushes and closes the file
# deterministically (a bare `open(...)` relies on garbage collection), and the
# output directory is created if it is missing.
os.makedirs("weights", exist_ok=True)
with open("weights/totals.pkl", "wb") as totals_file:
    pkl.dump(all_totals, totals_file)
# To reuse a previous run instead, load it back — only from a trusted file,
# since unpickling can execute arbitrary code:
# with open("weights/totals.pkl", "rb") as totals_file:
#     all_totals = pkl.load(totals_file)


In [17]:
# Build per-parameter masks from the attack results with the project helper.
# Entries that are True are kept when the masks are applied further down;
# 0.3 is the threshold handed to the helper.
net = load_efficientnet_b0()
thre = 0.3
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

In [18]:
# Tally True (kept) mask entries over all parameters and print each
# parameter's kept fraction.
temp, all_num = 0, 0
for param in param_remove:
    mask = param_remove[param]
    temp = temp + mask.sum()
    all_num = all_num + mask.size
    print(param, mask.mean())


features.0.0.weight 0.9733796296296297
features.1.0.block.0.0.weight 0.9756944444444444
features.1.0.block.1.fc1.weight 0.88671875
features.1.0.block.1.fc1.bias 1.0
features.1.0.block.1.fc2.weight 0.5078125
features.1.0.block.1.fc2.bias 0.96875
features.1.0.block.2.0.weight 0.982421875
features.2.0.block.0.0.weight 0.9563802083333334
features.2.0.block.1.0.weight 0.9849537037037037
features.2.0.block.2.fc1.weight 0.75
features.2.0.block.2.fc1.bias 0.75
features.2.0.block.2.fc2.weight 0.5390625
features.2.0.block.2.fc2.bias 1.0
features.2.0.block.3.0.weight 0.96484375
features.2.1.block.0.0.weight 0.9421296296296297
features.2.1.block.1.0.weight 0.9537037037037037
features.2.1.block.2.fc1.weight 0.6736111111111112
features.2.1.block.2.fc1.bias 1.0
features.2.1.block.2.fc2.weight 0.6574074074074074
features.2.1.block.2.fc2.bias 0.9861111111111112
features.2.1.block.3.0.weight 0.9337384259259259
features.3.0.block.0.0.weight 0.9508101851851852
features.3.0.block.1.0.weight 0.8719444444444

In [19]:
# Overall fraction of mask entries that are True (kept) across all parameters.
temp / all_num


0.6825984508509829

In [20]:
# Baseline: accuracy of the unmodified pretrained network on test_dataloader_all.
with torch.no_grad():
    net = load_efficientnet_b0()
    preds, labels = test_model(net, test_dataloader_all)
    top1_hits = preds.argmax(-1) == labels
    print("原始准确率", top1_hits.mean())


  0%|          | 0/94 [00:00<?, ?it/s]

原始准确率 0.8881666666666667


In [21]:
# Apply the attack-derived masks (zero every entry whose mask value is False)
# to a fresh pretrained network, then re-evaluate it.
with torch.no_grad():
    net = load_efficientnet_b0()
    for param in all_param_names:
        # Resolve the dotted parameter name directly instead of building code
        # strings for exec(); a bad name now raises a clear AttributeError
        # rather than being swallowed by a bare `except:`.
        weight = net.get_parameter(param)
        # Move the mask next to the weight so indexing never mixes devices.
        keep = torch.as_tensor(param_remove[param], device=weight.device)
        # A boolean mask over the leading dimension(s) already addresses whole
        # trailing slices, so the old `[mask, :]` fallback is unnecessary.
        weight[~keep] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


  0%|          | 0/94 [00:00<?, ?it/s]

现在准确率 0.827


In [None]:
# Comparison experiment: magnitude pruning with the same per-parameter keep
# rate as the attack-derived masks, so both accuracies are measured at equal
# per-parameter sparsity.
with torch.no_grad():
    net = load_efficientnet_b0()
    for param in all_param_names:
        weight = net.get_parameter(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        # Rank by |w|, not by signed w: thresholding signed values zeroes the
        # most negative (i.e. large-magnitude) weights instead of the smallest.
        magnitudes = weight.detach().abs()
        num_prune = int(magnitudes.numel() * (1 - keep_rate))
        if num_prune >= magnitudes.numel():
            # keep_rate == 0: prune everything (indexing the sorted values at
            # position `numel` would be out of range).
            weight.zero_()
            continue
        threshold = magnitudes.flatten().sort().values[num_prune]
        # Zero entries strictly below the num_prune-th smallest magnitude
        # (0-indexed); ties at the threshold are kept, as with "< threshold".
        weight[magnitudes < threshold] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())
