In [None]:
# typing: static type hints for clarity & IDE/tooling support
from typing import Union, Dict, List, Any
# pandas: tabular data I/O (CSV) and DataFrame transforms/joins/group-bys
import pandas as pd
# numpy: fast vectorized numerics (arrays, nansum, clipping) for counts & logits
import numpy as np
# math: scalar math utilities; used for log in safe_log (per-element, non-vector)
import math
# build file paths, make folders, read env vars, list files, etc.
import os
# open a connection to a DB file, run SQL queries, fetch results
import sqlite3


In [None]:
############### MAGI FUNCTION ###############
"""
MAGI: Dependent Bayes with Temporal Ordering — Reference Implementation
-------------------------------------------------------------------------------
This file adds **step-by-step comments** to the function `analyze_causal_sequence_py`. The comments mirror the proposed
algorithm sections:

1) Determination of Temporal Order
2) Estimation of Dependent Bayes (T-values, λ-links, and D-values)
3) Logistic link to produce P(Y=1 | Z)

INPUT TABLE EXPECTATIONS (long format, one row per directed pair):
- target_concept_code : str (# left node in the edge; in many places this
denotes the earlier event)
- concept_code : str (# right node in the edge; the later event)
- n_code_target : float/int (# co-occurrence count for concept & target
in the *target=Y* stratum; used as 'a')
- n_code_no_target : float/int (# co-occurrence count for concept & target
in the *target=¬Y* stratum; used as 'b')
- n_target : float/int (# total Y count for this target concept, per edge row; we take max within a block)
- n_no_target : float/int (# total ¬Y count; max within a block)
- no_code_no_target : float/int (# optional, computed if missing as n_no_target - n_code_no_target; clipped ≥0)
- n_target_before_code: float/int (# count where target occurred before code)
- n_code_before_target: float/int (# count where code occurred before target)

NOTES ON TEMPORAL COUNTS:
For a row with target_concept_code = A and concept_code = B, the columns
`n_code_before_target` and `n_target_before_code` are interpreted as:
- n_code_before_target: # of persons where **B happened before A**
- n_target_before_code: # of persons where **A happened before B**
We aggregate these across j≠i to compute *temporal scores* for each event.

PIECEWISE, SAMPLE-SIZE–ANCHORED ADJUSTMENTS:
- For odds terms that would be 0 or ∞ due to zero cells, we replace the
offending odds with 1/(N+1) or (N+1)/1, where N is the size of the
appropriate stratum, so all ratios remain finite and interpretable.

RETURN VALUE:
A dict with temporal ordering, T-values, λ-vectors, D-values, coefficients
for a logistic link, a `predict_proba` callable, and trace tables.

"""

