In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
from scipy.stats import entropy

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter
from PyComplexHeatmap import *

# Notebook-wide matplotlib defaults applied to every figure below.
plt.rcParams.update({
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    'font.size': 12,
    'axes.titlesize': 12,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'axes.linewidth': 1,
    'axes.facecolor': 'white',
})

## Functions


In [None]:
def _plot_overlap_heatmap(use_adata, ref_col, qry_col, image_path=None, current_datetime=None):
    """Plot, for each query label, the fraction of its cells assigned to each reference label.

    For every category of ``qry_col`` the cells are split over the categories of
    ``ref_col`` and converted to fractions (each row sums to 1). Rows are grouped
    by their best-matching reference label (the column with the largest fraction),
    ordered by the reference column order, annotated as ``"<code>: <label>"``, and
    drawn with ``ClusterMapPlotter``.

    Parameters
    ----------
    use_adata : anndata.AnnData or pandas.DataFrame
        Source of the labels; for an AnnData, ``.obs`` is used.
    ref_col : str
        Reference label column (heatmap columns).
    qry_col : str
        Query label column (heatmap rows). Must be categorical, because its
        category codes are used in the row annotation.
    image_path, current_datetime : optional
        Accepted for call compatibility; currently unused.

    Raises
    ------
    ValueError
        If there are no non-missing (query, reference) label pairs to plot.
    """
    if isinstance(use_adata, ad.AnnData):
        use_data = use_adata.obs.copy()
    else:
        use_data = use_adata.copy()

    # Cell counts per (query, reference) pair. With categorical columns,
    # DataFrame.value_counts may also emit zero-count combinations of unused
    # levels; drop them so absent categories do not become empty heatmap rows/columns.
    vc = use_data.loc[:, [qry_col, ref_col]].value_counts().reset_index()
    vc = vc[vc['count'] > 0]
    if vc.empty:
        raise ValueError(f"No non-missing ({qry_col}, {ref_col}) pairs to plot")
    query_totals = vc.groupby(qry_col, observed=True)['count'].transform('sum')
    vc = vc.assign(fraction=vc['count'] / query_totals)
    # Pairs that never co-occur have a fraction of 0.
    data = vc.pivot(index=qry_col, columns=ref_col, values='fraction').fillna(0)

    # Assign each query row to the reference label it overlaps most, then list rows
    # reference-by-reference so the heatmap reads roughly block-diagonal.
    ref_labels = data.columns.tolist()
    df_rows = data.index.to_series().to_frame()
    df_rows['GROUP'] = [ref_labels[i] for i in np.argmax(data.values, axis=1)]
    ordered_rows = [
        qry
        for ref in ref_labels
        for qry in df_rows.loc[df_rows['GROUP'] == ref, qry_col].unique()
    ]
    df_rows = df_rows.loc[ordered_rows]

    # Row annotation "<categorical code>: <query label>".
    code_of = {label: code for code, label in enumerate(use_data[qry_col].cat.categories)}
    df_rows['Label'] = [f"{code_of[x]}: {x}" for x in df_rows[qry_col]]

    # Plot
    row_ha = HeatmapAnnotation(
        label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
        axis=0, orientation='right',
    )

    plt.figure(figsize=(24, 12))
    ClusterMapPlotter(
        data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
        right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
        row_split_order=df_rows['GROUP'].unique().tolist(),
        show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
        xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
        yticklabels_kws=dict(labelcolor='red', labelsize=10),
        annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
        label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
        xlabel=ref_col, ylabel=qry_col,
        xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
        ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # increase the label gap via labelpad (points)
    )
    plt.show()
    plt.close()


