In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
from openpyxl import load_workbook

import matplotlib.pyplot as plt
import seaborn as sns
from plottable import Table, ColumnDefinition
from plottable.formatters import decimal_to_percent
from PyComplexHeatmap import HeatmapAnnotation, anno_label, ClusterMapPlotter
from spida.pl import plot_categorical, plot_continuous, categorical_scatter

In [None]:
# Notebook-wide matplotlib defaults: 300-dpi display and export, 4x4 inch
# figures, 14 pt Arial on white axes, and font type 42 for PDF/PS output.
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 14,
    'axes.facecolor': 'white',
})

### Functions

In [None]:
def _plot_overlap_heatmap(use_adata, ref_col, qry_col, image_path=None, current_datetime=None):
    """
    Plot, for every query label, the fraction of its cells that carry each reference label.

    Rows (query labels) are grouped by the reference label that holds their
    largest fraction, and the groups follow the column order of the heatmap.

    Parameters
    ----------
    use_adata : AnnData
        Object whose ``.obs`` contains both label columns.
    ref_col : str
        ``.obs`` column with the reference labels (heatmap columns).
    qry_col : str
        ``.obs`` column with the query labels (heatmap rows). Must be
        categorical: each row label is prefixed with its category code.
    image_path, current_datetime : optional
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    None
        The figure is shown and then closed.
    """
    obs = use_adata.obs

    # Long table of (query, reference) pair counts. The column is named
    # explicitly so the code does not depend on the pandas default name.
    vc = obs.loc[:, [qry_col, ref_col]].value_counts().reset_index(name='count')
    per_query_total = vc.groupby(qry_col)['count'].sum()
    vc['N'] = vc[qry_col].map(per_query_total).astype(int)
    vc['fraction'] = vc['count'] / vc['N']
    data = vc.pivot(index=qry_col, columns=ref_col, values='fraction')

    # Assign every row to the column holding its largest fraction, then order
    # the rows by that column's position (stable, so ties keep pivot order).
    cols = data.columns.tolist()
    best_col = np.argmax(data.fillna(0).values, axis=1)
    df_rows = data.index.to_series().to_frame()
    df_rows["GROUP"] = [cols[i] for i in best_col]
    df_rows = df_rows.iloc[np.argsort(best_col, kind='stable')].copy()

    # Row labels of the form "<category code>: <query label>".
    ct2code = {cat: code for code, cat in enumerate(obs[qry_col].cat.categories)}
    df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[qry_col].tolist()]

    # Plot
    row_ha = HeatmapAnnotation(
        label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
        axis=0, orientation='right',
    )

    plt.figure(figsize=(24, 12))
    ClusterMapPlotter(
        data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
        right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
        row_split_order=df_rows['GROUP'].unique().tolist(),
        show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
        xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
        yticklabels_kws=dict(labelcolor='red', labelsize=10),
        annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
        label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
        xlabel=ref_col, ylabel=qry_col,
        xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
        ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # labelpad is in points
    )
    plt.show()
    plt.close()


def plot_regional_composition_stacked(adata, region_col='brain_region', subclass_col='RNA.Subclass', 
                                     palette=None, figsize=(10, 6), dpi=300, 
                                     title="Regional Subclass Composition", show_percentages=True):
    """
    Draw one stacked bar per region showing how its cells split across subclasses.

    Parameters
    ----------
    adata : AnnData
        Annotated data object; only ``.obs`` (and optionally ``.uns``) is read.
    region_col : str, default 'brain_region'
        ``.obs`` column that defines the bars.
    subclass_col : str, default 'RNA.Subclass'
        ``.obs`` column that defines the stacked segments.
    palette : dict, optional
        Mapping subclass -> colour. If None, ``adata.uns[f'{subclass_col}_palette']``
        is used when present, otherwise colours are sampled from ``tab20``.
        Subclasses missing from the mapping are drawn in '#888888'.
    figsize : tuple, default (10, 6)
        Figure size (width, height).
    dpi : int, default 300
        Figure resolution.
    title : str, default "Regional Subclass Composition"
        Plot title.
    show_percentages : bool, default True
        If True, each bar is normalised to 100 %; otherwise raw cell counts
        are stacked.

    Returns
    -------
    fig, ax : matplotlib figure and axes objects
    """
    # Region x subclass table of cell counts (missing pairs become 0).
    counts = adata.obs.groupby([region_col, subclass_col]).size().unstack(subclass_col, fill_value=0)

    if show_percentages:
        # Normalise each region (row); a region with no cells yields NaN -> 0.
        pivot_data = counts.div(counts.sum(axis=1), axis=0).mul(100).fillna(0)
        ylabel = 'Percentage (%)'
    else:
        pivot_data = counts
        ylabel = 'Cell Count'

    # Resolve the colour mapping.
    if palette is None:
        palette_key = f'{subclass_col}_palette'
        if hasattr(adata, 'uns') and palette_key in adata.uns:
            palette = adata.uns[palette_key]
        else:
            tab20 = plt.get_cmap('tab20')
            n_colors = len(pivot_data.columns)
            palette = {cat: tab20(i / n_colors) for i, cat in enumerate(pivot_data.columns)}

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

    # Stack the segments, carrying the running top of each bar as a plain array.
    bottom = np.zeros(len(pivot_data))
    for subclass in pivot_data.columns:
        heights = pivot_data[subclass].to_numpy(dtype=float)
        ax.bar(pivot_data.index, heights, bottom=bottom,
               label=subclass, color=palette.get(subclass, '#888888'),
               edgecolor='white', linewidth=0.5)
        bottom = bottom + heights

    # Formatting
    ax.set_xlabel(region_col.replace('_', ' ').title())
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')

    # Tilt the tick labels once there are many regions.
    if len(pivot_data.index) > 6:
        ax.tick_params(axis='x', rotation=45)

    return fig, ax

