In [None]:
import os
import sys

import anndata as ad
import numpy as np
import pandas as pd
import rpy2.robjects as ro
import scanpy as sc
from rpy2.robjects import pandas2ri, numpy2ri
from rpy2.robjects.packages import importr
from scipy.stats import norm
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests

# pandas2ri.activate()
spark_pkg = importr("SPARK")

In [None]:
# Print versions of important packages
# `sys.version` is read from `sys` directly rather than through `os.sys`,
# which only works because `os` happens to import `sys` internally.
print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"Pandas: {pd.__version__}")
print(f"Anndata: {ad.__version__}")
print(f"Scanpy: {sc.__version__}")

In [None]:
# Path to the input AnnData file. The BICAN_BG_H5AD environment variable
# overrides the default so the notebook can run on another machine without
# editing this cell; the default is the original location.
ad_path = os.environ.get(
    "BICAN_BG_H5AD",
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad",
)
adata = ad.read_h5ad(ad_path)
adata

In [None]:
# Keep one section (region NAC, donor UCI5224, replicate salk) and only the
# cells that have both an MSN_Groups label and an MS_NORM value; copy so the
# subset is independent of `adata`.
sub = adata[
    adata.obs['brain_region'].eq('NAC')
    & adata.obs['donor'].eq('UCI5224')
    & adata.obs['replicate'].eq('salk')
    & adata.obs['MSN_Groups'].notna()
    & adata.obs['MS_NORM'].notna()
].copy()
sub

In [None]:
# Number of cells per value of the 'Group' annotation in the subset.
sub.obs['Group'].value_counts()

In [None]:
# Scratch cell: displays one unseeded uniform random draw per cell in `sub`.
# The result is not assigned to any variable.
np.random.rand(sub.n_obs)

In [None]:
# One-dimensional "spatial" layout: x is the MS_NORM value of each cell and
# y is a constant 0, indexed by cell name.
coords_df = pd.DataFrame({'x': sub.obs['MS_NORM'], 'y': 0}, index=sub.obs_names)
# Dense cells x genes table built from the 'counts' layer.
expr_df = pd.DataFrame(sub.layers['counts'].toarray(), columns=sub.var_names, index=sub.obs_names)
coords_df.head()
# expr_df.head()

In [None]:
# Convert the pandas tables to R objects. The combined converter is active
# only inside this `with` block.
with ro.conversion.localconverter(ro.default_converter + pandas2ri.converter + numpy2ri.converter):
    # Convert the pandas DataFrame to an R data.frame within this context
    # The expression table is transposed first, so genes become rows.
    r_expr = pandas2ri.py2rpy(expr_df.T)
    r_coords = pandas2ri.py2rpy(coords_df)
    # res = spark_pkg.sparkx(r_expr, r_coords, option="mixture")
    

In [None]:
%%R -i r_expr -i r_coords
# Genes x cells matrix from the imported data frame.
expr_mat <- as.matrix(r_expr)
# Drop genes with zero variance across cells.
expr_var <- apply(expr_mat, 1, var)
expr_mat <- expr_mat[expr_var > 0, , drop = FALSE]
# Keep genes detected (count > 0) in at least 5 cells.
expr_mat <- expr_mat[rowSums(expr_mat > 0) >= 5, , drop = FALSE]
# A constant coordinate column gets a tiny Gaussian jitter (sd 1e-6).
# The seed is fixed so the jitter is the same on every run.
set.seed(1)
if (sd(r_coords[,1]) == 0) r_coords[,1] <- r_coords[,1] + rnorm(nrow(r_coords), 0, 1e-6)
if (sd(r_coords[,2]) == 0) r_coords[,2] <- r_coords[,2] + rnorm(nrow(r_coords), 0, 1e-6)

In [None]:
%%R 
# Leftover debugging lines from inspecting the input types:
# class(r_expr)
# class(as.matrix(r_expr))
# r_expr_mat <- r_expr_mat[rowSums(r_expr_mat > 0) >= 5, ]