def plot_regional_composition_stacked(adata, region_col='brain_region', subclass_col='RNA.Subclass', 
                                     palette=None, figsize=(10, 6), dpi=300, 
                                     title="Regional Subclass Composition", show_percentages=True):
    """
    Create a stacked barplot showing the cumulative distribution of subclasses across brain regions.
    
    Parameters:
    -----------
    adata : AnnData
        Annotated data object containing observations (``adata.obs`` is used)
    region_col : str, default 'brain_region'
        Column name for brain regions
    subclass_col : str, default 'RNA.Subclass'
        Column name for subclass annotations
    palette : dict, optional
        Mapping subclass -> color (looked up with ``.get``; missing subclasses are
        drawn '#888888'). If None, ``adata.uns['AIT_subclass_palette']`` is used when
        present, otherwise colors are sampled from ``tab20``.
    figsize : tuple, default (10, 6)
        Figure size (width, height)
    dpi : int, default 300
        Figure resolution
    title : str, default "Regional Subclass Composition"
        Plot title
    show_percentages : bool, default True
        Whether to show percentages (each region's bar sums to 100) instead of raw counts
    
    Returns:
    --------
    fig, ax : matplotlib figure and axes objects

    Notes:
    ------
    Only (region, subclass) combinations that actually occur are counted, so unused
    categorical levels do not produce empty regions or legend entries.
    """
    # Cell counts per (region, subclass); observed=True skips unused categorical levels.
    composition_data = (
        adata.obs.groupby([region_col, subclass_col], observed=True)
        .size()
        .to_frame(name="count")
        .reset_index()
    )
    
    # Convert to percentage of each region's total if requested
    if show_percentages:
        region_totals = composition_data.groupby(region_col, observed=True)['count'].transform('sum')
        composition_data['percentage'] = composition_data['count'] / region_totals * 100
        value_col = 'percentage'
        ylabel = 'Percentage (%)'
    else:
        value_col = 'count'
        ylabel = 'Cell Count'
    
    # Wide table: one row per region, one column per subclass
    pivot_data = composition_data.pivot(index=region_col, columns=subclass_col, values=value_col).fillna(0)
    
    # Set up color palette
    if palette is None:
        if hasattr(adata, 'uns') and 'AIT_subclass_palette' in adata.uns:
            palette = adata.uns['AIT_subclass_palette']
        else:
            # Generate a default palette
            import matplotlib.cm as cm
            n_colors = len(pivot_data.columns)
            palette = {cat: cm.tab20(i / n_colors) for i, cat in enumerate(pivot_data.columns)}
    
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    
    # Stack one bar segment per subclass on top of the running total
    bottom = np.zeros(len(pivot_data))
    for subclass in pivot_data.columns:
        heights = pivot_data[subclass].to_numpy()
        ax.bar(pivot_data.index, heights, bottom=bottom,
               label=subclass, color=palette.get(subclass, '#888888'),
               edgecolor='white', linewidth=0.5)
        bottom += heights
    
    # Formatting
    ax.set_xlabel(region_col.replace('_', ' ').title())
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    
    # Rotate x-axis labels when there are many regions
    if len(pivot_data.index) > 6:
        ax.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    
    return fig, ax

# Proseg / CPS

## Read

In [None]:
# adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad", backed='r')
# Opened in backed read-only mode: this cell only keeps the `.uns` palettes and a copy
# of `.obs`, so the expression matrix can stay on disk.
adata = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad", backed='r')
try:
    donor_palette = adata.uns['donor_palette']
    lab_palette = adata.uns['replicate_palette']
    brain_region_palette = adata.uns['brain_region_palette']
    # subclass_palette = adata.uns['Subclass_palette']
    # group_palette = adata.uns['Group_palette']
    obs = adata.obs.copy()
finally:
    # `del` does not guarantee the backing HDF5 file is closed; release it explicitly.
    adata.file.close()
del adata

In [None]:
meth_annot_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/annot_with_scores.csv"
# Methylation-based annotations with scores; the first CSV column becomes the row index.
meth_annot = pd.read_csv(filepath_or_buffer=meth_annot_path, index_col=0)
meth_annot.head(n=5)

In [None]:
# Keep the path and the loaded table under separate names (matching `meth_annot_path`)
# so `rna_annot` is always a DataFrame.
rna_annot_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.tsv"
# Tab-separated RNA annotation table; the first column becomes the row index.
rna_annot = pd.read_csv(rna_annot_path, sep='\t', index_col=0)
rna_annot.head()

In [None]:
# Cells present in both annotation tables, in `rna_annot` order (sort=False is the default).
common_cells = rna_annot.index.intersection(meth_annot.index, sort=False)

