In [None]:
from __future__ import annotations

import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import numpy as np
import pandas as pd

In [None]:
# HDF5 store (relative path) and key under which the feature frame is saved.
LTR_HDF_PATH = Path("features/ltr_df.h5")
LTR_HDF_KEY  = "ltr_df"
# Feature frame; later joined with `backtest_df` on ("date", "permno")
# inside `build_oos_master`.
ltr_df = pd.read_hdf(LTR_HDF_PATH, key=LTR_HDF_KEY)

In [None]:
# Path and key of the HDF5 store that holds the backtest predictions.
# Replace both placeholders with your own values before running this cell, and
# keep machine-specific absolute paths out of the notebook.
# Example of a key previously used here:
#   'backtest_df_mlogloss_mom_bocd_exp_clean_mono_con_norm_by_lag_std'
BACKTEST_H5 = "YOUR_BACKTEST_DF_PATH"
BACKTEST_KEY = 'YOUR_BACKTEST_KEY'

# Backtest frame; `build_oos_master` requires the columns date, permno,
# p_bottom_test, p_middle_test, p_top_test and net_score_test.
backtest_df = pd.read_hdf(BACKTEST_H5, key=BACKTEST_KEY)

In [None]:
# Name of the realized-target column. `build_oos_master` takes it from
# `backtest_df` when it is present there and from `ltr_df` otherwise.
TARGET_COL = "fwd_ret_div_sigma_lag_63"


@dataclass
class InteractionSpec:
    """
    Slice definition: a momentum feature and a second (state) feature.

    Momentum is split at `ret_neg_threshold` into NEG (strictly below the
    threshold) and POS (everything else). The other feature is binned into
    LOW/HIGH using the `low_q` and `high_q` quantiles; the bin names come from
    `low_label` and `high_label`.

    Raises
    ------
    ValueError
        If the quantiles do not satisfy 0 <= low_q <= high_q <= 1, or if
        `low_label` and `high_label` are identical. Either situation would
        otherwise produce overlapping or indistinguishable bins without any
        error being reported.
    """
    name: str
    momentum_feature: str
    other_feature: str
    low_q: float = 0.20
    high_q: float = 0.80
    low_label: str = "MATURE"
    high_label: str = "YOUNG"
    ret_neg_threshold: float = 0.0

    def __post_init__(self) -> None:
        # Chained comparison is False for NaN as well, so NaN quantiles are rejected.
        if not (0.0 <= self.low_q <= self.high_q <= 1.0):
            raise ValueError(
                f"InteractionSpec '{self.name}': need 0 <= low_q <= high_q <= 1, "
                f"got low_q={self.low_q}, high_q={self.high_q}."
            )
        # Identical labels would merge the LOW and HIGH bins into one.
        if self.low_label == self.high_label:
            raise ValueError(
                f"InteractionSpec '{self.name}': low_label and high_label must differ, "
                f"both are '{self.low_label}'."
            )


# Helpers
def _slugify(text: str) -> str:
    text = text.strip().lower()
    text = re.sub(r"[^a-z0-9]+", "-", text)
    text = re.sub(r"-+", "-", text).strip("-")
    return text or "interaction"


def _share_long(x: pd.Series) -> float:
    return float((x > 0.0).mean()) if len(x) else np.nan


def _mean_and_se(x: pd.Series) -> Tuple[float, float]:
    mu = float(x.mean())
    se = float(x.std(ddof=1) / np.sqrt(len(x))) if len(x) > 1 else np.nan
    return mu, se


def _add_signed_percentile_by_date(
    df: pd.DataFrame,
    value_col: str,
    date_col: str = "date",
    new_col: str = "target_pct_signed",
) -> pd.DataFrame:
    """
    Per-date signed percentile in (-1, 1): 2 * (rank/(n+1)) - 1.
    Robust to ties and comparable across dates.
    """
    pct = df.groupby(date_col, observed=True)[value_col].rank(method="average", pct=True)
    n = df.groupby(date_col, observed=True)[value_col].transform("count")
    factor = n / (n + 1.0)
    df[new_col] = 2.0 * pct * factor - 1.0
    return df


