In [None]:
# imports
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import numpy as np
import pandas as pd
import anndata as ad
import geopandas as gpd

import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns
import PyComplexHeatmap as pch
plt.rcParams['figure.dpi'] = 150

from spida.utilities.sd_utils import _get_obs_or_gene
from spida.utilities._ad_utils import normalize_adata
from scipy.stats import pearsonr
from statsmodels.stats.multitest import multipletests 

In [None]:
# parameters
# Root of the BICAN BG project tree; defaults to the original cluster location and can be
# overridden with the BICAN_BG_DIR environment variable (no code edit needed on other machines).
_BG_DIR = (os.environ.get("BICAN_BG_DIR") or "/home/x-aklein2/projects/aklein/BICAN/BG").rstrip("/")
ad_path = os.path.join(_BG_DIR, "data", "BICAN_BG_PFV8_annotated_v5.h5ad")
geom_store_path = os.path.join(_BG_DIR, "data", "regions", "region_geometries.parquet")
image_path = os.path.join(_BG_DIR, "images", "regions", "ms_corr_rna")
DONOR = None  # set to a single donor ID to restrict the analysis to that donor; None = all donors

In [None]:
# Cells (AnnData) and region outlines (GeoParquet); each shape is echoed right after loading.
adata = ad.read_h5ad(ad_path)
print(adata.shape)
df_geoms = gpd.read_parquet(geom_store_path)
print(df_geoms.shape)

# All figures are written here; the directory (and any missing parents) is created if absent.
image_path = Path(image_path)
image_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Normalize Adata
# X is rebuilt from the raw 'counts' layer first, so re-running this cell never
# normalizes already-normalized values. normalize_adata's return value is not used,
# so it presumably modifies `adata` in place -- confirm against spida.
adata.X = adata.layers["counts"].copy()
normalize_adata(adata, log1p=True)

In [None]:
# Keep MSNs that carry an MS_SCORE, optionally restricted to a single donor.
_keep = adata.obs["MSN_Groups"].notna() & adata.obs["MS_SCORE"].notna()
if DONOR is not None:
    _keep &= adata.obs["donor"] == DONOR
adata_msn = adata[_keep].copy()
adata_msn

In [None]:
# Repeat for each brain region 

In [None]:
# Levels iterated over downstream, each in order of first appearance in adata_msn.obs.
donors, brain_regions, replicates = (
    adata_msn.obs[_col].unique().tolist() for _col in ("donor", "brain_region", "replicate")
)
# (donor, brain_region, replicate) triples intended to be skipped in the spatial plots.
skip = [("UWA7648", "CAT", "ucsd"), ("UWA7648", "CAT", "salk")]

In [None]:
# Per-region analysis: for each brain region, correlate every gene with MS_SCORE on a
# Matrix/Striosome-balanced subsample, then plot the top hits as a scatter grid,
# per-(donor, replicate) spatial maps, and an expression heatmap ordered by MS_SCORE.

layer = None  # forwarded to _get_obs_or_gene as its `layer` argument
_OBS_COLS = ["donor", "replicate", "brain_region", "Subclass", "Group", "MS_SCORE", "MS_compartment"]


def _to_dense(mat):
    """Return `mat` as a dense ndarray, accepting scipy-sparse or already-dense input."""
    return mat.toarray() if hasattr(mat, "toarray") else np.asarray(mat)


def _balanced_gene_frame(adata_reg, gene, layer=None, seed=42):
    """Return (adata_reg, df_obs): the key obs columns plus `gene`, balanced by compartment.

    Every Striosome cell is kept, together with a seeded random sample of the same number
    of Matrix cells (or all Matrix cells when there are fewer of them, instead of raising).
    `adata_reg` is returned because `_get_obs_or_gene` hands back the AnnData it used.
    """
    adata_reg, drop_col = _get_obs_or_gene(adata_reg, gene, layer)  # get the column from obs or var
    df_obs = adata_reg.obs[_OBS_COLS + [gene]].copy()
    if drop_col:  # the helper added the column to obs; remove it again
        adata_reg.obs.drop(columns=[gene], inplace=True)

    is_strio = df_obs["MS_compartment"] == "Striosome"
    matrix = df_obs.loc[df_obs["MS_compartment"] == "Matrix"]
    n_keep = min(int(is_strio.sum()), len(matrix))
    keep_idx = matrix.sample(n_keep, random_state=seed).index
    df_obs = df_obs.loc[keep_idx.tolist() + df_obs.loc[is_strio].index.tolist()]
    return adata_reg, df_obs