In [None]:
# Attach methylation-derived labels and scores to the RNA table. Assignment aligns on the
# cell index, so cells outside `common_cells` get NaN.
meth_common = meth_annot.loc[common_cells]
rna_annot['meth_subclass'] = meth_common['subclass'].astype("category")
rna_annot['meth_subclass_score'] = meth_common['subclass_score']
rna_annot['meth_group'] = meth_common['group'].astype("category")
# The group-level score is read later as `meth_group_score`; copy it when the
# methylation table provides one.
if 'group_score' in meth_annot.columns:
    rna_annot['meth_group_score'] = meth_common['group_score']

In [None]:
# Copy `allcools_Group_transfer_score` to `group_score`, the column name the
# RNA entropy computation reads.
rna_annot['group_score'] = rna_annot['allcools_Group_transfer_score'].copy()

In [None]:
# Cells annotated by both modalities, restricted to the label columns compared below.
label_cols = ['donor', 'replicate', 'brain_region', 'Subclass', 'meth_subclass', "Group", "meth_group"]
# The entropy analysis reads `group_score` / `meth_group_score` from this table, so carry
# those score columns along whenever they exist.
score_cols = [col for col in ["meth_group_score", "group_score"] if col in rna_annot.columns]
common_annot = rna_annot.loc[common_cells, label_cols + score_cols]
common_annot = common_annot.astype({col: "category" for col in label_cols})

# Categorical == requires both sides to share the same categories, so give each
# RNA/methylation pair the union of their categories before comparing them.
common_subclasses = common_annot['Subclass'].cat.categories.union(common_annot['meth_subclass'].cat.categories)
common_annot['Subclass'] = common_annot['Subclass'].cat.set_categories(common_subclasses)
common_annot['meth_subclass'] = common_annot['meth_subclass'].cat.set_categories(common_subclasses)

common_groups = common_annot['Group'].cat.categories.union(common_annot['meth_group'].cat.categories)
common_annot['Group'] = common_annot['Group'].cat.set_categories(common_groups)
common_annot['meth_group'] = common_annot['meth_group'].cat.set_categories(common_groups)

common_annot.dtypes

In [None]:
# Per-cell flags: True where the RNA and methylation labels agree (missing labels count as disagreement).
common_annot['subclass_match'] = common_annot['Subclass'].eq(common_annot['meth_subclass'])
common_annot['group_match'] = common_annot['Group'].eq(common_annot['meth_group'])

In [None]:
# Alias (not a copy) of the combined annotation table: changes made through `dfa`
# also change `common_annot`, and rebinding `dfa` elsewhere changes what it refers to.
dfa = common_annot

In [None]:
# Label agreement between RNA and methylation per donor/replicate over all regions.
# Reads `common_annot` directly rather than the `dfa` alias, so the result does not depend
# on what `dfa` currently points to. observed=True keeps only donor/replicate combinations
# that exist; with categorical keys the default would also list empty ones (size 0, rate NaN).
dfb = (
    common_annot
    .groupby(["donor", "replicate"], observed=True)
    .agg(
        subclass_match_sum=("subclass_match", "sum"),
        subclass_match_size=("subclass_match", "size"),
        group_match_sum=("group_match", "sum"),
        group_match_size=("group_match", "size"),
    )
)
dfb['subclass_match_rate'] = dfb['subclass_match_sum'] / dfb['subclass_match_size']
dfb['group_match_rate'] = dfb['group_match_sum'] / dfb['group_match_size']
dfb

## Agreement

In [None]:
# For each brain region, compute RNA-vs-methylation label agreement per donor/replicate
# section, then summarise as mean and std across sections (rates in percent).
# Loop variables are not named `dfa`/`dfb`, so the whole-dataset tables keep their values.
agreement = []
for region, region_df in common_annot.groupby('brain_region', observed=True):
    per_section = (
        region_df
        .groupby(["donor", "replicate"], observed=True)
        .agg(
            n_cells=("subclass_match", "size"),
            subclass_match_sum=("subclass_match", "sum"),
            group_match_sum=("group_match", "sum"),
        )
    )
    subclass_rate = per_section['subclass_match_sum'] / per_section['n_cells']
    group_rate = per_section['group_match_sum'] / per_section['n_cells']
    # Tuple order matches the column list used to build `agreement_df`.
    agreement.append((
        region,
        per_section['n_cells'].mean(),
        per_section['n_cells'].std(),
        subclass_rate.mean() * 100,
        subclass_rate.std() * 100,
        group_rate.mean() * 100,
        group_rate.std() * 100,
    ))

