In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

from scipy.stats import entropy

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter
from PyComplexHeatmap import *
# Global matplotlib defaults applied to every figure created in this notebook.
plt.rcParams.update({
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    'font.size': 12,
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'axes.linewidth': 1,
    'axes.facecolor': 'white',
})

## functions

In [None]:
def _make_map(
    adata,
    ref_col,
    cluster_col,
    query = "Modality == 'ref' "
):
    data = adata.obs.query(query).copy()
    df_map = (
        data.groupby(cluster_col, observed=True)[ref_col]
        .value_counts(normalize=True)
        .sort_values(ascending=False)
        .reset_index()
        .drop_duplicates(cluster_col, keep='first')
    )
    return df_map

def _combine_cluster_cell_annots(
    adata,
    cluster_col,
    cell_col,
    cell_prob_col,
    added_col,
    integrated_col=None,
    df_map=None,
    groupby_col = None
): 
    if groupby_col is None:
        data_list = [adata.obs.copy()]
    else: 
        data_list = [group for _, group in adata.obs.groupby(groupby_col, observed=True)]
    for data in data_list:
        df_ct = data[[cluster_col, cell_col, cell_prob_col]].copy()
        # display(df_ct.head())
        df_ct[f'final_{added_col}'] = "U"
        df_ct['eql_col'] = df_ct.apply(lambda x: x[cluster_col] if x[cluster_col] == x[cell_col] else x[f'final_{added_col}'], axis=1)
        if df_map is not None and integrated_col is not None:
            keeper_clusters = df_map.query(f"proportion>0.6").index.tolist()
            df_ct['keeper_clust'] = df_ct.apply(lambda x: x[cluster_col] if x[cluster_col] in keeper_clusters else "U", axis=1)
        df_ct['ff'] = df_ct.apply(lambda x: x[cell_col] if x[cell_prob_col] > 0.9 else "U", axis=1)

        
        if df_map is not None and integrated_col is not None:
            for i, _cell in enumerate(df_ct.index): 
                ff_annot = df_ct.at[_cell, 'ff']
                keeper_annot = df_ct.at[_cell, 'keeper_clust']
                eql_annot = df_ct.at[_cell, 'eql_col']

                if ff_annot == 'U' and keeper_annot == 'U':
                    annot = eql_annot
                elif ff_annot == keeper_annot:
                    annot = ff_annot
                elif ff_annot == 'U':
                    annot = keeper_annot
                elif keeper_annot == 'U':
                    annot = ff_annot
                elif ff_annot != keeper_annot: 
                    annot = "unknown"
                df_ct.at[_cell, f'final_{added_col}'] = annot
        else: 
            for i, _cell in enumerate(df_ct.index): 
                ff_annot = df_ct.at[_cell, 'ff']
                eql_annot = df_ct.at[_cell, 'eql_col']

                if ff_annot == 'U':
                    annot = eql_annot
                else:
                    annot = ff_annot
                df_ct.at[_cell, f'final_{added_col}'] = annot

        
        df_ct[f'final_{added_col}'] = df_ct[f'final_{added_col}'].replace("U", "unknown")
        final_annots = df_ct[f'final_{added_col}'].unique()
        
        adata.obs.loc[data.index, added_col] = df_ct[f'final_{added_col}'].astype(str).copy()
        adata.obs.loc[data.index, f'{added_col}.Prob'] = df_ct[cell_prob_col].copy()
    adata.obs[added_col] = adata.obs[added_col].astype('category')
    # display(adata.obs[added_col].value_counts())



### Add the colors: 
def add_colors(adata, cat_col, palette):
    """Store one color per category of ``adata.obs[cat_col]`` in ``adata.uns``.

    ``palette`` is either a dict keyed by category, or an object indexed by
    category through ``.loc`` with a ``'Hex'`` column. A category missing from
    the palette is printed and given the fallback ``'#808080'``. The colors are
    written, in category order, to ``adata.uns[f'{cat_col}_colors']``.
    """
    def _lookup(category):
        if isinstance(palette, dict):
            return palette[category]
        return palette.loc[category, 'Hex']

    resolved = []
    for category in adata.obs[cat_col].cat.categories:
        try:
            resolved.append(_lookup(category))
        except KeyError:
            print(category)
            resolved.append('#808080')

    adata.uns[f'{cat_col}_colors'] = resolved

# Proseg

## Load Data

In [None]:
# Root directory scanned by the loading cell below: each immediate child is checked for a final.h5ad file.
data_dir = Path("/home/x-aklein2/projects/aklein/BICAN/BG/annotation/execute/region_donor_lab_gp")

In [None]:
# Label-transfer columns copied from every group-level file onto the region-level table.
group_tfs_cols = ['c2c_allcools_label_Group', 'allcools_Group_transfer_score', 'allcools_Group', 'Group_Combined']
final_annot = []
# sorted() keeps the row order of the collected tables independent of the
# filesystem listing order, so repeated runs produce identical outputs.
for donor_region_dir in sorted(data_dir.glob("*")):
    final_file = donor_region_dir / "final.h5ad"
    if not final_file.exists():
        continue
    print(final_file)
    adata = ad.read_h5ad(final_file)
    df_sub_obs = adata.obs[['donor', 'brain_region', 'replicate', 'dataset_id', 'Subclass', 'Group']].copy()

    regroup_obs = []
    for _group_file in sorted(donor_region_dir.glob("groups/*/final.h5ad")):
        # joint.h5ad is optional; without it no cluster map is used
        joint_file = _group_file.parent / "joint.h5ad"
        if joint_file.exists():
            joint_adata = ad.read_h5ad(joint_file)
            df_map = _make_map(adata=joint_adata, ref_col='Group', cluster_col='integrated_leiden')
        else:
            df_map = None
        # separate name so the region-level `adata` above is not overwritten
        group_adata = ad.read_h5ad(_group_file)
        _combine_cluster_cell_annots(
            group_adata,
            cluster_col='c2c_allcools_label_Group',
            cell_col='allcools_Group',
            cell_prob_col='allcools_Group_transfer_score',
            added_col='Group_Combined',
            integrated_col='integrated_leiden',
            df_map=df_map,
        )
        regroup_obs.append(group_adata.obs[group_tfs_cols])

    if regroup_obs:
        df_group_obs = pd.concat(regroup_obs)
        df_sub_obs.loc[df_group_obs.index, group_tfs_cols] = df_group_obs[group_tfs_cols].copy()
    else:
        # pd.concat([]) raises; keep the cells and leave the group columns empty
        print(f"No group-level results under {donor_region_dir}; group columns left empty")
        df_sub_obs = df_sub_obs.reindex(columns=[*df_sub_obs.columns, *group_tfs_cols])
    final_annot.append(df_sub_obs)