# Run SPARK-X on the filtered matrix and the coordinates, with
# option="mixture"; the result stays in the R session as `res`.
res = sparkx(expr_mat, r_coords, option="mixture")

# r_expr_mat

In [None]:
%%R -o df
# class(res$res_mtest)
# Build a tidy per-gene table from the SPARK-X result: row names become a
# `gene` column, and the two p-value columns are kept and renamed.
df <- as.data.frame(res$res_mtest)
df$gene <- rownames(df)
df <- df[, c("gene", "combinedPval", "adjustedPval")]
colnames(df) <- c("gene", "combined_pvalue", "adjusted_pvalue")
# `df` is exported to Python by the `-o df` flag on this cell's magic line.
class(df)

In [None]:
# Registers the %%R cell magic. NOTE(review): this cell sits below the first
# %%R cells, so it has to be run before them on a fresh kernel.
%load_ext rpy2.ipython

In [None]:
%%R -i r_expr -i r_coords
# Build a SPARK object from the imported expression and coordinate tables.
# The former `-i counts` import was dropped: no Python variable `counts`
# exists in this notebook. The dangling `spark.` line (a parse error) is
# replaced by displaying the object.
library(SPARK)
spark <- CreateSPARKObject(counts=r_expr, location=r_coords, min_total_counts=20)
spark

In [None]:
# Display the R-side expression table produced by the conversion cell.
r_expr

In [None]:
# Call SPARK-X from Python through rpy2 on the converted tables as-is.
# Unlike the %%R cells, no as.matrix(), gene filtering or jitter is applied.
res = spark_pkg.sparkx(r_expr, r_coords, option="mixture")

In [None]:
# Display the Python-side SPARK-X result from the previous cell.
res

In [None]:
# R-style indexing on the coordinate table: all rows, columns 'x' and 'y'.
r_coords.rx(True, ro.StrVector(['x', 'y']))

In [None]:
# ---------- Step 1: SPARK-X per replicate ----------
def run_sparkx(expr_df, coords_df, min_cells=50):
    """Run SPARK-X on one group of cells and return per-gene adjusted p-values.

    NOTE(review): this function is defined a second time in the next cell;
    the later definition is the one in effect after Run All.

    Parameters
    ----------
    expr_df : pandas.DataFrame
        Expression table with cells as rows and genes as columns.
    coords_df : pandas.DataFrame
        Numeric coordinates, one row per cell in the same order as ``expr_df``.
    min_cells : int, default 50
        Groups with fewer rows than this are skipped.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene`` and ``p_sparkx``. Empty, with the same columns, when
        the group is skipped.

    Raises
    ------
    KeyError
        If the SPARK-X result table has no adjusted p-value column.
    """
    out_cols = ["gene", "p_sparkx"]
    if expr_df.shape[0] < min_cells:
        return pd.DataFrame(columns=out_cols)

    # Same workaround as the interactive %%R cell: a constant coordinate
    # column gets a tiny jitter (sd 1e-6). The generator is seeded so repeated
    # runs give identical results.
    coords = coords_df.astype(float)
    rng = np.random.default_rng(0)
    for col in coords.columns:
        if np.std(coords[col].to_numpy()) == 0:
            coords[col] = coords[col] + rng.normal(0.0, 1e-6, size=len(coords))

    # Conversion needs an explicit converter context because
    # pandas2ri.activate() is not called in this notebook.
    converter = ro.default_converter + pandas2ri.converter + numpy2ri.converter
    with ro.conversion.localconverter(converter):
        # Transposed to genes x cells, as in the interactive cell.
        r_expr = pandas2ri.py2rpy(expr_df.T)
        r_coords = pandas2ri.py2rpy(coords)

    # The call happens outside the converter context so `res` stays an R
    # object and `.rx2` is available.
    r_expr_mat = ro.r["as.matrix"](r_expr)
    res = spark_pkg.sparkx(r_expr_mat, r_coords, option="mixture")

    with ro.conversion.localconverter(converter):
        df = pandas2ri.rpy2py(res.rx2("res_mtest"))

    # SPARK-X names the column "adjustedPval"; "adjusted_pvalue" is accepted
    # too for tables that were already renamed.
    p_col = next((c for c in ("adjustedPval", "adjusted_pvalue") if c in df.columns), None)
    if p_col is None:
        raise KeyError(f"No adjusted p-value column in SPARK-X result: {list(df.columns)}")
    df = df.rename(columns={p_col: "p_sparkx"})
    df["gene"] = df.index
    return df[out_cols]

