In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, log_loss
import torch
import torch.nn as nn
from tqdm import tqdm
import random
import copy
import inspect
import pandas as pd

import data_interface
import mnar_blackout_lds

random.seed(42)
np.random.seed(42)

In [4]:
# Load the data
x_t, m_t, meta = data_interface.load_panel()

In [5]:
evaluation_windows = data_interface.get_eval_windows(
    data_dir="data",
    manifest_name="evaluation_windows_mnar_weighted.parquet",
)

In [6]:
from collections import defaultdict

def stratified_month_sampling(data, n_per_month, ts_key="blackout_start"):
    buckets = defaultdict(list)

    for item in data:
        ts = item[ts_key]
        month_key = (ts.year, ts.month)
        buckets[month_key].append(item)

    result = []
    for month_key, items in buckets.items():
        if len(items) < n_per_month:
            picks = random.choices(items, k=n_per_month)
        else:
            picks = random.sample(items, n_per_month)
        result.extend(picks)

    return result


In [7]:
# ------------------------------------------------------------------
# 1) Group evaluation windows by window_id and test_type/horizon
#    (handles cases where there are multiple rows per window_id).
# ------------------------------------------------------------------
impute_by_id = {}
forecast_1_by_id = {}
forecast_3_by_id = {}
forecast_6_by_id = {}

for w in evaluation_windows:
    wid = w["window_id"]
    if w["test_type"] == "impute":
        # If duplicates exist, last one wins – that's fine for eval.
        impute_by_id[wid] = w
    elif w["test_type"] == "forecast":
        h = int(w["horizon_steps"])
        if h == 1:
            forecast_1_by_id[wid] = w
        elif h == 3:
            forecast_3_by_id[wid] = w
        elif h == 6:
            forecast_6_by_id[wid] = w


In [8]:
# ------------------------------------------------------------------
# 2) Keep only window_ids that have impute + 1-step + 3-step + 6-step
# ------------------------------------------------------------------
common_ids = (
    set(impute_by_id.keys())
    & set(forecast_1_by_id.keys())
    & set(forecast_3_by_id.keys())
    & set(forecast_6_by_id.keys())
)

# Pool of impute windows that have all matching forecast horizons
impute_windows_pool = [impute_by_id[wid] for wid in common_ids]

In [9]:
# ------------------------------------------------------------------
# 3) Sample *validation* impute windows (stratified by month)
# ------------------------------------------------------------------
impute_evaluation_windows_val = stratified_month_sampling(
    impute_windows_pool,
    n_per_month=25,
    ts_key="blackout_start",
)

# Preserve order: we now build *aligned* forecast lists
val_ids_ordered = [w["window_id"] for w in impute_evaluation_windows_val]

forecast_1_evaluation_windows_val = [
    forecast_1_by_id[wid] for wid in val_ids_ordered
]
forecast_3_evaluation_windows_val = [
    forecast_3_by_id[wid] for wid in val_ids_ordered
]
forecast_6_evaluation_windows_val = [
    forecast_6_by_id[wid] for wid in val_ids_ordered
]

In [10]:
assert len(forecast_1_evaluation_windows_val) == len(impute_evaluation_windows_val)
assert len(forecast_3_evaluation_windows_val) == len(impute_evaluation_windows_val)
assert len(forecast_6_evaluation_windows_val) == len(impute_evaluation_windows_val)

# ------------------------------------------------------------------
# 4) Combined list used only for masking (deduped later)
# ------------------------------------------------------------------
evaluation_windows_val = (
    forecast_1_evaluation_windows_val
    + forecast_3_evaluation_windows_val
    + forecast_6_evaluation_windows_val
    + impute_evaluation_windows_val
)

In [11]:
# ============================================================
# 0) Time features (used by diagnosis + seasonal baselines)
# ============================================================
def build_time_features(timestamps: np.ndarray) -> np.ndarray:
    """
    timestamps: np.ndarray of pandas.Timestamp, shape (T,)
    Returns: X_time, shape (T, 6)
      [sin_hour, cos_hour, sin_dow, cos_dow, is_weekend, is_rush]
    """
    ts = pd.to_datetime(timestamps)
    hour = ts.hour.to_numpy()
    dow = ts.dayofweek.to_numpy()  # Mon=0

    hour_rad = 2.0 * np.pi * (hour / 24.0)
    dow_rad = 2.0 * np.pi * (dow / 7.0)

    sin_hour = np.sin(hour_rad)
    cos_hour = np.cos(hour_rad)
    sin_dow = np.sin(dow_rad)
    cos_dow = np.cos(dow_rad)

    is_weekend = ((dow >= 5).astype(float))
    # Simple rush-hour proxy (tune if needed): 7–10 and 16–19
    is_rush = (((hour >= 7) & (hour <= 10)) | ((hour >= 16) & (hour <= 19))).astype(float)

    return np.stack([sin_hour, cos_hour, sin_dow, cos_dow, is_weekend, is_rush], axis=1)

X_time = data_interface.build_time_features(meta["timestamps"])

In [12]:
def locf_impute_baseline(x_t, start_idx, end_idx, detector_idx):
    """
    Naive 'last observation carried forward' baseline for imputation.
    Uses the last finite value before blackout; falls back to
    detector-wise historical mean if needed.
    """
    last_idx = start_idx - 1
    while last_idx >= 0 and not np.isfinite(x_t[last_idx, detector_idx]):
        last_idx -= 1

    if last_idx < 0:
        det_vals = x_t[:, detector_idx]
        last_val = float(np.nanmean(det_vals))
    else:
        last_val = float(x_t[last_idx, detector_idx])

    length = end_idx - start_idx + 1
    return np.full(length, last_val, dtype=float)


def locf_forecast_baseline(x_t, end_idx, detector_idx):
    """
    Naive baseline for forecasting: hold the last available observation
    at the end of the blackout.
    """
    last_idx = end_idx
    while last_idx >= 0 and not np.isfinite(x_t[last_idx, detector_idx]):
        last_idx -= 1

    if last_idx < 0:
        det_vals = x_t[:, detector_idx]
        last_val = float(np.nanmean(det_vals))
    else:
        last_val = float(x_t[last_idx, detector_idx])

    return last_val


def evaluate_impute_forecast_model(
    model,
    mu_smooth,
    Sigma_smooth,
    mu_filt,
    Sigma_filt,
    x_t,
    meta,
    label="model",
):
    """
    Re-usable evaluation for any LDS-like model (MAR or MNAR):
    - blackout imputation RMSE/MAE (length-weighted)
    - 1 / 3 / 6-step forecast RMSE/MAE
    """
    # ---------------- Imputation ----------------
    impute_mae_list = []
    impute_mse_list = []

    for window in impute_evaluation_windows_val:
        if window["test_type"] != "impute":
            continue

        start_idx = np.where(meta["timestamps"] == window["blackout_start"])[0][0]
        end_idx = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        detector_idx = np.where(meta["detectors"] == window["detector_id"])[0][0]

        eval_x_t = x_t[start_idx : end_idx + 1].copy()
        eval_mu_smooth = mu_smooth[start_idx : end_idx + 1]
        eval_Sigma_smooth = Sigma_smooth[start_idx : end_idx + 1]

        reconstruct_x_t, _ = model.reconstruct_from_smoother(
            eval_mu_smooth, eval_Sigma_smooth
        )

        y_true = eval_x_t[:, detector_idx]
        y_pred = reconstruct_x_t[:, detector_idx]

        mask = np.isfinite(y_true) & np.isfinite(y_pred)
        if not mask.any():
            continue

        mae = sklearn.metrics.mean_absolute_error(y_pred[mask], y_true[mask])
        mse = sklearn.metrics.mean_squared_error(y_pred[mask], y_true[mask])

        impute_mae_list.append([mae, window["len_steps"]])
        impute_mse_list.append([mse, window["len_steps"]])

    final_mae = np.average(
        [item[0] for item in impute_mae_list],
        weights=[item[1] for item in impute_mae_list],
    )
    final_mse = np.average(
        [item[0] for item in impute_mse_list],
        weights=[item[1] for item in impute_mse_list],
    )
    final_rmse = np.sqrt(final_mse)

    print(f"\n[{label}] Imputation performance:")
    print("  MAE :", final_mae)
    print("  MSE :", final_mse)
    print("  RMSE:", final_rmse)

    # ---------------- Forecasting ----------------
    y_actual_1_step, y_forecast_1_step = [], []
    y_actual_3_step, y_forecast_3_step = [], []
    y_actual_6_step, y_forecast_6_step = [], []

    forecast_evaluation_windows_val = (
        forecast_1_evaluation_windows_val
        + forecast_3_evaluation_windows_val
        + forecast_6_evaluation_windows_val
    )

    for window in forecast_evaluation_windows_val:
        if window["test_type"] != "forecast":
            continue

        start_idx = np.where(meta["timestamps"] == window["blackout_start"])[0][0]
        end_idx = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        detector_idx = np.where(meta["detectors"] == window["detector_id"])[0][0]
        horizon = int(window["horizon_steps"])

        # Skip windows too close to the end of the series
        if end_idx + horizon >= x_t.shape[0]:
            continue

        eval_x_t = x_t[end_idx + 1 : end_idx + 1 + horizon].copy()

        forecast_x_t, _ = model.k_step_forecast(
            mu_filt, Sigma_filt, end_idx, k=horizon
        )

        y_true = eval_x_t[horizon - 1, detector_idx]
        y_pred = forecast_x_t[detector_idx]

        if not (np.isfinite(y_true) and np.isfinite(y_pred)):
            continue

        if horizon == 1:
            y_forecast_1_step.append(y_pred)
            y_actual_1_step.append(y_true)
        elif horizon == 3:
            y_forecast_3_step.append(y_pred)
            y_actual_3_step.append(y_true)
        elif horizon == 6:
            y_forecast_6_step.append(y_pred)
            y_actual_6_step.append(y_true)

    mae_1_step = sklearn.metrics.mean_absolute_error(
        y_forecast_1_step, y_actual_1_step
    )
    mse_1_step = sklearn.metrics.mean_squared_error(
        y_forecast_1_step, y_actual_1_step
    )
    rmse_1_step = np.sqrt(mse_1_step)

    mae_3_step = sklearn.metrics.mean_absolute_error(
        y_forecast_3_step, y_actual_3_step
    )
    mse_3_step = sklearn.metrics.mean_squared_error(
        y_forecast_3_step, y_actual_3_step
    )
    rmse_3_step = np.sqrt(mse_3_step)

    mae_6_step = sklearn.metrics.mean_absolute_error(
        y_forecast_6_step, y_actual_6_step
    )
    mse_6_step = sklearn.metrics.mean_squared_error(
        y_forecast_6_step, y_actual_6_step
    )
    rmse_6_step = np.sqrt(mse_6_step)

    print(f"\n[{label}] Forecasting performance:")
    print("-----------------------------------")
    print("1-step MAE :", mae_1_step)
    print("1-step MSE :", mse_1_step)
    print("1-step RMSE:", rmse_1_step)

    print("\n-----------------------------------")
    print("3-step MAE :", mae_3_step)
    print("3-step MSE :", mse_3_step)
    print("3-step RMSE:", rmse_3_step)

    print("\n-----------------------------------")
    print("6-step MAE :", mae_6_step)
    print("6-step MSE :", mse_6_step)
    print("6-step RMSE:", rmse_6_step)

    return {
        "impute_mae": final_mae,
        "impute_mse": final_mse,
        "impute_rmse": final_rmse,
        "forecast_mae_1": mae_1_step,
        "forecast_mse_1": mse_1_step,
        "forecast_rmse_1": rmse_1_step,
        "forecast_mae_3": mae_3_step,
        "forecast_mse_3": mse_3_step,
        "forecast_rmse_3": rmse_3_step,
        "forecast_mae_6": mae_6_step,
        "forecast_mse_6": mse_6_step,
        "forecast_rmse_6": rmse_6_step,
    }

