# In [None]:
import sqlite3
import math
from typing import Dict, List, Union, Any, Optional, Set, Tuple
import pandas as pd
import numpy as np
import os, time, re
from collections import defaultdict, deque

# ===== CONFIG =====
# Path of the SQLite database holding the MAGI edge counts. It is opened
# read-only (see `uri` below); override the default via the MAGI_DB_PATH env var.
MAGI_DB_PATH = os.getenv("MAGI_DB_PATH", "/projects/klybarge/pcori_ad/magi/magi_db/magi.db")  # read-only; override via env
# Directory receiving every CSV written by this script; created at import time.
OUT_DIR = "./Mesothelioma"
os.makedirs(OUT_DIR, exist_ok=True)
# Name of the edge table interpolated into the SQL of the _fetch_* helpers.
EDGE_TABLE = "magi_counts_top500"   
# SQLite URI form of MAGI_DB_PATH; `mode=ro` requests a read-only connection.
uri = f"file:{MAGI_DB_PATH}?mode=ro"

# Default cap on the number of predictors kept by _select_top_by_te_unique_k.
TOP_K = 500  # 500 for others

# Outcome codes processed by the main loop, one full run per entry.
TARGETS = [
    "dx_SNOMED_254645002",
]

# Updated MAGI AUC=0.7172
# Row orientation used throughout the file:
# T_{k,Y}  (target=Y, concept=k).
# λ_{k,j}  (target=j, concept=k).
# ===== SNOMED parents+descendants support =====
# typeId value that build_is_a_snapshot keeps (the is-a relationship type).
IS_A = "116680003"
# characteristicTypeId values selectable through build_is_a_snapshot's
# `characteristic` argument.
CHAR_TYPES = {
    "inferred": "900000000000011006",  # classification hierarchy
    "stated":   "900000000000010007",
}

# Tab-separated RF2 "Full" relationship file read by build_is_a_snapshot.
SNOMED_REL_FULL_US = "/projects/klybarge/pcori_ad/magi/Test/Test/Mesothelioma/sct2_Relationship_Full_US1000124_20250901.txt"

def build_is_a_snapshot(full_rel_path: str, characteristic: str = "inferred") -> pd.DataFrame:
    """From a Full RF2 file, return the *current* active IS-A rows.

    Args:
        full_rel_path: Path of a tab-separated RF2 relationship file that
            contains the columns listed in ``use_cols`` below.
        characteristic: Key into ``CHAR_TYPES`` selecting which
            characteristicTypeId to keep ("inferred" or "stated").

    Returns:
        DataFrame with columns ``sourceId`` and ``destinationId``: for every
        relationship ``id`` the row with the largest ``effectiveTime`` is
        taken, and it is kept only when that row has ``active == "1"``.

    Raises:
        KeyError: If ``characteristic`` is not a key of ``CHAR_TYPES``.
    """
    use_cols = [
        "id", "effectiveTime", "active", "moduleId",
        "sourceId", "destinationId", "relationshipGroup",
        "typeId", "characteristicTypeId", "modifierId",
    ]
    df = pd.read_csv(full_rel_path, sep="\t", dtype=str, usecols=use_cols)

    is_wanted = (df["typeId"] == IS_A) & (df["characteristicTypeId"] == CHAR_TYPES[characteristic])
    # Explicit copy: a column is added below, and assigning into a filtered
    # slice of the original frame triggers pandas' SettingWithCopyWarning.
    df = df.loc[is_wanted].copy()

    # effectiveTime is read as text; compare versions numerically.
    df["effectiveTime_num"] = df["effectiveTime"].astype(int)
    latest_idx = df.groupby("id")["effectiveTime_num"].idxmax()
    snap = df.loc[latest_idx]

    # A relationship whose most recent version is inactive has been retired.
    snap = snap.loc[snap["active"] == "1", ["sourceId", "destinationId"]].reset_index(drop=True)
    return snap

# Lazily-built caches for the is-a graph: the relationship file is only read
# the first time one of the graph helpers calls _ensure_graph().
__SNAP_REL__ = None   # DataFrame returned by build_is_a_snapshot
__P2C__ = None        # parent id -> set of direct child ids
__C2P__ = None        # child id  -> set of direct parent ids

def _ensure_graph():
    """Populate the module-level snapshot and adjacency maps if they are unset."""
    global __SNAP_REL__, __P2C__, __C2P__

    if __SNAP_REL__ is None:
        __SNAP_REL__ = build_is_a_snapshot(SNOMED_REL_FULL_US, characteristic="inferred")

    if __P2C__ is not None and __C2P__ is not None:
        return  # both adjacency maps already built

    __P2C__ = defaultdict(set)
    __C2P__ = defaultdict(set)
    # RF2: destinationId = parent, sourceId = child
    parents = __SNAP_REL__["destinationId"]
    children = __SNAP_REL__["sourceId"]
    for parent_id, child_id in zip(parents, children):
        __P2C__[parent_id].add(child_id)
        __C2P__[child_id].add(parent_id)

