In [1]:
#!/usr/bin/env python3
"""
wavenet_overfit_holdout.py

Replaces k-fold cross validation with:
  (A) a time-series-safe HOLDOUT split (contiguous blocks of window files)
  (B) an explicit overfitting test on a tiny fixed dataset (no resampling)

Expected datasets:
  data_windows_fast/12m, data_windows_fast/6m, data_windows_fast/3m
Window files: .parquet, .csv or .csv.gz
Each file must contain the VALUE_COL column (default 'max'); a 'utc'
column is optional and, when present, is used to sort rows by time.

Outputs:
  trained_wavenet_holdout_tests/<months>m/
    - overfit_test_summary.json
    - holdout_summary.json
    - checkpoints_overfit/{best.pt,last.pt}
    - checkpoints_holdout/{best.pt,last.pt}
    - plots/*.png
"""

from __future__ import annotations

import time
import json
import logging
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

from wavenet_anna import (
    EEGWaveNetCategorical,
    RandomWaveNetSegments,
    dilations_1s_context,
    receptive_field,
    mu_law_decode_np,
    mu_law_encode_np,
)

# ----------------------------
# Global settings
# ----------------------------
# Use the CUDA device when one is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Root folder holding one sub-folder per window length ("12m", "6m", "3m").
WINDOW_ROOT = Path("data_windows_fast")
# Window lengths (months) processed by main(), in this order.
WINDOW_MONTHS_LIST = (12, 6, 3)
# Column of each window file that is modelled.
VALUE_COL = "max"
# Output root; one "<months>m" sub-folder is created per run.
RUNS_DIR = Path("trained_wavenet_holdout_tests")
# When True, window files are also searched for in sub-directories.
RECURSIVE = True

# ----------------------------
# Config
# ----------------------------
@dataclass
class Cfg:
    """Hyper-parameters shared by the overfit test and the holdout run.

    ``receptive_field`` and ``seq_len`` are placeholders: the per-month runner
    overwrites them from the dilation stack. ``amp_min``/``amp_max`` are also
    overwritten by the runner from a percentile of the training data. Only
    ``amp_max`` is passed to the mu-law helpers in this file.
    """

    # --- model / quantisation ---
    n_bins: int = 256            # number of mu-law classes (mu = n_bins - 1)
    kernel_size: int = 2
    n_filters: int = 16

    # --- optimisation ---
    lr: float = 0.001
    lr_decay_gamma: float = 0.99           # ExponentialLR factor, stepped once per epoch
    batch_size: int = 32
    epochs: int = 200
    train_samples_per_epoch: int = 6000
    val_samples_fixed: int = 1500
    early_stop_patience: int = 10          # epochs without val improvement before stopping
    grad_clip: float = 1.0                 # max grad norm; 0 / negative disables clipping

    # --- mu-law amplitude range ---
    amp_min: float = -3.0
    amp_max: float = 3.0

    # --- sequence geometry ---
    receptive_field: int = 1024
    seq_len: int = 0