def _correlate_genes(adata_reg, layer=None, alpha=0.05):
    """Pearson-correlate each gene with MS_SCORE and add Benjamini-Hochberg adjusted p-values.

    Genes whose correlation is undefined (fewer than 2 balanced cells, or a NaN result from
    pearsonr, e.g. for constant expression in recent SciPy) keep NaN statistics but are left
    out of the FDR step: one NaN p-value would otherwise make every adjusted p-value NaN.

    Returns (adata_reg, df_returns), df_returns sorted by adjusted p-value with NaNs last.
    """
    returns = {}
    for gene in adata_reg.var_names:
        adata_reg, df_obs = _balanced_gene_frame(adata_reg, gene, layer)
        if len(df_obs) < 2:  # pearsonr needs at least two observations
            returns[gene] = (np.nan, np.nan)
            continue
        stat, pval = pearsonr(df_obs["MS_SCORE"], df_obs[gene])
        returns[gene] = (stat, pval)

    df_returns = pd.DataFrame.from_dict(returns, orient="index", columns=["pearsonr_stat", "pearsonr_pval"])
    df_returns["pearsonr_pval_adj"] = np.nan
    valid = df_returns["pearsonr_pval"].notna()
    if valid.any():
        df_returns.loc[valid, "pearsonr_pval_adj"] = multipletests(
            df_returns.loc[valid, "pearsonr_pval"], method="fdr_bh", alpha=alpha, maxiter=1
        )[1]
    return adata_reg, df_returns.sort_values(by="pearsonr_pval_adj", ascending=True)


