# 🔥 05 — Grad-CAM Visualizations

**Purpose:** Generate attention visualizations to understand what models focus on.

**Sections:**
1. Inline Setup
2. Copy Hybrid Crops to /content
3. Build Modality Bundles (load models + predictions)
4. Confidence Analysis
5. Grad-CAM Gallery (correct/wrong × high/low confidence)
6. Confusion-Pair Grad-CAM
7. ROI vs Control Side-by-Side Comparison

**Prerequisites:**
- Trained checkpoints exist on Drive
- Predictions exist for modalities you want to visualize


In [None]:
# --- INLINE SETUP ---
import os, subprocess, sys

REPO_DIRNAME   = "CNNs-distracted-driving"
PROJECT_ROOT   = f"/content/{REPO_DIRNAME}"
DRIVE_PATH     = "/content/drive/MyDrive/TFM"
DRIVE_DATA_ROOT = f"{DRIVE_PATH}/data"
FAST_DATA      = "/content/data"
DATASET_ROOT   = DRIVE_DATA_ROOT
OUT_ROOT       = f"{DRIVE_PATH}/outputs"
CKPT_ROOT      = f"{DRIVE_PATH}/checkpoints"

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

if not os.path.isdir(PROJECT_ROOT):
    subprocess.call(f"git clone https://github.com/ClaudiaCPach/CNNs-distracted-driving {PROJECT_ROOT}", shell=True)
subprocess.call(f"pip install -q -e {PROJECT_ROOT}", shell=True)
!pip -q install timm grad-cam

os.environ["DRIVE_PATH"] = DRIVE_PATH
os.environ["DATASET_ROOT"] = DATASET_ROOT
os.environ["OUT_ROOT"] = OUT_ROOT
os.environ["CKPT_ROOT"] = CKPT_ROOT
os.environ["FAST_DATA"] = FAST_DATA

sys.path.insert(0, PROJECT_ROOT)
sys.path.insert(0, os.path.join(PROJECT_ROOT, "src"))

!nvidia-smi || echo "No GPU"
print("‚úÖ Setup complete")


## ⚡ Section 2: Copy Hybrid Crops to /content


In [None]:
# Copy hybrid crops (needed for Grad-CAM visualization)
#
# For each hybrid variant, mirror the *.jpg crops and the split/manifest CSVs
# from Drive into /content. An interrupted copy is resumed on the next run
# instead of being mistaken for a complete one, and the CSVs are synced
# whenever they exist on Drive (not only on a fresh image copy).
import os, shutil
from pathlib import Path
import importlib

LOCAL_ROOT = Path("/content/data/hybrid")
DRIVE_ROOT = Path(OUT_ROOT) / "hybrid"


def count_jpgs(p: Path) -> int:
    """Return the number of ``*.jpg`` files anywhere under ``p`` (0 if ``p`` does not exist)."""
    return sum(1 for _ in p.rglob("*.jpg")) if p.exists() else 0


# Copy BOTH variants for comparison
for HYBRID_VARIANT in ["face", "face_hands"]:
    LOCAL_VARIANT_DIR = LOCAL_ROOT / HYBRID_VARIANT
    DRIVE_VARIANT_DIR = DRIVE_ROOT / HYBRID_VARIANT

    local_count = count_jpgs(LOCAL_VARIANT_DIR)
    drive_count = count_jpgs(DRIVE_VARIANT_DIR)

    if local_count == 0 and drive_count == 0:
        print(f"‚ö†Ô∏è {HYBRID_VARIANT}: Not found on Drive")
        continue

    if local_count > 0 and local_count >= drive_count:
        print(f"‚úÖ {HYBRID_VARIANT}: Already in /content ({local_count} jpgs)")
    else:
        # Either nothing is local yet, or a previous copy was interrupted.
        print(f"üì¶ {HYBRID_VARIANT}: Copying from Drive...")
        LOCAL_VARIANT_DIR.mkdir(parents=True, exist_ok=True)
        file_count = 0
        for src_dir, _, files in os.walk(DRIVE_VARIANT_DIR):
            rel_dir = Path(src_dir).relative_to(DRIVE_VARIANT_DIR)
            dst_dir = LOCAL_VARIANT_DIR / rel_dir
            dst_dir.mkdir(parents=True, exist_ok=True)
            for fname in files:
                if not fname.lower().endswith(".jpg"):
                    continue
                src_file = Path(src_dir) / fname
                dst_file = dst_dir / fname
                # Skip files that were already copied completely; a size
                # mismatch means the earlier copy was cut short.
                if dst_file.exists() and dst_file.stat().st_size == src_file.stat().st_size:
                    continue
                shutil.copy2(src_file, dst_file)
                file_count += 1
        print(f"   Copied {file_count} images")

    # The CSVs are small, so always refresh them from Drive when available.
    LOCAL_ROOT.mkdir(parents=True, exist_ok=True)
    for fname in [f"manifest_{HYBRID_VARIANT}.csv", f"train_{HYBRID_VARIANT}.csv",
                  f"val_{HYBRID_VARIANT}.csv", f"test_{HYBRID_VARIANT}.csv"]:
        src = DRIVE_ROOT / fname
        if src.exists():
            shutil.copy2(src, LOCAL_ROOT / fname)

