In [None]:
import os
from rich import inspect, print as rprint
from pathlib import Path
import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc

from spida.P.setup_adata import _calc_embeddings
from spida.I.allcools import run_allcools_seurat

import matplotlib.pyplot as plt
import seaborn as sns

from spida.pl import plot_categorical, plot_continuous
# Move Plot ALLCools functions outside of the cli
# Use a white axes background for the figures created in this notebook.
plt.rcParams['axes.facecolor'] = 'white'

# NOTE(review): star import; HeatmapAnnotation, anno_label and ClusterMapPlotter
# used further down presumably come from here — confirm and import them by name.
from PyComplexHeatmap import *

from datetime import datetime 
# Timestamp taken once at start-up; it prefixes the file name of every figure
# saved below. NOTE(review): "%H:%M" puts a ':' into those file names, which
# some file systems do not accept.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M")

In [None]:
#parameters
# Run configuration. NOTE(review): presumably this cell is overridden by an
# external parameter injector (e.g. papermill) — confirm.
# EXP and REF_EXP have no usable default: the next cell substitutes them into
# the path templates below and uses EXP as a sub-directory of image_path.
EXP = None
REF_EXP = None

# Path templates; "{REF_EXP}" / "{EXP}" are filled in with str.format in the next cell.
# ref_adata_path="/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}_filtered.h5ad"
ref_adata_path="/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_{REF_EXP}.h5ad"
spatial_adata_path="/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}.h5ad"
integrator_path="/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_{EXP}/{EXP}_joint.h5ad"

# Store locations forwarded to run_allcools_seurat, and the base directory for saved figures.
anndata_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/aggregated"
annotations_store_path = "/home/x-aklein2/projects/aklein/BICAN/data/annotated"
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/annotations"

# Settings forwarded as keyword arguments to run_allcools_seurat
# (coord_base is the exception: it is only used for plotting the query embedding).
rna_cell_type_column="Subclass"
qry_cluster_column="base_round1_leiden"
coord_base="base_round1_umap"
top_deg_genes = 30
max_cells_per_cluster = 2000
min_cells_per_cluster = 20
label_transfer_k = 50
run_joint_embeddings=True
joint_harmony_batch_keys=["Modality", "donor"]
run_clust_label_transfer=True
joint_embedding_leiden_res=4
confusion_matrix_cluster_min_value=0.25
confusion_matrix_cluster_max_value=0.9
confusion_matrix_cluster_resolution=1.5
save_integrator=False
save_adata_comb=False
save_query=False
# NOTE(review): outdir is not referenced anywhere else in the visible notebook.
outdir=None


# Embedding keys used by the final comparison figures.
# umap_coord_base currently holds the same value as coord_base above.
umap_coord_base="base_round1_umap"
tsne_coord_base="base_round1_tsne"


In [None]:
# Resolve the path templates for this run and create the figure directory.
#
# EXP and REF_EXP default to None in the parameters cell. Fail early with a
# clear message instead of formatting the string "None" into the data paths
# and then hitting a TypeError when building the image directory.
if EXP is None or REF_EXP is None:
    raise ValueError(
        "EXP and REF_EXP must be set in the parameters cell before running this notebook "
        f"(got EXP={EXP!r}, REF_EXP={REF_EXP!r})."
    )

# Once the placeholders are filled these templates contain no more fields,
# so formatting them again on a cell re-run leaves them unchanged.
ref_adata_path = ref_adata_path.format(REF_EXP=REF_EXP)
spatial_adata_path = spatial_adata_path.format(EXP=EXP)
integrator_path = integrator_path.format(EXP=EXP)

# Append the experiment sub-directory only once, so that re-running this cell
# does not produce ".../EXP/EXP".
image_path = Path(image_path)
if image_path.name != EXP:
    image_path = image_path / EXP
image_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the reference AnnData from the resolved ref_adata_path and display its summary.
adata_ref = ad.read_h5ad(ref_adata_path)
adata_ref

In [None]:
# Read the query AnnData for this experiment.
spatial_adata = ad.read_h5ad(spatial_adata_path)

# Annotation columns that may already be present in obs; they are removed
# here before the label transfer below is run.
added_cols = [
    'all_annot',
    'allcools_Subclass_transfer_score',
    'c2c_allcools_label_Subclass',
    'allcools_Subclass',
    'allcools_Subclass_filt',
]
for stale_col in [name for name in added_cols if name in spatial_adata.obs.columns]:
    del spatial_adata.obs[stale_col]

