# PatchAttack sweep
Run PatchAttack for multiple models and patch sizes, save outputs, and plot adversarial accuracy.


In [1]:
import os
import re
import subprocess
import sys
from pathlib import Path

import matplotlib.pyplot as plt

# --- Sweep configuration -------------------------------------------------
# Every (model, patch size) pair below results in one run of the script.
DATASET = "cifar10"                # forwarded as --dataset
MODELS = ["resnet18", "resnet34"]  # each forwarded as --model
# Each value is forwarded verbatim as --patch-size.
# NOTE(review): presumably a fraction of the image; confirm in patchattack.py.
PATCH_SIZES = [0.05, 0.1]
MAX_IMAGES = 50                    # forwarded as --max-images

# Per-run save directories and the summary CSV are placed under this root.
OUTPUT_ROOT = Path("output", "patch")
# Absolute path to the attack script, resolved against the working directory
# in effect when this cell runs.
PATCH_SCRIPT = str(Path("patchattack.py").resolve())


ModuleNotFoundError: No module named 'matplotlib'

In [None]:
def run_patchattack(model: str, patch_size: float) -> float:
    """Run the PatchAttack script once and return the accuracy it reports.

    The script at ``PATCH_SCRIPT`` is launched with the current interpreter
    and given the model, dataset, patch size, image limit and a per-run save
    directory under ``OUTPUT_ROOT``. Its stdout is then scanned for a line
    of the form ``Adversarial acc: <number>``.

    Args:
        model: Value passed to the script's ``--model`` flag.
        patch_size: Value passed (as ``str(patch_size)``) to ``--patch-size``.

    Returns:
        The first number following ``Adversarial acc: `` in the script's
        stdout, as a float.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status. Its captured stderr is echoed first so the cause is
            visible.
        RuntimeError: If no ``Adversarial acc: <number>`` text is found in
            the script's stdout.
    """
    size_str = str(patch_size)
    save_dir = OUTPUT_ROOT / f"{DATASET}_{model}_p{size_str}"
    cmd = [
        sys.executable,
        PATCH_SCRIPT,
        "--model",
        model,
        "--dataset",
        DATASET,
        "--patch-size",
        size_str,
        "--max-images",
        str(MAX_IMAGES),
        "--save-dir",
        str(save_dir),
    ]
    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as exc:
        # The child's output is captured, and the exception message only
        # carries the exit status, so surface stderr before re-raising.
        if exc.stderr:
            print(exc.stderr, file=sys.stderr)
        raise

    # Match a single well-formed decimal. A greedy ``[0-9.]+`` would also
    # swallow trailing punctuation ("0.85.") and make float() fail.
    match = re.search(
        r"Adversarial acc: ([0-9]+\.?[0-9]*|\.[0-9]+)", result.stdout
    )
    if not match:
        tail = result.stdout[-500:]
        raise RuntimeError(
            f"Could not parse Adversarial acc for {model} patch_size={size_str}."
            f" Last output:\n{tail}"
        )
    return float(match.group(1))


In [None]:
# PatchAttack sweep
# One list of accuracies per model, filled in PATCH_SIZES order.
accuracies = {name: [] for name in MODELS}

# Walk the model x patch-size grid as a single flat sequence of runs.
for model, patch_size in ((m, p) for m in MODELS for p in PATCH_SIZES):
    accuracies[model].append(run_patchattack(model, patch_size))

accuracies


In [None]:
# Save PatchAttack CSV
import csv

patch_csv = OUTPUT_ROOT / "patch_accuracy.csv"
patch_csv.parent.mkdir(parents=True, exist_ok=True)

with patch_csv.open("w", newline="", encoding="utf-8") as f:
    out = csv.writer(f)
    out.writerow(["model", "patch_size", "accuracy"])
    # One row per (model, patch size) pair, paired up in sweep order.
    out.writerows(
        [model, patch_size, acc]
        for model in MODELS
        for patch_size, acc in zip(PATCH_SIZES, accuracies[model])
    )

patch_csv


In [None]:
# Plot PatchAttack
# One line per model: reported accuracy against patch size.
fig, ax = plt.subplots(figsize=(8, 5))
for model in MODELS:
    ax.plot(PATCH_SIZES, accuracies[model], marker="o", label=model)

ax.set_xlabel("patch_size")
ax.set_ylabel("accuracy")
ax.set_title("PatchAttack adversarial accuracy vs. patch size")
ax.grid(True, alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()
