In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter
# Global matplotlib defaults for every figure in this notebook.
# fonttype 42 is set for both the PDF and PS backends.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.facecolor': 'white',
})


### functions

In [None]:
from spida.pl._utils import add_color_scheme
def plot_categorical(
    adata : ad.AnnData | str, 
    ax=None,
    coord_base='tsne',
    cluster_col='MajorType',
    palette_path=None,
    highlight_group : str | list | None = None,
    color_unknown : str = "#D0D0D0",
    axis_format='tiny',
    alpha=0.7,
    coding=True,
    id_marker=True,
    output=None,
    show=True,
    figsize=(4, 3.5),
    sheet_name=None,
    ncol=None,
    fontsize=5,
    legend_fontsize=5,
    legend_kws=None,
    legend_title_fontsize=5,
    marker_fontsize=4,
    marker_pad=0.1,
    linewidth=0.5,
    ret=False,
    text_anno:bool = False,
    text_kws=None,
    **kwargs
):
    """
    Scatter-plot cells on an embedding, colored by a categorical ``obs`` column.

    NOTE(review): this notebook-level definition shadows the
    ``plot_categorical`` imported from ``spida.pl`` in the first cell.

    Side effects on ``adata``: ``adata.obs[cluster_col]`` is cast to
    ``category`` if it is not already, and when ``palette_path`` is given
    ``adata.uns[f"{cluster_col}_colors"]`` is overwritten. Dicts passed as
    ``palette_path``, ``legend_kws`` or ``text_kws`` are copied, never modified.

    Parameters
    ----------
    adata : AnnData | str
        AnnData object, or a path to an ``.h5ad`` file (opened with ``backed='r'``).
    ax :
        Axes to draw on. If None, a new figure is created with ``figsize``.
    coord_base : str
        Forwarded to ``categorical_scatter`` as ``coord_base``.
    cluster_col : str
        Column in ``adata.obs`` used as hue and as the legend title.
    palette_path : str | mapping | None
        Path to an Excel file (sheet ``sheet_name``, first column as index,
        colors in a ``Hex`` column) or a mapping of category -> color.
        If None, colors come from ``adata.uns[f"{cluster_col}_colors"]`` when
        present, otherwise from ``add_color_scheme``.
    highlight_group : str | list | None
        A group (or list of groups) to highlight; every other category is
        drawn in ``color_unknown``. Names that match no category are allowed
        and result in all categories being grayed out.
    color_unknown : str
        Color for non-highlighted categories. Default is light gray (#D0D0D0).
    axis_format, alpha, marker_pad :
        Defaults for the same-named ``categorical_scatter`` keyword arguments.
    coding, id_marker : bool
        Always forwarded to ``categorical_scatter``. If both are True the
        legend marker is a circle (or rectangle) containing a code.
    output : str | None
        If given, the current figure is saved to this path (``~`` expanded).
    show : bool
        Whether to call ``plt.show()``.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    sheet_name : str | None
        Excel sheet to read the palette from; defaults to ``cluster_col``.
    ncol : int | None
        Number of legend columns (left to the legend default when None).
    fontsize : int
        Default font size for text annotations (``text_kws["fontsize"]``).
    legend_fontsize : int
        Legend font size, default 5.
    legend_kws : dict | None
        Keyword arguments passed to the legend; explicit entries take precedence
        over ``legend_fontsize``, ``legend_title_fontsize`` and ``ncol``.
    legend_title_fontsize : int
        Legend title font size, default 5.
    marker_fontsize : int
        Font size of the code inside the legend marker, default 4.
    linewidth : float
        Line width of the legend marker (circle or rectangle), default 0.5.
    ret : bool
        Whether to return the plot and colors, default False.
    text_anno : bool
        Whether to add a text annotation for each cluster, default False.
    text_kws : dict | None
        Keyword arguments for the text annotations.
    kwargs : dict
        Extra keyword arguments for ``categorical_scatter``; they override the
        defaults set here (e.g. ``show_legend=False`` removes the legend).

    Returns
    -------
    tuple
        ``(p, colors)`` if ``ret`` is True, where ``p`` is the value returned by
        ``categorical_scatter`` and ``colors`` the palette used; otherwise
        ``(None, None)``.
    """
    if sheet_name is None:  # the Excel sheet is named after the column by default
        sheet_name = cluster_col
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata, backed='r')
    if not isinstance(adata.obs[cluster_col].dtype, pd.CategoricalDtype):
        # the palette logic below relies on `.cat.categories`
        adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')

    categories = adata.obs[cluster_col].cat.categories.tolist()
    colors_key = f'{cluster_col}_colors'

    # ---- resolve the palette ----
    if palette_path is not None:
        if isinstance(palette_path, str):
            colors = pd.read_excel(
                os.path.expanduser(palette_path), sheet_name=sheet_name, index_col=0
            ).Hex.to_dict()
            keys = list(colors.keys())
            existed_vals = adata.obs[cluster_col].unique().tolist()
            # values present in the data but absent from the sheet fall back to gray
            for k in existed_vals:
                if k not in keys:
                    colors[k] = 'gray'
            # drop sheet entries that do not occur in the data
            for k in keys:
                if k not in existed_vals:
                    del colors[k]
        else:
            # copy so the highlight logic below never edits the caller's mapping
            colors = dict(palette_path)
        adata.uns[colors_key] = [colors.get(k, 'grey') for k in categories]
    elif colors_key not in adata.uns:
        # presumably returns a category -> color mapping; verify against spida
        colors = add_color_scheme(adata, cluster_col, palette_key=colors_key)
    else:
        colors = dict(zip(categories, adata.uns[colors_key]))

    # ---- gray out everything that is not highlighted ----
    # Highlighting is done purely through the palette; the plotted data is unchanged.
    if highlight_group is not None:
        if isinstance(highlight_group, str):
            highlight_group = [highlight_group]
        for group in categories:
            if group not in highlight_group:
                colors[group] = color_unknown

    # ---- defaults for categorical_scatter (caller-supplied kwargs win) ----
    text_kws = {} if text_kws is None else dict(text_kws)  # copy: do not mutate caller's dict
    text_kws.setdefault("fontsize", fontsize)
    kwargs.setdefault("hue", cluster_col)
    kwargs.setdefault("text_anno", cluster_col if text_anno else None)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs.setdefault("dodge_kws", {
            "arrowprops": {
                "arrowstyle": "->",
                "fc": 'grey',
                "ec": "none",
                "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
            },
            'autoalign': 'xy'})
    # these two are explicit parameters, so they always override kwargs
    kwargs["coding"] = coding
    kwargs["id_marker"] = id_marker

    legend_kws = {} if legend_kws is None else dict(legend_kws)  # copy: do not mutate caller's dict
    legend_kws.setdefault("fontsize", legend_fontsize)
    legend_kws.setdefault("title", cluster_col)
    legend_kws.setdefault("title_fontsize", legend_title_fontsize)
    if ncol is not None:
        legend_kws.setdefault("ncol", ncol)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, dpi=300)
    # cells without a label in `cluster_col` are not plotted
    p = categorical_scatter(
        data=adata[adata.obs[cluster_col].notna(),],
        ax=ax,
        coord_base=coord_base,
        palette=colors,
        legend_kws=legend_kws,
        **kwargs)

    if output is not None:
        plt.savefig(os.path.expanduser(output), bbox_inches='tight', dpi=300)
    if show:
        plt.show()

    if ret:
        return p, colors
    return None, None

