# 🚂 02 — Training

**Purpose:** Train models on full images or hybrid ROI crops.

**Sections:**
1. Inline Setup
2. Copy Data to /content (full images OR hybrid crops)
3. Register Models
4. Training Configuration & Execution
5. Learning Curves Visualization

**Prerequisites:** Manifests and splits exist on Drive (from 01_data_preparation.ipynb or prior run)


## 🔧 Section 1: Inline Setup


In [None]:
# --- INLINE SETUP ---
import os, subprocess, sys

# Config
# Repository that is cloned into the Colab VM, and the branch to track.
REPO_URL       = "https://github.com/ClaudiaCPach/CNNs-distracted-driving"
REPO_DIRNAME   = "CNNs-distracted-driving"
BRANCH         = "main"
# Local checkout location on the VM.
PROJECT_ROOT   = f"/content/{REPO_DIRNAME}"
# Everything below DRIVE_PATH lives on the mounted Google Drive.
DRIVE_PATH     = "/content/drive/MyDrive/TFM"
DRIVE_DATA_ROOT = f"{DRIVE_PATH}/data"
# VM-local directory that the Section 2 copy cells write into.
FAST_DATA      = "/content/data"
# Starts out pointing at Drive. The Section 2 cells later overwrite the
# DATASET_ROOT *environment variable* (not this Python name) with a /content path.
DATASET_ROOT   = DRIVE_DATA_ROOT
OUT_ROOT       = f"{DRIVE_PATH}/outputs"
CKPT_ROOT      = f"{DRIVE_PATH}/checkpoints"

# Mount Drive
# Must happen before any DRIVE_PATH-based location is read.
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Clone/update repo
def sh(cmd):
    """Echo *cmd*, run it through bash, and fail loudly on a non-zero exit.

    Args:
        cmd: Shell command line. It is executed with ``shell=True`` under
            ``/bin/bash`` so constructs such as ``cd ... && ...`` work.

    Raises:
        RuntimeError: If the command exits with a non-zero status. The
            message carries the exit code so failures are easier to diagnose.
    """
    print(f"$ {cmd}")
    rc = subprocess.call(cmd, shell=True, executable="/bin/bash")
    if rc != 0:
        raise RuntimeError(f"Command failed (exit code {rc}): {cmd}")

if os.path.isdir(PROJECT_ROOT):
    sh(f"cd {PROJECT_ROOT} && git pull --rebase origin {BRANCH}")
else:
    sh(f"git clone --branch {BRANCH} {REPO_URL} {PROJECT_ROOT}")

# Install
sh(f"pip install -q -e {PROJECT_ROOT}")
!pip -q install timm

# Set env vars
os.environ["DRIVE_PATH"] = DRIVE_PATH
os.environ["DATASET_ROOT"] = DATASET_ROOT
os.environ["OUT_ROOT"] = OUT_ROOT
os.environ["CKPT_ROOT"] = CKPT_ROOT
os.environ["FAST_DATA"] = FAST_DATA

sys.path.insert(0, PROJECT_ROOT)
sys.path.insert(0, os.path.join(PROJECT_ROOT, "src"))

# GPU check
!nvidia-smi || echo "No GPU"
print("âœ… Inline setup complete")


## ⚡ Section 2: Copy Data to /content

Choose ONE of the options below based on what you're training on.


In [None]:
# ⚡ OPTION A: Copy HYBRID CROPS to /content (for face/face_hands training)
import os, shutil
from pathlib import Path
import importlib

# Crop variant to stage. It selects both the image sub-folder and the
# manifest/split CSV file names used further down in this cell.
HYBRID_VARIANT = "face"  # face | face_hands

# VM-local staging root and its counterpart under the Drive outputs folder.
LOCAL_ROOT = Path("/content/data/hybrid")
DRIVE_ROOT = Path(OUT_ROOT) / "hybrid"

# Per-variant image folders: <root>/<variant>/
LOCAL_VARIANT_DIR = LOCAL_ROOT / HYBRID_VARIANT
DRIVE_VARIANT_DIR = DRIVE_ROOT / HYBRID_VARIANT

def count_jpgs(p: Path) -> int:
    """Count entries under *p* (recursively) whose name ends in ``.jpg``.

    The extension test is case-insensitive, matching the
    ``fname.lower().endswith(".jpg")`` filter used when the files are copied,
    so ``.JPG`` files that get copied are also counted.

    Args:
        p: Directory to scan.

    Returns:
        The number of matching entries, or 0 when *p* does not exist.
    """
    if not p.exists():
        return 0
    return sum(1 for entry in p.rglob("*") if entry.name.lower().endswith(".jpg"))

