# In [None]:
import os, sqlite3, math, sys, time, re
from typing import Dict, Any, List, Optional, Iterable, Tuple, Set, Union
import pandas as pd
import numpy as np
from collections import defaultdict, deque

# ===== CONFIG =====
# Paths can be overridden through the environment so the same script can run
# on another machine without editing the source.
MAGI_DB_PATH = os.getenv("MAGI_DB_PATH", "/projects/klybarge/pcori_ad/magi/magi_db/magiv2.db")
OUT_DIR = os.getenv("MAGI_OUT_DIR", "./BreastCancer")   # every CSV output is written below this dir
os.makedirs(OUT_DIR, exist_ok=True)
EDGE_TABLE = "magi_counts_published"   # pairwise count table queried by the _fetch_* helpers
uri = f"file:{MAGI_DB_PATH}?mode=ro"   # read-only SQLite URI for MAGI_DB_PATH
TOP_K = 500   # default number of predictors kept per target (_select_top_by_te_unique_k)

# Predefined outcome codes to analyze (concept_code strings as stored in concept_names).
TARGETS = [
    "dx_SNOMED_254837009",
]

def connect_in_read_only_mode(db_path):
    """
    Open a SQLite database in read-only mode through a SQLite URI.

    Read-only mode is the safest way for several users/processes to read the
    same database file concurrently: nothing can be written through it.

    Parameters
    ----------
    db_path : str or os.PathLike
        Path to the SQLite database file.

    Returns
    -------
    sqlite3.Connection or None
        An open read-only connection, or ``None`` if the file does not exist or
        SQLite refuses to open it (the reason is printed to stderr).
    """
    from pathlib import Path  # local import: only needed to build the URI safely

    if not os.path.exists(db_path):
        print(f"❌ ERROR: Database file not found at: {db_path}", file=sys.stderr)
        return None

    try:
        # Path.as_uri() percent-encodes characters such as '?', '#', '%' and
        # spaces. Interpolating the raw path into "file:...?mode=ro" lets SQLite
        # parse them as URI syntax: a '#' turns the rest (including '?mode=ro')
        # into a fragment, so a different file would be opened/created read-write.
        # as_uri() needs an absolute path, hence resolve().
        db_uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
        print(f"Connecting in read-only mode to: {db_uri}")

        # uri=True makes sqlite3 honour the 'file:' scheme and '?mode=ro'.
        conn = sqlite3.connect(db_uri, uri=True)
    except (sqlite3.Error, OSError, ValueError) as e:
        print(f"❌ An unexpected error occurred: {e}", file=sys.stderr)
        return None

    print("✅ Connection successful.")
    return conn

def load_code_maps(conn):
    """
    Build the two lookup dictionaries between string codes and integer ids.

    Reads ``concept_names`` (columns ``concept_code_int`` and ``concept_code``)
    and returns ``(code2int, int2code)``:
      code2int : concept_code (e.g. 'dx_SNOMED_254645002') -> concept_code_int,
                 the INTEGER id used in magi_counts_published;
      int2code : the reverse mapping.
    If a key appears more than once, the last row read wins.
    """
    mapping = pd.read_sql_query(
        "SELECT concept_code_int, concept_code FROM concept_names",
        conn,
    )
    codes = mapping["concept_code"]
    ids = mapping["concept_code_int"]
    forward = {code: cid for code, cid in zip(codes, ids)}
    backward = {cid: code for cid, code in zip(ids, codes)}
    return forward, backward

# ===== SNOMED parents+descendants support =====
from collections import defaultdict, deque

# SNOMED CT conceptId of the "Is a" relationship type (RF2 typeId).
IS_A = "116680003"
# RF2 characteristicTypeId values selectable via build_is_a_snapshot(characteristic=...).
CHAR_TYPES = {
    "inferred": "900000000000011006",
    "stated":   "900000000000010007",
}

# Full RF2 relationship release file; can be overridden from the environment.
SNOMED_REL_FULL_US = os.getenv(
    "SNOMED_REL_FULL_US",
    "/projects/klybarge/pcori_ad/magi/Test/Test/RareCancer/sct2_Relationship_Full_US1000124_20250901.txt",
)
# Lazily-built caches, filled on first use by _ensure_graph():
#   __SNAP_REL__ : DataFrame of active IS-A rows (sourceId = child, destinationId = parent)
#   __P2C__      : parent conceptId -> set of child conceptIds
#   __C2P__      : child conceptId  -> set of parent conceptIds
__SNAP_REL__ = None
__P2C__, __C2P__ = None, None

def build_is_a_snapshot(full_rel_path: str, characteristic: str = "inferred") -> pd.DataFrame:
    """
    Derive the *current* active IS-A edges from a Full RF2 relationship file.

    A Full RF2 file keeps every historical version of each relationship (one row
    per ``id`` and ``effectiveTime``). This keeps the IS-A rows of the requested
    characteristic type, takes the latest version of every relationship ``id`` and
    returns those whose latest version is active (``active == "1"``).

    Parameters
    ----------
    full_rel_path : str
        Path to a tab-separated ``sct2_Relationship_Full_*.txt`` file.
    characteristic : str
        Key of ``CHAR_TYPES`` ('inferred' or 'stated').

    Returns
    -------
    pd.DataFrame
        Columns ``sourceId`` (child) and ``destinationId`` (parent) as strings,
        with a fresh RangeIndex.

    Raises
    ------
    KeyError
        If ``characteristic`` is not a key of ``CHAR_TYPES``.
    """
    # Resolve first so an invalid characteristic fails before reading a large file.
    char_type_id = CHAR_TYPES[characteristic]

    # Only the columns actually used below are loaded.
    use_cols = [
        "id", "effectiveTime", "active",
        "sourceId", "destinationId",
        "typeId", "characteristicTypeId",
    ]
    df = pd.read_csv(full_rel_path, sep="\t", dtype=str, usecols=use_cols)

    # .copy(): the next line adds a column, which must not target a view of the
    # unfiltered frame (SettingWithCopy / chained-assignment problem).
    df = df[(df["typeId"] == IS_A) & (df["characteristicTypeId"] == char_type_id)].copy()
    df["effectiveTime_num"] = df["effectiveTime"].astype(int)

    # Latest version of each relationship id decides whether it is active now.
    idx = df.groupby("id")["effectiveTime_num"].idxmax()
    snap = df.loc[idx]
    snap = snap[snap["active"] == "1"][["sourceId", "destinationId"]].reset_index(drop=True)
    return snap