os.environ["HYBRID_ROOT_LOCAL"] = str(Path("/content/data/hybrid"))
os.environ["DATASET_ROOT"] = str(Path("/content/data/hybrid"))
print(f"\n‚úÖ DATASET_ROOT = {os.environ['DATASET_ROOT']}")


## 📦 Section 3: Build Modality Bundles

Load models, predictions, and Grad-CAM objects for each modality.

**5-Run Experimental Plan:**
| Run | Name | Tag Example | Mode |
|-----|------|-------------|------|
| 1 | Full | `effb0_fullframe_v1` | full |
| 2 | Face | `effb0_face_v1` | hybrid |
| 3 | Face+Hands | `effb0_face_hands_v1` | hybrid |
| 4 | Ctrl-FaceSub | `effb0_fullframe_facesubset_v1` | full |
| 5 | Ctrl-FHSub | `effb0_fullframe_fhsubset_v1` | full |


In [None]:
# ============== CONFIGURE YOUR RUNS ==============
# Update tags to match your experiment naming from 02_training.ipynb.
# To skip a run, comment out or delete its entry.

from pathlib import Path

RUNS = [
    # --- Natural Runs (different ID sets) ---
    {"name": "Full",       "tag": "effb0_fullframe_v1",             "mode": "full",   "roi_variant": None},
    {"name": "Face",       "tag": "effb0_face_v1",                  "mode": "hybrid", "roi_variant": "face"},
    {"name": "Face+Hands", "tag": "effb0_face_hands_v1",            "mode": "hybrid", "roi_variant": "face_hands"},
    
    # --- Control Runs (same IDs as ROI runs, but full-frame) ---
    {"name": "Ctrl-FaceSub", "tag": "effb0_fullframe_facesubset_v1", "mode": "full", "roi_variant": None},
    {"name": "Ctrl-FHSub",   "tag": "effb0_fullframe_fhsubset_v1",   "mode": "full", "roi_variant": None},
]

# When True, keep only the runs whose checkpoint AND predictions CSV exist.
AUTO_FILTER_AVAILABLE = True

MODEL_NAME = "efficientnet_b0"
SPLIT_TO_ANALYZE = "test"
IMAGE_SIZE = 224

FORCE_CPU = False
import torch
device = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
print(f"üß† Using device: {device}")


def _checkpoint_available(tag):
    """Return True if a checkpoint for ``tag`` exists in either supported layout.

    Accepts the flat layout ``CKPT_ROOT/<tag>/best.pt`` as well as the
    per-run layout ``CKPT_ROOT/runs/<tag>/<run>/{best,last}.pt`` (latest run
    directory only), which is where the bundle loader looks.
    """
    if (Path(CKPT_ROOT) / tag / "best.pt").exists():
        return True
    run_base = Path(CKPT_ROOT) / "runs" / tag
    if not run_base.is_dir():
        return False
    run_dirs = sorted(p for p in run_base.iterdir() if p.is_dir())
    if not run_dirs:
        return False
    return any((run_dirs[-1] / name).exists() for name in ("best.pt", "last.pt"))


def _predictions_available(tag, split):
    """Return True if a predictions CSV exists under either accepted file name."""
    preds_dir = Path(OUT_ROOT) / "preds" / split
    return any((preds_dir / name).exists() for name in (f"{tag}_{split}.csv", f"{tag}.csv"))


if AUTO_FILTER_AVAILABLE:
    available_runs = []
    for run in RUNS:
        ckpt_ok = _checkpoint_available(run["tag"])
        preds_ok = _predictions_available(run["tag"], SPLIT_TO_ANALYZE)
        if ckpt_ok and preds_ok:
            available_runs.append(run)
            print(f"‚úÖ {run['name']}: checkpoint + predictions found")
        else:
            print(f"‚ö†Ô∏è  {run['name']}: skipping (ckpt={ckpt_ok}, preds={preds_ok})")
    RUNS = available_runs
    print(f"\nüéØ Analyzing {len(RUNS)} runs")


In [None]:
# Build modality bundles
import pandas as pd
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from ddriver.models.registry import build_model, register_timm_backbone

