In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from scipy.stats import norm, spearmanr
from statsmodels.stats.multitest import multipletests

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Global matplotlib defaults for publication-style figures.
# fonttype 42 embeds TrueType fonts in PDF/PS output so text stays editable.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 14,
    'axes.facecolor': 'white',
})

In [None]:
# Load the AnnData object whose cells are subset and plotted by the heatmap cells further down.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(filename=ad_path)
adata

In [None]:
# Functions to debug: 
# ---------- Step 3: Stouffer meta-analysis + I² ----------
def stouffer_meta(pvals, signs, weights):
    """Combine signed two-sided p-values with the weighted Stouffer method.

    Parameters
    ----------
    pvals : array-like of float
        Two-sided p-values, one per replicate. They are clipped to [1e-300, 1]
        so that p == 0 maps to a large but finite z-score.
    signs : array-like
        Direction of each replicate; only ``np.sign`` of each value is used
        (0 means that replicate contributes nothing).
    weights : array-like or scalar
        Stouffer weights. A scalar is broadcast to every replicate, which gives
        the unweighted Stouffer statistic.

    Returns
    -------
    (Z, p) : tuple of float
        Combined z-score and its two-sided p-value. Both are NaN for empty
        input; NaN p-values or signs propagate to NaN.
    """
    pvals = np.clip(np.asarray(pvals, dtype=float), 1e-300, 1.0)
    direction = np.sign(np.asarray(signs, dtype=float))
    # Broadcasting a scalar weight to every replicate keeps sum(w**2) correct (k * w**2).
    weights = np.broadcast_to(np.asarray(weights, dtype=float), pvals.shape)
    if pvals.size == 0:
        return np.nan, np.nan
    z_signed = direction * norm.isf(pvals / 2.0)
    Z = np.sum(weights * z_signed) / np.sqrt(np.sum(weights ** 2))
    p = 2 * norm.sf(abs(Z))
    return Z, p

def i_squared(effect_sizes, weights):
    """Higgins' I² heterogeneity statistic, in percent (0-100).

    Computes Cochran's Q = sum(w * (y - ybar_w)**2) and returns
    max(0, (Q - df) / Q) * 100 with df = k - 1. For Q to be chi-square under
    homogeneity, ``weights`` should be inverse variances of ``effect_sizes``.

    Parameters
    ----------
    effect_sizes : array-like of float
        Per-study effect estimates.
    weights : array-like or scalar
        Per-study weights; a scalar is broadcast to every study.

    Returns
    -------
    float
        I² in percent. Returns NaN when fewer than two studies are given or
        when Q is not finite (inf/NaN effects), and 0.0 when Q == 0.
    """
    effect_sizes = np.asarray(effect_sizes, dtype=float)
    weights = np.broadcast_to(np.asarray(weights, dtype=float), effect_sizes.shape)
    k = effect_sizes.size
    if k <= 1:
        return np.nan
    mean_eff = np.average(effect_sizes, weights=weights)
    Q = np.sum(weights * (effect_sizes - mean_eff) ** 2)
    # A NaN/inf Q must not be reported as "no heterogeneity" (NaN > 0 is False).
    if not np.isfinite(Q):
        return np.nan
    if Q <= 0:
        return 0.0
    df = k - 1
    return max(0.0, (Q - df) / Q) * 100