def _composition_by_class(
    adata : ad.AnnData,
    class_col : str = 'RNA.Subclass',
    group_col : str = 'brain_region',
    palette : dict | None = None,
    show : bool = True, 
    save_fig : str | None = None,
    save_path : str | None = None,
):
    """
    Draw one stacked bar per class showing which groups its cells come from.

    Each bar is normalised over the class, i.e. the segments of one class sum
    to 100 % across the values of ``group_col``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object; ``.obs`` and ``.uns`` are read.
    class_col : str, default 'RNA.Subclass'
        ``.obs`` column that defines the bars (x axis).
    group_col : str, default 'brain_region'
        ``.obs`` column that defines the stacked segments.
    palette : dict, optional
        Mapping group -> colour. If not given, ``adata.uns[f'{group_col}_palette']``
        is used; if that is missing too, colours are sampled from ``tab20``.
        Groups missing from the mapping are drawn in '#888888'.
    show : bool, default True
        If True the figure is shown and nothing is returned.
    save_fig : str, optional
        File name for the saved figure. Saving needs ``save_path`` as well.
    save_path : str, optional
        Directory the figure is saved to.

    Returns
    -------
    (fig, ax) when ``show`` is False, otherwise None.
    """
    # Class x group table of cell counts, normalised per class (row).
    counts = adata.obs.groupby([group_col, class_col]).size().unstack(group_col, fill_value=0)
    pivot_data = counts.div(counts.sum(axis=1), axis=0).mul(100).fillna(0)

    palette = palette or adata.uns.get(f"{group_col}_palette", None)
    if palette is None:
        # No palette anywhere: fall back to generated colours instead of
        # failing on `None.get` below.
        tab20 = plt.get_cmap('tab20')
        n_groups = len(pivot_data.columns)
        palette = {grp: tab20(i / n_groups) for i, grp in enumerate(pivot_data.columns)}

    fig, ax = plt.subplots(figsize=(12, 8), dpi=200)

    # Stack the segments, carrying the running top of every bar.
    running_top = np.zeros(len(pivot_data))
    for _group in pivot_data.columns:
        heights = pivot_data[_group].to_numpy(dtype=float)
        ax.bar(pivot_data.index, heights, label=_group,
               color=palette.get(_group, '#888888'), bottom=running_top)
        running_top = running_top + heights

    ax.set_title(f"{group_col} Composition (%) for each {class_col}", fontsize=14, fontweight='bold')
    ax.set_xlabel(class_col, fontsize=12)
    ax.set_ylabel("Percentage (%)", fontsize=12)
    ax.legend(title=group_col, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Rotate x-axis labels for better readability
    plt.xticks(rotation=45, ha='right', fontsize=6)
    plt.tight_layout()

    if save_fig and save_path:
        fig.savefig(os.path.join(save_path, save_fig), bbox_inches='tight')
    if show:
        plt.show()
        return None
    return fig, ax


def plot_annot_table(
    adata : ad.AnnData | pd.DataFrame,
    group_col : str,
    class_col : str,
    palette : dict | None = None,
    plot_as_percent : bool = False,
): 
    """
    Render a class x group table of cell counts with each row tinted by its class colour.

    Parameters
    ----------
    adata : AnnData or DataFrame
        Either an AnnData object (its ``.obs`` is tabulated) or the
        observation table itself.
    group_col : str
        Column whose values become the table columns.
    class_col : str
        Column whose values become the table rows.
    palette : dict, optional
        Mapping class -> colour. If not given and ``adata`` is an AnnData,
        ``adata.uns[f'{class_col}_palette']`` is used. Classes without a
        colour are drawn in '#888888'.
    plot_as_percent : bool, default False
        If True, every row is divided by its row total and formatted as a
        percentage; otherwise raw counts are shown.

    Returns
    -------
    None
        The figure is shown and then closed.
    """
    if isinstance(adata, ad.AnnData):
        obs, stored = adata.obs, adata.uns
    else:
        # A DataFrame is the observation table itself. Wrapping it in
        # AnnData would make it the data matrix, not `.obs`.
        obs, stored = adata, {}

    composition = obs.groupby([group_col, class_col]).size().to_frame(name="count").reset_index()
    comp = composition.pivot(index=group_col, columns=class_col, values='count').fillna(0).T

    palette = palette or stored.get(f"{class_col}_palette", None)
    if palette is None:
        # No palette available: every row falls back to the default grey.
        palette = {}

    formatter = None
    if plot_as_percent:
        comp = comp.div(comp.sum(axis=1), axis=0)
        formatter = decimal_to_percent

    fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
    tab = Table(
        comp,
        textprops={'fontsize': 6},
        col_label_divider=True,
        col_label_divider_kw={'color': 'black', 'linewidth': 0.5},
        odd_row_color='#f0f0f0',
        column_definitions=(
            [ColumnDefinition(name=class_col, width=1.5)]
            + [
                ColumnDefinition(name=_reg, formatter=formatter, width=0.8)
                for _reg in comp.columns.tolist()
            ]
        ),
    )
    tab.columns[class_col].set_linewidth(0)

    # Tint each row with the colour of the class named in its first cell.
    for _r, _row in tab.rows.items():
        lab = _row.cells[0].text.get_text()
        _row.set_facecolor(palette[lab] if lab in palette else '#888888')

    plt.show()
    plt.close()


### Add the colors: 
def add_colors(adata, cat_col, palette):
    """
    Store one colour per category of ``adata.obs[cat_col]`` in ``adata.uns[f'{cat_col}_colors']``.

    ``palette`` is either a dict (category -> colour) or a table indexed by
    category with a 'Hex' column. A category missing from the palette is
    printed and given '#808080'. Colours are stored in category order.
    """
    palette_is_dict = isinstance(palette, dict)
    resolved = []
    for category in adata.obs[cat_col].cat.categories:
        try:
            hex_code = palette[category] if palette_is_dict else palette.loc[category, 'Hex']
        except KeyError:
            # Report the unmapped category and fall back to grey.
            print(category)
            hex_code = '#808080'
        resolved.append(hex_code)

    adata.uns[f'{cat_col}_colors'] = resolved

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type', 
                           figsize=(12, 8), title=None, colors=None, 
                           show_percentages=True, rotation=45, rasterized=False,
                           legend_threshold=5.0, text_threshold=2.0,
                           legend_fontsize=12, def_fontsize=12, title_fontsize=12,
                           xlabel=None, 
                        ):
    """
    Create a stacked bar chart showing cell type percentages across groups.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per cell; must contain ``group_column`` and ``cell_type_column``.
    group_column : str
        Column name to group by (one bar per value).
    cell_type_column : str, default 'cell_type'
        Column name containing the cell type of each row.
    figsize : tuple, default (12, 8)
        Figure size (width, height).
    title : str, optional
        Chart title. Defaults to a title derived from ``group_column``.
    colors : list or dict, optional
        Colours for the cell types. A dict maps cell type -> colour (missing
        types are drawn in 'gray'); a list is used in column order and is
        cycled if it is shorter than the number of cell types. If None, the
        seaborn "Set3" palette is used.
    show_percentages : bool, default True
        Whether to write the percentage inside each segment.
    rotation : int, default 45
        Rotation angle for x-axis labels.
    rasterized : bool, default False
        Passed to the bars, the segment labels and the title.
    legend_threshold : float, default 5.0
        A cell type appears in the legend only if it reaches this percentage
        in at least one group.
    text_threshold : float, default 2.0
        A segment is labelled only if its percentage is above this value.
    legend_fontsize, def_fontsize, title_fontsize : int, default 12
        Font sizes of the legend, of axis/tick/segment labels, and of the title.
    xlabel : str, optional
        X-axis label. Defaults to a label derived from ``group_column``.

    Returns
    -------
    fig, ax : matplotlib figure and axis objects
    """
    # Group x cell-type counts, normalised per group (row). A group with no
    # rows would give 0/0; it is treated as an empty bar.
    counts = df.groupby([group_column, cell_type_column]).size().unstack(fill_value=0)
    percentages = counts.div(counts.sum(axis=1), axis=0).mul(100).fillna(0)

    # Resolve one colour per cell type.
    n_cell_types = len(counts.columns)
    if colors is None:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]

    fig, ax = plt.subplots(figsize=figsize)

    # Cell types whose largest share across groups reaches the legend threshold.
    max_percentages = percentages.max(axis=0)
    legend_cell_types = max_percentages[max_percentages >= legend_threshold].index.tolist()

    bottom = np.zeros(len(percentages))
    bars = []
    for i, cell_type in enumerate(percentages.columns):
        heights = percentages[cell_type].to_numpy(dtype=float)
        # Unlabelled artists are skipped by ax.legend().
        label = cell_type if cell_type in legend_cell_types else None

        bar = ax.bar(percentages.index, heights,
                     bottom=bottom, label=label, color=colors[i % len(colors)],
                     rasterized=rasterized)
        bars.append(bar)

        if show_percentages:
            # Anchor each label to its own rectangle so it stays centred on
            # the bar whatever kind of values the x axis holds.
            for j, (rect, value) in enumerate(zip(bar.patches, heights)):
                if value > text_threshold:
                    ax.text(rect.get_x() + rect.get_width() / 2, bottom[j] + value / 2,
                            f'{value:.1f}%',
                            ha='center', va='center', fontsize=def_fontsize, fontweight='bold',
                            rasterized=rasterized)

        bottom = bottom + heights

    # Axis labels and limits
    if xlabel is None:
        ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=def_fontsize)
    else:
        ax.set_xlabel(xlabel, fontsize=def_fontsize)
    ax.set_ylabel('Percentage (%)', fontsize=def_fontsize)
    ax.set_ylim(0, 100)

    if title:
        ax.set_title(title, fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}', 
                    fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)

    # Rotate x-axis labels
    if rotation != 0:
        plt.xticks(rotation=rotation, ha='right', fontsize=def_fontsize)
    else:
        plt.xticks(ha='center', fontsize=def_fontsize)

    # Legend: two columns once there are more than 20 entries, none if empty.
    n_legend_entries = len(legend_cell_types)
    if n_legend_entries > 20:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize, ncol=2)
    elif n_legend_entries > 0:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=legend_fontsize)

    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    return fig, ax

