In [1]:
import pandas as pd

In [2]:
df = pd.read_parquet("pre-elec-pivot-all.parquet")
df.shape  # 60*60*6.5*20

(468000, 475)

In [3]:
"""
Information Leadership (single-lag) pipeline using log_mid_* (no delta_log_mid in input).

What this script does:
1) Time handling (UTC), RTH filtering, ToD bin labels (open60/mid/close60)
2) Returns construction: delta_log_mid_* from log_mid_* with safe per-day diffs
3) VECM on LOG LEVELS (log_mid_*): rolling CS / IS / ILS + leader per window
4) Lag calibration: choose a single lag per feature prefix OR one global lag
5) Build windowed supervised dataset (features averaged over VECM windows)
6) Train/eval:
     a) General (all RTH)
     b) ToD models (open60, close60)
7) Uniform plots + feature importance tables
8) Save artifacts: lag_map + models (joblib)

Required columns (wide format):
- 'bucket' OR 'bucket_ts' (time)
- 'log_mid_{TICKER}'
- Features by prefix, e.g.:
  'iso_flow_intensity_{TICKER}', 'total_flow_{TICKER}', 'total_flow_non_iso_{TICKER}',
  'num_trades_{TICKER}', 'quote_updates_{TICKER}', 'avg_rsprd_{TICKER}', 'pct_trades_iso_{TICKER}'
"""

import warnings

warnings.filterwarnings("ignore")

from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Literal

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    classification_report,
    confusion_matrix,
    r2_score,
    mean_absolute_error,
    mean_squared_error,
)
import joblib

from statsmodels.tsa.vector_ar.vecm import VECM


# ------------------
# CONFIG / CONSTANTS
# ------------------


@dataclass
class Config:
    # Universe
    etf: str
    tickers: List[str]  # desired tickers (we'll intersect with what's in df)
    # Time columns
    time_col: str = "bucket_ts"  # fallback to "bucket" if not present
    # Trading hours in UTC (RTH): 13:30 - 20:00
    rth_start_h: int = 13
    rth_start_m: int = 30
    rth_end_h: int = 20
    rth_end_m: int = 0
    # VECM (rolling)
    vecm_window: int = 1000
    vecm_step: int = 200
    k_ar_diff: int = 1
    coint_rank: int = 1
    # Lag calibration
    lag_max: int = 5
    # policy: "per_prefix" = one lag per feature prefix; "global" = one lag for all features
    lag_policy: Literal["per_prefix", "global"] = "per_prefix"
    lag_calibration_frac: float = 0.5  # first X% of each day used to calibrate lags
    # Feature prefixes to include (only those found in df are used)
    feature_prefixes: Tuple[str, ...] = (
        "iso_flow_intensity",
        "total_flow",
        "total_flow_non_iso",
        "num_trades",
        "quote_updates",
        "avg_rsprd",
        "pct_trades_iso",
    )
    # Target kind for modeling
    # "binary"  -> ILI_{ETF} (1 if ETF is leader for the window)
    # "regression" -> ILS_{ETF} (continuous)
    task_type: Literal["binary", "regression"] = "binary"
    # Model selector
    # For binary: "logit"|"rf" ; For regression: "linreg"|"rf"
    model_type: Literal["logit", "linreg", "rf"] = "logit"
    # Output paths
    artifacts_dir: str = "il_artifacts"  # where to save learned lags & models


# -----------------------------
# UTILS: TIME, ALIGNMENT, BINS
# -----------------------------


def ensure_datetime(df: pd.DataFrame, cfg: Config) -> pd.DataFrame:
    """Ensure we have a UTC datetime column named 'ts' from cfg.time_col or 'bucket'."""
    if cfg.time_col in df.columns:
        ts = pd.to_datetime(df[cfg.time_col], utc=True, errors="coerce")
    elif "bucket" in df.columns:
        ts = pd.to_datetime(df["bucket"], utc=True, errors="coerce")
    else:
        raise ValueError("No time column found. Provide 'bucket_ts' or 'bucket'.")

    out = df.copy()
    out["ts"] = ts
    out = out.loc[out["ts"].notna()].copy()
    out.sort_values("ts", inplace=True)
    out.reset_index(drop=True, inplace=True)
    return out


def date_utc(ts: pd.Series) -> pd.Series:
    if ts.dt.tz is None:
        ts = ts.dt.tz_localize("UTC")
    return ts.dt.tz_convert("UTC").dt.date


def in_rth(ts: pd.Series, cfg: Config) -> pd.Series:
    """Boolean mask for rows inside RTH window in UTC."""
    hhmm = ts.dt.hour * 60 + ts.dt.minute
    start = cfg.rth_start_h * 60 + cfg.rth_start_m
    end = cfg.rth_end_h * 60 + cfg.rth_end_m
    return (hhmm >= start) & (hhmm < end)


def tod_bucket(ts: pd.Series, cfg: Config) -> pd.Series:
    """Label rows into 'open60', 'close60', 'mid' (inside RTH), else 'off'."""
    minutes = (ts.dt.hour * 60 + ts.dt.minute) - (
        cfg.rth_start_h * 60 + cfg.rth_start_m
    )
    total = (cfg.rth_end_h * 60 + cfg.rth_end_m) - (
        cfg.rth_start_h * 60 + cfg.rth_start_m
    )  # usually 390
    labels = pd.Series("off", index=ts.index, dtype="object")
    mask_rth = in_rth(ts, cfg)
    m = minutes[mask_rth]
    lab = pd.Series(index=m.index, dtype="object")
    lab[(m >= 0) & (m < 60)] = "open60"
    lab[(m >= 60) & (m < total - 60)] = "mid"
    lab[(m >= total - 60) & (m < total)] = "close60"
    labels.loc[mask_rth] = lab
    return labels