def meta_per_gene(df, cell_type):
    """Run a per-gene weighted Stouffer meta-analysis across replicates for one cell type.

    Parameters
    ----------
    df : DataFrame
        Per-replicate SPARK-X results for a single cell type, with columns
        ``gene``, ``p_sparkx`` (treated as a two-sided p-value), ``rho_axis``
        (only its sign is used, as the direction) and ``n_cells`` (replicate
        size; Stouffer weight = sqrt(n_cells)).
    cell_type : str
        Label copied into the ``cell_type`` column of the output.

    Returns
    -------
    DataFrame
        One row per gene with columns
        ``cell_type, gene, meta_Z, meta_p, fdr, I2, direction``.
        ``fdr`` is Benjamini-Hochberg over the finite ``meta_p`` values of this
        cell type (NaN where ``meta_p`` is NaN). An empty input yields an empty
        frame with those columns.
    """
    columns = ["cell_type", "gene", "meta_Z", "meta_p", "fdr", "I2", "direction"]
    out = []
    for g, gdf in df.groupby("gene"):
        # Clip exactly like stouffer_meta so p == 0 gives a large finite z, not inf.
        pvals = np.clip(gdf["p_sparkx"].to_numpy(dtype=float), 1e-300, 1.0)
        signs = np.sign(gdf["rho_axis"].to_numpy(dtype=float))
        n_cells = gdf["n_cells"].to_numpy(dtype=float)
        weights = np.sqrt(n_cells)
        Z, p_meta = stouffer_meta(pvals, signs, weights)

        # I² on the per-cell effect scale. The sqrt(n)-weighted Stouffer model
        # assumes z_i ~ N(theta * sqrt(n_i), 1), so d_i = z_i / sqrt(n_i) has
        # variance 1/n_i and Cochran's Q needs inverse-variance weights n_i.
        # (Weighting unit-variance z-scores by sqrt(n) inflates Q and pushes I²
        # towards 100% regardless of true heterogeneity.)
        z_signed = signs * norm.isf(pvals / 2.0)
        keep = n_cells > 0
        effects = z_signed[keep] / weights[keep]
        I2 = i_squared(effects, n_cells[keep])

        out.append({
            "cell_type": cell_type, "gene": g, "meta_Z": Z, "meta_p": p_meta,
            "fdr": np.nan, "I2": I2, "direction": "up" if Z > 0 else "down"
        })
    if not out:
        return pd.DataFrame(columns=columns)
    res = pd.DataFrame(out, columns=columns)
    # statsmodels' BH propagates a single NaN p-value to every adjusted value,
    # so correct only the finite meta p-values.
    meta_p = res["meta_p"].to_numpy(dtype=float)
    finite = np.isfinite(meta_p)
    fdr = np.full(meta_p.shape, np.nan)
    if finite.any():
        fdr[finite] = multipletests(meta_p[finite], method="fdr_bh")[1]
    res["fdr"] = fdr
    return res

In [None]:
# -------------------------------------------------
# 5. Visualization utilities
# -------------------------------------------------
def plot_volcano(df, out_dir="plots", fdr_thresh=0.05):
    """Volcano plot (meta_Z vs -log10(FDR)), one figure per cell type.

    Parameters
    ----------
    df : DataFrame
        Needs ``cell_type``, ``meta_Z`` and ``fdr`` columns.
    out_dir : str or None
        Directory for ``volcano_<cell_type>.png`` files (created if needed).
        When None, each figure is shown inline instead.
    fdr_thresh : float
        Points with ``fdr < fdr_thresh`` are highlighted. A dashed line is drawn
        at -log10(fdr_thresh).

    Notes
    -----
    FDR gets +1e-10 before the log, so -log10(FDR) is capped at 10.
    """
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
    for ct, gdf in df.groupby("cell_type"):
        # assign() returns a new frame instead of writing into the groupby slice
        # (which raised SettingWithCopyWarning).
        gdf = gdf.assign(log10FDR=-np.log10(gdf["fdr"] + 1e-10))
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.scatterplot(
            data=gdf, x="meta_Z", y="log10FDR",
            hue=gdf["fdr"] < fdr_thresh,
            palette={True: "crimson", False: "grey"}, alpha=0.6, s=20, ax=ax,
        )
        ax.axhline(-np.log10(fdr_thresh), ls="--", color="black", lw=1)
        ax.set_title(f"SPARK-X Meta Volcano: {ct}")
        ax.set_xlabel("Stouffer meta Z-score")
        ax.set_ylabel("-log10(FDR)")
        ax.legend(title=f"FDR<{fdr_thresh}")
        fig.tight_layout()
        if out_dir is not None:
            # A path separator inside a cell-type name would point savefig at a
            # non-existent sub-directory.
            safe_ct = str(ct).replace(os.sep, "_").replace("/", "_")
            fig.savefig(os.path.join(out_dir, f"volcano_{safe_ct}.png"), dpi=200)
        else:
            plt.show()
        plt.close(fig)

