In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model, parse_param
from utils import caculate_param_remove
import random


In [2]:
def setup_seed(seed):
    """Seed every random number generator this notebook touches.

    The same ``seed`` is handed to Python's ``random``, NumPy's global
    generator, and PyTorch (the default generator plus all CUDA devices).
    cuDNN's ``deterministic`` flag is switched on afterwards.

    Args:
        seed: Integer seed shared by all generators.
    """
    # The generators are independent of each other, so seeding order is irrelevant.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
# Data-loader factory and model factory used throughout the rest of the notebook.
# NOTE(review): `datasets` is presumably the project-local module, not the
# third-party package of the same name — confirm the import path.
from datasets import load_cifar100
from models.resnet import load_cifar100_resnet50
# Reference network instance; the following cell enumerates its parameter names.
model = load_cifar100_resnet50()


In [4]:
# Names of the parameters handed to the attack: everything except entries whose
# name contains "bn" or "shortcut.1".
all_param_names = [
    param_name
    for param_name, _ in model.named_parameters()
    if "bn" not in param_name and "shortcut.1" not in param_name
]
# Discard the final two names as well.
# NOTE(review): presumably the classifier head's weight and bias — confirm.
del all_param_names[-2:]

In [6]:
# NOTE(review): the execution counter jumps from In [4] to In [6]; confirm that
# nothing below relies on a cell that has since been deleted.
(
    train_loaders,
    test_dataloaders,
    train_dataloader_all,
    test_dataloader_all,
) = load_cifar100()

# Settings shared by every attack run.
attack_settings = dict(norm=True, alpha=0.00001, num_steps=2, op="minus")

# One attack run per loader in train_loaders[0..99]; results are collected in order.
# The model *factory* (not an instance) is passed to each call.
all_totals = []
for loader_idx in range(100):
    loader_totals = attack(
        train_loaders[loader_idx],
        all_param_names,
        load_cifar100_resnet50,
        **attack_settings,
    )
    all_totals.append(loader_totals)


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:09<00:00,  2.19it/s]


0.00014571376331150532


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:08<00:00,  2.24it/s]


0.000462839786708355


100%|██████████| 20/20 [00:08<00:00,  2.26it/s]


0.0001262958489358425


100%|██████████| 20/20 [00:08<00:00,  2.25it/s]


0.0005848589763045311


100%|██████████| 20/20 [00:08<00:00,  2.25it/s]


0.0001318057008087635


100%|██████████| 20/20 [00:08<00:00,  2.25it/s]


0.0007869988471269607


100%|██████████| 20/20 [00:08<00:00,  2.27it/s]


0.00011125698871910573


100%|██████████| 20/20 [00:08<00:00,  2.25it/s]


0.0006073642581701278


100%|██████████| 20/20 [00:08<00:00,  2.27it/s]


0.00014102745279669763


100%|██████████| 20/20 [00:08<00:00,  2.26it/s]


0.0005955979764461517


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.00010021717865020037


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0006207811504602433


100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


0.00014143167398869992


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0004760828569531441


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.00012338287569582462


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.0004115798108279705


100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


9.683243203908205e-05


100%|██████████| 20/20 [00:09<00:00,  2.11it/s]


0.0005323725596070289


100%|██████████| 20/20 [00:09<00:00,  2.12it/s]


0.00011698988024145365


In [7]:
# Persist the per-loader attack results. The context manager guarantees that the
# file is flushed and closed even if pickling fails part-way through.
with open("weights/minus_totals.pkl", "wb") as totals_file:
    pkl.dump(all_totals, totals_file)

In [8]:
# Threshold forwarded to caculate_param_remove.
# NOTE(review): its exact meaning is defined in utils — confirm there.
thre = 0.25
# Fresh network instance passed to the helper alongside the attack results.
net = load_cifar100_resnet50()
# Mapping from parameter name to a per-element mask. Later cells reduce each mask
# with .sum()/.size/.mean() and zero the weights where the mask is False, so the
# entries appear to be boolean "keep" arrays despite the name.
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

  combine = np.array(combine)


