In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter

In [None]:
from datetime import datetime

# Run timestamp at minute resolution (not referenced elsewhere in the visible notebook).
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M")

# Directory every figure in this notebook is written to; created up front.
# NOTE(review): absolute, user-specific path — consider deriving it from a configurable root.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/annot/cps/"
Path(image_path).mkdir(parents=True, exist_ok=True)

# Figure defaults: 300 dpi on screen and on save, TrueType (type 42) fonts so PDF/PS text
# stays editable, Arial sans-serif text, white axes background.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.facecolor': 'white',
})

### Plotting helper (local override of `spida.pl.plot_categorical`)

In [None]:
from spida.pl._utils import add_color_scheme
def plot_categorical(
    adata : ad.AnnData | str, 
    ax=None,
    coord_base='tsne',
    cluster_col='MajorType',
    palette_path=None,
    highlight_group : str | list | None = None,
    color_unknown : str = "#D0D0D0",
    is_background : bool = False,
    axis_format='tiny',
    alpha=0.7,
    coding=True,
    id_marker=True,
    output=None,
    show=True,
    figsize=(4, 3.5),
    sheet_name=None,
    ncol=None,
    fontsize=5,
    legend_fontsize=5,
    legend_kws=None,
    legend_title_fontsize=5,
    marker_fontsize=4,
    marker_pad=0.1,
    linewidth=0.5,
    ret=False,
    text_anno:bool = False,
    text_kws=None,
    **kwargs
):
    """
    Scatter-plot a categorical ``.obs`` column with ``categorical_scatter``.

    This redefinition shadows the ``plot_categorical`` imported from ``spida.pl``.

    Parameters
    ----------
    adata : AnnData or str
        Data to plot, or a path to an ``.h5ad`` file (opened with ``backed='r'``).
        Side effects on this object: ``adata.obs[cluster_col]`` is cast to ``category``
        if it is not already, and ``adata.uns[f"{cluster_col}_colors"]`` is written when
        ``palette_path`` is given (or may be created via ``add_color_scheme`` when absent).
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new ``figsize`` figure at 300 dpi is created.
    coord_base : str
        Coordinate basis forwarded to ``categorical_scatter`` (e.g. "base_round1_umap", "spatial").
    cluster_col : str
        ``.obs`` column used as hue.
    palette_path : str or dict, optional
        Either a path to an Excel file (sheet ``sheet_name``, index = category, column ``Hex``)
        or a ``{category: color}`` mapping. A mapping is copied and never modified.
        Categories missing from the palette are drawn gray. When None, colors come from
        ``adata.uns[f"{cluster_col}_colors"]``.
    highlight_group : str or list, optional
        Group(s) to keep. Cells of every other group are dropped from this call. To show
        them in gray, draw them first in a separate call with ``is_background=True``.
    color_unknown : str
        Background color used when ``is_background`` is True. Default light gray (#D0D0D0).
    is_background : bool
        Draw every category (plus an ``'unknown'`` entry) in ``color_unknown``. Useful as an underlay.
    axis_format, alpha, coding, id_marker, marker_pad :
        Forwarded to ``categorical_scatter``.
    output : str, optional
        If given, the current figure is saved there (tight bbox, 300 dpi).
    show : bool
        Call ``plt.show()`` at the end.
    figsize : tuple
        Size of the figure created when ``ax`` is None.
    sheet_name : str, optional
        Excel sheet read from ``palette_path``. Defaults to ``cluster_col``.
    ncol : int, optional
        Number of legend columns.
    fontsize : int
        Default font size of the per-cluster text labels.
    legend_fontsize : int
        Legend font size, default 5.
    legend_kws : dict, optional
        Extra keyword arguments for the legend. Copied, never modified.
    legend_title_fontsize : int
        Legend title font size, default 5.
    marker_fontsize : int
        Marker font size, default 4. If ``id_marker`` and ``coding`` are True, the legend
        marker is a circle (or rectangle) containing the code.
    linewidth : float
        Line width of the legend marker (circle or rectangle), default 0.5.
    ret : bool
        Whether to return the plot and colors, default False.
    text_anno : bool
        Whether to add a text annotation for each cluster, default False.
    text_kws : dict, optional
        Keyword arguments for the text annotations. Copied, never modified.
    **kwargs
        Forwarded to ``categorical_scatter``. For example, ``show_legend=False`` removes the
        legend; ``s`` and ``max_points`` are also passed through.

    Returns
    -------
    tuple
        ``(p, colors)`` when ``ret`` is True, where ``p`` is the ``categorical_scatter`` result
        and ``colors`` the ``{category: color}`` mapping actually used. Otherwise ``(None, None)``.
    """
    # The Excel sheet holding the palette defaults to the column name.
    if sheet_name is None:
        sheet_name = cluster_col
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata, backed='r')
    # Everything below relies on the categorical accessor (.cat.categories, remove_unused_categories).
    if not isinstance(adata.obs[cluster_col].dtype, pd.CategoricalDtype):
        adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')
    categories = adata.obs[cluster_col].cat.categories.tolist()
    color_key = f'{cluster_col}_colors'

    # Resolve a {category: color} palette. It is always a fresh dict owned by this call,
    # so nothing below can write into a palette the caller (or .uns) still holds.
    if palette_path is not None:
        if isinstance(palette_path, str):
            sheet_colors = pd.read_excel(
                os.path.expanduser(palette_path), sheet_name=sheet_name, index_col=0
            ).Hex.to_dict()
            present = adata.obs[cluster_col].unique().tolist()
            # Keep palette entries only for values present in the data; unseen values get gray.
            colors = {k: v for k, v in sheet_colors.items() if k in present}
            for k in present:
                colors.setdefault(k, 'gray')
        else:
            colors = dict(palette_path)
        adata.uns[color_key] = [colors.get(k, 'grey') for k in categories]
    elif color_key not in adata.uns:
        colors = dict(add_color_scheme(adata, cluster_col, palette_key=color_key))
    else:
        colors = dict(zip(categories, adata.uns[color_key]))

    # Unlabelled cells are never drawn.
    data = adata[adata.obs[cluster_col].notna()]
    if highlight_group is not None:
        if isinstance(highlight_group, str):
            highlight_group = [highlight_group]
        # One vectorized filter: cells outside the highlighted groups are dropped from this call.
        data = data[data.obs[cluster_col].isin(highlight_group)]
        data.obs[cluster_col] = data.obs[cluster_col].cat.remove_unused_categories()
    plotted = set(data.obs[cluster_col].cat.categories)
    colors = {k: v for k, v in colors.items() if k in plotted}
    if is_background:
        # Underlay mode: every category, plus 'unknown', in the background color.
        colors = {k: color_unknown for k in colors}
        colors['unknown'] = color_unknown

    # Copy caller-supplied dicts so the defaults below are not written back into them.
    text_kws = {} if text_kws is None else dict(text_kws)
    text_kws.setdefault("fontsize", fontsize)
    legend_kws = {} if legend_kws is None else dict(legend_kws)

    kwargs.setdefault("hue", cluster_col)
    kwargs.setdefault("text_anno", cluster_col if text_anno else None)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs["coding"] = coding
    kwargs["id_marker"] = id_marker

    default_lgd_kws = dict(
        fontsize=legend_fontsize,
        title=cluster_col,
        title_fontsize=legend_title_fontsize,
    )
    if ncol is not None:
        default_lgd_kws['ncol'] = ncol
    for k, v in default_lgd_kws.items():
        legend_kws.setdefault(k, v)
    kwargs.setdefault("dodge_kws", {
            "arrowprops": {
                "arrowstyle": "->",
                "fc": 'grey',
                "ec": "none",
                "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
            },
            'autoalign': 'xy'})

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, dpi=300)
    p = categorical_scatter(
        data=data,
        ax=ax,
        coord_base=coord_base,
        palette=colors,
        legend_kws=legend_kws,
        **kwargs)

    if output is not None:
        plt.savefig(os.path.expanduser(output), bbox_inches='tight', dpi=300)
    if show:
        plt.show()

    if ret:
        return p, colors
    return None, None