# Example usage:
# Assuming you have a dataframe 'df' with columns 'region' and 'cell_type'
# fig, ax = create_stacked_bar_chart(df, group_column='region', cell_type_column='cell_type')
# plt.show()

# Alternative simpler version for quick use:
def quick_stacked_bar(df, group_col, cell_type_col='cell_type'):
    """Percentage-stacked bar chart of cell types per group, drawn with pandas plotting defaults."""
    cell_counts = df.groupby([group_col, cell_type_col]).size().unstack(fill_value=0)
    group_totals = cell_counts.sum(axis=1)
    share = cell_counts.div(group_totals, axis=0) * 100

    plot_options = dict(kind='bar', stacked=True, figsize=(10, 6), colormap='Set3', rot=45)
    ax = share.plot(**plot_options)
    ax.set_ylabel('Percentage (%)')
    ax.set_title(f'Cell Type Distribution by {group_col}')
    plt.tight_layout()
    return ax

## Read

In [None]:
# Load the AnnData object that the cells below read from. The commented-out
# paths point to other .h5ad files.
adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
# adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPSAM_annotated_v2.h5ad"
adata = ad.read_h5ad(adata_path)
adata

In [None]:
# Moved this elsewhere already!

# adata.obs['Group'] = adata.obs['Group'].fillna("unknown")
# msn_dtypes = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
# adata.obs.loc[(adata.obs['Subclass'].isin(msn_dtypes)) & (adata.obs['brain_region'] == "GP"), "Subclass"] = "unknown"
# adata.obs.loc[(adata.obs['Subclass'].isin(msn_dtypes)) & (adata.obs['brain_region'] == "GP"), "Group"] = "unknown"