In [None]:
# ---------- Step 1: SPARK-X per replicate ----------
def run_sparkx(expr_df, coords_df, min_cells=50):
    """Run SPARK-X on one group of cells and return per-gene adjusted p-values.

    Parameters
    ----------
    expr_df : pandas.DataFrame
        Expression table with cells as rows and genes as columns.
    coords_df : pandas.DataFrame
        Numeric coordinates, one row per cell in the same order as ``expr_df``.
    min_cells : int, default 50
        Groups with fewer rows than this are skipped.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene`` and ``p_sparkx``. Empty, with the same columns, when
        the group is skipped.

    Raises
    ------
    KeyError
        If the SPARK-X result table has no adjusted p-value column.
    """
    out_cols = ["gene", "p_sparkx"]
    if expr_df.shape[0] < min_cells:
        return pd.DataFrame(columns=out_cols)

    # Same workaround as the interactive %%R cell: a constant coordinate
    # column gets a tiny jitter (sd 1e-6). The generator is seeded so repeated
    # runs give identical results.
    coords = coords_df.astype(float)
    rng = np.random.default_rng(0)
    for col in coords.columns:
        if np.std(coords[col].to_numpy()) == 0:
            coords[col] = coords[col] + rng.normal(0.0, 1e-6, size=len(coords))

    # Conversion needs an explicit converter context because
    # pandas2ri.activate() is not called in this notebook.
    converter = ro.default_converter + pandas2ri.converter + numpy2ri.converter
    with ro.conversion.localconverter(converter):
        # Transposed to genes x cells, as in the interactive cell.
        r_expr = pandas2ri.py2rpy(expr_df.T)
        r_coords = pandas2ri.py2rpy(coords)

    # The call happens outside the converter context so `res` stays an R
    # object and `.rx2` is available.
    r_expr_mat = ro.r["as.matrix"](r_expr)
    res = spark_pkg.sparkx(r_expr_mat, r_coords, option="mixture")

    with ro.conversion.localconverter(converter):
        df = pandas2ri.rpy2py(res.rx2("res_mtest"))

    # SPARK-X names the column "adjustedPval"; "adjusted_pvalue" is accepted
    # too for tables that were already renamed.
    p_col = next((c for c in ("adjustedPval", "adjusted_pvalue") if c in df.columns), None)
    if p_col is None:
        raise KeyError(f"No adjusted p-value column in SPARK-X result: {list(df.columns)}")
    df = df.rename(columns={p_col: "p_sparkx"})
    df["gene"] = df.index
    return df[out_cols]

