In [None]:
import os
from pathlib import Path
from rich import print as rprint, inspect

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from spida.I.allcools import normalize_adata

import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Report the interpreter and core-library versions used for this run,
# one "<name>: <version>" line per entry.
_package_versions = {
    "Python": os.sys.version,
    "NumPy": np.__version__,
    "Pandas": pd.__version__,
    "Anndata": ad.__version__,
    "Scanpy": sc.__version__,
}
for _pkg_name, _pkg_version in _package_versions.items():
    print(f"{_pkg_name}: {_pkg_version}")

## Functions

In [None]:
# Move to spida.pl
from spida.pl._utils import add_color_scheme
from spida.pl import plot_categorical, plot_continuous, categorical_scatter
def plot_categorical(
    adata : ad.AnnData | str, 
    ax=None,
    coord_base='tsne',
    cluster_col='MajorType',
    palette_path=None,
    highlight_group : str | list | None = None,
    color_unknown : str = "#D0D0D0",
    is_background : bool = False,
    axis_format='tiny',
    alpha=0.7,
    coding=True,
    id_marker=True,
    output=None,
    show=True,
    figsize=(4, 3.5),
    sheet_name=None,
    ncol=None,
    fontsize=5,
    legend_fontsize=5,
    legend_kws=None,
    legend_title_fontsize=5,
    marker_fontsize=4,
    marker_pad=0.1,
    linewidth=0.5,
    ret=False,
    text_anno:bool = False,
    text_kws=None,
    **kwargs
):
    """
    Scatter-plot cells on an embedding, colored by a categorical ``obs`` column.

    Resolves a category -> color palette, optionally restricts the plot to a
    set of highlighted categories, fills in plotting defaults and delegates the
    drawing to ``categorical_scatter``.

    Note: this notebook-level definition shadows the ``plot_categorical``
    imported from ``spida.pl`` in the same cell.

    Parameters
    ----------
    adata : AnnData | str
        AnnData object, or a path to an ``.h5ad`` file (opened with
        ``ad.read_h5ad(adata, backed='r')``).
    ax :
        Axes to draw on. When None, a new figure is created with ``figsize``
        and dpi=300.
    coord_base : str
        Forwarded to ``categorical_scatter`` as ``coord_base``.
    cluster_col : str
        Column of ``adata.obs`` holding the categories. If it is not
        categorical it is converted to ``category`` dtype on ``adata`` itself.
    palette_path : None | str | mapping
        * None: use ``adata.uns[f"{cluster_col}_colors"]`` if present,
          otherwise call ``add_color_scheme``.
        * str: path to an Excel workbook; sheet ``sheet_name`` is read with the
          first column as index and its ``Hex`` column is used as the palette.
          Categories missing from the sheet get ``'gray'``.
        * otherwise: treated as a category -> color mapping (it is copied, the
          caller's object is not modified).
        When not None, ``adata.uns[f"{cluster_col}_colors"]`` is overwritten.
    highlight_group : str | list | None
        A group (or list of groups) to plot; cells of every other group are
        removed from the plotted data.
    color_unknown : str
        Color used for every category when ``is_background`` is True (also
        registered under the key ``'unknown'``). Default light gray (#D0D0D0).
    is_background : bool
        If True, all plotted categories are drawn in ``color_unknown``.
    axis_format, alpha, marker_fontsize, marker_pad, linewidth :
        Defaults forwarded to ``categorical_scatter``; an identically named
        entry in ``kwargs`` takes precedence.
    coding, id_marker :
        Always forwarded to ``categorical_scatter`` (override ``kwargs``).
    output : str | None
        If given, the current figure is saved to this path (``~`` expanded,
        ``bbox_inches='tight'``, dpi=300).
    show : bool
        Call ``plt.show()`` after plotting.
    figsize : tuple
        Figure size used only when ``ax`` is None.
    sheet_name : str | None
        Excel sheet to read the palette from; defaults to ``cluster_col``.
    ncol : int | None
        Number of legend columns (only set when not None).
    fontsize : int
        Default ``fontsize`` entry of ``text_kws``.
    legend_fontsize : int
        Legend fontsize, default 5.
    legend_kws : dict | None
        Extra keyword arguments for the legend; ``fontsize``, ``title``,
        ``title_fontsize`` (and ``ncol``) are filled in when absent.
    legend_title_fontsize : int
        Legend title fontsize, default 5.
    ret : bool
        Whether to return the plot object and the palette, default False.
    text_anno : bool
        Whether to annotate each cluster with its name, default False.
    text_kws : dict | None
        Keyword arguments for the text annotation; ``fontsize`` is filled in
        when absent.
    kwargs : dict
        Forwarded to ``categorical_scatter``. Defaults that are set only when
        the key is absent: ``hue``, ``text_anno``, ``text_kws``,
        ``luminance=0.65``, ``dodge_text=False``, ``axis_format``,
        ``show_legend=True``, ``marker_fontsize``, ``marker_pad``,
        ``linewidth``, ``alpha`` and ``dodge_kws``.
        Pass ``show_legend=False`` to remove the legend.

    Returns
    -------
    tuple
        ``(p, colors)`` when ``ret`` is True, where ``p`` is whatever
        ``categorical_scatter`` returned and ``colors`` is the palette used;
        ``(None, None)`` otherwise.
    """
    if sheet_name is None:  # the Excel sheet is named after the column by default
        sheet_name = cluster_col
    if isinstance(adata, str):  # a path was given: open the file read-only, backed
        adata = ad.read_h5ad(adata, backed='r')
    if not isinstance(adata.obs[cluster_col].dtype, pd.CategoricalDtype):
        adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')

    # --- resolve the category -> color palette ---
    if palette_path is not None:
        if isinstance(palette_path, str):
            colors = pd.read_excel(
                os.path.expanduser(palette_path), sheet_name=sheet_name, index_col=0
            ).Hex.to_dict()
            keys = list(colors.keys())
            existed_vals = adata.obs[cluster_col].unique().tolist()
            for k in existed_vals:  # categories absent from the sheet get a fallback color
                if k not in keys:
                    colors[k] = 'gray'
            for k in keys:  # sheet entries that do not occur in the data are dropped
                if k not in existed_vals:
                    del colors[k]
        else:
            # Copy so that the caller's mapping is never modified.
            colors = dict(palette_path)
        adata.uns[cluster_col + '_colors'] = [
            colors.get(k, 'grey') for k in adata.obs[cluster_col].cat.categories.tolist()
        ]
    else:
        if f'{cluster_col}_colors' not in adata.uns:
            colors = add_color_scheme(adata, cluster_col, palette_key=f"{cluster_col}_colors")
        else:
            colors = dict(zip(
                adata.obs[cluster_col].cat.categories.tolist(),
                adata.uns[f'{cluster_col}_colors'],
            ))

    # --- select the cells to plot ---
    data = adata[adata.obs[cluster_col].notna()]
    if highlight_group is not None:
        if isinstance(highlight_group, str):
            highlight_group = [highlight_group]
        # Keep only the highlighted groups with a single boolean mask instead
        # of writing NaN into a view group by group.
        keep = data.obs[cluster_col].isin(list(highlight_group)).to_numpy()
        data = data[keep, :]
        data.obs[cluster_col] = data.obs[cluster_col].cat.remove_unused_categories()

    # Restrict the palette to categories that are actually plotted.
    plotted_categories = set(data.obs[cluster_col].cat.categories.tolist())
    colors = {k: v for k, v in colors.items() if k in plotted_categories}
    if is_background:
        bg_colors = {k: color_unknown for k in colors.keys()}
        bg_colors['unknown'] = color_unknown
        colors = bg_colors

    # --- assemble keyword arguments for categorical_scatter ---
    hue = cluster_col
    text_anno = cluster_col if text_anno else None
    # Work on copies so the caller's dicts are left untouched.
    text_kws = {} if text_kws is None else dict(text_kws)
    text_kws.setdefault("fontsize", fontsize)
    kwargs.setdefault("hue", hue)
    kwargs.setdefault("text_anno", text_anno)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs["coding"] = coding
    kwargs["id_marker"] = id_marker

    legend_kws = {} if legend_kws is None else dict(legend_kws)
    default_lgd_kws = dict(
        fontsize=legend_fontsize,
        title=cluster_col, title_fontsize=legend_title_fontsize)
    if ncol is not None:
        default_lgd_kws['ncol'] = ncol
    for k in default_lgd_kws:
        legend_kws.setdefault(k, default_lgd_kws[k])
    kwargs.setdefault("dodge_kws", {
            "arrowprops": {
                "arrowstyle": "->",
                "fc": 'grey',
                "ec": "none",
                "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
            },
            'autoalign': 'xy'})

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, dpi=300)
    p = categorical_scatter(
        data=data,
        ax=ax,
        coord_base=coord_base,
        palette=colors,
        legend_kws=legend_kws,
        **kwargs)

    if output is not None:
        plt.savefig(os.path.expanduser(output), bbox_inches='tight', dpi=300)
    if show:
        plt.show()

    if ret:
        return p, colors
    else:
        return None, None

