<a href="https://colab.research.google.com/github/REDEVL4/Adversarial-Attacks-on-Deep-Neural-Networks/blob/AA_v4/AA_v4_multi_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1

Collecting torch==2.5.1
  Downloading torch-2.5.1-cp312-cp312-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision==0.20.1
  Downloading torchvision-0.20.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio==2.5.1
  Downloading torchaudio-2.5.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.5.1)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.5.1)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.5.1)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.5.1)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12

In [2]:
# =======================================================================================
# MURA v1.1 — Multi-backbone training (train/valid only) + robust resume + FGSM/PGD/CW
# - LOCAL or Google-Drive-ZIP mode (choose in CONFIG)
# - Five models: resnet50, densenet121, googlenet, efficientnet_b0, convnext_tiny
# - Per-model hyperparams (LR/WD/dropout/batch) + per-model checkpoint folders
# - No leakage: train uses /train, val uses /valid
# - Attacks: FGSM, PGD, C&W (torchattacks) with safe defaults + visuals & confusion matrices
# =======================================================================================

# ----------------
# Core imports
# ----------------
import os, re, time, random, json, math, warnings, zipfile, shutil
import tqdm
from pathlib import Path
from collections import defaultdict, Counter

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import (
    roc_auc_score, accuracy_score, f1_score, balanced_accuracy_score,
    confusion_matrix, ConfusionMatrixDisplay
)

# ================================================================
# 0) CONFIG — choose data mode (LOCAL or COLAB_DRIVE_ZIP)
# ================================================================
MODE = "COLAB_DRIVE_ZIP"   # "LOCAL" or "COLAB_DRIVE_ZIP"

# -- LOCAL: point to folder that directly contains 'train' and 'valid'
LOCAL_DATA_ROOT = r"C:\Users\reddy\Downloads\Documents\Capstone\Datasets\MURA-v1.1\MURA-v1.1"

# -- COLAB_DRIVE_ZIP: mount & extract once
# (safe to leave even if running locally; only used when MODE == "COLAB_DRIVE_ZIP")
COLAB_ZIP_PATH = "/content/drive/MyDrive/Capstone - adversarial attack on DNN/Implementation/MURA-v1.1.zip"
COLAB_EXTRACT_DIR = "/content/mura_data"
COLAB_RUN_DIR     = "/content/drive/MyDrive/mura_runs_v4"

# Experiment/global defaults (can be overridden per-model via get_model_config)
GLOBAL_EPOCHS         = 50           # set larger for full runs
GLOBAL_NUM_CLASSES    = 2
GLOBAL_SEED           = 42
GLOBAL_NUM_WORKERS    = 0           # Windows/CUDA/DirectML -> keep 0 to avoid spawn/pickle issues
GLOBAL_PIN_MEMORY     = False

# Attacks (tune for runtime)
ENABLE_ATTACKS        = True
FGSM_EPS              = 2/255
PGD_EPS               = 4/255
PGD_STEP_ALPHA        = 1/255
PGD_STEPS             = 5           # increase if you can afford the time
CW_CONFIDENCE         = 0.0         # typical default
CW_STEPS              = 100         # fewer steps to keep runtime reasonable
CW_LR                 = 0.01

# Visualization control (to keep notebooks responsive)
VISUALIZE_REALTIME      = True      # draw a few examples during attacks
VIS_EVERY_N_BATCHES     = 40        # how often (in batches) to visualize a panel
MAX_VIS_IMAGES_PER_STEP = 4         # per visualization step
SAVE_FIGS               = True

# Models to run
MODEL_LIST = ["googlenet", "efficientnet_b0","convnext_tiny"] #"resnet50", "densenet121"

In [3]:
# ================================================================
# 1) Device selection: DirectML -> CUDA -> CPU
# ================================================================
def pick_device(prefer_gpu=True):
    if prefer_gpu:
        # 1) DirectML (AMD/Intel on Windows)
        try:
            import torch_directml
            dml_dev = torch_directml.device()
            _ = torch.randn(1, device=dml_dev)
            print("Using DirectML GPU:", dml_dev)
            return dml_dev
        except Exception:
            pass
        # 2) CUDA
        if torch.cuda.is_available():
            print("Using CUDA GPU:", torch.cuda.get_device_name(0))
            return torch.device("cuda")
    print("Using CPU")
    return torch.device("cpu")

DEVICE = pick_device(True)


Using CUDA GPU: NVIDIA A100-SXM4-80GB


In [4]:
# ================================================================
# 2) Data root setup (LOCAL vs COLAB_DRIVE_ZIP)
# ================================================================
def _find_mura_root(root: Path) -> Path:
    """Find the folder that directly contains 'train' and 'valid'."""
    root = Path(root)
    if (root/"train").exists() and (root/"valid").exists():
        return root
    for p in root.iterdir():
        if p.is_dir() and (p/"train").exists() and (p/"valid").exists():
            return p
    raise FileNotFoundError("Could not locate 'train/' and 'valid/' under " + str(root))

