In [1]:
import torch
import numpy as np
import re
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from datasets import load_cifar10_choosen, load_cifar10
from attack import attack, test_model,parse_param
device = "cuda" if torch.cuda.is_available() else "cpu"
import random
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [2]:
choosen_classes = [0, 1, 2, 3, 4, 5]


In [3]:
train_dataloader, test_dataloader, test_dataloader_all = load_cifar10_choosen(
    choosen_classes=choosen_classes)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
def train_choosen_classes(train_dataloader, test_dataloader, epochs=10, lr=0.001, model=None, num_classes=6, masks=None, save_path="weights/best_model.pth"):
    if model is None:
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(2048, num_classes)
        model = model.to(device)
    loss_func = nn.CrossEntropyLoss(reduction="sum")
    best_acc = -np.inf
    for epoch in range(epochs):
        model.train()
        all_preds = []
        all_labels = []
        num = 0
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = loss_func(output, y)
            regularization_loss = 0
            for param in model.parameters():
                regularization_loss += torch.norm(param)
            loss = loss + 0.1 * regularization_loss
            loss.backward()
            num += x.shape[0]
        for name, param in model.named_parameters():
            if masks is not None:
                name = parse_param(name)
                if name in masks.keys():
                    weights = torch.ones_like(param)
                    try:
                        weights[masks[name]] = weights[masks[name]]  * 0.01
                    except:
                        weights[masks[name], :] = weights[masks[name], :] * 0.01
                    param.grad = param.grad * weights
                    # print(name, param.grad)
            # print(name, param.shape)
            param.data -= lr * torch.sign(param.grad)
            param.grad.zero_()
        model.eval()
        train_pred, train_label = test_model(model, train_dataloader)
        test_pred, test_label = test_model(model, test_dataloader)
        print(f"Epoch {epoch + 1} train acc: {(train_pred.argmax(-1) == train_label).mean()}")
        print(f"Epoch {epoch + 1} test acc: {(test_pred.argmax(-1) == test_label).mean()}")
        if (test_pred.argmax(-1) == test_label).mean() > best_acc:
            best_acc = (test_pred.argmax(-1) == test_label).mean()
            torch.save(model.state_dict(), save_path)


In [None]:
train_choosen_classes(train_dataloader, test_dataloader,
                      epochs=3, lr=0.001, masks=None)


In [5]:
def load_model(num_classes):
    model = resnet50(num_classes=num_classes)
    model.load_state_dict(torch.load("weights/best_model.pth"))
    model = model.to(device)
    model.eval()
    return model


In [6]:
model = load_model(num_classes=6)
test_pred, test_label = test_model(model, test_dataloader)
print(f"test acc: {(test_pred.argmax(-1) == test_label).mean()}")


test acc: 0.8395


In [7]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

all_param_names = all_param_names[:-2]

In [8]:
train_loaders, test_loaders, test_loaders_all = load_cifar10()


Files already downloaded and verified
Files already downloaded and verified


In [9]:
all_totals = list()
for clz in choosen_classes:
    all_totals.append(attack(
        train_loaders[clz], all_param_names, load_model, alpha=0.00001,num_steps=4, num_classes=6))


100%|██████████| 20/20 [00:02<00:00,  9.32it/s]


0.07726547727584838


100%|██████████| 20/20 [00:02<00:00,  9.77it/s]


0.14740314254760742


100%|██████████| 20/20 [00:01<00:00, 10.53it/s]


0.25791712951660156


100%|██████████| 20/20 [00:01<00:00, 10.33it/s]


0.4243775115966797


  param_totals = np.array(param_totals)
  x = np.array(x)
100%|██████████| 20/20 [00:01<00:00, 10.04it/s]


0.025977389174699782


100%|██████████| 20/20 [00:01<00:00, 10.25it/s]


0.056364512157440184


100%|██████████| 20/20 [00:01<00:00, 10.05it/s]


0.11205352020263672


100%|██████████| 20/20 [00:01<00:00, 10.57it/s]


0.20481037368774413


100%|██████████| 20/20 [00:01<00:00, 10.49it/s]


0.1707543472290039


100%|██████████| 20/20 [00:01<00:00, 10.30it/s]


0.26580281372070313


100%|██████████| 20/20 [00:01<00:00, 10.28it/s]


0.4068273056030273


100%|██████████| 20/20 [00:01<00:00, 10.24it/s]


0.6067653778076172


100%|██████████| 20/20 [00:01<00:00, 10.46it/s]


0.09997566471099853


100%|██████████| 20/20 [00:01<00:00, 10.28it/s]


0.150022855758667


