In [None]:
import torch
import numpy as np
import re
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm
from datasets import load_cifar10_choosen, load_cifar10
from attack import attack, test_model,parse_param
device = "cuda" if torch.cuda.is_available() else "cpu"
import random
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [None]:
choosen_classes = [0, 1, 2, 3, 4, 5]


In [None]:
train_dataloader, test_dataloader, test_dataloader_all = load_cifar10_choosen(
    choosen_classes=choosen_classes)


In [107]:
def train_choosen_classes(train_dataloader, test_dataloader, epochs=10, lr=0.001, model=None, num_classes=6, masks=None, save_path="weights/best_model.pth"):
    if model is None:
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(2048, num_classes)
        model = model.to(device)
    loss_func = nn.CrossEntropyLoss(reduction="sum")
    best_acc = -np.inf
    for epoch in range(epochs):
        model.train()
        all_preds = []
        all_labels = []
        num = 0
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = loss_func(output, y)
            regularization_loss = 0
            for param in model.parameters():
                regularization_loss += torch.norm(param)
            loss = loss + 0.1 * regularization_loss
            loss.backward()
            num += x.shape[0]
        for name, param in model.named_parameters():
            if masks is not None:
                name = parse_param(name)
                if name in masks.keys():
                    weights = torch.ones_like(param)
                    try:
                        weights[masks[name]] = weights[masks[name]]  * 0.01
                    except:
                        weights[masks[name], :] = weights[masks[name], :] * 0.01
                    param.grad = param.grad * weights
                    # print(name, param.grad)
            # print(name, param.shape)
            param.data -= lr * torch.sign(param.grad)
            param.grad.zero_()
        model.eval()
        train_pred, train_label = test_model(model, train_dataloader)
        test_pred, test_label = test_model(model, test_dataloader)
        print(f"Epoch {epoch + 1} train acc: {(train_pred.argmax(-1) == train_label).mean()}")
        print(f"Epoch {epoch + 1} test acc: {(test_pred.argmax(-1) == test_label).mean()}")
        if (test_pred.argmax(-1) == test_label).mean() > best_acc:
            best_acc = (test_pred.argmax(-1) == test_label).mean()
            torch.save(model.state_dict(), save_path)


In [None]:
train_choosen_classes(train_dataloader, test_dataloader,
                      epochs=3, lr=0.001, masks=None)


In [None]:
def load_model(num_classes):
    model = resnet50(num_classes=num_classes)
    model.load_state_dict(torch.load("weights/best_model.pth"))
    model = model.to(device)
    model.eval()
    return model


In [None]:
model = load_model(num_classes=6)
test_pred, test_label = test_model(model, test_dataloader)
print(f"test acc: {(test_pred.argmax(-1) == test_label).mean()}")


In [None]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name and not "shortcut.1" in name:
        all_param_names.append(name)

all_param_names = all_param_names[:-2]

In [None]:
train_loaders, test_loaders, test_loaders_all = load_cifar10()


In [None]:
all_totals = list()
for clz in choosen_classes:
    all_totals.append(attack(
        train_loaders[clz], all_param_names, load_model, alpha=0.00001,num_steps=4, num_classes=6))


In [41]:
thre = 0.8
net = load_model(6)
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)-1):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight) for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate([combine_.flatten() for combine_ in combine],axis=0)
    threshold = np.sort(combine_flatten)[::-1][int(len(combine_flatten) * thre)]
    for idx,param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t

  combine = np.array(combine)


In [42]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 1.0
layer1.0.conv1.weight 1.0
layer1.0.conv2.weight 0.9999457465277778
layer1.0.conv3.weight 0.97198486328125
layer1.0.downsample.0.weight 0.97265625
layer1.0.downsample.1.weight 0.97265625
layer1.0.downsample.1.bias 0.97265625
layer1.1.conv1.weight 0.95745849609375
layer1.1.conv2.weight 0.984375
layer1.1.conv3.weight 1.0
layer1.2.conv1.weight 1.0
layer1.2.conv2.weight 0.9998643663194444
layer1.2.conv3.weight 1.0
layer2.0.conv1.weight 1.0
layer2.0.conv2.weight 1.0
layer2.0.conv3.weight 0.998046875
layer2.0.downsample.0.weight 0.998046875
layer2.0.downsample.1.weight 0.998046875
layer2.0.downsample.1.bias 0.998046875
layer2.1.conv1.weight 0.9980316162109375
layer2.1.conv2.weight 0.9921875
layer2.1.conv3.weight 0.9921875
layer2.2.conv1.weight 1.0
layer2.2.conv2.weight 1.0
layer2.2.conv3.weight 1.0
layer2.3.conv1.weight 1.0
layer2.3.conv2.weight 0.9999796549479166
layer2.3.conv3.weight 1.0
layer3.0.conv1.weight 1.0
layer3.0.conv2.weight 0.9993048773871528
layer3.0.conv3.weigh

In [43]:
temp / all_num


0.7584785602545533

In [None]:
with torch.no_grad():
    net = load_model(6)
    preds, labels = test_model(net, test_dataloader)
    # print("原始准确率", (preds.argmax(-1) == labels).mean())
    print("原始准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))

In [None]:
with torch.no_grad():
    net = load_model(6)
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))