# Read

In [None]:
# Load the main AnnData object for this notebook and display its summary.
# NOTE(review): the path is absolute and user-specific; consider a configurable data directory.
# Alternative input, kept commented out:
# adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(adata_path)
adata

In [None]:
# Keep an independent in-memory copy of the freshly loaded object and display it.
# NOTE(review): `adata_cp` is not referenced by any other visible cell — presumably
# a snapshot taken before later cells modify `adata`; confirm it is still needed.
adata_cp = adata.copy()
adata_cp

In [None]:
# spatial_mch = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation/BG_mCH_Imp_SubR.h5ad")
# for k, v in adata.uns.items(): 
#     if "colors" in k or "palette" in k: 
#         spatial_mch.uns[k] = v
# spatial_mch.obsm['spatial'] = spatial_mch.obs[['CENTER_X', 'CENTER_Y']].values
# spatial_mch

In [None]:
from spida.P.setup_adata import _calc_embeddings

# Show the helper's signature/documentation, then compute the embeddings.
inspect(_calc_embeddings)
_embedding_params = dict(
    layer="volume_norm",
    key_added="all_",
    knn=35,
    p_cutoff=0.05,
    run_harmony=True,
    batch_key="replicate",
)
_calc_embeddings(adata, **_embedding_params)

