1. Builds the expanded outcome family = {target SNOMED} ∪ its SNOMED descendants (and, with the default `include_parents=True`, its ancestors) ∪ every source concept that "Maps to" any member of that set.

2. Excludes candidate predictors that are in that family, map to that family, or have descendants intersecting that family.

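3. Runs the sibling test: if a predictor maps to the same standard concept as the target → drop; otherwise compute Jaccard(k,Y) = a/(a+b+c) and PPV(Y|k) = a/(a+b) from the (target=Y, code=k) row, where a = n_code_target, b = n_code_no_target, c = n_target_no_code, and drop k if they meet or exceed the thresholds (defaults: Jaccard ≥ 0.88, PPV ≥ 0.90; tweakable).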

4. Adds an optional name-based safety net with strict disease-name patterns.

In [None]:
import sqlite3
import math
from typing import Dict, List, Union, Any, Optional, Set, Tuple
import pandas as pd
import numpy as np
import os, time, re
from collections import defaultdict, deque

# ===== CONFIG =====
# Every path below can be overridden through the environment variable named in
# its os.getenv call; the defaults are the original hard-coded locations.
MAGI_DB_PATH = os.getenv("MAGI_DB_PATH", "/projects/klybarge/pcori_ad/magi/magi_db/magi.db")  # read-only; override via env
OUT_DIR = os.getenv("MAGI_OUT_DIR", "./BreastCancer")
os.makedirs(OUT_DIR, exist_ok=True)  # NOTE: creates the output directory as soon as this cell runs
EDGE_TABLE = "magi_counts_top500"    # pairwise edge-count table read by the _fetch_* helpers
uri = f"file:{MAGI_DB_PATH}?mode=ro"  # SQLite URI filename that opens the MAGI DB read-only

TOP_K = 500  # default top_k of _select_top_by_te_unique_k (original note: 500 for others)

# OMOP/Athena vocabulary DB; the OMOP helpers attach it under the schema name `omop`.
OMOP_DB_PATH = os.getenv("OMOP_DB_PATH", "./omop_athena.db")

TARGETS = [
    "dx_SNOMED_254837009",
]

# Updated 20251101 MAGI AUC=0.7172
# T_{k,Y}  (target=Y, concept=k).
# λ_{k,j}  (target=j, concept=k).
# ===== SNOMED parents+descendants support =====
IS_A = "116680003"  # typeId value that selects IS-A rows in the RF2 relationship file
CHAR_TYPES = {
    "inferred": "900000000000011006",  # classification hierarchy
    "stated":   "900000000000010007",
}

# RF2 *Full* relationship file; only read lazily, on the first graph lookup.
SNOMED_REL_FULL_US = os.getenv(
    "SNOMED_REL_FULL_US",
    "/projects/klybarge/pcori_ad/magi/Test/Test/RareCancer/sct2_Relationship_Full_US1000124_20250901.txt",
)

def build_is_a_snapshot(full_rel_path: str, characteristic: str = "inferred") -> pd.DataFrame:
    """From a Full RF2 file, return the *current* active IS-A rows.

    A Full release keeps every historical version of each relationship row.
    This keeps rows whose ``typeId`` is ``IS_A`` and whose ``characteristicTypeId``
    is ``CHAR_TYPES[characteristic]``. It then takes the latest version (max
    ``effectiveTime``) of each relationship ``id`` and keeps it only if that
    version has ``active == "1"``.

    Args:
        full_rel_path: Tab-separated RF2 relationship file (all columns read as str).
        characteristic: Key into ``CHAR_TYPES`` ("inferred" or "stated"); an
            unknown key raises ``KeyError``.

    Returns:
        DataFrame with str columns ``sourceId`` (child) and ``destinationId``
        (parent), one row per active IS-A relationship, with a fresh RangeIndex.
    """
    # Only the columns actually used; skipping moduleId/relationshipGroup/modifierId
    # noticeably lowers memory on a multi-million-row Full file.
    use_cols = [
        "id", "effectiveTime", "active",
        "sourceId", "destinationId",
        "typeId", "characteristicTypeId",
    ]
    rel = pd.read_csv(full_rel_path, sep="\t", dtype=str, usecols=use_cols)
    keep = (rel["typeId"] == IS_A) & (rel["characteristicTypeId"] == CHAR_TYPES[characteristic])
    # .copy(): we add a column next, and writing into a filtered view would raise
    # SettingWithCopyWarning (and is not guaranteed to stick).
    rel = rel.loc[keep].copy()
    rel["effectiveTime_num"] = rel["effectiveTime"].astype(int)
    # Latest version of each relationship id (first row wins on exact ties).
    latest = rel.loc[rel.groupby("id")["effectiveTime_num"].idxmax()]
    active = latest[latest["active"] == "1"]
    return active[["sourceId", "destinationId"]].reset_index(drop=True)

# Lazily built SNOMED IS-A graph. Nothing is read when this cell runs, so the
# script still loads if the RF2 file is missing; the first lookup builds it.
__SNAP_REL__ = None  # DataFrame of active IS-A rows (sourceId, destinationId)
__P2C__ = None       # parent conceptId -> set of child conceptIds
__C2P__ = None       # child conceptId  -> set of parent conceptIds

def _ensure_graph():
    """Populate the module-level IS-A snapshot and both adjacency maps on first use.

    Reads ``SNOMED_REL_FULL_US`` (inferred characteristic) via
    ``build_is_a_snapshot`` only while ``__SNAP_REL__`` is still None, and
    rebuilds the maps only while either of them is still None.
    """
    global __SNAP_REL__, __P2C__, __C2P__
    if __SNAP_REL__ is None:
        __SNAP_REL__ = build_is_a_snapshot(SNOMED_REL_FULL_US, characteristic="inferred")
    if __P2C__ is not None and __C2P__ is not None:
        return
    __P2C__ = defaultdict(set)
    __C2P__ = defaultdict(set)
    down, up = __P2C__, __C2P__
    # RF2 convention: destinationId is the parent, sourceId is the child.
    edge_pairs = zip(__SNAP_REL__["destinationId"], __SNAP_REL__["sourceId"])
    for parent_id, child_id in edge_pairs:
        down[parent_id].add(child_id)
        up[child_id].add(parent_id)

def _bfs_closure(start: str, adjacency) -> set:
    """Breadth-first transitive closure of ``start`` over ``adjacency`` (node -> iterable of neighbours).

    ``start`` itself is only included if it is reachable from itself (a cycle).
    """
    seen = set()
    queue = deque([start])
    while queue:
        node = queue.popleft()
        # .get() so probing a defaultdict never inserts empty entries
        for nxt in adjacency.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return seen

def find_descendants_sct(concept_id: str) -> set:
    """All is-a descendants (children, grandchildren, ...) for a SNOMED conceptId.

    The graph is keyed by str conceptIds, so ``concept_id`` is coerced with
    ``str()`` (an int id previously matched nothing and returned an empty set).
    """
    _ensure_graph()
    return _bfs_closure(str(concept_id), __P2C__)

def find_ancestors_sct(concept_id: str) -> set:
    """All is-a ancestors (parents, grandparents, ...) for a SNOMED conceptId.

    ``concept_id`` is coerced with ``str()`` to match the str-keyed graph.
    """
    _ensure_graph()
    return _bfs_closure(str(concept_id), __C2P__)

def extract_snomed_id(code: str) -> Optional[str]:
    """Return the digits of a ``'dx_SNOMED_<digits>'`` MAGI code, or None.

    The input is passed through ``str()`` first and must match the whole
    pattern (no surrounding whitespace or suffixes).
    """
    match = re.fullmatch(r"dx_SNOMED_(\d+)", str(code))
    if match is None:
        return None
    return match.group(1)

def snomed_aliases_for_outcome(outcome_code: str, *, include_parents: bool = True) -> Tuple[Optional[Set[str]], Dict[str, str]]:
    """
    Expand a SNOMED outcome code into its alias family.

    For ``outcome_code`` like 'dx_SNOMED_254645002' returns ``(aliases_codes, name_map)``:
      - aliases_codes: {'dx_SNOMED_<id>', ...} for the root, all of its is-a
        descendants and, when ``include_parents`` is True, all of its ancestors.
      - name_map: alias_code -> outcome_code for every alias EXCEPT the
        outcome_code string itself.
    Non-SNOMED codes yield ``(None, {})``.
    """
    root_id = extract_snomed_id(outcome_code)
    if root_id is None:
        return None, {}

    family_ids: Set[str] = {root_id}
    family_ids.update(find_descendants_sct(root_id))
    if include_parents:
        family_ids.update(find_ancestors_sct(root_id))

    aliases_codes: Set[str] = {"dx_SNOMED_" + sid for sid in family_ids}
    name_map: Dict[str, str] = {}
    for alias in aliases_codes:
        if alias == outcome_code:
            continue  # the canonical code maps to itself implicitly
        name_map[alias] = outcome_code
    return aliases_codes, name_map


## Update 10 23 25
def analyze_causal_sequence_py(
    df_in: Union[str, pd.DataFrame],
    *,
    name_map: Dict[str, str] = None,     # raw_code -> friendly label (applied to target_concept_code only)
    events: List[str] = None,            # event names to KEEP (AFTER recoding). If None: auto-detect
    force_outcome: str = None,           # if provided and present, force this to be the FINAL node (Y)
    lambda_min_count: int = 15           # L-threshold for λ: if n_code < L ⇒ λ_{k,j}=0
) -> Dict[str, Any]:
    """
    MAGI (Python): Reference routine with explicit comments.

    ─────────────────────────────────────────────────────────────────────────────
    COLUMN CONVENTION PER ROW:
      Left  column: target_concept_code  (call this X for this row)
      Right column: concept_code         (call this k for this row)

      n_code_target        ≡  k ∧ X
      n_code_no_target     ≡  k ∧ ¬X
      n_target_no_code     ≡  ¬k ∧ X
      n_no_target          ≡  total(¬X)
      n_target             ≡  total(X)
      n_code               ≡  total(k)
      n_code_before_target ≡  count(k before X)
      n_target_before_code ≡  count(X before k)

    INPUT NORMALIZATION:
      • name_map recodes the LEFT column (target_concept_code) only, so alias
        outcomes collapse onto Y; predictor codes on the right are untouched.
      • Both code columns are then cast to str in every branch, so numeric codes
        work with auto-detected events and with force_outcome.
      • Explicit `events` are stringified and de-duplicated (first occurrence kept).

    ORIENTATION (locked to your spec):
      • Total effect T_{kY}:   read row (target = Y, code = k).
      • Lambda     λ_{k,j}:    read row (target = j, code = k), and compute
                               λ_{k,j} = n_code_target / n_code  with L-threshold on n_code.

    TEMPORAL SCORE for each node Zi:
      Score(Z_i) = Σ_{j≠i} [ C(Z_i≺Z_j) - C(Z_j≺Z_i) + C(Z_i∧¬Z_j) - C(Z_j∧¬Z_i) ]
      Read from row (target=Z_i, code=Z_j):
        - n_code_before_target      → C(Z_j≺Z_i)
        - n_target_before_code      → C(Z_i≺Z_j)
        - n_code_no_target          → C(Z_j∧¬Z_i)
        - n_target_no_code          → C(Z_i∧¬Z_j)
      A higher score means Z_i tends to come earlier; nodes are ordered by
      descending score (earliest first).

    OUTCOME Y:
      force_outcome when it is among the events; otherwise the LOWEST-scoring
      (latest) node. A force_outcome that is not found triggers a warning.

    T_{kY} from row (Y, k):
      a = n_code_target        (k ∧ Y)
      b = n_code_no_target     (k ∧ ¬Y)
      c = n_target_no_code     (¬k ∧ Y)
      d = n_no_target - b      (¬k ∧ ¬Y)   ← computed on the fly (no extra column needed)
      With sample-size–anchored odds:
         odds_k1 = a/b with guards; odds_k0 = c/d with guards; T = odds_k1 / odds_k0

    DIRECT EFFECTS via backward recursion:
      D_{k,Y} = ( T_{k,Y} - Σ_j λ_{k,j} D_{j,Y} ) / ( 1 - Σ_j λ_{k,j} ),
      where j are downstream nodes between k and Y in the temporal order.

    LOGISTIC LINK:
      logit P(Y=1 | Z) = β0 + Σ_k β_k Z_k  with β_k = log D_{k,Y};
      invalid/nonpositive D map to β_k=0.

    IMPLEMENTATION NOTE:
      Each ordered pair (target, code) is aggregated ONCE with a groupby (sums;
      max for n_no_target). The score/T/λ loops are then dict lookups instead of
      scanning the whole DataFrame per pair (previously O(E²·rows)).

    RETURNS a dict with:
      sorted_scores, temporal_order, order_used,
      T_val, D_val, lambda_l, coef_df, beta_0, beta, logit_predictors, predict_proba

    RAISES:
      ValueError if required columns are missing or fewer than two events remain.
    """
    import warnings

    # ── 0) Ingest & validate ───────────────────────────────────────────────────
    df = pd.read_csv(df_in) if isinstance(df_in, str) else df_in.copy()

    need_cols = [
        "target_concept_code", "concept_code",
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target",
        "n_target_no_code",
        "n_code",                                 # for λ denominator
        "n_code_before_target", "n_target_before_code",
    ]
    missing = [c for c in need_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")

    # Optional recoding: LEFT column only (alias outcomes → Y).
    if name_map:
        df["target_concept_code"] = df["target_concept_code"].replace(name_map)

    # Cast both endpoints to str in EVERY branch. Previously only the explicit
    # `events` branch did this, so numeric codes + auto-detected (string) events
    # filtered every row away.
    df["target_concept_code"] = df["target_concept_code"].astype(str)
    df["concept_code"] = df["concept_code"].astype(str)

    if events is None:
        events = sorted(set(df["target_concept_code"]) | set(df["concept_code"]))
    else:
        # Stringify and de-duplicate (keeping order); duplicates double-counted pairs.
        events = list(dict.fromkeys(str(e) for e in events))
    if len(events) < 2:
        raise ValueError("Need at least two events.")

    df = df[df["target_concept_code"].isin(events) & df["concept_code"].isin(events)].copy()

    # Coerce numerics; NA→0 to allow safe sums/max
    num_cols = [
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target", "n_target_no_code",
        "n_code", "n_code_before_target", "n_target_before_code",
    ]
    for c in num_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    # ── 0b) Aggregate each ordered pair (target, code) once ─────────────────────
    sum_cols = [
        "n_code_target", "n_code_no_target", "n_target_no_code", "n_code",
        "n_code_before_target", "n_target_before_code",
    ]
    grouped = df.groupby(["target_concept_code", "concept_code"], sort=False)
    pair_df = grouped[sum_cols].sum()
    pair_df["n_no_target_max"] = grouped["n_no_target"].max()
    pair_stats: Dict[Tuple[str, str], Dict[str, float]] = pair_df.to_dict("index")
    empty_cell = dict.fromkeys(sum_cols + ["n_no_target_max"], 0.0)  # stands in for a missing row

    # ── 1) Temporal score (read rows: target=Z_i, code=Z_j) ────────────────────
    per_target: Dict[str, float] = {}
    for (zi, zj), cell in pair_stats.items():
        if zi == zj:
            continue  # Σ runs over j ≠ i
        per_target[zi] = per_target.get(zi, 0.0) + (
            float(cell["n_target_before_code"])     # C(Z_i ≺ Z_j)
            - float(cell["n_code_before_target"])   # C(Z_j ≺ Z_i)
            + float(cell["n_target_no_code"])       # C(Z_i ∧ ¬Z_j)
            - float(cell["n_code_no_target"])       # C(Z_j ∧ ¬Z_i)
        )
    # Keep `events` order so ties sort exactly as before.
    scores: Dict[str, float] = {zi: per_target.get(zi, 0.0) for zi in events}

    sorted_scores = pd.Series(scores).sort_values(ascending=False)

    # Choose outcome Y: forced if present, else the latest (lowest-scoring) node.
    forced = str(force_outcome) if force_outcome else None
    if forced is not None and forced in sorted_scores.index:
        outcome = forced
    else:
        if forced is not None:
            warnings.warn(
                f"force_outcome={force_outcome!r} is not among the events; "
                "using the latest node by temporal score as the outcome.",
                stacklevel=2,
            )
        outcome = sorted_scores.index[-1]
    temporal_order = [ev for ev in sorted_scores.index if ev != outcome] + [outcome]

    events_order = temporal_order        # earliest … → Y
    nodes = events_order[:-1]            # everything before Y

    # ── 2) T and λ (row orientations locked) ───────────────────────────────────
    T_val = pd.Series(0.0, index=nodes, dtype=float)
    D_val = pd.Series(np.nan, index=nodes, dtype=float)
    lambda_l: Dict[str, pd.Series] = {}

    for pos_k, k in enumerate(nodes):
        # ---- T_{kY} from the row (target=Y, code=k) ----
        cell = pair_stats.get((outcome, k), empty_cell)
        a = float(cell["n_code_target"])        # k ∧ Y
        b = float(cell["n_code_no_target"])     # k ∧ ¬Y
        c = float(cell["n_target_no_code"])     # ¬k ∧ Y
        # d is not in data: d = total(¬Y) − (k ∧ ¬Y)
        n_noY = float(cell["n_no_target_max"])
        d = max(n_noY - b, 0.0)                 # ¬k ∧ ¬Y

        # Stratum sizes
        N1 = a + b                   # k = 1
        N0 = c + d                   # k = 0

        # Sample-size–anchored odds for k=1 and k=0
        if N1 == 0:
            odds_k1 = 1.0
        elif a == 0:
            odds_k1 = 1.0 / (N1 + 1.0)
        elif b == 0:
            odds_k1 = N1 + 1.0
        else:
            odds_k1 = a / b

        if N0 == 0:
            odds_k0 = 1.0
        elif c == 0:
            odds_k0 = 1.0 / (N0 + 1.0)
        elif d == 0:
            odds_k0 = N0 + 1.0
        else:
            odds_k0 = c / d

        T_val.loc[k] = float(odds_k1 / odds_k0) if odds_k0 > 0 else (N1 + 1.0)

        # ---- λ_{k,j} from rows (target=j, code=k): λ = n_code_target / n_code ----
        lam_by_j: Dict[str, float] = {}
        for j in events_order[pos_k + 1:-1]:
            cell_jk = pair_stats.get((j, k))
            if cell_jk is None:
                lam_by_j[j] = 0.0
                continue
            num = float(cell_jk["n_code_target"])
            den = float(cell_jk["n_code"])
            # L-threshold on the conditioning size n_code (count of k)
            if den <= 0 or den < lambda_min_count:
                lam_by_j[j] = 0.0
                continue
            lam = num / den
            lam_by_j[j] = float(min(max(lam, 0.0), 1.0)) if np.isfinite(lam) else 0.0

        lambda_l[k] = pd.Series(lam_by_j, dtype=float)

    # ── 3) Backward recursion for direct effects D ─────────────────────────────
    # Last node (just before Y): no downstream → D = T
    if len(nodes) >= 1:
        last_anc = nodes[-1]
        D_val.loc[last_anc] = T_val.loc[last_anc]

    # Walk backward for the rest
    if len(nodes) > 1:
        for k in list(reversed(nodes[:-1])):
            lam_vec = lambda_l.get(k, pd.Series(dtype=float))
            downstream = list(lam_vec.index)  # j nodes after k (already resolved)
            lam_vals = lam_vec.reindex(downstream).fillna(0.0).to_numpy()
            D_down  = pd.to_numeric(D_val.reindex(downstream), errors="coerce").fillna(0.0).to_numpy()

            num = T_val.loc[k] - float(np.nansum(lam_vals * D_down))
            den = 1.0 - float(np.nansum(lam_vals))

            if (not np.isfinite(den)) or den == 0.0:
                D_val.loc[k] = T_val.loc[k]            # neutralize if pathological
            else:
                tmp = num / den
                D_val.loc[k] = tmp if np.isfinite(tmp) else T_val.loc[k]

    # ── 4) Logistic link (β) and predict_proba ─────────────────────────────────
    # Intercept β0 from marginal prevalence of Y (rows with target == Y)
    resp_rows = df[df["target_concept_code"] == outcome]
    n_t = float(resp_rows["n_target"].max(skipna=True)) if not resp_rows.empty else np.nan
    n_n = float(resp_rows["n_no_target"].max(skipna=True)) if not resp_rows.empty else np.nan
    denom = n_t + n_n
    p_y = 0.5 if (not np.isfinite(denom) or denom <= 0) else (n_t / denom)
    p_y = min(max(p_y, 1e-12), 1 - 1e-12)
    beta_0 = float(np.log(p_y / (1 - p_y)))

    # β_k = log D_{k,Y}; invalid/nonpositive → 0
    D_clean = pd.to_numeric(D_val, errors="coerce").astype(float)
    beta_vals = np.log(D_clean.where(D_clean > 0.0)).replace([np.inf, -np.inf], np.nan).fillna(0.0)

    coef_df = pd.DataFrame({
        "predictor": list(beta_vals.index) + ["(intercept)"],
        "beta":      list(beta_vals.values) + [beta_0],
    })

    # Vectorized predict_proba
    predictors = list(beta_vals.index)
    beta_vec = beta_vals.values

    def predict_proba(Z: Union[Dict[str, Any], pd.Series, np.ndarray, List[float], pd.DataFrame]):
        """
        Compute P(Y=1|Z) using: logit P = β0 + Σ_k β_k Z_k.
        Z can be:
          - dict/Series mapping predictor name -> 0/1
          - 1D/2D numpy/list with columns ordered as `predictors`
          - DataFrame containing any/all of `predictors` (others ignored)
        """
        def sigmoid(x):
            x = np.clip(x, -700, 700)  # numerical stability for large |η|
            return 1.0 / (1.0 + np.exp(-x))

        if isinstance(Z, pd.DataFrame):
            M = Z.reindex(columns=predictors, fill_value=0.0).astype(float).to_numpy()
            return sigmoid(beta_0 + M @ beta_vec)

        if isinstance(Z, (dict, pd.Series)):
            v = np.array([float(Z.get(p, 0.0)) for p in predictors], dtype=float)
            return float(sigmoid(beta_0 + float(v @ beta_vec)))

        arr = np.asarray(Z, dtype=float)
        if arr.ndim == 1:
            if arr.size != len(predictors):
                raise ValueError(f"Expected {len(predictors)} features in order: {predictors}")
            return float(sigmoid(beta_0 + float(arr @ beta_vec)))
        if arr.ndim == 2:
            if arr.shape[1] != len(predictors):
                raise ValueError(f"Expected shape (*,{len(predictors)}), got {arr.shape}")
            return sigmoid(beta_0 + arr @ beta_vec)

        raise ValueError("Unsupported input for predict_proba")

    # ── 5) Package results ─────────────────────────────────────────────────────
    return {
        "sorted_scores": sorted_scores,
        "temporal_order": temporal_order,
        "order_used": events_order,
        "T_val": T_val,
        "D_val": D_val,
        "lambda_l": lambda_l,
        "coef_df": coef_df,
        "beta_0": beta_0,
        "beta": pd.Series(beta_vec, index=predictors, dtype=float),
        "logit_predictors": predictors,
        "predict_proba": predict_proba,
    }

def _ensure_derived_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with the 2×2 count columns coerced to numbers
    (unparseable/missing → 0.0) and a ``total_effect`` column added if absent.

    The total effect is the guarded odds ratio
    odds(Y | k=1) / odds(Y | k=0), using a = n_code_target, b = n_code_no_target,
    c = n_target_no_code and d = max(n_no_target - n_code_no_target, 0), because
    d is not stored in the data. An existing ``total_effect`` column is left
    untouched.

    Raises:
        ValueError: if any required count column is missing.
    """
    cell_cols = [
        "n_code_target", "n_code_no_target",
        "n_target_no_code", "n_target", "n_no_target"
    ]
    absent = [col for col in cell_cols if col not in df.columns]
    if absent:
        raise ValueError(f"Missing required columns for TE: {', '.join(absent)}")

    result = df.copy()
    for col in cell_cols:
        result[col] = pd.to_numeric(result[col], errors="coerce").fillna(0.0)

    if "total_effect" in result.columns:
        return result

    def guarded_odds(events_n, others_n):
        # Empty stratum → 1; both cells positive → plain odds; zero denominator
        # → n+1; zero numerator → 1/(n+1).
        stratum = events_n + others_n
        return np.where(
            stratum == 0, 1.0,
            np.where((events_n > 0) & (others_n > 0), events_n / others_n,
                     np.where(others_n == 0, stratum + 1.0, 1.0 / (stratum + 1.0)))
        )

    k_and_y = result["n_code_target"].astype(float)
    k_not_y = result["n_code_no_target"].astype(float)
    y_not_k = result["n_target_no_code"].astype(float)
    neither = np.maximum((result["n_no_target"] - result["n_code_no_target"]).astype(float), 0.0)

    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = guarded_odds(k_and_y, k_not_y) / guarded_odds(y_not_k, neither)

    result["total_effect"] = np.where(np.isfinite(ratio), ratio, 1.0).astype(float)
    return result


# TOPK_HI/LO
def _select_top_by_te_unique_k(k_to_T: pd.DataFrame, top_k: int = TOP_K,
                               hi: float = 1.5) -> pd.DataFrame:
    """
    Balanced selection without fallback:
      - Take up to top_k//2 strongest risk (TE >= hi)
      - Take up to top_k - top_k//2 strongest protective (TE <= 1/hi)
      - If either side is short, do NOT fill from elsewhere; total may be < top_k
      - Return one row per k (**concept_code**), ranked by extremeness

    Extremeness ("effect_strength") is TE when TE >= 1, else 1/TE. When a k
    appears in several rows, its most extreme row is kept. Ties are broken by
    original row order (stable sort), so the selection is deterministic.

    Raises:
        ValueError: if ``hi`` is not > 1. Otherwise the risk and protective pools
            overlap and the same k could be selected twice.
    """
    if not hi > 1.0:  # also rejects NaN
        raise ValueError(f"hi must be > 1.0 so the risk and protective pools are disjoint; got {hi!r}")

    df = k_to_T.copy()

    # Ensure total_effect numeric
    df["total_effect"] = pd.to_numeric(df["total_effect"], errors="coerce")
    df = df.dropna(subset=["total_effect"]).copy()
    if df.empty:
        print("[SELECT] no rows with total_effect; returning empty selection.")
        return df

    # Extremeness symmetrical around 1
    df["effect_strength"] = np.where(df["total_effect"] >= 1.0,
                                     df["total_effect"],
                                     1.0 / df["total_effect"])

    # One row per k: **k is concept_code in (target=Y, code=k)**.
    # mergesort is stable, so the kept row for tied strengths is reproducible.
    best_per_k = (df.sort_values("effect_strength", ascending=False, kind="mergesort")
                    .drop_duplicates(subset=["concept_code"], keep="first"))

    HI = hi
    LO = 1.0 / hi

    risk_pool = best_per_k[best_per_k["total_effect"] >= HI].copy()
    prot_pool = best_per_k[best_per_k["total_effect"] <= LO].copy()

    want_risk = top_k // 2
    want_prot = top_k - want_risk

    sel_risk = risk_pool.nlargest(min(len(risk_pool), want_risk), "effect_strength")
    sel_prot = prot_pool.nlargest(min(len(prot_pool), want_prot), "effect_strength")

    selected = pd.concat([sel_risk, sel_prot], ignore_index=True)

    print(f"[SELECT] total unique k={best_per_k['concept_code'].nunique():,}  "
          f"risk={len(risk_pool):,}  prot={len(prot_pool):,}  "
          f"selected(total)={selected['concept_code'].nunique()}  "
          f"with: risk={len(sel_risk)}, prot={len(sel_prot)}")

    return selected.reset_index(drop=True)


def _fetch_k_to_T(conn, outcome_code: str) -> pd.DataFrame:
    """Return every EDGE_TABLE row whose LEFT (target) code equals ``outcome_code``.

    Both integer endpoints are resolved to their string codes through
    ``concept_names`` and exposed as ``target_concept_code`` / ``concept_code``.
    """
    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,
             ccn.concept_code AS concept_code
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      WHERE tcn.concept_code = ?
    """
    bound = (outcome_code,)
    return pd.read_sql_query(sql, conn, params=bound)

def _fetch_k_to_T_in(conn, target_codes: List[str]) -> pd.DataFrame:
    """
    Fetch rows whose LEFT (target) is in target_codes.
    This is used to include alias outcomes (descendants of T) on the LEFT.

    The codes are staged in a TEMP table and joined, so any number of codes
    works without hitting SQLite's bound-parameter limit. Codes are
    stringified and de-duplicated first. Before this, a code listed twice
    returned every one of its edge rows twice, and downstream count sums were
    inflated.
    """
    tmp_name = "tmp_targets_magi"
    unique_codes = list(dict.fromkeys(str(c) for c in target_codes))
    with conn:
        # `temp.`-qualified: never falls through to a same-named table in main.
        conn.execute(f"DROP TABLE IF EXISTS temp.{tmp_name}")
        conn.execute(f"CREATE TEMP TABLE {tmp_name}(concept_code TEXT PRIMARY KEY)")
        conn.executemany(
            f"INSERT OR IGNORE INTO temp.{tmp_name}(concept_code) VALUES (?)",
            [(c,) for c in unique_codes],
        )

    q = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT  (Y or Y-alias)
             ccn.concept_code AS concept_code           -- RIGHT (k)
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      JOIN {tmp_name} tmp     ON tcn.concept_code         = tmp.concept_code   -- <<< LEFT filter
    """
    return pd.read_sql_query(q, conn)

def _fetch_subgraph_by_targets(conn, events_list):
    """
    Induced subgraph where LEFT (target_concept_code) is in events_list.

    The event codes are staged in a TEMP table via ``_load_temp_codes`` and
    joined, as in ``_fetch_subgraph_by_targets_fast``. Binding one ``?`` per code
    could exceed SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER) for
    large event lists. Only the LEFT endpoint is filtered; the RIGHT is
    unrestricted.
    """
    _load_temp_codes(conn, "tmp_subgraph_targets", {str(code) for code in events_list})
    q = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,   -- LEFT
             ccn.concept_code AS concept_code           -- RIGHT
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      JOIN tmp_subgraph_targets tt ON tt.concept_code     = tcn.concept_code
    """
    return pd.read_sql_query(q, conn)

def expand_targets_with_descendants(targets: List[str]) -> Dict[str, Set[str]]:
    """Map each target to its alias family (root + descendants + ancestors).

    Uses ``snomed_aliases_for_outcome(..., include_parents=True)``; a target that
    is not a SNOMED code maps to the singleton ``{target}``.
    """
    expanded: Dict[str, Set[str]] = {}
    for target in targets:
        alias_set, _unused_map = snomed_aliases_for_outcome(target, include_parents=True)
        expanded[target] = {target} if alias_set is None else alias_set
    return expanded

def print_full_targets(targets: List[str], *, save_csv: bool = True, out_dir: str = "./"):
    """
    Print a summary of each target's expanded alias set and, optionally, save
    all rows to ``<out_dir>/targets_expanded_<timestamp>.csv`` with columns
      root_target, alias_code, alias_conceptId, is_root

    Returns the ``{target: alias_set}`` dict from expand_targets_with_descendants.
    """
    expanded = expand_targets_with_descendants(targets)

    # console summary: size plus the first 10 aliases in sorted order
    for root, members in expanded.items():
        count = len(members)
        preview = ", ".join(sorted(members)[:10])
        suffix = " ..." if count > 10 else ""
        print("=" * 100)
        print(f"[TARGET] {root}")
        print(f"  aliases (parents + descendants + root): {count}")
        print(f"  sample: {preview}{suffix}")

    if not save_csv:
        return expanded

    records = [
        {
            "root_target": root,
            "alias_code": alias,
            "alias_conceptId": extract_snomed_id(alias) or "",
            "is_root": alias == root,
        }
        for root, members in expanded.items()
        for alias in sorted(members)
    ]
    table = pd.DataFrame(records, columns=["root_target","alias_code","alias_conceptId","is_root"])
    stamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"targets_expanded_{stamp}.csv")
    table.to_csv(csv_path, index=False)
    print(f"\n[SAVED] Full target expansions → {csv_path}")
    return expanded

# Keep an already-configured OMOP_DB_PATH; fall back to the local default if unset.
OMOP_DB_PATH = globals().get("OMOP_DB_PATH", "./omop_athena.db")

def _ensure_omop_attached(conn, omop_db_path: str):
    """Attach ``omop_db_path`` to ``conn`` under the schema name ``omop`` unless one is already attached.

    The path is passed as a bound parameter, not spliced into the SQL text. A
    path containing a single quote therefore no longer breaks (or injects into)
    the ATTACH statement. ``str()`` keeps ``pathlib.Path`` arguments working as
    they did with the old f-string.
    """
    attached = {row[1] for row in conn.execute("PRAGMA database_list").fetchall()}
    if "omop" not in attached:
        conn.execute("ATTACH DATABASE ? AS omop", (str(omop_db_path),))

def _load_temp_codes(conn, table_name: str, codes: set):
    """(Re)create TEMP table ``table_name(concept_code TEXT PRIMARY KEY)`` holding ``codes``.

    - Codes are stringified and de-duplicated (INSERT OR IGNORE on the key).
    - Any previous TEMP table of that name is dropped first. The DROP is
      qualified with ``temp.`` so it can never fall through to, and destroy, a
      same-named table in the main or an attached database.
    - The PRIMARY KEY already provides the lookup index; no extra index is built.
    - ``codes`` may be any iterable (or None / empty → empty table).

    ``table_name`` is interpolated into the SQL and must be a trusted identifier.
    """
    conn.execute(f"DROP TABLE IF EXISTS temp.{table_name}")
    conn.execute(f"CREATE TEMP TABLE {table_name}(concept_code TEXT PRIMARY KEY)")
    rows = [(str(c),) for c in (codes or ())]
    if rows:
        conn.executemany(
            f"INSERT OR IGNORE INTO temp.{table_name}(concept_code) VALUES (?)",
            rows,
        )

def _fetch_subgraph_by_targets_fast(conn, events_set):
    """Induced subgraph where BOTH endpoints (LEFT target and RIGHT code) are in ``events_set``.

    The stringified event codes are staged in TEMP table ``tmp_events`` and
    joined twice (once per endpoint), which avoids bound-parameter limits.
    """
    staged = {str(code) for code in events_set}
    _load_temp_codes(conn, "tmp_events", staged)
    sql = f"""
      SELECT m.*,
             tcn.concept_code AS target_concept_code,
             ccn.concept_code AS concept_code
      FROM {EDGE_TABLE} m
      JOIN concept_names tcn ON m.target_concept_code_int = tcn.concept_code_int
      JOIN concept_names ccn ON m.concept_code_int        = ccn.concept_code_int
      JOIN tmp_events te1    ON te1.concept_code          = tcn.concept_code
      JOIN tmp_events te2    ON te2.concept_code          = ccn.concept_code
    """
    return pd.read_sql_query(sql, conn)


# ====== OMOP "Maps to" helpers (defensive; return empty sets if tables missing) ======
def _get_domain_vocab(conn, magi_code: str):
    """Return ``(domain_id, vocabulary_id)`` as strings for ``magi_code``, or ``(None, None)``.

    Reads the first matching row of ``omop.concept_names``; the ``omop`` schema
    must already be attached (this function does not attach it).
    """
    lookup = pd.read_sql_query("""
        SELECT domain_id, vocabulary_id
        FROM omop.concept_names
        WHERE concept_code = ?
        LIMIT 1
    """, conn, params=[magi_code])
    if not lookup.empty:
        domain = str(lookup["domain_id"].iloc[0])
        vocab = str(lookup["vocabulary_id"].iloc[0])
        return domain, vocab
    return None, None
    
def _get_maps_to_sources(conn, standard_codes: set) -> set:
    """
    Given standard MAGI codes (e.g. dx_SNOMED_*), return every MAGI code that maps TO them.

    Follows omop.concept_relationship rows (concept_id_1 → concept_id_2) whose
    relationship_id is 'Maps to' / 'Maps to value' (either case) and whose
    invalid_reason is NULL or ''. The input codes are staged in TEMP table
    ``tmp_std_codes`` to avoid SQLite parameter limits.

    Attaches OMOP_DB_PATH as ``omop`` if needed. NOTE(review): the
    ``omop.concept_names`` table/view must already exist there; the attach
    helper does not create it.
    """
    staging = "tmp_std_codes"
    _ensure_omop_attached(conn, OMOP_DB_PATH)
    _load_temp_codes(conn, staging, standard_codes)

    q = f"""
      WITH std AS (
        SELECT DISTINCT c.concept_id
        FROM omop.concept_names cn
        JOIN omop.concept c ON c.concept_id = cn.concept_id
        JOIN {staging} t        ON t.concept_code = cn.concept_code
      )
      SELECT DISTINCT src_cn.concept_code AS concept_code
      FROM std
      JOIN omop.concept_relationship cr ON cr.concept_id_2 = std.concept_id
      JOIN omop.concept src            ON src.concept_id   = cr.concept_id_1
      JOIN omop.concept_names src_cn   ON src_cn.concept_id= src.concept_id
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
    """
    found = pd.read_sql_query(q, conn)
    if found.empty:
        return set()
    return set(found["concept_code"].astype(str))

def _get_maps_to_targets(conn, src_codes: set, restrict_domain: str = None, restrict_vocab: str = None) -> set:
    """
    Given MAGI codes (sources), return the MAGI codes of the standard concepts they map TO.

    Follows 'Maps to' / 'Maps to value' relationships (either case) with an
    empty invalid_reason. Optional case-insensitive filters apply to the
    STANDARD side: ``restrict_domain`` (domain_id) and ``restrict_vocab``
    (vocabulary_id); a falsy value disables that filter. Source codes are
    staged in TEMP table ``tmp_src_codes`` to avoid SQLite parameter limits.
    Attaches OMOP_DB_PATH as ``omop`` if needed.
    """
    _ensure_omop_attached(conn, OMOP_DB_PATH)
    tmp = "tmp_src_codes"
    _load_temp_codes(conn, tmp, src_codes)

    optional_filters = [
        ("UPPER(std.domain_id)=UPPER(?)", restrict_domain),
        ("UPPER(std.vocabulary_id)=UPPER(?)", restrict_vocab),
    ]
    active_filters = [(clause, value) for clause, value in optional_filters if value]
    where_extra = "".join(" AND " + clause for clause, _ in active_filters)
    params = [value for _, value in active_filters]

    q = f"""
      WITH src AS (
        SELECT DISTINCT c.concept_id
        FROM omop.concept_names cn
        JOIN omop.concept c ON c.concept_id = cn.concept_id
        JOIN {tmp} t        ON t.concept_code = cn.concept_code
      )
      SELECT DISTINCT std_cn.concept_code AS concept_code
      FROM src
      JOIN omop.concept_relationship cr ON cr.concept_id_1 = src.concept_id
      JOIN omop.concept std            ON std.concept_id   = cr.concept_id_2
      JOIN omop.concept_names std_cn   ON std_cn.concept_id= std.concept_id
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
        {where_extra}
    """
    found = pd.read_sql_query(q, conn, params=params)
    if found.empty:
        return set()
    return set(found["concept_code"].astype(str))

def _name_like_disease(conn, codes: Set[str], like_patterns: List[str]) -> Set[str]:
    """
    Drop guard: return codes whose concept_name matches any SQL LIKE pattern (case-insensitive).
    Patterns should be raw SQL LIKE strings e.g., '%malignant neoplasm%breast%'.

    Codes are staged in TEMP table ``tmp_name_like_codes`` and joined to
    ``concept_names``, rather than bound one ``?`` each. Large code sets can
    therefore no longer hit SQLite's host-parameter limit. Before, that failure
    was silently swallowed and the guard quietly dropped nothing.

    Best-effort: on a database error a RuntimeWarning is emitted and an empty
    set is returned, so the pipeline continues but the skip is visible.
    """
    import warnings

    if not codes or not like_patterns:
        return set()
    patterns = [str(p) for p in like_patterns]
    tmp_name = "tmp_name_like_codes"
    like_clause = " OR ".join(["LOWER(cn.concept_name) LIKE LOWER(?)"] * len(patterns))
    q = f"""
      SELECT DISTINCT cn.concept_code
      FROM concept_names cn
      JOIN temp.{tmp_name} t ON t.concept_code = cn.concept_code
      WHERE ({like_clause})
    """
    try:
        conn.execute(f"DROP TABLE IF EXISTS temp.{tmp_name}")
        conn.execute(f"CREATE TEMP TABLE {tmp_name}(concept_code TEXT PRIMARY KEY)")
        conn.executemany(
            f"INSERT OR IGNORE INTO temp.{tmp_name}(concept_code) VALUES (?)",
            [(str(c),) for c in codes],
        )
        rows = pd.read_sql_query(q, conn, params=patterns)
    except Exception as exc:  # deliberate best-effort guard — but never silent
        warnings.warn(
            f"_name_like_disease: name-based drop guard skipped ({exc})",
            RuntimeWarning,
            stacklevel=2,
        )
        return set()
    return set(rows["concept_code"].astype(str)) if not rows.empty else set()

def _to_magi_sct(concept_id: str) -> str:
    """Format a SNOMED conceptId as a MAGI code: ``dx_SNOMED_<concept_id>``."""
    return "dx_SNOMED_{}".format(concept_id)

def _from_magi_sct(magi_code: str) -> Optional[str]:
    """Inverse of _to_magi_sct: the SNOMED conceptId digits, or None for non-SNOMED codes."""
    snomed_id = extract_snomed_id(magi_code)
    return snomed_id

def _descendants_of_codes(codes: Set[str]) -> Set[str]:
    """
    RF2-backed: union of all SNOMED is-a descendants (as MAGI codes) of the
    given MAGI codes. Codes that are not 'dx_SNOMED_<digits>' are ignored; the
    inputs themselves are not added unless they are descendants of another input.
    """
    if not codes:
        return set()
    _ensure_graph()
    snomed_ids = (_from_magi_sct(code) for code in codes)
    return {
        _to_magi_sct(desc_id)
        for sid in snomed_ids
        if sid
        for desc_id in find_descendants_sct(sid)
    }

def _descendants_intersect_family(k: str, family: Set[str]) -> bool:
    """
    True if SNOMED predictor ``k`` has at least one is-a descendant whose MAGI
    code is in ``family``; False for non-SNOMED ``k`` or a leaf concept.
    """
    sid = _from_magi_sct(k)
    if not sid:
        return False
    descendant_codes = {_to_magi_sct(x) for x in find_descendants_sct(sid)}
    if not descendant_codes:
        return False
    return bool(descendant_codes & family)

def _drop_k_with_descendants_in_family(conn, k_codes: set, family_codes: set) -> set:
    """
    Return the codes in ``k_codes`` that have at least one OMOP ``concept_ancestor``
    descendant among ``family_codes``.

    Both sides are restricted to SNOMED / Condition rows of ``omop.concept_names``. The
    ``omop`` schema must already be attached; this function does not attach it.

    The inputs are staged in the temp tables ``tmp_k`` and ``tmp_family`` via
    ``_load_temp_codes``. Presumably this replaces their previous contents; confirm.
    """
    # Stage predictor codes first, then family codes (same order as before).
    for table_name, code_set in (("tmp_k", k_codes), ("tmp_family", family_codes)):
        _load_temp_codes(conn, table_name, code_set)

    sql = """
      WITH k_ids AS (
        SELECT cn.concept_id, cn.concept_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      ),
      fam_ids AS (
        SELECT cn.concept_id
        FROM tmp_family t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      )
      SELECT DISTINCT k.concept_code
      FROM k_ids k
      JOIN omop.concept_ancestor ca ON ca.ancestor_concept_id = k.concept_id
      JOIN fam_ids f               ON f.concept_id          = ca.descendant_concept_id
    """
    result = pd.read_sql_query(sql, conn)
    if result.empty:
        return set()
    return set(result["concept_code"].astype(str))

def drop_related_predictors(conn, T: str, k_to_T: pd.DataFrame,
                            alias_family_rf2: Optional[Set[str]] = None) -> pd.DataFrame:
    """
    Remove from ``k_to_T`` every predictor k (``concept_code`` column) that is
    structurally related to the outcome ``T``.

    Expanded family =
        seed = ``alias_family_rf2`` ∪ {T}
      ∪ concepts with an active 'Maps to' / 'Maps to value' relationship to any seed member
      ∪ OMOP concept_ancestor descendants of those sources (SNOMED/Condition sources only).

    Parents and descendants of T are in the family only if the caller put them in
    ``alias_family_rf2``.

    A predictor k is dropped when it:
      - is in the expanded family;
      - maps to any concept in that family;
      - is a SNOMED/Condition concept with an OMOP descendant in the family;
      - maps to a family member in the same domain + vocabulary as T. This rule is a
        subset of the previous "maps to family" rule and is kept only for parity with
        the spec.

    Parameters
    ----------
    conn : sqlite3 connection; the OMOP DB is attached (``omop`` schema) here.
    T : outcome code.
    k_to_T : Y→k rows; predictors are read from ``concept_code``.
    alias_family_rf2 : optional precomputed RF2 alias family for T.

    Returns
    -------
    A copy of ``k_to_T`` restricted to surviving predictors.
    """
    # Use the module-level path instead of a duplicated hard-coded literal.
    _ensure_omop_attached(conn, OMOP_DB_PATH)

    def _codes(sql: str, col: str = "concept_code", params=None) -> Set[str]:
        # Normalise to str. k_codes are str, and an INTEGER-typed concept_code column
        # would otherwise make every set difference below a silent no-op.
        return set(pd.read_sql_query(sql, conn, params=params)[col].astype(str))

    # seed family: T plus RF2 family if you already computed it
    family_seed = set(alias_family_rf2 or set()) | {T}
    _load_temp_codes(conn, "tmp_family_seed", family_seed)

    # all sources that map to the family (step 1: “any concept that Maps to the target (and their descendants)”)
    q_src = """
      WITH fam AS (
        SELECT cn.concept_id
        FROM tmp_family_seed t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      )
      SELECT DISTINCT src_cn.concept_code AS concept_code
      FROM fam
      JOIN omop.concept_relationship cr ON cr.concept_id_2 = fam.concept_id
      JOIN omop.concept src            ON src.concept_id   = cr.concept_id_1
      JOIN omop.concept_names src_cn   ON src_cn.concept_id= src.concept_id
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
    """
    maps_to_sources = _codes(q_src)

    # descendants of those sources (SNOMED/Condition only)
    _load_temp_codes(conn, "tmp_maps_src", maps_to_sources)
    q_src_desc = """
      WITH src_ids AS (
        SELECT cn.concept_id
        FROM tmp_maps_src t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      )
      SELECT DISTINCT cn.concept_code
      FROM src_ids s
      JOIN omop.concept_ancestor ca ON ca.ancestor_concept_id = s.concept_id
      JOIN omop.concept_names cn    ON cn.concept_id          = ca.descendant_concept_id
    """
    maps_to_desc = _codes(q_src_desc)

    expanded_family = family_seed | maps_to_sources | maps_to_desc
    _load_temp_codes(conn, "tmp_family_all", expanded_family)

    # PREP: predictors set (RIGHT==k in the k→T rows)
    k_codes = set(k_to_T["concept_code"].astype(str))
    _load_temp_codes(conn, "tmp_k", k_codes)

    # (2a) drop k ∈ expanded family
    drop_in_family = _codes("""
        SELECT t.concept_code FROM tmp_k t
        INNER JOIN tmp_family_all f ON f.concept_code = t.concept_code
    """)

    # (2b) drop k that maps-to any concept in expanded family
    drop_maps_to_family = _codes("""
      WITH fam AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      ),
      k_ids AS (
        SELECT cn.concept_id AS k_id, cn.concept_code AS k_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      )
      SELECT DISTINCT k.k_code
      FROM k_ids k
      JOIN omop.concept_relationship cr ON cr.concept_id_1 = k.k_id
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
        AND EXISTS (SELECT 1 FROM fam WHERE fam.concept_id = cr.concept_id_2)
    """, "k_code")

    # (2c) drop k having SNOMED descendants intersecting expanded family
    drop_desc_intersect = _codes("""
      WITH k_ids AS (
        SELECT cn.concept_id, cn.concept_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      ),
      fam_ids AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      )
      SELECT DISTINCT k.concept_code
      FROM k_ids k
      JOIN omop.concept_ancestor ca ON ca.ancestor_concept_id = k.concept_id
      JOIN fam_ids f               ON f.concept_id          = ca.descendant_concept_id
    """)

    # (3) sibling same-standard → exclude (k maps to a family member in T's domain/vocab).
    # NOTE: every row matched here also satisfies (2b), so this adds no extra drops today.
    td, tv = _get_domain_vocab(conn, T)  # expected ('Condition','SNOMED') for SNOMED dx
    drop_same_std = _codes("""
      WITH fam AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      ),
      k_ids AS (
        SELECT cn.concept_id AS k_id, cn.concept_code AS k_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      )
      SELECT DISTINCT k.k_code
      FROM k_ids k
      JOIN omop.concept_relationship cr ON cr.concept_id_1 = k.k_id
      JOIN omop.concept std            ON std.concept_id   = cr.concept_id_2
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
        AND UPPER(std.domain_id)    = UPPER(?)
        AND UPPER(std.vocabulary_id)= UPPER(?)
        AND EXISTS (SELECT 1 FROM fam WHERE fam.concept_id = cr.concept_id_2)
    """, "k_code", params=[td, tv])

    drops = drop_in_family | drop_maps_to_family | drop_desc_intersect | drop_same_std
    keep = k_codes - drops
    # Compare on the str form, matching how k_codes was built.
    return k_to_T[k_to_T["concept_code"].astype(str).isin(keep)].copy()

# --------------------------------------------------------------------
# ORIENTATION REMINDER (per your spec)
#   LEFT  column  = target_concept_code
#   RIGHT column  = concept_code
#
#   Total effect  T_{k,Y}  ⇒ read row where (target = Y,   code = k)
#   Lambda        λ_{k,j}  ⇒ read row where (target = j,   code = k)
#
#   ⇒ Predictors k are ALWAYS on the RIGHT side (concept_code).
#   
#   Step 1 seeds the family with {T} ∪ descendants ∪ parents, then adds the Maps-to sources of any member of that set (but not the descendants of those sources).
# --------------------------------------------------------------------
def _apply_decision_tree_filters(
    conn,
    T: str,
    k_to_T: pd.DataFrame,
    *,
    alias_family: Set[str],              # <- pass the {T}+descendants+parents set from main loop
    jaccard_drop: float = 0.88,
    ppv_drop: float = 0.90,
    disease_like_patterns: Optional[List[str]] = None,
    verbose: bool = True,
    max_examples: int = 6,
) -> pd.DataFrame:
    """
    Decision tree Steps 1–4 (parents INCLUDED in Step-1 seed):

      1) Expanded family = {T} ∪ descendants ∪ parents (RF2 alias_family you pass here)
         ∪ MAPS-TO sources of any member of that set.
         (No 'descendants of sources' expansion.)

      2) Drop k if:
           (a) k ∈ expanded family,
           (b) k MAPS-TO any member of expanded family,
           (c) k has DESCENDANTS (not parents) intersecting expanded family.

      3) Sibling test:
           • if k MAPS-TO a family member in T's domain+vocab → drop.
             This is a subset of (2b), so on the step-2 survivors it is always empty.
           • else compute Jaccard(k,Y) and PPV(Y|k) from the Y→k row;
             drop if Jaccard ≥ jaccard_drop or PPV ≥ ppv_drop.

      4) Optional name-based filter (SQL LIKE patterns).

    Overlap cells are read from ``n_code_target`` (a), ``n_code_no_target`` (b) and
    ``n_target_no_code`` (c), summed per predictor:
    Jaccard = a / (a+b+c), PPV = a / (a+b). A zero denominator yields NaN, which never
    triggers a drop.

    All code sets are compared as str, so an INTEGER-typed ``concept_code`` column (in
    OMOP or in ``k_to_T``) cannot silently defeat the set arithmetic.

    Returns a copy of the surviving ``k_to_T`` rows (an empty frame with the same
    columns when nothing survives).
    """
    def _log(tag: str, items: Set[str]) -> None:
        if not verbose:
            return
        print(f"[FILTER][{tag}] drop={len(items):,}" +
              ("" if not items else f"  e.g., {', '.join(sorted(list(items))[:max_examples])}"))

    def _code_set(sql: str, col: str, params=None) -> Set[str]:
        # Run a query and return one column as a set of str codes.
        return set(pd.read_sql_query(sql, conn, params=params)[col].astype(str))

    def _rows_for(codes: Set[str]) -> pd.DataFrame:
        # Select k_to_T rows whose predictor (as str) is in `codes`.
        return k_to_T[k_to_T["concept_code"].astype(str).isin(codes)].copy()

    _ensure_omop_attached(conn, OMOP_DB_PATH)

    # Universe of candidate predictors (RIGHT == k in Y→k rows)
    k_codes_all: Set[str] = set(k_to_T["concept_code"].astype(str))
    if verbose:
        print(f"[FILTER][init] candidates={len(k_codes_all):,}")

    # ── Step 1: Expanded family (parents+descendants IN seed; no descendants-of-sources) ──
    # alias_family is already {T} ∪ descendants ∪ parents from the main loop
    outcome_family_std: Set[str] = set(alias_family) | {T}
    _load_temp_codes(conn, "tmp_family_seed", outcome_family_std)

    # Add Maps-to sources to any member of the seed
    maps_to_sources: Set[str] = _get_maps_to_sources(conn, outcome_family_std)

    # Expanded family: seed ∪ maps-to sources  (NO descendants-of-sources)
    expanded_family: Set[str] = outcome_family_std | maps_to_sources
    _load_temp_codes(conn, "tmp_family_all", expanded_family)

    if verbose:
        print(f"[FILTER][step1] seed(T+desc+parents)={len(outcome_family_std):,}  "
              f"maps_to_sources={len(maps_to_sources):,}  "
              f"expanded={len(expanded_family):,}")

    # ── Step 2: Structural drops ───────────────────────────────────────────────
    _load_temp_codes(conn, "tmp_k", k_codes_all)

    # (2a) k ∈ expanded family
    drop_in_family = _code_set("""
        SELECT t.concept_code
        FROM tmp_k t
        INNER JOIN tmp_family_all f ON f.concept_code = t.concept_code
    """, "concept_code")

    # (2b) k MAPS-TO any member of expanded family
    drop_maps_to_family = _code_set("""
      WITH fam AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      ),
      k_ids AS (
        SELECT cn.concept_id AS k_id, cn.concept_code AS k_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      )
      SELECT DISTINCT k.k_code
      FROM k_ids k
      JOIN omop.concept_relationship cr ON cr.concept_id_1 = k.k_id
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
        AND EXISTS (SELECT 1 FROM fam WHERE fam.concept_id = cr.concept_id_2)
    """, "k_code")

    # (2c) k has DESCENDANTS (NOT parents) intersecting expanded family
    drop_desc_intersect = _code_set("""
      WITH k_ids AS (
        SELECT cn.concept_id, cn.concept_code
        FROM tmp_k t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      ),
      fam_ids AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
        WHERE cn.vocabulary_id='SNOMED' AND cn.domain_id='Condition'
      )
      SELECT DISTINCT k.concept_code
      FROM k_ids k
      JOIN omop.concept_ancestor ca ON ca.ancestor_concept_id = k.concept_id
      JOIN fam_ids f               ON f.concept_id          = ca.descendant_concept_id
    """, "concept_code")

    struct_drops = drop_in_family | drop_maps_to_family | drop_desc_intersect
    _log("2a_in_family", drop_in_family)
    _log("2b_maps_to_family", drop_maps_to_family)
    _log("2c_desc_intersect_family", drop_desc_intersect)

    keep_after_struct = k_codes_all - struct_drops
    if verbose:
        print(f"[FILTER][after step2] keep={len(keep_after_struct):,}")
    if not keep_after_struct:
        return k_to_T.loc[[]].copy()

    # ── Step 3: Sibling test ───────────────────────────────────────────────────
    # (i) Same standard as T (same domain+vocab) → drop.
    # NOTE(review): the query below is (2b) plus extra filters, and (2b) hits were
    # already removed from tmp_keep_struct, so this set is empty by construction. If the
    # intent is "k maps to the same standard concept T maps to", this needs T's own
    # Maps-to target instead of family membership.
    t_domain, t_vocab = _get_domain_vocab(conn, T)
    _load_temp_codes(conn, "tmp_keep_struct", keep_after_struct)
    drop_same_std = _code_set("""
      WITH fam AS (
        SELECT cn.concept_id
        FROM tmp_family_all t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      ),
      k_ids AS (
        SELECT cn.concept_id AS k_id, cn.concept_code AS k_code
        FROM tmp_keep_struct t
        JOIN omop.concept_names cn ON cn.concept_code = t.concept_code
      )
      SELECT DISTINCT k.k_code
      FROM k_ids k
      JOIN omop.concept_relationship cr ON cr.concept_id_1 = k.k_id
      JOIN omop.concept std            ON std.concept_id   = cr.concept_id_2
      WHERE cr.relationship_id IN ('Maps to','Maps to value','MAPS TO','MAPS TO VALUE')
        AND IFNULL(cr.invalid_reason,'') = ''
        AND UPPER(std.domain_id)     = UPPER(?)
        AND UPPER(std.vocabulary_id) = UPPER(?)
        AND EXISTS (SELECT 1 FROM fam WHERE fam.concept_id = cr.concept_id_2)
    """, "k_code", params=[t_domain, t_vocab])

    _log("3a_same_standard_as_target", drop_same_std)

    keep_for_overlap = keep_after_struct - drop_same_std
    if verbose:
        print(f"[FILTER][step3 prep] eligible for overlap test={len(keep_for_overlap):,}")
    if not keep_for_overlap:
        return k_to_T.loc[[]].copy()

    # (ii) Near-duplicate by Y→k overlap metrics (Jaccard & PPV)
    k_rows = _rows_for(keep_for_overlap)
    k_rows["concept_code"] = k_rows["concept_code"].astype(str)  # group on the same key type as the sets
    k_rows["a"] = pd.to_numeric(k_rows["n_code_target"], errors="coerce").fillna(0.0)
    k_rows["b"] = pd.to_numeric(k_rows["n_code_no_target"], errors="coerce").fillna(0.0)
    k_rows["c"] = pd.to_numeric(k_rows["n_target_no_code"], errors="coerce").fillna(0.0)

    agg = k_rows.groupby("concept_code", as_index=False)[["a", "b", "c"]].sum()
    # Zero denominators become NaN; NaN >= threshold is False, so those k are kept.
    denom_jacc = (agg["a"] + agg["b"] + agg["c"]).replace(0.0, np.nan)
    denom_ppv = (agg["a"] + agg["b"]).replace(0.0, np.nan)
    agg["jaccard"] = agg["a"] / denom_jacc
    agg["ppv"] = agg["a"] / denom_ppv

    is_near_dupe = (agg["jaccard"] >= jaccard_drop) | (agg["ppv"] >= ppv_drop)
    near_dupe = set(agg.loc[is_near_dupe, "concept_code"].astype(str))
    _log("3b_near_duplicate_overlap", near_dupe)

    keep_overlap = keep_for_overlap - near_dupe
    if verbose:
        print(f"[FILTER][after step3] keep={len(keep_overlap):,}")

    # ── Step 4: Optional name filter ───────────────────────────────────────────
    if disease_like_patterns:
        name_hits = {str(c) for c in _name_like_disease(conn, keep_overlap, disease_like_patterns)}
        _log("4_name_based", name_hits)
        keep_final = keep_overlap - name_hits
    else:
        keep_final = keep_overlap
        if verbose:
            print("[FILTER][step4] skipped (no patterns)")

    if verbose:
        dropped_total = len(k_codes_all - keep_final)
        print(f"[FILTER][summary] start={len(k_codes_all):,}  keep={len(keep_final):,}  drop_total={dropped_total:,}")

    return _rows_for(keep_final)


# ========= main loop (Y→k for T_{k,Y}; j→k for λ_{k,j}) =========
if __name__ == '__main__':
    # 0) Runtime safety checks
    uri = f"file:{MAGI_DB_PATH}?mode=ro"
    if not os.path.exists(MAGI_DB_PATH):
        raise FileNotFoundError(
            f"MAGI_DB_PATH not found: {MAGI_DB_PATH}. "
            "Set MAGI_DB_PATH to a valid SQLite DB or run with CSV input."
        )

    # `with sqlite3.connect(...) as conn` only wraps a transaction (commit/rollback);
    # it never closes the connection. Close it explicitly instead.
    conn = sqlite3.connect(uri, uri=True)
    try:
        # Attach OMOP once (needed for maps-to & concept_ancestor SQL)
        _ensure_omop_attached(conn, OMOP_DB_PATH)

        # (Optional) log the expanded RF2 families (parents+descendants+root)
        _ = print_full_targets(TARGETS, save_csv=True, out_dir=OUT_DIR)

        for T in TARGETS:
            print("\n" + "=" * 100)
            print(f"[RUN] Target (Y) = {T}")

            # 1) Pull rows with LEFT=Y and RIGHT=k (this is where we read T_{k,Y})
            #    Orientation: T_{k,Y} uses row (target_concept_code==Y, concept_code==k)  ← Y→k
            k_to_T = _fetch_k_to_T(conn, T)

            # DEDUP safeguard
            before = len(k_to_T)
            k_to_T = k_to_T.drop_duplicates()
            after = len(k_to_T)
            if after < before:
                print(f"[DEDUP] Y→k exact-row duplicates removed: {before - after}  (kept {after})")

            if k_to_T.empty:
                print(f"[WARN] No Y→k rows for {T}; skipping.")
                continue

            # 2) Add derived cells (incl. total_effect based on a,b,c,d with d computed on the fly)
            k_to_T = _ensure_derived_cols(k_to_T)

            # --- RF2 alias family used as the filtering seed: {T} ∪ parents ∪ descendants
            #     (include_parents=True); falls back to {T} when no aliases are found.
            aliases_codes, _nm = snomed_aliases_for_outcome(T, include_parents=True)
            alias_family = set(aliases_codes) if aliases_codes else {T}

            # 3) Decision-tree filters (Steps 1–3, + optional 4):
            #    • Expanded outcome family = {T} ∪ parents ∪ descendants ∪ Maps-to sources
            #      of any member (no descendants-of-sources expansion)
            #    • Drop k that are in the family, map to the family, OR whose DESCENDANTS
            #      intersect the family
            #    • Sibling test: if k maps to a family member in T's domain+vocab → drop
            #    • Else compute Jaccard(k,Y) & PPV(Y|k) from the Y–k row; drop if Jaccard≥0.88 or PPV≥0.90
            k_to_T = _apply_decision_tree_filters(
                conn, T, k_to_T,
                alias_family=alias_family,
                jaccard_drop=0.88,
                ppv_drop=0.90,
                disease_like_patterns=None,  # set patterns list to enable optional step 4
                verbose=True
            )
            if k_to_T.empty:
                print(f"[WARN] All predictors pruned by decision-tree for {T}; skipping.")
                continue

            # 4) Pick predictors by total_effect (RIGHT side = concept_code = k)
            sel_rows = _select_top_by_te_unique_k(k_to_T, top_k=TOP_K)
            if sel_rows.empty:
                print(f"[WARN] No predictors selected for {T} after filtering; skipping.")
                continue

            # Audit selected risk/protective list
            os.makedirs(OUT_DIR, exist_ok=True)
            audit_csv = os.path.join(OUT_DIR, f"risk_prot_{T}.csv")
            sel_rows.to_csv(audit_csv, index=False)
            print(f"[SAVED] Factors (post-filters, pre-MAGI) → {audit_csv}")

            # Whitelist of predictors (do not include T itself)
            selected_k = set(sel_rows["concept_code"].astype(str))
            selected_k.discard(T)

            print(f"[SELECT] unique k available={k_to_T['concept_code'].nunique():,}  "
                  f"selected={len(selected_k):,}")
            if not selected_k:
                print(f"[WARN] No predictors after removing target {T}; skipping.")
                continue

            # 5) Build subgraph for λ_{k,j}:
            #    • LEFT ∈ {T} ∪ selected_k  (targets are Y and all j)
            #    • RIGHT ∈ selected_k        (codes are predictors k only; exclude RIGHT=T)
            events_set = selected_k | {T}
            right_allowed = selected_k

            df_trim = _fetch_subgraph_by_targets(conn, sorted(events_set))
            df_trim = df_trim[df_trim["concept_code"].isin(right_allowed)].copy()

            # DEDUP again on trimmed subgraph
            before = len(df_trim)
            df_trim = df_trim.drop_duplicates()
            after = len(df_trim)
            if after < before:
                print(f"[DEDUP] subgraph exact-row duplicates removed: {before - after}  (kept {after})")

            # Ensure derived columns present/consistent
            df_trim = _ensure_derived_cols(df_trim)

            # Sanity counts by orientation
            # • Y→k rows present for TE
            y_to_k = int((df_trim["target_concept_code"] == T).sum())
            # • j→k rows present for λ (targets not T, rights ∈ selected_k)
            j_to_k = int(((df_trim["target_concept_code"] != T) &
                          (df_trim["concept_code"].isin(right_allowed))).sum())
            print(f"[TRIM] rows={len(df_trim):,} events={len(events_set)}  Y→k={y_to_k}  j→k={j_to_k}")

            # Save subgraph (audit)
            sub_csv = os.path.join(OUT_DIR, f"magi_subgraph_{T}.csv")
            df_trim.to_csv(sub_csv, index=False)
            print(f"[SAVED] Subgraph → {sub_csv}")

            # 6) Run MAGI with outcome forced to T
            #    Internals:
            #      • T_{k,Y} read from row (target=Y, code=k)         ← our Y→k rows
            #      • λ_{k,j} read from row (target=j, code=k)         ← our j→k rows
            try:
                res = analyze_causal_sequence_py(
                    df_trim,
                    events=None,     # auto-detect from df_trim
                    name_map=None,   # don't canonicalize aliases here
                    force_outcome=T
                )
            except Exception as e:
                print("[ERROR] MAGI failed:", e)
                continue

            # 7) Save NON-ZERO coefficients, restricted to whitelist (selected_k) + intercept
            # Guard against a missing OR empty order (indexing [-1] on [] would raise).
            order_used = res.get("order_used")
            if order_used is None or len(order_used) == 0:
                order_used = [T]
            order_used = list(order_used)
            outcome_used = order_used[-1]
            coef_df = res["coef_df"].copy()

            # accept 'coef' or 'beta'
            coef_col = "coef" if "coef" in coef_df.columns else ("beta" if "beta" in coef_df.columns else None)
            if coef_col is None:
                raise KeyError("Coefficient column not found in coef_df (expected 'coef' or 'beta').")

            eps = 1e-12
            coef_df[coef_col] = pd.to_numeric(coef_df[coef_col], errors="coerce").fillna(0.0)
            coef_nz = coef_df[coef_df[coef_col].abs() > eps].copy()

            # predictor column (usually 'predictor')
            pred_col = "predictor" if "predictor" in coef_nz.columns else None
            if pred_col is None:
                # fallback: first non-coef column
                maybe = [c for c in coef_nz.columns if c != coef_col]
                pred_col = maybe[0] if maybe else None

            # whitelist to selected_k (plus intercept); final guard: drop any alias-family codes if present
            if pred_col:
                pred_str = coef_nz[pred_col].astype(str)
                # Case-insensitive so an R-style "(Intercept)" label is not dropped.
                is_intercept = pred_str.str.strip().str.lower() == "(intercept)"
                in_whitelist = pred_str.isin(selected_k)
                is_alias = pred_str.isin(alias_family)
                keep_mask = is_intercept | (in_whitelist & ~is_alias)
                dropped = int((~keep_mask).sum())
                if dropped:
                    print(f"[FILTER] Dropping {dropped} non-whitelisted or alias-family predictor(s).")
                coef_nz = coef_nz.loc[keep_mask]

            # Save
            coef_dir = os.path.join(OUT_DIR, "Coef_Yk")
            os.makedirs(coef_dir, exist_ok=True)
            coef_csv = os.path.join(coef_dir, f"magi_coef_{outcome_used}_nonzero.csv")
            coef_nz.to_csv(coef_csv, index=False)
            print(f"[SAVED] Coefficients (non-zero only) → {coef_csv}  "
                  f"| kept={len(coef_nz):,} of {len(coef_df):,}  "
                  f"| nodes={len(order_used) - 1}")
    finally:
        conn.close()


