## **Fast Gradient Sign Method (FGSM)** による敵対的ノイズ生成

In [None]:
# セットアップ
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import random

# Seed every RNG the notebook uses (torch, NumPy, stdlib `random`) so that
# training, the random sample choice and the attack are reproducible.
seed = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Dataset preparation.
# ToTensor converts the PIL images to float tensors scaled to [0, 1]; no
# further normalization is applied.
transform = transforms.Compose([transforms.ToTensor()])

train_set = torchvision.datasets.MNIST(
    "./data", train=True, download=True, transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=64, shuffle=True, num_workers=2,
)

# batch_size=1 so that each test sample is evaluated (and attacked) on its own.
test_set = torchvision.datasets.MNIST(
    "./data", train=False, download=True, transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=1, shuffle=False, num_workers=2,
)

print(f"Train size: {len(train_set)}, Test size: {len(test_set)}")

In [None]:
# モデルには単純なCNNを使用
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding = 1), # 28 * 28 -> 28 * 28
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 28 * 28 -> 14 * 14
            nn.Conv2d(32, 64, 3, padding = 1), # 14 * 14 -> 14 * 14
            nn.ReLU(),
            nn.MaxPool2d(2),                   # 14 * 14 -> 7 * 7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Build the classifier; nn.Module.to moves its parameters in place.
model = SimpleCNN()
model.to(device)

In [None]:
# Train the model, then measure the clean (un-attacked) test accuracy.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
epochs = 5  # kept short for the demo

model.train()
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}/{epochs}, loss: {running_loss/len(train_loader):.4f}")

print("Training finished.")

# Baseline accuracy: count test samples whose arg-max logit equals the label.
model.eval()
with torch.no_grad():
    correct = sum(
        (model(images.to(device)).argmax(dim=1) == labels.to(device)).sum().item()
        for images, labels in test_loader
    )

base_acc = correct / len(test_set)
print(f"Baseline Accuracy: {base_acc*100:.2f}%")

In [None]:
# FGSM attack definition
def fgsm_attack(
    image: torch.Tensor,
    epsilon: float,
    data_grad: torch.Tensor,
    *,
    clip_min: float = 0.0,
    clip_max: float = 1.0,
) -> torch.Tensor:
    """
    Generate an adversarial example with the Fast Gradient Sign Method (FGSM).

    Computes ``clamp(image + epsilon * sign(data_grad), clip_min, clip_max)``.

    Parameters
    ----------
    image : torch.Tensor
        Original input of any shape (in this notebook a (1, 1, 28, 28) batch).
    epsilon : float
        Perturbation magnitude; intended range is 0 <= epsilon <= 1.
    data_grad : torch.Tensor
        Gradient of the loss with respect to ``image``; must have the same shape.
    clip_min, clip_max : float, optional
        Valid pixel range the result is clipped to. The default [0, 1] matches
        images produced by ``transforms.ToTensor()`` without normalization.

    Returns
    -------
    torch.Tensor
        The perturbed image, same shape as ``image``.

    Raises
    ------
    ValueError
        If ``data_grad`` and ``image`` differ in shape (broadcasting would
        silently change the output shape), or if ``clip_min > clip_max``.
    """
    if data_grad.shape != image.shape:
        raise ValueError(
            f"data_grad shape {tuple(data_grad.shape)} does not match "
            f"image shape {tuple(image.shape)}"
        )
    if clip_min > clip_max:
        raise ValueError(f"clip_min ({clip_min}) must not exceed clip_max ({clip_max})")

    # Step every pixel by epsilon in the direction that increases the loss.
    perturbed_image = image + epsilon * data_grad.sign()
    # Keep the result inside the valid pixel range.
    return torch.clamp(perturbed_image, clip_min, clip_max)

In [None]:
# Visualization (single sample): perturb one random test image with each
# epsilon and show the perturbed image with the model's prediction.
epsilons = [0.05 * i for i in range(7)]
examples = []

# Pick one test sample at random (reproducible via the seeded `random` module).
idx = random.randint(0, len(test_set) - 1)
image, label = test_set[idx]
image = image.unsqueeze(0).to(device)  # add the batch dimension
label = torch.tensor([label], device=device)