In [13]:
# ============================================================
# 1) Stronger baselines
#    - Linear interpolation inside blackout (uses pre + post points)
#    - Optional spline interpolation (falls back to linear)
#    - Seasonal naive for forecasting (same time yesterday / last week)
# ============================================================
def _find_last_finite(x: np.ndarray, idx: int, d: int) -> tuple[int, float] | tuple[None, None]:
    j = idx
    while j >= 0 and not np.isfinite(x[j, d]):
        j -= 1
    if j < 0:
        return None, None
    return j, float(x[j, d])


def _find_next_finite(x: np.ndarray, idx: int, d: int) -> tuple[int, float] | tuple[None, None]:
    j = idx
    T = x.shape[0]
    while j < T and not np.isfinite(x[j, d]):
        j += 1
    if j >= T:
        return None, None
    return j, float(x[j, d])


def linear_interp_impute_baseline(x_t_masked, start_idx, end_idx, detector_idx):
    """
    Impute a blackout [start_idx, end_idx] by linear interpolation between:
      left  = last finite before start
      right = first finite after end
    Falls back to LOCF if one side missing.
    """
    left_j, left_v = _find_last_finite(x_t_masked, start_idx - 1, detector_idx)
    right_j, right_v = _find_next_finite(x_t_masked, end_idx + 1, detector_idx)

    L = end_idx - start_idx + 1
    if left_v is None and right_v is None:
        # ultimate fallback: detector mean
        det_vals = x_t_masked[:, detector_idx]
        fill = float(np.nanmean(det_vals))
        return np.full(L, fill, dtype=float)
    if left_v is None:
        return np.full(L, right_v, dtype=float)
    if right_v is None:
        return np.full(L, left_v, dtype=float)

    # interpolate over actual index distance so long gaps are handled correctly
    xs = np.arange(start_idx, end_idx + 1)
    denom = max((right_j - left_j), 1)
    alpha = (xs - left_j) / denom
    return (1 - alpha) * left_v + alpha * right_v


def spline_impute_baseline(x_t_masked, start_idx, end_idx, detector_idx, order=3):
    """
    Optional spline interpolation via pandas (requires scipy).
    Falls back to linear interpolation if spline unavailable.
    """
    try:
        s = pd.Series(x_t_masked[:, detector_idx]).astype(float)
        # only interpolate the blackout segment; uses surrounding points
        s2 = s.copy()
        s2.iloc[start_idx:end_idx+1] = np.nan
        s2 = s2.interpolate(method="spline", order=order, limit_direction="both")
        return s2.iloc[start_idx:end_idx+1].to_numpy(dtype=float)
    except Exception:
        return linear_interp_impute_baseline(x_t_masked, start_idx, end_idx, detector_idx)


def seasonal_naive_forecast_baseline(x_t_masked, target_idx, detector_idx, offsets=(288, 2016)):
    """
    Forecast x[target_idx, d] using historical seasonal offsets:
      - 288 steps = 1 day back (5-min grid)
      - 2016 steps = 1 week back
    Falls back to LOCF at end of blackout if needed.
    """
    for off in offsets:
        j = target_idx - off
        if j >= 0 and np.isfinite(x_t_masked[j, detector_idx]):
            return float(x_t_masked[j, detector_idx])
    # fallback: last observed before target
    j, v = _find_last_finite(x_t_masked, target_idx - 1, detector_idx)
    if v is None:
        det_vals = x_t_masked[:, detector_idx]
        return float(np.nanmean(det_vals))
    return v