def analyze_causal_sequence_py(
    data: Union[str, pd.DataFrame],
    name_map: Union[Dict[str, str], None] = None,
    events: Union[List[str], None] = None,
    force_outcome=None,
) -> Dict[str, Any]:
    """Compute temporal order, dependent-Bayes direct effects (D), and
    a logistic-link probability for outcome Y from pairwise counts.

    Parameters
    ----------
    data : str | DataFrame
        CSV path or in-memory DataFrame (long format, one row per directed
        pair). Required columns: `target_concept_code`, `concept_code`,
        `n_code_target`, `n_code_no_target`, `n_target`, `n_no_target`,
        `n_target_before_code`, `n_code_before_target`. Optional columns:
        `no_code_no_target` (derived when absent), `n_target_no_code`,
        `total_effect`. A DataFrame argument is copied, never mutated.
    name_map : Dict[str, str] | None, default None
        Optional mapping raw code -> friendly label. If provided (non-empty),
        both `target_concept_code` and `concept_code` are replaced.
    events : List[str] | None, default None
        Event names/codes to restrict the analysis to. If `None`, events are
        auto-detected as those appearing on both sides of the edge table
        (falling back to the union when that intersection is empty).
    force_outcome : str | None
        If provided and found among events, this event is forced to be the
        **final** node (i.e., the outcome) in the temporal order. Otherwise
        the top-scoring event is used as the outcome.

    Returns
    -------
    Dict[str, Any]
        - sorted_scores : pd.Series of temporal scores (desc)
        - temporal_order: list of events (outcome at the end)
        - order_used    : same as temporal_order
        - T_val         : pd.Series of total effects T_{k,Y}
        - D_val         : pd.Series of direct effects D_{k,Y}
        - coef_df       : pd.DataFrame of coefficients (β_k and intercept)
        - lambda_l      : dict[str -> pd.Series] of λ_{k,j} vectors
        - trace_df      : pd.DataFrame detailing the backward recursion steps
        - invalid_predictors: list of predictors whose log(D) was invalid
        - beta_0, beta, logit_predictors, predict_proba: logistic elements

    Raises
    ------
    ValueError
        If required columns are missing, fewer than two events remain, or
        the outcome never appears as `target_concept_code`.
    """

    # ---------------------------------------------------------------------
    # 0) INGEST & BASIC VALIDATION
    # ---------------------------------------------------------------------
    if isinstance(data, str):
        # Read from CSV path
        df = pd.read_csv(data)
    else:
        # Work on a copy to avoid mutating caller's object
        df = data.copy()

    # Ensure required identifier columns are present
    for col in ["target_concept_code", "concept_code"]:
        if col not in df.columns:
            raise ValueError(f"Missing required column: {col}")

    # Optional recoding to human-friendly labels
    if name_map:
        df["target_concept_code"] = df["target_concept_code"].replace(name_map)
        df["concept_code"] = df["concept_code"].replace(name_map)

    # Numeric columns the algorithm expects
    need = [
        "n_code_target", "n_code_no_target", "n_target", "n_no_target",
        "n_target_before_code", "n_code_before_target",
    ]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # A precomputed total_effect column is coerced to numeric when present.
    # NOTE: it is currently NOT consumed below; λ is always computed from counts.
    if "total_effect" in df.columns:
        df["total_effect"] = pd.to_numeric(df["total_effect"], errors="coerce")

    # ---------------------------------------------------------------------
    # 1) EVENT SET (optional auto-detect) & TYPE COERCION
    # ---------------------------------------------------------------------
    if events is None:
        # Auto-detect: intersect events appearing on both sides of edges
        ev_targets = df["target_concept_code"].dropna().astype(str).unique().tolist()
        ev_code = df["concept_code"].dropna().astype(str).unique().tolist()
        events = sorted(set(ev_targets).intersection(ev_code))
        if len(events) == 0:
            # Fall back to union if intersection is empty
            events = sorted(set(ev_targets) | set(ev_code))
    if len(events) < 2:
        raise ValueError("Need at least two events after auto-detection.")

    # Keep only rows whose endpoints are both in the chosen event set
    df = df[df["target_concept_code"].isin(events) & df["concept_code"].isin(events)].copy()

    # Coerce numeric columns robustly (invalid -> NaN); subsequent ops handle NaNs
    for c in need:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # If `no_code_no_target` missing, derive as (n_no_target - n_code_no_target) ≥ 0
    if "no_code_no_target" not in df.columns:
        df["no_code_no_target"] = (df["n_no_target"] - df["n_code_no_target"]).clip(lower=0)
    else:
        df["no_code_no_target"] = pd.to_numeric(df["no_code_no_target"], errors="coerce").clip(lower=0)

    # `df` is read-only from here on. Index row positions by (target, code)
    # once, so each pair lookup below is a dict hit instead of a full-table
    # boolean scan (the loops below perform O(|events|^2) lookups).
    no_rows = np.array([], dtype=np.intp)
    pair_positions = df.groupby(
        ["target_concept_code", "concept_code"], sort=False
    ).indices

    def rows_for(target, code) -> pd.DataFrame:
        """Rows whose edge is (target_concept_code=target, concept_code=code)."""
        return df.iloc[pair_positions.get((target, code), no_rows)]

    # Helper: total count for an event (max n_target where that event is target).
    # Memoized because the λ loop asks for the same event repeatedly.
    c_cache: Dict[Any, float] = {}

    def C_of(ev: str) -> float:
        if ev not in c_cache:
            sub = df.loc[df["target_concept_code"] == ev, "n_target"]
            C = float("nan") if sub.empty else pd.to_numeric(sub, errors="coerce").max()
            c_cache[ev] = float(C) if pd.notna(C) and np.isfinite(C) else float("nan")
        return c_cache[ev]

    # Helper: numerically safe log (returns 0.0 for invalid/≤0 input)
    def safe_log(x: float) -> float:
        try:
            xv = float(x)
        except (TypeError, ValueError):
            return 0.0
        if not np.isfinite(xv) or xv <= 0.0:
            return 0.0
        return math.log(xv)

    # ---------------------------------------------------------------------
    # 2) TEMPORAL ORDER — pairwise-before counts → per-node score
    # ---------------------------------------------------------------------
    # Score(Z_i) = Σ_{j≠i} [ C(Z_i≪Z_j) - C(Z_j≪Z_i) + C(Z_i ∩ ¬Z_j) - C(Z_j ∩ ¬Z_i) ]
    # Only the *before/after* portion is implemented here, from the provided
    # columns.
    scores: Dict[str, float] = {}
    for zk in events:
        s = 0.0
        for zj in events:
            if zj == zk:
                continue
            # For the pair (zj, zk), we interpret:
            #   n_code_before_target   — # where zk (code) before zj (target)
            #   n_target_before_code   — # where zj (target) before zk (code)
            rowr = rows_for(zj, zk)
            if not rowr.empty:
                s += float(rowr["n_code_before_target"].sum(skipna=True) -
                           rowr["n_target_before_code"].sum(skipna=True))
        scores[zk] = s

    sorted_scores = pd.Series(scores).sort_values(ascending=False)

    # Outcome selection:
    #  - If `force_outcome` is provided and present, put it at the end.
    #  - Else, default to the top-scoring node as outcome.
    # NOTE(review): a high score means the event tends to occur *before*
    # others, so the default picks the earliest event as outcome — confirm
    # this is intended; pass `force_outcome` to avoid relying on it.
    if force_outcome and (force_outcome in sorted_scores.index):
        outcome_event = force_outcome
    else:
        outcome_event = sorted_scores.index[0]
    temporal_order = [ev for ev in sorted_scores.index if ev != outcome_event] + [outcome_event]

    # Propagation order is the temporal order; last is outcome Y
    events_order = temporal_order
    position = {ev: i for i, ev in enumerate(events_order)}
    outcome = events_order[-1]
    antecedents = events_order[:-1]

    # ---------------------------------------------------------------------
    # 3) T-VALUES (TOTAL EFFECTS) & λ-LINKS BETWEEN ANTECEDENTS
    # ---------------------------------------------------------------------
    # For each antecedent k, we compute T_{k,Y} as an odds ratio between
    # strata Z_k=1 and Z_k=0 with sample-size–anchored fixes for zero cells.
    # λ_{k,j} (dependence of j given k) is computed from counts with similar
    # piecewise corrections and/or normalized co-occurrence.

    T_val = pd.Series(0.0, index=antecedents, dtype=float)  # T_{k,Y}
    D_val = pd.Series(np.nan, index=antecedents, dtype=float)  # D_{k,Y}
    lambda_l: Dict[str, pd.Series] = {}  # per-k vector of λ_{k,j} for j after k and before Y

    has_c_col = "n_target_no_code" in df.columns

    for k in antecedents:
        # ---- Contingency for (k, outcome) ----
        row_ko = rows_for(k, outcome)

        # a: Y∩k    b: ¬Y∩k    c: Y∩¬k    d: ¬Y∩¬k
        a = float(row_ko["n_code_target"].sum(skipna=True))            # co-occurrence in Y
        b = float(row_ko["n_code_no_target"].sum(skipna=True))         # co-occurrence in ¬Y

        # If n_target_no_code absent, approximate c by (max n_target - a)
        if has_c_col:
            c = float(row_ko["n_target_no_code"].sum(skipna=True))
        else:
            c = float(pd.to_numeric(row_ko["n_target"], errors="coerce").max() - a)

        # `no_code_no_target` always exists at this point (derived in step 1)
        d = float(row_ko["no_code_no_target"].sum(skipna=True))

        N1, N0 = a + b, c + d  # stratum sizes for Z_k=1 and Z_k=0

        # ---- Sample-size–anchored odds for Z_k=1 ----
        if a == 0:
            odds_pos_adj = 1.0 / (N1 + 1.0)  # 0 → tiny positive odds
        elif b == 0:
            odds_pos_adj = (N1 + 1.0) / 1.0  # ∞ → capped by (N1+1)
        else:
            odds_pos_adj = a / b

        # ---- Sample-size–anchored odds for Z_k=0 ----
        if c == 0:
            odds_neg_adj = 1.0 / (N0 + 1.0)
        elif d == 0:
            odds_neg_adj = (N0 + 1.0) / 1.0
        else:
            odds_neg_adj = c / d

        # Total effect T_{k,Y}
        T_val.loc[k] = float(odds_pos_adj / odds_neg_adj)

        # ---- λ_{k,j}: dependence of j on k for nodes j after k (and before Y) ----
        pos_k = position[k]
        js = events_order[pos_k + 1 : -1] if pos_k < len(events_order) - 1 else []

        lam_pairs = []
        for j in js:
            row_kj = rows_for(k, j)
            if row_kj.empty:
                lam_pairs.append((j, 0.0))  # no evidence of dependence
                continue

            # Approximate λ with piecewise, size-anchored logic
            # C11 = C(j∩k); Cj_not_k = C(j∩¬k); Ck = C(k)
            C11 = float(row_kj["n_code_target"].sum(skipna=True))  # re-using the same column name for co-occurrence
            # `n_code_no_target` is a required column, so it is always present
            Cj_not_k = float(row_kj["n_code_no_target"].sum(skipna=True))
            Ck = C_of(k)

            if Cj_not_k == 0:
                L = 1.0 + C11          # always-with-k → boost
            elif C11 == 0:
                L = 1.0 / (1.0 + Cj_not_k)  # never-with-k → downweight
            elif np.isfinite(Ck) and Ck > 0:
                L = C11 / Ck           # normalized co-occurrence
            else:
                L = 0.0

            lam_pairs.append((j, float(L)))

        lambda_l[k] = pd.Series({j: v for j, v in lam_pairs}, dtype=float)

    # ---------------------------------------------------------------------
    # 4) BACKWARD RECURSION — resolve D_{k,Y} from T and λ
    # ---------------------------------------------------------------------
    # D_{k,Y} = ( T_{k,Y} - Σ_i λ_{k,k+i} * D_{k+i,Y} ) / ( 1 - Σ_i λ_{k,k+i} )
    # Start at the last antecedent (just before Y): D := T, since there are no
    # downstream nodes to adjust for.

    trace_rows = []  # for human-auditable tracing of the recursion

    last_anc = antecedents[-1] if antecedents else None
    if last_anc is not None:
        D_val.loc[last_anc] = T_val.loc[last_anc]
        trace_rows.append({
            "stage": "Last 2 Nodes",
            "nodes": f"{last_anc} - {outcome}",
            "k": last_anc,
            "T_kY": T_val.loc[last_anc],
            "lambda_terms": None,
            "sum_lambda": 0.0,
            "D_kY": D_val.loc[last_anc],
            "log_D": safe_log(D_val.loc[last_anc]),
        })

    # Walk backwards through the remaining antecedents (if any)
    for k in reversed(antecedents[:-1]):
        lam_vec = lambda_l.get(k, pd.Series(dtype=float))
        # When computing the adjustment, use **only** D-values for nodes that
        # are after k and already resolved (present in lam_vec index).
        code = list(lam_vec.index)
        sum_lambda = float(np.nansum(lam_vec.values))
        num = T_val.loc[k] - float(np.nansum(lam_vec.values * D_val.reindex(code).values))
        den = 1.0 - sum_lambda  # may approach 0 if λ's are large

        # Fall back to T (neutralization) only when the ratio is not finite
        # (den == 0 or NaN inputs). A negative den still yields a finite,
        # possibly negative D; such D's get β = 0 in step 5.
        # NOTE(review): confirm whether den < 0 should also fall back to T.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = np.float64(num) / np.float64(den)
        D_val.loc[k] = ratio if np.isfinite(ratio) else T_val.loc[k]

        # Number of nodes from k through the outcome, inclusive
        span = len(events_order) - position[k]
        lam_str = ", ".join(
            f"λ_{position[k]+1}{position[c]+1}={lam_vec[c]:.6f}"
            for c in code
        ) if len(lam_vec) else None

        trace_rows.append({
            "stage": f"Last {span} Nodes",
            "nodes": " - ".join([k] + events_order[position[k]+1:]),
            "k": k,
            "T_kY": T_val.loc[k],
            "lambda_terms": lam_str,
            "sum_lambda": sum_lambda,
            "D_kY": D_val.loc[k],
            "log_D": safe_log(D_val.loc[k]),
        })

    trace_df = pd.DataFrame(trace_rows)

    # ---------------------------------------------------------------------
    # 5) COEFFICIENTS — map D's onto a logistic link
    # ---------------------------------------------------------------------
    # We model:  logit P(Y=1 | Z) = β0 + Σ_k β_k * Z_k,  with β_k = log D_{k,Y}
    # Intercept β0 is set by the marginal prevalence of Y (from `n_target` &
    # `n_no_target`) for the outcome rows.

    resp_rows = df[df["target_concept_code"] == outcome]
    if resp_rows.empty:
        raise ValueError(f"No rows for outcome '{outcome}'.")

    # First non-missing totals among the outcome's rows
    n_t_vals = resp_rows["n_target"].dropna()
    n_n_vals = resp_rows["n_no_target"].dropna()
    n_t = n_t_vals.iloc[0] if n_t_vals.size else np.nan
    n_n = n_n_vals.iloc[0] if n_n_vals.size else np.nan
    denom = n_t + n_n
    p_y = 0.5 if (not np.isfinite(denom) or denom <= 0) else (n_t / denom)
    beta_0 = float(np.log(p_y / (1 - p_y)))  # ±inf when p_y is exactly 0 or 1

    # β_k = log(D_{k,Y}); protect against non-positive D by mapping to 0
    D_clean = pd.to_numeric(D_val, errors="coerce").astype(float)
    D_pos = D_clean.where(D_clean > 0)

    with np.errstate(divide="ignore", invalid="ignore"):
        beta_vals = np.log(D_pos.to_numpy())
    beta_k_raw = pd.Series(beta_vals, index=D_val.index)
    invalid_predictors = list(beta_k_raw[~np.isfinite(beta_k_raw)].index)

    beta_k = beta_k_raw.copy()
    beta_k[~np.isfinite(beta_k)] = 0.0  # neutralize invalid predictors

    coef_df = pd.DataFrame({
        "predictor": list(beta_k.index) + ["(intercept)"],
        "beta": list(beta_k.astype(float).values) + [beta_0],
    })

    # ---------------------------------------------------------------------
    # 6) PREDICT_PROBA — vectorized logistic link
    # ---------------------------------------------------------------------
    predictors = list(beta_k.index)
    beta_vec = beta_k.astype(float).values

    def _sigmoid(eta):
        # Clipping keeps exp() finite for extreme η
        return 1.0 / (1.0 + np.exp(-np.clip(eta, -700, 700)))

    def predict_proba(z: Union[Dict[str, Any], pd.Series, np.ndarray, List[float], pd.DataFrame]) -> Union[float, np.ndarray, pd.Series]:
        """Compute P(Y=1 | Z) using the logistic link.

        Accepts:
          - dict/Series mapping predictor name -> 0/1 (missing names count as 0)
          - 1D/2D numpy/list with columns ordered as `predictors`
          - DataFrame with columns containing any/all of `predictors` (others ignored)

        Returns:
          - float for dict/Series/1D inputs; np.ndarray for DataFrame/2D inputs

        Raises:
          - ValueError if an array input is not 1D/2D or has the wrong width
        """
        if isinstance(z, pd.DataFrame):
            Z = z.reindex(columns=predictors, fill_value=0).astype(float).to_numpy()
            return _sigmoid(beta_0 + Z @ beta_vec)

        if isinstance(z, (dict, pd.Series)):
            v = np.array([float(z.get(p, 0.0)) for p in predictors], dtype=float)
            return float(_sigmoid(beta_0 + float(v @ beta_vec)))

        arr = np.asarray(z, dtype=float)
        if arr.ndim == 1:
            if arr.size != len(predictors):
                raise ValueError(f"Expected {len(predictors)} features in order: {predictors}")
            return float(_sigmoid(beta_0 + float(arr @ beta_vec)))
        if arr.ndim != 2 or arr.shape[1] != len(predictors):
            raise ValueError(f"Expected shape (*, {len(predictors)}), got {arr.shape}")
        return _sigmoid(beta_0 + arr @ beta_vec)

    # ---------------------------------------------------------------------
    # 7) PACKAGE RESULTS
    # ---------------------------------------------------------------------
    return {
        "sorted_scores": sorted_scores,
        "temporal_order": temporal_order,
        "order_used": events_order,
        "T_val": T_val,
        "D_val": D_val,
        "coef_df": coef_df,
        "lambda_l": lambda_l,
        "trace_df": trace_df,
        "invalid_predictors": invalid_predictors,
        # Logistic link outputs:
        "beta_0": beta_0,
        "beta": pd.Series(beta_vec, index=predictors, dtype=float),
        "logit_predictors": predictors,
        "predict_proba": predict_proba,
    }