def _plot_scatter_grid(adata_reg, toplot, region, out_dir, donor=None, layer=None, ncols=5):
    """Scatter each gene in `toplot` against MS_SCORE (balanced subsample) with a least-squares line.

    Saves the grid to `out_dir` and shows it. Returns the AnnData handed back by
    `_get_obs_or_gene`.
    """
    nrows = max(1, (len(toplot) + ncols - 1) // ncols)  # ceil division
    # squeeze=False keeps `axes` 2-D, so axes[r, c] also works when only one row is needed
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3),
                             constrained_layout=True, squeeze=False)
    for i, gene in enumerate(toplot):
        ax = axes[i // ncols, i % ncols]
        adata_reg, df_obs = _balanced_gene_frame(adata_reg, gene, layer)
        x, y = df_obs["MS_SCORE"], df_obs[gene]
        stat, pval = pearsonr(x, y)
        slope, intercept = np.polyfit(x, y, 1)

        x_lo, x_hi = x.min() * 0.25, x.max() * 0.5  # x-range shown, scaled in from the extremes
        ax.scatter(x, y, alpha=0.1, s=1)
        # 100 evenly spaced points across the shown range (np.arange's 3rd arg is a step, not a count)
        x_vals = np.linspace(x_lo, x_hi, 100)
        ax.plot(x_vals, intercept + slope * x_vals, '--', color='red', alpha=0.5)
        ax.set_xlabel("MS_SCORE")
        ax.set_ylabel(gene)
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(y.min() - 0.1, y.max())
        ax.set_title(f"{gene} Corr.\n r={stat:.2f}, p={pval:.2e}")

    for ax in axes.ravel()[len(toplot):]:  # blank out unused panels
        ax.axis('off')

    fig.suptitle(f"Gene Correlations with MS_SCORE across {region} MSNs")
    fig.savefig(out_dir / f"{region}_top10_gene_scatter_{donor if donor is not None else 'all_donors'}.png", dpi=300)
    plt.show()
    plt.close(fig)
    return adata_reg


def _plot_spatial(adata_reg, gene, region, donors, replicates, geoms, skip, out_dir):
    """Save a two-panel map (MS_SCORE | `gene` expression) for each (donor, replicate) section.

    `skip` holds (donor, region, replicate) triples to leave out. Combinations with no cells
    in this region are skipped as well.
    """
    combos = list(itertools.product(donors, replicates))
    for donor, replicate in tqdm(combos, desc=f"{region} {gene}"):
        if (donor, region, replicate) in skip:
            continue
        cell_mask = (adata_reg.obs['donor'] == donor) & (adata_reg.obs['replicate'] == replicate)
        if not cell_mask.any():  # this donor/replicate was not sampled in this region
            continue

        sub_cells = adata_reg[cell_mask].obs.copy()
        expr = _to_dense(adata_reg[sub_cells.index, gene].X).ravel()
        sub_cells = gpd.GeoDataFrame(sub_cells, geometry=gpd.points_from_xy(sub_cells['CENTER_X'], sub_cells['CENTER_Y']), crs=None)
        sub_cells[gene] = expr

        sub_geoms = geoms.loc[(geoms['brain_region'] == region) & (geoms['donor'] == donor) & (geoms['lab'] == replicate)]
        # drop White_Matter *before* plotting so the per-row colours line up with the rows drawn
        sub_geoms = sub_geoms.loc[sub_geoms['type'] != "White_Matter"].copy()

        fig, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

        # Left panel: MS_SCORE, diverging around 0 with percentile-based limits and ticks.
        ax = axes[0]
        neg_scores = sub_cells.loc[sub_cells['MS_SCORE'] < 0, "MS_SCORE"]
        pos_scores = sub_cells.loc[sub_cells['MS_SCORE'] > 0, "MS_SCORE"]
        norm, ticks = None, None
        if len(neg_scores) and len(pos_scores):  # percentiles of an empty side would fail
            ticks = np.concatenate((np.percentile(neg_scores, [10, 40, 70]), [0], np.percentile(pos_scores, [30, 60, 90])))
            norm = TwoSlopeNorm(vmin=np.percentile(neg_scores, 5), vcenter=0, vmax=np.percentile(pos_scores, 95))
        sub_cells.plot(ax=ax, column="MS_SCORE", cmap='coolwarm_r', norm=norm, edgecolor='none', markersize=10, alpha=0.5, legend_kwds={"label": "MS_SCORE"}, legend=True).axis("off")
        cbar_ax = fig.axes[-1]  # the colorbar axes geopandas just appended
        if ticks is not None:
            cbar_ax.set_yticks(ticks)
            cbar_ax.set_yticklabels(ticks.astype(int).astype(str))
        cbar_ax.tick_params(labelsize=12)

        sub_geoms.plot(ax=ax, column="type", color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off")
        sub_geoms.plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off")
        ax.legend()
        ax.set_title(f"{donor} {region} {replicate}")

        # Right panel: gene expression, centred on its median.
        ax = axes[1]
        values = sub_cells[gene]
        try:
            norm = TwoSlopeNorm(vmin=values.min(), vcenter=np.percentile(values, 50), vmax=np.percentile(values, 90))
        except ValueError:  # TwoSlopeNorm needs vmin < vcenter < vmax (fails e.g. when most cells are 0)
            norm = None
        sub_cells.plot(ax=ax, column=gene, cmap='RdYlGn_r', norm=norm, edgecolor='none', markersize=10, alpha=0.9, legend=True, legend_kwds={"label": gene}).axis("off")
        fig.axes[-1].tick_params(labelsize=12)  # style this panel's own colorbar

        sub_geoms.plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off")
        ax.legend()
        ax.set_title(f"{donor} {region} {replicate}")

        fig.savefig(out_dir / f"spatial_{gene}_{donor}_{region}_{replicate}.png", dpi=300)
        plt.close(fig)


def _plot_heatmap(adata_reg, df_returns, toplot, region, out_dir, donor=None):
    """Heatmap of the top-20 genes (min-max scaled per gene) over cells sorted by MS_SCORE."""
    df_col = adata_reg.obs[['Group', 'MS_SCORE', 'MS_compartment']].copy()
    if df_col.shape[0] > 50000:  # cap the number of columns (cells) drawn
        df_col = df_col.sample(50000, random_state=42)
    df_col = df_col.sort_values('MS_SCORE')
    print(df_col.shape)

    # rows ordered from most positive to most negative correlation; only `toplot` genes get a label
    df_row = df_returns.iloc[:20].sort_values(by='pearsonr_stat', ascending=False).copy()
    df_row['annot'] = [g if g in toplot else np.nan for g in df_row.index]
    print(df_row.shape)

    # densify only the selected cells x genes instead of the whole region matrix
    sub = adata_reg[df_col.index, df_row.index]
    df_expr = pd.DataFrame(_to_dense(sub.X), index=sub.obs_names, columns=sub.var_names).T
    df_expr = df_expr.loc[df_row.index, df_col.index]
    row_min = df_expr.min(axis=1)
    df_expr_norm = df_expr.subtract(row_min, axis=0).div(df_expr.max(axis=1) - row_min, axis=0)
    print(df_expr_norm.shape)

    ms_score_norm = TwoSlopeNorm(vmin=df_col['MS_SCORE'].values.min() * 0.1, vcenter=0,
                                 vmax=df_col['MS_SCORE'].values.max() * 0.1)

    col_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_col['MS_compartment'], merge=True, rotation=90, extend=True,
            colors={"Matrix": "blue", "Striosome": "red"},
        ),
        MS_COMPARTMENT=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}),
        MS_SCORE=pch.anno_simple(df_col['MS_SCORE'], cmap="coolwarm_r", norm=ms_score_norm),
        verbose=1, axis=1, plot_legend=False
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=False,
            colors="black", relpos=(1, 0.5),
        ),
        verbose=1, axis=0
    )

    plt.figure(figsize=(8, 6))
    pch.ClusterMapPlotter(
        data=df_expr_norm,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        row_dendrogram=False,
        label="Expression",
        cmap='plasma',
        rasterized=True,
        ylabel="Genes",
        xlabel="Cells",
        vmax=0.5,
    )

    plt.savefig(out_dir / f"{region}_top20_genes_heatmap_{donor if donor is not None else 'all'}.png", dpi=300)
    plt.show()
    plt.close()


