In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import pickle as pkl
import torch.nn.functional as F
import torchvision
from attack import attack, test_model,parse_param


In [2]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_cifar100_resnet50()


In [3]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

In [4]:
all_param_names = all_param_names[:-2]

In [5]:
train_loaders, test_dataloaders,train_dataloader_all, test_dataloader_all = load_cifar100()
all_totals = list()
all_totals.append(attack(train_dataloader_all, all_param_names, load_cifar100_resnet50, alpha=0.0001))


Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 196/196 [01:20<00:00,  2.44it/s]


0.0025499327966570853


100%|██████████| 196/196 [01:23<00:00,  2.34it/s]


1.5512522412109375


100%|██████████| 196/196 [01:24<00:00,  2.31it/s]


4.551924611206054


100%|██████████| 196/196 [01:28<00:00,  2.22it/s]


18.10628897705078


100%|██████████| 196/196 [01:27<00:00,  2.23it/s]


72.1713884375


  param_totals = np.array(param_totals)
  x = np.array(x)


In [6]:
all_totals


[{'conv1.weight': array([[[[0.68770057, 1.9945182 , 1.1979407 ],
           [0.67375845, 1.2458116 , 0.78104633],
           [0.46793148, 0.74654156, 0.54974043]],
  
          [[0.5733383 , 2.0804217 , 1.3217131 ],
           [0.5539914 , 1.0907117 , 0.9137455 ],
           [0.3658267 , 0.7949592 , 0.3973766 ]],
  
          [[0.29734427, 1.9685125 , 1.21161   ],
           [0.32839137, 1.0168011 , 0.92834866],
           [0.16459915, 0.6974358 , 0.377849  ]]],
  
  
         [[[1.1297226 , 1.3676162 , 1.213601  ],
           [1.4521949 , 1.5203305 , 1.5650698 ],
           [1.3957955 , 1.6000452 , 1.4723983 ]],
  
          [[1.0539423 , 1.2908435 , 1.1287569 ],
           [1.3521117 , 1.4174347 , 1.4537528 ],
           [1.294312  , 1.4920663 , 1.359647  ]],
  
          [[1.1987495 , 1.4348295 , 1.2753388 ],
           [1.49794   , 1.5704635 , 1.6019444 ],
           [1.4360391 , 1.6309154 , 1.5075868 ]]],
  
  
         [[[0.71814317, 1.0399926 , 0.52506065],
           [0.7022153

In [7]:
pkl.dump(all_totals, open('weights/cifar100_resnet50.pkl', 'wb'))

In [14]:
thre = 0.6
net = load_cifar100_resnet50()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [15]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())

conv1.weight 0.9976851851851852
layer1.0.conv1.weight 0.92724609375
layer1.0.conv2.weight 0.8306206597222222
layer1.0.conv3.weight 0.871337890625
layer1.0.shortcut.0.weight 0.8856201171875
layer1.1.conv1.weight 0.80078125
layer1.1.conv2.weight 0.862548828125
layer1.1.conv3.weight 0.88433837890625
layer1.2.conv1.weight 0.85089111328125
layer1.2.conv2.weight 0.8541666666666666
layer1.2.conv3.weight 0.72662353515625
layer2.0.conv1.weight 0.940155029296875
layer2.0.conv2.weight 0.8363918728298612
layer2.0.conv3.weight 0.852294921875
layer2.0.shortcut.0.weight 0.8020095825195312
layer2.1.conv1.weight 0.62603759765625
layer2.1.conv2.weight 0.7291191948784722
layer2.1.conv3.weight 0.7801361083984375
layer2.2.conv1.weight 0.7144012451171875
layer2.2.conv2.weight 0.7576904296875
layer2.2.conv3.weight 0.7773590087890625
layer2.3.conv1.weight 0.7182769775390625
layer2.3.conv2.weight 0.7470364040798612
layer2.3.conv3.weight 0.703399658203125
layer3.0.conv1.weight 0.9081802368164062
layer3.0.conv2.

In [16]:
temp / all_num

0.5999999914702085

In [17]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    correct, all = test_model(net, test_dataloader_all)
    print("原始准确率", correct / all)


原始准确率 0.7931


In [18]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("现在准确率", correct / all)


现在准确率 0.725


In [19]:
with torch.no_grad():
    net = load_cifar100_resnet50()
    for param in all_param_names:
        param_ = parse_param(param)
        keep_rate = param_remove[param].sum() / param_remove[param].size
        weight_flatten = eval("net." + param_ + ".cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        try:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold] = 0")
        except:
            exec("net." + param_ + "[eval('net.' + param_ + '.cpu().detach().numpy()') < threshold,:] = 0")
    correct, all = test_model(net, test_dataloader_all)
    print("对比试验准确率", correct / all)


对比试验准确率 0.01
