In [None]:
# %% Imports (unchanged)
import os, re, pickle, glob, warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore", category=FutureWarning)

# %% CONFIG (unchanged)
# Every path below defaults to the original author's local layout; set the matching
# environment variable to run the notebook elsewhere without editing this cell.
MODEL_DIR = os.environ.get(
    "ML4H_MODEL_DIR", "/Users/authorname/Desktop/Projects/ml4h_project/LGBM(Genes+Demo)"
)
SEEDS = [1, 2, 3, 4, 5]  # seeds whose seed{N}_<class>_automl.pkl files are read in STEP C
GENO_PARQUET = os.environ.get(
    "ML4H_GENO_PARQUET", "/Users/authorname/Downloads/all_chr_merged.parquet"
)
GWAS_CATALOG_TSV = os.environ.get(
    "ML4H_GWAS_CATALOG_TSV",
    "/Users/authorname/Downloads/MONDO_0004975_associations_export.tsv",
)
GTEX_BRAIN_DIR = os.environ.get(
    "ML4H_GTEX_BRAIN_DIR", "/Users/authorname/Downloads/GTEx_v10_SuSiE_eQTL"
)
GTEX_GLOB = os.path.join(GTEX_BRAIN_DIR, "Brain_*eQTLs.SuSiE_summary.parquet")
OUT_DIR = os.path.join(MODEL_DIR, "eqtl_overlap_results")
os.makedirs(OUT_DIR, exist_ok=True)

# Model features that are excluded from the gene-level analysis in STEP C.
NON_GENE_FEATURES_EXACT = {"age_death", "educ", "msex"}
NON_GENE_PREFIXES = ("apoe_genotype_",)

# Optional minimum GTEx 'pip' value for a variant row to be kept; None keeps all rows.
PIP_MIN = None

def safe_cls(c: str) -> str:
    """Encode a class label so it can be embedded in a model filename.

    Substitutions are applied in order: '+' -> 'plus', ' ' -> '_', '/' -> '-'.
    """
    for raw, encoded in (("+", "plus"), (" ", "_"), ("/", "-")):
        c = c.replace(raw, encoded)
    return c
def unsafify_cls(s: str) -> str:
    """Decode a filename-safe class label back to a display label.

    Reverse substitutions are applied in order: 'plus' -> '+', '_' -> ' ',
    '-' -> '/'. The mapping is lossy: a label that originally contained '_',
    '-' or the substring 'plus' does not decode back to itself.
    """
    decoded = s
    for encoded, raw in (("plus", "+"), ("_", " "), ("-", "/")):
        decoded = decoded.replace(encoded, raw)
    return decoded
def get_feature_names(automl):
    """Return the list of feature names the fitted model was trained on.

    Lookup order:
      1. ``feature_name_`` on the wrapped estimator (``automl.model.estimator``,
         or ``automl.model`` itself when it has no ``estimator`` attribute);
      2. ``automl.feature_names_in_``;
      3. synthetic names ``feat_0 .. feat_{n-1}`` sized by the estimator's
         ``n_features_`` (or ``n_features_in_`` when that is missing/falsy).

    Raises:
        RuntimeError: if none of the above is available.
    """
    wrapped = automl.model
    estimator = getattr(wrapped, "estimator", wrapped)

    if hasattr(estimator, "feature_name_"):
        names = estimator.feature_name_
        if names is not None:
            return list(names)

    if hasattr(automl, "feature_names_in_"):
        names = automl.feature_names_in_
        if names is not None:
            return list(names)

    count = getattr(estimator, "n_features_", None) or getattr(estimator, "n_features_in_", None)
    if count is None:
        raise RuntimeError("Could not determine feature names.")
    return [f"feat_{idx}" for idx in range(count)]
def get_importances(automl):
    """Return per-feature importances of the fitted model as a float array.

    Prefers ``booster_.feature_importance(importance_type="gain")`` on the
    wrapped estimator; if that is unavailable, raises, or sums to zero, falls
    back to the estimator's ``feature_importances_``.

    Raises:
        RuntimeError: if neither source is available.
    """
    wrapped = automl.model
    estimator = getattr(wrapped, "estimator", wrapped)

    try:
        gain = estimator.booster_.feature_importance(importance_type="gain")
        if np.sum(gain) > 0:
            return gain.astype(float)
    except Exception:
        # Best-effort: any failure here (e.g. no booster_) drops to the fallback below.
        pass

    if not hasattr(estimator, "feature_importances_"):
        raise RuntimeError("No importances on fitted estimator.")
    return estimator.feature_importances_.astype(float)

