In [None]:
# parameters
# Optional data subsets; left as None by default.
regions = None
donors = None
replicates = None

# Genes and MSN subclasses to visualise.
genes_to_plot = [
    "PDE10A", "SYNPR", "WFS1", "CADM1",
    "PDE8B", "DRD1", "DRD2",
]
subclass_plot = [
    "STR D1 MSN",
    "STR D2 MSN",
    "STR Hybrid MSN",
]
plot_regions = ["CAH", "CAB", "CAT", "PU", "NAC"]

# Output directory for figures. NOTE: a later cell reassigns `image_path`.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/strmat"


In [None]:
import os
from pathlib import Path
from rich import print as rprint, inspect

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from spida.utilities._ad_utils import normalize_adata

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl._utils import add_color_scheme
from spida.pl import plot_categorical, plot_continuous, categorical_scatter, continuous_scatter
from spida.utilities.sd_utils import _get_obs_or_gene

# Figure defaults: 300 dpi on screen and on save, 4x4 inch figures,
# TrueType (Type 42) font embedding in PDF/PS, Arial, white axes background.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.facecolor': 'white',
})

In [None]:
# subclass_plot = ["CN LAMP5-CXCL14 GABA"]
# genes_to_plot = ["TOX3", "ANK1", "GRIK1", "HDAC9", "ADARB2"]

# subclass_plot = ["CN ST18 GABA"]
# genes_to_plot = ["GALNT17", "EPHA4", "GLP1R", "SYT1", "RASGRF2"]

In [None]:
# Print versions of important packages (recorded for reproducibility).
# `sys` is imported directly instead of going through the undocumented `os.sys` attribute.
import sys

print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"Pandas: {pd.__version__}")
print(f"Anndata: {ad.__version__}")
print(f"Scanpy: {sc.__version__}")

In [None]:
# NOTE: this overrides the `image_path` value set in the parameters cell.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_region_grad"
# Create the output directory (and any missing parents); no error if it already exists.
os.makedirs(image_path, exist_ok=True)

## functions

### Plot categorical (local override of `spida.pl.plot_categorical`)

