In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import pandas as pd
import numpy as np
import quiche as qu
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import matplotlib.cm as cm
import imageio as io
import datetime
import shutil
from supplementary_plot_helpers import *
%reload_ext autoreload
%load_ext autoreload
%autoreload 2
%matplotlib inline

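## Unstructured simulation visualization

Runs QUICHE, KMeans, and CellCharter on the four simulated unstructured datasets (`low_high`, `high_high`, `low_low`, `high_low`) and assembles the comparison grids for Supplementary Figures 1–2.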

In [None]:
# Data-generation step (disabled). Uncomment to regenerate the four simulated
# AnnData files data/simulated/adata_simulated_unstructured_{id_save}.h5ad
# that the next cell reads. Each id_save selects a grid_size / ratio pair that
# is passed to qu.tl.simulate_unstructured together with the fixed per-patient
# random seeds below, so regeneration is reproducible.
# da_vec_A = ['A', 'C', 'E']
# da_vec_B = ['B', 'D']
# n_regions = 1
# n_patients_condA = 10
# n_patients_condB = 10
# sample_size_A = {'A': 1000, 'B': 1000, 'C': 1000, 'D':1000, 'E': 1000}
# sample_size_B = {'A': 1000, 'B': 1000, 'C': 1000, 'D':1000, 'E': 2000}
# n_niches_A = np.array(list(sample_size_A.values())).sum()
# n_niches_B = np.array(list(sample_size_B.values())).sum()
# random_state_list_A = [58, 322, 1426, 65, 651, 417, 2788, 576, 213, 1828]
# random_state_list_B = [51, 1939, 2700, 1831, 804, 2633, 2777, 2053, 948, 420]
# A_id_join = ''.join(da_vec_A)
# B_id_join = ''.join(da_vec_B)
# save_directory = os.path.join('data', 'simulated', 'test')

# for id_save in ['low_high', 'high_high', 'low_low', 'high_low']:
#     if id_save == 'low_high':
#         grid_size = 8
#         ratio = 1.0
#     elif id_save == 'high_high':
#         grid_size = 3
#         ratio = 1.0
#     elif id_save == 'low_low':
#         grid_size = 8
#         ratio = 0.4
#     elif id_save == 'high_low':
#         grid_size = 3
#         ratio = 0.4
#     else:
#         print(f'{id_save} not recognized')

#     adata_simulated = qu.tl.simulate_unstructured(n_patients_condA = n_patients_condA, n_patients_condB = n_patients_condB, num_grids_x = grid_size, num_grids_y = grid_size, ratio = ratio,
#                                                   n_niches_A = n_niches_A, n_niches_B = n_niches_B, n_regionsA = n_regions, n_regionsB = n_regions, da_vec_A = da_vec_A, da_vec_B = da_vec_B,
#                                                   random_state_list_A = random_state_list_A, scale = 2048, random_state_list_B = random_state_list_B, sample_size_A = sample_size_A,
#                                                   sample_size_B = sample_size_B, fig_id = 'fig_id', save_directory=save_directory)
    
#     adata_simulated.write_h5ad(os.path.join('data', 'simulated', f'adata_simulated_unstructured_{id_save}.h5ad'))

In [None]:
# 20-colour categorical palette: the ten saturated hues (matplotlib's default
# tab10 cycle) followed by their ten lighter companions.
colors = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
] + [
    "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5",
    "#c49c94", "#f7b6d2", "#c7c7c7", "#dbdb8d", "#9edae5",
]

# Root output folder for Supplementary Figures 1-2 (handed to qu.pp.make_directory,
# which presumably creates it when missing).
save_directory = os.path.join(
    'publications',
    'supplementary_figures',
    'supplementary_figure01-02',
)
qu.pp.make_directory(save_directory)

# ---------------------------------------------------------------------------
# Helpers for the benchmark loop below. The plotting functions
# (plot_unstructured_niche*, load_and_plot_image) come from
# supplementary_plot_helpers.
# ---------------------------------------------------------------------------

GROUND_TRUTH_COLORS = {'A': '#B46CDA', 'B': '#78CE8B', 'C': '#FF8595', 'D': '#1885F2', 'E': '#D78F09'}
KMEANS_CLUSTERS = [3, 5, 7]
# None is passed straight to evaluate_cell_charter and labelled 'auto' in file names / titles.
CELLCHARTER_CLUSTERS = [3, 5, None]
SIG_ALPHA = 0.05       # a niche is highlighted when its mean p-value is below this
MIN_NICHE_CELLS = 5    # ... and it has at least this many rows


def significant_niches(frame, label_key, pval_key, alpha=SIG_ALPHA, min_count=MIN_NICHE_CELLS):
    """Return the labels of niches that are both significant and frequent.

    Parameters
    ----------
    frame : pandas.DataFrame
        Table with a niche-label column and a p-value column (one row per cell / neighbourhood).
    label_key, pval_key : str
        Names of the label and p-value columns.
    alpha : float
        A niche qualifies when the mean of its p-values is below ``alpha``.
    min_count : int
        A niche qualifies only if its label occurs at least ``min_count`` times.

    Returns
    -------
    list
        Qualifying labels, in ``groupby`` order.

    Notes
    -----
    A niche whose mean p-value is NaN is filled with 0 and therefore counts as
    significant (same as the original analysis).
    NOTE(review): confirm that treating NaN as significant is intended.
    """
    mean_pval = frame.groupby(label_key)[pval_key].mean().fillna(0)
    counts = frame[label_key].value_counts()
    frequent = set(counts[counts >= min_count].index)
    return [label for label, pval in mean_pval.items() if pval < alpha and label in frequent]


def patient_frame(mdata, patient_id, obs_modality, label_key=None):
    """Build the per-cell plotting table for one patient.

    Coordinates (X0, Y0) come from ``mdata['expression'].obsm['spatial']``.
    DA_group and cell_cluster are copied from ``mdata[obs_modality].obs``. When
    ``label_key`` is given, -log10(obs['pval']) is added as 'pval', followed by
    the ``label_key`` column.

    Columns are assigned positionally, so the 'expression' modality and
    ``obs_modality`` are assumed to list this patient's cells in the same order.
    NOTE(review): confirm this ordering assumption.
    """
    expression = mdata['expression']
    coords = expression[expression.obs['Patient_ID'] == patient_id].obsm['spatial']
    frame = pd.DataFrame(coords, columns=['X0', 'Y0'])

    obs = mdata[obs_modality].obs
    obs_patient = obs[obs['Patient_ID'] == patient_id]
    frame['DA_group'] = obs_patient['DA_group'].values
    frame['cell_cluster'] = obs_patient['cell_cluster'].values
    if label_key is not None:
        frame['pval'] = -1 * np.log10(obs_patient['pval'].values)
        frame[label_key] = obs_patient[label_key].values
    return frame


def add_quiche_scores(frame, mdata, patient_id):
    """Add QUICHE -log10(SpatialFDR) as 'pval' and the 'quiche_niche' label to ``frame``.

    Rows of ``mdata['quiche'].var`` are selected by matching ``index_cell``
    against this patient's ``spatial_nhood`` cell names, then assigned
    positionally. NOTE(review): this assumes exactly one var row per cell, in
    the same order as ``frame``; a mismatch in length raises and a mismatch in
    order silently misaligns values.
    """
    nhood = mdata['spatial_nhood']
    patient_cells = nhood[nhood.obs['Patient_ID'] == patient_id].obs_names
    var = mdata['quiche'].var
    var_patient = var[np.isin(var.index_cell, patient_cells)]
    frame['pval'] = -1 * np.log10(var_patient.SpatialFDR.values)
    frame['quiche_niche'] = var_patient.quiche_niche.values
    return frame


def plot_method_panels(frame, label_key, title, stem, suffix, niches, out_dir, score_max=5.0):
    """Save the categorical niche map and the -log10(p-value) map for one patient.

    Files are named ``{stem}_{suffix}`` and ``{stem}_pvalue_{suffix}`` with
    extension '.tiff', the names ``panel_row`` reads back. ``score_max`` is
    forwarded as the last positional score-range argument of
    plot_unstructured_niche_score.
    """
    plot_unstructured_niche_cat(frame, (4, 4), label_key, title, out_dir, f'{stem}_{suffix}',
                                axes_label=False, ext='.tiff', legend=False)
    plot_unstructured_niche_score(frame, 'Reds', (4, 4), 'pval', label_key, niches, '-log10(p-value)',
                                  out_dir, f'{stem}_pvalue_{suffix}', 0, 2.5, score_max,
                                  axes_label=False, cbar=False, ext='.tiff')


def panel_row(condition, id_save, method, display_name, settings):
    """Return (path relative to save_directory, title) pairs for one row of the summary grid.

    Each row is: ground truth, QUICHE niches, QUICHE p-values, then a
    niche / p-value pair for every entry of ``settings``.
    """
    panels = [(f'ground_truth_{condition}_{id_save}', 'Ground Truth'),
              (f'quiche_{condition}_{id_save}', 'QUICHE'),
              (f'quiche_{condition}_pvalue_{id_save}', 'QUICHE pval')]
    for setting in settings:
        panels.append((f'{method}_{condition}_{setting}_{id_save}', f'{display_name} {setting}'))
        panels.append((f'{method}_{condition}_pvalue_{setting}_{id_save}', f'{display_name} pval {setting}'))
    return [(os.path.join('predicted', f'{stem}.tiff'), title) for stem, title in panels]


cellcharter_labels = ['auto' if n is None else n for n in CELLCHARTER_CLUSTERS]

for id_save in ['low_high', 'high_high', 'low_low', 'high_low']:
    adata_simulated = anndata.read_h5ad(os.path.join('data', 'simulated', f'adata_simulated_unstructured_{id_save}.h5ad'))
    # Per-panel TIFFs for this scenario; the folder is moved away at the end of the iteration.
    save_directory_ = os.path.join(save_directory, 'predicted')
    qu.pp.make_directory(save_directory_)

    ## QUICHE
    spatial_method = qu.tl.run_quiche
    spatial_method_params = {'radius': 200,
                             'labels_key': 'cell_cluster',
                             'spatial_key': 'spatial',
                             'fov_key': 'Patient_ID',
                             'patient_key': 'Patient_ID',
                             'khop': 3,
                             'n_neighbors': 10,
                             'delaunay': False,
                             'min_cells': 5,
                             'k_sim': 100,
                             'design': '~condition',
                             'model_contrasts': 'conditionA-conditionB',
                             'sketch_size': None,
                             'nlargest': 5,
                             'annotation_key': 'quiche_niche',
                             'n_jobs': -1,
                             'label_scheme': 'neighborhood_norm',
                             'sig_key': 'PValue'}

    benchmarker = qu.tl.benchmark(adata=adata_simulated, spatial_method=spatial_method, spatial_method_params=spatial_method_params)
    mdata, _ = benchmarker.perform_enrichment()
    niches = significant_niches(mdata['quiche'].var, 'quiche_niche', 'SpatialFDR')

    df_A0 = add_quiche_scores(patient_frame(mdata, 'A0', 'spatial_nhood'), mdata, 'A0')
    df_B0 = add_quiche_scores(patient_frame(mdata, 'B0', 'spatial_nhood'), mdata, 'B0')

    ## ground truth
    plot_unstructured_niche(df_A0, GROUND_TRUTH_COLORS, (4, 4), 'cell_cluster', 'DA_group', ['A_C_E'], 'Ground Truth',
                            save_directory_, f'ground_truth_ACE_{id_save}', axes_label=False, ext='.tiff', legend=False)
    plot_unstructured_niche(df_B0, GROUND_TRUTH_COLORS, (4, 4), 'cell_cluster', 'DA_group', ['B_D'], 'Ground Truth',
                            save_directory_, f'ground_truth_BD_{id_save}', axes_label=False, ext='.tiff', legend=False)

    ## quiche (score limit stays the int 5 used originally)
    plot_method_panels(df_A0, 'quiche_niche', 'QUICHE Niche', 'quiche_ACE', id_save, niches, save_directory_, score_max=5)
    plot_method_panels(df_B0, 'quiche_niche', 'QUICHE Niche', 'quiche_BD', id_save, niches, save_directory_, score_max=5)

    ## kmeans
    for cluster in KMEANS_CLUSTERS:
        spatial_method = qu.tl.evaluate_kmeans
        spatial_method_params = {'n_clusters': cluster, 'random_state': 42, 'fov_key': 'Patient_ID', 'condition_key': 'condition',
                                 'labels_key': 'cell_cluster', 'radius': 200, 'delaunay': False, 'save_directory': None,
                                 'condition_list': ['A', 'B'], 'filename_save': 'simulated', 'sig_threshold': 0.05, 'nlargest': 5}

        benchmarker = qu.tl.benchmark(adata=adata_simulated, spatial_method=spatial_method, spatial_method_params=spatial_method_params)
        mdata, _ = benchmarker.perform_enrichment()
        niches = significant_niches(mdata['spatial_nhood'].obs, 'kmeans_cluster_labeled', 'pval')

        df_A0 = patient_frame(mdata, 'A0', 'spatial_nhood', label_key='kmeans_cluster_labeled')
        df_B0 = patient_frame(mdata, 'B0', 'spatial_nhood', label_key='kmeans_cluster_labeled')

        plot_method_panels(df_A0, 'kmeans_cluster_labeled', 'kmeans', 'kmeans_ACE', f'{cluster}_{id_save}', niches, save_directory_)
        plot_method_panels(df_B0, 'kmeans_cluster_labeled', 'kmeans', 'kmeans_BD', f'{cluster}_{id_save}', niches, save_directory_)

    ## cellcharter
    for cluster, cluster_label in zip(CELLCHARTER_CLUSTERS, cellcharter_labels):
        spatial_method = qu.tl.evaluate_cell_charter
        spatial_method_params = {'n_clusters': cluster,
                                 'fov_key': 'Patient_ID',
                                 'condition_key': 'condition',
                                 'max_runs': 2,
                                 'n_jobs': 1,
                                 'condition_list': ['A', 'B']}

        benchmarker = qu.tl.benchmark(adata=adata_simulated, spatial_method=spatial_method, spatial_method_params=spatial_method_params)
        mdata, _ = benchmarker.perform_enrichment()
        niches = significant_niches(mdata['expression'].obs, 'spatial_cluster', 'pval')

        df_A0 = patient_frame(mdata, 'A0', 'expression', label_key='spatial_cluster')
        df_B0 = patient_frame(mdata, 'B0', 'expression', label_key='spatial_cluster')

        plot_method_panels(df_A0, 'spatial_cluster', 'cellcharter', 'cellcharter_ACE', f'{cluster_label}_{id_save}', niches, save_directory_)
        plot_method_panels(df_B0, 'spatial_cluster', 'cellcharter', 'cellcharter_BD', f'{cluster_label}_{id_save}', niches, save_directory_)

    ## summary grids: row 1 = QUICHE vs KMeans, row 2 = QUICHE vs CellCharter
    conditions = ["ACE", "BD"]
    for condition in conditions:
        rows = [panel_row(condition, id_save, 'kmeans', 'KMeans', KMEANS_CLUSTERS),
                panel_row(condition, id_save, 'cellcharter', 'CellCharter', cellcharter_labels)]
        n_cols = max(len(row) for row in rows)
        fig, axes = plt.subplots(len(rows), n_cols, figsize=(2 * n_cols, 8), squeeze=False)

        for row_axes, row in zip(axes, rows):
            for ax, (filename, title) in zip(row_axes, row):
                load_and_plot_image(ax, os.path.join(save_directory, filename), title)

        # rect reserves the top 5% of the canvas (room for a title; none is set at present)
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(os.path.join(save_directory, f'unstructured_{id_save}_{condition}.pdf'), bbox_inches='tight', dpi=400)
        # The grid lives on in the PDF; close it so figures don't pile up across iterations.
        plt.close(fig)

    # Move this scenario's 'predicted' folder to tmp_{id_save}/predicted; the next
    # iteration recreates an empty 'predicted' folder.
    folder = save_directory_
    tmp_root = f'tmp_{id_save}'
    destination = os.path.join(tmp_root, os.path.basename(folder))
    if os.path.exists(folder):
        qu.pp.make_directory(tmp_root)
        if os.path.exists(destination):
            # Stale copy from a previous run: shutil.move would otherwise nest the
            # folder inside it (and fail on the run after that).
            shutil.rmtree(destination)
        shutil.move(folder, destination)
        print(f"Moved: {folder} -> {destination}")
    else:
        print(f"Folder does not exist and was skipped: {folder}")