local_count = count_jpgs(LOCAL_VARIANT_DIR)
drive_count = count_jpgs(DRIVE_VARIANT_DIR)

print(f"ðŸ”Ž Local: {local_count} jpgs | Drive: {drive_count} jpgs")

# Skip only when the local copy is at least as complete as the Drive copy.
# Checking just "local_count > 0" would accept a half-finished copy left
# behind by an interrupted run and silently train on a partial dataset.
if local_count > 0 and local_count >= drive_count:
    print(f"âœ… Hybrid crops already in /content. Skipping copy.")
elif drive_count == 0:
    raise FileNotFoundError(f"No crops on Drive at {DRIVE_VARIANT_DIR}")
else:
    print(f"ðŸ“¦ Copying {HYBRID_VARIANT} crops from Drive to /content...")
    LOCAL_VARIANT_DIR.mkdir(parents=True, exist_ok=True)

    file_count = 0
    already_present = 0
    for src_dir, _, files in os.walk(DRIVE_VARIANT_DIR):
        # Mirror the Drive sub-folder layout under the local variant folder.
        rel_dir = Path(src_dir).relative_to(DRIVE_VARIANT_DIR)
        dst_dir = LOCAL_VARIANT_DIR / rel_dir
        dst_dir.mkdir(parents=True, exist_ok=True)
        for fname in files:
            if not fname.lower().endswith(".jpg"):
                continue
            src_file = Path(src_dir) / fname
            dst_file = dst_dir / fname
            # Resume support: keep files that are already fully copied. The
            # size comparison catches a file truncated by an interrupted copy.
            if dst_file.exists() and dst_file.stat().st_size == src_file.stat().st_size:
                already_present += 1
                continue
            shutil.copy2(src_file, dst_file)
            file_count += 1
    print(f"   Copied {file_count} images")
    if already_present:
        print(f"   Kept {already_present} images that were already present")

    # Copy CSVs (manifest plus train/val/test splits for this variant)
    for fname in [f"manifest_{HYBRID_VARIANT}.csv", f"train_{HYBRID_VARIANT}.csv",
                  f"val_{HYBRID_VARIANT}.csv", f"test_{HYBRID_VARIANT}.csv"]:
        src = DRIVE_ROOT / fname
        if src.exists():
            shutil.copy2(src, LOCAL_ROOT / fname)
            print(f"   Copied {fname}")
        else:
            # Report the gap instead of hiding it; a later step that reads
            # this CSV from LOCAL_ROOT would otherwise fail without context.
            print(f"   WARNING: {fname} not found on Drive at {src}")

# Update env vars so later cells and the reloaded config see the local copy.
os.environ["HYBRID_ROOT_LOCAL"] = str(LOCAL_ROOT)
os.environ["DATASET_ROOT"] = str(LOCAL_ROOT)
from ddriver import config as _cfg
# Reload so the config module is re-evaluated after the env change.
importlib.reload(_cfg)
print(f"\nâœ… DATASET_ROOT = {os.environ['DATASET_ROOT']}")


In [None]:
# ⚡ OPTION B: Copy FULL IMAGES to /content (for full-frame training)
# Skip this if using hybrid crops above

import importlib
from pathlib import Path
from ddriver.data.fastcopy import CompressionSpec, copy_splits_with_compression

# Dataset folder on Drive and its mirror on the VM-local disk.
SRC_ROOT = Path(DRIVE_DATA_ROOT, "auc.distracted.driver.dataset_v2")
DST_ROOT = Path(FAST_DATA, "auc.distracted.driver.dataset_v2")

# One split CSV per phase, all stored under <OUT_ROOT>/splits/.
split_csvs = {
    phase: Path(OUT_ROOT, "splits", f"{phase}.csv")
    for phase in ("train", "val", "test")
}

# NOTE(review): presumably resizes to a 320 px short side and re-encodes as
# JPEG at quality 80 - confirm against ddriver.data.fastcopy.
compression_spec = CompressionSpec(target_short_side=320, jpeg_quality=80)

summary = copy_splits_with_compression(
    split_csvs=split_csvs,
    src_root=SRC_ROOT,
    dst_root=DST_ROOT,
    compression=compression_spec,
    skip_existing=True,
)

print(f"ðŸ“‰ Copied {summary['processed']} files (skipped {summary['skipped']})")

