In [7]:

from core.training import train_model
from util.datasets import RandomWaveNetSegments
from util.wavenet import dilations_1s_context, receptive_field
from util.metrics import metrics_1d
from util.quantization import mu_law_decode_np
from core.wavenet import WaveNetCategorical

import time
import json
import logging
from pathlib import Path
from dataclasses import dataclass, asdict
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import yaml

# Load the run configuration. The context manager closes the file handle as
# soon as parsing finishes instead of leaving it open for the process lifetime.
with open("config.yaml", "r", encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root directory of the window datasets; each window length is read from the
# "<months>m" sub-directory below it.
WINDOW_ROOT = Path.cwd() / config["training"]["training_windows_dataset"]
# Window lengths (in months) to run the k-fold experiment for.
WINDOW_MONTHS_LIST = (3, )
# Column of each window file that is used as the series value.
VALUE_COL = "max"

# All run artefacts (logs, checkpoints, plots, summaries) are written here.
RUNS_DIR = Path("trained_wavenet_kfold_tests")
# When True, window files are discovered recursively below the dataset directory.
RECURSIVE = True

# K-fold settings
K_FOLDS = 3                     # number of contiguous validation blocks; lower it if runs take too long
TRAIN_FRAC_WITHIN_FILE = None   # not referenced by the code in this cell
MIN_LEN_PADDING = 0             # extra length beyond seq_len + 10 if you want stricter filtering

# Training settings per fold
@dataclass
class Cfg:
    """Hyper-parameters for one fold: model size, optimisation and quantisation.

    An instance is created per fold in ``run_fold`` and handed to
    ``train_model`` and to the evaluation helpers.
    """

    # Model / quantisation size: n_bins is the number of output classes and is
    # passed to both the model and the segment datasets.
    n_bins: int = 256
    kernel_size: int = 2
    n_filters: int = 16

    # Optimisation settings. These are consumed by train_model (not visible
    # here) - presumably learning rate, per-epoch decay, early stopping, etc.
    lr: float = 1e-3
    lr_decay_gamma: float = 0.99
    batch_size: int = 32
    epochs: int = 200
    train_samples_per_epoch: int = 6000
    val_samples_fixed: int = 1500
    early_stop_patience: int = 10
    grad_clip: float = 1.0

    # Amplitude range handed to the datasets and used to place the bin centres
    # in eval_teacher_forced_metrics.
    amp_min: float = -3.0
    amp_max: float =  3.0

    # Placeholders: run_fold overwrites receptive_field with the value computed
    # from the dilation schedule and sets seq_len = receptive_field + 256.
    receptive_field: int = 1024
    seq_len: int = 0


def setup_logger(out_dir: Path, name: str) -> logging.Logger:
    """Return a logger that writes to the console and to ``out_dir/run.log``.

    Args:
        out_dir: Directory for ``run.log``; created (with parents) if missing.
        name: Logger name. Calling again with the same name reconfigures the
            same logger object instead of stacking handlers.

    Returns:
        The configured logger at INFO level with exactly two handlers
        (stream + file).
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    lg = logging.getLogger(name)
    lg.setLevel(logging.INFO)

    # Detach AND close handlers left over from a previous call. Clearing the
    # list alone would drop the references but keep the old run.log file
    # handle open, which leaks one descriptor per re-run of the cell.
    for stale in list(lg.handlers):
        lg.removeHandler(stale)
        stale.close()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
    ch = logging.StreamHandler()
    ch.setFormatter(fmt)
    lg.addHandler(ch)
    fh = logging.FileHandler(out_dir / "run.log")
    fh.setFormatter(fmt)
    lg.addHandler(fh)
    return lg


# ----------------------------
# Window loading helpers
# ----------------------------
def iter_window_files(ds_dir: Path) -> list[Path]:
    """Collect the load-only window files below ``ds_dir`` in sorted order.

    Parquet, gzipped-CSV and CSV files are globbed (recursively when the
    module-level ``RECURSIVE`` flag is set), de-duplicated and sorted. A file
    is kept only when its lower-cased name contains "load" and does not
    contain "pv".
    """
    prefix = "**/" if RECURSIVE else ""
    found: set[Path] = set()
    for suffix in ("*.parquet", "*.csv.gz", "*.csv"):
        found.update(ds_dir.glob(prefix + suffix))

    selected = []
    for candidate in sorted(found):
        lowered = candidate.name.lower()
        # Filename-based filter: adjust the keywords if the naming differs.
        if "load" in lowered and "pv" not in lowered:
            selected.append(candidate)
    return selected


def read_window_file(path: Path) -> pd.DataFrame:
    """Read one window file into a DataFrame, choosing the reader by extension.

    Supports ``.parquet``, ``.csv.gz`` and ``.csv`` (case-insensitive).

    Raises:
        ValueError: If the file has any other extension.
    """
    ext = path.suffix.lower()
    if ext == ".parquet":
        frame = pd.read_parquet(path)
    elif path.name.lower().endswith(".csv.gz"):
        frame = pd.read_csv(path, compression="gzip")
    elif ext == ".csv":
        frame = pd.read_csv(path)
    else:
        raise ValueError(f"Unsupported window file: {path}")
    return frame


def load_series_from_window_file(
    path: Path,
    value_col: str = "max",
    require_metric: str | None = "load_power",  # filter PV via metric column if present
) -> np.ndarray | None:
    """Load one window file and return its mean-centred value series.

    Args:
        path: Window file to read (see ``read_window_file`` for formats).
        value_col: Column holding the series values.
        require_metric: If not None and the file has a ``metric`` column, only
            rows whose metric equals this value (case-insensitive) are kept.

    Returns:
        A float64 array with the mean subtracted, ordered by ``utc`` when that
        column exists; or None when no rows match the metric, ``value_col`` is
        missing, or fewer than 10 numeric values remain.
    """
    df = read_window_file(path)

    # If there is a metric column, only keep the requested metric's rows.
    if require_metric is not None and "metric" in df.columns:
        wanted = require_metric.lower()
        df = df[df["metric"].astype(str).str.lower() == wanted]
        if df.empty:
            return None

    if "utc" in df.columns:
        # assign() builds a new frame. Writing df["utc"] = ... here would be an
        # assignment into the filtered slice created above, which pandas flags
        # as chained assignment (SettingWithCopyWarning).
        df = df.assign(utc=pd.to_datetime(df["utc"], errors="coerce"))
        df = df.dropna(subset=["utc"]).sort_values("utc")

    if value_col not in df.columns:
        return None

    # Non-numeric entries become NaN and are dropped.
    y = pd.to_numeric(df[value_col], errors="coerce").dropna().to_numpy(dtype=np.float64)
    if y.size < 10:
        return None

    # Vertical normalization ONLY: remove the mean, do not rescale.
    y = y - float(np.mean(y))
    return y


def files_to_series_list(files: list[Path], min_len: int, value_col: str) -> list[np.ndarray]:
    """Load every usable series from ``files``, preserving the input order.

    A file is skipped when its series cannot be loaded (the loader returns
    None) or when the series is shorter than ``min_len``. Only rows with
    metric "load_power" are requested from the loader.
    """
    # Lazy generator: files are still read one at a time, in order.
    loaded = (
        load_series_from_window_file(path, value_col=value_col, require_metric="load_power")
        for path in files
    )
    return [
        series.astype(np.float64)
        for series in loaded
        if series is not None and len(series) >= min_len
    ]


# ----------------------------
# Evaluation helpers
# ----------------------------
@torch.no_grad()
def eval_teacher_forced_metrics(
    model: WaveNetCategorical,
    loader: DataLoader,
    cfg: Cfg,
    device: torch.device,
    max_batches: int = 200,
) -> dict:
    """Score teacher-forced predictions against decoded targets.

    For each batch the model's logits are turned into an expected amplitude
    (softmax over the bin axis, weighted by the bin centres), and the targets
    are decoded with ``mu_law_decode_np``. The first ``cfg.receptive_field``
    time steps are skipped. Everything is flattened and passed to
    ``metrics_1d``.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Yields ``(inputs, targets)`` batches.
        cfg: Supplies receptive_field, amp_min, amp_max and n_bins.
        device: Device the batches are moved to.
        max_batches: Upper bound on the number of batches scored.

    Returns:
        The ``metrics_1d`` result, or a dict of NaNs with ``"N": 0.0`` when no
        batch was scored.
    """
    model.eval()
    skip = int(cfg.receptive_field)

    # NOTE(review): the bin centres are linearly spaced while the targets are
    # decoded with mu-law - confirm both match the dataset's quantisation.
    bin_values = torch.tensor(
        np.linspace(cfg.amp_min, cfg.amp_max, cfg.n_bins, dtype=np.float64),
        device=device,
        dtype=torch.float32,
    ).view(1, -1, 1)

    truths, preds = [], []
    for batch_idx, (inputs, targets) in enumerate(loader):
        if batch_idx >= max_batches:
            break
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Original author's shape notes: logits [B,K,T], targets [B,T].
        tail_logits = model(inputs)[:, :, skip:]
        tail_targets = targets[:, skip:]

        # Expected amplitude under the predicted categorical distribution.
        expected = (torch.softmax(tail_logits, dim=1) * bin_values).sum(dim=1)

        decoded = mu_law_decode_np(
            tail_targets.detach().cpu().numpy(),
            mu=cfg.n_bins - 1,
            amp_max=cfg.amp_max,
        )
        truths.append(decoded.reshape(-1))
        preds.append(expected.detach().cpu().numpy().reshape(-1))

    if len(truths) == 0:
        return {"MSE": float("nan"), "MAE": float("nan"), "Corr": float("nan"), "N": 0.0}

    return metrics_1d(np.concatenate(truths), np.concatenate(preds))


@torch.no_grad()
def eval_teacher_forced_ce(model: WaveNetCategorical, loader: DataLoader, cfg: Cfg, device: torch.device,
                           max_batches: int = 200) -> float:
    """Compute the mean teacher-forced cross-entropy per target on a loader.

    The first ``cfg.receptive_field`` time steps of every sequence are
    skipped. The result is the total cross-entropy divided by the total number
    of scored targets, so a short final batch is weighted by its actual size
    rather than counting as much as a full batch.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Yields ``(inputs, targets)`` batches of class indices.
        cfg: Supplies receptive_field and n_bins.
        device: Device the batches are moved to.
        max_batches: Upper bound on the number of batches scored.

    Returns:
        Mean cross-entropy per target, or NaN when nothing was scored.
    """
    import torch.nn.functional as F
    model.eval()
    rf = int(cfg.receptive_field)
    total_loss = 0.0
    total_targets = 0
    for bi, (xb, yb) in enumerate(loader):
        if bi >= max_batches:
            break
        xb = xb.to(device)
        yb = yb.to(device)
        logits_v = model(xb)[:, :, rf:]   # original author's note: logits are [B,K,T]
        y_flat = yb[:, rf:].reshape(-1)
        n_targets = int(y_flat.numel())
        if n_targets == 0:
            # Sequences no longer than the receptive field contribute nothing.
            continue
        # Sum (not mean) per batch so batches of different size are weighted
        # by the number of targets they contain.
        loss_sum = F.cross_entropy(
            logits_v.permute(0, 2, 1).reshape(-1, cfg.n_bins),
            y_flat,
            reduction="sum",
        )
        total_loss += float(loss_sum.item())
        total_targets += n_targets
    return total_loss / total_targets if total_targets else float("nan")


def load_best_last_ckpts(ckpt_dir: Path):
    """Load ``best.pt`` and ``last.pt`` from ``ckpt_dir`` onto the CPU.

    Returns:
        A ``(best, last)`` tuple; an entry is None when its file is missing.
    """
    loaded = []
    for filename in ("best.pt", "last.pt"):
        candidate = ckpt_dir / filename
        if candidate.exists():
            # weights_only=False unpickles arbitrary objects: only use this on
            # checkpoints written by this pipeline, never on untrusted files.
            loaded.append(torch.load(candidate, map_location="cpu", weights_only=False))
        else:
            loaded.append(None)
    best, last = loaded
    return best, last


# ----------------------------
# Time-series-safe k-fold splitter (blocked folds)
# ----------------------------
def blocked_kfold(files_sorted: list[Path], k: int) -> list[tuple[list[Path], list[Path]]]:
    """Split ordered files into ``k`` contiguous validation blocks.

    Each fold validates on one contiguous block and trains on all the other
    blocks, which includes blocks that come later in the ordering. When the
    files do not divide evenly, the first ``n % k`` blocks get one extra file.

    Args:
        files_sorted: Files in the order that defines the blocks.
        k: Requested number of folds; capped at ``len(files_sorted)``.

    Returns:
        A list of ``(train_files, val_files)`` tuples, one per fold. Empty
        when ``files_sorted`` is empty.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n = len(files_sorted)
    if n == 0:
        # Nothing to split; without this guard k would be capped to 0 and the
        # block-size computation would divide by zero.
        return []
    k = min(k, n)

    base, extra = divmod(n, k)
    splits = []
    start = 0
    for block in range(k):
        end = start + base + (1 if block < extra else 0)
        val_files = files_sorted[start:end]
        train_files = files_sorted[:start] + files_sorted[end:]
        splits.append((train_files, val_files))
        start = end
    return splits


# ----------------------------
# One fold run
# ----------------------------
def run_fold(months: int, fold_idx: int, train_files: list[Path], val_files: list[Path]) -> dict:
    """Train and evaluate one fold, writing its artefacts under RUNS_DIR.

    Creates ``RUNS_DIR/<months>m/fold_<fold_idx>/`` with a ``run.log``, the
    checkpoints written by ``train_model``, a train-vs-val CE bar plot and a
    ``fold_summary.json``.

    Args:
        months: Window length of the dataset (used for paths and logging).
        fold_idx: Fold number (used for paths and logging).
        train_files: Window files used for training.
        val_files: Window files used for validation.

    Returns:
        The fold summary dict that is also written to ``fold_summary.json``.

    Raises:
        RuntimeError: If either split yields no series long enough to sample.
    """
    run_dir = RUNS_DIR / f"{months}m" / f"fold_{fold_idx}"
    ckpt_dir = run_dir / "checkpoints"
    plots_dir = run_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)

    logger = setup_logger(run_dir, name=f"wavenet_{months}m_fold{fold_idx}")
    try:
        cfg = Cfg()

        # Derive the real receptive field from the dilation schedule and size
        # the training sequences so 256 steps remain after it.
        dils = dilations_1s_context()
        cfg.receptive_field = int(receptive_field(cfg.kernel_size, dils))
        cfg.seq_len = cfg.receptive_field + 256
        min_len = cfg.seq_len + 10 + int(MIN_LEN_PADDING)

        logger.info(f"months={months} fold={fold_idx} device={DEVICE}")
        logger.info(f"seq_len={cfg.seq_len} RF={cfg.receptive_field}")
        logger.info(f"train_files={len(train_files)} val_files={len(val_files)}")

        tr_list = files_to_series_list(train_files, min_len=min_len, value_col=VALUE_COL)
        va_list = files_to_series_list(val_files,   min_len=min_len, value_col=VALUE_COL)

        if not tr_list or not va_list:
            raise RuntimeError(f"Fold {fold_idx} has insufficient usable series.")

        train_ds = RandomWaveNetSegments(
            epochs_1d=tr_list,
            seq_len=cfg.seq_len,
            n_samples=cfg.train_samples_per_epoch,
            amp_min=cfg.amp_min, amp_max=cfg.amp_max, n_bins=cfg.n_bins,
            rng=np.random.default_rng(1),
        )

        # Fixed validation pairs for stability: sample them once, then build
        # the real validation dataset from exactly those pairs.
        val_ds_tmp = RandomWaveNetSegments(
            epochs_1d=va_list,
            seq_len=cfg.seq_len,
            n_samples=cfg.val_samples_fixed,
            amp_min=cfg.amp_min, amp_max=cfg.amp_max, n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
        )
        fixed_pairs = list(val_ds_tmp.pairs)
        val_ds = RandomWaveNetSegments(
            epochs_1d=va_list,
            seq_len=cfg.seq_len,
            n_samples=len(fixed_pairs),
            amp_min=cfg.amp_min, amp_max=cfg.amp_max, n_bins=cfg.n_bins,
            rng=np.random.default_rng(2),
            fixed_pairs=fixed_pairs,
        )

        val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)

        # Small train loader used only to compare train CE with val CE.
        train_ds_eval_tmp = RandomWaveNetSegments(
            epochs_1d=tr_list,
            seq_len=cfg.seq_len,
            n_samples=min(2000, cfg.train_samples_per_epoch),
            amp_min=cfg.amp_min, amp_max=cfg.amp_max, n_bins=cfg.n_bins,
            rng=np.random.default_rng(3),
        )
        train_loader_eval = DataLoader(train_ds_eval_tmp, batch_size=cfg.batch_size, shuffle=False, num_workers=0, drop_last=False)

        model = WaveNetCategorical(n_bins=cfg.n_bins, n_filters=cfg.n_filters, kernel_size=cfg.kernel_size)

        t0 = time.perf_counter()
        model = train_model(
            model=model,
            train_ds=train_ds,
            val_loader=val_loader,
            rf=cfg.receptive_field,
            cfg=cfg,
            device=DEVICE,
            logger=logger,
            wb_run=None,
            ckpt_dir=ckpt_dir,
        )
        train_time = time.perf_counter() - t0

        # Best/last checkpoint info (-1 / NaN when a checkpoint is missing).
        best, last = load_best_last_ckpts(ckpt_dir)
        best_epoch = int(best.get("best_epoch", best.get("epoch", -1))) if best else -1
        best_val_ce = float(best.get("best_val", np.nan)) if best else float("nan")
        last_epoch = int(last.get("epoch", -1)) if last else -1

        # Evaluate the returned model. Per the original author's note,
        # train_model reloads best.pt before returning - not verifiable here.
        val_m = eval_teacher_forced_metrics(model, val_loader, cfg, DEVICE, max_batches=250)
        val_ce_tf = eval_teacher_forced_ce(model, val_loader, cfg, DEVICE, max_batches=250)
        tr_ce_tf  = eval_teacher_forced_ce(model, train_loader_eval, cfg, DEVICE, max_batches=250)

        # Overfitting signals. The gap is train minus val, so overfitting
        # (train CE well below val CE) shows up as a clearly NEGATIVE gap;
        # a gap near zero or positive means validation is not worse than train.
        ce_gap = tr_ce_tf - val_ce_tf
        trained_past_best = (last_epoch > best_epoch) if (best_epoch >= 0 and last_epoch >= 0) else False

        # Quick plot: train vs val CE bars.
        plt.figure(figsize=(6, 3))
        plt.bar(["train_CE", "val_CE"], [tr_ce_tf, val_ce_tf])
        plt.title(f"{months}m fold {fold_idx} CE (teacher-forced)")
        plt.tight_layout()
        plt.savefig(plots_dir / "ce_train_vs_val.png", dpi=150)
        plt.close()

        # Save fold summary
        summary = {
            "months": months,
            "fold": fold_idx,
            "train_files": len(train_files),
            "val_files": len(val_files),
            "train_time_s": round(train_time, 2),
            "best_epoch": best_epoch,
            "last_epoch": last_epoch,
            "best_val_ce": best_val_ce,
            "train_ce_teacher_forced": tr_ce_tf,
            "val_ce_teacher_forced": val_ce_tf,
            "ce_gap_train_minus_val": ce_gap,
            "trained_past_best": trained_past_best,
            "val_metrics_1d": val_m,
        }
        (run_dir / "fold_summary.json").write_text(json.dumps(summary, indent=2))
        logger.info(f"FOLD SUMMARY: {summary}")

        return summary
    finally:
        # Release this fold's run.log handle on every exit path, including an
        # interrupted training run; nothing uses this logger afterwards.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()