In [None]:
# final_annot = []
# for donor_region_dir in data_dir.glob("*"):
#     # print(donor_region_dir)
#     final_file = donor_region_dir / "final.h5ad"
#     if final_file.exists():
#         print(final_file)
#         adata = ad.read_h5ad(final_file)
#         df_sub_obs = adata.obs[['donor', 'brain_region', 'replicate', 'dataset_id', 'Subclass', 'Group']].copy()
#         regroup_obs = []
#         for _group_file in donor_region_dir.glob("groups/*/final.h5ad"):
#             adata = ad.read_h5ad(_group_file)
#             regroup_obs.append(adata.obs[['c2c_allcools_label_Group']])
#         df_group_obs = pd.concat(regroup_obs)
#         # df_sub_obs['ReGroup'] = np.nan
#         df_sub_obs.loc[df_group_obs.index, 'ReGroup'] = df_group_obs['c2c_allcools_label_Group'].copy()
#         df_sub_obs['ReGroup'] = df_sub_obs['ReGroup'].fillna("unknown")
#         final_annot.append(df_sub_obs)
#         # break

In [None]:
# Stack the per-directory obs tables collected in final_annot into a single DataFrame.
df_obs = pd.concat(final_annot)

In [None]:
# Keep the incoming Group labels as 'orig_group', then make the combined labels the new 'Group'.
# Not re-runnable as-is: 'Group_Combined' is dropped below, so a second run raises KeyError.
df_obs['orig_group'] = df_obs['Group'].copy()
df_obs['Group'] = df_obs['Group_Combined'].copy()
df_obs = df_obs.drop(columns=['Group_Combined'])
# df_obs = df_obs.drop(columns=['group_mismatch'])

In [None]:
# Add neuron type
# Two-column CSV without header: column 0 is the key, column 1 the mapped value.
subclass_to_neighborhood_dict = (
    pd.read_csv(
        "/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/subclass_to_neighborhood_dict.csv",
        header=None,
        index_col=0,
    )
    .to_dict()[1]
)
df_obs['Subclass'] = df_obs['Subclass'].fillna('unknown')
df_obs['Neighborhood'] = df_obs['Subclass'].map(subclass_to_neighborhood_dict).fillna("unknown")
# Start with the Nonneuron / Neuron split, then overwrite cells whose neighborhood is unknown.
_neuron_split = df_obs['Neighborhood'].eq('Nonneuron').map({True: 'Nonneuron', False: 'Neuron'})
df_obs['neuron_type'] = _neuron_split.mask(df_obs['Neighborhood'].eq("unknown"), "unknown")
df_obs['neuron_type'].value_counts()

In [None]:
# Write the annotation table as a tab-separated file, keeping the index as the first column.
df_obs.to_csv("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_GP_PFV8_annotated.tsv", sep="\t", index=True)

In [None]:
# Load the AnnData object that the annotation table is attached to in the next cell.
all_adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_pfv8_all.h5ad")
all_adata

In [None]:
# Prints True when every annotated cell id is present in all_adata.
print(df_obs.index.isin(all_adata.obs.index).all())
# Subset (and reorder) all_adata to the rows of df_obs.
all_adata = all_adata[df_obs.index].copy()
# obs columns retained from all_adata; everything else is discarded before the merge.
keep_cols = ['CELL_ID', 'experiment', 'region', 'segmentation', 'donor', 'CENTER_X',
       'CENTER_Y', 'volume', 'nCount_RNA', 'nFeature_RNA', 'nBlank',
       'nCount_RNA_per_Volume', 'pass_qc_pre', 'pass_qc', 'cell',
       'original_cell_id', 'centroid_x', 'centroid_y', 'centroid_z',
       'component', 'surface_area', 'scale', 'transcript_count', 'slide',
       'dataset_id', 'cells_region', 'base_umap_0', 'base_umap_1',
       'base_tsne_0', 'base_tsne_1', 'base_leiden', 'integrated_tsne_0',
       'integrated_umap_1','integrated_leiden', 'integrated_tsne_1', 
       'integrated_umap_0', 'brain_region', 'replicate', 'ename', 'dataset',
       'base_round1_umap_0', 'base_round1_umap_1', 'base_round1_tsne_0',
       'base_round1_tsne_1', 'base_round1_leiden', 'base_round2_leiden',]
all_adata.obs = all_adata.obs[keep_cols]
# del all_adata.obs[df_obs.columns]
# Write every df_obs column into obs: columns already in keep_cols are overwritten,
# the remaining ones are created by this assignment.
all_adata.obs.loc[df_obs.index, df_obs.columns] = df_obs

In [None]:
# Cells without a combined label are reported as "unknown".
# The result is assigned back explicitly: an in-place fillna on a column
# selection does not reliably update the parent frame (it is a no-op under
# pandas Copy-on-Write). A categorical column also needs the category first.
_group = all_adata.obs['Group']
if isinstance(_group.dtype, pd.CategoricalDtype) and "unknown" not in _group.cat.categories:
    _group = _group.cat.add_categories("unknown")
all_adata.obs['Group'] = _group.fillna("unknown")

In [None]:
# Number of rows in the annotation table.
print(len(df_obs.index))

In [None]:
# all_adata

In [None]:
# Persist the annotated AnnData object.
all_adata.write_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_GP_PFV8_annotated.h5ad")

In [None]:
# df_obs.head()

## Check Agreement:

In [None]:
# Flag cells whose combined label ('Group') differs from the label they arrived with ('orig_group').
# Compared as plain objects so the cell still runs after both columns have been
# converted to categoricals with different category sets (see the Entropy section).
df_obs['group_mismatch'] = df_obs['Group'].astype(object) != df_obs['orig_group'].astype(object)
msvc = df_obs.groupby(['Group'], observed=True).agg({"group_mismatch" : ['sum', 'count']})
msvc.columns = ['mismatch', 'total']
msvc['mismatch_rate'] = msvc['mismatch'] / msvc['total']

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.bar(msvc.index.astype(str), msvc['mismatch_rate'], color="salmon", edgecolor="black")
ax.set_ylabel("Mismatch Rate", fontsize=6)
# bars are grouped by the combined 'Group' column, not by 'orig_group'
ax.set_xlabel("Combined Group", fontsize=6)
ax.set_title("Group vs ReGroup Mismatch Rate", fontsize=6)
# restyle the existing tick labels instead of re-setting them, which avoids
# matplotlib's fixed-locator warning
plt.setp(ax.get_xticklabels(), fontsize=6, rotation=45, ha='right')
ax.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

## Entropy

In [None]:
# Both label columns become categoricals; the entropy cells below use the .cat accessor.
for _label_col in ('Group', 'orig_group'):
    df_obs[_label_col] = df_obs[_label_col].astype('category')

In [None]:
# For every (donor, region): entropy of the histogram of transfer scores
# (rounded to 3 decimals) of each combined Group present in that subset.
entropies = {}
for (_donor, _region), _df in df_obs.groupby(['donor', 'brain_region']):
    _present = _df['Group'].cat.remove_unused_categories().cat.categories
    entropies[(_donor, _region)] = {
        _class: entropy(
            _df.loc[_df['Group'] == _class, 'allcools_Group_transfer_score']
            .round(3)
            .value_counts()
            .sort_index()
        )
        for _class in _present
    }

In [None]:
# Same computation as for `entropies`, but cells are grouped by their original label.
entropies_orig = {}
for (_donor, _region), _df in df_obs.groupby(['donor', 'brain_region']):
    _present = _df['orig_group'].cat.remove_unused_categories().cat.categories
    entropies_orig[(_donor, _region)] = {
        _class: entropy(
            _df.loc[_df['orig_group'] == _class, 'allcools_Group_transfer_score']
            .round(3)
            .value_counts()
            .sort_index()
        )
        for _class in _present
    }