## Read

In [None]:
# Load the AnnData object used throughout this notebook.
# Earlier inputs in the same directory: BICAN_BG_CPSAM_annotated_v2.h5ad, BICAN_BG_PFV8_annotated_v5.h5ad.
adata = ad.read_h5ad(Path('/home/x-aklein2/projects/aklein/BICAN/BG/data') / 'BICAN_BG_CPS.h5ad')
adata

In [None]:
# Distinct donors, replicates ("labs") and brain regions, each in order of first appearance.
donors, labs, brain_regions = (
    adata.obs[_col].unique().tolist() for _col in ('donor', 'replicate', 'brain_region')
)

In [None]:
# For handling the anndata version differences: list every None value stored anywhere in .uns.
# NOTE(review): presumably such values block writing the object back to .h5ad on some anndata
# versions (see the commented write below) — confirm.
def _none_paths(mapping, prefix=()):
    """Yield the key path (a tuple of keys) of every None value in a nested dict."""
    for key, val in mapping.items():
        path = prefix + (key,)
        if val is None:
            yield path
        elif isinstance(val, dict):
            yield from _none_paths(val, path)


for _path in _none_paths(adata.uns):
    print(" / ".join(map(str, _path)), None)
# del adata.uns['joint_pca']
# adata.write_h5ad('/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CP_annotated_v1.h5ad')