# Core functions
def build_oos_master(
    ltr_df: pd.DataFrame,
    backtest_df: pd.DataFrame,
    features: Iterable[str],
) -> pd.DataFrame:
    """
    Assemble the out-of-sample frame used by the slice analysis.

    The requested feature columns of `ltr_df` are inner-joined one-to-one with
    the prediction columns of `backtest_df` on ("date", "permno"). `TARGET_COL`
    is read from `backtest_df` when it exists there and from `ltr_df`
    otherwise. Only rows with a non-null `net_score_test` are kept, and a
    per-date signed percentile of the target is added as `target_pct_signed`.

    Raises
    ------
    KeyError
        If a required column is missing from either input, or if `TARGET_COL`
        is absent after the merge.
    """
    target_in_backtest = TARGET_COL in backtest_df.columns

    # Columns pulled from the backtest side.
    cols_bt = ["date", "permno", "p_bottom_test", "p_middle_test", "p_top_test", "net_score_test"]
    if target_in_backtest:
        cols_bt.append(TARGET_COL)
    for col in cols_bt:
        if col not in backtest_df.columns:
            raise KeyError(f"Required column missing in backtest_df: {col}")

    # Columns pulled from the feature side (de-duplicated, sorted features).
    keep_ltr = ["date", "permno", *sorted(set(features))]
    if not target_in_backtest:
        keep_ltr.append(TARGET_COL)
    for col in keep_ltr:
        if col not in ltr_df.columns:
            raise KeyError(f"Required column missing in ltr_df: {col}")

    left = ltr_df[keep_ltr]
    right = backtest_df[cols_bt]
    merged = left.merge(right, on=["date", "permno"], how="inner", validate="one_to_one").copy()
    merged["date"] = pd.to_datetime(merged["date"])
    if TARGET_COL not in merged.columns:
        raise KeyError(f"Could not find {TARGET_COL} after merge.")

    # Test rows are the ones that carry a score.
    has_score = merged["net_score_test"].notna()
    oos = merged.loc[has_score].copy()
    return _add_signed_percentile_by_date(oos, TARGET_COL, date_col="date", new_col="target_pct_signed")