def plot_I2_distribution(df, out_dir="plots"):
    """Step histograms of the I² heterogeneity statistic, one outline per cell type.

    Parameters
    ----------
    df : DataFrame
        Needs ``I2`` and ``cell_type`` columns.
    out_dir : str or None
        When truthy, the figure is written to ``<out_dir>/I2_distribution.png``
        (dpi 200); otherwise it is shown inline. The directory is created
        whenever ``out_dir`` is not None.
    """
    save_to_disk = bool(out_dir)
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(7, 5))
    sns.histplot(data=df, x="I2", hue="cell_type", bins=40, element="step", fill=False, ax=ax)
    ax.set_xlabel("I² (%)")
    ax.set_ylabel("Gene count")
    ax.set_title("Replicate Heterogeneity Across Cell Types")
    fig.tight_layout()

    if save_to_disk:
        fig.savefig(os.path.join(out_dir, "I2_distribution.png"), dpi=200)
    else:
        plt.show()
    plt.close(fig)

def plot_significant_counts(df, fdr_thresh=0.05, i2_thresh=0.1, out_dir="plots"):
    """Bar plot of the number of significant, low-heterogeneity genes per cell type.

    A gene is counted when ``fdr < fdr_thresh`` and ``I2 < i2_thresh``. ``I2`` is on
    the percent scale (0-100) produced by ``i_squared``, and ``i_squared`` clamps at 0,
    so the default ``i2_thresh=0.1`` keeps essentially only genes with I² == 0.
    Rows with NaN ``fdr`` or ``I2`` never count.

    Parameters
    ----------
    df : DataFrame
        Needs ``cell_type``, ``fdr`` and ``I2`` columns.
    fdr_thresh : float
        Upper bound on FDR (exclusive).
    i2_thresh : float
        Upper bound on I² in percent (exclusive).
    out_dir : str or None
        When truthy, the figure is written to ``<out_dir>/sig_gene_counts.png``;
        otherwise it is shown inline.
    """
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
    # Per-row flag + groupby sum. This avoids DataFrameGroupBy.apply over the
    # grouping column (deprecated in pandas 2.2) and also works for an empty df.
    is_sig = (df["fdr"] < fdr_thresh) & (df["I2"] < i2_thresh)
    sig_counts = (
        df.assign(n_sig=is_sig.to_numpy())
        .groupby("cell_type", as_index=False)["n_sig"]
        .sum()
    )
    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=sig_counts, x="cell_type", y="n_sig", color="steelblue", ax=ax)
    # The label names both filters so the count is not mistaken for an FDR-only count.
    ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh}, I²<{i2_thresh})")
    ax.set_xlabel("Cell Type")
    ax.set_title("Significant Spatially Variable Genes per Cell Type")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    if out_dir:
        fig.savefig(os.path.join(out_dir, "sig_gene_counts.png"), dpi=200)
    else:
        plt.show()
    plt.close(fig)

In [None]:
### Read Data: 
# Per-replicate SPARK-X results. The commented paths are alternative runs
# (presumably different replicate groupings: dsid / donor / brs — confirm).
# path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_dsid/sparkx_per_replicate_results.csv"
# path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_donor/sparkx_per_replicate_results.csv"
path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_brs/sparkx_per_replicate_results.csv"
sparkx_df = pd.read_csv(filepath_or_buffer=path)
sparkx_df.head(n=5)

In [None]:
# Per-cell-type meta-analysis. groupby(sort=False) keeps first-appearance order
# (like .unique()) and skips NaN cell types, which would otherwise yield an empty
# subset. ignore_index=True gives df_meta a unique index instead of repeating
# 0..n for every cell type.
df_meta = pd.concat(
    [meta_per_gene(ct_df, ct) for ct, ct_df in sparkx_df.groupby("cell_type", sort=False)],
    ignore_index=True,
)
df_meta.head()

In [None]:
plot_volcano(df_meta.query("cell_type == 'Astrocyte'"), out_dir=None)

In [None]:
# I² distribution for astrocytes only.
astro_meta = df_meta.query("cell_type == 'Astrocyte'")
sns.histplot(data=astro_meta, x="I2", hue="cell_type", bins=40, element="step", fill=False)

In [None]:
# "font.fontsize" is not a valid rcParam (rc_context raises KeyError);
# the base font-size key is "font.size".
with plt.rc_context({"font.size": 8}):
    plot_significant_counts(df_meta, out_dir=None, i2_thresh=90)

In [None]:
# Number of distinct replicates in the SPARK-X table.
sparkx_df["replicate"].nunique()

In [None]:
# Row count of the per-replicate SPARK-X table for each cell type.
for ct_name in sparkx_df['cell_type'].unique():
    n_rows = int((sparkx_df['cell_type'] == ct_name).sum())
    print(ct_name, n_rows)

