This notebook plots the heatmaps and spatial plots of the spatially variable genes (SVGs) associated with matrix–striosome compartment switching. The figures from this notebook go into spatial supplementary figure 5 as well as the main spatial figure (Fig. 6).

Author: Amit Klein 
Email: a3klein@ucsd.edu

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from scipy.stats import norm, spearmanr
from statsmodels.stats.multitest import multipletests
import geopandas as gpd
from spida.utilities._ad_utils import normalize_adata

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.colors as mcolors
import seaborn as sns

from spida.pl._utils import add_color_scheme
from spida.pl import plot_categorical, plot_continuous, categorical_scatter, continuous_scatter
from spida.utilities.sd_utils import _get_obs_or_gene

In [None]:
# Print versions of important packages
# `sys` is imported explicitly instead of reaching through `os.sys`, which only
# works because the `os` module happens to import `sys` internally.
import sys

print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"Pandas: {pd.__version__}")
print(f"Anndata: {ad.__version__}")
print(f"Scanpy: {sc.__version__}")

In [None]:
# Global matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update({
    # Figure geometry / resolution
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    # Font type 42 embeds TrueType fonts in PDF/PS output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
    'image.interpolation': 'nearest',
    # Saving behaviour
    'savefig.dpi': 300,
    'savefig.transparent': True,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.01,
})

# Notebook-wide switch passed as `rasterized=` to the summary plots below.
RASTERIZED = False

In [None]:
# Load the main AnnData object (used below as the RNA expression data).
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(filename=ad_path)
# Bare name as the last expression so the notebook displays the object summary.
adata

In [None]:
# Load the mCH AnnData object.
spatial_mch = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/BG_mCH_Imp_SubR.h5ad")

# Copy every colour/palette entry from the main object so both share the same colours.
for uns_key in adata.uns:
    if any(tag in uns_key for tag in ("colors", "palette")):
        spatial_mch.uns[uns_key] = adata.uns[uns_key]

# Expose the cell centre coordinates under the conventional `spatial` key.
spatial_mch.obsm['spatial'] = spatial_mch.obs[['CENTER_X', 'CENTER_Y']].to_numpy()
spatial_mch

In [None]:
# Output directory for all figures of this notebook; created (with parents) if missing.
image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/ms-svg3")
image_path.mkdir(parents=True, exist_ok=True)

## Helper Functions

In [None]:
# Downsample Function
def _downsample_reference(
    ref_adata : ad.AnnData,
    cluster_col : str,
    max_cluster_size: int = 3000,
    min_cluster_size: int = 0,
):
    """
    Balance cluster sizes of a reference AnnData object.

    Clusters (values of ``cluster_col``) with fewer than ``min_cluster_size``
    cells are removed first; clusters with more than ``max_cluster_size``
    cells are then downsampled. A step is skipped when its size argument is
    not positive. Both steps are delegated to helpers in
    ``spida.utilities._ad_utils``.

    Returns the object produced by the last applied step, or ``ref_adata``
    itself when both steps are skipped.
    """
    from spida.utilities import _ad_utils

    result = ref_adata
    if min_cluster_size > 0:
        result = _ad_utils._remove_small_clusters(result, cluster_col, min_cells=min_cluster_size)
    if max_cluster_size > 0:
        result = _ad_utils._downsample_ref_clusters(result, cluster_col, max_cells=max_cluster_size)
    return result