In [None]:
# Move to spida.pl
# NOTE: this local definition shadows the `plot_categorical` imported from `spida.pl`.
def plot_categorical(
    adata : ad.AnnData | str, 
    ax=None,
    coord_base='tsne',
    cluster_col='MajorType',
    palette_path=None,
    highlight_group : str | list | None = None,
    color_unknown : str = "#D0D0D0",
    is_background : bool = False,
    axis_format='tiny',
    alpha=0.7,
    coding=True,
    id_marker=True,
    output=None,
    show=True,
    figsize=(4, 3.5),
    sheet_name=None,
    ncol=None,
    fontsize=5,
    legend_fontsize=5,
    legend_kws=None,
    legend_title_fontsize=5,
    marker_fontsize=4,
    marker_pad=0.1,
    linewidth=0.5,
    ret=False,
    text_anno:bool = False,
    text_kws=None,
    **kwargs
):
    """
    Scatter plot of cells colored by a categorical ``obs`` column.

    Parameters
    ----------
    adata : AnnData | str
        AnnData object, or a path to an ``.h5ad`` file (opened with ``backed='r'``).
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new ``figsize`` figure at 300 dpi is created.
    coord_base : str
        Coordinate base passed to ``categorical_scatter`` (e.g. 'tsne', 'spatial').
    cluster_col : str
        ``adata.obs`` column to color by. It is converted to ``category`` dtype in
        place when it is not categorical already.
    palette_path : str | dict | None
        Either a path to an Excel file, or a ``{category: color}`` mapping.
        For a file, the sheet ``sheet_name`` is read with a ``Hex`` column indexed
        by category. Categories missing from the sheet get 'gray'.
        In both cases the per-category colors are written to
        ``adata.uns[f"{cluster_col}_colors"]``. When None, an existing
        ``adata.uns[f"{cluster_col}_colors"]`` is used, or one is created with
        ``add_color_scheme``.
    highlight_group : str | list | None
        Group (or list of groups) of ``cluster_col`` to keep. Cells of every
        other group are removed from the plot.
    color_unknown : str
        Color given to every category when ``is_background`` is True.
        Default light gray (#D0D0D0).
    is_background : bool
        Draw all categories (plus an 'unknown' key) in ``color_unknown``.
    axis_format, alpha :
        Defaults forwarded to ``categorical_scatter`` (``kwargs`` take precedence).
    coding, id_marker : bool
        Always forwarded to ``categorical_scatter`` (they override ``kwargs``).
    output : str, optional
        Save the current figure to this path (``~`` is expanded) at 300 dpi.
    show : bool
        Call ``plt.show()`` after plotting.
    figsize : tuple
        Size of the figure created when ``ax`` is None.
    sheet_name : str, optional
        Excel sheet used with ``palette_path``; defaults to ``cluster_col``.
    ncol : int, optional
        Number of legend columns.
    fontsize : int
        Default ``fontsize`` for text annotations.
    legend_fontsize : int
        Legend fontsize, default 5.
    legend_kws : dict, optional
        Legend kwargs forwarded to ``categorical_scatter``. The caller's dict is
        not modified.
    legend_title_fontsize : int
        Legend title fontsize, default 5.
    marker_fontsize : int
        Marker fontsize, default 4. If ``id_marker`` and ``coding`` are True, the
        legend marker is a circle (or rectangle) with the category code.
    marker_pad : float
        Forwarded to ``categorical_scatter``.
    linewidth : float
        Line width of the legend marker (circle or rectangle), default 0.5.
    ret : bool
        Whether to return the plot and colors, default False.
    text_anno : bool
        Whether to add a text annotation for each cluster, default False.
    text_kws : dict, optional
        Text-annotation kwargs (``fontsize`` defaults to ``fontsize``). The
        caller's dict is not modified.
    kwargs : dict
        Extra keyword arguments for ``categorical_scatter``. For example,
        ``show_legend=False`` removes the legend.

    Returns
    -------
    tuple
        ``(p, colors)`` when ``ret`` is True, else ``(None, None)``. ``p`` is the
        ``categorical_scatter`` return value and ``colors`` is the palette used.
    """
    if sheet_name is None:  # Excel sheet holding the palette defaults to the column name
        sheet_name = cluster_col
    if isinstance(adata, str):
        adata = ad.read_h5ad(adata, backed='r')
    if not isinstance(adata.obs[cluster_col].dtype, pd.CategoricalDtype):
        adata.obs[cluster_col] = adata.obs[cluster_col].astype('category')
    categories = adata.obs[cluster_col].cat.categories.tolist()

    # ---- resolve the {category: color} palette ----
    if palette_path is not None:
        if isinstance(palette_path, str):
            colors = pd.read_excel(
                os.path.expanduser(palette_path), sheet_name=sheet_name, index_col=0
            ).Hex.to_dict()
            keys = list(colors.keys())
            existed_vals = adata.obs[cluster_col].unique().tolist()
            for k in existed_vals:  # categories missing from the sheet fall back to gray
                if k not in keys:
                    colors[k] = 'gray'
            for k in keys:  # drop sheet entries that do not occur in the data
                if k not in existed_vals:
                    del colors[k]
        else:
            colors = dict(palette_path)  # copy: never mutate the caller's mapping
        adata.uns[cluster_col + '_colors'] = [colors.get(k, 'grey') for k in categories]
    elif f'{cluster_col}_colors' not in adata.uns:
        colors = add_color_scheme(adata, cluster_col, palette_key=f"{cluster_col}_colors")
    else:
        colors = dict(zip(categories, adata.uns[f'{cluster_col}_colors']))

    # ---- select the cells to draw ----
    keep = adata.obs[cluster_col].notna()
    if highlight_group is not None:
        if isinstance(highlight_group, str):
            highlight_group = [highlight_group]
        keep = keep & adata.obs[cluster_col].isin(highlight_group)
    data = adata[keep]
    if highlight_group is not None:
        # Materialize the subset explicitly before editing obs, instead of
        # writing into an AnnData view.
        data = data.copy()
        data.obs[cluster_col] = data.obs[cluster_col].cat.remove_unused_categories()
    shown = set(data.obs[cluster_col].cat.categories)
    colors = {k: v for k, v in colors.items() if k in shown}
    if is_background:
        colors = {k: color_unknown for k in colors}
        colors['unknown'] = color_unknown

    # ---- keyword defaults for categorical_scatter (explicit kwargs win) ----
    text_kws = {} if text_kws is None else dict(text_kws)  # copy: caller's dict untouched
    text_kws.setdefault("fontsize", fontsize)
    kwargs.setdefault("hue", cluster_col)
    kwargs.setdefault("text_anno", cluster_col if text_anno else None)
    kwargs.setdefault("text_kws", text_kws)
    kwargs.setdefault("luminance", 0.65)
    kwargs.setdefault("dodge_text", False)
    kwargs.setdefault("axis_format", axis_format)
    kwargs.setdefault("show_legend", True)
    kwargs.setdefault("marker_fontsize", marker_fontsize)
    kwargs.setdefault("marker_pad", marker_pad)
    kwargs.setdefault("linewidth", linewidth)
    kwargs.setdefault("alpha", alpha)
    kwargs["coding"] = coding
    kwargs["id_marker"] = id_marker

    legend_kws = {} if legend_kws is None else dict(legend_kws)  # copy: caller's dict untouched
    legend_kws.setdefault("fontsize", legend_fontsize)
    legend_kws.setdefault("title", cluster_col)
    legend_kws.setdefault("title_fontsize", legend_title_fontsize)
    if ncol is not None:
        legend_kws.setdefault("ncol", ncol)

    kwargs.setdefault("dodge_kws", {
            "arrowprops": {
                "arrowstyle": "->",
                "fc": 'grey',
                "ec": "none",
                "connectionstyle": "angle,angleA=-90,angleB=180,rad=5",
            },
            'autoalign': 'xy'})

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, dpi=300)
    p = categorical_scatter(
        data=data,
        ax=ax,
        coord_base=coord_base,
        palette=colors,
        legend_kws=legend_kws,
        **kwargs)

    if output is not None:
        plt.savefig(os.path.expanduser(output), bbox_inches='tight', dpi=300)
    if show:
        plt.show()

    return (p, colors) if ret else (None, None)

