This notebook generates the figures summarizing the annotation of the spatial data and its integration with the HMBA and snm3C data. The figures generated here go into spatial supplementary figures 3-4 (supp. #).

author: Amit Klein
email: a3klein@ucsd.edu

In [None]:
import os
from pathlib import Path
from tqdm import tqdm

import numpy as np
import pandas as pd
import anndata as ad
from scipy.stats import entropy

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, categorical_scatter, plot_continuous
from spida.pl._utils import plot_text_legend

In [None]:
# Global matplotlib defaults applied to every figure produced in this notebook.
plt.rcParams.update({
    # On-screen resolution and default canvas size
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    # Type 42 (TrueType) font embedding for PDF/PS output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # Text appearance
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    # Axes and image rendering
    'axes.facecolor': 'white',
    'image.interpolation': 'nearest',
    # Defaults used by savefig
    'savefig.dpi': 300,
    'savefig.transparent': True,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.01,
})

In [None]:
# Output directory for the RNA-integration figures saved below; created (with parents) if missing.
image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_annot")
image_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the full AnnData object; its obs labels (Subclass, Group) and uns palettes
# are reused by the cells below.
adata_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad")
adata = ad.read_h5ad(adata_path)
adata

### Helper Functions

In [None]:
def add_colors(adata, cat_col, palette):
    """Store one colour per category of ``adata.obs[cat_col]`` in ``adata.uns``.

    The colours are written to ``adata.uns[f'{cat_col}_colors']`` in the order
    of the column's categories.

    Parameters
    ----------
    adata
        Object exposing ``obs`` (with a categorical column ``cat_col``) and a
        writable ``uns`` mapping.
    cat_col
        Name of the categorical column in ``adata.obs``.
    palette
        Either a ``dict`` mapping category -> colour, or an object indexed as
        ``palette.loc[category, 'Hex']``.

    Categories missing from the palette are printed and given the grey
    fallback ``'#808080'``.
    """
    fallback = '#808080'
    is_mapping = isinstance(palette, dict)

    resolved = []
    for category in adata.obs[cat_col].cat.categories:
        try:
            hex_code = palette[category] if is_mapping else palette.loc[category, 'Hex']
        except KeyError:
            # Report the category that has no palette entry and fall back to grey.
            print(category)
            hex_code = fallback
        resolved.append(hex_code)

    adata.uns[f'{cat_col}_colors'] = resolved

def entropy_to_df(entropies_dict, method_name="RNA"):
    """Flatten nested per-sample entropies into a long-format DataFrame.

    Parameters
    ----------
    entropies_dict
        Mapping ``(donor, region, lab) -> {group: entropy}``.
    method_name
        Value written to the ``method`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per (sample, group) with columns ``donor``, ``brain_region``,
        ``replicate``, ``group``, ``entropy`` and ``method``. The columns are
        present even when ``entropies_dict`` is empty, so downstream column
        access does not fail on empty input.
    """
    columns = ['donor', 'brain_region', 'replicate', 'group', 'entropy', 'method']
    records = [
        {
            'donor': donor,
            'brain_region': region,
            'replicate': lab,
            'group': group,
            'entropy': entropy_val,
            'method': method_name,
        }
        for (donor, region, lab), group_entropies in entropies_dict.items()
        for group, entropy_val in group_entropies.items()
    ]
    return pd.DataFrame(records, columns=columns)

In [None]:
def export_text_circle_legend(
    code2label,
    color_dict,
    filename=None,
    save=False,
    show=True,
    marker_fontsize=8,
    luminance=0.6,
    title=None,
    legend_kws=None
):
    """Draw a standalone legend on an otherwise empty figure and save/show it.

    The legend itself is produced by ``plot_text_legend``; this function only
    provides a 2x2 inch figure with a single invisible axes to host it.

    Parameters
    ----------
    code2label, color_dict
        Forwarded unchanged to ``plot_text_legend``. In this notebook they are
        an ``{integer code: label}`` dict and a ``{label: colour}`` dict.
    filename
        A single path (``str`` or path-like) or an iterable of paths. The
        figure is written once per path. Ignored unless ``save`` is true.
    save
        Write the figure to ``filename`` when true and ``filename`` is given.
    show
        Call ``plt.show()`` after saving.
    marker_fontsize, luminance, title, legend_kws
        Forwarded unchanged to ``plot_text_legend``.
    """
    fig = plt.figure(figsize=(2, 2))
    try:
        # One axes filling the whole canvas, with all axis decorations hidden.
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')

        plot_text_legend(
            color_dict=color_dict,
            code2label=code2label,
            ax=ax,
            title=title,
            legend_kws=legend_kws,
            marker_fontsize=marker_fontsize,
            luminance=luminance
        )

        if save and filename is not None:
            # A lone str or Path must be wrapped: a str would be iterated
            # character by character and a Path is not iterable at all.
            if isinstance(filename, (str, os.PathLike)):
                filename = [filename]
            for _f in filename:
                fig.savefig(_f)
        if show:
            plt.show()
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)


## Example RNA Integration

In [None]:
### Need to: Get an example - CAH 5224 salk
# Directory holding the integration output used as the example in this section.
int_dir = Path("/home/x-aklein2/projects/aklein/BICAN/BG/annotation/execute/region_donor_lab_cps2/CAH_UCI5224_salk")

# AnnData read from neuronal_subclass.h5ad in that directory.
ad_neu_sub = ad.read_h5ad(int_dir / "neuronal_subclass.h5ad")
ad_neu_sub

In [None]:
# Label columns compared in this section: the "Subclass" column and the "ANNOT" column
# (plotted in the next cell as the reference and the MERSCOPE annotation, respectively).
rna_cell_type_column = "Subclass"
qry_cell_type_column = "ANNOT"

# Set ANNOT to the Subclass labels of the full object for the cells present in both objects.
common_cells = ad_neu_sub.obs_names.intersection(adata.obs_names)
ad_neu_sub.obs.loc[common_cells, "ANNOT"] = adata.obs.loc[common_cells, "Subclass"]
add_colors(ad_neu_sub, qry_cell_type_column, adata.uns["Subclass_palette"])
add_colors(ad_neu_sub, rna_cell_type_column, adata.uns["Subclass_palette"])


# For the legend
# Sorted union of the (non-missing) labels observed in either column.
qry_incl_groups = ad_neu_sub.obs[qry_cell_type_column].unique().dropna().tolist().copy()
rna_incl_groups = ad_neu_sub.obs[rna_cell_type_column].unique().dropna().tolist().copy()
incl_groups = np.unique(qry_incl_groups + rna_incl_groups).tolist()

# Put both columns on the same category set, then take the reference label and
# fall back to the query label where the reference label is missing.
df = ad_neu_sub.obs[[rna_cell_type_column, qry_cell_type_column]].copy()
df[rna_cell_type_column] = df[rna_cell_type_column].cat.set_categories(incl_groups)
df[qry_cell_type_column] = df[qry_cell_type_column].cat.set_categories(incl_groups)
df['comb'] = df[rna_cell_type_column].fillna(df[qry_cell_type_column])

# codes: integer code -> label (used for the legend); obs['codes'] holds the
# inverse mapping applied to the combined labels (used via coding='codes' when plotting).
codes = {k: v for k, v in enumerate(sorted(incl_groups))}
ad_neu_sub.obs['codes'] = df['comb'].map({v: k for k, v in codes.items()})
# Label -> colour for the legend. NOTE: unlike add_colors, no fallback colour is
# substituted here for a label that is missing from the palette.
color_dict = {label: adata.uns[f"{rna_cell_type_column}_palette"][label] for label in incl_groups}

In [None]:
# Two-panel integrated UMAP: reference Subclass labels (left) vs. MERSCOPE annotation (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=200)
axes = axes.flatten()

