In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import pickle as pkl
from tqdm.notebook import tqdm
from model.resnet_cifar10 import resnet56
from attack import attack, test_model, parse_param
import random
from fgsm import FGSMGrad
device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(3407)


In [3]:
def load_resnet56():
    model = resnet56()
    state_dict = torch.load("weights/checkpoint_best.pth")["state_dict"]
    for key in list(state_dict.keys()):
        if key.startswith('module.'):
            state_dict[key[7:]] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


In [4]:
data_min = np.min((0 - np.array([0.485, 0.456, 0.406])) /
                  np.array([0.229, 0.224, 0.225]))
data_max = np.max((1 - np.array([0.485, 0.456, 0.406])) /
                  np.array([0.229, 0.224, 0.225]))


In [5]:
fgsm = FGSMGrad(0.3 * 255, data_min, data_max)


In [6]:
from datasets import load_cifar10, load_cifar100
from models.resnet import load_cifar10_resnet50, load_cifar100_resnet50
model = load_resnet56()


In [7]:
all_param_names = list()
for name, param in model.named_parameters():
    if not "bn" in name:
        all_param_names.append(name)


In [8]:
all_param_names = all_param_names[:-2]


In [9]:
train_loaders, test_dataloaders, train_dataloader_all, test_dataloader_all = load_cifar10()


Files already downloaded and verified
Files already downloaded and verified


In [10]:
adv_data_1k = list()
adv_true_label_1k = list()
from tqdm.notebook import tqdm
pbar = tqdm(total=1000)
for data, target in train_dataloader_all:
    data = data.to(device)
    target = target.to(device)
    correct_index = torch.argmax(model(data), dim=-1) == target
    data = data[correct_index]
    target = target[correct_index]
    adv_data, success, _, _, _ = fgsm(
        model, data, target, num_steps=20, alpha=0.001, early_stop=False, use_sign=True, use_softmax=True)
    adv_data_1k.append(adv_data[success])
    adv_true_label_1k.append(target[success])
    pbar.update(adv_data[success].shape[0])
    if pbar.n >= 1000:
        break


  0%|          | 0/1000 [00:00<?, ?it/s]

In [11]:
adv_dataset = TensorDataset(torch.cat(adv_data_1k, dim=0), torch.cat(adv_true_label_1k, dim=0))
adv_dataloader = DataLoader(adv_dataset, batch_size=128, shuffle=True)

In [12]:
# all_totals = list()
# for i in range(10):
#     all_totals.append(attack(train_loaders[i], all_param_names,
#                       load_resnet56, train_dataloader_all, alpha=0.0001, num_steps=2, op="add"))


In [13]:
# pkl.dump(all_totals, open("weights/resnet56.pkl", "wb"))
all_totals = pkl.load(open("weights/resnet56.pkl", "rb"))


In [23]:
thre = 0.045
net = load_resnet56()
param_remove = dict()
for param in all_param_names:
    param_remove[param] = None
for i in range(len(all_totals)):
    totals = all_totals[i]
    totals = [totals[param] for param in all_param_names]
    param_weights = [eval("net." + parse_param(param) + ".cpu().detach().numpy()")
                     for param in all_param_names]
    combine = [np.abs(total * weight)
               for total, weight in zip(totals, param_weights)]
    combine = np.array(combine)
    combine_flatten = np.concatenate(
        [combine_.flatten() for combine_ in combine], axis=0)
    threshold = np.sort(combine_flatten)[
        ::-1][int(len(combine_flatten) * thre)]
    for idx, param in enumerate(all_param_names):
        if param_remove[param] is None:
            param_remove[param] = combine[idx] > threshold
        else:
            t = combine[idx] > threshold
            param_remove[param] = param_remove[param] | t


  combine = np.array(combine)


In [24]:
temp = 0
all_num = 0
for param in param_remove:
    temp += param_remove[param].sum()
    all_num += param_remove[param].size
    print(param, param_remove[param].mean())


conv1.weight 0.5648148148148148
layer1.0.conv1.weight 0.1762152777777778
layer1.0.conv2.weight 0.1918402777777778
layer1.1.conv1.weight 0.1015625
layer1.1.conv2.weight 0.07942708333333333
layer1.2.conv1.weight 0.16796875
layer1.2.conv2.weight 0.11067708333333333
layer1.3.conv1.weight 0.015190972222222222
layer1.3.conv2.weight 0.0013020833333333333
layer1.4.conv1.weight 0.08463541666666667
layer1.4.conv2.weight 0.016927083333333332
layer1.5.conv1.weight 0.0008680555555555555
layer1.5.conv2.weight 0.0
layer1.6.conv1.weight 0.19270833333333334
layer1.6.conv2.weight 0.140625
layer1.7.conv1.weight 0.18663194444444445
layer1.7.conv2.weight 0.07378472222222222
layer1.8.conv1.weight 0.2881944444444444
layer1.8.conv2.weight 0.15364583333333334
layer2.0.conv1.weight 0.2827690972222222
layer2.0.conv2.weight 0.059895833333333336
layer2.1.conv1.weight 0.0
layer2.1.conv2.weight 0.0
layer2.2.conv1.weight 0.0
layer2.2.conv2.weight 0.0
layer2.3.conv1.weight 0.0
layer2.3.conv2.weight 0.0
layer2.4.conv1.