100%|██████████| 20/20 [00:01<00:00, 10.27it/s]


0.22474251213073732


100%|██████████| 20/20 [00:01<00:00, 10.72it/s]


0.33471857681274414


100%|██████████| 20/20 [00:01<00:00, 10.83it/s]


0.07433102321624756


100%|██████████| 20/20 [00:01<00:00, 10.23it/s]


0.13142997093200684


100%|██████████| 20/20 [00:01<00:00, 10.34it/s]


0.22438429412841798


100%|██████████| 20/20 [00:01<00:00, 10.35it/s]


0.3676878646850586


100%|██████████| 20/20 [00:01<00:00, 10.12it/s]


0.2892280326843262


100%|██████████| 20/20 [00:01<00:00, 10.25it/s]


0.4544783889770508


100%|██████████| 20/20 [00:01<00:00, 10.36it/s]


0.6904158203125


100%|██████████| 20/20 [00:01<00:00, 10.22it/s]


1.008963104248047


In [10]:
thre = 0.8
net = load_model(6)
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)-1):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [total for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [11]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 1.0
layer1.0.conv1.weight 1.0
layer1.0.conv2.weight 0.9999457465277778
layer1.0.conv3.weight 0.97198486328125
layer1.0.downsample.0.weight 0.97265625
layer1.0.downsample.1.weight 0.97265625
layer1.0.downsample.1.bias 0.97265625
layer1.1.conv1.weight 0.95745849609375
layer1.1.conv2.weight 0.984375
layer1.1.conv3.weight 1.0
layer1.2.conv1.weight 1.0
layer1.2.conv2.weight 0.9998643663194444
layer1.2.conv3.weight 1.0
layer2.0.conv1.weight 1.0
layer2.0.conv2.weight 1.0
layer2.0.conv3.weight 0.998046875
layer2.0.downsample.0.weight 0.998046875
layer2.0.downsample.1.weight 0.998046875
layer2.0.downsample.1.bias 0.998046875
layer2.1.conv1.weight 0.9980316162109375
layer2.1.conv2.weight 0.9921875
layer2.1.conv3.weight 0.9921875
layer2.2.conv1.weight 1.0
layer2.2.conv2.weight 1.0
layer2.2.conv3.weight 1.0
layer2.3.conv1.weight 1.0
layer2.3.conv2.weight 0.9999796549479166
layer2.3.conv3.weight 1.0
layer3.0.conv1.weight 1.0
layer3.0.conv2.weight 0.9993048773871528
layer3.0.conv3.weigh

In [12]:
temp / all_num


0.7584785602545533

In [13]:
with torch.no_grad():
    net = load_model(6)
    preds, labels = test_model(net, test_dataloader)
    # print("原始准确率", (preds.argmax(-1) == labels).mean())
    print("原始准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))

原始准确率: 0.8395

[[902  15  27  34  18   4]
 [ 18 959   5  13   2   3]
 [ 39   4 791  53  90  23]
 [  8   7  33 802  71  79]
 [ 16   3  20  54 890  17]
 [  6   3  32 207  59 693]]


In [14]:
with torch.no_grad():
    net = load_model(6)
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))

现在准确率: 0.8391666666666666

[[902  15  27  35  17   4]
 [ 17 959   5  14   2   3]
 [ 40   4 790  52  91  23]
 [  8   7  33 802  71  79]
 [ 16   3  20  54 890  17]
 [  6   3  32 207  60 692]]


In [15]:
param_remove