# Left panel: every cell in light grey first, then the labelled cells on the same axes.
ax = axes[0]
categorical_scatter(data=ad_neu_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
# NOTE(review): legend_kws is passed although show_legend=False; the legend is exported separately below.
plot_categorical(ad_neu_sub, coord_base="integrated_umap", cluster_col=rna_cell_type_column, 
                 show=False, coding='codes', text_anno=True, ax=ax, marker_fontsize=8, show_legend=False, 
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper right', borderaxespad=0., fontsize=10, title_fontsize=12, title="Subclass"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"HMBA Ref Ca Subclass", fontsize=16)

# Right panel: same layout, coloured by the query annotation column.
ax = axes[1]
categorical_scatter(data=ad_neu_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_neu_sub, coord_base="integrated_umap", cluster_col=qry_cell_type_column,
                 show=False, coding='codes', text_anno=True, ax=ax, labelsize=10, marker_fontsize=8, show_legend=False, 
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper right', borderaxespad=0., fontsize=10, title_fontsize=12, title="Subclass"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE Annotated Subclass", fontsize=16)

# Save as PNG and PDF, then release the figure.
plt.tight_layout()
plt.savefig(image_path / "RNA_Ca_Subclass_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(image_path / "RNA_Ca_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Shared legend (code -> label, label -> colour) written as its own PDF and PNG.
export_text_circle_legend(
    codes, color_dict, marker_fontsize=6,
    show=False, save=True,
    filename=[image_path / f"RNA_Ca_Subclass_Annot_Legend.{ext}" for ext in ['pdf', 'png']],
    legend_kws=dict(bbox_to_anchor=(0, 1), borderaxespad=0., ncol=1, title=rna_cell_type_column, title_fontsize=12)
)

In [None]:
# Same preparation as for Subclass, now at the Group level ("Group" vs. "G_ANNOT").
rna_cell_type_column = "Group"
qry_cell_type_column = "G_ANNOT"

# Set G_ANNOT to the Group labels of the full object for the cells present in both objects.
common_cells = ad_neu_sub.obs_names.intersection(adata.obs_names)
ad_neu_sub.obs.loc[common_cells, "G_ANNOT"] = adata.obs.loc[common_cells, "Group"]
add_colors(ad_neu_sub, qry_cell_type_column, adata.uns["Group_palette"])
add_colors(ad_neu_sub, rna_cell_type_column, adata.uns["Group_palette"])

# Sorted union of the (non-missing) labels observed in either column.
qry_incl_groups = ad_neu_sub.obs[qry_cell_type_column].unique().dropna().tolist().copy()
rna_incl_groups = ad_neu_sub.obs[rna_cell_type_column].unique().dropna().tolist().copy()
incl_groups = np.unique(qry_incl_groups + rna_incl_groups).tolist()

# Shared category set, then reference label with fallback to the query label.
df = ad_neu_sub.obs[[rna_cell_type_column, qry_cell_type_column]].copy()
df[rna_cell_type_column] = df[rna_cell_type_column].cat.set_categories(incl_groups)
df[qry_cell_type_column] = df[qry_cell_type_column].cat.set_categories(incl_groups)
df['comb'] = df[rna_cell_type_column].fillna(df[qry_cell_type_column])

# codes: integer code -> label; obs['codes'] is overwritten with the Group-level codes.
codes = {k: v for k, v in enumerate(sorted(incl_groups))}
ad_neu_sub.obs['codes'] = df['comb'].map({v: k for k, v in codes.items()})
# Label -> colour for the legend (no fallback colour for labels missing from the palette).
color_dict = {label: adata.uns[f"{rna_cell_type_column}_palette"][label] for label in incl_groups}

In [None]:
# Two-panel integrated UMAP: reference Group labels (left) vs. MERSCOPE annotation (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=200)
axes = axes.flatten()

# Left panel: every cell in light grey first, then the labelled cells on the same axes.
ax = axes[0]
categorical_scatter(data=ad_neu_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_neu_sub, coord_base="integrated_umap", cluster_col=rna_cell_type_column, 
                 show=False, coding='codes', text_anno=True, ax=ax, marker_fontsize=6, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., ncol=1, title="Group", fontsize=8, title_fontsize=10),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"HMBA Ref Ca Group", fontsize=16)

# Right panel: same layout, coloured by the query annotation column.
ax = axes[1]
categorical_scatter(data=ad_neu_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_neu_sub, coord_base="integrated_umap", cluster_col=qry_cell_type_column, 
                 show=False, coding='codes', text_anno=True, ax=ax, labelsize=10, marker_fontsize=6, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., ncol=1, title="Group", fontsize=8, title_fontsize=10),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE Annotated Group", fontsize=16)

# Save as PNG and PDF, then release the figure.
plt.savefig(image_path / "RNA_Ca_Group_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(image_path / "RNA_Ca_Group_Annot.pdf", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Shared legend written as its own PDF and PNG.
export_text_circle_legend(
    codes, color_dict, marker_fontsize=6,
    show=False, save=True,
    filename=[image_path / f"RNA_Ca_Group_Annot_Legend.{ext}" for ext in ['pdf', 'png']],
    legend_kws=dict(bbox_to_anchor=(0, 1), borderaxespad=0., ncol=1, title=rna_cell_type_column, title_fontsize=12)
)

In [None]:
# AnnData read from nonneuronal_subclass.h5ad in the same example directory.
ad_nn_sub = ad.read_h5ad(int_dir / "nonneuronal_subclass.h5ad")
ad_nn_sub

In [None]:
# Non-neuronal example at the Subclass level: prepare labels/codes, plot, export legend.
rna_cell_type_column = "Subclass"
qry_cell_type_column = "ANNOT"

# Set ANNOT to the Subclass labels of the full object for the cells present in both objects.
common_cells = ad_nn_sub.obs_names.intersection(adata.obs_names)
ad_nn_sub.obs.loc[common_cells, "ANNOT"] = adata.obs.loc[common_cells, "Subclass"]
add_colors(ad_nn_sub, qry_cell_type_column, adata.uns["Subclass_palette"])
add_colors(ad_nn_sub, rna_cell_type_column, adata.uns["Subclass_palette"])

# Sorted union of the (non-missing) labels observed in either column.
qry_incl_groups = ad_nn_sub.obs[qry_cell_type_column].unique().dropna().tolist().copy()
rna_incl_groups = ad_nn_sub.obs[rna_cell_type_column].unique().dropna().tolist().copy()
incl_groups = np.unique(qry_incl_groups + rna_incl_groups).tolist()

# Shared category set, then reference label with fallback to the query label.
df = ad_nn_sub.obs[[rna_cell_type_column, qry_cell_type_column]].copy()
df[rna_cell_type_column] = df[rna_cell_type_column].cat.set_categories(incl_groups)
df[qry_cell_type_column] = df[qry_cell_type_column].cat.set_categories(incl_groups)
df['comb'] = df[rna_cell_type_column].fillna(df[qry_cell_type_column])

# codes: integer code -> label; obs['codes'] holds the inverse mapping of the combined labels.
codes = {k: v for k, v in enumerate(sorted(incl_groups))}
ad_nn_sub.obs['codes'] = df['comb'].map({v: k for k, v in codes.items()})
# Label -> colour for the legend (no fallback colour for labels missing from the palette).
color_dict = {label: adata.uns[f"{rna_cell_type_column}_palette"][label] for label in incl_groups}


# Two-panel integrated UMAP: reference labels (left) vs. MERSCOPE annotation (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=200)
axes = axes.flatten()

# Left panel: every cell in light grey first, then the labelled cells on the same axes.
ax = axes[0]
categorical_scatter(data=ad_nn_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_nn_sub, coord_base="integrated_umap", cluster_col=rna_cell_type_column, 
                 show=False, coding="codes", text_anno=True, ax=ax, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"HMBA Ref Ca Subclass", fontsize=16)

# Right panel: same layout, coloured by the query annotation column.
ax = axes[1]
categorical_scatter(data=ad_nn_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_nn_sub, coord_base="integrated_umap", cluster_col=qry_cell_type_column,
                 show=False, coding="codes", text_anno=True, ax=ax, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title="Subclass"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE Annotated Subclass", fontsize=16)

# Save as PNG and PDF, then release the figure.
plt.savefig(image_path / "RNA_Ca_Subclass_Annot_NN.png", dpi=300, bbox_inches='tight')
plt.savefig(image_path / "RNA_Ca_Subclass_Annot_NN.pdf", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Shared legend written as its own PDF and PNG.
export_text_circle_legend(
    codes, color_dict, marker_fontsize=6,
    show=False, save=True,
    filename=[image_path / f"RNA_Ca_Subclass_Annot_NN_Legend.{ext}" for ext in ['pdf', 'png']],
    legend_kws=dict(bbox_to_anchor=(0, 1), borderaxespad=0., ncol=1, title=rna_cell_type_column, title_fontsize=12)
)

In [None]:
# Non-neuronal example at the Group level: prepare labels/codes, plot, export legend.
rna_cell_type_column = "Group"
qry_cell_type_column = "G_ANNOT"

# Set G_ANNOT to the Group labels of the full object for the cells present in both objects.
common_cells = ad_nn_sub.obs_names.intersection(adata.obs_names)
ad_nn_sub.obs.loc[common_cells, "G_ANNOT"] = adata.obs.loc[common_cells, "Group"]
add_colors(ad_nn_sub, qry_cell_type_column, adata.uns["Group_palette"])
add_colors(ad_nn_sub, rna_cell_type_column, adata.uns["Group_palette"])

# Sorted union of the (non-missing) labels observed in either column.
qry_incl_groups = ad_nn_sub.obs[qry_cell_type_column].unique().dropna().tolist().copy()
rna_incl_groups = ad_nn_sub.obs[rna_cell_type_column].unique().dropna().tolist().copy()
incl_groups = np.unique(qry_incl_groups + rna_incl_groups).tolist()

# Shared category set, then reference label with fallback to the query label.
df = ad_nn_sub.obs[[rna_cell_type_column, qry_cell_type_column]].copy()
df[rna_cell_type_column] = df[rna_cell_type_column].cat.set_categories(incl_groups)
df[qry_cell_type_column] = df[qry_cell_type_column].cat.set_categories(incl_groups)
df['comb'] = df[rna_cell_type_column].fillna(df[qry_cell_type_column])

# codes: integer code -> label; obs['codes'] is overwritten with the Group-level codes.
codes = {k: v for k, v in enumerate(sorted(incl_groups))}
ad_nn_sub.obs['codes'] = df['comb'].map({v: k for k, v in codes.items()})
# Label -> colour for the legend (no fallback colour for labels missing from the palette).
color_dict = {label: adata.uns[f"{rna_cell_type_column}_palette"][label] for label in incl_groups}


# Two-panel integrated UMAP: reference labels (left) vs. MERSCOPE annotation (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=200)
axes = axes.flatten()

# Left panel: every cell in light grey first, then the labelled cells on the same axes.
ax = axes[0]
categorical_scatter(data=ad_nn_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_nn_sub, coord_base="integrated_umap", cluster_col=rna_cell_type_column, 
                 show=False, coding="codes", text_anno=True, ax=ax, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., ncol=1),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"HMBA Ref Ca Group", fontsize=16)

# Right panel: same layout, coloured by the query annotation column.
ax = axes[1]
categorical_scatter(data=ad_nn_sub, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
plot_categorical(ad_nn_sub, coord_base="integrated_umap", cluster_col=qry_cell_type_column, 
                 show=False, coding="codes", text_anno=True, ax=ax, show_legend=False,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., ncol=1, title="Group"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE Annotated Group", fontsize=16)

# Save as PNG and PDF, then release the figure.
plt.savefig(image_path / "RNA_Ca_Group_Annot_NN.png", dpi=300, bbox_inches='tight')
plt.savefig(image_path / "RNA_Ca_Group_Annot_NN.pdf", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Shared legend written as its own PDF and PNG.
export_text_circle_legend(
    codes, color_dict, marker_fontsize=6,
    show=False, save=True,
    filename=[image_path / f"RNA_Ca_Group_Annot_NN_Legend.{ext}" for ext in ['pdf', 'png']],
    legend_kws=dict(bbox_to_anchor=(0, 1), borderaxespad=0., ncol=1, title=rna_cell_type_column, title_fontsize=12)
)

## Example Meth Integration

In [None]:
# import subprocess
# tt = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_annot")
# for _file in tt.glob("*"):
#     if ":" in _file.stem:
#         _new_file = _file.parent / (_file.stem.replace(":", "_") + _file.suffix)
#         subprocess.run(["mv", str(_file), str(_new_file)])
#         # print(_file)
#         # print(_new_file)

In [None]:
# Collect one final_with_coords.h5ad per sub-directory of the integration work directory.
work_dir = Path('/anvil/projects/x-mcb130189/qzeng/analysis/251105_merfish_methylation_2/Neuron.mC_merfish.integration-3')
all_adata_files = list(work_dir.glob('*/final_with_coords.h5ad'))
len(all_adata_files)

In [None]:
# Output directory for the per-run methylation integration figures; created (with parents) if missing.
mc_image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/annotations/qz_meth")
mc_image_path.mkdir(exist_ok=True, parents=True)

In [None]:
# For every methylation/MERFISH integration result: save the Subclass and Group
# annotation figures (each with a separate legend), then a metadata overview figure.
for i, _test_file in enumerate(all_adata_files): 
    # Colons in the run directory name are replaced so it can be used in file names.
    title = _test_file.parents[0].stem.replace(":", "_")
    print(title)
    temp_adata = ad.read_h5ad(_test_file)
    add_colors(temp_adata, 'donor', adata.uns['donor_palette'])
    add_colors(temp_adata, "Modality", {"mC": "#1f77b4", "merfish": "#ff7f0e"})

    # The Subclass and Group figures differ only in the label column, so both
    # levels are produced by the same code.
    for hue in ("Subclass", "Group"):
        hue_mc = f"{hue}_transfer"

        # Sorted union of the (non-missing) labels observed in either column.
        qry_incl_groups = temp_adata.obs[hue].unique().dropna().tolist()
        rna_incl_groups = temp_adata.obs[hue_mc].unique().dropna().tolist()
        incl_groups = np.unique(qry_incl_groups + rna_incl_groups).tolist()

        # Shared category set, then the transferred label with fallback to `hue`.
        df = temp_adata.obs[[hue_mc, hue]].copy()
        df[hue_mc] = df[hue_mc].cat.set_categories(incl_groups)
        df[hue] = df[hue].cat.set_categories(incl_groups)
        df['comb'] = df[hue_mc].fillna(df[hue])

        # codes: integer code -> label (legend); obs['codes'] holds the inverse
        # mapping of the combined labels (used via coding='codes' when plotting).
        codes = {k: v for k, v in enumerate(sorted(incl_groups))}
        temp_adata.obs['codes'] = df['comb'].map({v: k for k, v in codes.items()})
        color_dict = {label: adata.uns[f"{hue}_palette"][label] for label in incl_groups}

        # Split by modality after 'codes' was written so both subsets carry it.
        mc_data = temp_adata[temp_adata.obs['Modality'] == 'mC']
        merfish_data = temp_adata[temp_adata.obs['Modality'] == 'merfish']
        add_colors(mc_data, hue, adata.uns[f"{hue}_palette"])
        add_colors(merfish_data, hue, adata.uns[f"{hue}_palette"])
        add_colors(merfish_data, hue_mc, adata.uns[f"{hue}_palette"])

        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
        axes = axes.flatten()

        # (data, label column, panel title, legend_kws) for the three panels.
        panels = [
            (mc_data, hue, f"snm3C Ref {hue}",
             dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.)),
            (merfish_data, hue_mc, f"MERSCOPE MC Annotated {hue}",
             dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue)),
            (merfish_data, hue, f"MERSCOPE RNA Annotated {hue}",
             dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue)),
        ]
        for ax, (panel_data, panel_col, panel_title, panel_legend_kws) in zip(axes, panels):
            # Every cell in light grey first, then the labelled subset on the same axes.
            categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
            plot_categorical(panel_data, coord_base="umap", cluster_col=panel_col,
                             show=False, coding='codes', text_anno=True, ax=ax, show_legend=False,
                             legend_kws=panel_legend_kws,
                             rasterized=True, axis_format=None
                             )
            ax.set_title(panel_title, fontsize=16)

        # plt.suptitle(f"{title} - {hue} Annotation")
        plt.savefig(mc_image_path / f"{title}_{hue}_Annot.png", dpi=300, bbox_inches='tight')
        plt.savefig(mc_image_path / f"{title}_{hue}_Annot.pdf", dpi=300, bbox_inches='tight')
        # plt.show()
        plt.close()

        # The legend is titled after the column plotted in this iteration. The
        # previous code used `rna_cell_type_column`, a variable left over from
        # the RNA cells, which mislabelled the Subclass legend.
        export_text_circle_legend(
            codes, color_dict, marker_fontsize=6,
            show=False, save=True,
            filename=[mc_image_path / f"{title}_{hue}_Annot_Legend.{ext}" for ext in ['pdf', 'png']],
            legend_kws=dict(bbox_to_anchor=(0, 1), borderaxespad=0., ncol=1, title=hue, title_fontsize=12)
        )

    # Plot Meta
    # `hue`, `mc_data` and `merfish_data` keep their values from the last
    # ("Group") iteration of the loop above.
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
    axes = axes.flatten()

    ax = axes[0]
    categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
    plot_categorical(temp_adata, coord_base="umap", cluster_col="Modality", 
                    show=False, coding=False, text_anno=False, ax=ax, marker_fontsize=8,
                    legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., fontsize=10, title_fontsize=12, title="Modality"),
                    rasterized=True, axis_format=None
                    )
    ax.set_title(f"Modality", fontsize=16)


    ax = axes[1]
    categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
    plot_categorical(mc_data, coord_base="umap", cluster_col="donor", 
                    show=False, coding=False, text_anno=False, ax=ax, marker_fontsize=8,
                    legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., fontsize=10, title_fontsize=12, title="Donor"),
                    rasterized=True, axis_format=None
                    )
    ax.set_title(f"Donor", fontsize=16)

    ax = axes[2]
    categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None)
    plot_categorical(mc_data, coord_base="umap", cluster_col=hue, 
                    show=False, coding=True, text_anno=True, ax=ax, marker_fontsize=8,
                    legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., fontsize=8, title_fontsize=12, title=hue),
                    rasterized=True, axis_format=None
                    )
    ax.set_title(f"snm3C Ref {hue}", fontsize=16)
    # plt.suptitle(f"Meta Information - {title}")
    plt.savefig(mc_image_path / f"{title}_meta.png", dpi=300, bbox_inches='tight')
    plt.savefig(mc_image_path / f"{title}_meta.pdf", dpi=300, bbox_inches='tight')
    # plt.show()
    plt.close()

    # if i == 0: 
    #     break


## RN specific figures

In [None]:
# Output directory for the figures of this section; created (with parents) if missing.
rn_image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_rn")
rn_image_path.mkdir(exist_ok=True, parents=True)
# When False, the figures below are only saved to disk and plt.show() is skipped.
show_plots = False

### Null Example

In [None]:
# Donor and lab selecting the integration run used as the example; both also appear in the output file names.
_donor = "UCI2424"
_lab = "salk"

_file = f"/anvil/projects/x-mcb130189/qzeng/analysis/251105_merfish_methylation_2/Neuron.mC_merfish.integration-3/MGM1-MGM1:{_donor}:{_lab}:Neuron/final_with_coords.h5ad"
temp_adata = ad.read_h5ad(_file)
add_colors(temp_adata, 'donor', adata.uns['donor_palette'])
add_colors(temp_adata, "Modality", {"mC": "#1f77b4", "merfish": "#ff7f0e"})

# Start at the Subclass level and split the object by modality.
hue = "Subclass"
hue_mc = f"{hue}_transfer"
mc_data = temp_adata[temp_adata.obs['Modality'] == 'mC']
merfish_data = temp_adata[temp_adata.obs['Modality'] == 'merfish']

In [None]:
# Attach the Subclass palette to every (subset, column) pair plotted in the next three cells.
add_colors(mc_data, hue, adata.uns["Subclass_palette"])
add_colors(merfish_data, hue, adata.uns["Subclass_palette"])
add_colors(merfish_data, hue_mc, adata.uns["Subclass_palette"])

In [None]:
# mC cells coloured by Subclass, over all cells in light grey.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(mc_data, coord_base="umap", cluster_col=hue, 
                 show=False, coding=False, text_anno=True, ax=ax,
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., fontsize=8, title_fontsize=8, ),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"snm3C Ref. Subclass", fontsize=12)
# NOTE(review): "snmc3C" in the file name looks like a typo for "snm3C"; confirm before renaming outputs.
plt.savefig(rn_image_path / f"{_donor}_{_lab}_snmc3C_Subclass_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_snmc3C_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# merfish cells coloured by the Subclass_transfer column, over all cells in light grey.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(merfish_data, coord_base="umap", cluster_col=hue_mc,
                 show=False, coding=False, text_anno=True, ax=ax,
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., fontsize=8, title_fontsize=8, title="Subclass"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE MC Annotated Subclass", fontsize=12)
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeMC_Subclass_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeMC_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# merfish cells coloured by the Subclass column, over all cells in light grey.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(merfish_data, coord_base="umap", cluster_col=hue,
                 show=False, coding=False, text_anno=True, ax=ax,
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., fontsize=8, title_fontsize=8, title="Subclass"),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE RNA Annotated Subclass", fontsize=12)
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeRNA_Subclass_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeRNA_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# Switch to the Group level and attach the Group palette to the same subsets.
hue = "Group"
hue_mc = f"{hue}_transfer"
add_colors(mc_data, hue, adata.uns[f"{hue}_palette"])
add_colors(merfish_data, hue, adata.uns[f"{hue}_palette"])
add_colors(merfish_data, hue_mc, adata.uns[f"{hue}_palette"])

