In [None]:
######### All of Us Case-Control Organization ################ 
# ---------- UTILS ----------
def banner(txt):
    """Print *txt* framed above and below by a rule of '=' characters.

    The rule is four characters wider than the text but never shorter than 12,
    and the whole banner is preceded by a blank line.
    """
    rule = "=" * max(12, len(txt) + 4)
    print("", rule, txt, rule, sep="\n")

def subhead(txt):
    """Print *txt* as a lightweight '--- txt ---' heading after a blank line."""
    print("\n--- {} ---".format(txt))

def safe_name(s: str) -> str:
    """Return *s* made filename-safe.

    Alphanumeric characters (per ``str.isalnum``, so non-ASCII letters are kept),
    '-' and '_' pass through unchanged; every other character becomes '_'.
    """
    keep_as_is = {"-", "_"}
    pieces = []
    for ch in s:
        if ch.isalnum() or ch in keep_as_is:
            pieces.append(ch)
        else:
            pieces.append("_")
    return "".join(pieces)

def qtiles(x):
    """Return the min, 1st, 25th, 50th, 75th, 99th percentile and max of *x*.

    Uses ``np.quantile`` with its default (linear) interpolation; the result is a
    length-7 array in that order.
    """
    levels = (0.0, 0.01, 0.25, 0.5, 0.75, 0.99, 1.0)
    values = np.asarray(x)
    return np.quantile(values, levels)

def preview_active_codes(X_csr, feature_codes, row_indices, k=8):
    """Print up to k active codes for a few rows.

    For each row index, reads that row's stored column indices straight from the
    CSR ``indptr``/``indices`` arrays, then prints how many there are and the codes
    of the first *k* (in stored order), looked up in *feature_codes*.
    """
    for r in row_indices:
        lo = X_csr.indptr[r]
        hi = X_csr.indptr[r + 1]
        active = X_csr.indices[lo:hi]
        sample = [feature_codes[c] for c in active[:k]]
        print(f"   row {r}: n_active={len(active)}  sample_active={sample}")

def load_magi_betas(coef_csv):
    """Load a MAGI coefficient CSV and split it into an intercept and per-code betas.

    The code column is the first present of ``concept_code``,
    ``standard_concept_code``, ``predictor``, ``feature``, ``term``, ``name``;
    the coefficient column is the first present of ``coef``, ``coefficient``,
    ``beta``, ``estimate``, ``b``, ``value``. Codes are stripped strings. A row
    whose code, lower-cased, is one of ``(intercept)``, ``intercept``, ``const``,
    ``(const)``, ``bias`` is the intercept (the first such row wins).

    Rows with an empty (NaN) coefficient are dropped with a warning, and an empty
    intercept cell is treated as 0.0 with a warning: a NaN would otherwise flow into
    every linear predictor that touches it and break the downstream metrics.

    Returns:
        (intercept, coef_map): ``intercept`` as float (0.0 when there is no usable
        intercept row) and ``coef_map`` mapping code string -> float beta.

    Raises:
        KeyError: no recognised code column or coefficient column.
        ValueError: the coefficient column contains non-numeric text.
    """
    df = pd.read_csv(coef_csv)

    def _first_present(candidates):
        # First candidate column name that exists in the table.
        for name in candidates:
            if name in df.columns:
                return name
        raise KeyError(f"Missing any of {candidates} in {coef_csv}. Found: {list(df.columns)}")

    code_col = _first_present(["concept_code", "standard_concept_code", "predictor", "feature", "term", "name"])
    beta_col = _first_present(["coef", "coefficient", "beta", "estimate", "b", "value"])

    code_s = df[code_col].astype(str).str.strip()
    beta_s = df[beta_col].astype(float)  # non-numeric text raises ValueError, as before
    is_int = code_s.str.lower().isin(["(intercept)", "intercept", "const", "(const)", "bias"])

    intercept = 0.0
    if is_int.any():
        first_int = beta_s[is_int].iloc[0]
        if pd.isna(first_int):
            print(f"[WARN] Intercept row in {coef_csv} has no value; using 0.0.")
        else:
            intercept = float(first_int)

    missing = ~is_int & beta_s.isna()
    if missing.any():
        print(f"[WARN] Dropping {int(missing.sum())} coefficient row(s) with no value in {coef_csv}.")
    keep = ~is_int & beta_s.notna()
    coef_map = dict(zip(code_s[keep], beta_s[keep]))
    return intercept, coef_map