In [None]:
# for _ct in sparkx_df['cell_type'].unique():
#     print(_ct)
#     display(sparkx_df[sparkx_df['cell_type'] == _ct].groupby("gene")['p_sparkx'].describe().sort_values('50%', ascending=True).head(10))

In [None]:
def expr_heatmap(
    adata,
    df_genes,
    n_genes=40,
    gene_names=10,
    heatmap_order_col='rho_axis',
    title="",
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    rasterized=True,
    show=True, 
    save=False,
    image_path=None,
):
    """Plot a genes x cells expression heatmap with cells ordered by MS_NORM.

    Parameters
    ----------
    adata : AnnData
        Cells to plot. ``adata.obs`` must contain 'Group', 'MS_NORM' and
        'MS_compartment'. ``adata.X`` may be scipy-sparse or dense. More than
        50,000 cells are randomly down-sampled (random_state=42).
    df_genes : DataFrame
        Indexed by gene name and already ranked; the first ``n_genes`` rows are
        plotted. Must contain ``heatmap_order_col``.
    n_genes : int
        Number of top-ranked genes (rows) to plot.
    gene_names : int
        Number of top-ranked genes that get a text label on the left.
    heatmap_order_col : str
        Column of ``df_genes`` used to order the rows (ascending).
    title : str
        Figure suptitle. Also used (spaces -> underscores) in the saved file name.
    out_path : any
        Unused; kept for backward compatibility.
    ylabel, xlabel : str or None
        Axis labels passed to ``ClusterMapPlotter``.
    rasterized : bool
        Passed to ``ClusterMapPlotter``.
    show : bool
        Call ``plt.show()`` before closing the figure.
    save : bool
        Write ``<image_path>/expr_heatmap_<title>.svg`` when ``image_path`` is set.
    image_path : str or None
        Output directory for the saved figure.

    Raises
    ------
    KeyError
        If any selected gene is missing from ``adata.var_names``.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm
    from scipy import sparse

    # Column (cell) annotations: down-sample for plotting speed, then order along MS_NORM.
    df_col = adata.obs[['Group', 'MS_NORM', 'MS_compartment']].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    df_col = df_col.sort_values('MS_NORM')

    # Row (gene) annotations: label only the top `gene_names` genes.
    df_row = df_genes.iloc[:n_genes].copy()
    toplot = df_row.index[:gene_names]
    df_row['annot'] = [c if c in toplot else np.nan for c in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    # Densify only the plotted cells x genes block instead of the whole matrix.
    # This also works when adata.X is already a dense ndarray (which has no .toarray()).
    cell_idx = adata.obs_names.get_indexer(df_col.index)
    gene_idx = adata.var_names.get_indexer(df_row.index)
    missing = df_row.index[gene_idx < 0]
    if len(missing) > 0:
        raise KeyError(f"Genes not found in adata.var_names: {list(missing)}")
    sub = adata.X[cell_idx][:, gene_idx]
    if sparse.issparse(sub):
        sub = sub.toarray()
    df_expr = pd.DataFrame(np.asarray(sub).T, index=df_row.index, columns=df_col.index)

    # Per-gene min-max scaling. Constant genes get range 1 (-> all zeros) instead of 0/0 = NaN.
    row_min = df_expr.min(axis=1)
    row_range = df_expr.max(axis=1) - row_min
    row_range = row_range.where(row_range > 0, 1.0)
    df_expr_norm = df_expr.subtract(row_min, axis=0).div(row_range, axis=0)

    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)
    col_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_col['MS_compartment'], merge=True, rotation=0, extend=True,
            colors={"Matrix": "blue", "Striosome": "red"}, 
        ),
        Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}),
        MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm),
        verbose=1, axis=1, plot_legend=False, legend_gap=5, hgap=2,
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=False,
            colors="black", relpos=(1, 0.5), 
        ),
        verbose=1, axis=0
    )

    plt.figure(figsize=(8,6))
    pch.ClusterMapPlotter(
        data=df_expr_norm,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        row_dendrogram=False,
        label="Expression",
        cmap='plasma',
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel=xlabel,
        vmax=0.5,  # colour saturates at half of each gene's scaled range
    )

    plt.suptitle(title)
    if save and image_path is not None: 
        # Change the extension to .png / .pdf for other formats.
        plt.savefig(
            os.path.join(image_path, f"expr_heatmap_{title.replace(' ', '_')}.svg"),
            dpi=300, bbox_inches="tight",
        )
    if show: 
        plt.show()
    plt.close()

In [None]:
# Expression heatmap of the top SPARK-X genes (by median replicate p-value) for every cell type.
for _ct in sparkx_df['cell_type'].unique():
    print(_ct)
    df_sub = (
        sparkx_df[sparkx_df['cell_type'] == _ct]
        .groupby("gene")[['p_sparkx', 'rho_axis']]
        .median()
        .sort_values('p_sparkx', ascending=True)
    )
    adata_ct = adata[(adata.obs['Subclass'] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
    if adata_ct.n_obs == 0:
        # Cell type absent from adata, or none of its cells has an MS_NORM score:
        # there is nothing to draw.
        print(f"  skipping {_ct}: no cells with MS_NORM")
        continue
    expr_heatmap(adata_ct, df_sub, n_genes=40, gene_names=10, title=f"Top SPARK-X genes in {_ct}", rasterized=True, show=True, save=False)

In [None]:
# Output directory for the poster figures.
image_path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/images", "poster")

In [None]:
# Save poster heatmaps (SVG) for the selected subclasses.
toplot = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN", "CN LAMP5-CXCL14 GABA", "CN ST18 GABA", "CN VIP GABA"]
# savefig does not create directories, so make sure the output folder exists.
os.makedirs(image_path, exist_ok=True)
for _group in toplot:
    print(_group)
    adata_ct = adata[(adata.obs['Subclass'] == _group) & (~adata.obs['MS_NORM'].isna())].copy()
    if adata_ct.n_obs == 0:
        print(f"  skipping {_group}: no cells with MS_NORM")
        continue
    df_sub = sparkx_df[sparkx_df['cell_type'] == _group].groupby("gene")[['p_sparkx', 'rho_axis']].median().sort_values('p_sparkx', ascending=True)
    expr_heatmap(adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {_group}", ylabel=None, rasterized=True, show=False, save=True, image_path=image_path)

In [None]:
# ### Investigating Further: 
group = "CN LAMP5-CXCL14 GABA"
cell_mask = (adata.obs['Subclass'] == group) & adata.obs['MS_NORM'].notna()
adata_ct = adata[cell_mask].copy()
df_sub = (
    sparkx_df.query("cell_type == @group")
    .groupby("gene")[['p_sparkx', 'rho_axis']]
    .median()
    .sort_values('p_sparkx')
)
expr_heatmap(adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {group}", ylabel=None, rasterized=True, show=True, save=False)

In [None]:
# Spread of the per-replicate SPARK-X p-values for HDAC9 in the current `group`.
sparkx_df.query("cell_type == @group").groupby("gene")[['p_sparkx']].describe().loc['HDAC9']

In [None]:
### Investigating Further: 

group = "CN ST18 GABA"
cell_mask = (adata.obs['Subclass'] == group) & adata.obs['MS_NORM'].notna()
adata_ct = adata[cell_mask].copy()
df_sub = (
    sparkx_df.query("cell_type == @group")
    .groupby("gene")[['p_sparkx', 'rho_axis']]
    .median()
    .sort_values('p_sparkx')
)
expr_heatmap(adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {group}", ylabel=None)

In [None]:
# Spread of the per-replicate SPARK-X p-values for RASGRF2 in the current `group`.
sparkx_df.query("cell_type == @group").groupby("gene")[['p_sparkx']].describe().loc['RASGRF2']

In [None]:
### Investigating Further: 

group = "CN VIP GABA"
cell_mask = (adata.obs['Subclass'] == group) & adata.obs['MS_NORM'].notna()
adata_ct = adata[cell_mask].copy()
df_sub = (
    sparkx_df.query("cell_type == @group")
    .groupby("gene")[['p_sparkx', 'rho_axis']]
    .median()
    .sort_values('p_sparkx')
)
expr_heatmap(adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {group}", ylabel=None)

In [None]:
# Spread of the per-replicate SPARK-X p-values for KIAA1211 in the current `group`.
sparkx_df.query("cell_type == @group").groupby("gene")[['p_sparkx']].describe().loc['KIAA1211']