In [None]:
import cellcharter as cc
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
import pandas as pd
import seaborn as sns
import scanpy as sc
import anndata as ad
import squidpy as sq
import os
import json

In [None]:
# Output format and figure destinations for main and supplementary Figure 5.
extension = 'png'
PROJECT_DIR = '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal'
FIGURES_DIR = f'{PROJECT_DIR}/paper/plots/figures'
save_path = f'{FIGURES_DIR}/figure_5/{extension}/'
save_path_supp = f'{FIGURES_DIR}/suppl_figure_5/{extension}/'

# Inline display resolution vs. resolution used when saving figures.
plt.rcParams.update({'figure.dpi': 100, 'savefig.dpi': 300})

adata = ad.read_h5ad(f'{PROJECT_DIR}/results/standard/adatas/cells_final.h5ad')

COLOR_MAP_DIR = f'{PROJECT_DIR}/github/myeloma_standal/src/Paper/figure_plots'


def _read_color_map(prefix):
    """Load `<prefix>_color_map.json` from COLOR_MAP_DIR and return the parsed JSON (category -> colour)."""
    with open(f'{COLOR_MAP_DIR}/{prefix}_color_map.json', 'r') as handle:
        return json.load(handle)


def _categorical_cmap(color_map, obs_key):
    """Build a ListedColormap whose i-th colour is `color_map` of the i-th category of `adata.obs[obs_key]`."""
    return ListedColormap([color_map[category] for category in adata.obs[obs_key].cat.categories])


neighborhood_color_map = _read_color_map('neighborhood')
phenotype_color_map = _read_color_map('phenotype')
disease_color_map = _read_color_map('disease')
disease3_color_map = _read_color_map('disease3')

neighborhood_colors = _categorical_cmap(neighborhood_color_map, 'cellcharter_CN')
phenotype_colors = _categorical_cmap(phenotype_color_map, 'Phenotype4')
disease_colors = _categorical_cmap(disease_color_map, 'disease2')
disease3_colors = _categorical_cmap(disease3_color_map, 'disease3')

In [None]:
# Load clinical metadata and keep only cohorts B and UB.
md = pd.read_csv('/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/metadata/metadata.csv')
md = md[md['Cohort'].isin(['B', 'UB'])]

# Series.map needs a unique lookup index; fail early with a readable message otherwise.
assert md['IMC label'].is_unique, "duplicate 'IMC label' entries in metadata"
pfs_by_patient = md.set_index('IMC label')['Short_long PFS (< or >  2 yrs)']

# Keep only cells from patients in the selected cohorts. `.copy()` turns the subset into a real
# AnnData so the `.obs` assignment below does not modify a view (which AnnData would otherwise
# silently convert to a copy, with an ImplicitModificationWarning).
adata = adata[adata.obs['patient_ID'].isin(md['IMC label'])].copy()
adata.obs['PFS_group'] = adata.obs['patient_ID'].map(pfs_by_patient).astype('category')

# Differential neighborhood enrichment between PFS groups

In [None]:
# Differential neighborhood enrichment of CellCharter neighborhoods between PFS groups,
# using `image_ID` as the library key.
cc.gr.diff_nhood_enrichment(
    adata,
    condition_key='PFS_group',
    cluster_key='cellcharter_CN',
    library_key='image_ID',
)

In [None]:
# Plot the long- vs short-PFS differential neighborhood enrichment and save it to the figure folder.
cc.pl.diff_nhood_enrichment(
    adata,
    cluster_key='cellcharter_CN',
    condition_key='PFS_group',
    condition_groups=['long_PFS', 'short_PFS'],
    figsize=(4, 4),
    fontsize=12,
    palette=neighborhood_color_map,
    # Use the configured `extension` so the file type matches the `{extension}/` folder in save_path.
    save=os.path.join(save_path, f'diff_nhood_enrichment.{extension}'),
)

# Let's compute selected CellCharter neighborhood enrichment scores per patient

In [None]:
# Permutation-based neighborhood enrichment on the whole (cohort-filtered) dataset.
# NOTE(review): no random seed is passed, so permutation p-values may differ between runs —
# check whether cc.gr.nhood_enrichment exposes a seed argument.
cc.gr.nhood_enrichment(
    adata,
    cluster_key='cellcharter_CN',
    n_perms=250,
    pvalues=True,
    n_jobs=8,
)