# ----------------------------
# Run k-fold for one dataset
# ----------------------------
def run_kfold_for_months(months: int) -> dict:
    """Run blocked k-fold for one window length and aggregate the fold results.

    Reads the window files from ``WINDOW_ROOT/<months>m``, runs every fold via
    ``run_fold``, then writes ``kfold_aggregate.json`` and a per-fold MAE plot
    to ``RUNS_DIR/<months>m``.

    Args:
        months: Window length whose dataset directory is evaluated.

    Returns:
        Aggregate dict with NaN-aware mean/std of MAE, MSE, Corr and the CE
        gap across folds, plus ``overfit_risk_flag``.

    Raises:
        RuntimeError: If fewer than two window files are found.
    """
    ds_dir = WINDOW_ROOT / f"{months}m"
    all_files = iter_window_files(ds_dir)

    # Deterministic order by full path string. This is chronological only if
    # the directory/file names sort by date - confirm against the naming used.
    all_files = sorted(all_files, key=lambda p: str(p))

    if len(all_files) < 2:
        raise RuntimeError(f"Not enough window files in {ds_dir}")

    splits = blocked_kfold(all_files, K_FOLDS)

    fold_summaries = []
    for fi, (train_files, val_files) in enumerate(splits, start=1):
        fold_summaries.append(run_fold(months, fi, train_files, val_files))

    # Aggregate results
    maes = [fs["val_metrics_1d"]["MAE"] for fs in fold_summaries]
    mses = [fs["val_metrics_1d"]["MSE"] for fs in fold_summaries]
    cors = [fs["val_metrics_1d"]["Corr"] for fs in fold_summaries]
    gaps = [fs["ce_gap_train_minus_val"] for fs in fold_summaries]

    agg = {
        "months": months,
        "k_folds": len(fold_summaries),
        "MAE_mean": float(np.nanmean(maes)),
        "MAE_std": float(np.nanstd(maes)),
        "MSE_mean": float(np.nanmean(mses)),
        "MSE_std": float(np.nanstd(mses)),
        "Corr_mean": float(np.nanmean(cors)),
        "Corr_std": float(np.nanstd(cors)),
        "CE_gap_mean": float(np.nanmean(gaps)),
        "CE_gap_std": float(np.nanstd(gaps)),
        # The gap is train CE minus val CE, so overfitting (train fits better
        # than val) is a NEGATIVE mean gap. The threshold is 0 with no
        # tolerance, so read the flag together with CE_gap_mean.
        "overfit_risk_flag": bool(np.nanmean(gaps) < 0.0),
    }

    # Save + plot distribution of MAE across folds
    out_dir = RUNS_DIR / f"{months}m"
    # run_fold normally creates this directory; make it explicit so saving the
    # aggregate does not depend on that side effect.
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "kfold_aggregate.json").write_text(json.dumps(agg, indent=2))

    plt.figure(figsize=(7, 4))
    plt.plot(range(1, len(maes) + 1), maes, marker="o")
    plt.xlabel("fold")
    plt.ylabel("VAL MAE (metrics_1d)")
    plt.title(f"{months}m dataset: MAE across folds")
    plt.tight_layout()
    plt.savefig(out_dir / "kfold_mae_across_folds.png", dpi=150)
    plt.close()

    return agg