def build_hour_of_week_climatology(
    x_t_masked: np.ndarray,
    m_t_masked: np.ndarray,
    timestamps: np.ndarray,
    step_minutes: int = 5,
):
    """
    Builds time-of-week climatology: mean per (slot_of_week, detector).

    slot_of_week = dow * steps_per_day + step_in_day
    where step_in_day = (hour*60 + minute) // step_minutes

    Returns:
      how_mean: (S, D) array of means, S = 7 * steps_per_day
      slot_of_week: (T,) int array mapping each t -> slot index
      global_mean: (D,) detector-wise global mean fallback
    """
    ts = pd.to_datetime(timestamps)
    dow = ts.dayofweek.to_numpy()  # Mon=0
    minutes = (ts.hour.to_numpy() * 60 + ts.minute.to_numpy())

    steps_per_day = int((24 * 60) // step_minutes)
    step_in_day = (minutes // step_minutes).astype(int)
    # safety: if timestamps aren't aligned to step_minutes, clamp
    step_in_day = np.clip(step_in_day, 0, steps_per_day - 1)

    slot_of_week = (dow * steps_per_day + step_in_day).astype(int)
    S = 7 * steps_per_day
    T, D = x_t_masked.shape

    # observed mask: 1 where observed & finite
    obs = (m_t_masked == 0) & np.isfinite(x_t_masked)

    # detector-wise global mean fallback (computed only from observed)
    global_mean = np.nan_to_num(
        (np.nansum(np.where(obs, x_t_masked, 0.0), axis=0) / (obs.sum(axis=0) + 1e-6)),
        nan=0.0
    ).astype(float)

    sums = np.zeros((S, D), dtype=np.float64)
    cnts = np.zeros((S, D), dtype=np.float64)

    # single pass over time; vectorized over detectors
    for t in range(T):
        s = slot_of_week[t]
        m = obs[t]  # (D,)
        if not m.any():
            continue
        sums[s, m] += x_t_masked[t, m]
        cnts[s, m] += 1.0

    how_mean = sums / np.maximum(cnts, 1.0)
    how_mean[cnts < 1.0] = np.nan  # mark unseen slots as nan
    return how_mean.astype(float), slot_of_week, global_mean


def make_hour_of_week_forecast_fn(how_mean, slot_of_week, global_mean):
    """
    Returns a forecast_fn compatible with evaluate_impute_forecast_baseline:
      forecast_fn(x_t_masked, target_idx, detector_idx) -> float
    """
    def _fn(x_t_masked, target_idx, detector_idx):
        s = int(slot_of_week[target_idx])
        mu = how_mean[s, detector_idx]
        if np.isfinite(mu):
            return float(mu)
        # fallback: detector global mean, then LOCF
        gm = global_mean[detector_idx]
        if np.isfinite(gm):
            return float(gm)
        j, v = _find_last_finite(x_t_masked, target_idx - 1, detector_idx)
        if v is None:
            return 0.0
        return float(v)
    return _fn


def evaluate_impute_forecast_baseline(
    x_t_true,
    x_t_masked,
    meta,
    impute_fn,
    forecast_fn,
    label="baseline",
):
    """
    Generic evaluator for baselines.
      impute_fn(x_t_masked, start_idx, end_idx, detector_idx) -> (L,)
      forecast_fn(x_t_masked, target_idx, detector_idx) -> float
    """
    # ---------- Imputation ----------
    impute_mae_list = []
    impute_mse_list = []

    for window in impute_evaluation_windows_val:
        start_idx = np.where(meta["timestamps"] == window["blackout_start"])[0][0]
        end_idx   = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        d         = np.where(meta["detectors"] == window["detector_id"])[0][0]

        y_true = x_t_true[start_idx:end_idx+1, d].copy()
        y_pred = impute_fn(x_t_masked, start_idx, end_idx, d)

        mask = np.isfinite(y_true) & np.isfinite(y_pred)
        if not mask.any():
            continue

        mae = sklearn.metrics.mean_absolute_error(y_pred[mask], y_true[mask])
        mse = sklearn.metrics.mean_squared_error(y_pred[mask], y_true[mask])
        impute_mae_list.append([mae, window["len_steps"]])
        impute_mse_list.append([mse, window["len_steps"]])

    final_mae = np.average([a for a,_ in impute_mae_list], weights=[w for _,w in impute_mae_list])
    final_mse = np.average([a for a,_ in impute_mse_list], weights=[w for _,w in impute_mse_list])
    final_rmse = float(np.sqrt(final_mse))

    print(f"\n[{label}] Imputation performance:")
    print("  MAE :", final_mae)
    print("  MSE :", final_mse)
    print("  RMSE:", final_rmse)

    # ---------- Forecast ----------
    y_true_1, y_pred_1 = [], []
    y_true_3, y_pred_3 = [], []
    y_true_6, y_pred_6 = [], []

    forecast_windows = (
        forecast_1_evaluation_windows_val
        + forecast_3_evaluation_windows_val
        + forecast_6_evaluation_windows_val
    )

    for window in forecast_windows:
        end_idx = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        d       = np.where(meta["detectors"] == window["detector_id"])[0][0]
        h       = int(window["horizon_steps"])
        target_idx = end_idx + h
        if target_idx >= x_t_true.shape[0]:
            continue

        yt = x_t_true[target_idx, d]
        yp = forecast_fn(x_t_masked, target_idx, d)
        if not (np.isfinite(yt) and np.isfinite(yp)):
            continue

        if h == 1:
            y_true_1.append(yt); y_pred_1.append(yp)
        elif h == 3:
            y_true_3.append(yt); y_pred_3.append(yp)
        elif h == 6:
            y_true_6.append(yt); y_pred_6.append(yp)

    def _pack(y_pred, y_true, h):
        mae = sklearn.metrics.mean_absolute_error(y_pred, y_true)
        mse = sklearn.metrics.mean_squared_error(y_pred, y_true)
        rmse = float(np.sqrt(mse))
        print(f"{h}-step MAE :", mae)
        print(f"{h}-step MSE :", mse)
        print(f"{h}-step RMSE:", rmse)
        return mae, mse, rmse

    print(f"\n[{label}] Forecasting performance:")
    print("-----------------------------------")
    mae1,mse1,rmse1 = _pack(y_pred_1, y_true_1, 1)
    print("\n-----------------------------------")
    mae3,mse3,rmse3 = _pack(y_pred_3, y_true_3, 3)
    print("\n-----------------------------------")
    mae6,mse6,rmse6 = _pack(y_pred_6, y_true_6, 6)

    return {
        "impute_mae": final_mae,
        "impute_mse": final_mse,
        "impute_rmse": final_rmse,
        "forecast_mae_1": mae1, "forecast_mse_1": mse1, "forecast_rmse_1": rmse1,
        "forecast_mae_3": mae3, "forecast_mse_3": mse3, "forecast_rmse_3": rmse3,
        "forecast_mae_6": mae6, "forecast_mse_6": mse6, "forecast_rmse_6": rmse6,
    }

In [14]:
def evaluate_locf_baseline(
    x_t_true,
    x_t_masked,
    meta,
    label="LOCF baseline",
):
    """
    Baseline evaluation using LOCF for both imputation and forecasting.
    x_t_true   : full panel (no artificial masking), used ONLY for y_true
    x_t_masked : panel with blackout windows masked (same as training),
                 used for baseline predictions so it can't peek inside.
    """
    # ---------- Imputation ----------
    impute_mae_list = []
    impute_mse_list = []

    for window in impute_evaluation_windows_val:
        if window["test_type"] != "impute":
            continue

        start_idx = np.where(meta["timestamps"] == window["blackout_start"])[0][0]
        end_idx   = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        detector_idx = np.where(meta["detectors"] == window["detector_id"])[0][0]

        # Truth from full data
        y_true = x_t_true[start_idx : end_idx + 1, detector_idx].copy()

        # LOCF baseline only sees masked training panel
        y_pred = locf_impute_baseline(
            x_t_masked, start_idx, end_idx, detector_idx
        )

        mask = np.isfinite(y_true) & np.isfinite(y_pred)
        if not mask.any():
            continue

        mae = sklearn.metrics.mean_absolute_error(y_pred[mask], y_true[mask])
        mse = sklearn.metrics.mean_squared_error(y_pred[mask], y_true[mask])

        impute_mae_list.append([mae, window["len_steps"]])
        impute_mse_list.append([mse, window["len_steps"]])

    final_mae = np.average(
        [item[0] for item in impute_mae_list],
        weights=[item[1] for item in impute_mae_list],
    )
    final_mse = np.average(
        [item[0] for item in impute_mse_list],
        weights=[item[1] for item in impute_mse_list],
    )
    final_rmse = np.sqrt(final_mse)

    print(f"\n[{label}] Imputation performance:")
    print("  MAE :", final_mae)
    print("  MSE :", final_mse)
    print("  RMSE:", final_rmse)

    # ---------- Forecast ----------
    y_actual_1_step, y_forecast_1_step = [], []
    y_actual_3_step, y_forecast_3_step = [], []
    y_actual_6_step, y_forecast_6_step = [], []

    forecast_evaluation_windows_val = (
        forecast_1_evaluation_windows_val
        + forecast_3_evaluation_windows_val
        + forecast_6_evaluation_windows_val
    )

    for window in forecast_evaluation_windows_val:
        if window["test_type"] != "forecast":
            continue

        start_idx = np.where(meta["timestamps"] == window["blackout_start"])[0][0]
        end_idx   = np.where(meta["timestamps"] == window["blackout_end"])[0][0]
        detector_idx = np.where(meta["detectors"] == window["detector_id"])[0][0]
        horizon = int(window["horizon_steps"])

        if end_idx + horizon >= x_t_true.shape[0]:
            continue

        # Truth from full data
        y_true = x_t_true[end_idx + horizon, detector_idx]

        # Baseline sees ONLY masked panel (so it uses last observed pre-blackout)
        y_pred = locf_forecast_baseline(
            x_t_masked, end_idx, detector_idx
        )

        if not (np.isfinite(y_true) and np.isfinite(y_pred)):
            continue

        if horizon == 1:
            y_forecast_1_step.append(y_pred)
            y_actual_1_step.append(y_true)
        elif horizon == 3:
            y_forecast_3_step.append(y_pred)
            y_actual_3_step.append(y_true)
        elif horizon == 6:
            y_forecast_6_step.append(y_pred)
            y_actual_6_step.append(y_true)

    mae_1_step = sklearn.metrics.mean_absolute_error(
        y_forecast_1_step, y_actual_1_step
    )
    mse_1_step = sklearn.metrics.mean_squared_error(
        y_forecast_1_step, y_actual_1_step
    )
    rmse_1_step = np.sqrt(mse_1_step)

    mae_3_step = sklearn.metrics.mean_absolute_error(
        y_forecast_3_step, y_actual_3_step
    )
    mse_3_step = sklearn.metrics.mean_squared_error(
        y_forecast_3_step, y_actual_3_step
    )
    rmse_3_step = np.sqrt(mse_3_step)

    mae_6_step = sklearn.metrics.mean_absolute_error(
        y_forecast_6_step, y_actual_6_step
    )
    mse_6_step = sklearn.metrics.mean_squared_error(
        y_forecast_6_step, y_actual_6_step
    )
    rmse_6_step = np.sqrt(mse_6_step)

    print(f"\n[{label}] Forecasting performance:")
    print("-----------------------------------")
    print("1-step MAE :", mae_1_step)
    print("1-step MSE :", mse_1_step)
    print("1-step RMSE:", rmse_1_step)

    print("\n-----------------------------------")
    print("3-step MAE :", mae_3_step)
    print("3-step MSE :", mse_3_step)
    print("3-step RMSE:", rmse_3_step)

    print("\n-----------------------------------")
    print("6-step MAE :", mae_6_step)
    print("6-step MSE :", mse_6_step)
    print("6-step RMSE:", rmse_6_step)

    return {
        "impute_mae": final_mae,
        "impute_mse": final_mse,
        "impute_rmse": final_rmse,
        "forecast_mae_1": mae_1_step,
        "forecast_mse_1": mse_1_step,
        "forecast_rmse_1": rmse_1_step,
        "forecast_mae_3": mae_3_step,
        "forecast_mse_3": mse_3_step,
        "forecast_rmse_3": rmse_3_step,
        "forecast_mae_6": mae_6_step,
        "forecast_mse_6": mse_6_step,
        "forecast_rmse_6": rmse_6_step,
    }

In [15]:
def mask_evaluation_windows(x_t, m_t, evaluation_windows_val, meta):
    x_t_masked = x_t.copy()
    m_t_masked = m_t.copy()

    # Deduplicate (detector, start, end) so we don't re-mask the same block many times
    unique_blocks = set()
    for window in evaluation_windows_val:
        start_ts = window["blackout_start"]
        end_ts = window["blackout_end"]
        det_id = window["detector_id"]
        unique_blocks.add((start_ts, end_ts, det_id))

    for (start_ts, end_ts, det_id) in unique_blocks:
        start_idx = np.where(meta["timestamps"] == start_ts)[0][0]
        end_idx = np.where(meta["timestamps"] == end_ts)[0][0]
        detector_idx = np.where(meta["detectors"] == det_id)[0][0]

        x_t_masked[start_idx:end_idx+1, detector_idx] = np.nan
        m_t_masked[start_idx:end_idx+1, detector_idx] = 1

    return x_t_masked, m_t_masked

In [16]:
# Prepare training data by masking evaluation windows
x_t_train, m_t_train = mask_evaluation_windows(x_t, m_t, evaluation_windows_val, meta)
latent_dim = 20
D = x_t_train.shape[1]

In [17]:
from data_interface import build_time_features

X_time = build_time_features(meta["timestamps"])

# ---------------- MAR model ----------------
mar_params = mnar_blackout_lds.MNARParams.init_random(K=latent_dim, D=D, seed=42)
model_mar = mnar_blackout_lds.MNARBlackoutLDS(mar_params)
em_train_history_mar = model_mar.em_train(
    x_t_train,
    m_t_train,
    X_time=X_time,
    num_iters=12,
    update_phi=False,
    phi_steps=0,
    phi_lr=0.0,
    verbose=True,
    convergence_tol=1e-3,
)

# ---------------- MNAR model ----------------
mnar_params = copy.deepcopy(model_mar.params)

# Initialize b from empirical missing rate
eps = 1e-6
p_miss_d = np.clip(m_t_train.mean(axis=0), eps, 1 - eps)
mnar_params.b = np.log(p_miss_d / (1 - p_miss_d))

model_mnar = mnar_blackout_lds.MNARBlackoutLDS(mnar_params)
em_train_history_mnar = model_mnar.em_train(
    x_t_train,
    m_t_train,
    X_time=X_time,
    num_iters=12,
    update_phi=True,
    phi_steps=4,
    phi_lr=5e-4,
    verbose=True,
    convergence_tol=1e-3,
)



=== EM iteration 1/12 ===
  A norm: 4.344
  Q trace: 8.194
  mean diag(R): 37.230

=== EM iteration 2/12 ===
  A norm: 4.225
  Q trace: 24.719
  mean diag(R): 22.689
  max relative param change: 2.570e-01

=== EM iteration 3/12 ===
  A norm: 4.222
  Q trace: 35.362
  mean diag(R): 20.612
  max relative param change: 1.206e-01

=== EM iteration 4/12 ===
  A norm: 4.215
  Q trace: 39.576
  mean diag(R): 20.497
  max relative param change: 8.958e-02

=== EM iteration 5/12 ===
  A norm: 4.210
  Q trace: 41.870
  mean diag(R): 21.159
  max relative param change: 8.940e-02

=== EM iteration 6/12 ===
  A norm: 4.192
  Q trace: 46.512
  mean diag(R): 22.482
  max relative param change: 1.043e-01

=== EM iteration 7/12 ===
  A norm: 4.142
  Q trace: 50.419
  mean diag(R): 23.147
  max relative param change: 9.172e-02

=== EM iteration 8/12 ===
  A norm: 4.094
  Q trace: 53.918
  mean diag(R): 23.430
  max relative param change: 8.581e-02

=== EM iteration 9/12 ===
  A norm: 4.062
  Q trace: 57

### Reconstruction and Prediction

In [18]:
ekf_mar = model_mar.ekf_forward(x_t_train, m_t_train, X_time=X_time, use_missingness_obs=False)
smoother_mar = model_mar.rts_smoother(ekf_mar)

mu_filt_mar = ekf_mar["mu_filt"]
Sigma_filt_mar = ekf_mar["Sigma_filt"]
mu_smooth_mar = smoother_mar["mu_smooth"]
Sigma_smooth_mar = smoother_mar["Sigma_smooth"]

metrics_mar = evaluate_impute_forecast_model(
    model=model_mar,
    mu_smooth=mu_smooth_mar,
    Sigma_smooth=Sigma_smooth_mar,
    mu_filt=mu_filt_mar,
    Sigma_filt=Sigma_filt_mar,
    x_t=x_t,
    meta=meta,
    label="MAR LDS",
)


[MAR LDS] Imputation performance:
  MAE : 5.421049097566691
  MSE : 70.32399358378896
  RMSE: 8.385940232543335

[MAR LDS] Forecasting performance:
-----------------------------------
1-step MAE : 4.748208136663754
1-step MSE : 53.04238739123073
1-step RMSE: 7.283020485432588

-----------------------------------
3-step MAE : 5.104959797781098
3-step MSE : 58.34928297969126
3-step RMSE: 7.638670236349469

-----------------------------------
6-step MAE : 5.7794533988064165
6-step MSE : 71.47613528938345
6-step RMSE: 8.454355994952156


In [19]:
# ============================================================
# 2) Missingness diagnosis (Tests 1–3)
# ============================================================
def build_blackout_onset_dataset(windows, x_t_true, m_t_true, meta, X_time, past_steps=12):
    """
    Build (X, y) where each sample is a blackout ONSET edge:
      time t0 = start_idx-1 (last observed before blackout)
      label y = m[t0+1, d] (should be 1 for true onsets)
    Features (observed-only):
      - last observed speed at t0
      - rolling variance over past_steps (default 1 hour = 12*5min)
      - time features at t0
    Returns:
      X_obs: (N, 2+F_time)
      y:     (N,)
      t0_idx: (N,)
      d_idx:  (N,)
    """
    X_list, y_list, t0_list, d_list = [], [], [], []
    for w in windows:
        start_idx = np.where(meta["timestamps"] == w["blackout_start"])[0][0]
        d = np.where(meta["detectors"] == w["detector_id"])[0][0]
        t0 = start_idx - 1
        if t0 <= past_steps or t0 + 1 >= x_t_true.shape[0]:
            continue
        last_speed = x_t_true[t0, d]
        if not np.isfinite(last_speed):
            continue
        hist = x_t_true[t0-past_steps:t0, d]
        hist = hist[np.isfinite(hist)]
        if hist.size < max(3, past_steps//3):
            continue
        roll_var = float(np.var(hist))
        feats = np.concatenate([[last_speed, roll_var], X_time[t0]], axis=0)
        y = float(m_t_true[t0+1, d])  # whether next step is missing
        X_list.append(feats); y_list.append(y); t0_list.append(t0); d_list.append(d)
    return np.asarray(X_list, float), np.asarray(y_list, float), np.asarray(t0_list), np.asarray(d_list)


def build_matched_control_dataset(N, x_t_true, m_t_true, X_time, t0_idx, d_idx, rng=42):
    """
    Controls: sample (t,d) pairs with the SAME hour/weekend/rush distribution
    as onsets by sampling t from the same t0_idx pool, but with random detectors,
    and forcing label to be actual next-step missingness.
    """
    rs = np.random.default_rng(rng)
    Xc, yc = [], []
    T, D = x_t_true.shape
    for i in range(N):
        t0 = int(t0_idx[rs.integers(0, len(t0_idx))])
        d = int(rs.integers(0, D))
        if t0 <= 12 or t0 + 1 >= T:
            continue
        last_speed = x_t_true[t0, d]
        if not np.isfinite(last_speed):
            continue
        hist = x_t_true[t0-12:t0, d]
        hist = hist[np.isfinite(hist)]
        if hist.size < 4:
            continue
        roll_var = float(np.var(hist))
        feats = np.concatenate([[last_speed, roll_var], X_time[t0]], axis=0)
        y = float(m_t_true[t0+1, d])
        Xc.append(feats); yc.append(y)
    Xc = np.asarray(Xc, float)
    yc = np.asarray(yc, float)
    return Xc, yc


def auc_logreg(X, y, label="clf"):
    clf = LogisticRegression(max_iter=2000, class_weight="balanced")
    clf.fit(X, y)
    p = clf.predict_proba(X)[:, 1]
    auc = roc_auc_score(y, p)
    print(f"[{label}] AUC = {auc:.4f}  (N={len(y)}, pos={y.mean():.3f})")
    return auc


# --- Test 1: observed-only proxies near blackout edges ---
X_on, y_on, t0_on, d_on = build_blackout_onset_dataset(
    impute_evaluation_windows_val, x_t, m_t, meta, X_time, past_steps=12
)
X_ctrl, y_ctrl = build_matched_control_dataset(len(y_on), x_t, m_t, X_time, t0_on, d_on, rng=42)

X1 = np.vstack([X_on, X_ctrl])
y1 = np.concatenate([np.ones(len(y_on)), np.zeros(len(y_ctrl))])  # onset vs control
auc_test1 = auc_logreg(X1, y1, label="Test1 (observed-only): onset vs control")


# --- Test 2: latent improves missingness prediction ---
# Build two classifiers predicting next-step missingness:
#   (a) time + last_speed
#   (b) time + last_speed + smoothed latent state
def build_nextstep_missingness_dataset_balanced(
    x_t_true, m_t_true, X_time, mu_smooth,
    sample_stride=12, max_pos=20_000, neg_per_pos=3, seed=0
):
    rs = np.random.default_rng(seed)
    T, D = x_t_true.shape
    pos, neg = [], []

    for t in range(0, T - 2, sample_stride):
        ds = rs.integers(0, D, size=min(64, D))
        for d in ds:
            if not np.isfinite(x_t_true[t, d]):
                continue
            y = int(m_t_true[t + 1, d])  # next-step missing

            feat_obs = np.concatenate([[x_t_true[t, d]], X_time[t]])
            feat_lat = np.concatenate([[x_t_true[t, d]], X_time[t], mu_smooth[t]])

            if y == 1 and len(pos) < max_pos:
                pos.append((t, feat_obs, feat_lat, 1))
            elif y == 0 and (len(neg) < neg_per_pos * max(1, len(pos))):
                neg.append((t, feat_obs, feat_lat, 0))

    data = pos + neg
    # IMPORTANT: keep time ordering for a time-aware split
    data.sort(key=lambda z: z[0])  # sort by t

    Xobs = np.asarray([a for _, a, _, _ in data], float)
    Xlat = np.asarray([b for _, _, b, _ in data], float)
    y    = np.asarray([c for *_, c in data], int)
    return Xobs, Xlat, y


# ----------------------------
# Balanced Test2 (ROC-AUC + PR-AUC + LogLoss)
# ----------------------------
X2_obs, X2_lat, y2 = build_nextstep_missingness_dataset_balanced(
    x_t_true=x_t,
    m_t_true=m_t,
    X_time=X_time,
    mu_smooth=mu_smooth_mar,
    sample_stride=12,
    max_pos=20_000,
    neg_per_pos=3,
    seed=0,
)

print(f"[Test2 Balanced] N={len(y2)}  pos_rate={y2.mean():.3f}")

# Time-aware split: first 70% train, last 30% test (because we kept time order)
n = len(y2)
split = int(0.7 * n)

Xobs_tr, Xobs_te = X2_obs[:split], X2_obs[split:]
Xlat_tr, Xlat_te = X2_lat[:split], X2_lat[split:]
y_tr, y_te       = y2[:split], y2[split:]

clf_obs = LogisticRegression(max_iter=200, n_jobs=-1).fit(Xobs_tr, y_tr)
clf_lat = LogisticRegression(max_iter=200, n_jobs=-1).fit(Xlat_tr, y_tr)

for name, clf, Xte in [("OBS", clf_obs, Xobs_te), ("LAT", clf_lat, Xlat_te)]:
    p = clf.predict_proba(Xte)[:, 1]
    # sklearn versions removed `eps=` in log_loss; clip manually instead
    p = np.clip(p, 1e-7, 1 - 1e-7)
    print(f"[{name}] ROC-AUC:", roc_auc_score(y_te, p))
    print(f"[{name}] PR-AUC :", average_precision_score(y_te, p))
    print(f"[{name}] LogLoss:", log_loss(y_te, p))


# --- Test 3: event-level clustering / structure ---
try:
    det_events = data_interface.load_detector_blackouts("data", as_dataframe=True)
    net_events = data_interface.load_network_blackouts("data", as_dataframe=True)

    print("\n[Test3] Detector blackout durations (steps) summary:")
    print(det_events["len_steps"].describe())
    print("\n[Test3] Network blackout durations (steps) summary:")
    print(net_events["len_steps"].describe())

    # Simple clustering proxy: inter-event time (network-level)
    net_starts = pd.to_datetime(net_events["start"]).sort_values()
    deltas_min = net_starts.diff().dropna().dt.total_seconds() / 60.0
    print("\n[Test3] Network inter-event time (minutes) summary:")
    print(deltas_min.describe())
except Exception as e:
    print(f"[Test3] Skipped (missing blackout event parquet?): {e}")

[Test1 (observed-only): onset vs control] AUC = 0.7374  (N=578, pos=0.516)
[Test2 Balanced] N=60  pos_rate=0.250
[OBS] ROC-AUC: 0.48214285714285715
[OBS] PR-AUC : 0.29583333333333334
[OBS] LogLoss: 0.8846304717482462
[LAT] ROC-AUC: 0.5357142857142857
[LAT] PR-AUC : 0.3138888888888889
[LAT] LogLoss: 1.5093931476681492

[Test3] Detector blackout durations (steps) summary:
count      942.000000
mean       849.284501
std       7138.302338
min         12.000000
25%         18.000000
50%         37.000000
75%         84.000000
max      88536.000000
Name: len_steps, dtype: float64

[Test3] Network blackout durations (steps) summary:
count     25.00000
mean      66.92000
std       92.28214
min        2.00000
25%       19.00000
50%       36.00000
75%       84.00000
max      427.00000
Name: len_steps, dtype: float64

[Test3] Network inter-event time (minutes) summary:
count        24.000000
mean      18596.041667
std       28327.880873
min          20.000000
25%         748.750000
50%        583

In [20]:
# ---------------- Reconstruction & prediction: MNAR ----------------
ekf_mnar = model_mnar.ekf_forward(
    x_t_train,
    m_t_train,
    X_time=X_time,
    use_missingness_obs=True,     
)
smoother_mnar = model_mnar.rts_smoother(ekf_mnar)

mu_filt_mnar = ekf_mnar["mu_filt"]
Sigma_filt_mnar = ekf_mnar["Sigma_filt"]
mu_smooth_mnar = smoother_mnar["mu_smooth"]
Sigma_smooth_mnar = smoother_mnar["Sigma_smooth"]

metrics_mnar = evaluate_impute_forecast_model(
    model=model_mnar,
    mu_smooth=mu_smooth_mnar,
    Sigma_smooth=Sigma_smooth_mnar,
    mu_filt=mu_filt_mnar,
    Sigma_filt=Sigma_filt_mnar,
    x_t=x_t,
    meta=meta,
    label="MNAR LDS (report: const missingness var)",
)


[MNAR LDS (report: const missingness var)] Imputation performance:
  MAE : 5.330255647881937
  MSE : 66.29155389355113
  RMSE: 8.141962533293256

[MNAR LDS (report: const missingness var)] Forecasting performance:
-----------------------------------
1-step MAE : 4.697742157492266
1-step MSE : 50.374242781494495
1-step RMSE: 7.097481439320183

-----------------------------------
3-step MAE : 5.0728911274653505
3-step MSE : 56.15808247323508
3-step RMSE: 7.493869659477344

-----------------------------------
6-step MAE : 5.699766120342447
6-step MSE : 68.5219621844729
6-step RMSE: 8.27779935637926


In [22]:
# ============================================================
# 3) Inference ablation: missingness variance
#    (i) moment-matched (default) vs (ii) constant variance
# ============================================================
ekf_mnar_const = model_mnar.ekf_forward(
    x_t_train, m_t_train,X_time=X_time,
    missingness_var_mode="constant",
)
smoother_mnar_const = model_mnar.rts_smoother(ekf_mnar_const)

metrics_mnar_const = evaluate_impute_forecast_model(
    model=model_mnar,
    mu_smooth=smoother_mnar_const["mu_smooth"],
    Sigma_smooth=smoother_mnar_const["Sigma_smooth"],
    mu_filt=ekf_mnar_const["mu_filt"],
    Sigma_filt=ekf_mnar_const["Sigma_filt"],
    x_t=x_t,
    meta=meta,
    label="MNAR LDS (const missingness var)",
)


[MNAR LDS (const missingness var)] Imputation performance:
  MAE : 5.39648843115324
  MSE : 67.61631888903317
  RMSE: 8.22291425767247

[MNAR LDS (const missingness var)] Forecasting performance:
-----------------------------------
1-step MAE : 4.757923626273731
1-step MSE : 50.96568696378213
1-step RMSE: 7.139025631259642

-----------------------------------
3-step MAE : 5.115397432570886
3-step MSE : 56.20095663580049
3-step RMSE: 7.496729729408717

-----------------------------------
6-step MAE : 5.7311422446607345
6-step MSE : 68.25567849470387
6-step RMSE: 8.261699491914715


In [23]:
# ---------------- Baseline: LOCF ----------------
baseline_locf_metrics = evaluate_locf_baseline(
    x_t_true=x_t,
    x_t_masked=x_t_train,
    meta=meta,
    label="LOCF baseline",
)

print("\nDone: LOCF vs MAR vs MNAR evaluated on the same blackout windows.")


[LOCF baseline] Imputation performance:
  MAE : 7.880999151920512
  MSE : 181.6578376980285
  RMSE: 13.478050218708509

[LOCF baseline] Forecasting performance:
-----------------------------------
1-step MAE : 11.52644662019662
1-step MSE : 318.6954144807004
1-step RMSE: 17.852042305593507

-----------------------------------
3-step MAE : 12.502798017520238
3-step MSE : 362.3945872475539
3-step RMSE: 19.036664288880914

-----------------------------------
6-step MAE : 14.010510346065898
6-step MSE : 407.7179884622047
6-step RMSE: 20.192027844231117

Done: LOCF vs MAR vs MNAR evaluated on the same blackout windows.


In [24]:
# ============================================================
# 4) Bootstrap uncertainty on RMSE deltas (paper-ready)
#    - window-resampling with replacement
#    - length-weighted RMSE for imputation windows
#    - plain RMSE for 1/3/6-step forecast windows
# ============================================================
def collect_impute_window_mse(model, mu_smooth, Sigma_smooth, x_t_true, meta):
    """
    Returns lists aligned to impute_evaluation_windows_val:
      mse_list: per-window mean squared error inside blackout (float)
      w_list:   per-window weight = len_steps (int)
    """
    mse_list, w_list = [], []
    for w in impute_evaluation_windows_val:
        start_idx = np.where(meta["timestamps"] == w["blackout_start"])[0][0]
        end_idx   = np.where(meta["timestamps"] == w["blackout_end"])[0][0]
        d         = np.where(meta["detectors"] == w["detector_id"])[0][0]

        # reconstruct only within blackout slice
        eval_mu = mu_smooth[start_idx:end_idx+1]
        eval_S  = Sigma_smooth[start_idx:end_idx+1]
        recon, _ = model.reconstruct_from_smoother(eval_mu, eval_S)

        y_true = x_t_true[start_idx:end_idx+1, d]
        y_pred = recon[:, d]
        mask = np.isfinite(y_true) & np.isfinite(y_pred)
        if not mask.any():
            continue
        mse = float(np.mean((y_true[mask] - y_pred[mask])**2))
        mse_list.append(mse)
        w_list.append(int(w["len_steps"]))
    return np.asarray(mse_list, float), np.asarray(w_list, int)


def collect_forecast_sqerr(model, mu_filt, Sigma_filt, x_t_true, meta, horizon: int):
    """
    Returns per-window squared error list for given horizon.
    Forecast is produced from end_idx using model.k_step_forecast (no peeking).
    """
    sqerrs = []
    windows = {1: forecast_1_evaluation_windows_val,
               3: forecast_3_evaluation_windows_val,
               6: forecast_6_evaluation_windows_val}[horizon]

    for w in windows:
        end_idx = np.where(meta["timestamps"] == w["blackout_end"])[0][0]
        d       = np.where(meta["detectors"] == w["detector_id"])[0][0]
        target_idx = end_idx + horizon
        if target_idx >= x_t_true.shape[0]:
            continue

        y_true = float(x_t_true[target_idx, d])
        if not np.isfinite(y_true):
            continue

        pred_x, _ = model.k_step_forecast(mu_filt, Sigma_filt, end_idx, k=horizon)
        y_pred = float(pred_x[d])
        if not np.isfinite(y_pred):
            continue

        sqerrs.append((y_true - y_pred)**2)

    return np.asarray(sqerrs, float)


def collect_forecast_sqerr_baseline(x_t_true, x_t_masked, meta, forecast_fn, horizon: int):
    """
    Baseline version: forecast_fn(x_t_masked, target_idx, d) -> float
    """
    sqerrs = []
    windows = {1: forecast_1_evaluation_windows_val,
               3: forecast_3_evaluation_windows_val,
               6: forecast_6_evaluation_windows_val}[horizon]
    for w in windows:
        end_idx = np.where(meta["timestamps"] == w["blackout_end"])[0][0]
        d       = np.where(meta["detectors"] == w["detector_id"])[0][0]
        target_idx = end_idx + horizon
        if target_idx >= x_t_true.shape[0]:
            continue
        y_true = float(x_t_true[target_idx, d])
        y_pred = float(forecast_fn(x_t_masked, target_idx, d))
        if not (np.isfinite(y_true) and np.isfinite(y_pred)):
            continue
        sqerrs.append((y_true - y_pred)**2)
    return np.asarray(sqerrs, float)


def _rmse_from_sqerrs(sqerrs: np.ndarray) -> float:
    return float(np.sqrt(np.mean(sqerrs))) if sqerrs.size else np.nan


def _weighted_rmse_from_window_mse(mse_list: np.ndarray, w_list: np.ndarray) -> float:
    if mse_list.size == 0:
        return np.nan
    return float(np.sqrt(np.average(mse_list, weights=w_list)))

In [25]:
def bootstrap_rmse_delta(
    a_values: np.ndarray,
    b_values: np.ndarray,
    n_boot: int = 500,
    seed: int = 0,
    mode: str = "sqerr",         # "sqerr" or "win_mse"
    a_weights: np.ndarray | None = None,
    b_weights: np.ndarray | None = None,
):
    """
    Bootstraps delta = RMSE(a) - RMSE(b). Lower is better => negative delta is improvement.
    mode:
      - "sqerr": a_values/b_values are per-window squared errors, RMSE = sqrt(mean)
      - "win_mse": a_values/b_values are per-window MSE, RMSE = sqrt(weighted avg MSE)
    Returns: dict with point_est, CI, boot_samples
    """
    rs = np.random.default_rng(seed)

    # point estimates
    if mode == "sqerr":
        rmse_a = _rmse_from_sqerrs(a_values)
        rmse_b = _rmse_from_sqerrs(b_values)
    else:
        rmse_a = _weighted_rmse_from_window_mse(a_values, a_weights)
        rmse_b = _weighted_rmse_from_window_mse(b_values, b_weights)

    point = rmse_a - rmse_b

    # bootstrap
    boots = []
    n_a = len(a_values)
    n_b = len(b_values)
    for _ in range(n_boot):
        ia = rs.integers(0, n_a, size=n_a)
        ib = rs.integers(0, n_b, size=n_b)

        if mode == "sqerr":
            ra = _rmse_from_sqerrs(a_values[ia])
            rb = _rmse_from_sqerrs(b_values[ib])
        else:
            ra = _weighted_rmse_from_window_mse(a_values[ia], a_weights[ia])
            rb = _weighted_rmse_from_window_mse(b_values[ib], b_weights[ib])
        boots.append(ra - rb)

    boots = np.asarray(boots, float)
    lo, hi = np.quantile(boots, [0.025, 0.975])
    return {
        "point_delta": float(point),
        "ci95": (float(lo), float(hi)),
        "boot": boots,
        "rmse_a": float(rmse_a),
        "rmse_b": float(rmse_b),
    }


# ---------- Build error arrays ----------
# Imputation (window MSE + len_steps weights)
mar_imp_mse, mar_imp_w = collect_impute_window_mse(model_mar, mu_smooth_mar, Sigma_smooth_mar, x_t, meta)
mnar_imp_mse, mnar_imp_w = collect_impute_window_mse(model_mnar, mu_smooth_mnar, Sigma_smooth_mnar, x_t, meta)

# LOCF imputation: compute per-window MSE (still cheap)
def collect_impute_window_mse_baseline(x_t_true, x_t_masked, meta, impute_fn):
    mse_list, w_list = [], []
    for w in impute_evaluation_windows_val:
        start_idx = np.where(meta["timestamps"] == w["blackout_start"])[0][0]
        end_idx   = np.where(meta["timestamps"] == w["blackout_end"])[0][0]
        d         = np.where(meta["detectors"] == w["detector_id"])[0][0]
        y_true = x_t_true[start_idx:end_idx+1, d]
        y_pred = impute_fn(x_t_masked, start_idx, end_idx, d)
        mask = np.isfinite(y_true) & np.isfinite(y_pred)
        if not mask.any():
            continue
        mse = float(np.mean((y_true[mask] - y_pred[mask])**2))
        mse_list.append(mse)
        w_list.append(int(w["len_steps"]))
    return np.asarray(mse_list, float), np.asarray(w_list, int)

locf_imp_mse, locf_imp_w = collect_impute_window_mse_baseline(x_t, x_t_train, meta, locf_impute_baseline)

# Forecast sqerr arrays per horizon
mar_sq1 = collect_forecast_sqerr(model_mar, mu_filt_mar, Sigma_filt_mar, x_t, meta, horizon=1)
mar_sq3 = collect_forecast_sqerr(model_mar, mu_filt_mar, Sigma_filt_mar, x_t, meta, horizon=3)
mar_sq6 = collect_forecast_sqerr(model_mar, mu_filt_mar, Sigma_filt_mar, x_t, meta, horizon=6)

mnar_sq1 = collect_forecast_sqerr(model_mnar, mu_filt_mnar, Sigma_filt_mnar, x_t, meta, horizon=1)
mnar_sq3 = collect_forecast_sqerr(model_mnar, mu_filt_mnar, Sigma_filt_mnar, x_t, meta, horizon=3)
mnar_sq6 = collect_forecast_sqerr(model_mnar, mu_filt_mnar, Sigma_filt_mnar, x_t, meta, horizon=6)

locf_sq1 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, lambda xt, ti, d: locf_forecast_baseline(xt, ti-1, d), horizon=1)
locf_sq3 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, lambda xt, ti, d: locf_forecast_baseline(xt, ti-3, d), horizon=3)
locf_sq6 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, lambda xt, ti, d: locf_forecast_baseline(xt, ti-6, d), horizon=6)