# ----------------------------
# Logging
# ----------------------------
def setup_logger(out_dir: Path, name: str) -> logging.Logger:
    """Return logger *name* writing to stderr and to ``out_dir/run.log``.

    *out_dir* is created if needed. Loggers are process-wide singletons per
    name, so calling this again with the same *name* replaces the previous
    handlers. They are closed first, which releases the old ``run.log`` file
    descriptor instead of leaking it on every re-run.

    Args:
        out_dir: Directory that receives ``run.log`` (opened in append mode).
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The configured logger at INFO level.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    lg = logging.getLogger(name)
    lg.setLevel(logging.INFO)

    # handlers.clear() alone would leave FileHandlers (and their files) open.
    for old in list(lg.handlers):
        lg.removeHandler(old)
        old.close()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")

    ch = logging.StreamHandler()
    ch.setFormatter(fmt)
    lg.addHandler(ch)

    fh = logging.FileHandler(out_dir / "run.log", encoding="utf-8")
    fh.setFormatter(fmt)
    lg.addHandler(fh)
    return lg


# ----------------------------
# Window loading helpers
# ----------------------------
def iter_window_files(ds_dir: Path) -> List[Path]:
    """List every ``.parquet``, ``.csv.gz`` and ``.csv`` file under *ds_dir*.

    Sub-directories are searched too when the module flag ``RECURSIVE`` is
    set. The result is de-duplicated and sorted.
    """
    prefix = "**/" if RECURSIVE else ""
    found = {
        hit
        for ext in ("*.parquet", "*.csv.gz", "*.csv")
        for hit in ds_dir.glob(prefix + ext)
    }
    return sorted(found)

def read_window_file(path: Path) -> pd.DataFrame:
    """Load one window file into a DataFrame, dispatching on its extension.

    Extension matching is case-insensitive: ``.csv.gz`` is read as
    gzip-compressed CSV, ``.parquet`` with ``pd.read_parquet`` and plain
    ``.csv`` with ``pd.read_csv``.

    Raises:
        ValueError: for any other extension.
    """
    if path.name.lower().endswith(".csv.gz"):
        return pd.read_csv(path, compression="gzip")
    reader = {".parquet": pd.read_parquet, ".csv": pd.read_csv}.get(path.suffix.lower())
    if reader is None:
        raise ValueError(f"Unsupported window file: {path}")
    return reader(path)

def load_series_from_window_file(path: Path, value_col: str = "max") -> Optional[np.ndarray]:
    """Read column *value_col* of one window file as a float64 1-D array.

    If a ``utc`` column is present, rows whose timestamp cannot be parsed are
    dropped and the rest are ordered by time before the values are taken.
    Non-numeric values are discarded.

    Returns:
        The series, or ``None`` if the column is missing or fewer than 10
        numeric samples remain.
    """
    frame = read_window_file(path)
    if "utc" in frame.columns:
        frame["utc"] = pd.to_datetime(frame["utc"], errors="coerce")
        frame = frame.dropna(subset=["utc"]).sort_values("utc")
    if value_col not in frame.columns:
        return None
    numeric = pd.to_numeric(frame[value_col], errors="coerce").dropna()
    series = numeric.to_numpy(dtype=np.float64)
    return None if series.size < 10 else series

def files_to_series_list(files: List[Path], min_len: int, value_col: str) -> List[np.ndarray]:
    """Load each file and keep the series that have at least *min_len* samples.

    Files for which the loader returns ``None`` (missing column or too few
    samples) are skipped. Input order is preserved, and each kept series is
    returned as a float64 copy.
    """
    loaded = (load_series_from_window_file(p, value_col=value_col) for p in files)
    return [
        series.astype(np.float64)
        for series in loaded
        if series is not None and len(series) >= min_len
    ]


# ----------------------------
# Time-series-safe holdout split (by contiguous file blocks)
# ----------------------------
def holdout_split_files(files_sorted: List[Path], train_frac: float, val_frac: float) -> Tuple[List[Path], List[Path], List[Path]]:
    """Cut an ordered file list into contiguous ``[train | val | test]`` blocks.

    No shuffling is done, so the split is only time-safe if *files_sorted* is
    already in chronological order. Block sizes are ``round(n * frac)``
    (Python's round-half-to-even), then clamped so every block keeps at least
    one file. The test block receives all remaining files.

    Args:
        files_sorted: Files in the order they should be split.
        train_frac: Target fraction of files for the train block.
        val_frac: Target fraction of files for the validation block.

    Returns:
        ``(train_files, val_files, test_files)``.

    Raises:
        ValueError: if fewer than 3 files are given, because three non-empty
            blocks are then impossible.
    """
    n = len(files_sorted)
    if n < 3:
        raise ValueError(f"need at least 3 files for a train/val/test split, got {n}")
    # Leave at least one file each for val and test.
    n_train = max(1, min(int(round(n * train_frac)), n - 2))
    # Leave at least one file for test.
    n_val = max(1, min(int(round(n * val_frac)), n - n_train - 1))
    train_files = files_sorted[:n_train]
    val_files = files_sorted[n_train:n_train + n_val]
    test_files = files_sorted[n_train + n_val:]
    return train_files, val_files, test_files


# ----------------------------
# Fixed dataset for overfitting test (NO resampling)
# ----------------------------
class FixedWaveNetSegments(Dataset):
    """Deterministic WaveNet segments taken at pre-chosen offsets.

    Item *i* always yields the same window: ``fixed_pairs[i]`` is a
    ``(series_index, start)`` pair selecting
    ``epochs_1d[series_index][start:start + seq_len]``. The window is mu-law
    encoded (``mu = n_bins - 1``, scaled by ``amp_max``) into integer targets.
    The model input is the targets shifted right by one step, with 0 in the
    first position (teacher forcing). Both are returned as int64 tensors.

    ``amp_min`` is stored but not used by this class.
    """

    def __init__(
        self,
        epochs_1d: List[np.ndarray],
        seq_len: int,
        amp_min: float,
        amp_max: float,
        n_bins: int,
        fixed_pairs: List[Tuple[int, int]],
    ):
        self.epochs = epochs_1d
        self.seq_len = int(seq_len)
        self.amp_min, self.amp_max = float(amp_min), float(amp_max)
        self.n_bins = int(n_bins)
        # Private copy so later changes to the caller's list do not leak in.
        self.pairs = list(fixed_pairs)

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx: int):
        series_idx, start = self.pairs[idx]
        window = self.epochs[series_idx][start:start + self.seq_len].astype(np.float64)
        target = mu_law_encode_np(window, mu=self.n_bins - 1, amp_max=self.amp_max)
        # Position t of the input holds the target of step t-1; step 0 gets 0.
        shifted = np.roll(target, 1)
        shifted[0] = 0
        return (
            torch.from_numpy(shifted.astype(np.int64)),
            torch.from_numpy(target.astype(np.int64)),
        )


def sample_fixed_pairs(
    epochs_1d: List[np.ndarray],
    seq_len: int,
    n_samples: int,
    amp_min: float,
    amp_max: float,
    n_bins: int,
    seed: int = 0,
) -> List[Tuple[int, int]]:
    """Draw *n_samples* segment offsets once, reproducibly from *seed*.

    The sampling is delegated to ``RandomWaveNetSegments`` with
    ``np.random.default_rng(seed)``. The function returns a list copy of that
    object's ``pairs`` so they can be frozen into a ``FixedWaveNetSegments``.
    """
    generator = np.random.default_rng(seed)
    sampler = RandomWaveNetSegments(
        epochs_1d=epochs_1d,
        seq_len=seq_len,
        n_samples=n_samples,
        amp_min=amp_min,
        amp_max=amp_max,
        n_bins=n_bins,
        rng=generator,
    )
    frozen = list(sampler.pairs)
    return frozen


# ----------------------------
# Training loop with history (so you can see overfit curves)
# ----------------------------
def _masked_ce(logits: torch.Tensor, targets: torch.Tensor, rf: int, n_bins: int) -> torch.Tensor:
    """Mean cross-entropy over time steps ``t >= rf`` (logits [B,K,T], targets [B,T])."""
    logits_v = logits[:, :, rf:]
    y_v = targets[:, rf:]
    return F.cross_entropy(logits_v.permute(0, 2, 1).reshape(-1, n_bins), y_v.reshape(-1))


def _checkpoint_payload(ep, best_val, best_epoch, model, opt, sched, bad) -> dict:
    """Build the dict stored in ``best.pt`` / ``last.pt``."""
    return {"epoch": ep, "best_val": best_val, "best_epoch": best_epoch, "model": model.state_dict(),
            "opt": opt.state_dict(), "sched": sched.state_dict(), "bad": bad}


def train_with_history(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    rf: int,
    cfg: Cfg,
    device: torch.device,
    logger: logging.Logger,
    ckpt_dir: Path,
) -> dict:
    """Train *model* with Adam + ExponentialLR and early stopping on val CE.

    Each epoch does the following:
      * one pass over *train_loader*, clipping the gradient norm when
        ``cfg.grad_clip > 0``;
      * one no-grad pass over *val_loader*;
      * logs both mean cross-entropies;
      * steps the LR scheduler.

    Only time steps ``t >= rf`` are scored. ``best.pt`` is written in
    *ckpt_dir* whenever val CE improves, and ``last.pt`` after every epoch.
    Training stops after ``cfg.epochs`` epochs, or after
    ``cfg.early_stop_patience`` consecutive epochs without improvement.

    On return the model holds the best weights saved by THIS call. If val CE
    never improved (e.g. an empty or all-NaN val loader), the model keeps its
    final weights and ``best_epoch`` is 0. A ``best.pt`` left in *ckpt_dir* by
    an earlier run is deliberately not loaded in that case.

    Returns:
        ``{"history": {"epoch", "train_ce", "val_ce", "lr"} per-epoch lists,
        "best_val_ce": float, "best_epoch": int}``.
    """
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    best_path = ckpt_dir / "best.pt"
    last_path = ckpt_dir / "last.pt"

    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=float(cfg.lr_decay_gamma))

    best_val = float("inf")
    best_epoch = 0
    bad = 0
    saved_best = False  # becomes True once this call has written best.pt

    hist = {"epoch": [], "train_ce": [], "val_ce": [], "lr": []}

    for ep in range(1, cfg.epochs + 1):
        t0 = time.perf_counter()
        model.train()
        tr_losses = []

        for xb, yb in train_loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad(set_to_none=True)

            loss = _masked_ce(model(xb), yb, rf, cfg.n_bins)
            loss.backward()
            if cfg.grad_clip and cfg.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()
            tr_losses.append(float(loss.item()))

        model.eval()
        va_losses = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device)
                yb = yb.to(device)
                loss = _masked_ce(model(xb), yb, rf, cfg.n_bins)
                va_losses.append(float(loss.item()))

        tr = float(np.mean(tr_losses)) if tr_losses else float("inf")
        va = float(np.mean(va_losses)) if va_losses else float("inf")
        lr_now = float(opt.param_groups[0]["lr"])  # LR used during this epoch
        dt = time.perf_counter() - t0

        hist["epoch"].append(ep)
        hist["train_ce"].append(tr)
        hist["val_ce"].append(va)
        hist["lr"].append(lr_now)

        logger.info(f"epoch {ep:03d}/{cfg.epochs} | train_ce={tr:.6e} | val_ce={va:.6e} | lr={lr_now:.3e} | dt={dt:.2f}s")

        # NaN/inf val losses never compare as an improvement.
        improved = (va + 1e-12) < best_val
        if improved:
            best_val = va
            best_epoch = ep
            bad = 0
            torch.save(_checkpoint_payload(ep, best_val, best_epoch, model, opt, sched, bad), best_path)
            saved_best = True
        else:
            bad += 1

        sched.step()

        torch.save(_checkpoint_payload(ep, best_val, best_epoch, model, opt, sched, bad), last_path)

        if bad >= int(cfg.early_stop_patience):
            logger.info(f"early stop @ epoch={ep} (patience={cfg.early_stop_patience}) best_epoch={best_epoch} best_val={best_val:.6e}")
            break

    # Reload the best weights only if they were produced by this run; a stale
    # best.pt from an earlier run could belong to a different model.
    if saved_best:
        best = torch.load(best_path, map_location="cpu", weights_only=False)
        model.load_state_dict(best["model"])

    return {"history": hist, "best_val_ce": float(best_val), "best_epoch": int(best_epoch)}


# ----------------------------
# Evaluation helpers (teacher-forced CE + expected-value metrics)
# ----------------------------
@torch.no_grad()
def eval_teacher_forced_ce(model: EEGWaveNetCategorical, loader: DataLoader, cfg: Cfg, device: torch.device, max_batches: int = 300) -> float:
    """Average teacher-forced cross-entropy of *model* over *loader*.

    Only time steps at or after ``cfg.receptive_field`` are scored, and at
    most *max_batches* batches are read. The result is the unweighted mean of
    the per-batch losses, or NaN if no batch was evaluated.
    """
    model.eval()
    skip = int(cfg.receptive_field)
    per_batch = []
    for batch_no, (inputs, targets) in enumerate(loader):
        if batch_no >= max_batches:
            break
        inputs = inputs.to(device)
        targets = targets.to(device)
        scores = model(inputs)[:, :, skip:]
        wanted = targets[:, skip:]
        flat_scores = scores.permute(0, 2, 1).reshape(-1, cfg.n_bins)
        batch_loss = F.cross_entropy(flat_scores, wanted.reshape(-1))
        per_batch.append(float(batch_loss.item()))
    if not per_batch:
        return float("nan")
    return float(np.mean(per_batch))

@torch.no_grad()
def eval_teacher_forced_metrics(model: EEGWaveNetCategorical, loader: DataLoader, cfg: Cfg, device: torch.device, max_batches: int = 300) -> dict:
    """Regression metrics of the model's expected amplitude (teacher forced).

    For every scored time step (``t >= cfg.receptive_field``) the prediction
    is the expectation of the bin amplitudes under the softmax distribution.
    The bin amplitudes come from ``mu_law_decode_np`` applied to each bin
    index, not from linearly spaced centres. The reference is the target bin
    index decoded the same way, i.e. the quantised signal. At most
    *max_batches* batches are read.

    Returns:
        ``{"MSE", "MAE", "Corr", "N"}``. ``Corr`` is NaN when undefined (fewer
        than two points or a constant series). If no batch was evaluated,
        MSE/MAE/Corr are NaN and ``N`` is 0.0.
    """
    model.eval()
    skip = int(cfg.receptive_field)
    mu = cfg.n_bins - 1

    amplitudes = mu_law_decode_np(np.arange(cfg.n_bins), mu=mu, amp_max=cfg.amp_max).astype(np.float64)
    # Shaped [1, K, 1] so it broadcasts over batch and time.
    amp_weights = torch.tensor(amplitudes, device=device, dtype=torch.float32).view(1, -1, 1)

    truths: List[np.ndarray] = []
    preds: List[np.ndarray] = []
    for batch_no, (inputs, targets) in enumerate(loader):
        if batch_no >= max_batches:
            break
        inputs = inputs.to(device)
        targets = targets.to(device)

        scores = model(inputs)[:, :, skip:]     # [B, K, T - rf]
        wanted = targets[:, skip:]              # [B, T - rf]
        expected = (torch.softmax(scores, dim=1) * amp_weights).sum(dim=1)   # [B, T - rf]

        decoded = mu_law_decode_np(wanted.detach().cpu().numpy(), mu=mu, amp_max=cfg.amp_max)
        truths.append(decoded.reshape(-1))
        preds.append(expected.detach().cpu().numpy().reshape(-1))

    if len(truths) == 0:
        return {"MSE": float("nan"), "MAE": float("nan"), "Corr": float("nan"), "N": 0.0}

    yt = np.concatenate(truths)
    yp = np.concatenate(preds)
    err = yt - yp
    corr_defined = yt.size > 1 and np.std(yt) > 0 and np.std(yp) > 0
    return {
        "MSE": float(np.mean(err ** 2)),
        "MAE": float(np.mean(np.abs(err))),
        "Corr": float(np.corrcoef(yt, yp)[0, 1]) if corr_defined else float("nan"),
        "N": float(yt.size),
    }


# ----------------------------
# Amp range helper (important)
# ----------------------------
def auto_amp_from_train(tr_list: List[np.ndarray], pct: float = 99.9, sample_per_series: int = 20000) -> float:
    """Estimate a symmetric mu-law amplitude bound from training series.

    Collects ``|x|`` from every series and returns its *pct*-th percentile.
    The caller uses that value as ``amp_max`` and its negation as
    ``amp_min``.

    Each series contributes at most *sample_per_series* values, taken at an
    even stride over its whole length. Long files are thus represented
    end-to-end rather than only by their first samples. Series no longer than
    *sample_per_series* are used in full.

    Non-finite values are ignored. If no finite data remain, or the
    percentile is not strictly positive (e.g. all-zero data), 1.0 is
    returned. NOTE(review): the fallback assumes the mu-law helpers scale by
    ``amp_max``, so a zero/NaN bound would be degenerate; confirm in
    wavenet_anna.

    Args:
        tr_list: Training series (1-D arrays).
        pct: Percentile in [0, 100] of the absolute values.
        sample_per_series: Cap on values taken from each series; values
            below 1 are treated as 1.

    Returns:
        A strictly positive, finite amplitude bound.
    """
    cap = max(1, int(sample_per_series))
    xs = []
    for ep in tr_list:
        n = len(ep)
        if n == 0:
            continue
        stride = -(-n // cap)  # ceil(n / cap) -> at most `cap` samples
        xs.append(np.abs(np.asarray(ep[::stride], dtype=np.float64)))
    if not xs:
        return 1.0
    xcat = np.concatenate(xs)
    xcat = xcat[np.isfinite(xcat)]
    if xcat.size == 0:
        return 1.0
    amp = float(np.percentile(xcat, pct))
    return amp if np.isfinite(amp) and amp > 0 else 1.0


# ----------------------------
# Main per-month runner (+ private helpers)
# ----------------------------
def _plot_ce_curves(history: dict, train_label: str, val_label: str, title: str, out_path: Path) -> None:
    """Save a train/val cross-entropy-vs-epoch line plot to *out_path*."""
    plt.figure(figsize=(7, 4))
    plt.plot(history["epoch"], history["train_ce"], marker="o", label=train_label)
    plt.plot(history["epoch"], history["val_ce"], marker="o", label=val_label)
    plt.xlabel("epoch")
    plt.ylabel("cross-entropy")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.close()


def _fixed_eval_loader(series: List[np.ndarray], cfg: Cfg, n_samples: int, seed: int) -> DataLoader:
    """Build an unshuffled loader over *n_samples* segments of *series*, sampled once with *seed*."""
    pairs = sample_fixed_pairs(series, cfg.seq_len, n_samples=n_samples,
                               amp_min=cfg.amp_min, amp_max=cfg.amp_max, n_bins=cfg.n_bins, seed=seed)
    ds = FixedWaveNetSegments(series, cfg.seq_len, cfg.amp_min, cfg.amp_max, cfg.n_bins, pairs)
    return DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)


def run_one_months(
    months: int,
    train_frac: float = 0.70,
    val_frac: float = 0.15,
    auto_amp_pct: float = 99.9,
    overfit_n_files: int = 2,
    overfit_fixed_segments: int = 1500,
    overfit_epochs: int = 250,
    holdout_epochs: int = 200,
):
    """Run the overfit test and the holdout experiment for one window length.

    Steps:
      1. List the window files under ``WINDOW_ROOT/<months>m``, sort them by
         path string and cut them into contiguous train/val/test blocks.
         NOTE(review): the split is only time-safe if path order equals
         chronological order; confirm the file naming guarantees this.
      2. Load the series, dropping any shorter than ``seq_len + 10``. Set the
         mu-law amplitude range from a percentile of the TRAIN data.
      3. (B) Overfit test: train a fresh model on a fixed set of segments
         from the first ``overfit_n_files`` train files. Validate on that
         same set, then also report CE on the real val split.
      4. (A) Holdout run: train a fresh model on segments drawn by
         ``RandomWaveNetSegments`` from the train split, with a fixed val set.
         Then evaluate CE and expected-amplitude metrics on fixed val and
         test sets.

    Writes ``overfit_test_summary.json``, ``holdout_summary.json``, both
    checkpoint folders, the CE plots and ``run.log`` under
    ``RUNS_DIR/<months>m``.

    Returns:
        ``{"overfit": overfit_summary, "holdout": holdout_summary}``.

    Raises:
        RuntimeError: if there are fewer than 5 window files, or no usable
            series in a split or in the tiny overfit subset.
    """
    run_dir = RUNS_DIR / f"{months}m"
    ckpt_overfit = run_dir / "checkpoints_overfit"
    ckpt_holdout = run_dir / "checkpoints_holdout"
    plots_dir = run_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    logger = setup_logger(run_dir, name=f"wavenet_holdout_{months}m")

    # list & split files (contiguous blocks); re-sort by path string so the
    # order is deterministic and independent of Path comparison rules
    ds_dir = WINDOW_ROOT / f"{months}m"
    all_files = sorted(iter_window_files(ds_dir), key=str)
    if len(all_files) < 5:
        raise RuntimeError(f"Not enough window files in {ds_dir} (found {len(all_files)})")

    train_files, val_files, test_files = holdout_split_files(all_files, train_frac=train_frac, val_frac=val_frac)
    logger.info(f"[FILES] months={months} total={len(all_files)} train={len(train_files)} val={len(val_files)} test={len(test_files)}")

    cfg = Cfg()
    dils = dilations_1s_context()
    cfg.receptive_field = int(receptive_field(cfg.kernel_size, dils))
    cfg.seq_len = cfg.receptive_field + 120
    min_len = cfg.seq_len + 10

    # load series
    tr_list = files_to_series_list(train_files, min_len=min_len, value_col=VALUE_COL)
    va_list = files_to_series_list(val_files,   min_len=min_len, value_col=VALUE_COL)
    te_list = files_to_series_list(test_files,  min_len=min_len, value_col=VALUE_COL)
    if not tr_list or not va_list or not te_list:
        raise RuntimeError(f"[DATA] insufficient usable series after filtering: train={len(tr_list)} val={len(va_list)} test={len(te_list)}")

    # auto amp range from TRAIN only (no leakage from val/test)
    amp = auto_amp_from_train(tr_list, pct=auto_amp_pct)
    cfg.amp_max = amp
    cfg.amp_min = -amp
    logger.info(f"[AMP] auto pct={auto_amp_pct} -> amp_max={cfg.amp_max:.6e} amp_min={cfg.amp_min:.6e}")

    logger.info(f"[CFG] RF={cfg.receptive_field} seq_len={cfg.seq_len} device={DEVICE}")

    # ----------------------------
    # (B) Overfitting test: tiny fixed dataset
    # ----------------------------
    tiny_files = train_files[: max(1, min(overfit_n_files, len(train_files)))]
    tiny_list = files_to_series_list(tiny_files, min_len=min_len, value_col=VALUE_COL)
    if not tiny_list:
        raise RuntimeError("[OVERFIT] tiny_list empty; increase overfit_n_files or reduce min_len")

    cfg_overfit = Cfg(**asdict(cfg))
    cfg_overfit.epochs = int(overfit_epochs)
    cfg_overfit.early_stop_patience = max(20, cfg_overfit.early_stop_patience)  # allow real overfit
    cfg_overfit.train_samples_per_epoch = int(overfit_fixed_segments)
    cfg_overfit.val_samples_fixed = int(overfit_fixed_segments)

    # fixed pairs -> fixed dataset
    fixed_pairs = sample_fixed_pairs(
        epochs_1d=tiny_list,
        seq_len=cfg_overfit.seq_len,
        n_samples=cfg_overfit.train_samples_per_epoch,
        amp_min=cfg_overfit.amp_min,
        amp_max=cfg_overfit.amp_max,
        n_bins=cfg_overfit.n_bins,
        seed=123,
    )
    train_ds_fix = FixedWaveNetSegments(
        epochs_1d=tiny_list,
        seq_len=cfg_overfit.seq_len,
        amp_min=cfg_overfit.amp_min,
        amp_max=cfg_overfit.amp_max,
        n_bins=cfg_overfit.n_bins,
        fixed_pairs=fixed_pairs,
    )

    # With fewer segments than one batch, drop_last=True would yield zero
    # batches and the model would never be updated.
    drop_partial = len(train_ds_fix) >= cfg_overfit.batch_size
    if not drop_partial:
        logger.warning(f"[OVERFIT] only {len(train_ds_fix)} fixed segments < batch_size={cfg_overfit.batch_size}; keeping the partial batch")
    train_loader_fix = DataLoader(train_ds_fix, batch_size=cfg_overfit.batch_size, shuffle=True, num_workers=0, drop_last=drop_partial)
    # for the overfit test the val loader reads the SAME fixed dataset (measures memorization)
    val_loader_fix = DataLoader(train_ds_fix, batch_size=cfg_overfit.batch_size, shuffle=False, num_workers=0, drop_last=False)

    model_overfit = EEGWaveNetCategorical(n_bins=cfg_overfit.n_bins, n_filters=cfg_overfit.n_filters, kernel_size=cfg_overfit.kernel_size)
    logger.info(f"[OVERFIT] training on {len(tiny_files)} file(s), fixed_segments={len(train_ds_fix)} epochs={cfg_overfit.epochs}")

    out_overfit = train_with_history(
        model=model_overfit,
        train_loader=train_loader_fix,
        val_loader=val_loader_fix,
        rf=cfg_overfit.receptive_field,
        cfg=cfg_overfit,
        device=DEVICE,
        logger=logger,
        ckpt_dir=ckpt_overfit,
    )

    # CE on the fixed set (train == val) and, for comparison, on the REAL holdout val split
    val_loader_real = _fixed_eval_loader(va_list, cfg, min(1500, cfg.val_samples_fixed), seed=456)

    train_ce_fix = eval_teacher_forced_ce(model_overfit, train_loader_fix, cfg_overfit, DEVICE)
    val_ce_fix   = eval_teacher_forced_ce(model_overfit, val_loader_fix,   cfg_overfit, DEVICE)
    val_ce_real  = eval_teacher_forced_ce(model_overfit, val_loader_real,  cfg, DEVICE)

    overfit_summary = {
        "months": months,
        "tiny_files_used": [str(p) for p in tiny_files],
        "fixed_segments": int(len(train_ds_fix)),
        "best_epoch": int(out_overfit["best_epoch"]),
        "best_val_ce_fixedset": float(out_overfit["best_val_ce"]),
        "train_ce_fixedset": float(train_ce_fix),
        "val_ce_fixedset": float(val_ce_fix),
        "val_ce_real_holdout": float(val_ce_real),
        "gap_fixed_train_minus_real_val": float(train_ce_fix - val_ce_real),
        "history": out_overfit["history"],
    }

    (run_dir / "overfit_test_summary.json").write_text(json.dumps(overfit_summary, indent=2))

    _plot_ce_curves(
        overfit_summary["history"],
        "train_ce (fixed)",
        "val_ce (fixed)",
        f"{months}m OVERFIT TEST (tiny fixed set)",
        plots_dir / "overfit_ce_curves.png",
    )

    logger.info(f"[OVERFIT] train_ce_fixed={train_ce_fix:.4f} val_ce_fixed={val_ce_fix:.4f} val_ce_real={val_ce_real:.4f}")

    # ----------------------------
    # (A) Holdout training (normal training, but single split)
    # ----------------------------
    cfg_holdout = Cfg(**asdict(cfg))
    cfg_holdout.epochs = int(holdout_epochs)

    train_ds = RandomWaveNetSegments(
        epochs_1d=tr_list,
        seq_len=cfg_holdout.seq_len,
        n_samples=cfg_holdout.train_samples_per_epoch,
        amp_min=cfg_holdout.amp_min,
        amp_max=cfg_holdout.amp_max,
        n_bins=cfg_holdout.n_bins,
        rng=np.random.default_rng(1),
    )

    # fixed val for stability (pairs drawn once with seed 2)
    val_loader = _fixed_eval_loader(va_list, cfg_holdout, cfg_holdout.val_samples_fixed, seed=2)
    train_loader = DataLoader(train_ds, batch_size=cfg_holdout.batch_size, shuffle=True, num_workers=0, drop_last=True)

    model_holdout = EEGWaveNetCategorical(n_bins=cfg_holdout.n_bins, n_filters=cfg_holdout.n_filters, kernel_size=cfg_holdout.kernel_size)

    logger.info(f"[HOLDOUT] training: epochs={cfg_holdout.epochs} train_samples/epoch={cfg_holdout.train_samples_per_epoch} val_fixed={len(val_loader.dataset)}")
    out_holdout = train_with_history(
        model=model_holdout,
        train_loader=train_loader,
        val_loader=val_loader,
        rf=cfg_holdout.receptive_field,
        cfg=cfg_holdout,
        device=DEVICE,
        logger=logger,
        ckpt_dir=ckpt_holdout,
    )

    # evaluate on val/test with teacher-forced metrics
    val_ce = eval_teacher_forced_ce(model_holdout, val_loader, cfg_holdout, DEVICE)
    val_m  = eval_teacher_forced_metrics(model_holdout, val_loader, cfg_holdout, DEVICE)

    test_loader = _fixed_eval_loader(te_list, cfg, min(1500, cfg.val_samples_fixed), seed=789)

    test_ce = eval_teacher_forced_ce(model_holdout, test_loader, cfg_holdout, DEVICE)
    test_m  = eval_teacher_forced_metrics(model_holdout, test_loader, cfg_holdout, DEVICE)

    holdout_summary = {
        "months": months,
        "files_total": len(all_files),
        "files_train": len(train_files),
        "files_val": len(val_files),
        "files_test": len(test_files),
        "amp_min": cfg_holdout.amp_min,
        "amp_max": cfg_holdout.amp_max,
        "best_epoch": int(out_holdout["best_epoch"]),
        "best_val_ce": float(out_holdout["best_val_ce"]),
        "val_ce_teacher_forced": float(val_ce),
        "test_ce_teacher_forced": float(test_ce),
        "val_metrics_expected": val_m,
        "test_metrics_expected": test_m,
        "history": out_holdout["history"],
    }
    (run_dir / "holdout_summary.json").write_text(json.dumps(holdout_summary, indent=2))

    _plot_ce_curves(
        holdout_summary["history"],
        "train_ce",
        "val_ce",
        f"{months}m HOLDOUT TRAINING (time-safe split)",
        plots_dir / "holdout_ce_curves.png",
    )

    logger.info(f"[HOLDOUT] val_ce={val_ce:.4f} test_ce={test_ce:.4f} | val_MAE={val_m['MAE']:.4f} test_MAE={test_m['MAE']:.4f}")

    return {"overfit": overfit_summary, "holdout": holdout_summary}


def main():
    """Run every window length in WINDOW_MONTHS_LIST, then print a one-line summary."""
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

    run_kwargs = dict(
        train_frac=0.70,
        val_frac=0.15,
        auto_amp_pct=99.9,
        overfit_n_files=2,
        overfit_fixed_segments=1500,
        overfit_epochs=250,
        holdout_epochs=200,
    )
    reports = [run_one_months(months=m, **run_kwargs) for m in WINDOW_MONTHS_LIST]

    def _describe(report: dict) -> str:
        hold = report["holdout"]
        gap = report["overfit"]["gap_fixed_train_minus_real_val"]
        return (
            f"{hold['months']}m(valCE={hold['val_ce_teacher_forced']:.3f}, "
            f"testCE={hold['test_ce_teacher_forced']:.3f}, "
            f"overfitGap={gap:.3f})"
        )

    print(" | ".join(_describe(r) for r in reports))
    print(f"[DONE] Outputs saved under: {RUNS_DIR.resolve()}")


if __name__ == "__main__":
    main()

KeyboardInterrupt: 