def main():
    """Run k-fold for every configured window length and report the results.

    Prints a one-line summary per window length (joined with " | "), saves a
    plot of mean validation MAE against window length to RUNS_DIR, and prints
    where the outputs were written.
    """
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

    aggs = [run_kfold_for_months(months) for months in WINDOW_MONTHS_LIST]

    def describe(agg: dict) -> str:
        # One report fragment, e.g. 3m(MAE=...,Corr=...,gap=...,overfit=Y)
        overfit_mark = "Y" if agg["overfit_risk_flag"] else "N"
        return (
            f"{agg['months']}m("
            f"MAE={agg['MAE_mean']:.4f}±{agg['MAE_std']:.4f},"
            f"Corr={agg['Corr_mean']:.3f}±{agg['Corr_std']:.3f},"
            f"gap={agg['CE_gap_mean']:.4f}±{agg['CE_gap_std']:.4f},"
            f"overfit={overfit_mark})"
        )

    # Single-line final report
    print(" | ".join(describe(agg) for agg in aggs))

    # Comparison plot: MAE_mean vs months
    window_lengths = [agg["months"] for agg in aggs]
    mean_maes = [agg["MAE_mean"] for agg in aggs]
    plt.figure(figsize=(7, 4))
    plt.plot(window_lengths, mean_maes, marker="o")
    plt.xlabel("dataset (months)")
    plt.ylabel("K-fold mean VAL MAE")
    plt.title("K-fold validation MAE vs dataset window length")
    plt.tight_layout()
    plt.savefig(RUNS_DIR / "compare_kfold_mae.png", dpi=150)
    plt.close()

    print(f"[DONE] Outputs saved under: {RUNS_DIR.resolve()}")