In [None]:
# Colour limits symmetric around zero.
# NOTE(review): this value looks copied from a previously computed maximum — document its origin.
nhood_vlim = 0.03208789256832183
cc.pl.nhood_enrichment(
    adata,
    cluster_key='cellcharter_CN',
    annotate=True,
    significance=0.05,
    figsize=(4, 4),
    fontsize=12,
    vmin=-nhood_vlim,
    vmax=nhood_vlim,
)

In [None]:
# Neighborhood pairs ("<nbh1>*<nbh2>") whose enrichment score is extracted per patient.
score_categories = [
    'adaptive_immune*bone_vasculature',
    'adaptive_immune*focal_pc_oxphos',
    'focal_pc_oxphos*stroma_adipocyte',
    'pc_myeloid*focal_pc_oxphos',
    'bone_vasculature*focal_pc_oxphos',
    'hypoxic_immune*focal_pc_oxphos',
]
score_dict = {}


def add_patient(patient_id):
    """Create (or reset) the entry for `patient_id` in the global score_dict: one fresh empty list per score category."""
    score_dict[patient_id] = dict((name, list()) for name in score_categories)

In [None]:
# For every metadata patient, compute neighborhood enrichment on that patient's cells only and
# record the score of each pair in `score_categories`. A pair whose neighborhoods are absent
# from the patient's enrichment matrix is recorded as NaN (same convention as the per-image
# loop) instead of raising KeyError.
for patient in md['IMC label'].unique():
    add_patient(patient)
    patient_mask = adata.obs['patient_ID'] == patient
    if not patient_mask.any():
        # Patient is listed in the metadata but has no cells: nothing to compute.
        score_dict[patient] = dict.fromkeys(score_categories, np.nan)
        continue
    adata_subset = adata[patient_mask].copy()
    cc.gr.nhood_enrichment(
        adata_subset,
        cluster_key='cellcharter_CN',
    )
    enrichment = pd.DataFrame(adata_subset.uns['cellcharter_CN_nhood_enrichment']['enrichment'])
    for category in score_categories:
        index_nbh1, index_nbh2 = category.split('*')
        if index_nbh1 in enrichment.index and index_nbh2 in enrichment.columns:
            score_dict[patient][category] = enrichment.loc[index_nbh1, index_nbh2]
        else:
            score_dict[patient][category] = np.nan
    del adata_subset


In [None]:
# One row per patient: the selected enrichment scores plus the patient's PFS group.
# Rows without a PFS group (or with any missing score) are dropped; rows are ordered by PFS group.
patient_to_pfs = (
    adata.obs[['PFS_group', 'patient_ID']]
    .drop_duplicates()
    .set_index('patient_ID')
)['PFS_group']
df = pd.DataFrame(score_dict).T
df = (
    df.assign(
        patient_ID=df.index,
        pfs=lambda frame: frame['patient_ID'].map(patient_to_pfs),
    )
    .dropna()
    .sort_values('pfs')
    .set_index('patient_ID')
)
df

In [None]:
df.loc[df['pfs'].eq('short_PFS')]

In [None]:
# Persist the per-patient enrichment table for the clinical-correlation analysis.
patient_scores_csv = '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/clinical_correlation/nhood_enrichment_patients.csv'
df.to_csv(patient_scores_csv)

In [None]:
# Inspect one patient's full enrichment matrix. The per-patient loop above deletes its subset
# after every iteration, so recompute it explicitly here (last metadata patient, i.e. the one
# the loop processed last).
inspect_patient = md['IMC label'].unique()[-1]
adata_inspect = adata[adata.obs['patient_ID'] == inspect_patient].copy()
cc.gr.nhood_enrichment(adata_inspect, cluster_key='cellcharter_CN')
adata_inspect.uns['cellcharter_CN_nhood_enrichment']['enrichment']

In [None]:
# Sanity check (instead of displaying loop variables leaked from the loop above): list the
# score categories whose neighborhood names are not categories of `cellcharter_CN`.
# An empty list means every configured pair can be looked up.
known_neighborhoods = set(adata.obs['cellcharter_CN'].cat.categories)
[
    category
    for category in score_categories
    if not set(category.split('*')) <= known_neighborhoods
]

In [None]:
# Neighborhood pairs ("<nbh1>*<nbh2>") whose enrichment score is extracted per image.
score_categories = [
    'adaptive_immune*bone_vasculature',
    'adaptive_immune*focal_pc_oxphos',
    'focal_pc_oxphos*stroma_adipocyte',
    'pc_myeloid*focal_pc_oxphos',
    'bone_vasculature*focal_pc_oxphos',
    'hypoxic_immune*focal_pc_oxphos',
]


