In [None]:
import os
import sys
import json 
import math
import shutil
import tempfile
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
from scipy.stats import spearmanr, norm
from statsmodels.stats.multitest import multipletests

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# parameters
# Input AnnData file (.h5ad), loaded with anndata.read_h5ad.
AD_PATH = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
# Value exported to the R_HOME environment variable during setup.
R_HOME = "/anvil/projects/x-mcb130189/aklein/SPIDA/.pixi/envs/dist/lib/R"
# R script executed through `pixi run -e dist Rscript` for every subset.
R_SCRIPT = "/home/x-aklein2/projects/aklein/BICAN/spida_dev/helper_scripts/sparkX.R"
# Directory that receives the per-replicate and meta-analysis CSV files.
OUTDIR = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx"
# Directory that receives the saved figures.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS/sparkx"
# adata.obs column used to split cells into replicates.
REP_KEY = "dataset_id"
# adata.obs column used to split cells by cell type; a falsy value disables that split.
CELLTYPE_KEY = "Subclass"
# adata.obs column used as the 1-D coordinate when COORDS_MODE == "axis_only".
AXIS = "MS_NORM"

# Subsets with fewer cells than this are skipped; also forwarded to the R script as --min_cells.
min_cells = 50
# Forwarded to the R script as --num_cores.
num_cores = 4

# "axis_only": x is taken from AXIS and y is set to 0 for every cell.
# "xy": x and y are read from the two obs columns named in XY_KEYS.
COORDS_MODE = "axis_only"
# adata.obs columns used as (x, y) when COORDS_MODE == "xy".
XY_KEYS = ("CENTER_X", "CENTER_Y")

In [None]:
# Promote the configured output locations to Path objects and make sure both exist.
OUTDIR, image_path = Path(OUTDIR), Path(image_path)
OUTDIR.mkdir(exist_ok=True, parents=True)
image_path.mkdir(exist_ok=True, parents=True)

# Point R tooling at the configured R installation.
os.environ["R_HOME"] = R_HOME
# R_BIN = subprocess.check_output(["which", "Rscript"], text=True).strip()

In [None]:
def to_dense_df(X, var_names, obs_names):
    """Build a dense DataFrame from a matrix such as AnnData.X.

    Objects exposing ``toarray`` are densified through it, objects exposing
    ``A`` use that attribute, and anything else is passed to ``np.asarray``.
    Rows are labelled with ``obs_names`` and columns with ``var_names``.
    """
    if hasattr(X, "toarray"):
        dense = X.toarray()
    else:
        dense = X.A if hasattr(X, "A") else np.asarray(X)
    return pd.DataFrame(data=dense, columns=var_names, index=obs_names)

def write_rds_from_python(obj, out_path):
    """
    Save a pandas DataFrame or numpy array to an .Rds file using rpy2.
    Requires rpy2 and R in the environment.

    Parameters
    ----------
    obj : pd.DataFrame or array-like
        DataFrames are converted with the pandas2ri converter; any other
        object is handed to R using rpy2's currently active conversion rules.
    out_path : str or Path
        Destination of the .Rds file.

    Notes
    -----
    ``saveRDS`` is called as an R function with the path passed as an
    argument, rather than being spliced into an R source string. Paths
    containing quotes or backslashes therefore work, and no temporary
    variable is left behind in the R global environment.
    """
    import rpy2.robjects as ro
    from rpy2.robjects import pandas2ri

    if isinstance(obj, pd.DataFrame):
        with ro.conversion.localconverter(ro.default_converter + pandas2ri.converter):
            r_obj = pandas2ri.py2rpy(obj)
    else:
        # assume numpy array
        # NOTE(review): conversion of a raw numpy array depends on the rpy2
        # conversion rules active at call time — confirm numpy2ri is enabled
        # wherever this branch is used.
        r_obj = obj

    ro.r["saveRDS"](r_obj, file=str(out_path))

def write_r_input(expr_df, coords_df, base_path):
    """Write the expression and coordinate tables as CSV files under ``base_path``.

    The expression frame is transposed before writing; the coordinate frame
    is written unchanged. Returns ``(expr_path, coords_path)``.
    """
    targets = (base_path / "expr_df.csv", base_path / "coords_df.csv")
    frames = (expr_df.T, coords_df)
    for frame, target in zip(frames, targets):
        frame.to_csv(target)
    return targets


