# In [1]:
# =========================
# Enrichment pipeline (ORA)
# =========================
from __future__ import annotations
import re
from pathlib import Path
from textwrap import wrap
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple

import numpy as np
import pandas as pd
from scipy.stats import hypergeom
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt


# -------- Configuration --------
# Canonical timepoint code -> display label. Dict insertion order is the canonical
# processing order, so TP_ORDER below is derived from it instead of being repeated.
_TIMEPOINT_LABELS: Dict[str, str] = {
    "B0": "Baseline",
    "D1": "Day 1",
    "D5": "Day 5",
    "F3": "Month 3",
    "F6": "Month 6",
}

CFG = dict(
    # Timepoints: the label mapping is the single source of truth.
    TP_LABEL=_TIMEPOINT_LABELS,
    TP_ORDER=list(_TIMEPOINT_LABELS),
    # Input columns
    feature_col="feature_id",   # <-- change to your metabolite/feature ID column
    id_col="kegg_maps",         # <-- pathway membership column; several IDs per cell
    id_delims=r"[,\s;]+",       # regex: split on runs of commas, whitespace, semicolons
    sig_alpha=0.05,             # default significance threshold
    # Output
    out_dir="./results/enrichment",  # base output directory; one subdirectory per timepoint
    # Plotting
    plot_topn=20,
    plot_alpha_line_is_fdr=True,  # default for plot_top's `use_fdr` argument
)

# Optional {pathway_id: pathway_name} lookup used for the "pathway_name" column;
# leave empty to display the raw pathway IDs.
PATHWAY_NAME_MAP: Dict[str, str] = {}


# -------- File discovery & timepoint mapping --------
# A timepoint code must stand alone in the file name: not preceded or followed by a
# letter/digit. `\b` is deliberately not used because `_` is a word character, so
# `\bB0\b` would miss typical names such as "ora_B0.csv". `[^\W_]` means
# "word character other than underscore", so underscores act as separators.
TP_REGEX = re.compile(r"(?<![^\W_])(B0|D1|D5|F3|F6)(?![^\W_])", flags=re.IGNORECASE)

def parse_tp_from_name(path: Path) -> Optional[str]:
    """Extract the timepoint code from a file's base name.

    Args:
        path: File path; only ``path.name`` (not parent directories) is searched.

    Returns:
        The first matching code, upper-cased (e.g. ``"D5"`` for
        ``"results_d5_final.csv"``), or ``None`` when no standalone code is present.
    """
    m = TP_REGEX.search(path.name)
    return m.group(1).upper() if m else None

def build_tp_map(files: Sequence[str]) -> Dict[str, Path]:
    """Assign each configured timepoint exactly one input file.

    Files whose names carry no recognizable timepoint code are ignored.

    Args:
        files: Candidate file paths.

    Returns:
        Timepoint code -> ``Path`` for every code in ``CFG["TP_LABEL"]``.

    Raises:
        ValueError: Two files resolve to the same timepoint, or at least one
            configured timepoint has no file.
    """
    tp_to_path: Dict[str, Path] = {}
    for candidate in map(Path, files):
        code = parse_tp_from_name(candidate)
        if code is None:
            continue
        already = tp_to_path.get(code)
        if already is not None:
            raise ValueError(f"Duplicate file for {code}:\n  {already}\n  {candidate}")
        tp_to_path[code] = candidate

    absent = [code for code in CFG["TP_LABEL"] if code not in tp_to_path]
    if absent:
        raise ValueError(f"Missing required timepoints in files: {absent}")
    return tp_to_path