In [None]:
# donors = df_obs['donor'].unique().tolist()
# regions = df_obs['brain_region'].unique().tolist()
# nrows = len(donors)
# ncols = len(regions)
# fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4), squeeze=False, sharey=True, sharex=True)
# for i, donor in enumerate(donors):
#     for j, region in enumerate(regions):
#         ax = axes[i, j]
#         if (donor, region) in entropies:
#             group_entropies = entropies[(donor, region)]
#             groups = list(group_entropies.keys())
#             entropy_values = list(group_entropies.values())
#             ax.bar(groups, entropy_values, color='salmon', edgecolor='black')
            
#             # ax.tick_params(labelbottom=True, labelleft=True)
#             # ax.set_xticklabels(ax.get_xticklabels(), fontsize=6, rotation=45, ha='right')
#             ax.grid(axis='y', linestyle='--', alpha=0.5)
#         else:
#             ax.axis('off')
# for ax, col in zip(axes[0], regions):
#     ax.set_title(col, fontsize=10)
# for ax, row in zip(axes[:,0], donors):
#     ax.set_ylabel(row, fontsize=10, rotation=90, labelpad=10)
# for ax, col in zip(axes[-1], regions):
#     ax.set_xlabel("Group", fontsize=8)
#     ax.tick_params(labelbottom=True)
#     ax.set_xticklabels(ax.get_xticklabels(), fontsize=8, rotation=45, ha='right')
# # for ax, row in zip(axes[:,-1], donors):
# #     ax.set_ylabel(row, fontsize=10, rotation=90, labelpad=10)
# # ax.set_ylabel("Entropy", fontsize=8)


# plt.xticks(rotation=45, ha='right')
# plt.tight_layout()
# plt.show()

In [None]:
for _label, _column in (("Unique donors:", 'donor'), ("Unique brain regions:", 'brain_region')):
    print(_label, df_obs[_column].unique())

# Show the per-group entropies stored under the first (donor, region) key.
sample_key = list(entropies)[0]
print(f"\nSample entropy data for {sample_key}:")
print(entropies[sample_key])

In [None]:
# Flatten {(donor, region): {group: entropy}} into one record per
# (donor, region, group) so pandas can aggregate across donors.
plot_data = [
    {'donor': donor, 'brain_region': region, 'group': group, 'entropy': entropy_val}
    for (donor, region), group_entropies in entropies.items()
    for group, entropy_val in group_entropies.items()
]
entropy_df = pd.DataFrame(plot_data)

# Sorted label lists give every panel the same group order.
regions = sorted(entropy_df['brain_region'].unique())
all_groups = sorted(entropy_df['group'].unique())

# Mean and standard deviation across donors for each (region, group).
stats_df = (
    entropy_df.groupby(['brain_region', 'group'])['entropy']
    .agg(['mean', 'std'])
    .reset_index()
)
stats_df['std'] = stats_df['std'].fillna(0)  # std is NaN when a single donor contributes

# One panel per brain region, stacked vertically on a shared x axis.
n_regions = len(regions)
fig, axes = plt.subplots(nrows=n_regions, ncols=1, figsize=(12, 4*n_regions),
                         sharex=True, squeeze=False)
axes = axes.flatten()

palette = all_adata.uns['brain_region_palette']
x_positions = range(len(all_groups))

