This notebook illustrates the region-calling paradigm for white matter and for the striosome–matrix compartments, and then summarizes cell-type enrichment within the called compartments. The figures generated here go into spatial supplementary figure 5 (overall supp #).

Author: Amit Klein
email: a3klein@ucsd.edu

In [None]:
# Imports
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import math
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd

import geopandas as gpd
from shapely import Polygon, Point, box
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA, NMF
from sklearn.cluster import KMeans
import libpysal as lps
import networkx as nx
import alphashape

from statsmodels.stats.multitest import multipletests

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
import spatialdata_plot as sdp # type: ignore
from spida.pl import plot_categorical, categorical_scatter

from spida.utilities.tiling import create_hexagonal_grid

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Global matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update({
    # on-screen resolution and default canvas size
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    # fonttype 42 = TrueType, so text stays editable in exported PDF/PS files
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
    'image.interpolation': 'nearest',
    # defaults applied by plt.savefig
    'savefig.dpi': 300,
    'savefig.transparent': True,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.01,
})

# Forwarded as `rasterized=` to the plotting calls further down.
RASTERIZED = False

In [None]:
# parameters
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
sd_store = "/home/x-aklein2/projects/aklein/BICAN/data/zarr_store"

image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp_reg")
image_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the AnnData object; the bare `adata` on the last line shows its summary as the cell output.
adata = ad.read_h5ad(ad_path)
adata

# White Matter Calling

In [None]:
# transfer_genes = ["BCAS1", "OPALIN", "MOBP", "PLEKHH1"]
# Hexagon size handed to create_hexagonal_grid (NOTE(review): radius vs. width and units are defined by that helper — confirm)
hex_size = 30
# Passed as `overlap=` to create_hexagonal_grid
hex_overlap = 0
# `covariance_type` of the per-gene GaussianMixture
gmm_cov_type = "tied"
# Minimum fraction of transfer genes that must call a hexagon positive for it to be selected
gene_agreement_thr = 0.5
# Connected components made of fewer hexagons than this are discarded
dsc_comp_min_size = 2
# Number of mixture components fitted to each gene's per-hexagon counts
gmm_ncomp = 2
# Gene(s) whose transcripts drive the white-matter call
transfer_genes = ["BCAS1"]

In [None]:
# Unique values of the sample-describing columns in adata.obs
donors = adata.obs['donor'].unique().tolist()
replicates = adata.obs['replicate'].unique().tolist()
brain_regions = adata.obs['brain_region'].unique().tolist()
experiments = adata.obs['experiment'].unique().tolist()
# NOTE(review): looks like (donor, brain_region, replicate) tuples to exclude; not referenced
# anywhere in the visible part of this notebook — confirm before relying on it.
skip = [("UWA7648", "CAT", "ucsd"), ("UWA7648", "CAT", "salk")]

In [None]:
# Allow a single gene name to be given as a plain string: wrap it in a list
# so the loops over `transfer_genes` below iterate over genes, not characters.
if isinstance(transfer_genes, str):
    transfer_genes = [transfer_genes]

In [None]:
# Example section used for the white-matter figures.
# The three values are matched against adata.obs['donor'], ['brain_region'] and ['replicate'].
_donor = "UCI5224"
_brain_region = "CAB"
_replicate = "salk"

In [None]:
# Find which experiment / region the example section belongs to; the first
# matching cell is enough because only the two identifiers are needed.
_example_mask = (
    (adata.obs['donor'] == _donor)
    & (adata.obs['brain_region'] == _brain_region)
    & (adata.obs['replicate'] == _replicate)
)
_example_ids = adata.obs.loc[_example_mask, ['experiment', 'region']]
_experiment, _region = _example_ids.values[0]

# Read in the spatialdata object
zarr_path = f"{sd_store}/{_experiment}/{_region}"
sdata = sd.read_zarr(zarr_path)

In [None]:
# Keys of the SpatialData elements for this experiment / region
cs = "pixel"
ch = "DAPI"
image_key = f"default_{_experiment}_{_region}_z3"
points_key = f"proseg_fv38_{_experiment}_{_region}_transcripts"
shapes_key = f"proseg_fv38_{_experiment}_{_region}_polygons"
tab_key1 = "proseg_fv38_table_filt"
tab_key2 = "proseg_fv38_annot"

# Load in the Transcript Dataframe (materialise the lazy points element)
fts = sdata[points_key].compute()
fts = fts.reset_index()
fts['gene'] = fts['gene'].astype("category")

# Accept a bare gene name as well as a list of names
_marker_genes = [transfer_genes] if isinstance(transfer_genes, str) else list(transfer_genes)

gdf = gpd.GeoDataFrame(fts, geometry=gpd.points_from_xy(fts['x'], fts['y']))
# Keep the transcripts of every configured marker gene. The filter used to be
# hard-coded to "BCAS1", which left every other entry of `transfer_genes` with
# all-zero counts downstream. `.copy()` makes the subset an independent frame so
# that later column assignments do not write into a slice of the full table.
gdf = gdf.loc[gdf['gene'].isin(_marker_genes)].copy()

In [None]:
# Get the Distance Band Weights
# binary=False with alpha=-1 yields inverse-distance weights; p=1 selects the
# Manhattan metric; neighbours are searched within a band of 100 coordinate units.
_xy = [(_pt.x, _pt.y) for _pt in gdf.geometry]
kd = lps.cg.KDTree(_xy)
wnndb = lps.weights.DistanceBand(
    kd,
    threshold=100,
    p=1,
    binary=False,
    alpha=-1,
    ids=gdf.index.tolist(),
)

# Score points based on weights: sum of the weights to all neighbours in the band
bcas_weights = [np.sum(wnndb.weights[_pt_id]) for _pt_id in gdf.index]
gdf['bcas1_score'] = bcas_weights

# One 0/1 indicator column per gene, summed per hexagon later on
for _gene in transfer_genes:
    _gene_known = _gene in fts['gene'].cat.categories
    print(_gene, _gene_known)
    gdf[_gene] = (gdf['gene'] == _gene).astype(int)

In [None]:
# Create Grid and aggregate Gene Information in it
total_bounds = gdf.total_bounds  # (minx, miny, maxx, maxy)
grid = create_hexagonal_grid(total_bounds, hex_size, overlap=hex_overlap)

# String hexagon ids become the index so that the join result can be grouped by them
grid = grid.assign(hex_id=grid.index.astype(str)).set_index("hex_id")
joint_grid = gpd.sjoin(grid, gdf, how="inner", predicate="contains")

# Per-hexagon totals; hexagons containing no point are absent from the join and end up as NaN
_by_hex = joint_grid.groupby('hex_id')
grid['cell_count'] = _by_hex.size()
for _gene in transfer_genes:
    grid[f'{_gene}_count'] = _by_hex[_gene].sum()

In [None]:
# Plot Gene Counts in Grid: one panel per gene, colour scale clipped at the 99th percentile
ncols = len(transfer_genes)
fig, axes = plt.subplots(1, ncols, figsize=(5*ncols, 5))
# np.atleast_1d lets the single-panel case (axes is one Axes) share the loop below
for _gene, ax in zip(transfer_genes, np.atleast_1d(axes)):
    _count_col = f"{_gene}_count"
    ax.set_title(f'{_gene} Counts', fontsize=14)
    grid.plot(ax=ax, column=_count_col, vmax=np.nanpercentile(grid[_count_col], 99), cmap="YlOrRd",
            edgecolor='k', alpha=0.7, linewidth=0.5, rasterized=RASTERIZED, legend=True, legend_kwds={'shrink': 0.6}
            )
    ax.axis('off')
    # The colorbar just drawn is the most recently created Axes. fig.axes[1] is only
    # the colorbar when there is a single panel; with several genes it is the second subplot.
    cbar_ax = fig.axes[-1]
    cbar_ax.tick_params(labelsize=12)
plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_counts_grid.pdf", dpi=300, bbox_inches='tight')
plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_counts_grid.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot GMM fit for each gene
# For every gene a Gaussian mixture is fitted to the per-hexagon counts; the component
# with the highest mean count is relabelled 1 ("positive"), everything else 0.
ncols = len(transfer_genes)  # defined here so the cell does not depend on the previous plotting cell
fig, axes = plt.subplots(1, ncols, figsize=(5*ncols, 3))
for _gene, ax in zip(transfer_genes, np.atleast_1d(axes)):
    print(f"Fitting GMM for {_gene}...")
    _count_col = _gene + '_count'
    # hexagons without any transcript have NaN counts and are left out of the fit
    df_gene = grid[[_count_col]].dropna().copy()
    _counts = df_gene[[_count_col]].values
    gmm = GaussianMixture(n_components=gmm_ncomp, random_state=0, covariance_type=gmm_cov_type).fit(_counts)
    gene_prediction = gmm.predict(_counts)
    df_gene['predict'] = gene_prediction
    # Label of the component with the highest mean count. Selecting the column before
    # idxmax() avoids positional `[0]` indexing on a label-indexed Series.
    pred_val = df_gene.groupby("predict")[_count_col].mean().idxmax()
    df_gene['predict'] = (df_gene['predict'] == pred_val).astype(int)
    grid.loc[df_gene.index, _gene + '_predict'] = df_gene['predict']

    # for plotting
    sns.histplot(ax=ax, data=df_gene, x=_count_col, bins=30, hue="predict", palette="viridis", edgecolor='k', rasterized=RASTERIZED, alpha=0.7, stat="density")
    ax.set_xlabel(f'{_gene} Counts', fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title(f'{_gene} GMM Prediction', fontsize=14)
    # Resize tick labels without freezing their text (set_xticklabels(get_xticklabels())
    # pins the current strings and triggers a FixedLocator warning).
    ax.tick_params(axis='both', labelsize=12)

plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_gmm_fit.pdf", dpi=300, bbox_inches='tight')
plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_gmm_fit.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Plot predicted regions based on GMM: one map per gene, coloured by the 0/1 call
ncols = len(transfer_genes)
fig, axes = plt.subplots(1, ncols, figsize=(5*ncols, 5))
for i, _gene in enumerate(transfer_genes):
    # plt.subplots returns a bare Axes (not an array) when there is a single panel
    if ncols > 1:
        ax = axes[i]
    else:
        ax = axes
    ax.set_title(f'{_gene} Prediction', fontsize=14)
    _panel = grid.plot(
        ax=ax,
        column=f"{_gene}_predict",
        cmap="YlOrRd",
        edgecolor='k',
        alpha=0.7,
        linewidth=0.5,
        rasterized=RASTERIZED,
    )
    _panel.axis('off')
plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_gmm_pred.pdf", dpi=300, bbox_inches='tight')
plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_gmm_pred.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Combining gene predictions into overall region selection
# A hexagon is selected when at least `gene_agreement_thr` of the genes called it positive.
_genes = [transfer_genes] if isinstance(transfer_genes, str) else list(transfer_genes)
_predict_cols = [f"{_gene}_predict" for _gene in _genes]

# Fraction of genes calling each hexagon positive (a missing prediction counts as "not positive").
# Computed column-wise so the rows keep the order of `grid`; iterating over Python sets
# would give an order that changes between kernel sessions.
_agreement = grid[_predict_cols].eq(1).mean(axis=1)
# Same layout as before: one column named 0, only hexagons with at least one positive call
df_hids = _agreement[_agreement > 0].to_frame(0)
chosen_cells = df_hids[df_hids[0] >= gene_agreement_thr].index
chosen_cells = grid.loc[list(chosen_cells)]
print(f"Chosen Hexes: {len(chosen_cells)}")

# Plot selection
fig, ax = plt.subplots()
grid.plot(ax=ax, color='lightgrey', edgecolor='k', alpha=0.5)
# an empty selection is reported above; skip the overlay instead of plotting nothing
if len(chosen_cells) > 0:
    chosen_cells.plot(ax=ax, color='red', edgecolor='k', alpha=0.8)

In [None]:
# Join all chosen cells into geometries based on connectivity
# Hexagons sharing an edge or a vertex (Queen contiguity) are linked; every connected
# component that is large enough is replaced by its convex hull.
chosen_cells = chosen_cells.reset_index()
W = lps.weights.Queen.from_dataframe(chosen_cells)
G = W.to_networkx()
connected_components = list(nx.connected_components(G))
disconnected_comp = list(connected_components)

# -1 marks hexagons whose component was too small to keep
chosen_cells['comp'] = -1
geoms = []
for i, disc in enumerate(disconnected_comp):
    if len(disc) >= dsc_comp_min_size:
        chosen_cells.loc[list(disc), "comp"] = i
        temp = chosen_cells[chosen_cells.index.isin(disc)]
        geoms.append(temp.union_all().convex_hull)

# Merge overlapping hulls into a single (multi)geometry
union = gpd.GeoDataFrame(geometry=geoms).union_all()
gdf_geoms = gpd.GeoDataFrame(geometry=[union])

In [None]:
# Save Geoms:
# Split the merged geometry into one row per polygon and register it on the in-memory
# SpatialData object (this cell does not write anything to disk).
gdf_geoms = gdf_geoms.explode()

geoms_key = "wm_regions"
sdata[geoms_key] = sd.models.ShapesModel().parse(gdf_geoms)

# The regions were built from transcript coordinates, so they reuse the transcripts'
# transformation for each coordinate system.
for _target_cs in ("pixel", "global"):
    _points_tf = sd.transformations.get_transformation(sdata[points_key], to_coordinate_system=_target_cs)
    sd.transformations.set_transformation(
        sdata[geoms_key],
        _points_tf,
        to_coordinate_system=_target_cs,
    )

In [None]:
# final Plots: called white-matter regions outlined on top of the DAPI image
fig, ax = plt.subplots(figsize=(5,5))
try:
    (
        sdata.pl.render_images(image_key, channel="DAPI", cmap="gray")
        .pl.render_shapes(geoms_key, color="none", outline_color="red", outline_width=2, outline_alpha=1, fill_alpha=0.5)
        .pl.show(ax=ax, coordinate_systems=cs)
    )
    ax.set_title(f"{_donor} {_brain_region} {_replicate} - WM Regions")
    plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_final_region_DAPI.pdf", dpi=300, bbox_inches='tight')
    plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_bcas1_final_region_DAPI.png", dpi=300, bbox_inches='tight')
    plt.show()
except ValueError as e:
    # Report the section that failed. The message used to format `_i`, which is not
    # defined at this point of the notebook and turned the handler into a NameError.
    print(f"Could not plot final figure for {_donor} {_brain_region} {_replicate}: {e}")
finally:
    # release the figure whether or not rendering succeeded
    plt.close(fig)


# Striosome - Matrix Calling

In [None]:
# I don't really think that I need to expand upon this too much, it is KNN on MSN subtypes --> outlier removal via KNN graph within each cluster --> alphashape for regions.
# Do I want to plot a couple of examples though?

In [None]:
# Region geometries stored as a GeoParquet table; the cells below filter it by the
# 'donor', 'brain_region', 'lab' and 'type' columns.
geom_store_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/region_geometries_cps.parquet"
geoms_all = gpd.read_parquet(geom_store_path)

In [None]:
# Example section used for the striosome / matrix figures.
# Note: this rebinds the variables of the white-matter example (replicate is now "ucsd").
_donor = "UCI5224"
_brain_region = "CAB"
_replicate = "ucsd"

In [None]:
# Region polygons of the example section, split by compartment type
_geom_mask = (
    (geoms_all['donor'] == _donor)
    & (geoms_all['brain_region'] == _brain_region)
    & (geoms_all['lab'] == _replicate)
)
sub_geoms = geoms_all.loc[_geom_mask]
str_geoms = sub_geoms.loc[sub_geoms['type'] == "Striosome"]
mat_geoms = sub_geoms.loc[sub_geoms['type'] == "Matrix"]

# Cells of the same section, plus a point GeoDataFrame built from their centre coordinates
_cell_mask = (
    (adata.obs['donor'] == _donor)
    & (adata.obs['brain_region'] == _brain_region)
    & (adata.obs['replicate'] == _replicate)
)
adata_ss = adata[_cell_mask].copy()
_cell_points = gpd.points_from_xy(adata_ss.obs['CENTER_X'], adata_ss.obs['CENTER_Y'])
cells = gpd.GeoDataFrame(adata_ss.obs, geometry=_cell_points)

In [None]:
# Cells coloured by MS_NORM with the striosome (red) and matrix (blue) regions on top.
# Saving and showing are commented out at the bottom, so the figure is built and then discarded.
fig, ax = plt.subplots(figsize=(8,8))

# light-grey layer with every cell, then the same cells coloured by score
cells.plot(ax=ax, color='lightgrey', markersize=3, edgecolor='none')
ax.axis("off")
cells.plot(
    ax=ax,
    column="MS_NORM",
    cmap='coolwarm_r',
    edgecolor='none',
    markersize=3,
    alpha=0.5,
    legend_kwds={'shrink': 0.6},
    legend=True,
    rasterized=RASTERIZED,
)
ax.axis("off")

# the colorbar is the second Axes of this single-panel figure
cbar_ax = fig.axes[1]
cbar_ax.set_title("Mat-Str Score", fontsize=12)
cbar_ax.tick_params(labelsize=10)

# faint fills first, crisp black outlines second
for _regions, _fill in ((str_geoms, 'red'), (mat_geoms, 'blue')):
    _regions.plot(ax=ax, color=_fill, edgecolor='none', alpha=0.05, rasterized=RASTERIZED)
for _regions in (str_geoms, mat_geoms):
    _regions.plot(ax=ax, color='none', edgecolor='black', linewidth=1, alpha=1, rasterized=RASTERIZED)

# ax.legend()
ax.set_title(f"{_donor} {_brain_region} {_replicate}", fontsize=14)

# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_ms_score_regions.pdf", dpi=300, bbox_inches='tight')
# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_ms_score_regions.png", dpi=300, bbox_inches='tight')
# plt.show()
plt.close()

In [None]:
# MSN groups of the example section drawn over a light-grey scatter of the same cells.
# Saving and showing are commented out at the bottom, so the figure is built and then discarded.
fig, ax = plt.subplots(figsize=(8,8))

_background_kws = dict(
    data=adata_ss,
    coord_base="spatial",
    max_points=None,
    hue=None,
    scatter_kws=dict(color='lightgrey'),
    rasterized=RASTERIZED,
    axis_format=None,
    ax=ax,
)
categorical_scatter(**_background_kws)

_groups_kws = dict(
    cluster_col="MSN_Groups",
    coord_base="spatial",
    legend_kws=dict(title="MSN Groups", fontsize=8, title_fontsize=12),
    axis_format=None,
    show=False,
    ax=ax,
    show_legend=False,
    # scatter_kws=dict(s=2, linewidth=0, edgecolor='black'),
)
plot_categorical(adata_ss, **_groups_kws)

# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_msn_groups.pdf", dpi=300, bbox_inches='tight')
# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_msn_groups.png", dpi=300, bbox_inches='tight')
# plt.show()
plt.close()

In [None]:
# Same MSN-group map as the previous figure (legend left at its default here), with the
# striosome, matrix and white-matter regions overlaid.
# Saving and showing are commented out at the bottom, so the figure is built and then discarded.
fig, ax = plt.subplots(figsize=(8,8))

_background_kws = dict(
    data=adata_ss,
    coord_base="spatial",
    max_points=None,
    hue=None,
    scatter_kws=dict(color='lightgrey'),
    rasterized=RASTERIZED,
    axis_format=None,
    ax=ax,
)
categorical_scatter(**_background_kws)

_groups_kws = dict(
    cluster_col="MSN_Groups",
    coord_base="spatial",
    legend_kws=dict(title="MSN Groups", fontsize=8, title_fontsize=12),
    axis_format=None,
    show=False,
    ax=ax,
    # scatter_kws=dict(s=2, linewidth=0, edgecolor='black'),
)
plot_categorical(adata_ss, **_groups_kws)

# faint fills first, black outlines second, white matter last
for _regions, _fill in ((str_geoms, 'red'), (mat_geoms, 'blue')):
    _regions.plot(ax=ax, color=_fill, edgecolor='none', alpha=0.05, rasterized=RASTERIZED)
for _regions in (str_geoms, mat_geoms):
    _regions.plot(ax=ax, color='none', edgecolor='black', linewidth=1, alpha=1, rasterized=RASTERIZED)
wm_geoms = sub_geoms[sub_geoms['type'] == "White_Matter"]
wm_geoms.plot(ax=ax, color='lightgrey', edgecolor='none', linewidth=1, alpha=0.05, rasterized=RASTERIZED)

# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_msn_groups_with_regions.pdf", dpi=300, bbox_inches='tight')
# plt.savefig(image_path / f"{_donor}_{_brain_region}_{_replicate}_msn_groups_with_regions.png", dpi=300, bbox_inches='tight')
# plt.show()
plt.close()

# Compartment Enrichment 

## Helper Functions

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type',
                           figsize=(12, 8), title=None, colors=None,
                           show_percentages=True, rotation=45, rasterized=False,
                           legend_threshold=5.0, text_threshold=2.0,
                           legend_fontsize=12, def_fontsize=12, title_fontsize=12,
                           xlabel=None,
                        ):
    """
    Create a stacked bar chart showing cell type percentages across groups.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input dataframe containing the data
    group_column : str
        Column name to group by (x-axis categories)
    cell_type_column : str, default 'cell_type'
        Column name containing cell type information
    figsize : tuple, default (12, 8)
        Figure size (width, height)
    title : str, optional
        Chart title. If None, a title is derived from ``group_column``
    colors : list or dict, optional
        Colors for cell types. A dict maps cell type -> color (missing types
        are drawn gray); a list is used in column order and cycled if it is
        shorter than the number of cell types. If None, the seaborn "Set3"
        palette is used
    show_percentages : bool, default True
        Whether to show percentage labels on bars
    rotation : int, default 45
        Rotation angle for x-axis labels
    rasterized : bool, default False
        Forwarded as ``rasterized=`` to the bars, the labels and the title
    legend_threshold : float, default 5.0
        Minimum percentage (maximum over groups) a cell type must reach to be
        listed in the legend
    text_threshold : float, default 2.0
        A segment is labelled only if its percentage is above this value
    legend_fontsize : int, default 12
        Font size of the legend entries
    def_fontsize : int, default 12
        Font size of axis labels, tick labels and segment labels
    title_fontsize : int, default 12
        Font size of the title
    xlabel : str, optional
        x-axis label. If None, it is derived from ``group_column``

    Returns:
    --------
    fig, ax : matplotlib figure and axis objects
    """

    # Calculate cell type counts and percentages (rows: groups, columns: cell types)
    counts = df.groupby([group_column, cell_type_column]).size().unstack(fill_value=0)
    totals = counts.sum(axis=1)
    # A group without any cell would give 0/0 = NaN; draw it as an empty bar instead
    percentages = (counts.div(totals, axis=0) * 100).fillna(0.0)

    # Set up colors
    n_cell_types = len(counts.columns)
    if colors is None:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]

    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)

    # Cell types whose largest share over all groups reaches the legend threshold
    max_percentages = percentages.max(axis=0)
    legend_cell_types = set(max_percentages[max_percentages >= legend_threshold].index)

    # Bars and their text labels share explicit integer positions. Placing the bars at
    # the raw group values and the labels at 0..n-1 misplaces labels for numeric groups.
    x_pos = np.arange(len(percentages))
    bottom = np.zeros(len(percentages))
    n_legend_entries = 0

    for i, cell_type in enumerate(percentages.columns):
        heights = percentages[cell_type].to_numpy(dtype=float)

        # Only include in legend if it meets the threshold
        in_legend = cell_type in legend_cell_types
        n_legend_entries += int(in_legend)

        ax.bar(x_pos, heights, bottom=bottom,
               label=cell_type if in_legend else None,
               color=colors[i % len(colors)],  # cycle rather than fail on a short list
               rasterized=rasterized)

        # Add percentage labels if requested, centred vertically in each segment
        if show_percentages:
            for x, base, value in zip(x_pos, bottom, heights):
                if value > text_threshold:
                    ax.text(x, base + value/2, f'{value:.1f}%',
                            ha='center', va='center', fontsize=def_fontsize, fontweight='bold',
                            rasterized=rasterized)

        # new array each time so previously drawn bars never share a mutable base
        bottom = bottom + heights

    # Customize the plot
    if xlabel is None:
        ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=def_fontsize)
    else:
        ax.set_xlabel(xlabel, fontsize=def_fontsize)
    ax.set_ylabel('Percentage (%)', fontsize=def_fontsize)
    ax.set_ylim(0, 100)

    if title:
        ax.set_title(title, fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}',
                    fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)

    # Group names as tick labels; set on `ax` itself rather than the pyplot "current axes"
    ax.set_xticks(x_pos)
    ax.set_xticklabels([str(group) for group in percentages.index],
                       rotation=rotation,
                       ha='right' if rotation != 0 else 'center',
                       fontsize=def_fontsize)

    # Add legend (bars drawn with label=None are skipped automatically)
    if 0 < n_legend_entries <= 20:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize)
    elif n_legend_entries > 20:
        # many entries: use two columns
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize, ncol=2)

    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    # plt.tight_layout()

    return fig, ax