season_sq1 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, seasonal_naive_forecast_baseline, horizon=1)
season_sq3 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, seasonal_naive_forecast_baseline, horizon=3)
season_sq6 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, seasonal_naive_forecast_baseline, horizon=6)

# ---------- Hour-of-Week baseline sqerr arrays ----------
how_mean, slot_of_week, global_mean = build_hour_of_week_climatology(
    x_t_masked=x_t_train,
    m_t_masked=m_t_train,
    timestamps=meta["timestamps"],
    step_minutes=5,
)
hour_of_week_forecast_fn = make_hour_of_week_forecast_fn(how_mean, slot_of_week, global_mean)

how_sq1 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=1)
how_sq3 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=3)
how_sq6 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=6)

# ---------- Bootstrap deltas ----------
print("\n=== Bootstrap RMSE deltas (delta = A - B; negative means A better) ===")

# MAR vs LOCF
res_imp_mar_locf = bootstrap_rmse_delta(mar_imp_mse, locf_imp_mse, n_boot=500, seed=0, mode="win_mse", a_weights=mar_imp_w, b_weights=locf_imp_w)
print(f"Impute RMSE: MAR - LOCF = {res_imp_mar_locf['point_delta']:+.4f}  CI95={res_imp_mar_locf['ci95']}")