def run_single_interaction(
    oos_master: pd.DataFrame,
    spec: InteractionSpec,
) -> Dict[str, pd.DataFrame]:
    """
    Run one interaction:
      • bin by momentum (NEG/POS) and other feature (LOW/HIGH),
      • compute date-balanced cell stats,
      • compute paired diffs (HIGH − LOW) within NEG and POS.

    Rows with a null momentum or other feature are dropped first. The LOW/HIGH
    cuts are the `spec.low_q` / `spec.high_q` quantiles of the other feature
    over all remaining rows (pooled across dates); rows strictly between the
    cuts are discarded. Cell statistics are computed per date and then averaged
    across dates, so each date carries equal weight.

    A momentum side on which one of the two bins never occurs yields a paired
    diff row with `n_dates_paired == 0` and NaN statistics, and an input that
    leaves no cells yields an empty `cell_summary` with the usual columns.

    Returns: thresholds, cell_summary, paired_diff_NEG, paired_diff_POS.

    Raises
    ------
    KeyError
        If a required column is missing from `oos_master`.
    """
    needed = [
        spec.momentum_feature,
        spec.other_feature,
        "net_score_test",
        "p_top_test",
        "p_bottom_test",
        TARGET_COL,
        "target_pct_signed",
    ]
    for c in needed:
        if c not in oos_master.columns:
            raise KeyError(f"Column '{c}' not found in oos_master.")

    oos = oos_master.dropna(subset=[spec.momentum_feature, spec.other_feature]).copy()

    # Quantile cuts for the other feature
    low_cut = float(oos[spec.other_feature].quantile(spec.low_q))
    high_cut = float(oos[spec.other_feature].quantile(spec.high_q))

    # Define bins (np.select is first-match, so LOW wins if the cuts coincide)
    oos["momentum_bin"] = np.where(oos[spec.momentum_feature] < spec.ret_neg_threshold, "NEG", "POS")
    oos["other_bin"] = np.select(
        [oos[spec.other_feature] <= low_cut, oos[spec.other_feature] >= high_cut],
        [spec.low_label, spec.high_label],
        default="MID",
    )
    oos = oos.loc[oos["other_bin"].isin([spec.low_label, spec.high_label])].copy()

    # Per-date cell metrics (date-balanced)
    per_date = (
        oos.groupby(["date", "momentum_bin", "other_bin"], observed=True)
        .agg(
            n=("net_score_test", "size"),
            mean_s=("net_score_test", "mean"),
            share_long=("net_score_test", _share_long),
            mean_p_top=("p_top_test", "mean"),
            mean_p_bottom=("p_bottom_test", "mean"),
            mean_y=(TARGET_COL, "mean"),
            mean_target_pct=("target_pct_signed", "mean"),
        )
        .reset_index()
    )

    # Per-date metrics that get a (mean, se) pair in the cell summary.
    summary_metrics = ["mean_s", "share_long", "mean_p_top", "mean_p_bottom", "mean_y", "mean_target_pct"]
    summary_columns = ["interaction", "momentum_bin", "other_bin", "n_dates", "n_obs_total"]
    for col in summary_metrics:
        summary_columns += [col, f"se_{col}"]

    # Collapse across dates (each date gets equal weight)
    summaries: List[Dict[str, float]] = []
    for (mom, oth), g in per_date.groupby(["momentum_bin", "other_bin"], observed=True):
        row = {
            "interaction": spec.name,
            "momentum_bin": mom,
            "other_bin": oth,
            "n_dates": int(g["date"].nunique()),
            "n_obs_total": int(g["n"].sum()),
        }
        for col in summary_metrics:
            row[col], row[f"se_{col}"] = _mean_and_se(g[col])
        summaries.append(row)
    # Explicit columns keep the sort below valid even when there are no cells.
    cell_summary = pd.DataFrame(summaries, columns=summary_columns).sort_values(["momentum_bin", "other_bin"])

    # (per-date column, suffix used in the paired-diff column names)
    diff_metrics = [
        ("mean_s", "s"),
        ("share_long", "share_long"),
        ("mean_y", "y"),
        ("mean_target_pct", "target_pct"),
    ]
    pair_labels = [spec.low_label, spec.high_label]

    def _high_minus_low(side: pd.DataFrame, value_col: str) -> pd.Series:
        """HIGH − LOW of `value_col` on the dates where both bins have a value."""
        wide = side.pivot(index="date", columns="other_bin", values=value_col)
        # A bin that never occurs on this side has no column after the pivot;
        # add it as all-NaN so the date filter below removes every row instead
        # of the column lookup raising KeyError.
        for label in pair_labels:
            if label not in wide.columns:
                wide[label] = np.nan
        wide = wide.loc[wide[pair_labels].notna().all(axis=1)]
        return wide[spec.high_label] - wide[spec.low_label]

    # Paired HIGH − LOW diffs within each momentum side
    def _paired_diff(df_in: pd.DataFrame, momentum_label: str) -> pd.DataFrame:
        g = df_in.loc[df_in["momentum_bin"] == momentum_label].copy()
        diffs = {suffix: _high_minus_low(g, col) for col, suffix in diff_metrics}

        # The paired-date count is taken from the score metric.
        n_dates = int(len(diffs["s"]))

        out = {
            "interaction": [spec.name],
            "momentum_bin": [momentum_label],
            "n_dates_paired": [n_dates],
        }
        for _, suffix in diff_metrics:
            mu, se = _mean_and_se(diffs[suffix]) if n_dates else (np.nan, np.nan)
            out[f"mean_diff_{suffix}_{spec.high_label}_minus_{spec.low_label}"] = [mu]
            out[f"se_diff_{suffix}"] = [se]
        return pd.DataFrame(out)

    paired_diff_neg = _paired_diff(per_date, "NEG")
    paired_diff_pos = _paired_diff(per_date, "POS")

    thresholds = pd.DataFrame(
        [
            {
                "interaction": spec.name,
                "momentum_feature": spec.momentum_feature,
                "other_feature": spec.other_feature,
                "ret_neg_threshold": spec.ret_neg_threshold,
                "low_q": spec.low_q,
                "high_q": spec.high_q,
                "low_cut_value": low_cut,
                "high_cut_value": high_cut,
                "low_label": spec.low_label,
                "high_label": spec.high_label,
                "target_col": TARGET_COL,
            }
        ]
    )

    return {
        "thresholds": thresholds,
        "cell_summary": cell_summary,
        "paired_diff_NEG": paired_diff_neg,
        "paired_diff_POS": paired_diff_pos,
    }