In [None]:
def expr_heatmap(
    adata,
    df_genes,
    n_genes=40,
    gene_names=10,
    heatmap_order_col='rho_axis',
    groupby_replicate=True,
    replicate_col="brain_region",
    replicate_palette=None,
    replicate_order=None,
    title="",
    filename = None,
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    cmap="plasma",
    rasterized=True,
    show=True, 
    save=False,
    image_path=None,
    min_max_quantiles=(0.02, 0.98),
    vmin=0, vmax=1,
    color_splits=None,
):
    """Plot a quantile-scaled expression heatmap (genes x cells) for one cell type.

    Columns are cells ordered by ``MS_NORM`` (within each replicate when
    ``groupby_replicate`` is True); rows are genes ordered by
    ``heatmap_order_col``. Each gene is scaled so that its lower/upper quantile
    across the plotted cells maps to 0/1.

    Parameters
    ----------
    adata
        AnnData-like object. Reads ``.X``, ``.obs_names``, ``.var_names``,
        ``.uns`` and the ``.obs`` columns ``Group``, ``MS_NORM``,
        ``MS_compartment`` and ``replicate_col``.
    df_genes
        Gene table with the genes of interest first. Needs ``heatmap_order_col``
        and either a ``gene`` column or the gene names as index.
    n_genes
        Number of leading rows of ``df_genes`` to plot.
    gene_names
        Number of leading (de-duplicated) genes that get a text label.
    heatmap_order_col
        Column of ``df_genes`` used to order the rows (ascending).
    groupby_replicate
        If True, sort and split the columns by ``replicate_col``.
    replicate_col, replicate_palette, replicate_order
        Replicate column in ``adata.obs``, its colour mapping (defaults to
        ``adata.uns[f"{replicate_col}_palette"]``) and the order of the column
        panels (forwarded as ``col_split_order``).
    title
        Shown as the top x-label; also used to build the default file name.
    filename
        Base name (no extension) of the saved files.
    out_path, xlabel
        Accepted for interface compatibility; not used by this function.
    ylabel, cmap, rasterized, vmin, vmax
        Forwarded to ``PyComplexHeatmap.ClusterMapPlotter``.
    show, save, image_path
        ``save`` writes ``<filename>.png`` and ``<filename>.pdf`` into
        ``image_path`` (only when ``image_path`` is given); ``show`` calls
        ``plt.show()``. The figure is always closed at the end.
    min_max_quantiles
        ``(low, high)`` per-gene quantiles that are mapped to 0 and 1.
    color_splits
        If given, colour of the right spine drawn on every column panel except
        the last one, i.e. the separator between column splits.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm

    if filename is None:
        filename = f"expr_heatmap_{title.replace(' ', '_')}"
    if replicate_palette is None:
        replicate_palette = adata.uns.get(f"{replicate_col}_palette", None)

    # --- Column (cell) metadata ---
    df_col = adata.obs[['Group', 'MS_NORM', 'MS_compartment']].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    # The replicate column is always attached: the top annotation needs it even
    # when the columns are not grouped by replicate.
    df_col[replicate_col] = adata.obs[replicate_col]
    sort_keys = [replicate_col, 'MS_NORM'] if groupby_replicate else ['MS_NORM']
    df_col = df_col.sort_values(sort_keys)

    # --- Row (gene) metadata ---
    df_row = df_genes.iloc[:n_genes].copy()
    if "gene" in df_row.columns:
        df_row = df_row.drop_duplicates(subset='gene', keep='first').set_index('gene')
    else:
        # Gene names already live in the index
        df_row = df_row[~df_row.index.duplicated(keep='first')].copy()
    labelled = set(df_row.index[:gene_names])
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    # --- Expression matrix (genes x cells), scaled per gene ---
    values = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
    df_expr = pd.DataFrame(values, index=adata.obs_names, columns=adata.var_names).T
    df_expr = df_expr.loc[df_row.index, df_col.index]
    q_lo = df_expr.quantile(min_max_quantiles[0], axis=1)
    q_hi = df_expr.quantile(min_max_quantiles[1], axis=1)
    df_expr_norm = df_expr.subtract(q_lo, axis=0).div(q_hi - q_lo, axis=0)

    # --- Annotations ---
    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)
    col_ha = pch.HeatmapAnnotation(
        Brain_Region=pch.anno_simple(df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False, height=3, text_kws={"fontsize":8, "fontweight":"bold"}),
        Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}, height=3),
        MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3),
        verbose=1, axis=1, plot_legend=True, legend_gap=5,
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=True,
            colors="black", relpos=(1, 0.5), 
            arrowprops = dict(visible=True,)
        ),
        verbose=1, axis=0
    )

    plt.figure(figsize=(8,4))
    cm = pch.ClusterMapPlotter(
        data=df_expr_norm,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=df_col[replicate_col] if groupby_replicate else None,
        row_dendrogram=False,
        label="Expression",
        cmap=cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel=title,
        xlabel_side="top",
        vmin=vmin, vmax=vmax,
        col_split_order=replicate_order
    )

    # Colour the gap between column splits: draw the right spine of every
    # column panel except the last one, in every row panel. (The loop bounds
    # were previously swapped, so nothing was drawn for a single row panel.)
    if color_splits is not None:
        n_row_panels, n_col_panels = cm.heatmap_axes.shape
        for i in range(n_row_panels):
            for j in range(n_col_panels - 1):
                spine = cm.heatmap_axes[i][j].spines["right"]
                spine.set_visible(True)
                spine.set_color(color_splits)
                spine.set_linewidth(2)

    if save and image_path is not None:
        print("saving")
        plt.savefig(image_path / f"{filename}.png", dpi=300, bbox_inches="tight")
        plt.savefig(image_path / f"{filename}.pdf", dpi=300, bbox_inches="tight")
    if show: 
        plt.show()
    plt.close()

In [None]:
# WUBIN THIS IS WHERE THE FUNCTION THAT DOES THE ACTUAL PLOTTING LIVES. 
def joint_expr_heatmap(
    adata_rna,
    adata_mch,
    df_genes,
    n_genes=40,
    gene_names=10,
    manual_added_genes= None, 
    heatmap_order_col='rho_axis',
    groupby_replicate=True,
    replicate_col="brain_region",
    replicate_palette=None,
    replicate_order=None,
    title="",
    filename=None, 
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    rna_cmap="Reds",
    mch_cmap="parula",
    rasterized=True,
    show=True, 
    save=False,
    image_path=None,
    meth_min_max_quantiles=(0.02, 0.98),
    rna_min_max_quantiles=(0.02, 0.98),
    vmin=0, vmax=1,
    color_splits=None,
    figsize=(8,4),
    title_fontsize=14,
    def_fontsize=12, 
):
    """Plot side-by-side RNA and mCH heatmaps (genes x cells) for one cell type.

    The cells and their order come from ``adata_rna``; ``adata_mch`` must
    contain the same cells and genes. Columns are ordered by ``MS_NORM``
    (within each replicate when ``groupby_replicate`` is True) and rows by
    ``heatmap_order_col``. Each gene is quantile-scaled to 0/1, separately
    within each replicate when ``groupby_replicate`` is True and across all
    plotted cells otherwise.

    Parameters
    ----------
    adata_rna, adata_mch
        AnnData-like objects. ``.X``, ``.obs_names`` and ``.var_names`` are
        read from both; ``.uns`` and the ``.obs`` columns ``Group``,
        ``MS_NORM``, ``MS_compartment`` and ``replicate_col`` are read from
        ``adata_rna`` only.
    df_genes
        Gene table with the genes of interest first. Needs ``heatmap_order_col``
        and either a ``gene`` column or the gene names as index.
    n_genes
        Number of leading rows of ``df_genes`` to plot.
    gene_names
        Number of leading (de-duplicated) genes that get a text label.
    manual_added_genes
        Optional iterable of genes that are always plotted (if present in
        ``df_genes``) and always labelled.
    heatmap_order_col
        Column of ``df_genes`` used to order the rows (ascending).
    groupby_replicate, replicate_col, replicate_palette, replicate_order
        Column grouping; the palette defaults to
        ``adata_rna.uns[f"{replicate_col}_palette"]`` and ``replicate_order``
        is forwarded as ``col_split_order``.
    title, title_fontsize
        Title of the composite figure; ``title`` also builds the default
        file name.
    filename
        Base name (no extension) of the saved files.
    out_path, xlabel
        Accepted for interface compatibility; not used by this function (the
        panels are labelled "RNA" and "mCH").
    ylabel, rna_cmap, mch_cmap, rasterized, vmin, vmax
        Forwarded to ``PyComplexHeatmap.ClusterMapPlotter``.
    show, save, image_path
        ``save`` writes ``<filename>.png`` and ``<filename>.pdf`` into
        ``image_path`` (only when ``image_path`` is given); ``show`` calls
        ``plt.show()``. The figure is always closed at the end.
    meth_min_max_quantiles, rna_min_max_quantiles
        ``(low, high)`` per-gene quantiles mapped to 0 and 1 for each modality.
    color_splits
        If given, colour of the right spine drawn on every column panel except
        the last one, i.e. the separator between column splits.
    figsize
        Size of the composite figure.
    def_fontsize
        Font size of the gene labels.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm

    if filename is None: 
        filename = f"joint_expr_heatmap_{title.replace(' ', '_')}"
    if replicate_palette is None:
        replicate_palette = adata_rna.uns.get(f"{replicate_col}_palette", None)

    # --- Column (cell) metadata ---
    df_col = adata_rna.obs[['Group', 'MS_NORM', 'MS_compartment']].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    # Always attached: the annotations need it even without replicate grouping.
    df_col[replicate_col] = adata_rna.obs[replicate_col]
    sort_keys = [replicate_col, 'MS_NORM'] if groupby_replicate else ['MS_NORM']
    df_col = df_col.sort_values(sort_keys)

    # --- Row (gene) metadata ---
    manual = list(manual_added_genes) if manual_added_genes is not None else []
    has_gene_col = "gene" in df_genes.columns
    df_row = df_genes.iloc[:n_genes].copy()
    if manual:
        gene_ids = df_genes['gene'] if has_gene_col else df_genes.index
        df_row = pd.concat([df_row, df_genes[gene_ids.isin(manual)]], axis=0)
    if has_gene_col:
        df_row = df_row.drop_duplicates(subset='gene', keep='first').set_index('gene')
    else:
        df_row = df_row[~df_row.index.duplicated(keep='first')].copy()
    labelled = set(df_row.index[:gene_names]) | set(manual)
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    # --- Scaled expression / methylation matrices (genes x cells) ---
    groups = df_col[replicate_col] if groupby_replicate else None
    df_expr_norm_rna = _quantile_scale_rows(
        _genes_by_cells_frame(adata_rna, df_row.index, df_col.index),
        rna_min_max_quantiles, groups,
    )
    df_expr_norm_mch = _quantile_scale_rows(
        _genes_by_cells_frame(adata_mch, df_row.index, df_col.index),
        meth_min_max_quantiles, groups,
    )

    # --- Annotations ---
    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)

    def _top_annotation(**extra_kws):
        # Each heatmap gets its own annotation instance.
        return pch.HeatmapAnnotation(
            Brain_Region=pch.anno_simple(df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False, height=3, text_kws={"fontsize":8, "fontweight":"bold"}, label="Brain Region"),
            Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}, height=3, label="Compartment"),
            MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3, label="MatStr Score"),
            verbose=1, axis=1, plot_legend=True, legend_gap=5, **extra_kws,
        )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=True,
            colors="black", relpos=(1, 0.7), height=3, fontsize=def_fontsize,
            arrowprops = dict(visible=True, linewidth=1, connectionstyle='arc3'),
        ),
        verbose=1, axis=0, orientation='left'
    )

    col_split = df_col[replicate_col] if groupby_replicate else None

    # RNA panel: carries the gene labels; annotation labels are hidden here
    cm1 = pch.ClusterMapPlotter(
        data=df_expr_norm_rna,
        top_annotation=_top_annotation(label_kws={'visible':False}),
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=col_split,
        col_split_order=replicate_order,
        row_dendrogram=False,
        label="Expression",
        cmap=rna_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="RNA",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    # mCH panel
    cm2 = pch.ClusterMapPlotter(
        data=df_expr_norm_mch,
        top_annotation=_top_annotation(),
        row_cluster=False,
        col_cluster=False,
        col_split=col_split,
        col_split_order=replicate_order,
        row_dendrogram=False,
        col_dendrogram=False,
        label="Score",
        cmap=mch_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="mCH",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    plt.figure(figsize=figsize)
    ax, legend_axes = pch.composite(cmlist=[cm1, cm2], main=0, legend_hpad=3, col_gap=0.1)
    for _leg in legend_axes: 
        _leg.set_rasterized(False)
    ax.set_title(title, fontsize=title_fontsize, rasterized=False)

    # Colour the gap between column splits: right spine of every column panel
    # except the last one, in both heatmaps.
    if color_splits is not None: 
        for plotter in (cm1, cm2):
            n_row_panels, n_col_panels = plotter.heatmap_axes.shape
            for i in range(n_row_panels):
                for j in range(n_col_panels - 1):
                    spine = plotter.heatmap_axes[i][j].spines["right"]
                    spine.set_visible(True)
                    spine.set_color(color_splits)
                    spine.set_linewidth(2)

    if save and image_path is not None: 
        print("saving")
        plt.savefig(image_path / f"{filename}.png", dpi=300, bbox_inches="tight")
        plt.savefig(image_path / f"{filename}.pdf", dpi=300, bbox_inches="tight", transparent=False)
    if show: 
        plt.show()
    plt.close()


def _genes_by_cells_frame(adata, genes, cells):
    """Return ``adata.X`` as a dense genes x cells DataFrame restricted to ``genes`` and ``cells``."""
    values = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X
    frame = pd.DataFrame(values, index=adata.obs_names, columns=adata.var_names).T
    return frame.loc[genes, cells]


def _quantile_scale_rows(df_expr, quantiles, groups=None):
    """Scale every row of ``df_expr`` so its (low, high) quantiles map to 0 and 1.

    ``groups`` is an optional Series indexed by the columns of ``df_expr``;
    when given, rows are scaled separately within each group of columns and the
    original column order is restored afterwards.
    """
    q_low, q_high = quantiles

    def _scale(block):
        lo = block.quantile(q_low, axis=1)
        hi = block.quantile(q_high, axis=1)
        return block.subtract(lo, axis=0).div(hi - lo, axis=0)

    if groups is None:
        return _scale(df_expr)
    pieces = [
        _scale(df_expr[members.index])
        for _, members in groups.groupby(groups, observed=True)
    ]
    return pd.concat(pieces, axis=1).reindex(columns=df_expr.columns)

# SparkX Summary Plots

In [None]:
# Annotation level of the SPARK-X run; its first letter selects the results folder.
level = "Subclass"
_let = level[0].lower()
res_path = f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_br{_let}/sparkx_per_replicate_results.csv"
meta_path = f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_br{_let}/sparkx_meta_results.csv"

sparkx_df = pd.read_csv(res_path)
df_meta = pd.read_csv(meta_path)

# Restrict both tables to genes present in the RNA and the mCH objects.
common_genes = adata.var_names.intersection(spatial_mch.var_names)
sparkx_df = sparkx_df.loc[sparkx_df['gene'].isin(common_genes)].copy()
df_meta = df_meta.loc[df_meta['gene'].isin(common_genes)].copy()

In [None]:
# Do not want results from CAT
sparkx_df = sparkx_df.loc[sparkx_df['replicate'].ne("CAT")].copy()
# Negative rho_axis is labelled Striosome, everything else Matrix
sparkx_df['direction'] = np.where(sparkx_df['rho_axis'] < 0, 'Striosome', 'Matrix')

In [None]:
# sig_counts from meta analysis: number of genes with fdr < 1e-2 per cell type and direction
_meta_sig = df_meta[df_meta['fdr'] < 1e-2]
sig_counts = (
    _meta_sig.groupby("cell_type")['direction']
    .value_counts()
    .reset_index()
    .rename(columns={'count' : 'n_sig'})
)
_direction_labels = {'down': 'Striosome', 'up':'Matrix'}
sig_counts['direction'] = sig_counts['direction'].map(_direction_labels)

In [None]:
# --- Plot settings (fdr_col is reused by later cells) ---
fdr_col = "p_sparkx"
fdr_thresh = 1e-2

# Alternative: replicate-level counts instead of the `sig_counts` table built
# in the previous cell.
# sig_counts = (
#     sparkx_df.groupby(["cell_type", "replicate", "direction"], observed=True)
#     .apply(lambda x: ((x[fdr_col] < fdr_thresh)).sum())
#     .reset_index(name="n_sig")
# )
# sig_counts = sig_counts.reset_index(drop=True)
# sig_counts = sig_counts[sig_counts['cell_type'] != "unknown"]

# Categories on the x-axis and the bar groups. These must be defined here:
# the plotting code below uses both names.
cell_types = sig_counts["cell_type"].unique()
directions = sig_counts["direction"].dropna().unique()

# Updated mid-soft palette
palette = {
    "Matrix":    (0.30, 0.50, 0.90),
    "Striosome": (0.90, 0.45, 0.45),
}


# --- Set up figure ---
fig, ax = plt.subplots(figsize=(8, 3))

x = np.arange(len(cell_types))
n_dir = len(directions)
width = 0.8 / max(n_dir, 1)

jitter_scale = 0.3   # jitter occupies 30% of the bar width
rng = np.random.default_rng(42)

# --- Plot bars, error bars, and jittered points ---
for i, direction in enumerate(directions):

    # x-pos of this direction's bars
    x_pos = x - 0.4 + i * width + width / 2

    means = []
    errs = []
    points_per_ct = []   # one entry per cell type, aligned with x_pos

    for ct in cell_types:
        vals = sig_counts.loc[
            (sig_counts["cell_type"] == ct) &
            (sig_counts["direction"] == direction),
            "n_sig"
        ].values

        # NaN (nothing drawn) when there is no value / fewer than two values;
        # the explicit guards avoid NumPy's empty-slice and ddof warnings.
        means.append(vals.mean() if vals.size else np.nan)
        errs.append(vals.std(ddof=1) if vals.size > 1 else np.nan)
        points_per_ct.append(vals)

    # --- Bars ---
    ax.bar(
        x_pos, means, width=width,
        color=palette.get(direction, "gray"),
        edgecolor="black", linewidth=0.5,
        label=direction,
        rasterized=RASTERIZED, zorder=2
    )

    # --- Error bars ---
    ax.errorbar(
        x_pos, means, yerr=errs,
        fmt="none", ecolor="black",
        elinewidth=1, capsize=3,
        zorder=3
    )

    # --- Jittered points ---
    for xi, vals in zip(x_pos, points_per_ct):
        if len(vals) == 0:
            continue
        jitter = rng.uniform(
            -width * jitter_scale,
            width * jitter_scale,
            size=len(vals)
        )
        ax.scatter(
            xi + jitter, vals,
            color="black", s=2,
            zorder=4,
            rasterized=RASTERIZED
        )

# --- Axes styling ---
ax.set_xticks(x)
ax.set_xticklabels(cell_types, rotation=30, ha="right", fontsize=8, rasterized=RASTERIZED)

ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=RASTERIZED)
# ax.set_xlabel("Cell Type", rasterized=RASTERIZED)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=RASTERIZED, fontsize=14)

# --- Color x-labels conditionally: grey out cell types with at most 25 significant genes in total ---
color_ct = [
    ct for ct in cell_types
    if sig_counts[sig_counts["cell_type"] == ct]["n_sig"].sum() > 25
]

for ticklabel in ax.get_xticklabels():
    ticklabel.set_color("black" if ticklabel.get_text() in color_ct else "lightgray")

ax.legend(frameon=False)

fig.tight_layout()
# fig.savefig(image_path / f"num_sig_genes_per_celltype_full_{level}.png",dpi=300, bbox_inches="tight")
# fig.savefig(image_path / f"num_sig_genes_per_celltype_full_{level}.pdf",dpi=300, bbox_inches="tight")
plt.show()
plt.close()


In [None]:
# fdr_col = "p_sparkx"
# fdr_thresh = 1e-2
# sig_counts = (
#     sparkx_df.groupby(["cell_type", "replicate", "direction"], observed=True)
#     .apply(lambda x: ((x[fdr_col] < fdr_thresh)).sum())
#     .reset_index(name="n_sig")
# )
# sig_counts = sig_counts.reset_index()

# fig, ax = plt.subplots(figsize=(8,3))
# sns.barplot(
#     data=sig_counts, x="cell_type", y="n_sig", hue="direction",
#     palette = {"Matrix": "blue", "Striosome": "red"},
#     errorbar=None, err_kws={"color": "black", "linewidth": 1}, capsize=0.2,
#     ax=ax, rasterized=RASTERIZED
# )
# ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=RASTERIZED)
# ax.set_xlabel("Cell Type", rasterized=RASTERIZED)
# ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=RASTERIZED)
# ax.set_xticks(ax.get_xticks())
# ax.set_xticklabels(ax.get_xticklabels(), rotation=30, ha="right", rasterized=RASTERIZED)


# color_ct = []
# for _ct in sig_counts['cell_type'].unique():
#     if sig_counts[(sig_counts['cell_type'] == _ct)]['n_sig'].sum() > 5: 
#         color_ct.append(_ct)

# for i, ticklabel in enumerate(ax.get_xticklabels()): 
#     ticklabel.set_color('black' if ticklabel.get_text() in color_ct else 'lightgray')

# plt.tight_layout()
# plt.savefig(image_path / f"num_sig_genes_per_celltype_full_{level}.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / f"num_sig_genes_per_celltype_full_{level}.pdf", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

In [None]:
fdr_col = "p_sparkx"
fdr_thresh = 1e-5

# Number of genes below the threshold per (cell type, replicate). A boolean
# group-sum replaces `DataFrameGroupBy.apply`, which is deprecated when the
# applied frame includes the grouping columns.
_n_sig = (
    sparkx_df[fdr_col].lt(fdr_thresh)
    .groupby([sparkx_df["cell_type"], sparkx_df["replicate"]], observed=True)
    .sum()
    .reset_index(name="n_sig")
)

# Mean / std over replicates for every cell type
sig_counts = (
    _n_sig.groupby("cell_type", observed=True)["n_sig"]
    .agg(["mean", "std"])
    .reset_index()
)
# Keep cell types with more than one significant gene on average, drop "unknown"
sig_counts = (
    sig_counts[(sig_counts["mean"] > 1) & (sig_counts["cell_type"] != "unknown")]
    .sort_values("mean", ascending=True)
)
sig_counts = sig_counts.assign(color=sig_counts["cell_type"].map(adata.uns[f"{level}_palette"]))


fig, ax = plt.subplots(figsize=(4,3))
ax.barh(
    data=sig_counts, y="cell_type", width="mean", color=sig_counts['color'],
    rasterized=RASTERIZED
)
ax.set_ylabel("Cell Type", rasterized=RASTERIZED)
ax.set_xlabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=RASTERIZED)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=RASTERIZED, x=0.2)