for i, (ax, region) in enumerate(zip(axes, regions)):
    color = palette.get(region, 'gray')
    region_data = stats_df[stats_df['brain_region'] == region]

    # Align this region's statistics to the full group list; groups that do
    # not occur in the region are drawn with mean 0 and std 0.
    per_group = region_data.set_index('group')
    groups_to_plot = list(all_groups)
    means_to_plot = per_group['mean'].reindex(all_groups, fill_value=0).tolist()
    stds_to_plot = per_group['std'].reindex(all_groups, fill_value=0).tolist()

    if groups_to_plot:
        bars = ax.bar(x_positions, means_to_plot,
                      yerr=stds_to_plot, capsize=3,
                      color=color, alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(groups_to_plot, rotation=45, ha='right', fontsize=8)

    ax.set_ylabel(f'{region}\nEntropy', fontsize=10, rotation=0, ha='right', va='center')
    ax.grid(axis='y', linestyle='--', alpha=0.3)
    ax.set_ylim(bottom=0)

    # only the bottom panel keeps its tick labels
    is_bottom_panel = i == len(regions) - 1
    if not is_bottom_panel:
        ax.set_xticklabels([])

axes[-1].set_xlabel('Group', fontsize=12)

fig.suptitle('Entropy by Brain Region and Group\n(Mean ± Std Dev across donors)',
             fontsize=14, y=0.98)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Flatten {(donor, region): {group: entropy}} (original labels) into one
# record per (donor, region, group) so pandas can aggregate across donors.
plot_data = [
    {'donor': donor, 'brain_region': region, 'group': group, 'entropy': entropy_val}
    for (donor, region), group_entropies in entropies_orig.items()
    for group, entropy_val in group_entropies.items()
]
entropy_df = pd.DataFrame(plot_data)

# Sorted label lists give every panel the same group order.
regions = sorted(entropy_df['brain_region'].unique())
all_groups = sorted(entropy_df['group'].unique())

# Mean and standard deviation across donors for each (region, group).
stats_df = (
    entropy_df.groupby(['brain_region', 'group'])['entropy']
    .agg(['mean', 'std'])
    .reset_index()
)
stats_df['std'] = stats_df['std'].fillna(0)  # std is NaN when a single donor contributes

# One panel per brain region, stacked vertically on a shared x axis.
n_regions = len(regions)
fig, axes = plt.subplots(nrows=n_regions, ncols=1, figsize=(12, 4*n_regions),
                         sharex=True, squeeze=False)
axes = axes.flatten()

palette = all_adata.uns['brain_region_palette']
x_positions = range(len(all_groups))

for i, (ax, region) in enumerate(zip(axes, regions)):
    color = palette.get(region, 'gray')
    region_data = stats_df[stats_df['brain_region'] == region]

    # Align this region's statistics to the full group list; groups that do
    # not occur in the region are drawn with mean 0 and std 0.
    per_group = region_data.set_index('group')
    groups_to_plot = list(all_groups)
    means_to_plot = per_group['mean'].reindex(all_groups, fill_value=0).tolist()
    stds_to_plot = per_group['std'].reindex(all_groups, fill_value=0).tolist()

    if groups_to_plot:
        bars = ax.bar(x_positions, means_to_plot,
                      yerr=stds_to_plot, capsize=3,
                      color=color, alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(groups_to_plot, rotation=45, ha='right', fontsize=8)

    ax.set_ylabel(f'{region}\nEntropy', fontsize=10, rotation=0, ha='right', va='center')
    ax.grid(axis='y', linestyle='--', alpha=0.3)
    ax.set_ylim(bottom=0)

    # only the bottom panel keeps its tick labels
    is_bottom_panel = i == len(regions) - 1
    if not is_bottom_panel:
        ax.set_xticklabels([])

axes[-1].set_xlabel('Group', fontsize=12)

fig.suptitle('Entropy by Brain Region and Group\n(Mean ± Std Dev across donors)',
             fontsize=14, y=0.98)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

## Specific Regions

### PU

In [None]:
# Read the "MSN" sheet (first column as index) and keep its "Hex" column as a dict.
msn_palette = pd.read_excel("/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx", sheet_name="MSN", index_col=0).to_dict()["Hex"]
msn_palette

In [None]:
# Distinct donors and replicates; they define the columns and rows of the grid figures below.
donors= all_adata.obs['donor'].unique().tolist()
labs= all_adata.obs['replicate'].unique().tolist()

In [None]:
# Independent copy of the cells whose brain_region is 'PU'.
adata_pu = all_adata[all_adata.obs['brain_region'] == 'PU'].copy()

In [None]:
# Cell counts per Group label in the PU subset.
adata_pu.obs['Group'].value_counts()

In [None]:
# Group labels highlighted in the PU spatial panels.
plot_cats = ['STRd D1 Striosome MSN', 'STRd D2 Striosome MSN', 'STRd D1 Matrix MSN', 'STRd D2 Matrix MSN', 
             'STRd D1/D2 Hybrid MSN', 'STRv D1 MSN', 'STRv D2 MSN']
# NOTE(review): ncols / nrows are reassigned from donors / labs in the plotting cell before being used.
ncols = 4
nrows = int(np.ceil(len(plot_cats) / ncols))

In [None]:
# Add a palette entry for the 'unknown' label.
msn_palette['unknown'] = 'magenta'

In [None]:
# Grid of spatial panels: one column per donor, one row per replicate.
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional, so axes[j, i] also works when
# there is only one donor or one replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_pu[(adata_pu.obs['donor'] == _donor) & (adata_pu.obs['replicate'] == _lab)]
        if adata_sub.shape[0] == 0:
            # no cells for this donor/replicate: hide the panel instead of
            # leaving an empty frame
            ax.set_visible(False)
            continue

        adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue

        # all cells in grey as background, highlighted categories on top
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group", palette_path = msn_palette, ax=ax, show=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle("PU", fontsize=16)
plt.tight_layout()
plt.show()


### GP

In [None]:
# Independent copy of the cells whose brain_region is 'GP'.
adata_gp = all_adata[all_adata.obs['brain_region'] == 'GP'].copy()

In [None]:
# Read the "Group" sheet (first column as index) and keep its "Hex" column as a dict.
group_palette = pd.read_excel("/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx", sheet_name="Group", index_col=0).to_dict()["Hex"]
group_palette

In [None]:
# Group labels containing "GP". Missing labels are dropped first: a NaN is a
# float, and `"GP" in <float>` raises TypeError.
plot_cats = [x for x in adata_gp.obs['Group'].dropna().unique() if "GP" in str(x)]

In [None]:
# Grid of spatial panels: one column per donor, one row per replicate.
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional, so axes[j, i] also works when
# there is only one donor or one replicate.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_gp[(adata_gp.obs['donor'] == _donor) & (adata_gp.obs['replicate'] == _lab)]
        if adata_sub.shape[0] == 0:
            # no cells for this donor/replicate: hide the panel instead of
            # leaving an empty frame
            ax.set_visible(False)
            continue

        # NOTE(review): cells are selected and colored by 'orig_group' here,
        # while plot_cats is derived from 'Group' - confirm this is intended.
        adata_cat = adata_sub[adata_sub.obs['orig_group'].isin(plot_cats)]
        if adata_cat.shape[0] == 0:
            ax.set_visible(False)
            continue

        # all cells in grey as background, highlighted categories on top
        categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
        plot_categorical(adata_cat, coord_base="spatial", cluster_col="orig_group", palette_path = group_palette, ax=ax, show=False, show_legend=False)
        ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

plt.suptitle("GP", fontsize=16)
plt.tight_layout()
plt.show()


# CPSAM

## Load Data

In [None]:
# Root directory scanned by the loading cell below: each immediate child is checked for a final.h5ad file.
data_dir = Path("/home/x-aklein2/projects/aklein/BICAN/BG/annotation/execute/region_donor_lab_cpsam_gp")

In [None]:
# Label-transfer columns copied from every group-level file onto the region-level table.
group_tfs_cols = ['c2c_allcools_label_Group', 'allcools_Group_transfer_score', 'allcools_Group', 'Group_Combined']
final_annot = []
# sorted() keeps the row order of the collected tables independent of the
# filesystem listing order, so repeated runs produce identical outputs.
for donor_region_dir in sorted(data_dir.glob("*")):
    final_file = donor_region_dir / "final.h5ad"
    if not final_file.exists():
        continue
    print(final_file)
    adata = ad.read_h5ad(final_file)
    df_sub_obs = adata.obs[['donor', 'brain_region', 'replicate', 'dataset_id', 'Subclass', 'Group']].copy()

    regroup_obs = []
    for _group_file in sorted(donor_region_dir.glob("groups/*/final.h5ad")):
        # joint.h5ad is optional; without it no cluster map is used
        joint_file = _group_file.parent / "joint.h5ad"
        if joint_file.exists():
            joint_adata = ad.read_h5ad(joint_file)
            df_map = _make_map(adata=joint_adata, ref_col='Group', cluster_col='integrated_leiden')
        else:
            df_map = None
        # separate name so the region-level `adata` above is not overwritten
        group_adata = ad.read_h5ad(_group_file)
        _combine_cluster_cell_annots(
            group_adata,
            cluster_col='c2c_allcools_label_Group',
            cell_col='allcools_Group',
            cell_prob_col='allcools_Group_transfer_score',
            added_col='Group_Combined',
            integrated_col='integrated_leiden',
            df_map=df_map,
        )
        regroup_obs.append(group_adata.obs[group_tfs_cols])

    if regroup_obs:
        df_group_obs = pd.concat(regroup_obs)
        df_sub_obs.loc[df_group_obs.index, group_tfs_cols] = df_group_obs[group_tfs_cols].copy()
    else:
        # pd.concat([]) raises; keep the cells and leave the group columns empty
        print(f"No group-level results under {donor_region_dir}; group columns left empty")
        df_sub_obs = df_sub_obs.reindex(columns=[*df_sub_obs.columns, *group_tfs_cols])
    final_annot.append(df_sub_obs)

In [None]:
# Stack the per-directory obs tables, keep the incoming Group labels as 'orig_group',
# and make the combined labels the new 'Group'.
df_obs = pd.concat(final_annot)
df_obs['orig_group'] = df_obs['Group'].copy()
df_obs['Group'] = df_obs['Group_Combined'].copy()
df_obs = df_obs.drop(columns=['Group_Combined'])
# df_obs = df_obs.drop(columns=['group_mismatch'])

In [None]:
# Add neuron type
# Two-column CSV without header: column 0 is the key, column 1 the mapped value.
subclass_to_neighborhood_dict = (
    pd.read_csv(
        "/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/subclass_to_neighborhood_dict.csv",
        header=None,
        index_col=0,
    )
    .to_dict()[1]
)
df_obs['Subclass'] = df_obs['Subclass'].fillna('unknown')
df_obs['Neighborhood'] = df_obs['Subclass'].map(subclass_to_neighborhood_dict).fillna("unknown")
# Start with the Nonneuron / Neuron split, then overwrite cells whose neighborhood is unknown.
_neuron_split = df_obs['Neighborhood'].eq('Nonneuron').map({True: 'Nonneuron', False: 'Neuron'})
df_obs['neuron_type'] = _neuron_split.mask(df_obs['Neighborhood'].eq("unknown"), "unknown")
df_obs['neuron_type'].value_counts()

In [None]:
# Write the annotation table as a tab-separated file, keeping the index as the first column.
# NOTE(review): this file name says "CP" while the .h5ad written further down says "CPSAM" - confirm the naming.
df_obs.to_csv("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CP_annotated_v2.tsv", sep="\t", index=True)

In [None]:
# Load the AnnData object that the annotation table is attached to in the next cell.
# This rebinds `all_adata`, replacing the object loaded in the Proseg section.
all_adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_cpsam_all.h5ad")
all_adata

In [None]:
# Prints True when every annotated cell id is present in all_adata.
print(df_obs.index.isin(all_adata.obs.index).all())
# Subset (and reorder) all_adata to the rows of df_obs.
all_adata = all_adata[df_obs.index].copy()
# obs columns retained from all_adata; everything else is discarded before the merge.
keep_cols = ['CELL_ID', 'experiment', 'region', 'segmentation', 'donor', 'CENTER_X',
       'CENTER_Y', 'volume', 'nCount_RNA', 'nFeature_RNA', 'nBlank',
       'nCount_RNA_per_Volume', 'pass_qc_pre', 'pass_qc', 'transcript_count', 'slide',
       'dataset_id', 'cells_region', 'base_umap_0', 'base_umap_1',
       'base_tsne_0', 'base_tsne_1', 'base_leiden', 'integrated_tsne_0',
       'integrated_umap_1','integrated_leiden', 'integrated_tsne_1', 
       'integrated_umap_0', 'brain_region', 'replicate', 'ename', 'dataset',
       'base_round1_umap_0', 'base_round1_umap_1', 'base_round1_tsne_0',
       'base_round1_tsne_1', 'base_round1_leiden', 'base_round2_leiden',]
all_adata.obs = all_adata.obs[keep_cols]
# Write every df_obs column into obs: columns already in keep_cols are overwritten,
# the remaining ones are created by this assignment.
all_adata.obs.loc[df_obs.index, df_obs.columns] = df_obs

In [None]:
# Persist the annotated AnnData object.
all_adata.write_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPSAM_annotated_v2.h5ad")

## Check Agreement

In [None]:
# Flag cells whose combined label ('Group') differs from the label they arrived with ('orig_group').
# Compared as plain objects so the cell still runs after both columns have been
# converted to categoricals with different category sets (see the Entropy section).
df_obs['group_mismatch'] = df_obs['Group'].astype(object) != df_obs['orig_group'].astype(object)
msvc = df_obs.groupby(['Group'], observed=True).agg({"group_mismatch" : ['sum', 'count']})
msvc.columns = ['mismatch', 'total']
msvc['mismatch_rate'] = msvc['mismatch'] / msvc['total']

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.bar(msvc.index.astype(str), msvc['mismatch_rate'], color="salmon", edgecolor="black")
ax.set_ylabel("Mismatch Rate", fontsize=6)
# bars are grouped by the combined 'Group' column, not by 'orig_group'
ax.set_xlabel("Combined Group", fontsize=6)
ax.set_title("Group vs ReGroup Mismatch Rate", fontsize=6)
# restyle the existing tick labels instead of re-setting them, which avoids
# matplotlib's fixed-locator warning
plt.setp(ax.get_xticklabels(), fontsize=6, rotation=45, ha='right')
ax.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

## Entropy

In [None]:
# Both label columns become categoricals; the entropy cells below use the .cat accessor.
for _label_col in ('Group', 'orig_group'):
    df_obs[_label_col] = df_obs[_label_col].astype('category')

In [None]:
# For every (donor, region): entropy of the histogram of transfer scores
# (rounded to 3 decimals) of each combined Group present in that subset.
entropies = {}
for (_donor, _region), _df in df_obs.groupby(['donor', 'brain_region']):
    _present = _df['Group'].cat.remove_unused_categories().cat.categories
    entropies[(_donor, _region)] = {
        _class: entropy(
            _df.loc[_df['Group'] == _class, 'allcools_Group_transfer_score']
            .round(3)
            .value_counts()
            .sort_index()
        )
        for _class in _present
    }

In [None]:
# Same computation as for `entropies`, but cells are grouped by their original label.
entropies_orig = {}
for (_donor, _region), _df in df_obs.groupby(['donor', 'brain_region']):
    _present = _df['orig_group'].cat.remove_unused_categories().cat.categories
    entropies_orig[(_donor, _region)] = {
        _class: entropy(
            _df.loc[_df['orig_group'] == _class, 'allcools_Group_transfer_score']
            .round(3)
            .value_counts()
            .sort_index()
        )
        for _class in _present
    }

In [None]:
# Flatten {(donor, region): {group: entropy}} into one record per
# (donor, region, group) so pandas can aggregate across donors.
plot_data = [
    {'donor': donor, 'brain_region': region, 'group': group, 'entropy': entropy_val}
    for (donor, region), group_entropies in entropies.items()
    for group, entropy_val in group_entropies.items()
]
entropy_df = pd.DataFrame(plot_data)

# Sorted label lists give every panel the same group order.
regions = sorted(entropy_df['brain_region'].unique())
all_groups = sorted(entropy_df['group'].unique())

# Mean and standard deviation across donors for each (region, group).
stats_df = (
    entropy_df.groupby(['brain_region', 'group'])['entropy']
    .agg(['mean', 'std'])
    .reset_index()
)
stats_df['std'] = stats_df['std'].fillna(0)  # std is NaN when a single donor contributes

# One panel per brain region, stacked vertically on a shared x axis.
n_regions = len(regions)
fig, axes = plt.subplots(nrows=n_regions, ncols=1, figsize=(12, 4*n_regions),
                         sharex=True, squeeze=False)
axes = axes.flatten()

palette = all_adata.uns['brain_region_palette']
x_positions = range(len(all_groups))

for i, (ax, region) in enumerate(zip(axes, regions)):
    color = palette.get(region, 'gray')
    region_data = stats_df[stats_df['brain_region'] == region]

    # Align this region's statistics to the full group list; groups that do
    # not occur in the region are drawn with mean 0 and std 0.
    per_group = region_data.set_index('group')
    groups_to_plot = list(all_groups)
    means_to_plot = per_group['mean'].reindex(all_groups, fill_value=0).tolist()
    stds_to_plot = per_group['std'].reindex(all_groups, fill_value=0).tolist()

    if groups_to_plot:
        bars = ax.bar(x_positions, means_to_plot,
                      yerr=stds_to_plot, capsize=3,
                      color=color, alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(groups_to_plot, rotation=45, ha='right', fontsize=8)

    ax.set_ylabel(f'{region}\nEntropy', fontsize=10, rotation=0, ha='right', va='center')
    ax.grid(axis='y', linestyle='--', alpha=0.3)
    ax.set_ylim(bottom=0)

    # only the bottom panel keeps its tick labels
    is_bottom_panel = i == len(regions) - 1
    if not is_bottom_panel:
        ax.set_xticklabels([])

axes[-1].set_xlabel('Group', fontsize=12)

fig.suptitle('Entropy by Brain Region and Group\n(Mean ± Std Dev across donors)',
             fontsize=14, y=0.98)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

In [None]:
# Flatten {(donor, region): {group: entropy}} (original labels) into one
# record per (donor, region, group) so pandas can aggregate across donors.
plot_data = [
    {'donor': donor, 'brain_region': region, 'group': group, 'entropy': entropy_val}
    for (donor, region), group_entropies in entropies_orig.items()
    for group, entropy_val in group_entropies.items()
]
entropy_df = pd.DataFrame(plot_data)

# Sorted label lists give every panel the same group order.
regions = sorted(entropy_df['brain_region'].unique())
all_groups = sorted(entropy_df['group'].unique())

# Mean and standard deviation across donors for each (region, group).
stats_df = (
    entropy_df.groupby(['brain_region', 'group'])['entropy']
    .agg(['mean', 'std'])
    .reset_index()
)
stats_df['std'] = stats_df['std'].fillna(0)  # std is NaN when a single donor contributes

# One panel per brain region, stacked vertically on a shared x axis.
n_regions = len(regions)
fig, axes = plt.subplots(nrows=n_regions, ncols=1, figsize=(12, 4*n_regions),
                         sharex=True, squeeze=False)
axes = axes.flatten()

palette = all_adata.uns['brain_region_palette']
x_positions = range(len(all_groups))

for i, (ax, region) in enumerate(zip(axes, regions)):
    color = palette.get(region, 'gray')
    region_data = stats_df[stats_df['brain_region'] == region]

    # Align this region's statistics to the full group list; groups that do
    # not occur in the region are drawn with mean 0 and std 0.
    per_group = region_data.set_index('group')
    groups_to_plot = list(all_groups)
    means_to_plot = per_group['mean'].reindex(all_groups, fill_value=0).tolist()
    stds_to_plot = per_group['std'].reindex(all_groups, fill_value=0).tolist()

    if groups_to_plot:
        bars = ax.bar(x_positions, means_to_plot,
                      yerr=stds_to_plot, capsize=3,
                      color=color, alpha=0.8, edgecolor='black', linewidth=0.5)
        ax.set_xticks(x_positions)
        ax.set_xticklabels(groups_to_plot, rotation=45, ha='right', fontsize=8)

    ax.set_ylabel(f'{region}\nEntropy', fontsize=10, rotation=0, ha='right', va='center')
    ax.grid(axis='y', linestyle='--', alpha=0.3)
    ax.set_ylim(bottom=0)

    # only the bottom panel keeps its tick labels
    is_bottom_panel = i == len(regions) - 1
    if not is_bottom_panel:
        ax.set_xticklabels([])

axes[-1].set_xlabel('Group', fontsize=12)

fig.suptitle('Entropy by Brain Region and Group\n(Mean ± Std Dev across donors)',
             fontsize=14, y=0.98)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

## Specific Regions

### PU

In [None]:
# Independent copy (not a view) of the cells whose brain_region is 'PU'
is_pu = all_adata.obs['brain_region'] == 'PU'
adata_pu = all_adata[is_pu].copy()

In [None]:
# Group labels that are highlighted in the spatial panels
plot_cats = [
    'STRd D1 Striosome MSN',
    'STRd D2 Striosome MSN',
    'STRd D1 Matrix MSN',
    'STRd D2 Matrix MSN',
    'STRd D1/D2 Hybrid MSN',
    'STRv D1 MSN',
    'STRv D2 MSN',
]
# Grid with 4 columns and as many rows as needed (integer ceiling division)
ncols = 4
nrows = -(-len(plot_cats) // ncols)

In [None]:
# Spatial overview of the PU sections: one column per donor, one row per
# replicate (lab). `donors`, `labs`, `plot_cats`, `adata_pu` and `msn_palette`
# must already exist in the notebook namespace.
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional even with a single donor or a
# single replicate, so the axes[j, i] lookup below never fails.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_pu[(adata_pu.obs['donor'] == _donor) & (adata_pu.obs['replicate'] == _lab)]
        if adata_sub.shape[0] > 0:
            # Cells of this section that belong to one of the highlighted groups
            adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
            if adata_cat.shape[0] == 0:
                # Section exists but holds none of the highlighted groups: hide the panel
                ax.set_visible(False)
            else:
                # spida.pl helpers: all cells of the section are passed with
                # color='lightgrey' first (presumably a background layer), then
                # the highlighted groups are drawn on the same axes.
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group", palette_path = msn_palette, ax=ax, show=False)
                ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

fig.suptitle("PU", fontsize=16)
plt.tight_layout()
plt.show()


### GP

In [None]:
# Spatial overview of the GP sections: one column per donor, one row per
# replicate (lab). `donors`, `labs`, `plot_cats` and `group_palette` must
# already exist in the notebook namespace.
# NOTE(review): `plot_cats` is not set in this cell, so it is whatever the
# namespace currently holds (the MSN list from the PU section if the notebook
# is run top to bottom). The later "Plot PU / GP" GP cell rebuilds it from the
# group names containing "GP" -- confirm which set is intended here.
adata_gp = all_adata[all_adata.obs['brain_region'] == 'GP'].copy()
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional even with a single donor or a
# single replicate, so the axes[j, i] lookup below never fails.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_gp[(adata_gp.obs['donor'] == _donor) & (adata_gp.obs['replicate'] == _lab)]
        if adata_sub.shape[0] > 0:
            # Cells of this section that belong to one of the highlighted groups
            adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
            if adata_cat.shape[0] == 0:
                # Section exists but holds none of the highlighted groups: hide the panel
                ax.set_visible(False)
            else:
                # spida.pl helpers: all cells of the section are passed with
                # color='lightgrey' first (presumably a background layer), then
                # the highlighted groups are drawn on the same axes, no legend.
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group", palette_path = group_palette, ax=ax, show=False, show_legend=False)
                ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

fig.suptitle("GP", fontsize=16)
plt.tight_layout()
plt.show()


# CPS

## Load Data

In [None]:
# Root of the CPS annotation run. The loading cell below globs every entry in
# this directory and reads "<entry>/final.h5ad" where it exists.
data_dir = Path("/home/x-aklein2/projects/aklein/BICAN/BG/annotation/execute/region_donor_lab_cps2")

In [None]:
# Collect, per run directory, the Subclass/Group annotation of every cell plus
# the group-level transfer columns produced by the per-group re-annotation.
group_tfs_cols = ['c2c_allcools_label_Group', 'allcools_Group_transfer_score', 'allcools_Group', 'Group_Combined']
final_annot = []
# sorted() makes the processing order (and therefore the row order of the
# concatenated table) reproducible; glob order depends on the filesystem.
for donor_region_dir in sorted(data_dir.glob("*")):
    final_file = donor_region_dir / "final.h5ad"
    if not final_file.exists():
        continue
    print(final_file)
    adata = ad.read_h5ad(final_file)
    df_sub_obs = adata.obs[['donor', 'brain_region', 'replicate', 'dataset_id', 'Subclass', 'Group']].copy()

    regroup_obs = []
    for _group_file in sorted(donor_region_dir.glob("groups/*/final.h5ad")):
        # The cluster -> reference-group map is only available when the joint
        # (integrated) object was written next to the group result.
        joint_file = _group_file.parent / "joint.h5ad"
        if joint_file.exists():
            joint_adata = ad.read_h5ad(joint_file)
            df_map = _make_map(adata=joint_adata, ref_col='Group', cluster_col='integrated_leiden')
        else:
            df_map = None

        # Separate name so the region-level `adata` above is not overwritten
        group_adata = ad.read_h5ad(_group_file)
        # Called for its side effect only (return value unused); presumably it
        # writes 'Group_Combined' into group_adata.obs -- TODO confirm.
        _combine_cluster_cell_annots(
            group_adata,
            cluster_col='c2c_allcools_label_Group',
            cell_col='allcools_Group',
            cell_prob_col='allcools_Group_transfer_score',
            added_col='Group_Combined',
            integrated_col='integrated_leiden',
            df_map=df_map,
        )
        regroup_obs.append(group_adata.obs[group_tfs_cols])

    if regroup_obs:
        df_group_obs = pd.concat(regroup_obs)
        # Only cells that went through a group-level run receive values;
        # all other rows stay NaN in the transfer columns.
        df_sub_obs.loc[df_group_obs.index, group_tfs_cols] = df_group_obs[group_tfs_cols].copy()
    else:
        # No group-level results for this directory: pd.concat([]) would raise,
        # so add the transfer columns as all-NaN instead.
        df_sub_obs = df_sub_obs.reindex(columns=[*df_sub_obs.columns, *group_tfs_cols])
    final_annot.append(df_sub_obs)

In [None]:
# Stack the per-directory tables, keep the previous label as 'orig_group' and
# promote the combined annotation to be the new 'Group' column.
df_obs = pd.concat(final_annot)
df_obs['orig_group'] = df_obs['Group']
df_obs['Group'] = df_obs.pop('Group_Combined')

In [None]:
# Derive Neighborhood and a coarse Neuron / Nonneuron / unknown label from Subclass
subclass_to_neighborhood_dict = pd.read_csv(
    "/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/subclass_to_neighborhood_dict.csv",
    header=None,
    index_col=0,
)[1].to_dict()

df_obs['Subclass'] = df_obs['Subclass'].fillna('unknown')

# Subclasses without an entry in the lookup table fall back to "unknown"
neighborhood = df_obs['Subclass'].map(subclass_to_neighborhood_dict).fillna("unknown")
df_obs['Neighborhood'] = neighborhood

# Everything outside the 'Nonneuron' neighborhood counts as a neuron,
# except cells whose neighborhood could not be resolved.
neuron_type = neighborhood.isin(['Nonneuron']).map({True: 'Nonneuron', False: 'Neuron'})
neuron_type[neighborhood == "unknown"] = "unknown"
df_obs['neuron_type'] = neuron_type

df_obs['neuron_type'].value_counts()

In [None]:
# Load the AnnData object that the annotations collected in df_obs are merged
# into below; the bare name on the last line displays its summary.
all_adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_cps_all.h5ad")
all_adata

In [None]:
# Sanity check: is every annotated cell present in the AnnData object?
all_cells_found = df_obs.index.isin(all_adata.obs.index).all()
print(all_cells_found)

# Restrict (and reorder) the AnnData object to the annotated cells
all_adata = all_adata[df_obs.index].copy()

# obs columns carried over from the original object; everything else is dropped
keep_cols = [
    'CELL_ID', 'experiment', 'region', 'segmentation', 'donor',
    'CENTER_X', 'CENTER_Y', 'volume',
    'nCount_RNA', 'nFeature_RNA', 'nBlank', 'nCount_RNA_per_Volume',
    'pass_qc_pre', 'pass_qc', 'transcript_count', 'slide',
    'dataset_id', 'cells_region',
    'base_umap_0', 'base_umap_1', 'base_tsne_0', 'base_tsne_1', 'base_leiden',
    'brain_region', 'replicate', 'ename', 'dataset',
    'base_round1_umap_0', 'base_round1_umap_1',
    'base_round1_tsne_0', 'base_round1_tsne_1',
    'base_round1_leiden', 'base_round2_leiden',
]
all_adata.obs = all_adata.obs[keep_cols]

# Write every df_obs column into obs, aligned on the cell index
all_adata.obs.loc[df_obs.index, df_obs.columns] = df_obs

In [None]:
# Cells with an unknown Subclass also get neuron_type "unknown"
unknown_subclass = all_adata.obs['Subclass'] == "unknown"

# Missing Group labels (these come from cells whose Subclass was NA) become "unknown"
all_adata.obs['Group'] = all_adata.obs['Group'].fillna("unknown")
all_adata.obs.loc[unknown_subclass, "neuron_type"] = "unknown"

In [None]:
# Store the annotation columns as categoricals
for _col in ('Group', 'orig_group', 'Subclass', 'neuron_type'):
    all_adata.obs[_col] = all_adata.obs[_col].astype('category')

In [None]:
# --- Colour palettes -------------------------------------------------------
color_palette_path = "/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx"

# Each sheet is read with its first column as index; the 'Hex' column holds the colour
_hex_by_sheet = {
    _sheet: pd.read_excel(color_palette_path, index_col=0, sheet_name=_sheet)['Hex'].to_dict()
    for _sheet in ("Subclass", "Group", "MSN")
}
subclass_color_palette = _hex_by_sheet["Subclass"]
group_color_palette = _hex_by_sheet["Group"]
msn_palette = _hex_by_sheet["MSN"]

neuron_type_palette = {
    "Neuron": "#195f91",
    "Nonneuron": "#dd6a06",
    "unknown": "#808080",
}

# Keep the palettes with the object so that they are saved alongside it
all_adata.uns['Subclass_palette'] = subclass_color_palette
all_adata.uns['Group_palette'] = group_color_palette
all_adata.uns['MSN_Groups_palette'] = msn_palette
all_adata.uns['Neuron_type_palette'] = neuron_type_palette

# --- Derived group columns (NaN for cells outside the selection) -----------
msn_subclasses = ['STR D1 MSN', 'STR D2 MSN', 'STR Hybrid MSN', 'OT Granular GABA']
is_msn = all_adata.obs['Subclass'].isin(msn_subclasses)
all_adata.obs['MSN_Groups'] = all_adata[is_msn].obs['Group'].astype('category')

# IT_TYPES
IT_types = [
    'LAMP5-CXCL14 GABA',
    'LAMP5-LHX6 GABA',
    'OB FRMD7 GABA',
    'STR FS PTHLH-PVALB GABA',
    'STR LYPD6-RSPO2 GABA',
    'STR SST-CHODL GABA',
    'STR SST-RSPO2 GABA',
    'STR TAC3-PLPP4 GABA',
    'STR-BF TAC3-PLPP4-LHX8 GABA',
    'STRd Cholinergic GABA',
    'VIP GABA',
]
is_it = all_adata.obs['Group'].isin(IT_types)
all_adata.obs['IT_Group'] = all_adata[is_it].obs['Group'].astype('category')

# --- Register the colours for each annotation column ------------------------
for _col, _palette in (
    ('Subclass', subclass_color_palette),
    ('Group', group_color_palette),
    ('MSN_Groups', msn_palette),
    ('IT_Group', group_color_palette),
    ('neuron_type', neuron_type_palette),
):
    add_colors(all_adata, _col, _palette)

# Drop categories that no cell uses
for _col in ('Subclass', 'Group', 'MSN_Groups', 'IT_Group'):
    all_adata.obs[_col] = all_adata.obs[_col].cat.remove_unused_categories()

In [None]:
# Persist the results: the annotation table as a tab-separated file (cell index
# included) and the annotated AnnData object as .h5ad.
df_obs.to_csv("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.tsv", sep="\t", index=True)
all_adata.write_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad")

In [None]:
# Naming fixes (to match the paper): reload the object written above into a
# separate variable, `adata`; the in-memory `all_adata` is left untouched.
adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(adata_path)
adata

In [None]:
# Brain-region codes used in the data -> names used in the paper
br_mapper = {
    "CAB": "CaB",
    "CAH": "CaH",
    "CAT": "CaT",
    "MGM1": "MGM1",
    "NAC": "NAC",
    "GP": "GP",
    "PU": "Pu",
    "SUBTH": "STH",
}
# Inverse lookup (paper name -> data code)
br_mapper_rev = dict(zip(br_mapper.values(), br_mapper.keys()))

# The corrected names go into a new column; 'brain_region' itself is left unchanged
corrected_regions = adata.obs['brain_region'].map(br_mapper)
adata.obs['brain_region_corr'] = corrected_regions.astype('category')
print(adata.obs['brain_region_corr'].unique())

# Re-key the existing palette with the corrected names, skipping palette
# entries that have no counterpart in the mapper
br_palette = adata.uns['brain_region_palette']
br_palette_fixed = {}
for _old_name, _colour in br_palette.items():
    if _old_name in br_mapper:
        br_palette_fixed[br_mapper[_old_name]] = _colour
adata.uns['brain_region_corr_palette'] = br_palette_fixed
add_colors(adata, 'brain_region_corr', br_palette_fixed)

# Overwrite the file that was loaded above
adata.write_h5ad(adata_path)

## Check Agreement

In [None]:
# Fraction of cells, per combined Group label, whose label differs from the
# label they had before the group-level re-annotation ('orig_group').
df_obs['group_mismatch'] = df_obs['Group'] != df_obs['orig_group']
msvc = df_obs.groupby(['Group'], observed=True).agg({"group_mismatch" : ['sum', 'count']})
msvc.columns = ['mismatch', 'total']
msvc['mismatch_rate'] = msvc['mismatch'] / msvc['total']

fig, ax = plt.subplots(1, 1, figsize=(8, 4))
# Explicit bar positions so the tick labels can be set on fixed ticks,
# instead of re-applying labels read back from the axis.
x_positions = range(len(msvc))
ax.bar(x_positions, msvc['mismatch_rate'], color="salmon", edgecolor="black")
ax.set_xticks(x_positions)
ax.set_xticklabels(msvc.index, fontsize=6, rotation=45, ha='right')
ax.set_ylabel("Mismatch Rate", fontsize=6)
# The bars are grouped by 'Group', which holds the combined annotation at this
# point (the previous labels live in 'orig_group'), so label the axis accordingly.
ax.set_xlabel("Group (combined)", fontsize=6)
ax.set_title("Group vs ReGroup Mismatch Rate", fontsize=6)
ax.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

## Plot PU / GP

In [None]:
# Fetch the two palettes stored in all_adata.uns that the spatial plots below use
msn_palette, group_palette = (
    all_adata.uns[_key] for _key in ('MSN_Groups_palette', 'Group_palette')
)
msn_palette

In [None]:
# Independent copy of the cells whose brain_region is 'PU'
is_pu = all_adata.obs['brain_region'] == 'PU'
adata_pu = all_adata[is_pu].copy()

# Group labels that are highlighted in the spatial panels
plot_cats = [
    'STRd D1 Striosome MSN',
    'STRd D2 Striosome MSN',
    'STRd D1 Matrix MSN',
    'STRd D2 Matrix MSN',
    'STRd D1/D2 Hybrid MSN',
    'STRv D1 MSN',
    'STRv D2 MSN',
]
# Grid with 4 columns and as many rows as needed (integer ceiling division)
ncols = 4
nrows = -(-len(plot_cats) // ncols)

# Donors / replicates across ALL regions define the panel grid used below
donors, labs = (
    all_adata.obs[_col].unique().tolist() for _col in ('donor', 'replicate')
)

In [None]:
# Spatial overview of the PU sections: one column per donor, one row per
# replicate (lab). `donors`, `labs`, `plot_cats`, `adata_pu` and `msn_palette`
# must already exist in the notebook namespace.
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional even with a single donor or a
# single replicate, so the axes[j, i] lookup below never fails.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_pu[(adata_pu.obs['donor'] == _donor) & (adata_pu.obs['replicate'] == _lab)]
        if adata_sub.shape[0] > 0:
            # Cells of this section that belong to one of the highlighted groups
            adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
            if adata_cat.shape[0] == 0:
                # Section exists but holds none of the highlighted groups: hide the panel
                ax.set_visible(False)
            else:
                # spida.pl helpers: all cells of the section are passed with
                # color='lightgrey' first (presumably a background layer), then
                # the highlighted groups are drawn on the same axes.
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group", palette_path = msn_palette, ax=ax, show=False)
                ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

fig.suptitle("PU", fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Debug display: `adata_cat` is a variable left over from the plotting loop
# above, i.e. the highlighted-group subset that was assigned last.
adata_cat

In [None]:
# Spatial overview of the GP sections: one column per donor, one row per
# replicate (lab). `donors`, `labs` and `group_palette` must already exist in
# the notebook namespace.
adata_gp = all_adata[all_adata.obs['brain_region'] == 'GP'].copy()
# Highlight every Group label present in GP whose name contains "GP".
# This rebinds the notebook-level `plot_cats`.
plot_cats = [a for a in adata_gp.obs['Group'].unique().tolist() if "GP" in a]
ncols = len(donors)
nrows = len(labs)
# squeeze=False keeps `axes` two-dimensional even with a single donor or a
# single replicate, so the axes[j, i] lookup below never fails.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for i, _donor in enumerate(donors):
    for j, _lab in enumerate(labs):
        ax = axes[j, i]
        adata_sub = adata_gp[(adata_gp.obs['donor'] == _donor) & (adata_gp.obs['replicate'] == _lab)]
        if adata_sub.shape[0] > 0:
            # Cells of this section that belong to one of the highlighted groups
            adata_cat = adata_sub[adata_sub.obs['Group'].isin(plot_cats)]
            if adata_cat.shape[0] == 0:
                # Section exists but holds none of the highlighted groups: hide the panel
                ax.set_visible(False)
            else:
                # spida.pl helpers: all cells of the section are passed with
                # color='lightgrey' first (presumably a background layer), then
                # the highlighted groups are drawn on the same axes, no legend.
                categorical_scatter(adata_sub, coord_base="spatial", color='lightgrey', ax=ax)
                plot_categorical(adata_cat, coord_base="spatial", cluster_col="Group", palette_path = group_palette, ax=ax, show=False, show_legend=False)
                ax.set_title(f"{_donor} - {_lab}\n{adata_cat.shape[0]} cells", fontsize=8)

fig.suptitle("GP", fontsize=16)
plt.tight_layout()
plt.show()
