In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

import matplotlib.pyplot as plt
import seaborn as sns
from PyComplexHeatmap import *

# Imported after the star import so the name cannot be shadowed by it.
from scipy.stats import entropy

## Functions

In [None]:
def _plot_overlap_heatmap(use_adata, ref_col, qry_col, image_path=None, current_datetime=None):
    """
    Plot, for every `qry_col` label, the fraction of its cells carrying each `ref_col` label.

    Fractions are normalised within each query label (count / total count of that
    query label). Rows are grouped by the reference label holding their largest
    fraction, the groups follow the order of the heatmap columns, and each row is
    annotated on the right as "<category code>: <label>".

    Parameters
    ----------
    use_adata : AnnData
        Object whose `.obs` holds both annotation columns.
    ref_col : str
        Column of `.obs` shown as heatmap columns.
    qry_col : str
        Column of `.obs` shown as heatmap rows.
    image_path : str or Path, optional
        When given, the figure is additionally saved to this path before it is
        shown. With the default (None) nothing is written.
    current_datetime : optional
        Accepted for backward compatibility; not used by this function.

    Returns
    -------
    None
    """
    obs = use_adata.obs

    # Long table of (query, reference) pair counts -> within-query fractions.
    vc = obs.loc[:, [qry_col, ref_col]].value_counts().reset_index()
    per_query_total = vc.groupby(qry_col)['count'].sum()
    vc['N'] = vc[qry_col].map(per_query_total).astype(int)
    vc['fraction'] = vc['count'] / vc['N']
    data = vc.pivot(index=qry_col, columns=ref_col, values='fraction')

    # Assign each query row to the reference column with the highest fraction.
    df_rows = data.index.to_series().to_frame()
    cols = data.columns.tolist()
    best_col_pos = np.argmax(data.fillna(0).values, axis=1)
    df_rows["GROUP"] = [cols[i] for i in best_col_pos]

    # Reorder rows so that groups appear in the same order as the columns.
    use_rows = [
        row
        for col in cols
        for row in df_rows.loc[df_rows['GROUP'] == col, qry_col].tolist()
    ]
    df_rows = df_rows.loc[use_rows]

    # Numeric code of every query category. `astype('category')` is a no-op for
    # columns that are already categorical and lets plain string columns work too.
    qry_categories = obs[qry_col].astype('category').cat.categories
    ct2code = {category: code for code, category in enumerate(qry_categories)}
    df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[qry_col].tolist()]

    # Plot
    row_ha = HeatmapAnnotation(
        label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
        axis=0, orientation='right',
    )

    plt.figure(figsize=(24, 12))
    ClusterMapPlotter(
        data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
        right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
        row_split_order=df_rows['GROUP'].unique().tolist(),
        show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
        xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
        yticklabels_kws=dict(labelcolor='red', labelsize=10),
        annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
        label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
        xlabel=ref_col, ylabel=qry_col,
        xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
        ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # labelpad is in points
    )
    if image_path is not None:
        # Save before show(): some backends clear the figure once it is displayed.
        plt.savefig(image_path, bbox_inches='tight')
    plt.show()
    plt.close()


