In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from learn2learn.vision.datasets import MiniImagenet
from tqdm.notebook import tqdm
from attack import attack, parse_param, test_model
from utils import get_device, caculate_param_remove
import random
device = get_device()  # project helper from utils; NOTE(review): `device` is not referenced in any of the visible cells


In [2]:
def setup_seed(seed):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Args:
        seed: integer seed applied to torch (CPU and all CUDA devices),
            numpy's global RNG and Python's ``random`` module.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # `deterministic` alone is not enough: with benchmark mode on, cuDNN
    # autotunes and may select a different convolution algorithm per run.
    torch.backends.cudnn.benchmark = False


setup_seed(3407)


In [3]:
# NOTE(review): these imports belong in the import cell at the top of the notebook.
# `datasets` here is presumably the project-local module providing load_cifar10; it
# shares its name with the Hugging Face `datasets` package, so the local one must
# win on sys.path.
from datasets import load_cifar10
from models.resnet import load_cifar10_resnet50
# In the visible cells this instance is only used to enumerate parameter names.
model = load_cifar10_resnet50()


In [4]:
# Names of the parameters to be masked: everything except BatchNorm params
# ("bn" in the name) and "shortcut.1" (presumably the shortcut's BatchNorm),
# minus the final two entries (presumably the classifier's weight and bias —
# TODO confirm against models.resnet).
all_param_names = [
    param_name
    for param_name, _unused in model.named_parameters()
    if "bn" not in param_name and "shortcut.1" not in param_name
][:-2]


In [5]:
# Step size and iteration count forwarded to `attack` for every loader.
ATTACK_ALPHA = 0.00001
ATTACK_NUM_STEPS = 4

train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()
# One `attack` result per entry of `train_loaders` (previously hard-coded to 10,
# presumably one loader per CIFAR-10 class — TODO confirm in load_cifar10).
all_totals = [
    attack(train_loaders[i], all_param_names, load_cifar10_resnet50,
           norm=True, alpha=ATTACK_ALPHA, num_steps=ATTACK_NUM_STEPS, op="add")
    for i in range(len(train_loaders))
]


Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.00995345048904419


  0%|          | 0/20 [00:00<?, ?it/s]

0.013395670342445374


  0%|          | 0/20 [00:00<?, ?it/s]

0.38232105712890624


  0%|          | 0/20 [00:00<?, ?it/s]

1.9999372589111328


  x = np.array(x)


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010153193998336791


  0%|          | 0/20 [00:00<?, ?it/s]

-0.0011885146617889404


  0%|          | 0/20 [00:00<?, ?it/s]

0.15676712760925293


  0%|          | 0/20 [00:00<?, ?it/s]

1.0295591491699219


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010031173849105835


  0%|          | 0/20 [00:00<?, ?it/s]

0.005215082144737243


  0%|          | 0/20 [00:00<?, ?it/s]

0.2086234100341797


  0%|          | 0/20 [00:00<?, ?it/s]

1.252728094482422


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.009829036808013916


  0%|          | 0/20 [00:00<?, ?it/s]

0.016188286733627318


  0%|          | 0/20 [00:00<?, ?it/s]

0.39721205825805667


  0%|          | 0/20 [00:00<?, ?it/s]

2.036865411376953


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010008672952651978


  0%|          | 0/20 [00:00<?, ?it/s]

0.0010555079221725465


  0%|          | 0/20 [00:00<?, ?it/s]

0.17615248546600343


  0%|          | 0/20 [00:00<?, ?it/s]

1.183141372680664


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010020434951782226


  0%|          | 0/20 [00:00<?, ?it/s]

0.007610933089256287


  0%|          | 0/20 [00:00<?, ?it/s]

0.2939016357421875


  0%|          | 0/20 [00:00<?, ?it/s]

1.7085674224853515


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.009995253705978393


  0%|          | 0/20 [00:00<?, ?it/s]

0.0004541816234588623


  0%|          | 0/20 [00:00<?, ?it/s]

0.19505502586364745


  0%|          | 0/20 [00:00<?, ?it/s]

1.4660293243408202


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010139955663681031


  0%|          | 0/20 [00:00<?, ?it/s]

0.0011085834503173828


  0%|          | 0/20 [00:00<?, ?it/s]

0.17003772735595704


  0%|          | 0/20 [00:00<?, ?it/s]

1.150430776977539


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010204456233978271


  0%|          | 0/20 [00:00<?, ?it/s]

-7.367324829101562e-05


  0%|          | 0/20 [00:00<?, ?it/s]

0.20973687477111816


  0%|          | 0/20 [00:00<?, ?it/s]

1.4370737152099609


  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

-0.010083657836914062


  0%|          | 0/20 [00:00<?, ?it/s]

0.00612880482673645


  0%|          | 0/20 [00:00<?, ?it/s]

0.3244254638671875


  0%|          | 0/20 [00:00<?, ?it/s]

1.8953181213378907