# Run the experiment only when executed directly, not when imported.
if __name__ == "__main__":
    main()

[2026-02-22 01:07:51,777] INFO - months=3 fold=1 device=cpu
[2026-02-22 01:07:51,778] INFO - seq_len=1256 RF=1000
[2026-02-22 01:07:51,778] INFO - train_files=628 val_files=315
[2026-02-22 01:10:48,743] INFO - epoch 001/200 | train_ce=4.846106e+00 | val_ce=4.334502e+00 | lr=1.000e-03 | dt=30.85s
[2026-02-22 01:11:19,961] INFO - epoch 002/200 | train_ce=3.872791e+00 | val_ce=3.543097e+00 | lr=9.900e-04 | dt=31.17s
[2026-02-22 01:11:49,376] INFO - epoch 003/200 | train_ce=3.158140e+00 | val_ce=3.113452e+00 | lr=9.801e-04 | dt=29.39s
[2026-02-22 01:12:26,147] INFO - epoch 004/200 | train_ce=2.881385e+00 | val_ce=2.991494e+00 | lr=9.703e-04 | dt=36.73s
[2026-02-22 01:13:31,074] INFO - epoch 005/200 | train_ce=2.786013e+00 | val_ce=2.911303e+00 | lr=9.606e-04 | dt=64.90s
[2026-02-22 01:14:44,000] INFO - epoch 006/200 | train_ce=2.683351e+00 | val_ce=2.872121e+00 | lr=9.510e-04 | dt=72.88s
[2026-02-22 01:15:46,598] INFO - epoch 007/200 | train_ce=2.688212e+00 | val_ce=2.845222e+00 | lr=9.415

KeyboardInterrupt: 