# %% Discover classes (unchanged)
patt = re.compile(r"seed(\d+)_(.+)_automl\.pkl$")
# Decode the class label of every seed{N}_<class>_automl.pkl file in MODEL_DIR.
_discovered = set()
for fn in os.listdir(MODEL_DIR):
    m = patt.match(fn)
    if m:
        _discovered.add(unsafify_cls(m.group(2)))
classes = sorted(_discovered)
print(f"Discovered classes: {classes}")

# %% STEP A — Rebuild SNP→gene map + add COORDS  --------------------------------
print("\n[STEP A] Loading genotype + GWAS catalog, rebuilding SNP→gene map with QC ...")
df = pd.read_parquet(GENO_PARQUET)  # genotype frame; SNP columns are selected by name below
gwas_df = pd.read_csv(GWAS_CATALOG_TSV, sep="\t")

# Clean GWAS fields (unchanged): first rsID in riskAllele, first gene in mappedGenes.
gwas_df['riskAllele_cleaned'] = gwas_df['riskAllele'].str.extract(r'(rs\d+)', expand=False)
gwas_df['first_gene'] = gwas_df['mappedGenes'].str.split(',').str[0].str.strip()

# Pick the first available coordinate column (candidate names tried in this order).
GWAS_LOCATION_COLUMNS = ['locations', 'location', 'chromosomeLocation', 'CHR_POS']
loc_col = next((c for c in GWAS_LOCATION_COLUMNS if c in gwas_df.columns), None)
if loc_col is None:
    # Without a coordinate column, rsID-only SNP columns cannot be given a chr_pos,
    # so they can never overlap GTEx; surface that instead of failing silently.
    warnings.warn(
        f"No GWAS location column found (tried {GWAS_LOCATION_COLUMNS}); "
        "rsID-only SNP columns will have no coordinates."
    )
def norm_loc_to_chrpos(val):
    """Normalize a GWAS-catalog location value to the 'chr{chrom}_{pos}' key.

    Only the first ';'-separated entry is used and an optional leading 'chr'
    is accepted, e.g. '1:1049997', 'chr1:1049997' and '1:1049997-1050000'
    all become 'chr1_1049997'.

    Returns None for missing values or text that does not start with
    '<chrom>:<digits>'.
    """
    if pd.isna(val):
        return None
    first_entry = str(val).split(';')[0].strip()
    match = re.match(r'^(?:chr)?(\w+):(\d+)', first_entry)
    return f"chr{match.group(1)}_{match.group(2)}" if match else None
if loc_col:
    gwas_df['chr_pos'] = gwas_df[loc_col].apply(norm_loc_to_chrpos)
else:
    gwas_df['chr_pos'] = None

# rsID -> first mapped gene (for a repeated rsID the last catalog row wins)
_gene_rows = gwas_df.dropna(subset=['riskAllele_cleaned', 'first_gene'])
rs_to_gene = dict(zip(_gene_rows['riskAllele_cleaned'], _gene_rows['first_gene']))

# rsID -> 'chr{chrom}_{pos}', used to recover coordinates for rsID-only columns
_coord_rows = (gwas_df.dropna(subset=['riskAllele_cleaned', 'chr_pos'])
                      .drop_duplicates(subset=['riskAllele_cleaned', 'chr_pos']))
rs_to_chrpos = dict(zip(_coord_rows['riskAllele_cleaned'], _coord_rows['chr_pos']))

# Candidate SNP columns are those whose (whitespace-stripped) name starts with 'rs' or 'chr'
df.columns = [name.strip() for name in df.columns]
snp_cols_raw = [name for name in df.columns if name.startswith(('rs', 'chr'))]

def base_rsid(col):
    """Return the text before the first '_' of a SNP column name ('rs123_A' -> 'rs123')."""
    return col.partition('_')[0]

# rsID-based map: SNP column -> gene, keeping only columns whose rsID has a (truthy) gene
all_mapped_snps = {}
for snp_col in snp_cols_raw:
    mapped_gene = rs_to_gene.get(base_rsid(snp_col))
    if mapped_gene:
        all_mapped_snps[snp_col] = mapped_gene
print(f"Initially mapped {len(all_mapped_snps)} SNP columns to genes via GWAS catalog.")