def plot_regional_composition_stacked(adata, region_col='brain_region', subclass_col='RNA.Subclass', 
                                     palette=None, figsize=(10, 6), dpi=300, 
                                     title="Regional Subclass Composition", show_percentages=True):
    """
    Create a stacked barplot showing the cumulative distribution of subclasses across brain regions.
    
    Parameters:
    -----------
    adata : AnnData
        Annotated data object containing observations
    region_col : str, default 'brain_region'
        Column name for brain regions
    subclass_col : str, default 'RNA.Subclass'
        Column name for subclass annotations
    palette : dict or DataFrame, optional
        Color palette for subclasses, either a mapping ``label -> color`` or a
        DataFrame indexed by label with a ``'Hex'`` column. If None, will try to
        use ``adata.uns['AIT_subclass_palette']`` and otherwise fall back to
        colors sampled from matplotlib's ``tab20`` colormap. Labels missing from
        the palette are drawn in grey (``'#888888'``).
    figsize : tuple, default (10, 6)
        Figure size (width, height)
    dpi : int, default 300
        Figure resolution
    title : str, default "Regional Subclass Composition"
        Plot title
    show_percentages : bool, default True
        Whether to show percentages instead of raw counts
    
    Returns:
    --------
    fig, ax : matplotlib figure and axes objects
    """
    
    # Count cells per (region, subclass). observed=False keeps the full category
    # grid for categorical columns (the pandas default the original relied on).
    composition_data = (
        adata.obs.groupby([region_col, subclass_col], observed=False)
        .size()
        .to_frame(name="count")
        .reset_index()
    )
    
    # Convert to percentage if requested
    if show_percentages:
        # Per-row region total, computed once instead of a Python call per row.
        region_totals = composition_data.groupby(region_col, observed=False)['count'].transform('sum')
        composition_data['percentage'] = (composition_data['count'] / region_totals) * 100
        value_col = 'percentage'
        ylabel = 'Percentage (%)'
    else:
        value_col = 'count'
        ylabel = 'Cell Count'
    
    # Pivot for stacked plotting (regions with no cells yield NaN -> 0)
    pivot_data = composition_data.pivot(index=region_col, columns=subclass_col, values=value_col).fillna(0)
    
    # Set up color palette
    if palette is None:
        # Try to get palette from adata.uns
        if hasattr(adata, 'uns') and 'AIT_subclass_palette' in adata.uns:
            palette = adata.uns['AIT_subclass_palette']
        else:
            # Generate a default palette
            import matplotlib.cm as cm
            n_colors = len(pivot_data.columns)
            palette = {cat: cm.tab20(i/n_colors) for i, cat in enumerate(pivot_data.columns)}
    if isinstance(palette, pd.DataFrame) and 'Hex' in palette.columns:
        # Tabular palette: DataFrame.get() would look up *columns*, not labels,
        # so flatten it to a label -> hex mapping first.
        palette = palette['Hex'].to_dict()
    
    # Create the plot
    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    
    # Stack one bar segment per subclass on top of the running total.
    bottom = np.zeros(len(pivot_data))
    for subclass in pivot_data.columns:
        heights = pivot_data[subclass].to_numpy()
        ax.bar(pivot_data.index, heights, bottom=bottom,
               label=subclass, color=palette.get(subclass, '#888888'),
               edgecolor='white', linewidth=0.5)
        # Rebind instead of += so the array handed to ax.bar is never mutated.
        bottom = bottom + heights
    
    # Formatting
    ax.set_xlabel(region_col.replace('_', ' ').title())
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1, 1), loc='upper left')
    
    # Rotate x-axis labels if needed
    if len(pivot_data.index) > 6:
        ax.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    
    return fig, ax

## Paths

In [None]:
# Directories holding one *.h5ad per annotated group, for RNA and methylation (MC).
rna_group_dir, mc_group_dir = (
    Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL_RNA"),
    Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL_MC"),
)

# Combined neuronal / non-neuronal objects (kept as plain strings).
neu_path, nn_path = (
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_pfv8_neu.h5ad",
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_pfv8_nn.h5ad",
)

### RNA