def run_all_interactions(
    ltr_df: pd.DataFrame,
    backtest_df: pd.DataFrame,
    interactions: List[InteractionSpec],
    results_root: str = "slice_results/mom_bocd",
    save: bool = True,
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Evaluate every spec against one shared OOS frame.

    The OOS frame is built once from the union of all features the specs
    mention. With `save=True`, each spec's four tables are written as CSV to
    `<results_root>/<slug of spec name>/`, and concatenated `ALL_*.csv`
    rollups are written to `results_root` (fixed folder, no timestamp).

    Returns a nested dict: {interaction_name: {table_name: DataFrame}}.
    """
    table_names = ("thresholds", "cell_summary", "paired_diff_NEG", "paired_diff_POS")

    # One OOS frame covering every feature any spec needs.
    required_features = {
        feature
        for spec in interactions
        for feature in (spec.momentum_feature, spec.other_feature)
    }
    oos_master = build_oos_master(ltr_df, backtest_df, features=required_features)

    if save:
        os.makedirs(results_root, exist_ok=True)

    all_results: Dict[str, Dict[str, pd.DataFrame]] = {}
    for spec in interactions:
        tables = run_single_interaction(oos_master, spec)
        all_results[spec.name] = tables
        if not save:
            continue

        out_dir = os.path.join(results_root, _slugify(spec.name))
        os.makedirs(out_dir, exist_ok=True)
        for table in table_names:
            tables[table].to_csv(os.path.join(out_dir, f"{table}.csv"), index=False)

    # Combined rollups: build all four before writing any of them.
    if save and all_results:
        combined = {
            table: pd.concat([res[table] for res in all_results.values()], ignore_index=True)
            for table in table_names
        }
        for table in table_names:
            combined[table].to_csv(os.path.join(results_root, f"ALL_{table}.csv"), index=False)

    return all_results


# Default interaction set
# Each row: (name, momentum_feature, other_feature, low_label, high_label).
# Every spec uses the 20% / 80% quantile cuts.
INTERACTIONS = [
    InteractionSpec(
        name=spec_name,
        momentum_feature=momentum_col,
        other_feature=state_col,
        low_q=0.20,
        high_q=0.80,
        low_label=label_low,
        high_label=label_high,
    )
    for spec_name, momentum_col, state_col, label_low, label_high in (
        (
            "slow_mom_vtskip_252_42_z  x  Pr_le_63 (mature=low)",
            "(252, 42)_day_vol_time_scaled_ret_z",
            "Pr_le_63",
            "MATURE",
            "YOUNG",
        ),
        (
            "slow_mom_vtskip_252_42_z  x  Pr_le_252_z (mature=low)",
            "(252, 42)_day_vol_time_scaled_ret_z",
            "Pr_le_252_z",
            "MATURE",
            "YOUNG",
        ),
        (
            "slow_mom_vtskip_252_42_z  x  E_rt_z (mature=high)",
            "(252, 42)_day_vol_time_scaled_ret_z",
            "E_rt_z",
            "YOUNG",   # here HIGH means mature
            "MATURE",
        ),
        (
            "slow_mom_vtskip_252_42_z  x  Var_rt_z (uncertainty)",
            "(252, 42)_day_vol_time_scaled_ret_z",
            "Var_rt_z",
            "LOW_VAR",
            "HIGH_VAR",
        ),
        (
            "macd_32_96_lag_3m_z  x  Pr_le_63 (mature=low)",
            "32_96_lag_3m_z",
            "Pr_le_63",
            "MATURE",
            "YOUNG",
        ),
        (
            "slow_mom_vtskip_252_42_z  x  Pr_le_126_z (mature=low)",
            "(252, 42)_day_vol_time_scaled_ret_z",
            "Pr_le_126_z",
            "MATURE",
            "YOUNG",
        ),
    )
]


# Run every default spec, writing the CSV tables under slice_results/mom_bocd.
results = run_all_interactions(
    ltr_df=ltr_df,
    backtest_df=backtest_df,
    interactions=INTERACTIONS,
    results_root="slice_results/mom_bocd",
    save=True,
)
# List the interaction names, then preview one cell summary as the cell output.
print(list(results))
results["slow_mom_vtskip_252_42_z  x  Pr_le_126_z (mature=low)"]["cell_summary"].head()


['slow_mom_vtskip_252_42_z  x  Pr_le_63 (mature=low)', 'slow_mom_vtskip_252_42_z  x  Pr_le_252_z (mature=low)', 'slow_mom_vtskip_252_42_z  x  E_rt_z (mature=high)', 'slow_mom_vtskip_252_42_z  x  Var_rt_z (uncertainty)', 'macd_32_96_lag_3m_z  x  Pr_le_63 (mature=low)', 'slow_mom_vtskip_252_42_z  x  Pr_le_126_z (mature=low)']


Unnamed: 0,interaction,momentum_bin,other_bin,n_dates,n_obs_total,mean_s,se_mean_s,share_long,se_share_long,mean_p_top,se_mean_p_top,mean_p_bottom,se_mean_p_bottom,mean_y,se_mean_y,mean_target_pct,se_mean_target_pct
0,slow_mom_vtskip_252_42_z x Pr_le_126_z (matu...,NEG,MATURE,5034,508849,-0.004312,0.000308,0.476951,0.001788,0.336309,0.000222,0.340621,0.000134,0.483564,0.043692,0.002168,0.001745
1,slow_mom_vtskip_252_42_z x Pr_le_126_z (matu...,NEG,YOUNG,4705,586828,-0.028341,0.000212,0.280858,0.001438,0.295146,0.000161,0.323487,0.000201,0.37833,0.033909,-0.020215,0.001584
2,slow_mom_vtskip_252_42_z x Pr_le_126_z (matu...,POS,MATURE,5034,493857,-0.002396,0.000323,0.509582,0.001866,0.34003,0.000148,0.342426,0.000231,0.559283,0.042764,0.012811,0.001798
3,slow_mom_vtskip_252_42_z x Pr_le_126_z (matu...,POS,YOUNG,4705,415878,0.005144,0.000222,0.487301,0.001824,0.324665,0.00018,0.319522,0.000184,0.493437,0.035699,0.001543,0.001756