# Build both directions once
def _ensure_graph():
    """
    Lazily populate the module-level SNOMED IS-A caches.

    On first use, loads the inferred IS-A snapshot from ``SNOMED_REL_FULL_US``
    into ``__SNAP_REL__`` and derives two adjacency maps from it:
    ``__P2C__`` (parent -> set of children) and ``__C2P__`` (child -> set of
    parents). Once both maps exist, further calls return immediately.
    """
    global __SNAP_REL__, __P2C__, __C2P__

    if __SNAP_REL__ is None:
        __SNAP_REL__ = build_is_a_snapshot(SNOMED_REL_FULL_US, characteristic="inferred")

    if __P2C__ is not None and __C2P__ is not None:
        return

    parent_to_children = defaultdict(set)
    child_to_parents = defaultdict(set)
    edges = zip(__SNAP_REL__["sourceId"], __SNAP_REL__["destinationId"])
    for child_id, parent_id in edges:
        parent_to_children[parent_id].add(child_id)
        child_to_parents[child_id].add(parent_id)
    __P2C__, __C2P__ = parent_to_children, child_to_parents

def choose_umbrella_ancestor(root_id: str) -> str:
    """
    Pick an 'umbrella' ancestor for a SNOMED conceptId (e.g. '254645002').

      - no direct parents      -> ``root_id`` itself
      - exactly one parent     -> that parent
      - several direct parents -> the lowest common ancestor of all of them:
          * walk upward (BFS) from each direct parent, recording every
            ancestor's distance from that parent (the parent itself is 0);
          * keep only concepts reachable from every parent;
          * rank them by (max distance, sum of distances, conceptId) and return
            the smallest, i.e. the most specific shared ancestor, with a
            deterministic tie-break.
        If the parents share no ancestor, the smallest parent id (string
        order) is returned.

    Preferring the concept closest to the parents avoids generic concepts
    high up in the hierarchy.
    """
    _ensure_graph()

    parents = set(__C2P__.get(root_id, ()))
    if len(parents) == 0:
        return root_id
    if len(parents) == 1:
        (only_parent,) = parents
        return only_parent

    def _upward_distances(start):
        # BFS over child->parent links; the first visit is the shortest distance.
        seen = {start: 0}
        frontier = deque([start])
        while frontier:
            node = frontier.popleft()
            for up in __C2P__.get(node, ()):
                if up in seen:
                    continue
                seen[up] = seen[node] + 1
                frontier.append(up)
        return seen

    per_parent = [_upward_distances(p) for p in parents]
    shared = set(per_parent[0]).intersection(*per_parent[1:])
    if not shared:
        return min(parents)

    def _rank(cid):
        # every candidate is in every distance map (it comes from the intersection)
        distances = [dm[cid] for dm in per_parent]
        return (max(distances), sum(distances), cid)

    return min(shared, key=_rank)

def find_descendants_sct(concept_id: str) -> set:
    """
    Return every transitive IS-A descendant of ``concept_id``.

    Breadth-first walk over the parent->children map. The starting concept is
    not part of the result unless the hierarchy loops back to it.
    """
    _ensure_graph()
    found = set()
    pending = deque([concept_id])
    while pending:
        node = pending.popleft()
        new_kids = [kid for kid in __P2C__.get(node, ()) if kid not in found]
        found.update(new_kids)
        pending.extend(new_kids)
    return found

def find_parents_sct(concept_id: str) -> set:
    """Return a new set with the direct IS-A parents of ``concept_id`` (empty if none)."""
    _ensure_graph()
    parents = __C2P__.get(concept_id)
    return set(parents) if parents else set()

# ===== MAP PREDEFINED TARGET → PARENT TARGET + DESCENDANT BLOCKLIST =====
def extract_snomed_id(dx_code: str) -> Optional[str]:
    """
    Strip the 'dx_SNOMED_' prefix from a diagnosis code.

    >>> extract_snomed_id('dx_SNOMED_254645002')
    '254645002'

    Codes without that prefix (i.e. non-SNOMED dx codes) yield None.
    """
    marker = "dx_SNOMED_"
    if not dx_code.startswith(marker):
        return None
    return dx_code[len(marker):]

def _make_dx_snomed(concept_id: str) -> str:
    """Prefix a bare SNOMED conceptId: '254645002' -> 'dx_SNOMED_254645002'."""
    return "dx_SNOMED_{}".format(concept_id)