spatial_adata

### SPIDA Annotations

In [None]:
# Run the ALLCools/Seurat-style integration helper on the reference and query
# objects, forwarding the settings from the parameters cell unchanged.
# It returns two objects, bound here to `adata` and `adata_joint`.
# NOTE(review): presumably `adata` is the annotated query and `adata_joint`
# the combined reference+query object — confirm against run_allcools_seurat.
adata, adata_joint = run_allcools_seurat(
    ref_adata=adata_ref, 
    qry_adata=spatial_adata, 
    anndata_store_path=anndata_store_path,
    annotations_store_path=annotations_store_path,
    rna_cell_type_column=rna_cell_type_column,
    qry_cluster_column=qry_cluster_column,
    top_deg_genes=top_deg_genes,
    label_transfer_k=label_transfer_k,
    max_cells_per_cluster=max_cells_per_cluster,
    min_cells_per_cluster=min_cells_per_cluster,
    run_joint_embeddings=run_joint_embeddings,
    joint_embedding_leiden_res=joint_embedding_leiden_res,
    # The notebook-level name is joint_harmony_batch_keys; the helper's keyword is harmony_batch_keys.
    harmony_batch_keys=joint_harmony_batch_keys,
    run_clust_label_transfer=run_clust_label_transfer,
    confusion_matrix_cluster_min_value=confusion_matrix_cluster_min_value,
    confusion_matrix_cluster_max_value=confusion_matrix_cluster_max_value,
    confusion_matrix_cluster_resolution=confusion_matrix_cluster_resolution,
    save_integrator=save_integrator,
    save_adata_comb=save_adata_comb,
    save_query=save_query,
)

In [None]:
# Copy every colour-scheme entry (uns keys ending in "_palette" or "_colors")
# from `adata` onto `adata_joint`.
palette_suffixes = ('_palette', '_colors')
for uns_key, uns_value in adata.uns.items():
    if not uns_key.endswith(palette_suffixes):
        continue
    adata_joint.uns[uns_key] = uns_value

In [None]:
# 2x2 overview of the integrated UMAP, one panel per obs column.
fig, axes = plt.subplots(2, 2,figsize=(12, 12), dpi=200)
axes = axes.flatten()

# (obs column used for colouring, panel title), listed in panel order.
panel_specs = [
    (rna_cell_type_column, f"HMBA {EXP} Cells"),
    ("integrated_leiden", "Integrated Leiden Clusters"),
    ("donor", "Donor"),
    ("replicate", "Replicate"),
]
for panel_ax, (obs_column, panel_title) in zip(axes, panel_specs):
    plot_categorical(
        adata_joint,
        coord_base="integrated_umap",
        cluster_col=obs_column,
        show=False,
        coding=True,
        text_anno=True,
        ax=panel_ax,
    )
    panel_ax.set_title(panel_title)

