# Notebook cell export (originally "In [None]:").
import sqlite3
import math
from typing import Dict, List, Union, Any, Optional, Set, Tuple
import pandas as pd
import numpy as np
import os, time, re
from collections import defaultdict

# ===== CONFIG =====
# --- database --------------------------------------------------------------
# Path of the MAGI SQLite file; the MAGI_DB_PATH environment variable, when
# set, takes precedence over the built-in default.
MAGI_DB_PATH = os.getenv("MAGI_DB_PATH", "/projects/klybarge/pcori_ad/magi/magi_db/magi.db")
# SQLite URI for the same file, opened read-only (mode=ro).
uri = f"file:{MAGI_DB_PATH}?mode=ro"
# Edge table name (the fetch queries in this file spell the same name literally).
EDGE_TABLE = "magi_counts_top500"

# --- selection -------------------------------------------------------------
TOP_K = 500  # 500 for others

# --- outputs ---------------------------------------------------------------
# Created eagerly so later CSV writes never hit a missing directory.
OUT_DIR = "./MAGI_LASSO"
os.makedirs(OUT_DIR, exist_ok=True)

# --- outcomes --------------------------------------------------------------
# One outcome code per drug, all following the pattern aa_meas_<drug>_rem.
# The order below is the order in which the outcomes are processed.
TARGETS = [
    f"aa_meas_{drug}_rem"
    for drug in (
        "amitriptyline",
        "fluoxetine",
        "citalopram",
        "venlafaxine",
        "mirtazapine",
        "sertraline",
        "bupropion",
        "trazodone",
        "duloxetine",
        "escitalopram",
        "paroxetine",
        "nortriptyline",
        "other",
        "doxepin",
        "desvenlafaxine",
    )
]

def _anchored_odds(hits: float, misses: float) -> float:
    """Return the odds hits/misses, anchored to the stratum size when a cell is empty."""
    stratum = hits + misses
    if stratum == 0:
        return 1.0                       # empty stratum: neutral odds
    if hits == 0:
        return 1.0 / (stratum + 1.0)     # no hits: smallest odds the stratum supports
    if misses == 0:
        return stratum + 1.0             # no misses: largest odds the stratum supports
    return hits / misses