{'conv1.weight': array([[[[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  Tru

In [16]:
for name, param in net.named_parameters():
    print(name, param.shape)

conv1.weight torch.Size([64, 3, 7, 7])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight torch.Size([64])
layer1.0.bn2.bias torch.Size([64])
layer1.0.conv3.weight torch.Size([256, 64, 1, 1])
layer1.0.bn3.weight torch.Size([256])
layer1.0.bn3.bias torch.Size([256])
layer1.0.downsample.0.weight torch.Size([256, 64, 1, 1])
layer1.0.downsample.1.weight torch.Size([256])
layer1.0.downsample.1.bias torch.Size([256])
layer1.1.conv1.weight torch.Size([64, 256, 1, 1])
layer1.1.bn1.weight torch.Size([64])
layer1.1.bn1.bias torch.Size([64])
layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight torch.Size([64])
layer1.1.bn2.bias torch.Size([64])
layer1.1.conv3.weight torch.Size([256, 64, 1, 1])
layer1.1.bn3.weight torch.Size([256])
layer1.1.bn3.bias torch.Size([256])
layer1.2.conv1.weight tor

In [17]:
param_remove["fc.weight"] = np.ones((6, 2048),dtype=bool)
param_remove["fc.weight"][-1,:] = False
param_remove["fc.bias"] = np.ones((6,),dtype=bool)
param_remove["fc.bias"][-1] = False

In [18]:
param_remove

{'conv1.weight': array([[[[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  Tru

In [19]:
# with torch.no_grad():
#     net = load_model(6)
#     for name, param in net.named_parameters():
#         if name in param_remove:
#             try:
#                 exec("net." + parse_param(name) + "[~param_remove[name]] = 0.00001")
#             except:
#                 exec("net." + parse_param(name) + "[~param_remove[name],:] = 0.00001")
#     preds, labels = test_model(net, test_dataloader)
#     print(preds)
#     # print("现在准确率", (preds.argmax(-1) == labels).mean())
#     print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
#     print(confusion_matrix(labels, preds.argmax(-1)))

In [20]:
train_dataloader_one_class, test_dataloader_one_class, test_dataloader_all = load_cifar10_choosen(
    choosen_classes=6)

Files already downloaded and verified
Files already downloaded and verified


In [21]:
train_dataloader_one_class.dataset.targets = (np.array(train_dataloader_one_class.dataset.targets) - 1).tolist()
test_dataloader_one_class.dataset.targets = (np.array(test_dataloader_one_class.dataset.targets) - 1).tolist()

In [22]:
from torch.utils.data import TensorDataset, DataLoader
test_x = list()
test_y = list()
for x,y in test_dataloader_one_class:
    test_x.append(x)
    test_y.append(y)
for x,y in test_dataloader:
    test_x.append(x[y != 5])
    test_y.append(y[y != 5])
test_x = torch.cat(test_x,dim=0)
test_y = torch.cat(test_y,dim=0)
test_dataloader_new = DataLoader(TensorDataset(test_x,test_y),batch_size=256,shuffle=False)

In [23]:
train_choosen_classes(train_dataloader_one_class, test_dataloader_one_class,epochs=4, lr=0.0003, masks=param_remove,save_path="weights/best_model_one_class.pth",model=load_model(6))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 train acc: 0.4892
Epoch 1 test acc: 0.5


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 train acc: 0.335
Epoch 2 test acc: 0.22


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 train acc: 0.5578
Epoch 3 test acc: 0.401


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 train acc: 0.7034
Epoch 4 test acc: 0.441


In [24]:
with torch.no_grad():
    net = resnet50(num_classes=6)
    net.load_state_dict(torch.load("weights/best_model_one_class.pth"))
    net = net.to(device)
    net.eval()
    # for param in all_param_names:
    #     param_ = parse_param(param)
    #     try:
    #         exec("net." + param_ + "[~param_remove[param]] = 0")
    #     except:
    #         exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_new)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))
    for i in range(6):
        print(i, (preds.argmax(-1) == labels)[labels==i].mean())

现在准确率: 0.5243333333333333

[[964  10   8   3   0  15]
 [155 804   3   9   0  29]
 [270  13 471  42   8 196]
 [140  10  27 278   2 543]
 [288  13  64  91 129 415]
 [212  19  97 167   5 500]]
0 0.964
1 0.804
2 0.471
3 0.278
4 0.129
5 0.5


In [25]:
train_choosen_classes(train_dataloader_one_class, test_dataloader_one_class,epochs=4, lr=0.0003, masks=None,save_path="weights/best_model_one_class2.pth",model=load_model(6))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 train acc: 0.4794
Epoch 1 test acc: 0.499


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 train acc: 0.3692
Epoch 2 test acc: 0.265


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 train acc: 0.558
Epoch 3 test acc: 0.379


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 train acc: 0.7086
Epoch 4 test acc: 0.437


In [26]:
with torch.no_grad():
    net = resnet50(num_classes=6)
    net.load_state_dict(torch.load("weights/best_model_one_class2.pth"))
    net = net.to(device)
    net.eval()
    # for param in all_param_names:
    #     param_ = parse_param(param)
    #     try:
    #         exec("net." + param_ + "[~param_remove[param]] = 0")
    #     except:
    #         exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_new)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))
    for i in range(6):
        print(i, (preds.argmax(-1) == labels)[labels==i].mean())

现在准确率: 0.5235

[[961  12   7   4   0  16]
 [142 822   3  11   0  22]
 [275  13 458  48   7 199]
 [141  10  24 275   2 548]
 [296  13  63  96 126 406]
 [212  19  94 171   5 499]]
0 0.961
1 0.822
2 0.458
3 0.275
4 0.126
5 0.499