plt.savefig(image_path / f"{current_datetime}_01_subclass_integrated_space.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

### Cell-to-cell label transfer

In [None]:
# Histogram of the per-cell Subclass transfer score, with a dashed marker at 0.75.
transfer_score_values = adata.obs['allcools_Subclass_transfer_score']

fig, ax = plt.subplots(figsize=(12, 6), dpi=200)
ax.hist(transfer_score_values, bins=100, color='lightblue', edgecolor='k')
ax.axvline(x=0.75, color='red', linestyle='--')
ax.set(
    title="Distribution of Subclass Transfer Scores (Cell 2 Cell)",
    xlabel="Subclass Transfer Score",
    ylabel="Frequency",
)

plt.savefig(image_path / f"{current_datetime}_01_subclass_transfer_score_dist.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Transfer-score distribution per cell-to-cell Subclass label, plus a printed
# summary of how many cells of each label score below the 0.75 marker.
score_marker = 0.75
subclass_labels = adata.obs['allcools_Subclass']
transfer_score_values = adata.obs['allcools_Subclass_transfer_score']
is_below_marker = transfer_score_values < score_marker

fig, ax = plt.subplots(figsize=(10, 6), dpi=200)
for _cell_type in subclass_labels.unique():
    is_type = subclass_labels == _cell_type
    ax.hist(transfer_score_values[is_type], bins=50, alpha=0.5, label=_cell_type)

    n_cells = is_type.sum()
    n_low_quality = (is_type & is_below_marker).sum()
    # Same zero-guard as the low-quality summary cell further down.
    low_quality_pct = n_low_quality / n_cells if n_cells > 0 else 0
    print(f"{_cell_type}: {n_cells} cells, {n_low_quality} low quality annotations ({low_quality_pct:.2%})")

# Draw the marker once, after the histograms, instead of once per cell type.
ax.axvline(x=score_marker, color='red', linestyle='--')
ax.set_title("Distribution of Subclass Transfer Scores by Cell Type (Cell 2 Cell)")
ax.set_xlabel("Subclass Transfer Score")
ax.set_ylabel("Frequency")
ax.legend(title="Cell Type", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.savefig(image_path / f"{current_datetime}_01_subclass_transfer_score_dist_by_celltype.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Build 'allcools_Subclass_filt': a categorical copy of the cell-to-cell labels
# in which cells with a transfer score below 0.5 are relabelled "unknown".
# NOTE(review): this cut-off (0.5) differs from the 0.75 marker drawn in the
# score histograms — confirm that is intended.
subclass_filt = adata.obs['allcools_Subclass'].astype('category')
if "unknown" not in subclass_filt.cat.categories:
    # The category has to exist before it can be assigned to any row.
    subclass_filt = subclass_filt.cat.add_categories("unknown")
adata.obs['allcools_Subclass_filt'] = subclass_filt

low_quality_annot = adata.obs.index[adata.obs['allcools_Subclass_transfer_score'] < 0.5]
adata.obs.loc[low_quality_annot, 'allcools_Subclass_filt'] = "unknown"

In [None]:
# Bar chart of the number of cells per filtered cell-to-cell Subclass label.
fig, ax = plt.subplots(figsize=(10, 6))
vc = adata.obs['allcools_Subclass_filt'].value_counts()
bars = ax.bar(x = vc.index, height=vc.values, color='lightblue', edgecolor='black', alpha=1)

# Write each count just above its bar. The x position is derived from the
# bar's own width so the label is centred (a fixed +0.5 offset is off-centre
# for matplotlib's default bar width of 0.8).
for rect, n_in_bar in zip(bars, vc.values):
    label_x = rect.get_x() + rect.get_width() / 2
    label_y = rect.get_y() + rect.get_height()
    ax.text(label_x, label_y, '%i' % n_in_bar, ha='center', va='bottom', weight='bold', fontsize=8)

# Restyle the tick labels the axis already has instead of replacing them with
# set_xticklabels, which expects tick positions to have been fixed first.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_title(f"Subclass distribution (n = {vc.sum()} / {adata.shape[0]})")
ax.set_ylabel("Number of cells")
ax.set_xlabel("Annotated Subclass")
plt.tight_layout()

plt.savefig(image_path / f"{current_datetime}_01_subclass_transfer_cell_numbers.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Fraction of cells per cell-to-cell Subclass label whose transfer score is below 0.75.
subclass_labels = adata.obs['allcools_Subclass']
is_low_quality = adata.obs['allcools_Subclass_transfer_score'] < 0.75

# Maps cell type -> fraction of low-quality annotations (kept under its original name).
low_qc_list = {}
for _cell_type in subclass_labels.unique():
    is_type = subclass_labels == _cell_type
    n_cells = is_type.sum()
    n_low_quality = (is_type & is_low_quality).sum()
    low_qc_list[_cell_type] = n_low_quality / n_cells if n_cells > 0 else 0

cell_types = list(low_qc_list.keys())
low_quality_fractions = list(low_qc_list.values())

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(cell_types, low_quality_fractions, color='lightblue', edgecolor='black', alpha=1)

# Label every bar with the value it actually shows. The labels used to be
# taken from `vc`, a count table left over from an earlier cell, whose order
# and meaning do not match these bars.
for rect, fraction in zip(bars, low_quality_fractions):
    label_x = rect.get_x() + rect.get_width() / 2  # centre of the bar
    label_y = rect.get_y() + rect.get_height()
    ax.text(label_x, label_y, f"{fraction:.1%}", ha='center', va='bottom', weight='bold', fontsize=8)

# Restyle the existing tick labels rather than overwriting them via set_xticklabels.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_title("Proportion of Low Quality Annotations by Cell Type")
ax.set_ylabel("Proportion of Low Quality Annotations")
ax.set_xlabel("Cell Type")
plt.tight_layout()

plt.savefig(image_path / f"{current_datetime}_01_subclass_transfer_low_quality_annotations.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

### Cluster-to-cluster label transfer

In [None]:
# Pull the cluster-to-cluster integration results stored under
# adata_joint.uns['c2c_allcools_integration_results'].
# NOTE(review): presumably written by run_allcools_seurat — confirm.
# ref_group / qry_group are compared against group labels in the next cell to
# size the outlined blocks; confusion_matrix is the matrix drawn there.
ref_group = adata_joint.uns['c2c_allcools_integration_results']['ref_group']['ref_group']
qry_group = adata_joint.uns['c2c_allcools_integration_results']['qry_group']['qry_group']
confusion_matrix = adata_joint.uns['c2c_allcools_integration_results']['confusion_matrix']

In [None]:
# Heatmap of the cluster-to-cluster confusion matrix with a thin white outline
# around the block that belongs to each group label.
fig, ax = plt.subplots(figsize=(6, 6), dpi=200)
sns.heatmap(confusion_matrix, vmin=0.1, vmax=0.9, cmap='cividis', cbar=None, ax=ax, yticklabels=True)

# Walk through the group labels in sorted order. (cumsum_c, cumsum_r) is the
# corner where the current block starts; its width is the number of ref_group
# entries with that label, its height the number of qry_group entries.
cumsum_c, cumsum_r = 0, 0
for xx in ref_group.sort_values().unique():
    rlen = (qry_group == xx).sum()
    clen = (ref_group == xx).sum()
    x_start, x_end = cumsum_c, cumsum_c + clen
    y_start, y_end = cumsum_r, cumsum_r + rlen
    outline_segments = (
        ([x_start, x_end], [y_start, y_start]),
        ([x_start, x_end], [y_end, y_end]),
        ([x_start, x_start], [y_start, y_end]),
        ([x_end, x_end], [y_start, y_end]),
    )
    for seg_x, seg_y in outline_segments:
        ax.plot(seg_x, seg_y, c='w', linewidth=0.5)
    cumsum_c, cumsum_r = x_end, y_end

ax.set_yticklabels(ax.get_yticklabels(), fontsize=6)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=6)

plt.savefig(image_path / f"{current_datetime}_01_subclass_cluster2cluster_matrix.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# For the cells with Modality == 'ref': what fraction of each integrated
# Leiden cluster belongs to each Subclass, drawn as an annotated heatmap.
ref_col = "Subclass"
cluster_col = "integrated_leiden"

use_adata = adata_joint[adata_joint.obs['Modality'] == 'ref'].copy()

# Count cells per (cluster, Subclass) pair and normalise by the cluster total.
pair_counts = use_adata.obs.loc[:, [cluster_col, ref_col]].value_counts().reset_index()
cluster_totals = pair_counts.groupby(cluster_col)['count'].sum()
pair_counts['N'] = pair_counts[cluster_col].map(cluster_totals).astype(int)
pair_counts['fraction'] = pair_counts['count'] / pair_counts['N']
data = pair_counts.pivot(index=cluster_col, columns=ref_col, values='fraction')

# Assign every row (cluster) to the column (Subclass) holding its largest fraction.
df_rows = data.index.to_series().to_frame()
cols = data.columns.tolist()
df_rows["GROUP"] = [cols[i] for i in np.argmax(data.fillna(0).values, axis=1)]

# Reorder the rows so that clusters sharing a GROUP are adjacent, with the
# groups following the column order of `data`.
use_rows = [
    cluster
    for group in cols
    for cluster in df_rows.loc[df_rows['GROUP'] == group, cluster_col].unique().tolist()
]
df_rows = df_rows.loc[use_rows]

# Row labels of the form "<category code>: <cluster>".
code_table = (
    use_adata.obs
    .assign(code=use_adata.obs[cluster_col].cat.codes)
    .loc[:, [cluster_col, 'code']]
    .drop_duplicates()
    .set_index(cluster_col)
)
ct2code = code_table.code.to_dict()
df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[cluster_col].tolist()]

# Plot
row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0,
    orientation='right',
)

plt.figure(figsize=(24,12))
ClusterMapPlotter(
    data.loc[df_rows.index.tolist()],
    row_cluster=False,
    col_cluster=False,
    cmap='Reds',
    right_annotation=row_ha,
    row_split=df_rows['GROUP'],
    row_split_gap=0.5,
    row_split_order=df_rows['GROUP'].unique().tolist(),
    show_rownames=False,
    show_colnames=True,
    yticklabels=True,
    xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True,
    fmt='.2g',
    linewidth=0.05,
    linecolor='gold',
    linestyle='-:',
    label='fraction',
    legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=ref_col,
    ylabel=cluster_col,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5),
    xlabel_side='top',
    # labelpad (in points) can be increased manually if the label overlaps.
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),
)
plt.savefig(image_path / f"{current_datetime}_01_subclass_int_confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Bar chart of the number of cells per cluster-to-cluster Subclass label.
fig, ax = plt.subplots(figsize=(10, 6))
vc = adata.obs['c2c_allcools_label_Subclass'].value_counts()
bars = ax.bar(x = vc.index, height=vc.values, color='lightblue', edgecolor='black', alpha=1)

# Write each count just above its bar. The x position is derived from the
# bar's own width so the label is centred (a fixed +0.5 offset is off-centre
# for matplotlib's default bar width of 0.8).
for rect, n_in_bar in zip(bars, vc.values):
    label_x = rect.get_x() + rect.get_width() / 2
    label_y = rect.get_y() + rect.get_height()
    ax.text(label_x, label_y, '%i' % n_in_bar, ha='center', va='bottom', weight='bold', fontsize=8)

# Restyle the tick labels the axis already has instead of replacing them with
# set_xticklabels, which expects tick positions to have been fixed first.
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_title(f"Subclass distribution (n = {vc.sum()} / {adata.shape[0]})")
ax.set_ylabel("Number of cells")
ax.set_xlabel("Annotated Subclass")
plt.tight_layout()

plt.savefig(image_path / f"{current_datetime}_01_subclass_cluster2cluster_cell_numbers.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Cluster-to-cluster Subclass labels in three views: the joint object on the
# integrated UMAP, `adata` on its own embedding (coord_base), and the 'salk'
# replicate of `adata` in spatial coordinates.
c2c_label_col = "c2c_allcools_label_Subclass"

fig, ax = plt.subplots(1, 3,figsize=(20, 6), dpi=200)
joint_ax, query_ax, spatial_ax = ax

plot_categorical(
    adata_joint, coord_base="integrated_umap", cluster_col=c2c_label_col,
    show=False, coding=True, text_anno=True, ax=joint_ax,
)
plot_categorical(
    adata, coord_base=coord_base, cluster_col=c2c_label_col,
    show=False, coding=True, text_anno=True, ax=query_ax,
)
# The spatial panel is restricted to one replicate and drawn without text annotation.
salk_cells = adata[adata.obs['replicate'] == 'salk']
plot_categorical(
    salk_cells, coord_base="spatial", cluster_col=c2c_label_col,
    show=False, coding=True, ax=spatial_ax,
)

plt.savefig(image_path / f"{current_datetime}_01_subclass_cluster2cluster_joint_umap.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

### Save results and compare annotations

In [None]:
# Persist both objects.
# NOTE: `adata` is written to spatial_adata_path, the same path the query
# object was read from at the top of the notebook, so that file is overwritten.
# These writes are unconditional; the save_query / save_integrator flags from
# the parameters cell are only passed to run_allcools_seurat and are not
# consulted here.
adata.write_h5ad(spatial_adata_path)
adata_joint.write_h5ad(integrator_path)

In [None]:
# obs column names used by the comparison cells below.
# Note: cluster_col is rebound here; an earlier cell set it to "integrated_leiden".
cluster_col = "c2c_allcools_label_Subclass"  # cluster-to-cluster label
cell_col = "allcools_Subclass"  # cell-to-cell label
cell_prob_col = "allcools_Subclass_transfer_score"  # score attached to the cell-to-cell label

In [None]:
# Side-by-side comparison of the cluster-to-cluster labels, the cell-to-cell
# labels and the cell-to-cell score, restricted to cells whose score is at
# least 0.5. One three-panel figure is saved per embedding (t-SNE, then UMAP).
use_obs = adata.obs.index[adata.obs[cell_prob_col] >= 0.5].tolist()
# Subset once and reuse it for all six panels.
adata_confident = adata[use_obs, :]

# (tag used in the output file name, embedding key passed as coord_base)
embeddings = (
    ("tsne", tsne_coord_base),
    ("umap", umap_coord_base),
)
for embedding_tag, embedding_key in embeddings:
    fig, axes = plt.subplots(1, 3, figsize=(12, 4), dpi=200)

    plot_categorical(adata_confident, coord_base=embedding_key, cluster_col=cluster_col, show=False, coding=True, text_anno=True, ax=axes[0])
    axes[0].set_title("Cluster2Cluster Annotations")
    plot_categorical(adata_confident, coord_base=embedding_key, cluster_col=cell_col, show=False, coding=True, text_anno=True, ax=axes[1])
    axes[1].set_title("Cell2Cell Annotations")
    plot_continuous(adata_confident, coord_base=embedding_key, color_by=cell_prob_col, show=False, ax=axes[2], cmap='viridis')
    axes[2].set_title("Cell2Cell Annotation Confidence")

    plt.savefig(image_path / f"{current_datetime}_01_subclass_{embedding_tag}_joint.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
# For every cluster-to-cluster label: what fraction of its cells received each
# cell-to-cell label. `data` and `df_rows` are consumed by the plotting cell below.
cluster_col = "c2c_allcools_label_Subclass"
cell_col = "allcools_Subclass"
cell_prob_col = "allcools_Subclass_transfer_score"

# Count cells per (cluster label, cell label) pair and normalise by the cluster-label total.
pair_counts = adata.obs.loc[:, [cluster_col, cell_col]].value_counts().reset_index()
cluster_totals = pair_counts.groupby(cluster_col)['count'].sum()
pair_counts['N'] = pair_counts[cluster_col].map(cluster_totals).astype(int)
pair_counts['fraction'] = pair_counts['count'] / pair_counts['N']
data = pair_counts.pivot(index=cluster_col, columns=cell_col, values='fraction')

# Assign every row to the column holding its largest fraction.
df_rows = data.index.to_series().to_frame()
cols = data.columns.tolist()
df_rows["GROUP"] = [cols[i] for i in np.argmax(data.fillna(0).values, axis=1)]

# Reorder the rows so that entries sharing a GROUP are adjacent, with the
# groups following the column order of `data`.
use_rows = [
    row_label
    for group in cols
    for row_label in df_rows.loc[df_rows['GROUP'] == group, cluster_col].unique().tolist()
]
df_rows = df_rows.loc[use_rows]

# Row labels of the form "<category code>: <cluster label>".
code_table = (
    adata.obs
    .assign(code=adata.obs[cluster_col].cat.codes)
    .loc[:, [cluster_col, 'code']]
    .drop_duplicates()
    .set_index(cluster_col)
)
ct2code = code_table.code.to_dict()
df_rows['Label'] = [f"{ct2code[x]}: {x}" for x in df_rows[cluster_col].tolist()]
df_rows.head()

In [None]:
# Annotated heatmap of the fractions in `data`, with rows ordered and split by
# the GROUP column of `df_rows` and labelled on the right with df_rows.Label.
row_ha = HeatmapAnnotation(
    label=anno_label(df_rows.Label, colors='black', relpos=(0, 0.5)),
    axis=0,
    orientation='right',
)

ordered_data = data.loc[df_rows.index.tolist()]
group_order = df_rows['GROUP'].unique().tolist()

plt.figure(figsize=(24,12))
ClusterMapPlotter(
    ordered_data,
    row_cluster=False,
    col_cluster=False,
    cmap='Reds',
    right_annotation=row_ha,
    row_split=df_rows['GROUP'],
    row_split_gap=0.5,
    row_split_order=group_order,
    show_rownames=False,
    show_colnames=True,
    yticklabels=True,
    xticklabels=True,
    xticklabels_kws=dict(labelrotation=-60, labelcolor='blue', labelsize=10),
    yticklabels_kws=dict(labelcolor='red', labelsize=10),
    annot=True,
    fmt='.2g',
    linewidth=0.05,
    linecolor='gold',
    linestyle='-:',
    label='fraction',
    legend_kws=dict(extend='both', extendfrac=0.1),
    xlabel=cell_col,
    ylabel=cluster_col,
    xlabel_kws=dict(color='blue', fontsize=14, labelpad=5),
    xlabel_side='top',
    # labelpad (in points) can be increased manually if the label overlaps.
    ylabel_kws=dict(color='red', fontsize=14, labelpad=5),
)
plt.savefig(image_path / f"{current_datetime}_01_subclass_confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# vc.groupby("integrated_leiden")['N'].sum().sort_values(ascending=True).to_frame()