# Repoint the dataset root at the local copy and reload the project config
# so it is re-evaluated after the environment change.
os.environ["DATASET_ROOT"] = str(FAST_DATA)
from ddriver import config as _cfg
importlib.reload(_cfg)
print(f"âœ… DATASET_ROOT = {os.environ['DATASET_ROOT']}")


## 📦 Section 3: Register Models


In [None]:
# Register models from timm
from ddriver.models import registry

# Register the backbone whose name is used as MODEL_NAME in Section 4.
# NOTE(review): Section 4 launches training in a separate `python -m ...`
# process, so confirm whether this in-kernel registration is visible there
# or whether the CLI registers the model on its own.
registry.register_timm_backbone("efficientnet_b0")
# Alternative backbones: uncomment to register them as well.
# registry.register_timm_backbone("convnext_tiny")
# registry.register_timm_backbone("resnet50")

# Sanity check: show at most the first ten names the registry reports.
print("Available models:", registry.available_models()[:10])


## 🚂 Section 4: Training Configuration & Execution


In [None]:
# 🚂 TRAINING CONFIGURATION
import os
import subprocess, textwrap, json, time, threading
from pathlib import Path

# ============== EXPERIMENT CONFIG ==============
# RUN_TAG is passed to the CLI as --out-tag and is also the folder name the
# Section 5 cell looks for under <CKPT_ROOT>/runs/.
RUN_TAG = "effb0_face_v1"           # <<<< CHANGE for each experiment
MODEL_NAME = "efficientnet_b0"
SEED = 42                           # <<<< CHANGE for stability runs (e.g., 42, 123, 456)

# Training hyperparameters
# Each value is forwarded to the matching CLI flag by the "RUN TRAINING" cell.
EPOCHS = 30
BATCH_SIZE = 32
NUM_WORKERS = 2
IMAGE_SIZE = 224
LR = 3e-4
LABEL_SMOOTHING = 0.05
# Selects the train_small*.csv split instead of the full train split.
# NOTE: ignored when USE_CONTROL_SPLIT is set (that branch always uses
# train_<control>.csv).
USE_TINY_SPLIT = False

# Data source: choose ONE
USE_HYBRID = True                   # True = use hybrid crops
# When the crops were staged with Option A, ROI_VARIANT must equal that cell's
# HYBRID_VARIANT; otherwise the manifest/split CSVs read here were never copied.
ROI_VARIANT = "face"                # face | face_hands (only if USE_HYBRID=True)

# Control split selection (for 5-run experimental plan)
# Set to None for natural runs, or "facesubset" / "fhsubset" / "both" for control runs
# NOTE: Control splits only work with full-frame (USE_HYBRID=False)
USE_CONTROL_SPLIT = None            # None | "facesubset" | "fhsubset" | "both"

# Validate settings
# Fail fast on the one combination the path-building code below cannot serve.
if USE_CONTROL_SPLIT and USE_HYBRID:
    raise ValueError("Control splits require full-frame training. Set USE_HYBRID=False to use control splits.")

# ============== BUILD PATHS ==============
if not USE_HYBRID:
    # Full-frame training: one shared manifest, splits under <OUT_ROOT>/splits/.
    manifest_csv = Path(OUT_ROOT, "manifests", "manifest.csv")
    if not USE_CONTROL_SPLIT:
        train_split = "train.csv" if not USE_TINY_SPLIT else "train_small.csv"
        train_csv, val_csv, test_csv = (
            Path(OUT_ROOT, "splits", csv_name)
            for csv_name in (train_split, "val.csv", "test.csv")
        )
        print("ðŸ“· Using full-frame images")
    else:
        # Control runs (5-run experimental plan) read pre-filtered splits
        # from <OUT_ROOT>/splits/control/.
        control_root = Path(OUT_ROOT, "splits", "control")
        train_csv, val_csv, test_csv = (
            control_root / f"{phase}_{USE_CONTROL_SPLIT}.csv"
            for phase in ("train", "val", "test")
        )
        print(f"ðŸ“· Using full-frame images with CONTROL SPLIT: {USE_CONTROL_SPLIT}")
        print(f"   (Filtered to {USE_CONTROL_SPLIT} IDs for fair comparison)")