In [None]:
# mC cells coloured by the current `hue` column (coding=True here, unlike the Subclass panels).
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(mc_data, coord_base="umap", cluster_col=hue, 
                 show=False, coding=True, text_anno=True, ax=ax, 
                 legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0., fontsize=8, title_fontsize=8, markersize=20),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"snm3C Ref. {hue}", fontsize=12)
# NOTE(review): "snmc3C" in the file name looks like a typo for "snm3C"; confirm before renaming outputs.
plt.savefig(rn_image_path / f"{_donor}_{_lab}_snmc3C_{hue}_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_snmc3C_{hue}_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# merfish cells coloured by the `{hue}_transfer` column, over all cells in light grey.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(merfish_data, coord_base="umap", cluster_col=hue_mc,
                 show=False, coding=True, text_anno=True, ax=ax,
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue, fontsize=8, title_fontsize=8, markersize=20),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE MC Annotated {hue}", fontsize=12)
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeMC_{hue}_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeMC_{hue}_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# merfish cells coloured by the `hue` column, over all cells in light grey.
fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
plot_categorical(merfish_data, coord_base="umap", cluster_col=hue,
                 show=False, coding=True, text_anno=True, ax=ax,
                 legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue, fontsize=8, title_fontsize=8, markersize=20),
                 rasterized=True, axis_format=None
                )
ax.set_title(f"MERSCOPE RNA Annotated {hue}", fontsize=12)
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeRNA_{hue}_Annot.png", dpi=300, bbox_inches='tight')
plt.savefig(rn_image_path / f"{_donor}_{_lab}_merscopeRNA_{hue}_Annot.pdf", dpi=300, bbox_inches='tight')
if show_plots:
    plt.show()
plt.close()

In [None]:
# Register a fixed two-colour palette for the "Modality" column; add_colors
# writes the per-category colour list to temp_adata.uns["Modality_colors"].
# add_colors(temp_adata, 'donor', adata.uns['donor_palette'])
_modality_palette = {"mC": "#19AAD1", "merfish": "#E8743B"}
# _modality_palette = {"mC": "#8EBA42", "merfish": "#9467BD"}  # alternative colours
add_colors(temp_adata, "Modality", _modality_palette)