# QC (unchanged)
def calc_maf(col):
    """Allele frequency of SNP column `col` in the global genotype frame `df`.

    The numerator counts two alleles for every value equal to 2 and one for
    every value equal to 1; the denominator is 2 x the number of non-missing
    values (assumes 0/1/2 dosage coding — TODO confirm).
    NOTE(review): this is the frequency of the counted allele, not
    min(p, 1 - p); it equals the minor-allele frequency only when the counted
    allele is the minor one. Kept as-is to match the training-time QC.
    """
    dosages = df[col]
    counted_alleles = 2 * dosages.eq(2).sum() + dosages.eq(1).sum()
    called_alleles = 2 * dosages.notna().sum()
    return counted_alleles / called_alleles
snps_to_map = list(all_mapped_snps.keys())
maf = pd.Series({snp_col: calc_maf(snp_col) for snp_col in snps_to_map})
miss = df[snps_to_map].isna().mean()  # per-column fraction of missing calls
# Drop SNPs with allele frequency < 0.01 or missingness > 0.05
low_freq = maf[maf < 0.01].index
high_missing = miss[miss > 0.05].index
snps_remove = set(low_freq) | set(high_missing)
all_mapped_snps_qc = {
    snp: gene for snp, gene in all_mapped_snps.items() if snp not in snps_remove
}
print(f"Retained {len(all_mapped_snps_qc)} SNPs after QC.")

# Invert the post-QC map: gene -> list of its SNP column names (QC-dict order)
from collections import defaultdict

gene_to_snpcols = defaultdict(list)
for qc_col in all_mapped_snps_qc:
    gene_to_snpcols[all_mapped_snps_qc[qc_col]].append(qc_col)

# Coordinate embedded in a column name: 'chr<chrom>' then ':' or '_' then the position,
# e.g. 'chr1:285155' or 'chr1_285155_A'. The chromosome token is limited to letters and
# digits: a greedy \w+ also matches '_' and, when a later '_<digits>' suffix exists,
# swallows the position ('chr1_285155_2' -> chrom '1_285155', pos '2').
chrpos_regex = re.compile(r'chr([A-Za-z0-9]+)[_:](\d+)')


def col_to_chrpos(col):
    """Return the 'chr{chrom}_{pos}' key for a genotype column name, or None.

    1. If the name itself embeds a coordinate, parse it from the name.
    2. Otherwise, for 'rs…' columns, look up the rsID prefix (text before the
       first '_') in the GWAS-derived ``rs_to_chrpos`` map.
    """
    match = chrpos_regex.search(col)
    if match:
        return f"chr{match.group(1)}_{match.group(2)}"
    if col.startswith('rs'):
        return rs_to_chrpos.get(base_rsid(col))
    return None

# gene -> set of 'chr{chrom}_{pos}' keys resolvable from that gene's SNP columns
gene_to_chrpos_used = {}
for g, cols in gene_to_snpcols.items():
    resolved = (col_to_chrpos(c) for c in cols)
    gene_to_chrpos_used[g] = {cp for cp in resolved if cp is not None}

# %% STEP B — Load GTEx and add chr_pos from variant_id  ------------------------
print("\n[STEP B] Loading GTEx v10 SuSiE brain cis-eQTL parquet files ...")
gtex_files = sorted(glob.glob(GTEX_GLOB))
if not gtex_files:
    raise RuntimeError("No GTEx brain parquet files found. Check GTEX_GLOB path/pattern.")

GTEX_SUFFIX = ".v10.eQTLs.SuSiE_summary.parquet"
eqtl_rows = []
for path in gtex_files:
    # Tissue label = file name with the GTEx suffix removed.
    tissue = os.path.basename(path).split(GTEX_SUFFIX)[0]
    tdf = pd.read_parquet(path)
    if "gene_name" not in tdf.columns or "variant_id" not in tdf.columns:
        # Report the skip so a schema change cannot silently drop a whole tissue.
        print(f"[skip] {os.path.basename(path)}: missing 'gene_name' or 'variant_id' column")
        continue
    # variant_id looks like 'chr1_285155_A_C_b38'; keep the first two '_' tokens
    # ('chr1_285155'). str.extract returns a Series even for an empty file, whereas
    # split(expand=True) yields a frame with no columns and parts[0] raises KeyError.
    tdf["chr_pos"] = tdf["variant_id"].str.extract(r"^([^_]*_[^_]*)", expand=False)
    # optional PIP filter
    if PIP_MIN is not None and "pip" in tdf.columns:
        tdf = tdf.loc[tdf["pip"] >= PIP_MIN]
    keep_cols = ["gene_name", "chr_pos"] + (["pip"] if "pip" in tdf.columns else [])
    tdf = tdf[keep_cols].copy()
    tdf["tissue"] = tissue
    eqtl_rows.append(tdf)

