In [None]:
import os
import sys
from rich import inspect, print as rprint
from pathlib import Path
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import scanpy.experimental as sce

from spida.P.setup_adata import _calc_embeddings
from spida.I.allcools import run_allcools_seurat, normalize_adata

import matplotlib.pyplot as plt
import seaborn as sns

from spida.pl import plot_categorical, plot_continuous, categorical_scatter
# Move Plot ALLCools functions outside of the cli
# Axes background is white for every figure created below.
plt.rcParams['axes.facecolor'] = 'white'
# Star import provides the heatmap API used by the confusion-matrix cells
# (ClusterMapPlotter, HeatmapAnnotation, anno_label).
from PyComplexHeatmap import *

from datetime import datetime 
# Hour-resolution timestamp string; NOTE(review): not referenced anywhere else in this notebook.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H")

In [None]:
#parameters
# Run parameters (this looks like a papermill-style parameters cell; injected
# values would override these defaults). Path templates carry `{EXP}` /
# `{REF_EXP}` placeholders that are filled in by the path-resolution cell.

# --- experiments ---------------------------------------------------------------
EXP = None
REF_EXP = None

# --- input / output locations --------------------------------------------------
ref_adata_path = "/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}_filtered.h5ad"
# ref_adata_path = "/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}.h5ad"
spatial_adata_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}.h5ad"
save_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}_{SUBSET_LEVEL}_{SUBSET_NAME}.h5ad"
integrator_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}_{SUBSET_LEVEL}_{SUBSET_NAME}_joint.h5ad"
anndata_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/aggregated"
annotations_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/annotated"
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/annotations"
outdir = None  # NOTE(review): not read anywhere else in this notebook

# --- which cells to annotate ---------------------------------------------------
rna_subset_col = "Subclass"
rna_subset_value = None
qry_subset_col = "c2c_allcools_label_Subclass"
qry_subset_value = None

# --- label / cluster / embedding columns ---------------------------------------
rna_cell_type_column = "Group"
spatial_cell_type_column = None  # NOTE(review): not read anywhere else in this notebook
qry_cluster_column = None
added_joint_emb_key = "subclass_"
umap_coord_base = "base_round1_umap"
tsne_coord_base = "base_round1_tsne"

# --- marker genes and label transfer -------------------------------------------
top_deg_genes = 100
deg_type = "cef"
cef_column = None
max_cells_per_cluster = 10000
min_cells_per_cluster = 20
label_transfer_k = 30
run_clust_label_transfer = False
confusion_matrix_cluster_min_value = 0.25
confusion_matrix_cluster_max_value = 0.9
confusion_matrix_cluster_resolution = 1.5

# --- joint (reference + query) embedding ---------------------------------------
run_joint_embeddings = True
joint_embedding_leiden_res = 1.5
joint_harmony_batch_keys = ["Modality"]
run_harmony_joint_embeddings = True
harmony_nclust_joint_embeddings = 4  # NOTE(review): the integration call derives this from the data instead
max_iter_harmony_joint_embeddings = 20  # NOTE(review): not passed to the integration call

# --- query-only embedding ------------------------------------------------------
harmony_batch_key = ["replicate"]

# --- optional outputs of the integration ---------------------------------------
save_integrator = False
save_adata_comb = False
save_query = False

In [None]:
# Resolve the templated paths from the parameters cell and create output folders.
# `save_path` / `integrator_path` live in the same `BICAN_BG_{EXP}` directory as the
# spatial input, so `{EXP}` is filled in there too. `str.replace` is used instead of
# `str.format` because those templates also contain `{SUBSET_LEVEL}` / `{SUBSET_NAME}`,
# which are not defined in this notebook (presumably supplied already-expanded by the
# caller — TODO confirm); `.format` would raise KeyError on them. A path without an
# `{EXP}` placeholder is left unchanged.
ref_adata_path = ref_adata_path.format(REF_EXP=REF_EXP)
spatial_adata_path = spatial_adata_path.format(EXP=EXP)
save_path = Path(str(save_path).replace("{EXP}", str(EXP)))
integrator_path = Path(str(integrator_path).replace("{EXP}", str(EXP)))
image_path = Path(image_path)
image_path.mkdir(parents=True, exist_ok=True)
save_path.parent.mkdir(parents=True, exist_ok=True)
integrator_path.parent.mkdir(parents=True, exist_ok=True)
# Per-cell label-transfer table written next to the saved query object.
save_label_transfer_path = save_path.parent / f"{save_path.stem}_label_transfer.tsv"

