In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from scipy.stats import norm, spearmanr
from statsmodels.stats.multitest import multipletests

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Figure defaults for the whole notebook: 300 dpi on screen and on save, 4x4 in
# figures, TrueType (Type 42) font embedding in PDF/PS output, 8 pt Arial text,
# and white axes backgrounds.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
})

## Plotting Functions

In [None]:
# Downsample Function
def _downsample_reference(
    ref_adata : ad.AnnData,
    cluster_col : str,
    max_cluster_size: int = 3000,
    min_cluster_size: int = 0,
):
    """
    Balance a reference AnnData by cluster size.

    Both steps delegate to ``spida.utilities._ad_utils`` helpers:
      * when ``min_cluster_size > 0``, ``_remove_small_clusters`` is applied to
        drop clusters of ``cluster_col`` with fewer than ``min_cluster_size`` cells;
      * when ``max_cluster_size > 0``, ``_downsample_ref_clusters`` is applied to
        cap clusters at ``max_cluster_size`` cells.
    A threshold of 0 (or below) skips that step. Returns the resulting AnnData.
    """
    from spida.utilities._ad_utils import (
        _downsample_ref_clusters,
        _remove_small_clusters,
    )

    balanced = ref_adata
    if min_cluster_size > 0:
        balanced = _remove_small_clusters(balanced, cluster_col, min_cells=min_cluster_size)
    if max_cluster_size > 0:
        balanced = _downsample_ref_clusters(balanced, cluster_col, max_cells=max_cluster_size)
    return balanced