# Example usage:
# Assuming you have a dataframe 'df' with columns 'region' and 'cell_type'
# fig, ax = create_stacked_bar_chart(df, group_column='region', cell_type_column='cell_type')
# plt.show()

# Alternative simpler version for quick use:
def quick_stacked_bar(df, group_col, cell_type_col='cell_type'):
    """Minimal stacked-percentage bar chart of cell types per group.

    Counts the rows of ``df`` per (``group_col``, ``cell_type_col``) pair,
    converts every group's counts to percentages and draws them with the
    pandas plotting API. Returns the matplotlib Axes.
    """
    pair_sizes = df.groupby([group_col, cell_type_col]).size()
    count_table = pair_sizes.unstack(fill_value=0)
    group_totals = count_table.sum(axis=1)
    share_table = count_table.div(group_totals, axis=0) * 100

    ax = share_table.plot.bar(
        stacked=True,
        figsize=(10, 6),
        colormap='Set3',
        rot=45,
    )
    ax.set_ylabel('Percentage (%)')
    ax.set_title(f'Cell Type Distribution by {group_col}')
    plt.tight_layout()
    return ax
    
def _get_palette(key):
    """Return the palette stored in ``adata.uns`` for an enrichment-table key.

    The part of ``key`` before the first underscore selects the palette:
    ``"subclass"`` maps to ``Subclass_palette``, anything else to
    ``Group_palette``.
    """
    level = key.split("_")[0]
    # Disabled alternative kept for reference:
    # elif key.split("_")[1] != "white":
    #     return adata.uns['MSN_Groups_palette']
    palette_name = 'Subclass_palette' if level == "subclass" else 'Group_palette'
    return adata.uns[palette_name]