else:
    # Hybrid crops: prefer the VM-local copy announced via HYBRID_ROOT_LOCAL,
    # falling back to the Drive folder when that variable is not set.
    hybrid_root = Path(os.environ.get("HYBRID_ROOT_LOCAL", Path(OUT_ROOT) / "hybrid"))
    manifest_csv = hybrid_root / f"manifest_{ROI_VARIANT}.csv"
    if USE_TINY_SPLIT:
        train_split = f"train_small_{ROI_VARIANT}.csv"
    else:
        train_split = f"train_{ROI_VARIANT}.csv"
    train_csv = hybrid_root / train_split
    val_csv, test_csv = (
        hybrid_root / f"{phase}_{ROI_VARIANT}.csv" for phase in ("val", "test")
    )
    print(f"ðŸ”€ Using Hybrid crops: {ROI_VARIANT}")
    print(f"   hybrid_root = {hybrid_root}")

# Update DATASET_ROOT for hybrid
if USE_HYBRID:
    _roi_root = hybrid_root
    # Only repoint DATASET_ROOT when the crops live on the VM-local disk;
    # a Drive-based hybrid_root leaves the environment untouched.
    if str(_roi_root).startswith("/content/data"):
        import importlib
        os.environ["DATASET_ROOT"] = str(_roi_root)
        from ddriver import config as _cfg
        # Reload after the env change so the config module is re-evaluated.
        importlib.reload(_cfg)
        print(f"   âš¡ DATASET_ROOT = {os.environ['DATASET_ROOT']}")

# Echo the effective settings so they end up in the notebook output.
print(f"\nðŸ“‹ Training config:")
print(f"   RUN_TAG: {RUN_TAG}")
print(f"   Model: {MODEL_NAME}")
print(f"   Seed: {SEED}")
print(f"   Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, LR: {LR}")
if USE_CONTROL_SPLIT:
    print(f"   Control Split: {USE_CONTROL_SPLIT}")


In [None]:
# 🚂 RUN TRAINING
# Build the shell command for the training CLI.
#
# * Every string/path value goes through shlex.quote so a space or shell
#   metacharacter in RUN_TAG or in a path cannot split or alter the command.
#   Values without such characters are left unchanged by shlex.quote.
# * The trailing backslashes sit inside a normal (non-raw) f-string, so Python
#   itself removes each backslash+newline pair; the shell receives one line.
# * "cd ... &&" makes the run stop if the checkout is missing, instead of
#   launching python from whatever the current directory happens to be.
import shlex

train_cmd = textwrap.dedent(f"""
cd {shlex.quote(str(PROJECT_ROOT))} && \
python -m src.ddriver.cli.train \
    --model-name {shlex.quote(str(MODEL_NAME))} \
    --epochs {EPOCHS} \
    --batch-size {BATCH_SIZE} \
    --num-workers {NUM_WORKERS} \
    --image-size {IMAGE_SIZE} \
    --lr {LR} \
    --weight-decay .01 \
    --optimizer adamw \
    --label-smoothing {LABEL_SMOOTHING} \
    --seed {SEED} \
    --out-tag {shlex.quote(str(RUN_TAG))} \
    --manifest-csv {shlex.quote(str(manifest_csv))} \
    --train-csv {shlex.quote(str(train_csv))} \
    --val-csv {shlex.quote(str(val_csv))} \
    --test-csv {shlex.quote(str(test_csv))}
""")

print("Running training command:\n", train_cmd)

