In [None]:
# imports
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import numpy as np
import pandas as pd
import anndata as ad
import geopandas as gpd

import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns
import PyComplexHeatmap as pch
plt.rcParams['figure.dpi'] = 150

from spida.utilities.sd_utils import _get_obs_or_gene
from spida.utilities._ad_utils import normalize_adata
from scipy.stats import pearsonr
from statsmodels.stats.multitest import multipletests 

In [None]:
# parameters
# NOTE(review): this looks like a papermill-style parameters cell (plain literal
# assignments only) — confirm before adding derived values here.
# NOTE(review): the paths are absolute and user-specific; they must be overridden
# to run the notebook on another machine.

# Annotated AnnData file (.h5ad); read with ad.read_h5ad in the loading cell.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# Region geometries stored as parquet; read with gpd.read_parquet in the loading cell.
geom_store_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/region_geometries.parquet"
# Output directory for the per-group CSV tables and PNG figures (created if missing).
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/regions/ms_corr_rna"
# Donor label. In the code visible here it only affects output file names
# (None -> "all_donors" / "all"); no donor filtering is applied.
DONOR = None

In [None]:
# Load the annotated AnnData object and report its shape.
adata = ad.read_h5ad(ad_path)
print(adata.shape)

# Load the region geometry table and report its shape.
df_geoms = gpd.read_parquet(geom_store_path)
print(df_geoms.shape)

# Turn the output location into a Path and make sure the directory (and any
# missing parents) exists before later cells write into it.
image_path = Path(image_path)
os.makedirs(image_path, exist_ok=True)

In [None]:
# Normalize Adata
# X is first reset from a copy of the 'counts' layer, so re-running this cell
# always starts from the same matrix. The return value of normalize_adata is
# not used here, so it presumably updates `adata` in place (with log1p on).
adata.X = adata.layers['counts'].copy()
normalize_adata(adata, log1p=True)

In [None]:
# Per-group correlation of gene expression with MS_SCORE.
#
# For every value of obs['Group'] this cell:
#   1. picks a Matrix/Striosome-balanced subset of the cells that have an MS_SCORE,
#   2. computes the Pearson correlation of each gene with MS_SCORE on that subset,
#      BH-adjusts the p-values and writes the ranked table to CSV,
#   3. saves a scatter grid (with a least-squares line) for the top-ranked genes,
#   4. saves a heatmap of the top-ranked genes across cells ordered by MS_SCORE.

N_SCATTER_GENES = 10        # genes in the scatter grid (file name still says "top10")
N_HEATMAP_GENES = 20        # genes in the heatmap (title/file name still say "top20")
N_SCATTER_COLS = 5          # columns of the scatter grid
MAX_HEATMAP_CELLS = 50000   # cap on heatmap columns (cells)
SAMPLE_SEED = 42            # seed for every subsampling step
MS_COLORS = {"Matrix": "blue", "Striosome": "red"}
layer = None                # passed through to _get_obs_or_gene


def _balanced_ms_index(obs, seed=SAMPLE_SEED):
    """Return index labels of a compartment-balanced subset of ``obs``.

    Only rows with a non-missing ``MS_SCORE`` are considered. Matrix rows are
    subsampled to the number of Striosome rows; if there are fewer Matrix than
    Striosome rows, the Striosome rows are subsampled instead of raising.

    Returns a list: sampled Matrix labels followed by Striosome labels.
    """
    compartment = obs.loc[obs["MS_SCORE"].notna(), "MS_compartment"]
    matrix = compartment[compartment == "Matrix"]
    striosome = compartment[compartment == "Striosome"]
    n_keep = min(len(matrix), len(striosome))
    matrix = matrix.sample(n_keep, random_state=seed)
    if len(striosome) > n_keep:
        striosome = striosome.sample(n_keep, random_state=seed)
    return matrix.index.tolist() + striosome.index.tolist()