In [None]:
# Fixed palette for neuron_type: one color per category, in category order.
# NOTE(review): assumes neuron_type has exactly two categories — confirm.
adata.uns['neuron_type_colors'] = ['#0762f5', '#f7781e']

## NAC

In [None]:
# NAC cells only; Subclass/Group categories with no NAC cells are dropped.
adata_nac = adata[adata.obs['brain_region'] == 'NAC'].copy()
for _col in ('Subclass', 'Group'):
    adata_nac.obs[_col] = adata_nac.obs[_col].cat.remove_unused_categories()

In [None]:
# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_nac, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / "NAC_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Round-1 UMAP colored by annotation level. neuron_type and Subclass get coded legends and
# text labels; Group (many categories) gets neither.
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_nac, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / "NAC_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Subclass in spatial coordinates, salk replicate only.
fig, axes = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(
    adata_nac[adata_nac.obs['replicate'] == 'salk'],
    cluster_col='Subclass',
    coord_base="spatial",
    ax=axes,
    show=False,
)
plt.savefig(Path(image_path) / "NAC_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Group in spatial coordinates for the salk replicate, leaving out cells labelled "unknown".
_keep = (adata_nac.obs['replicate'] == 'salk') & (adata_nac.obs['Group'] != "unknown")
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_nac[_keep], cluster_col="Group", coord_base='spatial',
                 color_unknown="#D0D0D0", ax=ax, show=False)