# One UMAP panel per annotation column, side by side on a shared figure.
_umap_columns = ('brain_region', 'Subclass', 'Group')
fig, axes = plt.subplots(1, 3, figsize=(15,4), dpi=300)
for _panel, _column in zip(axes, _umap_columns):
    plot_categorical(
        adata,
        coord_base='all_umap',
        cluster_col=_column,
        text_anno=True,
        coding=True,
        ax=_panel,
        show=False,
    )
plt.show()

In [None]:
# Load a second AnnData object from the annotation directory and display its summary.
# NOTE(review): `ad_cps` is not referenced by any other visible cell — confirm it is still needed.
ad_cps = ad.read_h5ad('/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_cps_all.h5ad')
ad_cps

In [None]:
# For every Group with enough cells and at least two brain regions, rank genes
# between brain regions (Wilcoxon) and show the top 5 per region as a heatmap.
# (A one-vs-rest ranking on the full `adata`, grouped by 'Subclass', and a
# brain-region dendrogram were sketched here earlier and are not run.)
for _celltype in adata.obs['Group'].cat.categories.tolist():
    print(_celltype)
    in_group = adata.obs['Group'] == _celltype
    adata_sub = adata[in_group].copy()

    n_cells = adata_sub.shape[0]
    if n_cells < 100:
        print(f"Skipping {_celltype} because it has less than 100 cells.")
        continue

    # Drop regions that no longer occur in this subset before counting them.
    regions_present = adata_sub.obs['brain_region'].cat.remove_unused_categories()
    adata_sub.obs['brain_region'] = regions_present
    if regions_present.nunique() < 2:
        print(f"Skipping {_celltype} because it has less than 2 brain regions.")
        continue

    normalize_adata(adata_sub, layer="counts")
    sc.tl.rank_genes_groups(adata_sub, groupby="brain_region", method="wilcoxon")
    sc.pl.rank_genes_groups_heatmap(adata_sub, n_genes=5, dendrogram=False)

## Plots

In [None]:
# Global matplotlib defaults for all figures below: white backgrounds, 300 dpi.
plt.rcParams.update({
    "axes.facecolor": "white",
    "figure.facecolor": "white",
    "figure.dpi": 300,
})

In [None]:
# GENES to Plot
genes_to_plot = [
    "PENK", "PDE10A", "SYNPR", "ATP2B2", "PDYN", "LAMP5", "TAC1","WFS1",
]

