In [None]:
from model import VisionTransformer
import torch
from datasets import load_cifar10
from attack import attack,test_model,parse_param
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def load_model():
    """Build the ViT classifier and restore the weights in ``weights/best.pth``.

    Returns:
        A ``VisionTransformer`` with 10 output classes, loaded from the
        checkpoint's ``"state_dict"`` entry, moved to ``device`` and put in
        eval mode.
    """
    model = VisionTransformer(
        image_size=(384, 384),
        patch_size=(16, 16),
        emb_dim=768,
        mlp_dim=3072,
        num_heads=12,
        num_layers=12,
        num_classes=10,
        attn_dropout_rate=0.0,
        dropout_rate=0.1,
    )
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint whose tensors were saved on a GPU still loads when `device`
    # has fallen back to "cpu".
    checkpoint = torch.load("weights/best.pth", map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()
    return model

In [None]:
# Instantiate the model once; later cells only read its parameter names and shapes.
model = load_model()

In [None]:
# List every parameter tensor of the model together with its shape.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape)

In [None]:
# Names of all model parameters, in named_parameters() order.
all_param_names = [param_name for param_name, _ in model.named_parameters()]

In [None]:
# Data loaders returned by load_cifar10(): `train_loaders` is indexed 0-9 by the
# attack loop, and `test_dataloader_all` is the loader used for accuracy evaluation.
train_loaders, test_dataloaders, test_dataloader_all = load_cifar10()

In [None]:
# test_model(model,test_dataloader_all)

In [None]:

# Run the attack once for each of train_loaders[0..9] and collect the results in order.
all_totals = [
    attack(train_loaders[loader_idx], all_param_names, load_model, alpha=0.0001)
    for loader_idx in range(10)
]


In [None]:
import pickle as pkl

# Persist the attack results. The context manager guarantees the file is
# flushed and closed even if pickling raises.
with open("weights/cifar10_vit_all.pkl", "wb") as fout:
    pkl.dump(all_totals, fout)

In [None]:
import numpy as np

In [None]:
# Fraction of all scored entries selected per attack result.
thre = 0.4
net = load_model()

# The weights do not change between attack results, so read them once.
# eval() only resolves attribute paths built by parse_param from the model's
# own parameter names; it is never given external input.
param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                 for param in all_param_names]

# param_remove[name] is a boolean mask with the shape of that parameter's
# score: True where the entry is in the top `thre` fraction for at least one
# attack result.
param_remove = {param: None for param in all_param_names}
for totals in all_totals:
    # Per-entry score: |attack total * weight|.
    combine = [np.abs(totals[param] * weight)
               for param, weight in zip(all_param_names, param_weights)]
    # `combine` stays a plain list: the parameters have different shapes, and
    # np.array() on such a ragged list raises ValueError on NumPy >= 1.24.
    combine_flatten = np.concatenate([score.flatten() for score in combine], axis=0)
    # Global cut-off over all parameters: the score at rank len * thre, descending.
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for param, score in zip(all_param_names, combine):
        selected = score > threshold
        if param_remove[param] is None:
            param_remove[param] = selected
        else:
            # Union of the selections across attack results.
            param_remove[param] = param_remove[param] | selected

In [None]:
# Count the True entries across all masks and print each parameter's True fraction.
temp = 0
all_num = 0
for param, mask in param_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(param, mask.mean())

In [None]:
# Overall fraction of entries marked True in param_remove, across all parameters.
temp / all_num

In [None]:
# Baseline: accuracy of the unmodified model on the full test loader.
with torch.no_grad():
    net = load_model()
    # Named `total` rather than `all` so the builtin all() is not shadowed for
    # the rest of the kernel session.
    correct, total = test_model(net, test_dataloader_all)
    print("原始准确率", correct / total)

In [None]:
# Zero every entry that is NOT marked in param_remove, then re-evaluate accuracy.
with torch.no_grad():
    net = load_model()
    for param in all_param_names:
        # Resolve the parameter tensor once. eval() only sees an attribute path
        # that parse_param derives from the model's own parameter names.
        weight = eval("net." + parse_param(param))
        drop_mask = ~param_remove[param]
        try:
            weight[drop_mask] = 0
        except Exception:
            # Alternative indexing form kept from the original code.
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            # NOTE(review): it is unclear which parameters need this form -
            # confirm and narrow the exception type.
            weight[drop_mask, :] = 0
    # `total` instead of `all`, to avoid shadowing the builtin.
    correct, total = test_model(net, test_dataloader_all)
    print("现在准确率", correct / total)


In [None]:
# Sanity check on the pruned `net`: for each parameter, the entries outside its
# mask should have been zeroed, so each printed sum is expected to be 0.
# NOTE(review): the original iterated `param.items()`, but `param` is a
# parameter-name string by this point, and it appended ".weight" to names that
# are already full parameter names. This version iterates param_remove and
# resolves names with parse_param like the other cells - confirm the intent.
for key, value in param_remove.items():
    weight = eval("net." + parse_param(key) + ".cpu().detach().numpy()")
    print(weight[~value].sum())

In [None]:
# Control experiment: prune each parameter by weight magnitude, keeping the
# same fraction of entries that param_remove keeps for that parameter.
with torch.no_grad():
    net = load_model()
    for param in all_param_names:
        # eval() only resolves an attribute path that parse_param derives from
        # the model's own parameter names.
        weight = eval("net." + parse_param(param))
        keep_rate = param_remove[param].sum() / param_remove[param].size
        # Rank by |w|. The original ranked signed values, which zeroes the
        # most negative weights instead of the smallest-magnitude ones.
        # NOTE(review): assumes the control is meant to be magnitude pruning.
        magnitudes = np.abs(weight.cpu().detach().numpy())
        ranked = np.sort(magnitudes.flatten())
        # Clamp so keep_rate == 0 does not index one past the end.
        cut = min(int(len(ranked) * (1 - keep_rate)), len(ranked) - 1)
        threshold = ranked[cut]
        prune_mask = magnitudes < threshold
        try:
            weight[prune_mask] = 0
        except Exception:
            # Alternative indexing form kept from the original code.
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            weight[prune_mask, :] = 0
    # `total` instead of `all`, to avoid shadowing the builtin.
    correct, total = test_model(net, test_dataloader_all)
    print("对比试验准确率", correct / total)
