In [None]:
import os
import numpy as np
import pandas as pd
import anndata as ad
import subprocess
import matplotlib.pyplot as plt
import tempfile
from pathlib import Path
from scipy.stats import norm
import pyarrow.feather as feather
from concurrent.futures import ProcessPoolExecutor, as_completed

# Functions

## HETEROGENEITY UTILITIES

In [None]:


def stouffer_meta(pvals, weights=None):
    """
    Combine one-sided p-values with Stouffer's (optionally weighted) Z method.

    Entries that are non-finite or not strictly inside (0, 1) are dropped,
    together with their weights, before combining.

    Parameters
    ----------
    pvals : array-like
        P-values, one per replicate.
    weights : array-like, optional
        Per-replicate weights aligned with ``pvals``; equal weights if omitted.

    Returns
    -------
    float
        Combined p-value, or NaN when no usable p-value remains.
    """
    p = np.asarray(pvals)
    keep = np.isfinite(p) & (p > 0) & (p < 1)
    n_keep = keep.sum()
    if n_keep == 0:
        return np.nan

    # Upper-tail z-score for each surviving p-value.
    scores = norm.isf(p[keep])
    coef = np.ones(n_keep) if weights is None else np.asarray(weights)[keep]

    combined = np.sum(coef * scores) / np.sqrt(np.sum(coef**2))
    return norm.sf(combined)


def i_squared(z_vals):
    """
    Relative heterogeneity I² of a set of replicate z-scores.

    Q is the sum of squared deviations from the unweighted mean and
    I² = max(0, (Q - (k - 1)) / Q), where k is the number of finite z-scores.

    Parameters
    ----------
    z_vals : array-like
        Replicate z-scores; non-finite entries are ignored.

    Returns
    -------
    float
        I² in [0, 1), or NaN when fewer than two finite z-scores are given.
    """
    z = np.asarray(z_vals, dtype=float)
    z = z[np.isfinite(z)]
    k = z.size
    if k < 2:
        return np.nan

    Q = float(np.sum((z - z.mean()) ** 2))
    # Identical z-scores give Q == 0: there is no heterogeneity, and dividing
    # by Q would raise a divide-by-zero RuntimeWarning.
    if Q <= 0:
        return 0.0
    return max(0.0, (Q - (k - 1)) / Q)


def tau_squared_dl(z_vals, weights=None):
    """
    DerSimonian–Laird-style estimate of between-replicate variance (tau²).

    Computes the weighted Q statistic around the weighted mean and returns
    max(0, (Q - (k - 1)) / (sum(w) - sum(w²) / sum(w))).

    Parameters
    ----------
    z_vals : array-like
        Replicate z-scores; non-finite entries (and their weights) are dropped.
    weights : array-like, optional
        Per-replicate weights aligned with ``z_vals``; equal weights if omitted.

    Returns
    -------
    float
        tau² (>= 0). NaN when fewer than two finite z-scores are available or
        when the weights make the estimator undefined (non-positive total
        weight, or a zero denominator such as all weight on one replicate).
    """
    z_all = np.asarray(z_vals, dtype=float)
    valid = np.isfinite(z_all)
    z = z_all[valid]
    k = z.size

    if k < 2:
        return np.nan

    if weights is None:
        w = np.ones(k)
    else:
        w = np.asarray(weights, dtype=float)[valid]

    w_sum = np.sum(w)
    # Guard before dividing: a zero/invalid total weight leaves the weighted
    # mean undefined.
    if not w_sum > 0:
        return np.nan

    denom = w_sum - (np.sum(w**2) / w_sum)
    # denom is 0 when effectively a single replicate carries all the weight;
    # the estimator is undefined there rather than 0 or inf.
    if not denom > 0:
        return np.nan

    z_bar = np.sum(w * z) / w_sum
    Q = np.sum(w * (z - z_bar) ** 2)

    return max(0.0, (Q - (k - 1)) / denom)