if MODE.upper() == "COLAB_DRIVE_ZIP":
    # Mount drive only in Colab
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

    EXTRACT_DIR = Path(COLAB_EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    if not any(EXTRACT_DIR.iterdir()):
        zipp = Path(COLAB_ZIP_PATH)
        assert zipp.exists(), f"ZIP not found: {zipp}"
        print(f"Extracting {zipp} -> {EXTRACT_DIR} …")
        with zipfile.ZipFile(zipp, 'r') as zf:
            zf.extractall(EXTRACT_DIR)
        print("Extraction complete.")

    DATA_ROOT = str(_find_mura_root(EXTRACT_DIR))
    RUN_ROOT  = Path(COLAB_RUN_DIR)
    RUN_ROOT.mkdir(parents=True, exist_ok=True)
else:
    DATA_ROOT = str(_find_mura_root(Path(LOCAL_DATA_ROOT)))
    RUN_ROOT  = Path(DATA_ROOT) / "AA_v4_multi_models_dml_runs_full"
    RUN_ROOT.mkdir(parents=True, exist_ok=True)

print("DATA_ROOT =", DATA_ROOT)
print("RUN_ROOT  =", RUN_ROOT)


Mounted at /content/drive
Extracting /content/drive/MyDrive/Capstone - adversarial attack on DNN/Implementation/MURA-v1.1.zip -> /content/mura_data …
Extraction complete.
DATA_ROOT = /content/mura_data/MURA-v1.1
RUN_ROOT  = /content/drive/MyDrive/mura_runs_v4


In [5]:
# ================================================================
# 3) Scan MURA (train/valid only) — NO mixing
# ================================================================
def scan_mura_split(split_root: Path, anatomy_subset=None):
    rows = []
    for part_dir in sorted(split_root.iterdir()):
        if not part_dir.is_dir(): continue
        if anatomy_subset and part_dir.name != anatomy_subset: continue
        for pat in part_dir.iterdir():
            if not pat.is_dir(): continue
            for study in pat.iterdir():
                if not study.is_dir(): continue
                label = 1 if "positive" in study.name.lower() else 0
                for f in study.iterdir():
                    if f.suffix.lower() in {".png", ".jpg", ".jpeg"}:
                        rows.append({"path": str(f), "label": label})
    return pd.DataFrame(rows)

ANATOMY_SUBSET = None  # or "XR_SHOULDER"
train_root = Path(DATA_ROOT)/"train"
valid_root = Path(DATA_ROOT)/"valid"
assert train_root.exists() and valid_root.exists(), "train/valid missing under DATA_ROOT"

train_df = scan_mura_split(train_root, anatomy_subset=ANATOMY_SUBSET)
val_df   = scan_mura_split(valid_root, anatomy_subset=ANATOMY_SUBSET)

print(f"Train images: {len(train_df)}  | Val images: {len(val_df)}")
print("Label dist (train):", Counter(train_df["label"].tolist()))
print("Label dist (val):  ", Counter(val_df["label"].tolist()))


Train images: 36812  | Val images: 3197
Label dist (train): Counter({0: 21939, 1: 14873})
Label dist (val):   Counter({0: 1667, 1: 1530})


In [6]:
# ================================================================
# 4) Cleaning helpers + study/patient IDs
# ================================================================
VALID_EXTS = {".png", ".jpg", ".jpeg"}

def clean_df(df):
    def ok(p):
        pth = Path(p)
        return (
            (pth.suffix.lower() in VALID_EXTS) and pth.is_file()
            and (not pth.name.startswith(".")) and (not pth.name.startswith("._"))
        )
    before = len(df)
    df2 = df[df["path"].map(ok)].copy()
    drop = before - len(df2)
    if drop:
        print(f"[clean_df] removed {drop} bad rows (hidden/missing/non-image).")
    return df2

def extract_patient_id(p):
    parts = Path(p).parts
    pid = next((q for q in parts if q.lower().startswith("patient")), None)
    return pid or "unknown"

def extract_study_id(p):
    path = Path(p)
    parts = path.parts
    patient = next((q for q in parts if q.lower().startswith("patient")), "unknown")
    # combine patient/study folder
    study = parts[parts.index(patient)+1] if patient in parts and (parts.index(patient)+1)<len(parts) else "study_unknown"
    return f"{patient}/{study}"

train_df = clean_df(train_df)
val_df   = clean_df(val_df)
for _df in (train_df, val_df):
    _df["patient_id"] = _df["path"].apply(extract_patient_id)
    _df["study_id"]   = _df["path"].apply(extract_study_id)
    _df.reset_index(drop=True, inplace=True)


[clean_df] removed 4 bad rows (hidden/missing/non-image).


In [7]:
# ## In case if wanna test out using some samples for a quick run or for random sampling
# train_df = train_df.sample(frac=0.001, random_state=GLOBAL_SEED).reset_index(drop=True)
# val_df   = val_df.sample(frac=0.001, random_state=GLOBAL_SEED).reset_index(drop=True)

In [8]:
# ================================================================
# 5) Preprocessing (Bone window → CLAHE → Edge mix → Auto-crop)
# ================================================================
class BoneWindowTransform:
    def __init__(self, lp=3, hp=97): self.lp=lp; self.hp=hp
    def __call__(self, img):
        f = img.astype(np.float32)
        nz = f[f>0]
        if nz.size==0: return img
        p_low, p_high = np.percentile(nz, self.lp), np.percentile(nz, self.hp)
        if p_high <= p_low: return img
        w = np.clip(f, p_low, p_high)
        w = ((w - p_low)/(p_high-p_low)*255.0).astype(np.uint8)
        return w

class ClaheTransform:
    def __init__(self, clip=2.5, tile=(8,8)):
        self.clahe = cv2.createCLAHE(clipLimit=clip, tileGridSize=tile)
    def __call__(self, img): return self.clahe.apply(img.astype(np.uint8))

class EdgeTransform:
    def __init__(self, alpha=0.5): self.alpha=float(alpha)
    def __call__(self, img):
        f = img.astype(np.float32)
        gx = cv2.Sobel(f, cv2.CV_32F, 1, 0, 3); gy = cv2.Sobel(f, cv2.CV_32F, 0, 1, 3)
        mag = np.sqrt(gx*gx + gy*gy); mag = (mag/(mag.max()+1e-6))*255.0
        out = np.clip(f + self.alpha*mag, 0, 255).astype(np.uint8)
        return out

class AutoCropTransform:
    def __init__(self, thresh=5): self.thresh=thresh
    def __call__(self, img):
        _, bw = cv2.threshold(img, self.thresh, 255, cv2.THRESH_BINARY)
        cnts,_ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not cnts: return img
        x,y,w,h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
        if w<16 or h<16: return img
        return img[y:y+h, x:x+w]

window   = BoneWindowTransform(3,97)
clahe    = ClaheTransform(2.5,(8,8))
edge     = EdgeTransform(0.5)
autocrop = AutoCropTransform(5)

def preprocess_chain(img):
    x = window(img); x = clahe(x); x = edge(x); x = autocrop(x)
    return x


In [9]:
### Transformation
IMAGENET_MEAN = [0.485,0.456,0.406]
IMAGENET_STD  = [0.229,0.224,0.225]

def make_transforms(img_size):
    tf_eval = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    tf_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(img_size, scale=(0.9, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])
    return tf_train, tf_eval

def safe_imread(p):
    img = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {p}")
    return img

# dataset returns (tensor, label, study_id) to support study-level eval
class MuraDataset(Dataset):
    def __init__(self, df, img_size=320, augment=False):
        self.paths  = df["path"].astype(str).tolist()
        self.labels = df["label"].astype(int).tolist()
        self.studies= df["study_id"].astype(str).tolist()
        self.img_size = img_size
        self.augment = augment
        self.tf_train, self.tf_eval = make_transforms(img_size)

    def __len__(self): return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]; y = self.labels[idx]
        img = preprocess_chain(safe_imread(p))
        img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)
        img3 = np.stack([img,img,img], axis=-1)
        tfm = self.tf_train if self.augment else self.tf_eval
        x = tfm(img3)
        return x, y, self.studies[idx]