def add_image(image_id):
    """Create (or reset) the entry for `image_id` in the global score_dict: one fresh empty list per score category."""
    score_dict[image_id] = dict((name, list()) for name in score_categories)


score_dict = {}
# Per image: neighborhood enrichment on that image's cells only; pairs whose neighborhoods
# are missing from the image's enrichment matrix are recorded as NaN.
for img in adata.obs['image_ID'].unique():
    add_image(img)
    adata_img = adata[adata.obs['image_ID'] == img].copy()
    cc.gr.nhood_enrichment(adata_img, cluster_key='cellcharter_CN')
    enrichment = adata_img.uns['cellcharter_CN_nhood_enrichment']['enrichment']
    for category in score_categories:
        index_nbh1, index_nbh2 = category.split('*')
        both_present = index_nbh1 in enrichment.index and index_nbh2 in enrichment.index
        score_dict[img][category] = (
            pd.DataFrame(enrichment).loc[index_nbh1, index_nbh2] if both_present else np.nan
        )
    del adata_img


In [None]:
# One row per image: the selected enrichment scores plus the image's PFS group.
# Images with any missing score (absent neighborhood) or no PFS group are dropped; rows are ordered by PFS group.
image_to_pfs = (
    adata.obs[['PFS_group', 'image_ID']]
    .drop_duplicates()
    .set_index('image_ID')
)['PFS_group']
df = pd.DataFrame(score_dict).T
df = (
    df.assign(
        image_ID=df.index,
        pfs=lambda frame: frame['image_ID'].map(image_to_pfs),
    )
    .dropna()
    .sort_values('pfs')
    .set_index('image_ID')
)
df

In [None]:
# Persist the per-image enrichment table for the clinical-correlation analysis.
image_scores_csv = '/Users/lukashat/Documents/PhD_Schapiro/Projects/Myeloma_Standal/github/myeloma_standal/src/downstream/clinical_correlation/nhood_enrichment_images.csv'
df.to_csv(image_scores_csv)

In [None]:
# First two colours of seaborn's Set3 palette: short PFS, then long PFS.
set3_palette = sns.color_palette('Set3')
pfs_color_map = {
    'short_PFS': set3_palette[0],
    'long_PFS': set3_palette[1],
}
pfs_color_map

In [None]:
# Heatmap of per-image enrichment scores (rows: neighborhood pairs, columns: images ordered by
# PFS group) with a colour strip above it marking each image's PFS group.
plt.style.use('default')
fig, ax = plt.subplots(figsize=(20, 2))
ax_plot = sns.heatmap(
    df.drop(columns=['pfs']).T,
    cmap='coolwarm',
    center=0,
    vmin=-0.15,
    vmax=0.15,
    # All score categories are plotted, so the label must not name a single pair.
    cbar_kws={'label': 'neighborhood\nenrichment', 'shrink': 1.2, 'aspect': 8, 'pad': 0.13},
    ax=ax,
)

# Named `pfs_colors` so the global `disease_colors` colormap from the config cell is not overwritten.
pfs_colors = [pfs_color_map[group] for group in df['pfs']]
heatmap_pos = ax.get_position()
ax_pfs = fig.add_axes([heatmap_pos.x0, heatmap_pos.y1 + 0.01, heatmap_pos.width, 0.14])
ax_pfs.imshow([pfs_colors], aspect='auto', extent=[0, len(pfs_colors), 0, 1])
ax_pfs.set_xticks([])
ax_pfs.set_yticks([])

ax_plot.set_xlabel('')
ax_plot.set_xticklabels([])
ax_plot.set_yticklabels(ax_plot.get_yticklabels(), size=16)

cbar = ax_plot.collections[0].colorbar
cbar.ax.tick_params(labelsize=16)
cbar.ax.yaxis.label.set_size(14)

legend_patches = [mpatches.Patch(color=color, label=group) for group, color in pfs_color_map.items()]
ax.legend(
    handles=legend_patches,
    bbox_to_anchor=(1.0, 1.6),
    framealpha=0.0,
    loc='upper left',
    title='PFS group',  # legend entries are PFS groups, not disease cohorts
    title_fontsize=16,
    fontsize=14,
)
plt.show()