res_f1_mar_locf = bootstrap_rmse_delta(mar_sq1, locf_sq1, n_boot=500, seed=1, mode="sqerr")
res_f3_mar_locf = bootstrap_rmse_delta(mar_sq3, locf_sq3, n_boot=500, seed=2, mode="sqerr")
res_f6_mar_locf = bootstrap_rmse_delta(mar_sq6, locf_sq6, n_boot=500, seed=3, mode="sqerr")
print(f"Fcast1 RMSE: MAR - LOCF = {res_f1_mar_locf['point_delta']:+.4f}  CI95={res_f1_mar_locf['ci95']}")
print(f"Fcast3 RMSE: MAR - LOCF = {res_f3_mar_locf['point_delta']:+.4f}  CI95={res_f3_mar_locf['ci95']}")
print(f"Fcast6 RMSE: MAR - LOCF = {res_f6_mar_locf['point_delta']:+.4f}  CI95={res_f6_mar_locf['ci95']}")

# MNAR vs MAR
res_imp_mnar_mar = bootstrap_rmse_delta(mnar_imp_mse, mar_imp_mse, n_boot=500, seed=10, mode="win_mse", a_weights=mnar_imp_w, b_weights=mar_imp_w)
print(f"\nImpute RMSE: MNAR - MAR = {res_imp_mnar_mar['point_delta']:+.4f}  CI95={res_imp_mnar_mar['ci95']}")