def find_descendants_sct(concept_id: str) -> set:
    """Return every transitive is-a descendant of ``concept_id``.

    Walks the parent→children map breadth-first; ``concept_id`` itself is not
    added unless it is reachable from one of its own descendants.
    """
    _ensure_graph()
    reached = set()
    frontier = deque([concept_id])
    while frontier:
        node = frontier.popleft()
        # .get() (not []) so the defaultdict is not grown by lookups.
        unseen = [kid for kid in __P2C__.get(node, ()) if kid not in reached]
        reached.update(unseen)
        frontier.extend(unseen)
    return reached

def find_ancestors_sct(concept_id: str) -> set:
    """Return every transitive is-a ancestor of ``concept_id``.

    Walks the child→parents map breadth-first; ``concept_id`` itself is not
    added unless it is reachable from one of its own ancestors.
    """
    _ensure_graph()
    reached = set()
    frontier = deque([concept_id])
    while frontier:
        node = frontier.popleft()
        # .get() (not []) so the defaultdict is not grown by lookups.
        unseen = [mom for mom in __C2P__.get(node, ()) if mom not in reached]
        reached.update(unseen)
        frontier.extend(unseen)
    return reached

def extract_snomed_id(code: str) -> Optional[str]:
    """Return the digit string of a 'dx_SNOMED_<digits>' code, or None otherwise.

    The input is passed through ``str()`` first and the whole string must match.
    """
    match = re.fullmatch(r"dx_SNOMED_(\d+)", str(code))
    if match is None:
        return None
    return match.group(1)

def snomed_aliases_for_outcome(outcome_code: str, *, include_parents: bool = True) -> Tuple[Optional[Set[str]], Dict[str, str]]:
    """
    Build the SNOMED "alias family" of an outcome such as 'dx_SNOMED_254645002'.

    Returns a pair ``(aliases_codes, name_map)``:
      - aliases_codes: set of 'dx_SNOMED_<id>' codes for the root, all of its
        descendants and (when include_parents=True) all of its ancestors;
        None when ``outcome_code`` is not a 'dx_SNOMED_<digits>' code.
      - name_map: every alias code other than ``outcome_code`` itself, mapped
        to ``outcome_code``; empty when the outcome is not SNOMED.
    """
    root_id = extract_snomed_id(outcome_code)
    if root_id is None:
        return None, {}

    # Root first, then the hierarchy below it and (optionally) above it.
    family_ids = {root_id}
    family_ids.update(find_descendants_sct(root_id))
    if include_parents:
        family_ids.update(find_ancestors_sct(root_id))

    aliases_codes: Set[str] = {f"dx_SNOMED_{sid}" for sid in family_ids}

    # The canonical code is left out so it never maps onto itself.
    name_map: Dict[str, str] = {
        alias: outcome_code
        for alias in aliases_codes
        if alias != outcome_code
    }
    return aliases_codes, name_map


def _anchored_odds(hits: float, misses: float) -> float:
    """Return hits/misses, anchored to the stratum size when a cell is empty.

    An empty stratum gives neutral odds of 1; a zero numerator gives
    1/(N+1) and a zero denominator gives N+1, where N = hits + misses.
    """
    n = hits + misses
    if n == 0:
        return 1.0
    if hits == 0:
        return 1.0 / (n + 1.0)
    if misses == 0:
        return n + 1.0
    return hits / misses