In [None]:
def plot_gene_groups(
    adata: ad.AnnData,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    hue_norm = 0.8,
): 
    """
    Plot spatial expression of one gene for selected cell groups.

    One figure is produced per donor, laid out as replicates (rows) x brain
    regions (columns). In each panel all cells are drawn in light grey and the
    cells whose ``group_level`` value is in ``groups`` are overlaid, colored by
    ``gene`` with the "YlOrRd" colormap. The color range is shared across a
    row: from the minimum of the gene over that donor/replicate to
    ``hue_norm`` times its maximum.

    Note: a later cell of this notebook defines another ``plot_gene_groups``
    (the mCH variant), which shadows this one once it has been run.

    Parameters
    ----------
    adata : AnnData
        Input data; a copy restricted to ``regions`` is used, the caller's
        object is not modified.
    gene : str
        Gene to plot; must be in ``adata.var_names``.
    groups : list[str]
        Values of ``adata.obs[group_level]`` to overlay in color.
    regions : list[str]
        Values of ``adata.obs['brain_region']`` to plot, one column each.
    group_level : str
        ``obs`` column used to select ``groups``.
    labs : list[str] | None
        Values of ``adata.obs['replicate']`` (rows); defaults to all present.
    donors : list[str] | None
        Values of ``adata.obs['donor']`` (one figure each); defaults to all
        present.
    layers :
        If not None, ``adata.layers[layers]`` replaces ``X`` on the working copy.
    layer_norm : bool
        If True, ``normalize_adata(adata, log1p=True)`` is applied to the
        working copy.
    hue_norm : float
        Factor applied to the per-row maximum to obtain the upper color limit.

    Returns
    -------
    None
        Figures are shown with ``plt.show()``. If ``gene`` is missing, a
        message is printed and nothing is plotted.
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return
    adata = adata[adata.obs['brain_region'].isin(regions)].copy()
    if layers is not None:
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)

    if labs is None:
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(labs)
    ncols = len(regions)

    for donor in donors:
        # squeeze=False keeps `axes` two-dimensional for every grid shape,
        # including 1x1 where plt.subplots would otherwise return a bare Axes.
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False
        )
        for l, lab in enumerate(labs):
            in_row = (adata.obs['donor'] == donor) & (adata.obs['replicate'] == lab)
            adata_don = adata[in_row].copy()
            if adata_don.n_obs == 0:
                # No cells for this donor/replicate: min/max are undefined, so
                # leave the row blank instead of raising.
                for ax in axes[l]:
                    ax.axis("off")
                continue
            # Color limits are shared by all regions of this donor/replicate.
            hue_min = adata_don[:, gene].X.min()
            hue_max = adata_don[:, gene].X.max() * hue_norm
            for r, region in enumerate(regions):
                ax = axes[l, r]
                adata_sub = adata_don[adata_don.obs['brain_region'] == region].copy()
                # Grey background of every cell in the section.
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap="YlOrRd", ax=ax, show=False, hue_norm=(hue_min, hue_max))
                ax.set_title(f"{region}")
        fig.suptitle(f"Gene: {gene} Donor: {donor}, Expression", y=1.02)
        plt.show()
        plt.close(fig)  # free the figure; one is created per donor

In [None]:
from spida.pl import continuous_scatter

In [None]:
from spida.utilities.sd_utils import _get_obs_or_gene

In [None]:
def _plot_gene_violinplots(
    adata : ad.AnnData, 
    gene: str,
    regions: list,
    groups, 
    subset_level: str = "Subclass",
    group_level: str = "Group",
    donors: list = None,
    labs: list = None,
    layer: str = "volume_norm",
    layer_norm: bool = True,
): 
    """
    Violin plots of one gene per cell group, split by brain region, one row per donor.

    Note: this notebook defines ``_plot_gene_violinplots`` a second time in a
    later cell (the mCH variant), which shadows this one once it has been run.

    Parameters
    ----------
    adata : AnnData
        Input data. ``adata.uns['brain_region_palette']`` is used as the hue palette.
    gene : str
        Gene (or ``obs`` column) to plot; resolved through ``_get_obs_or_gene``.
    regions : list
        Brain regions to keep; also used as the hue order.
    groups :
        Values of ``obs[subset_level]`` to keep.
    subset_level : str
        ``obs`` column used to select ``groups``.
    group_level : str
        Categorical ``obs`` column shown on the x axis; rows equal to
        ``"unknown"`` in this column are dropped.
    donors : list | str | None
        Donors to plot (one row each); defaults to all donors in ``adata.obs``.
    labs : list | str | None
        If given, only cells whose ``replicate`` is in ``labs`` are plotted.
        None keeps every replicate.
    layer : str
        Passed to ``_get_obs_or_gene``.
    layer_norm : bool
        If True, ``normalize_adata(adata, log1p=True)`` is called first.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    if layer_norm:
        # NOTE(review): this is called on the caller's object; if normalize_adata
        # works in place, repeated calls re-normalize the same data — confirm.
        normalize_adata(adata, log1p=True)

    adata, _drop_col = _get_obs_or_gene(adata, gene, layer)  # get the column from obs or var
    # Use the configured level columns instead of hard-coded 'Subclass'/'Group';
    # dict.fromkeys removes duplicates while preserving order.
    obs_cols = list(dict.fromkeys(
        ["donor", "replicate", "brain_region", subset_level, group_level, gene]
    ))
    df_obs = adata.obs[obs_cols].copy()
    if _drop_col:  # drop the column if it was added for plotting
        adata.obs.drop(columns=[gene], inplace=True)

    keep = (
        df_obs[subset_level].isin(groups)
        & df_obs["brain_region"].isin(regions)
        & (df_obs[group_level] != "unknown")
    )
    if labs is not None:
        # Apply the replicate filter that the signature advertises.
        keep &= df_obs["replicate"].isin(labs)
    # Copy so the column assignments below do not write into a slice.
    df_obs = df_obs.loc[keep].copy()
    df_obs[group_level] = df_obs[group_level].cat.remove_unused_categories()
    df_obs['brain_region'] = df_obs['brain_region'].cat.remove_unused_categories()

    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(donors)
    # squeeze=False keeps `axes` two-dimensional even for a single donor.
    fig, axes = plt.subplots(
        nrows, 1, figsize=(10, 2*nrows+1), dpi=300, sharex=True, sharey=False, squeeze=False
    )
    for i, _d in enumerate(donors):
        df_d = df_obs[df_obs['donor'] == _d]
        ax = axes[i, 0]
        sns.violinplot(data=df_d, x=group_level, y=gene, hue="brain_region", ax=ax, inner="quart", hue_order=regions, palette=adata.uns['brain_region_palette'])
        if i == nrows-1:
            # Only the bottom row carries the (rotated) group labels.
            ax.tick_params(axis='x', labelbottom=True)
            ax.set_xticks(range(len(df_d[group_level].cat.categories)))
            ax.set_xticklabels(ax.get_xticklabels(), rotation=15, ha='right', fontsize=12)
        ax.set_title(f"Donor: {_d}")
        ax.legend(title='Brain Region', fontsize=6, title_fontsize=7, loc='upper right', bbox_to_anchor=(1.15, 1))

    fig.suptitle(f"{gene} Expression in STR D1 and D2 MSNs across Brain Regions", y=1.02)
    plt.show()