In [10]:
# ================================================================
# 6) Models + per-model hyperparams
# ================================================================
def get_model_config(name: str):
    name = name.lower()
    common = dict(num_classes=GLOBAL_NUM_CLASSES, label_smooth=0.05)
    if name == "resnet50":
        return dict(**common, img_size=320, batch_size=24, lr=3e-4, weight_decay=1e-4, dropout=0.40)
    if name == "densenet121":
        return dict(**common, img_size=320, batch_size=16, lr=2e-4, weight_decay=1e-4, dropout=0.20)
    if name == "googlenet":
        return dict(**common, img_size=320, batch_size=24, lr=3e-4, weight_decay=1e-4, dropout=0.30)
    if name == "efficientnet_b0":
        return dict(**common, img_size=320, batch_size=24, lr=1e-4, weight_decay=1e-5, dropout=0.30)
    if name == "convnext_tiny":
        return dict(**common, img_size=320, batch_size=24, lr=2e-4, weight_decay=5e-5, dropout=0.10)
    raise ValueError(f"Unknown model: {name}")

def build_model(name="resnet50", num_classes=2, dropout_p=0.4):
    name = name.lower()
    if name == "resnet50":
        m = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        in_f = m.fc.in_features
        m.fc = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(in_f, num_classes))
    elif name == "densenet121":
        m = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        in_f = m.classifier.in_features
        m.classifier = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(in_f, num_classes))
    # elif name == "googlenet":
    #     m = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1, aux_logits=True)
    #     in_f = m.fc.in_features
    #     m.fc = nn.Sequential(nn.Dropout(dropout_p), nn.Linear(in_f, num_classes))
    elif name == "googlenet":
        m = models.googlenet(
            weights=models.GoogLeNet_Weights.IMAGENET1K_V1,
            aux_logits=True
        )
        # also force off in case someone toggles it later
        m.aux_logits = True
        in_f = m.fc.in_features
        m.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_f, num_classes))

    elif name == "efficientnet_b0":
        m = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        in_f = m.classifier[-1].in_features
        # keep its internal dropout and add ours before final linear
        m.classifier[-1] = nn.Identity()
        m.classifier = nn.Sequential(*list(m.classifier), nn.Dropout(dropout_p), nn.Linear(in_f, num_classes))
    elif name == "convnext_tiny":
        m = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
        in_f = m.classifier[-1].in_features
        m.classifier[-1] = nn.Identity()
        m.classifier = nn.Sequential(*list(m.classifier), nn.Dropout(dropout_p), nn.Linear(in_f, num_classes))
    else:
        raise ValueError("Unknown model: " + name)
    return m