In [None]:
def expr_heatmap(
    adata,
    df_genes,
    n_genes=40,
    gene_names=10,
    heatmap_order_col='rho_axis',
    groupby_replicate=True,
    replicate_col="brain_region",
    replicate_palette=None,
    replicate_order=None,
    title="",
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    cmap="plasma",
    rasterized=True,
    show=True,
    save=False,
    image_path=None,
    min_max_quantiles=(0.02, 0.98),
    vmin=0, vmax=1,
    color_splits=None,
    figsize=(8, 4),
):
    """Plot a genes x cells expression heatmap with compartment annotations.

    Columns are cells from ``adata.obs``. When there are more than 50,000, a
    random sample of 50,000 is taken with ``random_state=42``. Cells are sorted
    by ``MS_NORM``; when ``groupby_replicate`` is True they are sorted within,
    and split by, ``replicate_col``.

    Rows are the first ``n_genes`` entries of ``df_genes`` (duplicate genes
    dropped), sorted ascending by ``heatmap_order_col``. The first
    ``gene_names`` of them get text labels. Each gene is rescaled so that its
    ``min_max_quantiles`` map to 0 and 1.

    Parameters
    ----------
    adata : AnnData
        Needs obs columns ``Group``, ``MS_NORM``, ``MS_compartment`` and
        ``replicate_col``. Its var names must include every plotted gene.
    df_genes : pandas.DataFrame
        Gene table with either a ``gene`` column or a gene index, plus
        ``heatmap_order_col``.
    n_genes, gene_names : int
        Number of genes to plot / to label.
    heatmap_order_col : str
        Column of ``df_genes`` used to order the rows (ascending).
    groupby_replicate : bool
        Sort and split the columns by ``replicate_col``.
    replicate_col, replicate_palette, replicate_order
        Obs column used for the replicate annotation, its colors, and the split
        order. Colors default to ``adata.uns[f"{replicate_col}_palette"]``.
    title : str
        Drawn above the heatmap and used in saved file names.
    ylabel, cmap, rasterized, vmin, vmax
        Passed to ``PyComplexHeatmap.ClusterMapPlotter``.
    show, save, image_path
        Show the figure. When ``save`` is set and ``image_path`` is given, also
        write ``expr_heatmap_<title>.png`` and ``.pdf`` into ``image_path``.
    min_max_quantiles : tuple of float
        Per-gene lower/upper quantiles mapped to 0 and 1.
    color_splits : color or None
        If given, draw a right spine of this color between adjacent column splits.
    figsize : tuple
        Figure size in inches (default matches the previous hard-coded value).
    out_path, xlabel
        Unused; kept for backward compatibility.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm

    if replicate_palette is None:
        replicate_palette = adata.uns.get(f"{replicate_col}_palette", None)

    # --- columns (cells) ---
    # Always carry replicate_col: the Brain_Region annotation below reads it
    # even when the columns are not split by replicate.
    # dict.fromkeys de-duplicates in case replicate_col is one of the other columns.
    obs_cols = list(dict.fromkeys(['Group', 'MS_NORM', 'MS_compartment', replicate_col]))
    df_col = adata.obs[obs_cols].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    df_col = df_col.sort_values([replicate_col, 'MS_NORM'] if groupby_replicate else 'MS_NORM')

    # --- rows (genes) ---
    # Check for a 'gene' column before de-duplicating on it, so gene-indexed
    # tables work as well.
    df_row = df_genes.iloc[:n_genes].copy()
    if "gene" in df_row.columns:
        df_row = df_row.drop_duplicates(subset='gene', keep='first').set_index('gene')
    else:
        df_row = df_row[~df_row.index.duplicated(keep='first')].copy()
    labelled = set(df_row.index[:gene_names])
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    # --- expression matrix ---
    # Subset to the plotted cells x genes *before* densifying, so a sparse X is
    # never expanded in full.
    sub = adata[df_col.index, df_row.index]
    X = sub.X.toarray() if hasattr(sub.X, "toarray") else np.asarray(sub.X)
    df_expr = pd.DataFrame(X, index=sub.obs_names, columns=sub.var_names).T
    df_expr = df_expr.loc[df_row.index, df_col.index]

    # Per-gene min-max scaling between two quantiles.
    # NOTE(review): a gene whose two quantiles coincide yields NaN/inf here.
    lo = df_expr.quantile(min_max_quantiles[0], axis=1)
    hi = df_expr.quantile(min_max_quantiles[1], axis=1)
    df_expr_norm = df_expr.subtract(lo, axis=0).div(hi - lo, axis=0)

    # --- annotations ---
    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)
    col_ha = pch.HeatmapAnnotation(
        Brain_Region=pch.anno_simple(df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False, height=3, text_kws={"fontsize":8, "fontweight":"bold"}),
        Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}, height=3),
        MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3),
        verbose=1, axis=1, plot_legend=True, legend_gap=5,
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=True,
            colors="black", relpos=(1, 0.5),
            arrowprops = dict(visible=True,)
        ),
        verbose=1, axis=0
    )

    plt.figure(figsize=figsize)
    cm = pch.ClusterMapPlotter(
        data=df_expr_norm,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=df_col[replicate_col] if groupby_replicate else None,
        row_dendrogram=False,
        label="Expression",
        cmap=cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel=title,
        xlabel_side="top",
        vmin=vmin, vmax=vmax,
        col_split_order=replicate_order
    )

    # Color the gap between adjacent column splits. Every column split except
    # the last gets a visible right spine.
    # Assumes heatmap_axes is (row splits x column splits), as joint_expr_heatmap
    # also does. Iterating rows[:-1] instead draws nothing when there is only one
    # row split.
    if color_splits is not None:
        split_axes = np.atleast_2d(np.asarray(cm.heatmap_axes, dtype=object))
        for row_axes in split_axes:
            for split_ax in row_axes[:-1]:
                spine = split_ax.spines["right"]
                spine.set_visible(True)
                spine.set_color(color_splits)
                spine.set_linewidth(2)

    if save and image_path is not None:
        print("saving")
        stem = os.path.join(image_path, f"expr_heatmap_{title.replace(' ', '_')}")
        for ext in ("png", "pdf"):
            plt.savefig(f"{stem}.{ext}", dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

In [None]:
def joint_expr_heatmap(
    adata_rna,
    adata_mch,
    df_genes,
    n_genes=40,
    gene_names=10,
    heatmap_order_col='rho_axis',
    groupby_replicate=True,
    replicate_col="brain_region",
    replicate_palette=None,
    replicate_order=None,
    title="",
    out_path=None,
    ylabel="Genes",
    xlabel="Cells",
    rna_cmap="Reds",
    mch_cmap="parula",
    rasterized=True,
    show=True,
    save=False,
    image_path=None,
    meth_min_max_quantiles=(0.02, 0.98),
    rna_min_max_quantiles=(0.02, 0.98),
    vmin=0, vmax=1,
    color_splits=None,
    figsize=(8,4)
):
    """Plot side-by-side RNA and mCH heatmaps for the same genes and cells.

    Columns are cells from ``adata_rna.obs``. When there are more than 50,000, a
    random sample of 50,000 is taken with ``random_state=42``. Cells are sorted
    by ``MS_NORM``, within ``replicate_col`` when ``groupby_replicate`` is True.
    The same cells are looked up in ``adata_mch``, so every plotted RNA cell
    must also be an obs name of ``adata_mch``.

    Rows are the first ``n_genes`` entries of ``df_genes`` (duplicate genes
    dropped), sorted by ``heatmap_order_col``. The first ``gene_names`` of them
    are labelled.

    Each gene is min-max rescaled separately within every ``replicate_col``
    group. ``rna_min_max_quantiles`` / ``meth_min_max_quantiles`` give the
    quantiles mapped to 0 and 1. The RNA and mCH heatmaps are combined with
    ``PyComplexHeatmap.composite``.

    Parameters
    ----------
    adata_rna : AnnData
        Needs obs columns ``Group``, ``MS_NORM``, ``MS_compartment`` and
        ``replicate_col``. Its var names must include every plotted gene.
    adata_mch : AnnData
        Only ``X``, obs names and var names are used.
    df_genes : pandas.DataFrame
        Gene table with either a ``gene`` column or a gene index, plus
        ``heatmap_order_col``.
    rna_cmap, mch_cmap : str
        Colormaps for the RNA and mCH panels.
    color_splits : color or None
        If given, draw a right spine of this color between adjacent column splits.
    figsize : tuple
        Figure size in inches.
    out_path, xlabel
        Unused; kept for backward compatibility.
    Remaining parameters are as in ``expr_heatmap``.

    Notes
    -----
    Files are saved as ``joint_expr_heatmap_<title>.{png,pdf}`` so they do not
    overwrite ``expr_heatmap`` output produced with the same title.
    """
    import PyComplexHeatmap as pch
    from matplotlib.colors import TwoSlopeNorm

    if replicate_palette is None:
        replicate_palette = adata_rna.uns.get(f"{replicate_col}_palette", None)

    # --- columns (cells) ---
    # replicate_col is always needed: the per-replicate scaling and the
    # Brain_Region annotation read it even when columns are not split.
    obs_cols = list(dict.fromkeys(['Group', 'MS_NORM', 'MS_compartment', replicate_col]))
    df_col = adata_rna.obs[obs_cols].copy()
    if df_col.shape[0] > 50000:
        df_col = df_col.sample(50000, random_state=42)
    df_col = df_col.sort_values([replicate_col, 'MS_NORM'] if groupby_replicate else 'MS_NORM')

    # --- rows (genes) ---
    # Check for a 'gene' column before de-duplicating on it.
    df_row = df_genes.iloc[:n_genes].copy()
    if "gene" in df_row.columns:
        df_row = df_row.drop_duplicates(subset='gene', keep='first').set_index('gene')
    else:
        df_row = df_row[~df_row.index.duplicated(keep='first')].copy()
    labelled = set(df_row.index[:gene_names])
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    df_row = df_row.sort_values(heatmap_order_col, ascending=True)

    def _genes_by_cells(adata_mod):
        """Dense genes x cells frame of adata_mod restricted to the plotted genes/cells."""
        # Subset before densifying so the full matrix is never materialized.
        sub = adata_mod[df_col.index, df_row.index]
        X = sub.X.toarray() if hasattr(sub.X, "toarray") else np.asarray(sub.X)
        frame = pd.DataFrame(X, index=sub.obs_names, columns=sub.var_names).T
        return frame.loc[df_row.index, df_col.index]

    def _scale_per_replicate(df_expr, quantiles):
        """Min-max scale each gene between two quantiles, separately per replicate."""
        q_lo, q_hi = quantiles
        parts = []
        for _, cells in df_col.groupby(replicate_col, observed=True):
            block = df_expr[cells.index]
            lo = block.quantile(q_lo, axis=1)
            hi = block.quantile(q_hi, axis=1)
            # NOTE(review): a gene whose two quantiles coincide yields NaN/inf here.
            parts.append(block.subtract(lo, axis=0).div(hi - lo, axis=0))
        # concat emits columns grouped by replicate. Restore df_col's order so the
        # matrix stays aligned with the column annotations when columns are not
        # grouped by replicate.
        return pd.concat(parts, axis=1).reindex(columns=df_col.index)

    df_expr_norm_rna = _scale_per_replicate(_genes_by_cells(adata_rna), rna_min_max_quantiles)
    df_expr_norm_mch = _scale_per_replicate(_genes_by_cells(adata_mch), meth_min_max_quantiles)

    ms_score_norm = TwoSlopeNorm(vmin=-1, vcenter=0, vmax=1)
    # RNA panel: annotation row labels hidden (label_kws), since the mCH panel
    # repeats the same annotations with labels.
    col_ha = pch.HeatmapAnnotation(
        Brain_Region=pch.anno_simple(df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False, height=3, text_kws={"fontsize":8, "fontweight":"bold"}),
        Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}, height=3),
        MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3),
        verbose=1, axis=1, plot_legend=True, legend_gap=5, label_kws={'visible':False}
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=True,
            colors="black", relpos=(1, 0.7), height=3,
            arrowprops = dict(visible=True,)
        ),
        verbose=1, axis=0
    )

    cm1 = pch.ClusterMapPlotter(
        data=df_expr_norm_rna,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=df_col[replicate_col] if groupby_replicate else None,
        col_split_order=replicate_order,
        row_dendrogram=False,
        label="Expression",
        cmap=rna_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="RNA",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    col_ha = pch.HeatmapAnnotation(
        Brain_Region=pch.anno_simple(df_col[replicate_col], colors=replicate_palette, add_text=True, legend=False, height=3, text_kws={"fontsize":8, "fontweight":"bold"}),
        Compartment=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}, height=3),
        MatStr_Score=pch.anno_simple(df_col['MS_NORM'], cmap="coolwarm_r", norm=ms_score_norm, height=3),
        verbose=1, axis=1, plot_legend=True, legend_gap=5,
    )

    cm2 = pch.ClusterMapPlotter(
        data=df_expr_norm_mch,
        top_annotation=col_ha,
        row_cluster=False,
        col_cluster=False,
        col_split=df_col[replicate_col] if groupby_replicate else None,
        col_split_order=replicate_order,
        row_dendrogram=False,
        col_dendrogram=False,
        label="Score",
        cmap=mch_cmap,
        rasterized=rasterized,
        ylabel=ylabel,
        xlabel="mCH",
        xlabel_side="top",
        plot=False,
        vmin=vmin, vmax=vmax,
    )

    plt.figure(figsize=figsize)
    main_ax, legend_axes = pch.composite(cmlist=[cm1, cm2], main=0, legend_hpad=3, col_gap=0.1)
    main_ax.set_title(title)

    # Color the gap between adjacent column splits. Every column split except
    # the last gets a visible right spine.
    if color_splits is not None:
        for cm in (cm1, cm2):
            split_axes = np.atleast_2d(np.asarray(cm.heatmap_axes, dtype=object))
            for row_axes in split_axes:
                for split_ax in row_axes[:-1]:
                    spine = split_ax.spines["right"]
                    spine.set_visible(True)
                    spine.set_color(color_splits)
                    spine.set_linewidth(2)

    if save and image_path is not None:
        print("saving")
        # Distinct prefix so the joint figure doesn't overwrite expr_heatmap output
        # saved under the same title.
        stem = os.path.join(image_path, f"joint_expr_heatmap_{title.replace(' ', '_')}")
        for ext in ("png", "pdf"):
            plt.savefig(f"{stem}.{ext}", dpi=300, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

### Colors

In [None]:
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap

n_bins = 256  # number of entries in each interpolated colormap

# Light grey -> purple ramp, registered as 'RNA'.
colors = ['#e3e3e3', '#d2c2d7', '#c1a1cb', '#b080bf', '#9f5fb3', '#8e3ea7', '#7e1e9c']
gene_cmap = LinearSegmentedColormap.from_list('RNA', colors, N=n_bins)

# Light grey -> dark blue ramp, registered as 'mCH'. These stops equal
# get_intermediate_hex_colors('#E3E3E3', '#00008b', 5), defined below.
colors = ['#e3e3e3', '#bdbdd4', '#9797c5', '#7171b7', '#4b4ba8', '#252599', '#00008b']
mch_cmap = LinearSegmentedColormap.from_list('mCH', colors, N=n_bins)

spatial_cmaps = {'RNA': gene_cmap, 'mCH': mch_cmap}

# Unregister before registering so re-running the cell does not trip over an
# already-registered name.
for name, cmap in spatial_cmaps.items():
    mpl.colormaps.unregister(name)
    mpl.colormaps.register(cmap=cmap, name=name)

In [None]:
def hex_to_rgb(hex_color):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints (0-255).

    Accepts ``'#RRGGBB'`` / ``'RRGGBB'`` and the CSS shorthand ``'#RGB'``, which
    is expanded digit-by-digit (``'#fff'`` -> ``'#ffffff'``). For longer strings
    (e.g. ``'#RRGGBBAA'``) only the first three channels are read, as before.
    Raises ValueError if a channel is not valid hex.
    """
    digits = hex_color.lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def rgb_to_hex(rgb_color):
    """Format the first three components of ``rgb_color`` as lowercase ``'#rrggbb'``.

    Each component is passed through ``int()`` (truncating floats toward zero)
    before being written as two hex digits.
    """
    red, green, blue = (int(rgb_color[idx]) for idx in range(3))
    return f'#{red:02x}{green:02x}{blue:02x}'