def prediction_interval(z_vals, weights=None):
    """
    Combined z-score with a ±1.96·sqrt(tau² + 1) interval around it.

    The centre is the weighted sum of z-scores divided by sqrt(sum(w²)); tau²
    comes from ``tau_squared_dl`` on the same finite z-scores and weights.

    Parameters
    ----------
    z_vals : array-like
        Replicate z-scores; non-finite entries (and their weights) are dropped.
    weights : array-like, optional
        Per-replicate weights aligned with ``z_vals``; equal weights if omitted.

    Returns
    -------
    tuple
        ``(z_meta, PI_low, PI_high)``; three NaNs when fewer than two finite
        z-scores are available.
    """
    arr = np.asarray(z_vals)
    finite = np.isfinite(arr)
    scores = arr[finite]
    n = len(scores)

    if n < 2:
        return (np.nan,) * 3

    wts = np.ones(n) if weights is None else np.asarray(weights)[finite]

    centre = np.sum(wts * scores) / np.sqrt(np.sum(wts**2))
    # Half-width: 1.96 times sqrt(between-replicate variance + 1).
    half_width = 1.96 * np.sqrt(tau_squared_dl(scores, weights=wts) + 1)

    return centre, centre - half_width, centre + half_width

In [None]:
def to_dense_df(X, var_names, obs_names):
    """
    Build a dense cells × genes DataFrame from a matrix-like object.

    ``X`` is densified via ``.toarray()`` when available, else via ``.A``,
    else through ``np.asarray``. Rows are labelled with ``obs_names`` and
    columns with ``var_names``.
    """
    if hasattr(X, "toarray"):
        dense = X.toarray()
    else:
        dense = X.A if hasattr(X, "A") else np.asarray(X)
    return pd.DataFrame(data=dense, columns=var_names, index=obs_names)

## R CALL WRAPPER USING TEMP FILES