In [11]:
# ================================================================
# 7) Checkpoints (per-model directories) + resume helpers
# ================================================================
def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)
    return p

def ckpt_paths_for_model(run_root: Path, model_name: str):
    base = ensure_dir(run_root / "checkpoints" / model_name)
    return {
        "latest": base / "latest.pt",
        "best":   run_root / f"{model_name}_best_by_auc.pt",
        "history_json": base / "history.json",
        "snapshots": ensure_dir(base / "snapshots")
    }

def get_rng_state():
    return {
        "py_random": random.getstate(),
        "np_random": np.random.get_state(),
        "torch_cpu": torch.get_rng_state().tolist(),
        "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }

def set_rng_state(st):
    try: random.setstate(st["py_random"])
    except: pass
    try: np.random.set_state(st["np_random"])
    except: pass
    try: torch.set_rng_state(torch.tensor(st["torch_cpu"], dtype=torch.uint8))
    except: pass
    if torch.cuda.is_available() and st.get("torch_cuda") is not None:
        try: torch.cuda.set_rng_state_all(st["torch_cuda"])
        except: pass

def atomic_save(state, path: Path):
    tmp = str(path) + ".tmp"
    torch.save(state, tmp)
    os.replace(tmp, path)

def save_history_json(history_dict, path: Path):
    try:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(history_dict, f, indent=2)
    except Exception as e:
        print(f"[warn] could not write {path}: {e}")

def save_checkpoint_for_model(ck, epoch_idx, best_auc, history, model, optimizer=None, scheduler=None):
    state = {
        "epoch": epoch_idx,
        "model_name": model.__class__.__name__.lower(),
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict() if optimizer else None,
        "scheduler_state": scheduler.state_dict() if scheduler else None,
        "best_auc": best_auc,
        "history": history,
        "rng_state": get_rng_state(),
        "device_info": str(DEVICE),
    }
    atomic_save(state, ck["latest"])
    save_history_json(history, ck["history_json"])

def try_load_checkpoint_for_model(ck, model, optimizer=None, scheduler=None, source="latest",
                                  load_optimizer=True, load_scheduler=True, load_rng=True):
    if source == "latest":
        ckpt_path = ck["latest"]
    elif isinstance(source, int):
        ckpt_path = ck["snapshots"]/f"epoch_{source:03d}.pt"
    else:
        ckpt_path = Path(source)

    if not ckpt_path.exists():
        print(f"[warn] Checkpoint not found for {ckpt_path}")
        return None

    print(f"Found checkpoint for {ckpt_path}, loading …")
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])

    if load_optimizer and optimizer and ckpt.get("optimizer_state") is not None:
        try: optimizer.load_state_dict(ckpt["optimizer_state"])
        except Exception as e: print("[warn] optimizer not restored:", e)
    if load_scheduler and scheduler and ckpt.get("scheduler_state") is not None:
        try: scheduler.load_state_dict(ckpt["scheduler_state"])
        except Exception as e: print("[warn] scheduler not restored:", e)
    if load_rng and ckpt.get("rng_state") is not None:
        set_rng_state(ckpt["rng_state"])

    start_epoch = int(ckpt.get("epoch",0)) + 1
    best_auc = float(ckpt.get("best_auc",-1.0))
    hist = ckpt.get("history", defaultdict(list))
    if isinstance(hist, defaultdict): hist=dict(hist)
    print(f"Resume at epoch {start_epoch} (best_auc={best_auc:.4f})")
    return start_epoch, best_auc, hist



In [12]:
# ================================================================
# 8) Metrics helpers (image-level + study-level)
# ================================================================
@torch.no_grad()
def predict_probs(model, loader):
    model.eval()
    all_probs, all_targets, all_studies = [], [], []
    for xb, yb, sb in loader:
        xb = xb.to(DEVICE, non_blocking=True)

        logits = model(xb)
        if isinstance(logits, (tuple, list)):
        # common case: first element is the main logits
          logits = logits[0]
        elif hasattr(logits, "logits"):  # namedtuple with a .logits field
          logits = logits.logits

        probs = torch.softmax(logits, dim=1)[:,1].detach().cpu().numpy()
        all_probs.append(probs); all_targets.append(yb.numpy()); all_studies.extend(list(sb))
    return np.concatenate(all_probs), np.concatenate(all_targets), np.array(all_studies)

def metrics_from_probs(probs, targets):
    probs = np.asarray(probs)
    targets = np.asarray(targets)
    preds = (probs >= 0.5).astype(int)
    out = {}
    try: out["AUROC"] = float(roc_auc_score(targets, probs))
    except Exception: out["AUROC"] = float("nan")
    out["ACC"]  = float(accuracy_score(targets, preds))
    out["BACC"] = float(balanced_accuracy_score(targets, preds))
    out["F1"]   = float(f1_score(targets, preds))
    return out