In [None]:
# --- Figure 1: integrated UMAP coloured by modality -------------------------
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=temp_adata,
    coord_base="umap",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {
    "bbox_to_anchor": (0.9, 1),
    "loc": "upper left",
    "borderaxespad": 0.,
    "fontsize": 8,
    "title_fontsize": 8,
}
plot_categorical(
    temp_adata,
    coord_base="umap",
    cluster_col="Modality",
    show=False,
    coding=False,
    text_anno=False,
    ax=ax,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("Modality", fontsize=12)
for _out_name in (f"{_donor}_{_lab}_Modality.png", f"{_donor}_{_lab}_Modality.pdf"):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

# --- Figure 2: snm3C reference cells coloured by donor ----------------------
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=temp_adata,
    coord_base="umap",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {
    "bbox_to_anchor": (0.9, 1),
    "loc": "upper left",
    "borderaxespad": 0.,
    "fontsize": 8,
    "title_fontsize": 8,
    "title": "Donor",
}
plot_categorical(
    mc_data,
    coord_base="umap",
    cluster_col="donor",
    show=False,
    coding=False,
    text_anno=False,
    ax=ax,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("snm3C Ref. Donor", fontsize=12)
for _out_name in (f"{_donor}_{_lab}_Donor.png", f"{_donor}_{_lab}_Donor.pdf"):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

In [None]:
# Expose the CENTER_X / CENTER_Y columns as the "spatial" embedding.
merfish_data.obsm['spatial'] = merfish_data.obs[['CENTER_X', 'CENTER_Y']].to_numpy()

# Subset the full dataset to the first experiment / region value found in
# merfish_data (assumes merfish_data covers a single section — TODO confirm).
_ss_experiment = merfish_data.obs['experiment'].unique()[0]
_ss_region = merfish_data.obs['region'].unique()[0]
_in_section = (adata.obs['experiment'] == _ss_experiment) & (adata.obs['region'] == _ss_region)
adata_ss = adata[_in_section]

In [None]:
# --- Figure 1: integrated Leiden clusters on the UMAP ------------------------
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=temp_adata,
    coord_base="umap",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {
    "bbox_to_anchor": (0.9, 1),
    "loc": "upper left",
    "borderaxespad": 0.,
    "fontsize": 8,
    "title_fontsize": 8,
    "markersize": 10,
}
plot_categorical(
    temp_adata,
    coord_base="umap",
    cluster_col="leiden",
    show=False,
    coding=True,
    text_anno=True,
    ax=ax,
    show_legend=False,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("Integrated Leiden Clusters", fontsize=12)
for _out_name in (f"{_donor}_{_lab}_int_leiden.png", f"{_donor}_{_lab}_int_leiden.pdf"):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

# --- Figure 2: the same clusters at the cells' spatial coordinates -----------
# Background here is adata_ss rather than temp_adata.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=adata_ss,
    coord_base="spatial",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {"bbox_to_anchor": (0.9, 1), "loc": "upper left", "borderaxespad": 0.}
plot_categorical(
    merfish_data,
    coord_base="spatial",
    cluster_col="leiden",
    show=False,
    coding=True,
    text_anno=True,
    ax=ax,
    show_legend=False,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("Spatial Integrated Leiden Clusters", fontsize=12)
for _out_name in (
    f"{_donor}_{_lab}_int_leiden_spatial.png",
    f"{_donor}_{_lab}_int_leiden_spatial.pdf",
):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

In [None]:
# Label each Leiden cluster with the most frequent reference Group among its
# mC cells, then collapse the three GATA3 groups of interest to short names.
# fmg_data = temp_adata[(temp_adata.obs['Subclass'] == "F M GATA3 GABA") & (temp_adata.obs['Modality'] == "mC")].obs.copy()
_is_mc = temp_adata.obs['Modality'] == "mC"
fmg_data = temp_adata[_is_mc].obs.copy()

# Per-cluster Group fractions, largest first; keeping the first row per
# cluster therefore keeps that cluster's dominant Group.
_group_fraction = fmg_data.groupby('leiden', observed=True)['Group'].value_counts(normalize=True)
ref_vc = _group_fraction.sort_values(ascending=False).reset_index()
ref_vc.drop_duplicates(subset=['leiden'], keep='first', inplace=True)
gmap = ref_vc.set_index('leiden')['Group'].to_dict()

g_to_rn = {
    'F M GATA3 GABA': "RN",
    'SN GATA3-PVALB GABA': "SN",
    'SN-VTR-HTH GATA3-TCF7L2 GABA': "SN-VTR-HTH",
}
# Clusters whose dominant Group is not one of the three above are dropped
# from the map and end up as "other" below.
gmap = {_cluster: g_to_rn[_grp] for _cluster, _grp in gmap.items() if _grp in g_to_rn}


cc_col = 'FM_GATA3_GABA_TYPES'
_mapped_types = temp_adata.obs['leiden'].map(gmap)
temp_adata.obs[cc_col] = _mapped_types.fillna("other").astype('category')

_group_palette = adata.uns['Group_palette']
fm_palette = {
    'RN': '#880808',
    'SN': _group_palette['SN GATA3-PVALB GABA'],
    'SN-VTR-HTH': _group_palette['SN-VTR-HTH GATA3-TCF7L2 GABA'],
    'other': '#d3d3d3',
}

add_colors(temp_adata, cc_col, fm_palette)
# Copy the labels onto the MERSCOPE object (pandas aligns on the obs index).
merfish_data.obs[cc_col] = temp_adata.obs[cc_col]
add_colors(merfish_data, cc_col, fm_palette)

In [None]:
# --- Figure 1: F M GATA3 GABA types on the integrated UMAP -------------------
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=temp_adata,
    coord_base="umap",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {
    "bbox_to_anchor": (0.9, 1),
    "loc": "upper left",
    "borderaxespad": 0.,
    "title": "F M GATA3 GABA Types",
    "fontsize": 8,
    "title_fontsize": 8,
}
plot_categorical(
    temp_adata,
    coord_base="umap",
    cluster_col=cc_col,
    show=False,
    coding=False,
    text_anno=True,
    ax=ax,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("F M GATA3 GABA Types", fontsize=12)
for _out_name in (f"{_donor}_{_lab}_umap_rn_plots.png", f"{_donor}_{_lab}_umap_rn_plots.pdf"):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

# --- Figure 2: the same types at the cells' spatial coordinates --------------
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
categorical_scatter(
    data=adata_ss,
    coord_base="spatial",
    max_points=None,
    hue=None,
    scatter_kws={"color": "lightgrey"},
    ax=ax,
    rasterized=True,
)
_legend_kws = {
    "bbox_to_anchor": (0.9, 1),
    "loc": "upper left",
    "borderaxespad": 0.,
    "title": "F M GATA3 GABA Types",
    "fontsize": 8,
    "title_fontsize": 8,
}
plot_categorical(
    merfish_data,
    coord_base="spatial",
    cluster_col=cc_col,
    show=False,
    coding=False,
    text_anno=True,
    ax=ax,
    legend_kws=_legend_kws,
    rasterized=True,
    axis_format=None,
)
ax.set_title("F M GATA3 GABA Types", fontsize=12)
for _out_name in (
    f"{_donor}_{_lab}_spatial_rn_plots.png",
    f"{_donor}_{_lab}_spatial_rn_plots.pdf",
):
    plt.savefig(rn_image_path / _out_name, dpi=300, bbox_inches="tight")
if show_plots:
    plt.show()
plt.close()

### Overview 

In [None]:
# add_colors(mc_data, hue, adata.uns["Subclass_palette"])
# add_colors(merfish_data, hue, adata.uns["Subclass_palette"])
# add_colors(merfish_data, hue_mc, adata.uns["Subclass_palette"])

# fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
# axes = axes.flatten()

# ax = axes[0]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(mc_data, coord_base="umap", cluster_col=hue, 
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
#                  rasterized=True
#                 )
# ax.set_title(f"snm3C Ref Subclass")

# ax = axes[1]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(merfish_data, coord_base="umap", cluster_col=hue_mc,
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title="Subclass"),
#                  rasterized=True
#                 )
# ax.set_title(f"MERSCOPE MC Annotated Subclass")

# ax = axes[2]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(merfish_data, coord_base="umap", cluster_col=hue,
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title="Subclass"),
#                  rasterized=True
#                 )
# ax.set_title(f"MERSCOPE RNA Annotated Subclass")

# # plt.savefig(image_path / "RNA_Ca_Subclass_Annot.png", dpi=300, bbox_inches='tight')
# # plt.savefig(image_path / "RNA_Ca_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
# plt.show()
# plt.close()

In [None]:
# hue = "Group"
# hue_mc = f"{hue}_transfer"
# add_colors(mc_data, hue, adata.uns[f"{hue}_palette"])
# add_colors(merfish_data, hue, adata.uns[f"{hue}_palette"])
# add_colors(merfish_data, hue_mc, adata.uns[f"{hue}_palette"])

In [None]:
# fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
# axes = axes.flatten()

# ax = axes[0]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(mc_data, coord_base="umap", cluster_col=hue, 
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
#                  rasterized=True
#                 )
# ax.set_title(f"snm3C Ref {hue}")

# ax = axes[1]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(merfish_data, coord_base="umap", cluster_col=hue_mc,
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue),
#                  rasterized=True
#                 )
# ax.set_title(f"MERSCOPE MC Annotated {hue}")

# ax = axes[2]
# categorical_scatter(data=temp_adata, coord_base="umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
# plot_categorical(merfish_data, coord_base="umap", cluster_col=hue,
#                  show=False, coding=True, text_anno=True, ax=ax,
#                  legend_kws=dict(bbox_to_anchor=(0.8, 1), loc='upper left', borderaxespad=0., title=hue),
#                  rasterized=True
#                 )
# ax.set_title(f"MERSCOPE RNA Annotated {hue}")

# # plt.savefig(image_path / "RNA_Ca_Subclass_Annot.png", dpi=300, bbox_inches='tight')
# # plt.savefig(image_path / "RNA_Ca_Subclass_Annot.pdf", dpi=300, bbox_inches='tight')
# plt.show()
# plt.close()

In [None]:
# Three-panel overview on the integrated UMAP: modality, snm3C donor, and the
# snm3C reference annotation stored in `hue`. Displayed only, not saved.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
axes = axes.flatten()

# (foreground data, column to colour by, panel title)
_panels = (
    (temp_adata, "Modality", "Modality"),
    (mc_data, "donor", "donor"),
    (mc_data, hue, f"snm3C Ref {hue}"),
)

for ax, (_panel_data, _panel_col, _panel_title) in zip(axes, _panels):
    # Grey layer of every integrated cell, then the coloured layer on top.
    categorical_scatter(
        data=temp_adata,
        coord_base="umap",
        max_points=None,
        hue=None,
        scatter_kws={"color": "lightgrey"},
        ax=ax,
        rasterized=True,
    )
    plot_categorical(
        _panel_data,
        coord_base="umap",
        cluster_col=_panel_col,
        show=False,
        coding=True,
        text_anno=True,
        ax=ax,
        legend_kws={"bbox_to_anchor": (0.9, 1), "loc": "upper left", "borderaxespad": 0.},
        rasterized=True,
    )
    ax.set_title(_panel_title)

plt.show()
plt.close()



In [None]:
# Expose the CENTER_X / CENTER_Y columns as the "spatial" embedding, then
# subset the full dataset to the first experiment / region value found in
# merfish_data (assumes merfish_data covers a single section — TODO confirm).
# The obsm assignment used to be repeated in two consecutive cells; it is
# idempotent, so a single cell is sufficient.
merfish_data.obsm['spatial'] = merfish_data.obs[['CENTER_X', 'CENTER_Y']].to_numpy()
adata_ss = adata[(adata.obs['experiment'] == merfish_data.obs['experiment'].unique()[0]) &
                 (adata.obs['region'] == merfish_data.obs['region'].unique()[0])]

In [None]:
# Three-panel overview of the integrated Leiden clusters: on the UMAP, at the
# cells' spatial coordinates, and next to the snm3C reference Subclass labels.
# Displayed only, not saved.
# Panel titles previously read "Modality" / "donor" (copied from the meta
# overview) although the panels show Leiden clusters; they now describe the
# plotted column.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
axes = axes.flatten()

# (background data, foreground data, embedding, column to colour by, title)
_panels = (
    (temp_adata, temp_adata, "umap", "leiden", "Integrated Leiden Clusters"),
    (adata_ss, merfish_data, "spatial", "leiden", "Spatial Integrated Leiden Clusters"),
    (temp_adata, mc_data, "umap", "Subclass", "snm3C Ref Subclass"),
)

for ax, (_bg_data, _fg_data, _coord_base, _panel_col, _panel_title) in zip(axes, _panels):
    # Grey background layer first, coloured clusters on top.
    categorical_scatter(data=_bg_data, coord_base=_coord_base, max_points=None, hue=None,
                        scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
    plot_categorical(_fg_data, coord_base=_coord_base, cluster_col=_panel_col,
                     show=False, coding=True, text_anno=True, ax=ax,
                     legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
                     rasterized=True
                    )
    ax.set_title(_panel_title)

plt.show()
plt.close()



In [None]:
cc_col = 'FM_GATA3_GABA_TYPES'

# Manual Leiden-cluster -> type assignment. Cluster ids differ per sample;
# any cluster that is not listed is labelled "other".
# ucsd sample
# _rn_clusters = ['0', '1', '2', '3']
# _splatter_clusters = ['7', '8', '19', '22']

# salk sample
_rn_clusters = ['0', '1', '2', '5']
_splatter_clusters = ['6', '8', '12', '17', '21']

_leiden_to_type = {
    **dict.fromkeys(_rn_clusters, 'RN'),
    **dict.fromkeys(_splatter_clusters, 'splatter'),
}
_mapped_types = temp_adata.obs['leiden'].map(_leiden_to_type)
temp_adata.obs[cc_col] = _mapped_types.fillna("other").astype('category')

fm_palette = {
    'RN': '#1f77b4',
    'splatter': '#ff7f0e',
    'other': '#808080',
}

add_colors(temp_adata, cc_col, fm_palette)
# Copy the labels onto the MERSCOPE object (pandas aligns on the obs index).
merfish_data.obs[cc_col] = temp_adata.obs[cc_col]
add_colors(merfish_data, cc_col, fm_palette)

In [None]:
# Three-panel overview of the manual `cc_col` type assignment: on the UMAP,
# at the cells' spatial coordinates, and next to the snm3C reference Subclass
# labels. Displayed only, not saved.
# Panel titles previously read "Modality" / "donor" (copied from the meta
# overview) although the panels show `cc_col`; they now describe the plotted
# column.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=200)
axes = axes.flatten()

# (background data, foreground data, embedding, column to colour by, title)
_panels = (
    (temp_adata, temp_adata, "umap", cc_col, "F M GATA3 GABA Types (UMAP)"),
    (adata_ss, merfish_data, "spatial", cc_col, "F M GATA3 GABA Types (spatial)"),
    (temp_adata, mc_data, "umap", "Subclass", "snm3C Ref Subclass"),
)

for ax, (_bg_data, _fg_data, _coord_base, _panel_col, _panel_title) in zip(axes, _panels):
    # Grey background layer first, coloured types on top.
    categorical_scatter(data=_bg_data, coord_base=_coord_base, max_points=None, hue=None,
                        scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True)
    plot_categorical(_fg_data, coord_base=_coord_base, cluster_col=_panel_col,
                     show=False, coding=True, text_anno=True, ax=ax,
                     legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
                     rasterized=True
                    )
    ax.set_title(_panel_title)

plt.show()
plt.close()



## Subclass / Group annotation entropy

### Helper Plot

In [None]:
# Code from ChatGPT and then modified
def violin_jitter_plot_with_palette(
    ax,
    df,
    cat_col,
    y_col,
    palette=None,
    max_points=500,
    jitter=0.15,
    violin_alpha=0.7,
    point_size=12,
    point_alpha=0.9,
    seed=42,
    sort_by_median=True,
    median_marker="line",    # "line" or "dot"
    median_color="black",
    median_linewidth=2,
    median_size=40,
    xlabel_color_map=None,
    rasterized=False,
    title=None,
    title_fontsize=24,
    ax_tick_fontsize=12,
    ax_label_fontsize=14,
    ax_tick_rotation=30,
    ylabel="\nAnnotation Score",
):
    """Draw one violin per category with jittered points and a median marker.

    Rows whose category or value is missing are ignored, so unused categorical
    levels and NaN scores never produce an empty or NaN violin.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    df : pandas.DataFrame
        Table holding ``cat_col`` and ``y_col``.
    cat_col, y_col : str
        Column of category labels and column of numeric values.
    palette : mapping, optional
        Category -> colour, looked up with ``.get``. Missing categories get
        gray violins and black points. ``None`` makes everything gray.
    max_points : int or None
        Maximum number of jittered points per category; larger groups are
        subsampled without replacement. ``None`` disables subsampling.
    jitter : float
        Half-width of the uniform horizontal jitter.
    violin_alpha, point_size, point_alpha
        Styling of violin bodies and points.
    seed : int
        Seed of the random generator used for subsampling and jitter.
    sort_by_median : bool
        Order categories by ascending median; otherwise by first appearance.
    median_marker : {"line", "dot"}
        Shape of the median marker; any other value draws no marker.
    median_color, median_linewidth, median_size
        Styling of the median marker.
    xlabel_color_map : mapping, optional
        Tick-label text -> colour; unknown labels are drawn in '#808080'.
    rasterized : bool
        Forwarded to the violins, the points, the grid and the text artists.
    title, title_fontsize, ax_tick_fontsize, ax_label_fontsize, ax_tick_rotation
        Title text and font / rotation settings.
    ylabel : str
        Y-axis label.

    Returns
    -------
    The ``ax`` that was passed in.
    """
    rng = np.random.default_rng(seed)

    # Keep only rows usable for plotting: known category and non-missing value.
    usable = df.loc[df[cat_col].notna() & df[y_col].notna(), [cat_col, y_col]]

    if sort_by_median:
        # observed=True: skip categorical levels that have no rows at all.
        ordered = usable.groupby(cat_col, observed=True)[y_col].median().sort_values()
        categories = list(ordered.index)
    else:
        categories = list(usable[cat_col].unique())

    # Filter once per category and reuse for violins, points and medians.
    values_by_cat = {
        cat: usable.loc[usable[cat_col] == cat, y_col].to_numpy()
        for cat in categories
    }

    if palette is None:
        palette = {cat: "gray" for cat in categories}

    x_positions = np.arange(len(categories))

    # ------------------------------
    # Violins
    # ------------------------------
    if categories:  # violinplot cannot be called with an empty dataset list
        vp = ax.violinplot(
            [values_by_cat[cat] for cat in categories],
            positions=x_positions,
            showmeans=False,
            showextrema=False,
            showmedians=False,
        )
        for body, cat in zip(vp["bodies"], categories):
            body.set_facecolor(palette.get(cat, "gray"))
            body.set_edgecolor("black")
            body.set_alpha(violin_alpha)
            body.set_rasterized(rasterized)

    # ------------------------------
    # Jittered scatter points
    # ------------------------------
    for xi, cat in zip(x_positions, categories):
        points = values_by_cat[cat]

        # Downsample so dense categories do not swamp the figure.
        if max_points is not None and len(points) > max_points:
            points = rng.choice(points, size=max_points, replace=False)

        jitter_x = xi + rng.uniform(-jitter, jitter, size=len(points))
        ax.scatter(
            jitter_x,
            points,
            s=point_size,
            color=palette.get(cat, "black"),
            alpha=point_alpha,
            zorder=3,
            rasterized=rasterized,
        )

    # ------------------------------
    # Median markers (computed on all values, not the downsampled points)
    # ------------------------------
    for xi, cat in zip(x_positions, categories):
        med = np.median(values_by_cat[cat])
        if median_marker == "line":
            ax.hlines(
                y=med,
                xmin=xi - 0.25,
                xmax=xi + 0.25,
                color=median_color,
                linewidth=median_linewidth,
                zorder=4,
            )
        elif median_marker == "dot":
            ax.scatter(
                [xi], [med],
                color=median_color,
                s=median_size,
                zorder=4,
            )

    # ------------------------------
    # Axes formatting
    # ------------------------------
    ax.set_xticks(x_positions)
    ax.set_xticklabels(categories, rotation=ax_tick_rotation, ha="right", fontsize=ax_tick_fontsize, rasterized=rasterized)
    if xlabel_color_map:
        for ticklabel in ax.get_xticklabels():
            ticklabel.set_color(xlabel_color_map.get(ticklabel.get_text(), '#808080'))
    if categories:  # avoid a zero-width x range when nothing was plotted
        ax.set_xlim(-0.5, len(categories) - 0.5)

    ax.set_ylabel(ylabel, fontsize=ax_label_fontsize, rasterized=rasterized)
    # Fixed ticks at 0.0, 0.2, ..., 1.0.
    ax.set_yticks(np.arange(0, 1.1, 0.2), np.arange(0, 1.1, 0.2).round(1), fontsize=ax_tick_fontsize, rasterized=rasterized)
    ax.grid(axis='y', linestyle='--', alpha=0.3, rasterized=rasterized)
    ax.set_ylim(bottom=0)

    ax.set_title(title, fontsize=title_fontsize, rasterized=rasterized)

    return ax


### RNA

In [None]:
# Entropy by cell type (with std-dev error bars for each dsid!)
# For Group, wherever there is 1 subclass - group just drop that group and don't show it! 

In [None]:
# TODO: Assign the transfer score!
def _assigned_label_scores(score_file, obs, label_col, out_prefix):
    """Return, per cell, the transfer score of the label the cell carries in ``obs``.

    ``score_file`` is a tab-separated cell x label score table. For every cell
    the score in the column matching ``obs[label_col]`` is kept; cells whose
    label is not a column of the table are dropped.

    Returns a DataFrame indexed by cell (index name "index") with the columns
    ``out_prefix`` (the label) and ``f"{out_prefix}_score"`` (its score).
    """
    ts_scores = pd.read_csv(score_file, sep="\t", index_col=0)
    assigned = obs.loc[ts_scores.index, label_col].to_dict()
    # rename_axis makes the melt independent of how the TSV names its index column.
    long_scores = ts_scores.rename_axis("index").reset_index().melt(id_vars=["index"])
    long_scores["annot"] = long_scores["index"].map(assigned)
    kept = long_scores.loc[long_scores["variable"] == long_scores["annot"]]
    return (
        kept.rename(columns={"annot": out_prefix, "value": f"{out_prefix}_score"})
        .drop(columns=["variable"])
        .set_index("index")
    )


all_annots = []
root_annot_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/annotation/execute/region_donor_lab_cps2/")
# iterdir() yields entries in arbitrary order; sorting keeps the row order of
# all_annots (and any seeded subsampling done on it later) reproducible.
for _dir in sorted(root_annot_path.iterdir()):
    if not _dir.is_dir():
        continue

    # Group-level scores for the "nn" and "neu" label-transfer runs.
    nn_gr_df_annots = _assigned_label_scores(_dir / "nn_gr_label_transfer.tsv", adata.obs, "Group", "group")
    neu_gr_df_annots = _assigned_label_scores(_dir / "neu_gr_label_transfer.tsv", adata.obs, "Group", "group")

    # Subclass-level scores for the same two runs.
    nn_sub_df_annots = _assigned_label_scores(_dir / "nn_sub_label_transfer.tsv", adata.obs, "Subclass", "subclass")
    neu_sub_df_annots = _assigned_label_scores(_dir / "neu_sub_label_transfer.tsv", adata.obs, "Subclass", "subclass")

    # Stack the two runs row-wise, then put subclass and group columns side by side.
    gr_df_annots = pd.concat((neu_gr_df_annots, nn_gr_df_annots), axis=0)
    sub_df_annots = pd.concat((neu_sub_df_annots, nn_sub_df_annots), axis=0)
    df_annots = pd.concat((sub_df_annots, gr_df_annots), axis=1)

    all_annots.append(df_annots)
    # break
all_annots = pd.concat(all_annots, axis=0)

In [None]:
# Replace every value equal to 0 with 1, in all columns of the table.
# NOTE(review): presumably a score of 0 marks cells that needed no transfer
# and should count as fully confident — confirm against the annotation pipeline.
all_annots = all_annots.replace(0, 1)
all_annots.head()

In [None]:
# Violin + jitter plot of the per-cell Subclass annotation scores in all_annots.
level = "Subclass"
lower_lvl = level.lower()
fig, ax = plt.subplots(figsize=(12, 4))


labels = adata.obs[level].unique().tolist()

# Tick labels are coloured by the Neighborhood each category belongs to.
_level_to_neighborhood = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level)
group_to_nt = _level_to_neighborhood.to_dict()['Neighborhood']
_nn_palette = adata.uns['m3c_neighborhood_palette']
tick_to_nn_pal = {tick: _nn_palette.get(nt, '#808080') for tick, nt in group_to_nt.items()}

_violin_style = {
    "max_points": 50,       # downsample per category
    "jitter": 0.05,
    "violin_alpha": 0.6,
    "point_size": 1,
    "point_alpha": 0.9,
    "rasterized": False,
    "title_fontsize": 24,
    "ax_tick_fontsize": 8,
    "ax_label_fontsize": 10,
    "ax_tick_rotation": 30,
}
violin_jitter_plot_with_palette(
    ax,
    df=all_annots,
    palette=adata.uns[f"{level}_palette"],
    cat_col=lower_lvl,
    y_col=f"{lower_lvl}_score",
    xlabel_color_map=tick_to_nn_pal,
    title=f"Annotation Score: HMBA - {level}",
    **_violin_style,
)

plt.tight_layout()
for _out_name in (
    f"Annotation_Score_HMBA_{lower_lvl}.png",
    f"Annotation_Score_HMBA_{lower_lvl}.pdf",
):
    plt.savefig(image_path / _out_name, dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Violin + jitter plot of the per-cell Group annotation scores in all_annots.
level = "Group"
lower_lvl = level.lower()
fig, ax = plt.subplots(figsize=(12, 4))


labels = adata.obs[level].unique().tolist()

# Tick labels are coloured by the Neighborhood each category belongs to.
_level_to_neighborhood = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level)
group_to_nt = _level_to_neighborhood.to_dict()['Neighborhood']
_nn_palette = adata.uns['m3c_neighborhood_palette']
tick_to_nn_pal = {tick: _nn_palette.get(nt, '#808080') for tick, nt in group_to_nt.items()}

_violin_style = {
    "max_points": 50,       # downsample per category
    "jitter": 0.05,
    "violin_alpha": 0.6,
    "point_size": 1,
    "point_alpha": 0.9,
    "rasterized": False,
    "title_fontsize": 24,
    "ax_tick_fontsize": 8,
    "ax_label_fontsize": 10,
    "ax_tick_rotation": 30,
}
violin_jitter_plot_with_palette(
    ax,
    df=all_annots,
    palette=adata.uns[f"{level}_palette"],
    cat_col=lower_lvl,
    y_col=f"{lower_lvl}_score",
    xlabel_color_map=tick_to_nn_pal,
    title=f"Annotation Score: HMBA - {level}",
    **_violin_style,
)

plt.tight_layout()
for _out_name in (
    f"Annotation_Score_HMBA_{lower_lvl}.png",
    f"Annotation_Score_HMBA_{lower_lvl}.pdf",
):
    plt.savefig(image_path / _out_name, dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Copy both score columns onto adata.obs for the cells present in all_annots,
# then snapshot obs for the entropy computation.
for _score_name in ('subclass_score', 'group_score'):
    adata.obs.loc[all_annots.index, _score_name] = all_annots.loc[all_annots.index, _score_name]
common_annot = adata.obs.copy()

In [None]:
# Annotation level used by the entropy cells below; the lower-case form is the
# prefix of the matching "<level>_score" column and of the output file names.
label_toplot = "Subclass"
lower_tpl = label_toplot.lower()

In [None]:
# For every (donor, brain_region, replicate) sample and every label present in
# it, compute the entropy of the distribution of score values (rounded to 3
# decimals) among the cells carrying that label.
entropies_rna = {}
_entropy_score_col = f'{lower_tpl}_score'
for (_donor, _region, _replicate), _df in common_annot.groupby(['donor', 'brain_region', 'replicate'], observed=True):
    # Only labels that actually occur in this sample.
    _present_labels = _df[label_toplot].cat.remove_unused_categories().cat.categories
    group_entropies = {
        _class: entropy(
            _df.loc[_df[label_toplot] == _class, _entropy_score_col]
            .round(3)
            .value_counts()
            .sort_index()
        )
        for _class in _present_labels
    }
    entropies_rna[(_donor, _region, _replicate)] = group_entropies

In [None]:
# Flatten the nested {sample: {label: entropy}} dict into a table via the
# entropy_to_df helper. The cells below read its 'group' and 'entropy' columns.
df_ent_rna = entropy_to_df(entropies_rna, method_name="RNA")
# df_ent_mc = entropy_to_df(entropies_meth, method_name="MC")
# df_ent = df_ent_rna.merge(df_ent_mc, on=['donor', 'brain_region', 'group'], suffixes=('_RNA', '_MC'))

In [None]:
# Preview the first rows.
df_ent_rna.head()

In [None]:
# all_groups = (df_ent_rna['group'].unique())

# Mean and standard deviation of the entropy for each group (across samples).
stats_rna = (
    df_ent_rna.groupby(['group'])['entropy']
    .agg(mean_RNA='mean', std_RNA='std')
    .reset_index()
)
# A group seen in a single sample has no std; plot it without an error bar.
stats_rna = stats_rna.assign(std_RNA=stats_rna['std_RNA'].fillna(0))
_is_known_group = stats_rna['group'] != 'unknown'
stats_rna = stats_rna[_is_known_group]
# Plot order: ascending mean entropy.
all_groups = stats_rna.sort_values(by="mean_RNA")['group'].tolist()

In [None]:
rasterized = False
# Single-panel figure; squeeze=False keeps `axes` 2-D so it can be flattened.
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 4),
                        sharex=True, squeeze=False)

ax = axes.flatten()[0]

palette = adata.uns[f'{label_toplot}_palette']

# Align the statistics to the plotting order. A group listed in all_groups
# but absent from stats_rna is plotted as a zero-height bar without error bar.
_stats_by_group = (
    stats_rna.set_index('group')[['mean_RNA', 'std_RNA']]
    .reindex(all_groups, fill_value=0)
)
groups_to_plot = list(all_groups)
means_rna = _stats_by_group['mean_RNA'].tolist()
stds_rna = _stats_by_group['std_RNA'].tolist()

# Bar plot: mean entropy per group, error bar = std across samples.
if groups_to_plot:
    x_pos = np.arange(len(groups_to_plot))
    width = 0.7  # Width of bars
    # Bars are centred at x_pos - width/2; ticks reuse the same positions so
    # each label sits under its bar.
    bar_centers = x_pos - width / 2

    ax.bar(bar_centers, means_rna, width,
           yerr=stds_rna, capsize=3,
           color=[palette.get(group, 'gray') for group in groups_to_plot],
           alpha=0.8,
           edgecolor='black', linewidth=0.5,
           rasterized=rasterized,)

    # # MC bars
    # bars2 = ax.bar(x_pos + width/2, means_mc, width,
    #                 yerr=stds_mc, capsize=3,
    #                 color='lightblue', alpha=0.8,
    #                 edgecolor='black', linewidth=0.5,
    #                 label='MC')

    # Set the x-tick labels
    ax.set_xticks(bar_centers)
    ax.set_xticklabels(groups_to_plot, rotation=30, ha='right', fontsize=7, rasterized=rasterized)

# Colour each tick label by the Neighborhood its category belongs to.
group_to_nt = adata.obs[[label_toplot, "Neighborhood"]].drop_duplicates().set_index(label_toplot).to_dict()['Neighborhood']
tick_to_nn_pal = {tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080') for tick, nt in group_to_nt.items()}
for ticklabel in ax.get_xticklabels():
    ticklabel.set_color(tick_to_nn_pal.get(ticklabel.get_text(), '#808080'))

# Formatting
ax.set_ylabel(f'\nEntropy', fontsize=10, rotation=0, ha='right', va='center', rasterized=rasterized)
ax.grid(axis='y', linestyle='--', alpha=0.3, rasterized=rasterized)
ax.set_ylim(bottom=0)

# ax.legend(loc='upper right')
# ax.set_xlabel(label_toplot, fontsize=12)

# Title
ax.set_title(f'Annotation Entropy: HMBA - {label_toplot}\n(Mean ± Std Dev across samples)',
            fontsize=14, y=0.98, rasterized=rasterized)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig(image_path / f"Annotation_Entropy_HMBA_{lower_tpl}.png", dpi=300, bbox_inches='tight')
plt.savefig(image_path / f"Annotation_Entropy_HMBA_{lower_tpl}.pdf", dpi=300, bbox_inches='tight')
plt.show()
# Release the figure, as every other plotting cell in this notebook does.
plt.close()

### Meth

In [None]:
# Load the methylation annotation table, using its first column as the index.
# This rebinds `all_annots`, replacing the RNA score table built above. The
# cells below read its subclass / group label and score columns.
all_annots = pd.read_csv("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/annot_with_scores.csv", index_col=0)
all_annots.head()

In [None]:
# Violin + jitter plot of the per-cell Subclass annotation scores in all_annots.
level = "Subclass"
lower_lvl = level.lower()
fig, ax = plt.subplots(figsize=(12, 4))


labels = adata.obs[level].unique().tolist()

# Tick labels are coloured by the Neighborhood each category belongs to.
_level_to_neighborhood = adata.obs[[level, "Neighborhood"]].drop_duplicates().set_index(level)
group_to_nt = _level_to_neighborhood.to_dict()['Neighborhood']
_nn_palette = adata.uns['m3c_neighborhood_palette']
tick_to_nn_pal = {tick: _nn_palette.get(nt, '#808080') for tick, nt in group_to_nt.items()}

_violin_style = {
    "max_points": 50,       # downsample per category
    "jitter": 0.05,
    "violin_alpha": 0.6,
    "point_size": 1,
    "point_alpha": 0.9,
    "rasterized": False,
    "title_fontsize": 24,
    "ax_tick_fontsize": 8,
    "ax_label_fontsize": 10,
    "ax_tick_rotation": 30,
}
violin_jitter_plot_with_palette(
    ax,
    df=all_annots,
    palette=adata.uns[f"{level}_palette"],
    cat_col=lower_lvl,
    y_col=f"{lower_lvl}_score",
    xlabel_color_map=tick_to_nn_pal,
    title=f"Annotation Score: snm3C - {level}",
    **_violin_style,
)

plt.tight_layout()
for _out_name in (
    f"Annotation_Score_snm3C_{lower_lvl}.png",
    f"Annotation_Score_snm3C_{lower_lvl}.pdf",
):
    plt.savefig(image_path / _out_name, dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Violin + jitter plot of the snm3C annotation score for every Group.
level = "Group"
lower_lvl = level.lower()

# Map each category of `level` to the palette colour of its Neighborhood so
# that the x tick labels can be coloured by neighborhood.
labels = adata.obs[level].unique().tolist()
group_to_nt = (
    adata.obs[[level, "Neighborhood"]]
    .drop_duplicates()
    .set_index(level)
    .to_dict()['Neighborhood']
)
tick_to_nn_pal = {
    tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080')
    for tick, nt in group_to_nt.items()
}

fig, ax = plt.subplots(figsize=(12, 4))
_violin_kws = dict(
    df=all_annots,
    palette=adata.uns[f"{level}_palette"],
    cat_col=lower_lvl,
    y_col=f"{lower_lvl}_score",
    max_points=50,  # downsample per category
    jitter=0.05,
    violin_alpha=0.6,
    point_size=1,
    point_alpha=0.9,
    xlabel_color_map=tick_to_nn_pal,
    rasterized=False,
    title=f"Annotation Score: snm3C - {level}",
    title_fontsize=24,
    ax_tick_fontsize=8,
    ax_label_fontsize=10,
    ax_tick_rotation=30,
)
violin_jitter_plot_with_palette(ax, **_violin_kws)

plt.tight_layout()
# Save both a raster and a vector copy of the figure.
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"Annotation_Score_snm3C_{lower_lvl}.{_ext}", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Copy the methylation annotation scores onto adata.obs, then snapshot obs
# as `common_annot` for the entropy computation.
# del adata.obs['subclass_score']
# del adata.obs['group_score']

# Only cells present in BOTH tables can be written: `.loc` assignment with
# labels that are missing from adata.obs raises a KeyError.
_scored_cells = adata.obs.index.intersection(all_annots.index)
for _score_col in ('subclass_score', 'group_score'):
    adata.obs.loc[_scored_cells, _score_col] = all_annots.loc[_scored_cells, _score_col]
common_annot = adata.obs.copy()

In [None]:
# Annotation level used by the entropy cells below; the lower-cased name is
# used to build the score column (f'{lower_tpl}_score') and output file names.
label_toplot = "Group"
lower_tpl = label_toplot.lower()

In [None]:
# For every (donor, brain_region, replicate) sample, compute for each label
# present in that sample the entropy of the distribution of its annotation
# scores (scores are rounded to 3 decimals and their value counts are passed
# to scipy's `entropy`, which normalises them).
entropies_meth = {}
_sample_keys = ['donor', 'brain_region', 'replicate']
_score_col = f'{lower_tpl}_score'
for (_donor, _region, _replicate), _df in common_annot.groupby(_sample_keys, observed=True):
    _present = _df[label_toplot].cat.remove_unused_categories().cat.categories
    group_entropies = {}
    for _class in _present:
        _in_class = _df[label_toplot] == _class
        probs = _df.loc[_in_class, _score_col]
        _score_counts = probs.round(3).value_counts().sort_index()
        group_entropies[_class] = entropy(_score_counts)
    entropies_meth[(_donor, _region, _replicate)] = group_entropies

In [None]:
# Convert the nested {sample: {label: entropy}} dict into a DataFrame with the
# entropy_to_df helper defined earlier in this notebook. Only the methylation
# ("MC") table is built here; the RNA table and the merge stay disabled.
# df_ent_rna = entropy_to_df(entropies_rna, method_name="RNA")
df_ent_mc = entropy_to_df(entropies_meth, method_name="MC")
# df_ent = df_ent_rna.merge(df_ent_mc, on=['donor', 'brain_region', 'group'], suffixes=('_RNA', '_MC'))

In [None]:
# Preview the first rows of the methylation entropy table.
df_ent_mc.head()

In [None]:
# Mean and standard deviation of the entropy per group (across all samples).
stats_mc = (
    df_ent_mc
    .groupby(['group'])['entropy']
    .agg(['mean', 'std'])
    .reset_index()
)
stats_mc.columns = ['group', 'mean_MC', 'std_MC']

# A group seen in a single sample has an undefined std; plot it as 0.
stats_mc['std_MC'] = stats_mc['std_MC'].fillna(0)

# Drop the 'unknown' label and groups whose mean entropy is exactly 0.
_keep = (stats_mc['group'] != 'unknown') & (stats_mc['mean_MC'] != 0)
stats_mc = stats_mc[_keep]

# Plotting order: ascending mean entropy.
all_groups = stats_mc.sort_values(by="mean_MC")['group'].tolist()

In [None]:
# Bar plot of the snm3C annotation entropy per group (mean with std error bars).
rasterized = False

# One panel only; squeeze=False keeps `axes` a 2-D array.
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 4),
                         sharex=True, squeeze=False)
ax = axes[0, 0]

palette = adata.uns[f'{label_toplot}_palette']

# Collect the statistics in plotting order. A group without a row in
# `stats_mc` is still plotted, with mean and std of 0.
groups_to_plot = []
means_rna = []
stds_rna = []
means_mc = []
stds_mc = []
for group in all_groups:
    _rows = stats_mc.loc[stats_mc['group'] == group]
    _found = not _rows.empty
    groups_to_plot.append(group)
    means_mc.append(_rows['mean_MC'].iloc[0] if _found else 0)
    stds_mc.append(_rows['std_MC'].iloc[0] if _found else 0)

if groups_to_plot:
    x_pos = np.arange(len(groups_to_plot))
    width = 0.7  # Width of bars
    _bar_colors = [palette.get(group, 'gray') for group in groups_to_plot]

    # MC bars, shifted right by half a bar width
    bars2 = ax.bar(
        x_pos + width/2, means_mc, width,
        yerr=stds_mc, capsize=3,
        color=_bar_colors,
        alpha=0.8,
        edgecolor='black', linewidth=0.5,
        rasterized=rasterized,
    )

    # Ticks sit at the bar centres (0.35 == width / 2)
    ax.set_xticks(x_pos+0.35)
    ax.set_xticklabels(groups_to_plot, rotation=30, ha='right', fontsize=8, rasterized=rasterized)

# Colour every tick label by the Neighborhood of its group (grey if unknown).
group_to_nt = (
    adata.obs[[label_toplot, "Neighborhood"]]
    .drop_duplicates()
    .set_index(label_toplot)
    .to_dict()['Neighborhood']
)
tick_to_nn_pal = {
    tick: adata.uns['m3c_neighborhood_palette'].get(nt, '#808080')
    for tick, nt in group_to_nt.items()
}
for ticklabel in ax.get_xticklabels():
    ticklabel.set_color(tick_to_nn_pal.get(ticklabel.get_text(), '#808080'))

# Axis formatting
ax.set_ylabel(f'\nEntropy', fontsize=10, rotation=0, ha='right', va='center', rasterized=rasterized)
ax.grid(axis='y', linestyle='--', alpha=0.3, rasterized=rasterized)
ax.set_ylim(bottom=0)

# Title
ax.set_title(f'Annotation Entropy: snm3C - {label_toplot}\n(Mean ± Std Dev across samples)',
             fontsize=14, y=0.98, rasterized=rasterized)

plt.tight_layout(rect=[0, 0, 1, 0.96])
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"Annotation_Entropy_snm3C_{lower_tpl}.{_ext}", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

## Neuron type / Subclass / Group counts

In [None]:
# Grouped bar plot: number of cells of each neuron type, one bar per
# (corrected) brain region, with the count printed above every bar.
rasterized = False
adata.obs['neuron_type'] = adata.obs['neuron_type'].cat.remove_unused_categories()
# observed=False is spelled out so every region/type combination is counted
# (matches the historical pandas default and avoids the FutureWarning).
region_neuron_composition = (
    adata.obs
    .groupby(['brain_region_corr', 'neuron_type'], observed=False)
    .size()
    .to_frame()
    .reset_index()
)

# Pivot to wide format (rows: neuron type, columns: region); the count column
# produced by .size().to_frame() is named 0.
pivot = region_neuron_composition.pivot(index='neuron_type', columns='brain_region_corr', values=0).fillna(0)

fig, ax = plt.subplots(1, 1, figsize=(12, 4), dpi=300)
n_groups = len(pivot.index)
n_bars = len(pivot.columns)
x = np.arange(n_groups)  # centre of each neuron-type group on the x axis
width = 0.9 / max(n_bars, 1)

# Resolve palette (allow dict or list-like palettes stored in adata.uns);
# fall back to grey when no usable palette is available.
palette = adata.uns.get('brain_region_corr_palette', None)
if isinstance(palette, dict):
    colors = [palette.get(col, '#808080') for col in pivot.columns]
else:
    try:
        colors = [palette[i] for i in range(len(pivot.columns))]
    except (TypeError, IndexError, KeyError):
        colors = ['#808080'] * n_bars

# One bar container per region, spread symmetrically around the group centre.
bars_containers = []
for i, col in enumerate(pivot.columns):
    xpos = x + (i - (n_bars - 1) / 2) * width
    container = ax.bar(xpos, pivot[col].values, width, label=str(col), color=colors[i], edgecolor='black', linewidth=0.5)
    bars_containers.append(container)

# collect the actual Rectangle artists for the bars
bar_artists = [rect for container in bars_containers for rect in container]

# Print the cell count above every bar, using a fixed vertical offset
# (no automatic overlap adjustment is applied).
texts = []
for rect in bar_artists:
    h = rect.get_height()
    x_center = rect.get_x() + rect.get_width() / 2
    txt = ax.text(
        x_center,
        h + 0.02 * ax.get_ylim()[1],  # vertical offset relative to the axis height
        f"{int(h):,}",
        ha='center', va='bottom',
        fontsize=6, fontweight='bold',
        clip_on=False,
        zorder=10,
    )
    texts.append(txt)

# Place the x ticks at the group centres. Deriving them from one specific bar
# container is only centred for one particular number of regions and fails
# when there are fewer regions than the chosen container index.
ax.set_xticks(x)
ax.set_xticklabels(pivot.index, rotation=0, ha='center', fontsize=10, rasterized=rasterized)
ax.set_ylabel("Cell Count", fontsize=10, rasterized=rasterized)
ax.set_title("Regional Neuronal Composition", fontsize=16, rasterized=rasterized)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
ax.grid(axis='y', linestyle='--', alpha=0.3)

plt.tight_layout()
plt.savefig(image_path / "regional_neuron_composition.png", dpi=300, bbox_inches="tight")
plt.savefig(image_path / "regional_neuron_composition.pdf", dpi=300, bbox_inches="tight")
plt.show()

# Composition plots

### function

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type', 
                           figsize=(12, 8), title=None, colors=None, 
                           show_percentages=True, rotation=45, rasterized=False,
                           legend_threshold=5.0, text_threshold=2.0,
                           legend_fontsize=12, def_fontsize=12, title_fontsize=12,
                           xlabel=None, 
                        ):
    """
    Create a stacked bar chart showing cell type percentages across groups.

    Each bar is one value of ``group_column``; its segments are the
    percentages of the ``cell_type_column`` values within that group.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input dataframe containing the data
    group_column : str
        Column name to group by (x-axis categories)
    cell_type_column : str, default 'cell_type'
        Column name containing cell type information
    figsize : tuple, default (12, 8)
        Figure size (width, height)
    title : str, optional
        Chart title. If empty/None a default title built from
        ``group_column`` is used
    colors : list or dict, optional
        Colors for cell types. A dict is looked up per cell type (missing
        entries become 'gray'); a list is indexed by cell type position.
        If None, uses the seaborn "Set3" palette
    show_percentages : bool, default True
        Whether to show percentage labels on bars
    rotation : int, default 45
        Rotation angle for x-axis labels. A non-zero rotation right-aligns
        the labels, 0 centers them
    rasterized : bool, default False
        Passed to the bars, the percentage labels and the title
    legend_threshold : float, default 5.0
        Minimum percentage (maximum over all groups) a cell type must reach
        to be included in the legend
    text_threshold : float, default 2.0
        A segment is labelled only if its percentage is strictly greater
        than this value
    legend_fontsize : int, default 12
        Font size of the legend entries
    def_fontsize : int, default 12
        Font size of axis labels, tick labels and percentage labels
    title_fontsize : int, default 12
        Font size of the title
    xlabel : str, optional
        X-axis label. If None, it is derived from ``group_column``

    Returns:
    --------
    fig, ax : matplotlib figure and axis objects
    """

    # Calculate cell type counts and percentages
    counts = df.groupby([group_column, cell_type_column]).size().unstack(fill_value=0)
    # A group without any cell gives 0/0 = NaN; treat it as 0 % so the NaN
    # does not propagate into the running `bottom` of the stack.
    percentages = (counts.div(counts.sum(axis=1), axis=0) * 100).fillna(0.0)

    # Set up colors
    n_cell_types = len(counts.columns)
    if colors is None:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]

    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)

    # Cell types whose largest share over all groups reaches the threshold
    # are the only ones that get a legend entry.
    max_percentages = percentages.max(axis=0)
    legend_cell_types = set(max_percentages[max_percentages >= legend_threshold].index)

    # Create stacked bar chart
    bottom = np.zeros(len(percentages))
    n_legend_entries = 0

    for i, cell_type in enumerate(percentages.columns):
        heights = percentages[cell_type].to_numpy()
        in_legend = cell_type in legend_cell_types
        n_legend_entries += int(in_legend)

        # A label of None keeps the artist out of the automatic legend
        ax.bar(percentages.index, heights,
               bottom=bottom, label=cell_type if in_legend else None,
               color=colors[i], rasterized=rasterized)

        # Add percentage labels, centred in each sufficiently large segment
        if show_percentages:
            for j, value in enumerate(heights):
                if value > text_threshold:
                    ax.text(j, bottom[j] + value / 2, f'{value:.1f}%',
                            ha='center', va='center', fontsize=def_fontsize, fontweight='bold',
                            rasterized=rasterized)

        # New array on purpose: never mutate what was handed to ax.bar
        bottom = bottom + heights

    # Customize the plot
    if xlabel is None:
        ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=def_fontsize)
    else:
        ax.set_xlabel(xlabel, fontsize=def_fontsize)
    ax.set_ylabel('Percentage (%)', fontsize=def_fontsize)
    ax.set_ylim(0, 100)

    if title:
        ax.set_title(title, fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}',
                     fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)

    # Format the x tick labels on THIS axes (not on pyplot's current axes)
    if rotation != 0:
        plt.setp(ax.get_xticklabels(), rotation=rotation, ha='right', fontsize=def_fontsize)
    else:
        plt.setp(ax.get_xticklabels(), ha='center', fontsize=def_fontsize)

    # Add legend (only for cell types that meet the threshold); use two
    # columns when there are more than 20 entries.
    if n_legend_entries > 0:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize,
                  ncol=2 if n_legend_entries > 20 else 1)

    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    # tight_layout is intentionally left to the caller
    return fig, ax