if not eqtl_rows:
    raise RuntimeError("No usable GTEx rows parsed.")
gtex_eqtl = pd.concat(eqtl_rows, ignore_index=True)

# Lookups: gene_name -> set of chr_pos keys seen in any brain tissue
_by_gene_positions = gtex_eqtl.groupby("gene_name")["chr_pos"]
gene_to_gtex_chrpos = _by_gene_positions.apply(set).to_dict()

# Per (gene_name, chr_pos): sorted unique tissues, plus max PIP when a 'pip' column exists
_detail_aggs = {"tissues": ("tissue", lambda x: sorted(set(x)))}
if "pip" in gtex_eqtl.columns:
    _detail_aggs["max_pip"] = ("pip", "max")
eqtl_detail = (gtex_eqtl.groupby(["gene_name", "chr_pos"])
                        .agg(**_detail_aggs)
                        .reset_index())
if "max_pip" not in eqtl_detail.columns:
    eqtl_detail["max_pip"] = np.nan  # uniform schema for the lookups in STEP D

print(f"GTEx eQTL rows loaded: {len(gtex_eqtl):,}")

# %% STEP C — Predictive genes per class (unchanged)
print("\n[STEP C] Identifying predictive genes per class (non-zero importance in ≥2 seeds) ...")

def is_non_gene(feature_name: str) -> bool:
    """True for features configured as non-gene.

    A feature is non-gene if it is listed in NON_GENE_FEATURES_EXACT or starts
    with any prefix in NON_GENE_PREFIXES.
    """
    return (feature_name in NON_GENE_FEATURES_EXACT
            or feature_name.startswith(tuple(NON_GENE_PREFIXES)))

all_gene_names = set(gene_to_snpcols)
predictive_genes_per_class, counts_per_class = {}, {}

MIN_SEEDS_FOR_PREDICTIVE = 2  # gene must have non-zero importance in at least this many seeds

for cls in classes:
    counts = {}  # gene -> number of seeds in which its importance was > 0
    for seed in SEEDS:
        pkl_path = os.path.join(MODEL_DIR, f"seed{seed}_{safe_cls(cls)}_automl.pkl")
        if not os.path.exists(pkl_path):
            print(f"[skip] Missing {pkl_path}")
            continue
        # NOTE: pickle.load can execute arbitrary code; only load model files you produced.
        with open(pkl_path, "rb") as f:
            automl = pickle.load(f)
        feats = get_feature_names(automl)
        imps = get_importances(automl)
        if len(feats) != len(imps):
            raise ValueError(f"Mismatch in feature name/importance lengths for {cls}, seed {seed}")
        for feat, imp in zip(feats, imps):
            if imp > 0 and feat in all_gene_names and not is_non_gene(feat):
                counts[feat] = counts.get(feat, 0) + 1
    counts_per_class[cls] = counts
    predictive_genes_per_class[cls] = sorted(
        g for g, c in counts.items() if c >= MIN_SEEDS_FOR_PREDICTIVE
    )
    print(f"[{cls}] predictive genes (≥2 seeds): {len(predictive_genes_per_class[cls])}")

# Rows = genes, columns = classes, values = number of seeds with non-zero importance
pd.DataFrame.from_dict(counts_per_class, orient="index").T.to_csv(
    os.path.join(OUT_DIR, "predictive_gene_counts_per_class.csv")
)

# %% STEP D — COORDINATE-LEVEL overlap with GTEx (per class & gene) ------------
print("\n[STEP D] Computing coordinate-level overlaps with GTEx brain cis-eQTLs ...")

import json


def _overlap_detail_to_json(details):
    """Serialize a list of per-coordinate overlap dicts as a JSON string ('' if empty).

    NaN max_pip (GTEx files without a 'pip' column) becomes null so the output is
    strict JSON rather than a Python repr with single quotes and bare `nan`.
    """
    if not details:
        return ""
    clean = [
        {**d, "max_pip": None if pd.isna(d["max_pip"]) else float(d["max_pip"])}
        for d in details
    ]
    return json.dumps(clean)


OVERLAP_COLUMNS = ["class", "gene", "n_chrpos_used", "n_overlap_chrpos",
                   "overlap_chrpos", "overlap_detail", "any_eqtl_anyBrain"]