In [None]:
def dl_tau2(yi, vi):
    """
    DerSimonian–Laird estimator of the between-study variance (tau squared).

    Parameters
    ----------
    yi : array-like
        Per-study effect sizes.
    vi : array-like or scalar
        Per-study sampling variances. A scalar is broadcast to every study.

    Returns
    -------
    tau2 : float
        Between-study variance estimate, truncated at 0.
    Q : float
        Cochran's Q heterogeneity statistic.
    """
    yi = np.atleast_1d(np.asarray(yi, dtype=float))
    # Broadcast so that a scalar variance means "the same variance for every study".
    # Without this, np.sum(w) is 1 instead of k and the weighted mean turns into a sum.
    vi = np.broadcast_to(np.asarray(vi, dtype=float), yi.shape)
    w = 1.0 / vi
    ybar = np.sum(w * yi) / np.sum(w)
    Q = np.sum(w * (yi - ybar) ** 2)
    k = yi.size
    c = np.sum(w) - np.sum(w ** 2) / np.sum(w)
    # c is 0 for a single study; tau2 is then not estimable and set to 0
    tau2 = max(0.0, (Q - (k - 1)) / c) if c > 0 else 0.0
    return tau2, Q

def re_meta(yi, vi):
    """
    Run a random-effects meta-analysis given effect sizes yi and variances vi.

    Parameters
    ----------
    yi : array-like
        Per-study effect sizes.
    vi : array-like or scalar
        Per-study sampling variances. A scalar is broadcast to every study.

    Returns
    -------
    dict
        Pooled mean (``mu``), its standard error (``se``), z-statistic (``z``),
        two-sided p-value (``p``), 95% confidence bounds (``ci_lb``, ``ci_ub``),
        ``tau2``, ``Q``, number of studies ``k`` and ``I2`` in percent.
    """
    from scipy.stats import norm

    yi = np.atleast_1d(np.asarray(yi, dtype=float))
    # Same broadcasting as in dl_tau2: every study needs its own weight
    vi = np.broadcast_to(np.asarray(vi, dtype=float), yi.shape)

    tau2, Q = dl_tau2(yi, vi)
    # random-effects weights: within-study plus between-study variance
    w_star = 1.0 / (vi + tau2)
    mu = np.sum(w_star * yi) / np.sum(w_star)
    se = np.sqrt(1.0 / np.sum(w_star))
    z = mu / se if se > 0 else np.nan
    p = 2 * norm.sf(abs(z)) if np.isfinite(z) else np.nan
    ci_lb, ci_ub = mu - 1.96 * se, mu + 1.96 * se
    k = yi.size
    I2 = max(0.0, (Q - (k - 1)) / Q) * 100 if (k > 1 and Q > 0) else 0.0
    return dict(mu=mu, se=se, z=z, p=p,
                ci_lb=ci_lb, ci_ub=ci_ub,
                tau2=tau2, Q=Q, k=k, I2=I2)