In [None]:
# Per-file RNA annotation at the `ref_col` level.
# For every *.h5ad in rna_group_dir three candidate labels are derived per cell:
#   eql_col      - the cell-level label when it equals the cluster-level label
#   keeper_clust - the cluster-level label, only for clusters whose mapped
#                  proportion is > 0.8 ("keeper" clusters)
#   ff           - the cell-level label, only when its transfer score is > 0.75
# They are then reconciled into one final label (see the rule table below) and
# the per-file result is appended to df_res.
ref_col = "Group"
qry_col = "subclass_leiden"
cluster_col = "integrated_leiden"
df_res = []
for rel_path in sorted(rna_group_dir.glob("*.h5ad")):  # sorted -> reproducible file order
    print(rel_path)
    group_adata = ad.read_h5ad(rel_path)

    if "integrated_leiden_Group_map" in group_adata.uns.keys():
        # Cluster -> label table shipped with the file; keep only rows backed by
        # more than 5 cells on both sides and a majority (> 0.5) proportion.
        df_map = group_adata.uns['integrated_leiden_Group_map']
        df_map = df_map.query(f"{ref_col}_cell_count > 5 & {qry_col}_cell_count > 5")
        df_map = df_map.query(f"{ref_col}_proportion > 0.5")
        group_adata.obs[f'RNA.{ref_col}.ct'] = (
            group_adata.obs[cluster_col].astype(str)
            .map(df_map[ref_col].to_dict())
            .fillna("unknown")
            .astype('category')
        )
        keeper_clusters = df_map.query(f"{ref_col}_proportion>0.8").index.tolist()
    else:
        group_adata.obs[f'RNA.{ref_col}.ct'] = group_adata.obs[f'c2c_allcools_label_{ref_col}'].copy()
        # No cluster map for this file, hence no keeper clusters. Without this reset
        # the name was undefined (first file) or silently reused the previous file's
        # cluster ids.
        keeper_clusters = []

    # Cell-level label; calls with a transfer score below 0.5 become "unknown".
    cell_level = group_adata.obs[f'allcools_{ref_col}'].astype("category").cat.remove_unused_categories()
    if "unknown" not in cell_level.cat.categories:  # add_categories raises on duplicates
        cell_level = cell_level.cat.add_categories("unknown")
    group_adata.obs[f'RNA.{ref_col}'] = cell_level
    group_adata.obs[f'RNA.{ref_col}.Prob'] = group_adata.obs[f'allcools_{ref_col}_transfer_score'].copy()
    group_adata.obs.loc[group_adata.obs[f'RNA.{ref_col}.Prob'] < 0.5, f'RNA.{ref_col}'] = "unknown"

    df_ct = group_adata.obs[[f'RNA.{ref_col}', f'RNA.{ref_col}.Prob', f'RNA.{ref_col}.ct', cluster_col]].copy()

    # Candidate labels; "U" marks "this line of evidence gives no call".
    cell_label = df_ct[f'RNA.{ref_col}'].astype(object)
    cluster_label = df_ct[f'RNA.{ref_col}.ct'].astype(object)
    df_ct['eql_col'] = cell_label.where(cell_label == cluster_label, "U")
    df_ct['keeper_clust'] = cluster_label.where(df_ct[cluster_col].isin(keeper_clusters), "U")
    df_ct['ff'] = cell_label.where(df_ct[f'RNA.{ref_col}.Prob'] > 0.75, "U")

    # Reconcile (first matching rule wins, same order as the former per-cell if/elif):
    #   1. neither ff nor keeper has a call -> eql_col
    #   2. ff and keeper agree             -> that label
    #   3. only keeper has a call          -> keeper
    #   4. only ff has a call              -> ff
    #   5. both have a call but disagree   -> "unknown"
    ff = df_ct['ff']
    keeper = df_ct['keeper_clust']
    ff_unset = ff.eq("U")
    keeper_unset = keeper.eq("U")
    final = np.select(
        [
            (ff_unset & keeper_unset).to_numpy(),
            ff.eq(keeper).to_numpy(),
            ff_unset.to_numpy(),
            keeper_unset.to_numpy(),
        ],
        [
            df_ct['eql_col'].to_numpy(dtype=object),
            ff.to_numpy(dtype=object),
            keeper.to_numpy(dtype=object),
            ff.to_numpy(dtype=object),
        ],
        default="unknown",
    )
    df_ct[f'final_{ref_col}'] = pd.Series(final, index=df_ct.index, dtype=object).replace("U", "unknown")

    # Cross-tabulation of the candidates vs. the final call, kept for interactive inspection.
    vc_ef = df_ct[['eql_col', 'ff', 'keeper_clust', f'final_{ref_col}']].value_counts().reset_index()

    print("Finall Annotations for %s: " % rel_path.name)
    group_adata.obs[f'RNA.{ref_col}'] = df_ct[f'final_{ref_col}'].astype("category").cat.remove_unused_categories().copy()
    group_adata.obs[f'RNA.{ref_col}.Prob'] = df_ct[f'RNA.{ref_col}.Prob'].copy()
    display(group_adata.obs[f'RNA.{ref_col}'].value_counts())
    df_res.append(group_adata.obs[[f'RNA.{ref_col}', f'RNA.{ref_col}.Prob', f'allcools_{ref_col}', f'allcools_{ref_col}_transfer_score']].copy())