def write_arrow_input(expr_df, coords_df, base_path: Path):
    """
    Write SPARK-X inputs as Feather files via pyarrow (no rpy2 involved).

    Parameters
    ----------
    expr_df : pd.DataFrame
        Cells × genes dataframe.
    coords_df : pd.DataFrame
        Cells × 2 dataframe (x/y or axis coords).
    base_path : Path
        Directory in which the two files are created.

    Returns
    -------
    tuple(Path, Path)
        Paths to expression feather and coords feather files.

    Notes
    -----
    The expression table is written transposed (genes × cells) with the gene
    labels in a leading ``gene`` column; the coordinate table gets a leading
    ``cell`` column holding its index labels.
    """
    import pyarrow.feather as feather

    expr_path = base_path / "expr_df.feather"
    coords_path = base_path / "coords_df.feather"

    # Name the index explicitly before resetting it. reset_index() names the
    # new column after the index when the index has a name, so renaming a
    # column called "index" afterwards would silently miss in that case.
    expr_genes_cells = expr_df.T.rename_axis(index="gene").reset_index()
    coords_reset = coords_df.rename_axis(index="cell").reset_index()

    feather.write_feather(expr_genes_cells, expr_path)
    feather.write_feather(coords_reset, coords_path)
    return expr_path, coords_path

In [None]:
def run_sparkx_once(expr_rds, coords_rds, out_csv):
    """Run the SPARK-X R script once through ``pixi run -e dist Rscript``.

    The three path arguments are forwarded as ``--expr``, ``--coords`` and
    ``--out``; the module-level ``min_cells`` and ``num_cores`` settings are
    forwarded as ``--min_cells`` and ``--num_cores``. Because the subprocess
    is started with ``check=True``, a non-zero exit status raises
    ``subprocess.CalledProcessError``.
    """
    flag_values = (
        ("--expr", expr_rds),
        ("--coords", coords_rds),
        ("--out", out_csv),
        ("--min_cells", min_cells),
        ("--num_cores", num_cores),
    )
    argv = ["pixi", "run", "-e", "dist", "Rscript", R_SCRIPT]
    for flag, value in flag_values:
        argv.extend((flag, str(value)))
    subprocess.run(argv, check=True)

def stouffer_meta(pvals, signs, weights):
    """Weighted, signed Stouffer combination of two-sided p-values.

    Parameters
    ----------
    pvals : array-like
        Two-sided p-values, one per study/replicate.
    signs : array-like
        Direction of each effect; only ``np.sign`` of the value is used.
    weights : array-like
        Per-study weights (broadcast against ``pvals`` and ``signs``).

    Returns
    -------
    (Z, p) : tuple of float
        Combined signed Z-score and its two-sided p-value. Both are NaN when
        no study has finite inputs or when all usable weights are zero.

    Notes
    -----
    Studies with a non-finite p-value, sign or weight are dropped instead of
    turning the combined statistic into NaN (e.g. a NaN correlation sign for a
    gene that is constant in one replicate).
    """
    pvals, signs, weights = np.broadcast_arrays(
        np.asarray(pvals, dtype=float),
        np.asarray(signs, dtype=float),
        np.asarray(weights, dtype=float),
    )

    usable = np.isfinite(pvals) & np.isfinite(signs) & np.isfinite(weights)
    if not usable.any():
        return np.nan, np.nan
    pvals, signs, weights = pvals[usable], signs[usable], weights[usable]

    denom = np.sqrt(np.sum(weights ** 2))
    if denom == 0:
        return np.nan, np.nan

    # Clip to avoid infinite z-scores for p-values that underflowed to 0.
    pvals = np.clip(pvals, 1e-300, 1.0)
    z = norm.isf(pvals / 2.0)        # two-sided to one-sided tail
    z_signed = np.sign(signs) * z
    Z = np.sum(weights * z_signed) / denom
    p = 2 * norm.sf(abs(Z))
    return Z, p