# Set eval mode explicitly so re-running cells out of order does not matter.
model.eval()

# Track gradients with respect to the input pixels.
image.requires_grad = True
output = model(image)
init_pred = output.max(1, keepdim=True)[1]

if init_pred.item() != label.item():
    print("Model misclassified the original sample, try another idx.")
else:
    loss = nn.CrossEntropyLoss()(output, label)
    model.zero_grad()
    loss.backward()
    data_grad = image.grad.data

    # Perturbed images are only evaluated, so no autograd graph is needed.
    with torch.no_grad():
        for eps in epsilons:
            perturbed = fgsm_attack(image, eps, data_grad)
            pred = model(perturbed).max(1, keepdim=True)[1].item()
            examples.append((eps, perturbed.squeeze().detach().cpu(), pred))


# Drawing
def imshow(img, title=""):
    """Show a CPU tensor (expected shape (H, W)) as a grayscale image without axes."""
    np_img = img.numpy()
    plt.imshow(np_img, cmap="gray")
    plt.title(title)
    plt.axis("off")


if examples:
    ncols = 4
    nrows = -(-len(examples) // ncols)  # ceiling division: grid grows with epsilons
    plt.figure(figsize=(10, 2.5 * nrows))
    for i, (eps, ex, pred) in enumerate(examples):
        plt.subplot(nrows, ncols, i + 1)
        # Format eps: 0.05 * i yields values like 0.15000000000000002.
        imshow(ex, title=f"ε={eps:.2f}\nPred={pred}")
    plt.suptitle(f"True label: {label.item()}")
    plt.tight_layout()
    plt.show()

In [None]:
# Apply FGSM to every test image and measure how accuracy drops.
def test_fgsm(
        model: nn.Module,
        device: torch.device,
        loader: torch.utils.data.DataLoader,
        epsilon: float,
) -> float:
    """
    Attack every sample from ``loader`` with FGSM and return the accuracy.

    A sample counts as correct only if the model classifies it correctly both
    before and after the perturbation. Samples that are already misclassified
    are not attacked in effect and count as incorrect. The denominator is the
    total number of samples, so the result works for any batch size.

    Parameters
    ----------
    model : nn.Module
        Trained model to evaluate. It is temporarily put in eval mode; its
        previous train/eval mode is restored on exit.
    device : torch.device
        Device to run on (CPU / CUDA).
    loader : torch.utils.data.DataLoader
        Test data loader yielding ``(data, target)`` batches.
    epsilon : float
        FGSM perturbation strength.

    Returns
    -------
    float
        Accuracy in [0, 1] under attack.

    Raises
    ------
    ValueError
        If ``loader`` yields no samples.
    """
    # Sum reduction: each sample's input gradient is unscaled by batch size
    # (FGSM uses only its sign, but this avoids underflow for large batches).
    criterion = nn.CrossEntropyLoss(reduction="sum")
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            data.requires_grad_(True)

            output = model(data)
            init_pred = output.argmax(dim=1)

            # Gradient w.r.t. the input only; model parameters' .grad are untouched.
            loss = criterion(output, target)
            (data_grad,) = torch.autograd.grad(loss, data)

            with torch.no_grad():
                perturbed = fgsm_attack(data, epsilon, data_grad)
                final_pred = model(perturbed).argmax(dim=1)

            # Correct only if right before AND after the attack.
            survived = (init_pred == target) & (final_pred == target)
            correct += int(survived.sum().item())
            total += target.size(0)
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("loader yielded no samples")
    return correct / total

# Sweep epsilon over the whole test set, reporting accuracy under attack.
accuracies = []
for eps in epsilons:
    acc = test_fgsm(model, device, test_loader, eps)
    accuracies.append(acc)
    print(f"Epsilon: {eps:.2f}  Test Accuracy: {acc*100:.2f}%")

# Plot accuracy as a function of perturbation strength.
_fig, ax = plt.subplots()
ax.plot(epsilons, accuracies, marker="o")
ax.set_xlabel("Epsilon (ε)")
ax.set_ylabel("Accuracy")
ax.set_title("FGSM: Epsilon vs Accuracy")
ax.grid(True)
plt.show()