## Plot

In [None]:
# Stacked composition of the non-neuronal cells per brain region at the
# Subclass level; the figure is saved but not displayed.
_nonneuron_obs = adata[adata.obs['neuron_type'] == "Nonneuron"].obs
_chart_kws = dict(
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    title='Nonneuronal Subclass Cell Type Distribution',
    colors=adata.uns['Subclass_palette'],
    rotation=0,
    rasterized=False,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
    xlabel="",
)
fig, ax = create_stacked_bar_chart(_nonneuron_obs, **_chart_kws)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"nn_composition_subclass.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# Stacked composition of the non-neuronal cells per brain region at the
# Group level; the figure is saved but not displayed.
_nonneuron_obs = adata[adata.obs['neuron_type'] == "Nonneuron"].obs
_chart_kws = dict(
    group_column='brain_region_corr',
    cell_type_column='Group',
    title='Nonneuronal Group Cell Type Distribution',
    colors=adata.uns['Group_palette'],
    rotation=0,
    rasterized=False,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
    xlabel="",
)
fig, ax = create_stacked_bar_chart(_nonneuron_obs, **_chart_kws)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"nn_composition_group.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# Stacked composition of the neuronal cells per brain region at the Subclass
# level; only subclasses reaching 2 % in some region get a legend entry.
# The figure is saved but not displayed.
_neuron_obs = adata[adata.obs['neuron_type'] == "Neuron"].obs
_chart_kws = dict(
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    title='Neuronal Subclass Cell Type Distribution',
    colors=adata.uns['Subclass_palette'],
    rotation=0,
    rasterized=False,
    legend_threshold=2.0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
    xlabel="",
)
fig, ax = create_stacked_bar_chart(_neuron_obs, **_chart_kws)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"neu_composition_subclass.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# Stacked composition of the neuronal cells per brain region at the Group
# level; the figure is saved but not displayed.
_neuron_obs = adata[adata.obs['neuron_type'] == "Neuron"].obs
_chart_kws = dict(
    group_column='brain_region_corr',
    cell_type_column='Group',
    title='Neuronal Group Cell Type Distribution',
    colors=adata.uns['Group_palette'],
    rotation=0,
    rasterized=False,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
    xlabel="",
)
fig, ax = create_stacked_bar_chart(_neuron_obs, **_chart_kws)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"neu_composition_group.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