In [None]:
# Load the reference atlas and keep only cells with `rna_subset_col == rna_subset_value`.
adata_ref = ad.read_h5ad(ref_adata_path)
adata_ref = adata_ref[adata_ref.obs[rna_subset_col] == rna_subset_value].copy()
# Fail early: an empty subset (e.g. `rna_subset_value` left at its default None) would
# otherwise only surface as an obscure error much later in the integration.
if adata_ref.n_obs == 0:
    raise ValueError(
        f"No reference cells with {rna_subset_col} == {rna_subset_value!r} in {ref_adata_path}"
    )
# Rename two reference Group labels (Oligodendrocyte -> Oligo OPALIN, ImAstro -> Astrocyte);
# every other label is kept unchanged by the fillna.
adata_ref.obs['Group'] = adata_ref.obs["Group"].map({"Oligodendrocyte" : "Oligo OPALIN", "ImAstro" : "Astrocyte"}).fillna(adata_ref.obs["Group"])
# Start from the raw matrix; assumes `raw` has the same genes/order as `var` (TODO confirm).
adata_ref.X = adata_ref.raw.X.copy()

In [None]:
# Load the spatial (query) object and keep only cells with `qry_subset_col == qry_subset_value`.
spatial_adata = ad.read_h5ad(spatial_adata_path)
adata = spatial_adata[spatial_adata.obs[qry_subset_col] == qry_subset_value].copy()
# The full spatial object is not used again; release it to free memory.
del spatial_adata
if adata.n_obs == 0:
    raise ValueError(
        f"No query cells with {qry_subset_col} == {qry_subset_value!r} in {spatial_adata_path}"
    )
# Duplicate the spatial coordinates under the `X_`-prefixed key (on the subset only).
adata.obsm['X_spatial'] = adata.obsm['spatial'].copy()

if rna_cell_type_column in adata.obs.columns:
    # Keep the existing label under `*_orig`, then drop columns this notebook will
    # (re)create — presumably left over from an earlier annotation pass.
    adata.obs[f'{rna_cell_type_column}_orig'] = adata.obs[rna_cell_type_column].copy()
    columns_to_regenerate = [
        rna_cell_type_column,
        f'allcools_{rna_cell_type_column}',
        f'allcools_{rna_cell_type_column}_filt',
        f'allcools_{rna_cell_type_column}_transfer_score',
        f'c2c_allcools_label_{rna_cell_type_column}',
    ]
    adata.obs = adata.obs.drop(columns=columns_to_regenerate, errors='ignore')
adata

In [None]:
# Number of query cells contributed by each dataset.
query_dataset_counts = adata.obs['dataset_id'].value_counts()
query_dataset_counts

In [None]:
# Restrict both objects to the genes measured in each, then re-normalise from counts.
common_genes = adata_ref.var_names.intersection(adata.var_names)
print(f"{len(common_genes)} common genes found between reference and query")

# Reference: reset X to the raw matrix, keep the shared genes, stash the counts
# in a layer, normalise.
# NOTE(review): not idempotent — on a re-run `raw.X` still has every gene while X
# only has the shared ones, so restart from the loading cell instead.
ref_raw_counts = adata_ref.raw.X.copy()
adata_ref.X = ref_raw_counts
adata_ref = adata_ref[:, common_genes].copy()
adata_ref.layers['counts'] = adata_ref.X.copy()
normalize_adata(adata_ref)

# Query: keep the shared genes, rebuild X from the stored count layer, normalise.
adata = adata[:, common_genes].copy()
adata.X = adata.layers['counts'].copy()
normalize_adata(adata)

# Embedding