# Colour every y tick label by the neighborhood of its cell type (grey fallback)
group_to_nt = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level).to_dict()['Neighborhood']
tick_to_nn_pal = {tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080') for tick, nt in group_to_nt.items()}
for ticklabel in ax.get_yticklabels():
    ticklabel.set_color(tick_to_nn_pal.get(ticklabel.get_text(), '#808080'))

plt.tight_layout()
# plt.savefig(image_path / f"num_sig_genes_per_celltype_{level}.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / f"num_sig_genes_per_celltype_{level}.pdf", dpi=300, bbox_inches="tight", pad_inches=0.5)
plt.show()
plt.close()

## Trial: Overlap with Wubin DMGs

In [None]:
from matplotlib_venn import venn2

In [None]:
# Load the merged one-vs-rest DMG table and keep the Group-level MSN rows for mCH.
dmg_res = pd.read_csv("/anvil/projects/x-mcb130189/Wubin/BG/DMG/merged_dmg.one_vs_rest.tsv", sep="\t")
_keep = (
    (dmg_res['Group'] == 'Group')
    & (dmg_res['Parent'] == 'MSN')
    & (dmg_res['mc_type'] == 'CH')
)
msn_dmg = dmg_res[_keep]
# The full table is not needed any more
del dmg_res
msn_dmg.head()

In [None]:
# Keep only DMGs whose gene is part of the spatial panel.
# Note: this rebinds `common_genes` to the spatial/DMG gene overlap.
spatial_genes = adata.var_names
common_genes = spatial_genes.intersection(msn_dmg['Gene'])
_in_panel = msn_dmg['Gene'].isin(common_genes)
msn_dmg = msn_dmg.loc[_in_panel].copy()

In [None]:
# Split the DMG table by the compartment of the tested group (groupA).
matrix_dmgs = msn_dmg.loc[msn_dmg['groupA'].isin(["STRd D2 Matrix MSN", "STRd D1 Matrix MSN"])]
striosome_dmgs = msn_dmg.loc[msn_dmg['groupA'].isin(["STRd D1 Striosome MSN", "STRd D2 Striosome MSN"])]

In [None]:
_ct = "STR D2 MSN"
fdr_thresh = 1e-5

# Rows of this cell type that pass the threshold
_ct_hits = (
    sparkx_df.loc[sparkx_df['cell_type'] == _ct]
    .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0")
)
# Median per (gene, replicate), most significant first
_per_replicate = (
    _ct_hits.groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
)
# One row per gene: the replicate with the smallest value of fdr_col
df_sub = _per_replicate.drop_duplicates(subset='gene', keep='first')
df_sub.shape

In [None]:
# Negative rho_axis -> striosome-enriched, positive -> matrix-enriched (zeros fall in neither set)
df_sub_str = df_sub[df_sub['rho_axis'].lt(0)]
df_sub_mat = df_sub[df_sub['rho_axis'].gt(0)]

In [None]:
# Spatial cell type -> [Matrix group, Striosome group] names used in the DMG table.
naming_map = {
    f"STR {_msn} MSN": [f"STRd {_msn} Matrix MSN", f"STRd {_msn} Striosome MSN"]
    for _msn in ("D1", "D2")
}

In [None]:
# Overlap between matrix-enriched SPARK-X genes and matrix DMGs for `_ct`.
dmg_fdr_thresh = 1e-5
sub_matrix_dmgs = matrix_dmgs[(matrix_dmgs["groupA"].isin(naming_map[_ct])) & (matrix_dmgs['fdr_bh'] < dmg_fdr_thresh)].copy()
sub_matrix_dmgs.drop_duplicates(subset='Gene', keep='first', inplace=True)

# Venn region sizes: only-enriched, only-DMG, shared
mat_ab = df_sub_mat['gene'].isin(sub_matrix_dmgs['Gene']).sum()
mat_a = df_sub_mat.shape[0] - mat_ab
mat_b = sub_matrix_dmgs.shape[0] - mat_ab

fig, ax = plt.subplots(figsize=(4,4))
venn2(
    subsets= (mat_a, mat_b, mat_ab),
    set_labels = ('Matrix Enriched Genes', 'Matrix DMGs'),
    # Colour and title follow `_ct` instead of a hard-coded "STR D1 MSN",
    # so they always match the cell type the gene sets were computed for.
    set_colors = (adata.uns[f"{level}_palette"][_ct], 'blue'),
    ax=ax)
ax.set_title(f"Matrix Enriched Genes vs Matrix DMGs\n({_ct}s)", rasterized=True)
plt.show()

In [None]:
# Overlap between striosome-enriched SPARK-X genes and striosome DMGs for `_ct`.
dmg_fdr_thresh = 1e-5
sub_str_dmgs = striosome_dmgs[(striosome_dmgs["groupA"].isin(naming_map[_ct])) & (striosome_dmgs['fdr_bh'] < dmg_fdr_thresh)].copy()
sub_str_dmgs.drop_duplicates(subset='Gene', keep='first', inplace=True)

# Venn region sizes: only-enriched, only-DMG, shared
str_ab = df_sub_str['gene'].isin(sub_str_dmgs['Gene']).sum()
str_a = df_sub_str.shape[0] - str_ab
str_b = sub_str_dmgs.shape[0] - str_ab

fig, ax = plt.subplots(figsize=(4,4))
venn2(
    subsets= (str_a, str_b, str_ab),
    set_labels = ('Striosome Enriched Genes', 'Striosome DMGs'),
    # naming_map[_ct] is [Matrix group, Striosome group]; index 1 selects the
    # Striosome group so the colour matches the set it represents.
    set_colors = (adata.uns["MSN_Groups_palette"][naming_map[_ct][1]], 'blue'),
    ax=ax)
# Title follows `_ct` instead of a hard-coded "STR D1 MSNs"
ax.set_title(f"Striosome Enriched Genes vs Striosome DMGs\n({_ct}s)", rasterized=True)
plt.show()

# SPARKX Heatmaps
To plot only the heatmaps, start from here after running the initial setup cells and the helper functions.

In [None]:
# Annotation level of the SPARK-X run; its first letter selects the results folder.
level = "Subclass"
_let = level.lower()[0]
res_path = f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_br{_let}/sparkx_per_replicate_results.csv"
meta_path = f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_br{_let}/sparkx_meta_results.csv"

df_meta = pd.read_csv(meta_path)
sparkx_df = pd.read_csv(res_path)

# Restrict both tables to genes present in the RNA and the mCH objects.
common_genes = adata.var_names.intersection(spatial_mch.var_names)
df_meta = df_meta[df_meta['gene'].isin(common_genes)].copy()
sparkx_df = sparkx_df[sparkx_df['gene'].isin(common_genes)].copy()

# Do not want results from CAT
sparkx_df = sparkx_df[sparkx_df['replicate'] != "CAT"].copy()
# Negative rho_axis = Striosome, matching the summary-plot section and the
# rho_axis < 0 striosome gene split used for the DMG overlap (the sign was
# inverted here before).
sparkx_df['direction'] = (sparkx_df['rho_axis'] < 0).map({True: 'Striosome', False: 'Matrix'})

# --- Significance column and threshold used by the heatmap cells below ---
fdr_col = "p_sparkx"
fdr_thresh = 1e-5

# Cell type (at `level`) -> Neighborhood lookup
group_to_nt = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level).to_dict()['Neighborhood']

In [None]:
# Copy compartment annotations from the RNA object onto the mCH object.
# `reindex` aligns on cell names and leaves NaN for mCH cells that are absent
# from `adata`, instead of raising a KeyError as `.loc` would.
transfer_cols = ['brain_region_corr', 'MS_NORM', 'MS_SCORE', 'MS_compartment']
for col in transfer_cols:
    spatial_mch.obs[col] = adata.obs[col].reindex(spatial_mch.obs_names)

# Cells present in both modalities
joint_cells = adata.obs_names.intersection(spatial_mch.obs_names)
adata_joint_rna = adata[joint_cells].copy()
adata_joint_mch = spatial_mch[joint_cells].copy()

# Plotting order of the brain regions in the heatmaps
repr_order = ["CaH", "CaB", "Pu", "NAC"]

In [None]:
# %%capture
# # TODO: Add Manual Labeling of Genes
# topn = 10
# for _ct in sig_counts['cell_type'].unique():
#     df_sub = (
#         sparkx_df[sparkx_df['cell_type'] == _ct]
#         .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0")
#         .groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
#         .median()
#         .sort_values(fdr_col, ascending=True)
#         .reset_index()
#         # .drop_duplicates(subset='gene', keep='first')
#     )
#     gene_vc = df_sub['gene'].value_counts()
#     ns = len(df_sub['gene'].unique())
#     if ns < 50: 
#         gvc_thr = 1
#     elif ns < 100: 
#         gvc_thr = 2
#     else: 
#         gvc_thr = 3
#     keep_genes = gene_vc[gene_vc >= gvc_thr].index
#     df_sub = df_sub[df_sub['gene'].isin(keep_genes)].copy()
#     df_sub = df_sub.drop_duplicates(subset='gene', keep='first')

#     nt = group_to_nt.get(_ct)
#     ns = len(df_sub['gene'].unique())
#     print(_ct, ns, nt)
#     _ct_filename = _ct.replace("/", "_").replace(" ", "_")
#     if ns < 5: 
#         continue

#     if nt in ['Nonneuron', "unknown"]: 
#         # Plot only RNA
#         # pass
#         adata_ct = adata[(adata.obs[level] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
#         adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
#         adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
#         adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
        
#         rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
#         expr_heatmap(adata_ct, df_sub, n_genes=ns, gene_names=10, ylabel=None, color_splits="white",
#                     rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
#                     replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
#                     cmap='agSunset', vmin=0, vmax=1, min_max_quantiles=(0.25, 0.99), 
#                     filename=f"sparkx_heatmap_{_ct_filename}",
#         )
#         # break
#     else: 
#         # Plot joint RNA + mCH
#         adata_rna = adata_joint_rna[(adata_joint_rna.obs[level] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

#         adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
#         adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
#         adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

#         adata_mch = adata_joint_mch[(adata_joint_mch.obs[level] == _ct) & (~adata_joint_mch.obs['MS_NORM'].isna())].copy()
#         adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
#         adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)
        
#         rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
#         joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=ns, gene_names=topn, ylabel=None, color_splits="white", 
#                     rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
#                     replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
#                     rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
#                     figsize=(12,6), filename=f"sparkx_heatmap_{_ct_filename}",
#         )

In [None]:
# %%capture
# %%capture
manual_genes = ['BACH2', 'KIRREL3', 'GRM1', 'PDYN']
topn = 5
# max_genes = 40
_ct = "STR D1 MSN"

# Significant rows of this cell type plus every row of the manually requested
# genes, reduced to one median row per (gene, replicate), most significant first.
_ct_hits = (
    sparkx_df.loc[sparkx_df['cell_type'] == _ct]
    .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0 | gene in {manual_genes}")
)
df_sub = (
    _ct_hits.groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
)

# The more candidate genes there are, the more replicates a gene must appear in.
gene_vc = df_sub['gene'].value_counts()
ns = df_sub['gene'].nunique()
gvc_thr = 1 if ns < 50 else (2 if ns < 100 else 3)
keep_genes = gene_vc.index[gene_vc >= gvc_thr]
df_sub = df_sub.loc[df_sub['gene'].isin(keep_genes)].copy()
df_sub = df_sub.drop_duplicates(subset='gene', keep='first')
# df_sub = df_sub.head(max_genes)