# Meth - RNA Overlap

In [None]:
from PyComplexHeatmap import *
def _plot_overlap_heatmap(
    use_adata,
    ref_col,
    qry_col,
    image_path=None,
    show=True,
    save_fig=False,
    current_datetime=None,
    rasterized=True,
    xlabel=None, 
    ylabel=None,
    annot_fontsize=10,
    ):
    """
    Plot a heatmap of how the labels in ``qry_col`` distribute over ``ref_col``.

    Every row is one query label, every column one reference label, and each
    cell is the fraction of the query label's cells that carry the reference
    label (rows sum to 1). Rows are grouped by their best-matching reference
    label, in column order, and annotated with "<category code>: <label>".

    Parameters
    ----------
    use_adata : anndata.AnnData or pandas.DataFrame
        AnnData (its ``.obs`` is used) or an obs-like dataframe.
    ref_col : str
        Column holding the reference labels (heatmap columns).
    qry_col : str
        Column holding the query labels (heatmap rows).
    image_path : path-like, optional
        Output directory; used only when ``save_fig`` is true.
    show : bool, default True
        Whether to call ``plt.show()``.
    save_fig : bool, default False
        Whether to save PNG and PDF copies into ``image_path``.
    current_datetime : optional
        Unused; kept so existing callers keep working.
    rasterized : bool, default True
        Passed to ``ClusterMapPlotter``.
    xlabel, ylabel : str, optional
        Axis labels; default to ``ref_col`` / ``qry_col``.
    annot_fontsize : int, default 10
        Font size of the fraction printed in each cell.

    Raises
    ------
    ValueError
        If there is no (query, reference) pair to plot.
    """
    if isinstance(use_adata, ad.AnnData):
        use_data = use_adata.obs.copy()
    else: 
        use_data = use_adata.copy()

    # Pair counts. The count column is named explicitly because the name that
    # value_counts() gives it differs between pandas versions.
    vc = use_data.loc[:, [qry_col, ref_col]].value_counts().reset_index(name='count')
    # Number of cells per query label
    vc['N'] = vc.groupby(qry_col, observed=True)['count'].transform('sum').astype(int)
    # Depending on the pandas version, categorical columns may also yield
    # zero-count pairs; query labels without any cell would give 0/0 rows.
    vc = vc.loc[vc['N'] > 0].copy()
    vc['fraction'] = vc['count'] / vc['N']
    data = vc.pivot(index=qry_col, columns=ref_col, values='fraction')
    if data.empty:
        raise ValueError(f"No overlapping annotations between '{qry_col}' and '{ref_col}'")

    # Assign every query label to its best-matching reference label and order
    # the rows by the column position of that label (stable, so rows keep
    # their relative order inside a group).
    cols = data.columns.tolist()
    max_idx = np.argmax(data.fillna(0).values, axis=1)
    df_rows = data.index.to_series().to_frame()
    df_rows["GROUP"] = [cols[i] for i in max_idx]
    df_rows = df_rows.iloc[np.argsort(max_idx, kind='stable')]

    # Row label "<code>: <name>", where code is the position of the label in
    # the categories of qry_col. The map covers every category, not only the
    # observed ones, and a non-categorical column is converted first.
    qry_categories = use_data[qry_col].astype('category').cat.categories
    ct2code = {cat: code for code, cat in enumerate(qry_categories)}
    df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[qry_col].tolist()]

    # Plot
    row_ha=HeatmapAnnotation(
        label=anno_label(df_rows.Label,colors='black',relpos=(0,0.5)),
        axis=0,orientation='right',
    )

    plt.figure(figsize=(12,6))
    ClusterMapPlotter(
        data.loc[df_rows.index.tolist()],
        row_cluster=False,
        col_cluster=False,
        cmap='Reds',
        rasterized=rasterized,
        right_annotation=row_ha,
        row_split=df_rows['GROUP'],
        row_split_gap=0.5,
        row_split_order=df_rows['GROUP'].unique().tolist(),
        show_rownames=False,
        show_colnames=True,
        yticklabels=True,
        xticklabels=True,
        xticklabels_kws=dict(labelrotation=-45,labelcolor='blue',labelsize=10),
        yticklabels_kws=dict(labelcolor='red',labelsize=10),
        annot=True,
        annot_kws=dict(fontsize=annot_fontsize),
        fmt='.2g',
        linewidth=0.05,
        linecolor='gold',
        linestyle='-:',
        label='fraction',
        legend_kws=dict(extend='both',extendfrac=0.1),
        xlabel=ref_col if xlabel is None else xlabel,
        ylabel=qry_col if ylabel is None else ylabel,
        xlabel_kws=dict(color='blue',fontsize=14,labelpad=5),
        xlabel_side='top',
        ylabel_kws=dict(color='red',fontsize=14,labelpad=5),  # labelpad is in points
    )
    if image_path is not None and save_fig:
        out_dir = Path(image_path)  # also accept a plain string path
        plt.savefig(out_dir / f"Overlap_Heatmap_{qry_col}_by_{ref_col}.png", dpi=300, bbox_inches='tight')
        plt.savefig(out_dir / f"Overlap_Heatmap_{qry_col}_by_{ref_col}.pdf", dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    plt.close()

In [None]:
# Load the methylation annotations and attach them, prefixed with `meth_`,
# to a copy of the spatial annotations in adata.obs.
meth_annot_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/annot_with_scores.csv"
meth_annot = pd.read_csv(meth_annot_path, index_col=0)

rna_annot = adata.obs.copy()

# Cells annotated in both tables
common_cells = rna_annot.index.intersection(meth_annot.index)

# (source column, store as categorical?) -- label columns become categorical,
# score columns stay numeric. Assignment aligns on the index, so cells
# outside `common_cells` are left missing.
_meth_columns = (
    ('subclass', True),
    ('subclass_score', False),
    ('group', True),
    ('group_score', False),
)
for _src_col, _as_category in _meth_columns:
    _values = meth_annot.loc[common_cells, _src_col]
    if _as_category:
        _values = _values.astype("category")
    rna_annot.loc[:, f'meth_{_src_col}'] = _values.copy()

In [None]:
# Restrict to the cells annotated by both methods and keep only the columns
# needed for the comparison.
_compare_cols = [
    'donor', 'replicate', 'brain_region', 'brain_region_corr',
    'Subclass', 'meth_subclass', "Group", "meth_group",
]
common_annot = rna_annot.loc[common_cells, _compare_cols]

_to_category = ["donor", "replicate", "brain_region", "Subclass", "meth_subclass", "Group", "meth_group"]
common_annot = common_annot.astype({_col: "category" for _col in _to_category})

# Give both label columns of a level the same (union) category set, so the
# two categoricals can be compared element-wise.
common_subclasses = common_annot['Subclass'].cat.categories.union(common_annot['meth_subclass'].cat.categories)
for _col in ('Subclass', 'meth_subclass'):
    common_annot[_col] = common_annot[_col].cat.set_categories(common_subclasses)

common_groups = common_annot['Group'].cat.categories.union(common_annot['meth_group'].cat.categories)
for _col in ('Group', 'meth_group'):
    common_annot[_col] = common_annot[_col].cat.set_categories(common_groups)

common_annot.dtypes

In [None]:
# Percentage of shared cells whose two annotations agree, per level.
_n_cells = common_annot.shape[0]
_subclass_agree = (common_annot['Subclass'] == common_annot['meth_subclass']).sum() / _n_cells
_group_agree = (common_annot['Group'] == common_annot['meth_group']).sum() / _n_cells
print(f"Subclass Level Agreement {_subclass_agree * 100:.3f}%")
print(f"Group Level Agreement {_group_agree * 100:.3f}%")

In [None]:
# Overlap heatmaps over all shared cells, saved to `image_path` without being
# displayed. A single region can be plotted by subsetting first, e.g.:
# _plot_overlap_heatmap(common_annot[common_annot['brain_region'] == "CAB"], ref_col='Subclass', qry_col='meth_subclass')
# _plot_overlap_heatmap(common_annot[common_annot['brain_region'] == "CAB"], ref_col='Group', qry_col='meth_group')
_heatmap_kws = dict(show=False, save_fig=True, image_path=image_path)
_plot_overlap_heatmap(
    common_annot,
    ref_col='Subclass', qry_col='meth_subclass',
    ylabel="snm3C Subclass", xlabel="HMBA Subclass",
    annot_fontsize=8,
    **_heatmap_kws,
)
_plot_overlap_heatmap(
    common_annot,
    ref_col='Group', qry_col='meth_group',
    ylabel="snm3C Group", xlabel="HMBA Group",
    annot_fontsize=4,
    **_heatmap_kws,
)

In [None]:
# Per-cell boolean flags: True where the two annotations carry the same label.
common_annot['subclass_match'] = (common_annot['Subclass'] == common_annot['meth_subclass'])
common_annot['group_match'] = (common_annot['Group'] == common_annot['meth_group'])
# Palette for the corrected brain regions, used to colour the agreement bars.
brain_region_palette = adata.uns['brain_region_corr_palette']

In [None]:
# For every brain region: number of shared cells and agreement rates per
# (donor, replicate) sample, summarised as mean and std across samples.
agreement = []
for region, _region_df in common_annot.groupby('brain_region_corr', observed=True):
    _per_sample = (
        _region_df
        .groupby(["donor", "replicate"], observed=True)
        .agg(
            subclass_match_sum=("subclass_match", "sum"),
            subclass_match_size=("subclass_match", "size"),
            group_match_sum=("group_match", "sum"),
            group_match_size=("group_match", "size"),
        )
    )
    _per_sample['subclass_match_rate'] = _per_sample['subclass_match_sum'] / _per_sample['subclass_match_size']
    _per_sample['group_match_rate'] = _per_sample['group_match_sum'] / _per_sample['group_match_size']

    _subclass_rate = _per_sample['subclass_match_rate']
    _group_rate = _per_sample['group_match_rate']
    agreement.append((
        region,
        _per_sample['subclass_match_size'].mean(),
        _per_sample['group_match_size'].std(),
        _subclass_rate.mean() * 100,
        _subclass_rate.std() * 100,
        _group_rate.mean() * 100,
        _group_rate.std() * 100,
    ))

In [None]:
# Three-panel summary per brain region: number of shared cells, Subclass
# agreement and Group agreement (bars: mean, error bars: std across samples).
rasterized = False
_agreement_cols = [
    'brain_region_corr',
    'n_cells_mean', 'n_cells_std',
    'subclass_agreement_mean', 'subclass_agreement_std',
    'group_agreement_mean', 'group_agreement_std',
]
agreement_df = pd.DataFrame(agreement, columns=_agreement_cols)
agreement_df['color'] = agreement_df['brain_region_corr'].map(brain_region_palette)

fig, axes = plt.subplots(1, 3, dpi=300, figsize=(15, 4))

# (height column, error column, title, y label, y limits or None)
_panels = [
    ('n_cells_mean', 'n_cells_std',
     "Common Annotated Cell Types by Brain Region", "Cell Number", None),
    ('subclass_agreement_mean', 'subclass_agreement_std',
     "Subclass Agreement by Brain Region", "Subclass Agreement (%)", (0, 100)),
    ('group_agreement_mean', 'group_agreement_std',
     "Group Agreement by Brain Region", "Group Agreement (%)", (0, 100)),
]
_regions = agreement_df['brain_region_corr']
for ax, (_height_col, _err_col, _title, _ylabel, _ylim) in zip(axes, _panels):
    ax.bar(data=agreement_df, x='brain_region_corr', height=_height_col, color='color',
           yerr=agreement_df[_err_col], capsize=5, rasterized=rasterized)
    ax.set_title(_title, fontsize=12)
    ax.set_ylabel(_ylabel, fontsize=12)
    if _ylim is not None:
        ax.set_ylim(*_ylim)
    ax.set_xticks(np.arange(len(_regions)))
    ax.set_xticklabels(_regions, fontsize=10, rasterized=rasterized)
    ax.grid(axis='y', linestyle='--', alpha=0.75, rasterized=rasterized)

for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"snm3C_HMBA_Annotation_Agreement_by_Brain_Region.{_ext}", dpi=300, bbox_inches='tight')
plt.show()
plt.close()


