In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [13]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [14]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [15]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [None]:
all_param_names = all_param_names[:-2]

In [16]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
for i in range(100):
    all_totals.append(attack(train_loaders[i], all_param_names, load_cifar100_resnet50,train_dataloader_all, alpha=0.0001,num_steps=1,op="add"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 20/20 [00:09<00:00,  2.14it/s]


0.0006625839769840241


100%|██████████| 20/20 [00:08<00:00,  2.23it/s]


0.0004628416374325752


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0005848595090210438


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0007870129093527794


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0006073626540601254


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.000595596294850111


100%|██████████| 20/20 [00:09<00:00,  2.22it/s]


0.0006207933589816093


100%|██████████| 20/20 [00:09<00:00,  2.20it/s]


0.0004760825902223587


100%|██████████| 20/20 [00:09<00:00,  2.21it/s]


0.0004115789517760277


100%|██████████| 20/20 [00:08<00:00,  2.23it/s]


0.0005323646783828735


In [None]:
pkl.dump(all_totals, open("weights/sm_totals.pkl", "wb"))

In [17]:
thre = 0.25
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [18]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.943359375
layer1.0.conv2.weight 0.8621148003472222
layer1.0.conv3.weight 0.883544921875
layer1.0.shortcut.0.weight 0.849365234375
layer1.1.conv1.weight 0.80364990234375
layer1.1.conv2.weight 0.8681098090277778
layer1.1.conv3.weight 0.8525390625
layer1.2.conv1.weight 0.80810546875
layer1.2.conv2.weight 0.8195258246527778
layer1.2.conv3.weight 0.80548095703125
layer2.0.conv1.weight 0.938873291015625
layer2.0.conv2.weight 0.8439805772569444
layer2.0.conv3.weight 0.8879241943359375
layer2.0.shortcut.0.weight 0.8387908935546875
layer2.1.conv1.weight 0.73004150390625
layer2.1.conv2.weight 0.8270195855034722
layer2.1.conv3.weight 0.8274383544921875
layer2.2.conv1.weight 0.791595458984375
layer2.2.conv2.weight 0.8248765733506944
layer2.2.conv3.weight 0.789764404296875
layer2.3.conv1.weight 0.843353271484375
layer2.3.conv2.weight 0.8288642035590278
layer2.3.conv3.weight 0.7334442138671875
layer3.0.conv1.weight 0.9242935180664062
layer3.0.c

In [19]:
temp / all_num

0.4442216419642264

In [20]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.954


In [21]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.9306


In [22]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.1
