In [1]:
from model import VisionTransformer
import re
import torch
from datasets import load_cifar10
from attack import attack,test_model
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
def load_model(weights_path="weights/best.pth"):
    """Build the CIFAR-10 ViT, load trained weights, and return it in eval mode.

    Args:
        weights_path: checkpoint file to read. It must hold a dict with a
            ``"state_dict"`` entry (that is the key read below). Defaults to the
            path this notebook has always used, so zero-argument callers (e.g.
            passing ``load_model`` as a factory) keep working.

    Returns:
        A freshly constructed ``VisionTransformer`` with the checkpoint weights,
        moved to the module-level ``device`` and switched to ``eval()`` mode.
    """
    model = VisionTransformer(
                image_size=(384, 384),
                patch_size=(16, 16),
                emb_dim=768,
                mlp_dim=3072,
                num_heads=12,
                num_layers=12,
                num_classes=10,
                attn_dropout_rate=0.0,
                dropout_rate=0.1)
    # Map tensors to CPU while loading: a checkpoint saved from a GPU run would
    # otherwise fail on a CPU-only machine (where ``device`` is "cpu"). The model
    # is moved to ``device`` right after the weights are copied in.
    checkpoint = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()
    return model

In [3]:
# Unpruned model, used for parameter inspection and the baseline accuracy check.
model = load_model()

In [4]:
# List every parameter's qualified name with its tensor shape, to inspect the
# naming scheme of the model's layers.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.shape)

cls_token torch.Size([1, 1, 768])
embedding.weight torch.Size([768, 3, 16, 16])
embedding.bias torch.Size([768])
transformer.pos_embedding.pos_embedding torch.Size([1, 577, 768])
transformer.encoder_layers.0.norm1.weight torch.Size([768])
transformer.encoder_layers.0.norm1.bias torch.Size([768])
transformer.encoder_layers.0.attn.query.weight torch.Size([768, 12, 64])
transformer.encoder_layers.0.attn.query.bias torch.Size([12, 64])
transformer.encoder_layers.0.attn.key.weight torch.Size([768, 12, 64])
transformer.encoder_layers.0.attn.key.bias torch.Size([12, 64])
transformer.encoder_layers.0.attn.value.weight torch.Size([768, 12, 64])
transformer.encoder_layers.0.attn.value.bias torch.Size([12, 64])
transformer.encoder_layers.0.attn.out.weight torch.Size([12, 64, 768])
transformer.encoder_layers.0.attn.out.bias torch.Size([768])
transformer.encoder_layers.0.norm2.weight torch.Size([768])
transformer.encoder_layers.0.norm2.bias torch.Size([768])
transformer.encoder_layers.0.mlp.fc1.wei

In [5]:
# Turn every ``<module path>.weight`` parameter name into an attribute path that
# can be eval'd against the model, e.g.
#   "transformer.encoder_layers.0.attn.query.weight"
#       -> "transformer.encoder_layers[0].attn.query"
# Numeric path components (ModuleList / Sequential indices) become ``[i]``.
# The lookahead lets consecutive indices ("a.0.1.b") all be rewritten, and the
# raw string avoids the invalid "\." escape warning.
_INDEX_COMPONENT = re.compile(r"\.(\d+)(?=\.)")
all_layer_names = list()
for name, _ in model.named_parameters():
    # Only parameters literally named "weight": stripping the suffix below is
    # only valid for those.
    if not name.endswith(".weight"):
        continue
    indexed_name = _INDEX_COMPONENT.sub(r"[\1]", name)
    all_layer_names.append(indexed_name[: -len(".weight")])
all_layer_names