def _gene_values(adata_obj, gene, layer, index):
    """Fetch ``gene`` for the cells in ``index`` as a float array.

    Uses _get_obs_or_gene to expose the gene as an obs column, reads it, and
    drops the column again when the helper reports that it was added.

    Returns ``(adata_obj, values)``; the AnnData object is passed back because
    _get_obs_or_gene returns it.
    """
    adata_obj, drop_col = _get_obs_or_gene(adata_obj, gene, layer)
    values = adata_obj.obs.loc[index, gene].to_numpy(dtype=float)
    if drop_col:
        adata_obj.obs.drop(columns=[gene], inplace=True)
    return adata_obj, values


groups = adata.obs['Group'].dropna().unique().tolist()
group_sizes = adata.obs['Group'].value_counts()

for _group in groups:
    print(_group, group_sizes[_group])
    adata_sub = adata[adata.obs['Group'] == _group].copy()

    # The balanced subset depends only on MS_SCORE / MS_compartment, not on the
    # gene, so it is computed once per group and reused everywhere below.
    balanced_idx = _balanced_ms_index(adata_sub.obs)
    if len(balanced_idx) < 2:
        # pearsonr needs at least two observations.
        print(f"Skipping {_group}: fewer than 2 balanced cells with an MS_SCORE")
        continue
    ms_score = adata_sub.obs.loc[balanced_idx, 'MS_SCORE'].to_numpy(dtype=float)

    # ---- 1. correlation of every gene with MS_SCORE -------------------------
    returns = {}
    for _gene in adata_sub.var_names:
        adata_sub, expr = _gene_values(adata_sub, _gene, layer, balanced_idx)
        _stat, _p = pearsonr(ms_score, expr)
        returns[_gene] = (_stat, _p)

    df_returns = pd.DataFrame.from_dict(returns, orient='index', columns=['pearsonr_stat', 'pearsonr_pval'])

    # Adjust only the defined p-values: a constant gene yields a NaN p-value,
    # which would otherwise risk propagating to every adjusted value.
    has_p = df_returns['pearsonr_pval'].notna()
    df_returns['pearsonr_pval_adj'] = np.nan
    if has_p.any():
        df_returns.loc[has_p, 'pearsonr_pval_adj'] = multipletests(
            df_returns.loc[has_p, 'pearsonr_pval'], method='fdr_bh', alpha=0.05
        )[1]

    # Rank by adjusted p-value; ties (e.g. p-values that underflow to 0) are
    # broken by |r| so the "top" genes are deterministic. NaNs sort last.
    df_returns = (
        df_returns
        .assign(_abs_r=df_returns['pearsonr_stat'].abs())
        .sort_values(by=['pearsonr_pval_adj', '_abs_r'], ascending=[True, False])
        .drop(columns='_abs_r')
    )
    df_returns.to_csv(image_path / f"{_group}_ms_corr_rna_results.csv")

    # ---- 2. scatter grid of the top genes -----------------------------------
    toplot = df_returns.index[:N_SCATTER_GENES]
    print(toplot)

    ncols = N_SCATTER_COLS
    nrows = max(1, -(-len(toplot) // ncols))  # ceil division, at least one row
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3),
        constrained_layout=True, squeeze=False,
    )
    flat_axes = axes.ravel()

    # Displayed x-range (a zoomed-in window of the score range); shared by all panels.
    x_lo = ms_score.min() * 0.25
    x_hi = ms_score.max() * 0.5

    for ax, _gene in zip(flat_axes, toplot):
        adata_sub, expr = _gene_values(adata_sub, _gene, layer, balanced_idx)

        # Same data as in the correlation pass, so reuse the stored statistics.
        _stat = df_returns.at[_gene, 'pearsonr_stat']
        _p = df_returns.at[_gene, 'pearsonr_pval']
        slope, intercept = np.polyfit(ms_score, expr, 1)

        ax.scatter(ms_score, expr, alpha=0.1, s=1)
        # linspace (100 points), not arange with step 100, so the line is drawn.
        x_vals = np.linspace(x_lo, x_hi, 100)
        y_vals = intercept + slope * x_vals
        ax.plot(x_vals, y_vals, '--', color='red', alpha=0.5)
        ax.set_xlabel("MS_SCORE")
        ax.set_ylabel(_gene)
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(expr.min() - 0.1, expr.max())
        ax.set_title(f"{_gene} Corr.\n r={_stat:.2f}, p={_p:.2e}")

    # Hide panels that did not receive a gene.
    for ax in flat_axes[len(toplot):]:
        ax.axis('off')

    fig.suptitle(f"Gene Correlations with MS_SCORE for {_group}")
    fig.savefig(image_path/ f"{_group}_top10_gene_scatter_{DONOR if DONOR is not None else 'all_donors'}.png", dpi=300)
    plt.close(fig)

    # ---- 3. heatmap of the top genes across cells ---------------------------
    # Columns: cells with an MS_SCORE (missing scores would turn the colour-norm
    # limits into NaN), capped at MAX_HEATMAP_CELLS and ordered by score.
    df_col = adata_sub.obs.loc[
        adata_sub.obs['MS_SCORE'].notna(), ['Group', 'MS_SCORE', 'MS_compartment']
    ].copy()
    if df_col.shape[0] > MAX_HEATMAP_CELLS:
        df_col = df_col.sample(MAX_HEATMAP_CELLS, random_state=SAMPLE_SEED)
    df_col = df_col.sort_values('MS_SCORE')
    print(df_col.shape)

    # Rows: top-ranked genes ordered from most positive to most negative r;
    # only the genes shown in the scatter grid get a text label.
    df_row = df_returns.iloc[:N_HEATMAP_GENES].sort_values(by='pearsonr_stat', ascending=False)
    labelled = set(toplot)
    df_row['annot'] = [g if g in labelled else np.nan for g in df_row.index]
    print(df_row.shape)

    # Subset to the needed cells/genes BEFORE densifying, so only a
    # (cells x top genes) block is materialised; also works when X is dense.
    expr_block = adata_sub[df_col.index, df_row.index].X
    if hasattr(expr_block, "toarray"):
        expr_block = expr_block.toarray()
    df_expr = pd.DataFrame(np.asarray(expr_block).T, index=df_row.index, columns=df_col.index)

    # Per-gene min-max scaling to [0, 1].
    gene_min = df_expr.min(axis=1)
    gene_range = df_expr.max(axis=1) - gene_min
    df_expr_norm = df_expr.subtract(gene_min, axis=0).div(gene_range, axis=0)
    print(df_expr_norm.shape)

    # Diverging colour scale centred on 0; limits are 10% of the score extremes.
    # TwoSlopeNorm requires vmin < 0 < vmax, i.e. scores on both sides of zero.
    ms_score_norm = TwoSlopeNorm(
        vmin=df_col['MS_SCORE'].min() * 0.1,
        vcenter=0,
        vmax=df_col['MS_SCORE'].max() * 0.1,
    )

    col_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_col['MS_compartment'], merge=True, rotation=90, extend=True,
            colors=MS_COLORS,
        ),
        MS_COMPARTMENT=pch.anno_simple(df_col['MS_compartment'], colors=MS_COLORS),
        MS_SCORE=pch.anno_simple(df_col['MS_SCORE'], cmap="coolwarm_r", norm=ms_score_norm),
        verbose=1, axis=1, plot_legend=False
    )

    left_ha = pch.HeatmapAnnotation(
        label=pch.anno_label(
            df_row['annot'], merge=True, rotation=0, extend=False,
            colors="black", relpos=(1, 0.5),
        ),
        verbose=1, axis=0
    )

    plt.figure(figsize=(8,6))
    cm = pch.ClusterMapPlotter(
        data=df_expr_norm,
        top_annotation=col_ha,
        left_annotation=left_ha,
        row_cluster=False,
        col_cluster=False,
        row_dendrogram=False,
        label="Expression",
        cmap='plasma',
        rasterized=True,
        ylabel="Genes",
        xlabel="Cells",
        vmax=0.5,
    )

    plt.suptitle(f"Top 20 Genes Correlated with MS_SCORE for {_group}")
    plt.savefig(image_path / f"{_group}_top20_genes_heatmap_{DONOR if DONOR is not None else 'all'}.png", dpi=300)
    plt.show()
    plt.close()