In [None]:
param_remove

In [None]:
for name, param in net.named_parameters():
    print(name, param.shape)

In [44]:
param_remove["fc.weight"] = np.ones((6, 2048),dtype=bool)
param_remove["fc.weight"][-1,:] = False
param_remove["fc.bias"] = np.ones((6,),dtype=bool)
param_remove["fc.bias"][-1] = False

In [45]:
param_remove

{'conv1.weight': array([[[[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True]],
 
         [[ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          [ True,  True,  True, ...,  True,  True,  True],
          ...,
          [ True,  True,  True, ...,  True,  Tru

In [46]:
# with torch.no_grad():
#     net = load_model(6)
#     for name, param in net.named_parameters():
#         if name in param_remove:
#             try:
#                 exec("net." + parse_param(name) + "[~param_remove[name]] = 0.00001")
#             except:
#                 exec("net." + parse_param(name) + "[~param_remove[name],:] = 0.00001")
#     preds, labels = test_model(net, test_dataloader)
#     print(preds)
#     # print("现在准确率", (preds.argmax(-1) == labels).mean())
#     print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
#     print(confusion_matrix(labels, preds.argmax(-1)))

In [81]:
train_dataloader_one_class, test_dataloader_one_class, test_dataloader_all = load_cifar10_choosen(
    choosen_classes=6)

Files already downloaded and verified
Files already downloaded and verified


In [82]:
train_dataloader_one_class.dataset.targets = (np.array(train_dataloader_one_class.dataset.targets) - 1).tolist()
test_dataloader_one_class.dataset.targets = (np.array(test_dataloader_one_class.dataset.targets) - 1).tolist()

In [83]:
from torch.utils.data import TensorDataset, DataLoader
test_x = list()
test_y = list()
for x,y in test_dataloader_one_class:
    test_x.append(x)
    test_y.append(y)
for x,y in test_dataloader:
    test_x.append(x[y != 5])
    test_y.append(y[y != 5])
test_x = torch.cat(test_x,dim=0)
test_y = torch.cat(test_y,dim=0)
test_dataloader_new = DataLoader(TensorDataset(test_x,test_y),batch_size=256,shuffle=False)

In [110]:
train_choosen_classes(train_dataloader_one_class, test_dataloader_one_class,epochs=4, lr=0.0003, masks=param_remove,save_path="weights/best_model_one_class.pth",model=load_model(6))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 train acc: 0.4998
Epoch 1 test acc: 0.5


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 train acc: 0.3758
Epoch 2 test acc: 0.269


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 train acc: 0.5448
Epoch 3 test acc: 0.349


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 train acc: 0.7174
Epoch 4 test acc: 0.452


In [111]:
with torch.no_grad():
    net = resnet50(num_classes=6)
    net.load_state_dict(torch.load("weights/best_model_one_class.pth"))
    net = net.to(device)
    net.eval()
    # for param in all_param_names:
    #     param_ = parse_param(param)
    #     try:
    #         exec("net." + param_ + "[~param_remove[param]] = 0")
    #     except:
    #         exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_new)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))
    for i in range(6):
        print(i, (preds.argmax(-1) == labels)[labels==i].mean())

现在准确率: 0.5245

[[965  10   6   4   0  15]
 [150 803   3  15   0  29]
 [278  11 462  41   7 201]
 [127   9  27 282   1 554]
 [272  11  65  98 135 419]
 [214  18  88 173   7 500]]
0 0.965
1 0.803
2 0.462
3 0.282
4 0.135
5 0.5


In [112]:
train_choosen_classes(train_dataloader_one_class, test_dataloader_one_class,epochs=4, lr=0.0003, masks=None,save_path="weights/best_model_one_class2.pth",model=load_model(6))

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 train acc: 0.4934
Epoch 1 test acc: 0.492


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 train acc: 0.3252
Epoch 2 test acc: 0.214


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 train acc: 0.5544
Epoch 3 test acc: 0.371


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 train acc: 0.6952
Epoch 4 test acc: 0.519


In [113]:
with torch.no_grad():
    net = resnet50(num_classes=6)
    net.load_state_dict(torch.load("weights/best_model_one_class2.pth"))
    net = net.to(device)
    net.eval()
    # for param in all_param_names:
    #     param_ = parse_param(param)
    #     try:
    #         exec("net." + param_ + "[~param_remove[param]] = 0")
    #     except:
    #         exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_new)
    # print("现在准确率", (preds.argmax(-1) == labels).mean())
    print("现在准确率: " + str((preds.argmax(-1) == labels).mean()) + "\n")
    print(confusion_matrix(labels, preds.argmax(-1)))
    for i in range(6):
        print(i, (preds.argmax(-1) == labels)[labels==i].mean())

现在准确率: 0.4865

[[779  12  12  11  10 176]
 [272 404   3   9  91 221]
 [118  12 520  34  29 287]
 [ 49   9  25 251  14 652]
 [ 88  11  62  51 446 342]
 [ 81  43 124 144  89 519]]
0 0.779
1 0.404
2 0.52
3 0.251
4 0.446
5 0.519