In [None]:
# Stack the per-file tables, then rebuild the label columns as categoricals that
# only list the categories actually present.
add_obs = pd.concat(df_res, axis=0)
for _label_col in ('RNA.Group', 'allcools_Group'):
    add_obs[_label_col] = add_obs[_label_col].astype("category").cat.remove_unused_categories()
add_obs

In [None]:
### print Entropy:
# For every raw allcools label: Shannon entropy of the histogram of its cells'
# transfer scores (scores rounded to 3 decimals before counting).
subclass_entropies = {}
for _class in add_obs[f'allcools_{ref_col}'].cat.categories:
    probs = add_obs.loc[add_obs[f'allcools_{ref_col}'] == _class, f'allcools_{ref_col}_transfer_score']
    subclass_entropies[_class] = entropy(probs.round(3).value_counts().sort_index())
    print(f"Class: {_class}, Entropy: {subclass_entropies[_class]}")

df_rna_entropy = (
    pd.Series(subclass_entropies, name='Entropy', dtype=float)
    .sort_values(ascending=False)
    .to_frame()
)
fig, ax = plt.subplots(figsize=(6,4), dpi=200)
sns.barplot(data=df_rna_entropy, x=df_rna_entropy.index, y='Entropy', ax=ax, edgecolor='black', linewidth=0.5, color='coral')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right', fontsize=4)
# Labels name what is actually plotted (ref_col-level allcools calls, not subclasses).
ax.set_title(f"Entropy of allcools {ref_col} transfer scores (RNA)")
ax.set_xlabel(f"allcools {ref_col}")
plt.tight_layout()
plt.show()

In [None]:
### print Entropy:
# For every final RNA label: Shannon entropy of the histogram of its cells'
# assignment probabilities (rounded to 3 decimals before counting).
subclass_entropies = {}
for _class in add_obs[f'RNA.{ref_col}'].cat.categories:
    probs = add_obs.loc[add_obs[f'RNA.{ref_col}'] == _class, f'RNA.{ref_col}.Prob']
    subclass_entropies[_class] = entropy(probs.round(3).value_counts().sort_index())
    print(f"Class: {_class}, Entropy: {subclass_entropies[_class]}")

df_rna_entropy = (
    pd.Series(subclass_entropies, name='Entropy', dtype=float)
    .sort_values(ascending=False)
    .to_frame()
)
fig, ax = plt.subplots(figsize=(6,4), dpi=200)
sns.barplot(data=df_rna_entropy, x=df_rna_entropy.index, y='Entropy', ax=ax, edgecolor='black', linewidth=0.5, color='coral')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right', fontsize=4)
# Labels name what is actually plotted (final ref_col-level RNA calls).
ax.set_title(f"Entropy of final RNA {ref_col} assignment probabilities")
ax.set_xlabel(f"RNA {ref_col}")
plt.tight_layout()
plt.show()

In [None]:
# Independent snapshot of the RNA table; `add_obs` is reassigned by the methylation cells.
add_obs_rna = add_obs.copy(deep=True)

In [None]:
# Number of cells per final RNA group label.
add_obs_rna.loc[:, 'RNA.Group'].value_counts()

### Methylation 