def i_squared(effects, weights):
    """Heterogeneity statistic I² (in percent) from weighted effects.

    Computes ``Q = sum(w * (effect - weighted_mean)**2)`` with ``df = k - 1``
    and returns ``max(0, (Q - df) / Q) * 100`` (0.0 when ``Q`` is not
    positive).

    Parameters
    ----------
    effects : array-like
        Per-study effect values.
    weights : array-like
        Per-study weights (broadcast against ``effects``).

    Returns
    -------
    float
        I² in percent, or NaN when fewer than two studies have finite inputs
        or the usable weights sum to zero.

    Notes
    -----
    Studies with a non-finite effect or weight are dropped first. Without
    this, one NaN effect made ``Q`` NaN and the function reported 0.0, i.e.
    "no heterogeneity", for a gene it could not actually assess.
    """
    effects, weights = np.broadcast_arrays(
        np.asarray(effects, dtype=float),
        np.asarray(weights, dtype=float),
    )
    usable = np.isfinite(effects) & np.isfinite(weights)
    effects, weights = effects[usable], weights[usable]

    k = effects.size
    if k <= 1:
        return np.nan
    if weights.sum() == 0:
        # np.average would raise ZeroDivisionError here.
        return np.nan

    mean_eff = np.average(effects, weights=weights)
    Q = np.sum(weights * (effects - mean_eff) ** 2)
    df = k - 1
    return max(0.0, (Q - df) / Q) * 100 if Q > 0 else 0.0

In [None]:
def _bh_fdr_ignore_nan(pvals):
    """Benjamini-Hochberg adjust the finite entries of ``pvals``; others stay NaN."""
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    finite = np.isfinite(p)
    if finite.any():
        adjusted[finite] = multipletests(p[finite], method="fdr_bh")[1]
    return adjusted


def meta_and_qc(per_rep_csv, fdr_thresh=0.05):
    """Combine per-replicate SPARK-X results per gene and write QC plots.

    Parameters
    ----------
    per_rep_csv : str or Path
        CSV with one row per (replicate, gene) and the columns ``gene``,
        ``p_sparkx``, ``n_cells`` and ``rho_axis``; an optional ``cell_type``
        column makes the analysis run separately per cell type.
    fdr_thresh : float, default 0.05
        FDR cut-off used for colouring the volcano plots and for counting
        significant genes.

    Returns
    -------
    pd.DataFrame
        One row per (cell_type, gene) with ``meta_Z``, ``meta_p``, ``I2``,
        ``direction`` and ``fdr``. An empty DataFrame is returned when the
        input CSV has no rows.

    Raises
    ------
    KeyError
        If a required column is missing from the input CSV.

    Side effects
    ------------
    Writes ``sparkx_meta_results.csv`` to ``OUTDIR`` and saves one volcano
    plot per cell type, the I² histogram and the significant-gene bar plot
    to ``image_path``.

    Notes
    -----
    * FDR is computed per cell type on finite meta p-values only. Passing a
      NaN p-value to the Benjamini-Hochberg step would make the adjusted
      value NaN for every gene of that cell type.
    * ``direction`` is "up" for Z > 0, "down" for Z < 0 and "none" when Z is
      zero or undefined.
    """
    df = pd.read_csv(per_rep_csv)
    if df.empty:
        print("No replicate results found.")
        return pd.DataFrame()

    required = ("gene", "p_sparkx", "n_cells", "rho_axis")
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"{per_rep_csv} is missing required column(s): {missing}")

    records = []
    # group either by cell_type (if present) or do single group
    group_cols = ["cell_type"] if "cell_type" in df.columns else []
    group_cols_gene = group_cols + ["gene"]

    for keys, g in df.groupby(group_cols_gene):
        # Depending on the pandas version a single grouping key arrives as a
        # scalar or as a 1-tuple; handle both.
        if isinstance(keys, tuple):
            *ct_part, gene = keys
            cell_type = ct_part[0] if ct_part else None
        else:
            gene = keys
            cell_type = None

        pvals = g["p_sparkx"].to_numpy(dtype=float)
        weights = np.sqrt(g["n_cells"].to_numpy(dtype=float))

        # sign from Spearman rho w.r.t aligned axis
        signs = np.sign(g["rho_axis"].to_numpy(dtype=float))
        Z, p_meta = stouffer_meta(pvals, signs, weights)
        effs = signs * norm.isf(np.clip(pvals, 1e-300, 1.0) / 2.0)
        I2 = i_squared(effs, weights)

        if Z > 0:
            direction = "up"
        elif Z < 0:
            direction = "down"
        else:
            # Z == 0 or NaN: no direction can be claimed.
            direction = "none"

        records.append({
            "cell_type": cell_type if cell_type is not None else "ALL",
            "gene": gene,
            "meta_Z": Z,
            "meta_p": p_meta,
            "I2": I2,
            "direction": direction,
        })

    meta = pd.DataFrame(records)
    # FDR per cell type (finite p-values only)
    meta["fdr"] = meta.groupby("cell_type")["meta_p"].transform(
        lambda p: _bh_fdr_ignore_nan(p.to_numpy())
    )
    meta = meta.sort_values(["cell_type", "fdr", "meta_p", "meta_Z"])
    meta.to_csv(OUTDIR / "sparkx_meta_results.csv", index=False)
    print(f"[meta] Saved: {OUTDIR/'sparkx_meta_results.csv'}")

    # --- Plots ---
    # Volcano per cell type
    for ct, g in meta.groupby("cell_type"):
        g = g.assign(log10FDR=-np.log10(g["fdr"].clip(lower=1e-300)))
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.scatterplot(data=g, x="meta_Z", y="log10FDR",
                        hue=g["fdr"] < fdr_thresh,
                        palette={True: "crimson", False: "grey"},
                        s=18, alpha=0.7, legend=False, ax=ax)
        ax.axhline(-np.log10(fdr_thresh), ls="--", lw=1, c="black")
        ax.set_title(f"SPARK-X Stouffer Volcano — {ct}")
        ax.set_xlabel("Meta Z (signed)")
        ax.set_ylabel("-log10(FDR)")
        fig.tight_layout()
        # Path separators inside a cell-type label would otherwise be
        # interpreted as sub-directories of image_path.
        safe_ct = str(ct).replace("/", "_").replace("\\", "_")
        fig.savefig(image_path / f"volcano_{safe_ct}.png", dpi=200)
        plt.close(fig)

    # I2 distribution
    fig, ax = plt.subplots(figsize=(7, 5))
    sns.histplot(data=meta, x="I2", hue="cell_type", bins=40,
                 element="step", fill=False, ax=ax)
    ax.set_xlabel("I² (%)")
    ax.set_ylabel("Gene count")
    ax.set_title("Replicate heterogeneity (I²)")
    fig.tight_layout()
    fig.savefig(image_path / "I2_distribution.png", dpi=200)
    plt.close(fig)

    # Significant counts; cell types without any hit are kept with a count
    # of 0 so the bar plot never receives an empty table.
    sig_counts = (
        meta.assign(is_sig=meta["fdr"] < fdr_thresh)
        .groupby("cell_type")["is_sig"]
        .sum()
        .astype(int)
        .reset_index(name="n_sig")
    )
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=sig_counts, x="cell_type", y="n_sig", color="steelblue", ax=ax)
    ax.set_ylabel(f"# significant genes (FDR<{fdr_thresh})")
    ax.set_xlabel("Cell type")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_title("Significant SVGs per cell type")
    fig.tight_layout()
    fig.savefig(image_path / "sig_gene_counts.png", dpi=200)
    plt.close(fig)

    print(f"[plots] Saved to {image_path}")
    return meta