# ---------- Step 2: run per replicate & cell type ----------
def run_per_replicate(adata, axis_key="axis", celltype_key="cell_type", replicate_key="replicate",
                      layer=None, min_cells=50):
    """Run SPARK-X separately for every (cell type, replicate) group.

    Parameters
    ----------
    adata : AnnData
        Dataset whose ``obs`` holds the three key columns below.
    axis_key : str
        ``obs`` column with the 1-D axis used as the x coordinate (y is 0).
    celltype_key, replicate_key : str
        ``obs`` columns that define the groups.
    layer : str or None, default None
        Layer to read expression from; ``None`` uses ``adata.X`` as before.
    min_cells : int, default 50
        Groups with fewer cells are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per gene and group with columns ``gene``, ``p_sparkx``,
        ``cell_type``, ``replicate``, ``n_cells`` and ``rho_axis`` (Spearman
        correlation of the gene with the axis). Empty, with the same columns,
        when no group is large enough.
    """
    out_cols = ["gene", "p_sparkx", "cell_type", "replicate", "n_cells", "rho_axis"]
    results = []
    for ct in adata.obs[celltype_key].unique():
        for rep in adata.obs[replicate_key].unique():
            idx = (adata.obs[celltype_key] == ct) & (adata.obs[replicate_key] == rep)
            if idx.sum() < min_cells:
                continue
            sub = adata[idx]
            mat = sub.X if layer is None else sub.layers[layer]
            # `.toarray()` replaces the `.A` shortcut, which recent SciPy
            # releases no longer provide on sparse containers.
            dense = mat.toarray() if hasattr(mat, "toarray") else np.asarray(mat)
            expr = pd.DataFrame(dense, columns=sub.var_names, index=sub.obs_names)
            axis_vals = sub.obs[axis_key]
            coords = pd.DataFrame({"x": axis_vals, "y": np.zeros(len(sub))}, index=sub.obs_names)

            print(f"Running SPARK-X for {ct}, replicate {rep} ({expr.shape[0]} cells)")
            df = run_sparkx(expr, coords, min_cells=min_cells)
            if df.empty:
                # Nothing to annotate or correlate for this group.
                continue
            df = df.copy()
            df["cell_type"] = ct
            df["replicate"] = rep
            df["n_cells"] = expr.shape[0]

            # Spearman correlation with aligned axis
            df["rho_axis"] = [spearmanr(expr[g], axis_vals)[0] for g in df["gene"]]
            results.append(df)

    # pd.concat raises on an empty list, so return a typed empty table instead.
    if not results:
        return pd.DataFrame(columns=out_cols)
    return pd.concat(results, ignore_index=True)

# ---------- Step 3: Stouffer meta-analysis + I² ----------
def stouffer_meta(pvals, signs, weights):
    """Combine two-sided p-values with a signed, weighted Stouffer Z.

    Each p-value becomes a |z| via the normal upper tail of ``p / 2``, takes
    the sign of the matching entry in ``signs``, and the weighted sum is
    divided by ``sqrt(sum(weights**2))``.

    Parameters
    ----------
    pvals, signs, weights : array-like of equal length
        Lists are accepted. Entries where any of the three is non-finite are
        dropped before combining.

    Returns
    -------
    (Z, p) : tuple of float
        Combined Z and its two-sided p-value; ``(nan, nan)`` when no usable
        entry remains or all weights are zero.
    """
    pvals = np.asarray(pvals, dtype=float)
    signs = np.asarray(signs, dtype=float)
    weights = np.asarray(weights, dtype=float)

    usable = np.isfinite(pvals) & np.isfinite(signs) & np.isfinite(weights)
    if not usable.any():
        return np.nan, np.nan
    pvals, signs, weights = pvals[usable], signs[usable], weights[usable]

    denom = np.sqrt(np.sum(weights ** 2))
    if denom == 0:
        return np.nan, np.nan

    # Clip so p == 0 does not turn into an infinite z-score.
    pvals = np.clip(pvals, 1e-300, 1.0)
    z_signed = np.sign(signs) * norm.isf(pvals / 2.0)
    Z = np.sum(weights * z_signed) / denom
    p = 2 * norm.sf(abs(Z))
    return Z, p