res_f1_mnar_mar = bootstrap_rmse_delta(mnar_sq1, mar_sq1, n_boot=500, seed=11, mode="sqerr")
res_f3_mnar_mar = bootstrap_rmse_delta(mnar_sq3, mar_sq3, n_boot=500, seed=12, mode="sqerr")
res_f6_mnar_mar = bootstrap_rmse_delta(mnar_sq6, mar_sq6, n_boot=500, seed=13, mode="sqerr")
print(f"Fcast1 RMSE: MNAR - MAR = {res_f1_mnar_mar['point_delta']:+.4f}  CI95={res_f1_mnar_mar['ci95']}")
print(f"Fcast3 RMSE: MNAR - MAR = {res_f3_mnar_mar['point_delta']:+.4f}  CI95={res_f3_mnar_mar['ci95']}")
print(f"Fcast6 RMSE: MNAR - MAR = {res_f6_mnar_mar['point_delta']:+.4f}  CI95={res_f6_mnar_mar['ci95']}")

# --- Baseline sanity: SeasonalNaive should be compared to HourOfWeekMean ---
res_f1_season_how = bootstrap_rmse_delta(season_sq1, how_sq1, n_boot=500, seed=21, mode="sqerr")
res_f3_season_how = bootstrap_rmse_delta(season_sq3, how_sq3, n_boot=500, seed=22, mode="sqerr")
res_f6_season_how = bootstrap_rmse_delta(season_sq6, how_sq6, n_boot=500, seed=23, mode="sqerr")
print(f"\nFcast1 RMSE: SeasonalNaive - HourOfWeek = {res_f1_season_how['point_delta']:+.4f}  CI95={res_f1_season_how['ci95']}")
print(f"Fcast3 RMSE: SeasonalNaive - HourOfWeek = {res_f3_season_how['point_delta']:+.4f}  CI95={res_f3_season_how['ci95']}")
print(f"Fcast6 RMSE: SeasonalNaive - HourOfWeek = {res_f6_season_how['point_delta']:+.4f}  CI95={res_f6_season_how['ci95']}")

# --- Models vs HourOfWeekMean ---
res_f1_mar_how = bootstrap_rmse_delta(mar_sq1, how_sq1, n_boot=500, seed=31, mode="sqerr")
res_f3_mar_how = bootstrap_rmse_delta(mar_sq3, how_sq3, n_boot=500, seed=32, mode="sqerr")
res_f6_mar_how = bootstrap_rmse_delta(mar_sq6, how_sq6, n_boot=500, seed=33, mode="sqerr")
print(f"\nFcast1 RMSE: MAR - HourOfWeek = {res_f1_mar_how['point_delta']:+.4f}  CI95={res_f1_mar_how['ci95']}")
print(f"Fcast3 RMSE: MAR - HourOfWeek = {res_f3_mar_how['point_delta']:+.4f}  CI95={res_f3_mar_how['ci95']}")
print(f"Fcast6 RMSE: MAR - HourOfWeek = {res_f6_mar_how['point_delta']:+.4f}  CI95={res_f6_mar_how['ci95']}")