# LEFT = k (predictor) and RIGHT = Y/j (outcome/downstream)
def analyze_causal_sequence_py(
    df_in: Union[str, pd.DataFrame],
    *,
    name_map: Optional[Dict[str, str]] = None,   # raw_code -> friendly label (applied to BOTH columns)
    events: Optional[List[str]] = None,          # event names to KEEP (AFTER recoding). If None: auto-detect
    force_outcome: Optional[str] = None,         # if provided and present, force this to be the FINAL node (Y)
    lambda_min_count: int = 15                   # L-threshold for λ: if total(k) < L ⇒ λ_{k,j}=0
) -> Dict[str, Any]:
    """
    MAGI (Python): Reference routine with explicit comments.

    ─────────────────────────────────────────────────────────────────────────────
    COLUMN CONVENTION PER ROW:
    Left  column: target_concept_code  (this is k for the row)
    Right column: concept_code         (this is Y or j for the row)

    n_code_target        ≡  (RIGHT ∧ LEFT) = (Y/j ∧ k)
    n_code_no_target     ≡  (RIGHT ∧ ¬LEFT) = (Y/j ∧ ¬k)
    n_target_no_code     ≡  (LEFT ∧ ¬RIGHT) = (k ∧ ¬Y/j)
    n_no_target          ≡  total(¬LEFT) = total(¬k)
    n_target             ≡  total(LEFT)  = total(k)
    n_code_before_target ≡  count(RIGHT before LEFT) = count(Y/j before k)
    n_target_before_code ≡  count(LEFT before RIGHT) = count(k before Y/j)

    (n_code must be present and is coerced to numeric, but no computation
    below reads it.)

    TEMPORAL SCORE for each node v (rows read as target = k, code = v):
      Score(v) = Σ_{k≠v} [ C(k≺v) - C(v≺k) + C(k∧¬v) - C(v∧¬k) ]
        C(k≺v)  = n_target_before_code      C(v≺k)  = n_code_before_target
        C(k∧¬v) = n_target_no_code          C(v∧¬k) = n_code_no_target
      Nodes are ranked by descending score. The outcome Y is `force_outcome`
      when it is among the events, otherwise the top-scoring node (a
      `force_outcome` that is not among the events falls back silently to the
      top-scoring node). The remaining nodes keep their descending-score
      order and Y is appended last.

    TOTAL EFFECT T_{kY} from row (target = k, code = Y):
      a = n_code_target                  (Y ∧ k)
      b = n_code_no_target               (Y ∧ ¬k)
      c = n_target_no_code               (k ∧ ¬Y)
      d = max(n_no_target - b, 0)        (¬k ∧ ¬Y)  ← computed on the fly
      odds_k1 = a/c and odds_k0 = b/d, each replaced by a sample-size-anchored
      value when a cell is empty (see `_anchored_odds`); T = odds_k1 / odds_k0.
      A missing (k, Y) row yields T = 1.

    LAMBDA λ_{k,j} from row (target = k, code = j), for every j that sits
    between k and Y in the order:
      λ_{k,j} = n_code_target / n_target = (j ∧ k) / total(k), clipped to
      [0, 1]; λ = 0 when the row is missing or total(k) < lambda_min_count.

    DIRECT EFFECTS via backward recursion:
      D_{k,Y} = ( T_{k,Y} - Σ_j λ_{k,j} D_{j,Y} ) / ( 1 - Σ_j λ_{k,j} ),
      falling back to T_{k,Y} when the denominator is 0 or the result is not
      finite. The node immediately before Y has D = T.

    LOGISTIC LINK:
      logit P(Y=1 | Z) = β0 + Σ_k β_k Z_k  with β_k = log D_{k,Y};
      invalid/nonpositive D map to β_k=0. β0 is the logit of
      n_target / (n_target + n_no_target) taken from rows whose target is Y
      (0.5 prevalence if no such rows exist).

    Multiple rows for the same (target, code) pair are combined: counts are
    summed and the totals n_target / n_no_target take their maximum.
    Duplicate names in `events` are ignored (first occurrence wins).

    RAISES ValueError if a required column is missing or fewer than two
    distinct events remain.

    RETURNS a dict with:
      sorted_scores, temporal_order, order_used,
      T_val, D_val, lambda_l, coef_df, beta_0, beta, logit_predictors, predict_proba
    """
    # ── 0) Ingest & validate ───────────────────────────────────────────────────
    df = pd.read_csv(df_in) if isinstance(df_in, str) else df_in.copy()

    key_cols = ["target_concept_code", "concept_code"]
    num_cols = [
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target", "n_target_no_code",
        "n_code", "n_code_before_target", "n_target_before_code",
    ]
    missing = [c for c in key_cols + num_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # Always treat endpoints as strings to avoid silent filter drop
    for col in key_cols:
        df[col] = df[col].astype(str)

    # Optional recoding
    if name_map:
        for col in key_cols:
            df[col] = df[col].replace(name_map)

    # Limit to selected events (union of both columns if auto)
    if events is None:
        events = sorted(set(df["target_concept_code"]) | set(df["concept_code"]))
    else:
        # Order-preserving de-duplication: a repeated name would otherwise be
        # visited twice as k below and double-count its score contribution.
        events = list(dict.fromkeys(str(e) for e in events))

    if len(events) < 2:
        raise ValueError("Need at least two events.")

    df = df[df["target_concept_code"].isin(events) & df["concept_code"].isin(events)].copy()

    # Coerce numerics
    for c in num_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    # Aggregate once per (target, code) pair so every later lookup is a dict
    # access instead of a full-frame boolean scan.
    sum_cols = [
        "n_code_target", "n_code_no_target", "n_target_no_code",
        "n_code_before_target", "n_target_before_code",
    ]
    max_cols = ["n_target", "n_no_target"]
    stat_cols = sum_cols + max_cols
    pair_stats: Dict[Tuple[str, str], np.ndarray] = {}
    if not df.empty:
        agg_spec = {c: "sum" for c in sum_cols}
        agg_spec.update({c: "max" for c in max_cols})
        stats = df.groupby(key_cols, sort=False).agg(agg_spec)[stat_cols]
        pair_stats = dict(zip(stats.index, stats.to_numpy(dtype=float)))

    # ── 1) Temporal score (read rows: target=k, code=v) ───────────────────────────
    scores: Dict[str, float] = {}
    for v in events:
        total = 0.0
        for k in events:
            if k == v:
                continue
            cell = pair_stats.get((k, v))
            if cell is None:
                continue
            # Row (k, v): code-side counts describe v, target-side counts describe k.
            _, v_and_not_k, k_and_not_v, v_before_k, k_before_v, _, _ = cell
            total += (k_before_v - v_before_k + k_and_not_v - v_and_not_k)
        scores[v] = float(total)

    sorted_scores = pd.Series(scores).sort_values(ascending=False)

    # Choose outcome Y: either forced (when present) or the top-scoring node
    if force_outcome and (force_outcome in sorted_scores.index):
        outcome = force_outcome
    else:
        outcome = sorted_scores.index[0]
    # NOTE(review): non-outcome nodes keep descending-score order here; confirm
    # that this is the intended earliest→latest convention.
    temporal_order = [ev for ev in sorted_scores.index if ev != outcome] + [outcome]

    events_order = temporal_order        # … → Y
    nodes = events_order[:-1]            # everything before Y

    # ── 2) T and λ (row orientations locked) ───────────────────────────────────
    T_val = pd.Series(0.0, index=nodes, dtype=float)
    D_val = pd.Series(np.nan, index=nodes, dtype=float)
    lambda_l: Dict[str, pd.Series] = {}

    for pos_k, k in enumerate(nodes):
        # ---- T_{kY} from the row(s) (target=k, code=Y) ----
        cell_kY = pair_stats.get((k, outcome))
        if cell_kY is None:
            a = b = c = n_not_k = 0.0
        else:
            a = float(cell_kY[0])        # Y ∧ k
            b = float(cell_kY[1])        # Y ∧ ¬k
            c = float(cell_kY[2])        # k ∧ ¬Y
            n_not_k = float(cell_kY[6])  # total(¬k)
        d = max(n_not_k - b, 0.0)        # ¬k ∧ ¬Y

        odds_k1 = _anchored_odds(a, c)   # odds of Y within k = 1
        odds_k0 = _anchored_odds(b, d)   # odds of Y within k = 0
        T_val.loc[k] = float(odds_k1 / odds_k0) if odds_k0 > 0 else (a + c + 1.0)

        # ---- λ_{k,j} from rows (target=k, code=j): λ = (j ∧ k) / total(k) ----
        # j ranges over the nodes strictly between k and Y in the order.
        lam: Dict[str, float] = {}
        for j in events_order[pos_k + 1 : -1]:
            cell_kj = pair_stats.get((k, j))
            if cell_kj is None:
                lam[j] = 0.0
                continue

            num = float(cell_kj[0])      # j ∧ k
            den = float(cell_kj[5])      # total(k), stored in n_target on this row

            # L-threshold on total(k)
            if (den <= 0) or (den < lambda_min_count):
                lam[j] = 0.0
                continue

            ratio = num / den
            lam[j] = 0.0 if not np.isfinite(ratio) else float(min(max(ratio, 0.0), 1.0))

        lambda_l[k] = pd.Series(lam, dtype=float)

    # ── 3) Backward recursion for direct effects D ─────────────────────────────
    # Last antecedent (just before Y): no downstream → D = T
    if len(nodes) >= 1:
        last_anc = nodes[-1]
        D_val.loc[last_anc] = T_val.loc[last_anc]

    # Walk backward for the rest
    if len(nodes) > 1:
        for k in reversed(nodes[:-1]):
            lam_vec = lambda_l.get(k, pd.Series(dtype=float))
            downstream = list(lam_vec.index)  # j nodes after k (already resolved)
            lam_vals = lam_vec.reindex(downstream).fillna(0.0).to_numpy()
            D_down = pd.to_numeric(D_val.reindex(downstream), errors="coerce").fillna(0.0).to_numpy()

            num = T_val.loc[k] - float(np.nansum(lam_vals * D_down))
            den = 1.0 - float(np.nansum(lam_vals))

            if (not np.isfinite(den)) or den == 0.0:
                D_val.loc[k] = T_val.loc[k]            # neutralize if pathological
            else:
                tmp = num / den
                D_val.loc[k] = tmp if np.isfinite(tmp) else T_val.loc[k]

    # ── 4) Logistic link (β) and predict_proba ─────────────────────────────────
    # Intercept β0 from marginal prevalence of Y (rows with target == Y)
    resp_rows = df[df["target_concept_code"] == outcome]
    n_t = float(resp_rows["n_target"].max()) if not resp_rows.empty else np.nan
    n_n = float(resp_rows["n_no_target"].max()) if not resp_rows.empty else np.nan
    denom = n_t + n_n
    p_y = 0.5 if (not np.isfinite(denom) or denom <= 0) else (n_t / denom)
    p_y = min(max(p_y, 1e-12), 1 - 1e-12)   # keep the logit finite
    beta_0 = float(np.log(p_y / (1 - p_y)))

    # β_k = log D_{k,Y}; invalid/nonpositive → 0
    D_clean = pd.to_numeric(D_val, errors="coerce").astype(float)
    beta_vals = np.log(D_clean.where(D_clean > 0.0)).replace([np.inf, -np.inf], np.nan).fillna(0.0)

    coef_df = pd.DataFrame({
        "predictor": list(beta_vals.index) + ["(intercept)"],
        "beta":      list(beta_vals.values) + [beta_0],
    })

    # Vectorized predict_proba
    predictors = list(beta_vals.index)
    beta_vec = beta_vals.values

    def predict_proba(Z: Union[Dict[str, Any], pd.Series, np.ndarray, List[float], pd.DataFrame]):
        """
        Compute P(Y=1|Z) using: logit P = β0 + Σ_k β_k Z_k.
        Z can be:
          - dict/Series mapping predictor name -> 0/1
          - 1D/2D numpy/list with columns ordered as `predictors`
          - DataFrame containing any/all of `predictors` (others ignored)
        """
        def sigmoid(x):
            x = np.clip(x, -700, 700)  # numerical stability for large |η|
            return 1.0 / (1.0 + np.exp(-x))

        if isinstance(Z, pd.DataFrame):
            M = Z.reindex(columns=predictors, fill_value=0.0).astype(float).to_numpy()
            return sigmoid(beta_0 + M @ beta_vec)

        if isinstance(Z, (dict, pd.Series)):
            v = np.array([float(Z.get(p, 0.0)) for p in predictors], dtype=float)
            return float(sigmoid(beta_0 + float(v @ beta_vec)))

        arr = np.asarray(Z, dtype=float)
        if arr.ndim == 1:
            if arr.size != len(predictors):
                raise ValueError(f"Expected {len(predictors)} features in order: {predictors}")
            return float(sigmoid(beta_0 + float(arr @ beta_vec)))
        if arr.ndim == 2:
            if arr.shape[1] != len(predictors):
                raise ValueError(f"Expected shape (*,{len(predictors)}), got {arr.shape}")
            return sigmoid(beta_0 + arr @ beta_vec)

        raise ValueError("Unsupported input for predict_proba")

    # ── 5) Package results ─────────────────────────────────────────────────────
    return {
        "sorted_scores": sorted_scores,
        "temporal_order": temporal_order,
        "order_used": events_order,
        "T_val": T_val,
        "D_val": D_val,
        "lambda_l": lambda_l,
        "coef_df": coef_df,
        "beta_0": beta_0,
        "beta": pd.Series(beta_vec, index=predictors, dtype=float),
        "logit_predictors": predictors,
        "predict_proba": predict_proba,
    }

def _ensure_derived_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` with the count columns coerced to numeric and a
    `total_effect` column guaranteed to exist.

    Unparseable or missing counts become 0. An existing `total_effect` column
    is left untouched; otherwise it is derived per row from the 2×2 cells,
    with d = n_no_target - n_code_no_target computed on the fly (floored at 0)
    because d is not stored in the data.

    Raises ValueError if any of the required count columns is absent.
    """
    required = [
        "n_code_target", "n_code_no_target",
        "n_target_no_code", "n_target", "n_no_target"
    ]
    absent = [name for name in required if name not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns for TE: {', '.join(absent)}")

    out = df.copy()
    for name in required:
        out[name] = pd.to_numeric(out[name], errors="coerce").fillna(0.0)

    # Nothing to derive when the caller already supplied total effects.
    if "total_effect" in out.columns:
        return out

    def _guarded_odds(hits, misses):
        # Plain ratio when both cells are populated; otherwise fall back to a
        # value anchored on the stratum size (neutral 1.0 for an empty stratum).
        stratum = hits + misses
        return np.select(
            [stratum == 0, (hits > 0) & (misses > 0), misses == 0],
            [1.0, hits / misses, stratum + 1.0],
            default=1.0 / (stratum + 1.0),
        )

    y_and_k = out["n_code_target"].to_numpy(dtype=float)         # Y ∧ k
    y_not_k = out["n_code_no_target"].to_numpy(dtype=float)      # Y ∧ ¬k
    k_not_y = out["n_target_no_code"].to_numpy(dtype=float)      # k ∧ ¬Y
    total_not_k = out["n_no_target"].to_numpy(dtype=float)
    neither = np.maximum(total_not_k - y_not_k, 0.0)             # ¬k ∧ ¬Y

    # Zero cells are expected here; silence the division warnings locally.
    with np.errstate(divide="ignore", invalid="ignore"):
        exposed_odds = _guarded_odds(y_and_k, k_not_y)     # stratum k = 1
        unexposed_odds = _guarded_odds(y_not_k, neither)   # stratum k = 0
        ratio = exposed_odds / unexposed_odds

    # Any non-finite ratio is treated as "no effect".
    out["total_effect"] = np.where(np.isfinite(ratio), ratio, 1.0).astype(float)
    return out


# TOPK_HI/LO
def _select_top_by_te_unique_k(
    k_to_T: pd.DataFrame, top_k: Optional[int] = None, hi: float = 1.5
) -> pd.DataFrame:
    """
    Balanced selection without fallback:
      - Take up to top_k//2 strongest risk (TE >= hi)
      - Take up to top_k - top_k//2 strongest protective (TE <= 1/hi)
      - If either side is short, do NOT fill from elsewhere; total may be < top_k
      - Return one row per k (**target_concept_code**), ranked by extremeness

    Parameters:
      k_to_T : rows oriented (target = k, code = Y); must carry the columns
               `target_concept_code` and `total_effect`.
      top_k  : overall budget; None means the module-level TOP_K.
      hi     : risk threshold; the protective threshold is its reciprocal.

    Returns the selected rows (risk rows first, then protective rows) with an
    added `effect_strength` column = TE if TE >= 1 else 1/TE. When no row has
    a numeric `total_effect`, the empty frame is returned as-is, without the
    `effect_strength` column. A one-line summary is printed in either case.
    """
    if top_k is None:
        top_k = TOP_K

    df = k_to_T.copy()

    # Ensure total_effect numeric; rows that cannot be parsed are dropped
    df["total_effect"] = pd.to_numeric(df["total_effect"], errors="coerce")
    df = df.dropna(subset=["total_effect"]).copy()
    if df.empty:
        print("[SELECT] no rows with total_effect; returning empty selection.")
        return df

    # Extremeness symmetrical around 1
    df["effect_strength"] = np.where(df["total_effect"] >= 1.0,
                                     df["total_effect"],
                                     1.0 / df["total_effect"])

    # One row per k: k is LEFT = target_concept_code in (target=k, code=Y).
    # Sorting first makes drop_duplicates keep each k's most extreme row.
    best_per_k = (df.sort_values("effect_strength", ascending=False)
                    .drop_duplicates(subset=["target_concept_code"], keep="first"))

    HI = hi
    LO = 1.0 / hi

    risk_pool = best_per_k[best_per_k["total_effect"] >= HI].copy()
    prot_pool = best_per_k[best_per_k["total_effect"] <= LO].copy()

    # Odd budgets give the extra slot to the protective side.
    want_risk = top_k // 2
    want_prot = top_k - want_risk

    sel_risk = risk_pool.nlargest(min(len(risk_pool), want_risk), "effect_strength")
    sel_prot = prot_pool.nlargest(min(len(prot_pool), want_prot), "effect_strength")

    selected = pd.concat([sel_risk, sel_prot], ignore_index=True)

    print(f"[SELECT] total unique k={best_per_k['target_concept_code'].nunique():,}  "
        f"risk={len(risk_pool):,}  prot={len(prot_pool):,}  "
        f"selected(total)={selected['target_concept_code'].nunique()}  "
        f"with: risk={len(sel_risk)}, prot={len(sel_prot)}")

    return selected.reset_index(drop=True)


def _fetch_k_to_T(conn, outcome_code: str) -> pd.DataFrame:
    """
    Load every edge row whose RIGHT endpoint (concept side) is the outcome Y.

    Both integer endpoint ids are resolved to codes through `concept_names`,
    so each returned row is a (k, Y) line carrying all edge-table columns plus
    `target_concept_code` (k) and `concept_code` (Y). These are the rows from
    which T_{kY} is computed. The outcome code is passed as a bound parameter.
    """
    bound_values = [outcome_code]
    sql = """
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT  (k)
             ccn.concept_code AS concept_code           -- RIGHT (Y)
      FROM magi_counts_top500 m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE ccn.concept_code = ?                         -- <<< RIGHT == Y
    """
    frame = pd.read_sql_query(sql, conn, params=bound_values)
    return frame

def _fetch_subgraph_by_targets(conn, events_list) -> pd.DataFrame:
    """
    Fetch every edge row whose LEFT endpoint (target side, k) is in the event set.

    Only the LEFT side is restricted here; the RIGHT side (concept_code) is
    returned unfiltered, so callers that want the induced subgraph must filter
    `concept_code` themselves.

    Parameters:
      conn        : open DB connection accepted by `pd.read_sql_query`; the
                    query uses `?` placeholders.
      events_list : any iterable of concept codes (list, set, generator, …).
                    It is materialised once, so one-shot iterables are fine.

    Returns all edge-table columns plus `target_concept_code` (k) and
    `concept_code` (Y/j), resolved through `concept_names`.

    Note: one placeholder is bound per event, so very large event sets are
    subject to SQLite's per-statement bound-parameter limit.
    """
    # Materialise first: the codes are needed both for counting placeholders
    # and as the bound values, which a generator could not serve twice.
    codes = list(events_list)
    ph = ",".join(["?"] * len(codes))
    q = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT (k)
             ccn.concept_code AS concept_code           -- RIGHT (Y/j)
      FROM magi_counts_top500 m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE tcn.concept_code IN ({ph})
    """
    return pd.read_sql_query(q, conn, params=codes)

# ========= main loop =========
# For each outcome in TARGETS: pick predictors by total effect, pull the
# matching subgraph from the DB, run MAGI with the outcome forced, and write
# three audit CSVs (selected factors, subgraph, non-zero coefficients).
if __name__ == '__main__':
    # Runtime safety checks
    uri = f"file:{MAGI_DB_PATH}?mode=ro"
    if not os.path.exists(MAGI_DB_PATH):
        raise FileNotFoundError(
            f"MAGI_DB_PATH not found: {MAGI_DB_PATH}. "
            f"Set MAGI_DB_PATH to a valid SQLite DB or run with CSV input."
        )

    # Loop-invariant: make sure the output directory exists before any write.
    os.makedirs(OUT_DIR, exist_ok=True)

    # sqlite3's `with conn:` only scopes a transaction and leaves the handle
    # open, so the connection is closed explicitly in `finally`.
    conn = sqlite3.connect(uri, uri=True)
    try:
        for T in TARGETS:
            print("\n" + "=" * 100)
            print(f"[RUN] Target = {T}")

            # 1) k→Y rows to compute T_{kY}
            k_to_T = _fetch_k_to_T(conn, T)

            # DEDUP
            before = len(k_to_T)
            k_to_T = k_to_T.drop_duplicates()
            after = len(k_to_T)
            if after < before:
                print(f"[DEDUP] k_to_T exact-row duplicates removed: {before - after}  (kept {after})")

            if k_to_T.empty:
                print(f"[WARN] No predictor→target rows for {T}; skipping.")
                continue

            # 2) derive totals/effects on the k→Y rows
            k_to_T = _ensure_derived_cols(k_to_T)

            # 3) pick predictors by TE (k is the LEFT side = target_concept_code)
            sel_rows = _select_top_by_te_unique_k(k_to_T, top_k=TOP_K)
            if sel_rows.empty:
                print(f"[WARN] No predictors selected for {T}; skipping.")
                continue

            # Save selected factors for audit
            sub_csv = os.path.join(OUT_DIR, f"risk_prot_{T}.csv")
            sel_rows.to_csv(sub_csv, index=False)
            print(f"[SAVED] Factors → {sub_csv}")

            # selected_k := chosen predictors; never include T itself
            selected_k = set(sel_rows["target_concept_code"].astype(str))
            if T in selected_k:
                selected_k.remove(T)

            print(f"[SELECT] unique k available={k_to_T['target_concept_code'].nunique():,}  "
                f"selected={len(selected_k):,}")

            if not selected_k:
                print(f"[WARN] No predictors after removing target {T}; skipping.")
                continue

            # 4) subgraph for λ: both LEFT and RIGHT restricted to selected_k ∪ {T}
            events_set    = selected_k | {T}
            right_allowed = events_set

            df_trim = _fetch_subgraph_by_targets(conn, sorted(events_set))  # LEFT ∈ selected_k ∪ {T}
            df_trim = df_trim[df_trim["concept_code"].isin(right_allowed)].copy()

            # DEDUP
            before = len(df_trim)
            df_trim = df_trim.drop_duplicates()
            after = len(df_trim)
            if after < before:
                print(f"[DEDUP] subgraph exact-row duplicates removed: {before - after}  (kept {after})")

            # Coerce counts and add total_effect on the subgraph if it is absent
            df_trim = _ensure_derived_cols(df_trim)

            # Sanity: count rows pointing at Y (k→Y) and at other selected nodes (k→j)
            y_on_right = int((df_trim["concept_code"] == T).sum())
            j_on_right = int(((df_trim["concept_code"] != T) &
                            (df_trim["concept_code"].isin(selected_k))).sum())
            print(f"[TRIM] rows={len(df_trim):,} events={len(events_set)}  k→Y={y_on_right}  k→j={j_on_right}")

            # 5) save subgraph (audit)
            sub_csv = os.path.join(OUT_DIR, f"magi_subgraph_{T}.csv")
            df_trim.to_csv(sub_csv, index=False)
            print(f"[SAVED] Subgraph → {sub_csv}")

            # 6) run MAGI with outcome forced to T; no alias/name_map in this version
            # Top-level boundary: report the failure and move on to the next target.
            try:
                res = analyze_causal_sequence_py(
                    df_trim, events=None, name_map=None, force_outcome=T
                )
            except Exception as e:
                print("[ERROR] MAGI failed:", e)
                continue

            # 7) save NON-ZERO coefficients, whitelisting to selected_k (+ intercept)
            # The file is named after the outcome MAGI actually used (last node).
            outcome_used = res.get("order_used", [T])[-1]
            coef_df = res["coef_df"].copy()

            # robust: accept 'coef' or 'beta' column name
            coef_col = "coef" if "coef" in coef_df.columns else ("beta" if "beta" in coef_df.columns else None)
            if coef_col is None:
                raise KeyError("Coefficient column not found in coef_df (expected 'coef' or 'beta').")

            eps = 1e-12
            mask_nz = coef_df[coef_col].astype(float).abs() > eps
            coef_nz = coef_df.loc[mask_nz].copy()

            # Keep only intercept or predictors in selected_k
            is_intercept = coef_nz["predictor"].eq("(intercept)")
            in_whitelist = coef_nz["predictor"].isin(selected_k)
            keep_mask = is_intercept | in_whitelist
            dropped = int((~keep_mask).sum())
            if dropped:
                print(f"[FILTER] Dropping {dropped} predictor(s) not in whitelist.")

            coef_nz = coef_nz.loc[keep_mask]

            coef_csv = os.path.join(OUT_DIR, f"magi_coef_{outcome_used}_nonzero.csv")
            coef_nz.to_csv(coef_csv, index=False)
            print(f"[SAVED] Coefficients (non-zero only) → {coef_csv}  "
                  f"| kept={len(coef_nz):,} of {len(coef_df):,}  "
                  f"| nodes={len(res.get('order_used', [])) - 1}")
    finally:
        conn.close()
