### Calculate Category Proportions per Group

In [None]:
import pandas as pd

# ------------
# User options
# ------------
group_key = "batch"         # obs column defining the bars (one bar per value)
category_key = "cell_type"  # obs column whose composition is stacked in each bar
cluster_key = "leiden"      # obs column that `sub_clusters` selects from

# specify sub clusters or use all cells
sub_clusters = None
# sub_clusters = ['0', '1', '2', '3']

# -------------------------------------
# Step: Select subset (or skip if None)
# -------------------------------------
if sub_clusters is None:
    selected_obs = adata.obs[[group_key, category_key]]
    print("Using all cells (no sub-cluster filtering).")
else:
    sub_mask = adata.obs[cluster_key].isin(sub_clusters)
    selected_obs = adata.obs.loc[sub_mask, [group_key, category_key]]
    print(f"Using sub-clusters: {sub_clusters}")

# ----------------------------------------------------
# Step: Compute counts & proportion within each group
# ----------------------------------------------------
# observed=False (explicit, so pandas does not warn about the changing default)
# keeps every category of categorical columns, including zero-count combinations.
cluster_counts = (
    selected_obs.groupby([group_key, category_key], observed=False)
    .size()
    .reset_index(name='count')
)
group_totals = (
    cluster_counts.groupby(group_key, observed=False)['count'].transform("sum")
)
# A group with no selected cells gives 0/0 = NaN; report it as 0 so the row is
# not silently dropped later by pivot_table(dropna=True).
cluster_counts['proportion'] = (cluster_counts['count'] / group_totals).fillna(0.0)

print("Cluster proportions:")
print(cluster_counts)

# -------------------------------------
# Step: Create pivot table for plotting
# -------------------------------------
# Rows: values of `group_key`; columns: values of `category_key`.
df_pivot = cluster_counts.pivot_table(
    index=group_key,
    columns=category_key,
    values='proportion',  # set to 'count' to plot raw cell counts instead
    fill_value=0,
    observed=False,
)

# Reorder rows to the order groups first appear in adata.obs. reindex (rather
# than .loc) turns a group with no cells in the selected sub-clusters into an
# all-zero row instead of raising KeyError.
df_pivot = df_pivot.reindex(adata.obs[group_key].unique(), fill_value=0)

### Visualization

In [None]:
import scanpy as sc

# ------------------------------------------------------
# Step: Color map for the categories of `category_key`
# ------------------------------------------------------
# Categories in canonical order. astype("category") keeps the existing
# categories of a categorical column and also accepts plain string columns.
categories = adata.obs[category_key].astype("category").cat.categories

# Reuse the colors scanpy stored for this column (adata.uns["<key>_colors"])
# so the bar chart matches other scanpy plots of the same column. Otherwise
# fall back to the smallest default scanpy palette that fits all categories.
stored_colors = adata.uns.get(f"{category_key}_colors")
if stored_colors is not None and len(stored_colors) == len(categories):
    palette = list(stored_colors)
elif len(categories) <= len(sc.pl.palettes.default_20):
    palette = sc.pl.palettes.default_20
elif len(categories) <= len(sc.pl.palettes.default_28):
    palette = sc.pl.palettes.default_28
else:
    palette = sc.pl.palettes.default_102

# Wrap around the palette when there are more categories than colors, so that
# every category still gets an explicit color instead of being dropped by zip.
color_map = {
    category: palette[i % len(palette)] for i, category in enumerate(categories)
}

In [None]:
import matplotlib.pyplot as plt

# -----------------------
# Plot: Stacked Bar Chart
# -----------------------
# One bar per row of df_pivot (values of `group_key`). Each column (values of
# `category_key`) is stacked on top of the previous ones.
fig, ax = plt.subplots(figsize=(9, 6))

# Draw bars at evenly spaced integer positions and label them afterwards. This
# keeps numeric or categorical group labels evenly spaced and in df_pivot order.
x_positions = list(range(len(df_pivot.index)))

# Running top of each stack. Float, so adding proportions never truncates.
bottom = pd.Series(0.0, index=df_pivot.index)

# Loop to plot each category
for cluster in df_pivot.columns:
    heights = df_pivot[cluster]
    # Categories with no cells in any group (e.g. after sub-cluster filtering)
    # would only clutter the legend with invisible bars.
    if not (heights > 0).any():
        continue
    ax.bar(
        x_positions,
        heights.to_numpy(),
        bottom=bottom.to_numpy(),
        label=f"{cluster}",
        color=color_map.get(cluster, None) if color_map else None,
    )
    bottom += heights

ax.set_xticks(x_positions)
ax.set_xticklabels([str(label) for label in df_pivot.index], rotation=0)
ax.set_yticks([])
ax.grid(False)

ax.set_ylabel("Proportion")
ax.set_title(f"Proportion of {category_key} grouped by {group_key}", fontsize=16, fontweight='bold')

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
for spine in ax.spines.values():
    spine.set_linewidth(1)

# Set legend. Reverse the entries so their top-to-bottom order matches the
# visual top-to-bottom order of the stacked segments.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles[::-1],
    labels[::-1],
    title=category_key,
    bbox_to_anchor=(1.05, 1),
    loc="upper left",
    fontsize=10,
    title_fontsize=12,
    frameon=False,
)

fig.tight_layout()
plt.show()