res_f1_mnar_how = bootstrap_rmse_delta(mnar_sq1, how_sq1, n_boot=500, seed=41, mode="sqerr")
res_f3_mnar_how = bootstrap_rmse_delta(mnar_sq3, how_sq3, n_boot=500, seed=42, mode="sqerr")
res_f6_mnar_how = bootstrap_rmse_delta(mnar_sq6, how_sq6, n_boot=500, seed=43, mode="sqerr")
print(f"\nFcast1 RMSE: MNAR - HourOfWeek = {res_f1_mnar_how['point_delta']:+.4f}  CI95={res_f1_mnar_how['ci95']}")
print(f"Fcast3 RMSE: MNAR - HourOfWeek = {res_f3_mnar_how['point_delta']:+.4f}  CI95={res_f3_mnar_how['ci95']}")
print(f"Fcast6 RMSE: MNAR - HourOfWeek = {res_f6_mnar_how['point_delta']:+.4f}  CI95={res_f6_mnar_how['ci95']}")



=== Bootstrap RMSE deltas (delta = A - B; negative means A better) ===
Impute RMSE: MAR - LOCF = -5.0921  CI95=(-7.0998724022880175, -3.0292486885804264)
Fcast1 RMSE: MAR - LOCF = -10.5690  CI95=(-12.646440837166292, -8.106431925662411)
Fcast3 RMSE: MAR - LOCF = -11.3980  CI95=(-13.451453222550814, -9.145692835320263)
Fcast6 RMSE: MAR - LOCF = -11.7377  CI95=(-13.701436252005843, -9.553280102131584)

Impute RMSE: MNAR - MAR = -0.2440  CI95=(-1.5589055096183935, 1.2333407544169206)
Fcast1 RMSE: MNAR - MAR = -0.1855  CI95=(-1.5828058910198368, 1.2839101823746895)
Fcast3 RMSE: MNAR - MAR = -0.1448  CI95=(-1.4951432853203375, 1.0912836638542414)
Fcast6 RMSE: MNAR - MAR = -0.1766  CI95=(-1.5759430465968478, 1.1617406523019946)

Fcast1 RMSE: SeasonalNaive - HourOfWeek = +3.9765  CI95=(1.6405572258517576, 6.156798138636804)
Fcast3 RMSE: SeasonalNaive - HourOfWeek = +2.5999  CI95=(0.346698429252156, 4.981561996180359)
Fcast6 RMSE: SeasonalNaive - HourOfWeek = +3.0434  CI95=(0.7548304148588332

In [26]:
# ---------------- Hour-of-Week baseline ----------------
# Build detector-specific mean by (day_of_week, hour) using TRAIN/OBSERVED values only.
# Forecast fn signature must match: fn(x_t_masked, t_idx, detector_idx) -> float
def _build_hour_of_week_stats(x_obs: np.ndarray, meta: dict):
    # timestamps expected to be numpy datetime64 or pandas timestamps
    ts = meta["timestamps"]
    # Convert to pandas to reliably get dayofweek/hour (works for datetime64 too)
    ts_pd = pd.to_datetime(ts)
    how = (ts_pd.dayofweek.to_numpy() * 24 + ts_pd.hour.to_numpy()).astype(int)  # 0..167

    T, D = x_obs.shape
    means = np.full((D, 168), np.nan, float)
    global_means = np.nanmean(x_obs, axis=0)  # per-detector fallback

    for d in range(D):
        y = x_obs[:, d]
        for k in range(168):
            m = (how == k) & np.isfinite(y)
            if np.any(m):
                means[d, k] = float(np.mean(y[m]))
    return means, global_means, how

how_means, how_global_means, how_index = _build_hour_of_week_stats(x_t_train, meta)

def hour_of_week_forecast_fn(x_t_masked: np.ndarray, t_idx: int, d: int) -> float:
    k = int(how_index[t_idx])  # 0..167
    v = how_means[d, k]
    if np.isfinite(v):
        return float(v)
    # Fallback if that (d, hour-of-week) bin has no training samples
    vg = how_global_means[d]
    return float(vg) if np.isfinite(vg) else 0.0

how_sq1 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=1)
how_sq3 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=3)
how_sq6 = collect_forecast_sqerr_baseline(x_t, x_t_train, meta, hour_of_week_forecast_fn, horizon=6)

print("\n[HourOfWeekMean baseline] Forecast RMSEs:")
print("  1-step:", _rmse_from_sqerrs(how_sq1))
print("  3-step:", _rmse_from_sqerrs(how_sq3))
print("  6-step:", _rmse_from_sqerrs(how_sq6))

# ---------- Bootstrap deltas ----------
print("\n=== Bootstrap RMSE deltas (delta = A - B; negative means A better) ===")


[HourOfWeekMean baseline] Forecast RMSEs:
  1-step: 10.137130026808572
  3-step: 9.579017766777703
  6-step: 9.333722857457461

=== Bootstrap RMSE deltas (delta = A - B; negative means A better) ===


In [None]:
# ---------------- Baselines: Linear interp + Seasonal naive ----------------
metrics_lin_season = evaluate_impute_forecast_baseline(
    x_t_true=x_t,
    x_t_masked=x_t_train,
    meta=meta,
    impute_fn=linear_interp_impute_baseline,
    forecast_fn=seasonal_naive_forecast_baseline,
    label="LinearInterp (impute) + SeasonalNaive (forecast)",
)


how_mean, slot_of_week, global_mean = build_hour_of_week_climatology(
    x_t_masked=x_t_train,
    m_t_masked=m_t_train,
    timestamps=meta["timestamps"],
    step_minutes=5,
)
hour_of_week_forecast_fn = make_hour_of_week_forecast_fn(how_mean, slot_of_week, global_mean)


metrics_locf_how = evaluate_impute_forecast_baseline(
    x_t_true=x_t,
    x_t_masked=x_t_train,
    meta=meta,
    impute_fn=locf_impute_baseline,              # keep imputation baseline simple
    forecast_fn=hour_of_week_forecast_fn,        # NEW baseline
    label="LOCF (impute) + HourOfWeekMean (forecast)",
)

metrics_locf_season = evaluate_impute_forecast_baseline(
    x_t_true=x_t,
    x_t_masked=x_t_train,
    meta=meta,
    impute_fn=locf_impute_baseline,
    forecast_fn=seasonal_naive_forecast_baseline,
    label="LOCF (impute) + SeasonalNaive (forecast)",
)

metrics_spline_season = evaluate_impute_forecast_baseline(
    x_t_true=x_t,
    x_t_masked=x_t_train,
    meta=meta,
    impute_fn=spline_impute_baseline,
    forecast_fn=seasonal_naive_forecast_baseline,
    label="Spline (impute, fallback->linear) + SeasonalNaive (forecast)",
)

def _delta(a, b, key):
    return a[key] - b[key]

print("\n=== Summary deltas (lower is better) ===")
print("Dynamics win  (LOCF -> MAR)  impute_RMSE:", _delta(metrics_mar, baseline_locf_metrics, "impute_rmse"))
print("MNAR refine   (MAR  -> MNAR) impute_RMSE:", _delta(metrics_mnar, metrics_mar, "impute_rmse"))
print("MNAR refine   (MAR  -> MNAR) fcast_RMSE1:", _delta(metrics_mnar, metrics_mar, "forecast_rmse_1"))
print("MNAR refine   (MAR  -> MNAR) fcast_RMSE3:", _delta(metrics_mnar, metrics_mar, "forecast_rmse_3"))
print("MNAR refine   (MAR  -> MNAR) fcast_RMSE6:", _delta(metrics_mnar, metrics_mar, "forecast_rmse_6"))


[LinearInterp (impute) + SeasonalNaive (forecast)] Imputation performance:
  MAE : 4.683692189112716
  MSE : 55.559623086886994
  RMSE: 7.4538327783018445

[LinearInterp (impute) + SeasonalNaive (forecast)] Forecasting performance:
-----------------------------------
1-step MAE : 8.296791097902208
1-step MSE : 190.43454001592664
1-step RMSE: 13.799802173072143

-----------------------------------
3-step MAE : 7.088293768710437
3-step MSE : 141.3627038084115
3-step RMSE: 11.88960486342635

-----------------------------------
6-step MAE : 7.184098242292688
6-step MSE : 142.20147969735766
6-step RMSE: 11.924826191494686

[LOCF (impute) + HourOfWeekMean (forecast)] Imputation performance:
  MAE : 7.880999151920512
  MSE : 181.6578376980285
  RMSE: 13.478050218708509

[LOCF (impute) + HourOfWeekMean (forecast)] Forecasting performance:
-----------------------------------
1-step MAE : 6.386332833952402
1-step MSE : 96.49647383428665
1-step RMSE: 9.823261873445432

--------------------------