In [6]:
import os

# Persist the expensive attack results so later cells can be re-run without
# recomputing them. To reload (only from a trusted file — unpickling runs code):
#     with open("weights/totals.pkl", "rb") as f: all_totals = pkl.load(f)
os.makedirs("weights", exist_ok=True)
with open("weights/totals.pkl", "wb") as totals_file:
    pkl.dump(all_totals, totals_file)


In [13]:
# Threshold and a fresh network are handed to the project helper, which returns a
# per-parameter mask dict. Despite the name, later cells zero the entries where
# the mask is False, i.e. True marks weights that are kept.
thre, net = 0.7, load_cifar10_resnet50()
param_remove = caculate_param_remove(all_param_names, all_totals, net, thre)

In [14]:
# Accumulate kept-entry counts (`temp`) and total entries (`all_num`) over all
# masks, printing each layer's kept fraction along the way.
temp, all_num = 0, 0
for param in param_remove:
    mask = param_remove[param]
    print(param, mask.mean())
    temp += mask.sum()
    all_num += mask.size


conv1.weight 1.0
layer1.0.conv1.weight 0.996337890625
layer1.0.conv2.weight 0.992919921875
layer1.0.conv3.weight 0.9595947265625
layer1.0.shortcut.0.weight 0.95208740234375
layer1.1.conv1.weight 0.9461669921875
layer1.1.conv2.weight 0.9931911892361112
layer1.1.conv3.weight 0.97125244140625
layer1.2.conv1.weight 0.9637451171875
layer1.2.conv2.weight 0.9746365017361112
layer1.2.conv3.weight 0.98480224609375
layer2.0.conv1.weight 0.998016357421875
layer2.0.conv2.weight 0.99462890625
layer2.0.conv3.weight 0.9896697998046875
layer2.0.shortcut.0.weight 0.9908065795898438
layer2.1.conv1.weight 0.9828338623046875
layer2.1.conv2.weight 0.9921061197916666
layer2.1.conv3.weight 0.98687744140625
layer2.2.conv1.weight 0.9900665283203125
layer2.2.conv2.weight 0.9926079644097222
layer2.2.conv3.weight 0.98370361328125
layer2.3.conv1.weight 0.994110107421875
layer2.3.conv2.weight 0.9930487738715278
layer2.3.conv3.weight 0.974884033203125
layer3.0.conv1.weight 0.9975128173828125
layer3.0.conv2.weight 0.

In [15]:
# Overall fraction of mask entries that are True (kept) across every layer in param_remove.
temp / all_num


0.9050795846605689

In [16]:
# Baseline: accuracy of a freshly loaded, unmodified network on test_dataloader_all.
with torch.no_grad():
    net = load_cifar10_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    hits = preds.argmax(-1) == labels
    print("原始准确率", hits.mean())


原始准确率 0.954


In [None]:
# Apply the attack-derived masks: entries where param_remove[name] is False are
# zeroed, then the pruned network is evaluated.
with torch.no_grad():
    net = load_cifar10_resnet50()
    # Dotted names from named_parameters() map directly to the Parameter objects,
    # so no exec()/parse_param string building is needed.
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        weight = named_params[param]
        # Mask lives on the weight's device; a bool mask matching the weight's
        # leading dims indexes identically to the old `[mask, :]` fallback.
        keep_mask = torch.as_tensor(param_remove[param], dtype=torch.bool, device=weight.device)
        weight[~keep_mask] = 0
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


In [None]:
def _magnitude_prune_(weight, keep_rate):
    """Zero, in place, the smallest-magnitude entries of ``weight``.

    The cut-off is the magnitude at sorted position ``int(numel * (1 - keep_rate))``;
    entries strictly below it are zeroed, so about ``1 - keep_rate`` of the tensor
    is removed (ties at the cut-off are kept). Magnitudes are used instead of signed
    values so the most negative weights are not mistaken for the smallest ones.

    Args:
        weight: parameter tensor modified in place; call under ``torch.no_grad()``.
        keep_rate: fraction of entries to keep, expected in [0, 1].
    """
    magnitude = weight.detach().abs()
    sorted_magnitude = magnitude.flatten().sort().values
    numel = sorted_magnitude.numel()
    n_remove = int(numel * (1 - keep_rate))
    if n_remove <= 0:
        return  # keep everything
    if n_remove >= numel:
        weight.zero_()  # keep_rate == 0: indexing sorted[numel] would overflow
        return
    threshold = sorted_magnitude[n_remove]
    weight[magnitude < threshold] = 0


# Control experiment: per layer, keep the same fraction of weights as the
# attack-derived mask, but choose which ones by magnitude.
with torch.no_grad():
    net = load_cifar10_resnet50()
    named_params = dict(net.named_parameters())
    for param in all_param_names:
        keep_rate = param_remove[param].sum() / param_remove[param].size
        _magnitude_prune_(named_params[param], keep_rate)
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())