# stderr is merged into stdout so a single reader loop sees all output.
proc = subprocess.Popen(train_cmd, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

# GPU monitor thread
def _gpu_monitor():
    """Print GPU utilisation and memory use every 5 seconds during training.

    Polls the module-level ``proc`` and returns once that process has exited.
    Monitoring is best-effort: if ``nvidia-smi`` cannot be run or exits with
    an error, that sample is skipped and polling continues.
    """
    while proc.poll() is None:
        try:
            stats = subprocess.check_output(
                "nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total --format=csv,nounits,noheader",
                shell=True,
            ).decode("utf-8", errors="replace").strip()
        except (subprocess.SubprocessError, OSError):
            # Only tool failures are swallowed; a bare "except:" would also
            # hide programming errors and interpreter-exit signals.
            pass
        else:
            print(f"[GPU] {stats}")
        time.sleep(5)

# Daemon thread: it must never keep the kernel alive on its own.
monitor = threading.Thread(target=_gpu_monitor, daemon=True)
monitor.start()

# Stream the trainer's combined stdout/stderr into the cell output as it arrives.
for line in proc.stdout:
    print(line, end="")

proc.wait()
# Surface failures: without this check a crashed run would still be reported
# as complete and the next cell would plot a stale or missing history.
if proc.returncode != 0:
    raise RuntimeError(f"Training failed with exit code {proc.returncode}")
print("\nâœ… Training complete!")


## 📈 Section 5: Learning Curves Visualization


In [None]:
# 📈 Display training metrics and learning curves
import json
import matplotlib.pyplot as plt
from pathlib import Path

# Run folders live at <CKPT_ROOT>/runs/<RUN_TAG>/<run>/; the last name in
# sorted order is treated as the latest run.
run_base = Path(CKPT_ROOT) / "runs" / RUN_TAG
# Filter with is_dir() explicitly: before Python 3.11 a glob pattern ending in
# "/" also returns plain files, so a stray file could be picked as the run.
all_runs = sorted(entry for entry in run_base.glob("*") if entry.is_dir())
if not all_runs:
    raise FileNotFoundError(f"No run folders found under {run_base}")
latest_run = all_runs[-1]

history_path = latest_run / "history.json"
if not history_path.exists():
    raise FileNotFoundError(f"history.json not found in {latest_run}")

# The per-epoch records are stored under the top-level "history" key.
history = json.loads(history_path.read_text()).get("history", [])

print(f"ðŸ“Š Epoch metrics for run: {latest_run.name}")
for record in history:
    train_metrics = record.get("train", {})
    # "val" may be absent or null for epochs without validation.
    val_metrics = record.get("val", {}) or {}
    train_loss = train_metrics.get("loss")
    train_acc = train_metrics.get("accuracy")
    val_loss = val_metrics.get("loss")
    val_acc = val_metrics.get("accuracy")
    # Test against None explicitly: a truthiness check would report a genuine
    # loss of 0.0 as missing, and a missing accuracy would crash the ":.4f".
    if val_loss is not None and val_acc is not None:
        val_str = f"val_loss={val_loss:.4f} acc={val_acc:.4f}"
    else:
        val_str = "val_loss=â€” val_acc=â€”"
    print(f"  Epoch {record['epoch']:>2}: train_loss={train_loss:.4f} acc={train_acc:.4f}  {val_str}")

# Plot learning curves
epochs = [r["epoch"] for r in history]
train_loss = [r["train"]["loss"] for r in history]
train_acc = [r["train"]["accuracy"] for r in history]
val_loss = [(r.get("val") or {}).get("loss") for r in history]
val_acc = [(r.get("val") or {}).get("accuracy") for r in history]

# Each validation metric gets its own x-axis, built from the epochs where
# that metric is present. Reusing the loss-filtered epochs for accuracy
# fails with a length mismatch whenever only one of the two is recorded.
val_epochs = [e for e, v in zip(epochs, val_loss) if v is not None]
val_loss_f = [v for v in val_loss if v is not None]
val_acc_epochs = [e for e, v in zip(epochs, val_acc) if v is not None]
val_acc_f = [v for v in val_acc if v is not None]

fig, axes = plt.subplots(1, 2, figsize=(12, 4), dpi=140)

# Left panel: accuracy.
axes[0].plot(epochs, train_acc, label="Training Accuracy")
if val_acc_f:
    axes[0].plot(val_acc_epochs, val_acc_f, label="Validation Accuracy")
axes[0].set_title("Accuracy Curves")
axes[0].set_xlabel("Epochs")
axes[0].set_ylabel("Accuracy")
axes[0].legend()

# Right panel: loss.
axes[1].plot(epochs, train_loss, label="Training Loss")
if val_loss_f:
    axes[1].plot(val_epochs, val_loss_f, label="Validation Loss")
axes[1].set_title("Loss Curves")
axes[1].set_xlabel("Epochs")
axes[1].set_ylabel("Loss")
axes[1].legend()

plt.tight_layout()
# The figure is stored next to history.json inside the run folder.
out_png = latest_run / "learning_curves.png"
fig.savefig(out_png, bbox_inches="tight")
plt.show()
print("âœ… Saved:", out_png)


## ✅ Training Complete!

**Outputs saved to Drive:**
- `CKPT_ROOT/runs/{RUN_TAG}/{timestamp}/best.pt` — Best checkpoint
- `CKPT_ROOT/runs/{RUN_TAG}/{timestamp}/history.json` — Training history
- `CKPT_ROOT/runs/{RUN_TAG}/{timestamp}/learning_curves.png` — Curves figure

**Next steps:**
- Run **03_evaluation.ipynb** to generate predictions and metrics