# color_palette_path = "/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx"
# subclass_color_palette = pd.read_excel(color_palette_path, index_col=0, sheet_name="Subclass").to_dict()['Hex']
# group_color_palette = pd.read_excel(color_palette_path, index_col=0, sheet_name="Group").to_dict()['Hex']
# msn_palette = pd.read_excel(color_palette_path, index_col=0, sheet_name="MSN").to_dict()['Hex']

# neuron_type_palette = {
#     "Neuron": "#1f77b4",
#     "Nonneuron": "#ff7f0e",
#     "unknown": "#808080",
# }

# adata.uns['Subclass_palette'] = subclass_color_palette
# adata.uns['Group_palette'] = group_color_palette
# adata.uns['MSN_Groups_palette'] = msn_palette
# adata.uns['Neuron_type_palette'] = neuron_type_palette

# add_colors(adata, 'Subclass', subclass_color_palette)
# add_colors(adata, 'Group', group_color_palette)
# adata.obs['MSN_Groups'] = adata[adata.obs['Subclass'].isin(['STR D1 MSN', 'STR D2 MSN', 'STR Hybrid MSN', 'OT Granular GABA'])].obs['Group'].astype('category')
# add_colors(adata, 'MSN_Groups', msn_palette)
# add_colors(adata, 'neuron_type', neuron_type_palette)