nt = group_to_nt.get(_ct)
ns = df_sub['gene'].nunique()
print(_ct, ns, nt)
_ct_filename = _ct.replace("/", "_").replace(" ", "_")

if nt in ('Nonneuron', "unknown"):
    # Plot only RNA
    _keep = (
        (adata.obs[level] == _ct)
        & adata.obs['MS_NORM'].notna()
        & adata.obs['brain_region_corr'].isin(repr_order)
    )
    adata_ct = adata[_keep].copy()
    adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
    adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)

    # Palette restricted to the regions that are actually present
    _present = adata_ct.obs['brain_region_corr'].unique()
    rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in _present}
    expr_heatmap(adata_ct, df_sub, n_genes=ns, gene_names=10, ylabel=None, color_splits="white",
                rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
                replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                cmap='agSunset', vmin=0, vmax=1, min_max_quantiles=(0.25, 0.99), 
                filename=f"sparkx_heatmap_{_ct_filename}",
    )
else:
    # Plot joint RNA + mCH
    _keep_rna = (
        (adata_joint_rna.obs[level] == _ct)
        & adata_joint_rna.obs['MS_NORM'].notna()
        & adata_joint_rna.obs['brain_region_corr'].isin(repr_order)
    )
    adata_rna = adata_joint_rna[_keep_rna].copy()
    adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=8000, min_cluster_size=0)
    adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

    _keep_mch = (
        (adata_joint_mch.obs[level] == _ct)
        & adata_joint_mch.obs['MS_NORM'].notna()
        & adata_joint_mch.obs['brain_region_corr'].isin(repr_order)
    )
    adata_mch = adata_joint_mch[_keep_mch].copy()
    adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

    _region_palette = adata_rna.uns["brain_region_corr_palette"]
    rep_palette = {k : _region_palette.get(k, None) for k in repr_order}
    joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=ns, gene_names=topn, manual_added_genes = manual_genes, 
                ylabel=None, color_splits="white", 
                rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
                replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
                figsize=(12,6), filename=f"sparkx_heatmap_{_ct_filename}_2",
    )

In [None]:
manual_genes = ['RASGRF2', 'GLP1R', 'GALNT17', 'ALK']
topn = 5
_ct = "CN ST18 GABA"

# Significant rows of this cell type plus every row of the manually requested
# genes, reduced to one median row per (gene, replicate), most significant first.
_ct_hits = (
    sparkx_df.loc[sparkx_df['cell_type'] == _ct]
    .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0 | gene in {manual_genes}")
)
df_sub = (
    _ct_hits.groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
)

# The more candidate genes there are, the more replicates a gene must appear in.
gene_vc = df_sub['gene'].value_counts()
ns = df_sub['gene'].nunique()
gvc_thr = 1 if ns < 50 else (2 if ns < 100 else 3)
keep_genes = gene_vc.index[gene_vc >= gvc_thr]
df_sub = df_sub.loc[df_sub['gene'].isin(keep_genes)].copy()
df_sub = df_sub.drop_duplicates(subset='gene', keep='first')

In [None]:
# Number of meta-analysis hits (fdr < 0.05) per cell type and direction.
sig_counts = (
    df_meta.loc[df_meta['fdr'] < 0.05]
    .groupby("cell_type")['direction']
    .value_counts()
    .reset_index()
    .rename(columns={'count': 'n_sig'})
)
# Relabel the direction codes with compartment names.
sig_counts['direction'] = sig_counts['direction'].map({'down': 'Striosome', 'up': 'Matrix'})

In [None]:
# Meta-analysis rows for one cell type below a strict FDR cut-off.
# NOTE(review): `fdr_thresh` is a notebook-level name that other cells also read
# in their query strings - confirm that overwriting it here is intended.
_ct = "STR D2 MSN"
fdr_thresh = 1e-5
df_sub2 = df_meta.loc[df_meta['cell_type'] == _ct].query(f"fdr < {fdr_thresh}")

# Cell-type name made safe for use in file names.
_ct_filename = _ct.replace("/", "_").replace(" ", "_")

In [None]:
# Cells of the selected type that have an MS_NORM value ...
adata_ct = adata[(adata.obs[level] == _ct) & adata.obs['MS_NORM'].notna()].copy()
# ... restricted to the regions listed in repr_order ...
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
# ... and downsampled per MS_compartment.
adata_ct = _downsample_reference(
    adata_ct,
    cluster_col="MS_compartment",
    max_cluster_size=4000,
    min_cluster_size=0,
)
# Ordered categorical so that regions follow repr_order.
adata_ct.obs['brain_region_corr'] = pd.Categorical(
    adata_ct.obs['brain_region_corr'],
    categories=repr_order,
    ordered=True,
)

# Palette entries only for the regions present in this subset.
rep_palette = {
    k: v
    for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items()
    if k in adata_ct.obs['brain_region_corr'].unique()
}
# NOTE(review): `ns` is not assigned in this cell; it is whatever an earlier cell
# computed from its own gene table - confirm that is the intended n_genes for df_sub2.
expr_heatmap(
    adata_ct,
    df_sub2,
    n_genes=ns,
    gene_names=10,
    ylabel=None,
    color_splits="white",
    heatmap_order_col="meta_Z",
    rasterized=True,
    show=True,
    save=False,
    image_path=image_path,
    title=f"Top SPARK-X genes in {_ct}",
    replicate_col="brain_region_corr",
    replicate_palette=rep_palette,
    replicate_order=repr_order,
    cmap='agSunset',
    vmin=0,
    vmax=1,
    min_max_quantiles=(0.25, 0.99),
    filename=f"sparkx_heatmap_{_ct_filename}",
)

In [None]:
# RNA side: cells of the selected type with an MS_NORM value, limited to the
# regions in repr_order, then downsampled per MS_compartment.
adata_rna = adata_joint_rna[
    (adata_joint_rna.obs[level] == _ct) & adata_joint_rna.obs['MS_NORM'].notna()
].copy()
adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
adata_rna = _downsample_reference(
    adata_rna,
    cluster_col="MS_compartment",
    max_cluster_size=2000,
    min_cluster_size=0,
)
adata_rna.obs['brain_region_corr'] = pd.Categorical(
    adata_rna.obs['brain_region_corr'],
    categories=repr_order,
    ordered=True,
)

# mCH side: same cell-type / region filters, no downsampling.
adata_mch = adata_joint_mch[
    (adata_joint_mch.obs[level] == _ct) & adata_joint_mch.obs['MS_NORM'].notna()
].copy()
adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
adata_mch.obs['brain_region_corr'] = pd.Categorical(
    adata_mch.obs['brain_region_corr'],
    categories=repr_order,
    ordered=True,
)

# One palette entry per region in repr_order (None when the palette lacks it).
rep_palette = {k: adata_rna.uns["brain_region_corr_palette"].get(k) for k in repr_order}
# NOTE(review): `ns`, `topn` and `manual_genes` are not assigned in this cell and
# come from other cells - confirm they are the intended values for df_sub2.
joint_expr_heatmap(
    adata_rna,
    adata_mch,
    df_sub2,
    n_genes=ns,
    gene_names=topn,
    manual_added_genes=manual_genes,
    ylabel=None,
    color_splits="white",
    heatmap_order_col="meta_Z",
    rasterized=True,
    show=True,
    save=False,
    image_path=image_path,
    title=f"Top SPARK-X genes in {_ct}",
    replicate_col="brain_region_corr",
    replicate_palette=rep_palette,
    replicate_order=repr_order,
    rna_cmap='agSunset',
    vmin=0,
    vmax=1,
    meth_min_max_quantiles=(0.1, 0.9),
    rna_min_max_quantiles=(0.25, 0.99),
    figsize=(12, 6),
    filename=f"sparkx_heatmap_{_ct_filename}_2",
)

In [None]:
# %%capture
# TODO: Add Manual Labeling of Genes
manual_genes = ['RASGRF2', 'GLP1R', 'GALNT17', 'ALK']
topn = 5
_ct = "CN ST18 GABA"

# Median FDR / rho_axis per (gene, replicate) for this cell type, restricted to
# rows passing the FDR cut-off (or belonging to a manual gene), lowest FDR first.
df_sub = (
    sparkx_df.loc[sparkx_df['cell_type'] == _ct]
    .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0 | gene in {manual_genes}")
    .groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
)

# Require more replicate rows per gene as the number of passing genes grows,
# then keep a single (lowest-FDR) row per gene.
gene_vc = df_sub['gene'].value_counts()
ns = len(df_sub['gene'].unique())
gvc_thr = 1 if ns < 50 else (2 if ns < 100 else 3)
keep_genes = gene_vc[gene_vc >= gvc_thr].index
df_sub = (
    df_sub[df_sub['gene'].isin(keep_genes)]
    .copy()
    .drop_duplicates(subset='gene', keep='first')
)

nt = group_to_nt.get(_ct)
ns = len(df_sub['gene'].unique())  # gene count after filtering
print(_ct, ns, nt)
_ct_filename = _ct.replace("/", "_").replace(" ", "_")

if nt not in ['Nonneuron', "unknown"]:
    # Joint RNA + mCH heatmap.
    adata_rna = adata_joint_rna[
        (adata_joint_rna.obs[level] == _ct) & adata_joint_rna.obs['MS_NORM'].notna()
    ].copy()
    adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
    adata_rna = _downsample_reference(
        adata_rna,
        cluster_col="MS_compartment",
        max_cluster_size=2000,
        min_cluster_size=0,
    )
    adata_rna.obs['brain_region_corr'] = pd.Categorical(
        adata_rna.obs['brain_region_corr'],
        categories=repr_order,
        ordered=True,
    )

    adata_mch = adata_joint_mch[
        (adata_joint_mch.obs[level] == _ct) & adata_joint_mch.obs['MS_NORM'].notna()
    ].copy()
    adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
    adata_mch.obs['brain_region_corr'] = pd.Categorical(
        adata_mch.obs['brain_region_corr'],
        categories=repr_order,
        ordered=True,
    )

    # One palette entry per region in repr_order (None when the palette lacks it).
    rep_palette = {k: adata_rna.uns["brain_region_corr_palette"].get(k) for k in repr_order}
    joint_expr_heatmap(
        adata_rna,
        adata_mch,
        df_sub,
        n_genes=ns,
        gene_names=topn,
        manual_added_genes=manual_genes,
        ylabel=None,
        color_splits="white",
        rasterized=True,
        show=False,
        save=True,
        image_path=image_path,
        title=f"Top SPARK-X genes in {_ct}",
        replicate_col="brain_region_corr",
        replicate_palette=rep_palette,
        replicate_order=repr_order,
        rna_cmap='agSunset',
        vmin=0,
        vmax=1,
        meth_min_max_quantiles=(0.1, 0.9),
        rna_min_max_quantiles=(0.25, 0.99),
        figsize=(12, 6),
        filename=f"sparkx_heatmap_{_ct_filename}_2",
    )
else:
    # RNA-only heatmap.
    adata_ct = adata[(adata.obs[level] == _ct) & adata.obs['MS_NORM'].notna()].copy()
    adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
    adata_ct = _downsample_reference(
        adata_ct,
        cluster_col="MS_compartment",
        max_cluster_size=4000,
        min_cluster_size=0,
    )
    adata_ct.obs['brain_region_corr'] = pd.Categorical(
        adata_ct.obs['brain_region_corr'],
        categories=repr_order,
        ordered=True,
    )

    # Palette entries only for the regions present in this subset.
    rep_palette = {
        k: v
        for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items()
        if k in adata_ct.obs['brain_region_corr'].unique()
    }
    expr_heatmap(
        adata_ct,
        df_sub,
        n_genes=ns,
        gene_names=10,
        ylabel=None,
        color_splits="white",
        rasterized=True,
        show=False,
        save=True,
        image_path=image_path,
        title=f"Top SPARK-X genes in {_ct}",
        replicate_col="brain_region_corr",
        replicate_palette=rep_palette,
        replicate_order=repr_order,
        cmap='agSunset',
        vmin=0,
        vmax=1,
        min_max_quantiles=(0.25, 0.99),
        filename=f"sparkx_heatmap_{_ct_filename}",
    )

In [None]:
def _quantile_scale_rows(df_values, column_groups, quantiles):
    """Rescale every row to (x - q_lo) / (q_hi - q_lo), separately within each column group."""
    q_lo, q_hi = quantiles
    scaled = []
    for cols in column_groups:
        block = df_values[cols]
        lo = block.quantile(q_lo, axis=1)
        hi = block.quantile(q_hi, axis=1)
        scaled.append(block.subtract(lo, axis=0).div(hi - lo, axis=0))
    return pd.concat(scaled, axis=1)