# Preprocessing applied to every image before it is fed to a model:
# resize to a square of IMAGE_SIZE, convert to a tensor, then normalize each
# channel (the values below are the commonly used ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
_TRANSFORM_STEPS = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_TRANSFORM_STEPS)

def get_target_layers(model, model_name):
    """Pick the layer(s) Grad-CAM should hook for ``model``.

    The search runs on ``model.backbone`` when that attribute exists,
    otherwise on ``model`` itself. For names containing "efficientnet" the
    ``conv_head`` is preferred, then the last entry of ``blocks``. Otherwise
    (or if neither is present) the last entry of the first indexable
    container among ``features``, ``blocks``, ``stages``, ``layer4`` is used,
    and as a last resort the second-to-last child module.

    Returns:
        A one-element list holding the selected layer.
    """
    net = getattr(model, "backbone", model)

    wants_effnet = "efficientnet" in model_name.lower()
    if wants_effnet and hasattr(net, "conv_head"):
        return [net.conv_head]
    if wants_effnet and hasattr(net, "blocks"):
        return [net.blocks[-1]]

    # Lazily walk the well-known container attributes in priority order.
    containers = (
        getattr(net, name)
        for name in ("features", "blocks", "stages", "layer4")
        if hasattr(net, name)
    )
    for container in containers:
        if hasattr(container, "__getitem__"):
            return [container[-1]]

    child_modules = list(net.children())
    return [child_modules[-2]]

def find_ckpt_for_tag(run_tag):
    """Locate the checkpoint file for ``run_tag``.

    Search order:
      1. ``CKPT_ROOT/runs/<run_tag>/<latest run dir>/`` where "latest" is the
         last sub-directory in sorted (lexicographic) order.
      2. ``CKPT_ROOT/<run_tag>/`` (flat layout, the location the run
         auto-filter checks).
    In each location ``best.pt`` is preferred over ``last.pt``.

    Args:
        run_tag: Experiment tag used as the directory name.

    Returns:
        Path to the checkpoint file.

    Raises:
        FileNotFoundError: If no checkpoint exists in either layout.
    """
    ckpt_names = ("best.pt", "last.pt")

    run_base = Path(CKPT_ROOT) / "runs" / run_tag
    # Keep directories only: before Python 3.11 a "*/" glob pattern also
    # yields plain files, which could then be mistaken for the latest run.
    run_dirs = sorted(p for p in run_base.glob("*") if p.is_dir())
    latest = run_dirs[-1] if run_dirs else None
    if latest is not None:
        for name in ckpt_names:
            if (latest / name).exists():
                return latest / name

    # Fall back to the flat layout so this agrees with the availability filter.
    flat_dir = Path(CKPT_ROOT) / run_tag
    for name in ckpt_names:
        if (flat_dir / name).exists():
            return flat_dir / name

    if latest is None:
        raise FileNotFoundError(f"No runs under {run_base}")
    raise FileNotFoundError(f"No checkpoint in {latest}")