def load_tp_df(tp: str, tp_map: Mapping[str, Path]) -> Tuple[pd.DataFrame, str, Path]:
    """Load the results CSV for one timepoint and return it with its display label.

    Args:
        tp: Timepoint code, case-insensitive; must be a key of ``CFG["TP_LABEL"]``.
        tp_map: Timepoint code -> CSV path (e.g. the output of ``build_tp_map``).

    Returns:
        ``(df, label, src)``: the loaded table, the display label from
        ``CFG["TP_LABEL"]`` (e.g. ``"Day 1"``), and the path it was read from.

    Raises:
        ValueError: ``tp`` is not a configured timepoint.
        KeyError: ``tp`` is configured but missing from ``tp_map``.
        AssertionError: the file has a ``timepoint`` column containing any
            non-missing value other than ``tp``.
    """
    tp = tp.upper()
    if tp not in CFG["TP_LABEL"]:
        raise ValueError(f"Unsupported tp={tp}. Allowed: {list(CFG['TP_LABEL'])}")
    src = tp_map[tp]
    df = pd.read_csv(src)
    # Optional sanity check if the file carries a timepoint column. Every non-missing
    # value is validated: checking only the first row would miss mixed files, and a
    # header-only CSV has no first row at all (``.iloc[0]`` would raise IndexError).
    if "timepoint" in df.columns:
        seen = {str(v).strip().upper() for v in df["timepoint"].dropna().unique()}
        unexpected = sorted(seen - {tp})
        if unexpected:
            raise AssertionError(
                f"File {src} says timepoint={','.join(unexpected)}, expected {tp}"
            )
    day = CFG["TP_LABEL"][tp]
    return df, day, src


# -------- Helpers for pathway membership & significance --------
def split_ids(cell: object, pattern: str = CFG["id_delims"]) -> List[str]:
    """Split one multi-valued membership cell into individual IDs.

    Args:
        cell: A value from the membership column; non-strings are converted with ``str``.
        pattern: Delimiter regex (defaults to ``CFG["id_delims"]``).

    Returns:
        Non-empty tokens in their original order. Missing cells (``pd.isna``) and
        whitespace-only cells give ``[]``.
    """
    if pd.isna(cell):
        return []
    # Splitting a blank string produces only empty tokens, which the filter drops,
    # so whitespace-only cells need no separate branch.
    return list(filter(None, re.split(pattern, str(cell).strip())))

def extract_feature_to_sets(df: pd.DataFrame,
                            feature_col: str,
                            id_col: str) -> Dict[str, Set[str]]:
    """Map each feature ID to the set of pathway IDs it belongs to.

    Rows with a missing feature ID or an empty membership cell are skipped. When
    several rows share a feature ID, their pathway IDs are unioned.

    Args:
        df: Per-timepoint results table.
        feature_col: Column holding feature IDs; keys are ``str(value)``.
        id_col: Column with delimiter-separated pathway IDs (parsed by ``split_ids``).

    Returns:
        Feature ID -> non-empty set of pathway IDs.

    Raises:
        KeyError: ``feature_col`` or ``id_col`` is not a column of ``df``.
    """
    if feature_col not in df.columns:
        raise KeyError(f"feature_col '{feature_col}' not in df")
    if id_col not in df.columns:
        raise KeyError(f"id_col '{id_col}' not in df")
    mapping: Dict[str, Set[str]] = {}
    for feat, memberships in zip(df[feature_col], df[id_col].map(split_ids)):
        # A missing ID would otherwise be stringified to the literal "nan", pooling
        # every unnamed row into one fake feature that inflates the universe and
        # pathway sizes.
        if pd.isna(feat) or not memberships:
            continue
        mapping.setdefault(str(feat), set()).update(memberships)
    return mapping

# Case-insensitive string spellings accepted as "significant" in a flag column.
_TRUE_FLAG_STRINGS = frozenset({"true", "t", "yes", "y", "1", "1.0"})


def _coerce_sig_flag(value: object) -> bool:
    """Interpret one cell of a significance-flag column as a bool (missing -> False)."""
    if pd.isna(value):
        return False
    if isinstance(value, str):
        # bool("False") is True, so strings must be parsed rather than cast.
        return value.strip().lower() in _TRUE_FLAG_STRINGS
    return bool(value)


