In [None]:
import os 
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

import matplotlib.pyplot as plt
from spida.pl import plot_categorical, categorical_scatter
# Global matplotlib defaults for this notebook: higher-resolution figures
# drawn on a white axes background.
plt.rcParams.update({
    'figure.dpi': 150,
    'axes.facecolor': 'white',
})

In [None]:
# Directory holding the BICAN basal-ganglia data. Set the BICAN_BG_DATA_DIR
# environment variable to run this notebook on another machine; the default is
# the original author's location.
DATA_DIR = Path(os.environ.get("BICAN_BG_DATA_DIR", "/home/x-aklein2/projects/aklein/BICAN/BG/data"))
# Kept as a str so any later code treating it as a string still works.
annot_path = str(DATA_DIR / "BICAN_BG_CPS.h5ad")
adata = ad.read_h5ad(annot_path)
adata

In [None]:
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def plot_grouped_pies(
    df: pd.DataFrame,
    cluster_col: str,
    subclass_col: str,
    normalize_subclass: bool = False,
    ncols: int = 4,
    figsize: tuple | None = None,
    min_label_frac: float = 0.02,
    title_prefix: str = "",
    palette: dict | list | tuple | None = None,
    radius: float = 1.15,
    labeldistance: float = 1.05,
    textprops: dict | None = None,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """
    Plot one pie chart per category of `cluster_col`, showing how that category
    is composed of the categories of `subclass_col`.

    Parameters
    ----------
    df : pd.DataFrame
        Table with one row per observation (e.g. ``adata.obs``).
    cluster_col : str
        Column whose categories each get their own pie (one subplot each).
    subclass_col : str
        Column whose categories form the wedges of every pie.
    normalize_subclass : bool, default False
        If False, each pie shows the fraction of the cluster's rows falling in
        each subclass. If True, each subclass's counts are first divided by that
        subclass's total across all clusters in `df`. Each cluster is then
        renormalized to sum to 1. This down-weights abundant subclasses so that
        rare ones stay visible.
    ncols : int, default 4
        Number of subplot columns; rows are added as needed.
    figsize : tuple, optional
        Figure size. Defaults to 4 inches per subplot in each direction.
    min_label_frac : float, default 0.02
        Wedges whose fraction is below this value are drawn without a label.
    title_prefix : str, default ""
        Text prepended to each cluster name in the subplot titles.
    palette : dict or sequence, optional
        Either a mapping ``subclass -> color`` or a sequence of colors cycled
        over the subclasses in column order. Subclasses missing from a dict
        palette are passed to matplotlib as ``None``.
    radius, labeldistance, textprops
        Forwarded to ``Axes.pie``. `textprops` defaults to ``{"fontsize": 8}``.

    Returns
    -------
    fig : matplotlib Figure
    axes : list of Axes
        All subplot axes, flattened row-major. Axes beyond the number of
        clusters are switched off. If `df` yields no clusters, a single blank
        (switched-off) axes is returned instead of raising.
    """
    # Contingency table of counts: rows = clusters, columns = subclasses.
    counts = pd.crosstab(df[cluster_col], df[subclass_col])

    if normalize_subclass:
        # Give every subclass equal total weight by dividing each column by its
        # total. np.nan (not pd.NA) keeps the frame float64; pd.NA would upcast
        # it to object dtype.
        subclass_totals = counts.sum(axis=0).replace(0, np.nan)
        weights = counts.div(subclass_totals, axis=1).fillna(0)
    else:
        weights = counts

    # Normalize each cluster (row) to sum to 1 for pie plotting; the nan
    # replacement guards against all-zero rows.
    row_totals = weights.sum(axis=1).replace(0, np.nan)
    data = weights.div(row_totals, axis=0).fillna(0)

    n_clusters = data.shape[0]
    # At least one row: plt.subplots cannot build a grid with nrows=0.
    nrows = max(1, (n_clusters + ncols - 1) // ncols)
    if figsize is None:
        figsize = (4 * ncols, 4 * nrows)
    # squeeze=False always yields a 2-D ndarray, even for a 1x1 grid.
    fig, axes_grid = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    axes = axes_grid.ravel().tolist()

    # Colors aligned to the subclass (column) order of `data`.
    colors = None
    if palette is not None:
        if isinstance(palette, dict):
            colors = [palette.get(k) for k in data.columns]
        else:
            colors = list(itertools.islice(itertools.cycle(palette), len(data.columns)))

    if textprops is None:
        textprops = {"fontsize": 8}

    for ax, (cluster, row) in zip(axes, data.iterrows()):
        # Only label wedges large enough for the text to be readable.
        labels = [f"{k} ({v:.1%})" if v >= min_label_frac else "" for k, v in row.items()]
        ax.pie(
            row.to_numpy(),
            labels=labels,
            startangle=90,
            counterclock=False,
            colors=colors,
            radius=radius,
            labeldistance=labeldistance,
            textprops=textprops,
        )
        ax.set_title(f"{title_prefix}{cluster}")
        ax.axis("equal")

    # Hide grid cells that received no cluster (also correct when n_clusters == 0).
    for ax in axes[n_clusters:]:
        ax.axis("off")

    fig.tight_layout()
    return fig, axes

# Example: plot_grouped_pies(adata.obs, "leiden", "subclass", normalize_subclass=True, palette=my_palette)

In [None]:
# Pick one dataset to inspect. Sampling a single cell means larger datasets are
# proportionally more likely to be chosen. The fixed seed keeps the choice
# identical across Restart & Run All; change it to look at another dataset.
DATASET_SEED = 0
dataset_id = adata.obs['dataset_id'].sample(n=1, random_state=DATASET_SEED).iloc[0]
adata_sub = adata[adata.obs['dataset_id'] == dataset_id]

# Figure title "<brain region> - <donor> - <replicate>", taken from the first
# cell of the subset. Assumes these fields are constant within a dataset_id;
# TODO confirm. The f-string also handles non-string (e.g. integer) values.
first_obs = adata_sub.obs.iloc[0]
title = f"{first_obs['brain_region_corr']} - {first_obs['donor']} - {first_obs['replicate']}"
adata_sub

In [None]:
# Side-by-side UMAPs of the selected dataset: Leiden clusters (left) next to
# annotated subclasses (right), for a visual comparison of the two labelings.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
umap_panels = [
    ('base_leiden', 'Leiden Clusters'),
    ('Subclass', 'Annotated Subclasses'),
]
for panel_ax, (color_col, panel_title) in zip(axes, umap_panels):
    plot_categorical(
        adata_sub,
        cluster_col=color_col,
        coord_base='base_umap',
        ax=panel_ax,
        show=False,
        text_anno=True,
    )
    panel_ax.set_title(panel_title)
plt.suptitle(title)
plt.show()

In [None]:
# Composition of each Leiden cluster in terms of annotated subclasses.
# The palette is read from the full `adata`, not the subset, presumably so that
# subclass colors stay consistent across datasets.
pie_fig, pie_axes = plot_grouped_pies(
    adata_sub.obs,
    'base_leiden',
    'Subclass',
    normalize_subclass=False,
    palette=adata.uns['Subclass_palette'],
    radius=1.5,
    textprops={"fontsize": 8},
)
# Label the figure with the dataset it shows, like the UMAP figure's suptitle.
pie_fig.suptitle(title)
pie_fig.tight_layout()
plt.show()