# ========= helpers =========
def _ensure_derived_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure derived columns exist and are numeric; compute `total_effect` if missing.

    Purpose
    -------
    This normalizes the input long-format edge table so downstream MAGI steps
    can rely on consistent columns and numeric types.

    What it does
    ------------
    1) Coerces the count columns to numeric; invalid/missing values become 0.
       A count column that is absent altogether is created as 0.
    2) Derives `no_code_no_target` if absent, as max(n_no_target - n_code_no_target, 0),
       and `n_target_no_code` if absent, as max(n_target - n_code_target, 0).
    3) Computes `total_effect` (TE) for every row where it is absent or not
       numeric: an odds ratio for (target_concept_code -> concept_code) with
       sample-size–anchored replacements for zero cells. Existing numeric TE
       values are kept as they are.

    Earlier revisions created the optional columns as all-zero before checking
    for their absence, so steps 2 and 3 never ran; the optional columns are
    therefore handled separately from the required counts here.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format table, normally with
        ['target_concept_code','concept_code','n_code_target','n_code_no_target',
         'n_target','n_no_target','n_target_before_code','n_code_before_target'].

    Returns
    -------
    pd.DataFrame
        A copy of `df` with coerced numeric columns and the columns
        `no_code_no_target`, `n_target_no_code` and `total_effect` present.
    """
    df = df.copy()

    # -- 1) Coerce to numeric (robust to garbage / missing strings)
    count_cols = [
        "n_code_target", "n_code_no_target", "n_target", "n_no_target",
        "n_target_before_code", "n_code_before_target",
    ]
    for c in count_cols:
        if c not in df.columns:
            df[c] = 0
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    # Optional cells: coerce only when supplied, so absence stays detectable
    for c in ("no_code_no_target", "n_target_no_code"):
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    # 2x2 cells per row; residual cells are clipped at 0 against data noise
    a = df["n_code_target"].astype(float)                      # k=1, T=1
    b = df["n_code_no_target"].astype(float)                   # k=1, T=0
    c_ = (df["n_target"].astype(float) - a).clip(lower=0.0)     # k=0, T=1
    d = (df["n_no_target"].astype(float) - b).clip(lower=0.0)  # k=0, T=0

    # -- 2) Derive the residual cells if absent
    if "no_code_no_target" not in df.columns:
        df["no_code_no_target"] = d
    if "n_target_no_code" not in df.columns:
        df["n_target_no_code"] = c_

    # -- 3) Total_effect (OR with sample-size–anchored smoothing)
    N1, N0 = (a + b), (c_ + d)
    with np.errstate(divide="ignore", invalid="ignore"):
        # zero denominator -> (N+1)/1; zero numerator -> 1/(N+1)
        odds_pos = np.where((a > 0) & (b > 0), a / b,
                            np.where(b == 0, N1 + 1.0, 1.0 / (N1 + 1.0)))
        odds_neg = np.where((c_ > 0) & (d > 0), c_ / d,
                            np.where(d == 0, N0 + 1.0, 1.0 / (N0 + 1.0)))
        computed_te = pd.Series(odds_pos / odds_neg, index=df.index, dtype=float)

    if "total_effect" in df.columns:
        supplied = pd.to_numeric(df["total_effect"], errors="coerce")
        df["total_effect"] = supplied.fillna(computed_te)
    else:
        df["total_effect"] = computed_te

    return df

# TOPK_HI/LO
def _select_by_risk_prot(k_to_T: pd.DataFrame, top_k: int = 500,
                                   hi: float = 1.5) -> pd.DataFrame:
    """
    Pick a balanced set of predictors by total effect (TE), with no back-filling.

    Rows with a non-numeric TE are discarded. Each remaining row gets an
    `effect_strength` (TE when TE >= 1, else 1/TE) and only the strongest row
    per `target_concept_code` is retained. From those:
      - risk side: up to top_k//2 rows with TE >= hi
      - protective side: up to top_k - top_k//2 rows with TE <= 1/hi
    A short side is not topped up from the other, so fewer than top_k rows
    may come back. Risk rows precede protective rows; each side is ordered by
    descending strength. If no row has a numeric TE, the empty frame is
    returned before `effect_strength` is added.
    """
    cand = k_to_T.copy()

    # Non-numeric effects become NaN and are dropped
    cand["total_effect"] = pd.to_numeric(cand["total_effect"], errors="coerce")
    cand = cand.dropna(subset=["total_effect"]).copy()
    if cand.empty:
        return cand

    # Strength is symmetric around TE == 1
    cand["effect_strength"] = np.where(cand["total_effect"] >= 1.0,
                                       cand["total_effect"],
                                       1.0 / cand["total_effect"])

    # Collapse to the single most extreme row for each k
    by_strength = cand.sort_values("effect_strength", ascending=False)
    per_k = by_strength.drop_duplicates(subset=["target_concept_code"], keep="first")

    upper = hi
    lower = 1.0 / hi
    risk_side = per_k[per_k["total_effect"] >= upper].copy()
    prot_side = per_k[per_k["total_effect"] <= lower].copy()

    # Quotas: half (rounded down) for risk, the remainder for protective
    risk_quota = top_k // 2
    prot_quota = top_k - risk_quota

    picks = [
        risk_side.nlargest(min(len(risk_side), risk_quota), "effect_strength"),
        prot_side.nlargest(min(len(prot_side), prot_quota), "effect_strength"),
    ]
    combined = pd.concat(picks, ignore_index=True)

    return combined.reset_index(drop=True)

def _fetch_k_to_T(conn, target_code: str) -> pd.DataFrame:
    """Load every edge row whose right-hand `concept_code` equals `target_code`.

    The count table is joined twice against `concept_names` to turn the two
    integer keys into codes, exposed as `target_concept_code` (k, the
    predictor) and `concept_code` (T, the outcome) alongside all of the
    table's own columns. `target_code` is passed as a bound parameter.
    """
    return pd.read_sql_query(
        sql="""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- k (predictor)
             ccn.concept_code AS concept_code           -- T (outcome)
      FROM magi_counts_top500 m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE ccn.concept_code = ?
    """,
        con=conn,
        params=[target_code],
    )

def _fetch_subgraph_by_targets(conn, events_list):
    """Load every edge row whose left-hand `target_concept_code` is in `events_list`.

    One `?` placeholder is generated per event and the whole list is bound to
    a single IN clause. The right-hand `concept_code` is not filtered here.
    NOTE(review): SQLite caps the number of bound parameters per statement
    (the limit depends on the build), so a very long `events_list` would need
    batching — confirm against the expected list sizes.
    """
    # "?" * n is n question marks; joining its characters yields "?,?,...,?"
    placeholders = ",".join("?" * len(events_list))
    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,
             ccn.concept_code AS concept_code
      FROM magi_counts_top500 m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE tcn.concept_code IN ({placeholders})
    """
    bound = list(events_list)
    return pd.read_sql_query(sql, conn, params=bound)