def study_level_metrics(studies, probs, targets):
    df = pd.DataFrame({"study": studies, "prob": probs, "y": targets})
    g = df.groupby("study")
    probs_s = g["prob"].mean().values
    targs_s = g["y"].max().astype(int).values
    return metrics_from_probs(probs_s, targs_s)

def plot_confusion(y_true, probs, title, out_path=None):
    preds = (np.asarray(probs) >= 0.5).astype(int)
    cm = confusion_matrix(y_true, preds)
    disp = ConfusionMatrixDisplay(cm, display_labels=["negative", "positive"])
    fig, ax = plt.subplots(figsize=(4,4))
    disp.plot(ax=ax, cmap="Blues", colorbar=False)
    ax.set_title(title)
    plt.tight_layout()
    if out_path: plt.savefig(out_path, dpi=140)
    plt.show()


In [13]:
# ================================================================
# 9) Attacks (FGSM/PGD/C&W) — robust wrappers
# ================================================================
# Try torchattacks; if missing, we implement FGSM/PGD manually
USE_TORCHATTACKS = False
try:
    import torchattacks as ta
    USE_TORCHATTACKS = True
except Exception:
    pass

def fgsm_attack(model, x, y, eps=2/255):
    x = x.clone().detach().to(DEVICE); x.requires_grad_(True)
    y = y.to(DEVICE)
    logits = model(x)
    loss = nn.CrossEntropyLoss()(logits, y)
    loss.backward()
    x_adv = x + eps * x.grad.sign()
    return torch.clamp(x_adv, -10, 10).detach()

def pgd_attack(model, x, y, eps=4/255, alpha=1/255, iters=7):
    x = x.clone().detach().to(DEVICE)
    y = y.to(DEVICE)
    x_adv = x.clone().detach()
    for _ in range(iters):
        x_adv.requires_grad_(True)
        logits = model(x_adv)
        loss = nn.CrossEntropyLoss()(logits, y)
        model.zero_grad(set_to_none=True)
        loss.backward()
        with torch.no_grad():
            x_adv = x_adv + alpha * x_adv.grad.sign()
            # project back to eps-ball around original x
            delta = torch.clamp(x_adv - x, min=-eps, max=eps)
            x_adv = torch.clamp(x + delta, -10, 10)
    return x_adv.detach()

def cw_attack_torchattacks(model, x, y, c=1.0, steps=100, lr=0.01, kappa=0.0):
    # Build only once per call with explicit kwargs (no None)
    attack = ta.CW(model, c=c, kappa=kappa, steps=steps, lr=lr)
    return attack(x, y)

def attack_and_eval(model, loader, attack_name, out_dir, params, vis_every=40, max_vis=4):
    """
    Runs attack across loader, logs metrics, and saves occasional visual panels.
    attack_name in {"clean","fgsm","pgd","cw"}.
    params: dict of params for the chosen attack (ignored for "clean").
    """
    ensure_dir(out_dir)
    model.eval()
    all_probs, all_targets = [], []
    shown = 0

    for b_idx, (xb, yb, sb) in enumerate(loader):
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)

        if attack_name == "clean":
            x_use = xb
        elif attack_name == "fgsm":
            x_use = fgsm_attack(model, xb, yb, eps=float(params.get("eps", FGSM_EPS)))
        elif attack_name == "pgd":
            x_use = pgd_attack(
                model, xb, yb,
                eps=float(params.get("eps", PGD_EPS)),
                alpha=float(params.get("alpha", PGD_STEP_ALPHA)),
                iters=int(params.get("iters", PGD_STEPS)),
            )
        elif attack_name == "cw":
            if not USE_TORCHATTACKS:
                # Fallback: use a strong PGD if CW not available
                x_use = pgd_attack(model, xb, yb, eps=PGD_EPS, alpha=PGD_STEP_ALPHA, iters=PGD_STEPS)
            else:
                x_use = cw_attack_torchattacks(
                    model, xb, yb,
                    c=float(params.get("c", 1.0)),
                    steps=int(params.get("steps", CW_STEPS)),
                    lr=float(params.get("lr", CW_LR)),
                    kappa=float(params.get("kappa", CW_CONFIDENCE)),
                )
        else:
            raise ValueError("Unknown attack: " + attack_name)

        with torch.no_grad():
            logits = model(x_use)
            probs = torch.softmax(logits, dim=1)[:,1].detach().cpu().numpy()
        all_probs.append(probs)
        all_targets.append(yb.detach().cpu().numpy())

        # real-time visualization
        if VISUALIZE_REALTIME and (b_idx % vis_every == 0):
            # show up to max_vis images from this batch
            k = min(max_vis, x_use.size(0))
            x0 = xb[:k].detach().cpu()
            xa = x_use[:k].detach().cpu()
            p0 = torch.softmax(model(x0.to(DEVICE)), dim=1)[:,1].detach().cpu().numpy()
            pa = torch.softmax(model(xa.to(DEVICE)), dim=1)[:,1].detach().cpu().numpy()
            y0 = yb[:k].detach().cpu().numpy()

            def denorm(t):
                # tensor BCHW -> BHWC in [0,1]
                arr = t.numpy().transpose(0,2,3,1)
                arr = arr * np.array(IMAGENET_STD) + np.array(IMAGENET_MEAN)
                return np.clip(arr, 0, 1)

            clean = denorm(x0)
            adv   = denorm(xa)
            fig, axes = plt.subplots(2, k, figsize=(3.2*k, 6))
            for i in range(k):
                axes[0,i].imshow(clean[i]); axes[0,i].axis("off")
                axes[0,i].set_title(f"clean y={y0[i]} p1={p0[i]:.2f}")
                axes[1,i].imshow(adv[i]);   axes[1,i].axis("off")
                axes[1,i].set_title(f"{attack_name} p1={pa[i]:.2f}")
            plt.suptitle(f"{attack_name.upper()} — batch {b_idx}")
            plt.tight_layout()
            if SAVE_FIGS:
                plt.savefig(out_dir/f"{attack_name}_vis_batch{b_idx:04d}.png", dpi=130)
            plt.show()
            shown += k

    probs = np.concatenate(all_probs); targs = np.concatenate(all_targets)
    met   = metrics_from_probs(probs, targs)
    # confusion matrices
    plot_confusion(targs, probs, f"{attack_name.upper()} — image-level", out_path=(out_dir/f"{attack_name}_cm.png" if SAVE_FIGS else None))
    return probs, targs, met