# Plots for Tatiana (all annotations on the figure)

In [None]:
# Output directory for the per-dataset spatial plots (created if missing).
tt_image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_tt_annot")
tt_image_path.mkdir(exist_ok=True, parents=True)
# When False the figures are only saved to disk, not displayed.
show_plots = False

In [None]:
# One spatial plot per dataset: all cells in light grey as background, with
# the cells coloured by Group drawn on top. Files are named after the first
# donor / corrected brain region / replicate value found in the dataset.
_is_dataset = adata.obs['dataset_id']
for _dsid in _is_dataset.unique():
    adata_sub = adata[_is_dataset == _dsid].copy()
    _donor, _region, _lab = (
        adata_sub.obs[_col].unique()[0]
        for _col in ('donor', 'brain_region_corr', 'replicate')
    )

    fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300)
    # Background layer
    categorical_scatter(
        data=adata_sub, coord_base="spatial", max_points=None, hue=None,
        scatter_kws=dict(color='lightgrey'), ax=ax, rasterized=True, axis_format=None,
    )
    # Group layer (no legend, no text annotation)
    plot_categorical(
        adata_sub, coord_base="spatial", cluster_col="Group",
        show=False, coding=False, text_anno=False, ax=ax, show_legend=False,
        legend_kws=dict(bbox_to_anchor=(0.9, 1), loc='upper left', borderaxespad=0.),
        rasterized=True, axis_format=None,
    )

    _stem = f"{_donor}_{_region}_{_lab}"
    for _ext in ("png", "pdf"):
        plt.savefig(tt_image_path / f"{_stem}.{_ext}", dpi=300, bbox_inches='tight')
    if show_plots:
        plt.show()
    plt.close()
    