## Do

In [None]:
# Directory holding the per-sample `ms_composition_<level>_<compartment>*.csv` tables that are globbed below
DIR = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/ms_enrichment")

In [None]:
# Panel titles, keyed by the "<level>_<compartment>" keys of `agg_tables`
naming_map = {
    "subclass_white_matter": "Subclass - White Matter",
    "subclass_matrix": "Subclass - Matrix",
    "subclass_striosome": "Subclass - Striosome",
    "group_white_matter": "Group - White Matter",
    "group_matrix": "Group - Matrix",
    "group_striosome": "Group - Striosome",
}

In [None]:
# For every annotation level x compartment: pool the per-sample log2 fold changes of each
# cell type with a random-effects meta-analysis and FDR-correct the p-values.
agg_tables = {}
for _level in ["subclass", "group"]:
    for _compartment in ["white_matter", "matrix", "striosome"]:
        df_list = []
        # sorted() makes the concatenation order independent of the filesystem
        for _file in sorted(DIR.glob(f"ms_composition_{_level}_{_compartment}*.csv")):
            # File stems end in "<donor>_<region>_<lab>". Separate names are used so the
            # notebook-level `_donor` / `_region` of the example section are not overwritten.
            _file_donor, _file_region, _file_lab = _file.stem.split("_")[-3:]
            df = pd.read_csv(_file)
            df['donor'] = _file_donor
            df['region'] = _file_region
            df['lab'] = _file_lab
            df['id'] = f"{_file_donor}|{_file_region}|{_file_lab}"
            df_list.append(df)
        if not df_list:
            raise FileNotFoundError(
                f"No files matching ms_composition_{_level}_{_compartment}*.csv in {DIR}"
            )
        df_ms = pd.concat(df_list, axis=0)

        rows = []
        for cat, df_cat in df_ms.groupby('cell_type'):
            # df_cat['var_null'] = df_cat['std_null_count'] ** 2
            _effects = df_cat['log_2FC'].to_numpy(dtype=float)
            # Unit within-sample variance for every sample, passed as a vector: a bare scalar
            # gives a total weight of 1 instead of k, so the pooled estimate becomes the sum
            # of the effects with a standard error fixed at 1.
            res = re_meta(_effects, np.ones_like(_effects))
            res['cell_type'] = cat
            rows.append(res)
        df_rows = pd.DataFrame(rows)
        df_rows["p_fdr"] = multipletests(df_rows["p"], method="fdr_bh")[1]
        agg_tables[f"{_level}_{_compartment}"] = df_rows