def _dense_genes_by_cells(adata, genes, cells):
    """Return ``adata.X`` as a dense genes x cells DataFrame limited to the requested labels."""
    # Subset before densifying so only the plotted genes/cells are materialised.
    sub = adata[cells, genes]
    values = sub.X
    values = values.toarray() if hasattr(values, "toarray") else np.asarray(values)
    frame = pd.DataFrame(values, index=sub.obs_names, columns=sub.var_names).T
    return frame.loc[genes, cells]


def joint_expr_heatmap(
    adata_rna,
    adata_mch,
    df_genes,
    n_genes=40,
    gene_names=10,
    manual_added_genes= None, 
    heatmap_order_col='rho_axis',
    groupby_replicate=True,
    replicate_col="brain_region",
    replicate_palette=None,
    replicate_order=None,
    title="",
    filename=None, 
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    rna_cmap="Reds",
    mch_cmap="parula",
    rasterized=True,
    show=True, 
    save=False,
    image_path=None,
    meth_min_max_quantiles=(0.02, 0.98),
    rna_min_max_quantiles=(0.02, 0.98),
    vmin=0, vmax=1,
    color_splits=None,
    figsize=(8,4),
    title_fontsize=14,
    def_fontsize=12, 
):
    """Plot side-by-side RNA and mCH heatmaps (genes x cells) for a gene table.

    Cells (columns) come from ``adata_rna.obs`` and are ordered by ``MS_NORM``
    (within ``replicate_col`` when ``groupby_replicate`` is True). The same cell
    and gene labels are looked up in ``adata_mch``, so both objects must contain
    them. Each gene is rescaled between two quantiles, per replicate group when
    grouping and across all cells otherwise.

    Parameters
    ----------
    adata_rna, adata_mch
        AnnData objects holding the values for the left ("RNA") and right
        ("mCH") heatmaps. ``adata_rna.obs`` must have ``Group``, ``MS_NORM`` and
        ``MS_compartment`` columns (plus ``replicate_col`` when grouping).
    df_genes : pandas.DataFrame
        Gene table, either with a ``gene`` column or indexed by gene. Must
        contain ``heatmap_order_col``.
    n_genes : int
        Number of leading rows of ``df_genes`` to plot.
    gene_names : int
        Number of leading plotted genes that receive a text label.
    manual_added_genes : iterable of str, optional
        Genes that are added to the plot (when present in ``df_genes``) and
        always labelled.
    heatmap_order_col : str
        Column of ``df_genes`` used to sort the heatmap rows (ascending).
    groupby_replicate : bool
        If True, sort, split and normalise columns per ``replicate_col`` and
        draw the "Brain Region" track. If False, cells are sorted by ``MS_NORM``
        only, normalised together, and that track is omitted.
    replicate_col, replicate_palette, replicate_order
        Grouping column, its colour mapping (defaults to
        ``adata_rna.uns[f"{replicate_col}_palette"]``) and the split order.
    title, title_fontsize, def_fontsize
        Figure title, its font size, and the font size of the gene labels.
    filename : str, optional
        Output file stem; derived from ``title`` when None.
    out_path, xlabel
        Accepted for signature compatibility; not used by this function.
    ylabel : str or None
        Passed to both heatmaps as their y label.
    rna_cmap, mch_cmap
        Colormaps of the RNA and mCH heatmaps.
    rasterized : bool
        Passed to both heatmaps.
    show, save, image_path
        Display the figure and/or save ``<filename>.png`` and ``<filename>.pdf``
        under ``image_path`` (str or Path). Nothing is saved if ``image_path``
        is None.
    meth_min_max_quantiles, rna_min_max_quantiles : tuple of float
        (low, high) quantiles used to rescale the mCH / RNA rows.
    vmin, vmax
        Colour limits applied to both heatmaps.
    color_splits : color, optional
        If given, the right spine of every heatmap panel except the last
        column is drawn in this colour.
    figsize : tuple
        Figure size.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm

    if filename is None: 
        filename = f"joint_expr_heatmap_{title.replace(' ', '_')}"
    if replicate_palette is None:
        replicate_palette = adata_rna.uns.get(f"{replicate_col}_palette", None)

    # ---- Columns: per-cell metadata, ordering and normalisation groups ----
    df_col = adata_rna.obs[['Group', 'MS_NORM', 'MS_compartment']].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    if groupby_replicate:
        df_col[replicate_col] = adata_rna.obs[replicate_col]
        df_col = df_col.sort_values([replicate_col, 'MS_NORM'])
        column_groups = [grp.index for _, grp in df_col.groupby(replicate_col, observed=True)]
    else:
        # Without grouping there is no replicate column in df_col: treat all
        # cells as one normalisation group instead of failing on the lookup.
        df_col = df_col.sort_values('MS_NORM')
        column_groups = [df_col.index]

    # ---- Rows: gene table indexed by gene, top n_genes plus manual genes ----
    gene_table = df_genes.set_index('gene') if "gene" in df_genes.columns else df_genes
    df_row = gene_table.iloc[:n_genes]
    if manual_added_genes is not None:
        manual_added_genes = list(manual_added_genes)
        df_manual = gene_table[gene_table.index.isin(manual_added_genes)]
        df_row = pd.concat([df_row, df_manual], axis=0)
    df_row = df_row[~df_row.index.duplicated(keep='first')].copy()

    # Label the leading genes (in table order, before re-sorting) and all manual genes.
    labelled = set(df_row.index[:gene_names])
    if manual_added_genes is not None:
        labelled.update(manual_added_genes)
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    # ---- Values: dense genes x cells matrices, quantile-scaled per group ----
    df_expr_norm_rna = _quantile_scale_rows(
        _dense_genes_by_cells(adata_rna, df_row.index, df_col.index),
        column_groups,
        rna_min_max_quantiles,
    )
    df_expr_norm_mch = _quantile_scale_rows(
        _dense_genes_by_cells(adata_mch, df_row.index, df_col.index),
        column_groups,
        meth_min_max_quantiles,
    )

    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)

    def _column_annotation(**extra):
        """Build a fresh top annotation (each heatmap needs its own instance)."""
        tracks = {}
        if groupby_replicate:
            tracks["Brain_Region"] = pch.anno_simple(
                df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False,
                height=3, text_kws={"fontsize":8, "fontweight":"bold"}, label="Brain Region",
            )
        tracks["Compartment"] = pch.anno_simple(
            df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"},
            height=3, label="Compartment",
        )
        tracks["MatStr_Score"] = pch.anno_simple(
            df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3, label="MatStr Score",
        )
        return pch.HeatmapAnnotation(
            **tracks, verbose=1, axis=1, plot_legend=True, legend_gap=5, **extra
        )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=True,
            colors="black", relpos=(1, 0.7), height=3, fontsize=def_fontsize,
            arrowprops = dict(visible=True, linewidth=0.5)
        ),
        verbose=1, axis=0
    )

    col_split = df_col[replicate_col] if groupby_replicate else None
    col_split_order = replicate_order if groupby_replicate else None

    # Left heatmap (RNA): carries the gene labels; track labels hidden.
    cm1 = pch.ClusterMapPlotter(
        data=df_expr_norm_rna,
        top_annotation=_column_annotation(label_kws={'visible':False}),
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=col_split,
        col_split_order=col_split_order,
        row_dendrogram=False,
        label="Expression",
        cmap=rna_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="RNA",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    # Right heatmap (mCH): same columns and rows, no gene labels.
    cm2 = pch.ClusterMapPlotter(
        data=df_expr_norm_mch,
        top_annotation=_column_annotation(),
        row_cluster=False,
        col_cluster=False,
        col_split=col_split,
        col_split_order=col_split_order,
        row_dendrogram=False,
        col_dendrogram=False,
        label="Score",
        cmap=mch_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="mCH",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    plt.figure(figsize=figsize)
    ax, legend_axes = pch.composite(cmlist=[cm1, cm2], main=0, legend_hpad=3, col_gap=0.1)
    for _leg in legend_axes: 
        _leg.set_rasterized(False)
    ax.set_title(title, fontsize=title_fontsize, rasterized=False)

    # Colour the boundary between column splits (right spine of every panel
    # except the last column) so the gaps are not left as plain background.
    if color_splits is not None: 
        for cm in (cm1, cm2):
            for i in range(cm.heatmap_axes.shape[0]):
                for j in range(cm.heatmap_axes.shape[1]-1):
                    spine = cm.heatmap_axes[i][j].spines["right"]
                    spine.set_visible(True)
                    spine.set_color(color_splits)
                    spine.set_linewidth(2)

    if save and image_path is not None: 
        print("saving")
        out_dir = Path(image_path)  # accept both str and Path
        plt.savefig(out_dir / f"{filename}.png", dpi=300, bbox_inches="tight")
        plt.savefig(out_dir / f"{filename}.pdf", dpi=300, bbox_inches="tight", transparent=False)
    if show: 
        plt.show()
    plt.close()

# Spatial Plots

## Helper Functions

In [None]:
# TODO: @aklein Move to spida.pl
def plot_categorical(
    adata : ad.AnnData | str, 
    ax=None,
    coord_base='tsne',
    cluster_col='MajorType',
    palette_path=None,
    highlight_group : str | list | None = None,
    color_unknown : str = "#D0D0D0",
    is_background : bool = False,
    axis_format='tiny',
    alpha=0.7,
    coding=True,
    id_marker=True,
    output=None,
    show=True,
    figsize=(4, 3.5),
    sheet_name=None,
    ncol=None,
    fontsize=5,
    legend_fontsize=5,
    legend_kws=None,
    legend_title_fontsize=5,
    marker_fontsize=4,
    marker_pad=0.1,
    linewidth=0.5,
    ret=False,
    text_anno:bool = False,
    text_kws=None,
    **kwargs
):
    """
    Scatter-plot cells coloured by a categorical ``obs`` column.

    Resolves a palette for ``cluster_col``, optionally restricts the data to
    highlighted groups, fills in default keyword arguments and delegates the
    drawing to ``categorical_scatter``.

    Parameters
    ----------
    adata : AnnData or str
        AnnData object, or a path that is opened with
        ``ad.read_h5ad(path, backed='r')``. ``adata.obs[cluster_col]`` is
        converted to a categorical in place if it is not one already.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure of size ``figsize`` is created when None.
    coord_base : str
        Passed to ``categorical_scatter`` as ``coord_base``.
    cluster_col : str
        Column of ``adata.obs`` to colour by; also the legend title.
    palette_path : str or mapping, optional
        Path to an Excel file (sheet ``sheet_name``, first column as index, a
        ``Hex`` column) or a mapping category -> colour. When given,
        ``adata.uns[cluster_col + '_colors']`` is overwritten from it. When
        None, colours come from ``adata.uns[f'{cluster_col}_colors']`` or,
        if absent, from ``add_color_scheme``.
    highlight_group : str or list, optional
        Group(s) to keep. Cells of every other group are dropped from the
        plotted data (they are not drawn in a background colour).
    color_unknown : str
        Colour used for every category when ``is_background`` is True.
    is_background : bool
        If True, every remaining category (and an extra ``'unknown'`` key) is
        mapped to ``color_unknown``.
    axis_format, alpha, marker_fontsize, marker_pad, linewidth
        Defaults forwarded to ``categorical_scatter`` unless already present
        in ``kwargs``.
    coding, id_marker
        Always forwarded to ``categorical_scatter`` (they override ``kwargs``).
        Original author note: if both are True the legend marker is a circle
        (or rectangle) containing the code.
    output : str, optional
        If given, the figure is saved to this path (``~`` is expanded).
    show : bool
        Whether to call ``plt.show()``.
    figsize : tuple
        Size of the figure created when ``ax`` is None.
    sheet_name : str, optional
        Excel sheet holding the palette; defaults to ``cluster_col``.
    ncol : int, optional
        Number of legend columns (added to the legend defaults when given).
    fontsize : int
        Default ``fontsize`` entry of ``text_kws``.
    legend_fontsize, legend_title_fontsize : int
        Legend defaults, both 5.
    legend_kws : dict, optional
        Extra keyword arguments for the legend; the caller's dict is not modified.
    ret : bool
        Whether to return the plot object and the colour mapping.
    text_anno : bool
        If True, ``cluster_col`` is used as the text annotation key; if False
        no text annotation is requested.
    text_kws : dict, optional
        Keyword arguments for the text annotation; the caller's dict is not
        modified.
    kwargs : dict
        Further arguments for ``categorical_scatter``; values given here take
        precedence over the defaults set by this function (e.g. ``show_legend``
        defaults to True).

    Returns
    -------
    tuple
        ``(p, colors)`` when ``ret`` is True, where ``p`` is the return value of
        ``categorical_scatter``; ``(None, None)`` otherwise.
    """
    if sheet_name is None: # for getting color scheme from excel file
        sheet_name = cluster_col
    if isinstance(adata, str): # getting adata
        adata = ad.read_h5ad(adata, backed='r')
    if not isinstance(adata.obs[cluster_col].dtype, pd.CategoricalDtype): # make sure cluster_col is categorical 
        adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')

    # ---- Resolve the palette ----
    if palette_path is not None:
        if isinstance(palette_path, str):
            colors = pd.read_excel(os.path.expanduser(palette_path), sheet_name=sheet_name, index_col=0).Hex.to_dict()
            keys = list(colors.keys())
            existed_vals = adata.obs[cluster_col].unique().tolist()
            # Values missing from the sheet get gray; sheet entries not in the data are dropped.
            for k in existed_vals:
                if k not in keys:
                    colors[k] = 'gray'
            for k in keys:
                if k not in existed_vals:
                    del colors[k]
        else:
            colors = palette_path
        adata.uns[cluster_col + '_colors'] = [colors.get(k, 'grey') for k in adata.obs[cluster_col].cat.categories.tolist()]
    else:
        if f'{cluster_col}_colors' not in adata.uns:
            colors = add_color_scheme(adata, cluster_col, palette_key=f"{cluster_col}_colors")
        else:
            colors = dict(zip(adata.obs[cluster_col].cat.categories.tolist(), adata.uns[f'{cluster_col}_colors']))

    # ---- Select the cells to draw ----
    data = adata[adata.obs[cluster_col].notna()]
    if highlight_group is not None:
        if isinstance(highlight_group, str):
            highlight_group = [highlight_group]
        # Keep only the highlighted groups. Work on a copy so that neither the
        # view nor the caller's palette mapping is modified along the way.
        keep = data.obs[cluster_col].isin(highlight_group).values
        data = data[keep].copy()
        data.obs[cluster_col] = data.obs[cluster_col].cat.remove_unused_categories()
    plotted_categories = set(data.obs[cluster_col].cat.categories.tolist())
    colors = {k: v for k, v in colors.items() if k in plotted_categories}
    if is_background:
        bg_colors = {k: color_unknown for k in colors}
        bg_colors['unknown'] = color_unknown
        colors = bg_colors

    # ---- Defaults forwarded to categorical_scatter ----
    # Copy caller-supplied dicts before filling in defaults.
    text_kws = {} if text_kws is None else dict(text_kws)
    text_kws.setdefault("fontsize", fontsize)
    kwargs.setdefault("hue", cluster_col)
    kwargs.setdefault("text_anno", cluster_col if text_anno else None)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs["coding"] = coding
    kwargs["id_marker"] = id_marker

    legend_kws = {} if legend_kws is None else dict(legend_kws)
    default_lgd_kws = dict(
        fontsize=legend_fontsize,
        title=cluster_col, title_fontsize=legend_title_fontsize)
    if ncol is not None:
        default_lgd_kws['ncol'] = ncol
    for k, v in default_lgd_kws.items():
        legend_kws.setdefault(k, v)
    kwargs.setdefault("dodge_kws", {
            "arrowprops": {
                "arrowstyle": "->",
                "fc": 'grey',
                "ec": "none",
                "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
            },
            'autoalign': 'xy'})

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=300)
    p = categorical_scatter(
        data=data,
        ax=ax,
        coord_base=coord_base,
        palette=colors,
        legend_kws=legend_kws,
        **kwargs)

    if output is not None:
        plt.savefig(os.path.expanduser(output), bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    
    if ret:
        return p, colors
    else: 
        return None, None

In [None]:
## Set up scale bar function

def make_scale_bar(ax, x_coords, y_coords, microns_per_pixel=0.106, x_pct=100, y_pct=100):
    """
    Adds a scale bar to a matplotlib plot.

    The bar length is roughly 1/6 of the x extent of the data, rounded to a
    "nice" value (1, 2, 5 or 10 times a power of ten, in microns). Nothing is
    drawn when there are no finite coordinates or the x extent is zero.

    Parameters:
    - ax: matplotlib Axes object
    - x_coords: list or array of x coordinates
    - y_coords: list or array of y coordinates
    - microns_per_pixel: conversion factor from pixels to microns
    - x_pct: percentile of x-coordinates to position the scale bar horizontally
    - y_pct: percentile of y-coordinates to position the scale bar vertically

    Returns:
    - None

    Example usage:
    - fig, ax = plt.subplots()
    - ax.scatter(x, y)
    - make_scale_bar(ax, x, y)

    """
    # Work on finite float arrays so NaNs cannot poison min/max/percentile.
    xs = np.asarray(x_coords, dtype=float)
    ys = np.asarray(y_coords, dtype=float)
    xs = xs[np.isfinite(xs)]
    ys = ys[np.isfinite(ys)]
    if xs.size == 0 or ys.size == 0:
        return

    # Calculate x-axis range in microns
    x_length_um = (xs.max() - xs.min()) * microns_per_pixel
    if not x_length_um > 0:
        # Degenerate extent: log10 below would be undefined.
        return

    # Target scale length ~1/6 of the x-axis
    target = x_length_um / 6

    # Compute order of magnitude
    order = 10 ** np.floor(np.log10(target))
    mantissa = target / order

    # Round mantissa to a nice value: 1, 2, 5 or 10
    if mantissa < 1.5:
        nice_mantissa = 1
    elif mantissa < 3.5:
        nice_mantissa = 2
    elif mantissa < 7.5:
        nice_mantissa = 5
    else:
        nice_mantissa = 10

    # Final scale length in microns and in pixels
    scale_length_um = nice_mantissa * order
    scale_length_px = scale_length_um / microns_per_pixel

    # Format label
    if scale_length_um >= 1000:
        scale_label = f"{scale_length_um / 1000:.1f} mm"
    else:
        scale_label = f"{scale_length_um:.0f} µm"

    # Set coordinates for the scale bar, anchored at the requested percentiles
    x_anchor = np.percentile(xs, x_pct)
    x_start = x_anchor - scale_length_px * 1.1
    x_end = x_anchor - scale_length_px * 0.1
    y_pos = np.percentile(ys, y_pct) - scale_length_px * 0.1
    bar_y = y_pos + scale_length_px * 0.3

    # Semi-transparent white background behind the bar
    scale_bg = mpatches.Rectangle(
        (x_start - scale_length_px * 0.05, y_pos - scale_length_px * 0.05),
        width=(x_end - x_start) + scale_length_px * 0.1,
        height=scale_length_px * 0.4,
        color='white', alpha=0.8,
        zorder = 10
    )
    ax.add_patch(scale_bg)

    # Draw scale bar and its label (above the background via zorder)
    ax.plot([x_start, x_end], [bar_y, bar_y], color='black', linewidth=2, zorder=11)
    ax.text((x_start + x_end) / 2, bar_y, scale_label,
            color='black', ha='center', va='bottom', zorder=12)

In [None]:
def to_dense_df(X, var_names, obs_names):
    """Wrap a matrix in a dense DataFrame with ``obs_names`` as rows and ``var_names`` as columns.

    Objects exposing ``toarray()`` are converted through it, objects exposing
    ``.A`` through that attribute, and anything else through ``np.asarray``.
    """
    to_array = getattr(X, "toarray", None)
    if to_array is not None:
        dense = to_array()
    elif hasattr(X, "A"):
        dense = X.A
    else:
        dense = np.asarray(X)
    return pd.DataFrame(dense, index=obs_names, columns=var_names)

In [None]:
def plot_gene_groups(
    adata: ad.AnnData,
    gdf_geoms: gpd.GeoDataFrame,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    color_wm = False,
    hue_norm = 0.8,
    min_norm = 0, 
    max_norm = 90,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    save_fig: bool = False,
    show: bool = True, 
    shade_regions = False,
    shade_alpha=0.05,
    title_fontsize=16,
    title = None,
    show_scalebar = True,
): 
    """Plot spatial expression of one gene in selected cell groups, one figure per donor.

    Each figure is a grid with one row per lab (``replicate``) and one column
    per region. Every panel shows all cells in light grey, the cells of
    ``groups`` coloured by ``gene`` and the outlines of the matching geometries.

    Parameters
    ----------
    adata : AnnData
        Needs ``obs`` columns ``brain_region``, ``replicate``, ``donor``,
        ``group_level`` and (for the scale bar) ``CENTER_X`` / ``CENTER_Y``.
    gdf_geoms : GeoDataFrame
        Geometries with ``brain_region``, ``donor`` and ``lab`` columns; ``type``
        is read when ``color_wm`` is set and ``type_color`` when
        ``shade_regions`` is set.
    gene : str
        Gene to plot. If it is not in ``adata.var_names`` a message is printed
        and the function returns without plotting.
    groups : list of str
        Values of ``group_level`` whose cells are coloured by expression.
    regions : list of str
        Brain regions to plot (grid columns).
    group_level : str
        ``obs`` column matched against ``groups``.
    labs, donors : list of str, optional
        Replicates (grid rows) and donors (figures); default to all values
        present after the region filter.
    layers : str, optional
        Layer copied into ``X`` before plotting.
    layer_norm : bool
        If True, ``normalize_adata(adata, log1p=True)`` is applied.
    color_wm : color or False
        When truthy, geometries of type ``"White_Matter"`` are filled with this
        colour.
    hue_norm
        Accepted for signature compatibility; not used by this function.
    min_norm, max_norm : float
        Percentiles of the gene's values within each donor/lab (all cells, all
        regions) used as colour limits.
    image_path, image_name, save_fig
        When ``save_fig`` is True and ``image_path`` is given, a PNG and a PDF
        are written. The file stem is ``image_name`` (with ``_<donor>`` appended
        when several donors are plotted, so figures do not overwrite each
        other) or ``gene_<gene>_donor_<donor>_spatial`` when no name is given.
    rasterized : bool
        Forwarded to the plotting calls.
    show : bool
        Whether to call ``plt.show()`` for each figure.
    shade_regions, shade_alpha
        Fill geometries with their ``type_color`` at the given alpha.
    title_fontsize, title
        Font size of the title; ``title`` is only used for 1x1 grids.
    show_scalebar : bool
        Draw a scale bar on the last panel (skipped when that panel has no cells).
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return
    adata = adata[adata.obs['brain_region'].isin(regions)].copy()
    if layers is not None: 
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)
    
    if labs is None: 
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(labs)
    ncols = len(regions)
    
    for donor in donors:
        # squeeze=False keeps `axes` two-dimensional for every grid shape, so
        # axes[l, r] is always a single Axes (also for 1xN and Nx1 grids).
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False)
        is_donor = adata.obs['donor'] == donor
        for l, lab in enumerate(labs):
            is_lab = is_donor & (adata.obs['replicate'] == lab)
            # Colour limits from this donor/lab; only the one gene is densified.
            lab_gene = adata[is_lab, gene]
            gene_vals = to_dense_df(lab_gene.X, lab_gene.var_names, lab_gene.obs_names)[gene].values
            if gene_vals.size:
                hue_min = np.percentile(gene_vals, min_norm)
                hue_max = np.percentile(gene_vals, max_norm)
            else:
                # No cells for this donor/lab: nothing will be coloured below.
                hue_min = hue_max = None
            for r, region in enumerate(regions):
                ax = axes[l, r]
                # plot background cells
                adata_sub = adata[is_lab & (adata.obs['brain_region'] == region)].copy()
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax, rasterized=rasterized, axis_format=None)
                
                # Plot Data
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap="YlOrRd", ax=ax, show=False, hue_norm=(hue_min, hue_max), rasterized=rasterized, axis_format=None) 
                ax.set_title(f"{region}", rasterized=False)

                # plot regions (MS vs. STR)
                sub_geoms = gdf_geoms.loc[
                    (gdf_geoms['brain_region'] == region) & 
                    (gdf_geoms['donor'] == donor) & 
                    (gdf_geoms['lab'] == lab)
                ].copy()
                sub_geoms.plot(ax=ax, edgecolor='black', facecolor='none', linewidth=1, rasterized=rasterized)
                if color_wm: 
                    # color_wm doubles as the flag and the fill colour.
                    ss = sub_geoms[sub_geoms['type'] == "White_Matter"]
                    ss.plot(ax=ax, color=color_wm, edgecolor="none", rasterized=rasterized).axis("off")
                if shade_regions: 
                    sub_geoms.plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=shade_alpha, rasterized=rasterized).axis("off")
        
        # TODO: CHANGE THE WAY TITLE PLOTS
        # `ax`, `region` and `adata_sub` below refer to the last panel drawn.
        if nrows == 1 and ncols == 1: 
            if title is None: 
                ax.set_title(f"{region} - {donor} - {gene}", y=0.96, fontsize=title_fontsize, rasterized=False)
            else: 
                ax.set_title(title, y=0.96, fontsize=title_fontsize, rasterized=False)
            plt.suptitle("")
        else:
            plt.suptitle(f"Gene: {gene} Donor: {donor}, Expression", y=1.02, fontsize=title_fontsize, rasterized=False)
        if show_scalebar and adata_sub.n_obs > 0: 
            make_scale_bar(ax, adata_sub.obs['CENTER_X'], adata_sub.obs['CENTER_Y'], microns_per_pixel=1, x_pct=0, y_pct=0)
        if save_fig and image_path is not None:
            if image_name is None: 
                stem = f"gene_{gene}_donor_{donor}_spatial"
            elif len(donors) > 1:
                # One file per donor instead of overwriting the same name.
                stem = f"{image_name}_{donor}"
            else: 
                stem = image_name
            plt.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
            plt.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)
        if show: 
            plt.show()
        plt.close()