def get_intermediate_hex_colors(hex_color1, hex_color2, num_steps):
    """
    Linearly interpolate between two hex colors in RGB space.

    Args:
        hex_color1 (str): Start color, e.g. '#FF0000'.
        hex_color2 (str): End color, e.g. '#0000FF'.
        num_steps (int): How many intermediate colors to insert between the two.

    Returns:
        list: ``num_steps + 2`` hex strings, from the start color to the end
        color inclusive. Channels are truncated to ints by ``rgb_to_hex``.
    """
    start = hex_to_rgb(hex_color1)
    stop = hex_to_rgb(hex_color2)
    n_segments = num_steps + 1
    return [
        rgb_to_hex(tuple(a + (b - a) * (step / n_segments) for a, b in zip(start, stop)))
        for step in range(n_segments + 1)
    ]

def print_colored_text(text, hex_foreground_color=None, hex_background_color=None):
    """Print ``text`` with optional 24-bit ANSI foreground/background colors.

    Each color is given as a hex string. Falsy values (None, '') leave that
    layer unchanged. The output always ends with the ANSI reset sequence.
    """
    RESET = '\033[0m'
    prefix = ''
    # SGR 38 selects the foreground, 48 the background; ";2;r;g;b" is truecolor.
    for sgr_code, hex_color in ((38, hex_foreground_color), (48, hex_background_color)):
        if hex_color:
            r, g, b = hex_to_rgb(hex_color)
            prefix += f'\033[{sgr_code};2;{r};{g};{b}m'
    print(prefix + text + RESET)

# Example usage: light grey -> dark blue gradient. These are the stops used for
# the 'mCH' colormap above.
start_color = '#E3E3E3'
end_color = '#00008b'
steps = 5                # 5 intermediate colors -> 7 colors including both endpoints

gradient = get_intermediate_hex_colors(start_color, end_color, steps)
print(gradient)
for color in gradient:
    print_colored_text(f"{color}", hex_foreground_color=color)

## Load data

In [None]:
# Main CPS AnnData; its uns color palettes are reused for the mCH object below.
ad_path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/data", "BICAN_BG_CPS.h5ad")
adata = ad.read_h5ad(ad_path)
adata

In [None]:
# mCH AnnData. Copy over adata's color/palette entries and expose the cell
# centroids (CENTER_X, CENTER_Y) as obsm['spatial'].
spatial_mch = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/BG_mCH_Imp_SubR.h5ad")
shared_palettes = {key: val for key, val in adata.uns.items() if "colors" in key or "palette" in key}
spatial_mch.uns.update(shared_palettes)
spatial_mch.obsm['spatial'] = spatial_mch.obs[['CENTER_X', 'CENTER_Y']].to_numpy()
spatial_mch

## MorphoGAM

In [None]:
# Quick count of STR D2 MSN genes with meta_p_t < 1e-12.
# The read of `df` must be active here; with it commented out, `df` is
# undefined on a fresh kernel.
DIR = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/morphogam_dsid/")
fp = "morphogam_meta_STR D2 MSN.csv"
df = pd.read_csv(DIR / fp)
(df['meta_p_t'] < 1e-12).sum()

In [None]:
# Combined ("meta") MorphoGAM results for all cell types in one table.
DIR = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/morphogam_dsid/")
all_meta = "morphogam_meta_all_celltypes.csv"
df_meta = pd.read_csv(DIR.joinpath(all_meta))

In [None]:
# Meta results restricted to STR D2 MSN.
is_d2 = df_meta['celltype'].eq('STR D2 MSN')
df_tt = df_meta.loc[is_d2].copy()
df_tt.head()

In [None]:
# df_tt.sort_values('fdr_t').head(20)
# # df_tt[df_tt['tau2_t'] < 1.5].sort_values("fdr_t").head(50)

In [None]:
# genes = df_tt[df_tt['tau2_t'] < 2].sort_values("fdr_t").gene
# len(genes)

In [None]:
### Read Data:
path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/morphogam_gr_brs/results_all.csv"
morphogam_df = pd.read_csv(path)
morphogam_df.head()

In [None]:
# Keep only genes measured in both the RNA and the mCH objects.
common_genes = adata.var_names.intersection(spatial_mch.var_names)
morphogam_df = morphogam_df.loc[morphogam_df['gene'].isin(common_genes)].copy()

In [None]:
# morphogam_df['region'] = morphogam_df['replicate'].str.split('_').str[0]
# morphogam_df['donor'] = morphogam_df['replicate'].str.split('_').str[0]
# morphogam_df['lab'] = morphogam_df['replicate'].str.split('_').str[0]

In [None]:
# Per cell type x replicate: number of MorphoGAM genes below a cutoff on fdr_col.
# fdr_col is 'pv.t' here. It is presumably the unadjusted p-value (the adjusted
# 'fdr_t' column is used further down), so the axis label names the column
# rather than claiming FDR.
# fdr_col = "end_test_pvalue"
fdr_col = "pv.t"
fdr_thresh = 0.01
# Count with a boolean column + groupby.sum instead of groupby.apply, which is
# deprecated when it operates on the grouping columns.
sig_counts = (
    morphogam_df
    .assign(_is_sig=morphogam_df[fdr_col] < fdr_thresh)
    .groupby(["celltype", 'replicate'], observed=True)["_is_sig"]
    .sum()
    .reset_index(name="n_sig")
)