### Individual plots

In [None]:
# for i, (_key, df_ms) in enumerate(agg_tables.items()):
#     fontsize= 10
#     if i < 3: 
#         height=8 
#     else: 
#         height=12


#     fig, ax = plt.subplots(figsize=(6, height))
#     palette = _get_palette(_key)

#     rightmost = max(df_ms['ci_ub'])
#     leftmost = min(df_ms['ci_lb'])
#     star_x = rightmost + (rightmost - leftmost) / 15
#     labels = []
#     for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#         point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         labels.append(_row['cell_type'])
#         stars = ""
#         if _row['p_fdr'] < 0.01: 
#             stars += "*"
#         if _row['p_fdr'] < 0.001: 
#             stars += "*"
#         if _row['p_fdr'] < 0.0001: 
#             stars += "*"
#         ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#         # print(stars)

#     ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
#     ax.set_yticks(np.arange(0, len(labels)))
#     ax.set_yticklabels(labels, fontsize=fontsize)
#     ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
#     ax.set_title(naming_map[_key], fontsize=14)
#     plt.tight_layout()
#     # plt.savefig(image_path / f"ms_enrichment_PI_{_key}.pdf", dpi=300, bbox_inches="tight")
#     # plt.savefig(image_path / f"ms_enrichment_PI_{_key}.png", dpi=300, bbox_inches="tight")
#     plt.show()
#     plt.close()

In [None]:
# One forest plot per table, restricted to cell types whose 95% CI excludes zero.
# Each row shows the CI of the pooled log2 fold change; stars encode the FDR-adjusted p-value.
for _key, df_ms in agg_tables.items():
    # subclass-level tables get a shorter panel with larger labels (keyed on the table
    # name rather than on its position in the dict)
    if _key.split("_")[0] == "subclass":
        height, fontsize = 4, 14
    else:
        height, fontsize = 5, 12

    palette = _get_palette(_key)

    df_ms = df_ms[(df_ms['ci_lb'] > 0) | (df_ms['ci_ub'] < 0)]
    df_ms = df_ms[~df_ms['cell_type'].isin(["unknown", "VLMC"])]
    if df_ms.empty:
        # nothing to draw: max()/min() below would raise on an empty table
        print(f"No cell type with a confidence interval excluding zero for {_key}; skipping plot")
        continue

    fig, ax = plt.subplots(figsize=(6, height))

    rightmost = df_ms['ci_ub'].max()
    leftmost = df_ms['ci_lb'].min()
    # stars sit slightly to the right of the widest interval
    star_x = rightmost + (rightmost - leftmost) / 15
    labels = []
    for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()):
        _color = palette[_row['cell_type']]
        _ci = [_row['ci_lb'], _row['ci_ub']]
        ax.scatter(_ci, [idx, idx], color=_color)
        ax.plot(_ci, [idx, idx], color=_color)
        labels.append(_row['cell_type'])
        # one star per threshold passed: <0.01 *, <0.001 **, <0.0001 ***
        stars = "*" * int(sum(_row['p_fdr'] < _cut for _cut in (0.01, 0.001, 0.0001)))
        ax.text(star_x, idx-0.1, stars, fontsize=12, va='center')

    ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
    ax.set_yticks(np.arange(0, len(labels)))
    ax.set_yticklabels(labels, fontsize=fontsize)
    ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
    ax.set_title(naming_map[_key], fontsize=24)
    plt.tight_layout()
    plt.savefig(image_path / f"ms_enrichment_PI_{_key}_top.pdf", dpi=300, bbox_inches="tight")
    plt.savefig(image_path / f"ms_enrichment_PI_{_key}_top.png", dpi=300, bbox_inches="tight")
    # plt.show()
    plt.close(fig)

### Group plots (of multiple images together)

In [None]:
# fig, axes = plt.subplots(1, 3, figsize=(15,5))
# for i, (_key, df_ms) in enumerate(agg_tables.items()):
#     if i >= 3:
#         break
#     # print(f"{_key}: {_value.shape}")
#     ax = axes[i]
#     palette = _get_palette(_key)

#     rightmost = max(df_ms['ci_ub'])
#     leftmost = min(df_ms['ci_lb'])
#     star_x = rightmost + (rightmost - leftmost) / 15
#     labels = []
#     for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#         point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         labels.append(_row['cell_type'])
#         stars = ""
#         if _row['p_fdr'] < 0.01: 
#             stars += "*"
#         if _row['p_fdr'] < 0.001: 
#             stars += "*"
#         if _row['p_fdr'] < 0.0001: 
#             stars += "*"
#         ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#         # print(stars)

#     ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
#     ax.set_yticks(np.arange(0, len(labels)))
#     ax.set_yticklabels(labels, fontsize=8)
#     ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
#     ax.set_title(naming_map[_key], fontsize=14)
# plt.tight_layout()
# # plt.savefig(image_path / "ms_enrichment_PI_sub.pdf", dpi=300, bbox_inches="tight")
# # plt.savefig(image_path / "ms_enrichment_PI_sub.png", dpi=300, bbox_inches="tight")
# # plt.show()
# plt.close()

In [None]:
# Forest plots for the tables at positions 3..5 of `agg_tables`, one panel each.
# Saving and showing are commented out at the bottom, so the figure is built and then discarded.
fig, axes = plt.subplots(1, 3, figsize=(15,5))
_star_cutoffs = (0.01, 0.001, 0.0001)
for i, (_key, df_ms) in enumerate(agg_tables.items()):
    # the first three tables are not drawn here
    if i >= 3:
        # print(f"{_key}: {_value.shape}")
        ax = axes[i % 3]
        palette = _get_palette(_key)

        rightmost = max(df_ms['ci_ub'])
        leftmost = min(df_ms['ci_lb'])
        star_x = rightmost + (rightmost - leftmost) / 15

        # rows ordered by lower CI bound, bottom to top
        _ordered = df_ms.sort_values(by="ci_lb")
        labels = _ordered['cell_type'].tolist()
        for idx, _row in enumerate(_ordered.to_dict("records")):
            _ci = [_row['ci_lb'], _row['ci_ub']]
            _color = palette[_row['cell_type']]
            point = ax.scatter(_ci, [idx, idx], color=_color)
            ax.plot(_ci, [idx, idx], color=_color)
            # one star per significance cutoff passed
            stars = "*" * int(sum(_row['p_fdr'] < _cut for _cut in _star_cutoffs))
            ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
            # print(stars)

        ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
        ax.set_yticks(np.arange(0, len(labels)))
        ax.set_yticklabels(labels, fontsize=6)
        ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
        ax.set_title(naming_map[_key], fontsize=14)