def parent_target_and_blocklist_for_T(T: str) -> Tuple[str, Set[str]]:
    """
    Resolve the outcome actually modelled for target ``T`` and its predictor blocklist.

    SNOMED targets ('dx_SNOMED_<id>'):
      * outcome   = choose_umbrella_ancestor(<id>) turned back into a
                    'dx_SNOMED_...' code (T itself when it has no parents, the
                    parent when it has one, otherwise a common ancestor of the
                    direct parents);
      * blocklist = T + siblings of T (other children of T's direct parents)
                    + every IS-A descendant of T.
      Parents / the umbrella ancestor are not added to the blocklist explicitly,
      so they remain usable as predictors.
    Non-SNOMED targets: outcome = T and an empty blocklist.

    Returns (effective_outcome_code, predictor_blocklist_set_of_codes).
    """
    snomed_id = extract_snomed_id(T)
    if snomed_id is None:
        return T, set()

    outcome_code = _make_dx_snomed(choose_umbrella_ancestor(snomed_id))

    # Make sure the parent/child maps exist before reading them directly.
    _ensure_graph()

    # Siblings: every child of every direct parent of T, except T.
    siblings: Set[str] = set()
    for parent in __C2P__.get(snomed_id, ()):
        siblings.update(__P2C__.get(parent, ()))
    siblings.discard(snomed_id)

    blocked_ids = siblings | find_descendants_sct(snomed_id)
    blocked_ids.add(snomed_id)

    blocklist: Set[str] = {_make_dx_snomed(cid) for cid in blocked_ids}
    # Normally a no-op (T == 'dx_SNOMED_' + snomed_id); keeps the exact input spelling.
    blocklist.add(T)
    return outcome_code, blocklist

def snomed_aliases_for_outcome(outcome_code: str, include_parents: bool = True) -> Tuple[Optional[Set[str]], Dict[str, str]]:
    """
    Collect the SNOMED alias family of an outcome code.

    For 'dx_SNOMED_<id>' returns ``(aliases_codes, name_map)``:
      aliases_codes : 'dx_SNOMED_<x>' for <id>, every IS-A descendant of <id>
                      and, when ``include_parents`` is true, its direct parents;
      name_map      : alias_code -> outcome_code (canonical) for every alias
                      except ``outcome_code`` itself.
    Non-SNOMED codes return ``(None, {})``.
    """
    root_id = extract_snomed_id(outcome_code)
    if root_id is None:
        return None, {}

    family = set(find_descendants_sct(root_id))
    family.add(root_id)
    if include_parents:
        family.update(find_parents_sct(root_id))

    aliases_codes: Set[str] = set()
    for sid in family:
        aliases_codes.add(f"dx_SNOMED_{sid}")

    name_map: Dict[str, str] = {}
    for alias in aliases_codes:
        if alias != outcome_code:
            name_map[alias] = outcome_code
    return aliases_codes, name_map


# ======================================================================
# MAGI core: analyze_causal_sequence_py (INT-BASED)
# ======================================================================