## Methylation

In [None]:
def plot_gene_groups(
    adata: ad.AnnData,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    max_norm = 90,
    min_norm = 10,
    bg_adata: ad.AnnData = None,
    cmap="YlOrRd",
): 
    """
    Plot the spatial mCH score of one gene for selected cell groups.

    One figure is produced per donor, laid out as replicates (rows) x brain
    regions (columns). Each panel shows a light-grey background of cells (from
    ``bg_adata`` when given, otherwise from ``adata``) with the cells whose
    ``group_level`` value is in ``groups`` overlaid, colored by ``gene``.
    The color range is shared across a row and spans the ``min_norm``-th to
    ``max_norm``-th percentile of the gene over that donor/replicate.

    Note: this definition shadows the ``plot_gene_groups`` defined in an
    earlier cell of this notebook (the expression variant).

    Parameters
    ----------
    adata : AnnData
        Input data; a copy restricted to ``regions`` is used, the caller's
        object is not modified.
    gene : str
        Gene to plot; must be in ``adata.var_names``.
    groups : list[str]
        Values of ``adata.obs[group_level]`` to overlay in color.
    regions : list[str]
        Values of ``adata.obs['brain_region']`` to plot, one column each.
    group_level : str
        ``obs`` column used to select ``groups``.
    labs : list[str] | None
        Values of ``adata.obs['replicate']`` (rows); defaults to all present.
    donors : list[str] | None
        Values of ``adata.obs['donor']`` (one figure each); defaults to all
        present.
    layers :
        If not None, ``adata.layers[layers]`` replaces ``X`` on the working copy.
    layer_norm : bool
        If True, ``normalize_adata(adata, log1p=True)`` is applied to the
        working copy.
    max_norm, min_norm : float
        Upper / lower percentile (0-100) defining the color range.
    bg_adata : AnnData | None
        Optional object providing the grey background cells; must have the
        ``donor``, ``brain_region`` and ``replicate`` ``obs`` columns.
    cmap :
        Colormap passed to ``plot_continuous``.

    Returns
    -------
    None
        Figures are shown with ``plt.show()``. If ``gene`` is missing, a
        message is printed and nothing is plotted.
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return
    adata = adata[adata.obs['brain_region'].isin(regions)].copy()
    if layers is not None:
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)

    if labs is None:
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(labs)
    ncols = len(regions)

    for donor in donors:
        # squeeze=False keeps `axes` two-dimensional for every grid shape,
        # including 1x1 where plt.subplots would otherwise return a bare Axes.
        fig, axes = plt.subplots(
            nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False
        )
        for l, lab in enumerate(labs):
            in_row = (adata.obs['donor'] == donor) & (adata.obs['replicate'] == lab)
            adata_don = adata[in_row].copy()
            if adata_don.n_obs == 0:
                # No cells for this donor/replicate: percentiles are undefined,
                # so leave the row blank instead of raising.
                for ax in axes[l]:
                    ax.axis("off")
                continue
            # np.percentile cannot consume a sparse matrix; densify the single
            # gene column first. nanpercentile ignores missing scores.
            gene_vals = adata_don[:, gene].X
            if hasattr(gene_vals, "toarray"):
                gene_vals = gene_vals.toarray()
            gene_vals = np.asarray(gene_vals)
            hue_min = np.nanpercentile(gene_vals, min_norm)
            hue_max = np.nanpercentile(gene_vals, max_norm)
            for r, region in enumerate(regions):
                ax = axes[l, r]
                adata_sub = adata_don[adata_don.obs['brain_region'] == region].copy()
                if bg_adata is not None:
                    bg_sub = bg_adata[(bg_adata.obs['donor'] == donor) & (bg_adata.obs['brain_region'] == region) & (bg_adata.obs['replicate'] == lab)].copy()
                    categorical_scatter(bg_sub, coord_base="spatial", color='lightgrey', ax=ax)
                else:
                    categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap=cmap, ax=ax, show=False, hue_norm=(hue_min, hue_max))
                ax.set_title(f"{region}")
        fig.suptitle(f"Gene: {gene} Donor: {donor}, mCH score", y=1.02)
        plt.show()
        plt.close(fig)  # free the figure; one is created per donor

def _plot_gene_violinplots(
    adata : ad.AnnData, 
    gene: str,
    regions: list,
    groups, 
    subset_level: str = "Subclass",
    group_level: str = "Group",
    donors: list = None,
    labs: list = None,
    layer: str = "volume_norm",
    layer_norm: bool = True,
): 
    """
    Violin plots of one gene's mCH score per cell group, split by brain region, one row per donor.

    Note: this definition shadows the ``_plot_gene_violinplots`` defined in an
    earlier cell of this notebook (the expression variant).

    Parameters
    ----------
    adata : AnnData
        Input data. ``adata.uns['brain_region_palette']`` is used as the hue palette.
    gene : str
        Gene (or ``obs`` column) to plot; resolved through ``_get_obs_or_gene``.
    regions : list
        Brain regions to keep; also used as the hue order.
    groups :
        Values of ``obs[subset_level]`` to keep.
    subset_level : str
        ``obs`` column used to select ``groups``.
    group_level : str
        Categorical ``obs`` column shown on the x axis; rows equal to
        ``"unknown"`` in this column are dropped.
    donors : list | str | None
        Donors to plot (one row each); defaults to all donors in ``adata.obs``.
    labs : list | str | None
        If given, only cells whose ``replicate`` is in ``labs`` are plotted.
        None keeps every replicate.
    layer : str
        Passed to ``_get_obs_or_gene``.
    layer_norm : bool
        If True, ``normalize_adata(adata, log1p=True)`` is called first.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    if layer_norm:
        # NOTE(review): this is called on the caller's object; if normalize_adata
        # works in place, repeated calls re-normalize the same data — confirm.
        normalize_adata(adata, log1p=True)

    adata, _drop_col = _get_obs_or_gene(adata, gene, layer)  # get the column from obs or var
    # Use the configured level columns instead of hard-coded 'Subclass'/'Group';
    # dict.fromkeys removes duplicates while preserving order.
    obs_cols = list(dict.fromkeys(
        ["donor", "replicate", "brain_region", subset_level, group_level, gene]
    ))
    df_obs = adata.obs[obs_cols].copy()
    if _drop_col:  # drop the column if it was added for plotting
        adata.obs.drop(columns=[gene], inplace=True)

    keep = (
        df_obs[subset_level].isin(groups)
        & df_obs["brain_region"].isin(regions)
        & (df_obs[group_level] != "unknown")
    )
    if labs is not None:
        # Apply the replicate filter that the signature advertises.
        keep &= df_obs["replicate"].isin(labs)
    # Copy so the column assignments below do not write into a slice.
    df_obs = df_obs.loc[keep].copy()
    df_obs[group_level] = df_obs[group_level].cat.remove_unused_categories()
    df_obs['brain_region'] = df_obs['brain_region'].cat.remove_unused_categories()

    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows = len(donors)
    # squeeze=False keeps `axes` two-dimensional even for a single donor.
    fig, axes = plt.subplots(
        nrows, 1, figsize=(10, 2*nrows+1), dpi=300, sharex=True, sharey=False, squeeze=False
    )
    for i, _d in enumerate(donors):
        df_d = df_obs[df_obs['donor'] == _d]
        ax = axes[i, 0]
        sns.violinplot(data=df_d, x=group_level, y=gene, hue="brain_region", ax=ax, inner="quart", hue_order=regions, palette=adata.uns['brain_region_palette'])
        if i == nrows-1:
            # Only the bottom row carries the (rotated) group labels.
            ax.tick_params(axis='x', labelbottom=True)
            ax.set_xticks(range(len(df_d[group_level].cat.categories)))
            ax.set_xticklabels(ax.get_xticklabels(), rotation=15, ha='right', fontsize=12)
        ax.set_title(f"Donor: {_d}")
        ax.legend(title='Brain Region', fontsize=6, title_fontsize=7, loc='upper right', bbox_to_anchor=(1.15, 1))

    fig.suptitle(f"{gene} mCH Score in STR D1 and D2 MSNs across Brain Regions", y=1.02)
    plt.show()