In [None]:
# Per-file methylation (MC) annotation at the `ref_col` level.
# For every *.h5ad in mc_group_dir three candidate labels are derived per cell:
#   eql_col      - the cell-level label when it equals the cluster-level label
#   keeper_clust - the cluster-level label, only for clusters whose mapped
#                  proportion is > 0.8 ("keeper" clusters)
#   ff           - the cell-level label, only when its probability is > 0.75
# They are then reconciled into one final label (see the rule table below) and
# the per-file result is appended to df_res.
ref_col = "Group"
qry_col = "subclass_leiden"
cluster_col = "integrated_leiden"
df_res = []
for rel_path in sorted(mc_group_dir.glob("*.h5ad")):  # sorted -> reproducible file order
    print(rel_path)
    group_adata = ad.read_h5ad(rel_path)
    # Companion table: file name minus its last "_"-separated token, plus "_cell_mapping.tsv".
    ct = "_".join(rel_path.name.split("_")[:-1])
    ct_path = rel_path.parent / f"{ct}_cell_mapping.tsv"
    if ct_path.exists():
        group_adata = group_adata[group_adata.obs['Modality'] == "MERSCOPE"].copy()
        df_map = pd.read_csv(ct_path, sep='\t', header=0, index_col=0)

        # Keep only rows backed by more than 5 cells on both sides and a majority
        # (> 0.5) proportion.
        df_map = df_map.query(f"{ref_col}_cell_count > 5 & {qry_col}_cell_count > 5")
        df_map = df_map.query(f"{ref_col}_proportion > 0.5")
        group_adata.obs[f'MC.{ref_col}.ct'] = (
            group_adata.obs[cluster_col].astype(str)
            .map(df_map[ref_col].to_dict())
            .fillna("unknown")
            .astype('category')
        )
        keeper_clusters = df_map.query(f"{ref_col}_proportion>0.8").index.tolist()
    else:
        group_adata.obs[f'MC.{ref_col}.ct'] = group_adata.obs[f'infer_{ref_col}_c2c'].astype("category").cat.remove_unused_categories()
        group_adata.obs[f'infer_{ref_col}_prob'] = 1.0
        # No mapping table for this file, hence no keeper clusters. Without this reset
        # the name was undefined (first file) or silently reused the previous file's
        # cluster ids.
        keeper_clusters = []

    # Cell-level label; calls with a probability below 0.6 become "unknown".
    cell_level = group_adata.obs[f'infer_{ref_col}'].astype("category").cat.remove_unused_categories()
    if "unknown" not in cell_level.cat.categories:  # add_categories raises on duplicates
        cell_level = cell_level.cat.add_categories("unknown")
    group_adata.obs[f'MC.{ref_col}'] = cell_level
    group_adata.obs[f'MC.{ref_col}.Prob'] = group_adata.obs[f'infer_{ref_col}_prob'].copy()
    group_adata.obs.loc[group_adata.obs[f'MC.{ref_col}.Prob'] < 0.6, f'MC.{ref_col}'] = "unknown"

    df_ct = group_adata.obs[[f'MC.{ref_col}', f'MC.{ref_col}.Prob', f'MC.{ref_col}.ct', cluster_col]].copy()

    # Candidate labels; "U" marks "this line of evidence gives no call".
    cell_label = df_ct[f'MC.{ref_col}'].astype(object)
    cluster_label = df_ct[f'MC.{ref_col}.ct'].astype(object)
    df_ct['eql_col'] = cell_label.where(cell_label == cluster_label, "U")
    df_ct['keeper_clust'] = cluster_label.where(df_ct[cluster_col].isin(keeper_clusters), "U")
    df_ct['ff'] = cell_label.where(df_ct[f'MC.{ref_col}.Prob'] > 0.75, "U")

    # Reconcile (first matching rule wins, same order as the former per-cell if/elif):
    #   1. neither ff nor keeper has a call -> eql_col
    #   2. ff and keeper agree             -> that label
    #   3. only keeper has a call          -> keeper
    #   4. only ff has a call              -> ff
    #   5. both have a call but disagree   -> "unknown"
    ff = df_ct['ff']
    keeper = df_ct['keeper_clust']
    ff_unset = ff.eq("U")
    keeper_unset = keeper.eq("U")
    final = np.select(
        [
            (ff_unset & keeper_unset).to_numpy(),
            ff.eq(keeper).to_numpy(),
            ff_unset.to_numpy(),
            keeper_unset.to_numpy(),
        ],
        [
            df_ct['eql_col'].to_numpy(dtype=object),
            ff.to_numpy(dtype=object),
            keeper.to_numpy(dtype=object),
            ff.to_numpy(dtype=object),
        ],
        default="unknown",
    )
    df_ct[f'final_{ref_col}'] = pd.Series(final, index=df_ct.index, dtype=object).replace("U", "unknown")

    # Cross-tabulation of the candidates vs. the final call, kept for interactive inspection.
    vc_ef = df_ct[['eql_col', 'ff', 'keeper_clust', f'final_{ref_col}']].value_counts().reset_index()

    group_adata.obs[f'MC.{ref_col}'] = df_ct[f'final_{ref_col}'].astype("category").cat.remove_unused_categories().copy()
    group_adata.obs[f'MC.{ref_col}.Prob'] = df_ct[f'MC.{ref_col}.Prob'].copy()
    df_res.append(group_adata.obs[[f'MC.{ref_col}', f'MC.{ref_col}.Prob', f'infer_{ref_col}', f'infer_{ref_col}_prob']].copy())