# -----------------------------
# RETURNS FROM LOG LEVELS
# -----------------------------


def add_delta_returns_from_log_mid(
    df: pd.DataFrame, tickers: List[str]
) -> pd.DataFrame:
    """
    Create delta_log_mid_{t} = log_mid_{t} - log_mid_{t}.shift(1)
    SAFELY per day (no cross-day diffs).
    """
    out = df.copy()
    d = date_utc(out["ts"])
    for t in tickers:
        colL = f"log_mid_{t}"
        if colL not in out.columns:
            continue
        # raw diff
        delta = out[colL].diff()
        # invalidate diffs at day boundaries
        same_day = d == d.shift(1)
        delta[~same_day] = np.nan
        out[f"delta_log_mid_{t}"] = delta
    return out


# -----------------------------
# VECM → CS/IS/ILS/ILI per window
# -----------------------------


def _alpha_perp_from_alpha(alpha: np.ndarray) -> np.ndarray:
    U, S, Vt = np.linalg.svd(alpha, full_matrices=True)
    vec = U[:, -1]
    sgn = np.sign(vec[0]) if vec[0] != 0 else 1.0
    vec = sgn * vec
    denom = np.sum(np.abs(vec))
    return vec / denom if denom > 1e-12 else np.ones(len(vec)) / len(vec)


def _safe_is_from_alpha_perp(alpha_perp: np.ndarray, Omega: np.ndarray) -> np.ndarray:
    psi = alpha_perp.reshape(-1, 1)
    denom = float(psi.T @ Omega @ psi)
    if not np.isfinite(denom) or abs(denom) < 1e-18:
        denom = float(np.trace(Omega))
        denom = denom if denom > 0 else 1.0
    num = np.array(
        [(psi[i, 0] ** 2) * Omega[i, i] for i in range(len(alpha_perp))], dtype=float
    )
    num = np.clip(num, 0.0, None)
    if num.sum() <= 0:
        return np.ones_like(num) / len(num)
    IS = num / denom
    IS = IS / IS.sum()
    return IS


def _ils_from_cs_is(CS: np.ndarray, IS: np.ndarray) -> np.ndarray:
    beta = IS / (CS + 1e-12)
    w = beta**2
    return w / w.sum() if w.sum() > 0 else np.ones_like(w) / len(w)


def build_levels_matrix(df: pd.DataFrame, tickers: List[str]) -> pd.DataFrame:
    """
    Collect log levels for tickers: columns -> [tickers], rows -> aligned timestamps.
    """
    cols = [f"log_mid_{t}" for t in tickers if f"log_mid_{t}" in df.columns]
    P = df[cols].copy()
    P.columns = [c.replace("log_mid_", "") for c in cols]
    return P