for _region in brain_regions:
    adata_reg = adata_msn[adata_msn.obs['brain_region'] == _region].copy()

    adata_reg, df_returns = _correlate_genes(adata_reg, layer)
    print(df_returns.head())

    toplot = df_returns.index[:10]
    print(toplot)

    adata_reg = _plot_scatter_grid(adata_reg, toplot, _region, image_path, donor=DONOR, layer=layer)

    for _gene in toplot:
        _plot_spatial(adata_reg, _gene, _region, donors, replicates, df_geoms, skip, image_path)

    _plot_heatmap(adata_reg, df_returns, toplot, _region, image_path, donor=DONOR)



In [None]:
# All-MSN (not per-region) version of the analysis above, kept commented out for reference.
# layer=None
# returns = {}
# for _gene in adata_msn.var_names:
#     adata_msn, _drop_col = _get_obs_or_gene(adata_msn, _gene, layer) # get the column from obs or var
#     df_obs = adata_msn.obs[["donor", "replicate", "brain_region", "Subclass", "Group", "MS_SCORE", "MS_compartment", _gene]].copy()
#     if _drop_col:  # drop the column if it was added for plotting
#         adata_msn.obs.drop(columns=[_gene], inplace=True)
    
#     # Whether to filter out the zeros / NaNs
#     # df_obs = df_obs[(~df_obs[_gene].isna()) & (df_obs[_gene] != 0)]
#     n_obs = (df_obs['MS_compartment'] == "Striosome").sum()
#     keep_idx = df_obs.loc[df_obs["MS_compartment"] == "Matrix"].sample(n_obs, random_state=42).index
#     df_obs = df_obs.loc[keep_idx.tolist() + df_obs.loc[df_obs["MS_compartment"] == "Striosome"].index.tolist()]
    
#     _stat, _p = pearsonr(df_obs['MS_SCORE'], df_obs[_gene])
#     returns[_gene] = (_stat, _p)

In [None]:
# toplot = df_returns.index[:10]
# toplot

In [None]:
# df_returns = pd.DataFrame.from_dict(returns, orient='index', columns=['pearsonr_stat', 'pearsonr_pval'])
# df_returns['pearsonr_pval_adj'] = multipletests(df_returns['pearsonr_pval'], method='fdr_bh', alpha=0.05, maxiter=1)[1]
# df_returns = df_returns.sort_values(by='pearsonr_pval_adj', ascending=True)
# df_returns.head()

In [None]:
# ncols = 5
# nrows = len(toplot) // ncols + int(len(toplot) % ncols > 0)
# fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), constrained_layout=True)