## Paths

In [None]:
# Both per-modality annotation folders live under the same parent directory.
_annotation_root = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation")
rna_group_dir = _annotation_root / "BICAN_BG_ALL_RNA"
mc_group_dir = _annotation_root / "BICAN_BG_ALL_MC"

In [None]:
# Cell Type + Markers:
# Both MSN subclasses use the same marker panel; each entry gets its own list copy.
_msn_markers = (
    "STXBP6", "EPHA4", "GDA", "SEMA3E", "KCNIP1", "BACH2",
    "KCNT1", "KHDRBS3", "ARHGAP6", "GREB1L", "GRIA4", "GRIK1",
)
ctm = {_subclass: list(_msn_markers) for _subclass in ("STR D1 MSN", "STR D2 MSN")}

# obs columns plotted as metadata panels for every cell type
meta_plots = [
    'donor',
    'replicate',
    'brain_region',
    'nCount_RNA',
    'nFeature_RNA',
]

# Markers --> Cell Types
STXBP6 = Matrix

KCNIP1 = Striosome

ARHGAP6 = Ventral

In [None]:
# AnnData file read in full (no `backed` mode); presumably the neuron subset — confirm.
neu_path = (
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/"
    "BG_pfv8_neu.h5ad"
)
adata_all = ad.read_h5ad(neu_path)

In [None]:
def _panel_grid(n_panels, ncols=6, panel_size=(3, 4)):
    """Create a figure holding `n_panels` axes laid out `ncols` per row.

    Returns ``(fig, axes)`` where ``axes`` is a flat array of exactly
    `n_panels` axes; surplus grid cells are removed from the figure.
    """
    nrows = max(1, int(np.ceil(n_panels / ncols)))
    fig, grid = plt.subplots(
        nrows, ncols,
        figsize=(ncols * panel_size[0], nrows * panel_size[1]),
        squeeze=False,  # always 2-D, so flattening is safe for any grid shape
    )
    flat = grid.ravel()
    for _unused_ax in flat[n_panels:]:
        fig.delaxes(_unused_ax)
    return fig, flat[:n_panels]