def i_squared(effect_sizes, weights):
    """Heterogeneity I² (in percent) from weighted effect sizes.

    Computes ``Q = sum(w * (e - weighted_mean)**2)`` and returns
    ``max(0, (Q - (k - 1)) / Q) * 100``.

    Parameters
    ----------
    effect_sizes, weights : array-like of equal length
        Lists are accepted. Pairs with a non-finite effect or weight are
        dropped first, so they cannot silently turn the result into 0.

    Returns
    -------
    float
        I² in [0, 100]; ``nan`` when fewer than two usable pairs remain or the
        weights sum to zero; ``0.0`` when ``Q`` is not positive.
    """
    effects = np.asarray(effect_sizes, dtype=float)
    w = np.asarray(weights, dtype=float)
    usable = np.isfinite(effects) & np.isfinite(w)
    effects, w = effects[usable], w[usable]

    k = effects.size
    if k <= 1 or np.sum(w) == 0:
        return np.nan

    mean_eff = np.average(effects, weights=w)
    Q = np.sum(w * (effects - mean_eff) ** 2)
    if not Q > 0:
        return 0.0
    dof = k - 1
    return max(0.0, (Q - dof) / Q) * 100

def meta_per_gene(df, cell_type):
    """Meta-analyse per-replicate SPARK-X results gene by gene.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-replicate rows with columns ``gene``, ``p_sparkx``, ``rho_axis``
        and ``n_cells``.
    cell_type : str
        Label written to the ``cell_type`` column of the output.

    Returns
    -------
    pandas.DataFrame
        One row per gene with columns ``cell_type``, ``gene``, ``meta_Z``,
        ``meta_p``, ``fdr`` (Benjamini-Hochberg over the genes with a finite
        ``meta_p``), ``I2`` and ``direction``. Empty, with the same columns,
        when ``df`` has no rows.
    """
    out_cols = ["cell_type", "gene", "meta_Z", "meta_p", "fdr", "I2", "direction"]
    out = []
    for g, gdf in df.groupby("gene"):
        # Clip here as well, so the effect sizes passed to I² stay finite
        # when a replicate reports p == 0.
        pvals = np.clip(gdf["p_sparkx"].to_numpy(dtype=float), 1e-300, 1.0)
        signs = np.sign(gdf["rho_axis"].to_numpy(dtype=float))
        # Each replicate is weighted by sqrt(number of cells).
        weights = np.sqrt(gdf["n_cells"].to_numpy(dtype=float))
        Z, p_meta = stouffer_meta(pvals, signs, weights)
        effs = signs * norm.isf(pvals / 2.0)
        I2 = i_squared(effs, weights)
        out.append({
            "cell_type": cell_type, "gene": g, "meta_Z": Z, "meta_p": p_meta,
            "fdr": np.nan, "I2": I2, "direction": "up" if Z > 0 else "down"
        })

    if not out:
        return pd.DataFrame(columns=out_cols)

    res = pd.DataFrame(out, columns=out_cols)
    # Correct only the genes with a defined p-value; NaNs would otherwise
    # corrupt the Benjamini-Hochberg ranking.
    finite = res["meta_p"].notna()
    if finite.any():
        res.loc[finite, "fdr"] = multipletests(res.loc[finite, "meta_p"], method="fdr_bh")[1]
    return res

# # ---------- Step 4: driver ----------
# if __name__ == "__main__":
#     adata = sc.read_h5ad("your_spatial_data.h5ad")
#     all_reps = run_per_replicate(adata)
#     all_reps.to_csv("sparkx_per_replicate_results.csv", index=False)
#     meta = []
#     for ct, gdf_ct in all_reps.groupby("cell_type"):
#         meta.append(meta_per_gene(gdf_ct, ct))
#     meta = pd.concat(meta)
#     meta.to_csv("sparkx_meta_results.csv", index=False)
#     print("Pipeline finished: results saved.")


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os