In [None]:
# Destination of the concatenated per-replicate results, and the list that
# collects one result table per processed subset.
per_rep_csv = OUTDIR / "sparkx_per_replicate_results.csv"
rows = list()

adata = ad.read_h5ad(AD_PATH)

# The replicate column is always required.
if not (REP_KEY in adata.obs):
    raise KeyError(f"obs missing '{REP_KEY}'")

# The cell-type column is only required when a key is configured.
if CELLTYPE_KEY:
    if CELLTYPE_KEY not in adata.obs:
        raise KeyError(f"obs missing '{CELLTYPE_KEY}'")

# Coordinate columns depend on the selected mode.
if COORDS_MODE == "axis_only":
    if AXIS not in adata.obs:
        raise KeyError(f"obs missing '{AXIS}'")
elif COORDS_MODE == "xy":
    xk, yk = XY_KEYS
    if not (xk in adata.obs and yk in adata.obs):
        raise KeyError(f"obs missing XY keys {XY_KEYS}")

In [None]:
# Distinct replicate labels, and distinct cell-type labels when a cell-type
# key is configured (otherwise a single placeholder group).
reps = adata.obs[REP_KEY].unique().tolist()
if CELLTYPE_KEY:
    cts = adata.obs[CELLTYPE_KEY].unique().tolist()
else:
    cts = [None]

In [None]:
# Run SPARK-X once per (replicate, cell type) subset, attach a Spearman
# direction per gene, then concatenate, save and meta-analyse the results.

# Start from an empty accumulator so re-running this cell does not append the
# same replicate tables a second time.
rows = []