In [None]:
# Reference cells per target label. If only a single label has more than 20 reference
# cells there is nothing to discriminate, so every query cell gets that label with a
# transfer score of 1.0, the result is written to `save_path`, and the run stops.
ref_group_counts = adata_ref.obs[rna_cell_type_column].value_counts()
print(ref_group_counts)

n_well_represented = int((ref_group_counts > 20).sum())
if n_well_represented == 1:
    print(f"Only one group with > 20 cells found in {rna_subset_value}, skipping")
    # value_counts sorts in descending order, so the first entry is the dominant label.
    annot = ref_group_counts.index[0]

    shortcut_columns = {
        f'allcools_{rna_cell_type_column}_filt': annot,
        f'allcools_{rna_cell_type_column}': annot,
        f'c2c_allcools_label_{rna_cell_type_column}': annot,
        f'allcools_{rna_cell_type_column}_transfer_score': 1.0,
    }
    for column_name, fill_value in shortcut_columns.items():
        adata.obs[column_name] = fill_value

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    adata.write_h5ad(save_path)
    # NOTE(review): inside a Jupyter kernel SystemExit is caught rather than ending the
    # process — confirm the notebook runner treats this as a clean early stop.
    sys.exit(0)

In [None]:
# Query-only embedding computed on a 200-gene HVG copy; new obs columns / obsm
# entries are then merged back into the full-gene `adata`.
adata_hvg = adata.copy()
sce.pp.highly_variable_genes(adata_hvg, n_top_genes=200, subset=True, layer='counts')

embedding_kwargs = dict(
    layer="volume_norm",
    key_added=added_joint_emb_key,
    leiden_res=2.5,
    min_dist=0.25,
    knn=50,
    run_harmony=True,
    batch_key=harmony_batch_key,
    harmony_nclust=3,
    max_iter_harmony=20,
)
_calc_embeddings(adata_hvg, **embedding_kwargs)

# Copy over only keys `adata` does not have yet.
# NOTE(review): on a re-run, results from the previous run already exist in `adata`
# and are therefore NOT refreshed with the new values.
new_obs_columns = [name for name in adata_hvg.obs.columns if name not in adata.obs]
for name in new_obs_columns:
    adata.obs[name] = adata_hvg.obs[name].copy()

new_obsm_entries = {name: arr for name, arr in adata_hvg.obsm.items() if name not in adata.obsm}
for name, arr in new_obsm_entries.items():
    adata.obsm[name] = arr.copy()

In [None]:
# Query-only UMAP coloured by its Leiden clusters and by donor / replicate / region.
query_umap_panels = [
    (f"{added_joint_emb_key}leiden", dict(coding=True, text_anno=True)),
    ("donor", {}),
    ("replicate", {}),
    ("brain_region", {}),
]
fig, axes = plt.subplots(1, len(query_umap_panels), figsize=(20, 5), dpi=200)
for panel_ax, (color_col, extra_kwargs) in zip(axes, query_umap_panels):
    plot_categorical(
        adata,
        coord_base=f"{added_joint_emb_key}umap",
        cluster_col=color_col,
        show=False,
        ax=panel_ax,
        **extra_kwargs,
    )