fig, ax = plt.subplots(figsize=(6,4))
sns.barplot(data=sig_counts, x="celltype", y="n_sig", hue="replicate", palette="muted", ax=ax)
ax.set_ylabel(f"# Significant Genes ({fdr_col}<{fdr_thresh})")
ax.set_xlabel("Cell Type")
ax.set_title("Significant Spatially Variable Genes per Cell Type")
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# Copy region and matrix/striosome annotations from the RNA object onto the
# mCH cells, matched by cell name (KeyError if an mCH cell is absent from adata).
transfer_cols = ['brain_region_corr', 'MS_NORM', 'MS_SCORE', 'MS_compartment']
mch_cells = spatial_mch.obs_names
for col in transfer_cols:
    spatial_mch.obs[col] = adata.obs.loc[mch_cells, col].copy()

In [None]:
# Brain regions shown in the heatmaps, in display order, and the figure output folder.
repr_order = ["CaH", "CaB", "Pu", "NAC"]
image_path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS", "ms_morphogam")

In [None]:
# Alias 'pv.t' under a name that is a valid Python identifier; unlike 'pv.t', it
# can be referenced unquoted in DataFrame.query / eval expressions.
morphogam_df["pv"] = morphogam_df.loc[:, "pv.t"]

In [None]:
# Quick look at MorphoGAM hits (fdr_col < 1e-2) for one cell type.
# _ct must be defined here: its first assignment is in a later cell, so this
# cell fails on a fresh kernel otherwise.
_ct = "STR FS PTHLH-PVALB GABA"
# fdr_col is 'pv.t' at this point. Backticks make query() treat it as a column
# name; unquoted, 'pv.t' parses as attribute access.
df_sub = (
    morphogam_df[morphogam_df['celltype'] == _ct]
    .query(f"`{fdr_col}` < 1e-2 & `{fdr_col}` >= 0")
)

In [None]:
# Display the MorphoGAM hits for _ct selected in the previous cell.
df_sub

In [None]:
# Group-level composition of the CN ST18 GABA subclass (counted on the AnnData
# subset, as before).
st18_mask = adata.obs['Subclass'] == "CN ST18 GABA"
adata[st18_mask].obs['Group'].value_counts()

In [None]:
# Cell types available in the MorphoGAM results.
morphogam_df.celltype.unique()