for rep in reps:
    for ct in cts:
        idx = (adata.obs[REP_KEY] == rep)
        label = f"{REP_KEY}={rep}"
        if CELLTYPE_KEY and ct is not None:
            idx = idx & (adata.obs[CELLTYPE_KEY] == ct)
            label += f", {CELLTYPE_KEY}={ct}"
        if COORDS_MODE == "axis_only":
            # Cells without an axis value cannot be given a coordinate.
            idx = idx & (~adata.obs[AXIS].isna())

        n_cells = int(idx.sum())
        if n_cells < min_cells:
            print(f"[skip] {label} — {n_cells} cells (<{min_cells})")
            continue

        sub = adata[idx].copy()
        print(f"[run] {label}: {sub.n_obs} cells, {sub.n_vars} genes")

        # Dense cells×genes table; the writer transposes it for SPARK-X.
        expr_cells_genes = to_dense_df(sub.X, sub.var_names, sub.obs_names)

        # Build coords. axis_vals is also the variable the per-gene Spearman
        # direction is computed against further down.
        if COORDS_MODE == "axis_only":
            axis_vals = sub.obs[AXIS].values
            coords_df = pd.DataFrame({
                "x": axis_vals,
                "y": np.zeros(sub.n_obs, dtype=float)
            }, index=sub.obs_names)
        else:
            axis_vals = sub.obs[XY_KEYS[0]].values
            coords_df = pd.DataFrame({
                "x": axis_vals,
                "y": sub.obs[XY_KEYS[1]].values
            }, index=sub.obs_names)

        # Temporary files for one run; the result CSV must be read before the
        # directory is removed at the end of the `with` block.
        with tempfile.TemporaryDirectory() as td:
            td = Path(td)
            out_csv = td / "sparkx_results.csv"

            # Inputs are exchanged with R as Feather files.
            # (write_r_input / write_rds_from_python are alternative writers.)
            expr_path, coords_path = write_arrow_input(expr_cells_genes, coords_df, td)
            print("Wrote Feather input files.")

            # call R
            try:
                run_sparkx_once(expr_path, coords_path, out_csv)
            except subprocess.CalledProcessError as e:
                print(f"[err] SPARK-X failed for {label}: {e}")
                continue

            # read results
            if not out_csv.exists():
                print(f"[warn] no CSV produced for {label}")
                continue
            res = pd.read_csv(out_csv)

        # If empty, continue
        if res.empty:
            print(f"[warn] empty results for {label}")
            continue

        # Spearman rho (direction) for each gene vs aligned axis, computed on
        # the cells×genes table.
        rho_map = {}
        for g in res["gene"]:
            if g not in expr_cells_genes.columns:
                # should not happen, but be safe (filtering may drop genes)
                continue
            rho, _ = spearmanr(expr_cells_genes[g].values, axis_vals)
            rho_map[g] = rho

        res["rho_axis"] = res["gene"].map(rho_map)
        res["replicate"] = rep
        res["n_cells"] = n_cells
        if CELLTYPE_KEY and ct is not None:
            res["cell_type"] = ct

        rows.append(res)

if not rows:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise ValueError("No runs produced results; nothing to concatenate or meta-analyse.")

all_rep = pd.concat(rows, ignore_index=True)

# Normalise the p-value column name to 'p_sparkx'. Several spellings of the
# adjusted p-value column are accepted; the first one present wins, so no
# duplicate 'p_sparkx' columns can be created.
# NOTE(review): Stouffer combination is normally applied to unadjusted
# p-values — confirm that preferring the adjusted column is intended.
if "p_sparkx" not in all_rep.columns:
    for candidate in ("adjusted_pvalue", "adjusted_pValue", "adjusted_pval", "adjustedPval"):
        if candidate in all_rep.columns:
            all_rep = all_rep.rename(columns={candidate: "p_sparkx"})
            break
    else:
        # fall back to the combined p-value if no adjusted column exists
        for candidate in ("combined_pvalue", "combinedPval"):
            if candidate in all_rep.columns:
                print("[warn] adjusted_pvalue missing; using combined_pvalue for meta (less conservative).")
                all_rep = all_rep.rename(columns={candidate: "p_sparkx"})
                break

all_rep.to_csv(per_rep_csv, index=False)
print(f"[io] saved per-replicate: {per_rep_csv}")

if "p_sparkx" not in all_rep.columns:
    raise KeyError(
        "No p-value column found in SPARK-X results "
        f"(columns: {list(all_rep.columns)}); cannot run the meta-analysis."
    )

# meta + plots
meta_and_qc(per_rep_csv)