# Trying out PyComplexHeatmap
Not working yet.


In [None]:
# PyComplexHeatmap is only used by the experimental cells below.
import PyComplexHeatmap as pch
# Display the installed PyComplexHeatmap version as the cell output.
pch.__version__

In [None]:
# Build the tables used by the PyComplexHeatmap cells below:
#   df_melt       - long table, one row per cell and gene, with a joined
#                   "donor.Group.brain_region.replicate" label in column 'index'
#   df_annot      - one row per label with its Donor / Group / Brain_Region /
#                   Replicate values and the matching palette colors
#   melt_by_donor - per donor, a wide table (label x running cell number) of expression
#
# NOTE(review): `df_obs` is not defined at notebook level in any visible cell
# (it only exists as a local variable inside _plot_gene_violinplots); it must be
# created before this cell for a fresh "Restart & Run All" to work.
df_obs = df_obs.sort_values(by=['donor', 'Group', 'brain_region', 'replicate',])

_id_cols = ["donor", "Group", "brain_region", "replicate"]
df_melt = (
    df_obs
    # 'Subclass' is not an id column and is not numeric, so it must not be
    # melted as if it were a gene; errors='ignore' keeps the cell re-runnable.
    .drop(columns=['Subclass'], errors='ignore')
    .melt(id_vars=_id_cols, var_name="gene", value_name="expression")
    .astype({'donor': 'string', 'replicate': 'string', 'brain_region': 'string', 'Group': 'string', 'gene': 'string', 'expression': 'float64'})
)
df_melt['index'] = df_melt["donor"] + "." + df_melt["Group"] + "." + df_melt["brain_region"] + "." + df_melt["replicate"]