def infer_significant_mask(df: pd.DataFrame,
                           alpha: float = CFG["sig_alpha"],
                           sig_flag_col: Optional[str] = None,
                           fdr_cols: Sequence[str] = ("FDR","qval","adj_pval","padj","BH_FDR","q","q_value","adj.P.Val"),
                           p_cols: Sequence[str] = ("pval","p_value","P-value","P Value","p","P","P.value")) -> pd.Series:
    """Return a boolean Series marking the significant rows of ``df``.

    Sources are tried in priority order:

    1. ``sig_flag_col``, if given and present. Missing values count as False, and
       strings are parsed (``"True"``/``"yes"``/``"1"`` -> True, anything else -> False).
       If the named column is absent, the function silently falls through to 2.
    2. The first column in ``fdr_cols`` present in ``df``, thresholded at ``<= alpha``.
    3. The first column in ``p_cols`` present in ``df``, thresholded at ``<= alpha``.

    Non-numeric or missing values in the FDR/P columns are treated as not significant.

    Raises:
        KeyError: None of the three sources is available.
    """
    if sig_flag_col and sig_flag_col in df.columns:
        # A plain astype(bool) would map NaN -> True and "False" -> True.
        return df[sig_flag_col].map(_coerce_sig_flag).astype(bool)

    for c in fdr_cols:
        if c in df.columns:
            return (pd.to_numeric(df[c], errors="coerce") <= alpha).fillna(False)

    for c in p_cols:
        if c in df.columns:
            return (pd.to_numeric(df[c], errors="coerce") <= alpha).fillna(False)

    raise KeyError("Could not infer significance. Provide either a boolean flag column, "
                   "or an FDR/adjusted p-value column, or a raw p-value column.")


# -------- Build global universe across ALL timepoints --------
def build_global_universe(tp_map: Mapping[str, Path],
                          feature_col: str,
                          id_col: str) -> Tuple[Set[str], Dict[str, Set[str]]]:
    """Build one feature universe shared by every timepoint.

    Each timepoint file in ``CFG["TP_ORDER"]`` is loaded, and its feature -> pathway
    memberships are unioned.

    Returns:
        ``(U, feat2sets_all)``. ``U`` is the set of features with at least one pathway
        membership in any timepoint. ``feat2sets_all`` maps each feature to the union
        of its pathway IDs across timepoints.

    Raises:
        ValueError: No feature has any pathway membership in any file.
    """
    merged: Dict[str, Set[str]] = {}
    for code in CFG["TP_ORDER"]:
        frame = load_tp_df(code, tp_map)[0]
        per_tp = extract_feature_to_sets(frame, feature_col, id_col)
        for feature, pathway_ids in per_tp.items():
            merged.setdefault(feature, set()).update(pathway_ids)

    universe = {feature for feature, pathway_ids in merged.items() if pathway_ids}
    if not universe:
        raise ValueError("Empty universe: no features with pathway membership across files.")
    return universe, merged


