In [70]:
import torch
import numpy as np
import re
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from datasets import load_cifar10_choosen,load_cifar10
from attack import attack, test_model
device = "cuda" if torch.cuda.is_available() else "cpu"


In [71]:
choosen_classes = [0, 1, 2, 3]

In [72]:
train_dataloader, test_dataloader, test_dataloader_all = load_cifar10_choosen(
    choosen_classes=choosen_classes)

Files already downloaded and verified
Files already downloaded and verified


In [73]:
def train_choosen_classes(train_dataloader, test_dataloader, epochs=10, lr=0.001, model=None,num_classes=4,masks=None,save_path="weights/best_model.pth"):
    if model is None:
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(2048, num_classes)
        model = model.to(device)
    loss_func = nn.CrossEntropyLoss()
    best_acc = -np.inf
    for epoch in range(epochs):
        model.train()
        all_preds = []
        all_labels = []
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = loss_func(output, y)
            loss.backward()
            for name, param in model.named_parameters():
                if masks is not None:
                    if "weight" in name:
                        reg = re.compile("\.\d\.")
                        finded = reg.findall(name)
                        if len(finded) == 0:
                            name = name[:-7]
                        else:
                            for f in finded:
                                f = f[1:-1]
                                name = name.replace(f".{f}.", f"[{f}].")
                            name = name[:-7]
                    if name in masks.keys():
                        param.grad *= torch.BoolTensor(~masks[name]).to(device)
                # print(name, param.shape)
                param.data -= lr * param.grad
                param.grad.zero_()
        model.eval()
        train_correct, train_total = test_model(model, train_dataloader)
        test_correct, test_total = test_model(model, test_dataloader)
        print(f"Epoch {epoch + 1} train acc: {train_correct / train_total}")
        print(f"Epoch {epoch + 1} test acc: {test_correct / test_total}")
        if test_correct / test_total > best_acc:
            best_acc = test_correct / test_total
            torch.save(model.state_dict(), save_path)


In [74]:
train_choosen_classes(train_dataloader, test_dataloader, epochs=10, lr=0.01,masks=None)

  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 train acc: 0.71545
Epoch 1 test acc: 0.6815


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 train acc: 0.8351
Epoch 2 test acc: 0.7905


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 train acc: 0.8902
Epoch 3 test acc: 0.82


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 train acc: 0.9253
Epoch 4 test acc: 0.8385


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 train acc: 0.9502
Epoch 5 test acc: 0.85225


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 6 train acc: 0.9599
Epoch 6 test acc: 0.855


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 7 train acc: 0.97165
Epoch 7 test acc: 0.85975


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 8 train acc: 0.9755
Epoch 8 test acc: 0.86125


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 9 train acc: 0.96785
Epoch 9 test acc: 0.846


  0%|          | 0/79 [00:00<?, ?it/s]

Epoch 10 train acc: 0.9876
Epoch 10 test acc: 0.869


In [6]:
def load_model(num_classes):
    model = resnet50(num_classes=num_classes)
    model.load_state_dict(torch.load("weights/best_model.pth"))
    model = model.to(device)
    model.eval()
    return model


In [7]:
model = load_model(num_classes=4)
correct, all = test_model(model, test_dataloader)
print(f"test acc: {correct / all}")

test acc: 0.87275


In [8]:
all_layer_names = list()
for name,param in model.named_parameters():
    if "weight" in name:
        reg = re.compile("\.\d\.")
        finded = reg.findall(name)
        if len(finded) == 0:
            all_layer_names.append(name[:-7])
        else:
            for f in finded:
                f = f[1:-1]
                name = name.replace(f".{f}.", f"[{f}].")
            all_layer_names.append(name[:-7])

In [9]:
train_loaders, test_loaders, test_loaders_all = load_cifar10()

Files already downloaded and verified
Files already downloaded and verified


In [10]:
all_totals = list()
for clz in choosen_classes:
    all_totals.append(attack(train_loaders[clz], all_layer_names, load_model, alpha=0.000015,num_classes=4))

100%|██████████| 20/20 [00:01<00:00, 10.54it/s]


0.06343568496704101


100%|██████████| 20/20 [00:01<00:00, 11.77it/s]


0.15740043601989745


100%|██████████| 20/20 [00:01<00:00, 11.74it/s]


0.3945034729003906


100%|██████████| 20/20 [00:01<00:00, 11.46it/s]


0.8578034790039063


100%|██████████| 20/20 [00:01<00:00, 11.25it/s]


1.5460462219238282


  layer_totals = np.array(layer_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:01<00:00, 11.48it/s]


0.0386423077583313


100%|██████████| 20/20 [00:01<00:00, 11.69it/s]


0.1051037015914917


100%|██████████| 20/20 [00:01<00:00, 11.67it/s]


0.3275358253479004


100%|██████████| 20/20 [00:01<00:00, 11.40it/s]


0.8173002777099609


100%|██████████| 20/20 [00:01<00:00, 11.66it/s]