In [None]:
def plot_gene_groups_meth(
    adata: ad.AnnData,
    gdf_geoms: gpd.GeoDataFrame,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    color_wm = False,
    max_norm = 90,
    min_norm = 10,
    bg_adata: ad.AnnData = None,
    cmap="YlOrRd",
    save_fig: bool = False,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    show: bool = True,
    shade_regions = False,
    shade_alpha=0.05,
    title_fontsize=16,
    title = None,
    show_scalebar = True,
):
    """Plot per-cell values of one gene on spatial coordinates, one figure per donor.

    Each figure is a grid with one row per entry of ``labs`` and one column
    per entry of ``regions``. In every panel all cells are first drawn in
    light grey, then the cells whose ``obs[group_level]`` is in ``groups``
    are coloured by ``gene``, and finally the region outlines from
    ``gdf_geoms`` are drawn on top.

    Parameters
    ----------
    adata
        AnnData holding the values to plot. ``obs`` must provide
        ``brain_region``, ``donor``, ``replicate`` and ``group_level``
        (plus ``CENTER_X`` / ``CENTER_Y`` when ``show_scalebar`` is set).
    gdf_geoms
        Region geometries with ``brain_region``, ``donor`` and ``lab``
        columns; ``type`` is read when ``color_wm`` is set and
        ``type_color`` when ``shade_regions`` is set.
    gene
        Name looked up in ``adata.var_names``. If it is missing a message
        is printed and the function returns without plotting.
    groups, group_level
        Values of ``obs[group_level]`` whose cells are coloured.
    regions
        Brain regions to plot (grid columns).
    labs, donors
        Replicates (grid rows) and donors (one figure each). Default to
        every value present in ``adata`` after the region filter.
    layers
        Optional layer copied into ``X`` before plotting.
    layer_norm
        If true, ``normalize_adata(adata, log1p=True)`` is applied.
    color_wm
        Falsy to skip; otherwise the fill colour used for geometries whose
        ``type`` is ``"White_Matter"``.
    min_norm, max_norm
        Percentiles of the gene's values over *all* cells of the current
        donor/replicate (every selected region and group, not only the
        plotted groups) used as the colour limits.
    bg_adata
        Optional AnnData used for the grey background instead of ``adata``.
    cmap
        Colormap forwarded to ``plot_continuous``.
    save_fig, image_path, image_name
        When ``save_fig`` is true and ``image_path`` is given, a PNG and a
        PDF are written. ``image_name`` is the file stem; with more than one
        donor the donor id is appended so figures do not overwrite each
        other. Without ``image_name`` the stem is
        ``gene_{gene}_donor_{donor}_spatial``.
    rasterized
        Forwarded to the scatter and geometry plotting calls.
    show
        Call ``plt.show()`` before closing each figure.
    shade_regions, shade_alpha
        If set, fill every geometry with its ``type_color`` at this alpha.
    title_fontsize, title
        Title settings. ``title`` is only used for single-panel figures;
        multi-panel figures get an automatic suptitle.
    show_scalebar
        Draw a scale bar via ``make_scale_bar`` on the last panel drawn.

    Returns
    -------
    None
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return

    # Work on a region-restricted copy so the caller's object is untouched.
    adata = adata[adata.obs['brain_region'].isin(regions)].copy()
    if layers is not None:
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)

    if labs is None:
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(labs)
    ncols = len(regions)
    single_panel = nrows == 1 and ncols == 1

    for donor in donors:
        # squeeze=False always returns a 2-D array of Axes, so axes[l, r] is
        # valid for 1x1, 1xN and Nx1 grids as well as full NxM grids.
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False
        )
        for l, lab in enumerate(labs):
            in_donor_lab = (adata.obs['donor'] == donor) & (adata.obs['replicate'] == lab)

            # Colour limits are shared by every region panel of this donor/lab row.
            values = adata[in_donor_lab][:, gene].X
            if hasattr(values, "toarray"):
                # np.percentile cannot operate on scipy sparse matrices.
                values = values.toarray()
            values = np.asarray(values)
            if values.size:
                hue_min = np.percentile(values, min_norm)
                hue_max = np.percentile(values, max_norm)
            else:
                # No cells for this donor/lab: nothing will be coloured below,
                # so the limits are never read.
                hue_min = hue_max = None

            for r, region in enumerate(regions):
                ax = axes[l, r]
                adata_sub = adata[in_donor_lab & (adata.obs['brain_region'] == region)].copy()

                # Grey background of all cells in the section.
                if bg_adata is not None:
                    bg_sub = bg_adata[
                        (bg_adata.obs['donor'] == donor)
                        & (bg_adata.obs['brain_region'] == region)
                        & (bg_adata.obs['replicate'] == lab)
                    ].copy()
                    categorical_scatter(bg_sub, coord_base="spatial", color='lightgrey', ax=ax,
                                        axis_format=None, rasterized=rasterized)
                else:
                    categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax,
                                        axis_format=None, rasterized=rasterized)

                # Cells of the requested groups, coloured by the gene's value.
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap=cmap, ax=ax,
                                    show=False, hue_norm=(hue_min, hue_max),
                                    rasterized=rasterized, axis_format=None)

                # Region outlines for this section.
                sub_geoms = gdf_geoms.loc[
                    (gdf_geoms['brain_region'] == region)
                    & (gdf_geoms['donor'] == donor)
                    & (gdf_geoms['lab'] == lab)
                ].copy()
                sub_geoms.plot(ax=ax, edgecolor='black', facecolor='none', linewidth=1, rasterized=rasterized)
                if color_wm:
                    white_matter = sub_geoms[sub_geoms['type'] == "White_Matter"]
                    white_matter.plot(ax=ax, color=color_wm, edgecolor="none", rasterized=rasterized).axis("off")
                if shade_regions:
                    sub_geoms.plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', legend=True,
                                   alpha=shade_alpha, rasterized=rasterized).axis("off")

        # TODO: CHANGE THE WAY TITLE PLOTS
        if single_panel:
            panel_title = f"{region} - {donor} - {gene}" if title is None else title
            ax.set_title(panel_title, y=0.96, fontsize=title_fontsize, rasterized=False)
            fig.suptitle("")
        else:
            fig.suptitle(f"Gene: {gene} Donor: {donor}, Expression", y=1.02,
                         fontsize=title_fontsize, rasterized=False)

        if show_scalebar:
            # `ax` / `adata_sub` are the last panel drawn (bottom-right of the grid).
            make_scale_bar(ax, adata_sub.obs['CENTER_X'], adata_sub.obs['CENTER_Y'],
                           microns_per_pixel=1, x_pct=0, y_pct=0)

        if save_fig and image_path is not None:
            os.makedirs(image_path, exist_ok=True)
            if image_name is not None:
                # One figure is produced per donor; keep the caller's exact
                # stem for a single donor, otherwise disambiguate by donor so
                # later figures do not overwrite earlier ones.
                stem = image_name if len(donors) == 1 else f"{image_name}_{donor}"
            else:
                stem = f"gene_{gene}_donor_{donor}_spatial"
            fig.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
            fig.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)

        if show:
            plt.show()
        plt.close(fig)

## Plots

In [None]:
# Read the region-geometry table (parquet) that the spatial plots overlay,
# then preview its first rows.
geom_store_path = (
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/"
    "region_geometries_cps.parquet"
)
gdf_geoms = gpd.read_parquet(geom_store_path)
gdf_geoms.head()

In [None]:
# Replace adata.X with a copy of the 'counts' layer.
# NOTE(review): presumably restores raw counts before plotting — confirm.
adata.X = adata.layers['counts'].copy()

In [None]:
# Overrides the RASTERIZED = False set in the rcParams cell near the top of
# the notebook; the saved-figure cells below pass it as `rasterized=`.
RASTERIZED = True

### STR D1 MSN 

PDYN + BACH2 good for Striosome

KIRREL3 + GRM1 good for Matrix

There are many more; these are just the two per compartment that I chose.

In [None]:
# Marker genes plotted for STR D1 MSN: PDYN and BACH2 (Striosome),
# KIRREL3 and GRM1 (Matrix).
genes = ["PDYN", "BACH2", "KIRREL3", "GRM1"]

In [None]:
# STR D1 MSN, RNA: save one spatial panel per gene in `genes`
# for donor UCI5224, lab ucsd, region PU.
subclass_plot = ["STR D1 MSN"]
_donor = "UCI5224"
_lab = "ucsd"
_regions = ["PU"]

# Human-readable labels for the panel title.
subc_title = ", ".join(subclass_plot)
reg_title = ", ".join(_regions)

# File-name-safe labels (spaces replaced by underscores).
subc_filename = "_".join(s.replace(" ", "_") for s in subclass_plot)
reg_filename = "_".join(r.replace(" ", "_") for r in _regions)

# `_gene` is bound only by the loop; the former standalone `_gene = ...`
# assignment was overwritten before it was ever read.
for _gene in genes:
    plot_gene_groups(
        adata,
        gdf_geoms,
        groups=subclass_plot,
        group_level="Subclass",
        regions=_regions,
        donors=[_donor],
        labs=[_lab],
        gene=_gene,
        layers='volume_norm',
        # NOTE(review): the other RNA cells in this notebook pass
        # layer_norm=False with the same layer — confirm True is intended here.
        layer_norm=True,
        min_norm=0,
        max_norm=99,
        title_fontsize=14,
        shade_regions=False,
        color_wm="white",
        image_path=image_path,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        rasterized=RASTERIZED,
        show=False,
        save_fig=True,
        show_scalebar=True,
    )

In [None]:
# STR D1 MSN, mCH: save one spatial panel per gene in `genes`
# for donor UCI5224, lab salk, region PU. `adata` supplies the grey background.
subclass_plot = ["STR D1 MSN"]
_donor = "UCI5224"
_lab = "salk"
_regions = ["PU"]

# Human-readable labels for the panel title.
subc_title = ", ".join(subclass_plot)
reg_title = ", ".join(_regions)

# File-name-safe labels (spaces replaced by underscores).
subc_filename = "_".join(s.replace(" ", "_") for s in subclass_plot)
reg_filename = "_".join(r.replace(" ", "_") for r in _regions)

# `_gene` is bound only by the loop; the former standalone `_gene = ...`
# assignment was overwritten before it was ever read.
for _gene in genes:
    plot_gene_groups_meth(
        spatial_mch,
        gdf_geoms,
        bg_adata=adata,
        groups=subclass_plot,
        group_level="Subclass",
        regions=_regions,
        donors=[_donor],
        labs=[_lab],
        gene=_gene,
        layers=None,
        layer_norm=False,
        cmap="Blues_r",
        min_norm=20,
        max_norm=80,
        color_wm='white',
        shade_regions=False,
        shade_alpha=0.05,
        image_path=image_path,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        title_fontsize=14,
        save_fig=True,
        show=False,
        rasterized=RASTERIZED,
        show_scalebar=True,
    )

In [None]:
# subclass_plot = ["STR D1 MSN"]
# _donor = "UCI5224"
# _lab = "ucsd"
# _regions = ["PU"]
# _gene = "KIRREL3"

# subc_title = ", ".join(subclass_plot)
# reg_title = ", ".join(_regions)

# subc_filename = "_".join([s.replace(" ", "_") for s in subclass_plot])
# reg_filename = "_".join([r.replace(" ", "_") for r in _regions])

# plot_gene_groups(
#     adata,
#     gdf_geoms,
#     groups=subclass_plot,
#     group_level="Subclass", 
#     regions=_regions,
#     donors=[_donor],
#     labs = [_lab],
#     gene=_gene,
#     layers='volume_norm',
#     layer_norm=True,
#     hue_norm=0.5,
#     shade_regions = False, 
#     color_wm = "white",
#     image_path=image_path,
#     image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna',
#     title = f"{subc_title}\n{reg_title} - {_gene}",
#     title_fontsize=18,
#     save_fig=True,
#     show=False,
#     rasterized=True
# )

In [None]:
# subclass_plot = ["STR D1 MSN"]
# _donor = "UCI5224"
# _lab = "salk"
# _regions = ["PU"]
# _gene = "BACH2"

# subc_title = ", ".join(subclass_plot)
# reg_title = ", ".join(_regions)

# subc_filename = "_".join([s.replace(" ", "_") for s in subclass_plot])
# reg_filename = "_".join([r.replace(" ", "_") for r in _regions])

# plot_gene_groups_meth(
#     spatial_mch,
#     gdf_geoms,
#     bg_adata = adata,
#     groups=subclass_plot,
#     group_level="Subclass", 
#     regions=_regions,
#     donors=[_donor],
#     labs = [_lab],
#     gene=_gene,
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",    
#     min_norm=20,
#     max_norm=80,
#     color_wm = 'white',
#     shade_regions = False, 
#     shade_alpha=0.05,
#     image_path=image_path,
#     image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch',
#     title = f"{subc_title}\n{reg_title} - {_gene}",
#     title_fontsize=18,
#     save_fig=True,
#     show=False,
#     rasterized=True
# )

In [None]:
# subclass_plot = ["STR D1 MSN"]
# _donor = "UCI5224"
# _lab = "salk"
# _regions = ["PU"]
# _gene = "KIRREL3"

# subc_title = ", ".join(subclass_plot)
# reg_title = ", ".join(_regions)

# subc_filename = "_".join([s.replace(" ", "_") for s in subclass_plot])
# reg_filename = "_".join([r.replace(" ", "_") for r in _regions])

# plot_gene_groups_meth(
#     spatial_mch,
#     gdf_geoms,
#     bg_adata = adata,
#     groups=subclass_plot,
#     group_level="Subclass", 
#     regions=_regions,
#     donors=[_donor],
#     labs = [_lab],
#     gene=_gene,
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",    
#     min_norm=20,
#     max_norm=80,
#     color_wm = 'white',
#     shade_regions = False, 
#     shade_alpha=0.05,
#     image_path=image_path,
#     image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch',
#     title = f"{subc_title}\n{reg_title} - {_gene}",
#     title_fontsize=18,
#     show=False,
#     save_fig=True,
#     rasterized=True
# )

### CN ST18 GABA + STR FS PTHLH-PVALB GABA

PLXNA4 + ALK + RASGRF2 for Striosome 
- PLXNA4 is a marker for *VIP GABA* interneurons
- ALK is a marker for *STR-BF TAC3-PLPP4-LHX8 GABA* and *CN LAMP5-CXCL14 GABA*
- RASGRF2 is a marker for *STR FS PTHLH-PVALB GABA*

GLP1R + GALNT17 for Matrix
- GLP1R is a marker for *CN ST18 GABA neurons*
- GALNT17 is not a marker!

In [None]:
# subclass_plot = ["CN ST18 GABA"]
# donors = ["UCI5224", "UCI2424", "UCI4723", "UWA7648"]
# labs = ["salk", "ucsd"]
# _regions = ["CAH", "CAB", "PU", "NAC"]
# _gene = "PLXNA4"

# subc_title = ", ".join(subclass_plot)
# reg_title = ", ".join(_regions)

# subc_filename = "_".join([s.replace(" ", "_") for s in subclass_plot])
# reg_filename = "_".join([r.replace(" ", "_") for r in _regions])

# plot_gene_groups(
#     adata,
#     gdf_geoms,
#     groups=subclass_plot,
#     group_level="Subclass", 
#     regions=_regions,
#     donors= donors,
#     labs = labs,
#     gene=_gene,
#     layers='volume_norm',
#     layer_norm=False,
#     min_norm=1,
#     max_norm=99,
#     title_fontsize=18,
#     shade_regions = False, 
#     color_wm = "white",
#     image_path=image_path,
#     image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna',
#     title = f"{subc_title}\n{reg_title} - {_gene}",
#     rasterized=True,
#     show=True,
#     save_fig=False,
# )

In [None]:
# Rebinds `genes` for the interneuron panels: PLXNA4, ALK and RASGRF2
# (Striosome), GLP1R and GALNT17 (Matrix).
genes = ["PLXNA4", "ALK", "RASGRF2", "GLP1R", "GALNT17"]

In [None]:
# CN ST18 GABA, RNA: save one spatial panel per gene in `genes`
# for donor UWA7648, lab ucsd, region CAH.
subclass_plot = ["CN ST18 GABA"]
_donor, _lab = "UWA7648", "ucsd"
_regions = ["CAH"]

# Title labels keep spaces; file-name labels swap them for underscores.
subc_title = ", ".join(subclass_plot)
subc_filename = "_".join(name.replace(" ", "_") for name in subclass_plot)
reg_title = ", ".join(_regions)
reg_filename = "_".join(name.replace(" ", "_") for name in _regions)

# Arguments that stay the same for every gene.
_shared_kwargs = dict(
    groups=subclass_plot,
    group_level="Subclass",
    regions=_regions,
    donors=[_donor],
    labs=[_lab],
    layers='volume_norm',
    layer_norm=False,
    min_norm=0,
    max_norm=99,
    title_fontsize=14,
    shade_regions=False,
    color_wm="white",
    image_path=image_path,
    rasterized=RASTERIZED,
    show=False,
    save_fig=True,
    show_scalebar=True,
)

for _gene in genes:
    plot_gene_groups(
        adata,
        gdf_geoms,
        gene=_gene,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        **_shared_kwargs,
    )

In [None]:
# CN ST18 GABA, mCH: save one spatial panel per gene in `genes`
# for donor UWA7648, lab ucsd, region CAH. `adata` supplies the grey background.
subclass_plot = ["CN ST18 GABA"]
_donor, _lab = "UWA7648", "ucsd"
_regions = ["CAH"]

# Title labels keep spaces; file-name labels swap them for underscores.
subc_title = ", ".join(subclass_plot)
subc_filename = "_".join(name.replace(" ", "_") for name in subclass_plot)
reg_title = ", ".join(_regions)
reg_filename = "_".join(name.replace(" ", "_") for name in _regions)

# Arguments that stay the same for every gene.
_shared_kwargs = dict(
    bg_adata=adata,
    groups=subclass_plot,
    group_level="Subclass",
    regions=_regions,
    donors=[_donor],
    labs=[_lab],
    layers=None,
    layer_norm=False,
    cmap="Blues_r",
    min_norm=2,
    max_norm=98,
    color_wm='white',
    shade_regions=False,
    shade_alpha=0.05,
    image_path=image_path,
    title_fontsize=14,
    show=False,
    save_fig=True,
    rasterized=RASTERIZED,
    show_scalebar=True,
)

for _gene in genes:
    plot_gene_groups_meth(
        spatial_mch,
        gdf_geoms,
        gene=_gene,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        **_shared_kwargs,
    )

In [None]:
# STR FS PTHLH-PVALB GABA (selected at the "Group" level), RNA: save one
# spatial panel per gene in `genes` for donor UWA7648, lab ucsd, region CAH.
subclass_plot = ["STR FS PTHLH-PVALB GABA"]
_donor = "UWA7648"
_lab = "ucsd"
_regions = ["CAH"]

# Human-readable labels for the panel title.
subc_title = ", ".join(subclass_plot)
reg_title = ", ".join(_regions)

# File-name-safe labels (spaces replaced by underscores).
subc_filename = "_".join(s.replace(" ", "_") for s in subclass_plot)
reg_filename = "_".join(r.replace(" ", "_") for r in _regions)

# `_gene` is bound only by the loop; the former standalone `_gene = "EPHA4"`
# was overwritten before use, and EPHA4 is not in `genes`, so it was never plotted.
for _gene in genes:
    plot_gene_groups(
        adata,
        gdf_geoms,
        groups=subclass_plot,
        group_level="Group",
        regions=_regions,
        donors=[_donor],
        labs=[_lab],
        gene=_gene,
        layers='volume_norm',
        layer_norm=False,
        min_norm=0,
        max_norm=99,
        title_fontsize=12,
        shade_regions=False,
        color_wm="white",
        image_path=image_path,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        rasterized=RASTERIZED,
        show=False,
        save_fig=True,
        show_scalebar=True,
    )

In [None]:
# STR FS PTHLH-PVALB GABA (selected at the "Group" level), mCH: save one
# spatial panel per gene in `genes` for donor UWA7648, lab ucsd, region CAH.
# `adata` supplies the grey background.
subclass_plot = ["STR FS PTHLH-PVALB GABA"]
_donor = "UWA7648"
_lab = "ucsd"
_regions = ["CAH"]

# Human-readable labels for the panel title.
subc_title = ", ".join(subclass_plot)
reg_title = ", ".join(_regions)

# File-name-safe labels (spaces replaced by underscores).
subc_filename = "_".join(s.replace(" ", "_") for s in subclass_plot)
reg_filename = "_".join(r.replace(" ", "_") for r in _regions)

# `_gene` is bound only by the loop; the former standalone `_gene = "EPHA4"`
# was overwritten before use, and EPHA4 is not in `genes`, so it was never plotted.
for _gene in genes:
    plot_gene_groups_meth(
        spatial_mch,
        gdf_geoms,
        bg_adata=adata,
        groups=subclass_plot,
        group_level="Group",
        regions=_regions,
        donors=[_donor],
        labs=[_lab],
        gene=_gene,
        layers=None,
        layer_norm=False,
        cmap="Blues_r",
        min_norm=1,
        max_norm=90,
        color_wm='white',
        shade_regions=False,
        shade_alpha=0.05,
        image_path=image_path,
        image_name=f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch',
        title=f"{subc_title}\n{reg_title} - {_gene}",
        title_fontsize=12,
        show=False,
        save_fig=True,
        rasterized=RASTERIZED,
        show_scalebar=True,
    )

In [None]:
# Preview only (show=True, save_fig=False): GLP1R RNA in
# STR FS PTHLH-PVALB GABA cells for donor UWA7648, lab ucsd, region CAH.
subclass_plot = ["STR FS PTHLH-PVALB GABA"]
_donor, _lab = "UWA7648", "ucsd"
_regions = ["CAH"]
_gene = "GLP1R"

# Title labels keep spaces; file-name labels swap them for underscores.
subc_title = ", ".join(subclass_plot)
subc_filename = "_".join(name.replace(" ", "_") for name in subclass_plot)
reg_title = ", ".join(_regions)
reg_filename = "_".join(name.replace(" ", "_") for name in _regions)

_panel_title = f"{subc_title}\n{reg_title} - {_gene}"
_file_stem = f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_rna'

plot_gene_groups(
    adata,
    gdf_geoms,
    gene=_gene,
    groups=subclass_plot,
    group_level="Group",
    regions=_regions,
    donors=[_donor],
    labs=[_lab],
    layers='volume_norm',
    layer_norm=False,
    min_norm=1,
    max_norm=99,
    color_wm="white",
    shade_regions=False,
    title=_panel_title,
    title_fontsize=18,
    image_path=image_path,
    image_name=_file_stem,
    rasterized=True,
    save_fig=False,
    show=True,
)

In [None]:
# Preview only (show=True, save_fig=False): GLP1R mCH in
# STR FS PTHLH-PVALB GABA cells for donor UWA7648, lab ucsd, region CAH.
# `adata` supplies the grey background.
subclass_plot = ["STR FS PTHLH-PVALB GABA"]
_donor, _lab = "UWA7648", "ucsd"
_regions = ["CAH"]
_gene = "GLP1R"

# Title labels keep spaces; file-name labels swap them for underscores.
subc_title = ", ".join(subclass_plot)
subc_filename = "_".join(name.replace(" ", "_") for name in subclass_plot)
reg_title = ", ".join(_regions)
reg_filename = "_".join(name.replace(" ", "_") for name in _regions)

_panel_title = f"{subc_title}\n{reg_title} - {_gene}"
_file_stem = f'ms_spatial_{subc_filename}_{_gene}_{_donor}_{_lab}_{reg_filename}_mch'

plot_gene_groups_meth(
    spatial_mch,
    gdf_geoms,
    gene=_gene,
    bg_adata=adata,
    groups=subclass_plot,
    group_level="Group",
    regions=_regions,
    donors=[_donor],
    labs=[_lab],
    layers=None,
    layer_norm=False,
    cmap="Blues_r",
    min_norm=1,
    max_norm=90,
    color_wm='white',
    shade_regions=False,
    shade_alpha=0.05,
    title=_panel_title,
    title_fontsize=18,
    image_path=image_path,
    image_name=_file_stem,
    rasterized=True,
    save_fig=False,
    show=True,
)