In [None]:
def run_tradeseq_single(expr_df, axis_vec, tradeseq_script, idx=None, outdir=None, nknots=8, r_num_cores=4):
    """
    Run the tradeSeq R script on a single replicate using temporary Feather files.

    Parameters
    ----------
    expr_df : pandas.DataFrame
        Expression matrix, cells × genes.
    axis_vec : pandas.Series
        One axis value per cell; its index supplies the cell names.
    tradeseq_script : str or Path
        R script, invoked as ``pixi run -e dist Rscript <script> ...``.
    idx : optional
        Label embedded in the result file name (``results_{idx}.csv``);
        plain ``results.csv`` when None.
    outdir : str or Path, optional
        Directory for the result CSV (created if missing). When None the CSV
        is written inside the temporary directory, which is removed on return.
    nknots : int
        Forwarded to the script as ``--nknots``.
    r_num_cores : int
        Forwarded to the script as ``--num_cores``.

    Returns
    -------
    pandas.DataFrame or None
        Parsed result CSV, or None if the R call exits non-zero or writes no
        output file.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        expr_path = tmp / "expr.feather"
        axis_path = tmp / "axis.feather"
        out_name = f"results_{idx}.csv" if idx is not None else "results.csv"
        if outdir is None:
            out_csv = tmp / out_name
        else:
            out_dir = Path(outdir)
            out_dir.mkdir(parents=True, exist_ok=True)
            out_csv = out_dir / out_name

        # tradeSeq expects genes × cells. Name the index explicitly so the
        # column is always called "gene": reset_index() uses the index's own
        # name when it has one, so renaming "index" afterwards is unreliable.
        expr_gc = expr_df.T.rename_axis(index="gene").reset_index()
        axis_df = pd.DataFrame({"cell": axis_vec.index, "axis": axis_vec.values})

        feather.write_feather(expr_gc, expr_path)
        feather.write_feather(axis_df, axis_path)

        cmd = [
            "pixi", "run", "-e", "dist",
            "Rscript", str(tradeseq_script),
            "--expr", str(expr_path),
            "--axis", str(axis_path),
            "--out", str(out_csv),
            "--nknots", str(nknots),
            "--num_cores", str(r_num_cores)
        ]

        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as e:
            print(f"[ERROR] tradeSeq failed: {e}")
            # Output is captured, so the R-side error is invisible unless it
            # is echoed here.
            stderr_text = (e.stderr or b"").decode("utf-8", errors="replace").strip()
            if stderr_text:
                print(f"[ERROR] Rscript stderr:\n{stderr_text}")
            return None

        if not out_csv.exists():
            print("[WARN] tradeSeq produced no output for this replicate.")
            return None

        return pd.read_csv(out_csv)

## MAIN PIPELINE

In [None]:
# def run_tradeseq_pipeline(
#     adata_path,
#     tradeseq_script,
#     celltype_field="cell_type",
#     replicate_field="slide_id",
#     axis_field="axis",
#     layer=None, 
#     min_cells=20,
#     nknots=6,
#     num_cores=4
# ):
#     adata = ad.read_h5ad(adata_path)
#     if axis_field not in adata.obs:
#         raise ValueError(f"axis vector missing in adata.obs['{axis_field}']")

#     results = {}   # {celltype: {replicate: df}}

#     for ct in np.unique(adata.obs[celltype_field]):
#         print(f"\n=== Cell type: {ct} ===")
#         ad_ct = adata[adata.obs[celltype_field] == ct]
#         reps = np.unique(ad_ct.obs[replicate_field])
#         results[ct] = {}

#         for rep in reps:
#             print(f"  - replicate {rep}")
#             ad_rep = ad_ct[(ad_ct.obs[replicate_field] == rep) & (~ad_ct.obs[axis_field].isna())]

#             if ad_rep.n_obs < min_cells:
#                 print(f"    [SKIP] only {ad_rep.n_obs} cells (< {min_cells})")
#                 continue

#             # convert sparse → dense when necessary
#             X = ad_rep.X if layer is None else ad_rep.layers[layer]
#             expr_df = to_dense_df(ad_rep.X, ad_rep.var_names, ad_rep.obs_names)

#             axis_vec = ad_rep.obs[axis_field]
#             # axis_vec = (axis_vec - axis_vec.min()) / (axis_vec.max() - axis_vec.min())

#             res = run_tradeseq_single(expr_df, axis_vec, tradeseq_script,
#                                       nknots=nknots, num_cores=num_cores)
#             results[ct][rep] = res

#     return results

In [None]:
def _worker_tradeseq_job(job):
    """
    Process-pool entry point: unpack one job dict and run tradeSeq on it.

    Required keys: 'celltype', 'rep', 'expr_df', 'axis_vec',
    'tradeseq_script', 'nknots', 'r_num_cores'. Optional key: 'outdir'.

    Returns the tuple ``(celltype, rep, result)`` where ``result`` is whatever
    ``run_tradeseq_single`` returned.
    """
    celltype, replicate, counts, axis_values, script = (
        job[key] for key in ("celltype", "rep", "expr_df", "axis_vec", "tradeseq_script")
    )
    target_dir = job.get("outdir", None)
    knots = job["nknots"]
    cores = job["r_num_cores"]

    n_cells, n_genes = counts.shape[0], counts.shape[1]
    print(f"[worker] Running tradeSeq for {celltype} / {replicate} "
          f"({n_cells} cells × {n_genes} genes)")

    result = run_tradeseq_single(
        counts,
        axis_values,
        script,
        idx=f"{celltype}_{replicate}",
        outdir=target_dir,
        nknots=knots,
        r_num_cores=cores,
    )
    return celltype, replicate, result

In [None]:
def run_tradeseq_pipeline_mp(
    adata_path,
    tradeseq_script,
    celltype_field="cell_type",
    replicate_field="slide_id",
    axis_field="axis",
    outdir=None,
    layer=None, 
    min_cells=20,
    nknots=8,
    r_num_cores=4,
    n_workers=4,
):
    """
    Run tradeSeq for every (cell type, replicate) pair in parallel.

    Parameters
    ----------
    adata_path : path to the .h5ad file to load.
    tradeseq_script : path to the R script handed to each job.
    celltype_field, replicate_field, axis_field : ``adata.obs`` column names.
    outdir : forwarded to each job as its result directory.
    layer : ``adata.layers`` key to use instead of ``.X`` (None → ``.X``).
    min_cells : pairs with fewer cells (after dropping NaN axis values) are skipped.
    nknots : spline knots for tradeSeq.
    r_num_cores : cores per R call.
    n_workers : Python worker processes (jobs in parallel).

    Returns
    -------
    dict
        ``{celltype: {replicate: DataFrame or None}}``; skipped pairs are absent,
        jobs that raised in the worker map to None.

    Raises
    ------
    ValueError
        If ``axis_field`` is not a column of ``adata.obs``.
    """
    adata = ad.read_h5ad(adata_path)
    if axis_field not in adata.obs:
        raise ValueError(f"axis vector missing in adata.obs['{axis_field}']")

    celltypes = np.unique(adata.obs[celltype_field])

    # Build every job up front; each carries its own dense expression table.
    jobs = []
    for ct in celltypes:
        subset_ct = adata[adata.obs[celltype_field] == ct]

        for rep in np.unique(subset_ct.obs[replicate_field]):
            in_replicate = subset_ct.obs[replicate_field] == rep
            has_axis = ~subset_ct.obs[axis_field].isna()
            subset_rep = subset_ct[in_replicate & has_axis]

            if subset_rep.n_obs >= min_cells:
                # convert sparse → dense when necessary
                matrix = subset_rep.layers[layer] if layer is not None else subset_rep.X
                jobs.append({
                    "celltype": ct,
                    "rep": rep,
                    "expr_df": to_dense_df(matrix, subset_rep.var_names, subset_rep.obs_names),
                    "axis_vec": subset_rep.obs[axis_field],
                    "tradeseq_script": tradeseq_script,
                    "outdir": outdir,
                    "nknots": nknots,
                    "r_num_cores": r_num_cores,
                })

    # Run jobs in parallel and collect results as they finish.
    results = {ct: {} for ct in celltypes}

    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        pending = {pool.submit(_worker_tradeseq_job, job): job for job in jobs}

        for finished in as_completed(pending):
            submitted = pending[finished]
            ct, rep = submitted["celltype"], submitted["rep"]
            try:
                ct_out, rep_out, res_df = finished.result()
                results[ct_out][rep_out] = res_df
            except Exception as e:
                print(f"[ERROR] Job {ct}/{rep} failed in worker: {e}")
                results[ct][rep] = None

    return results

## META-ANALYSIS

In [None]:
def _bh_adjust(pvals):
    """Benjamini–Hochberg adjusted p-values; non-finite inputs stay NaN."""
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    finite = np.isfinite(p)
    m = int(finite.sum())
    if m == 0:
        return adjusted

    order = np.argsort(p[finite])
    scaled = p[finite][order] * m / np.arange(1, m + 1)
    # Enforce monotonicity from the largest p-value downwards, then cap at 1.
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    scaled = np.clip(scaled, 0.0, 1.0)

    restored = np.empty(m)
    restored[order] = scaled
    adjusted[finite] = restored
    return adjusted


def meta_combine_results(results_dict):
    """
    Combine per-replicate tradeSeq results into one meta-analysis table per cell type.

    Parameters
    ----------
    results_dict : dict
        ``{celltype: {replicate: DataFrame or None}}``. Each DataFrame must
        have columns ``gene``, ``association_pvalue`` and ``end_test_pvalue``.
        None entries are skipped. Within one replicate, a gene listed more
        than once (or with a missing name) is ignored for that replicate.

    Returns
    -------
    dict
        ``{celltype: DataFrame}`` with one row per gene, in order of first
        appearance across replicates. For both the association ("assoc_") and
        end ("end_") tests it holds the Stouffer meta p-value, I², tau², tau,
        the combined z-score with its interval bounds, and a
        Benjamini–Hochberg adjusted meta p-value ("*_fdr"). Cell types with no
        usable replicate are omitted.
    """
    columns = [
        "gene",
        "assoc_meta_p", "assoc_I2", "assoc_tau2", "assoc_tau",
        "assoc_zmeta", "assoc_PI_low", "assoc_PI_high",
        "end_meta_p", "end_I2", "end_tau2", "end_tau",
        "end_zmeta", "end_PI_low", "end_PI_high"
    ]
    meta_out = {}

    for ct, rep_dict in results_dict.items():
        print(f"\n=== Meta-analysis: {ct} ===")

        # One pass per replicate: gather each gene's p-values into lists
        # instead of re-scanning every replicate table once per gene.
        assoc_by_gene = {}
        end_by_gene = {}
        n_reps = 0
        for df in rep_dict.values():
            if df is None:
                continue
            n_reps += 1
            usable = df[df["gene"].notna()].drop_duplicates(subset="gene", keep=False)
            for gene, assoc_p, end_p in zip(
                usable["gene"], usable["association_pvalue"], usable["end_test_pvalue"]
            ):
                assoc_by_gene.setdefault(gene, []).append(assoc_p)
                end_by_gene.setdefault(gene, []).append(end_p)

        if n_reps == 0:
            continue

        rows = []
        for gene, assoc_pvals in assoc_by_gene.items():
            end_pvals = end_by_gene[gene]

            assoc_z = norm.isf(assoc_pvals)
            end_z = norm.isf(end_pvals)

            # Meta p-values
            assoc_meta_p = stouffer_meta(assoc_pvals)
            end_meta_p = stouffer_meta(end_pvals)

            # Relative heterogeneity
            assoc_I2 = i_squared(assoc_z)
            end_I2 = i_squared(end_z)

            # Absolute heterogeneity
            assoc_tau2 = tau_squared_dl(assoc_z)
            end_tau2 = tau_squared_dl(end_z)

            assoc_tau = np.sqrt(assoc_tau2) if np.isfinite(assoc_tau2) else np.nan
            end_tau = np.sqrt(end_tau2) if np.isfinite(end_tau2) else np.nan

            # Prediction intervals
            assoc_zmeta, assoc_PI_low, assoc_PI_high = prediction_interval(assoc_z)
            end_zmeta, end_PI_low, end_PI_high = prediction_interval(end_z)

            rows.append([
                gene,
                assoc_meta_p, assoc_I2, assoc_tau2, assoc_tau,
                assoc_zmeta, assoc_PI_low, assoc_PI_high,
                end_meta_p, end_I2, end_tau2, end_tau,
                end_zmeta, end_PI_low, end_PI_high
            ])

        df_meta = pd.DataFrame(rows, columns=columns)

        # FDR: Benjamini–Hochberg adjustment of the meta p-values. A plain
        # rank / n is only the percentile rank of each p-value, not an FDR.
        df_meta["assoc_fdr"] = _bh_adjust(df_meta["assoc_meta_p"].to_numpy())
        df_meta["end_fdr"] = _bh_adjust(df_meta["end_meta_p"].to_numpy())

        meta_out[ct] = df_meta

    return meta_out

## Plotting Functions

In [None]:
def plot_mean_vs_meta(df_meta, rep_dict, ct, outdir):
    """
    Scatter -log10(mean replicate association p) against -log10(meta p).

    Parameters
    ----------
    df_meta : DataFrame with columns ``gene`` and ``assoc_meta_p``.
    rep_dict : ``{replicate: DataFrame or None}``; each DataFrame needs
        ``gene`` and ``association_pvalue``. None entries are skipped, and a
        gene listed more than once in a replicate is ignored for that replicate.
    ct : label used in the title and the file name.
    outdir : directory (created if missing) receiving ``{ct}_mean_vs_meta.png``.

    Nothing is drawn when no replicate has results.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    rep_pvals = {
        rep: df.drop_duplicates(subset="gene", keep=False)
               .set_index("gene")["association_pvalue"]
        for rep, df in rep_dict.items() if df is not None
    }
    if not rep_pvals:
        return

    # Dict keys become the column labels; the row index is the union of genes.
    rep_mat = pd.concat(rep_pvals, axis=1)
    rep_mean = rep_mat.mean(axis=1)

    merged = df_meta.set_index("gene").join(rep_mean.rename("rep_mean_p")).reset_index()
    X = -np.log10(merged["rep_mean_p"] + 1e-12)
    Y = -np.log10(merged["assoc_meta_p"] + 1e-12)

    # Extent of the y = x reference line; NaNs (genes without a usable p-value)
    # are excluded so int() never sees NaN or an empty sequence.
    finite_x = X[np.isfinite(X)]
    upper = int(finite_x.max()) + 1 if len(finite_x) else 1

    fig, ax = plt.subplots(figsize=(6,6))
    ax.scatter(X, Y, s=10, alpha=0.5)
    ax.plot([0, upper], [0, upper], "--", color="red")
    ax.set_xlabel("-log10(mean replicate p)")
    ax.set_ylabel("-log10(meta p)")
    ax.set_title(f"{ct}: mean vs meta")
    fig.tight_layout()
    fig.savefig(outdir / f"{ct}_mean_vs_meta.png", dpi=200)
    plt.close(fig)