# for i, _gene in enumerate(toplot):
#     r = i // ncols
#     c = i % ncols
#     ax = axes[r, c]

#     adata_msn, _drop_col = _get_obs_or_gene(adata_msn, _gene, layer) # get the column from obs or var
#     df_obs = adata_msn.obs[["donor", "replicate", "brain_region", "Subclass", "Group", "MS_SCORE", "MS_compartment", _gene]].copy()
#     if _drop_col:  # drop the column if it was added for plotting
#         adata_msn.obs.drop(columns=[_gene], inplace=True)
    
#     # df_obs = df_obs[(~df_obs[_gene].isna()) & (df_obs[_gene] != 0)]
#     n_obs = (df_obs['MS_compartment'] == "Striosome").sum()
#     keep_idx = df_obs.loc[df_obs["MS_compartment"] == "Matrix"].sample(n_obs, random_state=42).index
#     df_obs = df_obs.loc[keep_idx.tolist() + df_obs.loc[df_obs["MS_compartment"] == "Striosome"].index.tolist()]
    
#     _stat, _p = pearsonr(df_obs['MS_SCORE'], df_obs[_gene])
#     slope, intercept = np.polyfit(df_obs['MS_SCORE'], df_obs[_gene], 1)
    

#     ax.scatter(df_obs['MS_SCORE'], df_obs[_gene], alpha=0.1, s=1)
#     x_vals = np.arange(df_obs['MS_SCORE'].min()*0.25, df_obs['MS_SCORE'].max()*0.5, 100)
#     y_vals = intercept + slope * x_vals
#     ax.plot(x_vals, y_vals, '--', color='red', alpha=0.5)
#     ax.set_xlabel("MS_SCORE")
#     ax.set_ylabel(_gene)
#     ax.set_xlim(df_obs['MS_SCORE'].min()*0.25, df_obs['MS_SCORE'].max()*0.5)
#     ax.set_ylim(df_obs[_gene].min() - 0.1, df_obs[_gene].max())
#     ax.set_title(f"{_gene} Corr.\n r={_stat:.2f}, p={_p:.2e}")

# for j in range(i+1, nrows*ncols):
#     r = j // ncols
#     c = j % ncols
#     ax = axes[r, c]
#     ax.axis('off')

# plt.suptitle("Gene Correlations with MS_SCORE across all MSNs")
# plt.savefig(image_path/ f"top10_gene_scatter_{DONOR if DONOR is not None else 'all_donors'}.png", dpi=300)
# plt.show()
# plt.close()

In [None]:
# for _gene in toplot:
#     pbar = tqdm(itertools.product(donors, brain_regions, replicates))
#     for counter, _i in enumerate(pbar):
#         if _i in skip:
#             # print(f"Skipping {_i}")
#             continue
#         _donor, _brain_region, _replicate = _i
#         sub_geoms = df_geoms.loc[(df_geoms['brain_region'] == _brain_region) & (df_geoms['donor'] == _donor) & (df_geoms['lab'] == _replicate)]
#         sub_cells = adata_msn[(adata_msn.obs['brain_region'] == _brain_region) & (adata_msn.obs['donor'] == _donor) & (adata_msn.obs['replicate'] == _replicate)].obs.copy()
#         expr = adata_msn[sub_cells.index, _gene].X.toarray()
#         sub_cells = gpd.GeoDataFrame(sub_cells, geometry=gpd.points_from_xy(sub_cells['CENTER_X'], sub_cells['CENTER_Y']), crs=None)
#         sub_cells[_gene] = expr

#         fig, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
#         ax = axes[0]

#         neg_scores = sub_cells.loc[sub_cells['MS_SCORE'] < 0, "MS_SCORE"]
#         pos_scores = sub_cells.loc[sub_cells['MS_SCORE'] > 0, "MS_SCORE"]
#         ticks = np.concatenate((np.percentile(neg_scores, [10, 40, 70]), [0], np.percentile(pos_scores, [30, 60, 90])))

#         try: 
#             norm = TwoSlopeNorm(vmin=np.percentile(neg_scores, 5), vcenter=0, vmax=np.percentile(pos_scores, 95))
#         except ValueError as e:
#             norm = None
#         sub_cells.plot(ax=ax, column="MS_SCORE", cmap='coolwarm_r', norm=norm, edgecolor='none', markersize=10, alpha=0.5, legend_kwds={"label" : "MS_SCORE"}, legend=True).axis("off");
#         cbar_ax = fig.axes[len(axes)] 
#         cbar_ax.set_yticks(ticks)
#         cbar_ax.set_yticklabels(ticks.astype(int).astype(str))
#         cbar_ax.tick_params(labelsize=12)

