# Stage 1 Ferning Classification – Full Pipeline Notebook

**Run cells top to bottom.** This notebook is fully self-contained:

| Cell | What it does |
|------|--------------|
| 1 | Imports & GPU/CPU setup |
| 2 | Path resolution (auto-detects your dataset folder) |
| 3 | Package check (same as old `setup_verify.py`) |
| 4 | **Generate** `master_patch_index.csv` from your `.npy` files |
| 5 | **Generate** `fold_splits.csv` (stratified, leakage-free) |
| 6 | Data verification (same as old `verify_data.py`) |
| 7 | Data loading & preprocessing utilities |
| 8 | Generator + model definition |
| 9 | Evaluation helper |
| 10 | Training loop (cross-validation) |
| 11 | Results summary |

---
## Cell 1 — Imports & GPU/CPU Setup

In [None]:
import os
import re
import sys
from pathlib import Path
from datetime import datetime

# TensorFlow's native logging reads TF_CPP_MIN_LOG_LEVEL when the library is
# loaded, so the variable must be set BEFORE `import tensorflow`. Setting it
# after the import does not filter the messages emitted while importing.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB3
from sklearn.metrics import confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold

# Fix Windows console encoding. This is best effort: streams that cannot be
# reconfigured are simply left as they are.
if sys.platform == 'win32':
    try:
        sys.stdout.reconfigure(encoding='utf-8')
    except Exception:
        pass

print(f"Python version     : {sys.version.split()[0]}")
print(f"TensorFlow version : {tf.__version__}")

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
physical_gpus = tf.config.list_physical_devices("GPU")
if physical_gpus:
    for gpu in physical_gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except Exception as e:
            # e.g. raised if the GPUs were already initialised (cell re-run)
            print(f"  Could not set memory growth: {e}")
    print(f"Device             : GPU {[g.name for g in physical_gpus]}")
else:
    print("Device             : CPU (no GPU detected - training will be slower)")

---
## Cell 2 - Configuration & Path Resolution

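The notebook looks for a data root that contains a `dataset/` folder. If `DATA_ROOT_OVERRIDE` is set, that path is used and auto-detection is skipped. Otherwise these locations are tried in order:
```
1. <notebook_dir>/../local/data/     <- README default
2. <notebook_dir>/local/data/
3. <notebook_dir>/data/
```
If none of them contains `dataset/`, the first location is used (and created).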
Your dataset folder should look like:
```
data/
  dataset/
    PF/   <- Partial Ferning .npy files
    CF/   <- Complete Ferning .npy files
    NF/   <- No Ferning .npy files
  master_patch_index.csv   <- auto-generated by Cell 4
  fold_splits.csv          <- auto-generated by Cell 5
```

In [None]:
# ---- Optional manual override ------------------------------------------------
# Set this to a string if auto-detection fails, e.g.:
#   DATA_ROOT_OVERRIDE = r"C:/Users/YourName/my_project/local/data"
# A leading "~" is expanded to your home directory.
DATA_ROOT_OVERRIDE = None
# ------------------------------------------------------------------------------

# Training hyperparameters
EPOCHS               = 10
BATCH_SIZE           = 32
LEARNING_RATE        = 1e-4
INPUT_SHAPE          = (64, 64, 3)    # shape of one model input patch
CLASS_NAMES          = ["No Ferning", "Ferning"]   # list index == Stage-1 label
NUM_CLASSES          = len(CLASS_NAMES)  # derived so it can never disagree with CLASS_NAMES
N_FOLDS              = 5
FERNING_CLASSES      = ["PF", "CF"]   # classes that map to label=1
KNOWN_CLASSES        = ["PF", "CF", "NF"]
POSITIVE_CLASS_INDEX = 1
THRESHOLD            = 0.5
RANDOM_SEED          = 42

# ---- Locate notebook directory -----------------------------------------------
try:
    NOTEBOOK_DIR = Path(__file__).resolve().parent
except NameError:
    # __file__ does not exist when running inside a notebook kernel.
    NOTEBOOK_DIR = Path.cwd()

# ---- Search for data root ----------------------------------------------------
if DATA_ROOT_OVERRIDE:
    # An explicit override is the only candidate; auto-detection is skipped.
    _candidates = [Path(DATA_ROOT_OVERRIDE).expanduser()]
else:
    _candidates = [
        NOTEBOOK_DIR.parent / "local" / "data",
        NOTEBOOK_DIR / "local" / "data",
        NOTEBOOK_DIR / "data",
    ]

def _find_data_root(candidates):
    """
    Return the first candidate directory that contains a ``dataset/`` folder.

    If no candidate qualifies, the first candidate is returned as a fallback;
    later cells create the folders there.

    Args:
        candidates: iterable of ``Path`` objects or path strings, in priority order.

    Returns:
        The chosen root as a ``Path``.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    candidates = [Path(c) for c in candidates]
    if not candidates:
        raise ValueError("No data-root candidates to search.")
    for root in candidates:
        # is_dir() rather than exists(): a stray *file* called "dataset" must
        # not be mistaken for the dataset folder.
        if (root / "dataset").is_dir():
            return root
    return candidates[0]  # fallback: use first candidate, cells will create folders

# Every project path hangs off the resolved data root.
DATA_ROOT = _find_data_root(_candidates)
DATASET_DIR, MASTER_INDEX_PATH, FOLD_SPLITS_PATH = (
    DATA_ROOT / "dataset",
    DATA_ROOT / "master_patch_index.csv",
    DATA_ROOT / "fold_splits.csv",
)
OUTPUT_DIR = NOTEBOOK_DIR / "outputs"

for _needed_dir in (DATA_ROOT, OUTPUT_DIR):
    _needed_dir.mkdir(parents=True, exist_ok=True)


def _presence(path, absent_text):
    """Status tag for the report: '[EXISTS]' or the supplied placeholder."""
    return '[EXISTS]' if path.exists() else absent_text


print(f"Notebook dir      : {NOTEBOOK_DIR}")
print(f"Data root         : {DATA_ROOT}")
print(f"Dataset folder    : {DATASET_DIR}  {_presence(DATASET_DIR, '[NOT FOUND - place .npy files here]')}")
print(f"Master index      : {_presence(MASTER_INDEX_PATH, '[will be generated]')}")
print(f"Fold splits       : {_presence(FOLD_SPLITS_PATH, '[will be generated]')}")
print(f"Output dir        : {OUTPUT_DIR}")

---
## Cell 3 - Package Check
Same dependency check as the old `setup_verify.py`.

In [None]:
print("=" * 60)
print("PACKAGE CHECK")
print("=" * 60)

# import name -> pip distribution name
required_packages = {
    'numpy':      'numpy',
    'pandas':     'pandas',
    'tensorflow': 'tensorflow',
    'sklearn':    'scikit-learn',
    'matplotlib': 'matplotlib',
    'cv2':        'opencv-python',
}

missing_packages = []
for module_name, package_name in required_packages.items():
    try:
        __import__(module_name)
    except ImportError:
        print(f"  [X]  {package_name}  <- MISSING")
        missing_packages.append(package_name)
        continue
    print(f"  [OK] {package_name}")

if not missing_packages:
    print("\n  [OK] All packages installed - ready to continue.")
else:
    # Show a copy-pasteable install command, then stop the notebook here.
    print("\n  Install missing packages with:")
    print(f"    pip install {' '.join(missing_packages)}")
    raise ImportError(f"Missing packages: {missing_packages}")

---
## Cell 4 - Generate `master_patch_index.csv`

Scans `dataset/PF/`, `dataset/CF/`, `dataset/NF/` and builds the index CSV.

**Skips generation if the file already exists.** Set `FORCE_REGENERATE = True` to rebuild it.

In [None]:
FORCE_REGENERATE = False   # set True to rebuild even if CSV already exists

def generate_master_index(
    dataset_dir: Path,
    output_path: Path,
    known_classes=None,
    ferning_classes=None,
) -> pd.DataFrame:
    """
    Scan ``dataset_dir/<class>/`` for ``.npy`` patches and write master_patch_index.csv.

    Columns of the returned DataFrame (and of the CSV at ``output_path``):

        sample_id    - derived from filename (everything before the last '_')
                       e.g. 'sample001_patch03.npy' -> sample_id = 'sample001'
                       If there is no '_', the full stem is the sample_id.
        patch_id     - full filename stem
        class        - name of the class sub-folder (PF / CF / NF by default)
        label_stage1 - 1 if the class is in ``ferning_classes`` (Ferning), else 0
        path         - absolute path to the .npy file

    Args:
        dataset_dir: folder containing one sub-folder per class.
        output_path: where the CSV is written.
        known_classes: class sub-folders to scan, in order. ``None`` means
            ``KNOWN_CLASSES``.
        ferning_classes: classes labelled 1. ``None`` means ``FERNING_CLASSES``.

    Returns:
        The index as a DataFrame, in class order then sorted file order.

    Raises:
        FileNotFoundError: if ``dataset_dir`` does not exist.
        RuntimeError: if no ``.npy`` file is found in any class folder.
    """
    if known_classes is None:
        known_classes = KNOWN_CLASSES
    if ferning_classes is None:
        ferning_classes = FERNING_CLASSES
    positive_classes = set(ferning_classes)

    if not dataset_dir.exists():
        raise FileNotFoundError(
            f"Dataset folder not found: {dataset_dir}\n"
            f"Please create it and place your .npy files in subfolders PF/, CF/, NF/."
        )

    records = []
    for class_name in known_classes:
        class_dir = dataset_dir / class_name
        if not class_dir.is_dir():
            print(f"  [!] Class folder not found, skipping: {class_dir}")
            continue

        # Match the extension case-insensitively so e.g. "x.NPY" is not silently
        # skipped on case-sensitive filesystems; ignore sub-folders.
        npy_files = sorted(
            fp for fp in class_dir.iterdir()
            if fp.is_file() and fp.suffix.lower() == ".npy"
        )
        if not npy_files:
            print(f"  [!] No .npy files found in {class_dir}")
            continue

        print(f"  {class_name}: {len(npy_files):,} patches found")

        label = 1 if class_name in positive_classes else 0
        for fp in npy_files:
            stem = fp.stem
            records.append({
                "sample_id"   : stem.rsplit("_", 1)[0],  # whole stem when there is no '_'
                "patch_id"    : stem,
                "class"       : class_name,
                "label_stage1": label,
                "path"        : str(fp.resolve()),
            })

    if not records:
        raise RuntimeError(
            "No .npy files found in any class folder. "
            "Check that your dataset is placed correctly."
        )

    df = pd.DataFrame(records)
    df.to_csv(output_path, index=False)
    print(f"\n  Saved -> {output_path}")
    return df


print("=" * 60)
print("GENERATING master_patch_index.csv")
print("=" * 60)

# Reuse the CSV on disk unless the user explicitly asked for a rebuild.
_reuse_master_index = MASTER_INDEX_PATH.exists() and not FORCE_REGENERATE
if not _reuse_master_index:
    master_index = generate_master_index(DATASET_DIR, MASTER_INDEX_PATH)
else:
    print("  [SKIP] Already exists - loading from disk.")
    print("         Set FORCE_REGENERATE = True to rebuild.")
    master_index = pd.read_csv(MASTER_INDEX_PATH)

_stage1 = master_index["label_stage1"]
_n_ferning, _n_no_ferning = (_stage1 == 1).sum(), (_stage1 == 0).sum()

print(f"\n  Total patches  : {len(master_index):,}")
print(f"  Unique samples : {master_index['sample_id'].nunique():,}")
print("\n  Class breakdown:")
for _cls, _count in master_index['class'].value_counts().items():
    # label shown from the current FERNING_CLASSES configuration
    print(f"    {_cls} (label={int(_cls in FERNING_CLASSES)}): {_count:,} patches")
print("\n  Stage-1 binary:")
print(f"    Ferning  (1): {_n_ferning:,}")
print(f"    No Fern  (0): {_n_no_ferning:,}")
print("\n  Preview:")
display(master_index.head(5))

---
## Cell 5 - Generate `fold_splits.csv`

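Creates a **stratified** `N_FOLDS`-fold split (5 by default) at the **sample level**. Each sample is stratified by the majority Stage-1 label of its patches. Within a fold, all patches from one sample land in the same split (train or val), which prevents patch-level leakage between training and validation.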

**Skips generation if the file already exists.** Set `FORCE_REGENERATE_FOLDS = True` to rebuild.

In [None]:
FORCE_REGENERATE_FOLDS = False   # flip to True to rebuild fold_splits.csv even when it already exists


def generate_fold_splits(
    master_index: pd.DataFrame,
    output_path: Path,
    n_folds: int = 5,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Assign every sample to train/val in each of ``n_folds`` folds and save the table.

    Splitting is done per ``sample_id``, never per patch. Each sample is
    reduced to the most frequent ``label_stage1`` among its patches (``mode``;
    on a tie the first mode value is taken). These per-sample labels drive a
    shuffled ``StratifiedKFold`` seeded with ``seed``. Within each fold every
    sample therefore appears exactly once, as either "train" or "val".

    Returns:
        DataFrame with columns ``sample_id``, ``fold`` (1-based) and ``split``.
        The same table is written to ``output_path`` without the index.
    """
    per_sample_label = (
        master_index.groupby("sample_id")["label_stage1"]
        .agg(lambda patch_labels: int(patch_labels.mode()[0]))
    )
    ids    = per_sample_label.index.values
    strata = per_sample_label.values

    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    assignments = []
    for fold, (train_pos, val_pos) in enumerate(splitter.split(ids, strata), start=1):
        # train rows first, then val rows, for every fold
        for split_name, positions in (("train", train_pos), ("val", val_pos)):
            assignments.extend(
                {"sample_id": ids[pos], "fold": fold, "split": split_name}
                for pos in positions
            )

    fold_table = pd.DataFrame(assignments)
    fold_table.to_csv(output_path, index=False)
    print(f"  Saved -> {output_path}")
    return fold_table


print("=" * 60)
print("GENERATING fold_splits.csv")
print("=" * 60)

if FOLD_SPLITS_PATH.exists() and not FORCE_REGENERATE_FOLDS:
    print(f"  [SKIP] Already exists - loading from disk.")
    print(f"         Set FORCE_REGENERATE_FOLDS = True to rebuild.")
    # Parse sample_id with the dtype master_index currently uses. Left to
    # inference, read_csv turns ids such as "001" into the integer 1. Those
    # then match none of the string ids of a freshly generated master_index,
    # which silently empties every fold.
    _sample_id_dtype = master_index["sample_id"].dtype
    try:
        fold_splits = pd.read_csv(FOLD_SPLITS_PATH, dtype={"sample_id": _sample_id_dtype})
    except (ValueError, TypeError):
        print("  [!] sample_id values in fold_splits.csv do not fit master_index's id type -")
        print("      the file is probably stale; consider FORCE_REGENERATE_FOLDS = True.")
        fold_splits = pd.read_csv(FOLD_SPLITS_PATH)
else:
    fold_splits = generate_fold_splits(
        master_index, FOLD_SPLITS_PATH, n_folds=N_FOLDS, seed=RANDOM_SEED
    )

print(f"\n  Fold summary (sample level):")
for fold_num in range(1, N_FOLDS + 1):
    fold_data   = fold_splits[fold_splits["fold"] == fold_num]
    train_ids   = set(fold_data[fold_data["split"] == "train"]["sample_id"])
    val_ids     = set(fold_data[fold_data["split"] == "val"]["sample_id"])
    train_patch = len(master_index[master_index["sample_id"].isin(train_ids)])
    val_patch   = len(master_index[master_index["sample_id"].isin(val_ids)])
    print(f"  Fold {fold_num}: {len(train_ids):3d} train samples ({train_patch:,} patches) | "
          f"{len(val_ids):3d} val samples ({val_patch:,} patches)")

print("\n  Preview:")
display(fold_splits.head(10))

---
## Cell 6 - Data Verification
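Runs all checks from the old `verify_data.py`. Rows whose `.npy` file is missing are dropped from `master_index`, as long as some files remain. If any check fails, the cell raises an error so later cells don't run on bad data.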

In [None]:
print("=" * 60)
print("DATA VERIFICATION")
print("=" * 60)

# Cleared by any failing check; the summary at the bottom raises if it is False.
_checks_passed = True

# ---- 1. Required columns -----------------------------------------------------
print("\n[1/4] Required Columns")
REQUIRED_COLS = ["sample_id", "class", "path", "label_stage1"]
missing_cols  = [c for c in REQUIRED_COLS if c not in master_index.columns]
if missing_cols:
    print(f"  [X] Missing columns: {missing_cols}")
    # Every later check indexes these columns, so continuing would only end in
    # an unrelated KeyError - stop here with the real reason instead.
    raise RuntimeError(
        f"Data verification failed: master_index is missing columns {missing_cols}. "
        "Set FORCE_REGENERATE = True in Cell 4 and re-run."
    )
print("  [OK] All required columns present")

# ---- 2. File existence -------------------------------------------------------
print("\n[2/4] .npy File Existence")
# Stat each file exactly once; the mask / list are reused for counting,
# listing, trimming, the spot-check and the visualisation below.
_exists_mask   = master_index["path"].map(os.path.exists).astype(bool)
existing_paths = master_index.loc[_exists_mask, "path"].tolist()
n_existing     = len(existing_paths)
n_missing      = len(master_index) - n_existing

print(f"  Total   : {len(master_index):,}")
print(f"  [OK] Found  : {n_existing:,}")

if n_missing > 0:
    print(f"  [!] Missing : {n_missing:,}")
    for p in master_index.loc[~_exists_mask, "path"].head(3):
        print(f"       {p}")
    if n_existing == 0:
        print("  [X] Zero files accessible - check your dataset folder.")
        _checks_passed = False
    else:
        print(f"  [!] Trimming master_index to {n_existing:,} accessible rows.")
        master_index = master_index[_exists_mask].copy()
else:
    print("  [OK] All files found")

# ---- 3. .npy format spot-check -----------------------------------------------
print("\n[3/4] .npy Format Spot-Check")
if existing_paths:
    _s = np.load(existing_paths[0])
    print(f"  Shape      : {_s.shape}")
    print(f"  dtype      : {_s.dtype}")
    print(f"  Value range: [{_s.min():.3f}, {_s.max():.3f}]")
    print(f"  {'[OK] Shape compatible' if _s.shape in [(64,64,3),(64,64),(64,64,1)] else '[!] Unexpected shape - preprocess_npy will attempt to handle it'}")
    # preprocess_npy treats negative values as already ImageNet-normalised.
    if _s.min() < 0:
        print("  [OK] Already ImageNet-normalised (negative values present)")
    elif _s.max() <= 1.0:
        print("  [OK] Normalised to [0,1]")
    else:
        print("  [!] Values in [0,255] - will normalise during loading")
else:
    print("  [X] No files to spot-check")
    _checks_passed = False

# ---- 4. Fold leakage & coverage check -----------------------------------------
print("\n[4/4] Fold Splits & Leakage Check")
_leakage_found = False
for fold_num in range(1, N_FOLDS + 1):
    fold_data = fold_splits[fold_splits["fold"] == fold_num]
    train_ids = set(fold_data[fold_data["split"] == "train"]["sample_id"])
    val_ids   = set(fold_data[fold_data["split"] == "val"]["sample_id"])
    overlap   = train_ids & val_ids
    train_pat = len(master_index[master_index["sample_id"].isin(train_ids)])
    val_pat   = len(master_index[master_index["sample_id"].isin(val_ids)])
    status    = "[OK]" if not overlap else "[X] LEAKAGE DETECTED"
    print(f"  Fold {fold_num}: {len(train_ids):3d} train samples ({train_pat:,} patches) | "
          f"{len(val_ids):3d} val samples ({val_pat:,} patches)  {status}")
    if overlap:
        print(f"    Overlapping IDs: {list(overlap)[:5]}")
        _leakage_found = True

if _leakage_found:
    print("  [X] Leakage detected - set FORCE_REGENERATE_FOLDS=True in Cell 5 and re-run.")
    _checks_passed = False

# A stale fold_splits.csv (e.g. master_index was regenerated, or ids were
# parsed with a different dtype) silently drops samples from training.
_index_ids  = set(master_index["sample_id"])
_split_ids  = set(fold_splits["sample_id"])
_unassigned = _index_ids - _split_ids
_fold_ids   = {int(f) for f in fold_splits["fold"].unique()}
if _fold_ids != set(range(1, N_FOLDS + 1)):
    print(f"  [!] fold_splits.csv has folds {sorted(_fold_ids)} but N_FOLDS = {N_FOLDS} - "
          "set FORCE_REGENERATE_FOLDS=True in Cell 5 and re-run.")
if _index_ids and not (_index_ids & _split_ids):
    print("  [X] No sample_id from master_index appears in fold_splits.csv - "
          "set FORCE_REGENERATE_FOLDS=True in Cell 5 and re-run.")
    _checks_passed = False
elif _unassigned:
    print(f"  [!] {len(_unassigned):,} samples are in master_index but in no fold "
          f"(e.g. {sorted(map(str, _unassigned))[:5]}) - they will not be used.")

# ---- Sample visualisation ----------------------------------------------------
print("\n[Visual] Sample Patches per Class")
_existing_set = set(existing_paths)
try:
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, cls in zip(axes, ["PF", "CF", "NF"]):
        cls_paths = master_index[master_index["class"] == cls]["path"].tolist()
        hit = next((p for p in cls_paths if p in _existing_set), None)
        if hit:
            img = np.load(hit).astype(np.float32)
            # Squeeze (H, W, 1) so it is replicated to RGB like other grayscale patches.
            if img.ndim == 3 and img.shape[-1] == 1:
                img = img[..., 0]
            if img.ndim == 2:
                img = np.stack([img] * 3, axis=-1)
            img = img[..., :3]
            if img.min() < 0:
                # ImageNet-normalised patch: min-max stretch purely for display.
                img = (img - img.min()) / max(float(img.max() - img.min()), 1e-8)
            elif img.max() > 1.0:
                img = img / 255.0
            ax.imshow(np.clip(img, 0, 1))
            ax.set_title(cls, fontsize=13, fontweight='bold')
        else:
            ax.set_title(f"{cls} (no file)")
        ax.axis("off")
    plt.suptitle("One sample patch per class", y=1.02)
    plt.tight_layout()
    vis_path = OUTPUT_DIR / "verification_samples.png"
    plt.savefig(vis_path, dpi=150, bbox_inches="tight")
    plt.show()
    print(f"  Saved to {vis_path}")
except Exception as e:
    # The preview is cosmetic; never let it block training.
    print(f"  [!] Could not render visualisation: {e}")

# ---- Summary -----------------------------------------------------------------
print("\n" + "=" * 60)
if _checks_passed:
    print("[OK] ALL CHECKS PASSED - ready to train!")
else:
    print("[X]  SOME CHECKS FAILED - fix the issues above before continuing.")
    raise RuntimeError("Data verification failed. See output above.")
print("=" * 60)

---
## Cell 7 - Data Loading & Preprocessing Utilities

In [None]:
def load_fold_data(fold_num: int):
    """
    Split ``master_index`` into ``(train_df, val_df)`` for one fold.

    Membership is decided per ``sample_id`` using the global ``fold_splits``
    table, so every patch of a sample follows its sample into the same split.
    Both returned frames are independent copies of the matching
    ``master_index`` rows.
    """
    in_fold = fold_splits.loc[fold_splits["fold"] == fold_num]
    ids_for = {
        split: in_fold.loc[in_fold["split"] == split, "sample_id"].tolist()
        for split in ("train", "val")
    }
    frames = {
        split: master_index[master_index["sample_id"].isin(ids)].copy()
        for split, ids in ids_for.items()
    }
    train_df, val_df = frames["train"], frames["val"]

    print(f"  Fold {fold_num}: {len(ids_for['train'])} train samples ({len(train_df):,} patches) | "
          f"{len(ids_for['val'])} val samples ({len(val_df):,} patches)")
    return train_df, val_df


# ImageNet per-channel statistics (RGB), used to undo or apply mean/std normalisation.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess_npy(npy_path, *, target_size=(64, 64), imagenet_normalise=False) -> np.ndarray:
    """
    Load a .npy patch and return a (H, W, 3) float32 array ready for the model.

    Keras' EfficientNet models contain their own rescaling/normalisation layers
    and expect raw pixel values in [0, 255]. By default this function therefore
    returns [0, 255] floats. Also applying ImageNet mean/std here would
    normalise the input twice.

    Input interpretation:
      * (H, W) or (H, W, 1) grayscale  -> replicated to 3 channels
      * (H, W, C>3)                    -> first 3 channels kept (e.g. alpha dropped)
      * integer dtype                  -> 0-255 pixel values
      * float, negative values present -> assumed ImageNet-normalised; inverted
      * float, max > 1                 -> 0-255 pixel values
      * float, otherwise               -> already in [0, 1]

    Args:
        npy_path: path to the .npy file.
        target_size: (height, width) of the output; resized with cv2 if different.
        imagenet_normalise: return ImageNet mean/std-normalised values instead
            (the previous output; useful for backbones without built-in
            normalisation).

    Raises:
        ValueError: if the array cannot be interpreted as an image.
    """
    img = np.load(npy_path)
    # Integer arrays cannot hold fractional [0, 1] intensities, so they are
    # 0-255 data even when a dark patch has max <= 1.
    from_integer = np.issubdtype(img.dtype, np.integer)
    img = img.astype(np.float32)

    # ---- bring to (H, W, 3) ----
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] > 3:
        img = img[..., :3]
    if img.ndim != 3 or img.shape[-1] != 3:
        raise ValueError(f"Cannot interpret array of shape {img.shape} from {npy_path} as an image")

    # ---- map to [0, 1] ----
    if from_integer:
        img = img / 255.0
    elif img.min() < 0:
        # Stored already ImageNet-normalised: undo it so every patch shares one scale.
        img = img * _IMAGENET_STD + _IMAGENET_MEAN
    elif img.max() > 1.0:
        img = img / 255.0

    # ---- resize (also for pre-normalised patches, so batch shapes always agree) ----
    height, width = target_size
    if img.shape[:2] != (height, width):
        img = cv2.resize(img, (width, height))  # cv2 dsize is (width, height)

    if imagenet_normalise:
        return ((img - _IMAGENET_MEAN) / _IMAGENET_STD).astype(np.float32)
    return (img * 255.0).astype(np.float32)


print("[OK] Data utilities defined")

---
## Cell 8 - Generator & Model

In [None]:
class NumpyDataGenerator(keras.utils.Sequence):
    """
    On-the-fly .npy data generator for ``model.fit`` / ``model.predict``.

    Each batch loads its patches with ``preprocess_npy`` (column ``path``) and
    one-hot encodes ``label_stage1`` into ``NUM_CLASSES`` columns.

    Args:
        dataframe: rows to serve; needs ``path`` and ``label_stage1`` columns.
        batch_size: patches per batch; the last batch may be smaller.
        shuffle: shuffle the row order on construction and after every epoch.
            Keep False for prediction so outputs line up with ``dataframe``.
        seed: optional seed for a generator-private RNG. ``None`` (default)
            shuffles with NumPy's global RNG, as before.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataframe: pd.DataFrame, batch_size: int = 32,
                 shuffle: bool = True, seed=None):
        # Keras 3's PyDataset base class expects its __init__ to run (it warns
        # otherwise); the call is harmless on older tf.keras.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.df         = dataframe.reset_index(drop=True)
        self.batch_size = batch_size
        self.shuffle    = shuffle
        self.n          = len(self.df)
        # Private RNG only when seeded, so unseeded behaviour is unchanged.
        self._rng       = np.random.default_rng(seed) if seed is not None else None
        self.on_epoch_end()

    def __len__(self) -> int:
        # Number of batches, including a final partial batch.
        return int(np.ceil(self.n / self.batch_size))

    def __getitem__(self, index: int):
        start = index * self.batch_size
        end   = min(start + self.batch_size, self.n)
        rows  = self.df.iloc[self.indices[start:end]]
        X     = np.array([preprocess_npy(p) for p in rows["path"]])
        y     = keras.utils.to_categorical(rows["label_stage1"].values, num_classes=NUM_CLASSES)
        return X, y

    def on_epoch_end(self):
        """Rebuild the row order, shuffling it when ``shuffle`` is True."""
        self.indices = np.arange(self.n)
        if self.shuffle:
            if self._rng is not None:
                self._rng.shuffle(self.indices)
            else:
                np.random.shuffle(self.indices)

    def reset(self):
        """Re-run ``on_epoch_end``; with ``shuffle=False`` this restores row order."""
        self.on_epoch_end()


def build_model(input_shape=INPUT_SHAPE, num_classes=NUM_CLASSES, learning_rate=None) -> keras.Model:
    """
    Build and compile an EfficientNetB3 backbone with a small classification head.

    Head: global-average-pooled backbone features -> Dropout(0.3) ->
    Dense(128, relu) -> Dropout(0.2) -> Dense(num_classes, softmax).
    Compiled with Adam and categorical cross-entropy, reporting accuracy.

    The backbone is frozen only when ImageNet weights load. If they cannot be
    loaded, it is randomly initialised and left trainable: freezing random
    weights would leave the head learning on fixed random features.

    Args:
        input_shape: shape of one input patch.
        num_classes: width of the softmax output.
        learning_rate: Adam learning rate; ``None`` uses ``LEARNING_RATE``.

    Returns:
        A compiled ``keras.Model``.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE

    try:
        base = EfficientNetB3(include_top=False, weights="imagenet",
                              input_shape=input_shape, pooling="avg")
        pretrained = True
    except Exception as e:
        # Download / cache / file errors all land here; fall back instead of aborting.
        print(f"  Could not load ImageNet weights, using random init: {e}")
        print("  [!] Backbone left trainable and trained from scratch (slower).")
        base = EfficientNetB3(include_top=False, weights=None,
                              input_shape=input_shape, pooling="avg")
        pretrained = False

    # Transfer learning: freeze the backbone only when it carries pretrained weights.
    base.trainable = not pretrained

    inputs  = layers.Input(shape=input_shape)
    # training=False keeps a frozen pretrained backbone's BatchNorm layers in
    # inference mode; a from-scratch backbone must update its BatchNorm statistics.
    x       = base(inputs, training=False) if pretrained else base(inputs)
    x       = layers.Dropout(0.3)(x)
    x       = layers.Dense(128, activation="relu")(x)
    x       = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs, outputs, name="EfficientNetB3_stage1")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


# Build a throwaway model purely to print its architecture; no reference is kept.
build_model().summary()

---
## Cell 9 - Evaluation Helper

In [None]:
def evaluate_predictions(y_true, y_pred_proba, threshold=THRESHOLD, class_names=CLASS_NAMES):
    """
    Print and return binary (Stage-1) classification metrics.

    Args:
        y_true: ground-truth labels, 0 (No Ferning) or 1 (Ferning).
        y_pred_proba: either an (n, k) array of class probabilities, whose
            column ``POSITIVE_CLASS_INDEX`` is the Ferning class, or a
            length-n vector (or (n, 1) column) of Ferning probabilities.
        threshold: a prediction is positive when its Ferning probability is
            >= this value. Applied to both input forms.
        class_names: only used in the printed header.

    Returns:
        dict with sensitivity, specificity, balanced_accuracy, accuracy,
        precision, f1_score, auc (NaN when undefined) and the raw
        tp / tn / fp / fn counts.

    Raises:
        ValueError: if ``y_true`` contains labels other than 0/1, or its length
            differs from the number of predictions.
    """
    y_true       = np.asarray(y_true).ravel()
    y_pred_proba = np.asarray(y_pred_proba)

    if y_pred_proba.ndim > 1 and y_pred_proba.shape[1] > 1:
        y_pred_pos = y_pred_proba[:, POSITIVE_CLASS_INDEX]
    else:
        y_pred_pos = y_pred_proba.ravel()
    # Threshold the positive-class probability in every case, so THRESHOLD is
    # honoured for softmax output too (argmax would silently ignore it).
    y_pred_cls = (y_pred_pos >= threshold).astype(int)

    if len(y_true) != len(y_pred_cls):
        raise ValueError(f"Got {len(y_true)} labels but {len(y_pred_cls)} predictions")
    if not np.isin(y_true, [0, 1]).all():
        raise ValueError(f"y_true must contain only 0/1 labels, got {np.unique(y_true)}")

    # labels=[0, 1] always yields a 2x2 matrix, even when a fold contains - or
    # the model predicts - only one class.
    cm = confusion_matrix(y_true, y_pred_cls, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    total        = tp + tn + fp + fn
    sensitivity  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity  = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    precision    = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    accuracy     = (tp + tn) / total if total > 0 else 0.0
    balanced_acc = 0.5 * (sensitivity + specificity)
    f1           = (2 * precision * sensitivity / (precision + sensitivity)
                    if (precision + sensitivity) > 0 else 0.0)
    try:
        auc = roc_auc_score(y_true, y_pred_pos)
    except ValueError:
        # e.g. undefined when y_true contains a single class
        auc = np.nan

    print("\n" + "=" * 60)
    print("EVALUATION REPORT")
    print("=" * 60)
    print(f"Classes           : {class_names}")
    print(f"\nConfusion matrix:\n{cm}")
    print(f"\nSensitivity (TPR) : {sensitivity:.4f}")
    print(f"Specificity (TNR) : {specificity:.4f}")
    print(f"Balanced accuracy : {balanced_acc:.4f}")
    print(f"Accuracy          : {accuracy:.4f}")
    print(f"Precision (PPV)   : {precision:.4f}")
    print(f"F1-score          : {f1:.4f}")
    print(f"AUC-ROC           : {auc:.4f}" if not np.isnan(auc) else "AUC-ROC           : N/A")

    return dict(sensitivity=sensitivity, specificity=specificity,
                balanced_accuracy=balanced_acc, accuracy=accuracy,
                precision=precision, f1_score=f1, auc=auc,
                tp=int(tp), tn=int(tn), fp=int(fp), fn=int(fn))


print("[OK] Evaluation helper defined")

---
## Cell 10 - Training Loop (Cross-Validation)

In [None]:
all_results = []

for fold_num in range(1, N_FOLDS + 1):
    print("\n" + "=" * 60)
    print(f"FOLD {fold_num} / {N_FOLDS}")
    print("=" * 60)

    fold_output_dir = OUTPUT_DIR / f"fold{fold_num}"
    fold_output_dir.mkdir(parents=True, exist_ok=True)

    train_df, val_df = load_fold_data(fold_num)

    # An empty split (e.g. fold_splits.csv out of sync with master_index) would
    # otherwise fail later inside compute_class_weight / model.fit with a far
    # less helpful error.
    if train_df.empty or val_df.empty:
        print(f"  [!] Fold {fold_num} has an empty train or val split - skipping. "
              "Check that fold_splits.csv matches master_index (Cells 5-6).")
        continue

    # Class weights to handle imbalance
    y_train       = train_df["label_stage1"].values
    classes       = np.unique(y_train)
    cw_arr        = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
    class_weights = {int(c): float(w) for c, w in zip(classes, cw_arr)}
    print(f"  Class weights: {class_weights}")

    # Fresh Keras state per fold, then seed Python / NumPy / TensorFlow so the
    # training-generator shuffle, dropout masks and head initialisation repeat
    # across runs (GPU kernels may still introduce some nondeterminism).
    tf.keras.backend.clear_session()
    keras.utils.set_random_seed(RANDOM_SEED + fold_num)

    train_gen = NumpyDataGenerator(train_df, batch_size=BATCH_SIZE, shuffle=True)
    val_gen   = NumpyDataGenerator(val_df,   batch_size=BATCH_SIZE, shuffle=False)

    model = build_model()

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=str(fold_output_dir / "best_model.h5"),
            monitor="val_accuracy", mode="max",
            save_best_only=True, verbose=1,
        ),
        keras.callbacks.EarlyStopping(
            monitor="val_accuracy", patience=5,
            restore_best_weights=True, verbose=1,
        ),
        keras.callbacks.CSVLogger(str(fold_output_dir / "history.csv")),
    ]

    start_time = datetime.now()
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        class_weight=class_weights,
        callbacks=callbacks,
        verbose=1,
    )
    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"  Training time: {elapsed:.0f}s")

    # val_gen is unshuffled, so predictions line up row-for-row with val_df.
    val_gen.reset()
    y_pred_proba = model.predict(val_gen, verbose=0)
    y_true       = val_df["label_stage1"].values

    metrics = evaluate_predictions(y_true, y_pred_proba)
    metrics["fold"]          = fold_num
    metrics["train_patches"] = int(len(train_df))
    metrics["val_patches"]   = int(len(val_df))
    metrics["train_time_s"]  = int(elapsed)
    all_results.append(metrics)

    # Save incrementally so a crash doesn't lose earlier folds
    pd.DataFrame(all_results).to_csv(OUTPUT_DIR / "all_results.csv", index=False)
    print(f"  Results saved -> {OUTPUT_DIR / 'all_results.csv'}")

---
## Cell 11 - Results Summary

In [None]:
# With no completed fold, the column selection below would fail with a
# confusing KeyError, so stop with an explicit message instead.
if not all_results:
    raise RuntimeError("No fold results to summarise - run Cell 10 and check its output.")

results_df  = pd.DataFrame(all_results)
metric_cols = ["sensitivity", "specificity", "balanced_accuracy", "accuracy", "f1_score", "auc"]
summary     = results_df[metric_cols].agg(["mean", "std"]).T   # Std is NaN with a single fold
summary.columns = ["Mean", "Std"]

print("\n" + "=" * 60)
print("CROSS-VALIDATION SUMMARY")
print("=" * 60)
print(summary.to_string(float_format="{:.4f}".format))
print("\nPer-fold results:")
display(results_df[["fold"] + metric_cols])