rows = []
for cls in classes:
    for g in predictive_genes_per_class.get(cls, []):
        used_chrpos = gene_to_chrpos_used.get(g, set())
        gtex_chrpos = gene_to_gtex_chrpos.get(g, set())
        overlap = sorted(used_chrpos & gtex_chrpos)
        # Per-coordinate details (tissues, max PIP) for the overlapping positions only.
        detail_list = []
        if overlap:
            det = eqtl_detail[(eqtl_detail["gene_name"] == g) & (eqtl_detail["chr_pos"].isin(overlap))]
            for _, r in det.iterrows():
                detail_list.append({
                    "chr_pos": r["chr_pos"],
                    "tissues": ";".join(r["tissues"]),
                    "max_pip": r.get("max_pip", np.nan),
                })
        rows.append({
            "class": cls,
            "gene": g,
            "n_chrpos_used": len(used_chrpos),
            "n_overlap_chrpos": len(overlap),
            "overlap_chrpos": ";".join(overlap),
            "overlap_detail": detail_list,
            "any_eqtl_anyBrain": (g in gene_to_gtex_chrpos),
        })

# Explicit columns keep the schema (and the CSV header) when no class has predictive
# genes; pd.DataFrame([]) would have no 'class' column and break later groupbys/reads.
overlap_df = pd.DataFrame(rows, columns=OVERLAP_COLUMNS)
overlap_path = os.path.join(OUT_DIR, "eqtl_coord_overlap_per_gene_per_class.csv")
overlap_df = (overlap_df
              .assign(overlap_detail_json=overlap_df["overlap_detail"].apply(_overlap_detail_to_json))
              .drop(columns=["overlap_detail"]))
overlap_df.to_csv(overlap_path, index=False)
print(f"[saved] {overlap_path}")

# Exploded details table: one row per (class, gene, overlapping chr_pos).
DETAIL_COLUMNS = ["class", "gene", "chr_pos", "tissues", "max_pip"]
detail_rows = []
for _, r in overlap_df.iterrows():
    chrpos_field = r["overlap_chrpos"]
    if not isinstance(chrpos_field, str) or not chrpos_field:
        continue
    cls, gene = r["class"], r["gene"]
    for cp in chrpos_field.split(";"):
        det = eqtl_detail[(eqtl_detail["gene_name"] == gene) & (eqtl_detail["chr_pos"] == cp)]
        for _, d in det.iterrows():
            detail_rows.append({
                "class": cls,
                "gene": gene,
                "chr_pos": cp,
                "tissues": ";".join(d["tissues"]),
                "max_pip": d.get("max_pip", np.nan),
            })
detail_df = pd.DataFrame(detail_rows, columns=DETAIL_COLUMNS)
detail_path = os.path.join(OUT_DIR, "eqtl_coord_overlap_details.csv")
# Always write (header-only when there are no overlaps) so a details file left over
# from an earlier run cannot be mistaken for this run's result.
detail_df.to_csv(detail_path, index=False)
print(f"[saved] {detail_path}")

# %% Summary
print("\n=== Summary (coordinate-based) ===")
if overlap_df.empty or "class" not in overlap_df.columns:
    # A frame built from zero rows may have no 'class' column; groupby would raise KeyError.
    print("No predictive genes in any class; nothing to summarize.")
else:
    print(overlap_df.groupby("class")["n_overlap_chrpos"].sum().rename("total_overlapping_chrpos_by_class"))

In [None]:
import os
import pandas as pd

# Reuse OUT_DIR from the analysis cell when it has run, so both cells read/write the same
# folder; fall back to the original location so this cell still works in a fresh kernel.
_overlap_dir = globals().get(
    "OUT_DIR",
    "/Users/authorname/Desktop/Projects/ml4h_project/LGBM(Genes+Demo)/eqtl_overlap_results",
)
overlap_path = os.path.join(_overlap_dir, "eqtl_coord_overlap_per_gene_per_class.csv")
# NOTE(review): this rebinds `df`, which the analysis cell uses for the genotype frame
# (calc_maf reads the global `df`); re-running that cell's QC after this one would fail.
try:
    df = pd.read_csv(overlap_path)
except pd.errors.EmptyDataError:
    # A file with no columns is produced when an analysis run had zero result rows.
    df = pd.DataFrame(columns=["class", "gene", "n_overlap_chrpos"])

# filter to only overlapping genes
hits = df[df["n_overlap_chrpos"] > 0]

print("=== Overlapping genes by class ===")
for cls, sub in hits.groupby("class"):
    # astype(str): read_csv turns some symbols (e.g. 'NA') into NaN, which join() rejects.
    genes = sub["gene"].astype(str).tolist()
    print(f"{cls} ({len(genes)}): {', '.join(genes)}")