# adata.obs['Subclass'] = adata.obs['Subclass'].cat.remove_unused_categories()
# adata.obs['Group'] = adata.obs['Group'].cat.remove_unused_categories()
# adata.obs['MSN_Groups'] = adata.obs['MSN_Groups'].cat.remove_unused_categories()
# adata.write_h5ad(adata_path)

## Plots

In [None]:
# Directory the figures below are saved to; created (with parents) if missing.
image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/CPS/composition")
image_path.mkdir(parents=True, exist_ok=True)

In [None]:
# _composition_by_class(adata, class_col='Group', group_col='donor', show=True)
# _composition_by_class(adata, class_col='Subclass', group_col='donor', show=True)

In [None]:
adata.obs['neuron_type'] = adata.obs['neuron_type'].cat.remove_unused_categories()

# Cell counts per (region, neuron type). The count column is named explicitly
# instead of being left as the integer label 0.
region_neuron_composition = (
    adata.obs.groupby(['brain_region_corr', 'neuron_type'])
    .size()
    .reset_index(name='count')
)

fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=300)
sns.barplot(
    data=region_neuron_composition, x='neuron_type', y='count', hue='brain_region_corr',
    ax=ax, palette=adata.uns['brain_region_corr_palette'],
)
# Write the count on top of every bar (one container per hue level).
for container in ax.containers:
    ax.bar_label(container, fontsize=6, padding=2)
ax.set_xlabel("Neuron Type")
ax.set_ylabel("Cell Count")
ax.set_title("Regional Neuronal Composition")
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')