def analyze_causal_sequence_py(
    df_in: Union[str, pd.DataFrame],
    *,
    name_map: Dict[str, str] = None,     # raw_code -> label (applied to target_concept_code only)
    events: List[str] = None,            # event names to KEEP (AFTER recoding). If None: auto-detect
    force_outcome: str = None,           # if provided and present, force this to be the FINAL node (Y)
    lambda_min_count: int = 15           # L-threshold for λ: if n_code < L ⇒ λ_{k,j}=0
) -> Dict[str, Any]:
    """
    MAGI (Python): Reference routine with explicit comments.

    ─────────────────────────────────────────────────────────────────────────────
    COLUMN CONVENTION PER ROW:
      Left  column: target_concept_code  (call this X for this row)
      Right column: concept_code         (call this k for this row)

      n_code_target        ≡  k ∧ X
      n_code_no_target     ≡  k ∧ ¬X
      n_target_no_code     ≡  ¬k ∧ X
      n_no_target          ≡  total(¬X)
      n_target             ≡  total(X)
      n_code               ≡  total(k)
      n_code_before_target ≡  count(k before X)
      n_target_before_code ≡  count(X before k)

    ORIENTATION (locked to your spec):
      • Total effect T_{kY}:   read row (target = Y, code = k).
      • Lambda     λ_{k,j}:    read row (target = j, code = k), and compute
                               λ_{k,j} = n_code_target / n_code  with L-threshold on n_code.

    TEMPORAL SCORE for each node Zi:
      Score(Z_i) = Σ_{j≠i} [ C(Z_i≺Z_j) - C(Z_j≺Z_i) + C(Z_i∧¬Z_j) - C(Z_j∧¬Z_i) ]
      Read from row (target=Z_i, code=Z_j):
        - n_code_before_target      → C(Z_j≺Z_i)
        - n_target_before_code      → C(Z_i≺Z_j)
        - n_code_no_target          → C(Z_j∧¬Z_i)
        - n_target_no_code          → C(Z_i∧¬Z_j)

    T_{kY} from row (Y, k):
      a = n_code_target        (k ∧ Y)
      b = n_code_no_target     (k ∧ ¬Y)
      c = n_target_no_code     (¬k ∧ Y)
      d = n_no_target - b      (¬k ∧ ¬Y)   ← computed on the fly (no extra column needed)
      With sample-size–anchored odds:
         odds_k1 = a/b with guards; odds_k0 = c/d with guards; T = odds_k1 / odds_k0

    DIRECT EFFECTS via backward recursion:
      D_{k,Y} = ( T_{k,Y} - Σ_j λ_{k,j} D_{j,Y} ) / ( 1 - Σ_j λ_{k,j} ),
      where j are downstream nodes between k and Y in the temporal order.

    LOGISTIC LINK:
      logit P(Y=1 | Z) = β0 + Σ_k β_k Z_k  with β_k = log D_{k,Y};
      invalid/nonpositive D map to β_k=0.

    ARGUMENTS:
      df_in            CSV path or DataFrame holding the columns listed above.
      name_map         Optional recoding applied to target_concept_code ONLY
                       (concept_code is left untouched), before event filtering.
      events           Event names to keep; duplicates are ignored. When None,
                       the union of both code columns is used.
      force_outcome    Event forced to be the final node Y. When it is missing
                       from the events, the top-scoring node is used instead.
      lambda_min_count λ_{k,j} is set to 0 when n_code is below this value.

    Both code columns are compared as strings; rows with a missing code in
    either column are dropped. If several rows share one (target, code) pair
    their counts are summed (n_no_target: maximum).

    RAISES ValueError when a required column is missing or fewer than two
    events remain.

    RETURNS a dict with:
      sorted_scores, temporal_order, order_used,
      T_val, D_val, lambda_l, coef_df, beta_0, beta, logit_predictors, predict_proba
    """
    tgt_col, code_col = "target_concept_code", "concept_code"

    # ── 0) Ingest & validate ───────────────────────────────────────────────────
    df = pd.read_csv(df_in) if isinstance(df_in, str) else df_in.copy()

    need_cols = [
        tgt_col, code_col,
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target",
        "n_target_no_code",
        "n_code",                                 # for λ denominator
        "n_code_before_target", "n_target_before_code",
    ]
    missing = [c for c in need_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # Optional recoding of the LEFT (target) endpoint; done before the string
    # cast so that name_map keys are matched against the raw values.
    if name_map:
        df[tgt_col] = df[tgt_col].replace(name_map)

    # A row without both endpoints cannot describe a pair. After dropping
    # those, compare codes as strings in BOTH the auto-detect and the explicit
    # branch; otherwise numeric codes never match the (string) event names and
    # every row is silently filtered out.
    df = df.dropna(subset=[tgt_col, code_col]).copy()
    df[tgt_col] = df[tgt_col].astype(str)
    df[code_col] = df[code_col].astype(str)

    if events is None:
        events = sorted(set(df[tgt_col]) | set(df[code_col]))
    else:
        # De-duplicate (order kept): a repeated name would be double-counted
        # in the temporal score below.
        events = list(dict.fromkeys(str(e) for e in events))
    if len(events) < 2:
        raise ValueError("Need at least two events.")

    df = df[df[tgt_col].isin(events) & df[code_col].isin(events)].copy()

    # Coerce numerics; NA→0 to allow safe sums/max
    num_cols = [
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target", "n_target_no_code",
        "n_code", "n_code_before_target", "n_target_before_code",
    ]
    for c in num_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    # One lookup table keyed by (target, code) replaces a full-frame scan for
    # every pair. Counts are summed per pair; n_no_target is a total, so max.
    if df.empty:
        pair_stats: Dict[Tuple[str, str], Dict[str, float]] = {}
    else:
        agg_spec = {
            "n_code_target": "sum",
            "n_code_no_target": "sum",
            "n_target_no_code": "sum",
            "n_code": "sum",
            "n_code_before_target": "sum",
            "n_target_before_code": "sum",
            "n_no_target": "max",
        }
        pair_stats = (
            df.groupby([tgt_col, code_col], sort=False)
              .agg(agg_spec)
              .to_dict("index")
        )

    # ── 1) Temporal score (read rows: target=Z_i, code=Z_j) ────────────────────
    scores: Dict[str, float] = {}
    for zi in events:
        s = 0.0
        for zj in events:
            if zj == zi:
                continue
            st = pair_stats.get((zi, zj))      # row oriented as (Zi, Zj)
            if st is None:
                continue
            c_i_before_j = float(st["n_target_before_code"])   # Zi before Zj
            c_j_before_i = float(st["n_code_before_target"])   # Zj before Zi
            c_i_and_not_j = float(st["n_target_no_code"])      # Zi ∧ ¬Zj
            c_j_and_not_i = float(st["n_code_no_target"])      # Zj ∧ ¬Zi
            s += (c_i_before_j - c_j_before_i + c_i_and_not_j - c_j_and_not_i)
        scores[zi] = s

    sorted_scores = pd.Series(scores).sort_values(ascending=False)

    # Choose outcome Y: either forced (compared as a string, like the codes)
    # or the top-scoring node.
    forced = None if force_outcome is None else str(force_outcome)
    if forced and (forced in sorted_scores.index):
        outcome = forced
    else:
        outcome = sorted_scores.index[0]
    temporal_order = [ev for ev in sorted_scores.index if ev != outcome] + [outcome]

    events_order = temporal_order              # earliest … → Y
    nodes = events_order[:-1]                  # everything before Y

    # ── 2) T and λ (row orientations locked) ───────────────────────────────────
    T_val = pd.Series(0.0, index=nodes, dtype=float)
    D_val = pd.Series(np.nan, index=nodes, dtype=float)
    lambda_l: Dict[str, pd.Series] = {}

    for pos_k, k in enumerate(nodes):
        # ---- T_{kY} from the pair (target=Y, code=k) ----
        st = pair_stats.get((outcome, k))
        if st is None:
            a = b = c = n_noY = 0.0
        else:
            a = float(st["n_code_target"])        # k ∧ Y
            b = float(st["n_code_no_target"])     # k ∧ ¬Y
            c = float(st["n_target_no_code"])     # ¬k ∧ Y
            n_noY = float(st["n_no_target"])      # total(¬Y)
        # d is not in data; compute from the same row: d = total(¬Y) − (k ∧ ¬Y)
        d = max(n_noY - b, 0.0)                   # ¬k ∧ ¬Y

        N1 = a + b                                # stratum size for k = 1
        odds_k1 = _anchored_odds(a, b)
        odds_k0 = _anchored_odds(c, d)
        T_val.loc[k] = float(odds_k1 / odds_k0) if odds_k0 > 0 else (N1 + 1.0)

        # ---- λ_{k,j} from pairs (target=j, code=k): λ = n_code_target / n_code ----
        # j ranges over the nodes strictly between k and Y in the order.
        lam: Dict[str, float] = {}
        for j in nodes[pos_k + 1:]:
            st_jk = pair_stats.get((j, k))
            if st_jk is None:
                lam[j] = 0.0
                continue

            num = float(st_jk["n_code_target"])
            den = float(st_jk["n_code"])

            # L-threshold on the conditioning size n_code (count of k)
            if (den <= 0) or (den < lambda_min_count):
                lam[j] = 0.0
                continue

            ratio = num / den
            lam[j] = 0.0 if not np.isfinite(ratio) else float(min(max(ratio, 0.0), 1.0))

        lambda_l[k] = pd.Series(lam, dtype=float)

    # ── 3) Backward recursion for direct effects D ─────────────────────────────
    # Last node (just before Y): no downstream → D = T
    if nodes:
        last_anc = nodes[-1]
        D_val.loc[last_anc] = T_val.loc[last_anc]

    # Walk backward for the rest
    for k in reversed(nodes[:-1]):
        lam_vec = lambda_l.get(k, pd.Series(dtype=float))
        downstream = list(lam_vec.index)  # j nodes after k (already resolved)
        lam_vals = lam_vec.reindex(downstream).fillna(0.0).to_numpy()
        D_down = pd.to_numeric(D_val.reindex(downstream), errors="coerce").fillna(0.0).to_numpy()

        num = T_val.loc[k] - float(np.nansum(lam_vals * D_down))
        den = 1.0 - float(np.nansum(lam_vals))

        if (not np.isfinite(den)) or den == 0.0:
            D_val.loc[k] = T_val.loc[k]            # neutralize if pathological
        else:
            tmp = num / den
            D_val.loc[k] = tmp if np.isfinite(tmp) else T_val.loc[k]

    # ── 4) Logistic link (β) and predict_proba ─────────────────────────────────
    # Intercept β0 from marginal prevalence of Y (rows with target == Y)
    resp_rows = df[df[tgt_col] == outcome]
    n_t = float(resp_rows["n_target"].max(skipna=True)) if not resp_rows.empty else np.nan
    n_n = float(resp_rows["n_no_target"].max(skipna=True)) if not resp_rows.empty else np.nan
    denom = n_t + n_n
    p_y = 0.5 if (not np.isfinite(denom) or denom <= 0) else (n_t / denom)
    p_y = min(max(p_y, 1e-12), 1 - 1e-12)      # keep the logit finite
    beta_0 = float(np.log(p_y / (1 - p_y)))

    # β_k = log D_{k,Y}; invalid/nonpositive → 0
    D_clean = pd.to_numeric(D_val, errors="coerce").astype(float)
    beta_vals = np.log(D_clean.where(D_clean > 0.0)).replace([np.inf, -np.inf], np.nan).fillna(0.0)

    coef_df = pd.DataFrame({
        "predictor": list(beta_vals.index) + ["(intercept)"],
        "beta":      list(beta_vals.values) + [beta_0],
    })

    # Vectorized predict_proba
    predictors = list(beta_vals.index)
    beta_vec = beta_vals.values

    def predict_proba(Z: Union[Dict[str, Any], pd.Series, np.ndarray, List[float], pd.DataFrame]):
        """
        Compute P(Y=1|Z) using: logit P = β0 + Σ_k β_k Z_k.
        Z can be:
          - dict/Series mapping predictor name -> 0/1
          - 1D/2D numpy/list with columns ordered as `predictors`
          - DataFrame containing any/all of `predictors` (others ignored)
        """
        def sigmoid(x):
            x = np.clip(x, -700, 700)  # numerical stability for large |η|
            return 1.0 / (1.0 + np.exp(-x))

        if isinstance(Z, pd.DataFrame):
            M = Z.reindex(columns=predictors, fill_value=0.0).astype(float).to_numpy()
            return sigmoid(beta_0 + M @ beta_vec)

        if isinstance(Z, (dict, pd.Series)):
            v = np.array([float(Z.get(p, 0.0)) for p in predictors], dtype=float)
            return float(sigmoid(beta_0 + float(v @ beta_vec)))

        arr = np.asarray(Z, dtype=float)
        if arr.ndim == 1:
            if arr.size != len(predictors):
                raise ValueError(f"Expected {len(predictors)} features in order: {predictors}")
            return float(sigmoid(beta_0 + float(arr @ beta_vec)))
        if arr.ndim == 2:
            if arr.shape[1] != len(predictors):
                raise ValueError(f"Expected shape (*,{len(predictors)}), got {arr.shape}")
            return sigmoid(beta_0 + arr @ beta_vec)

        raise ValueError("Unsupported input for predict_proba")

    # ── 5) Package results ─────────────────────────────────────────────────────
    return {
        "sorted_scores": sorted_scores,
        "temporal_order": temporal_order,
        "order_used": events_order,
        "T_val": T_val,
        "D_val": D_val,
        "lambda_l": lambda_l,
        "coef_df": coef_df,
        "beta_0": beta_0,
        "beta": pd.Series(beta_vec, index=predictors, dtype=float),
        "logit_predictors": predictors,
        "predict_proba": predict_proba,
    }

def _ensure_derived_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Only coerce numerics and compute total_effect if absent, using provided cells and
    computing d = n_no_target - n_code_no_target on the fly (since d isn't in data).
    """
    required = [
        "n_code_target", "n_code_no_target",
        "n_target_no_code", "n_target", "n_no_target"
    ]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns for TE: {', '.join(missing)}")

    out = df.copy()
    for c in required:
        out[c] = pd.to_numeric(out[c], errors="coerce").fillna(0.0)

    if "total_effect" not in out.columns:
        a = out["n_code_target"].astype(float)
        b = out["n_code_no_target"].astype(float)
        c = out["n_target_no_code"].astype(float)
        d = (out["n_no_target"] - out["n_code_no_target"]).astype(float)
        d = np.maximum(d, 0.0)

        N1 = a + b
        N0 = c + d

        with np.errstate(divide="ignore", invalid="ignore"):
            odds_k1 = np.where(
                N1 == 0, 1.0,
                np.where((a > 0) & (b > 0), a / b,
                         np.where(b == 0, N1 + 1.0, 1.0 / (N1 + 1.0)))
            )
            odds_k0 = np.where(
                N0 == 0, 1.0,
                np.where((c > 0) & (d > 0), c / d,
                         np.where(d == 0, N0 + 1.0, 1.0 / (N0 + 1.0)))
            )
            te = odds_k1 / odds_k0

        out["total_effect"] = np.where(np.isfinite(te), te, 1.0).astype(float)

    return out


# TOPK_HI/LO
def _select_top_by_te_unique_k(k_to_T: pd.DataFrame, top_k: int = TOP_K,
                               hi: float = 1.5) -> pd.DataFrame:
    """
    Pick a balanced set of predictors by total effect (TE), with no back-filling:
      - risk side:       up to top_k//2 rows with TE >= hi
      - protective side: up to top_k - top_k//2 rows with TE <= 1/hi
      - a short side is NOT topped up from the other, so fewer than top_k rows
        may come back
      - one row per k (**concept_code**), the most extreme one; rows are ranked
        by ``effect_strength`` = TE if TE >= 1 else 1/TE
    Rows whose total_effect is not numeric are discarded first.
    """
    ranked = k_to_T.copy()

    # Non-numeric effects become NaN and are dropped.
    ranked["total_effect"] = pd.to_numeric(ranked["total_effect"], errors="coerce")
    ranked = ranked.dropna(subset=["total_effect"]).copy()
    if ranked.empty:
        print("[SELECT] no rows with total_effect; returning empty selection.")
        return ranked

    # Distance from the neutral effect 1, made symmetric for protective effects.
    te = ranked["total_effect"]
    ranked["effect_strength"] = np.where(te >= 1.0, te, 1.0 / te)

    # Keep only the strongest row of each k (k = concept_code in (target=Y, code=k)).
    by_strength = ranked.sort_values("effect_strength", ascending=False)
    best_per_k = by_strength.drop_duplicates(subset=["concept_code"], keep="first")

    upper_cut = hi
    lower_cut = 1.0 / hi
    risk_pool = best_per_k[best_per_k["total_effect"] >= upper_cut].copy()
    prot_pool = best_per_k[best_per_k["total_effect"] <= lower_cut].copy()

    risk_quota = top_k // 2
    prot_quota = top_k - risk_quota

    sel_risk = risk_pool.nlargest(min(len(risk_pool), risk_quota), "effect_strength")
    sel_prot = prot_pool.nlargest(min(len(prot_pool), prot_quota), "effect_strength")

    selected = pd.concat([sel_risk, sel_prot], ignore_index=True)

    print(f"[SELECT] total unique k={best_per_k['concept_code'].nunique():,}  "
          f"risk={len(risk_pool):,}  prot={len(prot_pool):,}  "
          f"selected(total)={selected['concept_code'].nunique()}  "
          f"with: risk={len(sel_risk)}, prot={len(sel_prot)}")

    return selected.reset_index(drop=True)


def _fetch_k_to_T(conn, outcome_code: str) -> pd.DataFrame:
    """
    Fetch every EDGE_TABLE row whose LEFT (target) code equals ``outcome_code``,
    with both integer code columns resolved to text through ``concept_names``.
    """
    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,
             ccn.concept_code AS concept_code
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE tcn.concept_code = ?
    """
    bind_values = [outcome_code]
    return pd.read_sql_query(sql, conn, params=bind_values)

def _fetch_k_to_T_in(conn, target_codes: List[str]) -> pd.DataFrame:
    """
    Fetch the EDGE_TABLE rows whose LEFT (target) code is one of ``target_codes``.

    The codes are loaded into a TEMP table (dropped and recreated on each
    call) that the query joins against, so alias outcomes can appear on the LEFT.
    """
    temp_table = "tmp_targets_magi"
    code_rows = [(code,) for code in target_codes]

    # `with conn` wraps the temp-table setup in a single transaction.
    with conn:
        conn.execute(f"DROP TABLE IF EXISTS {temp_table}")
        conn.execute(f"CREATE TEMP TABLE {temp_table}(concept_code TEXT)")
        conn.executemany(f"INSERT INTO {temp_table}(concept_code) VALUES (?)", code_rows)

    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT  (Y or Y-alias)
             ccn.concept_code AS concept_code           -- RIGHT (k)
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      JOIN {temp_table} tmp     ON tcn.concept_code         = tmp.concept_code   -- <<< LEFT filter
    """
    return pd.read_sql_query(sql, conn)

def _fetch_subgraph_by_targets(conn, events_list):
    """
    Fetch the EDGE_TABLE rows whose LEFT (target) code is in ``events_list``.

    Only the LEFT side is filtered here; one bound parameter is used per code.
    """
    bind_values = list(events_list)
    ph = ",".join("?" * len(events_list))
    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT
             ccn.concept_code AS concept_code           -- RIGHT
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE tcn.concept_code IN ({ph})
    """
    return pd.read_sql_query(sql, conn, params=bind_values)

def expand_targets_with_descendants(targets: List[str]) -> Dict[str, Set[str]]:
    """
    Map each target to its alias family (root + descendants + ancestors).

    A target that is not a 'dx_SNOMED_<digits>' code maps to a set holding
    only itself.
    """
    expanded: Dict[str, Set[str]] = {}
    for target in targets:
        alias_codes, _ = snomed_aliases_for_outcome(target, include_parents=True)
        expanded[target] = {target} if alias_codes is None else alias_codes
    return expanded

def print_full_targets(targets: List[str], *, save_csv: bool = True, out_dir: str = "./"):
    """
    Pretty-print the expanded target definitions and (optionally) save a CSV with all rows:
      root_target, alias_code, alias_conceptId, is_root

    Args:
        targets: Root outcome codes to expand.
        save_csv: When True, write ``targets_expanded_<timestamp>.csv`` into
            ``out_dir`` (created if needed).
        out_dir: Directory receiving the CSV.

    Returns:
        The dict produced by ``expand_targets_with_descendants(targets)``.
    """
    expanded = expand_targets_with_descendants(targets)

    # pretty print to console
    for T, alias_set in expanded.items():
        n = len(alias_set)
        sample = ", ".join(sorted(alias_set)[:10])
        print("=" * 100)
        print(f"[TARGET] {T}")
        print(f"  aliases (parents + descendants + root): {n}")
        print(f"  sample: {sample}{' ...' if n > 10 else ''}")

    # optional CSV dump
    if save_csv:
        rows = []
        for T, alias_set in expanded.items():
            for alias in sorted(alias_set):
                rows.append({
                    "root_target": T,
                    "alias_code": alias,
                    # empty string when the alias is not a dx_SNOMED_<digits> code
                    "alias_conceptId": extract_snomed_id(alias) or "",
                    "is_root": (alias == T),
                })
        df = pd.DataFrame(rows, columns=["root_target", "alias_code", "alias_conceptId", "is_root"])
        ts = time.strftime("%Y%m%d-%H%M%S")
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, f"targets_expanded_{ts}.csv")
        df.to_csv(path, index=False)
        print(f"\n[SAVED] Full target expansions → {path}")

    return expanded

# ========= main loop (remove dup) =========
def _run_target(conn, T: str) -> None:
    """Run the MAGI pipeline for one root target ``T`` and write its CSVs to OUT_DIR.

    Steps: fetch the k→T rows, audit alias-family predictors, select the
    top-K predictors, build and save the subgraph, run
    ``analyze_causal_sequence_py`` with ``T`` forced as outcome, and save the
    non-zero coefficients with alias-family predictors removed. Returns early
    (after printing a message) when there is nothing to analyze or MAGI fails.
    """
    print("\n" + "=" * 100)
    print(f"[RUN] Target = {T}")

    # --- Identify alias family (parents + descendants + root) for filtering only
    #     These are NOT used as targets anywhere; they are detected and later
    #     dropped if they appear as predictors.
    aliases_codes, _name_map_alias = snomed_aliases_for_outcome(T, include_parents=True)
    alias_family = set(aliases_codes) if aliases_codes else set()
    if alias_family:
        print(f"[ALIASES] Identified {len(alias_family):,} alias codes for {T} (parents+descendants+root).")
    else:
        print("[ALIASES] Non-SNOMED or no aliases detected; proceeding without alias filtering.")

    # 1) k→T rows: **ONLY the root target T** on the LEFT (no parents/descendants as target)
    k_to_T = _fetch_k_to_T(conn, T)

    # --- Optional: audit if any alias codes show up as RIGHT-side predictors (concept_code)
    if aliases_codes and not k_to_T.empty:
        present_mask = k_to_T["concept_code"].isin(alias_family)
        k_to_T_alias = k_to_T.loc[present_mask].copy()
        if not k_to_T_alias.empty:
            alias_break = (
                k_to_T_alias
                .groupby("concept_code", as_index=False)
                .agg(
                    rows=("concept_code", "size"),
                    n_code_target=("n_code_target", "sum"),
                    n_code_no_target=("n_code_no_target", "sum"),
                    n_target=("n_target", "sum"),
                    n_no_target=("n_no_target", "sum"),
                )
                .sort_values("rows", ascending=False)
            )
            alias_csv = os.path.join(OUT_DIR, f"alias_predictors_present_{T}.csv")
            alias_break.to_csv(alias_csv, index=False)
            print(f"[AUDIT] Alias-family predictors present among k→T rows → {alias_csv}")
        else:
            print("[AUDIT] No alias-family predictors present among k→T rows.")

    # --- Drop exact duplicate rows (all columns identical)
    before = len(k_to_T)
    k_to_T = k_to_T.drop_duplicates()
    after = len(k_to_T)
    if after < before:
        print(f"[DEDUP] k→T exact-row duplicates removed: {before - after}  (kept {after})")

    if k_to_T.empty:
        print(f"[WARN] No k→T rows for {T}; skipping.")
        return

    # 2) derive totals/effects on the k→T list
    k_to_T = _ensure_derived_cols(k_to_T)

    # 3) select predictors: top-K by total_effect (one row per k).  (No alias-target usage.)
    sel_rows = _select_top_by_te_unique_k(k_to_T, top_k=TOP_K)
    if sel_rows.empty:
        print(f"[WARN] No predictors selected for {T}; skipping.")
        return

    selected_k = set(sel_rows["concept_code"].astype(str))  # predictors are on RIGHT
    print(f"[SELECT] unique k available={k_to_T['concept_code'].nunique():,}  "
        f"selected={len(selected_k):,}")

    # 4) build subgraph among {T} ∪ selected_k
    #    The query filters the LEFT side; the RIGHT side is restricted here.
    #    (**Do not** expand RIGHT with alias family; only the chosen predictors and T are analyzed.)
    events_set = selected_k | {T}
    df_trim = _fetch_subgraph_by_targets(conn, sorted(events_set))
    df_trim = df_trim[
        df_trim["target_concept_code"].isin(events_set) &
        df_trim["concept_code"].isin(events_set)
    ].copy()

    # --- Drop exact duplicate rows
    before = len(df_trim)
    df_trim = df_trim.drop_duplicates()
    after = len(df_trim)
    if after < before:
        print(f"[DEDUP] subgraph exact-row duplicates removed: {before - after}  (kept {after})")

    # Coerce counts and add total_effect on the subgraph if it is not there yet
    df_trim = _ensure_derived_cols(df_trim)

    # Reporting
    k_to_T_count = (df_trim["concept_code"] == T).sum()
    T_to_j_count = (df_trim["target_concept_code"] == T).sum()
    print(f"[TRIM] rows={len(df_trim):,}  events={len(events_set)}  "
        f"k→T rows={k_to_T_count}  T→j rows={T_to_j_count}")

    # 5) save subgraph (for audit)
    sub_csv = os.path.join(OUT_DIR, f"magi_subgraph_{T}.csv")
    df_trim.to_csv(sub_csv, index=False)
    print(f"[SAVED] Subgraph → {sub_csv}")

    # 6) run MAGI — analyze with **only** the canonical root outcome T
    try:
        # Do NOT pass name_map; do NOT canonicalize aliases; no alias-target expansion.
        res = analyze_causal_sequence_py(df_trim, events=None, name_map=None, force_outcome=T)
    except Exception as e:
        # Deliberate best-effort: report and move on to the next target.
        print("[ERROR] MAGI failed:", e)
        return

    # 7) save NON-ZERO coefficients, then REMOVE any parent/descendant predictors
    outcome_used = res.get("order_used", [T])[-1]
    coef_df = res["coef_df"].copy()

    # robust: accept 'coef' or 'beta' column name
    coef_col = "coef" if "coef" in coef_df.columns else ("beta" if "beta" in coef_df.columns else None)
    if coef_col is None:
        raise KeyError("Coefficient column not found in coef_df (expected 'coef' or 'beta').")

    # filter to non-zero (tolerance to avoid tiny numerical noise)
    eps = 1e-12
    mask_nz = coef_df[coef_col].astype(float).abs() > eps
    coef_nz = coef_df.loc[mask_nz].copy()

    # Alias filtering needs the 'predictor' column; without it the
    # coefficients are saved unfiltered.
    pred_col = "predictor" if "predictor" in coef_nz.columns else None

    # Drop alias-family predictors (parents/descendants of T), always keep intercept row
    dropped_alias = 0
    if pred_col and alias_family:
        pred_names = coef_nz[pred_col].astype(str)
        is_intercept = pred_names == "(intercept)"
        is_alias_pred = pred_names.isin(alias_family)
        dropped_alias = int(is_alias_pred.sum())
        coef_nz = coef_nz[is_intercept | (~is_alias_pred)].copy()

    coef_csv = os.path.join(OUT_DIR, f"magi_coef_{outcome_used}_nonzero.csv")
    coef_nz.to_csv(coef_csv, index=False)

    # NOTE(review): analyze_causal_sequence_py does not return a
    # 'used_total_effect' key, so the default (True) is what gets printed.
    print(f"[SAVED] Coefficients (non-zero only, alias-family predictors removed={dropped_alias}) → {coef_csv}  "
        f"| kept={len(coef_nz):,} of {len(coef_df):,}  "
        f"| nodes={len(res.get('order_used', [])) - 1}  "
        f"| used_total_effect={res.get('used_total_effect', True)}")


# ========= main loop (remove dup) =========
if __name__ == '__main__':
    from contextlib import closing

    # ---- Runtime safety checks ----
    uri = f"file:{MAGI_DB_PATH}?mode=ro"
    if not os.path.exists(MAGI_DB_PATH):
        raise FileNotFoundError(f"MAGI_DB_PATH not found: {MAGI_DB_PATH}. Set MAGI_DB_PATH to a valid SQLite DB or run with CSV input.")

    # sqlite3's own context manager only commits/rolls back; closing() is what
    # actually releases the connection, and the target loop must run inside it.
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        # ---- Save expanded targets list ----
        expanded_targets = print_full_targets(TARGETS, save_csv=True, out_dir=OUT_DIR)
        for T in TARGETS:
            _run_target(conn, T)