plt.tight_layout()
# plt.savefig(image_path / "ms_enrichment_PI_gr.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "ms_enrichment_PI_gr.png", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# All tables of `agg_tables` in a 2 x 3 grid, filled row by row in dict order.
# Saving and showing are commented out at the bottom, so the figure is built and then discarded.
fig, axes = plt.subplots(2, 3, figsize=(15,10))
_star_cutoffs = (0.01, 0.001, 0.0001)
for i, (_key, df_ms) in enumerate(agg_tables.items()):
    # print(f"{_key}: {_value.shape}")
    _grid_row, _grid_col = divmod(i, 3)
    ax = axes[_grid_row, _grid_col]
    palette = _get_palette(_key)

    rightmost = max(df_ms['ci_ub'])
    leftmost = min(df_ms['ci_lb'])
    star_x = rightmost + (rightmost - leftmost) / 15

    # rows ordered by lower CI bound, bottom to top
    _ordered = df_ms.sort_values(by="ci_lb")
    labels = _ordered['cell_type'].tolist()
    for idx, _row in enumerate(_ordered.to_dict("records")):
        _ci = [_row['ci_lb'], _row['ci_ub']]
        _color = palette[_row['cell_type']]
        point = ax.scatter(_ci, [idx, idx], color=_color)
        ax.plot(_ci, [idx, idx], color=_color)
        # one star per significance cutoff passed
        stars = "*" * int(sum(_row['p_fdr'] < _cut for _cut in _star_cutoffs))
        ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
        # print(stars)

    ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
    ax.set_yticks(np.arange(0, len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
    ax.set_title(naming_map[_key], fontsize=14)
plt.tight_layout()
# plt.savefig(image_path / "ms_enrichment_PI_all.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "ms_enrichment_PI_all.png", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# ## White Matter only Group + Subclass
# fig, axes = plt.subplots(1, 2, figsize=(8,4))
# keys = ["subclass_white_matter", "group_white_matter"]
# for i, _key in enumerate(keys):
#     df_ms = agg_tables[_key]
#     # print(f"{_key}: {_value.shape}")
#     ax = axes[i]
#     palette = _get_palette(_key)

#     rightmost = max(df_ms['ci_ub'])
#     leftmost = min(df_ms['ci_lb'])
#     star_x = rightmost + (rightmost - leftmost) / 10
#     labels = []
#     for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#         point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         labels.append(_row['cell_type'])
#         stars = ""
#         if _row['p_fdr'] < 0.01: 
#             stars += "*"
#         if _row['p_fdr'] < 0.001: 
#             stars += "*"
#         if _row['p_fdr'] < 0.0001: 
#             stars += "*"
#         ax.text(star_x, idx-1, stars)
#         # print(stars)

#     ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
#     ax.set_yticks(np.arange(0, len(labels)))
#     ax.set_yticklabels(labels, fontsize=6)
#     ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
#     ax.set_title(naming_map[_key])
# plt.tight_layout()
# # plt.savefig(image_path / "comp_enr_PI_WM.pdf", dpi=300, bbox_inches="tight")
# # plt.savefig(image_path / "comp_enr_PI_WM.png", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

In [None]:
# ## Subclass level: Striosome + Matrix
# fig, axes = plt.subplots(1, 2, figsize=(10,4))
# keys = ["subclass_striosome", "subclass_matrix"]
# for i, _key in enumerate(keys):
#     df_ms = agg_tables[_key]
#     # print(f"{_key}: {_value.shape}")
#     ax = axes[i]
#     palette = _get_palette(_key)

#     rightmost = max(df_ms['ci_ub'])
#     leftmost = min(df_ms['ci_lb'])
#     star_x = rightmost + (rightmost - leftmost) / 15
#     labels = []
#     for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#         point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         labels.append(_row['cell_type'])
#         stars = ""
#         if _row['p_fdr'] < 0.01: 
#             stars += "*"
#         if _row['p_fdr'] < 0.001: 
#             stars += "*"
#         if _row['p_fdr'] < 0.0001: 
#             stars += "*"
#         ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#         # print(stars)

#     ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
#     ax.set_yticks(np.arange(0, len(labels)))
#     ax.set_yticklabels(labels, fontsize=6)
#     ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
#     ax.set_title(naming_map[_key], fontsize=14)
# plt.tight_layout()
# plt.savefig(image_path / "comp_enr_PI_MS_sub.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "comp_enr_PI_MS_sub.png", dpi=300, bbox_inches="tight")
# # plt.savefig(image_path / "comp_enr_PI_MS_sub.svg", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

In [None]:
# ## Group level: Striosome + Matrix
# fig, axes = plt.subplots(1, 2, figsize=(10,5))
# keys = ["group_striosome", "group_matrix"]
# for i, _key in enumerate(keys):
#     df_ms = agg_tables[_key]
#     # print(f"{_key}: {_value.shape}")
#     ax = axes[i]
#     palette = _get_palette(_key)

#     rightmost = max(df_ms['ci_ub'])
#     leftmost = min(df_ms['ci_lb'])
#     star_x = rightmost + (rightmost - leftmost) / 15
#     labels = []
#     for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#         point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#         labels.append(_row['cell_type'])
#         stars = ""
#         if _row['p_fdr'] < 0.01: 
#             stars += "*"
#         if _row['p_fdr'] < 0.001: 
#             stars += "*"
#         if _row['p_fdr'] < 0.0001: 
#             stars += "*"
#         ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#         # print(stars)

#     ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
#     ax.set_yticks(np.arange(0, len(labels)))
#     ax.set_yticklabels(labels, fontsize=6)
#     ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
#     ax.set_title(naming_map[_key], fontsize=14)
# plt.tight_layout()
# plt.savefig(image_path / "comp_enr_PI_MS_gr.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "comp_enr_PI_MS_gr.png", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

In [None]:
# ## Group level, Striosome: keep only cell types whose P.I. does not include 0
# fig, ax = plt.subplots(1, 1, figsize=(4,3))
# _key = "group_striosome"
# df_ms = agg_tables[_key]
# df_ms = df_ms[(df_ms['ci_lb'] > 0) | (df_ms['ci_ub'] < 0)]
# # df_ms = df_ms[df_ms['p_fdr'] < 0.01]
# # print(f"{_key}: {_value.shape}")
# palette = _get_palette(_key)
# rightmost = max(df_ms['ci_ub'])
# leftmost = min(df_ms['ci_lb'])
# star_x = rightmost + (rightmost - leftmost) / 15
# labels = []
# for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#     point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#     ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#     labels.append(_row['cell_type'])
#     stars = ""
#     if _row['p_fdr'] < 0.01: 
#         stars += "*"
#     if _row['p_fdr'] < 0.001: 
#         stars += "*"
#     if _row['p_fdr'] < 0.0001: 
#         stars += "*"
#     ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#     # print(stars)

# ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
# ax.set_yticks(np.arange(0, len(labels)))
# ax.set_yticklabels(labels, fontsize=6)
# ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
# ax.set_title(naming_map[_key])

# plt.tight_layout()
# plt.savefig(image_path / "comp_enr_PI_S_top_gr.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "comp_enr_PI_S_top_gr.png", dpi=300, bbox_inches="tight")
# # plt.savefig(image_path / "comp_enr_PI_S_top_gr.svg", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

In [None]:
# ## Group level, Matrix: keep only cell types whose P.I. does not include 0
# fig, ax = plt.subplots(1, 1, figsize=(4,3))
# _key = "group_matrix"
# df_ms = agg_tables[_key]
# df_ms = df_ms[(df_ms['ci_lb'] > 0) | (df_ms['ci_ub'] < 0)]
# # df_ms = df_ms[df_ms['p_fdr'] < 0.05]
# # print(f"{_key}: {_value.shape}")
# palette = _get_palette(_key)
# rightmost = max(df_ms['ci_ub'])
# leftmost = min(df_ms['ci_lb'])
# star_x = rightmost + (rightmost - leftmost) / 15
# labels = []
# for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
#     point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#     ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
#     labels.append(_row['cell_type'])
#     stars = ""
#     if _row['p_fdr'] < 0.01: 
#         stars += "*"
#     if _row['p_fdr'] < 0.001: 
#         stars += "*"
#     if _row['p_fdr'] < 0.0001: 
#         stars += "*"
#     ax.text(star_x, idx-0.5, stars, fontsize=12, va='center')
#     # print(stars)

# ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
# ax.set_yticks(np.arange(0, len(labels)))
# ax.set_yticklabels(labels, fontsize=6)
# ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
# ax.set_title(naming_map[_key])

# plt.tight_layout()
# plt.savefig(image_path / "comp_enr_PI_M_top_gr.pdf", dpi=300, bbox_inches="tight")
# plt.savefig(image_path / "comp_enr_PI_M_top_gr.png", dpi=300, bbox_inches="tight")
# # plt.savefig(image_path / "comp_enr_PI_M_top_gr.svg", dpi=300, bbox_inches="tight")
# plt.show()
# plt.close()

# MS Composition 

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type', 
                           figsize=(12, 8), title=None, colors=None, 
                           show_percentages=True, rotation=45, rasterized=False,
                           legend_threshold=5.0, text_threshold=2.0,
                           legend_fontsize=12, def_fontsize=12, title_fontsize=12,
                           xlabel=None, show_legend=True,
                        ):
    """
    Create a stacked bar chart showing cell type percentages across groups.

    Rows of ``df`` are counted per (group, cell type) pair; each group's counts
    are normalised to percentages and drawn as one stacked bar.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input dataframe containing the data (one row per cell)
    group_column : str
        Column name to group by (x-axis categories)
    cell_type_column : str, default 'cell_type'
        Column name containing cell type information
    figsize : tuple, default (12, 8)
        Figure size (width, height)
    title : str, optional
        Chart title. If empty/None a title is derived from ``group_column``
    colors : list or dict, optional
        Colors for cell types. A dict maps cell type -> color (missing types are
        drawn gray); a list is indexed in the column order of the count table.
        If None, uses the seaborn "Set3" palette
    show_percentages : bool, default True
        Whether to show percentage labels on bars
    rotation : int, default 45
        Rotation angle for x-axis labels (0 centres the labels under the bars)
    rasterized : bool, default False
        Forwarded to the bar, text and title artists
    legend_threshold : float, default 5.0
        Minimum percentage (max across groups) for including a cell type in the legend
    text_threshold : float, default 2.0
        A segment is annotated with its percentage only if it exceeds this value
    legend_fontsize : int, default 12
        Font size of legend entries
    def_fontsize : int, default 12
        Font size for axis labels, tick labels and percentage annotations
    title_fontsize : int, default 12
        Font size of the title
    xlabel : str, optional
        X-axis label. If None, a label is derived from ``group_column``
    show_legend : bool, default True
        Whether to draw the legend on the chart axes

    Returns:
    --------
    fig, ax : matplotlib figure and axis objects
    """

    # Calculate cell type counts and percentages
    counts = df.groupby([group_column, cell_type_column]).size().unstack(fill_value=0)
    # A group with no rows (e.g. an unobserved category) divides 0 by 0;
    # treat it as an empty bar instead of propagating NaN through the stack.
    percentages = counts.div(counts.sum(axis=1), axis=0).fillna(0.0) * 100

    # Set up colors
    n_cell_types = len(counts.columns)
    if colors is None:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]

    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)

    # Cell types whose largest share in any group reaches the legend threshold
    max_percentages = percentages.max(axis=0)
    legend_cell_types = max_percentages[max_percentages >= legend_threshold].index.tolist()

    # Bars are drawn at explicit integer positions so that the percentage
    # annotations (which are placed by position) line up with their bars
    # whatever the dtype of the group labels.
    x_positions = np.arange(len(percentages))
    bottom = np.zeros(len(percentages))

    for i, cell_type in enumerate(percentages.columns):
        # Only include in legend if it meets the threshold
        label = cell_type if cell_type in legend_cell_types else None
        heights = percentages[cell_type].to_numpy()

        ax.bar(x_positions, heights,
               bottom=bottom, label=label, color=colors[i],
               rasterized=rasterized)

        # Annotate segments that are larger than text_threshold
        if show_percentages:
            for x_pos, base, value in zip(x_positions, bottom, heights):
                if value > text_threshold:
                    ax.text(x_pos, base + value/2, f'{value:.1f}%',
                            ha='center', va='center', fontsize=def_fontsize, fontweight='bold',
                            rasterized=rasterized)

        bottom = bottom + heights

    # Customize the plot
    if xlabel is None: 
        ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=def_fontsize)
    else: 
        ax.set_xlabel(xlabel, fontsize=def_fontsize)
    ax.set_ylabel('Percentage (%)', fontsize=def_fontsize)
    ax.set_ylim(0, 100)

    if title:
        ax.set_title(title, fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}', 
                    fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)

    # Label the bars with their group names; rotated labels are right-aligned,
    # unrotated labels are centred. Done on `ax` rather than via pyplot state.
    ax.set_xticks(x_positions)
    ax.set_xticklabels(
        [str(group) for group in percentages.index],
        rotation=rotation,
        ha='right' if rotation != 0 else 'center',
        fontsize=def_fontsize,
    )

    if show_legend and len(legend_cell_types) > 0:
        # Only labelled bars (cell types meeting the threshold) appear;
        # long legends are split into two columns.
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize,
                  ncol=2 if len(legend_cell_types) > 20 else 1)

    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    # plt.tight_layout()

    return fig, ax