def sample_all_pos_kx_neg(y, k=4, seed=42):
    """Return shuffled row indices: every positive (y == 1) plus up to k negatives per positive.

    Negatives (y == 0) are drawn without replacement; when fewer than ``k * n_pos``
    exist, all of them are taken. A fresh generator seeded with *seed* drives both
    the draw and the final shuffle, so the result is reproducible for (y, k, seed).

    Raises:
        ValueError: if *y* contains no positives.
    """
    gen = np.random.default_rng(seed)
    pos_rows = np.nonzero(y == 1)[0]
    neg_rows = np.nonzero(y == 0)[0]
    if not pos_rows.size:
        raise ValueError("No positives for this target.")
    n_draw = min(neg_rows.size, k * pos_rows.size)
    drawn = gen.choice(neg_rows, size=n_draw, replace=False)
    picked = np.concatenate((pos_rows, drawn))
    gen.shuffle(picked)
    return picked

def score_from_betas(X_sub, feature_codes, betas_map, intercept):
    """Score rows of *X_sub* with a fixed linear model given as code -> beta.

    Only columns whose code appears in *betas_map* contribute; codes in the map
    with no matching column are ignored.

    Returns:
        (lp, p, idx, betas): 1-D linear predictor ``intercept + X[:, idx] @ betas``,
        its logistic transform, the matched column indices, and the aligned betas.

    Raises:
        ValueError: if no feature code appears in *betas_map*.
    """
    code_arr = np.array(feature_codes, dtype=str)
    cols = np.where(np.isin(code_arr, list(betas_map.keys())))[0]
    if cols.size == 0:
        raise ValueError("No overlap between features and MAGI coefficients.")
    beta_vec = np.array([betas_map[c] for c in code_arr[cols]], dtype=float)
    linear = np.asarray(intercept + X_sub[:, cols].dot(beta_vec)).ravel()
    return linear, expit(linear), cols, beta_vec

def plot_roc(y_true, p_hat, title, out_png, out_svg):
    """Draw a ROC curve with its chance diagonal, save it as PNG (300 dpi) and SVG,
    and return the ROC AUC.

    The figure is created explicitly and closed in ``finally`` so that a failing
    save (e.g. a missing output directory) does not leave an open figure behind;
    pyplot keeps every unclosed figure alive.
    """
    fpr, tpr, _ = roc_curve(y_true, p_hat)
    auc = roc_auc_score(y_true, p_hat)
    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, label=f"AUC={auc:.3f}")
        ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1)  # chance diagonal
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(title)
        ax.legend(loc="lower right")
        fig.tight_layout()
        fig.savefig(out_png, dpi=300, bbox_inches="tight")
        fig.savefig(out_svg, bbox_inches="tight")
    finally:
        plt.close(fig)
    return auc

# ---------- LOAD DESIGN ONCE ----------
banner("LOAD DESIGN")
# Sparse persons x codes design matrix, converted to CSR float32.
X_full = load_npz(f"{BASE}/Lasso_X.npz").tocsr().astype(np.float32)
# Row labels (person ids) and column labels (concept codes), both as strings.
persons = (
    pd.read_csv(f"{BASE}/person_index.csv")["person_id"]
    .astype(str)
    .to_numpy()
)
codes = (
    pd.read_csv(f"{BASE}/code_index.csv")["concept_code"]
    .astype(str)
    .to_numpy()
)
_n_rows_design, _n_cols_design = X_full.shape
print(f"[INFO] Matrix: persons={_n_rows_design:,}  codes={_n_cols_design:,}")
# Labels must line up 1:1 with matrix rows/columns, or every later lookup is wrong.
if (len(persons), len(codes)) != (_n_rows_design, _n_cols_design):
    raise ValueError("[ERROR] person/code indices do not match matrix shape.")

