In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from attack import attack, test_model,parse_param
import random

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)

In [3]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [4]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [5]:
all_param_names = all_param_names[:-2]

In [6]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
all_totals.append(attack(train_dataloader_all, all_param_names, load_cifar100_resnet50,train_dataloader_all, alpha=0.00001,num_steps=2,op="minus"))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 196/196 [01:22<00:00,  2.38it/s]


0.0025499204021692274


100%|██████████| 196/196 [01:18<00:00,  2.50it/s]


0.0014761182779073716


  param_totals = np.array(param_totals)
  x = np.array(x)


In [7]:
pkl.dump(all_totals, open("weights/minus_totals.pkl", "wb"))

In [18]:
thre = 0.5
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [19]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9924768518518519
layer1.0.conv1.weight 0.8310546875
layer1.0.conv2.weight 0.6549207899305556
layer1.0.conv3.weight 0.73388671875
layer1.0.shortcut.0.weight 0.74945068359375
layer1.1.conv1.weight 0.59808349609375
layer1.1.conv2.weight 0.6349555121527778
layer1.1.conv3.weight 0.675048828125
layer1.2.conv1.weight 0.59375
layer1.2.conv2.weight 0.6066080729166666
layer1.2.conv3.weight 0.50640869140625
layer2.0.conv1.weight 0.786895751953125
layer2.0.conv2.weight 0.5433553059895834
layer2.0.conv3.weight 0.6605987548828125
layer2.0.shortcut.0.weight 0.5523605346679688
layer2.1.conv1.weight 0.3526611328125
layer2.1.conv2.weight 0.3769599066840278
layer2.1.conv3.weight 0.507049560546875
layer2.2.conv1.weight 0.4006500244140625
layer2.2.conv2.weight 0.46923828125
layer2.2.conv3.weight 0.5167388916015625
layer2.3.conv1.weight 0.4222564697265625
layer2.3.conv2.weight 0.4430406358506944
layer2.3.conv3.weight 0.4775390625
layer3.0.conv1.weight 0.7414703369140625
layer3.0.conv2.weight 

In [20]:
temp / all_num

0.5

In [21]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())

原始准确率 0.7929


In [22]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())

现在准确率 0.2187


In [13]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("对比实验准确率", (preds.argmax(-1) == labels).mean())

对比实验准确率 0.01