plt.savefig(Path(image_path) / "NAC_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# MSN Groups in spatial coordinates: one panel per replicate (row) x donor (column).
# In each panel, all cells of that donor/replicate are first drawn with color='lightgrey',
# then the MSN Groups on top. Panels with no cells, or no MSN cells, are hidden.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN', 
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or a single replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2 + 1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_nac[(adata_nac.obs['donor'] == _donor) & (adata_nac.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_nac.uns['MSN_Groups_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle("NAC", fontsize=16)
plt.tight_layout()
plt.savefig(Path(image_path) / "NAC_spatial_MSNs.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# plot_col = "Group"
# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# adata_s = adata_nac[(adata_nac.obs['replicate']=='salk') & (adata_nac.obs[plot_col] != "unknown")].copy()
# adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories().copy()

# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=None,
#                  is_background=True,
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  show_legend=False,
#                  ax=ax, show=False)
# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  alpha=1,
#                  s=4,
#                  max_points=None,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  legend_title_fontsize=12,
#                  ax=ax, show=False)
# plt.show()

## PU

In [None]:
# PU cells only; Subclass/Group categories with no PU cells are dropped.
adata_pu = adata[adata.obs['brain_region'] == 'PU'].copy()
for _col in ('Subclass', 'Group'):
    adata_pu.obs[_col] = adata_pu.obs[_col].cat.remove_unused_categories()

In [None]:
# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_pu, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / "PU_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Round-1 UMAP colored by annotation level. neuron_type and Subclass get coded legends and
# text labels; Group gets neither.
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_pu, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / "PU_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Subclass in spatial coordinates, salk replicate only.
fig, axes = plt.subplots(figsize=(3, 3), dpi=300)
plot_categorical(
    adata_pu[adata_pu.obs['replicate'] == 'salk'],
    cluster_col='Subclass',
    coord_base="spatial",
    ax=axes,
    show=False,
)
plt.savefig(Path(image_path) / "PU_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# MSN Groups in spatial coordinates: one panel per replicate (row) x donor (column).
# In each panel, all cells of that donor/replicate are first drawn with color='lightgrey',
# then the MSN Groups on top. Panels with no cells, or no MSN cells, are hidden.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN', 
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or a single replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2+1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_pu[(adata_pu.obs['donor'] == _donor) & (adata_pu.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_pu.uns['MSN_Groups_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle("PU", fontsize=16)
plt.tight_layout()
plt.savefig(Path(image_path) / "PU_spatial_MSNs.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# plot_cats = adata_pu[adata_pu.obs['Subclass'] == "STR D1 MSN"].obs['Group'].value_counts().index
# plot_cats
# plot_col = "Group"
# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# adata_s = adata_pu[(adata_pu.obs['replicate']=='salk') & (adata_pu.obs[plot_col] != "unknown")].copy()
# adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories().copy()

# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=None,
#                  is_background=True,
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  show_legend=False,
#                  ax=ax, show=False)
# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  alpha=1,
#                  s=4,
#                  max_points=None,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  legend_title_fontsize=12,
#                  ax=ax, show=False)
# plt.show()

## GP

In [None]:
# GP cells only; Subclass/Group categories with no GP cells are dropped.
adata_gp = adata[adata.obs['brain_region'] == 'GP'].copy()
for _col in ('Subclass', 'Group'):
    adata_gp.obs[_col] = adata_gp.obs[_col].cat.remove_unused_categories()

In [None]:
# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_gp, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / "GP_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Round-1 UMAP colored by annotation level. neuron_type and Subclass get coded legends and
# text labels; Group gets neither.
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_gp, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / "GP_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# Subclass in spatial coordinates, salk replicate only.
fig, axes = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(
    adata_gp[adata_gp.obs['replicate'] == 'salk'],
    cluster_col='Subclass',
    coord_base="spatial",
    ax=axes,
    show=False,
)
plt.savefig(Path(image_path) / "GP_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Group in spatial coordinates for the salk replicate, leaving out cells labelled "unknown".
_keep = (adata_gp.obs['replicate'] == 'salk') & (adata_gp.obs['Group'] != "unknown")
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_gp[_keep], cluster_col="Group", coord_base='spatial',
                 color_unknown="#D0D0D0", ax=ax, show=False)
plt.savefig(Path(image_path) / "GP_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Group labels containing "GP" that occur in the GP subset. NaN labels are dropped first,
# because `"GP" in nan` would raise TypeError.
plot_cats = [a for a in adata_gp.obs['Group'].dropna().unique().tolist() if "GP" in a]

In [None]:
# plot_cats = adata_gp[adata_gp.obs['Subclass'] == "CN LHX8 GABA"].obs['Group'].value_counts().index
# plot_cats

In [None]:
# GP groups (plot_cats) highlighted in color over a gray underlay of every labelled
# (non-"unknown") GP cell of one replicate.
_replicate = 'ucsd'
plot_col = "Group"
fig, ax = plt.subplots(figsize=(10, 10))
adata_s = adata_gp[(adata_gp.obs['replicate'] == _replicate) & (adata_gp.obs[plot_col] != "unknown")].copy()

# Layer 1: every cell in the background color, no legend.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=None,
                 is_background=True,
                 coord_base='spatial',
                 color_unknown="#D0D0D0",
                 show_legend=False,
                 ax=ax, show=False)
# Layer 2: only the highlighted groups, in their own colors.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=plot_cats,
                 color_unknown="#D0D0D0",
                 coord_base='spatial',
                 alpha=1,
                 s=8,
                 max_points=None,
                 legend_kws={"markersize": 10},
                 legend_fontsize=10,
                 legend_title_fontsize=12,
                 ax=ax, show=False)
# The file name carries the replicate that was actually plotted.
plt.savefig(Path(image_path) / f"GP_spatial_GPg_{_replicate}.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

## CAB

In [None]:
_reg = "CAB"
# Region subset; Subclass/Group categories with no cells in the region are dropped.
adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
for _col in ('Subclass', 'Group'):
    adata_reg.obs[_col] = adata_reg.obs[_col].cat.remove_unused_categories()

# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4), dpi=300)
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Round-1 UMAP colored by annotation level (coded legend + labels for neuron_type and Subclass).
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Spatial views of the salk replicate: Subclass, then Group without "unknown" cells.
_is_salk = adata_reg.obs['replicate'] == 'salk'
fig, axes = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk], cluster_col='Subclass', coord_base="spatial", ax=axes, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk & (adata_reg.obs['Group'] != "unknown")],
                 cluster_col="Group", coord_base='spatial', color_unknown="#D0D0D0",
                 ax=ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# MSN Groups in spatial coordinates for the region in `adata_reg`: one panel per
# replicate (row) x donor (column). In each panel, all cells of that donor/replicate are
# first drawn with color='lightgrey', then the MSN Groups on top. Panels with no cells, or
# no MSN cells, are hidden. Title and file name use `_reg`, the region adata_reg was built from.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN', 
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or a single replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2+1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_reg[(adata_reg.obs['donor'] == _donor) & (adata_reg.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_reg.uns['MSN_Groups_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle(_reg, fontsize=16)
plt.tight_layout()
plt.savefig(Path(image_path) / f"{_reg}_spatial_MSNs.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Selected non-neuronal Groups in spatial coordinates for the region in `adata_reg`:
# one panel per replicate (row) x donor (column), with all cells of that donor/replicate
# drawn first with color='lightgrey'. Panels with no cells, or none of these Groups, are hidden.
# NOTE(review): 'VLMC, Lymphocyte' is a single label — confirm it is a real Group name and
# not a typo for two entries ('VLMC', 'Lymphocyte').
plot_cats = ['VLMC, Lymphocyte', "Endo", "Ependymal"]
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or a single replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2+1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_reg[(adata_reg.obs['donor'] == _donor) & (adata_reg.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_reg.uns['Group_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle(_reg, fontsize=16)
plt.tight_layout()
# Saving is disabled. This figure needs its own file name so it cannot overwrite the MSN grid.
# plt.savefig(Path(image_path) / f"{_reg}_spatial_VLMC_Endo_Ependymal.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# # print(adata_reg.obs['Subclass'].value_counts())
# plot_cats = adata_reg[adata_reg.obs['Subclass'] == "STR D1 MSN"].obs['Group'].value_counts().index
# print(plot_cats)

# plot_col = "Group"
# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# adata_s = adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs[plot_col] != "unknown")].copy()
# adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories().copy()

# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=None,
#                  is_background=True,
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  show_legend=False,
#                  ax=ax, show=False)
# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  alpha=1,
#                  s=4,
#                  max_points=None,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  legend_title_fontsize=12,
#                  ax=ax, show=False)
# plt.show()

## CAH

In [None]:
_reg = "CAH"
# Region subset; Subclass/Group categories with no cells in the region are dropped.
adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
for _col in ('Subclass', 'Group'):
    adata_reg.obs[_col] = adata_reg.obs[_col].cat.remove_unused_categories()

# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4), dpi=300)
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Round-1 UMAP colored by annotation level (coded legend + labels for neuron_type and Subclass).
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Spatial views of the salk replicate: Subclass, then Group without "unknown" cells.
_is_salk = adata_reg.obs['replicate'] == 'salk'
fig, axes = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk], cluster_col='Subclass', coord_base="spatial", ax=axes, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk & (adata_reg.obs['Group'] != "unknown")],
                 cluster_col="Group", coord_base='spatial', color_unknown="#D0D0D0",
                 ax=ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# MSN Groups in spatial coordinates for the region in `adata_reg`: one panel per
# replicate (row) x donor (column). In each panel, all cells of that donor/replicate are
# first drawn with color='lightgrey', then the MSN Groups on top. Panels with no cells, or
# no MSN cells, are hidden. Title and file name use `_reg`, the region adata_reg was built from.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN', 
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or a single replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2+1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_reg[(adata_reg.obs['donor'] == _donor) & (adata_reg.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_reg.uns['MSN_Groups_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle(_reg, fontsize=16)
plt.tight_layout()
plt.savefig(Path(image_path) / f"{_reg}_spatial_MSNs.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# # print(adata_reg.obs['Subclass'].value_counts())
# plot_cats = adata_reg[adata_reg.obs['Subclass'] == "STR D1 MSN"].obs['Group'].value_counts().index
# print(plot_cats)

# plot_col = "Group"
# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# adata_s = adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs[plot_col] != "unknown")].copy()
# adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories().copy()

# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=None,
#                  is_background=True,
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  show_legend=False,
#                  ax=ax, show=False)
# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  alpha=1,
#                  s=4,
#                  max_points=None,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  legend_title_fontsize=12,
#                  ax=ax, show=False)
# plt.show()

## CAT

In [None]:
_reg = "CAT"
# Region subset; Subclass/Group categories with no cells in the region are dropped.
adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
for _col in ('Subclass', 'Group'):
    adata_reg.obs[_col] = adata_reg.obs[_col].cat.remove_unused_categories()

# Round-1 UMAP colored by Leiden cluster, donor and replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4), dpi=300)
for _ax, _col in zip(axes, ('base_round1_leiden', 'donor', 'replicate')):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Round-1 UMAP colored by annotation level (coded legend + labels for neuron_type and Subclass).
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_col, _labelled) in zip(axes, (('neuron_type', True), ('Subclass', True), ('Group', False))):
    plot_categorical(adata_reg, cluster_col=_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_labelled, text_anno=_labelled, show=False)
plt.savefig(Path(image_path) / f"{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Spatial views of the salk replicate: Subclass, then Group without "unknown" cells.
_is_salk = adata_reg.obs['replicate'] == 'salk'
fig, axes = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk], cluster_col='Subclass', coord_base="spatial", ax=axes, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[_is_salk & (adata_reg.obs['Group'] != "unknown")],
                 cluster_col="Group", coord_base='spatial', color_unknown="#D0D0D0",
                 ax=ax, show=False)
plt.savefig(Path(image_path) / f"{_reg}_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# Grid of spatial MSN maps for the current region: one column per donor, one row per
# lab (`replicate`). MSN groups are coloured on top of a light-grey layer showing every
# cell of that donor/lab section.
# NOTE(review): `labs` is not defined in this cell; presumably it is the list of
# `replicate` values set earlier in the notebook. Confirm before running out of order.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN',
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
donors = adata_reg.obs['donor'].unique().tolist()
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` 2-D even with a single donor or lab, so axes[j, i] always works.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*2+1), dpi=300, squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_reg[(adata_reg.obs['donor'] == _donor) & (adata_reg.obs['replicate'] == _lab)]
        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        # Hide panels with nothing to show. This covers donor/lab combos with no cells
        # at all as well as sections without MSNs, so no empty framed axes are left.
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group",
                         palette_path=adata_reg.uns['MSN_Groups_palette'], ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

# Title and file name follow `_reg`, like every other region cell.
fig.suptitle(_reg, fontsize=16)
fig.tight_layout()
fig.savefig(Path(image_path) / f"{_reg}_spatial_MSNs.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)


In [None]:
# # print(adata_reg.obs['Subclass'].value_counts())
# plot_cats = adata_reg[adata_reg.obs['Subclass'] == "STR D1 MSN"].obs['Group'].value_counts().index
# print(plot_cats)

# plot_col = "Group"
# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# adata_s = adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs[plot_col] != "unknown")].copy()
# adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories().copy()

# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=None,
#                  is_background=True,
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  show_legend=False,
#                  ax=ax, show=False)
# plot_categorical(adata_s,
#                  cluster_col=plot_col,
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  alpha=1,
#                  s=4,
#                  max_points=None,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  legend_title_fontsize=12,
#                  ax=ax, show=False)
# plt.show()

### MGM

In [None]:
# MGM1 region: UMAP metadata/annotation panels and salk-section spatial maps.
_reg = "MGM1"
adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
# Drop categories with no cells in this region.
for _cat_col in ['Subclass', 'Group']:
    adata_reg.obs[_cat_col] = adata_reg.obs[_cat_col].cat.remove_unused_categories()

out_dir = Path(image_path)

# UMAP coloured by clustering / donor / replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4), dpi=300)
for _ax, _meta_col in zip(axes, ['base_round1_leiden', 'donor', 'replicate']):
    plot_categorical(adata_reg, cluster_col=_meta_col, coord_base="base_round1_umap", ax=_ax, show=False)
fig.savefig(out_dir / f"{_reg}_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

# UMAP coloured by annotation level; `coding`/`text_anno` are on for
# neuron_type and Subclass and off for Group.
annot_panels = [('neuron_type', True), ('Subclass', True), ('Group', False)]
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_annot_col, _label_on) in zip(axes, annot_panels):
    plot_categorical(adata_reg, cluster_col=_annot_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_label_on, text_anno=_label_on, show=False)
fig.savefig(out_dir / f"{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

is_salk = adata_reg.obs['replicate'] == 'salk'

# Spatial map of the salk section coloured by Subclass.
fig, axes = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[is_salk], cluster_col='Subclass', coord_base="spatial", ax=axes, show=False)
fig.savefig(out_dir / f"{_reg}_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

# Spatial map of the salk section coloured by Group, excluding "unknown" cells.
has_group = adata_reg.obs['Group'] != "unknown"
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[is_salk & has_group],
                 cluster_col="Group",
                 coord_base='spatial',
                 color_unknown="#D0D0D0",
                 ax=ax,
                 show=False)
fig.savefig(out_dir / f"{_reg}_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

In [None]:
# print(adata_reg.obs['Group'].value_counts())
# plot_cats = ["SN SOX6 Dopa", "SN-VTR-HTH GATA3-TCF7L2 GABA", "VTR-HTH Glut", "SN EBF2 GABA", "SN SEMA5A GABA", "LAMP5-CXCL14 GABA"]

In [None]:
# print(adata_reg.obs['Subclass'].value_counts())
# plot_cats = adata_reg[adata_reg.obs['Subclass'] == "F M GATA3 GABA"].obs['Group'].value_counts().index
# Highlight a hand-picked set of groups on the salk section of the current region.
# All annotated (non-"unknown") cells are first drawn as a background layer, then the
# selected groups are overplotted with a legend.
plot_cats = ["SN SOX6 Dopa", "SN-VTR-HTH GATA3-TCF7L2 GABA", "VTR-HTH Glut", "SN EBF2 GABA", "SN SEMA5A GABA", "LAMP5-CXCL14 GABA"]

plot_col = "Group"
fig, ax = plt.subplots(figsize=(10, 10))
adata_s = adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs[plot_col] != "unknown")].copy()
# adata_s is already an independent copy, so no extra .copy() of the Series is needed.
adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories()

# Requested groups with no cells in this section would simply be absent from the
# figure; report them so the omission is visible.
missing_cats = [c for c in plot_cats if c not in adata_s.obs[plot_col].cat.categories]
if missing_cats:
    print(f"{_reg} salk section has no cells for: {missing_cats}")

# Background layer: every annotated cell, no legend.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=None,
                 is_background=True,
                 coord_base='spatial',
                 color_unknown="#D0D0D0",
                 show_legend=False,
                 ax=ax, show=False)
# Foreground layer: only the groups in plot_cats, drawn opaque with a legend.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=plot_cats,
                 color_unknown="#D0D0D0",
                 coord_base='spatial',
                 alpha=1,
                 s=4,
                 max_points=None,
                 legend_kws={"markersize": 10},
                 legend_fontsize=10,
                 legend_title_fontsize=12,
                 ax=ax, show=False)
fig.savefig(Path(image_path) / f"{_reg}_spatial_dopamine_gaba.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

### STH

In [None]:
# SUBTH region: UMAP metadata/annotation panels and salk-section spatial maps.
_reg = "SUBTH"
adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
# Drop categories with no cells in this region.
for _cat_col in ['Subclass', 'Group']:
    adata_reg.obs[_cat_col] = adata_reg.obs[_cat_col].cat.remove_unused_categories()

out_dir = Path(image_path)

# UMAP coloured by clustering / donor / replicate.
fig, axes = plt.subplots(1, 3, figsize=(16, 4), dpi=300)
for _ax, _meta_col in zip(axes, ['base_round1_leiden', 'donor', 'replicate']):
    plot_categorical(adata_reg, cluster_col=_meta_col, coord_base="base_round1_umap", ax=_ax, show=False)
fig.savefig(out_dir / f"{_reg}_UMAP_meta.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

# UMAP coloured by annotation level; `coding`/`text_anno` are on for
# neuron_type and Subclass and off for Group.
annot_panels = [('neuron_type', True), ('Subclass', True), ('Group', False)]
fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
for _ax, (_annot_col, _label_on) in zip(axes, annot_panels):
    plot_categorical(adata_reg, cluster_col=_annot_col, coord_base="base_round1_umap", ax=_ax,
                     coding=_label_on, text_anno=_label_on, show=False)
fig.savefig(out_dir / f"{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

is_salk = adata_reg.obs['replicate'] == 'salk'

# Spatial map of the salk section coloured by Subclass.
fig, axes = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[is_salk], cluster_col='Subclass', coord_base="spatial", ax=axes, show=False)
fig.savefig(out_dir / f"{_reg}_spatial_subclass_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

# Spatial map of the salk section coloured by Group, excluding "unknown" cells.
has_group = adata_reg.obs['Group'] != "unknown"
fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
plot_categorical(adata_reg[is_salk & has_group],
                 cluster_col="Group",
                 coord_base='spatial',
                 color_unknown="#D0D0D0",
                 ax=ax,
                 show=False)
fig.savefig(out_dir / f"{_reg}_spatial_group_salk.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

In [None]:
# print(adata_reg.obs['Subclass'].value_counts())
# plot_cats = adata_reg[adata_reg.obs['Subclass'] == "F Glut"].obs['Group'].value_counts().index
# Highlight a hand-picked set of groups on the salk section of the current region.
# All annotated (non-"unknown") cells are first drawn as a background layer, then the
# selected groups are overplotted with a legend.
plot_cats = ["SN-VTR-HTH GATA3-TCF7L2 GABA", "ZI-HTH GABA", "STH PVALB-PITX2 Glut", "SN-VTR CALB1 Dopa", "VTR-HTH Glut", "SN-VTR GAD2 Dopa"]

plot_col = "Group"
fig, ax = plt.subplots(figsize=(10, 10))
adata_s = adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs[plot_col] != "unknown")].copy()
# adata_s is already an independent copy, so no extra .copy() of the Series is needed.
adata_s.obs[plot_col] = adata_s.obs[plot_col].cat.remove_unused_categories()

# Requested groups with no cells in this section would simply be absent from the
# figure; report them so the omission is visible.
missing_cats = [c for c in plot_cats if c not in adata_s.obs[plot_col].cat.categories]
if missing_cats:
    print(f"{_reg} salk section has no cells for: {missing_cats}")

# Background layer: every annotated cell, no legend.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=None,
                 is_background=True,
                 coord_base='spatial',
                 color_unknown="#D0D0D0",
                 show_legend=False,
                 ax=ax, show=False)
# Foreground layer: only the groups in plot_cats, drawn opaque with a legend.
plot_categorical(adata_s,
                 cluster_col=plot_col,
                 highlight_group=plot_cats,
                 color_unknown="#D0D0D0",
                 coord_base='spatial',
                 alpha=1,
                 s=4,
                 max_points=None,
                 legend_kws={"markersize": 10},
                 legend_fontsize=10,
                 legend_title_fontsize=12,
                 ax=ax, show=False)
fig.savefig(Path(image_path) / f"{_reg}_spatial_sthg.png", bbox_inches='tight', dpi=300)
plt.show()
plt.close(fig)

## Region Template

In [None]:
# _reg = "CAB"
# adata_reg = adata[adata.obs['brain_region'] == _reg].copy()
# adata_reg.obs['RNA.Subclass'] = adata_reg.obs['RNA.Subclass'].cat.remove_unused_categories()
# adata_reg.obs['RNA.Group'] = adata_reg.obs['RNA.Group'].cat.remove_unused_categories()

# fig, axes = plt.subplots(1,3, figsize=(16, 4), dpi=300)
# plot_categorical(adata_reg, cluster_col='base_round1_leiden', coord_base="base_round1_umap", ax=axes[0], show=False)
# plot_categorical(adata_reg, cluster_col='donor', coord_base="base_round1_umap", ax=axes[1], show=False)
# plot_categorical(adata_reg, cluster_col='replicate', coord_base="base_round1_umap", ax=axes[2], show=False)
# plt.savefig(Path(image_path) / "annotations/ff"/ f"{current_datetime}_{_reg}_UMAP.png", bbox_inches='tight', dpi=300)
# plt.show()
# plt.close()

# fig, axes = plt.subplots(1, 3, figsize=(20, 4), dpi=300)
# plot_categorical(adata_reg, cluster_col='neuron_type', coord_base="base_round1_umap", ax=axes[0], coding=True, text_anno=True, show=False)
# plot_categorical(adata_reg, cluster_col='RNA.Subclass', coord_base="base_round1_umap", ax=axes[1], coding=True, text_anno=True, show=False)
# plot_categorical(adata_reg, cluster_col='RNA.Group', coord_base="base_round1_umap", ax=axes[2], coding=False, text_anno=False, show=False)
# plt.savefig(Path(image_path) / "annotations/ff" / f"{current_datetime}_{_reg}_UMAP_annot.png", bbox_inches='tight', dpi=300)
# plt.show()
# plt.close()

# fig, axes = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
# plot_categorical(adata_reg[adata_reg.obs['replicate']=='salk'], cluster_col='RNA.Subclass', coord_base="spatial", ax=axes, show=False)
# plt.show()
# plt.close()

# fig, ax = plt.subplots(figsize=(4, 4), dpi=300)
# plot_categorical(adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs['RNA.Group'] != "unknown")],
#                  cluster_col="RNA.Group",
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  ax=ax,
#                  show=False)
# plt.show()

# print(adata_reg.obs['RNA.Subclass'].value_counts())

# plot_cats = adata_reg[adata_reg.obs['RNA.Subclass'] == "F GABA"].obs['RNA.Group'].value_counts().index
# print(plot_cats)

# fig, ax = plt.subplots(figsize=(10, 10))
# # ax.set_facecolor("black")
# plot_categorical(adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs['RNA.Group'] != "unknown")],
#                  cluster_col="RNA.Group",
#                  highlight_group="None",
#                  coord_base='spatial',
#                  color_unknown="#D0D0D0",
#                  ax=ax, show=False)
# plot_categorical(adata_reg[(adata_reg.obs['replicate']=='salk') & (adata_reg.obs['RNA.Group'] != "unknown")],
#                  cluster_col="RNA.Group",
#                  highlight_group=plot_cats,
#                  color_unknown="#D0D0D0",
#                  coord_base='spatial',
#                  max_points=None,
#                  alpha=1,
#                  s=4,
#                  legend_kws={"markersize": 10},
#                  legend_fontsize=10,
#                  ax=ax, show=False)
# plt.show()