# -------- ORA core --------
def compute_ora_for_tp(df: pd.DataFrame,
                       tp: str,
                       universe_feats: Set[str],
                       feat2sets_all: Mapping[str, Set[str]],
                       feature_col: str,
                       alpha: float = CFG["sig_alpha"],
                       sig_flag_col: Optional[str] = None) -> pd.DataFrame:
    """Hypergeometric over-representation analysis using a fixed global universe.

    Args:
        df: Results table for this timepoint; significance is read by
            ``infer_significant_mask``.
        tp: Timepoint code, copied into the ``tp`` column.
        universe_feats: Population of features (M = its size).
        feat2sets_all: Feature -> pathway IDs; pathway members are restricted to
            ``universe_feats``.
        feature_col: Column of ``df`` holding feature IDs (compared as ``str``).
        alpha: Significance threshold passed to ``infer_significant_mask``.
        sig_flag_col: Optional boolean flag column passed to ``infer_significant_mask``.

    Returns:
        One row per pathway with at least one universe member, with columns
        tp, pathway_id, pathway_name, Overlap ("k/n"), k, n, N, M, Odds Ratio,
        P-value (P[X >= k]), FDR (Benjamini-Hochberg across this timepoint's
        pathways) and neglog10p. Rows are sorted by FDR and P-value ascending, then
        Odds Ratio descending, then pathway_id. If no significant feature lies in
        the universe, an empty frame with the same columns is returned.
    """
    columns = ["tp", "pathway_id", "pathway_name", "Overlap", "k", "n", "N", "M",
               "Odds Ratio", "P-value", "FDR", "neglog10p"]

    # Significant features at this TP, restricted to the universe and to features
    # that have pathway annotations.
    sig_mask = infer_significant_mask(df, alpha=alpha, sig_flag_col=sig_flag_col)
    sig_feats_raw = set(df.loc[sig_mask, feature_col].astype(str))
    sig_feats = {f for f in sig_feats_raw if f in universe_feats and f in feat2sets_all}

    M = len(universe_feats)     # population size
    N = len(sig_feats)          # number of draws (significant features)
    if N == 0:
        return pd.DataFrame(columns=columns)

    # Pathway -> members, from the global mapping, restricted to the universe.
    pathway_to_members: Dict[str, Set[str]] = {}
    for feat, sets_ in feat2sets_all.items():
        if feat not in universe_feats:
            continue
        for pid in sets_:
            pathway_to_members.setdefault(pid, set()).add(feat)

    rows = []
    for pid, members in pathway_to_members.items():
        n = len(members)              # pathway size in universe (>= 1 by construction)
        k = len(sig_feats & members)  # overlap
        # scipy's hypergeom(M, n, N): sf(k - 1) = P[X > k - 1] = P[X >= k]
        pval = hypergeom.sf(k - 1, M, n, N)
        # 2x2 odds ratio with the Haldane-Anscombe +0.5 correction to avoid zeros:
        # a = sig & in pathway, b = sig & not in pathway,
        # c = not sig & in pathway, d = not sig & not in pathway.
        a = k + 0.5
        b = (N - k) + 0.5
        c = (n - k) + 0.5
        d = (M - N - (n - k)) + 0.5
        rows.append({
            "tp": tp,
            "pathway_id": pid,
            "pathway_name": PATHWAY_NAME_MAP.get(pid, pid),
            "Overlap": f"{k}/{n}",
            "k": k, "n": n, "N": N, "M": M,
            "Odds Ratio": (a * d) / (b * c),
            "P-value": pval,
        })

    if not rows:
        return pd.DataFrame(columns=columns)

    res = pd.DataFrame(rows)
    # Multiple-testing correction across all pathways tested for this TP.
    res["FDR"] = multipletests(res["P-value"].values, method="fdr_bh")[1]
    res["neglog10p"] = -np.log10(np.clip(res["P-value"].values, 1e-300, 1.0))
    # Among equal FDR/P-value, rank the stronger effect (larger odds ratio) first.
    res = res.sort_values(["FDR", "P-value", "Odds Ratio", "pathway_id"],
                          ascending=[True, True, False, True])
    return res.reset_index(drop=True)