In [None]:
_ct = "STR FS PTHLH-PVALB GABA"
level="Group"
# _ct = "CN ST18 GABA"
# level = "Subclass"
fdr_col = "fdr_t"
print(_ct)
# Strong MorphoGAM hits for _ct: median fdr_col/rho per (gene, replicate), then
# keep each gene's best (lowest fdr_col) replicate.
df_sub = (
    morphogam_df[morphogam_df['celltype'] == _ct]
    .query(f"{fdr_col} < 1e-12 & {fdr_col} >= 0")
    .groupby(["gene", "replicate"])[[fdr_col, 'rho']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
ns = len(df_sub['gene'].unique())
print(ns)

# RNA cells of _ct with an MS_NORM score in the selected regions, capped per compartment.
adata_ct = adata[(adata.obs[level] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=2000, min_cluster_size=0)
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
# expr_heatmap(adata_ct, df_sub, n_genes=ns, gene_names=ns, title=f"Top SPARK-X genes in {_ct}",
#              rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col='rho',
#              replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
#              cmap='Reds', vmin=0, vmax=1, min_max_quantiles=(0.2, 0.98))

# `genes` must be defined here: its only other definition is in a commented-out
# cell above, so this cell would fail on a fresh kernel. Same selection as that cell.
# NOTE(review): these genes come from the STR D2 MSN meta table (df_tt), not from _ct.
genes = df_tt[df_tt['tau2_t'] < 2].sort_values("fdr_t").gene
mch_genes = set(spatial_mch.var_names)
genes = list(dict.fromkeys(g for g in genes if g in mch_genes))  # dedupe, deterministic order
# Take gene and rho from the same df_tt rows. Building the frame from a list plus
# a separately ordered Series pairs them positionally, which misaligns genes
# and rho values.
df_m = (
    df_tt.loc[df_tt['gene'].isin(genes), ['gene', 'meta_rho']]
    .rename(columns={'meta_rho': 'rho'})
)
ns = len(df_m)
expr_heatmap(adata_ct, df_m, n_genes=ns, gene_names=ns, title=f"Top SPARK-X genes in {_ct}",
             rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col='rho',
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             cmap='Reds', vmin=0, vmax=1, min_max_quantiles=(0.2, 0.98))

In [None]:
# Same MorphoGAM gene selection, drawn on the mCH object instead of RNA.
_ct = "STR FS PTHLH-PVALB GABA"
level = "Group"
# _ct = "CN ST18 GABA"
# level = "Subclass"
fdr_col = "fdr_t"
print(_ct)

ct_rows = morphogam_df[morphogam_df['celltype'] == _ct]
hits = ct_rows.query(f"{fdr_col} < 1e-12 & {fdr_col} >= 0")
per_rep = hits.groupby(["gene", "replicate"])[[fdr_col, 'rho']].median()
# Each gene's best (lowest fdr_col) replicate.
df_sub = (
    per_rep.sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
ns = df_sub['gene'].nunique(dropna=False)
print(ns)

keep_cells = (spatial_mch.obs[level] == _ct) & spatial_mch.obs['MS_NORM'].notna()
adata_ct = spatial_mch[keep_cells].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct.obs['brain_region_corr'] = pd.Categorical(
    adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True
)
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
present_regions = adata_ct.obs['brain_region_corr'].unique()
rep_palette = {
    region: color
    for region, color in adata_ct.uns.get("brain_region_corr_palette", {}).items()
    if region in present_regions
}
expr_heatmap(
    adata_ct, df_sub, n_genes=ns, gene_names=ns, title=f"Top SPARK-X genes in {_ct}",
    rasterized=True, show=True, save=False, ylabel=None, image_path=image_path, heatmap_order_col='rho',
    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
    cmap='parula', vmin=0, vmax=1, min_max_quantiles=(0.0, 0.98),
)

In [None]:
# Cells present in both the RNA and the mCH objects, as standalone copies of each.
joint_cells = adata.obs_names.intersection(spatial_mch.obs_names)
adata_joint_rna, adata_joint_mch = adata[joint_cells].copy(), spatial_mch[joint_cells].copy()
# adata_joint_rna.X = adata_joint_rna.layers['counts'].copy()

In [None]:
# Joint RNA + mCH heatmap of MorphoGAM hits for one cell type, on cells measured in both.
_ct = "STR FS PTHLH-PVALB GABA"
level="Group"
# _ct = "STR D2 MSN"
# level = "Subclass"
fdr_col = "fdr_t"
print(_ct)
df_sub = (
    morphogam_df[morphogam_df['celltype'] == _ct]
    .query(f"{fdr_col} < 1e-3 & {fdr_col} >= 0")
    .groupby(["gene", "replicate"])[[fdr_col, 'rho']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
ns = len(df_sub['gene'].unique())
print(ns)

adata_rna = adata_joint_rna[(adata_joint_rna.obs[level] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

# joint_expr_heatmap looks up the RNA cells in the mCH matrix, so take exactly
# those cells. Filtering spatial_mch's own `level`/MS_NORM columns separately
# may not yield the same cells, and it copies cells that are never plotted.
adata_mch = adata_joint_mch[adata_rna.obs_names].copy()
adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=ns, gene_names=10, title=f"Top SPARK-X genes in {_ct}",
             rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col='rho',
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
             figsize=(12,6))

# genes = list(set(genes).intersection(spatial_mch.var_names))
# ns = len(genes)
# df_m = pd.DataFrame({'gene': genes, 'rho': df_tt.loc[df_tt['gene'].isin(genes), 'meta_rho']})
# joint_expr_heatmap(adata_rna, adata_mch, df_m, n_genes=ns, gene_names=10, title=f"Top SPARK-X genes in {_ct}",
#              rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col='rho',
#              replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
#              rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
#              figsize=(12,6))

## SPARK-X

In [None]:
### Read Data:
# SPARKX -- alternative runs kept for reference:
# path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_dsid/sparkx_per_replicate_results.csv"
# path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_donor/sparkx_per_replicate_results.csv"
# path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_brg/sparkx_per_replicate_results.csv"
path = os.path.join(
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS", "sparkx_brs", "sparkx_per_replicate_results.csv"
)
sparkx_df = pd.read_csv(path)
sparkx_df.head()

In [None]:
level = "Subclass"
# The meta results live in sparkx_br<first letter of level>, e.g. sparkx_brs for Subclass.
_let = level.lower()[0]
meta_path = f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sparkx_br{_let}/sparkx_meta_results.csv"
df_meta = pd.read_csv(meta_path)

# Restrict both the meta and the per-replicate tables to genes present in both modalities.
common_genes = adata.var_names.intersection(spatial_mch.var_names)
df_meta = df_meta.loc[df_meta['gene'].isin(common_genes)].copy()
sparkx_df = sparkx_df.loc[sparkx_df['gene'].isin(common_genes)].copy()

In [None]:
# SPARK-X meta results: genes below the FDR cutoff per cell type, split by
# direction ('down' -> Matrix, 'up' -> Striosome).
fdr_col = "fdr"
fdr_thresh = 1e-5
# Boolean column + groupby.sum instead of the deprecated groupby.apply over
# grouping columns.
sig_counts = (
    df_meta
    .assign(_is_sig=df_meta[fdr_col] < fdr_thresh)
    .groupby(["cell_type", "direction"], observed=True)["_is_sig"]
    .sum()
    .reset_index(name="n_sig")
)
sig_counts['direction'] = sig_counts['direction'].map({'down': 'Matrix', 'up': 'Striosome'})

fig, ax = plt.subplots(figsize=(8,3))
sns.barplot(data=sig_counts, x="cell_type", y="n_sig", hue="direction", ax=ax, palette = {"Matrix": "blue", "Striosome": "red"}, rasterized=True)
ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=True)
ax.set_xlabel("Cell Type", rasterized=True)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=True)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", rasterized=True)

# Grey out cell types with 5 or fewer significant genes (both directions combined).
totals = sig_counts.groupby('cell_type', observed=True)['n_sig'].sum()
color_ct = totals.index[totals > 5].tolist()

for ticklabel in ax.get_xticklabels():
    ticklabel.set_color('black' if ticklabel.get_text() in color_ct else 'lightgray')

plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Cell types with more than one gene below the FDR cutoff (fdr_col/fdr_thresh
# from the previous cell). Bars are colored by the `level` palette; tick labels
# are colored by neighborhood.
sig_counts = (
    df_meta
    .assign(_is_sig=df_meta[fdr_col] < fdr_thresh)
    .groupby("cell_type", observed=True)["_is_sig"]
    .sum()
    .reset_index(name="n_sig")
    .query("n_sig > 1")
    .sort_values('n_sig', ascending=True)
)
# Cell types missing from the palette would map to NaN, which matplotlib rejects
# as a bar color; use the same grey fallback as the tick labels below.
sig_counts['color'] = sig_counts['cell_type'].map(adata.uns[f"{level}_palette"]).fillna('#808080')

fig, ax = plt.subplots(figsize=(8,3))
ax.barh(data=sig_counts, y="cell_type", width="n_sig", color=sig_counts['color'], rasterized=True)
ax.set_ylabel("Cell Type", rasterized=True)
ax.set_xlabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=True)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=True)

# Color each tick label by its cell type's neighborhood (last mapping wins if a
# cell type appears in several neighborhoods).
group_to_nt = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level).to_dict()['Neighborhood']
tick_to_nn_pal = {tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080') for tick, nt in group_to_nt.items()}
for ticklabel in ax.get_yticklabels():
    ticklabel.set_color(tick_to_nn_pal.get(ticklabel.get_text(), '#808080'))

plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Region-level SPARK-X results: drop the "CAT" replicate.
# (For dataset-level results, first map `replicate` -> `brain_region` through adata.obs
#  and filter on `brain_region` instead.)
sparkx_df = sparkx_df.loc[sparkx_df['replicate'].ne("CAT")].copy()

# Sign of the compartment-axis correlation decides the enriched compartment:
# rho_axis > 0 -> Striosome, anything else -> Matrix.
sparkx_df['direction'] = sparkx_df['rho_axis'].gt(0).map({True: 'Striosome', False: 'Matrix'})
sparkx_df.head()

In [None]:
fdr_col = "p_sparkx"
fdr_thresh = 1e-2
# Number of SPARK-X genes below the threshold per (cell type, replicate, direction).
# Summing the boolean mask avoids DataFrameGroupBy.apply; the result is already a flat
# frame, so no second reset_index (which only added a spurious 'index' column).
sig_counts = (
    (sparkx_df[fdr_col] < fdr_thresh)
    .groupby(
        [sparkx_df["cell_type"], sparkx_df["replicate"], sparkx_df["direction"]],
        observed=True,
    )
    .sum()
    .reset_index(name="n_sig")
)
sig_counts.head()

In [None]:
fig, ax = plt.subplots(figsize=(8,3))
# One bar per (cell type, direction); bar height is seaborn's default estimator over replicates.
sns.barplot(
    data=sig_counts,
    x="cell_type",
    y="n_sig",
    hue="direction",
    palette={"Matrix": "blue", "Striosome": "red"},
    errorbar=None,
    err_kws={"color": "black", "linewidth": 1},
    capsize=0.2,
    ax=ax,
    rasterized=True,
)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=True)
ax.set_xlabel("Cell Type", rasterized=True)
ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=True)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", rasterized=True)

# Cell types with more than 5 significant genes in total keep black labels; others are greyed out.
color_ct = []
for _ct in sig_counts['cell_type'].unique():
    if sig_counts.loc[sig_counts['cell_type'] == _ct, 'n_sig'].sum() > 5:
        color_ct.append(_ct)

for ticklabel in ax.get_xticklabels():
    ticklabel.set_color('black' if ticklabel.get_text() in color_ct else 'lightgray')

plt.tight_layout()
plt.show()
plt.close()

In [None]:
# Mean (and std) number of significant SPARK-X genes per cell type across replicates,
# keeping cell types that average more than one hit; ascending so the largest bar is on top.
level = "Subclass"
fdr_col = "p_sparkx"
fdr_thresh = 1e-5
# Count hits per (cell type, replicate) by summing the boolean mask (no DataFrameGroupBy.apply),
# then summarise those counts across replicates into flat 'mean' / 'std' columns.
sig_counts = (
    (sparkx_df[fdr_col] < fdr_thresh)
    .groupby([sparkx_df["cell_type"], sparkx_df["replicate"]], observed=True)
    .sum()
    .groupby(level="cell_type", observed=True)
    .agg(["mean", "std"])
    .reset_index()
    .query("mean > 1")
    .sort_values('mean', ascending=True)
)
sig_counts['color'] = sig_counts['cell_type'].map(adata.uns[f"{level}_palette"])

fig, ax = plt.subplots(figsize=(8,3))
ax.barh(
    data=sig_counts, y="cell_type", width="mean", color=sig_counts['color'],
    # xerr=sig_counts['std'], ecolor='black', capsize=4,  # optional: replicate spread
    rasterized=True
)
ax.set_ylabel("Cell Type", rasterized=True)
ax.set_xlabel(f"# Significant Genes (FDR<{fdr_thresh})", rasterized=True)
ax.set_title("Significant Compartment Enriched Genes per Cell Type", rasterized=True)

# Color each y tick label by the neighborhood of its cell type (grey when unknown).
# group_to_nt is reused by the heatmap loops further down.
group_to_nt = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level)['Neighborhood'].to_dict()
tick_to_nn_pal = {tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080') for tick, nt in group_to_nt.items()}
for ticklabel in ax.get_yticklabels():
    ticklabel.set_color(tick_to_nn_pal.get(ticklabel.get_text(), '#808080'))

plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Copy region / MS-axis annotations from the RNA object onto the mCH object, looked up
# by cell name (every spatial_mch cell must be present in adata.obs, otherwise .loc raises).
transfer_cols = ['brain_region_corr', 'MS_NORM', 'MS_SCORE', 'MS_compartment']
for col in transfer_cols:
    spatial_mch.obs[col] = adata.obs[col].loc[spatial_mch.obs_names].copy()

In [None]:
# Cells present in both the RNA and mCH objects, and per-modality copies restricted to them.
joint_cells = adata.obs_names.intersection(spatial_mch.obs_names)
adata_joint_rna, adata_joint_mch = adata[joint_cells].copy(), spatial_mch[joint_cells].copy()

In [None]:
# Brain regions to keep, in display order, and the output folder for the SPARK-X heatmaps.
repr_order = "CaH CaB Pu NAC".split()
image_path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS", "ms_sparkx")

In [None]:
# Scratch version of the per-cell-type gene selection used by the heatmap loop below.
# NOTE(review): `_ct`, `fdr_col` and `fdr_thresh` are not set in this cell; `_ct` is whatever
# an earlier loop left in the kernel, so this cell is not reproducible on a fresh kernel.
df_sub = sparkx_df.loc[sparkx_df['cell_type'] == _ct]
df_sub = df_sub.query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0")
df_sub = (
    df_sub.groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
)
# Keep genes significant in at least two replicates, then one row (best p-value) per gene.
gene_vc = df_sub['gene'].value_counts()
keep_genes = gene_vc.index[gene_vc >= 2]
df_sub = df_sub.loc[df_sub['gene'].isin(keep_genes)].drop_duplicates(subset='gene', keep='first').copy()

In [None]:
# p-value cutoff applied to fdr_col by the per-cell-type heatmap loop that follows.
fdr_thresh = 1e-5
fdr_thresh

In [None]:
# Per-cell-type SPARK-X heatmaps, saved to image_path. For each cell type in sig_counts:
#   1. take genes with fdr_col < fdr_thresh, median stats per (gene, replicate), best p first;
#   2. require a gene to pass in >= 1 / 2 / 3 replicates depending on how many genes passed;
#   3. plot RNA only for 'Nonneuron' / 'unknown' neighborhoods, joint RNA + mCH otherwise.
topn = 10
for _ct in sig_counts['cell_type'].unique():
    df_sub = (
        sparkx_df[sparkx_df['cell_type'] == _ct]
        .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0")
        .groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
        .median()
        .sort_values(fdr_col, ascending=True)
        .reset_index()
    )
    # Each row is a unique (gene, replicate), so this counts passing replicates per gene.
    gene_vc = df_sub['gene'].value_counts()
    ns = len(df_sub['gene'].unique())
    # Demand more replicate support when a cell type has more candidate genes.
    if ns < 50:
        gvc_thr = 1
    elif ns < 100:
        gvc_thr = 2
    else:
        gvc_thr = 3
    keep_genes = gene_vc[gene_vc >= gvc_thr].index
    df_sub = df_sub[df_sub['gene'].isin(keep_genes)].copy()
    # Rows are sorted by p-value, so 'first' keeps each gene's best replicate.
    df_sub = df_sub.drop_duplicates(subset='gene', keep='first')

    nt = group_to_nt.get(_ct)
    ns = len(df_sub['gene'].unique())
    print(_ct, ns, nt)
    if ns == 0:
        continue

    if nt in ['Nonneuron', "unknown"]:
        # RNA only.
        adata_ct = adata[(adata.obs[level] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
        adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
        adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
        rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
        expr_heatmap(adata_ct, df_sub, n_genes=ns, gene_names=topn, ylabel=None, color_splits="white",
                    rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
                    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                    cmap='agSunset', vmin=0, vmax=1, min_max_quantiles=(0.25, 0.99))
    else:
        # Joint RNA + mCH, restricted to cells profiled in both modalities.
        adata_rna = adata_joint_rna[(adata_joint_rna.obs[level] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

        adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
        adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

        # The mCH side is not downsampled; it keeps every joint cell of this type in the plotted regions.
        adata_mch = adata_joint_mch[(adata_joint_mch.obs[level] == _ct) & (~adata_joint_mch.obs['MS_NORM'].isna())].copy()
        adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

        rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
        joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=ns, gene_names=topn, ylabel=None, color_splits="white",
                    rasterized=True, show=False, save=True, image_path=image_path, title=f"Top SPARK-X genes in {_ct}",
                    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                    rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
                    figsize=(12,4))
    

In [None]:
# fdr_thresh=1e-2

In [None]:
# Heatmaps of the df_meta significant genes per cell type, rows ordered by 'direction'.
# RNA-only figures are shown inline; joint RNA + mCH figures are shown and saved to image_path.
# NOTE(review): this iterates over whatever `sig_counts` is in the kernel (the SPARK-X summary
# when run top to bottom), not the df_meta-based counts computed in a later cell — confirm intent.
topn = 10
for _ct in sig_counts['cell_type'].unique():
    df_sub = (
        df_meta[df_meta['cell_type'] == _ct]
        .query(f"{fdr_col} < {fdr_thresh} & {fdr_col} >= 0")
        .sort_values(fdr_col, ascending=True)
    )
    nt = group_to_nt.get(_ct)
    ns = len(df_sub['gene'].unique())
    print(_ct, df_sub.shape[0], nt)
    if ns == 0:
        # No significant genes for this cell type: nothing to plot.
        continue

    if nt in ['Nonneuron', "unknown"]:
        # RNA only.
        adata_ct = adata[(adata.obs[level] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
        adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
        adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
        rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
        expr_heatmap(adata_ct, df_sub, n_genes=ns, gene_names=topn, title=f"Top SPARK-X genes in {_ct}",
                    rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col='direction',
                    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                    cmap='agSunset', vmin=0, vmax=1, min_max_quantiles=(0.25, 0.99))
    else:
        # Joint RNA + mCH, restricted to cells profiled in both modalities.
        adata_rna = adata_joint_rna[(adata_joint_rna.obs[level] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

        adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
        adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

        # The mCH side is not downsampled.
        adata_mch = adata_joint_mch[(adata_joint_mch.obs[level] == _ct) & (~adata_joint_mch.obs['MS_NORM'].isna())].copy()
        adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
        adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

        rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
        # save=True needs an output folder, as in the SPARK-X loop above.
        joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=ns, gene_names=topn, title=f"Top SPARK-X genes in {_ct}",
                    rasterized=True, show=True, save=True, image_path=image_path, ylabel=None, color_splits="black", heatmap_order_col='direction',
                    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
                    rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
                    figsize=(12,4))
    

In [None]:
# Number of df_meta genes below the threshold per cell type (cell types with > 1 hit,
# ascending), plus each cell type's palette color. Summing the boolean mask replaces
# DataFrameGroupBy.apply, which pandas deprecates when it operates on the grouping columns.
sig_counts = (
    (df_meta[fdr_col] < fdr_thresh)
    .groupby(df_meta["cell_type"], observed=True)
    .sum()
    .reset_index(name="n_sig")
    .query("n_sig > 1")
    .sort_values('n_sig', ascending=True)
)
sig_counts['color'] = sig_counts['cell_type'].map(adata.uns[f"{level}_palette"])

In [None]:
# Only the very strongest SPARK-X hits for D1 MSNs (p < 1e-100), shown on RNA expression.
_ct = "STR D1 MSN"
fdr_col = "p_sparkx"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    sparkx_df.loc[sparkx_df['cell_type'].eq(_ct)]
    .query(f"{fdr_col} < 1e-100 & {fdr_col} >= 0")
    .groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col)
    .reset_index()
    .drop_duplicates(subset='gene')
)
ns = df_sub['gene'].nunique()
print(ns)

# RNA cells of this subclass with an MS-axis value, in the plotted regions, capped per compartment.
adata_ct = adata[adata.obs['Subclass'].eq(_ct) & adata.obs['MS_NORM'].notna()].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
rep_palette = {
    region: color
    for region, color in adata_ct.uns.get("brain_region_corr_palette", {}).items()
    if region in adata_ct.obs['brain_region_corr'].unique()
}
expr_heatmap(
    adata_ct, df_sub, n_genes=ns, gene_names=10, title=f"Top SPARK-X genes in {_ct}",
    rasterized=True, show=True, save=False, ylabel=None, color_splits="black",
    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
    cmap='Reds', vmin=0, vmax=1, min_max_quantiles=(0.2, 0.98),
)

In [None]:
# D1 MSN SPARK-X genes (p < 0.05) shown on the mCH object; figure saved to image_path.
_ct = "STR D1 MSN"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    sparkx_df.loc[sparkx_df['cell_type'].eq(_ct)]
    .query("p_sparkx < 0.05")
    .groupby(["gene", "replicate"])[['p_sparkx', 'rho_axis']]
    .median()
    .sort_values('p_sparkx')
    .reset_index()
    .drop_duplicates(subset='gene')
)
# mCH cells of this subclass with an MS-axis value, in the plotted regions.
adata_ct = spatial_mch[spatial_mch.obs['Subclass'].eq(_ct) & spatial_mch.obs['MS_NORM'].notna()].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
rep_palette = {
    region: color
    for region, color in adata_ct.uns.get("brain_region_corr_palette", {}).items()
    if region in adata_ct.obs['brain_region_corr'].unique()
}
expr_heatmap(
    adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {_ct}",
    rasterized=True, show=True, save=True, ylabel=None, image_path=image_path,
    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
    cmap='parula', vmin=0, vmax=1, min_max_quantiles=(0.0, 0.98),
)

In [None]:
# D1 MSN SPARK-X genes (p < 0.05) shown on RNA expression (displayed, not saved).
_ct = "STR D1 MSN"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    sparkx_df.loc[sparkx_df['cell_type'].eq(_ct)]
    .query("p_sparkx < 0.05")
    .groupby(["gene", "replicate"])[['p_sparkx', 'rho_axis']]
    .median()
    .sort_values('p_sparkx')
    .reset_index()
    .drop_duplicates(subset='gene')
)
# RNA cells of this subclass with an MS-axis value, in the plotted regions, capped per compartment.
adata_ct = adata[adata.obs['Subclass'].eq(_ct) & adata.obs['MS_NORM'].notna()].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=6000, min_cluster_size=0)
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
rep_palette = {
    region: color
    for region, color in adata_ct.uns.get("brain_region_corr_palette", {}).items()
    if region in adata_ct.obs['brain_region_corr'].unique()
}
expr_heatmap(
    adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SPARK-X genes in {_ct}",
    rasterized=True, show=True, save=False, ylabel=None, color_splits="black",
    replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
    cmap='Reds', vmin=0, vmax=1, min_max_quantiles=(0.2, 0.98),
)

In [None]:
# Rebuild the joint RNA / mCH objects from the cells shared by both modalities.
joint_cells = adata.obs_names.intersection(spatial_mch.obs_names)
adata_joint_rna, adata_joint_mch = adata[joint_cells].copy(), spatial_mch[joint_cells].copy()

In [None]:
# Replace X of the joint RNA object with a copy of its 'counts' layer (overwrites X);
# the cells below build heatmaps and matrices from adata_joint_rna-derived objects' X.
adata_joint_rna.X = adata_joint_rna.layers['counts'].copy()

In [None]:
# Joint RNA + mCH heatmap of SPARK-X genes (p < 1e-2) for a single cell type.
# Plots the top `topn` genes and labels up to `ns` of them (i.e. every plotted gene).
topn = 50
_ct = "STR FS PTHLH-PVALB GABA"
level = "Group"

# _ct = "STR D2 MSN"
# level = "Subclass"

fdr_col = "p_sparkx"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    sparkx_df[sparkx_df['cell_type'] == _ct]
    .query(f"{fdr_col} < 1e-2 & {fdr_col} >= 0")
    .groupby(["gene", "replicate"])[[fdr_col, 'rho_axis']]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
ns = len(df_sub['gene'].unique())
print(ns)
adata_rna = adata_joint_rna[(adata_joint_rna.obs[level] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=4000, min_cluster_size=0)
adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

# The mCH side is not downsampled.
adata_mch = adata_joint_mch[(adata_joint_mch.obs[level] == _ct) & (~adata_joint_mch.obs['MS_NORM'].isna())].copy()
adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
# save=True: pass image_path so the figure goes to the same folder as the other SPARK-X heatmaps.
joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=topn, gene_names=ns, title=f"Top SPARK-X genes in {_ct}",
             rasterized=True, show=True, save=True, image_path=image_path, ylabel=None, color_splits="black",
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             rna_cmap='agSunset', vmin=0, vmax=1, meth_min_max_quantiles=(0.1, 0.9), rna_min_max_quantiles=(0.25, 0.99),
             figsize=(12,4))

In [None]:
# Inspect the selected row(s) for STXBP6.
df_sub.query("gene == 'STXBP6'")

In [None]:
# Are these genes present in the RNA object's var_names?
tuple(gene in adata.var_names for gene in ('GLI3', 'KHDRBS3', 'GSA'))

In [None]:
groupby_replicate = True
replicate_col = "brain_region_corr"

# Column (cell) annotations for the manual heatmap: at most 50,000 cells (fixed seed),
# ordered by MS_NORM, within each replicate when groupby_replicate is set.
df_col = adata_rna.obs[['Group', 'MS_NORM', 'MS_compartment']].copy()
if len(df_col) > 50000:
    df_col = df_col.sample(50000, random_state=42)
if not groupby_replicate:
    df_col = df_col.sort_values('MS_NORM')
else:
    df_col[replicate_col] = adata_rna.obs[replicate_col]
    df_col = df_col.sort_values([replicate_col, 'MS_NORM'])

In [None]:
# Show the ordered column (cell) annotation frame.
df_col

In [None]:
# Row (gene) annotations for the manual heatmap: the first `n_genes` rows of the gene table,
# one row per gene, only the first `gene_names` genes labelled ('annot'), ordered by
# `heatmap_order_col`.
df_genes = df_sub.copy()
n_genes = 40
gene_names = 10
heatmap_order_col = 'rho_axis'
df_row = df_genes.iloc[:n_genes].copy()
# Index by gene first, then de-duplicate on the index: this works whether 'gene' is a
# column or already the index (drop_duplicates(subset='gene') raises KeyError in the latter case).
if "gene" in df_row.columns:
    df_row = df_row.set_index('gene')
df_row = df_row.loc[~df_row.index.duplicated(keep='first')].copy()
toplot = df_row.index[:gene_names]
df_row['annot'] = [g if g in toplot else np.nan for g in df_row.index]
df_row = df_row.sort_values(heatmap_order_col, ascending=True)

In [None]:
# Preview the first rows of the gene (row) annotation frame.
df_row.head()

In [None]:
# Min-max scale each gene's RNA expression between its low / high quantiles across the
# selected cells. Rows are the genes of df_row (in order); columns are the cells of df_col.
min_max_quantiles=(0.02, 0.98)
# Subset to the plotted cells/genes before densifying, so only a len(df_col) x len(df_row)
# matrix is materialised instead of every cell x every gene.
df_expr_rna = adata_rna[df_col.index, df_row.index].X
if hasattr(df_expr_rna, "toarray"):
    df_expr_rna = df_expr_rna.toarray()
df_expr_rna = pd.DataFrame(df_expr_rna, index=df_col.index, columns=df_row.index).T
q_lo = df_expr_rna.quantile(min_max_quantiles[0], axis=1)
q_hi = df_expr_rna.quantile(min_max_quantiles[1], axis=1)
# Genes whose quantile range is zero would divide by zero (inf / NaN); leave them as NaN.
df_expr_norm_rna = df_expr_rna.sub(q_lo, axis=0).div((q_hi - q_lo).replace(0, np.nan), axis=0)

## tradeSeq

In [None]:
# Load the combined tradeSeq results table.
path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/tradeseq_brs", "results_all.csv")
df_tradeseq = pd.read_csv(path)
# To drop the "CaT" replicate: df_tradeseq = df_tradeseq[df_tradeseq['replicate'] != "CaT"].copy()
df_tradeseq.head()

In [None]:
# Keep only tradeSeq rows for genes present in both the RNA and the mCH objects.
common_genes = adata.var_names.intersection(spatial_mch.var_names)
df_tradeseq = df_tradeseq.loc[df_tradeseq['gene'].isin(common_genes)].copy()

In [None]:
# df_tradeseq.shape

In [None]:
# Number of tradeSeq genes below the threshold per (cell type, replicate).
# fdr_col = "end_test_pvalue"   # alternative: end-point test
fdr_col = "association_pvalue"
fdr_thresh = 0.01
# Summing the boolean mask per group replaces DataFrameGroupBy.apply (deprecated when it
# operates on the grouping columns).
sig_counts = (
    (df_tradeseq[fdr_col] < fdr_thresh)
    .groupby([df_tradeseq["celltype"], df_tradeseq["replicate"]], observed=True)
    .sum()
    .reset_index(name="n_sig")
)

fig, ax = plt.subplots(figsize=(6,4))
sns.barplot(data=sig_counts, x="celltype", y="n_sig", hue="replicate", palette="muted", ax=ax)
ax.set_ylabel(f"# Significant Genes (FDR<{fdr_thresh})")
ax.set_xlabel("Cell Type")
ax.set_title("Significant Spatially Variable Genes per Cell Type")
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
plt.show()
plt.close(fig)

In [None]:
# Copy region / MS-axis annotations from the RNA object onto the mCH object, looked up
# by cell name (every spatial_mch cell must be present in adata.obs, otherwise .loc raises).
transfer_cols = ['brain_region_corr', 'MS_NORM', 'MS_SCORE', 'MS_compartment']
for col in transfer_cols:
    spatial_mch.obs[col] = adata.obs[col].loc[spatial_mch.obs_names].copy()

In [None]:
# Brain regions to keep, in display order, and the output folder for the tradeSeq heatmaps.
repr_order = "CaH CaB Pu NAC".split()
image_path = os.path.join("/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS", "ms_tradeseq")

In [None]:
# tradeSeq end-test genes (p < 0.01) for D2 MSNs on RNA expression, rows ordered by rho_axis.
_ct = "STR D2 MSN"
fdr_col = "end_test_pvalue"
rho_axis_col = "rho_axis"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    df_tradeseq[df_tradeseq['celltype'] == _ct]
    .query(f"{fdr_col} < 0.01")
    .groupby(["gene", "replicate"])[[fdr_col, rho_axis_col]]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
adata_ct = adata[(adata.obs['Subclass'] == _ct) & (~adata.obs['MS_NORM'].isna())].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=10000, min_cluster_size=0)
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)

rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
# These are tradeSeq results, not SPARK-X: title them as SVGs like the other tradeSeq heatmaps.
expr_heatmap(adata_ct, df_sub, n_genes=30, gene_names=30, title=f"Top SVG genes in {_ct}",
             rasterized=True, show=True, save=False, ylabel=None, color_splits="black",
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             heatmap_order_col=rho_axis_col,
             cmap='Reds', vmin=0, vmax=1, min_max_quantiles=(0.2, 0.98))

In [None]:
# tradeSeq end-test genes (p < 0.05) for D2 MSNs shown on the mCH object.
_ct = "STR D2 MSN"
fdr_col = "end_test_pvalue"
rho_axis_col = "rho_axis"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value, then each gene's best replicate.
df_sub = (
    df_tradeseq[df_tradeseq['celltype'] == _ct]
    .query(f"{fdr_col} < 0.05")
    .groupby(["gene", "replicate"])[[fdr_col, rho_axis_col]]
    .median()
    .sort_values(fdr_col, ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
adata_ct = spatial_mch[(spatial_mch.obs['Subclass'] == _ct) & (~spatial_mch.obs['MS_NORM'].isna())].copy()
adata_ct = adata_ct[adata_ct.obs['brain_region_corr'].isin(repr_order)].copy()
adata_ct.obs['brain_region_corr'] = pd.Categorical(adata_ct.obs['brain_region_corr'], categories=repr_order, ordered=True)
adata_ct = _downsample_reference(adata_ct, cluster_col="MS_compartment", max_cluster_size=10000, min_cluster_size=0)
rep_palette = {k : v for k, v in adata_ct.uns.get("brain_region_corr_palette", {}).items() if k in adata_ct.obs['brain_region_corr'].unique()}
# tradeSeq results, not SPARK-X: title consistent with the other tradeSeq heatmaps.
expr_heatmap(adata_ct, df_sub, n_genes=20, gene_names=20, title=f"Top SVG genes in {_ct}",
             rasterized=True, show=True, save=False, ylabel=None, image_path=image_path,
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             cmap='parula', vmin=0, vmax=1, min_max_quantiles=(0.0, 0.98))

In [None]:
# Rebuild the joint RNA / mCH objects from the cells shared by both modalities.
joint_cells = adata.obs_names.intersection(spatial_mch.obs_names)
adata_joint_rna, adata_joint_mch = adata[joint_cells].copy(), spatial_mch[joint_cells].copy()

In [None]:
# Joint RNA + mCH heatmap of tradeSeq end-test genes (p < 0.05) for CN VIP GABA neurons,
# rows ordered by rho_axis.
_ct = "CN VIP GABA"
fdr_col = "end_test_pvalue"
rho_axis_col = "rho_axis"
print(_ct)
# Median stats per (gene, replicate), sorted by p-value (ties by rho), then each gene's best replicate.
df_sub = (
    df_tradeseq[df_tradeseq['celltype'] == _ct]
    .query(f"{fdr_col} < 0.05")
    .groupby(["gene", "replicate"])[[fdr_col, rho_axis_col]]
    .median()
    .sort_values([fdr_col, rho_axis_col], ascending=True)
    .reset_index()
    .drop_duplicates(subset='gene', keep='first')
)
adata_rna = adata_joint_rna[(adata_joint_rna.obs['Subclass'] == _ct) & (~adata_joint_rna.obs['MS_NORM'].isna())].copy()

adata_rna = adata_rna[adata_rna.obs['brain_region_corr'].isin(repr_order)].copy()
adata_rna = _downsample_reference(adata_rna, cluster_col="MS_compartment", max_cluster_size=5000, min_cluster_size=0)
adata_rna.obs['brain_region_corr'] = pd.Categorical(adata_rna.obs['brain_region_corr'], categories=repr_order, ordered=True)

# The mCH side is not downsampled.
adata_mch = adata_joint_mch[(adata_joint_mch.obs['Subclass'] == _ct) & (~adata_joint_mch.obs['MS_NORM'].isna())].copy()
adata_mch = adata_mch[adata_mch.obs['brain_region_corr'].isin(repr_order)].copy()
adata_mch.obs['brain_region_corr'] = pd.Categorical(adata_mch.obs['brain_region_corr'], categories=repr_order, ordered=True)

rep_palette = {k : adata_rna.uns["brain_region_corr_palette"].get(k, None) for k in repr_order}
# joint_expr_heatmap takes separate RNA / mCH quantile ranges (rna_min_max_quantiles /
# meth_min_max_quantiles) at every other call site; the RNA range is passed that way here too.
joint_expr_heatmap(adata_rna, adata_mch, df_sub, n_genes=20, gene_names=20, title=f"Top SVG genes in {_ct}",
             rasterized=True, show=True, save=False, ylabel=None, color_splits="black", heatmap_order_col=rho_axis_col,
             replicate_col="brain_region_corr", replicate_palette=rep_palette, replicate_order=repr_order,
             rna_cmap='Reds', vmin=0, vmax=1, rna_min_max_quantiles=(0.1, 0.99))