def rolling_vecm_panel(
    df: pd.DataFrame, cfg: Config, group_tickers: List[str]
) -> pd.DataFrame:
    """
    Compute CS/IS/ILS and leader per rolling window for selected tickers using LOG LEVELS.
    """
    P_full = build_levels_matrix(df, group_tickers)
    out_rows = []
    for start in range(0, len(P_full) - cfg.vecm_window + 1, cfg.vecm_step):
        end = start + cfg.vecm_window
        P = P_full.iloc[start:end]
        # Need variance and no NaNs
        if (P.std() < 1e-10).any() or P.isna().any().any():
            continue
        try:
            vecm = VECM(P, k_ar_diff=cfg.k_ar_diff, coint_rank=cfg.coint_rank)
            res = vecm.fit()
            alpha = res.alpha
            Omega = res.sigma_u
            alpha_perp = _alpha_perp_from_alpha(alpha)
            CS = np.abs(alpha_perp)
            CS = CS / CS.sum()
            IS = _safe_is_from_alpha_perp(alpha_perp, Omega)
            ILS = _ils_from_cs_is(CS, IS)
            leader = P.columns[int(np.argmax(ILS))]

            row = {"start_idx": start, "end_idx": end, "leader": leader}
            for i, t in enumerate(P.columns):
                row[f"CS_{t}"] = CS[i]
                row[f"IS_{t}"] = IS[i]
                row[f"ILS_{t}"] = ILS[i]
            # attach timing (for later joining/labels)
            row["ts_start"] = df.iloc[start]["ts"]
            row["ts_end"] = df.iloc[end - 1]["ts"]
            row["ts_mid"] = df.iloc[(start + end) // 2]["ts"]
            out_rows.append(row)
        except Exception:
            continue
    return pd.DataFrame(out_rows)


# --------------------------
# LAG CALIBRATION & APPLYING
# --------------------------


def _lagged_corr(x: pd.Series, y: pd.Series, lag: int) -> float:
    """Correlation with positive lag meaning x leads y by lag steps."""
    if lag > 0:
        return x.shift(lag).corr(y)
    elif lag < 0:
        return x.corr(y.shift(-lag))
    else:
        return x.corr(y)


def _valid_same_day_pairs(ts: pd.Series, lag: int) -> pd.Series:
    """Mask rows where shifting by lag stays on the same UTC date (prevents cross-day leakage)."""
    if lag == 0:
        return pd.Series(True, index=ts.index)
    shifted = ts.shift(lag)
    return date_utc(ts) == date_utc(shifted)


def calibrate_lags(
    df: pd.DataFrame,
    cfg: Config,
    tickers: List[str],
    policy: Literal["per_prefix", "global"] = "per_prefix",
    max_lag: int = 5,
    calibration_frac: float = 0.5,
) -> Dict[str, int]:
    """
    Learn canonical lags with two options:
    - per_prefix: ONE lag per feature prefix (e.g., all iso_flow_intensity_* share same lag)
      chosen to maximize avg |corr(feature_{t} shifted by L, delta_log_mid_{t})| across all tickers+days
    - global: ONE lag for all prefixes at once
    Correlations computed on the *first calibration_frac of each day* and only on same-day-aligned pairs.
    """
    df = df.copy()
    df["date"] = date_utc(df["ts"])

    # figure out which prefixes actually exist
    prefixes = [
        p
        for p in cfg.feature_prefixes
        if any(c.startswith(p + "_") for c in df.columns)
    ]
    # returns series present
    returns_cols = {
        t: f"delta_log_mid_{t}" for t in tickers if f"delta_log_mid_{t}" in df.columns
    }

    if policy == "per_prefix":
        lag_map: Dict[str, int] = (
            {}
        )  # map per COLUMN at the end; we first pick per prefix
        prefix_best: Dict[str, int] = {}
        for pref in prefixes:
            # collect all feature columns for this prefix across tickers
            cols = [
                f"{pref}_{t}"
                for t in tickers
                if f"{pref}_{t}" in df.columns and t in returns_cols
            ]
            if not cols:
                continue
            # evaluate candidate lags
            score_by_L = {}
            for L in range(0, max_lag + 1):
                c_list, w_list = [], []
                for d, g in df.groupby("date"):
                    n = len(g)
                    if n < 5:
                        continue
                    split = int(max(1, np.floor(calibration_frac * n)))
                    gcal = g.iloc[:split]
                    mask = _valid_same_day_pairs(gcal["ts"], L)
                    if mask.sum() < 20:
                        continue
                    # average over all tickers' columns for this prefix
                    score_cols = []
                    for col in cols:
                        t = col.rsplit("_", 1)[-1]
                        ycol = returns_cols[t]
                        c = _lagged_corr(gcal.loc[mask, col], gcal.loc[mask, ycol], L)
                        if np.isfinite(c):
                            score_cols.append(abs(c))
                    if score_cols:
                        c_list.append(np.mean(score_cols))
                        w_list.append(mask.sum())
                score_by_L[L] = (
                    np.average(c_list, weights=w_list) if c_list else -np.inf
                )
            bestL = (
                int(max(score_by_L, key=lambda k: score_by_L[k]))
                if len(score_by_L)
                else 0
            )
            prefix_best[pref] = bestL
        # assign chosen lag to each COLUMN of the prefix
        for pref, L in prefix_best.items():
            for t in tickers:
                col = f"{pref}_{t}"
                if col in df.columns:
                    lag_map[col] = L
        return lag_map

    elif policy == "global":
        # All prefixes, pooled
        agg_scores = {L: [] for L in range(0, max_lag + 1)}
        agg_weights = {L: [] for L in range(0, max_lag + 1)}
        for pref in prefixes:
            for t in tickers:
                col = f"{pref}_{t}"
                if col not in df.columns or t not in returns_cols:
                    continue
                ycol = returns_cols[t]
                for L in range(0, max_lag + 1):
                    for d, g in df.groupby("date"):
                        n = len(g)
                        if n < 5:
                            continue
                        split = int(max(1, np.floor(calibration_frac * n)))
                        gcal = g.iloc[:split]
                        mask = _valid_same_day_pairs(gcal["ts"], L)
                        if mask.sum() < 20:
                            continue
                        c = _lagged_corr(gcal.loc[mask, col], gcal.loc[mask, ycol], L)
                        if np.isfinite(c):
                            agg_scores[L].append(abs(c))
                            agg_weights[L].append(mask.sum())
        lag_score = {
            L: (
                np.average(agg_scores[L], weights=agg_weights[L])
                if agg_scores[L]
                else -np.inf
            )
            for L in agg_scores.keys()
        }
        best_global = (
            int(max(lag_score, key=lambda k: lag_score[k])) if len(lag_score) else 0
        )
        # assign same lag to all feature columns we will use
        lag_map = {}
        for pref in prefixes:
            for t in tickers:
                col = f"{pref}_{t}"
                if col in df.columns:
                    lag_map[col] = best_global
        return lag_map

    else:
        raise ValueError("Unknown lag policy: use 'per_prefix' or 'global'.")


def apply_lags(df: pd.DataFrame, lag_map: Dict[str, int]) -> pd.DataFrame:
    """Shift features by learned lags; drop rows where any shifted feature crosses a day boundary."""
    out = df.copy()
    out["date"] = date_utc(out["ts"])
    safe_mask = pd.Series(True, index=out.index)
    for col, L in lag_map.items():
        if col not in out.columns or L == 0:
            continue
        same_day = _valid_same_day_pairs(out["ts"], L)
        safe_mask &= same_day
        out[col] = out[col].shift(L)
    out = out.loc[safe_mask].copy()
    out.drop(columns=["date"], inplace=True)
    out.reset_index(drop=True, inplace=True)
    return out


# -------------------------
# BUILD WINDOWED DATAFRAME
# -------------------------


def window_average_features(
    df: pd.DataFrame, cfg: Config, tickers: List[str], start_idx: int, end_idx: int
) -> Dict[str, float]:
    """Average per-ticker features over [start_idx:end_idx)."""
    sl = df.iloc[start_idx:end_idx]
    row = {}
    for t in tickers:
        for pref in cfg.feature_prefixes:
            col = f"{pref}_{t}"
            if col in sl.columns:
                row[col] = float(sl[col].mean())
    return row


def build_panel_dataset(
    df: pd.DataFrame, panel: pd.DataFrame, cfg: Config, tickers: List[str]
) -> pd.DataFrame:
    """
    Build a *predictive* dataset:
      - X_k = averages of features over window k (panel row k: [start_k:end_k))
      - y_k = ILS_{ETF} (or ILI) from *next* window k+1
    We drop the last window because it has no future label.
    """
    rows = []
    pnl = panel.reset_index(drop=True).copy()

    for k in range(len(pnl) - 1):
        r_feat = pnl.iloc[k]  # window k for features
        r_tgt = pnl.iloc[k + 1]  # window k+1 for target

        st, en = int(r_feat["start_idx"]), int(r_feat["end_idx"])
        feats = window_average_features(df, cfg, tickers, st, en)

        if cfg.task_type == "binary":
            target = 1 if r_tgt["leader"] == cfg.etf else 0
        else:
            target = float(r_tgt.get(f"ILS_{cfg.etf}", np.nan))

        ts_mid = r_tgt["ts_mid"]  # time stamp of the *target* window
        rows.append(
            {"ts_mid": ts_mid, "tod_bin": None, "y": target, **feats}  # fill after
        )

    Xy = pd.DataFrame(rows)

    # Time-of-day bins are computed on the *target* window timestamp
    Xy["tod_bin"] = tod_bucket(pd.to_datetime(Xy["ts_mid"], utc=True), cfg)
    Xy = Xy.loc[Xy["tod_bin"].isin(["open60", "mid", "close60"])].reset_index(drop=True)

    # OPTIONAL: remove ETF's *own* contemporaneous features to be stricter
    # for c in list(Xy.columns):
    #     if c.endswith(f"_{cfg.etf}") and c not in ("y","ts_mid","tod_bin"):
    #         Xy.drop(columns=c, inplace=True)

    return Xy


# -----------------
# MODEL FACTORY/EVAL
# -----------------


def make_model(cfg: Config):
    if cfg.task_type == "binary":
        if cfg.model_type == "logit":
            return Pipeline(
                [
                    ("scaler", StandardScaler(with_mean=False)),
                    ("clf", LogisticRegression(max_iter=2000)),
                ]
            )
        elif cfg.model_type == "rf":
            return RandomForestClassifier(n_estimators=300, random_state=42, n_jobs=-1)
        else:
            raise ValueError("For binary task, choose 'logit' or 'rf'")
    else:  # regression
        if cfg.model_type == "linreg":
            return Pipeline(
                [
                    ("scaler", StandardScaler(with_mean=False)),
                    ("reg", LinearRegression()),
                ]
            )
        elif cfg.model_type == "rf":
            return RandomForestRegressor(n_estimators=300, random_state=42, n_jobs=-1)
        else:
            raise ValueError("For regression task, choose 'linreg' or 'rf'")


def evaluate_model_binary(y_true, y_pred_proba, y_pred):
    auc = roc_auc_score(y_true, y_pred_proba) if len(np.unique(y_true)) == 2 else np.nan
    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    rep = classification_report(y_true, y_pred, digits=3)
    return {"auc": auc, "accuracy": acc, "confusion_matrix": cm, "report": rep}


def evaluate_model_regression(y_true, y_pred):
    r2 = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = mean_squared_error(y_true, y_pred, squared=False)
    return {"r2": r2, "mae": mae, "rmse": rmse}


def train_one_dataset(Xy: pd.DataFrame, cfg: Config, label: str):
    """
    Train a single model on provided Xy (no CV here to keep flow simple).
    Returns fitted model, per-row predictions df, and metrics.
    """
    os.makedirs(cfg.artifacts_dir, exist_ok=True)

    meta_cols = ["ts_mid", "tod_bin"]
    y = Xy["y"].astype(float if cfg.task_type == "regression" else int)
    X = Xy.drop(
        columns=["y"] + [c for c in meta_cols if c in Xy.columns], errors="ignore"
    ).fillna(0.0)

    model = make_model(cfg)
    model.fit(X, y)

    if cfg.task_type == "binary":
        if hasattr(model, "predict_proba"):
            y_proba = model.predict_proba(X)[:, 1]
        elif hasattr(model, "decision_function"):
            z = model.decision_function(X)
            y_proba = 1 / (1 + np.exp(-z))
        else:
            y_proba = model.predict(X)
        y_hat = (y_proba >= 0.5).astype(int)
        metrics = evaluate_model_binary(y.values, y_proba, y_hat)
        preds = pd.DataFrame(
            {
                "ts_mid": Xy["ts_mid"],
                "tod_bin": Xy.get("tod_bin"),
                "y": y,
                "y_proba": y_proba,
                "y_hat": y_hat,
            }
        )
    else:
        y_hat = model.predict(X)
        metrics = evaluate_model_regression(y.values, y_hat)
        preds = pd.DataFrame(
            {
                "ts_mid": Xy["ts_mid"],
                "tod_bin": Xy.get("tod_bin"),
                "y": y,
                "y_hat": y_hat,
            }
        )

    joblib.dump(model, f"{cfg.artifacts_dir}/model_{label}.joblib")
    return model, preds, metrics


def uniform_plots(
    preds_general: Optional[pd.DataFrame],
    metrics_general: Optional[Dict],
    preds_tod_open: Optional[pd.DataFrame],
    metrics_tod_open: Optional[Dict],
    preds_tod_close: Optional[pd.DataFrame],
    metrics_tod_close: Optional[Dict],
    cfg: Config,
):
    """Uniform visualization for both families (general/open60/close60)."""
    if cfg.task_type == "binary":
        plt.figure(figsize=(12, 4))
        labels = ["general", "open60", "close60"]
        aucs = [
            metrics_general.get("auc", np.nan) if metrics_general else np.nan,
            metrics_tod_open.get("auc", np.nan) if metrics_tod_open else np.nan,
            metrics_tod_close.get("auc", np.nan) if metrics_tod_close else np.nan,
        ]
        accs = [
            metrics_general.get("accuracy", np.nan) if metrics_general else np.nan,
            metrics_tod_open.get("accuracy", np.nan) if metrics_tod_open else np.nan,
            metrics_tod_close.get("accuracy", np.nan) if metrics_tod_close else np.nan,
        ]
        x = np.arange(len(labels))
        plt.subplot(1, 2, 1)
        plt.bar(x - 0.2, aucs, width=0.4, label="AUC")
        plt.bar(x + 0.2, accs, width=0.4, label="Accuracy")
        plt.xticks(x, labels)
        plt.title("Binary metrics")
        plt.legend()

        plt.subplot(1, 2, 2)
        for name, p in [
            ("general", preds_general),
            ("open60", preds_tod_open),
            ("close60", preds_tod_close),
        ]:
            if p is None or "y_proba" not in p.columns:
                continue
            dfb = p.copy()
            dfb["bin"] = pd.qcut(dfb["y_proba"], q=10, duplicates="drop")
            cal = dfb.groupby("bin").agg(
                y_mean=("y", "mean"), proba_mean=("y_proba", "mean")
            )
            plt.plot(
                cal["proba_mean"].values, cal["y_mean"].values, marker="o", label=name
            )
        plt.plot([0, 1], [0, 1], "--")
        plt.xlabel("Predicted prob (bin mean)")
        plt.ylabel("Empirical freq")
        plt.title("Calibration (deciles)")
        plt.legend()
        plt.tight_layout()
        plt.show()

        # Text reports in console
        print("\n=== Binary: Reports ===")
        for name, m in [
            ("general", metrics_general),
            ("open60", metrics_tod_open),
            ("close60", metrics_tod_close),
        ]:
            if m and "report" in m:
                print(f"\n[{name}]")
                print(m["report"])

    else:
        # Regression: scatter plots
        for name, p in [
            ("general", preds_general),
            ("open60", preds_tod_open),
            ("close60", preds_tod_close),
        ]:
            if p is None:
                continue
            plt.figure(figsize=(5, 5))
            plt.scatter(p["y"], p["y_hat"], alpha=0.5)
            mn = min(p["y"].min(), p["y_hat"].min())
            mx = max(p["y"].max(), p["y_hat"].max())
            plt.plot([mn, mx], [mn, mx], "--")
            plt.xlabel("True")
            plt.ylabel("Pred")
            plt.title(f"Regression fit: {name}")
            plt.tight_layout()
            plt.show()

        print("\n=== Regression: Metrics ===")
        for name, m in [
            ("general", metrics_general),
            ("open60", metrics_tod_open),
            ("close60", metrics_tod_close),
        ]:
            if m:
                print(
                    f"\n[{name}] R2={m['r2']:.3f}, MAE={m['mae']:.4f}, RMSE={m['rmse']:.4f}"
                )


def feature_importance_table(model, X_cols: List[str], top_n: int = 20) -> pd.DataFrame:
    """Return a tidy table of feature importances / coefficients where available."""
    if hasattr(model, "named_steps"):
        final_est = list(model.named_steps.values())[-1]
    else:
        final_est = model

    if hasattr(final_est, "feature_importances_"):
        vals = final_est.feature_importances_
        tbl = pd.DataFrame({"feature": X_cols, "importance": vals})
        return tbl.sort_values("importance", ascending=False).head(top_n)
    elif hasattr(final_est, "coef_"):
        coefs = getattr(final_est, "coef_")
        if coefs.ndim == 1:
            vals = coefs
        else:
            vals = np.linalg.norm(coefs, axis=0)  # multiclass / multi-output
        tbl = pd.DataFrame({"feature": X_cols, "coef_abs": np.abs(vals), "coef": vals})
        return tbl.sort_values("coef_abs", ascending=False).head(top_n)
    else:
        return pd.DataFrame()


# ----------------------
# TRAINING ENTRY POINT
# ----------------------


def run_pipeline(df_wide: pd.DataFrame, cfg: Config):
    """
    End-to-end:
      1) Ensure datetime, filter to RTH
      2) Build group = (ETF + tickers) that actually exist in df (via log_mid_*)
      3) Add returns delta_log_mid_* from log_mid_*
      4) Rolling VECM on log levels -> panel with CS/IS/ILS + leader
      5) Calibrate lags (per_prefix/global) vs returns
      6) Apply lags (same-day safety)
      7) Build supervised dataset (avg features per VECM window)
      8) Train general + ToD models; plot & print feature importances
      9) Save artifacts (lag map + models)
    """
    os.makedirs(cfg.artifacts_dir, exist_ok=True)

    # 1) Time handling
    df = ensure_datetime(df_wide, cfg)
    mask_rth = in_rth(df["ts"], cfg)
    df = df.loc[mask_rth].reset_index(drop=True)

    df["date"] = date_utc(df["ts"])
    log_mid_cols = [c for c in df.columns if c.startswith("log_mid_")]

    # forward-fill within each day (prevents cross-day leakage)
    if log_mid_cols:
        df[log_mid_cols] = df.groupby("date", group_keys=False)[log_mid_cols].apply(
            lambda g: g.ffill()
        )

    # drop rows that still have NaNs in any of the group log_mid_* columns
    present = [
        c.replace("log_mid_", "") for c in df.columns if c.startswith("log_mid_")
    ]
    group = [cfg.etf] + [t for t in cfg.tickers if t != cfg.etf]
    group = [t for t in group if t in present]
    need_cols = [f"log_mid_{t}" for t in group]
    df = df.dropna(subset=need_cols, how="any").reset_index(drop=True)

    # optional: remove the helper column to keep df tidy
    df = df.drop(columns=["date"])

    # 2) Build group from tickers that truly exist (via log_mid_*)
    present = [
        c.replace("log_mid_", "") for c in df.columns if c.startswith("log_mid_")
    ]
    group = [cfg.etf] + [t for t in cfg.tickers if t != cfg.etf]
    group = [t for t in group if t in present]

    # if ETF missing, try to fall back to a common ETF that exists
    if cfg.etf not in group:
        fallback = next((t for t in ["SPY", "QQQ", "IVV", "VOO"] if t in present), None)
        if fallback is None:
            raise ValueError(
                "ETF not found in df (log_mid_*). Provide an ETF present in your data."
            )
        print(f"[INFO] ETF '{cfg.etf}' not found. Using fallback ETF '{fallback}'.")
        cfg = Config(**{**cfg.__dict__, "etf": fallback})
        if fallback not in group:
            group = [fallback] + [t for t in group if t != cfg.etf]
            group = [t for t in group if t in present]

    if len(group) < 2:
        raise ValueError("Need at least two instruments with log_mid_* columns.")

    # 3) Returns (for lag calibration ONLY)
    df = add_delta_returns_from_log_mid(df, group)

    # 4) VECM panel on levels
    n_obs = len(df)
    if cfg.vecm_window >= n_obs:
        # keep window reasonable: ~1/3 of available samples, but at least 300
        cfg.vecm_window = max(300, n_obs // 3)
    if cfg.vecm_step >= cfg.vecm_window:
        # take about 20% of the window as step, but at least 50
        cfg.vecm_step = max(50, cfg.vecm_window // 5)

    panel = rolling_vecm_panel(df, cfg, group)
    if panel.empty:
        raise ValueError(
            "VECM panel is empty. Consider smaller vecm_window, larger step, or ensure no NaNs in log_mid_*."
        )

    # 5) Lag calibration (vs returns)
    lag_map = calibrate_lags(
        df=df,
        cfg=cfg,
        tickers=group,
        policy=cfg.lag_policy,
        max_lag=cfg.lag_max,
        calibration_frac=cfg.lag_calibration_frac,
    )
    joblib.dump(lag_map, f"{cfg.artifacts_dir}/lag_map_{cfg.lag_policy}.joblib")

    # 6) Apply lags
    df_lagged = apply_lags(df, lag_map)

    # 7) Build supervised dataset from VECM panel
    Xy = build_panel_dataset(df_lagged, panel, cfg, group)
    if Xy.empty:
        raise ValueError("Supervised dataset is empty after building from panel.")

    # subsets
    Xy_general = Xy.copy()
    Xy_open = Xy.loc[Xy["tod_bin"] == "open60"].reset_index(drop=True)
    Xy_close = Xy.loc[Xy["tod_bin"] == "close60"].reset_index(drop=True)

    # 8) Train
    model_general, preds_general, metrics_general = train_one_dataset(
        Xy_general, cfg, label=f"general_{cfg.model_type}_{cfg.task_type}"
    )

    preds_open = metrics_open = model_open = None
    if len(Xy_open) > 20:
        model_open, preds_open, metrics_open = train_one_dataset(
            Xy_open, cfg, label=f"open60_{cfg.model_type}_{cfg.task_type}"
        )
    else:
        print("Not enough open60 rows; skipping open model.")

    preds_close = metrics_close = model_close = None
    if len(Xy_close) > 20:
        model_close, preds_close, metrics_close = train_one_dataset(
            Xy_close, cfg, label=f"close60_{cfg.model_type}_{cfg.task_type}"
        )
    else:
        print("Not enough close60 rows; skipping close model.")

    # 9) Uniform viz + feature importances
    uniform_plots(
        preds_general,
        metrics_general,
        preds_open,
        metrics_open,
        preds_close,
        metrics_close,
        cfg,
    )

    meta_cols = ["ts_mid", "tod_bin", "y"]
    X_cols = [c for c in Xy_general.columns if c not in meta_cols]
    print("\nTop features (general):")
    print(feature_importance_table(model_general, X_cols, top_n=20))
    if model_open is not None:
        print("\nTop features (open60):")
        print(feature_importance_table(model_open, X_cols, top_n=20))
    if model_close is not None:
        print("\nTop features (close60):")
        print(feature_importance_table(model_close, X_cols, top_n=20))

    return {
        "lag_map": lag_map,
        "panel": panel,
        "Xy": Xy,
        "models": {
            "general": model_general,
            "open60": model_open,
            "close60": model_close,
        },
        "preds": {
            "general": preds_general,
            "open60": preds_open,
            "close60": preds_close,
        },
        "metrics": {
            "general": metrics_general,
            "open60": metrics_open,
            "close60": metrics_close,
        },
    }

In [4]:
universe = [
    "AAPL",
    "AMD",
    "AMZN",
    "AVGO",
    "BRK.B",
    "CEG",
    "DIA",
    "DJT",
    "EEM",
    "FXI",
    "GOOG",
    "GOOGL",
    "HYG",
    "IVV",
    "IWM",
    "JPM",
    "KRE",
    "LLY",
    "LQD",
    "META",
    "MSFT",
    "MSTR",
    "NVDA",
    "PLTR",
    "QQQ",
    "SHW",
    "SMCI",
    "SMH",
    "SOXL",
    "SPY",
    "SQQQ",
    "TQQQ",
    "TSLA",
    "VOO",
    "VUG",
    "XLE",
    "XLF",
    "XLU",
    "XOM",
]

# universe = ["AAPL","AMD","AMZN","AVGO","BRK.B","CEG","DIA","DJT","EEM","FXI","GOOG","GOOGL",
#             "HYG","IVV","IWM","JPM","KRE","LLY","LQD","META","MSFT","MSTR","NVDA","PLTR",
#             "QQQ","SHW","SMCI","SMH","SOXL","SPY","SQQQ","TLT","TQQQ","TSLA","VOO","VUG",
#             "XLE","XLF","XLU","XOM"]

cfg = Config(
    etf="TLT",
    tickers=universe,
    lag_policy="global",  # or "global"
    task_type="regression",  # "binary" (ILI ETF leads) OR "regression" (ILS ETF)
    model_type="linreg",  # "logit"|"linreg"|"rf"
)
artifacts = run_pipeline(df, cfg)

: 

In [None]:
from statsmodels.tsa.stattools import grangercausalitytests


def lead_lag_tests(
    df: pd.DataFrame,
    etf: str,
    tickers: List[str],
    max_lag: int = 10,
    tod: Optional[str] = None,  # "open60"|"mid"|"close60"|None
) -> pd.DataFrame:
    """
    Tests whether ETF returns lead/lag each stock's returns using:
      - max cross-correlation lag in [-max_lag, +max_lag]
      - one-sided Granger tests (ETF -> stock and stock -> ETF)
    Returns a tidy table of results. 'tod' can filter to a specific time-of-day bin.
    """
    # ensure returns exist
    have = [t for t in tickers if f"delta_log_mid_{t}" in df.columns]
    if f"delta_log_mid_{etf}" not in df.columns:
        raise ValueError(f"Missing delta_log_mid_{etf} for lead/lag tests.")
    # ToD filter
    tmp = df.copy()
    tmp["tod_bin"] = tod_bucket(
        pd.to_datetime(tmp["ts"], utc=True), Config(etf=etf, tickers=tickers)
    )
    if tod is not None:
        tmp = tmp.loc[tmp["tod_bin"] == tod].copy()

    # Same-day safety: we will align by shifting and drop cross-day pairs
    def safe_shift(s: pd.Series, L: int) -> pd.Series:
        # Positive L: shift forward (lead)
        if L == 0:
            return s
        out = s.shift(L)
        same = date_utc(
            s.index.to_series().map(lambda i: tmp.loc[i, "ts"])
        ) == date_utc(
            out.index.to_series().map(lambda i: tmp.loc[i, "ts"])
            if out.notna().all()
            else tmp["ts"].shift(L)
        )
        out.loc[~same] = np.nan
        return out

    results = []
    r_etf = tmp[f"delta_log_mid_{etf}"].astype(float)

    for t in have:
        r_stk = tmp[f"delta_log_mid_{t}"].astype(float)

        # Cross-correlation peak lag
        best_lag, best_corr = 0, 0.0
        for L in range(-max_lag, max_lag + 1):
            if L >= 0:
                c = r_stk.corr(r_etf.shift(L))
            else:
                c = r_stk.shift(-L).corr(r_etf)
            if np.isfinite(c) and abs(c) > abs(best_corr):
                best_corr, best_lag = c, L

        # Granger tests (small sample safe-guard: require enough rows)
        g_etf_to_stk_p = np.nan
        g_stk_to_etf_p = np.nan
        try:
            # dropna common
            D = pd.DataFrame({"stk": r_stk, "etf": r_etf}).dropna()
            if len(D) > (max_lag + 20):
                res1 = grangercausalitytests(
                    D[["stk", "etf"]], maxlag=min(max_lag, 5), verbose=False
                )
                g_etf_to_stk_p = min(d[0]["ssr_chi2test"][1] for k, d in res1.items())
                res2 = grangercausalitytests(
                    D[["etf", "stk"]], maxlag=min(max_lag, 5), verbose=False
                )
                g_stk_to_etf_p = min(d[0]["ssr_chi2test"][1] for k, d in res2.items())
        except Exception:
            pass

        # Interpret lag sign: best_lag > 0 means ETF return at t+L aligns with stock at t → stock *leads* ETF
        # best_lag < 0 means ETF leads stocks.
        interpretation = (
            "stock_leads_etf"
            if best_lag > 0
            else ("etf_leads_stock" if best_lag < 0 else "synchronous")
        )

        results.append(
            {
                "ticker": t,
                "xcorr_best_lag": int(best_lag),
                "xcorr_at_best": float(best_corr),
                "xcorr_interpretation": interpretation,
                "granger_p_etf_to_stock": g_etf_to_stk_p,
                "granger_p_stock_to_etf": g_stk_to_etf_p,
                "tod_bin": tod or "all",
            }
        )

    return (
        pd.DataFrame(results)
        .sort_values(["tod_bin", "xcorr_best_lag", "ticker"])
        .reset_index(drop=True)
    )


def add_returns_from_log_mid(df: pd.DataFrame, tickers: List[str]) -> pd.DataFrame:
    """
    Creates delta_log_mid_{T} by differencing log_mid_{T} *within each UTC date* to avoid
    cross-day leakage. Fills the first diff of the day with 0.0.
    """
    out = df.copy()
    out["date"] = date_utc(out["ts"])
    for t in tickers:
        col = f"log_mid_{t}"
        if col in out.columns:
            out[f"delta_log_mid_{t}"] = (
                out.groupby("date", group_keys=False)[col].diff().fillna(0.0)
            )
    return out.drop(columns=["date"])

In [None]:
# Build a minimal cfg for time handling
cfg_tmp = Config(etf="TLT", tickers=universe)

# 1) Ensure we have a UTC timestamp column 'ts'
df_ll = ensure_datetime(df, cfg_tmp)

# 2) Restrict to RTH (so ToD bins make sense)
df_ll = df_ll.loc[in_rth(df_ll["ts"], cfg_tmp)].reset_index(drop=True)

# 3) Create within-day returns for all tickers we have levels for
present_levels = [
    c.replace("log_mid_", "") for c in df_ll.columns if c.startswith("log_mid_")
]
df_ll = add_returns_from_log_mid(df_ll, present_levels)

# 4) Now run lead/lag
tbl_all = lead_lag_tests(df_ll, etf="TLT", tickers=universe, max_lag=10, tod=None)
tbl_open = lead_lag_tests(df_ll, etf="TLT", tickers=universe, max_lag=10, tod="open60")
tbl_close = lead_lag_tests(
    df_ll, etf="TLT", tickers=universe, max_lag=10, tod="close60"
)

In [47]:
tbl_all

Unnamed: 0,ticker,xcorr_best_lag,xcorr_at_best,xcorr_interpretation,granger_p_etf_to_stock,granger_p_stock_to_etf,tod_bin
0,JPM,-1,-0.006763,etf_leads_stock,2.115862e-05,2.941082e-06,all
1,XLE,-1,-0.017436,etf_leads_stock,0.02790372,1.722661e-44,all
2,XOM,-1,-0.01807,etf_leads_stock,0.008118418,2.357622e-44,all
3,AAPL,0,0.006596,synchronous,4.650405e-05,0.000538845,all
4,AMD,0,0.012721,synchronous,2.039351e-08,0.002703066,all
5,AMZN,0,0.010678,synchronous,1.733746e-11,0.1072253,all
6,BRK.B,0,-0.006843,synchronous,3.357168e-07,0.009279677,all
7,CEG,0,0.007103,synchronous,1.488443e-05,0.6027312,all
8,DIA,0,0.041887,synchronous,1.381951e-24,0.01410629,all
9,DJT,0,0.0,synchronous,,,all


In [48]:
tbl_open

Unnamed: 0,ticker,xcorr_best_lag,xcorr_at_best,xcorr_interpretation,granger_p_etf_to_stock,granger_p_stock_to_etf,tod_bin
0,LLY,-9,0.00691,etf_leads_stock,0.3841725,0.5367321,open60
1,KRE,-8,-0.014833,etf_leads_stock,0.001520882,0.0004606361,open60
2,GOOG,-6,0.012346,etf_leads_stock,0.02279695,0.1415995,open60
3,GOOGL,-6,0.012142,etf_leads_stock,0.02520223,0.1751493,open60
4,AAPL,-2,0.008454,etf_leads_stock,0.02290939,0.01338115,open60
5,JPM,-1,-0.011075,etf_leads_stock,0.05275763,0.003604224,open60
6,META,-1,0.009368,etf_leads_stock,0.1309723,0.01491098,open60
7,MSFT,-1,0.009721,etf_leads_stock,0.03388889,0.007917679,open60
8,XLE,-1,-0.020648,etf_leads_stock,0.5090386,3.998137e-10,open60
9,XOM,-1,-0.021703,etf_leads_stock,0.07752908,3.460679e-09,open60


In [49]:
tbl_close

Unnamed: 0,ticker,xcorr_best_lag,xcorr_at_best,xcorr_interpretation,granger_p_etf_to_stock,granger_p_stock_to_etf,tod_bin
0,MSFT,-10,0.011341,etf_leads_stock,0.01545801,0.01609708,close60
1,JPM,-9,-0.013559,etf_leads_stock,0.006440816,0.003851883,close60
2,MSTR,-9,-0.0059,etf_leads_stock,0.1176376,0.5654146,close60
3,TSLA,-9,-0.008816,etf_leads_stock,0.2737776,0.1395291,close60
4,SOXL,-8,0.010144,etf_leads_stock,0.1358024,0.2664878,close60
5,EEM,-7,-0.013312,etf_leads_stock,0.005593888,0.02830344,close60
6,XLE,-6,0.009808,etf_leads_stock,0.4080116,0.3881841,close60
7,AAPL,-5,0.016562,etf_leads_stock,6.736399e-08,8.590721e-06,close60
8,CEG,-5,0.013039,etf_leads_stock,0.03392587,0.00116448,close60
9,AMD,-4,-0.008365,etf_leads_stock,0.293752,0.1791292,close60