In [None]:
agreement_df = pd.DataFrame(agreement, columns=['brain_region', 'n_cells_mean', 'n_cells_std', 'subclass_agreement_mean', 'subclass_agreement_std', 'group_agreement_mean', 'group_agreement_std'])
# Regions missing from the palette are drawn gray instead of with an undefined (NaN) color.
agreement_df['color'] = agreement_df['brain_region'].map(brain_region_palette).fillna('gray')

# One panel per metric: (mean column, std column, title, y label, y limits or None).
panels = [
    ('n_cells_mean', 'n_cells_std', "Common Annotated Cell Types by Brain Region", "Cell Number", None),
    ('subclass_agreement_mean', 'subclass_agreement_std', "Subclass Agreement by Brain Region", "Subclass Agreement (%)", (0, 100)),
    ('group_agreement_mean', 'group_agreement_std', "Group Agreement by Brain Region", "Group Agreement (%)", (0, 100)),
]
x_pos = np.arange(len(agreement_df))

fig, axes = plt.subplots(1, len(panels), dpi=300, figsize=(15, 4))
for ax, (mean_col, std_col, title, ylabel, ylim) in zip(axes, panels):
    # Error bars are the std across donor/replicate sections.
    ax.bar(x_pos, agreement_df[mean_col], yerr=agreement_df[std_col],
           color=agreement_df['color'].tolist(), capsize=5)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(agreement_df['brain_region'], fontsize=8)
    ax.grid(axis='y', linestyle='--', alpha=0.75)


In [None]:
# Overall share of commonly annotated cells whose RNA and methylation labels agree.
for level, match_col in [("Subclass", "subclass_match"), ("Group", "group_match")]:
    print(f"{level} Level Agreement {common_annot[match_col].mean() * 100:.3f}%")

In [None]:
# Label-overlap heatmaps for the CAB region at subclass and group level.
# Pass `common_annot` unfiltered to plot all regions together.
cab_annot = common_annot[common_annot['brain_region'] == "CAB"]
for ref_level, qry_level in [('Subclass', 'meth_subclass'), ('Group', 'meth_group')]:
    _plot_overlap_heatmap(cab_annot, ref_col=ref_level, qry_col=qry_level)

In [None]:
def _score_entropies(annot, label_col, score_col):
    """Entropy of the rounded score distribution per label, for each (donor, brain_region).

    For every (donor, brain_region) subset of ``annot`` and every ``label_col`` category
    present in it, the ``score_col`` values are rounded to 3 decimals, their value counts
    are taken, and ``scipy.stats.entropy`` of those counts is recorded.

    Returns
    -------
    dict
        ``{(donor, brain_region): {label: entropy}}``

    Raises
    ------
    KeyError
        If ``label_col`` or ``score_col`` is not a column of ``annot``.
    """
    missing = [col for col in (label_col, score_col) if col not in annot.columns]
    if missing:
        raise KeyError(
            f"Missing column(s) {missing}; make sure the score columns were "
            "selected into the annotation table"
        )
    result = {}
    for (donor, region), sub in annot.groupby(['donor', 'brain_region'], observed=True):
        per_label = {}
        for label in sub[label_col].cat.remove_unused_categories().cat.categories:
            scores = sub.loc[sub[label_col] == label, score_col]
            per_label[label] = entropy(scores.round(3).value_counts().sort_index())
        result[(donor, region)] = per_label
    return result


entropies_rna = _score_entropies(common_annot, 'Group', 'group_score')
entropies_meth = _score_entropies(common_annot, 'meth_group', 'meth_group_score')