# RIGHT = k (predictor) and LEFT = Y/j
def analyze_causal_sequence_py(
    df_in: Union[str, pd.DataFrame],
    *,
    name_map: Optional[Dict[str, str]] = None,     # accepted for compatibility but IGNORED in the _int version
    events: Optional[List[int]] = None,            # event IDs to KEEP; if None: auto-detect from *_int cols
    force_outcome: Optional[int] = None,           # if provided and present in events, force this to be the FINAL node (Y)
    lambda_min_count: int = 15           # L-threshold for λ: if n_code < L ⇒ λ_{k,j}=0
) -> Dict[str, Any]:
    """
    MAGI (Python, INT-BASED): per-predictor logistic coefficients for one outcome,
    computed from pairwise count rows keyed by target_concept_code_int (LEFT = Y/j)
    and concept_code_int (RIGHT = k, the predictor). T uses DB total_effects only
    (no 2×2 fallback).

    Parameters
    ----------
    df_in : str or pd.DataFrame
        CSV path or DataFrame of count rows. The columns listed in ``need_cols``
        below are required; ``total_effects`` is optional.
    name_map : ignored (kept so older call sites still work).
    events : list of int, optional
        Event ids to keep; rows are kept only if BOTH ids are in this list.
        Default: every id seen in either *_int column.
    force_outcome : int, optional
        Used as the final node Y when it is one of ``events``; otherwise the
        highest-scoring event becomes Y.
    lambda_min_count : int
        λ_{k,j} is 0 when the summed n_code for (target=j, code=k) is below this.

    Computation
    -----------
      • Temporal score, per event Zi, from the rows with target = Zi:
            score(Zi) = Σ_codes [ n_target_before_code - n_code_before_target ]
        (the sum runs over every code row present for Zi, including a (Zi, Zi)
        row if one exists). Events are ordered by descending score and Y is
        moved to the end; every other event is a predictor node.
      • T_{kY}: mean of finite total_effects over rows (target=Y, code=k);
            1.0 if the column is absent, there are no such rows / no finite
            values, or the mean is ≤ 0.
      • λ_{k,j} = n_code_target / n_code from the summed (target=j, code=k)
            counts, for every j after k and before Y in the order; 0 if the pair
            is missing, n_code ≤ 0 or n_code < lambda_min_count, or the ratio is
            not finite; clipped to [0, 1].
      • D (backward recursion): D_last = T_last, then
            D_k = (T_k - Σ_j λ_{k,j} D_j) / (1 - Σ_j λ_{k,j}),
            falling back to T_k if the denominator is 0/non-finite or the result
            is non-finite.
      • β_k = log(D_k) (0 when D_k ≤ 0 or NaN);
        β_0 = logit(n_target / (n_target + n_no_target)), using the max of each
            count over rows with target = Y (p = 0.5 if unavailable), clipped to
            [1e-12, 1 - 1e-12] before the logit.

    Returns
    -------
    dict with keys: sorted_scores, temporal_order, order_used (same list),
    T_val, D_val, lambda_l, coef_df, beta_0, beta, logit_predictors,
    predict_proba (closure over β).

    Raises
    ------
    ValueError
        If required count columns are missing or fewer than two events are selected.
    """
    # ── 0) Ingest & basic checks ───────────────────────────────────────────────
    df = pd.read_csv(df_in) if isinstance(df_in, str) else df_in.copy()

    need_cols = [
        "target_concept_code_int", "concept_code_int",
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target",
        "n_target_no_code",
        "n_code",
        "n_code_before_target", "n_target_before_code",
    ]
    missing = [c for c in need_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required count columns for MAGI: {', '.join(missing)}")

    has_total_effects = "total_effects" in df.columns

    # Ensure *_int are numeric / nullable ints
    df["target_concept_code_int"] = pd.to_numeric(df["target_concept_code_int"], errors="coerce").astype("Int64")
    df["concept_code_int"]        = pd.to_numeric(df["concept_code_int"],        errors="coerce").astype("Int64")

    # name_map intentionally ignored in _int version

    # Limit to selected events
    if events is None:
        arr_t = df["target_concept_code_int"].dropna().unique()
        arr_c = df["concept_code_int"].dropna().unique()
        # IntegerArray -> list of Python ints
        ev_t = [int(x) for x in arr_t]
        ev_c = [int(x) for x in arr_c]
        events = sorted(set(ev_t) | set(ev_c))
    else:
        events = [int(e) for e in events]

    if len(events) < 2:
        raise ValueError("Need at least two events.")

    # Keep only rows whose LEFT and RIGHT ids are both selected events.
    df = df[
        df["target_concept_code_int"].isin(events)
        & df["concept_code_int"].isin(events)
    ].copy()

    # Coerce numerics
    num_cols = [
        "n_code_target", "n_code_no_target",
        "n_target", "n_no_target", "n_target_no_code",
        "n_code",
        "n_code_before_target", "n_target_before_code",
    ]
    if has_total_effects:
        num_cols.append("total_effects")

    for c in num_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # ── 1) Build an indexed edge table ────────────────────────────────────────
    # One row per (target, code) pair; duplicate rows for a pair are summed.
    edge = (
        df.groupby(["target_concept_code_int", "concept_code_int"], as_index=True)[
            ["n_target_before_code", "n_code_before_target",
             "n_code_target", "n_code"]
        ].sum()
    )

    # LEFT ids that have at least one row (others get a temporal score of 0).
    edge_targets = edge.index.get_level_values(0)

    # ── 2) Temporal scores ────────────────────────────────────────────────────
    scores: Dict[int, float] = {}
    for zi in events:
        if zi not in edge_targets:
            scores[zi] = 0.0
            continue

        try:
            sub = edge.xs(zi, level="target_concept_code_int")  # index = concept_code_int
        except KeyError:
            scores[zi] = 0.0
            continue

        s = float(
            (sub["n_target_before_code"].fillna(0.0) -
             sub["n_code_before_target"].fillna(0.0)).sum()
        )
        scores[zi] = s

    sorted_scores = pd.Series(scores).sort_values(ascending=False)

    # Choose outcome Y: force_outcome only if it is a scored event, else the top
    # score. Either way Y is moved to the END of the order.
    if (force_outcome is not None) and (force_outcome in sorted_scores.index):
        outcome = int(force_outcome)
        temporal_order = [ev for ev in sorted_scores.index if ev != outcome] + [outcome]
    else:
        outcome = int(sorted_scores.index[0])
        temporal_order = [ev for ev in sorted_scores.index if ev != outcome] + [outcome]

    events_order = temporal_order
    nodes = events_order[:-1]  # every event except Y is a predictor node

    pos_by_event = {ev: i for i, ev in enumerate(events_order)}

    # Pre-filter rows where LEFT == outcome for T_{kY}
    if has_total_effects:
        dfY = df[df["target_concept_code_int"] == outcome].copy()
        dfY["total_effects"] = pd.to_numeric(dfY["total_effects"], errors="coerce")
    else:
        dfY = None

    # ── 3) T and λ ─────────────────────────────────────────────────────────────
    T_val = pd.Series(0.0, index=nodes, dtype=float)
    D_val = pd.Series(np.nan, index=nodes, dtype=float)
    lambda_l: Dict[int, pd.Series] = {}

    for k in nodes:
        # T_{kY}
        if has_total_effects:
            row_Yk = dfY[dfY["concept_code_int"] == k]
            if row_Yk.empty:
                T_val.loc[k] = 1.0
            else:
                T_col = pd.to_numeric(row_Yk["total_effects"], errors="coerce")
                T_col = T_col.replace([np.inf, -np.inf], np.nan)
                T_clean = T_col.dropna()
                if T_clean.empty:
                    T_val.loc[k] = 1.0
                else:
                    T_db = float(T_clean.mean())
                    if (not np.isfinite(T_db)) or (T_db <= 0):
                        T_db = 1.0
                    T_val.loc[k] = T_db
        else:
            T_val.loc[k] = 1.0

        # λ_{k,j}
        pos_k = pos_by_event[k]
        # downstream j: nodes after k in the order, excluding Y (the last element)
        js = events_order[pos_k + 1 : -1] if pos_k < len(events_order) - 1 else []

        lam_pairs = {}
        for j in js:
            key = (j, k)
            if key not in edge.index:
                lam_pairs[j] = 0.0
                continue

            row_jk = edge.loc[key]
            num = float(row_jk["n_code_target"])
            den = float(row_jk["n_code"])

            # Too few co-occurrences → treat as no mediation (λ = 0).
            if (den <= 0) or (den < lambda_min_count):
                lam_pairs[j] = 0.0
                continue

            lam = num / den
            if not np.isfinite(lam):
                lam = 0.0
            lam_pairs[j] = float(min(max(lam, 0.0), 1.0))

        lambda_l[k] = pd.Series(lam_pairs, dtype=float)

    # ── 4) Backward recursion for D ────────────────────────────────────────────
    # The last predictor before Y has no downstream nodes, so D = T there.
    if len(nodes) >= 1:
        last_anc = nodes[-1]
        D_val.loc[last_anc] = T_val.loc[last_anc]

    if len(nodes) > 1:
        for k in list(reversed(nodes[:-1])):
            lam_vec = lambda_l.get(k, pd.Series(dtype=float))
            downstream = list(lam_vec.index)
            lam_vals = lam_vec.reindex(downstream).fillna(0.0).to_numpy()
            D_down  = pd.to_numeric(D_val.reindex(downstream),
                                    errors="coerce").fillna(0.0).to_numpy()

            num = T_val.loc[k] - float(np.nansum(lam_vals * D_down))
            den = 1.0 - float(np.nansum(lam_vals))

            # Degenerate denominator / result → fall back to the total effect.
            if (not np.isfinite(den)) or den == 0.0:
                D_val.loc[k] = T_val.loc[k]
            else:
                tmp = num / den
                D_val.loc[k] = tmp if np.isfinite(tmp) else T_val.loc[k]

    # ── 5) Logistic link (β) and predict_proba ─────────────────────────────────
    # Intercept = logit of the outcome prevalence read from rows with target == Y.
    resp_rows = df[df["target_concept_code_int"] == outcome]
    n_t = float(pd.to_numeric(resp_rows["n_target"],      errors="coerce").max()) if not resp_rows.empty else np.nan
    n_n = float(pd.to_numeric(resp_rows["n_no_target"],   errors="coerce").max()) if not resp_rows.empty else np.nan
    denom = n_t + n_n
    p_y = 0.5 if (not np.isfinite(denom) or denom <= 0) else (n_t / denom)
    p_y = min(max(p_y, 1e-12), 1 - 1e-12)
    beta_0 = float(np.log(p_y / (1 - p_y)))

    D_clean = pd.to_numeric(D_val, errors="coerce").astype(float)
    # β_k = log(D_k); non-positive, NaN or infinite results become 0.
    beta_vals = np.log(D_clean.where(D_clean > 0.0)) \
                    .replace([np.inf, -np.inf], np.nan).fillna(0.0)

    coef_df = pd.DataFrame({
        "predictor": list(beta_vals.index) + ["(intercept)"],
        "beta":      list(beta_vals.values) + [beta_0],
    })

    predictors = list(beta_vals.index)
    beta_vec = beta_vals.values

    def predict_proba(Z):
        """
        Compute P(Y=1|Z) using: logit P = β0 + Σ_k β_k Z_k.

        Z may be a DataFrame (columns = concept_code_int ids, missing → 0),
        a dict/Series keyed by concept_code_int ids (missing → 0), or a 1-D/2-D
        array whose columns follow ``predictors`` order exactly.
        """
        def sigmoid(x):
            # clip to avoid overflow in exp
            x = np.clip(x, -700, 700)
            return 1.0 / (1.0 + np.exp(-x))

        if isinstance(Z, pd.DataFrame):
            M = Z.reindex(columns=predictors, fill_value=0.0) \
                 .astype(float).to_numpy()
            return sigmoid(beta_0 + M @ beta_vec)

        if isinstance(Z, (dict, pd.Series)):
            v = np.array([float(Z.get(p, 0.0)) for p in predictors], dtype=float)
            return float(sigmoid(beta_0 + float(v @ beta_vec)))

        arr = np.asarray(Z, dtype=float)
        if arr.ndim == 1:
            if arr.size != len(predictors):
                raise ValueError(f"Expected {len(predictors)} features in order: {predictors}")
            return float(sigmoid(beta_0 + float(arr @ beta_vec)))
        if arr.ndim == 2:
            if arr.shape[1] != len(predictors):
                raise ValueError(f"Expected shape (*,{len(predictors)}), got {arr.shape}")
            return sigmoid(beta_0 + arr @ beta_vec)

        raise ValueError("Unsupported input for predict_proba")

    return {
        "sorted_scores": sorted_scores,
        "temporal_order": events_order,
        "order_used": events_order,
        "T_val": T_val,
        "D_val": D_val,
        "lambda_l": lambda_l,
        "coef_df": coef_df,
        "beta_0": beta_0,
        "beta": pd.Series(beta_vec, index=predictors, dtype=float),
        "logit_predictors": predictors,
        "predict_proba": predict_proba,
    }
    
def _ensure_derived_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with the MAGI count columns coerced to numeric dtypes.

    Intended for rows read from magi_counts_published, where
    n_no_code_no_target and total_effects are already stored; nothing is
    recomputed and there is no fallback. Unparseable values become NaN.

    Count columns that are absent are skipped instead of raising, so frames
    without optional columns (e.g. no ``total_effects``, which
    analyze_causal_sequence_py explicitly supports) can still be prepared; the
    consumers validate the columns they actually need.

    The id columns ``target_concept_code_int`` and ``concept_code_int`` are
    required and become nullable ``Int64``.

    Raises
    ------
    KeyError
        If either *_int id column is missing.
    """
    out = df.copy()

    num_cols = [
        "n_code_target",
        "n_code_no_target",
        "n_target_no_code",
        "n_target",
        "n_no_target",
        "n_no_code_no_target",
        "n_code",
        "n_code_before_target",
        "n_target_before_code",
        "total_effects",
    ]
    for col in num_cols:
        if col in out.columns:
            out[col] = pd.to_numeric(out[col], errors="coerce")

    # Ensure *_int are numeric (nullable so unparseable ids become <NA>, not errors)
    for col in ("target_concept_code_int", "concept_code_int"):
        out[col] = pd.to_numeric(out[col], errors="coerce").astype("Int64")

    return out

# TOPK_HI/LO
def _select_top_by_te_unique_k(k_to_T: pd.DataFrame, top_k: Optional[int] = None) -> pd.DataFrame:
    """
    Select the ``top_k`` best-ranked predictors (RIGHT side, concept_code_int).

    Uses the DB-provided ``total_effects_rank`` (dense rank on total_effects_norm
    partitioned by target_concept_code_int; smaller = better). When ``k_to_T``
    spans several targets (an alias family) a predictor can appear several
    times; only its best-ranked row is kept. Ties, which are frequent with dense
    ranks across several targets, are broken by ``concept_code_int`` so the
    selection is reproducible.

    Parameters
    ----------
    k_to_T : pd.DataFrame
        Rows with at least ``concept_code_int`` and ``total_effects_rank``.
    top_k : int, optional
        Number of predictors to keep (negative values keep none); defaults to
        the module-level ``TOP_K``.

    Returns
    -------
    pd.DataFrame
        One row per selected predictor, best rank first, with a fresh
        RangeIndex; an empty frame if no row has a valid rank.

    Raises
    ------
    ValueError
        If ``total_effects_rank`` is not a column of ``k_to_T``.
    """
    if top_k is None:
        top_k = TOP_K

    if "total_effects_rank" not in k_to_T.columns:
        raise ValueError("Expected 'total_effects_rank' in k_to_T for rank-based selection.")

    df = k_to_T.copy()
    df["total_effects_rank"] = pd.to_numeric(df["total_effects_rank"], errors="coerce")
    df = df.dropna(subset=["total_effects_rank"])
    if df.empty:
        print("[SELECT] no rows with valid total_effects_rank; returning empty selection.")
        return df

    # One row per k_int (RIGHT side). Sorting on (rank, k) makes both the kept
    # duplicate row and the order among tied predictors deterministic, so the
    # head() below is the top-k by rank with a reproducible tie-break.
    best_per_k = (
        df.sort_values(["total_effects_rank", "concept_code_int"])
          .drop_duplicates(subset=["concept_code_int"], keep="first")
    )

    # head() with a negative n would return "all but the last n" -> clamp at 0.
    selected = best_per_k.head(max(int(top_k), 0)).copy()

    print(
        f"[SELECT] (rank-based) total unique k={best_per_k['concept_code_int'].nunique():,}  "
        f"selected(total)={selected['concept_code_int'].nunique()}  top_k={top_k}"
    )

    return selected.reset_index(drop=True)

def _fetch_k_to_T(conn, outcome_int: int) -> pd.DataFrame:
    """
    Return every EDGE_TABLE row whose LEFT side (target_concept_code_int)
    equals ``outcome_int``; these single (Y, k) rows are the inputs for T_{kY}.
    The id is bound as a query parameter (converted with int()).
    """
    target_id = int(outcome_int)
    sql = f"""
      SELECT *
      FROM {EDGE_TABLE}
      WHERE target_concept_code_int = ?
    """
    frame = pd.read_sql_query(sql, conn, params=[target_id])
    return frame

def _fetch_k_to_T_in(conn, target_ints: List[int], *, chunk_size: int = 900) -> pd.DataFrame:
    """
    Fetch EDGE_TABLE rows whose LEFT side (target_concept_code_int) is in ``target_ints``.

    Used to include alias outcomes (descendants/parents of T) on the LEFT.

    Ids are bound as SQL parameters in batches of at most ``chunk_size``:
    SQLite caps the number of host parameters per statement
    (SQLITE_MAX_VARIABLE_NUMBER, 999 before SQLite 3.32), and a SNOMED alias
    family can be larger than that. Duplicate ids are dropped first so no row
    is returned twice across batches.

    Returns an empty DataFrame when ``target_ints`` is empty.

    Raises
    ------
    ValueError
        If ``chunk_size`` < 1.
    """
    if not target_ints:
        return pd.DataFrame()
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    ids = list(dict.fromkeys(int(x) for x in target_ints))  # de-dup, keep order
    frames = []
    for start in range(0, len(ids), chunk_size):
        batch = ids[start:start + chunk_size]
        placeholders = ",".join("?" * len(batch))
        q = (
            f"SELECT * FROM {EDGE_TABLE} "
            f"WHERE target_concept_code_int IN ({placeholders})"
        )
        frames.append(pd.read_sql_query(q, conn, params=batch))

    if len(frames) == 1:
        return frames[0]
    return pd.concat(frames, ignore_index=True)

def _fetch_subgraph_by_targets(conn, events_list_int: List[int], *, chunk_size: int = 900) -> pd.DataFrame:
    """
    Fetch the induced subgraph for the event set on the LEFT side (INT codes).

    This captures all rows (X, k) where X ∈ events_list_int (target side),
    regardless of what k is; callers filter k afterwards.

    Ids are bound as SQL parameters in batches of at most ``chunk_size`` to
    stay below SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER, 999
    before SQLite 3.32). Duplicate ids are dropped first so no row is returned
    twice across batches.

    Returns an empty DataFrame when ``events_list_int`` is empty.

    Raises
    ------
    ValueError
        If ``chunk_size`` < 1.
    """
    if not events_list_int:
        return pd.DataFrame()
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    ids = list(dict.fromkeys(int(e) for e in events_list_int))  # de-dup, keep order
    frames = []
    for start in range(0, len(ids), chunk_size):
        batch = ids[start:start + chunk_size]
        placeholders = ",".join("?" * len(batch))
        q = (
            f"SELECT * FROM {EDGE_TABLE} "
            f"WHERE target_concept_code_int IN ({placeholders})"
        )
        frames.append(pd.read_sql_query(q, conn, params=batch))

    if len(frames) == 1:
        return frames[0]
    return pd.concat(frames, ignore_index=True)

def expand_targets_with_descendants(targets: List[str]) -> Dict[str, Set[str]]:
    """
    Map every target code to its alias family.

    SNOMED targets get the alias set from snomed_aliases_for_outcome (with its
    default include_parents=True); any other code maps to the singleton ``{code}``.
    """
    expanded: Dict[str, Set[str]] = {}
    for target in targets:
        family = snomed_aliases_for_outcome(target)[0]
        expanded[target] = {target} if family is None else family
    return expanded

def print_full_targets(targets: List[str], *, save_csv: bool = True, out_dir: str = "./"):
    """
    Print a short summary of every expanded target and optionally dump all aliases to CSV.

    For each target: the alias count and the first 10 aliases in sorted order.
    With ``save_csv`` a file ``targets_expanded_<YYYYmmdd-HHMMSS>.csv`` is
    written in ``out_dir`` (created if needed) with one row per
    (root_target, alias) and the columns
    root_target, alias_code, alias_conceptId ('' for non-SNOMED aliases), is_root.

    Returns the mapping produced by expand_targets_with_descendants.
    """
    expanded = expand_targets_with_descendants(targets)

    # console summary
    for root, aliases in expanded.items():
        ordered = sorted(aliases)
        count = len(ordered)
        suffix = " ..." if count > 10 else ""
        print("=" * 100)
        print(f"[TARGET] {root}")
        print(f"  aliases (including descendants): {count}")
        print(f"  sample: {', '.join(ordered[:10])}{suffix}")

    if not save_csv:
        return expanded

    columns = ["root_target", "alias_code", "alias_conceptId", "is_root"]
    records = [
        {
            "root_target": root,
            "alias_code": alias,
            "alias_conceptId": extract_snomed_id(alias) or "",
            "is_root": (alias == root),
        }
        for root, aliases in expanded.items()
        for alias in sorted(aliases)
    ]
    table = pd.DataFrame(records, columns=columns)

    stamp = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, f"targets_expanded_{stamp}.csv")
    table.to_csv(csv_path, index=False)
    print(f"\n[SAVED] Full target expansions → {csv_path}")

    return expanded

# ========= main loop (remove dup) =========
if __name__ == '__main__':
    # Run MAGI for every predefined target:
    #   promote SNOMED targets to an umbrella outcome -> fetch k→T rows ->
    #   drop blocklisted predictors -> rank / keep TOP_K -> fetch subgraph ->
    #   analyze_causal_sequence_py -> save non-zero coefficients as CSV.
    import contextlib            # closing(): sqlite3's own `with` does not close
    from pathlib import Path     # safe file: URI construction

    if not os.path.exists(MAGI_DB_PATH):
        raise FileNotFoundError(
            f"MAGI_DB_PATH not found: {MAGI_DB_PATH}. "
            f"Set MAGI_DB_PATH to a valid SQLite DB or run with CSV input."
        )
    # Read-only URI. Path.as_uri() percent-encodes '?', '#', '%', spaces, ... that
    # SQLite would otherwise parse as URI syntax (a '#' would even drop '?mode=ro').
    uri = f"{Path(MAGI_DB_PATH).resolve().as_uri()}?mode=ro"

    # ===== BUILD PARENT/DESCENDANT CONFIG FOR ALL TARGETS (STRING LEVEL) =====
    PARENT_CONFIG = {}
    for T in TARGETS:
        parent_outcome, desc_blocklist = parent_target_and_blocklist_for_T(T)
        PARENT_CONFIG[T] = {
            "parent_outcome": parent_outcome,   # string dx_SNOMED_...
            "desc_blocklist": desc_blocklist,   # set of string dx_SNOMED_... (T, siblings, descendants)
        }

    print("\n[SNOMED parent promotion]")
    for T in TARGETS:
        cfg = PARENT_CONFIG[T]
        if cfg["parent_outcome"] != T:
            print(
                f"  {T} → parent outcome {cfg['parent_outcome']}, "
                f"desc_blocklist_size={len(cfg['desc_blocklist'])}"
            )
        else:
            print(
                f"  {T} (non-SNOMED or no parents) kept as outcome; "
                f"desc_blocklist_size={len(cfg['desc_blocklist'])}"
            )

    # sqlite3.Connection's context manager only commits/rolls back; closing()
    # guarantees the connection is really closed when the loop ends or fails.
    with contextlib.closing(sqlite3.connect(uri, uri=True)) as conn:
        # load mapping tables once
        code2int, int2code = load_code_maps(conn)

        # expanded targets: alias families (string level)
        expanded_targets = print_full_targets(TARGETS, save_csv=True, out_dir=OUT_DIR)

        for T in TARGETS:
            print("\n" + "=" * 100)
            print(f"[RUN] Target = {T}")

            cfg = PARENT_CONFIG.get(T, {"parent_outcome": T, "desc_blocklist": set()})
            outcome_T_str = cfg["parent_outcome"]          # string
            desc_blocklist_str = cfg["desc_blocklist"]     # set[str]

            # --- map outcome and blocklist to INTs ---
            if outcome_T_str not in code2int:
                print(f"[WARN] Outcome {outcome_T_str} not found in concept_names; skipping.")
                continue
            outcome_T_int = int(code2int[outcome_T_str])

            desc_blocklist_int = {
                int(code2int[c]) for c in desc_blocklist_str if c in code2int
            }

            print(
                f"[TARGET CONFIG] outcome_str = {outcome_T_str}, outcome_int = {outcome_T_int}, "
                f"desc_blocklist_size_str = {len(desc_blocklist_str)}, "
                f"desc_blocklist_size_int = {len(desc_blocklist_int)}"
            )

            # 0) SNOMED alias family (string level) -> INT ids known to concept_names
            aliases_codes_str = expanded_targets.get(T, set())
            alias_ints = {int(code2int[a]) for a in aliases_codes_str if a in code2int}

            # 1) k→T rows: allow target_concept_code_int ∈ alias_ints (if SNOMED), else outcome_T_int only
            if alias_ints:
                k_to_T = _fetch_k_to_T_in(conn, sorted(alias_ints))
            else:
                k_to_T = _fetch_k_to_T(conn, outcome_T_int)

            if k_to_T.empty:
                print(f"[WARN] No k→T rows for {T}; skipping.")
                continue

            # derive / coerce
            k_to_T = _ensure_derived_cols(k_to_T)
            n_k_available = k_to_T["concept_code_int"].nunique()

            # 2) drop blocklisted predictors (T, its siblings, its descendants)
            #    BEFORE ranking, so the top-K slots go to admissible predictors
            #    instead of being lost to codes that would be discarded afterwards.
            removed_sel = 0
            if desc_blocklist_int:
                blocked = k_to_T["concept_code_int"].isin(desc_blocklist_int)
                removed_sel = int(k_to_T.loc[blocked, "concept_code_int"].nunique())
                k_to_T = k_to_T.loc[~blocked].copy()
                print(
                    f"[FILTER] Removed {removed_sel} predictors in descendant "
                    f"blocklist of {T}"
                )

            if k_to_T.empty:
                print(
                    f"[WARN] All candidate predictors for {T} "
                    f"are in descendant blocklist; skipping."
                )
                continue

            # 3) select predictors: top-K by TE / rank (one row per k_int)
            sel_rows = _select_top_by_te_unique_k(k_to_T, top_k=TOP_K)
            if sel_rows.empty:
                print(f"[WARN] No predictors selected for {T}; skipping.")
                continue

            selected_k = {int(x) for x in sel_rows["concept_code_int"].dropna().unique()}

            print(
                f"[SELECT] unique k available={n_k_available:,}  "
                f"selected={len(selected_k):,}"
            )
            if not selected_k:
                print(f"[WARN] No predictors selected for {T}; skipping.")
                continue

            # 4) build subgraph among {outcome_T_int} ∪ selected_k
            events_set_int = selected_k | {outcome_T_int}

            df_trim = _fetch_subgraph_by_targets(conn, sorted(events_set_int))
            df_trim = df_trim[
                df_trim["target_concept_code_int"].isin(events_set_int)
                & df_trim["concept_code_int"].isin(selected_k)
            ].copy()

            before = len(df_trim)
            df_trim = df_trim.drop_duplicates()
            after = len(df_trim)
            if after < before:
                print(
                    f"[DEDUP] subgraph exact-row duplicates removed: {before - after}  "
                    f"(kept {after})"
                )

            df_trim = _ensure_derived_cols(df_trim)

            k_to_T_count = int((df_trim["target_concept_code_int"] == outcome_T_int).sum())
            T_to_j_count = int((df_trim["concept_code_int"] == outcome_T_int).sum())
            print(
                f"[TRIM] rows={len(df_trim):,}  events={len(events_set_int)}  "
                f"k→T rows={k_to_T_count}  outcome_T→j rows={T_to_j_count}"
            )

            sub_csv = os.path.join(
                OUT_DIR,
                f"magi_subgraph_{T}_parent_{outcome_T_str}_int.csv"
            )
            df_trim.to_csv(sub_csv, index=False)
            print(f"[SAVED] Subgraph → {sub_csv}")

            # 5) run MAGI, forcing the INT outcome
            try:
                res = analyze_causal_sequence_py(
                    df_trim,
                    name_map=None,             # ignored in _int version
                    events=None,               # let it auto-detect from *_int
                    force_outcome=outcome_T_int,
                    lambda_min_count=15,
                )
            except Exception as e:
                # best-effort batch run: report and move on to the next target
                print("[ERROR] MAGI failed:", e)
                continue

            # 6) save NON-ZERO coefficients with predictor mapped back to concept_code
            coef_df = res["coef_df"].copy()

            coef_col = "coef" if "coef" in coef_df.columns else (
                "beta" if "beta" in coef_df.columns else None
            )
            if coef_col is None:
                raise KeyError(
                    "Coefficient column not found in coef_df "
                    "(expected 'coef' or 'beta')."
                )

            eps = 1e-12
            coef_df[coef_col] = coef_df[coef_col].astype(float)
            mask_nz = coef_df[coef_col].abs() > eps
            coef_nz = coef_df.loc[mask_nz].copy()

            pred_col = "predictor"
            if pred_col not in coef_nz.columns:
                raise KeyError("Expected 'predictor' column in coef_df.")

            # split intercept vs real predictors
            is_intercept = coef_nz[pred_col].astype(str) == "(intercept)"
            coef_intercept = coef_nz[is_intercept].copy()
            coef_preds = coef_nz[~is_intercept].copy()

            if not coef_preds.empty:
                coef_preds["predictor_int"] = coef_preds[pred_col].astype(int)
                coef_preds["predictor"] = coef_preds["predictor_int"].map(int2code)
                # fallback if some ints not in map
                missing = coef_preds["predictor"].isna()
                if missing.any():
                    coef_preds.loc[missing, "predictor"] = (
                        coef_preds.loc[missing, "predictor_int"].astype(str)
                    )
            else:
                coef_preds["predictor_int"] = []
                coef_preds["predictor"] = []

            if not coef_intercept.empty:
                coef_intercept["predictor_int"] = np.nan
                coef_intercept["predictor"] = "(intercept)"

            coef_out = pd.concat([coef_intercept, coef_preds], ignore_index=True)

            coef_dir = os.path.join(OUT_DIR, "Coef_Yk")
            os.makedirs(coef_dir, exist_ok=True)

            coef_csv = os.path.join(
                coef_dir,
                f"magi_coef_{T}_parent_{outcome_T_str}_int_nonzero.csv"
            )
            coef_out.to_csv(coef_csv, index=False)

            print(
                f"[SAVED] Coefficients (non-zero only, descendant predictors removed={removed_sel}) "
                f"→ {coef_csv}  | kept={len(coef_out):,} of {len(coef_df):,}  "
                f"| nodes={len(res.get('order_used', [])) - 1}"
            )