['embedding',
 'transformer.encoder_layers[0].norm1',
 'transformer.encoder_layers[0].attn.query',
 'transformer.encoder_layers[0].attn.key',
 'transformer.encoder_layers[0].attn.value',
 'transformer.encoder_layers[0].attn.out',
 'transformer.encoder_layers[0].norm2',
 'transformer.encoder_layers[0].mlp.fc1',
 'transformer.encoder_layers[0].mlp.fc2',
 'transformer.encoder_layers[1].norm1',
 'transformer.encoder_layers[1].attn.query',
 'transformer.encoder_layers[1].attn.key',
 'transformer.encoder_layers[1].attn.value',
 'transformer.encoder_layers[1].attn.out',
 'transformer.encoder_layers[1].norm2',
 'transformer.encoder_layers[1].mlp.fc1',
 'transformer.encoder_layers[1].mlp.fc2',
 'transformer.encoder_layers[2].norm1',
 'transformer.encoder_layers[2].attn.query',
 'transformer.encoder_layers[2].attn.key',
 'transformer.encoder_layers[2].attn.value',
 'transformer.encoder_layers[2].attn.out',
 'transformer.encoder_layers[2].norm2',
 'transformer.encoder_layers[2].mlp.fc1',
 'transf

In [6]:
# NOTE(review): train_loaders is indexed 0..9 below, presumably one training
# loader per CIFAR-10 class — confirm in datasets.load_cifar10.
# test_dataloader_all is the loader used for full-test-set accuracy.
train_loaders, test_dataloaders, test_dataloader_all = load_cifar10()

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Accuracy of the unpruned model; test_model returns a (correct, total) pair.
test_model(model,test_dataloader_all)

  0%|          | 0/313 [00:00<?, ?it/s]

(9849, 10000)

In [8]:

# Run the attack once per training loader and keep each returned score mapping
# (indexed by layer name later on). The loop bound is taken from the data rather
# than a hard-coded class count.
# NOTE(review): presumably one loader per CIFAR-10 class — confirm in load_cifar10.
ATTACK_ALPHA = 0.0001
all_totals = list()
for class_idx in range(len(train_loaders)):
    all_totals.append(
        attack(train_loaders[class_idx], all_layer_names, load_model, alpha=ATTACK_ALPHA))


  0%|          | 0/157 [00:00<?, ?it/s]

9.17611303339072e-05


  0%|          | 0/157 [00:00<?, ?it/s]

0.02952886368483305


  0%|          | 0/157 [00:00<?, ?it/s]

1.6651802265167237


  0%|          | 0/157 [00:00<?, ?it/s]

11.395269366455079


  0%|          | 0/157 [00:00<?, ?it/s]

25.353899447631836


  layer_totals = np.array(layer_totals)
  x = np.array(x)


  0%|          | 0/157 [00:00<?, ?it/s]

0.00017273987938806385


  0%|          | 0/157 [00:00<?, ?it/s]

0.29743441443443297


  0%|          | 0/157 [00:00<?, ?it/s]

7.620735089874268


  0%|          | 0/157 [00:00<?, ?it/s]

20.800572982788086


  0%|          | 0/157 [00:00<?, ?it/s]

28.997444354248046


  0%|          | 0/157 [00:00<?, ?it/s]

0.00026605912309896664


  0%|          | 0/157 [00:00<?, ?it/s]

0.01807895293090296


  0%|          | 0/157 [00:00<?, ?it/s]

0.15765635643079876


  0%|          | 0/157 [00:00<?, ?it/s]

1.9378524852752685


  0%|          | 0/157 [00:00<?, ?it/s]

9.774940203857422


  0%|          | 0/157 [00:00<?, ?it/s]

0.0023813774497655686


  0%|          | 0/157 [00:00<?, ?it/s]

0.2881590758293867


  0%|          | 0/157 [00:00<?, ?it/s]

4.5767861892700195


  0%|          | 0/157 [00:00<?, ?it/s]

14.31907311553955


  0%|          | 0/157 [00:00<?, ?it/s]

24.74333949584961


  0%|          | 0/157 [00:00<?, ?it/s]

0.0008837171252442204


  0%|          | 0/157 [00:00<?, ?it/s]

0.06307871801811853


  0%|          | 0/157 [00:00<?, ?it/s]

2.1311324887394907


  0%|          | 0/157 [00:00<?, ?it/s]

11.453127532958984


  0%|          | 0/157 [00:00<?, ?it/s]

21.267425317382813


  0%|          | 0/157 [00:00<?, ?it/s]

0.0009060151876648888


  0%|          | 0/157 [00:00<?, ?it/s]

0.17408443751093


  0%|          | 0/157 [00:00<?, ?it/s]

3.9981314842224123


  0%|          | 0/157 [00:00<?, ?it/s]

14.684711595153809


  0%|          | 0/157 [00:00<?, ?it/s]

24.0784275970459


  0%|          | 0/157 [00:00<?, ?it/s]

0.00035041836898817564


  0%|          | 0/157 [00:00<?, ?it/s]

2.3683866453170777


  0%|          | 0/157 [00:00<?, ?it/s]

13.203306437683105


  0%|          | 0/157 [00:00<?, ?it/s]

23.150868743896485


  0%|          | 0/157 [00:00<?, ?it/s]

0.00012405874358228174


  0%|          | 0/157 [00:00<?, ?it/s]

0.02220252730140346


  0%|          | 0/157 [00:00<?, ?it/s]

0.8334930953726173


  0%|          | 0/157 [00:00<?, ?it/s]

9.376436051940917


  0%|          | 0/157 [00:00<?, ?it/s]

20.379634716796875


  0%|          | 0/157 [00:00<?, ?it/s]

0.0006660923528901435


  0%|          | 0/157 [00:00<?, ?it/s]

0.030550511024927254


  0%|          | 0/157 [00:00<?, ?it/s]

1.5538501966476441


  0%|          | 0/157 [00:00<?, ?it/s]

10.933225189208985


  0%|          | 0/157 [00:00<?, ?it/s]

22.915690103149416


  0%|          | 0/157 [00:00<?, ?it/s]

0.0007163493095542435


  0%|          | 0/157 [00:00<?, ?it/s]

0.2244058738410473


  0%|          | 0/157 [00:00<?, ?it/s]

7.2472744552612305


  0%|          | 0/157 [00:00<?, ?it/s]

20.877379736328123


  0%|          | 0/157 [00:00<?, ?it/s]

29.40443511047363


In [43]:
import pickle as pkl

# Persist the per-class attack scores. The context manager guarantees the file
# is flushed and closed, even if pickling raises.
with open("weights/cifar10_vit_weight.pkl", "wb") as f:
    pkl.dump(all_totals, f)

In [9]:
import numpy as np

In [22]:
# Build per-layer masks from the per-class attack scores.
# NOTE: despite its name, ``layer_remove[layer]`` is a KEEP mask. True marks a
# weight whose |score * weight| ranked in the top ``thre`` fraction (across all
# layers) for at least one class; later cells zero out the complement.
thre = 0.2
assert 0 <= thre < 1, "thre must be a fraction in [0, 1)"
net = load_model()
layer_remove = {layer: None for layer in all_layer_names}
# The weights are the same for every class, so read them once instead of
# re-evaluating them inside the per-class loop.
layer_weights = [eval("net." + layer + ".weight.cpu().detach().numpy()")
                 for layer in all_layer_names]
for totals in all_totals:
    # Keep this as a plain list. The layers have different shapes, and
    # np.array() on a ragged list warns (NumPy < 1.24) or raises (NumPy >= 1.24).
    combine = [np.abs(totals[layer] * weight)
               for layer, weight in zip(all_layer_names, layer_weights)]
    combine_flatten = np.concatenate([combine_.ravel() for combine_ in combine])
    # Same value as np.sort(combine_flatten)[::-1][int(n * thre)]. np.partition
    # selects that order statistic in linear time instead of doing a full sort.
    n_scores = combine_flatten.size
    kth_smallest = n_scores - 1 - int(n_scores * thre)
    threshold = np.partition(combine_flatten, kth_smallest)[kth_smallest]
    for layer, layer_combine in zip(all_layer_names, combine):
        keep = layer_combine > threshold
        if layer_remove[layer] is None:
            layer_remove[layer] = keep
        else:
            layer_remove[layer] = layer_remove[layer] | keep

  combine = np.array(combine)


In [23]:
# Print the fraction of True entries in each layer's mask. The running totals
# temp (True count) and all_num (element count) give the overall fraction later.
temp, all_num = 0, 0
for layer_name, mask in layer_remove.items():
    temp += mask.sum()
    all_num += mask.size
    print(layer_name, mask.mean())

embedding 0.6393110487196181
transformer.encoder_layers[0].norm1 1.0
transformer.encoder_layers[0].attn.query 0.3231455485026042
transformer.encoder_layers[0].attn.key 0.3710208468967014
transformer.encoder_layers[0].attn.value 0.2099728054470486
transformer.encoder_layers[0].attn.out 0.20081583658854166
transformer.encoder_layers[0].norm2 1.0
transformer.encoder_layers[0].mlp.fc1 0.24737591213650173
transformer.encoder_layers[0].mlp.fc2 0.16192330254448783
transformer.encoder_layers[1].norm1 0.9986979166666666
transformer.encoder_layers[1].attn.query 0.3634745279947917
transformer.encoder_layers[1].attn.key 0.3591766357421875
transformer.encoder_layers[1].attn.value 0.3278893364800347
transformer.encoder_layers[1].attn.out 0.4239976671006944
transformer.encoder_layers[1].norm2 1.0
transformer.encoder_layers[1].mlp.fc1 0.30961354573567706
transformer.encoder_layers[1].mlp.fc2 0.2048611111111111
transformer.encoder_layers[2].norm1 1.0
transformer.encoder_layers[2].attn.query 0.325600518

In [24]:
# Overall fraction of weights marked True in layer_remove, across all layers.
temp / all_num

0.3861793079619073

In [13]:
# Baseline: accuracy of a freshly loaded, unpruned model on the full test set.
# The label printed below means "original accuracy".
with torch.no_grad():
    net = load_model()
    # ``total`` rather than ``all``, so the builtin all() stays usable.
    correct, total = test_model(net, test_dataloader_all)
    print("原始准确率", correct / total)

  0%|          | 0/313 [00:00<?, ?it/s]

原始准确率 0.9849


In [28]:
# Prune: zero every weight that is NOT in the keep mask, then re-evaluate.
# The label printed below means "current accuracy".
with torch.no_grad():
    net = load_model()
    for layer in all_layer_names:
        weight = eval("net." + layer + ".weight")
        # layer_remove[layer] is True for weights to keep, so zero the complement.
        # A boolean mask indexes the leading dims directly, so no ``[mask, :]``
        # fallback is needed. A shape mismatch now surfaces instead of being
        # hidden by a bare except.
        prune_mask = torch.from_numpy(~layer_remove[layer]).to(weight.device)
        weight[prune_mask] = 0
    # ``total`` rather than ``all``, so the builtin all() stays usable.
    correct, total = test_model(net, test_dataloader_all)
    print("现在准确率", correct / total)

  0%|          | 0/313 [00:00<?, ?it/s]

现在准确率 0.9248


In [42]:
# Sanity check: every weight outside the keep mask must now be exactly zero.
# Sum absolute values, so positive and negative leftovers cannot cancel to 0.
with torch.no_grad():
    for key, value in layer_remove.items():
        weight = eval("net." + key + ".weight")
        pruned = torch.from_numpy(~value).to(weight.device)
        print(weight[pruned].abs().sum())

tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
tensor(0., d

In [37]:
# Shape of the first recorded layer's weight (the patch embedding in this run).
eval("net." + next(iter(layer_remove)) + ".weight").shape

torch.Size([768, 3, 16, 16])

In [27]:
# Comparison baseline: in each layer keep the same fraction of weights as the
# attack-derived mask, but choose them by raw weight value instead.
# NOTE(review): the ranking uses the SIGNED weight, so it keeps the most positive
# weights, not the largest-magnitude ones. Rank np.abs(...) if magnitude pruning
# was intended. The printed label means roughly "accuracy after removing the
# largest", which does not obviously match what the code zeroes — confirm intent.
with torch.no_grad():
    net = load_model()
    for layer in all_layer_names:
        keep_rate = layer_remove[layer].sum() / layer_remove[layer].size
        weight = eval("net." + layer + ".weight")
        # detach() before numpy(): on CPU, .cpu() returns the parameter itself,
        # which still requires grad.
        weight_np = weight.detach().cpu().numpy()
        weight_flatten = weight_np.ravel()
        # Same value as np.sort(weight_flatten)[k], selected in linear time.
        k = int(len(weight_flatten) * (1 - keep_rate))
        threshold = np.partition(weight_flatten, k)[k]
        # Build the whole mask before writing. On CPU, weight_np shares memory
        # with the parameter.
        prune_mask = torch.from_numpy(weight_np < threshold).to(weight.device)
        weight[prune_mask] = 0
    # ``total`` rather than ``all``, so the builtin all() stays usable.
    correct, total = test_model(net, test_dataloader_all)
    print("去掉最大准确率", correct / total)

  0%|          | 0/313 [00:00<?, ?it/s]

去掉最大准确率 0.0962