In [None]:
def create_chart_legend(
    df,
    group_column,
    cell_type_column='cell_type',
    figsize=(2, 1),
    title=None,
    colors=None,
    legend_threshold=5.0,
    legend_order=None,
    text_threshold=2.0,
    legend_fontsize=12,
    title_fontsize=12,
    rasterized=False,
):
    """
    Draw only the legend for a stacked bar chart, on an otherwise empty figure.

    Cell types are selected with the same rule as create_stacked_bar_chart: a
    type is listed when its largest percentage in any group is at least
    ``legend_threshold``.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input dataframe containing the data (one row per cell)
    group_column : str
        Column name the chart is grouped by
    cell_type_column : str, default 'cell_type'
        Column name containing cell type information
    figsize : tuple, default (2, 1)
        Figure size (width, height)
    title : str, optional
        Legend title
    colors : list or dict, optional
        Colors for cell types. A dict maps cell type -> color; a list (or the
        default "Set3" palette when None) is matched to cell types in the column
        order of the count table, i.e. the order the bar chart assigns them.
        Cell types without a color are drawn gray
    legend_threshold : float, default 5.0
        Minimum percentage (max across groups) for including a cell type
    legend_order : list, optional
        Order in which cell types are listed. Labels that do not occur in
        ``df`` are treated as 0 % (listed only if ``legend_threshold`` <= 0)
    text_threshold : float, default 2.0
        Unused; accepted so the chart's keyword arguments can be reused
    legend_fontsize : int, default 12
        Font size of legend entries
    title_fontsize : int, default 12
        Font size of the legend title
    rasterized : bool, default False
        Unused; accepted so the chart's keyword arguments can be reused

    Returns:
    --------
    fig, ax : matplotlib figure and axis objects
    """

    # Calculate cell type counts and percentages
    counts = df.groupby([group_column, cell_type_column]).size().unstack(fill_value=0)
    percentages = counts.div(counts.sum(axis=1), axis=0) * 100

    # Column order of the count table: this is the order in which the bar
    # chart indexes a list/palette of colors.
    chart_order = list(counts.columns)
    peak_pct = dict(zip(chart_order, percentages.max(axis=0).tolist()))

    # Resolve colors per cell type *before* any reordering, so that a custom
    # legend order cannot shift which color belongs to which cell type.
    if colors is None:
        color_of = dict(zip(chart_order, sns.color_palette("Set3", len(chart_order))))
    elif isinstance(colors, dict):
        color_of = colors
    else:
        color_of = dict(zip(chart_order, colors))

    legend_labels = chart_order if legend_order is None else list(legend_order)

    # One colored patch per cell type that meets the threshold
    # (a NaN peak compares False and is therefore left out).
    handles_labels = [
        (plt.Rectangle((0, 0), 1, 1, facecolor=color_of.get(ct, 'gray'), edgecolor='none', linewidth=0.5), ct)
        for ct in legend_labels
        if peak_pct.get(ct, 0.0) >= legend_threshold
    ]

    fig, ax = plt.subplots(figsize=figsize)
    ax.axis('off')

    if handles_labels:
        handles, labels = zip(*handles_labels)
        ncol = 1 if len(handles) <= 20 else 2
        legend = ax.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(0.5, 0.5),
            loc='center',
            fontsize=legend_fontsize,
            ncol=ncol,
            frameon=True,
        )
        if title:
            legend.set_title(title, prop={'size': title_fontsize})

    return fig, ax


In [None]:
# Per-cell metadata for white-matter cells, without cells whose Subclass is 'unknown'
df_wm = adata[adata.obs['wm_compartment'] == "WM"].obs.copy()
df_wm = df_wm.loc[df_wm['Subclass'] != 'unknown'].copy()

# Per-cell metadata for the two MS compartments
df_mat, df_str = (
    adata[adata.obs['MS_compartment'] == _compartment].obs.copy()
    for _compartment in ("Matrix", "Striosome")
)

In [None]:
# Group-level composition of white-matter cells per brain region.
# The chart is drawn without a legend; the legend is saved as its own figure.
_wm_group_common = dict(
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['Group_palette'],
    rasterized=RASTERIZED,
    legend_threshold=2.0,
    text_threshold=5.0,
    legend_fontsize=14,
    title_fontsize=24,
)