# -------------------------------------------------
# 5. Visualization utilities
# -------------------------------------------------
def plot_volcano(df, out_dir="plots", fdr_thresh=0.05):
    """Volcano plot: meta_Z vs -log10(FDR) per cell type.

    Writes one ``volcano_<cell type>.png`` per cell type into ``out_dir``
    (created if missing). Path separators in the cell-type name are replaced
    by ``_`` so the name cannot point outside ``out_dir``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``cell_type``, ``meta_Z`` and ``fdr``.
    out_dir : str, default "plots"
        Output directory.
    fdr_thresh : float, default 0.05
        Significance cut-off used for colouring and the dashed line.
    """
    os.makedirs(out_dir, exist_ok=True)
    for ct, gdf in df.groupby("cell_type"):
        # `.assign` builds a new frame instead of writing into the groupby
        # slice, which avoids SettingWithCopy warnings.
        plot_df = gdf.assign(
            log10FDR=-np.log10(gdf["fdr"] + 1e-10),  # offset avoids log10(0)
            significant=gdf["fdr"] < fdr_thresh,
        )
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.scatterplot(
            data=plot_df, x="meta_Z", y="log10FDR",
            hue="significant",
            palette={True: "crimson", False: "grey"}, alpha=0.6, s=20, ax=ax
        )
        ax.axhline(-np.log10(fdr_thresh), ls="--", color="black", lw=1)
        ax.set_title(f"SPARK-X Meta Volcano: {ct}")
        ax.set_xlabel("Stouffer meta Z-score")
        ax.set_ylabel("-log10(FDR)")
        ax.legend(title=f"FDR<{fdr_thresh}")
        fig.tight_layout()
        safe_ct = str(ct).replace("/", "_").replace(os.sep, "_")
        fig.savefig(os.path.join(out_dir, f"volcano_{safe_ct}.png"), dpi=200)
        plt.close(fig)

def plot_I2_distribution(df, out_dir="plots"):
    """Step histogram of I² per cell type, saved as ``I2_distribution.png``.

    ``df`` needs the columns ``I2`` and ``cell_type``; ``out_dir`` is created
    if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "I2_distribution.png")

    fig, ax = plt.subplots(figsize=(7, 5))
    sns.histplot(data=df, x="I2", hue="cell_type", bins=40,
                 element="step", fill=False, ax=ax)
    ax.set_xlabel("I² (%)")
    ax.set_ylabel("Gene count")
    ax.set_title("Replicate Heterogeneity Across Cell Types")
    fig.tight_layout()
    fig.savefig(target, dpi=200)
    plt.close(fig)

def plot_significant_counts(df, fdr_thresh=0.05, out_dir="plots"):
    """Bar plot of the number of significant genes per cell type.

    Saves ``sig_gene_counts.png`` into ``out_dir`` (created if missing).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``cell_type`` and ``fdr``.
    fdr_thresh : float, default 0.05
        A gene counts as significant when ``fdr < fdr_thresh``.
    out_dir : str, default "plots"
        Output directory.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Summing a boolean Series per group always yields a Series, so
    # `reset_index(name=...)` also works on empty input, and the deprecated
    # groupby.apply over grouping columns is avoided.
    sig_counts = (
        df["fdr"].lt(fdr_thresh)
        .groupby(df["cell_type"])
        .sum()
        .reset_index(name="n_sig")
    )
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=sig_counts, x="cell_type", y="n_sig", color="steelblue", ax=ax)
    ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})")
    ax.set_xlabel("Cell Type")
    ax.set_title("Significant Spatially Variable Genes per Cell Type")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "sig_gene_counts.png"), dpi=200)
    plt.close(fig)

# -------------------------------------------------
# 6. Run visualizations
# -------------------------------------------------
if __name__ == "__main__":
    # In a notebook this guard is always true, so the cell runs on every
    # Run All. The CSV is only written by the (currently commented-out)
    # Step 4 driver, so check for it instead of crashing when it is absent.
    meta_path = "sparkx_meta_results.csv"
    if os.path.exists(meta_path):
        meta = pd.read_csv(meta_path)
        plot_volcano(meta)
        plot_I2_distribution(meta)
        plot_significant_counts(meta)
        print("Summary plots saved in ./plots/")
    else:
        print(f"{meta_path} not found; run the Step 4 driver first to create it.")


In [None]:
# meta_per_gene(df, "astrocyte")