### RNA Plots

In [None]:
def plot_gene_groups(
    adata: ad.AnnData,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    hue_norm = 0.8,
    save_fig: bool = False,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    show: bool = True, 
): 
    """
    Spatial maps of one gene's expression, one figure per donor.

    Each figure is a grid with one row per replicate in ``labs`` and one column
    per region in ``regions``. Regions are matched against
    ``obs['brain_region_corr']``.

    Every panel first draws all cells of that donor/lab/region in light grey.
    It then overlays the cells whose ``obs[group_level]`` is in ``groups``,
    colored by ``gene`` with the ``YlOrRd`` colormap. The color range is shared
    by a row: from the minimum to ``hue_norm`` x the maximum of ``gene`` over
    that donor/lab's cells in ``regions``. Rows for a donor/lab combination with
    no cells are left blank. A donor with no cells at all is skipped (nothing is
    saved for it).

    Parameters
    ----------
    adata : AnnData
        Must have ``obs`` columns 'brain_region_corr', 'donor', 'replicate' and
        ``group_level``. Presumably it also needs spatial coordinates, since
        plotting uses ``coord_base="spatial"``.
    gene : str
        Gene in ``adata.var_names``. Otherwise a message is printed and the
        function returns.
    groups : list of str
        Values of ``obs[group_level]`` to color.
    regions : list of str
        Regions to show, in column order.
    group_level : str
        ``obs`` column that ``groups`` refers to.
    labs, donors : list of str | str, optional
        Replicates (grid rows) and donors (one figure each). By default, every
        value present in the data.
    layers : str, optional
        Layer copied into ``X`` of a private copy before plotting.
    layer_norm : bool
        Run ``normalize_adata(..., log1p=True)`` on the private copy.
    hue_norm : float
        Fraction of the per-row maximum used as the upper color limit.
    save_fig, image_path, image_name :
        When ``save_fig`` and ``image_path`` are set, each donor figure is saved
        as PNG and PDF under ``image_path``. The file stem is ``image_name``, or
        ``f"{image_name}_{donor}"`` when several donors are plotted, so the files
        do not overwrite each other. Without ``image_name`` the stem is
        ``gene_{gene}_donor_{donor}_spatial``.
    rasterized : bool
        Forwarded to the scatter helpers and the suptitle.
    show : bool
        Call ``plt.show()`` for each figure.
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    # Private copy restricted to the requested regions, so X can be overwritten below.
    adata = adata[adata.obs['brain_region_corr'].isin(regions)].copy()
    if layers is not None:
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)

    if labs is None:
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows, ncols = len(labs), len(regions)
    obs = adata.obs

    for donor in donors:
        # squeeze=False keeps `axes` 2-D, so axes[row, col] works for any grid size (incl. 1x1).
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False)
        n_plotted_rows = 0
        for row, lab in enumerate(labs):
            in_donor_lab = (obs['donor'] == donor) & (obs['replicate'] == lab)
            if not in_donor_lab.any():
                # Donor not profiled by this lab: min()/max() below would fail on zero cells.
                for ax in axes[row]:
                    ax.set_axis_off()
                continue
            gene_vals = adata[in_donor_lab][:, gene].X
            hue_min = gene_vals.min()
            hue_max = gene_vals.max() * hue_norm
            for col, region in enumerate(regions):
                ax = axes[row, col]
                adata_sub = adata[in_donor_lab & (obs['brain_region_corr'] == region)].copy()
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax, rasterized=rasterized, axis_format=None)
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap="YlOrRd", ax=ax, show=False, hue_norm=(hue_min, hue_max), rasterized=rasterized, axis_format=None)
                ax.set_title(f"{region}")
            n_plotted_rows += 1

        if n_plotted_rows == 0:
            print(f"No cells for donor {donor} in replicates {labs}; skipping")
            plt.close(fig)
            continue

        fig.suptitle(f"Gene: {gene} Donor: {donor}, Expression", y=1.02, rasterized=rasterized)
        if save_fig and image_path is not None:
            if image_name is None:
                stem = f"gene_{gene}_donor_{donor}_spatial"
            elif len(donors) > 1:
                stem = f"{image_name}_{donor}"  # one file per donor instead of overwriting
            else:
                stem = image_name
            fig.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
            fig.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)
        if show:
            plt.show()
        plt.close(fig)

def _plot_gene_violinplots(
    adata : ad.AnnData, 
    gene: str,
    regions: list,
    groups, 
    subset_level: str = "Subclass",
    group_level: str = "Group",
    donors: list = None,
    labs: list = None,
    layer: str = "volume_norm",
    layer_norm: bool = True,
    save_fig: bool = False,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    show: bool = True, 
): 
    """
    Violin plots of one gene per ``group_level`` category, split by brain region.

    The figure has one subplot row per donor. A cell is kept when all of these hold:

    - ``obs[subset_level]`` is in ``groups``;
    - ``obs['brain_region']`` is in ``regions``;
    - ``obs[group_level]`` is not ``"unknown"``;
    - ``obs['replicate']`` is in ``labs``, if ``labs`` is given.

    Parameters
    ----------
    adata : AnnData
        Must have ``obs`` columns 'donor', 'replicate', 'brain_region',
        ``subset_level`` and ``group_level``, plus ``uns['brain_region_palette']``.
    gene : str
        Value fetched with ``_get_obs_or_gene`` (from ``obs`` or from the gene
        matrix / ``layer``).
    regions : list
        Brain regions to include; also used as the hue order.
    groups : list
        Values of ``obs[subset_level]`` to include.
    subset_level, group_level : str
        ``obs`` columns used for filtering and for the x axis, respectively.
    donors : list | str, optional
        One subplot row per donor. By default, every donor in ``adata``.
    labs : list | str, optional
        Replicates to include. By default, all.
    layer : str
        Layer passed to ``_get_obs_or_gene``.
    layer_norm : bool
        Run ``normalize_adata(..., log1p=True)`` on a copy of ``adata`` first.
        The caller's object is not modified, at the cost of a full copy.
    save_fig, image_path, image_name :
        When ``save_fig`` and ``image_path`` are set, save PNG and PDF under
        ``image_path`` with stem ``image_name`` (default ``gene_{gene}_violin``).
    rasterized : bool
        Forwarded to ``sns.violinplot`` and the titles.
    show : bool
        Call ``plt.show()``.
    """
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    if layer_norm:
        # Normalize a copy: normalizing in place would re-normalize the caller's
        # object on every iteration of a plotting loop.
        adata = adata.copy()
        normalize_adata(adata, log1p=True)

    adata, _drop_col = _get_obs_or_gene(adata, gene, layer)  # get the column from obs or var
    columns = list(dict.fromkeys(["donor", "replicate", "brain_region", subset_level, group_level, gene]))
    df_obs = adata.obs[columns].copy()
    if _drop_col:  # drop the column if it was added for plotting
        adata.obs.drop(columns=[gene], inplace=True)
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()
    if not donors:
        print("No donors to plot")
        return

    keep = (
        df_obs[subset_level].isin(groups)
        & df_obs["brain_region"].isin(regions)
        & (df_obs[group_level] != "unknown")
    )
    if labs is not None:
        keep &= df_obs["replicate"].isin(labs)  # restrict to the requested replicates
    df_obs = df_obs.loc[keep].copy()  # explicit copy: avoids SettingWithCopy on the writes below
    for col in (group_level, "brain_region"):
        df_obs[col] = df_obs[col].astype("category").cat.remove_unused_categories()
    n_groups = len(df_obs[group_level].cat.categories)

    nrows = len(donors)
    fig, axes = plt.subplots(nrows, 1, figsize=(10, 2*nrows+1), dpi=300, sharex=True, sharey=False, squeeze=False)
    for i, _d in enumerate(donors):
        ax = axes[i, 0]
        df_d = df_obs[df_obs['donor'] == _d]
        if not df_d.empty:
            sns.violinplot(data=df_d, x=group_level, y=gene, hue="brain_region", ax=ax, inner="quart", hue_order=regions, palette=adata.uns['brain_region_palette'], rasterized=rasterized)
            ax.legend(title='Brain Region', fontsize=6, title_fontsize=7, loc='upper right', bbox_to_anchor=(1.15, 1))
        if i == nrows - 1:
            ax.tick_params(axis='x', labelbottom=True)
            ax.set_xticks(range(n_groups))
            ax.set_xticklabels(ax.get_xticklabels(), rotation=15, ha='right', fontsize=12)
        ax.set_title(f"Donor: {_d}", rasterized=rasterized)

    fig.suptitle(f"{gene} Expression in STR D1 and D2 MSNs across Brain Regions", y=1.02, rasterized=rasterized)

    if save_fig and image_path is not None:
        stem = image_name if image_name is not None else f"gene_{gene}_violin"
        fig.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
        fig.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    plt.close(fig)

### Meth Plots

In [None]:
def plot_gene_groups_meth(
    adata: ad.AnnData,
    gene: str,
    groups: list[str],
    regions: list[str],
    group_level: str = "Group",
    labs: list[str] = None,
    donors: list[str] = None,
    layers = None,
    layer_norm = False,
    max_norm = 90,
    min_norm = 10,
    bg_adata: ad.AnnData = None,
    cmap="YlOrRd",
    save_fig: bool = False,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    show: bool = True, 
): 
    """
    Spatial maps of one gene's mCH score, one figure per donor.

    Each figure is a grid with one row per replicate in ``labs`` and one column
    per region in ``regions``. Regions are matched against
    ``obs['brain_region_corr']``.

    Every panel draws a light-grey background of that donor/lab/region's cells,
    taken from ``bg_adata`` when given, otherwise from ``adata``. It then
    overlays the ``adata`` cells whose ``obs[group_level]`` is in ``groups``,
    colored by ``gene`` with ``cmap``. The color range of a row runs from the
    ``min_norm``-th to the ``max_norm``-th percentile of ``gene`` (NaNs ignored)
    over that donor/lab's cells in ``regions``. Rows for a donor/lab combination
    with no cells are left blank. A donor with no cells at all is skipped.

    Parameters
    ----------
    adata : AnnData
        Must have ``obs`` columns 'brain_region_corr', 'donor', 'replicate' and
        ``group_level``.
    gene : str
        Gene in ``adata.var_names``. Otherwise a message is printed and the
        function returns.
    groups : list of str
        Values of ``obs[group_level]`` to color.
    regions : list of str
        Regions to show, in column order.
    group_level : str
        ``obs`` column that ``groups`` refers to.
    labs, donors : list of str | str, optional
        Replicates (grid rows) and donors (one figure each). By default, every
        value present in the data.
    layers : str, optional
        Layer copied into ``X`` of a private copy before plotting.
    layer_norm : bool
        Run ``normalize_adata(..., log1p=True)`` on the private copy.
    max_norm, min_norm : float
        Percentiles (0-100) used as the upper and lower color limits.
    bg_adata : AnnData, optional
        Source of the grey background cells. Needs the same ``obs`` columns
        'donor', 'brain_region_corr' and 'replicate'.
    cmap : str
        Colormap for the gene values.
    save_fig, image_path, image_name :
        When ``save_fig`` and ``image_path`` are set, each donor figure is saved
        as PNG and PDF under ``image_path``. The file stem is ``image_name``, or
        ``f"{image_name}_{donor}"`` when several donors are plotted. Without
        ``image_name`` the stem is ``gene_{gene}_donor_{donor}_mch_spatial``.
    rasterized : bool
        Forwarded to the scatter helpers and the suptitle.
    show : bool
        Call ``plt.show()`` for each figure.
    """
    if gene not in adata.var_names:
        print(f"Gene {gene} not found in adata.var_names")
        return
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    # Private copy restricted to the requested regions, so X can be overwritten below.
    adata = adata[adata.obs['brain_region_corr'].isin(regions)].copy()
    if layers is not None:
        adata.X = adata.layers[layers].copy()
    if layer_norm:
        normalize_adata(adata, log1p=True)

    if labs is None:
        labs = adata.obs['replicate'].unique().tolist()
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()

    nrows, ncols = len(labs), len(regions)
    obs = adata.obs

    for donor in donors:
        # squeeze=False keeps `axes` 2-D, so axes[row, col] works for any grid size (incl. 1x1).
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), dpi=300, squeeze=False)
        n_plotted_rows = 0
        for row, lab in enumerate(labs):
            in_donor_lab = (obs['donor'] == donor) & (obs['replicate'] == lab)
            if not in_donor_lab.any():
                # Donor not profiled by this lab: percentiles of zero cells are undefined.
                for ax in axes[row]:
                    ax.set_axis_off()
                continue
            gene_vals = adata[in_donor_lab][:, gene].X
            if hasattr(gene_vals, "toarray"):  # sparse matrix -> dense for the percentile
                gene_vals = gene_vals.toarray()
            # nanpercentile: a single NaN score would otherwise make both limits NaN
            hue_min = np.nanpercentile(gene_vals, min_norm)
            hue_max = np.nanpercentile(gene_vals, max_norm)
            for col, region in enumerate(regions):
                ax = axes[row, col]
                adata_sub = adata[in_donor_lab & (obs['brain_region_corr'] == region)].copy()
                if bg_adata is not None:
                    bg_sub = bg_adata[(bg_adata.obs['donor'] == donor) & (bg_adata.obs['brain_region_corr'] == region) & (bg_adata.obs['replicate'] == lab)].copy()
                    categorical_scatter(bg_sub, coord_base="spatial", color='lightgrey', ax=ax, rasterized=rasterized, axis_format=None)
                else:
                    categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax, rasterized=rasterized, axis_format=None)
                adata_plot = adata_sub[adata_sub.obs[group_level].isin(groups)].copy()
                if adata_plot.shape[0] != 0:
                    plot_continuous(adata_plot, coord_base="spatial", color_by=gene, cmap=cmap, ax=ax, show=False, hue_norm=(hue_min, hue_max), rasterized=rasterized, axis_format=None)
                ax.set_title(f"{region}")
            n_plotted_rows += 1

        if n_plotted_rows == 0:
            print(f"No cells for donor {donor} in replicates {labs}; skipping")
            plt.close(fig)
            continue

        fig.suptitle(f"Gene: {gene} Donor: {donor}, mCH score", y=1.02, rasterized=rasterized)

        if save_fig and image_path is not None:
            if image_name is None:
                stem = f"gene_{gene}_donor_{donor}_mch_spatial"
            elif len(donors) > 1:
                stem = f"{image_name}_{donor}"  # one file per donor instead of overwriting
            else:
                stem = image_name
            fig.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
            fig.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)
        if show:
            plt.show()
        plt.close(fig)

def _plot_gene_violinplots_meth(
    adata : ad.AnnData, 
    gene: str,
    regions: list,
    groups, 
    subset_level: str = "Subclass",
    group_level: str = "Group",
    donors: list = None,
    labs: list = None,
    layer: str = "volume_norm",
    layer_norm: bool = True,
    save_fig: bool = False,
    image_path: str = None,
    image_name: str = None,
    rasterized: bool = False,
    show: bool = True, 
): 
    """
    Violin plots of one gene's mCH score per ``group_level`` category, split by brain region.

    The figure has one subplot row per donor. A cell is kept when all of these hold:

    - ``obs[subset_level]`` is in ``groups``;
    - ``obs['brain_region']`` is in ``regions``;
    - ``obs[group_level]`` is not ``"unknown"``;
    - ``obs['replicate']`` is in ``labs``, if ``labs`` is given.

    Parameters
    ----------
    adata : AnnData
        Must have ``obs`` columns 'donor', 'replicate', 'brain_region',
        ``subset_level`` and ``group_level``, plus ``uns['brain_region_palette']``.
    gene : str
        Value fetched with ``_get_obs_or_gene`` (from ``obs`` or from the gene
        matrix / ``layer``).
    regions : list
        Brain regions to include; also used as the hue order.
    groups : list
        Values of ``obs[subset_level]`` to include.
    subset_level, group_level : str
        ``obs`` columns used for filtering and for the x axis, respectively.
    donors : list | str, optional
        One subplot row per donor. By default, every donor in ``adata``.
    labs : list | str, optional
        Replicates to include. By default, all.
    layer : str
        Layer passed to ``_get_obs_or_gene`` (``None`` is passed through unchanged).
    layer_norm : bool
        Run ``normalize_adata(..., log1p=True)`` on a copy of ``adata`` first.
        The caller's object is not modified, at the cost of a full copy.
    save_fig, image_path, image_name :
        When ``save_fig`` and ``image_path`` are set, save PNG and PDF under
        ``image_path`` with stem ``image_name`` (default ``gene_{gene}_mch_violin``).
    rasterized : bool
        Forwarded to ``sns.violinplot`` and the titles.
    show : bool
        Call ``plt.show()``.
    """
    if isinstance(donors, str):
        donors = [donors]
    if isinstance(labs, str):
        labs = [labs]

    if layer_norm:
        # Normalize a copy: normalizing in place would re-normalize the caller's
        # object on every iteration of a plotting loop.
        adata = adata.copy()
        normalize_adata(adata, log1p=True)

    adata, _drop_col = _get_obs_or_gene(adata, gene, layer)  # get the column from obs or var
    columns = list(dict.fromkeys(["donor", "replicate", "brain_region", subset_level, group_level, gene]))
    df_obs = adata.obs[columns].copy()
    if _drop_col:  # drop the column if it was added for plotting
        adata.obs.drop(columns=[gene], inplace=True)
    if donors is None:
        donors = adata.obs['donor'].unique().tolist()
    if not donors:
        print("No donors to plot")
        return

    keep = (
        df_obs[subset_level].isin(groups)
        & df_obs["brain_region"].isin(regions)
        & (df_obs[group_level] != "unknown")
    )
    if labs is not None:
        keep &= df_obs["replicate"].isin(labs)  # restrict to the requested replicates
    df_obs = df_obs.loc[keep].copy()  # explicit copy: avoids SettingWithCopy on the writes below
    for col in (group_level, "brain_region"):
        df_obs[col] = df_obs[col].astype("category").cat.remove_unused_categories()
    n_groups = len(df_obs[group_level].cat.categories)

    nrows = len(donors)
    fig, axes = plt.subplots(nrows, 1, figsize=(10, 2*nrows+1), dpi=300, sharex=True, sharey=False, squeeze=False)
    for i, _d in enumerate(donors):
        ax = axes[i, 0]
        df_d = df_obs[df_obs['donor'] == _d]
        if not df_d.empty:
            sns.violinplot(data=df_d, x=group_level, y=gene, hue="brain_region", ax=ax, inner="quart", hue_order=regions, palette=adata.uns['brain_region_palette'], rasterized=rasterized)
            ax.legend(title='Brain Region', fontsize=6, title_fontsize=7, loc='upper right', bbox_to_anchor=(1.15, 1))
        if i == nrows - 1:
            ax.tick_params(axis='x', labelbottom=True)
            ax.set_xticks(range(n_groups))
            ax.set_xticklabels(ax.get_xticklabels(), rotation=15, ha='right', fontsize=12)
        ax.set_title(f"Donor: {_d}", rasterized=rasterized)

    fig.suptitle(f"{gene} mCH Score in STR D1 and D2 MSNs across Brain Regions", y=1.02, rasterized=rasterized)

    if save_fig and image_path is not None:
        stem = image_name if image_name is not None else f"gene_{gene}_mch_violin"
        fig.savefig(f"{image_path}/{stem}.png", bbox_inches='tight', dpi=300)
        fig.savefig(f"{image_path}/{stem}.pdf", bbox_inches='tight', dpi=300)
    if show:
        plt.show()
    plt.close(fig)

## Read

In [None]:
# Load the BG spatial transcriptomics AnnData (fully into memory) and display its summary.
adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(filename=adata_path)
adata

In [None]:
spatial_mch = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/BG_mCH_Imp_SubR.h5ad")

# Reuse the RNA object's color/palette entries so both modalities share colors.
for _uns_key in [key for key in adata.uns if "colors" in key or "palette" in key]:
    spatial_mch.uns[_uns_key] = adata.uns[_uns_key]

# Spatial coordinates for plotting come from the CENTER_X / CENTER_Y obs columns.
spatial_mch.obsm['spatial'] = spatial_mch.obs.loc[:, ['CENTER_X', 'CENTER_Y']].to_numpy()
# Assignment aligns on the obs index: cells absent from `adata` get NaN here.
spatial_mch.obs['brain_region_corr'] = adata.obs['brain_region_corr'].copy()
spatial_mch

In [None]:
# Donors and replicates (labs) to iterate over. Use the `donors` / `replicates`
# parameters when they are set; otherwise use every value present in the RNA object.
if isinstance(donors, str):
    donors = [donors]
if donors is None:
    donors = adata.obs['donor'].unique().tolist()
if isinstance(replicates, str):
    replicates = [replicates]
labs = list(replicates) if replicates is not None else adata.obs['replicate'].unique().tolist()
# Unique region labels. `Series.keys()` would return the obs index (cell IDs) instead.
brain_regions = adata.obs['brain_region'].unique().tolist()
print(donors)

## Plots

### RNA

In [None]:
# Narrow this run to a single gene (overrides `genes_to_plot` from the parameters cell).
_focus_gene = "PDE8B"
genes_to_plot = [_focus_gene]

In [None]:
# Inspect which MSN group labels exist before choosing what to plot.
pd.unique(adata.obs['MSN_Groups'])

In [None]:
# MSN subclasses to highlight: all three STR MSN subclasses.
# Alternatives used during exploration: ["STR Hybrid MSN"], ["STR D1 MSN"],
# ["STR D2 MSN"], or the group ['STRv D1 MSN'].
subclass_plot = [f"STR {kind} MSN" for kind in ("D1", "D2", "Hybrid")]
# Region labels in mixed case. An earlier draft used upper case: CAH, CAB, CAT, PU, NAC.
plot_regions = ["CaT", "Pu", "CaB", "CaH", "NAC"]

In [None]:
# Show the output directory used by the figure cells below.
str(image_path)

In [None]:
# # violin plots
# for _gene in genes_to_plot:
#     for _donor in donors: 
#         for _lab in labs: 
#             _plot_gene_violinplots(
#                 adata,
#                 gene=_gene,
#                 groups=subclass_plot,
#                 subset_level="Subclass",
#                 group_level="Group",
#                 regions=plot_regions,
#                 donors=[_donor],
#                 labs=[_lab],
#                 layer='volume_norm',
#                 layer_norm=False,
#                 save_fig=True, 
#                 image_path=image_path,
#                 image_name=f'violin_{_gene}_{_donor}_{_lab}',
#                 rasterized=True,
#                 show=False,
#             )

#             # _plot_gene_violinplots(
#             #     adata,
#             #     gene=_gene,
#             #     groups=subclass_plot,
#             #     subset_level="Subclass",
#             #     group_level="Group",
#             #     regions=["PU", "CAB", "NAC"],
#             #     donors=[_donor],
#             #     labs=[_lab],
#             #     layer='volume_norm',
#             #     layer_norm=True,
#             #     save_fig=True, 
#             #     image_path=image_path,
#             #     image_name=f'violin_DV_{_gene}_{_donor}_{_lab}',
#             #     show=False,
#             # )

In [None]:
# Spatial Plots: RNA expression maps for every gene x donor x replicate.
_spatial_rna_kwargs = dict(
    groups=subclass_plot,
    group_level="Subclass",
    regions=plot_regions,
    layers='volume_norm',
    layer_norm=False,
    hue_norm=0.5,
    save_fig=True,
    image_path=image_path,
    rasterized=True,
    show=False,
)
for _gene in genes_to_plot:
    for _donor in donors:
        for _lab in labs:
            plot_gene_groups(
                adata,
                gene=_gene,
                donors=[_donor],
                labs=[_lab],
                image_name=f'spatial_{_gene}_{_donor}_{_lab}',
                **_spatial_rna_kwargs,
            )

In [None]:
# # Spatial Plots
# for _gene in genes_to_plot:
#     for _donor in donors: 
#         for _lab in labs: 
#             plot_gene_groups(
#                 adata,
#                 groups=subclass_plot,
#                 group_level="Subclass", 
#                 regions=plot_regions,
#                 donors=[_donor],
#                 labs = [_lab],
#                 gene=_gene,
#                 layers='volume_norm',
#                 layer_norm=False,
#                 hue_norm=0.5,
#                 save_fig=True,
#                 image_path=image_path,
#                 image_name=f'MSN_D2_{_gene}_{_donor}_{_lab}',
#                 rasterized=True,
#                 show=False,
#             )

### Meth

#### violin plots

In [None]:
# violin plots: mCH score per Group, split by region, for every gene x donor x replicate.
for _gene in genes_to_plot:
    for _donor in donors: 
        for _lab in labs: 
            _plot_gene_violinplots_meth(
                spatial_mch,
                gene=_gene,
                # Same subclass selection as the spatial plots, instead of a hard-coded copy.
                groups=subclass_plot,
                subset_level="Subclass",
                group_level="Group",
                regions=plot_regions,
                donors=[_donor],
                labs=[_lab],
                layer=None,
                layer_norm=False,
                save_fig=True, 
                image_path=image_path,
                image_name=f'violin_mch_{_gene}_{_donor}_{_lab}',
                rasterized=True,
                show=False,
            )

#             _plot_gene_violinplots_meth(
#                 spatial_mch,
#                 gene=_gene,
#                 groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
#                 subset_level="Subclass",
#                 group_level="Group",
#                 regions=["PU", "CAB", "NAC"],
#                 donors=[_donor],
#                 labs=[_lab],
#                 layer=None,
#                 layer_norm=False,
#                 save_fig=True, 
#                 image_path=image_path,
#                 image_name=f'violin_mch_DV_{_gene}_{_donor}_{_lab}',
#                 show=False,
#             )

#### spatial plots

In [None]:
# List the obs columns available on the methylation object.
# (`DataFrame.column` does not exist and raised AttributeError; the attribute is `.columns`.)
spatial_mch.obs.columns

In [None]:
# Spatial Plots: mCH score maps for every gene x donor x replicate.
# The RNA cells (`bg_adata=adata`) are drawn as the grey background.
_spatial_mch_kwargs = dict(
    bg_adata=adata,
    groups=subclass_plot,
    group_level="Subclass",
    regions=plot_regions,
    layers=None,
    layer_norm=False,
    cmap="Blues_r",
    min_norm=10,
    max_norm=90,
    save_fig=True,
    image_path=image_path,
    rasterized=True,
    show=False,
)
for _gene in genes_to_plot:
    for _donor in donors:
        for _lab in labs:
            plot_gene_groups_meth(
                spatial_mch,
                gene=_gene,
                donors=[_donor],
                labs=[_lab],
                image_name=f'spatial_mch_{_gene}_{_donor}_{_lab}',
                **_spatial_mch_kwargs,
            )

            # plot_gene_groups_meth(
            #     spatial_mch,
            #     bg_adata = adata,
            #     groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
            #     group_level="Subclass", 
            #     regions=["CAH", "CAB", "CAT"],
            #     donors=[_donor],
            #     labs = [_lab],
            #     gene=_gene,
            #     layers=None,
            #     layer_norm=False,
            #     cmap="Blues_r",
            #     min_norm=10,
            #     max_norm=90,
            #     save_fig=True,
            #     image_path=image_path,
            #     image_name=f'spatial_mch_AP_{_gene}_{_donor}_{_lab}',
            #     show=False,
            # )
    #         break
    #     # break
    # break

In [None]:
# plot_gene_groups_meth(
#     spatial_mch,
#     bg_adata = adata,
#     groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
#     group_level="Subclass",
#     regions=["CAH", "CAB", "CAT"],
#     # regions=["PU", "CAB", "NAC"],
#     donors=adata.obs['donor'].unique().tolist(),
#     labs = ['salk'],
#     gene="CADM1",
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",
#     min_norm=10,
#     max_norm=90
#     # hue_norm=0.9
# )

# plot_gene_groups_meth(
#     spatial_mch,
#     bg_adata = adata,
#     groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
#     group_level="Subclass",
#     # regions=["CAH", "CAB", "CAT"],
#     regions=["PU", "CAB", "NAC"],
#     donors=adata.obs['donor'].unique().tolist(),
#     labs = ['salk'],
#     gene="CADM1",
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",
#     min_norm=10,
#     max_norm=90
#     # hue_norm=0.9
# )

In [None]:
# plot_gene_groups_meth(
#     spatial_mch,
#     bg_adata = adata,
#     groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
#     group_level="Subclass",
#     regions=["CAH", "CAB", "CAT"],
#     # regions=["PU", "CAB", "NAC"],
#     donors=adata.obs['donor'].unique().tolist(),
#     labs = ['salk'],
#     gene="PDE8B",
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",
#     min_norm=10,
#     max_norm=90
#     # hue_norm=0.9
# )

# plot_gene_groups_meth(
#     spatial_mch,
#     bg_adata = adata,
#     groups=["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"],
#     group_level="Subclass",
#     # regions=["CAH", "CAB", "CAT"],
#     regions=["PU", "CAB", "NAC"],
#     donors=adata.obs['donor'].unique().tolist(),
#     labs = ['salk'],
#     gene="PDE8B",
#     layers=None,
#     layer_norm=False,
#     cmap="Blues_r",
#     min_norm=10,
#     max_norm=90
#     # hue_norm=0.9
# )