# ---------- RUN PER TARGET ----------
# For each target code: take labels from its own column, drop that column from the
# predictors, sample all positives plus NEG_MULT x negatives, score with the MAGI
# coefficients, then save predictions, ROC plots and a summary row. Any target that
# cannot be processed is reported with [SKIP] and the loop moves on.
summary = []
for tcode in TARGETS:
    pretty = TARGET_NAME.get(tcode, tcode)
    banner(f"TARGET {tcode} — {pretty}")

    # SECTION A: labels
    subhead("A) Label vector from full design")
    idx_y = np.where(codes == tcode)[0]
    if idx_y.size == 0:
        print(f"[SKIP] Target not found in code_index.csv → {tcode}")
        continue
    # Binarize rather than cast: a stored value other than 0/1 (e.g. a count) would
    # otherwise be neither a positive (y == 1) nor a negative (y == 0) in the
    # sampler, and values above 127 would wrap in int8. Assumes any nonzero entry
    # means "code present"; identical to a plain cast for a 0/1 design.
    y_full = (X_full[:, idx_y[0]].toarray().ravel() != 0).astype(np.int8)
    print(f"[INFO] y_full: n={y_full.size:,}  pos={int(y_full.sum()):,}  "
          f"prev={y_full.mean():.4f}")

    # SECTION B: predictors (keep all except DV)
    subhead("B) Predictor matrix (keep all columns except DV)")
    mask_pred = (codes != tcode)
    X = X_full[:, mask_pred]
    feature_codes = codes[mask_pred]
    print(f"[INFO] Predictors: persons={X.shape[0]:,}  features={X.shape[1]:,}")
    print(f"[CHECK] DV in features? {tcode in feature_codes} (should be False)")

    # SECTION C: sampling (all pos + NEG_MULT x neg)
    subhead(f"C) Sampling (keep ALL positives + {NEG_MULT}× negatives)")
    try:
        sel = sample_all_pos_kx_neg(y_full, k=NEG_MULT, seed=RNG_SEED)
    except ValueError as e:
        # e.g. no positives for this target: skip it like the other [SKIP] paths
        # instead of aborting the whole run.
        print(f"[SKIP] Sampling failed: {e}")
        continue
    X_sub       = X[sel, :]
    y_sub       = y_full[sel].astype(np.int8)
    persons_sub = persons[sel]
    n_rows      = X_sub.shape[0]
    n_pos_sub   = int(y_sub.sum())
    n_neg_sub   = n_rows - n_pos_sub
    print(f"[INFO] subset: n={n_rows:,}  pos={n_pos_sub:,}  neg={n_neg_sub:,}  "
          f"ratio≈{(n_neg_sub/max(n_pos_sub,1)):.2f}:1  PR-baseline={y_sub.mean():.4f}")
    if n_neg_sub == 0:
        # ROC/PR AUC are undefined with a single class (roc_auc_score raises).
        print("[SKIP] No negatives in subset; AUC is undefined.")
        continue
    # a tiny peek at first 3 rows' active codes
    try:
        preview_active_codes(X_sub, feature_codes, row_indices=range(min(3, n_rows)), k=8)
    except Exception as e:
        print(f"[WARN] preview_active_codes failed: {e}")

    # SECTION D: coefficients
    subhead("D) Load MAGI coefficients")
    coef_csv = COEF_PATTERN.format(target=tcode)
    if not os.path.exists(coef_csv):
        print(f"[SKIP] Coef file missing: {coef_csv}")
        continue
    try:
        intercept, coef_map = load_magi_betas(coef_csv)
    except (KeyError, ValueError) as e:
        # Unrecognised columns or non-numeric betas: skip this target only.
        print(f"[SKIP] Could not read coefficients from {coef_csv}: {e}")
        continue
    print(f"[INFO] Coefs: intercept={intercept:.6f}  n_features={len(coef_map):,}")
    # show a few coef samples
    for cc, bb in list(coef_map.items())[:5]:
        print(f"   beta[{cc}] = {bb:.6f}")
    if intercept == 0.0:
        # load_magi_betas matches the intercept row case-insensitively under several
        # aliases and yields 0.0 when none is found, so report from its result
        # instead of a case-sensitive text search of the raw file.
        print("[NOTE] No intercept row found in CSV (or its value is 0); using 0.0.")

    # SECTION E: alignment & scoring
    subhead("E) Align & score")
    try:
        lp, p_hat, idx_cols, betas_vec = score_from_betas(X_sub, feature_codes, coef_map, intercept)
    except Exception as e:
        # Best-effort: any scoring failure (typically zero overlap) skips the target.
        print(f"[SKIP] Scoring failed (no overlap or other issue): {e}")
        continue

    n_overlap = idx_cols.size
    print(f"[INFO] overlap with predictors = {n_overlap:,} columns")
    print(f"[INFO] first 5 aligned columns: {[feature_codes[i] for i in idx_cols[:5]]}")
    print(f"[INFO] first 5 aligned betas:   {[float(b) for b in betas_vec[:5]]}")

    # SECTION F: metrics & distributions
    subhead("F) Metrics & probability distribution")
    auc    = roc_auc_score(y_sub, p_hat)
    pr_auc = average_precision_score(y_sub, p_hat)
    q = qtiles(p_hat)
    print(f"[RESULT] AUC={auc:.4f}  |  PR-AUC={pr_auc:.4f}  (baseline={y_sub.mean():.4f})")
    print(f"[DIST] prob quantiles: min={q[0]:.4g}, p1={q[1]:.4g}, p25={q[2]:.4g}, "
          f"median={q[3]:.4g}, p75={q[4]:.4g}, p99={q[5]:.4g}, max={q[6]:.4g}")
    print(f"[COUNT] prob>=0.999: {(p_hat>=0.999).sum()}  |  prob<=0.001: {(p_hat<=0.001).sum()}")

    # SECTION G: save predictions
    subhead("G) Save per-person predictions")
    safe = safe_name(pretty)
    pred_csv = os.path.join(CSV_DIR, f"pred_{safe}.csv")
    pd.DataFrame({
        "person_id": persons_sub,
        "y_true": y_sub.astype(int),
        "score_logit": lp,
        "prob": p_hat
    }).to_csv(pred_csv, index=False)
    print(f"[SAVE] predictions → {pred_csv}")
    # Echo the first rows as written to disk without re-reading the whole file.
    print(pd.read_csv(pred_csv, nrows=10))

    # SECTION H: plots
    subhead("H) ROC plots (PNG/SVG)")
    png_path = os.path.join(PNG_DIR, f"ROC_{safe}.png")
    svg_path = os.path.join(PNG_DIR, f"ROC_{safe}.svg")
    plot_roc(y_sub, p_hat, pretty, png_path, svg_path)
    print(f"[SAVE] ROC → {png_path}")
    print(f"[SAVE] ROC → {svg_path}")

    # accumulate summary
    summary.append({
        "target_code": tcode,
        "target_name": pretty,
        "n_cases": n_rows,
        "n_pos": n_pos_sub,
        "n_neg": n_neg_sub,
        "feature_overlap": n_overlap,
        "AUC": auc,
        "PR_AUC": pr_auc,
        "PR_baseline": y_sub.mean(),
        "coef_csv": coef_csv,
        "pred_csv": pred_csv,
        "roc_png": png_path
    })