# For every cell type: one figure of metadata panels, one figure of marker-gene panels,
# both drawn on the "subclass_umap" coordinates.
# NOTE: `_ct`, `_markers` and `rna_adata` keep their last-iteration values after the
# loop, and later cells read them.
for _ct, _markers in ctm.items():
    print(f"Plotting {_ct}...")

    # RNA
    # Single quotes inside the f-string: reusing the outer quote type is a
    # SyntaxError on Python < 3.12.
    rna_path = rna_group_dir / f"ALL_RNA_Subclass_{_ct.replace(' ', '_')}.h5ad"
    rna_adata = ad.read_h5ad(rna_path)

    # --- metadata panels ---
    fig, axes = _panel_grid(len(meta_plots))
    for _panel, col in zip(axes, meta_plots):
        _col_dtype = rna_adata.obs[col].dtype
        if _col_dtype.name == 'category' or _col_dtype == object:
            plot_categorical(rna_adata, cluster_col=col, coord_base="subclass_umap", ax=_panel, show=False, coding=True, text_anno=True)
        else:
            plot_continuous(rna_adata, coord_base="subclass_umap", color_by=col, ax=_panel, show=False, cmap="cividis")
    fig.suptitle(_ct)
    plt.show()
    plt.close(fig)  # release the figure; the loop creates two per cell type

    # --- marker-gene panels ---
    fig, axes = _panel_grid(len(_markers))
    for _panel, _gene in zip(axes, _markers):
        plot_continuous(rna_adata, coord_base="subclass_umap", color_by=_gene, ax=_panel, show=False, cmap="cividis")
    fig.suptitle(_ct)
    plt.show()
    plt.close(fig)

In [None]:
from spida.P.setup_adata import _calc_embeddings

In [None]:
# Restrict to a single dataset and copy so later writes do not touch `rna_adata`.
# NOTE: `rna_adata` is whichever cell type the plotting loop loaded last.
_in_dataset = rna_adata.obs['dataset_id'] == 'CAB_UCI5224_salk'
adata_sub = rna_adata[_in_dataset].copy()
adata_sub

In [None]:
# Recompute embeddings for the single-dataset subset.
# Later cells read "sub_umap" and "sub_leiden", presumably written by this call
# under the "sub_" prefix — verify against spida.
_embed_params = dict(min_dist=0.25, knn=30, key_added="sub_")
_calc_embeddings(adata_sub, **_embed_params)

In [None]:
# Overview of the subset: clusters on UMAP and in space, then two QC metrics on UMAP.
fig, ax = plt.subplots(1, 4, figsize=(10, 3))

# panels 0-1: "sub_leiden" clusters on two coordinate systems
for _panel, _basis in zip(ax[:2], ("sub_umap", "spatial")):
    plot_categorical(adata_sub, cluster_col="sub_leiden", coord_base=_basis, show=False, coding=True, text_anno=True, id_marker=True, ax=_panel)

# panels 2-3: per-cell count columns on the UMAP
for _panel, _qc_col in zip(ax[2:], ("nCount_RNA", "nFeature_RNA")):
    plot_continuous(adata_sub, coord_base="sub_umap", color_by=_qc_col, ax=_panel, show=False, cmap="cividis")

plt.show()

In [None]:
# Marker-gene panels on the subset's own UMAP.
# NOTE: `_markers` and `_ct` are the last values left behind by the plotting loop.
ncols = 6
nrows = -(-len(_markers) // ncols)  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 4 * nrows))
axes = axes.flatten()

for _panel, _gene in zip(axes, _markers):
    plot_continuous(adata_sub, coord_base="sub_umap", color_by=_gene, ax=_panel, show=False, cmap="cividis")

# remove the grid cells that received no gene
for _unused_ax in axes[len(_markers):]:
    fig.delaxes(_unused_ax)

plt.suptitle(_ct)
plt.show()

In [None]:
# First three markers of the last-loaded cell type on the "all_round1_umap" coordinates.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
for _idx in range(3):
    plot_continuous(rna_adata, coord_base="all_round1_umap", color_by=_markers[_idx], ax=ax[_idx], show=False)
plt.show()

In [None]:
# Trailing semicolon keeps the cell from echoing the function's `(None, None)` return value.
plot_categorical(rna_adata, coord_base="all_round1_umap", cluster_col="allcools_Group_filt");