In [19]:
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn

# Test-time preprocessing: convert to a tensor, then normalize with a fixed
# mean/std (presumably the FashionMNIST training-set statistics -- confirm).
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ]
)

# Load the FashionMNIST test split (fetched into ./data when download=True).
test_data = torchvision.datasets.FashionMNIST(
    './data',
    train=False,
    download=True,
    transform=transform_test,
)

# One image per batch. NOTE(review): shuffle=True means each run visits the
# test images in a different order, so any "first N images" subset differs.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=1,
    shuffle=True,
)

# Attack function: untargeted iterative FGSM.
def ifgsm_attack(image, epsilon, data_grad, max_iter, *, model=None, target=None,
                 alpha=None, clip_min=(0.0 - 0.2860) / 0.3530,
                 clip_max=(1.0 - 0.2860) / 0.3530):
    """Run an untargeted iterative FGSM (I-FGSM / BIM) attack.

    Starting from ``image``, repeatedly takes a step of ``alpha * sign(grad)``
    that increases the cross-entropy loss w.r.t. ``target``. After each step
    the result is projected back into the L-infinity ball of radius
    ``epsilon`` around ``image`` and clamped to ``[clip_min, clip_max]``.
    Stops early once every sample in the batch is misclassified.

    Args:
        image: Input batch fed directly to ``model``. It is not modified, and
            its ``requires_grad`` / ``.grad`` are not touched.
        epsilon: L-infinity perturbation budget, in the same (normalized)
            units as ``image``.
        data_grad: Unused. It is accepted only so that existing call sites
            keep working; the gradient is recomputed at every iteration.
        max_iter: Maximum number of gradient steps (0 returns a copy of
            ``image``).
        model: Classifier to attack. If omitted, falls back to the
            module-level ``model``, which is what the original implementation
            used implicitly.
        target: Labels to move away from, shape ``(batch,)``. If omitted, the
            model's own prediction on ``image`` is used.
        alpha: Per-step size. Defaults to ``epsilon / max_iter`` (the usual
            I-FGSM choice), so a single iteration is plain FGSM.
        clip_min, clip_max: Valid value range of ``image``. The defaults are
            the pixel range [0, 1] mapped through
            ``Normalize(mean=0.2860, std=0.3530)``.

    Returns:
        A detached adversarial tensor with the same shape as ``image``.

    Raises:
        ValueError: if no model is passed and no module-level ``model`` exists.
    """
    if model is None:
        model = globals().get("model")
        if model is None:
            raise ValueError("ifgsm_attack needs a model: pass model=... "
                             "or define a module-level `model`")

    # Work on a detached copy so the caller's tensor and its .grad stay intact.
    x_orig = image.detach()
    if target is None:
        with torch.no_grad():
            target = model(x_orig).argmax(1)
    if alpha is None:
        alpha = epsilon / max(max_iter, 1)

    criterion = nn.CrossEntropyLoss()
    x_adv = x_orig.clone()
    for _ in range(max_iter):
        # x_adv is a fresh leaf each iteration, so requires_grad can be set.
        x_adv.requires_grad_(True)
        loss = criterion(model(x_adv), target)
        # autograd.grad returns d(loss)/d(x_adv) without accumulating into
        # any .grad buffers (model parameters are left untouched).
        (grad,) = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()
            # Project onto the epsilon-ball around the clean input, then onto
            # the valid value range. The clean input lies inside both, so the
            # result satisfies both constraints.
            x_adv = x_orig + (x_adv - x_orig).clamp(-epsilon, epsilon)
            x_adv = x_adv.clamp(clip_min, clip_max)
            # Stop once every sample's prediction differs from its label.
            if (model(x_adv).argmax(1) != target).all():
                break
    return x_adv.detach()

# Evaluation function: attack correctly classified test images.
def test(model, device, test_loader, epsilon, max_examples=1000, max_iter=1000):
    """Attack correctly classified test images and report the success rate.

    Iterates over ``test_loader`` and skips images the model already
    misclassifies. Each remaining image is attacked with ``ifgsm_attack``.
    Iteration stops after ``max_examples`` images have been attacked.

    Assumes ``batch_size=1``: labels and predictions are compared with
    ``.item()``.

    Args:
        model: Classifier under attack; it is switched to eval mode.
        device: Device the images and labels are moved to.
        test_loader: Iterable of ``(image, label)`` batches.
        epsilon: Perturbation budget forwarded to ``ifgsm_attack``.
        max_examples: Maximum number of images to attack (default 1000).
        max_iter: Iteration cap forwarded to ``ifgsm_attack`` (default 1000).

    Returns:
        List of ``(clean_image, adversarial_image, label)`` tuples, with the
        adversarial images detached from the autograd graph.
    """
    model.eval()
    seen = 0     # images drawn from the loader
    correct = 0  # of those, correctly classified (= number attacked)
    fooled = 0   # attacked images whose prediction the attack changed
    adv_examples = []
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        seen += 1
        # Only attack images the model classifies correctly.
        with torch.no_grad():
            init_pred = model(data).argmax(1)
        if init_pred.item() != target.item():
            continue
        correct += 1

        perturbed_data = ifgsm_attack(image=data, epsilon=epsilon,
                                      data_grad=None, max_iter=max_iter)

        # Count the attack as successful if the prediction no longer matches.
        with torch.no_grad():
            final_pred = model(perturbed_data).argmax(1)
        if final_pred.item() != target.item():
            fooled += 1

        # Detach so the stored examples do not keep autograd graphs alive.
        adv_examples.append((data, perturbed_data.detach(), target))

        # Progress report.
        if len(adv_examples) % 100 == 0:
            print(f"Attack progress: {len(adv_examples)}/{max_examples}")
        if len(adv_examples) >= max_examples:
            break

    # Success rate = fraction of attacked (correctly classified) images that
    # end up misclassified.
    print(f"Correctly classified examples: {correct}/{seen}")
    success_rate = fooled / correct if correct else 0.0
    print(f"Attack success rate: {success_rate:.4f}")
    return adv_examples

# Attack parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
epsilon = 0.1  # perturbation budget passed to the attack (normalized-input units)

# ResNet-34 with a 1-input-channel 3x3/stride-1 stem (grayscale input). The
# layers must match the architecture the checkpoint was trained with.
model = torchvision.models.resnet34(weights=None, num_classes=10)
model.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
# Replace the final fully connected layer. num_classes=10 already gives a
# 10-way fc, so this is redundant but harmless (weights come from the checkpoint).
model.fc = nn.Linear(model.fc.in_features, 10)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('./checkpoints/checkpoint-60-93.71.pt', map_location=device))
model.to(device)

# Evaluate the model and attack the correctly classified test images.
adv_examples = test(model, device, test_loader, epsilon)


cpu


Traceback (most recent call last):
  File "_pydevd_bundle/pydevd_cython.pyx", line 1078, in _pydevd_bundle.pydevd_cython.PyDBFrame.trace_dispatch
  File "_pydevd_bundle/pydevd_cython.pyx", line 297, in _pydevd_bundle.pydevd_cython.PyDBFrame.do_wait_suspend
  File "/Users/alex/anaconda3/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 1976, in do_wait_suspend
    keep_suspended = self._do_wait_suspend(thread, frame, event, arg, suspend_type, from_this_thread, frames_tracker)
  File "/Users/alex/anaconda3/lib/python3.10/site-packages/debugpy/_vendored/pydevd/pydevd.py", line 2011, in _do_wait_suspend
    time.sleep(0.01)
KeyboardInterrupt


KeyboardInterrupt: 