In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

## Helper functions

In [None]:
from PyComplexHeatmap import * 
def _plot_overlap_heatmap(use_adata, ref_col, qry_col, image_path=None, current_datetime=None):
    """Heatmap of how the cells of each query label are distributed over the reference labels.

    Each row is one ``qry_col`` label and each column one ``ref_col`` label; a cell holds
    the fraction of that query label's cells that carry that reference label, so every
    row sums to 1. Rows are split into blocks by the reference label receiving their
    largest fraction, and blocks follow the reference column order, so a one-to-one
    correspondence between the two annotations shows up as a block diagonal.

    Parameters
    ----------
    use_adata : anndata.AnnData or pandas.DataFrame
        Cell table holding both label columns; for an ``AnnData`` its ``.obs`` is used.
        The input is never modified.
    ref_col : str
        Reference label column (heatmap columns, x axis).
    qry_col : str
        Query label column (heatmap rows, y axis). Row annotations are rendered as
        ``"<category code>: <label>"``, where the code is the label's position in the
        column's full category list.
    image_path, current_datetime : optional
        NOTE(review): accepted but currently unused -- the figure is only shown, not saved.

    Raises
    ------
    ValueError
        If no cell has both labels present.
    """
    source = use_adata.obs if isinstance(use_adata, ad.AnnData) else use_adata
    # Private copy of just the two label columns.
    pairs = source.loc[:, [qry_col, ref_col]].copy()

    for col in (qry_col, ref_col):
        # Category order drives heatmap order and the row codes; convert plain
        # (e.g. object) columns instead of failing on ``.cat``.
        if not isinstance(pairs[col].dtype, pd.CategoricalDtype):
            pairs[col] = pairs[col].astype("category")

    # Codes are positions in the full category list (what ``.cat.codes`` reports), taken
    # before pruning so labels stay stable whichever categories are present.
    ct2code = {label: code for code, label in enumerate(pairs[qry_col].cat.categories)}

    # DataFrame.value_counts enumerates every combination of categories, observed or not.
    # Categories with no cells here (e.g. left over after subsetting to one cell type)
    # would otherwise become all-NaN rows (0/0 fractions) and all-zero columns, so drop
    # incomplete pairs and prune unused categories first.
    pairs = pairs.dropna().copy()
    for col in (qry_col, ref_col):
        pairs[col] = pairs[col].cat.remove_unused_categories()
    if pairs.empty:
        raise ValueError(f"no cells with both {qry_col!r} and {ref_col!r} labels to plot")

    vc = pairs.value_counts().reset_index()
    # Total cells per query label, broadcast back onto every (query, reference) row.
    vc['N'] = vc.groupby(qry_col, observed=True)['count'].transform('sum')
    vc['fraction'] = vc['count'] / vc['N']
    data = vc.pivot(index=qry_col, columns=ref_col, values='fraction')

    cols = data.columns.tolist()
    df_rows = data.index.to_series().to_frame()
    # Assign each query label to the reference label holding its largest fraction.
    max_idx = np.argmax(data.fillna(0).values, axis=1)
    df_rows["GROUP"] = [cols[i] for i in max_idx]
    # Order rows block by block, following the reference column order.
    use_rows = [row for col in cols for row in df_rows.index[df_rows['GROUP'] == col]]
    df_rows = df_rows.loc[use_rows]
    df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[qry_col]]

    # Plot
    row_ha = HeatmapAnnotation(
        label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
        axis=0, orientation='right',
    )

    plt.figure(figsize=(24, 12))
    ClusterMapPlotter(
        data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
        right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
        row_split_order=df_rows['GROUP'].unique().tolist(),
        show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
        xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
        yticklabels_kws=dict(labelcolor='red', labelsize=10),
        annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
        label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
        xlabel=ref_col, ylabel=qry_col,
        xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
        ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # increase label distance via labelpad (points)
        # xlabel_bbox_kws=dict(facecolor='green'),
        # ylabel_bbox_kws=dict(facecolor='chocolate',edgecolor='red'),
        # standard_scale=0,
    )
    plt.show()
    plt.close()


## Load data

In [None]:
# Label -> hex colour lookups, one per annotation level, read from the shared colour scheme workbook.
COLOR_SCHEME_XLSX = "/home/x-aklein2/projects/aklein/BICAN/data/color_scheme.xlsx"
_color_sheets = pd.read_excel(COLOR_SCHEME_XLSX, sheet_name=["Subclass", "Group"], index_col=0)
subclass_color_palette = _color_sheets["Subclass"]["Hex"].to_dict()
group_color_palette = _color_sheets["Group"]["Hex"].to_dict()

In [None]:
data1 = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# data2 = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPSAM_annotated_v2.h5ad"
data2 = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_GP_PFV8_annotated.h5ad"


def _read_obs(h5ad_path):
    """Return an in-memory copy of ``.obs`` from an .h5ad file.

    The file is opened in backed, read-only mode so only the cell metadata is
    materialised, and the backing file handle is closed explicitly afterwards
    rather than being left to garbage collection.
    """
    adata = ad.read_h5ad(h5ad_path, backed='r')
    try:
        return adata.obs.copy()
    finally:
        adata.file.close()


data1_obs = _read_obs(data1)
data2_obs = _read_obs(data2)