In [None]:
# Stack the per-file tables, then rebuild the label columns as categoricals that
# only list the categories actually present.
add_obs = pd.concat(df_res, axis=0)
for _label_col in ('infer_Group', 'MC.Group'):
    add_obs[_label_col] = add_obs[_label_col].astype("category").cat.remove_unused_categories()
add_obs.loc[:, 'MC.Group'].value_counts()

In [None]:
### print Entropy:
# For every raw inferred MC label: Shannon entropy of the histogram of its cells'
# inference probabilities (rounded to 3 decimals before counting).
subclass_entropies = {}
for _class in add_obs[f'infer_{ref_col}'].cat.categories:
    probs = add_obs.loc[add_obs[f'infer_{ref_col}'] == _class, f'infer_{ref_col}_prob']
    subclass_entropies[_class] = entropy(probs.round(3).value_counts().sort_index())
    print(f"Class: {_class}, Entropy: {subclass_entropies[_class]}")

df_rna_entropy = (
    pd.Series(subclass_entropies, name='Entropy', dtype=float)
    .sort_values(ascending=False)
    .to_frame()
)
fig, ax = plt.subplots(figsize=(6,4), dpi=200)
sns.barplot(data=df_rna_entropy, x=df_rna_entropy.index, y='Entropy', ax=ax, edgecolor='black', linewidth=0.5, color='coral')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right', fontsize=4)
# This plot shows methylation (MC) data; the former "RNA Subclass" labels were a copy-paste leftover.
ax.set_title(f"Entropy of inferred MC {ref_col} probabilities")
ax.set_xlabel(f"Inferred MC {ref_col}")
plt.tight_layout()
plt.show()

In [None]:
### print Entropy:
# For every final MC label: Shannon entropy of the histogram of its cells'
# assignment probabilities (rounded to 3 decimals before counting).
subclass_entropies = {}
for _class in add_obs[f'MC.{ref_col}'].cat.categories:
    probs = add_obs.loc[add_obs[f'MC.{ref_col}'] == _class, f'MC.{ref_col}.Prob']
    subclass_entropies[_class] = entropy(probs.round(3).value_counts().sort_index())
    print(f"Class: {_class}, Entropy: {subclass_entropies[_class]}")

df_rna_entropy = (
    pd.Series(subclass_entropies, name='Entropy', dtype=float)
    .sort_values(ascending=False)
    .to_frame()
)
fig, ax = plt.subplots(figsize=(6,4), dpi=200)
sns.barplot(data=df_rna_entropy, x=df_rna_entropy.index, y='Entropy', ax=ax, edgecolor='black', linewidth=0.5, color='coral')
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right', fontsize=4)
# This plot shows methylation (MC) data; the former "RNA Subclass" labels were a copy-paste leftover.
ax.set_title(f"Entropy of final MC {ref_col} assignment probabilities")
ax.set_xlabel(f"MC {ref_col}")
plt.tight_layout()
plt.show()

In [None]:
# Independent snapshot of the methylation table; `add_obs` is reassigned again below.
add_obs_mc = add_obs.copy(deep=True)

## Add Obs

In [None]:
# Load the non-neuronal object first, then the neuronal one.
nn_adata, neu_adata = (ad.read_h5ad(_h5ad_path) for _h5ad_path in (nn_path, neu_path))

In [None]:
# Flag every annotated cell as neuronal when its id is present in neu_adata.
# One vectorised membership test per table replaces a Python call per row.
_neuron_ids = neu_adata.obs_names
for _annot_table in (add_obs_rna, add_obs_mc):
    _annot_table['Is_Neuron'] = np.where(_annot_table.index.isin(_neuron_ids), "Neuron", "Nonneuron")

