# FGSM epsilon sweep

Run the FGSM attack (`fgsm.py`) for every model/epsilon pair, save the per-run outputs and a summary CSV, and plot adversarial accuracy against epsilon.

In [None]:
import os
import re
import subprocess
import sys
from pathlib import Path

import matplotlib.pyplot as plt

# --- Sweep configuration ---------------------------------------------------
# Perturbation budgets to evaluate, in plotting order (epsilon=0 presumably
# gives the unperturbed baseline — confirm against fgsm.py).
EPSILONS = [
    0,
    0.025, 0.05, 0.075, 0.1,
    0.15, 0.2, 0.25, 0.3,
]
# Architectures to attack; each name is passed to fgsm.py via --model.
MODELS = "resnet18 resnet34 densenet121".split()
DATASET = "cifar10"
MAX_IMAGES = 50  # forwarded to fgsm.py as --max-images
# Absolute path so the subprocess finds the script regardless of how it is launched.
FGSM_SCRIPT = os.fspath(Path("fgsm.py").resolve())
# Root under which each run gets its own --save-dir subdirectory.
OUTPUT_ROOT = Path("output", "fgsm")


In [None]:
# Accuracy value printed after "FGSM acc:" on the script's stdout. The number
# must contain at least one digit, so a malformed value such as "." fails the
# match and raises the RuntimeError below rather than a ValueError from float().
_FGSM_ACC_RE = re.compile(r"FGSM acc:\s*(\d+(?:\.\d*)?|\.\d+)")


def run_fgsm(model: str, epsilon: float) -> float:
    """Run the FGSM script for one model/epsilon pair and return its accuracy.

    The script at ``FGSM_SCRIPT`` is launched with the current interpreter and
    told to write its outputs to ``OUTPUT_ROOT / f"{DATASET}_{model}_{epsilon}"``.
    The accuracy is parsed from the first ``FGSM acc: <number>`` occurrence in
    the script's stdout.

    Args:
        model: Model name forwarded as ``--model``.
        epsilon: Perturbation budget forwarded as ``--epsilon``.

    Returns:
        The parsed accuracy, on whatever scale the script prints it.

    Raises:
        subprocess.CalledProcessError: If the script exits non-zero; its
            captured stderr is echoed to this process's stderr first.
        RuntimeError: If no accuracy value can be found in the script's stdout.
    """
    eps_str = str(epsilon)
    save_dir = OUTPUT_ROOT / f"{DATASET}_{model}_{eps_str}"
    cmd = [
        sys.executable,
        FGSM_SCRIPT,
        "--model", model,
        "--dataset", DATASET,
        "--epsilon", eps_str,
        "--max-images", str(MAX_IMAGES),
        "--save-dir", str(save_dir),
    ]
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        # capture_output=True swallows the child's stderr, and the exception's
        # message does not include it; surface it so the failure is diagnosable.
        print(
            f"FGSM run failed for {model} eps={epsilon} "
            f"(exit {err.returncode}):\n{err.stderr}",
            file=sys.stderr,
        )
        raise

    match = _FGSM_ACC_RE.search(result.stdout)
    if match is None:
        # Show the end of stdout so an output-format change is easy to spot.
        tail = result.stdout[-500:]
        raise RuntimeError(
            f"Could not parse FGSM acc for {model} eps={epsilon}.\n"
            f"Last stdout output:\n{tail}"
        )
    return float(match.group(1))


In [None]:
# Gather adversarial accuracy for every (model, epsilon) pair, one list per
# model, in EPSILONS order. Every key exists up front, and results are appended
# as each run finishes, so a failed run leaves the completed entries in place.
accuracies = {name: [] for name in MODELS}
for name in MODELS:
    accuracies[name].extend(run_fgsm(name, eps) for eps in EPSILONS)

accuracies


In [None]:
import csv

# Summary table of every run: one row per (model, epsilon, accuracy).
# Built from OUTPUT_ROOT (the root passed to each run's --save-dir) instead of
# re-spelling the path, so the CSV always lands beside the per-run outputs.
csv_path = OUTPUT_ROOT / "fgsm_accuracy.csv"
csv_path.parent.mkdir(parents=True, exist_ok=True)

with csv_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["model", "epsilon", "accuracy"])
    # zip stops at the shorter sequence, so a partially completed sweep writes
    # only the epsilons that actually produced a result.
    writer.writerows(
        [model, epsilon, acc]
        for model in MODELS
        for epsilon, acc in zip(EPSILONS, accuracies[model])
    )

csv_path


In [None]:
# Adversarial accuracy vs. epsilon, one marked line per model.
fig, ax = plt.subplots(figsize=(8, 5))
for name in MODELS:
    ax.plot(EPSILONS, accuracies[name], marker="o", label=name)

ax.set(
    xlabel="epsilon",
    ylabel="accuracy",
    title="FGSM adversarial accuracy vs. epsilon",
)
ax.grid(True, alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()