# Take the annotation values from the id columns themselves rather than by
# splitting the joined label on '.', which breaks when a name contains a dot.
_annot_source = df_melt.drop_duplicates(subset='index').set_index('index')
_annot_spec = {
    'Donor': ('donor', 'donor_palette'),
    'Group': ('Group', 'MSN_Groups_palette'),
    'Brain_Region': ('brain_region', 'brain_region_palette'),
    'Replicate': ('replicate', 'replicate_palette'),
}
df_annot = pd.DataFrame(index=_annot_source.index)
for _label, (_source_col, _palette_key) in _annot_spec.items():
    df_annot[_label] = _annot_source[_source_col].astype(object)
    df_annot[f"{_label}_Color"] = df_annot[_label].map(adata.uns[_palette_key])

melt_by_donor = {}
for _donor in df_melt['donor'].unique().tolist():
    # Select only the needed columns into a new frame; the original dropped
    # columns and added 'cell' in place on a slice of df_melt.
    df_melt_d = df_melt.loc[df_melt['donor'] == _donor, ['expression', 'index']]
    # Running cell number within each label becomes the column key of the pivot.
    df_melt_d = df_melt_d.assign(cell=1 + df_melt_d.groupby("index").cumcount())
    df_melt_d = df_melt_d.pivot(index="index", columns='cell', values="expression")
    melt_by_donor[_donor] = df_melt_d