In [None]:
# ================================================================
# 10) Orchestration: train/resume + validate + attacks (per model)
# ================================================================
random.seed(GLOBAL_SEED); np.random.seed(GLOBAL_SEED); torch.manual_seed(GLOBAL_SEED)

all_runs_summary = []

for mname in MODEL_LIST:
    print("\n" + "="*80)
    print("Model:", mname)

    # per-model config
    cfg = get_model_config(mname)
    IMG_SIZE   = cfg["img_size"]
    BATCH_SIZE = cfg["batch_size"]
    LR         = cfg["lr"]
    WD         = cfg["weight_decay"]
    DROPOUT    = cfg["dropout"]
    LABEL_SMOOTH = cfg["label_smooth"]

    # per-model run dirs and checkpoints
    model_run_dir = ensure_dir(RUN_ROOT / f"run_{mname}")
    ck = ckpt_paths_for_model(RUN_ROOT, mname)

    # data loaders (train/val only; NO mixing)
    train_ds = MuraDataset(train_df, img_size=IMG_SIZE, augment=True)
    val_ds   = MuraDataset(val_df,   img_size=IMG_SIZE, augment=False)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=GLOBAL_NUM_WORKERS, pin_memory=GLOBAL_PIN_MEMORY)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=GLOBAL_NUM_WORKERS, pin_memory=GLOBAL_PIN_MEMORY)

    # build model + optim
    model = build_model(mname, GLOBAL_NUM_CLASSES, dropout_p=DROPOUT).to(DEVICE)
    criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2) #verbose=True

    # resume if possible
    history = defaultdict(list)
    best_auc = -1.0
    start_epoch = 1
    resume = try_load_checkpoint_for_model(ck, model, optimizer=optimizer, scheduler=scheduler, source="latest",
                                           load_optimizer=True, load_scheduler=True, load_rng=True)
    if resume is not None:
        start_epoch, best_auc, loaded_hist = resume
        for k,v in loaded_hist.items():
            history[k] = v if isinstance(v,list) else list(v)
        print(f"Resumed {mname} from epoch {start_epoch}")

    # ------------- TRAIN LOOP -------------
    for epoch in range(start_epoch, GLOBAL_EPOCHS+1):
        t0 = time.time()
        # TRAIN
        model.train()
        tr_loss_sum = 0.0; tr_seen=0; tr_correct=0
        pbar = tqdm.tqdm(train_loader, desc=f"{mname} Epoch {epoch}/{GLOBAL_EPOCHS}", leave=False)
        for xb, yb, _ in pbar:
            xb = xb.to(DEVICE); yb = yb.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            if isinstance(logits, (tuple, list)):
                # common case: first element is the main logits
                logits = logits[0]
            elif hasattr(logits, "logits"):  # namedtuple with a .logits field
                logits = logits.logits
            loss = criterion(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            bs = xb.size(0)
            tr_loss_sum += loss.item()*bs
            tr_correct  += (logits.argmax(1)==yb).sum().item()
            tr_seen     += bs
            pbar.set_postfix({"loss": f"{tr_loss_sum/max(1,tr_seen):.4f}", "acc": f"{tr_correct/max(1,tr_seen):.4f}"}, refresh=False)
        pbar.close()

        # VALID (loss/acc)
        model.eval()
        va_loss_sum=0.0; va_seen=0; va_correct=0
        with torch.no_grad():
            for xb, yb, _ in val_loader:
                xb=xb.to(DEVICE); yb=yb.to(DEVICE)
                logits = model(xb)
                if isinstance(logits, (tuple, list)):
                  # common case: first element is the main logits
                  logits = logits[0]
                elif hasattr(logits, "logits"):  # namedtuple with a .logits field
                  logits = logits.logits
                loss = criterion(logits, yb)
                bs = xb.size(0)
                va_loss_sum += loss.item()*bs
                va_correct  += (logits.argmax(1)==yb).sum().item()
                va_seen     += bs

        tr_loss = tr_loss_sum/max(1,tr_seen); tr_acc=tr_correct/max(1,tr_seen)
        va_loss = va_loss_sum/max(1,va_seen); va_acc=va_correct/max(1,va_seen)

        # VALID AUROC image-level
        probs_val, targs_val, studies_val = predict_probs(model, val_loader)
        val_img_metrics = metrics_from_probs(probs_val, targs_val)
        val_auc = val_img_metrics["AUROC"]

        history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)
        history["val_loss"].append(va_loss);   history["val_acc"].append(va_acc)
        history["val_auc"].append(val_auc)

        # Save latest & snapshot
        save_checkpoint_for_model(ck, epoch, best_auc, history, model, optimizer, scheduler)
        if epoch % 10 == 0:
            snap = ck["snapshots"]/f"epoch_{epoch:03d}.pt"
            torch.save({
                "epoch": epoch, "model_name": mname, "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(),
                "best_auc": best_auc, "history": history, "rng_state": get_rng_state()
            }, snap)
            print("Saved snapshot:", snap)

        # Save best-by-AUROC
        saved_msg = ""
        if not np.isnan(val_auc) and val_auc > best_auc + 1e-5:
            best_auc = float(val_auc)
            torch.save(model.state_dict(), ck["best"])
            saved_msg = f"  saved_best(AUROC={best_auc:.4f})"

        # schedule on val loss
        try: scheduler.step(va_loss)
        except: pass

        print(f"Epoch {epoch} {mname}: train_loss={tr_loss:.4f} acc={tr_acc:.4f} | val_loss={va_loss:.4f} acc={va_acc:.4f} | val_AUROC={val_auc:.4f}{saved_msg}  time={time.time()-t0:.1f}s")

    # After training: ensure best weights loaded for attacks
    if ck["best"].exists():
        print("Loading best weights for attacks:", ck["best"])
        model.load_state_dict(torch.load(ck["best"], map_location=DEVICE))

    # Plot training curves
    try:
        fig, ax = plt.subplots(1,2, figsize=(10,4))
        ax[0].plot(history["train_loss"], label="train_loss")
        ax[0].plot(history["val_loss"],   label="val_loss")
        ax[0].legend(); ax[0].set_title(f"{mname} — Loss")
        ax[1].plot(history["train_acc"], label="train_acc")
        ax[1].plot(history["val_acc"],   label="val_acc")
        ax[1].plot(history["val_auc"],   label="val_AUROC")
        ax[1].legend(); ax[1].set_title(f"{mname} — Acc/AUROC")
        plt.tight_layout()
        if SAVE_FIGS: plt.savefig(model_run_dir/f"{mname}_curves.png", dpi=140)
        plt.show()
    except Exception as e:
        print("[warn] curve plotting failed:", e)

    # ----------------- Attacks -----------------
    attack_summary = {"model": mname}

    if ENABLE_ATTACKS:
        out_adv_dir = ensure_dir(RUN_ROOT / "adv_examples" / mname)
        # CLEAN
        probs_c, targs_c, met_c = attack_and_eval(model, val_loader, "clean", out_adv_dir/"clean", {},
                                                  vis_every=VIS_EVERY_N_BATCHES, max_vis=MAX_VIS_IMAGES_PER_STEP)
        attack_summary["clean"] = met_c

        # FGSM
        probs_f, targs_f, met_f = attack_and_eval(model, val_loader, "fgsm", out_adv_dir/"fgsm",
                                                  {"eps": FGSM_EPS},
                                                  vis_every=VIS_EVERY_N_BATCHES, max_vis=MAX_VIS_IMAGES_PER_STEP)
        attack_summary["fgsm"] = met_f

        # PGD
        probs_p, targs_p, met_p = attack_and_eval(model, val_loader, "pgd", out_adv_dir/"pgd",
                                                  {"eps": PGD_EPS, "alpha": PGD_STEP_ALPHA, "iters": PGD_STEPS},
                                                  vis_every=VIS_EVERY_N_BATCHES, max_vis=MAX_VIS_IMAGES_PER_STEP)
        attack_summary["pgd"] = met_p

        # C&W (torchattacks if available, else falls back to PGD)
        params_cw = {"c":1.0, "steps":CW_STEPS, "lr":CW_LR, "kappa":CW_CONFIDENCE}
        probs_w, targs_w, met_w = attack_and_eval(model, val_loader, "cw", out_adv_dir/"cw",
                                                  params_cw, vis_every=VIS_EVERY_N_BATCHES, max_vis=MAX_VIS_IMAGES_PER_STEP)
        attack_summary["cw"] = met_w

        # Overlay comparison bar for this model (ACC/F1/AUROC)
        try:
            labels = ["clean","fgsm","pgd","cw"]
            aurocs = [attack_summary[k]["AUROC"] for k in labels]
            accs   = [attack_summary[k]["ACC"]   for k in labels]
            f1s    = [attack_summary[k]["F1"]    for k in labels]

            fig, ax = plt.subplots(1,3, figsize=(13,3.8))
            ax[0].bar(labels, aurocs); ax[0].set_title(f"{mname} AUROC")
            ax[1].bar(labels, accs);   ax[1].set_title(f"{mname} ACC")
            ax[2].bar(labels, f1s);    ax[2].set_title(f"{mname} F1")
            for a in ax:
                for label in a.get_xticklabels(): label.set_rotation(25)
            plt.tight_layout()
            if SAVE_FIGS: plt.savefig(model_run_dir/f"{mname}_attack_bar.png", dpi=140)
            plt.show()
        except Exception as e:
            print("[warn] attack comparison plotting failed:", e)

    # store per-model summary
    all_runs_summary.append({"model": mname, "history": {k:list(v) for k,v in history.items()}, "attacks": attack_summary})