# -------- Plotting --------
def plot_top(df: pd.DataFrame, title: str, out_png: Path, topn: int = CFG["plot_topn"],
             alpha: float = CFG["sig_alpha"], use_fdr: bool = CFG["plot_alpha_line_is_fdr"]) -> None:
    """Save a horizontal bar chart of the ``topn`` pathways with the smallest P-values.

    Bars show the ``neglog10p`` column (-log10 P-value). Each bar is annotated with
    its odds ratio and overlap. A dashed vertical line marks the significance cutoff:

    * ``use_fdr=True`` and ``df`` has an ``"FDR"`` column: the line is placed at
      -log10 of the largest P-value whose FDR is <= ``alpha``. The threshold is
      computed over the whole table, not just the plotted rows. If FDR holds
      BH-adjusted values (as produced by ``compute_ora_for_tp``), the bars reaching
      the line are exactly the FDR-significant pathways. If none pass, no line is drawn.
    * otherwise: the line is at -log10(``alpha``), a raw P-value cutoff.

    Args:
        df: ORA table with columns P-value, pathway_name, neglog10p, Odds Ratio,
            Overlap (and optionally FDR).
        title: Plot title.
        out_png: Output image path; parent directories are created.
        topn: Maximum number of pathways to draw.
        alpha: Significance level for the threshold line.
        use_fdr: Choose the FDR-based threshold line (see above).

    Returns:
        None. When ``df`` is empty, a message is printed and nothing is written.
    """
    if df.empty:
        print(f"No enriched pathways to plot for: {title}")
        return

    # Threshold from the full table, before truncating to topn.
    if use_fdr and "FDR" in df.columns:
        passing = df.loc[df["FDR"] <= alpha, "P-value"]
        thresh = (float(-np.log10(np.clip(passing.max(), 1e-300, 1.0)))
                  if not passing.empty else None)
        thresh_label = f"FDR = {alpha:g}"
    else:
        thresh = float(-np.log10(alpha))
        thresh_label = f"P = {alpha:g}"

    top = df.sort_values(["P-value", "pathway_name"]).head(topn).copy()
    top["label"] = top["pathway_name"].map(lambda s: "\n".join(wrap(str(s), width=35)))

    fig, ax = plt.subplots(figsize=(9, max(4, 0.45 * len(top))))
    try:
        y = np.arange(len(top))[::-1]  # most significant pathway at the top
        ax.barh(y, top["neglog10p"].values)
        ax.set_yticks(y)
        ax.set_yticklabels(top["label"].values)
        ax.set_xlabel(r"$-\log_{10}$(P-value)")
        ax.set_title(title)

        # Text labels (OR and Overlap) just right of each bar.
        for yi, xv, orr, ov in zip(y, top["neglog10p"], top["Odds Ratio"], top["Overlap"]):
            ax.text(xv + 0.05, yi, f"OR={orr:.2f} | {ov}", va="center", fontsize=9)

        if thresh is not None:
            ax.axvline(thresh, linestyle="--", linewidth=1.2, label=thresh_label)
            ax.legend(loc="best")

        fig.tight_layout()
        out_png.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_png, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


# -------- Orchestration --------
def run_for_timepoint(tp: str,
                      files: Sequence[str],
                      out_base: str = CFG["out_dir"],
                      feature_col: str = CFG["feature_col"],
                      id_col: str = CFG["id_col"],
                      alpha: float = CFG["sig_alpha"],
                      sig_flag_col: Optional[str] = None,
                      make_plot: bool = True) -> pd.DataFrame:
    """Deterministic load + global-universe ORA for a single timepoint.

    The universe is built from all timepoint files, so results are comparable with
    ``run_all_timepoints``.

    Args:
        tp: Timepoint code (case-insensitive; normalized to upper case).
        files: All input files; every configured timepoint must be present.
        out_base: Base output directory; results go under ``out_base/{TP}/``.
        feature_col, id_col: Feature-ID and pathway-membership columns.
        alpha: Significance threshold.
        sig_flag_col: Optional boolean significance column.
        make_plot: Also write a bar chart when there are results.

    Returns:
        The ORA table from ``compute_ora_for_tp`` (possibly empty).

    Side effects:
        Writes ``enrichment_{TP}.csv``, a best-effort ``enrichment_{TP}.parquet``
        and optionally ``enrichment_{TP}.png``, and prints a short summary.
    """
    # Normalize once so the output directory, file names and the "tp" column agree
    # with the upper-case codes used by load_tp_df (e.g. "d1" -> "D1").
    tp = tp.upper()
    tp_map = build_tp_map(files)
    # Build the global universe once using all files.
    U, feat2sets_all = build_global_universe(tp_map, feature_col, id_col)
    df_tp, day_label, src = load_tp_df(tp, tp_map)

    # ORA
    res = compute_ora_for_tp(df_tp, tp, U, feat2sets_all, feature_col, alpha=alpha, sig_flag_col=sig_flag_col)

    # Write outputs
    out_dir = Path(out_base) / tp
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / f"enrichment_{tp}.csv"
    out_parq = out_dir / f"enrichment_{tp}.parquet"
    res.to_csv(out_csv, index=False)
    try:
        res.to_parquet(out_parq, index=False)
    except Exception as exc:  # Parquet is optional (needs pyarrow/fastparquet) - report, don't fail
        print(f"[{tp}] Skipped Parquet output ({type(exc).__name__}: {exc})")

    # Plot
    if make_plot and not res.empty:
        out_png = out_dir / f"enrichment_{tp}.png"
        plot_top(res, f"{tp} — {day_label}", out_png, alpha=alpha)

    print(f"[{tp}] Loaded from: {src}")
    print(f"[{tp}] Universe size M={len(U)}; N_sig={res['N'].iloc[0] if not res.empty else 0}")
    print(f"[{tp}] Wrote: {out_csv}")
    return res