fig.savefig(image_path / "regional_neuron_composition.png", dpi=300, bbox_inches="tight")
fig.savefig(image_path / "regional_neuron_composition.pdf", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# plot_regional_composition_stacked(adata, region_col="brain_region", subclass_col='Subclass', title = "Regional Subclass Composition")
# plot_regional_composition_stacked(adata, region_col="brain_region", subclass_col='neuron_type', title = "Regional Neuron Type Composition", palette=neuron_type_palette)

In [None]:
# Subclass composition of the "Nonneuron" cells per `brain_region_corr`.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['neuron_type'] == "Nonneuron"].obs,
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    colors=adata.uns['Subclass_palette'],
    title='Nonneuronal Subclass Cell Type Distribution',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"nn_composition_subclass.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Group composition of the "Nonneuron" cells per `brain_region_corr`.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['neuron_type'] == "Nonneuron"].obs,
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['Group_palette'],
    title='Nonneuronal Group Cell Type Distribution',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"nn_composition_group.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Subclass composition of the "Neuron" cells per `brain_region_corr`; only
# subclasses reaching 2 % in some region are listed in the legend.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['neuron_type'] == "Neuron"].obs,
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    colors=adata.uns['Subclass_palette'],
    title='Neuronal Subclass Cell Type Distribution',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=2.0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"neu_composition_subclass.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Group composition of the "Neuron" cells per `brain_region_corr`.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['neuron_type'] == "Neuron"].obs,
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['Group_palette'],
    title='Neuronal Group Cell Type Distribution',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"neu_composition_group.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# plot_regional_composition_stacked(adata[adata.obs['neuron_type'] == 'Nonneuron'], region_col="brain_region", subclass_col='Subclass', title = "Regional Nonneuron Subclass Composition")
# plot_regional_composition_stacked(adata[adata.obs['neuron_type'] == 'Nonneuron'], region_col="brain_region", subclass_col='Group', title = "Regional Nonneuron Group Composition",)
# plot_regional_composition_stacked(adata[adata.obs['neuron_type'] == 'Neuron'], region_col="brain_region", subclass_col='Subclass', title = "Regional Neuron Subclass Composition")
# plot_regional_composition_stacked(adata[adata.obs['neuron_type'] == 'Neuron'], region_col="brain_region", subclass_col='Group', title = "Regional Neuron Group Composition")

In [None]:
# plot_annot_table(adata, group_col='brain_region', class_col='Subclass', palette=adata.uns['Subclass_palette'], plot_as_percent=False)
# plot_annot_table(adata, group_col='brain_region', class_col='Group', palette=adata.uns['Group_palette'], plot_as_percent=False)
# plot_annot_table(adata, group_col='brain_region', class_col='Subclass', palette=adata.uns['Subclass_palette'], plot_as_percent=True)
# plot_annot_table(adata, group_col='brain_region', class_col='Group', palette=adata.uns['Group_palette'], plot_as_percent=True)

In [None]:
# Group composition per `brain_region_corr`, restricted to cells that have a
# value in `MSN_Groups`. The figure is saved but not displayed.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['MSN_Groups'].notna()].obs,
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['MSN_Groups_palette'],
    title='MSN Cell Type Composition',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"msn_composition.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

In [None]:
# Show the available .obs column names.
adata.obs.columns

In [None]:
# Group composition per `brain_region_corr`, restricted to cells that have a
# value in `IT_Group`. The figure is saved but not displayed.
# NOTE(review): `IT_Group` is only added to adata.obs in a later cell of this
# notebook, so this cell depends on that cell having been run first.
fig, ax = create_stacked_bar_chart(
    adata[adata.obs['IT_Group'].notna()].obs,
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['Group_palette'],
    title='IT Cell Type Composition',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=0,
    text_threshold=5.0,
    legend_fontsize=14,
    def_fontsize=14,
    title_fontsize=24,
)
for _ext in ("png", "pdf"):
    plt.savefig(image_path / f"it_composition.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close()

### Define Types

In [None]:
# Number of cells per Group label.
adata.obs['Group'].value_counts()

In [None]:
# Group labels used to subset `adata` for the per-region count plot below.
# NOTE(review): the list is named "interneuron" but also holds Dopa and Glut
# labels -- confirm that this is intended.
interneuron_cell_types = [
    "STR TAC3-PLPP4 GABA",
    "STR FS PTHLH-PVALB GABA", 
    "VIP GABA",
    "SN SOX6 Dopa",
    "SN-VTR-HTH GATA3-TCF7L2 GABA",
    "STR SST-CHODL GABA",
    "VTR-HTH Glut",
    "LAMP5-CXCL14 GABA",
    "STH PVALB-PITX2 Glut",
    "AMY-SLEA-BNST GABA",
    "STR SST-RSPO2 GABA",
    "ZI-HTH GABA",
    "STRd Cholinergic GABA",
    "STR LYPD6-RSPO2 GABA",
    "BF SKOR1 Glut",
    "LAMP5-LHX6 GABA",
    "OB FRMD7 GABA",
    "STR-BF TAC3-PLPP4-LHX8 GABA",
    "SN EBF2 GABA",
    "SN SEMA5A GABA",
    "SN-VTR GAD2 Dopa",
    "STR SST-ADARB2 GABA",
    "SN-VTR CALB1 Dopa",
    "SN GATA3-PVALB GABA",
    "STR Cholinergic GABA"
]

In [None]:
# Take an explicit copy: assigning to `.obs` of an AnnData view would
# otherwise turn the view into a copy implicitly (with a warning).
adata_G = adata[adata.obs['Group'].isin(interneuron_cell_types)].copy()
adata_G.obs['Group'] = adata_G.obs['Group'].cat.remove_unused_categories()

In [None]:
# Explicit copy: assigning to `.obs` of an AnnData view would otherwise turn
# the view into a copy implicitly (with a warning).
adata_G = adata[adata.obs['Group'].isin(interneuron_cell_types)].copy()
adata_G.obs['Group'] = adata_G.obs['Group'].cat.remove_unused_categories()

# Cell counts per (region, Group), with an explicitly named count column.
region_neuron_composition = (
    adata_G.obs.groupby(['brain_region', 'Group'])
    .size()
    .reset_index(name='count')
)

fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=300)
sns.barplot(
    data=region_neuron_composition, x='Group', y='count', hue='brain_region',
    ax=ax, palette=adata.uns['brain_region_palette'],
)
# Bar labels are switched off here; to enable:
# for container in ax.containers: ax.bar_label(container, fontsize=6, padding=2)
ax.set_xlabel("IT Type")
ax.set_ylabel("Cell Count")
ax.set_title("Regional Neuronal Composition")
# Restyle the existing tick labels in place rather than re-setting them.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
plt.show()

In [None]:
# Group labels that define the `IT_Group` subset used in the cells below.
IT_types = [
    'LAMP5-CXCL14 GABA',
    'LAMP5-LHX6 GABA',
    'OB FRMD7 GABA',
    'STR FS PTHLH-PVALB GABA',
    'STR LYPD6-RSPO2 GABA',
    'STR SST-CHODL GABA',
    'STR SST-RSPO2 GABA',
    'STR TAC3-PLPP4 GABA',
    'STR-BF TAC3-PLPP4-LHX8 GABA',
    'STRd Cholinergic GABA',
    'VIP GABA',
]

# Print the labels in sorted order.
print(sorted(IT_types))

In [None]:
# Explicit copy: assigning to `.obs` of an AnnData view would otherwise turn
# the view into a copy implicitly (with a warning).
adata_G = adata[adata.obs['Group'].isin(IT_types)].copy()
adata_G.obs['Group'] = adata_G.obs['Group'].cat.remove_unused_categories()

# Cell counts per (region, Group), with an explicitly named count column.
region_neuron_composition = (
    adata_G.obs.groupby(['brain_region', 'Group'])
    .size()
    .reset_index(name='count')
)

fig, ax = plt.subplots(1, 1, figsize=(8, 4), dpi=300)
sns.barplot(
    data=region_neuron_composition, x='Group', y='count', hue='brain_region',
    ax=ax, palette=adata.uns['brain_region_palette'],
)
# Bar labels are switched off here; to enable:
# for container in ax.containers: ax.bar_label(container, fontsize=6, padding=2)
ax.set_xlabel("IT Type")
ax.set_ylabel("Cell Count")
ax.set_title("Regional Neuronal Composition")
# Restyle the existing tick labels in place rather than re-setting them.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=6)
ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
plt.show()

In [None]:
# Group label for cells whose Group is in IT_types, missing for all other
# cells. A boolean mask is used so that no AnnData view has to be created.
_is_it_type = adata.obs['Group'].isin(IT_types)
adata.obs['IT_Group'] = (
    adata.obs['Group']
    .where(_is_it_type)
    .astype('category')
    .cat.remove_unused_categories()
)
# Use the palette stored on the object: `group_color_palette` is only created
# in a later cell, so referencing it here fails on Restart & Run All.
add_colors(adata, 'IT_Group', adata.uns['Group_palette'])

In [None]:
# plot_categorical(
#     adata[adata.obs['dataset_id'] == "PU_UCI5224_salk"], cluster_col = 'IT_Group', coord_base="spatial"
# )

# Palette

In [None]:
# Label -> hex colour mappings from the "Subclass" and "Group" sheets of the
# shared palette workbook (both sheets fetched with one read).
wubing_palette_path = "/anvil/projects/x-mcb130189/Wubin/BG/metadata/BG_color_palette.xlsx"
_wubing_sheets = pd.read_excel(wubing_palette_path, index_col=0, sheet_name=["Subclass", "Group"])
wubing_subclass_palette = _wubing_sheets["Subclass"]["Hex"].to_dict()
wubing_group_palette = _wubing_sheets["Group"]["Hex"].to_dict()

In [None]:
# Label -> hex colour mappings from the "Subclass" and "Group" sheets of the
# project colour scheme workbook (both sheets fetched with one read).
color_palette_path = "/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx"
_scheme_sheets = pd.read_excel(color_palette_path, index_col=0, sheet_name=["Subclass", "Group"])
subclass_color_palette = _scheme_sheets["Subclass"]["Hex"].to_dict()
group_color_palette = _scheme_sheets["Group"]["Hex"].to_dict()

In [None]:
# Subclass labels present in the data but absent from each palette.
_subclass_labels = set(adata.obs['Subclass'].unique())
print(_subclass_labels - set(wubing_subclass_palette))
print(_subclass_labels - set(subclass_color_palette))

In [None]:
# Merge the two Subclass palettes: the shared workbook wins, the project
# colour scheme is the fallback, and anything unmapped becomes '#DDDDDD'.
subclass_final_palette = {
    k: wubing_subclass_palette.get(k, subclass_color_palette.get(k, '#DDDDDD'))
    for k in adata.obs['Subclass'].unique()
}

In [None]:
# Group labels present in the data but absent from each palette.
_group_labels = set(adata.obs['Group'].unique())
print(_group_labels - set(wubing_group_palette))
print(_group_labels - set(group_color_palette))

In [None]:
# Merge the two Group palettes: the shared workbook wins, the project colour
# scheme is the fallback, and anything unmapped becomes '#DDDDDD'.
group_final_palette = {
    k: wubing_group_palette.get(k, group_color_palette.get(k, '#DDDDDD'))
    for k in adata.obs['Group'].unique()
}

In [None]:
# Replace the palettes stored on the object with the merged ones (in memory
# only; nothing is written to disk in this cell).
adata.uns['Group_palette'] = group_final_palette
adata.uns['Subclass_palette'] = subclass_final_palette

In [None]:
# Scratch list of hex colour codes. These are bare string expressions: they
# assign nothing, and only the last one is echoed as the cell output.
'#FFD8FF'
'#FF85FF'
'#CD73FF'
'#991795'
'#407879'
'#797ED6'

In [None]:
import matplotlib.colors as mcolors

In [None]:
# Hex colour for each label of the MSN palette; the trailing comments list
# other hex codes noted next to some entries.
MSN_palette = {
    "STRd D1 Matrix MSN" : "#3B9BFF", 
    "STRd D2 Matrix MSN" : "#2E6BB8", # '#2E6BB8', #1E4D7A
    "STRd D2 Striosome MSN" : "#FF3B3B", 
    "STRd D1 Striosome MSN" : "#C93535", # '#8B2020
    "STRv D1 MSN" : "#3BFF6B", 
    "STRv D2 MSN" : "#2EB854", # '#2EB854', #1F7A3A
    "STR D1D2 Hybrid MSN" : "#8530C2",     
    "STRv D1 NUDAP MSN" : "#FFB03B", 
    "STRd D2 StrioMat Hybrid MSN" : "#FF6B1A",
    "OT D1 ICj" : "#C98830", 
}

In [None]:
# Print the merged-palette colour of six MSN Group labels, one per line.
for _label in (
    'STRd D1 Matrix MSN',
    'STRd D2 Matrix MSN',
    'STRd D1 Striosome MSN',
    'STRd D2 Striosome MSN',
    'STRv D1 MSN',
    'STRv D2 MSN',
):
    print(group_final_palette[_label])
# group_final_palette['STRd D1 Matrix MSN']

In [None]:
# Workbook that the palette sheets are written to below (same path as the
# earlier assignment of this variable).
color_palette_path = "/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx"


In [None]:
# One single-column ('Hex') frame per palette, indexed by label, ready to be
# written to the workbook.
df_subclass_palette = pd.DataFrame.from_dict(subclass_color_palette, orient='index', columns=['Hex'])
df_group_palette = pd.DataFrame.from_dict(group_color_palette, orient='index', columns=['Hex'])
# `MSN_palette` is the dict defined earlier in this notebook; the lower-case
# `msn_palette` used before is never assigned and raised a NameError.
df_msn_palette = pd.DataFrame.from_dict(MSN_palette, orient='index', columns=['Hex'])

In [None]:
# Write each palette to its own sheet of the existing workbook, replacing a
# sheet of the same name. Tuples are (frame, sheet name, index header).
_palette_sheets = (
    (df_group_palette, "Group", 'Group'),
    (df_subclass_palette, "Subclass", 'Subclass'),
    (df_msn_palette, "MSN", 'Group'),
)
with pd.ExcelWriter(color_palette_path, engine='openpyxl', mode='a', if_sheet_exists='replace') as writer:
    for _frame, _sheet, _index_label in _palette_sheets:
        _frame.to_excel(writer, sheet_name=_sheet, index_label=_index_label)