# Save the global summary
with open(RUN_ROOT/"results_summary.json", "w") as f:
    json.dump(all_runs_summary, f, indent=2)
print("Saved:", RUN_ROOT/"results_summary.json")

# Optional: combined comparison across models (val AUROC last epoch)
try:
    fig, ax = plt.subplots(figsize=(7,4))
    names = []
    last_aurocs = []
    for rec in all_runs_summary:
        names.append(rec["model"])
        a = rec["history"].get("val_auc", [])
        last_aurocs.append(a[-1] if len(a)>0 else np.nan)
    ax.bar(names, last_aurocs)
    ax.set_title("Validation AUROC (last epoch) per model")
    for lbl in ax.get_xticklabels(): lbl.set_rotation(20)
    plt.tight_layout()
    if SAVE_FIGS: plt.savefig(RUN_ROOT/"models_val_auc_bar.png", dpi=140)
    plt.show()
except Exception as e:
    print("[warn] summary plotting failed:", e)


Model: googlenet


Downloading: "https://download.pytorch.org/models/googlenet-1378be20.pth" to /root/.cache/torch/hub/checkpoints/googlenet-1378be20.pth
100%|██████████| 49.7M/49.7M [00:00<00:00, 229MB/s]


Found checkpoint for /content/drive/MyDrive/mura_runs_v4/checkpoints/googlenet/latest.pt, loading …
Resume at epoch 5 (best_auc=0.8565)
Resumed googlenet from epoch 5