In [None]:
# ============================================================
# 5) BRITS / GRU-D hook
# ============================================================
try:
    # ------------------------------------------------------------
    # GRU-D style imputer (torch-only baseline)
    # - Causal imputation (uses past only)
    # - Forecasting via free-run after blackout end (no peeking)
    # ------------------------------------------------------------
    class GRUDImputer(nn.Module):
        def __init__(self, D: int, hidden: int = 128):
            super().__init__()
            self.D = D
            self.hidden = hidden
            # Per-feature decay -> D outputs
            self.decay = nn.Linear(D, D)
            # Input uses [x_tilde, obs_mask]
            self.inp = nn.Linear(2 * D, hidden)
            self.cell = nn.GRUCell(hidden, hidden)
            self.out = nn.Linear(hidden, D)

        def forward(
            self,
            x_filled: torch.Tensor,   # (B,T,D) NaNs already replaced
            obs_mask: torch.Tensor,   # (B,T,D) 1 if observed else 0
            delta: torch.Tensor,      # (B,T,D) time since last obs (in steps)
            x_mean: torch.Tensor,     # (D,)
            h0: torch.Tensor | None = None,  # (B,H)
            last_x0: torch.Tensor | None = None,  # (B,D) initial last observed per feature
        ):
            B, T, D = x_filled.shape
            device = x_filled.device

            if h0 is None:
                h = torch.zeros(B, self.hidden, device=device)
            else:
                h = h0

            # last observed value per feature
            if last_x0 is None:
                last_x = x_mean[None, :].repeat(B, 1)  # (B,D)
            else:
                last_x = last_x0

            preds = []
            h_seq = []

            for t in range(T):
                x_t = x_filled[:, t, :]         # (B,D)
                m_t = obs_mask[:, t, :]         # (B,D)
                d_t = delta[:, t, :]            # (B,D)

                # gamma = exp(-relu(W d + b))  in (0,1]
                gamma = torch.exp(-torch.relu(self.decay(d_t)))

                # GRU-D input imputation
                x_hat = gamma * last_x + (1.0 - gamma) * x_mean[None, :]
                x_tilde = m_t * x_t + (1.0 - m_t) * x_hat

                # update last observed
                last_x = m_t * x_t + (1.0 - m_t) * last_x

                u = torch.tanh(self.inp(torch.cat([x_tilde, m_t], dim=-1)))  # (B,H)
                h = self.cell(u, h)                                          # (B,H)
                y = self.out(h)                                              # (B,D)

                preds.append(y)
                h_seq.append(h)

            preds = torch.stack(preds, dim=1)   # (B,T,D)
            h_seq = torch.stack(h_seq, dim=1)   # (B,T,H)
            return preds, h_seq


    def _compute_delta(obs_mask_np: np.ndarray) -> np.ndarray:
        """
        obs_mask_np: (T,D) with 1 if observed else 0
        returns delta in steps since last observation, (T,D)
        """
        T, D = obs_mask_np.shape
        delta = np.zeros((T, D), dtype=np.float32)
        last = np.zeros(D, dtype=np.float32)
        for t in range(T):
            if t == 0:
                delta[t] = 0.0
            else:
                last = last + 1.0
                # reset where observed
                last = last * (1.0 - obs_mask_np[t].astype(np.float32))
                delta[t] = last
        return delta


    # -----------------------------
    # Prepare tensors from your panel
    # -----------------------------
    # Expected existing vars in your notebook:
    #   x_t_train, m_t_train, x_t, m_t, meta
    # where m_t is 1=missing, 0=observed
    x_train = x_t_train.copy()
    obs_mask = ((m_t_train == 0) & np.isfinite(x_train)).astype(np.float32)  # (T,D) 1=observed
    x_filled = np.nan_to_num(x_train, nan=0.0).astype(np.float32)            # (T,D)
    delta = _compute_delta(obs_mask)                                         # (T,D)
    x_mean = ( (x_filled * obs_mask).sum(axis=0) / (obs_mask.sum(axis=0) + 1e-6) ).astype(np.float32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = x_filled.shape[1]
    model_grud = GRUDImputer(D=D, hidden=128).to(device)
    opt = torch.optim.Adam(model_grud.parameters(), lr=1e-3)

    x_mean_t = torch.tensor(x_mean, device=device)

    # -----------------------------
    # Mini-batch training via random subsequences
    # -----------------------------
    T_total = x_filled.shape[0]
    seq_len = 288           # 1 day on 5-min grid
    batch_size = 16
    steps = 300             # keep modest; bump if you want

    rs = np.random.default_rng(0)

    def make_batch():
        starts = rs.integers(0, max(1, T_total - seq_len - 1), size=batch_size)
        xb = np.stack([x_filled[s:s+seq_len] for s in starts], axis=0)     # (B,L,D)
        mb = np.stack([obs_mask[s:s+seq_len] for s in starts], axis=0)     # (B,L,D)
        db = np.stack([delta[s:s+seq_len] for s in starts], axis=0)        # (B,L,D)
        return xb, mb, db

    model_grud.train()
    for step in range(1, steps + 1):
        xb, mb, db = make_batch()
        xb_t = torch.tensor(xb, device=device)
        mb_t = torch.tensor(mb, device=device)
        db_t = torch.tensor(db, device=device)

        pred, _ = model_grud(xb_t, mb_t, db_t, x_mean_t)

        # loss only where observed (mb_t == 1)
        diff2 = (pred - xb_t) ** 2
        loss = (diff2 * mb_t).sum() / (mb_t.sum() + 1e-6)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_grud.parameters(), 1.0)
        opt.step()

        if step % 50 == 0:
            print(f"[GRU-D] step {step:4d}/{steps}  loss={float(loss):.5f}")

    # -----------------------------
    # One full forward pass to cache predictions + hidden states
    # -----------------------------
    model_grud.eval()
    with torch.no_grad():
        x_full = torch.tensor(x_filled[None, :, :], device=device)     # (1,T,D)
        m_full = torch.tensor(obs_mask[None, :, :], device=device)     # (1,T,D)
        d_full = torch.tensor(delta[None, :, :], device=device)        # (1,T,D)
        pred_full, h_full = model_grud(x_full, m_full, d_full, x_mean_t)
        pred_full = pred_full[0].cpu().numpy()                         # (T,D)
        h_full = h_full[0].cpu().numpy()                               # (T,H)
    
    # Cache last observed value per feature at each timestep (for proper free-run)
    last_x_hist = np.zeros_like(x_filled, dtype=np.float32)  # (T,D)
    last = x_mean.astype(np.float32).copy()
    for t in range(x_filled.shape[0]):
        # update last where observed at time t
        obs = obs_mask[t].astype(bool)
        last[obs] = x_filled[t, obs]
        last_x_hist[t] = last

    # -----------------------------
    # Forecast helper: free-run k steps after end_idx
    # -----------------------------
    def grud_forecast_k(end_idx: int, k: int) -> np.ndarray:
        """
        Returns predicted x at time end_idx + k, without using ground truth
        beyond end_idx (free-run with missing masks).
        """
        model_grud.eval()
        with torch.no_grad():
            # start hidden at end_idx
            h0 = torch.tensor(h_full[end_idx][None, :], device=device)  # (1,H)
            last_x0 = torch.tensor(last_x_hist[end_idx][None, :], device=device)  # (1,D)

            # build a tiny rollout of length k, with "all missing" inputs
            # delta should start from per-feature delta at end_idx and then increase
            xb = torch.zeros(1, k, D, device=device)
            mb = torch.zeros(1, k, D, device=device)   # all missing

            # delta should start from per-feature delta at end_idx and then increase
            delta0 = torch.tensor(
                delta[end_idx][None, :],
                device=device,
                dtype=torch.float32
            )  # (1,D)
            steps = torch.arange(
                1, k + 1,
                device=device,
                dtype=torch.float32
            )[:, None]  # (k,1)

            # (1,k,D) — broadcasting does the right thing
            db = (delta0 + steps).unsqueeze(0)

            pred_k, _ = model_grud(xb, mb, db, x_mean_t, h0=h0, last_x0=last_x0)
            return pred_k[0, -1].cpu().numpy()  # (D,)

    # -----------------------------
    # Evaluate on the SAME windows using your existing evaluator
    # -----------------------------
    def grud_impute_fn(x_t_masked_unused, start_idx, end_idx, detector_idx):
        # causal predictions from pred_full on the masked training panel
        return pred_full[start_idx:end_idx+1, detector_idx].astype(float)

    def grud_forecast_fn(x_t_masked_unused, target_idx, detector_idx):
        # target_idx = end_idx + h  => forecast from end_idx, horizon=h
        # Need end_idx; infer h by scanning forecast windows, or pass end_idx directly in a custom eval.
        # Here, we use the fact that evaluator calls forecast_fn(target_idx, d) only.
        # We'll approximate by using k=1 (NOT ideal). Prefer the custom eval below.
        return float(pred_full[target_idx, detector_idx])

    # Better: custom forecast evaluation that knows end_idx + horizon (no leakage)
    def evaluate_grud_forecast_only(label="GRU-D (free-run forecast)"):
        y1t, y1p, y3t, y3p, y6t, y6p = [], [], [], [], [], []
        forecast_windows = (
            forecast_1_evaluation_windows_val
            + forecast_3_evaluation_windows_val
            + forecast_6_evaluation_windows_val
        )
        for w in forecast_windows:
            end_idx = np.where(meta["timestamps"] == w["blackout_end"])[0][0]
            d = np.where(meta["detectors"] == w["detector_id"])[0][0]
            h = int(w["horizon_steps"])
            target_idx = end_idx + h
            if target_idx >= x_t.shape[0]:
                continue
            yt = float(x_t[target_idx, d])
            if not np.isfinite(yt):
                continue
            yp = float(grud_forecast_k(end_idx=end_idx, k=h)[d])
            if not np.isfinite(yp):
                continue
            if h == 1:
                y1t.append(yt); y1p.append(yp)
            elif h == 3:
                y3t.append(yt); y3p.append(yp)
            elif h == 6:
                y6t.append(yt); y6p.append(yp)

        import sklearn.metrics
        def _pack(y_pred, y_true, h):
            mae = sklearn.metrics.mean_absolute_error(y_true, y_pred)
            mse = sklearn.metrics.mean_squared_error(y_true, y_pred)
            rmse = float(np.sqrt(mse))
            print(f"{h}-step MAE : {mae}")
            print(f"{h}-step MSE : {mse}")
            print(f"{h}-step RMSE: {rmse}")

        print(f"\n[{label}] Forecasting performance (free-run):")
        print("-----------------------------------")
        _pack(y1p, y1t, 1)
        print("\n-----------------------------------")
        _pack(y3p, y3t, 3)
        print("\n-----------------------------------")
        _pack(y6p, y6t, 6)


    # Use your baseline evaluator for imputation (forecasting handled by custom function above)
    metrics_grud_impute = evaluate_impute_forecast_baseline(
        x_t_true=x_t,
        x_t_masked=x_t_train,
        meta=meta,
        impute_fn=grud_impute_fn,
        forecast_fn=seasonal_naive_forecast_baseline,  # keep a sane baseline here
        label="GRU-D (impute via torch) + SeasonalNaive (forecast)",
    )
    evaluate_grud_forecast_only(label="GRU-D (free-run forecast)")

    print("\n[GRU-D] Done. You now have a torch baseline you can cite + compare.")
except Exception:
    print("\n[Optional] torch not available; skipping BRITS/GRU-D baseline hook.")