In [None]:
# 使用 FGSM 和 BIM 攻击 ComRecCNN 防御的 ResNet18，使用CIFAR-10数据集。
import torch
from utils.model.resnet import ResNet18
from utils.utils.dataloader import *
from utils.attacks import FGSM, BIM
from utils.model.comreccnn_3 import ComRecCNN

# Evaluate ResNet18 on CIFAR-10 under FGSM and BIM, with and without the
# ComRecCNN defence, for a sweep of perturbation budgets (eps).
# For every (attack, eps) pair this prints the attack object followed by
# "<accuracy on adversarial images>,    <accuracy after defence>, ".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PATH_PARAMETERS = "models/cifar10/resnet.pth"
_, test_loader, _ = cifar10(100)

# Classifier under attack. map_location lets a checkpoint that was saved on a
# GPU be loaded on a CPU-only machine (the device fallback above allows that).
net = ResNet18().to(device).eval()
net.load_state_dict(torch.load(PATH_PARAMETERS, map_location=device))

# ComRecCNN defence wrapped around the same classifier.
cnn = ComRecCNN(net, device)
cnn.set_models_path("models/cifar10/comcnn.pth", "models/cifar10/reccnn.pth")
cnn.load_models_parameters()

atks = [
    FGSM,
    BIM,
]
for atk in atks:
    for eps in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
        print("--" * 30)
        # Build the attack once per (attack, eps) rather than once per batch.
        attack = atk(net, eps)
        print(attack)

        # Counts of samples the classifier still labels correctly on the
        # adversarial images (adv_correct) and on the defended images
        # (dfd_correct).
        adv_correct = 0
        dfd_correct = 0
        total_num = 0
        for imgs, lbls in test_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)

            adv_imgs = attack(imgs, lbls)
            dfd_imgs = cnn.defend(adv_imgs)

            # Pure evaluation: these forward passes need no autograd graph.
            with torch.no_grad():
                adv_pred = net(adv_imgs).argmax(1)
                dfd_pred = net(dfd_imgs).argmax(1)
            adv_correct += torch.eq(adv_pred, lbls).sum().item()
            dfd_correct += torch.eq(dfd_pred, lbls).sum().item()

            total_num += len(lbls)
        print(
            "{},    {}, ".format(
                adv_correct / total_num,
                dfd_correct / total_num,
            )
        )


In [None]:
# 使用 DeepFool 攻击 ComRecCNN 防御的 ResNet18，使用CIFAR-10数据集。
import torch
from utils.model.resnet import ResNet18
from utils.utils.dataloader import *
from utils.attacks import DeepFool
from utils.model.comreccnn_3 import ComRecCNN

# Evaluate ResNet18 on CIFAR-10 under DeepFool, with and without the ComRecCNN
# defence, for a sweep of DeepFool overshoot values.
# For every overshoot value this prints each batch index as progress, then
# "<accuracy on adversarial images>,    <accuracy after defence>, ".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): this cell loads resnet_def.pth while the FGSM/BIM cell loads
# resnet.pth -- confirm the different checkpoint is intentional.
PATH_PARAMETERS = "models/cifar10/resnet_def.pth"
_, test_loader, _ = cifar10(100)

# Classifier under attack. map_location lets a checkpoint that was saved on a
# GPU be loaded on a CPU-only machine (the device fallback above allows that).
net = ResNet18().to(device).eval()
net.load_state_dict(torch.load(PATH_PARAMETERS, map_location=device))

# ComRecCNN defence wrapped around the same classifier.
cnn = ComRecCNN(net, device)
cnn.set_models_path("models/cifar10/comcnn.pth", "models/cifar10/reccnn.pth")
cnn.load_models_parameters()

# Loop variable named `overshoot` (not `os`) so the stdlib `os` module is not
# shadowed in the notebook namespace.
for overshoot in [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
    print("--" * 30)
    attack = DeepFool(net, overshoot=overshoot)

    # Counts of samples the classifier still labels correctly on the
    # adversarial images (adv_correct) and on the defended images (dfd_correct).
    adv_correct = 0
    dfd_correct = 0
    total_num = 0
    for i, (imgs, lbls) in enumerate(test_loader):
        print(i)  # progress: DeepFool is slow, so report each batch
        imgs, lbls = imgs.to(device), lbls.to(device)

        adv_imgs = attack(imgs, lbls)
        dfd_imgs = cnn.defend(adv_imgs)

        # Pure evaluation: these forward passes need no autograd graph.
        with torch.no_grad():
            adv_pred = net(adv_imgs).argmax(1)
            dfd_pred = net(dfd_imgs).argmax(1)
        adv_correct += torch.eq(adv_pred, lbls).sum().item()
        dfd_correct += torch.eq(dfd_pred, lbls).sum().item()

        total_num += len(lbls)
    print(
        "{},    {}, ".format(
            adv_correct / total_num,
            dfd_correct / total_num,
        )
    )