# ========= main loop =========
# Driver: for each target T, pick predictors from the k→T rows, pull the
# subgraph among {T} ∪ predictors, run MAGI on it and write two CSVs to
# OUT_DIR (the subgraph, and the non-zero coefficients).
# MAGI_DB_PATH, TARGETS and OUT_DIR are expected to be defined earlier.
uri = f"file:{MAGI_DB_PATH}?mode=ro"  # read-only SQLite URI

# Make sure the output folder exists before any CSV is written
if OUT_DIR:
    os.makedirs(OUT_DIR, exist_ok=True)

# sqlite3's `with conn:` only scopes a transaction and leaves the connection
# open, so close it explicitly once the loop is done (or fails).
conn = sqlite3.connect(uri, uri=True)
try:
    for T in TARGETS:

        # 1) k→T rows (concept_code == T)
        k_to_T = _fetch_k_to_T(conn, T)

        # --- drop exact duplicate rows (all columns identical) ---
        before = len(k_to_T)
        k_to_T = k_to_T.drop_duplicates()
        after = len(k_to_T)
        if after < before:
            print(f"[DEDUP] k_to_T exact-row duplicates removed: {before - after}  (kept {after})")

        if k_to_T.empty:
            continue

        # 2) derive totals/effects on the k→T list
        k_to_T = _ensure_derived_cols(k_to_T)

        # 3) select predictors
        sel_rows = _select_by_risk_prot(k_to_T, top_k=500)
        if sel_rows.empty:
            continue

        # IMPORTANT: predictors are the LEFT side (target_concept_code)
        selected_k = set(sel_rows["target_concept_code"].astype(str))

        # 4) build subgraph among {T} ∪ selected_k
        events_set = selected_k | {T}
        df_trim = _fetch_subgraph_by_targets(conn, sorted(events_set))
        # the fetch filters the left side only; also restrict the right side
        df_trim = df_trim[
            df_trim["target_concept_code"].isin(events_set) &
            df_trim["concept_code"].isin(events_set)
        ].copy()

        # --- drop exact duplicate rows (all columns identical) ---
        before = len(df_trim)
        df_trim = df_trim.drop_duplicates()
        after = len(df_trim)
        if after < before:
            print(f"[DEDUP] subgraph exact-row duplicates removed: {before - after}  (kept {after})")

        # Recompute derived cols (incl. total_effect) on the subgraph
        df_trim = _ensure_derived_cols(df_trim)

        #   k→T rows: concept_code == T
        #   T→j rows: target_concept_code == T

        # 5) save subgraph (for audit)
        sub_csv = os.path.join(OUT_DIR, f"magi_subgraph_{T}.csv")
        df_trim.to_csv(sub_csv, index=False)

        # 6) run MAGI
        try:
            res = analyze_causal_sequence_py(df_trim, events=None, name_map=None, force_outcome=T)
        except TypeError:
            # Best-effort fallback for a signature without `force_outcome`.
            # `events`/`name_map` are still passed: they are required by the
            # signature, so omitting them would raise TypeError again.
            # NOTE(review): this also catches a TypeError raised inside the
            # analysis itself — consider narrowing.
            res = analyze_causal_sequence_py(df_trim, events=None, name_map=None)
            outcome_used = res.get("order_used", [None])[-1]
            if outcome_used != T:
                print(f"[NOTE] outcome auto-inferred as {outcome_used}, not {T}")

        # 7) save NON-ZERO coefficients only
        outcome_used = res.get("order_used", [T])[-1]
        coef_df = res["coef_df"].copy()

        # robust: accept 'coef' or 'beta' column name
        coef_col = "coef" if "coef" in coef_df.columns else ("beta" if "beta" in coef_df.columns else None)
        if coef_col is None:
            raise KeyError("Coefficient column not found in coef_df (expected 'coef' or 'beta').")

        # filter to non-zero (tolerance to avoid tiny numerical noise)
        eps = 1e-12
        mask_nz = coef_df[coef_col].astype(float).abs() > eps
        coef_nz = coef_df.loc[mask_nz].copy()

        coef_csv = os.path.join(OUT_DIR, f"magi_coef_{outcome_used}_nonzero.csv")
        coef_nz.to_csv(coef_csv, index=False)
finally:
    conn.close()