Epoch 5 googlenet: train_loss=0.4450 acc=0.8245 | val_loss=0.4696 acc=0.8104 | val_AUROC=0.8742  saved_best(AUROC=0.8742)  time=630.4s




Epoch 6 googlenet: train_loss=0.4335 acc=0.8303 | val_loss=0.5017 acc=0.8023 | val_AUROC=0.8781  saved_best(AUROC=0.8781)  time=630.4s




Epoch 7 googlenet: train_loss=0.4239 acc=0.8381 | val_loss=0.4691 acc=0.8039 | val_AUROC=0.8759  time=626.4s




Epoch 8 googlenet: train_loss=0.4120 acc=0.8448 | val_loss=0.5150 acc=0.7954 | val_AUROC=0.8631  time=627.3s




Epoch 9 googlenet: train_loss=0.3998 acc=0.8517 | val_loss=0.5623 acc=0.7998 | val_AUROC=0.8682  time=628.3s




Saved snapshot: /content/drive/MyDrive/mura_runs_v4/checkpoints/googlenet/snapshots/epoch_010.pt
Epoch 10 googlenet: train_loss=0.3906 acc=0.8570 | val_loss=0.4940 acc=0.7957 | val_AUROC=0.8678  time=634.1s




Epoch 11 googlenet: train_loss=0.3483 acc=0.8818 | val_loss=0.4784 acc=0.8264 | val_AUROC=0.8889  saved_best(AUROC=0.8889)  time=631.0s




Epoch 12 googlenet: train_loss=0.3256 acc=0.8949 | val_loss=0.5107 acc=0.8176 | val_AUROC=0.8846  time=629.2s




Epoch 13 googlenet: train_loss=0.3114 acc=0.9027 | val_loss=0.5192 acc=0.8114 | val_AUROC=0.8808  time=626.0s




Epoch 14 googlenet: train_loss=0.2765 acc=0.9208 | val_loss=0.5243 acc=0.8167 | val_AUROC=0.8831  time=629.4s




Epoch 15 googlenet: train_loss=0.2583 acc=0.9304 | val_loss=0.5397 acc=0.8170 | val_AUROC=0.8801  time=639.1s




Epoch 16 googlenet: train_loss=0.2481 acc=0.9364 | val_loss=0.5509 acc=0.8183 | val_AUROC=0.8815  time=625.9s




Epoch 17 googlenet: train_loss=0.2264 acc=0.9490 | val_loss=0.5599 acc=0.8136 | val_AUROC=0.8750  time=624.5s


googlenet Epoch 18/50:  51%|█████▏    | 787/1534 [04:43<04:27,  2.79it/s, loss=0.2171, acc=0.9545]