def plot_rep_heatmap(df_meta, rep_dict, ct, outdir, top_n=50):
    """
    Heatmap of -log10(association p) per replicate for the top meta-analysis genes.

    Parameters
    ----------
    df_meta : DataFrame with columns ``gene`` and ``assoc_meta_p``.
    rep_dict : ``{replicate: DataFrame or None}``; each DataFrame needs
        ``gene`` and ``association_pvalue``. None entries are skipped.
    ct : label used in the title and the file name.
    outdir : directory (created if missing) receiving ``{ct}_heatmap.png``.
    top_n : number of genes with the smallest ``assoc_meta_p`` to show.

    A gene absent from a replicate (or listed more than once there) is shown
    as a missing cell for that replicate. Nothing is drawn when there are no
    genes or no replicate results.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    top_genes = df_meta.sort_values("assoc_meta_p").head(top_n)["gene"].tolist()
    rep_pvals = {
        rep: df.drop_duplicates(subset="gene", keep=False)
               .set_index("gene")["association_pvalue"]
        for rep, df in rep_dict.items() if df is not None
    }
    if not top_genes or not rep_pvals:
        return

    # reindex (not .loc) so genes missing from a replicate become NaN rather
    # than raising KeyError; every column then shares the top_genes row order.
    rep_mat = pd.concat(
        {rep: pvals.reindex(top_genes) for rep, pvals in rep_pvals.items()},
        axis=1,
    )

    data = -np.log10(rep_mat.to_numpy(dtype=float) + 1e-12)

    fig, ax = plt.subplots(figsize=(10,8))
    im = ax.imshow(data, aspect='auto', cmap='viridis')
    ax.set_yticks(range(len(top_genes)))
    ax.set_yticklabels(top_genes)
    ax.set_xticks(range(rep_mat.shape[1]))
    ax.set_xticklabels([str(rep) for rep in rep_mat.columns], rotation=90)
    fig.colorbar(im, ax=ax, label="-log10(p)")
    ax.set_title(f"{ct}: replicate heatmap")
    fig.tight_layout()
    fig.savefig(outdir / f"{ct}_heatmap.png", dpi=200)
    plt.close(fig)


def plot_meta_vs_I2(df_meta, ct, outdir):
    """
    Scatter of I² (x) against -log10 of the association meta p-value (y).

    Reads ``assoc_I2`` and ``assoc_meta_p`` from ``df_meta`` and saves
    ``{ct}_meta_vs_I2.png`` into ``outdir`` (created if it does not exist).
    """
    target_dir = Path(outdir)
    target_dir.mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(7,6))
    heterogeneity = df_meta["assoc_I2"]
    neg_log_p = -np.log10(df_meta["assoc_meta_p"]+1e-12)
    ax.scatter(heterogeneity, neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("I²")
    ax.set_ylabel("-log10(meta p)")
    ax.set_title(f"{ct}: meta vs I²")
    fig.tight_layout()
    destination = target_dir / f"{ct}_meta_vs_I2.png"
    fig.savefig(destination, dpi=200)
    plt.close(fig)


def plot_meta_vs_tau(df_meta, ct, outdir):
    """
    Scatter of tau (x) against -log10 of the association meta p-value (y).

    Reads ``assoc_tau`` and ``assoc_meta_p`` from ``df_meta`` and saves
    ``{ct}_meta_vs_tau.png`` into ``outdir`` (created if it does not exist).
    """
    target_dir = Path(outdir)
    target_dir.mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(7,6))
    tau_values = df_meta["assoc_tau"]
    neg_log_p = -np.log10(df_meta["assoc_meta_p"]+1e-12)
    ax.scatter(tau_values, neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("tau (absolute heterogeneity)")
    ax.set_ylabel("-log10(meta p)")
    ax.set_title(f"{ct}: meta vs tau")
    fig.tight_layout()
    destination = target_dir / f"{ct}_meta_vs_tau.png"
    fig.savefig(destination, dpi=200)
    plt.close(fig)


def plot_PIwidth_vs_meta(df_meta, ct, outdir):
    """
    Scatter of interval width (x) against -log10 of the association meta p-value (y).

    The width is ``assoc_PI_high - assoc_PI_low`` from ``df_meta``. Saves
    ``{ct}_PIwidth_vs_meta.png`` into ``outdir`` (created if it does not exist).
    """
    target_dir = Path(outdir)
    target_dir.mkdir(exist_ok=True)
    width = df_meta["assoc_PI_high"] - df_meta["assoc_PI_low"]

    fig, ax = plt.subplots(figsize=(7,6))
    neg_log_p = -np.log10(df_meta["assoc_meta_p"]+1e-12)
    ax.scatter(width, neg_log_p, s=10, alpha=0.5)
    ax.set_xlabel("PI width")
    ax.set_ylabel("-log10(meta p)")
    ax.set_title(f"{ct}: PI width vs meta")
    fig.tight_layout()
    destination = target_dir / f"{ct}_PIwidth_vs_meta.png"
    fig.savefig(destination, dpi=200)
    plt.close(fig)


def generate_all_plots(df_meta, results_dict, ct, outdir="plots"):
    """
    Produce all five diagnostic figures for one cell type.

    ``results_dict`` is passed through as the per-replicate results mapping
    for the two plots that need replicate-level p-values; the remaining three
    only use ``df_meta``. All figures are written under ``outdir``.
    """
    # Plots that compare the meta table against replicate-level results.
    for replicate_plot in (plot_mean_vs_meta, plot_rep_heatmap):
        replicate_plot(df_meta, results_dict, ct, outdir)

    # Plots driven by the meta table alone.
    for meta_plot in (plot_meta_vs_I2, plot_meta_vs_tau, plot_PIwidth_vs_meta):
        meta_plot(df_meta, ct, outdir)

# Run

In [None]:
# parameters
# Run configuration for the cells below. The paths are absolute and specific
# to one cluster account; edit them before running elsewhere.
# AD_PATH: input .h5ad; R_SCRIPT: tradeSeq R script run for every job.
AD_PATH = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
R_SCRIPT = "/home/x-aklein2/projects/aklein/BICAN/BG/spatial_analysis/scripts/tradeseq.R"
# Exported as the R_HOME environment variable in the setup cell.
R_HOME = "/anvil/projects/x-mcb130189/aklein/SPIDA/.pixi/envs/dist/lib/R"
# OUTDIR receives the per-job result CSVs and the meta tables; image_path the figures.
OUTDIR = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/tradeseq_dsid"
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS/tradeseq_dsid"
# adata.obs columns: replicate identifier, cell-type label, and axis value.
REP_KEY = "dataset_id"
CELLTYPE_KEY = "Subclass"
AXIS = "MS_NORM"
# adata.layers key used as the expression matrix instead of .X.
layer="counts"
# (cell type, replicate) pairs with fewer cells than this are skipped.
min_cells = 50

N_WORKERS = 2        # Python processes (jobs in parallel)
R_NUM_CORES = 6      # cores per R tradeSeq call
NKNOTS = 8

In [None]:
# Point the environment at the configured R installation, then turn both
# output locations into Path objects and make sure they exist.
os.environ["R_HOME"] = R_HOME

image_path = Path(image_path)
OUTDIR = Path(OUTDIR)

image_path.mkdir(parents=True, exist_ok=True)
OUTDIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Run tradeSeq for every (cell type, replicate) pair using the settings from
# the parameter cell. `results` maps cell type -> replicate -> result table
# (None for jobs that failed).
results = run_tradeseq_pipeline_mp(
    adata_path=AD_PATH,
    tradeseq_script=R_SCRIPT,
    celltype_field=CELLTYPE_KEY,
    replicate_field=REP_KEY,
    axis_field=AXIS,
    outdir=OUTDIR,
    layer=layer,
    nknots=NKNOTS,
    r_num_cores=R_NUM_CORES,
    n_workers=N_WORKERS,
    min_cells=min_cells
)

In [None]:
# Combine replicates per cell type, then save each meta table and its figures.
meta = meta_combine_results(results)

for ct, df_meta in meta.items():
    # Write the table before plotting so that a plotting error cannot discard
    # the meta-analysis results for this cell type.
    df_meta.to_csv(OUTDIR / f"tradeseq_meta_{ct}.csv", index=False)
    generate_all_plots(df_meta, results[ct], ct, image_path)