plt.savefig(image_path / "00_group_umap_joint.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Integration

In [None]:
# Query cluster column used for cluster-level transfer: default to the Leiden clusters of
# the query embedding computed above, but keep a value supplied via the parameters cell
# instead of silently overwriting it.
if qry_cluster_column is None:
    qry_cluster_column = f"{added_joint_emb_key}leiden"
# UMAP of the query-only embedding, used for plotting below.
coord_base = f"{added_joint_emb_key}umap"

In [None]:
# ALLCools/Seurat-style integration of reference and query. Returns the annotated
# query and the joint (reference + query) object.
integration_kwargs = dict(
    # storage locations
    anndata_store_path=anndata_store_path,
    annotations_store_path=annotations_store_path,
    # label and cluster columns
    rna_cell_type_column=rna_cell_type_column,
    qry_cluster_column=qry_cluster_column,
    # marker-gene selection
    top_deg_genes=top_deg_genes,
    deg_type=deg_type,
    cef_column=cef_column,
    # label transfer
    label_transfer_k=label_transfer_k,
    max_cells_per_cluster=max_cells_per_cluster,
    min_cells_per_cluster=min_cells_per_cluster,
    run_clust_label_transfer=run_clust_label_transfer,
    confusion_matrix_cluster_min_value=confusion_matrix_cluster_min_value,
    confusion_matrix_cluster_max_value=confusion_matrix_cluster_max_value,
    confusion_matrix_cluster_resolution=confusion_matrix_cluster_resolution,
    # joint embedding; Harmony gets one cluster per distinct reference label
    # (the `harmony_nclust_joint_embeddings` parameter is not used here)
    run_joint_embeddings=run_joint_embeddings,
    joint_embedding_leiden_res=joint_embedding_leiden_res,
    run_harmony_joint_embeddings=run_harmony_joint_embeddings,
    harmony_nclust_joint_embeddings=len(adata_ref.obs[rna_cell_type_column].unique()),
    harmony_batch_keys=joint_harmony_batch_keys,
    # outputs
    save_integrator=save_integrator,
    save_adata_comb=save_adata_comb,
    save_query=save_query,
    save_label_transfer_path=save_label_transfer_path,
    # both inputs are normalised by the integration
    normalize_ref=True,
    normalize_qry=True,
)
adata, adata_joint = run_allcools_seurat(ref_adata=adata_ref, qry_adata=adata, **integration_kwargs)

In [None]:
# Cluster-to-cluster label transfer. For every integrated Leiden cluster take the most
# frequent reference label and query cluster; keep clusters with enough cells from both
# modalities and a dominant reference label, and give that label to the cluster's query
# cells (all other query cells become "unknown").
ref_batch_key = "ref"
qry_batch_key = "query"
integrated_col = "integrated_leiden"
ref_col = rna_cell_type_column  # "Group" with the default parameters
qry_col = qry_cluster_column
min_ref_cells = 5          # keep clusters with more than this many reference cells ...
min_qry_cells = 1          # ... and more than this many query cells
min_ref_proportion = 0.5   # dominant reference label must exceed this share of the cluster's reference cells

obs_joint = adata_joint.obs
is_ref = obs_joint['Modality'] == ref_batch_key
is_qry = obs_joint['Modality'] == qry_batch_key


def _top_label_per_cluster(mask, label_col):
    """Most frequent `label_col` value per integrated cluster, with its proportion."""
    top = (
        obs_joint.loc[mask]
        .groupby(integrated_col)[label_col]
        .value_counts(normalize=True)
        .sort_values(ascending=False)
        .reset_index()
        .drop_duplicates(integrated_col, keep='first')
        .rename(columns={'proportion': f'{label_col}_proportion'})
    )
    return top.set_index(integrated_col)


df_map = pd.concat(
    [_top_label_per_cluster(is_ref, ref_col), _top_label_per_cluster(is_qry, qry_col)],
    axis=1,
)
# A cluster may contain cells from only one modality; count the missing side as 0
# instead of crashing on a NaN -> int cast.
ref_counts = obs_joint.loc[is_ref].groupby(integrated_col)[ref_col].count()
qry_counts = obs_joint.loc[is_qry].groupby(integrated_col)[qry_col].count()
df_map[f'{ref_col}_cell_count'] = ref_counts.reindex(df_map.index).fillna(0).astype(int)
df_map[f'{qry_col}_cell_count'] = qry_counts.reindex(df_map.index).fillna(0).astype(int)
adata.uns[f'{integrated_col}_{ref_col}_map'] = df_map.copy()
all_cat = set(df_map[ref_col].unique().tolist())

# Each print: #clusters, #distinct reference labels, [labels lost], fraction of query cells covered.
qry_count_col = f'{qry_col}_cell_count'
print(df_map.shape[0], df_map[ref_col].nunique(), df_map[qry_count_col].sum() / adata.shape[0])
# Backticks let column names contain characters that are not valid identifiers.
df_map = df_map.query(f"`{ref_col}_cell_count` > {min_ref_cells} & `{qry_col}_cell_count` > {min_qry_cells}")
print(df_map.shape[0], df_map[ref_col].nunique(), all_cat - set(df_map[ref_col].unique().tolist()), df_map[qry_count_col].sum() / adata.shape[0])
df_map = df_map.query(f"`{ref_col}_proportion` > {min_ref_proportion}")
print(df_map.shape[0], df_map[ref_col].nunique(), all_cat - set(df_map[ref_col].unique().tolist()), df_map[qry_count_col].sum() / adata.shape[0])

c2c_label_col = f'c2c_allcools_label_{ref_col}'
# Go through object dtype: filling NaN with a new string on a categorical raises
# TypeError, which previously sent the code down a fallback that silently kept NaN
# instead of "unknown".
adata_joint.obs[c2c_label_col] = (
    obs_joint[integrated_col]
    .map(df_map[ref_col].to_dict())
    .astype(object)
    .fillna("unknown")
    .astype('category')
)
# Positional copy: assumes the query rows of adata_joint are in the same order as
# adata.obs (TODO confirm); at least verify the row counts agree.
qry_labels = adata_joint.obs.loc[is_qry, c2c_label_col]
assert len(qry_labels) == adata.n_obs, "query rows of adata_joint do not match adata"
adata.obs[c2c_label_col] = qry_labels.values

In [None]:
# For Color Schemes! 
for key, value in adata.uns.items(): 
    if key.endswith('_palette') or key.endswith('_colors'):
        adata_joint.uns[key] = value
        
qry_cell_type_key1 = f"allcools_{rna_cell_type_column}"
qry_cell_type_key2 = f"c2c_allcools_label_{rna_cell_type_column}"
adata.obsm['X_integrated_umap'] = adata_joint.obsm['X_integrated_umap'][(adata_joint.obs["Modality"] == "query").values]

fig, axes = plt.subplots(2, 2,figsize=(12, 12), dpi=200)
axes = axes.flatten()

categorical_scatter(data=adata_joint, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=axes[0])
plot_categorical(adata_joint, coord_base="integrated_umap", cluster_col=rna_cell_type_column, show=False, coding=True, text_anno=True, ax=axes[0])
axes[0].set_title(f"HMBA {EXP} Cells")
categorical_scatter(data=adata_joint, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=axes[1])
plot_categorical(adata_joint, coord_base="integrated_umap", cluster_col="integrated_leiden", show=False, coding=True, text_anno=True, ax=axes[1])
axes[1].set_title("Integrated Leiden Clusters")
categorical_scatter(data=adata_joint, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=axes[2])
plot_categorical(adata_joint, coord_base="integrated_umap", cluster_col=qry_cell_type_key1, show=False, coding=True, text_anno=True, ax=axes[2])
axes[2].set_title(f"MERSCOPE {EXP} Cells")
categorical_scatter(data=adata_joint, coord_base="integrated_umap", max_points=None, hue=None, scatter_kws=dict(color='lightgrey'), ax=axes[3])
plot_categorical(adata, coord_base="integrated_umap", cluster_col=qry_cell_type_key2, show=False, coding=True, text_anno=True, ax=axes[3])
axes[3].set_title(f"MERSCOPE {EXP} Cells")

plt.savefig(image_path / f"01_group_integrated_space.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Persist the annotated query object. Columns added by the QC cells below are not
# part of this file unless it is written again.
adata.write_h5ad(filename=save_path)
# The joint reference+query object is not saved; uncomment to also write it:
# adata_joint.write_h5ad(integrator_path)

## plots

### cell 2 cell

In [None]:
# Cell-to-cell (per-cell) label-transfer QC: transfer-score distributions, the share of
# low-confidence calls per predicted type, and a filtered label column in which
# low-confidence cells are set to "unknown".
# NOTE(review): `adata` was already written to `save_path` above, so the filtered column
# computed here is not persisted unless the file is written again — confirm intent.
c2c_label_col = f"allcools_{rna_cell_type_column}"   # "allcools_Group" with the defaults
c2c_score_col = f"{c2c_label_col}_transfer_score"
c2c_filt_col = f"{c2c_label_col}_filt"
low_quality_threshold = 0.75  # cells below this are reported as "low quality"
unknown_threshold = 0.5       # cells below this get "unknown" in the filtered column


def _annotate_bars(ax, bars, labels):
    """Write one bold text label centred above each bar."""
    for rect, label in zip(bars, labels):
        x_center = rect.get_x() + rect.get_width() / 2
        top = rect.get_y() + rect.get_height()
        ax.text(x_center, top, label, ha='center', va='bottom', weight='bold', fontsize=8)


def _set_rotated_xticklabels(ax, labels):
    """Label one tick per categorical bar, rotated 45 degrees.

    Pinning the tick positions first avoids matplotlib's FixedFormatter warning.
    """
    labels = list(labels)
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')


labels_c2c = adata.obs[c2c_label_col]
scores_c2c = adata.obs[c2c_score_col]

# 1) Overall transfer-score distribution.
fig, ax = plt.subplots(figsize=(12, 6), dpi=200)
ax.hist(scores_c2c, bins=100, color='lightblue', edgecolor='k')
ax.axvline(x=low_quality_threshold, color='red', linestyle='--')
ax.set_title("Distribution of Group Transfer Scores (Cell 2 Cell)")
ax.set_xlabel("Group Transfer Score")
ax.set_ylabel("Frequency")

plt.savefig(image_path / "01_group_transfer_score_dist.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# 2) Score distribution per predicted type; also collect each type's low-quality fraction.
low_qc_list = {}
fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
for _cell_type in labels_c2c.unique():
    in_type = labels_c2c == _cell_type
    ax.hist(scores_c2c[in_type], bins=50, alpha=0.5, label=_cell_type)
    n_cells = int(in_type.sum())
    n_low_quality = int((in_type & (scores_c2c < low_quality_threshold)).sum())
    low_qc_list[_cell_type] = n_low_quality / n_cells if n_cells > 0 else 0
    print(f"{_cell_type}: {n_cells} cells, {n_low_quality} low quality annotations ({low_qc_list[_cell_type]:.2%})")
ax.axvline(x=low_quality_threshold, color='red', linestyle='--')
ax.legend(title="Cell Type", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.savefig(image_path / "01_group_transfer_score_dist_by_celltype.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# 3) Filtered labels: low-confidence cells become "unknown" (boolean mask, so
#    duplicated obs names cannot affect extra rows).
filt_labels = labels_c2c.astype('category')
if "unknown" not in filt_labels.cat.categories:
    filt_labels = filt_labels.cat.add_categories("unknown")
adata.obs[c2c_filt_col] = filt_labels
adata.obs.loc[scores_c2c < unknown_threshold, c2c_filt_col] = "unknown"

# 4) Number of cells per filtered label.
fig, ax = plt.subplots(figsize=(10, 6))
vc = adata.obs[c2c_filt_col].value_counts()
bars = ax.bar(x=vc.index, height=vc.values, color='lightblue', edgecolor='black', alpha=1)
_annotate_bars(ax, bars, ['%i' % n for n in vc.values])
_set_rotated_xticklabels(ax, vc.index)
# The labels at this level are Groups (the reference was subset by Subclass).
ax.set_title(f"Group distribution (n = {vc.sum()} / {adata.shape[0]})")
ax.set_ylabel("Number of cells")
ax.set_xlabel("Annotated Group")
plt.tight_layout()

plt.savefig(image_path / "01_group_transfer_cell_numbers.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# 5) Low-quality fraction per predicted type. Each bar is annotated with its own
#    fraction (previously the cell counts from plot 4 were written here, in a different order).
qc_types = list(low_qc_list.keys())
qc_fractions = list(low_qc_list.values())
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(qc_types, qc_fractions, color='lightblue', edgecolor='black', alpha=1)
_annotate_bars(ax, bars, [f"{frac:.1%}" for frac in qc_fractions])
_set_rotated_xticklabels(ax, qc_types)
ax.set_title("Proportion of Low Quality Annotations by Cell Type")
ax.set_ylabel("Proportion of Low Quality Annotations")
ax.set_xlabel("Cell Type")
plt.tight_layout()

plt.savefig(image_path / "01_group_transfer_low_quality_annotations.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# 6) Filtered labels on the query embedding and in space, one panel per replicate.
fig, ax = plt.subplots(1, 3, figsize=(20, 6), dpi=200)
plot_categorical(adata, coord_base=coord_base, cluster_col=c2c_filt_col, show=False, coding=True, text_anno=True, ax=ax[0])
for panel_ax, replicate in zip(ax[1:], ('ucsd', 'salk')):
    plot_categorical(adata[adata.obs['replicate'] == replicate], coord_base="spatial", cluster_col=c2c_filt_col, show=False, coding=True, ax=panel_ax)

plt.savefig(image_path / "01_group_cell2cell_joint_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

### cluster 2 cluster

In [None]:
# Cluster-to-cluster QC:
#   (1) composition of each integrated Leiden cluster in terms of reference labels,
#   (2) number of query cells per transferred cluster label,
#   (3) transferred labels on the integrated, query and spatial coordinates.
ref_col = rna_cell_type_column  # "Group" with the default parameters
cluster_col = "integrated_leiden"
c2c_cluster_label_col = f"c2c_allcools_label_{ref_col}"

# Fraction of each reference label within every integrated cluster (reference cells only).
use_adata = adata_joint[adata_joint.obs['Modality'] == ref_batch_key].copy()
vc = use_adata.obs.loc[:, [cluster_col, ref_col]].value_counts().reset_index()
D = vc.groupby(cluster_col)['count'].sum()
vc['N'] = vc[cluster_col].map(D).astype(int)
vc['fraction'] = vc['count'] / vc['N']
data = vc.pivot(index=cluster_col, columns=ref_col, values='fraction')

# Rows: each cluster's dominant reference label ("GROUP"); rows sharing a label are
# kept together, in the heatmap's column order.
cols = data.columns.tolist()
df_rows = data.index.to_series().to_frame()
df_rows["GROUP"] = [cols[i] for i in np.argmax(data.fillna(0).values, axis=1)]
use_rows = []
for col in cols:
    use_rows.extend(df_rows.loc[df_rows['GROUP'] == col, cluster_col].unique().tolist())
df_rows = df_rows.loc[use_rows]
# Row labels "<category code>: <cluster>".
ct2code = {cat: code for code, cat in enumerate(use_adata.obs[cluster_col].cat.categories)}
df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[cluster_col].tolist()]

row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0, orientation='right',
)

plt.figure(figsize=(24, 12))
ClusterMapPlotter(
    data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
    right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
    row_split_order=df_rows['GROUP'].unique().tolist(),
    show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
    label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=ref_col, ylabel=cluster_col,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # labelpad is in points
)
plt.savefig(image_path / "01_group_int_confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Number of query cells per transferred cluster label.
fig, ax = plt.subplots(figsize=(10, 6))
vc = adata.obs[c2c_cluster_label_col].value_counts()
bars = ax.bar(x=vc.index, height=vc.values, color='lightblue', edgecolor='black', alpha=1)
for rect, n_cells in zip(bars, vc.values):
    # Centre the count over the bar (a fixed +0.5 offset was off-centre for width 0.8).
    ax.text(rect.get_x() + rect.get_width() / 2, rect.get_y() + rect.get_height(), '%i' % n_cells,
            ha='center', va='bottom', weight='bold', fontsize=8)
# Pin the tick positions before labelling them (avoids the FixedFormatter warning).
ax.set_xticks(range(len(vc)))
ax.set_xticklabels(vc.index, rotation=45, ha='right')
ax.set_title(f"Group distribution (n = {vc.sum()} / {adata.shape[0]})")
ax.set_ylabel("Number of cells")
ax.set_xlabel("Annotated Group")
plt.tight_layout()

plt.savefig(image_path / "01_group_cluster2cluster_cell_numbers.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

# Drop any cached colours for the transferred label, presumably so plot_categorical
# builds a fresh palette for its current categories.
for uns in (adata_joint.uns, adata.uns):
    for suffix in ("_colors", "_palette"):
        palette_key = f"{c2c_cluster_label_col}{suffix}"
        if palette_key in uns:
            del uns[palette_key]

fig, ax = plt.subplots(1, 3, figsize=(20, 6), dpi=200)
plot_categorical(adata_joint[adata_joint.obs['Modality'] == qry_batch_key], coord_base="integrated_umap", cluster_col=c2c_cluster_label_col, show=False, coding=True, text_anno=True, ax=ax[0])
plot_categorical(adata, coord_base=coord_base, cluster_col=c2c_cluster_label_col, show=False, coding=True, text_anno=True, ax=ax[1])
plot_categorical(adata, coord_base="spatial", cluster_col=c2c_cluster_label_col, show=False, coding=True, ax=ax[2])

plt.savefig(image_path / "01_group_cluster2cluster_joint_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

### joint

In [None]:
# Joint view: cluster-to-cluster labels vs. cell-to-cell labels on the precomputed t-SNE
# and UMAP coordinates (confident cell-level calls only), then their confusion matrix.
cluster_col = f"c2c_allcools_label_{rna_cell_type_column}"  # cluster-to-cluster label
cell_col = f"allcools_{rna_cell_type_column}"                # cell-to-cell label
cell_prob_col = f"{cell_col}_transfer_score"                # cell-to-cell confidence
min_cell_score = 0.5  # minimum transfer score for a cell to be shown in the embedding plots

use_obs = adata.obs.index[adata.obs[cell_prob_col] >= min_cell_score].tolist()

# Same three panels on each embedding; a fresh view of `adata` is taken per call.
for coord_key, coord_tag in ((tsne_coord_base, "tsne"), (umap_coord_base, "umap")):
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), dpi=200)
    plot_categorical(adata[use_obs, :], coord_base=coord_key, cluster_col=cluster_col, show=False, coding=True, text_anno=True, ax=axes[0])
    axes[0].set_title("Cluster2Cluster Annotations")
    plot_categorical(adata[use_obs, :], coord_base=coord_key, cluster_col=cell_col, show=False, coding=True, text_anno=True, ax=axes[1])
    axes[1].set_title("Cell2Cell Annotations")
    plot_continuous(adata[use_obs, :], coord_base=coord_key, color_by=cell_prob_col, show=False, ax=axes[2], cmap='viridis')
    axes[2].set_title("Cell2Cell Annotation Confidence")

    plt.savefig(image_path / f"01_group_{coord_tag}_joint.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

# Row-normalised confusion matrix: fraction of each cell-level label (columns) within
# every cluster-level label (rows), over all query cells.
vc = adata.obs.loc[:, [cluster_col, cell_col]].value_counts().reset_index()
D = vc.groupby(cluster_col)['count'].sum()
vc['N'] = vc[cluster_col].map(D).astype(int)
vc['fraction'] = vc['count'] / vc['N']
data = vc.pivot(index=cluster_col, columns=cell_col, values='fraction')

# Rows grouped by their dominant cell-level label, in column order.
cols = data.columns.tolist()
df_rows = data.index.to_series().to_frame()
df_rows["GROUP"] = [cols[i] for i in np.argmax(data.fillna(0).values, axis=1)]
use_rows = []
for col in cols:
    use_rows.extend(df_rows.loc[df_rows['GROUP'] == col, cluster_col].unique().tolist())
df_rows = df_rows.loc[use_rows]
# Row labels "<category code>: <label>".
ct2code = {cat: code for code, cat in enumerate(adata.obs[cluster_col].cat.categories)}
df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[cluster_col].tolist()]

row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0, orientation='right',
)

plt.figure(figsize=(24, 12))
ClusterMapPlotter(
    data.loc[df_rows.index.tolist()], row_cluster=False, col_cluster=False, cmap='Reds',
    right_annotation=row_ha, row_split=df_rows['GROUP'], row_split_gap=0.5,
    row_split_order=df_rows['GROUP'].unique().tolist(),
    show_rownames=False, show_colnames=True, yticklabels=True, xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True, fmt='.2g', linewidth=0.05, linecolor='gold', linestyle='-:',
    label='fraction', legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=cell_col, ylabel=cluster_col,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5), xlabel_side='top',
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),  # labelpad is in points
)
plt.savefig(image_path / "01_group_confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()