def find_preds_csv(run_tag, split):
    """Return the predictions CSV for ``run_tag`` inside ``OUT_ROOT/preds/<split>``.

    ``<run_tag>_<split>.csv`` is preferred; ``<run_tag>.csv`` is the fallback.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    file_names = (f"{run_tag}_{split}.csv", f"{run_tag}.csv")
    candidates = (Path(OUT_ROOT) / "preds" / split / name for name in file_names)
    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        raise FileNotFoundError(f"Preds not found for {run_tag}")
    return found

def extract_class_from_path(p):
    """Return the first path component shaped like ``c<digit>`` (e.g. ``"c3"``).

    Only two-character components are recognised; an empty string is returned
    when the path contains none.
    """
    def looks_like_class_dir(component):
        return len(component) == 2 and component.startswith("c") and component[1].isdigit()

    return next(filter(looks_like_class_dir, Path(p).parts), "")

def class_to_int(class_id):
    """Convert a class identifier to its integer index.

    Missing values (as judged by ``pd.isna``) map to ``-1``. Strings with a
    leading ``"c"`` have that prefix dropped before conversion; everything
    else is passed straight to ``int``.
    """
    if pd.isna(class_id):
        return -1
    has_prefix = isinstance(class_id, str) and class_id.startswith("c")
    numeric_part = class_id[1:] if has_prefix else class_id
    return int(numeric_part)

def load_bundle(run):
    """Load everything needed to visualise one run: model, Grad-CAM and predictions.

    Steps: locate and load the checkpoint, rebuild the model with the number
    of classes found in the checkpoint, create a ``GradCAM`` object on the
    selected target layer, then join the predictions CSV with the manifest so
    each prediction row gets a ground-truth label and an image path.

    Args:
        run: Dict with the keys ``"name"``, ``"tag"``, ``"mode"`` and
            ``"roi_variant"`` (an entry of ``RUNS``).

    Returns:
        Dict with ``run_name``, ``tag``, ``mode``, ``roi_variant``,
        ``preds_df``, ``model``, ``cam``, ``device``, ``data_root`` and
        ``generate_gradcam`` (a callable ``img_path -> (overlay, heatmap)``).

    Raises:
        FileNotFoundError: If the checkpoint or predictions CSV is missing.
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.

    Notes:
        The predictions CSV must provide ``path`` and ``pred_class_id``
        (``confidence`` is optional and defaults to 1.0); the manifest must
        provide ``path`` and ``class_id``.
        NOTE(review): the join key ``_class`` is a string such as "c3" on the
        predictions side and the raw ``class_id`` on the manifest side, so the
        manifest presumably stores ``class_id`` in the same "cN" form - confirm.
    """
    run_tag = run["tag"]
    mode = run["mode"]
    roi_variant = run["roi_variant"]
    
    print(f"\nüì¶ Loading: {run['name']} ({run_tag})")
    
    ckpt_path = find_ckpt_for_tag(run_tag)
    ckpt = torch.load(ckpt_path, map_location="cpu")
    
    # Registration is best-effort, but only ordinary errors are swallowed
    # (a bare ``except`` would also hide KeyboardInterrupt / SystemExit).
    try:
        register_timm_backbone(MODEL_NAME)
    except Exception as exc:
        print(f"   (register_timm_backbone skipped: {exc})")
    
    msd = ckpt.get("model_state_dict")
    if msd is None:
        # Fail with a clear message instead of loading an empty state dict.
        raise KeyError(f"'model_state_dict' missing from checkpoint {ckpt_path}")
    # Infer the head size from the classifier weights; fall back to 10 classes.
    num_classes = msd.get("classifier.weight", torch.zeros(10, 1)).shape[0]
    
    model = build_model(MODEL_NAME, num_classes=num_classes, pretrained=False)
    model.load_state_dict(msd)
    model.eval()
    model = model.to(device)
    
    target_layers = get_target_layers(model, MODEL_NAME)
    cam = GradCAM(model=model, target_layers=target_layers)
    
    preds_csv = find_preds_csv(run_tag, SPLIT_TO_ANALYZE)
    preds_df = pd.read_csv(preds_csv)
    
    # Hybrid runs read crops + manifest from the hybrid root; full-frame runs
    # use the dataset root and the global manifest.
    if mode == "hybrid":
        data_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
        manifest_path = data_root / f"manifest_{roi_variant}.csv"
    else:
        data_root = Path(DATASET_ROOT)
        manifest_path = Path(OUT_ROOT) / "manifests" / "manifest.csv"
    
    manifest_df = pd.read_csv(manifest_path)
    
    # Match predictions to manifest rows on (class, file name).
    preds_df["_class"] = preds_df["path"].apply(extract_class_from_path)
    preds_df["_filename"] = preds_df["path"].apply(lambda p: Path(p).name)
    manifest_df["_class"] = manifest_df["class_id"]
    manifest_df["_filename"] = manifest_df["path"].apply(lambda p: Path(p).name)
    
    manifest_for_merge = manifest_df[["_class", "_filename", "path", "class_id"]].drop_duplicates(
        subset=["_class", "_filename"], keep="first"
    ).rename(columns={"path": "crop_path"})
    
    preds_df = preds_df.merge(manifest_for_merge, on=["_class", "_filename"], how="left")
    # Rows without a manifest match have no ground-truth label; drop them.
    preds_df = preds_df.dropna(subset=["class_id"]).copy()
    preds_df["vis_path"] = preds_df["crop_path"]
    preds_df["label"] = preds_df["class_id"].apply(class_to_int)
    preds_df["pred"] = preds_df["pred_class_id"].apply(class_to_int)
    if "confidence" not in preds_df.columns:
        preds_df["confidence"] = 1.0
    
    def generate_gradcam_fn(img_path):
        """Return ``(overlay, heatmap)`` for the image at ``img_path``.

        ``targets=None`` is passed to the CAM object, leaving the choice of
        the explained class to the library.
        """
        img = Image.open(img_path).convert("RGB")
        img_resized = img.resize((IMAGE_SIZE, IMAGE_SIZE))
        # The overlay helper is given a float image scaled to [0, 1].
        img_np = np.array(img_resized) / 255.0
        img_tensor = transform(img).unsqueeze(0).to(device)
        grayscale_cam = cam(input_tensor=img_tensor, targets=None)[0, :]
        visualization = show_cam_on_image(img_np.astype(np.float32), grayscale_cam, use_rgb=True)
        return visualization, grayscale_cam
    
    print(f"   ‚úÖ Loaded {len(preds_df)} predictions, checkpoint from {ckpt_path.name}")
    
    return {
        "run_name": run["name"],
        "tag": run_tag,
        "mode": mode,
        "roi_variant": roi_variant,
        "preds_df": preds_df,
        "model": model,
        "cam": cam,
        "device": device,
        "data_root": data_root,
        "generate_gradcam": generate_gradcam_fn,
    }

# Build one bundle per configured run. Bundle keys are the lower-cased run
# names with "+" and spaces replaced by "_" (e.g. "Face+Hands" -> "face_hands").
# A run that fails to load is reported and skipped rather than aborting the cell.
MODALITY_BUNDLES = {}
_KEY_TRANSLATION = str.maketrans({"+": "_", " ": "_"})
for run in RUNS:
    try:
        bundle_key = run["name"].lower().translate(_KEY_TRANSLATION)
        MODALITY_BUNDLES[bundle_key] = load_bundle(run)
    except Exception as err:
        print(f"   ‚ö†Ô∏è Skipped {run['name']}: {err}")

print(f"\n‚úÖ Loaded {len(MODALITY_BUNDLES)} bundles: {list(MODALITY_BUNDLES.keys())}")


## 📊 Section 4: Confidence Analysis


In [None]:
# Confidence summary per modality
import matplotlib.pyplot as plt

# Per-modality confidence statistics: accuracy, mean confidence on correct and
# wrong predictions, their gap, and the share of high-confidence predictions
# (confidence >= HIGH_CONF_THRESHOLD) that are wrong.
HIGH_CONF_THRESHOLD = 0.8
results = []

for bundle in MODALITY_BUNDLES.values():
    scored = bundle["preds_df"].copy()
    scored["correct"] = scored["pred"] == scored["label"]

    conf_when_right = scored.loc[scored["correct"], "confidence"]
    conf_when_wrong = scored.loc[~scored["correct"], "confidence"]

    mean_conf_correct = np.nan
    if len(conf_when_right) > 0:
        mean_conf_correct = conf_when_right.mean()
    mean_conf_wrong = np.nan
    if len(conf_when_wrong) > 0:
        mean_conf_wrong = conf_when_wrong.mean()

    confident = scored[scored["confidence"] >= HIGH_CONF_THRESHOLD]
    n_confident = len(confident)
    if n_confident > 0:
        n_confident_wrong = int((~confident["correct"]).sum())
        overconf_rate = n_confident_wrong / n_confident * 100
    else:
        overconf_rate = 0

    # The gap is only defined when there is at least one wrong prediction.
    if np.isnan(mean_conf_wrong):
        conf_gap = np.nan
    else:
        conf_gap = mean_conf_correct - mean_conf_wrong

    summary_row = {
        "Modality": bundle["run_name"],
        "N": len(scored),
        "Accuracy": (scored["correct"].mean() * 100),
        "Conf (Correct)": mean_conf_correct,
        "Conf (Wrong)": mean_conf_wrong,
        "Conf Gap": conf_gap,
        "Overconf Error %": overconf_rate,
    }
    results.append(summary_row)

results_df = pd.DataFrame(results)
_rule = "=" * 80
print(_rule)
print("üìä CONFIDENCE SUMMARY")
print(_rule)
print(results_df.to_string(index=False))


## 🔥 Section 5: Grad-CAM Gallery

Generate example visualizations for correct/wrong predictions.


In [None]:
# Grad-CAM gallery: sample from each category
from PIL import Image

CLASS_NAMES = {
    0: "Safe", 1: "Txt-R", 2: "Ph-R", 3: "Txt-L",
    4: "Ph-L", 5: "Radio", 6: "Drink", 7: "Reach",
    8: "Hair", 9: "Pass"
}

BUNDLE_KEY = "face"  # face | face_hands | full
N_SAMPLES = 3


# Defined at cell top level (not inside the branch below) because the
# confusion-pair cell also calls it, even when BUNDLE_KEY is not loaded here.
def find_image_path(vis_path, data_root, roi_variant):
    """Resolve ``vis_path`` to an existing image file.

    Tried in order: ``vis_path`` itself when it is absolute,
    ``data_root / vis_path``, then ``data_root / roi_variant / vis_path``
    (only when ``roi_variant`` is set).

    Returns:
        The first existing ``Path``, or ``None`` if nothing matches.
    """
    p = Path(vis_path)
    if p.is_absolute() and p.exists():
        return p
    candidate = Path(data_root) / vis_path
    if candidate.exists():
        return candidate
    if roi_variant:
        candidate = Path(data_root) / roi_variant / vis_path
        if candidate.exists():
            return candidate
    return None


if BUNDLE_KEY not in MODALITY_BUNDLES:
    print(f"‚ö†Ô∏è Bundle '{BUNDLE_KEY}' not loaded")
else:
    bundle = MODALITY_BUNDLES[BUNDLE_KEY]
    preds_df = bundle["preds_df"].copy()
    preds_df["correct"] = preds_df["pred"] == preds_df["label"]
    
    # Sample categories: (title, rows belonging to the category)
    categories = [
        ("Correct + High Conf", preds_df[(preds_df["correct"]) & (preds_df["confidence"] >= 0.9)]),
        ("Wrong + High Conf", preds_df[(~preds_df["correct"]) & (preds_df["confidence"] >= 0.8)]),
        ("Wrong + Low Conf", preds_df[(~preds_df["correct"]) & (preds_df["confidence"] < 0.5)]),
    ]
    
    for cat_name, cat_df in categories:
        if len(cat_df) == 0:
            print(f"\n‚ö†Ô∏è {cat_name}: No samples")
            continue
        
        samples = cat_df.sample(min(N_SAMPLES, len(cat_df)))
        print(f"\nüì∏ {cat_name} ({len(cat_df)} total, showing {len(samples)})")
        
        fig, axes = plt.subplots(1, len(samples), figsize=(4*len(samples), 4))
        # With a single column, subplots returns a bare Axes instead of an array.
        if len(samples) == 1:
            axes = [axes]
        
        for ax, (_, row) in zip(axes, samples.iterrows()):
            img_path = find_image_path(row["vis_path"], bundle["data_root"], bundle["roi_variant"])
            if img_path and img_path.exists():
                viz, _ = bundle["generate_gradcam"](img_path)
                ax.imshow(viz)
                true_label = CLASS_NAMES.get(row["label"], f"c{row['label']}")
                pred_label = CLASS_NAMES.get(row["pred"], f"c{row['pred']}")
                icon = "‚úÖ" if row["correct"] else "‚ùå"
                ax.set_title(f"True: {true_label}\nPred: {pred_label} ({row['confidence']:.2f}) {icon}", fontsize=10)
            else:
                ax.text(0.5, 0.5, "Not found", ha="center", va="center")
            ax.axis("off")
        
        plt.suptitle(f"{bundle['run_name']}: {cat_name}", fontweight="bold")
        plt.tight_layout()
        
        # One PNG per category, grouped by run tag.
        out_dir = Path(OUT_ROOT) / "gradcam" / bundle["tag"]
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{cat_name.replace(' ', '_').lower()}.png"
        plt.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.show()
        print(f"   üíæ Saved to {out_path}")


## 🔍 Section 6: Confusion-Pair Grad-CAM

Investigate specific confusion pairs to understand why the model fails.


In [None]:
# Confusion-pair Grad-CAM
# For each (true class, predicted class, description) triple, draw Grad-CAM
# overlays for a few randomly sampled test images with exactly that confusion
# and save the figure under OUT_ROOT/gradcam/confusions.
CONFUSION_PAIRS = [
    (5, 8, "Radio ‚Üí Hair/Makeup"),
    (9, 0, "Passenger ‚Üí Safe"),
    (1, 2, "Texting(R) ‚Üí Phone(R)"),
]

N_EXAMPLES = 3
BUNDLE_KEY = "face_hands"  # Which modality to analyze

if BUNDLE_KEY in MODALITY_BUNDLES:
    bundle = MODALITY_BUNDLES[BUNDLE_KEY]
    preds_df = bundle["preds_df"]

    for true_c, pred_c, desc in CONFUSION_PAIRS:
        is_this_confusion = (preds_df["label"] == true_c) & (preds_df["pred"] == pred_c)
        confusion_df = preds_df[is_this_confusion]

        if len(confusion_df) == 0:
            print(f"\n‚¨ú {desc}: No examples")
            continue

        samples = confusion_df.sample(min(N_EXAMPLES, len(confusion_df)))
        print(f"\nüîç {desc} ({len(confusion_df)} total, showing {len(samples)})")

        n_cols = len(samples)
        fig, axes = plt.subplots(1, n_cols, figsize=(4*n_cols, 4))
        # A single column yields a bare Axes; wrap it so zip() below works.
        axes = [axes] if n_cols == 1 else axes

        for ax, (_, row) in zip(axes, samples.iterrows()):
            img_path = find_image_path(row["vis_path"], bundle["data_root"], bundle["roi_variant"])
            if img_path and img_path.exists():
                overlay, _ = bundle["generate_gradcam"](img_path)
                ax.imshow(overlay)
                ax.set_title(f"True: {CLASS_NAMES.get(true_c)}\nPred: {CLASS_NAMES.get(pred_c)} ({row['confidence']:.2f})", fontsize=10)
            else:
                ax.text(0.5, 0.5, "Not found", ha="center", va="center")
            ax.axis("off")

        plt.suptitle(f"{bundle['run_name']}: {desc}", fontweight="bold")
        plt.tight_layout()

        # Turn the description into a file-system friendly stem.
        safe_name = desc.replace(" ", "_").replace("‚Üí", "to").replace("/", "_")
        out_dir = Path(OUT_ROOT) / "gradcam" / "confusions"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{safe_name}__{BUNDLE_KEY}.png"
        plt.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.show()
        print(f"   üíæ Saved to {out_path}")
else:
    print(f"‚ö†Ô∏è Bundle '{BUNDLE_KEY}' not loaded")


## 🎯 Section 7: ROI vs Control Side-by-Side Comparison

Compare Grad-CAM attention between ROI models and their matched full-frame controls **on the same image IDs**. This shows whether the ROI crop helps the model focus on semantically relevant features.


In [None]:
# ROI vs Control Side-by-Side Comparison
# Compare attention on SAME images (matched by original image ID)

# Comparisons to make (ROI model, Control model), given as run display names
COMPARISONS = [
    ("Face+Hands", "Ctrl-FHSub"),   # Face+Hands ROI vs Full-frame on same IDs
    ("Face", "Ctrl-FaceSub"),       # Face ROI vs Full-frame on same IDs
]

N_COMPARISON_SAMPLES = 6  # Per comparison
SAMPLE_CATEGORIES = ["correct", "wrong"]  # Sample from correct and wrong predictions


def resolve_bundle_key(run_name):
    """Map a run display name to its ``MODALITY_BUNDLES`` key.

    Bundles are stored under the lower-cased name with "+" and spaces replaced
    by "_" (e.g. "Face+Hands" -> "face_hands"), so looking up the display name
    directly never matches. A name that already is a key is returned unchanged.
    """
    if run_name in MODALITY_BUNDLES:
        return run_name
    return run_name.lower().replace("+", "_").replace(" ", "_")


def extract_original_id(path):
    """Return the image ID used to match ROI crops with full-frame images.

    The ID is the first three "_"-separated tokens of the file stem, or the
    whole stem when it has fewer than three tokens.
    Expected crop name: ``img_123_driver_uuid.jpg`` -> ``img_123_driver``.
    NOTE(review): full-frame file names must produce the same ID for the
    matching to find common images - confirm against the actual naming.
    """
    fname = Path(path).stem  # e.g., "img_123_driver_uuid"
    parts = fname.split("_")
    if len(parts) >= 3:
        return "_".join(parts[:3])  # img_123_driver
    return fname


def resolve_comparison_image(bundle, vis_path):
    """Return the first existing location of ``vis_path`` for ``bundle``, else None.

    Tried in order: ``data_root / roi_variant / vis_path`` (ROI bundles only),
    ``data_root / vis_path``, then ``vis_path`` as given.
    """
    root = Path(bundle["data_root"])
    candidates = []
    if bundle["roi_variant"]:
        candidates.append(root / bundle["roi_variant"] / vis_path)
    candidates.append(root / vis_path)
    candidates.append(Path(vis_path))
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


for roi_name, ctrl_name in COMPARISONS:
    roi_key = resolve_bundle_key(roi_name)
    ctrl_key = resolve_bundle_key(ctrl_name)
    if roi_key not in MODALITY_BUNDLES or ctrl_key not in MODALITY_BUNDLES:
        print(f"‚ö†Ô∏è Skipping {roi_name} vs {ctrl_name}: one or both not loaded")
        continue
    
    roi_bundle = MODALITY_BUNDLES[roi_key]
    ctrl_bundle = MODALITY_BUNDLES[ctrl_key]
    
    print(f"\n{'='*80}")
    print(f"üéØ COMPARING: {roi_name} (ROI) vs {ctrl_name} (Full-frame control)")
    print(f"{'='*80}")
    
    # Find common image IDs (by extracting original image ID from paths)
    roi_df = roi_bundle["preds_df"].copy()
    ctrl_df = ctrl_bundle["preds_df"].copy()
    
    roi_df["orig_id"] = roi_df["vis_path"].apply(extract_original_id)
    ctrl_df["orig_id"] = ctrl_df["vis_path"].apply(extract_original_id)
    
    # Find common IDs
    common_ids = set(roi_df["orig_id"]) & set(ctrl_df["orig_id"])
    print(f"üìä Common image IDs: {len(common_ids)}")
    
    if len(common_ids) < N_COMPARISON_SAMPLES:
        print(f"‚ö†Ô∏è Not enough common IDs for comparison")
        continue
    
    # Filter to common IDs. Explicit copies so the column assignment below
    # writes to an independent frame instead of a slice of roi_df / ctrl_df.
    roi_common = roi_df[roi_df["orig_id"].isin(common_ids)].copy()
    ctrl_common = ctrl_df[ctrl_df["orig_id"].isin(common_ids)].copy()
    
    # Add correctness
    roi_common["correct"] = roi_common["pred"] == roi_common["label"]
    ctrl_common["correct"] = ctrl_common["pred"] == ctrl_common["label"]
    
    # Sample images for comparison
    for cat in SAMPLE_CATEGORIES:
        is_correct = (cat == "correct")
        
        # Get IDs where ROI model was correct/wrong (one row per image ID so
        # the same image cannot be drawn twice)
        roi_filtered = roi_common[roi_common["correct"] == is_correct].drop_duplicates(subset="orig_id")
        if len(roi_filtered) < N_COMPARISON_SAMPLES:
            print(f"‚ö†Ô∏è Not enough {cat} samples for {roi_name}")
            continue
        
        sample_ids = roi_filtered.sample(min(N_COMPARISON_SAMPLES, len(roi_filtered)))["orig_id"].tolist()
        
        # Create side-by-side figure; squeeze=False keeps axes 2-D for any width
        fig, axes = plt.subplots(2, len(sample_ids), figsize=(4*len(sample_ids), 8), squeeze=False)
        # Hide every frame up front so skipped columns do not show empty axes.
        for ax in axes.ravel():
            ax.axis("off")
        
        for col_idx, orig_id in enumerate(sample_ids):
            # Get ROI row
            roi_row = roi_filtered[roi_filtered["orig_id"] == orig_id].iloc[0]
            # Get Control row with same orig_id
            ctrl_rows = ctrl_common[ctrl_common["orig_id"] == orig_id]
            if len(ctrl_rows) == 0:
                continue
            ctrl_row = ctrl_rows.iloc[0]
            
            # The true class differs per column, so it is shown in each
            # column title rather than once in the figure title.
            true_class = CLASS_NAMES.get(roi_row["label"], f"c{roi_row['label']}")
            
            # ROI Grad-CAM (top row)
            ax_roi = axes[0, col_idx]
            roi_img_path = resolve_comparison_image(roi_bundle, roi_row["vis_path"])
            if roi_img_path is not None:
                viz_roi, _ = roi_bundle["generate_gradcam"](roi_img_path)
                ax_roi.imshow(viz_roi)
            ax_roi.set_title(f"{roi_name}\nT:{true_class[:6]} P:{CLASS_NAMES.get(roi_row['pred'], '?')[:6]} ({roi_row['confidence']:.2f})", fontsize=9)
            
            # Control Grad-CAM (bottom row) - uses full-frame image
            ax_ctrl = axes[1, col_idx]
            ctrl_img_path = resolve_comparison_image(ctrl_bundle, ctrl_row["vis_path"])
            if ctrl_img_path is not None:
                viz_ctrl, _ = ctrl_bundle["generate_gradcam"](ctrl_img_path)
                ax_ctrl.imshow(viz_ctrl)
            ax_ctrl.set_title(f"{ctrl_name}\nP:{CLASS_NAMES.get(ctrl_row['pred'], '?')[:6]} ({ctrl_row['confidence']:.2f})", fontsize=9)
        
        # Super title with category
        status = "‚úÖ ROI Correct" if is_correct else "‚ùå ROI Wrong"
        fig.suptitle(f"{roi_name} vs {ctrl_name} | {status}", fontweight="bold", fontsize=12)
        plt.tight_layout()
        
        # Save
        out_dir = Path(OUT_ROOT) / "gradcam" / "control_comparison"
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{roi_name}_vs_{ctrl_name}_{cat}.png"
        plt.savefig(out_path, dpi=150, bbox_inches="tight")
        plt.show()
        print(f"üíæ Saved to {out_path}")

print("\n‚úÖ Control comparison complete!")


## ✅ Grad-CAM Complete!

**Outputs saved to Drive:**
- `OUT_ROOT/gradcam/{model_tag}/` — Per-category Grad-CAM galleries
- `OUT_ROOT/gradcam/confusions/` — Confusion-pair visualizations
- `OUT_ROOT/gradcam/control_comparison/` — ROI vs Control side-by-side comparisons

**Use these figures in your thesis to:**
- Show what the model focuses on when correct (face, hands, posture)
- Identify shortcuts (looking at background, identity features)
- Explain why certain confusions happen
- **Demonstrate ROI effect:** Control comparison figures show attention on the same image IDs

**5-Run Control Analysis Key Findings:**
- Compare `Face+Hands` vs `Ctrl-FHSub` on same IDs → isolates ROI cropping effect
- Full-frame controls may attend to background/distractors
- ROI crops should attend more to face/hands regions
- Use these comparisons to justify (or question) ROI preprocessing