#         sub_geoms[sub_geoms["type"] != "White_Matter"].plot(ax=ax, column="type", color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off");
#         sub_geoms[sub_geoms["type"] != "White_Matter"].plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off");
#         ax.legend()
#         ax.set_title(f"{_donor} {_brain_region} {_replicate}")

#         ax = axes[1]
#         try: 
#             norm = TwoSlopeNorm(vmin=sub_cells[_gene].min(), vcenter=np.percentile(sub_cells[_gene], 50), vmax=np.percentile(sub_cells[_gene], 90))
#         except ValueError as e: 
#             norm = None
#         sub_cells.plot(ax=ax, column=_gene, cmap='RdYlGn_r', norm=norm, edgecolor='none', markersize=10, alpha=0.9, legend=True).axis("off");
#         cbar_ax.tick_params(labelsize=12)

#         sub_geoms = sub_geoms.loc[sub_geoms['type'] != "White_Matter"].copy()
#         # sub_geoms.plot(ax=ax, column="type", color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.05).axis("off");
#         sub_geoms.plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off");
#         ax.legend()
#         ax.set_title(f"{_donor} {_brain_region} {_replicate}")

#         plt.savefig(image_path/ f"spatial_{_gene}_{_donor}_{_brain_region}_{_replicate}.png", dpi=300)
#         # plt.show()
#         plt.close()

In [None]:
# df_col = adata_msn.obs[['Group', 'MS_SCORE', 'MS_compartment']].copy()
# if df_col.shape[0] > 50000:
#     df_col = df_col.sample(50000, random_state=42)
# df_col = df_col.sort_values('MS_SCORE')
# print(df_col.shape)

# df_row = df_returns.iloc[:20].sort_values(by='pearsonr_stat', ascending=False)
# df_row['annot'] = [c if c in toplot else np.nan for c in df_row.index]
# print(df_row.shape)

# df_expr = adata_msn.X.toarray()
# df_expr = pd.DataFrame(df_expr, index=adata_msn.obs_names, columns=adata_msn.var_names).T
# df_expr = df_expr.loc[df_row.index, df_col.index]
# df_expr_norm = df_expr.subtract(df_expr.min(axis=1), axis=0).div(df_expr.max(axis=1) - df_expr.min(axis=1), axis=0)
# print(df_expr_norm.shape)

In [None]:
# ms_score_norm = TwoSlopeNorm(vmin=df_col['MS_SCORE'].values.min() * 0.1, vcenter=0, vmax=df_col['MS_SCORE'].values.max() * 0.1)

# col_ha = pch.HeatmapAnnotation(
#     label=pch.anno_label(
#         df_col['MS_compartment'], merge=True, rotation=90, extend=True,
#         colors={"Matrix": "blue", "Striosome": "red"}, 
#     ),
#     MS_COMPARTMENT=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}),
#     MS_SCORE=pch.anno_simple(df_col['MS_SCORE'], cmap="coolwarm_r", norm=ms_score_norm),
#     verbose=1, axis=1, plot_legend=False
# )

# left_ha = pch.HeatmapAnnotation(
#     label=pch.anno_label(
#         df_row['annot'], merge=True, rotation=0, extend=False,
#         colors="black", relpos=(1, 0.5), 
#     ),
#     # Genes=pch.anno_simple(df_row[0]),
#     verbose=1, axis=0
# )

# plt.figure(figsize=(8,6))
# cm = pch.ClusterMapPlotter(
#     data=df_expr_norm,
#     top_annotation=col_ha,
#     left_annotation=left_ha,
#     row_cluster=False,
#     col_cluster=False,
#     row_dendrogram=False,
#     label="Expression",
#     cmap='plasma',
#     rasterized=True, 
#     ylabel="Genes",
#     xlabel="Cells",
#     vmax=0.5,
# )

# plt.savefig(image_path / f"top20_genes_heatmap_{DONOR if DONOR is not None else 'all'}.png", dpi=300)
# plt.show()
# plt.close()