
# 🚀 Colab Setup — **CNNs-distracted-driving** (hardcoded + config-aware)

This version is **simplified and hardcoded** for your repo and URL, and it **respects your `src/ddriver/config.py`**.
- Repo name fixed to **`CNNs-distracted-driving`**
- Repo URL fixed to **`https://github.com/ClaudiaCPach/CNNs-distracted-driving`**
- Uses your `config.py` convention: when running in Colab, we **set env vars** (`DRIVE_PATH`, `DATASET_ROOT`, `OUT_ROOT`, `CKPT_ROOT`, `FAST_DATA`) so your code reads correct paths via `ddriver.config`.
- Optional `FAST_DATA` at `/content/data` for faster I/O (if you later copy data there).

> Run cells **top → bottom** the first time. Re-run step **3) Clone or update the repo** to pull new commits after you push.


In [None]:

# 🔧 0) (Optional) quick GPU check
# `!` runs this line in a shell (IPython/Colab magic). If `nvidia-smi` is missing
# or exits non-zero (e.g. no GPU attached), the `||` branch prints a note instead.
!nvidia-smi || echo "No GPU detected ‚Äî CPU runtime is okay for setup steps."


In [None]:

# 🔧 1) Fixed configuration: repository coordinates + Google Drive layout.
# Plain strings consumed by the later cells (git clone, env vars, paths).
import os

# --- Repository (hardcoded; no name inference) ------------------------------
REPO_DIRNAME = "CNNs-distracted-driving"
REPO_URL = f"https://github.com/ClaudiaCPach/{REPO_DIRNAME}"
BRANCH = "main"
PROJECT_ROOT = f"/content/{REPO_DIRNAME}"  # checkout location inside the Colab VM

# --- Persistent storage on Google Drive -------------------------------------
DRIVE_PATH = "/content/drive/MyDrive/TFM"
DRIVE_DATA_ROOT = f"{DRIVE_PATH}/data"  # holds auc.distracted.driver.dataset_v2
OUT_ROOT = f"{DRIVE_PATH}/outputs"
CKPT_ROOT = f"{DRIVE_PATH}/checkpoints"

# --- Optional fast, ephemeral workspace on the VM's local disk --------------
FAST_DATA = "/content/data"  # rsync target for faster I/O

# Drive is the canonical dataset location to start with; later cells can
# switch DATASET_ROOT to FAST_DATA.
DATASET_ROOT = DRIVE_DATA_ROOT


In [None]:

# 🔌 2) Mount Google Drive under /content/drive.
# force_remount=False keeps an existing mount rather than forcing a fresh one.
from google.colab import drive

drive.mount(mountpoint="/content/drive", force_remount=False)
print("‚úÖ Drive mounted.")


In [None]:

# üìÅ 3) Clone or update the repo (no name inference ‚Äî all hardcoded)
import os, subprocess

def sh(cmd):
    """Run *cmd* through /bin/bash, echoing it first; raise if it exits non-zero.

    Args:
        cmd: A shell command line (may contain ``cd``, ``&&``, pipes, ...).

    Raises:
        RuntimeError: if the command's exit status is not 0.
    """
    # Flush the echo so it is emitted *before* the child's output: the child
    # inherits the raw stdout file descriptor and bypasses Python's buffer.
    print(f"\n$ {cmd}", flush=True)
    rc = subprocess.call(cmd, shell=True, executable="/bin/bash")
    if rc != 0:
        raise RuntimeError(f"Command failed with exit code {rc}: {cmd}")

# Fresh runtime: clone the configured branch. Existing checkout: fast-forward it
# by fetching, checking out BRANCH, and rebasing onto origin.
if not os.path.isdir(PROJECT_ROOT):
    print(f"‚¨áÔ∏è Cloning {REPO_URL} ‚Üí {PROJECT_ROOT}")
    sh(f"git clone --branch {BRANCH} {REPO_URL} {PROJECT_ROOT}")
else:
    print(f"üìÅ Repo already present at {PROJECT_ROOT}. Pulling latest on branch {BRANCH}...")
    sh(" && ".join([
        f"cd {PROJECT_ROOT}",
        f"git fetch origin {BRANCH}",
        f"git checkout {BRANCH}",
        f"git pull --rebase origin {BRANCH}",
    ]))

print("PROJECT_ROOT =", PROJECT_ROOT)


In [None]:

# üì¶ 4) Install the repo (editable) + requirements (uses pyproject.toml if present)
import os, subprocess

def sh(cmd):
    """Run *cmd* through /bin/bash, echoing it first; raise if it exits non-zero.

    Args:
        cmd: A shell command line (may contain ``cd``, ``&&``, pipes, ...).

    Raises:
        RuntimeError: if the command's exit status is not 0.
    """
    # Flush the echo so it is emitted *before* the child's output: the child
    # inherits the raw stdout file descriptor and bypasses Python's buffer.
    print(f"\n$ {cmd}", flush=True)
    rc = subprocess.call(cmd, shell=True, executable="/bin/bash")
    if rc != 0:
        raise RuntimeError(f"Command failed with exit code {rc}: {cmd}")

print("üîÑ Upgrading pip/setuptools/wheel...")
sh("python -m pip install --upgrade pip setuptools wheel")

has_pyproject = os.path.exists(os.path.join(PROJECT_ROOT, "pyproject.toml"))
if has_pyproject:
    print("üì¶ Editable install from pyproject.toml ...")
    sh(f"cd {PROJECT_ROOT} && pip install -e .")
else:
    print("‚ö†Ô∏è No pyproject.toml found. Skipping editable install.")

req_path = os.path.join(PROJECT_ROOT, "requirements.txt")
if os.path.exists(req_path):
    print("üìù Installing requirements.txt...")
    sh(f"pip install -r {req_path}")
else:
    print("‚ÑπÔ∏è No requirements.txt found ‚Äî continuing.")


In [None]:

# 🌳 5) Export the Colab paths as environment variables for ddriver.config.
# config.py reads these env vars and falls back to its own defaults when in Colab.
import os

colab_env = {
    "DRIVE_PATH": DRIVE_PATH,
    "DATASET_ROOT": DATASET_ROOT,
    "OUT_ROOT": OUT_ROOT,
    "CKPT_ROOT": CKPT_ROOT,
    "FAST_DATA": FAST_DATA,
}
os.environ.update(colab_env)

# Mirror the same KEY=value pairs into <repo>/.env (harmless in Colab; useful
# if some code calls load_dotenv()).
env_text = "".join(f"{key}={value}\n" for key, value in colab_env.items())
with open(os.path.join(PROJECT_ROOT, ".env"), "w") as env_file:
    env_file.write(env_text)

print("‚úÖ Environment variables set for ddriver.config")
print("\nSummary:")
for key in colab_env:
    print(f"{key} = {os.environ[key]}")


In [None]:

# ✅ 8) Import smoke test (uses your package + config.py)
import sys, os

# Make the repo root and src/ (where the ddriver package lives) importable.
# Append only missing entries so re-running this cell does not keep adding
# duplicate sys.path entries.
for _import_dir in (PROJECT_ROOT, os.path.join(PROJECT_ROOT, "src")):
    if _import_dir not in sys.path:
        sys.path.append(_import_dir)

try:
    import ddriver
    print("ddriver imported OK from:", ddriver.__file__)
    # Confirm ddriver.config picks up the Colab env vars set in step 5.
    try:
        from ddriver import config
        print("Loaded ddriver.config successfully.")
        # Echo the resolved paths from config (they are pathlib.Path objects)
        print("config.DATASET_ROOT =", config.DATASET_ROOT)
        print("config.OUT_ROOT     =", config.OUT_ROOT)
        print("config.CKPT_ROOT    =", config.CKPT_ROOT)
        print("config.FAST_DATA    =", config.FAST_DATA)
    except Exception as e:
        # Best-effort smoke test: report and keep going.
        print("Note: ddriver.config not imported:", e)
except Exception as e:
    print("‚ö†Ô∏è Import failed ‚Äî check package name/setup.")
    print(e)


# 📋 9) Generate Manifest and Split CSVs

This step creates the CSV files that tell your code where all the images are and which ones go to train/val/test.

**What this does:**
- Scans all your images in the dataset folder
- Creates a big list (manifest.csv) with info about every image
- Creates three smaller lists (train.csv, val.csv, test.csv) that say which images belong where
- Saves everything to your Google Drive so it's permanent

**Why we need this:**
- Your training code needs to know which images to use
- The manifest remembers which driver each image belongs to (for VAL split)
- The split CSVs organize images into train/val/test groups


In [None]:
# Run the manifest generator: scans the dataset and writes manifest.csv plus the
# train/val/test split CSVs under OUT_ROOT.

import os
import subprocess
import sys

# Keep the repo root importable in this kernel (without duplicating it on re-runs).
# This does NOT affect the child process below; its import path is set via env.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# The ddriver package lives under src/. Put src/ on the child's PYTHONPATH so the
# module resolves even if the editable install (step 4) was skipped.
src_dir = os.path.join(PROJECT_ROOT, "src")
child_env = dict(os.environ)
child_env["PYTHONPATH"] = os.pathsep.join(
    part for part in (src_dir, child_env.get("PYTHONPATH", "")) if part
)

# --write-split-lists means "also create train.csv, val.csv, test.csv files".
# Use the kernel's interpreter and an argv list (no shell string needed).
manifest_cmd = [sys.executable, "-m", "ddriver.data.manifest", "--write-split-lists"]

print("üî® Generating manifest and split CSVs...")
print(f"Running: {' '.join(manifest_cmd)}  (cwd={PROJECT_ROOT})\n")

result = subprocess.run(
    manifest_cmd,
    cwd=PROJECT_ROOT,
    env=child_env,
    capture_output=True,
    text=True,
)

# Show what happened
print(result.stdout)
if result.stderr:
    print("Warnings/Errors:")
    print(result.stderr)

if result.returncode == 0:
    print("\n‚úÖ Manifest and split CSVs generated successfully!")
    print(f"   Manifest: {os.environ['OUT_ROOT']}/manifests/manifest.csv")
    print(f"   Train split: {os.environ['OUT_ROOT']}/splits/train.csv")
    print(f"   Val split: {os.environ['OUT_ROOT']}/splits/val.csv")
    print(f"   Test split: {os.environ['OUT_ROOT']}/splits/test.csv")
else:
    print(f"\n‚ùå Error generating manifest (exit code {result.returncode})")
    raise RuntimeError("Manifest generation failed")


In [None]:
# Quick check: do the manifest and split CSVs exist? Report their row counts and
# columns, then preview a few manifest rows.

import os
import pandas as pd
from pathlib import Path

out_root = Path(os.environ['OUT_ROOT'])
manifest_path = out_root / "manifests" / "manifest.csv"
train_path = out_root / "splits" / "train.csv"
val_path = out_root / "splits" / "val.csv"
test_path = out_root / "splits" / "test.csv"

print("üìä Checking CSV files...\n")

manifest_df = None  # kept from the loop below so the manifest is read only once
for name, path in [("Manifest", manifest_path), ("Train", train_path), ("Val", val_path), ("Test", test_path)]:
    if path.exists():
        df = pd.read_csv(path)
        if name == "Manifest":
            manifest_df = df
        print(f"‚úÖ {name}: {len(df)} rows, columns: {list(df.columns)}")
    else:
        print(f"‚ùå {name}: File not found at {path}")

# Show a sample from the manifest, limited to the expected columns that are
# actually present (avoids a KeyError if the manifest schema differs).
if manifest_df is not None:
    print("\nüìÑ Sample from manifest (first 3 rows):")
    wanted = ['path', 'class_id', 'driver_id', 'camera', 'split']
    shown = [col for col in wanted if col in manifest_df.columns]
    missing = [col for col in wanted if col not in manifest_df.columns]
    print(manifest_df.head(3)[shown].to_string())
    if missing:
        print(f"   (columns not present in manifest: {missing})")


In [None]:
# Create a tiny balanced subset of train.csv for quick testing.
# Run this cell ONCE to create train_small.csv, then use it for fast experiments.

import pandas as pd
from pathlib import Path
from ddriver import config

N_PER_CLASS = 20   # images kept per class
SUBSET_SEED = 42   # fixed seed -> the same subset on every run

train_csv = Path(config.OUT_ROOT) / "splits" / "train.csv"
train_small_csv = Path(config.OUT_ROOT) / "splits" / "train_small.csv"

print(f"Reading {train_csv}...")
df = pd.read_csv(train_csv)

# Shuffle (seeded) before taking the first N rows per class so the subset is not
# just the first rows of the CSV, which may all come from the same few
# drivers/cameras if the CSV is ordered that way. sort_index() restores CSV order.
small = (
    df.sample(frac=1.0, random_state=SUBSET_SEED)
    .groupby("class_id")
    .head(N_PER_CLASS)
    .sort_index()
)

counts = small["class_id"].value_counts().sort_index()
print(f"Original train.csv: {len(df)} images")
# Report the real per-class count instead of assuming exactly 10 full classes.
print(f"Small subset: {len(small)} images (up to {N_PER_CLASS} per class, {counts.size} classes)")
short = counts[counts < N_PER_CLASS]
if not short.empty:
    print(f"   Note: classes with fewer than {N_PER_CLASS} images: {short.to_dict()}")
print("\nClass distribution in small subset:")
print(counts)

small.to_csv(train_small_csv, index=False)
print(f"\n‚úÖ Saved to {train_small_csv}")

### ⚡️ Tiny-train option

Set `USE_TINY_SPLIT = True` in the training cell below to replace the heavy
`train.csv` with the quick `train_small.csv` (20 images per class). Validation
and test splits stay full so you still see realistic metrics.

Run the "Create a tiny balanced subset" cell once per Drive setup before
enabling this flag.


# 🧪 10) Test dataset.py and datamod.py

Now let's make sure the code that loads images actually works!

**What we're testing:**
1. **dataset.py** - Can it load a single image and give us the right info?
2. **datamod.py** - Can it create data loaders that give us batches of images?

**Why test this:**
- If these don't work, training will fail
- Better to catch problems now than later
- We want to see that images load correctly and labels are right


## 🔍 MediaPipe Crop Quality Audit

**How to run a FAST audit (recommended):**
1. Run the "Copy crops to /content" cell further below (it copies the crops from Drive to the fast local SSD)
2. Run the audit cell (it auto-detects the local copy and uses it)

**Two modes:**
- **Full mode**: Uses `detection_metadata_{variant}.csv` (has face/hand detection info)
- **Lite mode**: Uses `manifest_{variant}.csv` (infers fallback from crop dimensions)

**What you get:**
- Numeric stats: fallback rates, ROI area/aspect distributions
- Breakdowns by class/camera/split
- Visual grids: "worst suspects" (tiny crops, fallbacks, extreme aspects)
- Per-class sample grids

**Path conventions:** All CSVs store relative paths. At runtime, paths are resolved using `config.OUT_ROOT` or `config.FAST_DATA` depending on where the crops are.


In [None]:
# 🔍 Run the MediaPipe crop-quality audit.
# Chooses the crop location (local /content copy preferred over Drive) and the
# audit mode (full = detection metadata CSV, lite = manifest CSV) automatically.

import matplotlib.pyplot as plt
from ddriver import config
from ddriver.data.mediapipe_audit import generate_audit_report, get_crop_root

VARIANT = "face_hands"  # must match the variant you extracted

crop_root = get_crop_root(prefer_fast=True)  # prefers FAST_DATA when available
print(f"üìÅ Using crop root: {crop_root}")

# First look next to the crops; if neither CSV is there, fall back to Drive.
metadata_csv, manifest_csv = (
    crop_root.parent / f"detection_metadata_{VARIANT}.csv",
    crop_root.parent / f"manifest_{VARIANT}.csv",
)
if not (metadata_csv.exists() or manifest_csv.exists()):
    mediapipe_dir = config.OUT_ROOT / "mediapipe"
    metadata_csv = mediapipe_dir / f"detection_metadata_{VARIANT}.csv"
    manifest_csv = mediapipe_dir / f"manifest_{VARIANT}.csv"

# Reports always go to Drive so they persist across runtimes.
audit_output = config.OUT_ROOT / "mediapipe" / "audit" / VARIANT

# Full mode needs the detection metadata; otherwise use the manifest (lite mode).
if metadata_csv.exists():
    print(f"‚úÖ Found detection metadata: {metadata_csv}")
    csv_source = {"metadata_csv": metadata_csv}
elif manifest_csv.exists():
    print(f"‚ö†Ô∏è Using manifest (lite mode): {manifest_csv}")
    csv_source = {"manifest_csv": manifest_csv}
else:
    raise FileNotFoundError(
        f"Neither metadata nor manifest found. Run extraction first.\n"
        f"  Checked: {metadata_csv}\n"
        f"  Checked: {manifest_csv}"
    )

audit_result = generate_audit_report(
    **csv_source,
    crop_root=crop_root,
    output_dir=audit_output,
    variant=VARIANT,
    n_samples=25,
    save_figures=True,
    show_figures=True,
)

# Expose the results to the following cells.
stats, breakdowns, lite_mode, crop_root = (
    audit_result["stats"],
    audit_result["breakdowns"],
    audit_result["lite_mode"],
    audit_result["crop_root"],
)

print(f"\n‚úÖ Audit complete! Outputs saved to: {audit_output}")


In [None]:
# 📊 Print the audit's summary statistics (uses `stats` / `lite_mode` from the audit cell).

banner = "=" * 60
print(banner)
print("üìä DETECTION SUMMARY STATS")
if lite_mode:
    print("   [LITE MODE - face/hand detection info not available]")
print(banner)
print(f"Total images processed: {stats['total_images']}")

# Face/hand detection info exists only in full mode.
has_detection_info = not lite_mode
if has_detection_info and "face_detected_pct" in stats:
    print("\nüéØ Detection rates:")
    print(f"   Face detected: {stats['face_detected_pct']:.1f}%")
    print(f"   Hands: 0={stats['hands_0_pct']:.1f}%, 1={stats['hands_1_pct']:.1f}%, 2={stats['hands_2_pct']:.1f}%")

print(f"\n‚ö†Ô∏è  Fallback to full frame: {stats['fallback_count']} ({stats['fallback_pct']:.1f}%)")
print(f"   Fallback reasons: {stats['fallback_reasons']}")

print("\nüìê ROI statistics:")
print(f"   Area fraction: mean={stats['roi_area_frac_mean']:.3f}, std={stats['roi_area_frac_std']:.3f}")
print(f"   Area percentiles: 5%={stats['roi_area_frac_p5']:.3f}, 25%={stats['roi_area_frac_p25']:.3f}, 50%={stats['roi_area_frac_median']:.3f}")
print(f"   Aspect ratio: mean={stats['roi_aspect_mean']:.3f}, min={stats['roi_aspect_min']:.3f}, max={stats['roi_aspect_max']:.3f}")

if has_detection_info and "detection_used_distribution" in stats:
    print("\nüè∑Ô∏è  Detection types used:")
    total_images = stats['total_images']
    for detection_type, n_used in stats['detection_used_distribution'].items():
        print(f"   {detection_type}: {n_used} ({100 * n_used / total_images:.1f}%)")


In [None]:
# 📋 Print the class / camera / split breakdown tables from the audit cell's `breakdowns`.
import pandas as pd

# (breakdown key, header, message when missing, blank line after the table?)
breakdown_sections = [
    ("class_id", "üìã BREAKDOWN BY CLASS", "‚ö†Ô∏è Class breakdown not available", True),
    ("camera", "üìã BREAKDOWN BY CAMERA", "‚ö†Ô∏è Camera breakdown not available", True),
    ("split", "üìã BREAKDOWN BY SPLIT", "‚ö†Ô∏è Split breakdown not available", False),
]

for section_key, header, missing_msg, blank_after in breakdown_sections:
    if section_key not in breakdowns:
        print(missing_msg)
        continue
    print(header)
    print("-" * 100)
    print(breakdowns[section_key].to_string(index=False))
    if blank_after:
        print()


In [None]:
# 🖼️ List the audit grid images saved on Drive (and optionally re-display one).
# Note: the grids were already shown inline when you ran the audit cell above.

from IPython.display import display, Image as IPImage
from ddriver import config

VARIANT = "face_hands"
audit_output = config.OUT_ROOT / "mediapipe" / "audit" / VARIANT

print("üìÅ Saved grids location:", audit_output)
print("   (These were already displayed inline during the audit)\n")

# Grids produced in both audit modes.
grids = [
    ("grid_area_small.png", "üî¨ Smallest ROI crops"),
    ("grid_fallback.png", "‚ö†Ô∏è Fallback to full frame"),
    ("grid_aspect_extreme.png", "üìê Extreme aspect ratios"),
]
# `lite_mode` is set by the audit cell. In a fresh runtime (audit not re-run) it
# is undefined; instead of failing with NameError, also list the full-mode grids
# and let the existence check below report any that are missing.
if not globals().get("lite_mode", False):
    grids.extend([
        ("grid_no_hands.png", "üë§ Face detected but no hands"),
        ("grid_one_hand.png", "‚úã Only one hand detected"),
    ])

for filename, title in grids:
    grid_path = audit_output / filename
    if grid_path.exists():
        print(f"‚úÖ {title}: {grid_path.name}")
    else:
        print(f"‚ö†Ô∏è {title}: not found")

# Uncomment below to re-display a specific grid:
# display(IPImage(filename=str(audit_output / "grid_area_small.png"), width=900))


In [None]:
# 🏷️ Check which per-class sample grids the audit saved on Drive.
# (They were already displayed inline when the audit cell ran.)

from IPython.display import display, Image as IPImage
from ddriver import config

VARIANT = "face_hands"
audit_output = config.OUT_ROOT / "mediapipe" / "audit" / VARIANT

# Human-readable names for class IDs 0-9.
CLASS_LABELS = {
    0: "c0 - Safe driving",
    1: "c1 - Texting (right)",
    2: "c2 - Phone (right)",
    3: "c3 - Texting (left)",
    4: "c4 - Phone (left)",
    5: "c5 - Radio",
    6: "c6 - Drinking",
    7: "c7 - Reaching behind",
    8: "c8 - Hair/makeup",
    9: "c9 - Talking to passenger",
}

print("üìÅ Per-class grids saved at:", audit_output)
print("   (Already displayed inline during audit)\n")

for class_id in range(10):
    label = CLASS_LABELS.get(class_id, f"Class {class_id}")
    found = (audit_output / f"grid_class_{class_id}.png").exists()
    print(f"‚úÖ {label}" if found else f"‚ö†Ô∏è {label}: not found")

# Uncomment to re-display specific class grids:
# for class_id in [0, 1, 7]:  # Adjust class IDs as needed
#     grid_path = audit_output / f"grid_class_{class_id}.png"
#     if grid_path.exists():
#         print(f"\nüè∑Ô∏è {CLASS_LABELS.get(class_id, f'Class {class_id}')}")
#         display(IPImage(filename=str(grid_path), width=900))


In [None]:
# Test 1: can dataset.py (AucDriverDataset) load a single image from the VAL split?

import os
from pathlib import Path

from ddriver.data.dataset import AucDriverDataset
from torchvision import transforms as T

# CSV locations under OUT_ROOT (exported as an env var in step 5)
manifest_csv = Path(os.environ['OUT_ROOT']) / "manifests" / "manifest.csv"
val_split_csv = Path(os.environ['OUT_ROOT']) / "splits" / "val.csv"

print("üß™ Test 1: Testing AucDriverDataset (dataset.py)")
print(f"   Manifest: {manifest_csv}")
print(f"   Using Val split: {val_split_csv}\n")

try:
    # Only convert to a tensor: no augmentation, so we see the raw image.
    simple_transforms = T.ToTensor()

    val_dataset = AucDriverDataset(
        manifest_csv=manifest_csv,
        split_csv=val_split_csv,
        transforms=simple_transforms
    )

    print(f"‚úÖ Dataset created! It has {len(val_dataset)} images in VAL split")

    # Try to load the first image
    print("\nüìñ Loading first image from VAL split...")
    sample = val_dataset[0]

    print(f"‚úÖ Image loaded successfully!")
    print(f"   Image shape: {sample['image'].shape} (should be [3, height, width])")
    print(f"   Label: {sample['label']} (should be 0-9)")
    print(f"   Driver ID: {sample['driver_id']} (VAL should have driver IDs)")
    print(f"   Camera: {sample['camera']} (should be 'cam1' or 'cam2')")
    # str() so truncation works whether the dataset returns a str or a Path.
    print(f"   Path: {str(sample['path'])[:80]}...")  # Show first 80 chars

    # Check that label is valid (0-9)
    if 0 <= sample['label'] <= 9:
        print(f"   ‚úÖ Label is valid (0-9)")
    else:
        print(f"   ‚ùå Label {sample['label']} is NOT in range 0-9!")

    # Check that VAL has driver IDs
    if sample['driver_id'] is not None:
        print(f"   ‚úÖ VAL split has driver ID (as expected)")
    else:
        print(f"   ‚ö†Ô∏è  VAL split missing driver ID (might be okay if this image wasn't in your DRIVER_RANGES)")

    print("\n‚úÖ Test 1 PASSED: dataset.py works!")

except Exception as e:
    print(f"\n‚ùå Test 1 FAILED: {e}")
    import traceback
    traceback.print_exc()
    raise


# 🧵 11) Full pipeline (train → predict → metrics)

Now that data loading is working, these next cells show how to:
1. Register the model you want (e.g., `efficientnet_b0` from timm, as in the cell below)
2. Run training from the command line helper
3. Generate predictions from a checkpoint
4. Evaluate metrics and save all logs to Drive

> You can change the `RUN_TAG`, model name, epochs, etc. in the code below.


In [None]:
# Register models you want to use (run once per runtime)
# This example uses timm's convnext_tiny.

!pip -q install timm

from ddriver.models import registry

registry.register_timm_backbone("efficientnet_b0")
print("Available models:", registry.available_models()[:10])


In [None]:
# Install the pinned MediaPipe stack: mediapipe 0.10.14, protobuf<5 and
# opencv-python-headless<4.11 (presumably pinned for mutual compatibility —
# confirm before bumping any of them). Run this before the extraction cell below.
!pip install "mediapipe==0.10.14" "protobuf<5" "opencv-python-headless<4.11"

In [None]:
# üß≠ Generate MediaPipe ROI crops (face, hands, face+hands)
# Run once per runtime/variant. Produces new manifest/splits under OUT_ROOT/mediapipe.
!pip -q install mediapipe opencv-python-headless

import subprocess
from pathlib import Path

VARIANT = "face_hands"  # choose: face | hands | face_hands
OUTPUT_ROOT = Path(OUT_ROOT) / "mediapipe"
manifest_csv = Path(OUT_ROOT) / "manifests" / "manifest.csv"
splits_root = Path(OUT_ROOT) / "splits"

extract_cmd = f"""
cd {PROJECT_ROOT}
python -m src.ddriver.data.mediapipe_extract \
  --manifest {manifest_csv} \
  --splits-root {splits_root} \
  --dataset-root {DATASET_ROOT} \
  --output-root {OUTPUT_ROOT} \
  --variant {VARIANT} \
  --max-side 720 \
  --model-complexity 2 \
  --min-detection-area-frac 0.05 \
  --min-area-frac 0.10 \
  --min-aspect 0.20 \
  --pad-frac 0.20 \
  --face-extra-down-frac 0.35 \
  --overwrite
"""

print("Running MediaPipe extraction for variant:", VARIANT)
print(extract_cmd)
proc = subprocess.Popen(
    extract_cmd,
    shell=True,
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
if proc.stdout is None:
    raise RuntimeError("Extraction process has no stdout pipe.")
for line in proc.stdout:
    print(line, end="")
proc.wait()
if proc.returncode != 0:
    raise RuntimeError("MediaPipe extraction failed. Check logs above.")



## 🎯 YOLO-World ROI Extraction (Alternative to MediaPipe)

YOLO-World uses open-vocabulary detection to find faces and hands without custom training.
This is an alternative to the MediaPipe pipeline above.

**Advantages over MediaPipe:**
- Better detection accuracy for occluded/partial views
- Confidence scores for filtering
- Faster inference on GPU

**Choose ONE pipeline:** Either run MediaPipe extraction OR YOLO extraction, not both.
The training cell below lets you pick which pipeline's crops to use (`USE_MEDIAPIPE` vs `USE_YOLO`).

In [None]:
# üéØ Generate YOLO-World ROI crops (face, hands, face+hands)
# Alternative to MediaPipe - uses open-vocabulary detection.
# Run once per runtime/variant. Produces new manifest/splits under OUT_ROOT/yolo.

!pip -q install ultralytics

import subprocess
from pathlib import Path

# ===== CONFIGURATION =====
VARIANT = "face_hands"  # choose: face | hands | face_hands
YOLO_OUTPUT_ROOT = Path(OUT_ROOT) / "yolo"
manifest_csv = Path(OUT_ROOT) / "manifests" / "manifest.csv"
splits_root = Path(OUT_ROOT) / "splits"

# ===== TEST MODE OPTIONS (toggle these!) =====
# Option 1: Use train_small.csv for quick testing (~200 images)
TEST_MODE = True  # Set False for full extraction
SAMPLE_CSV = Path(OUT_ROOT) / "splits" / "train_small.csv"  # Small balanced subset

# Option 2: Limit to first N images (even faster for debugging)
LIMIT = None  # Set to e.g. 50 for super quick test, None for no limit

# Build command
sample_flag = f"--sample-csv {SAMPLE_CSV}" if TEST_MODE and SAMPLE_CSV.exists() else ""
limit_flag = f"--limit {LIMIT}" if LIMIT else ""

extract_cmd = f"""
cd {PROJECT_ROOT}
python -m src.ddriver.data.yolo_extract \
  --manifest {manifest_csv} \
  --splits-root {splits_root} \
  --dataset-root {DATASET_ROOT} \
  --output-root {YOLO_OUTPUT_ROOT} \
  --variant {VARIANT} \
  --model-size m \
  --confidence 0.15 \
  --min-detection-area-frac 0.03 \
  --min-area-frac 0.08 \
  --min-aspect 0.20 \
  --pad-frac 0.20 \
  {sample_flag} \
  {limit_flag} \
  --overwrite
"""

if TEST_MODE:
    print("‚ö° TEST MODE: Using small sample for quick testing")
    print(f"   Sample CSV: {SAMPLE_CSV}")
    if LIMIT:
        print(f"   Limit: {LIMIT} images")
else:
    print("ü™µ FULL MODE: Processing all images")

print(f"\nRunning YOLO-World extraction for variant: {VARIANT}")
print(extract_cmd)
proc = subprocess.Popen(
    extract_cmd,
    shell=True,
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
if proc.stdout is None:
    raise RuntimeError("Extraction process has no stdout pipe.")
for line in proc.stdout:
    print(line, end="")
proc.wait()
if proc.returncode != 0:
    raise RuntimeError("YOLO extraction failed. Check logs above.")


## 🚚 Copy YOLO crops to /content (optional, faster I/O)

Use this if training/audit from Drive is slow. It copies the generated YOLO crops/CSVs into `/content/data/yolo/<variant>` and updates paths for the current runtime.


In [None]:
# 🚚 Copy YOLO crops and their CSVs from Drive to /content for faster I/O
import os, shutil
from pathlib import Path

# Set your variant to the one you extracted already
YOLO_VARIANT = "face_hands"  # face | hands | face_hands
SRC_ROOT = Path(OUT_ROOT) / "yolo"
SRC_VARIANT_DIR = SRC_ROOT / YOLO_VARIANT
DST_ROOT = Path("/content/data/yolo") / YOLO_VARIANT

if not SRC_VARIANT_DIR.exists():
    raise FileNotFoundError(f"Source YOLO folder not found: {SRC_VARIANT_DIR}\nRun the YOLO extraction cell first.")

print(f"Copying YOLO crops from {SRC_VARIANT_DIR} -> {DST_ROOT}")
DST_ROOT.mkdir(parents=True, exist_ok=True)

# Mirror the crop tree. A file is skipped only when the local copy already
# matches the Drive copy in size and mtime (copy2 preserves mtime). After a
# re-extraction, changed crops are copied again instead of staying stale while
# the CSVs below are refreshed.
file_count = 0
skipped_count = 0
for src_dir, _, files in os.walk(SRC_VARIANT_DIR):
    rel_dir = Path(src_dir).relative_to(SRC_VARIANT_DIR)
    dst_dir = DST_ROOT / rel_dir
    dst_dir.mkdir(parents=True, exist_ok=True)
    for fname in files:
        src_path = Path(src_dir) / fname
        dst_path = dst_dir / fname
        if dst_path.exists():
            src_stat = src_path.stat()
            dst_stat = dst_path.stat()
            if (src_stat.st_size == dst_stat.st_size
                    and abs(src_stat.st_mtime - dst_stat.st_mtime) < 1.0):
                skipped_count += 1
                continue
        shutil.copy2(src_path, dst_path)
        file_count += 1
print(f"   Copied {file_count} image files ({skipped_count} already up to date)")

# Copy manifest/split CSVs that live one level above the variant folder
csv_names = [
    f"manifest_{YOLO_VARIANT}.csv",
    f"train_{YOLO_VARIANT}.csv",
    f"val_{YOLO_VARIANT}.csv",
    f"test_{YOLO_VARIANT}.csv",
    f"detection_metadata_{YOLO_VARIANT}.csv",  # for auditing
]
for fname in csv_names:
    src_csv = SRC_ROOT / fname
    if not src_csv.exists():
        print(f"   ‚ö†Ô∏è Skipping {fname} (not found)")
        continue
    dst_csv = DST_ROOT.parent / fname
    shutil.copy2(src_csv, dst_csv)
    print(f"   ‚úÖ Copied {fname}")

# Point env vars for this runtime to the local YOLO copy
os.environ["YOLO_ROOT_LOCAL"] = str(DST_ROOT.parent)
print(f"\n‚úÖ Copy complete! Set USE_YOLO=True in training cell.")
print(f"   YOLO_ROOT_LOCAL = {os.environ['YOLO_ROOT_LOCAL']}")


## 🔍 YOLO Crop Quality Audit

Quick stats and visual inspection of YOLO-World crops. Uses `detection_metadata_{variant}.csv`.


In [None]:
# 🔍 YOLO Crop Quality Audit: summary stats from detection_metadata_{variant}.csv
import os
import pandas as pd
import numpy as np
from pathlib import Path
from ddriver import config

VARIANT = "face_hands"  # must match the variant you extracted

# Prefer the local /content copy registered by the copy cell (YOLO_ROOT_LOCAL).
# Require the variable to be non-empty: Path("") is Path("."), which always
# exists, so an unset variable would otherwise pick the current directory
# instead of falling back to Drive.
yolo_root_env = os.environ.get("YOLO_ROOT_LOCAL", "")
if yolo_root_env and Path(yolo_root_env).exists():
    yolo_root = Path(yolo_root_env)
else:
    yolo_root = Path(OUT_ROOT) / "yolo"

metadata_csv = yolo_root / f"detection_metadata_{VARIANT}.csv"
if not metadata_csv.exists():
    raise FileNotFoundError(f"Detection metadata not found: {metadata_csv}\nRun YOLO extraction first.")

print(f"üìÅ Loading metadata from: {metadata_csv}")
df = pd.read_csv(metadata_csv)

# Summary stats (all percentages below divide by n_total)
n_total = len(df)
if n_total == 0:
    raise ValueError(f"Detection metadata is empty: {metadata_csv}")
n_fallback = df["fallback_to_full"].sum()
n_face = (df["face_count"] > 0).sum()
n_hands = (df["hand_count"] > 0).sum()
n_face_and_hands = ((df["face_count"] > 0) & (df["hand_count"] > 0)).sum()
# Mean confidence over images where a detection was made (confidence > 0)
avg_face_conf = df.loc[df["face_confidence"] > 0, "face_confidence"].mean()
avg_hand_conf = df.loc[df["hand_confidence"] > 0, "hand_confidence"].mean()

print("=" * 60)
print("üìä YOLO DETECTION SUMMARY")
print("=" * 60)
print(f"Total images: {n_total}")
print(f"\nüéØ Detection rates:")
print(f"   Face detected: {n_face} ({100*n_face/n_total:.1f}%)")
print(f"   Hands detected: {n_hands} ({100*n_hands/n_total:.1f}%)")
print(f"   Both face+hands: {n_face_and_hands} ({100*n_face_and_hands/n_total:.1f}%)")
print(f"\nüìà Confidence scores:")
print(f"   Avg face confidence: {avg_face_conf:.3f}" if not np.isnan(avg_face_conf) else "   Avg face confidence: N/A")
print(f"   Avg hand confidence: {avg_hand_conf:.3f}" if not np.isnan(avg_hand_conf) else "   Avg hand confidence: N/A")
print(f"\n‚ö†Ô∏è  Fallback to full frame: {n_fallback} ({100*n_fallback/n_total:.1f}%)")

# Fallback reason breakdown
fallback_df = df[df["fallback_to_full"]]
if len(fallback_df) > 0:
    print(f"   Fallback reasons:")
    for reason, count in fallback_df["fallback_reason"].value_counts().items():
        print(f"      - {reason}: {count} ({100*count/n_total:.1f}%)")

# ROI stats (for non-fallback images)
non_fallback = df[~df["fallback_to_full"]]
if len(non_fallback) > 0:
    print(f"\nüìê ROI statistics (non-fallback only, n={len(non_fallback)}):")
    print(f"   Raw detection area: mean={non_fallback['raw_detection_area_frac'].mean():.3f}, "
          f"std={non_fallback['raw_detection_area_frac'].std():.3f}")
    print(f"   Final ROI area: mean={non_fallback['roi_area_frac'].mean():.3f}, "
          f"std={non_fallback['roi_area_frac'].std():.3f}")
    print(f"   Aspect ratio: mean={non_fallback['roi_aspect'].mean():.3f}, "
          f"min={non_fallback['roi_aspect'].min():.3f}, max={non_fallback['roi_aspect'].max():.3f}")


In [None]:
# 📋 Summarize YOLO detections per camera and per class (uses `df` from the audit cell).

def _detection_breakdown(frame, group_col):
    """Per-group fallback count/%, mean ROI area, and face/hand detection rates (%)."""
    table = frame.groupby(group_col).agg({
        "fallback_to_full": ["sum", "mean"],
        "roi_area_frac": "mean",
        "face_count": lambda x: (x > 0).mean(),
        "hand_count": lambda x: (x > 0).mean(),
    }).round(3)
    table.columns = ["fallback_count", "fallback_pct", "mean_roi_area", "face_rate", "hand_rate"]
    # The three ratio columns are shown as percentages with one decimal.
    for ratio_col in ("fallback_pct", "face_rate", "hand_rate"):
        table[ratio_col] = (table[ratio_col] * 100).round(1)
    return table


print("\nüìã BREAKDOWN BY CAMERA")
print("-" * 80)
camera_stats = _detection_breakdown(df, "camera")
print(camera_stats.to_string())

print("\nüìã BREAKDOWN BY CLASS")
print("-" * 80)
class_stats = _detection_breakdown(df, "class_id")
print(class_stats.to_string())


In [None]:
# üñºÔ∏è Visual Sample Grid - YOLO Crops
import matplotlib.pyplot as plt
import cv2

def show_sample_grid(df_subset, title, crop_root, n_samples=12, n_cols=4, seed=42):
    """Display a grid of randomly sampled crops from a DataFrame subset.

    Args:
        df_subset: Rows with ``cropped_path``, ``class_id``, ``roi_area_frac`` and
            ``fallback_to_full`` columns.
        title: Figure title.
        crop_root: Path that each ``cropped_path`` is relative to.
        n_samples: Maximum number of rows to draw (capped at ``len(df_subset)``).
        n_cols: Number of grid columns.
        seed: ``random_state`` passed to ``DataFrame.sample`` so grids are reproducible.
    """
    if len(df_subset) == 0:
        # plt.subplots(0, n_cols) would raise; nothing to draw anyway.
        print(f"(no rows to show for: {title})")
        return

    samples = df_subset.sample(n=min(n_samples, len(df_subset)), random_state=seed)
    n_rows = (len(samples) + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid needs no special case.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (_, row) in zip(axes, samples.iterrows()):
        crop_path = crop_root / row["cropped_path"]
        if not crop_path.exists():
            ax.text(0.5, 0.5, "Not found", ha="center", va="center")
        else:
            img = cv2.imread(str(crop_path))
            if img is None:
                # cv2.imread returns None (it does not raise) for corrupt/unreadable files.
                ax.text(0.5, 0.5, "Unreadable", ha="center", va="center")
            else:
                ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                label = f"c{int(row['class_id'])} | area={row['roi_area_frac']:.2f}"
                if row["fallback_to_full"]:
                    label += " [FALLBACK]"
                ax.set_title(label, fontsize=9)
        ax.axis("off")

    # Hide empty subplots
    for ax in axes[len(samples):]:
        ax.axis("off")

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()

# Show samples: fallbacks, smallest ROIs, by class
crop_root = yolo_root  # Use the root determined in the audit cell

# Derive both subsets from `df` here so this cell depends only on `df` and `yolo_root`
# (it previously reused a `non_fallback` variable computed in a different cell).
fallback_samples = df[df["fallback_to_full"]]
non_fallback = df[~df["fallback_to_full"]]

# Fallback samples
if len(fallback_samples) > 0:
    show_sample_grid(fallback_samples, "‚ö†Ô∏è Fallback to Full Frame Samples", crop_root)

# Smallest ROI samples (non-fallback)
smallest = non_fallback.nsmallest(12, "roi_area_frac")
if len(smallest) > 0:
    show_sample_grid(smallest, "üî¨ Smallest ROI Crops (non-fallback)", crop_root)

# Per-class samples: up to 8 random crops per class (groupby yields classes in sorted order)
print("\nüè∑Ô∏è Sample crop per class:")
for class_id, class_df in df.groupby("class_id"):
    show_sample_grid(class_df, f"Class {int(class_id)} Samples", crop_root, n_samples=8, n_cols=4)


## 🔀 Hybrid ROI Extraction (RetinaFace + MediaPipe Hands)

**Best accuracy option!** Uses specialized models:
- **RetinaFace**: State-of-the-art face detection (handles occlusion, angles). The extraction cell installs it via the `insightface` package, which runs on ONNX Runtime (no TensorFlow).
- **MediaPipe Hands**: Google's dedicated hand model (much better than Holistic)

This typically gives the lowest fallback rate and best hand detection.


In [None]:
# üîÄ Generate Hybrid ROI crops (InsightFace + MediaPipe Hands)
# Best accuracy option - uses specialized models for face and hands.
# InsightFace uses ONNX (no TensorFlow dependency!)
!pip -q install insightface onnxruntime mediapipe

import subprocess
from pathlib import Path

# ===== CONFIGURATION =====
VARIANT = "face_hands"  # choose: face | hands | face_hands
HYBRID_OUTPUT_ROOT = Path(OUT_ROOT) / "hybrid"
manifest_csv = Path(OUT_ROOT) / "manifests" / "manifest.csv"
splits_root = Path(OUT_ROOT) / "splits"

# ===== AUTO-DETECT LOCAL vs DRIVE IMAGES =====
# If you ran the "copy + compress dataset" cells (46/47), images will be in /content/data
# We auto-detect and use local images for faster I/O, falling back to Drive if not available
LOCAL_DATASET_ROOT = Path("/content/data/auc.distracted.driver.dataset_v2")
DRIVE_DATASET_ROOT = Path(DATASET_ROOT)

if LOCAL_DATASET_ROOT.exists() and any(LOCAL_DATASET_ROOT.iterdir()):
    EFFECTIVE_DATASET_ROOT = LOCAL_DATASET_ROOT
    print(f"üöÄ FAST MODE: Using local images from {LOCAL_DATASET_ROOT}")
else:
    EFFECTIVE_DATASET_ROOT = DRIVE_DATASET_ROOT
    print(f"üìÅ Using images from Drive: {DRIVE_DATASET_ROOT}")
    print("   üí° Tip: Run cells 46/47 first to copy images to /content for faster extraction!")

# ===== TEST MODE OPTIONS (toggle these!) =====
# Option 1: Use train_small.csv for quick testing (~200 images)
TEST_MODE = True  # Set False for full extraction
SAMPLE_CSV = Path(OUT_ROOT) / "splits" / "train_small.csv"  # Small balanced subset

# Option 2: Limit to first N images (even faster for debugging)
LIMIT = None  # Set to e.g. 50 for super quick test, None for no limit

# Build command
sample_flag = f"--sample-csv {SAMPLE_CSV}" if TEST_MODE and SAMPLE_CSV.exists() else ""
limit_flag = f"--limit {LIMIT}" if LIMIT else ""

extract_cmd = f"""
cd {PROJECT_ROOT}
python -m src.ddriver.data.hybrid_extract \
  --manifest {manifest_csv} \
  --splits-root {splits_root} \
  --dataset-root {EFFECTIVE_DATASET_ROOT} \
  --output-root {HYBRID_OUTPUT_ROOT} \
  --variant {VARIANT} \
  --min-detection-area-frac 0.03 \
  --min-area-frac 0.08 \
  --min-aspect 0.20 \
  --pad-frac 0.20 \
  {sample_flag} \
  {limit_flag} \
  --overwrite
"""

if TEST_MODE:
    print("‚ö° TEST MODE: Using small sample for quick testing")
    print(f"   Sample CSV: {SAMPLE_CSV}")
    if LIMIT:
        print(f"   Limit: {LIMIT} images")
else:
    print("ü™µ FULL MODE: Processing all images")

print(f"\nRunning Hybrid extraction (InsightFace + MediaPipe Hands) for variant: {VARIANT}")
print(extract_cmd)
proc = subprocess.Popen(
    extract_cmd,
    shell=True,
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
if proc.stdout is None:
    raise RuntimeError("Extraction process has no stdout pipe.")
for line in proc.stdout:
    print(line, end="")
proc.wait()
if proc.returncode != 0:
    raise RuntimeError("Hybrid extraction failed. Check logs above.")


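## 🚚 Copy Hybrid crops to /content (optional, faster I/O)

Use this if training or auditing from Drive is slow. It copies the crops to the local SSD under `/content/data/hybrid/<variant>` and the CSVs to `/content/data/hybrid/`. It also sets `HYBRID_ROOT_LOCAL` so later cells pick up the local copy.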


In [None]:
# 🚚 Copy Hybrid crops to /content for faster I/O
import os, shutil
from pathlib import Path

# Variant must match the one produced by the Hybrid extraction cell
HYBRID_VARIANT = "face_hands"  # face | hands | face_hands
SRC_ROOT = Path(OUT_ROOT) / "hybrid"
SRC_VARIANT_DIR = SRC_ROOT / HYBRID_VARIANT
DST_ROOT = Path("/content/data/hybrid") / HYBRID_VARIANT


def _mirror_missing_files(src_tree, dst_tree):
    """Copy every file under src_tree into dst_tree with the same layout.

    Files that already exist at the destination are skipped. Returns the number of files copied.
    """
    copied = 0
    for current_dir, _subdirs, filenames in os.walk(src_tree):
        target_dir = dst_tree / Path(current_dir).relative_to(src_tree)
        target_dir.mkdir(parents=True, exist_ok=True)
        for name in filenames:
            target = target_dir / name
            if target.exists():
                continue
            shutil.copy2(Path(current_dir) / name, target)
            copied += 1
    return copied


if not SRC_VARIANT_DIR.exists():
    raise FileNotFoundError(f"Source Hybrid folder not found: {SRC_VARIANT_DIR}\nRun the Hybrid extraction cell first.")

print(f"Copying Hybrid crops from {SRC_VARIANT_DIR} -> {DST_ROOT}")
DST_ROOT.mkdir(parents=True, exist_ok=True)

file_count = _mirror_missing_files(SRC_VARIANT_DIR, DST_ROOT)
print(f"   Copied {file_count} image files")

# Manifest / split / detection-metadata CSVs sit next to (not inside) the variant folder.
csv_names = [
    f"{prefix}_{HYBRID_VARIANT}.csv"
    for prefix in ("manifest", "train", "val", "test", "detection_metadata")
]
for csv_name in csv_names:
    source_csv = SRC_ROOT / csv_name
    if source_csv.exists():
        shutil.copy2(source_csv, DST_ROOT.parent / csv_name)
        print(f"   ‚úÖ Copied {csv_name}")
    else:
        print(f"   ‚ö†Ô∏è Skipping {csv_name} (not found)")

# Later cells read HYBRID_ROOT_LOCAL to find this local copy.
os.environ["HYBRID_ROOT_LOCAL"] = str(DST_ROOT.parent)
print(f"\n‚úÖ Copy complete! Set USE_HYBRID=True in training cell.")
print(f"   HYBRID_ROOT_LOCAL = {os.environ['HYBRID_ROOT_LOCAL']}")


## 🔍 Hybrid Crop Quality Audit

Quick detection statistics plus per-camera and per-class breakdowns for the Hybrid crops (InsightFace + MediaPipe Hands).


In [None]:
# 🔍 Hybrid Crop Quality Audit
import os
import pandas as pd
import numpy as np
from pathlib import Path

VARIANT = "face_hands"  # must match the variant you extracted

# Prefer the local /content copy when the copy cell has exported HYBRID_ROOT_LOCAL.
# Test the env var itself: Path("") is Path("."), which always exists, so the old
# `Path(os.environ.get(..., "")).exists()` check picked the working directory when unset.
_hybrid_env = os.environ.get("HYBRID_ROOT_LOCAL", "")
hybrid_root_local = Path(_hybrid_env) if _hybrid_env else None
if hybrid_root_local is not None and hybrid_root_local.exists():
    hybrid_root = hybrid_root_local
else:
    hybrid_root = Path(OUT_ROOT) / "hybrid"

metadata_csv = hybrid_root / f"detection_metadata_{VARIANT}.csv"
if not metadata_csv.exists():
    raise FileNotFoundError(f"Detection metadata not found: {metadata_csv}\nRun Hybrid extraction first.")

print(f"üìÅ Loading metadata from: {metadata_csv}")
df = pd.read_csv(metadata_csv)

n_total = len(df)
if n_total == 0:
    raise ValueError(f"{metadata_csv} has no rows - nothing to audit.")

# Per-image detection masks
face_found = df["face_count"] > 0
left_found = df["left_hand_detected"]
right_found = df["right_hand_detected"]
any_hand = left_found | right_found

# Summary stats
n_fallback = df["fallback_to_full"].sum()
n_face = face_found.sum()
n_left_hand = left_found.sum()
n_right_hand = right_found.sum()
n_both_hands = (left_found & right_found).sum()
n_any_hands = any_hand.sum()
# Row-wise "face AND at least one hand". This previously used the scalar `n_any_hands > 0`,
# which just counted every face whenever any hand existed anywhere in the dataset.
n_face_and_hands = (face_found & any_hand).sum()
avg_face_conf = df.loc[df["face_confidence"] > 0, "face_confidence"].mean()
avg_left_conf = df.loc[df["left_hand_confidence"] > 0, "left_hand_confidence"].mean()
avg_right_conf = df.loc[df["right_hand_confidence"] > 0, "right_hand_confidence"].mean()

print("=" * 60)
print("üìä HYBRID DETECTION SUMMARY (RetinaFace + MediaPipe Hands)")
print("=" * 60)
print(f"Total images: {n_total}")
print(f"\nüéØ Detection rates:")
print(f"   Face detected (RetinaFace): {n_face} ({100*n_face/n_total:.1f}%)")
print(f"   Left hand (MediaPipe): {n_left_hand} ({100*n_left_hand/n_total:.1f}%)")
print(f"   Right hand (MediaPipe): {n_right_hand} ({100*n_right_hand/n_total:.1f}%)")
print(f"   Both hands: {n_both_hands} ({100*n_both_hands/n_total:.1f}%)")
print(f"   Any hand: {n_any_hands} ({100*n_any_hands/n_total:.1f}%)")
print(f"   Face + any hand: {n_face_and_hands} ({100*n_face_and_hands/n_total:.1f}%)")
print(f"\nüìà Confidence scores:")
# .mean() over an empty selection is NaN (no detections with confidence > 0).
print(f"   Avg face confidence: {avg_face_conf:.3f}" if not np.isnan(avg_face_conf) else "   Avg face confidence: N/A")
print(f"   Avg left hand confidence: {avg_left_conf:.3f}" if not np.isnan(avg_left_conf) else "   Avg left hand confidence: N/A")
print(f"   Avg right hand confidence: {avg_right_conf:.3f}" if not np.isnan(avg_right_conf) else "   Avg right hand confidence: N/A")
print(f"\n‚ö†Ô∏è  Fallback to full frame: {n_fallback} ({100*n_fallback/n_total:.1f}%)")

# Fallback reason breakdown
fallback_df = df[df["fallback_to_full"]]
if len(fallback_df) > 0:
    print(f"   Fallback reasons:")
    for reason, count in fallback_df["fallback_reason"].value_counts().items():
        print(f"      - {reason}: {count} ({100*count/n_total:.1f}%)")

# ROI stats (for non-fallback images)
non_fallback = df[~df["fallback_to_full"]]
if len(non_fallback) > 0:
    print(f"\nüìê ROI statistics (non-fallback only, n={len(non_fallback)}):")
    print(f"   Raw detection area: mean={non_fallback['raw_detection_area_frac'].mean():.3f}, "
          f"std={non_fallback['raw_detection_area_frac'].std():.3f}")
    print(f"   Final ROI area: mean={non_fallback['roi_area_frac'].mean():.3f}, "
          f"std={non_fallback['roi_area_frac'].std():.3f}")


In [None]:
# 📋 Hybrid Breakdown by Camera and Class
_HYBRID_PCT_COLUMNS = ("fallback_pct", "face_rate", "left_hand_rate", "right_hand_rate")


def _hybrid_group_stats(frame, key):
    """Aggregate Hybrid detection stats per value of `key`.

    fallback_pct and the three *_rate columns are percentages rounded to 1 decimal;
    fallback_count is a sum and mean_roi_area a mean (3 decimals).
    """
    grouped = frame.groupby(key).agg({
        "fallback_to_full": ["sum", "mean"],
        "roi_area_frac": "mean",
        "face_count": lambda counts: (counts > 0).mean(),
        "left_hand_detected": "mean",
        "right_hand_detected": "mean",
    }).round(3)
    grouped.columns = ["fallback_count", "fallback_pct", "mean_roi_area", "face_rate", "left_hand_rate", "right_hand_rate"]
    for pct_column in _HYBRID_PCT_COLUMNS:
        grouped[pct_column] = (grouped[pct_column] * 100).round(1)
    return grouped


print("\nüìã BREAKDOWN BY CAMERA")
print("-" * 80)
camera_stats = _hybrid_group_stats(df, "camera")
print(camera_stats.to_string())

print("\nüìã BREAKDOWN BY CLASS")
print("-" * 80)
class_stats = _hybrid_group_stats(df, "class_id")
print(class_stats.to_string())


## 🚂 11.1 Train a model (adjust these knobs)

- Choose a `RUN_TAG` so logs/checkpoints go into `TFM/checkpoints/runs/<tag>/...`
- For a dry run, set epochs/batch size to something small (e.g. 1 epoch, batch size 16) or set `USE_TINY_SPLIT = True`
- This command uses the CLI helper (`python -m src.ddriver.cli.train ...`)
- Logs + checkpoints are saved automatically to Google Drive


In [None]:
import os
import subprocess, textwrap, json, time, threading
from pathlib import Path

# EfficientNet-B0 run (change RUN_TAG for each experiment)
RUN_TAG = "effb0_noLabelSmoothingCORRECTED"   # change me for each experiment
MODEL_NAME = "efficientnet_b0"                  # must be registered above (timm)

# Training hyperparameters (EfficientNet-B0)
EPOCHS = 15
BATCH_SIZE = 32
NUM_WORKERS = 2
IMAGE_SIZE = 224
LR = 3e-4                        # per provided hyperparams
WEIGHT_DECAY = 0.0               # forwarded as --weight-decay
OPTIMIZER = "adam"               # forwarded as --optimizer
# NOTE(review): LR_DROP_EPOCH / LR_DROP_FACTOR are NOT forwarded to the training CLI below;
# the Sheet-logging cell only uses them as fallbacks when params.json lacks these keys.
LR_DROP_EPOCH = None             # no LR drop
LR_DROP_FACTOR = 0.1
LABEL_SMOOTHING = 0.0
USE_TINY_SPLIT = False

# ROI crop pipeline selection (pick ONE, set others to False; all False = full frames)
USE_MEDIAPIPE = False             # set True to use MediaPipe ROI crops
USE_YOLO = False                  # set True to use YOLO-World ROI crops
USE_HYBRID = True                 # set True to use Hybrid (RetinaFace + MediaPipe Hands) crops
ROI_VARIANT = "face_hands"        # face | hands | face_hands

# Validate only one pipeline is selected
active_pipelines = sum([USE_MEDIAPIPE, USE_YOLO, USE_HYBRID])
if active_pipelines > 1:
    raise ValueError("Pick ONE pipeline: set only one of USE_MEDIAPIPE, USE_YOLO, USE_HYBRID to True.")

# ROI pipelines keep <split>_<variant>.csv next to their crops; the *_ROOT_LOCAL env vars
# (set by the copy cells) point at the local /content copy when it exists.
if USE_HYBRID:
    hybrid_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
    manifest_csv = hybrid_root / f"manifest_{ROI_VARIANT}.csv"
    train_split = f"train_{ROI_VARIANT}.csv" if not USE_TINY_SPLIT else f"train_small_{ROI_VARIANT}.csv"
    val_csv = hybrid_root / f"val_{ROI_VARIANT}.csv"
    test_csv = hybrid_root / f"test_{ROI_VARIANT}.csv"
    train_csv = hybrid_root / train_split
    print(f"üîÄ Using Hybrid (RetinaFace + MediaPipe Hands) ROI variant: {ROI_VARIANT}")
    print(f"   hybrid_root = {hybrid_root}")
elif USE_YOLO:
    yolo_root = Path(os.environ.get("YOLO_ROOT_LOCAL", Path(OUT_ROOT) / "yolo"))
    manifest_csv = yolo_root / f"manifest_{ROI_VARIANT}.csv"
    train_split = f"train_{ROI_VARIANT}.csv" if not USE_TINY_SPLIT else f"train_small_{ROI_VARIANT}.csv"
    val_csv = yolo_root / f"val_{ROI_VARIANT}.csv"
    test_csv = yolo_root / f"test_{ROI_VARIANT}.csv"
    train_csv = yolo_root / train_split
    print(f"üéØ Using YOLO-World ROI variant: {ROI_VARIANT}")
    print(f"   yolo_root = {yolo_root}")
elif USE_MEDIAPIPE:
    mp_root = Path(os.environ.get("MEDIAPIPE_ROOT_LOCAL", Path(OUT_ROOT) / "mediapipe"))
    manifest_csv = mp_root / f"manifest_{ROI_VARIANT}.csv"
    train_split = f"train_{ROI_VARIANT}.csv" if not USE_TINY_SPLIT else f"train_small_{ROI_VARIANT}.csv"
    val_csv = mp_root / f"val_{ROI_VARIANT}.csv"
    test_csv = mp_root / f"test_{ROI_VARIANT}.csv"
    train_csv = mp_root / train_split
    print(f"üß≠ Using MediaPipe ROI variant: {ROI_VARIANT}")
    print(f"   mp_root = {mp_root}")
else:
    manifest_csv = Path(OUT_ROOT) / "manifests" / "manifest.csv"
    train_split = "train_small.csv" if USE_TINY_SPLIT else "train.csv"
    train_csv = Path(OUT_ROOT) / "splits" / train_split
    val_csv = Path(OUT_ROOT) / "splits" / "val.csv"
    test_csv = Path(OUT_ROOT) / "splits" / "test.csv"
    print("üì∑ Using full-frame images (no ROI cropping)")

if USE_TINY_SPLIT:
    print("‚ö° Using train_small.csv (20 imgs/class) for a quick smoke test.")
else:
    print("ü™µ Using full train.csv for a proper run.")

train_cmd = textwrap.dedent(f"""
cd {PROJECT_ROOT}
python -m src.ddriver.cli.train \
    --model-name {MODEL_NAME} \
    --epochs {EPOCHS} \
    --batch-size {BATCH_SIZE} \
    --num-workers {NUM_WORKERS} \
    --image-size {IMAGE_SIZE} \
    --lr {LR} \
    --weight-decay {WEIGHT_DECAY} \
    --optimizer {OPTIMIZER} \
    --label-smoothing {LABEL_SMOOTHING} \
    --out-tag {RUN_TAG} \
    --manifest-csv {manifest_csv} \
    --train-csv {train_csv} \
    --val-csv {val_csv} \
    --test-csv {test_csv}
""")

print("Running training command and streaming logs:\n", train_cmd)

proc = subprocess.Popen(
    train_cmd,
    shell=True,
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)

# Background GPU monitor (prints every 5 seconds until training exits)
_monitor_stop = threading.Event()


def _gpu_monitor():
    """Print GPU utilisation/memory every 5 s while `proc` runs; stop on the first query failure."""
    while proc.poll() is None and not _monitor_stop.is_set():
        try:
            stats = subprocess.check_output(
                [
                    "nvidia-smi",
                    "--query-gpu=utilization.gpu,memory.used,memory.total",
                    "--format=csv,nounits,noheader",
                ],
                text=True,
            ).strip()
        except (OSError, subprocess.CalledProcessError) as exc:
            # No GPU / no nvidia-smi: report once instead of every 5 s for the whole run.
            print("[GPU] Could not query nvidia-smi:", exc)
            return
        print(f"[GPU] util%, mem_used, mem_total :: {stats}")
        _monitor_stop.wait(5)  # sleeps 5 s, but wakes immediately once training is done


monitor_thread = threading.Thread(target=_gpu_monitor, daemon=True)
monitor_thread.start()

# Stream CLI stdout live
if proc.stdout is None:
    raise RuntimeError("Training process has no stdout pipe.")

for line in proc.stdout:
    print(line, end="")

proc.wait()
_monitor_stop.set()
monitor_thread.join(timeout=1)

# Without this check a failed run still printed "complete" and then reported the history
# of whatever older run folder happened to sort last.
if proc.returncode != 0:
    raise RuntimeError(f"Training failed (exit code {proc.returncode}). Check logs above.")

print("\n‚úÖ Training run complete!\n")

# --- Display every epoch's metrics so the notebook shows the learning curve ---
run_base = Path(CKPT_ROOT) / "runs" / RUN_TAG
all_runs = sorted(run_base.glob("*/"))
if not all_runs:
    raise FileNotFoundError(f"No run folders found under {run_base}")
latest_run = all_runs[-1]

history_path = latest_run / "history.json"
if not history_path.exists():
    raise FileNotFoundError(f"history.json not found in {latest_run}")


def _fmt4(value):
    """Format a metric with 4 decimals, or 'n/a' when the epoch did not record it."""
    return "n/a" if value is None else f"{value:.4f}"


history = json.loads(history_path.read_text()).get("history", [])
print(f"üìä Epoch metrics for run: {latest_run.name}")
for record in history:
    train_metrics = record.get("train") or {}
    val_metrics = record.get("val") or {}
    val_loss = val_metrics.get("loss")
    val_acc = val_metrics.get("accuracy")
    val_str = (
        f"val_loss={val_loss:.4f} acc={val_acc:.4f}"
        if val_loss is not None and val_acc is not None
        else "val_loss=‚Äî val_acc=‚Äî"
    )
    print(
        f"  Epoch {record.get('epoch', '?'):>2}: "
        f"train_loss={_fmt4(train_metrics.get('loss'))} acc={_fmt4(train_metrics.get('accuracy'))}  "
        f"{val_str}"
    )
print("")


## 📝 11.1a Log training summary to Google Sheet
Run this right after the training cell finishes; the logging code cell comes after the MediaPipe copy section below. It looks up the newest run under `CKPT_ROOT/runs/<RUN_TAG>`, grabs the best/final train + val accuracies, and logs the model/hyperparams so you can compare experiments before doing predictions or metrics.


## 🚚 Copy MediaPipe crops to /content (optional, faster I/O)
Use this if training from Drive is slow. It copies the generated MediaPipe crops into `/content/data/mediapipe/<variant>` and the manifest/split CSVs into `/content/data/mediapipe/`. It then sets `MEDIAPIPE_ROOT_LOCAL` for the current runtime, and the training cell picks that up automatically. It does not affect the original full-image copy cell.


In [None]:
import os, shutil
from pathlib import Path

# Variant must match the one you already extracted
MEDIAPIPE_VARIANT = "face_hands"  # face | hands | face_hands
SRC_ROOT = Path(OUT_ROOT) / "mediapipe"
SRC_VARIANT_DIR = SRC_ROOT / MEDIAPIPE_VARIANT
DST_ROOT = Path("/content/data/mediapipe") / MEDIAPIPE_VARIANT

if not SRC_VARIANT_DIR.exists():
    raise FileNotFoundError(f"Source mediapipe folder not found: {SRC_VARIANT_DIR}\nRun the extraction cell first.")

print(f"Copying MediaPipe crops from {SRC_VARIANT_DIR} -> {DST_ROOT}")
DST_ROOT.mkdir(parents=True, exist_ok=True)

# Mirror the crop tree; files already present locally are left untouched.
for walk_dir, _dirnames, walk_files in os.walk(SRC_VARIANT_DIR):
    local_dir = DST_ROOT / Path(walk_dir).relative_to(SRC_VARIANT_DIR)
    local_dir.mkdir(parents=True, exist_ok=True)
    for name in (f for f in walk_files if not (local_dir / f).exists()):
        shutil.copy2(Path(walk_dir) / name, local_dir / name)

# Manifest/split CSVs live one level above the variant folder; all four are required.
csv_names = [f"{kind}_{MEDIAPIPE_VARIANT}.csv" for kind in ("manifest", "train", "val", "test")]
for csv_name in csv_names:
    csv_src = SRC_ROOT / csv_name
    if not csv_src.exists():
        raise FileNotFoundError(csv_src)
    csv_dst = DST_ROOT.parent / csv_name
    shutil.copy2(csv_src, csv_dst)
    print(f"Copied {csv_src} -> {csv_dst}")

# Point env vars for this runtime to the local mediapipe copy
os.environ["MEDIAPIPE_ROOT_LOCAL"] = str(DST_ROOT.parent)
os.environ["MEDIAPIPE_VARIANT"] = MEDIAPIPE_VARIANT
print("\n".join([
    "\n‚úÖ Copy complete. Set USE_MEDIAPIPE=True and point mp_root to MEDIAPIPE_ROOT_LOCAL in the training cell:",
    "  mp_root = Path(os.environ['MEDIAPIPE_ROOT_LOCAL'])",
    "  manifest_csv = mp_root / f'manifest_{MEDIAPIPE_VARIANT}.csv'",
    "  train_csv    = mp_root / f'train_{MEDIAPIPE_VARIANT}.csv'",
    "  val_csv      = mp_root / f'val_{MEDIAPIPE_VARIANT}.csv'",
    "  test_csv     = mp_root / f'test_{MEDIAPIPE_VARIANT}.csv'",
]))



In [None]:
# üìù Training summary ‚Üí Google Sheet
!pip -q install gspread

import json
from pathlib import Path

import gspread
from google.colab import auth
import google.auth

auth.authenticate_user()
creds, _ = google.auth.default()
gc = gspread.authorize(creds)

TRAIN_SHEET_NAME = "TFM Train Logs"   # create this sheet/tab ahead of time
TRAIN_WORKSHEET = "Sheet1"

run_base = Path(CKPT_ROOT) / "runs" / RUN_TAG
all_runs = sorted(run_base.glob("*/"))
if not all_runs:
    raise FileNotFoundError(f"No run folders found under {run_base}")
latest_run = all_runs[-1]
print(f"Logging training summary for run folder: {latest_run}")

history_path = latest_run / "history.json"
if not history_path.exists():
    raise FileNotFoundError(f"history.json not found under {latest_run}")

history_records = json.loads(history_path.read_text()).get("history", [])
if not history_records:
    raise ValueError(f"history.json under {latest_run} has no records.")

params_path = latest_run / "params.json"
params = json.loads(params_path.read_text()) if params_path.exists() else {}

model_name = params.get("model_name", MODEL_NAME)
epochs_cfg = params.get("epochs", EPOCHS)
batch_cfg = params.get("batch_size", BATCH_SIZE)
lr_cfg = params.get("lr", LR)
lr_drop_epoch_cfg = params.get("lr_drop_epoch", LR_DROP_EPOCH)
lr_drop_factor_cfg = params.get("lr_drop_factor", LR_DROP_FACTOR)
image_size_cfg = params.get("image_size", IMAGE_SIZE)
num_workers_cfg = params.get("num_workers", NUM_WORKERS)
use_tiny_cfg = params.get("use_tiny_split", USE_TINY_SPLIT)


def _best_metric(records, split: str) -> tuple[dict, float | None]:
    best_epoch = None
    best_acc = None
    for rec in records:
        split_metrics = rec.get(split) or {}
        acc = split_metrics.get("accuracy")
        if acc is None:
            continue
        if best_acc is None or acc > best_acc:
            best_acc = acc
            best_epoch = rec.get("epoch")
    final_metrics = records[-1].get(split) or {}
    final_acc = final_metrics.get("accuracy")
    return {"epoch": best_epoch, "accuracy": best_acc}, final_acc


best_train, final_train = _best_metric(history_records, "train")
best_val, final_val = _best_metric(history_records, "val")


def _cell_or_blank(value):
    """Sheet cell value: '' for a missing value, otherwise the value unchanged."""
    return "" if value is None else value


def _rounded_or_blank(value):
    """Sheet cell value: '' for a missing value, otherwise rounded to 4 decimals."""
    return "" if value is None else round(value, 4)


# Column order of the "TFM Train Logs" sheet: identifiers, hyperparameters, then accuracies.
row = [
    RUN_TAG,
    latest_run.name,
    model_name,
    epochs_cfg,
    batch_cfg,
    lr_cfg,
    lr_drop_epoch_cfg,
    lr_drop_factor_cfg,
    image_size_cfg,
    num_workers_cfg,
    use_tiny_cfg,
    _cell_or_blank(best_train["epoch"]),
    _rounded_or_blank(best_train["accuracy"]),
    _cell_or_blank(best_val["epoch"]),
    _rounded_or_blank(best_val["accuracy"]),
    _rounded_or_blank(final_train),
    _rounded_or_blank(final_val),
]

worksheet_handle = gc.open(TRAIN_SHEET_NAME).worksheet(TRAIN_WORKSHEET)
ws = worksheet_handle
ws.append_row(row, value_input_option="USER_ENTERED")
print(f"Appended training summary for {latest_run.name} ‚úÖ")



In [None]:
# 🔄 Optional: copy + compress dataset subset → fast local SSD (/content/data)
# Re-encodes JPEGs once (quality 80, short side 320px) before landing in /content/data.

import importlib
import os
from pathlib import Path

from ddriver.data.fastcopy import CompressionSpec, copy_splits_with_compression

SRC_ROOT = Path(DRIVE_DATA_ROOT) / "auc.distracted.driver.dataset_v2"
DST_ROOT = Path(FAST_DATA) / "auc.distracted.driver.dataset_v2"

# Splits to copy, in this order: train, val, train_small (test has its own cell).
_splits_dir = Path(OUT_ROOT) / "splits"
split_csvs = {name: _splits_dir / f"{name}.csv" for name in ("train", "val", "train_small")}

compression_spec = CompressionSpec(
    target_short_side=320,  # still >= image_size + resize margin for training
    jpeg_quality=80,        # ImageNet-level compression, visually lossless
)

_copy_kwargs = dict(
    split_csvs=split_csvs,
    src_root=SRC_ROOT,
    dst_root=DST_ROOT,
    compression=compression_spec,
    skip_existing=True,  # re-running only processes files not yet copied
)
summary = copy_splits_with_compression(**_copy_kwargs)

print(
    f"\nüìâ FAST_DATA copy stats: processed {summary['processed']} of {summary['total']} files "
    f"(skipped {summary['skipped']} already present)."
)
print(f"Compressed dataset root: {summary['dst_root']}")

# Point this runtime (and any subprocess inheriting os.environ) at the local copy.
DATASET_ROOT = FAST_DATA
os.environ["DATASET_ROOT"] = str(DATASET_ROOT)
try:
    # Best effort: refresh ddriver.config in this kernel so it re-reads DATASET_ROOT.
    from ddriver import config as _ddriver_config
    importlib.reload(_ddriver_config)
    print("\n‚ö° Copy complete. DATASET_ROOT now points to the local FAST_DATA copy for this runtime:")
    print("   ddriver.config.DATASET_ROOT =", _ddriver_config.DATASET_ROOT)
except Exception as exc:
    print("\n‚ö° Copy complete and DATASET_ROOT env updated, but could not reload ddriver.config:", exc)
print("   (Re-run env summary if you want to rewrite .env, but training now uses /content/data.)")


In [None]:
# 🔄 Optional: copy + compress TEST split → /content/data (same settings as train/val)

import importlib
from pathlib import Path

from ddriver.data.fastcopy import CompressionSpec, copy_splits_with_compression

SRC_ROOT = Path(DRIVE_DATA_ROOT) / "auc.distracted.driver.dataset_v2"
DST_ROOT = Path(FAST_DATA) / "auc.distracted.driver.dataset_v2"

# Only the held-out test split; train/val/train_small are handled by the previous cell.
split_csvs = {"test": Path(OUT_ROOT) / "splits" / "test.csv"}

# Same re-encoding settings as the train/val copy.
compression_spec = CompressionSpec(target_short_side=320, jpeg_quality=80)

summary = copy_splits_with_compression(
    split_csvs=split_csvs,
    src_root=SRC_ROOT,
    dst_root=DST_ROOT,
    compression=compression_spec,
    skip_existing=True,
)

print(
    f"\nüìâ FAST_DATA test copy stats: processed {summary['processed']} of {summary['total']} "
    f"(skipped {summary['skipped']} already present)."
)
print(f"Compressed dataset root: {summary['dst_root']}")

# DATASET_ROOT was switched to FAST_DATA by the earlier cell; just refresh ddriver.config (best effort).
try:
    from ddriver import config as _ddriver_config
    importlib.reload(_ddriver_config)
    print("\n‚ö° Test copy complete. ddriver.config now sees:")
    print("   ddriver.config.DATASET_ROOT =", _ddriver_config.DATASET_ROOT)
except Exception as exc:
    print("\n‚ö° Test copy complete; config reload optional:", exc)

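## 📦 11.2 Pick a checkpoint file

This cell sets its own `RUN_TAG` and lists the run folders under `CKPT_ROOT/runs/<RUN_TAG>/`. It selects one of them via `RUN_IDX` (the newest by default) and picks a checkpoint inside it. The default is `best.pt`; set `CHECKPOINT_NAME = None` to take the last match among `epoch_*.pt`, `best.pt` and `last.pt`. The chosen path is stored in `LATEST_CKPT` for the prediction step.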


In [None]:
import json
import re
from pathlib import Path

RUN_TAG = "convnext_tiny_full_v1"  # pick the tag you want to inspect
#RUN_TAG = globals().get("RUN_TAG", "convnext_tiny_full_v1")  # reuse your latest training tag by default

run_base = Path(CKPT_ROOT) / "runs" / RUN_TAG
runs = sorted(run_base.glob("*/"))
if not runs:
    raise FileNotFoundError(f"No run folders found under {run_base}")

# ---- choose which run folder to use ----
RUN_IDX = -1          # -1 = newest, 0 = oldest, or any index from the printout below
print("Available runs:")
for idx, run_dir in enumerate(runs):
    print(f"  [{idx}] {run_dir.name}")
target_run = runs[RUN_IDX]
print(f"\nSelected run: {target_run}\n")


def _epoch_number(path):
    """Numeric epoch of an 'epoch_<N>.pt' file (-1 if the name doesn't match).

    Sorting by this keeps epoch_10 after epoch_9; a plain name sort is lexicographic.
    """
    match = re.fullmatch(r"epoch_(\d+)", path.stem)
    return int(match.group(1)) if match else -1


# ---- choose which checkpoint (epoch) inside that run ----
# Order: epoch_*.pt by epoch number, then best.pt, then last.pt (whichever exist).
checkpoints = sorted(target_run.glob("epoch_*.pt"), key=_epoch_number)
checkpoints += [p for p in (target_run / "best.pt", target_run / "last.pt") if p.exists()]

if not checkpoints:
    raise FileNotFoundError(f"No checkpoints found under {target_run}")

CHECKPOINT_NAME = "best.pt"  # or "last.pt", or None to take the last match
if CHECKPOINT_NAME:
    chosen_ckpt = target_run / CHECKPOINT_NAME
    if not chosen_ckpt.exists():
        raise FileNotFoundError(chosen_ckpt)
else:
    chosen_ckpt = checkpoints[-1]

LATEST_CKPT = chosen_ckpt
print("Using checkpoint:", LATEST_CKPT)

# The prediction cell builds the model from MODEL_NAME / IMAGE_SIZE, which otherwise come from
# the training cell and may belong to a different RUN_TAG than the one selected above.
# Sync them with this run's params.json when it records them.
_ckpt_params_file = target_run / "params.json"
if _ckpt_params_file.exists():
    _ckpt_params = json.loads(_ckpt_params_file.read_text())
    if "model_name" in _ckpt_params:
        MODEL_NAME = _ckpt_params["model_name"]
    if "image_size" in _ckpt_params:
        IMAGE_SIZE = _ckpt_params["image_size"]
    print(f"From params.json: MODEL_NAME={globals().get('MODEL_NAME')}, IMAGE_SIZE={globals().get('IMAGE_SIZE')}")
else:
    print(f"NOTE: {_ckpt_params_file} not found; prediction will use MODEL_NAME/IMAGE_SIZE from the training cell.")

## 🔮 11.3 Generate predictions CSV

- Uses the checkpoint above
- Choose which split to predict on (`val` or `test`)
- Saves CSV under `OUT_ROOT/preds/<split>/<out_tag>.csv`


In [None]:
import os
import subprocess
import textwrap
from pathlib import Path

PRED_SPLIT = "test"           # or "val"
if PRED_SPLIT not in ("val", "test"):
    raise ValueError(f"PRED_SPLIT must be 'val' or 'test', got {PRED_SPLIT!r}")
PRED_TAG = f"{RUN_TAG}_{PRED_SPLIT}"

# Pick the manifest + split CSV for whichever ROI pipeline the training cell selected.
# A flag that was never defined in this session counts as False.
if globals().get("USE_HYBRID"):
    hybrid_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
    manifest_pred = hybrid_root / f"manifest_{ROI_VARIANT}.csv"
    split_pred = hybrid_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
elif globals().get("USE_YOLO"):
    yolo_root = Path(os.environ.get("YOLO_ROOT_LOCAL", Path(OUT_ROOT) / "yolo"))
    manifest_pred = yolo_root / f"manifest_{ROI_VARIANT}.csv"
    split_pred = yolo_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
elif globals().get("USE_MEDIAPIPE"):
    mp_root = Path(os.environ.get("MEDIAPIPE_ROOT_LOCAL", Path(OUT_ROOT) / "mediapipe"))
    manifest_pred = mp_root / f"manifest_{ROI_VARIANT}.csv"
    split_pred = mp_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
else:
    manifest_pred = Path(OUT_ROOT) / "manifests" / "manifest.csv"
    split_pred = Path(OUT_ROOT) / "splits" / f"{PRED_SPLIT}.csv"

split_flag = f"--{PRED_SPLIT}-csv"  # --val-csv or --test-csv
manifest_flag = f"--manifest-csv {manifest_pred}"
split_flag_str = f"{split_flag} {split_pred}"

# Fail fast with a clear message instead of a traceback from inside the CLI.
for required_path in (Path(LATEST_CKPT), manifest_pred, split_pred):
    if not required_path.exists():
        raise FileNotFoundError(f"Required input for prediction not found: {required_path}")

predict_cmd = textwrap.dedent(f"""
cd {PROJECT_ROOT}
python -m src.ddriver.cli.predict \
    --model-name {MODEL_NAME} \
    --checkpoint {LATEST_CKPT} \
    --split {PRED_SPLIT} \
    --batch-size {BATCH_SIZE} \
    --num-workers {NUM_WORKERS} \
    --image-size {IMAGE_SIZE} \
    --out-tag {PRED_TAG} \
    {manifest_flag} \
    {split_flag_str}
""")

print("Running prediction command:\n", predict_cmd)
result = subprocess.run(predict_cmd, shell=True, text=True)
if result.returncode != 0:
    raise RuntimeError("Prediction command failed. See logs above.")
print("\n‚úÖ Predictions completed! Check OUT_ROOT/preds/")


## 📊 11.4 Evaluate metrics

- Uses `src/ddriver/eval/metrics.py` (run as `python -m src.ddriver.eval.metrics`)
- Reads the manifest + split CSV + predictions CSV
- Saves results under `OUT_ROOT/metrics/<tag>/<timestamp>/`
- Shows accuracy + macro F1 + per-driver/camera (optional)


In [None]:
import os
import subprocess
import textwrap
from pathlib import Path

# Manifest + split CSV for whichever ROI pipeline the training cell selected
# (a flag that was never defined in this session counts as False).
if globals().get("USE_HYBRID"):
    hybrid_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
    manifest_path = hybrid_root / f"manifest_{ROI_VARIANT}.csv"
    split_csv_path = hybrid_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
elif globals().get("USE_YOLO"):
    yolo_root = Path(os.environ.get("YOLO_ROOT_LOCAL", Path(OUT_ROOT) / "yolo"))
    manifest_path = yolo_root / f"manifest_{ROI_VARIANT}.csv"
    split_csv_path = yolo_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
elif globals().get("USE_MEDIAPIPE"):
    mp_root = Path(os.environ.get("MEDIAPIPE_ROOT_LOCAL", Path(OUT_ROOT) / "mediapipe"))
    manifest_path = mp_root / f"manifest_{ROI_VARIANT}.csv"
    split_csv_path = mp_root / f"{PRED_SPLIT}_{ROI_VARIANT}.csv"
else:
    manifest_path = Path(OUT_ROOT) / "manifests" / "manifest.csv"
    split_csv_path = Path(OUT_ROOT) / "splits" / f"{PRED_SPLIT}.csv"

# The prediction cell writes to the same location regardless of ROI pipeline.
preds_csv_path = Path(OUT_ROOT) / "preds" / PRED_SPLIT / f"{PRED_TAG}.csv"
if not preds_csv_path.exists():
    raise FileNotFoundError(f"Predictions CSV not found: {preds_csv_path}\nRun the prediction cell first.")
METRICS_TAG = PRED_TAG

metrics_cmd = textwrap.dedent(f"""
cd {PROJECT_ROOT}
python -m src.ddriver.eval.metrics \
    --manifest {manifest_path} \
    --split-csv {split_csv_path} \
    --predictions {preds_csv_path} \
    --out-tag {METRICS_TAG} \
    --per-driver \
    --per-camera
""")

print("Running metrics command:\n", metrics_cmd)
result = subprocess.run(metrics_cmd, shell=True, text=True)
if result.returncode != 0:
    raise RuntimeError("Metrics command failed. See logs above.")
print("\n‚úÖ Metrics saved under OUT_ROOT/metrics/")


### ✅ You're all set!

**What just happened:**
1. ✅ Mounted Google Drive
2. ✅ Cloned/updated your repo
3. ✅ Installed the package
4. ✅ Set up paths (works on Colab and Mac!)
5. ✅ Generated manifest.csv and train/val/test split CSVs
6. ✅ Tested that dataset.py can load images
7. ✅ Tested that datamod.py can create data loaders
8. ✅ (Optional) Registered a model + ran training → prediction → metrics pipeline

**Your CSVs are saved in Google Drive:**
- `OUT_ROOT/manifests/manifest.csv` - Big list of all images
- `OUT_ROOT/splits/train.csv` - Training images
- `OUT_ROOT/splits/val.csv` - Validation images (with driver IDs!)
- `OUT_ROOT/splits/test.csv` - Test images

**Next steps:**
- Adjust the training/prediction cells (epochs, batch size, tags) to run bigger experiments
- All paths use `ddriver.config` so it works on Colab and Mac
- Re-run **Clone/Update** cell after pushing new commits
- Optional: copy some data into `/content/data` to use `FAST_DATA` for speed


In [None]:
# ---- Colab cell: append metrics + params to Google Sheet ----
!pip -q install gspread

import json
from pathlib import Path

import gspread
from google.colab import auth
import google.auth

auth.authenticate_user()
creds, _ = google.auth.default()
gc = gspread.authorize(creds)

EVAL_SHEET_NAME = "TFM Eval Logs"   # create this sheet/tab ahead of time
EVAL_WORKSHEET = "Sheet1"

METRICS_TAG = (
    globals().get("METRICS_TAG")
    or globals().get("PRED_TAG")
    or "convnext_tiny_full_v1_val"
)  # match the --out-tag you used
metrics_root = Path(OUT_ROOT) / "metrics" / METRICS_TAG
runs = sorted(metrics_root.glob("*/"))
if not runs:
    raise FileNotFoundError(f"No metrics runs found under {metrics_root}")
latest_metrics = runs[-1]
print("Logging metrics folder:", latest_metrics)

def _read_json(path: Path, *, required: bool = True) -> dict:
    """Parse the JSON file at *path*.

    If the file does not exist, raise FileNotFoundError when *required* is
    true; otherwise return an empty dict so optional artifacts can be skipped.
    """
    if path.exists():
        return json.loads(path.read_text())
    if not required:
        return {}
    raise FileNotFoundError(f"Expected file missing: {path}")

metrics = _read_json(latest_metrics / "metrics.json")
inputs = _read_json(latest_metrics / "inputs.json", required=False)
params = _read_json(latest_metrics / "params.json", required=False)

overall = metrics.get("overall", {})
macro = overall.get("macro_avg", {})


def _sheet_metric(value):
    """Round a numeric metric to 4 decimals; return "" when it is missing.

    A blank cell keeps an absent metric from being logged as a false 0.0
    score, and avoids ``round(None)`` raising TypeError.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return ""
    return round(value, 4)


# Sheet columns: run folder, predictions CSV, split source, example count,
# accuracy, macro F1, params JSON. The params JSON is truncated to 500 chars,
# so that cell may not be parseable JSON for large parameter sets.
row = [
    str(latest_metrics),
    inputs.get("predictions", ""),
    inputs.get("split_source", ""),
    metrics.get("num_examples", ""),
    _sheet_metric(overall.get("accuracy")),
    _sheet_metric(macro.get("f1")),
    json.dumps(params, sort_keys=True)[:500],
]

ws = gc.open(EVAL_SHEET_NAME).worksheet(EVAL_WORKSHEET)
ws.append_row(row, value_input_option="USER_ENTERED")
print(f"Appended metrics run {latest_metrics.name} to {EVAL_SHEET_NAME}/{EVAL_WORKSHEET} ‚úÖ")

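### 📊 11.4a Visualize Confusion Matrix

Quick peek at where the model confuses classes using the most recent metrics run. The raw-count and row-normalized heatmaps are saved as PNGs in that run's folder, next to `metrics.json`.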


In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

METRICS_TAG = (
    globals().get("METRICS_TAG")
    or globals().get("PRED_TAG")
    or "convnext_tiny_full_v1_val"
)  # change if you used a different --out-tag
metrics_root = Path(OUT_ROOT) / "metrics" / METRICS_TAG
# Each metrics run is a sub-directory. Filter explicitly, because
# Path.glob("*/") only restricts matches to directories on Python >= 3.11.
runs = sorted(p for p in metrics_root.glob("*") if p.is_dir())
if not runs:
    raise FileNotFoundError(f"No metrics runs found under {metrics_root}")
# Assumes run-folder names sort chronologically (e.g. timestamps).
latest_metrics = runs[-1]
print("Reading confusion matrix from:", latest_metrics)

metrics = json.loads((latest_metrics / "metrics.json").read_text())
cm_info = metrics.get("confusion_matrix")
if not cm_info:
    raise ValueError("confusion_matrix missing from metrics.json")

labels = cm_info["rows_cols_labels"]
cm_df = pd.DataFrame(cm_info["matrix"], index=labels, columns=labels)


def _save_cm_heatmap(frame, fmt, title, out_path):
    """Draw *frame* as an annotated blue heatmap, save it to *out_path*, then show it."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(frame, annot=True, fmt=fmt, cmap="Blues")
    plt.title(title)
    plt.ylabel("True class")
    plt.xlabel("Predicted class")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.show()


counts_path = latest_metrics / "confusion_matrix_counts.png"
_save_cm_heatmap(cm_df, "d", f"Confusion matrix ‚Äì {METRICS_TAG}", counts_path)
print("Saved counts heatmap to", counts_path)

# Row-normalize so each true class sums to 1; an all-zero row is divided by 1
# instead of 0 to avoid NaNs.
cm_norm = cm_df.div(cm_df.sum(axis=1).replace(0, 1), axis=0)
norm_path = latest_metrics / "confusion_matrix_normalized.png"
_save_cm_heatmap(cm_norm, ".2f", f"Normalized confusion matrix ‚Äì {METRICS_TAG}", norm_path)
print("Saved normalized heatmap to", norm_path)


## 🔥 11.5 Grad-CAM Visualizations

**Grad-CAM (Gradient-weighted Class Activation Mapping)** shows which regions of the image the model focuses on when making predictions.

**Use cases:**
- **Full image models:** Verify the model looks at face/hands, not background
- **Hybrid crop models:** See which specific features (hand position, facial expression) matter most
- **Thesis comparison:** Visual evidence of WHY ROI cropping helps

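This cell generates:
1. A per-class grid of correctly classified samples with their Grad-CAM overlays (per-class attention patterns)
2. Grad-CAM overlays and raw attention heatmaps for a sample of misclassified predictions
3. A summary of correct vs. misclassified counts for the analyzed split
4. Saved PNGs under `OUT_ROOT/gradcam/<tag>/` for your thesis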


In [None]:
# üî• Grad-CAM Visualization
!pip -q install grad-cam

import json
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms

# ===== CONFIGURATION =====
# Use the same model/checkpoint from your training run
GRADCAM_TAG = globals().get("TRAIN_TAG") or globals().get("PRED_TAG") or "effb0_hybrid_face_hands"
MODEL_NAME = globals().get("MODEL_NAME") or "efficientnet_b0"  # efficientnet_b0, resnet18, convnext_tiny
SPLIT_TO_ANALYZE = "val"  # val or test
N_SAMPLES_PER_CLASS = 3  # How many samples per class to visualize
N_MISCLASSIFIED = 6  # How many misclassified examples to show
IMAGE_SIZE = 224

# ===== FIND CHECKPOINT =====
# Training runs are folders under CKPT_ROOT/runs/<tag>/. Filter to directories
# explicitly (Path.glob("*/") only does so on Python >= 3.11); the last folder
# in sorted order is taken as the latest run (assumes sortable timestamp names).
run_base = Path(CKPT_ROOT) / "runs" / GRADCAM_TAG
all_runs = sorted(p for p in run_base.glob("*") if p.is_dir())
if not all_runs:
    raise FileNotFoundError(f"No run folders found under {run_base}")
latest_run = all_runs[-1]
# Prefer the best checkpoint; fall back to the last one saved.
ckpt_path = latest_run / "best.pt"
if not ckpt_path.exists():
    ckpt_path = latest_run / "last.pt"
if not ckpt_path.exists():
    raise FileNotFoundError(f"Neither best.pt nor last.pt found in {latest_run}")
print(f"üìÅ Using checkpoint: {ckpt_path}")

# ===== LOAD MODEL =====
from ddriver.models.registry import get_model

# Load checkpoint to get num_classes
# The checkpoint is read as a dict with weights under "model_state_dict".
# num_classes comes from a top-level "classifier.weight" key if present, else 10.
# NOTE(review): the target layers below go through `model.backbone`, so head
# weights may be stored under a prefixed key and the fallback of 10 may always
# apply — confirm against ddriver.models.registry.
# NOTE(review): newer PyTorch defaults torch.load to weights_only=True; a
# checkpoint holding non-tensor objects may need weights_only=False (trusted files only).
ckpt = torch.load(ckpt_path, map_location="cpu")
num_classes = ckpt["model_state_dict"]["classifier.weight"].shape[0] if "classifier.weight" in ckpt["model_state_dict"] else 10

# Rebuild the architecture without pretrained weights, then load the trained ones.
model = get_model(MODEL_NAME, num_classes=num_classes, pretrained=False)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"‚úÖ Loaded model: {MODEL_NAME} with {num_classes} classes")

# ===== DETERMINE TARGET LAYER FOR GRAD-CAM =====
# Different architectures have different final conv layers
# Grad-CAM hooks the last feature block of the wrapped backbone.
if "efficientnet" in MODEL_NAME:
    target_layers = [model.backbone.features[-1]]
elif "resnet" in MODEL_NAME:
    target_layers = [model.backbone.layer4[-1]]
elif "convnext" in MODEL_NAME:
    target_layers = [model.backbone.features[-1]]
else:
    # Fallback: try to find last conv layer
    # (assumes the backbone's final child is its classification head — verify per architecture)
    target_layers = [list(model.backbone.children())[-2]]
print(f"üéØ Target layer for Grad-CAM: {target_layers[0].__class__.__name__}")

# ===== LOAD PREDICTIONS AND DATA =====
# Find the predictions CSV. The prediction cell writes
# OUT_ROOT/preds/<split>/<PRED_TAG>.csv, and PRED_TAG need not equal the
# training tag used above for the checkpoint. Try the GRADCAM_TAG name first,
# then fall back to PRED_TAG (confirm it came from the same run/checkpoint).
preds_csv = Path(OUT_ROOT) / "preds" / SPLIT_TO_ANALYZE / f"{GRADCAM_TAG}.csv"
if not preds_csv.exists() and globals().get("PRED_TAG"):
    preds_csv_alt = preds_csv.with_name(f"{globals()['PRED_TAG']}.csv")
    if preds_csv_alt.exists():
        print(f"Using predictions named after PRED_TAG: {preds_csv_alt}")
        preds_csv = preds_csv_alt
if not preds_csv.exists():
    raise FileNotFoundError(f"Predictions not found: {preds_csv}\nRun the prediction cell first!")

preds_df = pd.read_csv(preds_csv)
print(f"üìä Loaded {len(preds_df)} predictions from {preds_csv}")

# Determine data root (hybrid, yolo, mediapipe, or full images). Each ROI root
# honours the same *_ROOT_LOCAL env override used by the predict/metrics cells.
if "USE_HYBRID" in globals() and USE_HYBRID:
    data_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
elif "USE_YOLO" in globals() and USE_YOLO:
    data_root = Path(os.environ.get("YOLO_ROOT_LOCAL", Path(OUT_ROOT) / "yolo"))
elif "USE_MEDIAPIPE" in globals() and USE_MEDIAPIPE:
    data_root = Path(os.environ.get("MEDIAPIPE_ROOT_LOCAL", Path(OUT_ROOT) / "mediapipe"))
else:
    data_root = Path(DATASET_ROOT)
print(f"üìÅ Data root: {data_root}")

# ===== CLASS NAMES =====
# Display names keyed by integer class id 0..9 (list position == class id).
CLASS_NAMES = dict(enumerate([
    "Safe driving",
    "Texting (right)",
    "Talking (right)",
    "Texting (left)",
    "Talking (left)",
    "Operating radio",
    "Drinking",
    "Reaching behind",
    "Hair/makeup",
    "Talking to passenger",
]))

# ===== TRANSFORMS =====
# Square resize to IMAGE_SIZE, tensor conversion, then the standard ImageNet
# channel mean/std. Presumably this mirrors the eval transform used in
# training — confirm against the project's data pipeline.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def load_image(path):
    """Open *path* as RGB and return ``(model_input, overlay_rgb)``.

    ``model_input`` is ``transform(image)`` with a leading batch dimension
    added; ``overlay_rgb`` is the image resized to IMAGE_SIZE x IMAGE_SIZE as a
    numpy array scaled into [0, 1], ready for ``show_cam_on_image``.
    """
    rgb = Image.open(path).convert("RGB")
    batch = transform(rgb).unsqueeze(0)
    overlay = np.array(rgb.resize((IMAGE_SIZE, IMAGE_SIZE))) / 255.0
    return batch, overlay

# ===== GENERATE GRAD-CAM =====
cam = GradCAM(model=model, target_layers=target_layers)

def generate_gradcam(img_path, target_class=None):
    """Compute a Grad-CAM overlay for the image at *img_path*.

    Args:
        img_path: Path to the image file.
        target_class: Class index whose score should be explained. ``None``
            lets Grad-CAM use the model's top-scoring class for this image.

    Returns:
        ``(visualization, grayscale_cam)``: the RGB overlay produced by
        ``show_cam_on_image`` and the raw CAM for the single input image.
    """
    img_tensor, img_np = load_image(img_path)
    img_tensor = img_tensor.to(device)

    # Explain the requested class when one is given; otherwise the top class.
    targets = None
    if target_class is not None:
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        targets = [ClassifierOutputTarget(int(target_class))]

    grayscale_cam = cam(input_tensor=img_tensor, targets=targets)
    grayscale_cam = grayscale_cam[0, :]  # single image in the batch

    # Overlay on image
    visualization = show_cam_on_image(img_np.astype(np.float32), grayscale_cam, use_rgb=True)
    return visualization, grayscale_cam

# ===== 1. SAMPLE CORRECT PREDICTIONS PER CLASS =====
print("\nüé® Generating Grad-CAM for correct predictions per class...")

correct_df = preds_df[preds_df["label"] == preds_df["pred"]]
# Two panels per sample (image, Grad-CAM overlay). squeeze=False keeps `axes`
# 2-D even when num_classes == 1, so `axes[row, col]` indexing always works.
fig, axes = plt.subplots(
    num_classes,
    N_SAMPLES_PER_CLASS * 2,
    figsize=(N_SAMPLES_PER_CLASS * 6, num_classes * 3),
    squeeze=False,
)
# Hide every panel up front; classes with fewer than N_SAMPLES_PER_CLASS
# correct predictions then leave blank cells instead of empty framed axes.
for ax in axes.flat:
    ax.axis("off")

for class_id in range(num_classes):
    class_samples = correct_df[correct_df["label"] == class_id]
    samples = class_samples.sample(min(N_SAMPLES_PER_CLASS, len(class_samples)), random_state=42)

    for i, (_, row) in enumerate(samples.iterrows()):
        # Relative paths resolve against the active data root, falling back to
        # the full-image dataset root.
        img_path = data_root / row["path"] if not Path(row["path"]).is_absolute() else Path(row["path"])
        if not img_path.exists():
            img_path = Path(DATASET_ROOT) / row["path"]

        img_ax = axes[class_id, i * 2]
        cam_ax = axes[class_id, i * 2 + 1]
        if not img_path.exists():
            img_ax.text(0.5, 0.5, "Not found", ha="center", va="center")
            continue

        # Original image
        img = Image.open(img_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        img_ax.imshow(img)
        img_ax.set_title(f"{CLASS_NAMES.get(class_id, f'Class {class_id}')}", fontsize=8)

        # Grad-CAM
        viz, _ = generate_gradcam(img_path)
        cam_ax.imshow(viz)
        cam_ax.set_title(f"Grad-CAM (conf: {row['confidence']:.2f})", fontsize=8)

plt.suptitle(f"Grad-CAM: Correct Predictions per Class ({MODEL_NAME})", fontsize=14, y=1.02)
plt.tight_layout()

# Save
gradcam_dir = Path(OUT_ROOT) / "gradcam" / GRADCAM_TAG
gradcam_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(gradcam_dir / "correct_per_class.png", dpi=150, bbox_inches="tight")
plt.show()
print(f"üíæ Saved to {gradcam_dir / 'correct_per_class.png'}")

# ===== 2. MISCLASSIFIED EXAMPLES =====
print("\nüî¥ Generating Grad-CAM for misclassified examples...")

misclassified_df = preds_df[preds_df["label"] != preds_df["pred"]]
if len(misclassified_df) > 0:
    samples = misclassified_df.sample(min(N_MISCLASSIFIED, len(misclassified_df)), random_state=42)

    # One row per sample: image, overlay, raw heatmap. squeeze=False keeps
    # `axes` 2-D even for a single sample.
    fig, axes = plt.subplots(len(samples), 3, figsize=(12, len(samples) * 3), squeeze=False)

    for i, (_, row) in enumerate(samples.iterrows()):
        img_path = data_root / row["path"] if not Path(row["path"]).is_absolute() else Path(row["path"])
        if not img_path.exists():
            img_path = Path(DATASET_ROOT) / row["path"]

        if not img_path.exists():
            # Mark the row instead of leaving three empty framed axes.
            axes[i, 0].text(0.5, 0.5, "Not found", ha="center", va="center")
            for ax in axes[i]:
                ax.axis("off")
            continue

        # Original
        img = Image.open(img_path).convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
        axes[i, 0].imshow(img)
        axes[i, 0].set_title(f"True: {CLASS_NAMES.get(int(row['label']), row['label'])}", fontsize=9)
        axes[i, 0].axis("off")

        # Grad-CAM: explains the model's top class for this image, which is
        # assumed to match row['pred'] (same preprocessing as the predict run).
        viz, heatmap = generate_gradcam(img_path)
        axes[i, 1].imshow(viz)
        axes[i, 1].set_title(f"Pred: {CLASS_NAMES.get(int(row['pred']), row['pred'])} ({row['confidence']:.2f})", fontsize=9)
        axes[i, 1].axis("off")

        # Heatmap only
        axes[i, 2].imshow(heatmap, cmap="jet")
        axes[i, 2].set_title("Attention heatmap", fontsize=9)
        axes[i, 2].axis("off")

    plt.suptitle(f"Grad-CAM: Misclassified Examples ({MODEL_NAME})", fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig(gradcam_dir / "misclassified.png", dpi=150, bbox_inches="tight")
    plt.show()
    print(f"üíæ Saved to {gradcam_dir / 'misclassified.png'}")
else:
    print("‚úÖ No misclassified examples found!")

# ===== 3. SUMMARY STATISTICS =====
# Guard the percentage denominator so an empty predictions CSV reports 0.0%
# instead of raising ZeroDivisionError.
n_preds_denominator = max(len(preds_df), 1)
print("\nüìä Grad-CAM Analysis Summary:")
print(f"   Model: {MODEL_NAME}")
print(f"   Checkpoint: {ckpt_path.name}")
print(f"   Split analyzed: {SPLIT_TO_ANALYZE}")
print(f"   Total predictions: {len(preds_df)}")
print(f"   Correct: {len(correct_df)} ({100*len(correct_df)/n_preds_denominator:.1f}%)")
print(f"   Misclassified: {len(misclassified_df)} ({100*len(misclassified_df)/n_preds_denominator:.1f}%)")
print(f"\nüìÅ Visualizations saved to: {gradcam_dir}")


## 🔬 11.5a Grad-CAM Comparison: Full Image vs Hybrid Crops (Optional)

If you've trained both a **full-image model** and a **hybrid crop model** (same architecture) and run predictions for both, run this cell to generate side-by-side comparisons showing how cropping changes the model's attention.

**Thesis insight:** This can illustrate WHY ROI cropping helps — e.g. the full-image model may attend to irrelevant regions while the crop model focuses on meaningful features.


In [None]:
# 🔬 Grad-CAM Comparison: Full Image vs Hybrid Crops
# Requires: trained models for BOTH full images AND hybrid crops

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import transforms

# ===== CONFIGURATION =====
# Tags for your trained models (must have run predictions for both)
FULL_IMAGE_TAG = "effb0_full_images"      # Your full-image model tag
HYBRID_TAG = "effb0_hybrid_face_hands"    # Your hybrid crop model tag
MODEL_NAME = "efficientnet_b0"            # Must be same architecture for both
SPLIT = "val"
N_COMPARISON_SAMPLES = 6
IMAGE_SIZE = 224

# ===== HELPER FUNCTIONS =====
def load_model_and_cam(tag, model_name):
    """Load the latest checkpoint trained under *tag* and wrap it in GradCAM.

    Looks for run folders in CKPT_ROOT/runs/<tag>/, takes the last one in
    sorted order (assumes sortable timestamp names), and prefers best.pt over
    last.pt. The weights are loaded with map_location="cpu"; the caller is
    responsible for moving the model to the target device.

    Returns:
        ``(model, cam, ckpt_path)``, or ``(None, None, None)`` when no run
        folder or checkpoint file exists for *tag*.
    """
    from ddriver.models.registry import get_model

    run_base = Path(CKPT_ROOT) / "runs" / tag
    # Filter to directories explicitly; Path.glob("*/") only does so on 3.11+.
    all_runs = sorted(p for p in run_base.glob("*") if p.is_dir())
    if not all_runs:
        return None, None, None
    latest_run = all_runs[-1]
    ckpt_path = latest_run / "best.pt"
    if not ckpt_path.exists():
        ckpt_path = latest_run / "last.pt"
    if not ckpt_path.exists():
        return None, None, None

    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt["model_state_dict"]
    # Falls back to 10 classes when there is no top-level "classifier.weight" key.
    num_classes = state_dict["classifier.weight"].shape[0] if "classifier.weight" in state_dict else 10

    model = get_model(model_name, num_classes=num_classes, pretrained=False)
    model.load_state_dict(state_dict)
    model.eval()

    # Grad-CAM hooks the last feature block of the wrapped backbone.
    if "efficientnet" in model_name:
        target_layers = [model.backbone.features[-1]]
    elif "resnet" in model_name:
        target_layers = [model.backbone.layer4[-1]]
    else:
        target_layers = [model.backbone.features[-1]]

    cam = GradCAM(model=model, target_layers=target_layers)
    return model, cam, ckpt_path

# Load both models
# load_model_and_cam returns (None, None, None) when a tag has no usable run.
print("Loading models...")
full_model, full_cam, full_ckpt = load_model_and_cam(FULL_IMAGE_TAG, MODEL_NAME)
hybrid_model, hybrid_cam, hybrid_ckpt = load_model_and_cam(HYBRID_TAG, MODEL_NAME)

# Report every missing model before failing, so both problems surface at once.
if full_model is None:
    print(f"‚ö†Ô∏è Full image model not found: {FULL_IMAGE_TAG}")
    print("   Train a full-image model first, or update FULL_IMAGE_TAG")
if hybrid_model is None:
    print(f"‚ö†Ô∏è Hybrid model not found: {HYBRID_TAG}")
    print("   Train a hybrid model first, or update HYBRID_TAG")

if full_model is None or hybrid_model is None:
    raise RuntimeError("Both models required for comparison. Check tags above.")

# nn.Module.to() moves parameters in place and returns the same module, so the
# GradCAM wrappers built around these models also see the moved weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
full_model = full_model.to(device)
hybrid_model = hybrid_model.to(device)
print(f"‚úÖ Loaded both models")

# Transform
# Square resize to IMAGE_SIZE, tensor conversion, standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load predictions to find common samples
full_preds = pd.read_csv(Path(OUT_ROOT) / "preds" / SPLIT / f"{FULL_IMAGE_TAG}.csv")
hybrid_preds = pd.read_csv(Path(OUT_ROOT) / "preds" / SPLIT / f"{HYBRID_TAG}.csv")

# Full images path and hybrid crops path
full_data_root = Path(DATASET_ROOT)
hybrid_data_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))

# An empty CSV would otherwise fail inside plt.subplots(0, ...) with an opaque error.
if full_preds.empty:
    raise ValueError(f"No predictions in {FULL_IMAGE_TAG}.csv for split '{SPLIT}'")

# Sample images (fixed seed so the comparison figure is reproducible)
samples = full_preds.sample(min(N_COMPARISON_SAMPLES, len(full_preds)), random_state=42)

# Generate comparison: one row per sample with 4 panels
# (full image, full Grad-CAM, hybrid crop, hybrid Grad-CAM).
fig, axes = plt.subplots(len(samples), 4, figsize=(16, len(samples) * 3.5))

CLASS_NAMES = {
    0: "Safe", 1: "Text-R", 2: "Talk-R", 3: "Text-L", 4: "Talk-L",
    5: "Radio", 6: "Drink", 7: "Reach", 8: "Hair", 9: "Passenger"
}

# The hybrid manifest maps every crop back to its source frame through its
# "original_path" column. It does not change per sample, so read it once.
hybrid_manifest = pd.read_csv(hybrid_data_root / f"manifest_face_hands.csv")
# Normalize separators so the literal path comparisons below are OS-agnostic.
_hybrid_orig_paths = hybrid_manifest["original_path"].astype(str).str.replace("\\", "/", regex=False)
_hybrid_orig_names = _hybrid_orig_paths.str.rsplit("/", n=1).str[-1]
_hybrid_pred_paths = hybrid_preds["path"].astype(str).str.replace("\\", "/", regex=False)
_hybrid_pred_names = _hybrid_pred_paths.str.rsplit("/", n=1).str[-1]

# Ensure 2-D indexing works even when only one sample was drawn, and hide all
# panels up front so skipped samples leave blank cells rather than empty frames.
axes = np.atleast_2d(axes)
for ax in np.ravel(axes):
    ax.axis("off")

for i, (_, row) in enumerate(samples.iterrows()):
    full_path = full_data_root / row["path"]
    rel_path = str(row["path"]).replace("\\", "/")

    # Find the crop derived from this frame: prefer a full relative-path match,
    # then an exact file-name match. Literal comparisons avoid the regex and
    # substring false positives of str.contains (e.g. "12.jpg" in "112.jpg").
    match = hybrid_manifest[(_hybrid_orig_paths == rel_path) | _hybrid_orig_paths.str.endswith("/" + rel_path)]
    if len(match) == 0:
        match = hybrid_manifest[_hybrid_orig_names == Path(rel_path).name]
    if len(match) == 0:
        axes[i, 0].text(0.5, 0.5, "No hybrid crop found", ha="center", va="center", fontsize=8)
        continue
    crop_rel = str(match.iloc[0]["path"])
    hybrid_path = hybrid_data_root / crop_rel

    # Load and process full image
    if full_path.exists():
        full_img = Image.open(full_path).convert("RGB")
        full_img_resized = full_img.resize((IMAGE_SIZE, IMAGE_SIZE))
        full_np = np.array(full_img_resized) / 255.0
        full_tensor = transform(full_img).unsqueeze(0).to(device)

        # targets=None: explain each model's own top-scoring class.
        full_cam_result = full_cam(input_tensor=full_tensor, targets=None)[0]
        full_viz = show_cam_on_image(full_np.astype(np.float32), full_cam_result, use_rgb=True)

        axes[i, 0].imshow(full_img_resized)
        axes[i, 0].set_title(f"Full Image\nTrue: {CLASS_NAMES.get(int(row['label']), row['label'])}", fontsize=9)

        axes[i, 1].imshow(full_viz)
        axes[i, 1].set_title(f"Full Grad-CAM\nPred: {CLASS_NAMES.get(int(row['pred']), row['pred'])}", fontsize=9)

    # Load and process hybrid crop
    if hybrid_path.exists():
        hybrid_img = Image.open(hybrid_path).convert("RGB")
        hybrid_img_resized = hybrid_img.resize((IMAGE_SIZE, IMAGE_SIZE))
        hybrid_np = np.array(hybrid_img_resized) / 255.0
        hybrid_tensor = transform(hybrid_img).unsqueeze(0).to(device)

        hybrid_cam_result = hybrid_cam(input_tensor=hybrid_tensor, targets=None)[0]
        hybrid_viz = show_cam_on_image(hybrid_np.astype(np.float32), hybrid_cam_result, use_rgb=True)

        axes[i, 2].imshow(hybrid_img_resized)
        axes[i, 2].set_title("Hybrid Crop", fontsize=9)

        axes[i, 3].imshow(hybrid_viz)
        # Look up the hybrid model's stored prediction for this crop (exact or
        # suffix path match first, then file name). Show "?" when none exists
        # instead of crashing on int("?").
        crop_norm = crop_rel.replace("\\", "/")
        hybrid_row = hybrid_preds[(_hybrid_pred_paths == crop_norm) | _hybrid_pred_paths.str.endswith("/" + crop_norm)]
        if len(hybrid_row) == 0:
            hybrid_row = hybrid_preds[_hybrid_pred_names == hybrid_path.name]
        if len(hybrid_row) > 0:
            pred_label = hybrid_row.iloc[0]["pred"]
            pred_text = CLASS_NAMES.get(int(pred_label), pred_label)
        else:
            pred_text = "?"
        axes[i, 3].set_title(f"Hybrid Grad-CAM\nPred: {pred_text}", fontsize=9)

plt.suptitle(f"Grad-CAM Comparison: Full Image vs Hybrid Crop ({MODEL_NAME})", fontsize=14, y=1.02)
plt.tight_layout()

# Persist the figure under OUT_ROOT/gradcam/ before displaying it.
comparison_path = Path(OUT_ROOT, "gradcam", "comparison_full_vs_hybrid.png")
comparison_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(comparison_path, dpi=150, bbox_inches="tight")
plt.show()
print(f"\nüíæ Saved comparison to: {comparison_path}")
# Template sentence for the write-up; fill in the bracketed observations.
print(
    "\nüìù Thesis talking point:\n"
    "   'The full-image model attends to [background/seat/etc] while the\n"
    "    hybrid-crop model focuses on [hand position/facial features/etc],\n"
    "    demonstrating the value of ROI-based preprocessing.'"
)