fig, ax = create_stacked_bar_chart(
    df_wm,
    title='Cell Type Distribution in White Matter Regions',
    rotation=0,
    def_fontsize=14,
    xlabel="",
    show_legend=False,
    **_wm_group_common,
)
fig.savefig(image_path / "ms_composition_wm_group.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_wm_group.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

fig, ax = create_chart_legend(df_wm, title='Group', **_wm_group_common)
fig.savefig(image_path / "ms_composition_wm_group_legend.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_wm_group_legend.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

In [None]:
# Subclass-level composition of white-matter cells per brain region.
# The chart is drawn without a legend; the legend is saved as its own figure.
_wm_subclass_common = dict(
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    colors=adata.uns['Subclass_palette'],
    rasterized=RASTERIZED,
    legend_threshold=2.0,
    text_threshold=5.0,
    legend_fontsize=14,
    title_fontsize=24,
)

fig, ax = create_stacked_bar_chart(
    df_wm,
    title='Cell Type Distribution in White Matter Regions',
    rotation=0,
    def_fontsize=14,
    xlabel="",
    show_legend=False,
    **_wm_subclass_common,
)
fig.savefig(image_path / "ms_composition_wm_subclass.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_wm_subclass.pdf", dpi=300, bbox_inches="tight")
# fig.savefig(image_path / "ms_composition_wm_subclass.svg", dpi=300, bbox_inches="tight")
plt.close(fig)

fig, ax = create_chart_legend(df_wm, title='Subclass', **_wm_subclass_common)
fig.savefig(image_path / "ms_composition_wm_subclass_legend.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_wm_subclass_legend.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

In [None]:
# Order in which the MSN group labels are listed in the composition legend
# (passed as `legend_order` to create_chart_legend).
leg_ord = [
    'STRd D1 Matrix MSN',
    'STRd D2 Matrix MSN',
    'STRd D1 Striosome MSN', 
    'STRd D2 Striosome MSN',
    'STRd D2 StrioMat Hybrid MSN', 
    'STR D1D2 Hybrid MSN', 
    'STRv D1 MSN',
    'STRv D2 MSN',
    'STRv D1 NUDAP MSN', 
    'OT D1 ICj', 
    'unknown'
]

In [None]:
# MSN-group composition of matrix-compartment cells per brain region.
# The legend is saved as a separate figure, listing every label in `leg_ord`
# order (legend_threshold=0).
_matrix_common = dict(
    group_column='brain_region_corr',
    cell_type_column='MSN_Groups',
    colors=adata.uns['MSN_Groups_palette'],
    rasterized=RASTERIZED,
    text_threshold=5.0,
    legend_fontsize=14,
    title_fontsize=24,
)

fig, ax = create_stacked_bar_chart(
    df_mat,
    title='Cell Type Distribution in Matrix Regions',
    rotation=0,
    legend_threshold=2.0,
    def_fontsize=14,
    xlabel="",
    show_legend=False,
    **_matrix_common,
)
fig.savefig(image_path / "ms_composition_matrix.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_matrix.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

fig, ax = create_chart_legend(
    df_mat,
    title='Group',
    legend_threshold=0,
    legend_order=leg_ord,
    figsize=(1,1),
    **_matrix_common,
)
fig.savefig(image_path / "ms_composition_legend.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_legend.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

In [None]:
# MSN-group composition of striosome-compartment cells per brain region
# (chart only; no legend is drawn on the axes).
_striosome_chart_kwargs = dict(
    group_column='brain_region_corr',
    cell_type_column='MSN_Groups',
    title='Cell Type Distribution in Striosome Regions',
    colors=adata.uns['MSN_Groups_palette'],
    rotation=0,
    rasterized=RASTERIZED,
    legend_threshold=2.0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
    xlabel="",
    show_legend=False,
)
fig, ax = create_stacked_bar_chart(df_str, **_striosome_chart_kwargs)
fig.savefig(image_path / "ms_composition_str.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "ms_composition_str.pdf", dpi=300, bbox_inches="tight")
plt.close(fig)

# Plotting All Regions

In [None]:
# Region outlines for every section, stored as a GeoParquet table
geom_store_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/region_geometries_cps.parquet"
geoms_all = gpd.read_parquet(geom_store_path)

# Lookup from raw `brain_region` to `brain_region_corr`, built from the unique
# pairs present in adata.obs, then applied to the geometry table.
_region_pairs = adata.obs[['brain_region', 'brain_region_corr']].drop_duplicates()
br_to_brc_map = dict(zip(_region_pairs['brain_region'], _region_pairs['brain_region_corr']))
geoms_all['brain_region_corr'] = geoms_all['brain_region'].map(br_to_brc_map)

# Columns (regions) and rows (donors) of the overview grid
donors = ["UCI2424", "UCI4723", "UCI5224", "UWA7648"]
brain_regions = ['CaH', 'CaB', 'CaT', 'Pu', 'NAC', "GP", "MGM1", "STH"]

In [None]:
# Selects which sections are drawn in the overview grid: compared against
# geoms_all['lab'] and adata.obs['replicate'], and used in the output file names.
plot_lab = "ucsd"

In [None]:
# donors = adata.obs['donor'].unique().tolist()
# brain_regions = adata.obs['brain_region_corr'].unique().tolist()
# replicates = adata.obs['replicate'].unique().tolist()

# Overview grid of called regions: one column per brain region, one row per donor.
# squeeze=False keeps `axes` two-dimensional even with a single row or column.
# constrained_layout is not requested because plt.tight_layout() below
# replaces it as the figure's layout anyway.
fig, axes = plt.subplots(
    ncols=len(brain_regions),
    nrows=len(donors),
    figsize=(3*len(brain_regions), 3*len(donors)),
    squeeze=False,
)
for i, _br in enumerate(brain_regions): 
    for j, _donor in enumerate(donors): 
        ax = axes[j, i]
        if j == 0: 
            # Column header; set before the skip below so a missing top-row
            # panel does not leave its column untitled.
            ax.set_title(_br, fontsize=24, y=1.05)

        # NOTE(review): geometries are filtered on `lab` but cells on `replicate`
        # with the same value - presumably both hold the lab name; confirm.
        geom = geoms_all[(geoms_all['brain_region_corr'] == _br) & (geoms_all['donor'] == _donor) & (geoms_all['lab'] == plot_lab)]
        if geom.shape[0] == 0:
            print("Skipping ", _br, _donor, " due to missing geometry.")
            ax.axis('off')
            continue

        # Filled regions first, then their outlines on top
        geom.plot(ax=ax, color=geom['type_color'], edgecolor='none', alpha=0.6).axis("off");
        geom.plot(ax=ax, color='none', edgecolor='black', alpha=1, linewidth=0.5).axis("off");

        # Outline of the section: convex hull of all cell centres
        cells = adata.obs[(adata.obs['brain_region_corr'] == _br) & (adata.obs['donor'] == _donor) & (adata.obs['replicate'] == plot_lab)].copy()
        cells = gpd.GeoDataFrame(
            cells,
            geometry=gpd.points_from_xy(cells['CENTER_X'], cells['CENTER_Y']),
        )
        # NOTE(review): geopandas >= 1.0 deprecates `unary_union` in favour of
        # `union_all()`; switch once the environment's version is confirmed.
        hull = cells.unary_union.convex_hull
        # With no cells, or fewer than three non-collinear ones, the hull is not
        # a polygon and has no exterior ring to draw.
        if hull.geom_type == "Polygon":
            x, y = hull.exterior.xy
            ax.fill(x, y, alpha=1, fc='none', ec='black')

# Row labels (donor ids) in figure coordinates; the y positions are hand-placed
# for a four-row grid. zip() stops at the shorter sequence, so fewer donors do
# not raise an IndexError.
for _donor, _label_y in zip(donors, (0.74, 0.54, 0.34, 0.14)):
    fig.text(0.11, _label_y, _donor, fontsize=24, ha='center', rotation='vertical')

plt.tight_layout()        
plt.savefig(image_path / f"all_region_geometries_{plot_lab}.png", dpi=300, bbox_inches="tight")
plt.savefig(image_path / f"all_region_geometries_{plot_lab}.pdf", dpi=300, bbox_inches="tight")
plt.show()