1.5427132446289062


100%|██████████| 20/20 [00:01<00:00, 11.44it/s]


0.10755170097351074


100%|██████████| 20/20 [00:01<00:00, 11.79it/s]


0.2414998825073242


100%|██████████| 20/20 [00:01<00:00, 11.58it/s]


0.5466656471252441


100%|██████████| 20/20 [00:01<00:00, 11.38it/s]


1.1162116973876952


100%|██████████| 20/20 [00:01<00:00, 11.71it/s]


2.0155708953857423


100%|██████████| 20/20 [00:01<00:00, 11.74it/s]


0.07487216796875


100%|██████████| 20/20 [00:01<00:00, 11.80it/s]


0.16055814151763917


100%|██████████| 20/20 [00:01<00:00, 11.57it/s]


0.38959078521728513


100%|██████████| 20/20 [00:01<00:00, 11.64it/s]


0.8893507110595703


100%|██████████| 20/20 [00:01<00:00, 11.15it/s]


1.6807810272216797


In [39]:
thre = 0.25
net = load_model(num_classes=4)
layer_remove = dict()
for layer in all_layer_names:
    layer_remove[layer] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[layer] for layer in all_layer_names]
    layer_weights = [eval("net." + layer + ".weight.cpu().detach().numpy()")
                     for layer in all_layer_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, layer_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,layer in enumerate(all_layer_names):
        if layer_remove[layer] is None:
            layer_remove[layer] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            layer_remove[layer] = layer_remove[layer] | t

  combine = np.array(combine)


In [40]:
temp = 0
all_num = 0
for layer in layer_remove:
    temp += layer_remove[layer].sum()
    all_num += layer_remove[layer].size
    print(layer, layer_remove[layer].mean())

conv1 0.9663052721088435
bn1 1.0
layer1[0].conv1 0.96728515625
layer1[0].bn1 1.0
layer1[0].conv2 0.7099066840277778
layer1[0].bn2 1.0
layer1[0].conv3 0.8153076171875
layer1[0].bn3 1.0
layer1[0].downsample[0] 0.92315673828125
layer1[0].downsample[1] 1.0
layer1[1].conv1 0.79901123046875
layer1[1].bn1 1.0
layer1[1].conv2 0.7809516059027778
layer1[1].bn2 1.0
layer1[1].conv3 0.858642578125
layer1[1].bn3 0.98046875
layer1[2].conv1 0.83544921875
layer1[2].bn1 1.0
layer1[2].conv2 0.8398980034722222
layer1[2].bn2 1.0
layer1[2].conv3 0.8336181640625
layer1[2].bn3 0.94921875
layer2[0].conv1 0.94921875
layer2[0].bn1 1.0
layer2[0].conv2 0.7417534722222222
layer2[0].bn2 1.0
layer2[0].conv3 0.77362060546875
layer2[0].bn3 0.974609375
layer2[0].downsample[0] 0.6619415283203125
layer2[0].downsample[1] 1.0
layer2[1].conv1 0.545867919921875
layer2[1].bn1 1.0
layer2[1].conv2 0.5102335611979166
layer2[1].bn2 1.0
layer2[1].conv3 0.7311248779296875
layer2[1].bn3 0.99609375
layer2[2].conv1 0.710296630859375
la

In [41]:
temp / all_num

0.35120302274225806

In [42]:
with torch.no_grad():
    net = load_model(num_classes=4)
    correct, all = test_model(net, test_dataloader)
    print("原始准确率", correct / all)


原始准确率 0.87275


In [43]:
with torch.no_grad():
    net = load_model(num_classes=4)
    for layer in all_layer_names:
        if len(eval("net." + layer + ".weight.shape")) == 2:
            exec("net." + layer + ".weight[~layer_remove[layer],:] = 0")
        else:
            exec("net." + layer + ".weight[~layer_remove[layer]] = 0")
    # for layer in ["layer3[0].conv2"]:
    #     exec("net." + layer + ".weight[~remove] = 0")
    correct, all = test_model(net, test_dataloader)
    print("现在准确率", correct / all)


现在准确率 0.8295


In [45]:
with torch.no_grad():
    net = load_model(num_classes=4)
    for layer in all_layer_names:
        keep_rate = layer_remove[layer].sum() / layer_remove[layer].size
        weight_flatten = eval("net." + layer + ".weight.cpu().detach().numpy()").flatten()
        threshold = np.sort(weight_flatten)[int(len(weight_flatten) * (1 - keep_rate))]
        if len(eval("net." + layer + ".weight.shape")) == 2:
            exec("net." + layer + ".weight[eval ('net.' + layer + '.weight.cpu().detach().numpy()') < threshold,:] = 0")
        else:
            exec("net." + layer + ".weight[eval('net.' + layer + '.weight.cpu().detach().numpy()') < threshold] = 0")
    correct, all = test_model(net, test_dataloader)
    print("对比准确率", correct / all)

对比准确率 0.25