In [None]:
def entropy_to_df(entropies_dict, method_name="RNA"):
    """Flatten nested per-(donor, region) entropies into a long-format DataFrame.

    Parameters
    ----------
    entropies_dict : dict
        Mapping ``{(donor, brain_region): {group: entropy_value}}``.
    method_name : str, default "RNA"
        Value written to the ``method`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per (donor, brain_region, group) with columns
        ``donor, brain_region, group, entropy, method`` (no columns if the input is empty).
    """
    records = [
        {
            'donor': donor,
            'brain_region': region,
            'group': group,
            'entropy': value,
            'method': method_name,
        }
        for (donor, region), per_group in entropies_dict.items()
        for group, value in per_group.items()
    ]
    return pd.DataFrame(records)

In [None]:
# Long-format entropy tables per modality, joined on (donor, brain_region, group).
# Inner join: only groups scored by both modalities are compared.
df_ent_rna = entropy_to_df(entropies_rna, method_name="RNA")
df_ent_mc = entropy_to_df(entropies_meth, method_name="MC")
df_ent = pd.merge(
    df_ent_rna,
    df_ent_mc,
    how='inner',
    on=['donor', 'brain_region', 'group'],
    suffixes=('_RNA', '_MC'),
)

In [None]:
# Compare per-group score entropies between RNA and methylation, one panel per brain region.
# Bars show the mean across donors with the standard deviation as error bars.
regions = sorted(df_ent['brain_region'].unique())
all_groups = sorted(df_ent['group'].unique())

# Mean and std across donors for each (brain_region, group), per modality. std is NaN
# when only one donor contributes, and is shown as 0.
stats_rna = df_ent.groupby(['brain_region', 'group'])['entropy_RNA'].agg(['mean', 'std']).reset_index()
stats_rna.columns = ['brain_region', 'group', 'mean_RNA', 'std_RNA']
stats_rna['std_RNA'] = stats_rna['std_RNA'].fillna(0)

stats_mc = df_ent.groupby(['brain_region', 'group'])['entropy_MC'].agg(['mean', 'std']).reset_index()
stats_mc.columns = ['brain_region', 'group', 'mean_MC', 'std_MC']
stats_mc['std_MC'] = stats_mc['std_MC'].fillna(0)

stats_df = stats_rna.merge(stats_mc, on=['brain_region', 'group'], how='outer').fillna(0)

# One panel per brain region, stacked vertically and sharing the group x axis.
fig, axes = plt.subplots(nrows=len(regions), ncols=1, figsize=(12, 4 * len(regions)),
                         sharex=True, squeeze=False)
axes = axes.flatten()

x_pos = np.arange(len(all_groups))
width = 0.35  # width of each of the two side-by-side bars

for i, region in enumerate(regions):
    ax = axes[i]
    # Align this region's stats to the shared group order; groups absent from the
    # region are drawn as zero-height bars so every panel uses the same x positions.
    region_stats = (
        stats_df[stats_df['brain_region'] == region]
        .set_index('group')
        .reindex(all_groups)
        .fillna(0)
    )

    ax.bar(x_pos - width / 2, region_stats['mean_RNA'].to_numpy(), width,
           yerr=region_stats['std_RNA'].to_numpy(), capsize=3,
           color='lightcoral', alpha=0.8,
           edgecolor='black', linewidth=0.5,
           label='RNA')
    ax.bar(x_pos + width / 2, region_stats['mean_MC'].to_numpy(), width,
           yerr=region_stats['std_MC'].to_numpy(), capsize=3,
           color='lightblue', alpha=0.8,
           edgecolor='black', linewidth=0.5,
           label='MC')

    ax.set_ylabel(f'{region}\nEntropy', fontsize=10, rotation=0, ha='right', va='center')
    ax.grid(axis='y', linestyle='--', alpha=0.3)
    ax.set_ylim(bottom=0)

    # Legend only on the first panel
    if i == 0:
        ax.legend(loc='upper right')

# The x axis is shared, so ticks, labels and the axis title are set once on the bottom panel.
axes[-1].set_xticks(x_pos)
axes[-1].set_xticklabels(all_groups, rotation=45, ha='right', fontsize=8)
axes[-1].set_xlabel('Group', fontsize=12)

fig.suptitle('Entropy Comparison: RNA vs Methylation by Brain Region and Group\n(Mean ± Std Dev across donors)',
             fontsize=14, y=0.98)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()