In [None]:
# Side-by-side view of the obs columns available in each annotation file.
(data1_obs.columns, data2_obs.columns)

In [None]:
# data1_obs['cp_cell_id'] = data1_obs['original_cell_id'].astype(str) + "." + data1_obs['dataset_id'].astype(str)
# data1_obs.reset_index(names="pf_cell_id", inplace=True)
# data1_obs.set_index('cp_cell_id', inplace=True)

In [None]:
# Restrict the first table to the cells of the second (KeyError if any are absent),
# then keep the label columns with source-prefixed names.
data1_obs = data1_obs.loc[data2_obs.index].copy()
label_cols = ['Subclass', 'Group', 'brain_region']
pf_names = {'Subclass': 'pf.Subclass', 'Group': 'pf.Group'}
cp_names = {'Subclass': 'cp.Subclass', 'Group': 'cp.Group'}
d1 = data1_obs[label_cols].rename(columns=pf_names)
d2 = data2_obs[label_cols].rename(columns=cp_names)

In [None]:
# Keep only cells annotated as neurons in each table, then take the prefixed label columns.
data1_obs_neu = data1_obs[data1_obs['neuron_type'].eq('Neuron')]
data2_obs_neu = data2_obs[data2_obs['neuron_type'].eq('Neuron')]
neu_label_cols = ['Subclass', 'Group', 'brain_region']
d1 = data1_obs_neu.loc[:, neu_label_cols].rename(columns={'Subclass': 'pf.Subclass', 'Group': 'pf.Group'})
d2 = data2_obs_neu.loc[:, neu_label_cols].rename(columns={'Subclass': 'cp.Subclass', 'Group': 'cp.Group'})

In [None]:
# (rows, columns) of the two label tables about to be joined.
(d1.shape, d2.shape)

In [None]:
# Join the two label tables on cell id (inner join: only cells present in both are kept);
# d2 contributes only the columns d1 lacks.
d2_cols = d2.columns.difference(d1.columns)
d_tog = d1.merge(d2[d2_cols], left_index=True, right_index=True, how="inner")

# Give every label column an explicit "Missing" category so unlabeled cells show up
# under their own label rather than as NaN. Guarded so non-categorical columns and
# columns that already have a "Missing" category do not raise.
for col in d_tog.columns:
    if not isinstance(d_tog[col].dtype, pd.CategoricalDtype):
        d_tog[col] = d_tog[col].astype("category")
    if "Missing" not in d_tog[col].cat.categories:
        d_tog[col] = d_tog[col].cat.add_categories("Missing")
    d_tog[col] = d_tog[col].fillna("Missing")
d_tog.shape

In [None]:
# Hex colours in palette order, plus a 0..n-1 integer code for each palette entry.
subclass_colors_vals = [*subclass_color_palette.values()]
subclass_colors_cat = np.arange(len(subclass_colors_vals))
group_colors_vals = [*group_color_palette.values()]
group_colors_cat = np.arange(len(group_colors_vals))

In [None]:
# Integer palette code for each cell's label, per source and level; labels absent
# from the palette map to NaN.
subclass_to_code = dict(zip(subclass_color_palette.keys(), subclass_colors_cat))
group_to_code = dict(zip(group_color_palette.keys(), group_colors_cat))
for new_col, label_col, code_map in (
    ('pf.subclass_colors', 'pf.Subclass', subclass_to_code),
    ('cp.subclass_colors', 'cp.Subclass', subclass_to_code),
    ('pf.group_colors', 'pf.Group', group_to_code),
    ('cp.group_colors', 'cp.Group', group_to_code),
):
    d_tog[new_col] = d_tog[label_col].map(code_map)

In [None]:
# d_tog_pu = d_tog.loc[d_tog['brain_region'] == 'GP']

In [None]:
# Light grey for cells whose label was filled in as "Missing".
for palette in (group_color_palette, subclass_color_palette):
    palette['Missing'] = '#D3D3D3'

In [None]:
plot_cat = "Subclass"  # annotation level to compare: "Subclass" or "Group"

# Ribbons are coloured by the pf label using the palette that matches `plot_cat`
# (previously the subclass palette was used for every level).
level_palettes = {"Subclass": subclass_color_palette, "Group": group_color_palette}
line_palette = level_palettes[plot_cat]

# Both axes list their categories in descending name order.
v1_dim = go.parcats.Dimension(
    values=d_tog[f'pf.{plot_cat}'],
    label=f"CELL {plot_cat}",
    categoryorder="category descending",
)
v2_dim = go.parcats.Dimension(
    values=d_tog[f'cp.{plot_cat}'],
    label=f"NUCLEAR {plot_cat}",
    categoryorder="category descending",
)
line_colors = [line_palette[x] for x in d_tog[f'pf.{plot_cat}']]

fig = go.Figure(data=[go.Parcats(
    dimensions=[v1_dim, v2_dim],
    line={'shape': "hspline", 'color': line_colors},
)])
fig.update_layout(
    width=2000,
    height=1000,
    margin=dict(l=200),
)
fig.show()

In [None]:
# One overlap heatmap per annotation level: pf labels as reference, cp labels as query.
for ref, qry in (('pf.Subclass', 'cp.Subclass'), ('pf.Group', 'cp.Group')):
    _plot_overlap_heatmap(d_tog, ref_col=ref, qry_col=qry)