In [25]:
temp / all_num


0.15257619909843642

In [26]:
with torch.no_grad():
    net = load_resnet56()
    preds, labels = test_model(net, test_dataloader_all)
    print("原始准确率", (preds.argmax(-1) == labels).mean())


原始准确率 0.9325


In [33]:
with torch.no_grad():
    net = load_resnet56()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.1


In [97]:
with torch.no_grad():
    net = load_resnet56()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, adv_dataloader)
    print("抗进攻准确率", (preds.argmax(-1) == labels).mean())

抗进攻准确率 0.19288389513108614


In [28]:
def train_model(train_dataloader, test_dataloader, model, epochs=10, lr=0.001, masks=None, save_path="weights/best_model.pth"):
    loss_func = nn.CrossEntropyLoss(reduction="sum")
    best_acc = -np.inf
    for epoch in range(epochs):
        model.train()
        all_preds = []
        all_labels = []
        num = 0
        for x, y in tqdm(train_dataloader):
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = loss_func(output, y)
            # regularization_loss = 0
            # for param in model.parameters():
            #     regularization_loss += torch.norm(param)
            # loss = loss + 0.1 * regularization_loss
            loss.backward()
            for name, param in model.named_parameters():
                if masks is not None:
                    name = parse_param(name)
                    if name in masks.keys():
                        param.grad = param.grad * \
                            torch.BoolTensor(masks[name]).to(device)
                        # print(name)
                        # print(name, param.grad)
                # print(name, param.shape)
                param.data -= lr * param.grad
                param.grad.zero_()
        model.eval()
        train_pred, train_label = test_model(model, train_dataloader)
        test_pred, test_label = test_model(model, test_dataloader)
        print(
            f"Epoch {epoch + 1} train acc: {(train_pred.argmax(-1) == train_label).mean()}")
        print(
            f"Epoch {epoch + 1} test acc: {(test_pred.argmax(-1) == test_label).mean()}")
        if (test_pred.argmax(-1) == test_label).mean() > best_acc:
            best_acc = (test_pred.argmax(-1) == test_label).mean()
            torch.save(model.state_dict(), save_path)


In [34]:
train_model(train_dataloader_all, test_dataloader_all,
            net, epochs=100, lr=0.0005, masks=param_remove)


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1 train acc: 0.95434
Epoch 1 test acc: 0.8883


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 2 train acc: 0.9689
Epoch 2 test acc: 0.8901


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 3 train acc: 0.98122
Epoch 3 test acc: 0.9013


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 4 train acc: 0.99044
Epoch 4 test acc: 0.9009


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 5 train acc: 0.99546
Epoch 5 test acc: 0.9063


  0%|          | 0/196 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [35]:
with torch.no_grad():
    net = resnet56()
    net.to(device)
    net.eval()
    for param in all_param_names:
        param_ = parse_param(param)
        try:
            exec("net." + param_ + "[~param_remove[param]] = 0")
        except:
            exec("net." + param_ + "[~param_remove[param],:] = 0")
    preds, labels = test_model(net, test_dataloader_all)
    print("现在准确率", (preds.argmax(-1) == labels).mean())


现在准确率 0.0982


In [36]:
train_model(train_dataloader_all, test_dataloader_all,
            net, epochs=100, lr=0.0005, masks=param_remove)


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1 train acc: 0.48392
Epoch 1 test acc: 0.4837


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 2 train acc: 0.52886
Epoch 2 test acc: 0.5156


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 3 train acc: 0.62882
Epoch 3 test acc: 0.6106


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 4 train acc: 0.64616
Epoch 4 test acc: 0.6276


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 5 train acc: 0.64282
Epoch 5 test acc: 0.6115


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 6 train acc: 0.73112
Epoch 6 test acc: 0.6881


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 7 train acc: 0.79846
Epoch 7 test acc: 0.7367


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 8 train acc: 0.77864
Epoch 8 test acc: 0.72


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 9 train acc: 0.8278
Epoch 9 test acc: 0.7524


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 10 train acc: 0.85272
Epoch 10 test acc: 0.7557


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 11 train acc: 0.83884
Epoch 11 test acc: 0.7401


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 12 train acc: 0.8513
Epoch 12 test acc: 0.7444


  0%|          | 0/196 [00:00<?, ?it/s]

Epoch 13 train acc: 0.89014
Epoch 13 test acc: 0.7551


  0%|          | 0/196 [00:00<?, ?it/s]

KeyboardInterrupt: 