def run_all_timepoints(files: Sequence[str],
                       out_base: str = CFG["out_dir"],
                       feature_col: str = CFG["feature_col"],
                       id_col: str = CFG["id_col"],
                       alpha: float = CFG["sig_alpha"],
                       sig_flag_col: Optional[str] = None,
                       make_plots: bool = True) -> pd.DataFrame:
    """Iterate in canonical order with a single, shared universe across all TPs.

    For each timepoint in ``CFG["TP_ORDER"]``, this runs ORA and writes
    ``out_base/{TP}/enrichment_{TP}.csv``, a best-effort ``.parquet`` and optionally
    a ``.png``. It then writes ``out_base/combined_enrichment.csv``.

    Returns:
        Concatenated results of all timepoints (distinguished by the ``tp`` column).
    """
    tp_map = build_tp_map(files)
    U, feat2sets_all = build_global_universe(tp_map, feature_col, id_col)
    all_res: List[pd.DataFrame] = []

    for tp in CFG["TP_ORDER"]:
        df_tp, day_label, src = load_tp_df(tp, tp_map)
        res = compute_ora_for_tp(df_tp, tp, U, feat2sets_all, feature_col, alpha=alpha, sig_flag_col=sig_flag_col)
        out_dir = Path(out_base) / tp
        out_dir.mkdir(parents=True, exist_ok=True)
        res.to_csv(out_dir / f"enrichment_{tp}.csv", index=False)
        try:
            res.to_parquet(out_dir / f"enrichment_{tp}.parquet", index=False)
        except Exception as exc:  # Parquet is optional (needs pyarrow/fastparquet) - report, don't fail
            print(f"[{tp}] Skipped Parquet output ({type(exc).__name__}: {exc})")
        if make_plots and not res.empty:
            plot_top(res, f"{tp} — {day_label}", out_dir / f"enrichment_{tp}.png", alpha=alpha)
        print(f"[{tp}] from {src}: M={len(U)}, N_sig={res['N'].iloc[0] if not res.empty else 0}, results={len(res)}")
        all_res.append(res)

    # Concatenate once. Empty per-TP frames contribute no rows and would only degrade
    # column dtypes to object, so they are left out unless every frame is empty.
    non_empty = [r for r in all_res if not r.empty]
    combined = pd.concat(non_empty or all_res, ignore_index=True)

    out_all = Path(out_base) / "combined_enrichment.csv"
    out_all.parent.mkdir(parents=True, exist_ok=True)
    combined.to_csv(out_all, index=False)
    print(f"Wrote combined: {out_all}")
    return combined