In [None]:
# For each modality table, show the RNA.Subclass make-up of the cells in the
# neuronal / non-neuronal objects that the table does not cover.
for _modality, _annotated in (("RNA", add_obs_rna), ("MC", add_obs_mc)):
    print(_modality)
    for _kind, _source in (("Neuron", neu_adata), ("Nonneuron", nn_adata)):
        _covered_ids = _annotated.index[(_annotated['Is_Neuron'] == _kind).to_numpy()]
        rem_adata = _source[~_source.obs_names.isin(_covered_ids)]
        print(_kind)
        display(rem_adata.obs['RNA.Subclass'].value_counts())

In [None]:
# Back-fill cells that have no methylation annotation: for every RNA.Subclass among
# the uncovered cells, if all of its cells are present in the RNA table, add them
# to add_obs_mc labelled with that class and probability 1.0.
df_list = []
for _kind, _source in (("Neuron", neu_adata), ("Nonneuron", nn_adata)):
    _covered_ids = add_obs_mc.loc[add_obs_mc['Is_Neuron'] == _kind].index
    rem_adata = _source[~_source.obs_names.isin(_covered_ids)]
    for _class in rem_adata.obs['RNA.Subclass'].cat.categories:
        cell_ids = rem_adata.obs_names[rem_adata.obs['RNA.Subclass'] == _class].tolist()
        all_in_rna = pd.Series(cell_ids, dtype=object).isin(add_obs_rna.index).all()
        print(_class, len(cell_ids), all_in_rna)
        # Skip classes with no uncovered cells: an empty frame adds nothing and
        # can only disturb the column dtypes during concatenation.
        if cell_ids and all_in_rna:
            df_list.append(pd.DataFrame(index=cell_ids, data={"MC.Group": _class, "MC.Group.Prob": 1.0, "infer_Group": _class, "infer_Group_prob": 1.0, "Is_Neuron": _kind}))

# pd.concat raises on an empty list, which is what happens when nothing needs
# back-filling (for example when this cell is run a second time).
if df_list:
    add_obs_mc = pd.concat((add_obs_mc, pd.concat(df_list, axis=0)))

In [None]:
# Neuron / non-neuron counts per table, followed by the MC and RNA row counts.
(
    add_obs_rna["Is_Neuron"].value_counts(),
    add_obs_mc["Is_Neuron"].value_counts(),
    len(add_obs_mc),
    len(add_obs_rna),
)

In [None]:
# Side-by-side join of the RNA and MC tables on the cell index
# (both contribute an 'Is_Neuron' column).
add_obs = pd.concat([add_obs_rna, add_obs_mc], axis="columns")

In [None]:
# Share of rows whose RNA and MC group labels are equal
# (rows missing either label compare unequal).
_labels_match = add_obs['RNA.Group'] == add_obs['MC.Group']
_agreement_pct = _labels_match.mean() * 100
print("Total Agreement at group level between RNA and MC annotations: %.3f%%" % _agreement_pct)

In [None]:
# Stack the non-neuronal and neuronal objects into one AnnData. The origin of each
# cell is recorded in obs['neuron_type'] ("nn" / "neu"); index_unique=None leaves
# the cell ids unchanged.
# NOTE(review): AnnData.concatenate is documented as deprecated in favour of
# anndata.concat; the two differ in how var/uns are merged, so confirm before switching.
adata_all = nn_adata.concatenate(neu_adata, batch_key="neuron_type", batch_categories=["nn", "neu"], index_unique=None)

In [None]:
# Carry the colour lists and palettes of the non-neuronal object over to the merged one.
_palette_suffixes = ("_colors", "_palette")
for _uns_key in nn_adata.uns:
    if _uns_key.endswith(_palette_suffixes):
        adata_all.uns[_uns_key] = nn_adata.uns[_uns_key]

In [None]:
# Remove the helper flag (present once per source table). errors="ignore" keeps the
# cell re-runnable: without it a second run raises KeyError because the column is gone.
add_obs = add_obs.drop(columns=["Is_Neuron"], errors="ignore")