In [None]:
# Store Group as a categorical column (overwrites df_annot['Group'] in place).
df_annot['Group'] = df_annot['Group'].astype('category').cat.remove_unused_categories()
# Pick the wide expression table of donor UCI2424 and the annotation rows
# that share its index labels.
df_melt_2424 = melt_by_donor['UCI2424']
df_annot_2424 = df_annot.loc[df_melt_2424.index]

In [None]:
# Restrict the brain-region and MSN-group palettes to the values that occur
# in the UCI2424 annotation table.
_regions_2424 = set(df_annot_2424['Brain_Region'].unique().tolist())
_groups_2424 = set(df_annot_2424['Group'].unique().tolist())

br_palette = {
    region: color
    for region, color in adata.uns['brain_region_palette'].items()
    if region in _regions_2424
}
msn_palette = {
    group: color
    for group, color in adata.uns['MSN_Groups_palette'].items()
    if group in _groups_2424
}

In [None]:
# Draw an annotation stack for donor UCI2424: merged Group labels, Group and
# Brain_Region color bars, and a boxplot track built from df_melt_2424.
# NOTE(review): `ax` is created but not passed to HeatmapAnnotation; presumably
# plot=True draws on the current axes — confirm against the PyComplexHeatmap docs.
fig, ax = plt.subplots(figsize=(8,4), dpi=300)
pch.HeatmapAnnotation(
    label=pch.anno_label(df_annot_2424.Group, colors=msn_palette, merge=True, rotation=15),
    Group = pch.anno_simple(df_annot_2424.Group, colors=msn_palette, legend=True, height=2),
    Brain_Region = pch.anno_simple(df_annot_2424.Brain_Region, colors=br_palette, legend=True, height=2),
    UCI2424=pch.anno_boxplot(df_melt_2424, cmap="turbo", legend=False, grid=True),
    plot=True, legend=True, legend_gap=2, hgap=2, vgap=1, legend_hpad=0.2, wgap=1
)
plt.show()

In [None]:
# Attempt at a combined annotation stack with one boxplot track per donor.
# NOTE(review): this call passes `color=` (a per-row color column) to anno_simple,
# whereas the UCI2424 cell uses `colors=` with a palette dict — confirm which
# keyword PyComplexHeatmap expects.
# NOTE(review): the Group / Brain_Region tracks use the all-donor `df_annot`,
# while each boxplot track uses a single donor's table from `melt_by_donor`;
# their row indexes may not match — confirm.
# NOTE(review): donor IDs are hard-coded; a KeyError is raised if one is absent
# from `melt_by_donor`.
plt.figure(figsize=(10,4), dpi=300)
pch.HeatmapAnnotation(
    # label=pch.anno_label(df_annot.Group, merge=True, rotation=15),
    Group = pch.anno_simple(df_annot.Group, color=df_annot["Group_Color"], legend=True),
    Brain_Region = pch.anno_simple(df_annot.Brain_Region, color=df_annot['Brain_Region_Color'], legend=True),
    UCI2424=pch.anno_boxplot(melt_by_donor['UCI2424'], cmap="turbo", legend=True),
    UCI4723=pch.anno_boxplot(melt_by_donor['UCI4723'], cmap="turbo", legend=True),
    UCI5224=pch.anno_boxplot(melt_by_donor['UCI5224'], cmap="turbo", legend=True),
    UWA7648=pch.anno_boxplot(melt_by_donor['UWA7648'], cmap="turbo", legend=True),
    plot=True, legend=True, legend_gap=5
)
plt.show()