In [9]:
# temp accumulates the number of True mask entries, all_num the number of mask
# entries overall; the per-parameter fraction of True entries is printed as we go.
temp, all_num = 0, 0
for param in param_remove:
    mask = param_remove[param]
    temp += mask.sum()
    all_num += mask.size
    print(param, mask.mean())

conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.957275390625
layer1.0.conv2.weight 0.894775390625
layer1.0.conv3.weight 0.9014892578125
layer1.0.shortcut.0.weight 0.86907958984375
layer1.1.conv1.weight 0.83428955078125
layer1.1.conv2.weight 0.8971082899305556
layer1.1.conv3.weight 0.87689208984375
layer1.2.conv1.weight 0.83740234375
layer1.2.conv2.weight 0.8482801649305556
layer1.2.conv3.weight 0.837158203125
layer2.0.conv1.weight 0.95440673828125
layer2.0.conv2.weight 0.8832600911458334
layer2.0.conv3.weight 0.91387939453125
layer2.0.shortcut.0.weight 0.8760604858398438
layer2.1.conv1.weight 0.789337158203125
layer2.1.conv2.weight 0.8686591254340278
layer2.1.conv3.weight 0.8619384765625
layer2.2.conv1.weight 0.841156005859375
layer2.2.conv2.weight 0.8667399088541666
layer2.2.conv3.weight 0.8291015625
layer2.3.conv1.weight 0.88299560546875
layer2.3.conv2.weight 0.8697577582465278
layer2.3.conv3.weight 0.7801666259765625
layer3.0.conv1.weight 0.9436111450195312
layer3.0.conv2.we

In [10]:
# Overall fraction of True entries across all masks in param_remove
# (both counters are accumulated by the loop in the previous cell).
temp / all_num

0.5226386210534361

In [11]:
# Reference point: accuracy of the unmodified network on the full test loader.
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    predicted_classes = preds.argmax(-1)
    baseline_accuracy = (predicted_classes == labels).mean()
    print("原始准确率", baseline_accuracy)

原始准确率 0.954


In [12]:
# Zero every weight whose mask entry in param_remove is False, then re-evaluate.
with torch.no_grad():
    net = load_cifar100_resnet50()
    # Resolve parameters by name directly rather than building attribute paths
    # for exec(); all_param_names was produced by named_parameters() on a model
    # from the same loader, so each name is expected to be a key here.
    net_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = net_params[param]
        drop = ~param_remove[param]
        try:
            weight[drop] = 0
        except Exception:
            # Fallback kept from the original cell for masks that cannot index
            # the tensor directly. `except Exception` (not a bare `except:`) so
            # that KeyboardInterrupt / SystemExit still stop the cell.
            weight[drop, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.9429


In [13]:
# Comparison experiment: per layer, zero the same fraction of weights as the mask
# in param_remove would, but choose them by value threshold instead of by mask.
with torch.no_grad():
    net = load_cifar100_resnet50()
    # Resolve parameters by name directly rather than via exec()/eval() strings.
    net_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = net_params[param]
        # Fraction of entries the mask keeps for this layer.
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_values = weight.cpu().detach().numpy()
        weight_flatten = weight_values.flatten()
        # Value at the (1 - keep_rate) quantile of the sorted weights.
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        # NOTE(review): the comparison uses signed values, so this removes the most
        # negative weights rather than the smallest-magnitude ones. If a magnitude
        # pruning baseline is intended, threshold on np.abs(...) instead. Behaviour
        # is left unchanged here so the reported number stays reproducible.
        prune = weight_values < threshold
        try:
            weight[prune] = 0
        except Exception:
            # Fallback kept from the original cell. `except Exception` (not a bare
            # `except:`) so that KeyboardInterrupt / SystemExit still stop the cell.
            weight[prune, :] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.1