In [None]:
# Attach every annotation column to adata_all.obs (aligned on the cell index) and
# fill the gaps left for cells absent from add_obs: labels -> "unknown", scores -> 0.
for _col in add_obs.columns:
    _src_dtype = add_obs[_col].dtype
    print(_col, _src_dtype)
    adata_all.obs[_col] = add_obs[_col].copy()
    if _src_dtype.name == "category":
        # Keep the existing category order; "unknown" is appended only when missing.
        _labels = adata_all.obs[_col]
        if "unknown" not in _labels.cat.categories:
            _labels = _labels.cat.add_categories("unknown")
        adata_all.obs[_col] = _labels.fillna("unknown").astype("category").cat.remove_unused_categories()
    elif pd.api.types.is_object_dtype(_src_dtype) or pd.api.types.is_string_dtype(_src_dtype):
        adata_all.obs[_col] = adata_all.obs[_col].fillna("unknown").astype("category").cat.remove_unused_categories()
    elif pd.api.types.is_float_dtype(_src_dtype):
        # Any float width: previously only float64 was handled, so float32 score
        # columns kept their NaN.
        adata_all.obs[_col] = adata_all.obs[_col].fillna(0).astype("float32")

In [None]:
# Display the merged per-cell table.
adata_all.obs

### Overlaps

In [None]:
# Overlap heatmap in both directions: MC rows vs. RNA columns, then the reverse.
for _ref_name, _qry_name in (('RNA.Group', 'MC.Group'), ('MC.Group', 'RNA.Group')):
    _plot_overlap_heatmap(adata_all, ref_col=_ref_name, qry_col=_qry_name)

In [None]:
# Combined label: the RNA call wherever it is not "unknown", otherwise the MC call.
# Done column-wise; the former row-wise apply walked every column of obs for each
# cell, and the initial "Unassigned" placeholder was overwritten immediately.
_rna_group = adata_all.obs['RNA.Group'].astype(object)
_mc_group = adata_all.obs['MC.Group'].astype(object)
adata_all.obs['Combined.Group'] = (
    _rna_group.where(_rna_group != "unknown", _mc_group)
    .astype("category")
    .cat.remove_unused_categories()
)

In [None]:
# The 20 most frequent combined group labels.
adata_all.obs['Combined.Group'].value_counts().iloc[:20]

In [None]:
# Persist the merged, annotated object.
# NOTE(review): the *_colors entries are added to adata_all.uns in the cells after
# this one, so they are not part of this write unless the file is saved again.
final_path = "/anvil/projects/x-mcb130189/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v3.h5ad"
adata_all.write_h5ad(filename=final_path)

In [None]:
### Add the colors: 
def add_colors(adata, cat_col, palette):
    """
    Store one color per category of `adata.obs[cat_col]` in `adata.uns[f'{cat_col}_colors']`.

    Parameters
    ----------
    adata : object with `.obs` (DataFrame) and `.uns` (mapping)
        `obs[cat_col]` must be categorical; `uns` is modified in place.
    cat_col : str
        Name of the categorical column.
    palette : dict or DataFrame
        Either a mapping ``category -> color`` or a table indexed by category
        with a ``'Hex'`` column.

    Categories missing from the palette are printed and colored '#808080'.
    Returns None.
    """
    fallback_color = '#808080'

    def _color_for(category):
        # Lookup form depends on the palette type; a miss in either raises KeyError.
        try:
            if isinstance(palette, dict):
                return palette[category]
            return palette.loc[category, 'Hex']
        except KeyError:
            print(category)
            return fallback_color

    adata.uns[f'{cat_col}_colors'] = [
        _color_for(category) for category in adata.obs[cat_col].cat.categories
    ]

In [None]:
# (annotation column, uns key of the palette to color it with), applied in order.
_color_targets = (
    ("RNA.Subclass", "AIT_subclass_palette"),
    ("MC.Subclass", "AIT_subclass_palette"),
    ("Combined.Subclass", "AIT_subclass_palette"),
    ("RNA.Group", "AIT_group_palette"),
    ("MC.Group", "AIT_group_palette"),
    ("Combined.Group", "AIT_group_palette"),
)
for _cat_col, _palette_key in _color_targets:
    add_colors(adata_all, _cat_col, adata_all.uns[_palette_key])