# Single Sample Method Benchmarking

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 06.01.2023
- **Date of Last Modification:** 08.12.2023

## 1. Setup

### 1.1 Import Libraries

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../../../nichecompass-reproducibility/utils")

In [None]:
import gc
import os
import shutil
import warnings
from datetime import datetime

import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import seaborn as sns
import mlflow
import numpy as np
import pandas as pd
import pickle
import plottable
import scanpy as sc
import scib_metrics

from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                create_new_color_dict,
                                extract_gp_dict_from_mebocost_es_interactions,
                                extract_gp_dict_from_nichenet_lrt_interactions,
                                extract_gp_dict_from_omnipath_lr_interactions,
                                filter_and_combine_gp_dict_gps)

from benchmarking_utils import *

### 1.2 Define Parameters

In [None]:
# Single-sample benchmarking metrics, grouped into three score categories
metric_cols_single_sample = [
    "cas", "mlami", # global spatial conservation
    "clisis", "gcs", # local spatial conservation
    "nasw", "cnmi", # niche coherence
]
# Per-metric weights (sum to 1). Each value equals the metric's share of the overall
# score, i.e. its category weight divided by the number of metrics in that category
# (0.25 / 2 = 1/8 and 0.5 / 2 = 1/4).
metric_col_weights_single_sample = [
    (1/8), (1/8), # global spatial conservation
    (1/8), (1/8), # local spatial conservation
    (1/4), (1/4), # niche coherence
]
metric_col_titles_single_sample = [
    "CAS", # "Cell Type Affinity Similarity",
    "MLAMI", # "Maximum Leiden Adjusted Mutual Info",
    "CLISIS", # "Cell Type Local Inverse Simpson's Index Similarity",
    "GCS", # "Graph Connectivity Similarity",
    "NASW", # "Niche Average Silhouette Width",
    "CNMI", # "Cell Type Normalized Mutual Info",
]

# Category score columns and their weights in the overall score (sum to 1)
category_cols_single_sample = [
    "Global Spatial Conservation Score",
    "Local Spatial Conservation Score",
    "Niche Coherence Score"]
category_col_weights_single_sample = [
    0.25,
    0.25,
    0.5]
category_col_titles_single_sample = [
    "Global Spatial Conservation Score",
    "Local Spatial Conservation Score",
    "Niche Coherence Score"]

### 1.3 Run Notebook Setup

In [None]:
# Set scanpy's default figure size to (6, 6) for the plots in this notebook
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Ignore future, user and deprecation warnings to keep the notebook output readable
for warning_category in (FutureWarning, UserWarning, DeprecationWarning):
    warnings.simplefilter(action="ignore", category=warning_category)

In [None]:
# Execution time of this notebook, formatted as DDMMYYYY_HHMMSS, used to tag saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

In [None]:
# Point mlflow at a tracking server on localhost:8889
# (the server must already be running on this port)
mlflow.set_tracking_uri("http://localhost:8889")

### 1.4 Configure Paths and Directories

In [None]:
# Input dataset and output artifact locations (relative paths, resolved against the
# current working directory)
data_folder_path = "../../datasets/srt_data/gold"
artifact_folder_path = "../../artifacts"
benchmarking_folder_path = f"{artifact_folder_path}/single_sample_method_benchmarking"

## 2. Method Benchmarking

- Run all model notebooks in the `notebooks/single_sample_method_benchmarking` directory before continuing.

### 2.1 Retrieve NicheCompass Runs

#### 2.1.1 seqFISH Mouse Organogenesis

#### 2.1.2 nanoString CosMx SMI Human Non-Small-Cell Lung Cancer (NSCLC)

#### 2.1.3 Vizgen MERFISH Mouse Liver

#### 2.1.4 Slide-seqV2 Mouse Hippocampus

### 2.2 Create Benchmarking Metrics Plots & Run Time Plots

#### 2.2.1 Slide-seqV2 Mouse Hippocampus

In [None]:
### Supplementary figure: spatial clusters ###
# Same file as before ("../../datasets/srt_data/gold/..."), now built from data_folder_path
adata = sc.read_h5ad(f"{data_folder_path}/slideseqv2_mouse_hippocampus.h5ad")
adata.obs["sample"] = "sample1"

leiden_resolution = 0.1

print("\nComputing neighbor graph...")
# Build the kNN graph on the "spatial" representation (physical coordinates) rather
# than on a latent embedding, so the Leiden clusters below are purely spatial
sc.pp.neighbors(adata,
                use_rep="spatial",
                key_added="spatial_knn")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key="spatial_knn")

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=leiden_resolution,
             key_added=f"spatial_leiden_{leiden_resolution}",
             neighbors_key="spatial_knn")

In [None]:
# Assign one colour per spatial Leiden cluster and draw the clusters in physical space
spatial_cluster_colors = create_new_color_dict(
    adata=adata,
    cat_key=f"spatial_leiden_{leiden_resolution}")

plot_category_in_latent_and_physical_space(
    adata=adata,
    plot_label="Spatial Clusters",
    model_label=None,
    cat_key=f"spatial_leiden_{leiden_resolution}",
    groups=None,
    sample_key="sample",
    samples=["sample1"],  # use samples=None (and figsize=(10, 20)) for the latent UMAP
    cat_colors=spatial_cluster_colors,
    size=720000 / len(adata),
    spot_size=30,
    save_fig=True,
    file_path=f"{benchmarking_folder_path}/spatial_clusters.svg")

In [None]:
### Main figure: full dataset metrics ###
datasets = ["slideseqv2_mouse_hippocampus"]
models = [#"nichecompass_gcnconv",
          "nichecompass_gatv2conv",
          #"staci", # did not run
          "graphst",
          "deeplinc",
          "sagenet",
          #"scvi",
          #"expimap"
]

# Load the per-run metrics of every model and min-max scale them per dataset
summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_df = pd.DataFrame()
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for runs 1-8 so the model is still listed (with NaN metrics)
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        # pd.concat instead of DataFrame.append, which was removed in pandas 2.0
        dataset_df = pd.concat([dataset_df, benchmark_df], ignore_index=True)

    # Min-max scale each metric across all models and runs of this dataset
    # (a metric with the same value in every row yields NaN, since max == min)
    for metric_col in metric_cols_single_sample:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (
            (dataset_df[metric_col] - min_val) / (max_val - min_val))

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Category scores: weighted mean of the scaled metrics within each category
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
# Overall score: weighted mean of the three category scores
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Map model identifiers to display names for the plot
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "deeplinc": "DeepLinc",
                    "expimap": "expiMap",
                    "graphst": "GraphST",
                    "sagenet": "SageNet",
                    "scvi": "scVI"}, inplace=True)

# Filter for just second run
summary_df = summary_df[summary_df["run_number"] == 2]

# Prepare metrics table plot: mean of the numeric columns per (dataset, model), best
# overall score first. numeric_only=True is stated explicitly: non-numeric columns
# cannot be averaged.
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean(numeric_only=True)
aggregate_df = aggregate_df.sort_values("Overall Score", ascending=False)
aggregate_df = aggregate_df[metric_cols_single_sample + ["Overall Score"]].reset_index()

# Long format: one row per (dataset, model, score_type)
unrolled_df = pd.melt(aggregate_df,
                      id_vars=group_cols,
                      value_vars=metric_cols_single_sample + ["Overall Score"],
                      var_name="score_type",
                      value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if the row's model display name is in the spatially aware list."""
    return row["model"] in ["NicheCompass", "STACI", "DeepLinc", "GraphST", "SageNet"]

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset")

In [None]:
# Plot table of the individual metric scores (swap in the category_* lists noted
# below to plot the category scores instead)
plot_simple_metrics_table(
    df=unrolled_df,
    model_col="model",
    model_col_width=1.6,
    group_col="dataset",
    metric_cols=metric_cols_single_sample,  # alternative: category_cols_single_sample
    metric_col_weights=metric_col_weights_single_sample,  # alternative: category_col_weights_single_sample
    metric_col_titles=[col.replace(" ", "\n") for col in metric_col_titles_single_sample],  # alternative: category_col_titles_single_sample
    metric_col_width=0.8,
    aggregate_col_width=1.2,
    plot_width=8.5,  # alternative: 32
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_slideseqv2_mouse_hippocampus_run2.svg")

In [None]:
### Supplementary figure: 25% subsample metrics ###
datasets = ["slideseqv2_mouse_hippocampus_subsample_25pct"]
models = [#"nichecompass_gcnconv",
          "nichecompass_gatv2conv",
          "staci",
          "graphst",
          "deeplinc",
          #"sagenet", # did not run
          #"scvi",
          #"expimap"
]

# Load the per-run metrics of every model and min-max scale them per dataset
summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_df = pd.DataFrame()
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for runs 1-8 so the model is still listed (with NaN metrics)
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        # pd.concat instead of DataFrame.append, which was removed in pandas 2.0
        dataset_df = pd.concat([dataset_df, benchmark_df], ignore_index=True)

    # Min-max scale each metric across all models and runs of this dataset
    # (a metric with the same value in every row yields NaN, since max == min)
    for metric_col in metric_cols_single_sample:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (
            (dataset_df[metric_col] - min_val) / (max_val - min_val))

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Category scores: weighted mean of the scaled metrics within each category
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
# Overall score: weighted mean of the three category scores
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Map model identifiers to display names for the plot
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "deeplinc": "DeepLinc",
                    "expimap": "expiMap",
                    "graphst": "GraphST",
                    "sagenet": "SageNet",
                    "scvi": "scVI"}, inplace=True)

# Filter for just second run
summary_df = summary_df[summary_df["run_number"] == 2]

# Prepare metrics table plot: mean of the numeric columns per (dataset, model), best
# overall score first. numeric_only=True is stated explicitly: non-numeric columns
# cannot be averaged.
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean(numeric_only=True)
aggregate_df = aggregate_df.sort_values("Overall Score", ascending=False)
aggregate_df = aggregate_df[metric_cols_single_sample + ["Overall Score"]].reset_index()

# Long format: one row per (dataset, model, score_type)
unrolled_df = pd.melt(aggregate_df,
                      id_vars=group_cols,
                      value_vars=metric_cols_single_sample + ["Overall Score"],
                      var_name="score_type",
                      value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if the row's model display name is in the spatially aware list."""
    return row["model"] in ["NicheCompass", "STACI", "DeepLinc", "GraphST", "SageNet"]

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset")

In [None]:
# Plot table of the individual metric scores (swap in the category_* lists noted
# below to plot the category scores instead)
plot_simple_metrics_table(
    df=unrolled_df,
    model_col="model",
    model_col_width=1.6,
    group_col="dataset",
    metric_cols=metric_cols_single_sample,  # alternative: category_cols_single_sample
    metric_col_weights=metric_col_weights_single_sample,  # alternative: category_col_weights_single_sample
    metric_col_titles=[col.replace(" ", "\n") for col in metric_col_titles_single_sample],  # alternative: category_col_titles_single_sample
    metric_col_width=0.8,
    aggregate_col_width=1.2,
    plot_width=8.5,  # alternative: 32
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_slideseqv2_mouse_hippocampus_subsample_25pct_run2.svg")

In [None]:
### Main figure: full dataset niches ###
dataset = "slideseqv2_mouse_hippocampus"
models = ["nichecompass_gatv2conv"]
run_number = 2
leiden_resolutions = [0.35]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_{model}.h5ad")
    # Strip the conv-layer suffix: the latent keys use the bare model name
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"
    latent_key = f"{model}_latent_run{run_number}"
    leiden_key = f"run{run_number}_leiden_{leiden_resolutions[i]}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata,
                    use_rep=latent_key,
                    key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata,
               neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolutions[i],
                 key_added=leiden_key,
                 neighbors_key=latent_key)

    # Hand-assigned niche names per Leiden cluster label.
    # NOTE(review): this assumes the clustering reproduces exactly; re-check it after
    # changing the resolution, the run or library versions.
    niche_annotations = {
        "0": "Stratum",
        "1": "Thalamus LD",
        "2": "Cortical layer 6a",
        "3": "Cortical layer 6b",
        "4": "Thalamus LP",
        "5": "Corpus callosum",
        "6": "CA1",
        "7": "Cortical layer 5",
        "8": "CA2 & CA3",
        "9": "Dentate gyrus",
        "10": "Medial habenula (MH)",
        "11": "Lateral habenula (LH)",
        "12": "Cortical layer 2/3",
        "13": "Third Ventricle (V3)"}

    # Read the Leiden column via leiden_key so it stays in sync with run_number /
    # leiden_resolutions instead of being hard-coded
    adata.obs["niche"] = adata.obs[leiden_key].map(niche_annotations)

    # Give each niche the colour of its underlying Leiden cluster
    latent_cluster_colors = create_new_color_dict(
        adata=adata,
        cat_key=leiden_key)

    niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotations.items()}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        #figsize=(10, 20), for latent UMAP
        plot_label="Latent Clusters",
        model_label=model,  # the name itself; `{model}` would pass a one-element set
        cat_key="niche",
        groups=None,
        sample_key="sample",
        samples=["sample1"], # =None for latent UMAP
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
# Ward-linkage dendrogram of the niches, computed on the NicheCompass run-2 latent
# representation and drawn left-oriented on a narrow axis
sc.tl.dendrogram(
    adata=adata,
    groupby="niche",
    use_rep="nichecompass_latent_run2",
    linkage_method="ward")

fig, ax = plt.subplots(figsize=(2, 5))
sc.pl.dendrogram(
    adata=adata,
    groupby="niche",
    orientation="left",
    ax=ax,
    show=False,
    save=f"_{dataset}_nichecompass_run2")
plt.show()

In [None]:
# Stacked horizontal bars of the cell type composition of each niche, with niches
# ordered by the `dendrogram_niche` result stored in adata.uns
cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_30",
    cat_key="cell_type")

# Row-normalised contingency table: each niche's cell type fractions sum to 1
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize="index")
tmp = tmp.reindex(adata.uns["dendrogram_niche"]["categories_ordered"])
# Keep the Axes returned by the plot; the legend is placed on it separately
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(6, 10))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=16)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/niche_cell_type_proportions.svg", bbox_inches="tight")

In [None]:
dataset = "slideseqv2_mouse_hippocampus"
models = ["graphst"]
run_number = 2
leiden_resolutions = [0.93]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_{model}.h5ad")
    # Strip a conv-layer suffix if present (no-op for "graphst")
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"
    latent_key = f"{model}_latent_run{run_number}"
    leiden_key = f"run{run_number}_leiden_{leiden_resolutions[i]}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata,
                    use_rep=latent_key,
                    key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata,
               neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolutions[i],
                 key_added=leiden_key,
                 neighbors_key=latent_key)

    # Hand-assigned niche names per Leiden cluster label.
    # NOTE(review): this assumes the clustering reproduces exactly; re-check it after
    # changing the resolution, the run or library versions.
    niche_annotations = {
        "0": "Cortical Layer 6a",
        "1": "Stratum",
        "2": "Thalamus LD",
        "3": "Corpus callosum",
        "4": "Thalamus LP",
        "5": "Cortical Layer 6b",
        "6": "CA1",
        "7": "CA2 & CA3",
        "8": "Dentate gyrus",
        "9": "Artifact 1",
        "10": "Third Ventricle (V3)",
        "11": "Medial habenula (MH)",
        "12": "Lateral habenula (LH)",
        "13": "Artifact 2"}

    # Read the Leiden column via leiden_key so it stays in sync with run_number /
    # leiden_resolutions instead of being hard-coded
    adata.obs["niche"] = adata.obs[leiden_key].map(niche_annotations)

    niche_colors = {
     'Stratum': '#8BE0A4',
     'Thalamus LD': '#F6CF71',
     'Cortical Layer 6b': '#B497E7',
     'Thalamus LP': '#87C55F',
     'Corpus callosum': '#DAB6C4',
     'CA1': '#FE88B1',
     'CA2 & CA3': '#DCB0F2',
     'Dentate gyrus': '#D3B484',
     'Medial habenula (MH)': '#F89C74',
     'Lateral habenula (LH)': '#C9DB74',
     'Third Ventricle (V3)': '#B3B3B3',
     'Cortical Layer 6a': "#66C5CC",
     'Artifact 1': "#FF4D4D",
     'Artifact 2': "#D2691E"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),  # for latent UMAP
        plot_label="Latent Clusters",
        model_label=model,  # the name itself; `{model}` would pass a one-element set
        cat_key="niche",
        groups=None,
        sample_key="sample",
        samples=None,  # None for latent UMAP; ["sample1"] for the spatial plot
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
# Ward-linkage dendrogram of the niches, computed on the GraphST run-2 latent
# representation and drawn left-oriented on a narrow axis
sc.tl.dendrogram(
    adata=adata,
    groupby="niche",
    use_rep="graphst_latent_run2",
    linkage_method="ward")

fig, ax = plt.subplots(figsize=(2, 5))
sc.pl.dendrogram(
    adata=adata,
    groupby="niche",
    orientation="left",
    ax=ax,
    show=False,
    save=f"_{dataset}_graphst_run2")
plt.show()

In [None]:
# Stacked horizontal bars of the cell type composition of each niche, with niches
# ordered by the `dendrogram_niche` result stored in adata.uns
cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_30",
    cat_key="cell_type")

# Row-normalised contingency table: each niche's cell type fractions sum to 1
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize="index")
tmp = tmp.reindex(adata.uns["dendrogram_niche"]["categories_ordered"])
# Keep the Axes returned by the plot; the legend is placed on it separately
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(6, 10))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=16)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/niche_cell_type_proportions_graphst.svg", bbox_inches="tight")

In [None]:
dataset = "slideseqv2_mouse_hippocampus"
models = ["sagenet"]
run_number = 2
leiden_resolutions = [0.0775]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_{model}.h5ad")
    # Strip a conv-layer suffix if present (no-op for "sagenet")
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"
    latent_key = f"{model}_latent_run{run_number}"
    leiden_key = f"run{run_number}_leiden_{leiden_resolutions[i]}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata,
                    use_rep=latent_key,
                    key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata,
               neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolutions[i],
                 key_added=leiden_key,
                 neighbors_key=latent_key)

    # Hand-assigned niche names per Leiden cluster label.
    # NOTE(review): this assumes the clustering reproduces exactly; re-check it after
    # changing the resolution, the run or library versions.
    niche_annotations = {
        "0": "Cortical layer 5",
        "1": "Corpus callosum",
        "2": "Thalamus LP",
        "3": "Thalamus LD",
        "4": "Artifact 1",
        "5": "Cortical layer 2/3",
        "6": "Stratum",
        "7": "CA2 & CA3",
        "8": "Artifact 2",
        "9": "Medial habenula (MH)",
        "10": "CA1",
        "11": "Cortical Layer 6a",
        "12": "Dentate gyrus",
        "13": "Third Ventricle (V3)"}

    # Read the Leiden column via leiden_key so it stays in sync with run_number /
    # leiden_resolutions instead of being hard-coded
    adata.obs["niche"] = adata.obs[leiden_key].map(niche_annotations)

    # Niche colours (shared with the other models' niche plots); this is the only
    # colour mapping used for the plot below
    niche_colors = {
         'Stratum': '#8BE0A4',
         'Thalamus LD': '#F6CF71',
         'Cortical Layer 6b': '#B497E7',
         'Cortical layer 5': '#9EB9F3',
         'Thalamus LP': '#87C55F',
         'Corpus callosum': '#DAB6C4',
         'CA1': '#FE88B1',
         'CA2 & CA3': '#DCB0F2',
         'Dentate gyrus': '#D3B484',
         'Medial habenula (MH)': '#F89C74',
         'Lateral habenula (LH)': '#C9DB74',
         'Third Ventricle (V3)': '#B3B3B3',
         'Cortical layer 2/3': '#276A8C',
         'Cortical Layer 6a': "#66C5CC",
         'Artifact 1': "#FF4D4D",
         'Artifact 2': "#D2691E"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),  # for latent UMAP
        plot_label="Latent Clusters",
        model_label=model,  # the name itself; `{model}` would pass a one-element set
        cat_key="niche",
        groups=None,
        sample_key="sample",
        samples=None,  # None for latent UMAP; ["sample1"] for the spatial plot
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
# Ward-linkage dendrogram of the niches, computed on the SageNet run-2 latent
# representation and drawn left-oriented on a narrow axis
sc.tl.dendrogram(
    adata=adata,
    groupby="niche",
    use_rep="sagenet_latent_run2",
    linkage_method="ward")

fig, ax = plt.subplots(figsize=(2, 5))
sc.pl.dendrogram(
    adata=adata,
    groupby="niche",
    orientation="left",
    ax=ax,
    show=False,
    save=f"_{dataset}_sagenet_run2")
plt.show()

In [None]:
# Stacked horizontal bars of the cell type composition of each niche, with niches
# ordered by the `dendrogram_niche` result stored in adata.uns
cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_30",
    cat_key="cell_type")

# Row-normalised contingency table: each niche's cell type fractions sum to 1
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize="index")
tmp = tmp.reindex(adata.uns["dendrogram_niche"]["categories_ordered"])
# Keep the Axes returned by the plot; the legend is placed on it separately
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(6, 10))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=16)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/niche_cell_type_proportions_sagenet.svg", bbox_inches="tight")

In [None]:
dataset = "slideseqv2_mouse_hippocampus"
models = ["deeplinc"]
run_number = 2
leiden_resolutions = [1.1]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_{model}.h5ad")
    # Strip a conv-layer suffix if present (no-op for "deeplinc")
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"
    latent_key = f"{model}_latent_run{run_number}"
    leiden_key = f"run{run_number}_leiden_{leiden_resolutions[i]}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata,
                    use_rep=latent_key,
                    key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata,
               neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolutions[i],
                 key_added=leiden_key,
                 neighbors_key=latent_key)

    # Hand-assigned niche names per Leiden cluster label.
    # NOTE(review): this assumes the clustering reproduces exactly; re-check it after
    # changing the resolution, the run or library versions.
    niche_annotations = {
        "0": "Cortical Layer 6a",
        "1": "Thalamus LD",
        "2": "Cortical Layer 6b",
        "3": "Thalamus LP",
        "4": "Corpus callosum",
        "5": "CA1",
        "6": "Stratum",
        "7": "Dentate gyrus",
        "8": "Medial habenula (MH)",
        "9": "CA2 & CA3",
        "10": "Cortical layer 2/3",
        "11": "Artifact 1",
        "12": "Third Ventricle (V3)",
        "13": "Fasciola cinerea"}

    # Read the Leiden column via leiden_key so it stays in sync with run_number /
    # leiden_resolutions instead of being hard-coded
    adata.obs["niche"] = adata.obs[leiden_key].map(niche_annotations)

    # Niche colours (shared with the other models' niche plots); this is the only
    # colour mapping used for the plot below
    niche_colors = {
         'Stratum': '#8BE0A4',
         'Thalamus LD': '#F6CF71',
         'Cortical Layer 6b': '#B497E7',
         'Cortical layer 5': '#9EB9F3',
         'Thalamus LP': '#87C55F',
         'Corpus callosum': '#DAB6C4',
         'CA1': '#FE88B1',
         'CA2 & CA3': '#DCB0F2',
         'Dentate gyrus': '#D3B484',
         'Medial habenula (MH)': '#F89C74',
         'Lateral habenula (LH)': '#C9DB74',
         'Third Ventricle (V3)': '#B3B3B3',
         'Cortical layer 2/3': '#276A8C',
         'Cortical Layer 6a': "#66C5CC",
         'Artifact 1': "#FF4D4D",
         'Fasciola cinerea': "#FF00FF"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),  # for latent UMAP
        plot_label="Latent Clusters",
        model_label=model,  # the name itself; `{model}` would pass a one-element set
        cat_key="niche",
        groups=None,
        sample_key="sample",
        samples=None,  # None for latent UMAP; ["sample1"] for the spatial plot
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
# Ward-linkage dendrogram of the niches, computed on the DeepLinc run-2 latent
# representation and drawn left-oriented on a narrow axis
sc.tl.dendrogram(
    adata=adata,
    groupby="niche",
    use_rep="deeplinc_latent_run2",
    linkage_method="ward")

fig, ax = plt.subplots(figsize=(2, 5))
sc.pl.dendrogram(
    adata=adata,
    groupby="niche",
    orientation="left",
    ax=ax,
    show=False,
    save=f"_{dataset}_deeplinc_run2")
plt.show()

In [None]:
# Stacked horizontal bars of the cell type composition of each niche, with niches
# ordered by the `dendrogram_niche` result stored in adata.uns
cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_30",
    cat_key="cell_type")

# Row-normalised contingency table: each niche's cell type fractions sum to 1
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize="index")
tmp = tmp.reindex(adata.uns["dendrogram_niche"]["categories_ordered"])
# Keep the Axes returned by the plot; the legend is placed on it separately
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(6, 10))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=16)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/niche_cell_type_proportions_deeplinc.svg", bbox_inches="tight")

In [None]:
# Plot the annotated cell types for "sample1"; per the original notes, pass
# samples=None (and a figsize such as (10, 20)) to plot the latent UMAP instead.
cell_type_plot_path = f"{benchmarking_folder_path}/cell_types.svg"
plot_category_in_latent_and_physical_space(
    adata=adata,
    plot_label="Cell Types",
    model_label=None,
    cat_key="cell_type",
    groups=None,
    sample_key="sample",
    samples=["sample1"],
    cat_colors=cell_type_colors,
    size=720000 / len(adata),
    spot_size=30,
    save_fig=True,
    file_path=cell_type_plot_path)

In [None]:
dataset = "slideseqv2_mouse_hippocampus"
models = ["scvi"]
run_number = 2
leiden_resolutions = [0.875]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    # Drop any GNN-layer suffix; the latent keys below use the bare model name.
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#66C5CC",
        "1": "#8BE0A4",
        "2": "#DAB6C4",
        "3": "#F6CF71",
        "4": "#87C55F",
        "5": "#9D88A2",
        "6": "#F89C74",
        "7": "#BA55D3",
        "8": "#B3B3B3",
        "9": "#D3B484",
        "10": "#8A2BE2",
        "11": "#FE88B1",
        "12": "#DCB0F2",
        "13": "#276A8C"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
dataset = "slideseqv2_mouse_hippocampus"
models = ["expimap"]
run_number = 2
leiden_resolutions = [1.5]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    # Drop any GNN-layer suffix; the latent keys below use the bare model name.
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#66C5CC",
        "1": "#F6CF71",
        "2": "#DAB6C4",
        "3": "#8BE0A4",
        "4": "#9D88A2",
        "5": "#FE88B1",
        "6": "#F89C74",
        "7": "#D3B484",
        "8": "#8A2BE2",
        "9": "#FF00FF",
        "10": "#48D1CC",
        "11": "#B3B3B3",
        "12": "#B497E7",
        "13": "#9B4DCA"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/latent_clusters_{model}.svg")

In [None]:
### Supplementary figure: 25% subsample niches ###
dataset = "slideseqv2_mouse_hippocampus_subsample_25pct"
models = ["nichecompass_gatv2conv"]

run_number = 2
leiden_resolutions = [0.1]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    # Drop any GNN-layer suffix; the latent keys below use the bare model name.
    model = model.replace("_gcnconv", "").replace("_gatv2conv", "")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#8BE0A4",
        "1": "#DCB0F2",
        "2": "#66C5CC",
        "3": "#F6CF71",
        "4": "#DAB6C4",
        "5": "#D3B484",
        "6": "#FE88B1"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],  # samples=None for latent UMAP; ["sample1"] for spatial plot
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        # Include the dataset name so this subsample figure does not overwrite
        # the full-dataset `latent_clusters_{model}.svg` figure.
        file_path=f"{benchmarking_folder_path}/latent_clusters_{dataset}_{model}.svg")

In [None]:
dataset = "slideseqv2_mouse_hippocampus_subsample_25pct"
models = ["staci"]

run_number = 2
leiden_resolutions = [0.1]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#8BE0A4",
        "1": "#66C5CC",
        "2": "#D3B484",
        "3": "#F6CF71",
        "4": "#DCB0F2",
        "5": "#FE88B1",
        "6": "#DAB6C4"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],  # samples=None for latent UMAP; ["sample1"] for spatial plot
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        # Include the dataset name so this subsample figure does not overwrite
        # the full-dataset `latent_clusters_{model}.svg` figure.
        file_path=f"{benchmarking_folder_path}/latent_clusters_{dataset}_{model}.svg")

In [None]:
dataset = "slideseqv2_mouse_hippocampus_subsample_25pct"
models = ["graphst"]

run_number = 2
leiden_resolutions = [0.15]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#8BE0A4",
        "1": "#DAB6C4",
        "2": "#F6CF71",
        "3": "#DCB0F2",
        "4": "#66C5CC",
        "5": "#D3B484",
        "6": "#FE88B1"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],  # samples=None for latent UMAP; ["sample1"] for spatial plot
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        # Include the dataset name so this subsample figure does not overwrite
        # the full-dataset `latent_clusters_{model}.svg` figure.
        file_path=f"{benchmarking_folder_path}/latent_clusters_{dataset}_{model}.svg")

In [None]:
dataset = "slideseqv2_mouse_hippocampus_subsample_25pct"
models = ["deeplinc"]

run_number = 2
leiden_resolutions = [0.2]

for i, model in enumerate(models):
    adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
    adata.obs["sample"] = "sample1"

    leiden_resolution = leiden_resolutions[i]
    latent_key = f"{model}_latent_run{run_number}"
    cluster_key = f"run{run_number}_leiden_{leiden_resolution}"

    print("\nComputing neighbor graph...")
    # Use latent representation for UMAP generation
    sc.pp.neighbors(adata, use_rep=latent_key, key_added=latent_key)

    print("\nComputing UMAP embedding...")
    sc.tl.umap(adata, neighbors_key=latent_key)

    print("\nComputing Leiden clustering...")
    sc.tl.leiden(adata=adata,
                 resolution=leiden_resolution,
                 key_added=cluster_key,
                 neighbors_key=latent_key)

    latent_cluster_colors = {
        "0": "#8BE0A4",
        "1": "#D3B484",
        "2": "#DCB0F2",
        "3": "#66C5CC",
        "4": "#F6CF71",
        "5": "#DAB6C4",
        "6": "#FE88B1"}

    plot_category_in_latent_and_physical_space(
        adata=adata,
        figsize=(10, 20),
        plot_label="Latent Clusters",
        # Pass the model name itself; `{model}` would build a one-element set.
        model_label=model,
        cat_key=cluster_key,
        groups=None,
        sample_key="sample",
        samples=["sample1"],  # samples=None for latent UMAP; ["sample1"] for spatial plot
        cat_colors=latent_cluster_colors,
        size=(720000 / len(adata)),
        spot_size=30,
        save_fig=True,
        # Include the dataset name so this subsample figure does not overwrite
        # the full-dataset `latent_clusters_{model}.svg` figure.
        file_path=f"{benchmarking_folder_path}/latent_clusters_{dataset}_{model}.svg")

In [None]:
datasets = ["slideseqv2_mouse_hippocampus",
            "slideseqv2_mouse_hippocampus_subsample_50pct",
            "slideseqv2_mouse_hippocampus_subsample_25pct",
            "slideseqv2_mouse_hippocampus_subsample_10pct",
            "slideseqv2_mouse_hippocampus_subsample_5pct",
            "slideseqv2_mouse_hippocampus_subsample_1pct"]
models = ["nichecompass_gcnconv",
          "nichecompass_gatv2conv",
          "staci",
          "deeplinc",
          "graphst",
          "sagenet",
          #"scvi",
          #"expimap"
         ]

# Number of placeholder runs created for a (dataset, model) whose metrics file is missing.
n_runs = 8

# Load per-run metrics for every (dataset, model) and min-max scale each metric
# within a dataset so models are compared on the same [0, 1] scale.
dataset_dfs = []
for dataset in datasets:
    model_dfs = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows (NaN run time, no metrics) keep the missing model
            # represented in summary_df.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * n_runs,
                "run_number": list(range(1, n_runs + 1)),
                "run_time": [np.nan] * n_runs})
        benchmark_df["model"] = model
        model_dfs.append(benchmark_df)
    # pd.concat instead of DataFrame.append, which was removed in pandas 2.0.
    dataset_df = pd.concat(model_dfs, ignore_index=True)

    for metric_col in metric_cols_single_sample:
        # Guarantee the column exists even if every metrics file was missing.
        if metric_col not in dataset_df.columns:
            dataset_df[metric_col] = np.nan
        metric_values = dataset_df[metric_col]
        metric_min, metric_max = metric_values.min(), metric_values.max()
        dataset_df[f"{metric_col}_scaled"] = (metric_values - metric_min) / (metric_max - metric_min)

    dataset_dfs.append(dataset_df)

summary_df = pd.concat(dataset_dfs, ignore_index=True)

# Category scores: weighted average of the scaled metrics in each category.
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Reformat model identifiers into display names (only the model column).
summary_df["model"] = summary_df["model"].replace({
    "nichecompass_gatv2conv": "NicheCompass",
    "nichecompass_gcnconv": "NicheCompass Light",
    "staci": "STACI",
    "deeplinc": "DeepLinc",
    "expimap": "expiMap",
    "graphst": "GraphST",
    "sagenet": "SageNet",
    "scvi": "scVI"})

# Prepare metrics table plot: mean over runs per (dataset, model), ranked by
# Overall Score, then unrolled to long format.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                [metric_cols_single_sample + ["Overall Score"]]
                .reset_index())

unrolled_df = pd.melt(
    aggregate_df,
    id_vars=group_cols,
    value_vars=metric_cols_single_sample + ["Overall Score"],
    var_name="score_type",
    value_name="score")

# Create spatial indicator column. Names must match the display names above.
# NOTE(review): STACI is not listed here; confirm whether that is intended.
def is_spatially_aware_model(row):
    """Return True if the row's display model name is a spatially aware method."""
    return row["model"] in ["NicheCompass", "NicheCompass Light", "DeepLinc", "GraphST", "SageNet"]

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]].copy()

# Order datasets; a stable sort keeps the Overall-Score ranking within each dataset.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

#print(summary_df["model"].value_counts())
#print(summary_df[summary_df.isna().any(axis=1)])

In [None]:
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "DeepLinc": "#0E1C4A",
                 "GraphST": "#D78FF8",
                 "SageNet": "#F46AA2",
                 "scVI": "#FE8B3B",
                 "expiMap": "#7E0028",
                 }

# Mean run time per (dataset, model) over all runs, divided by 60 for display in
# minutes (the axis label assumes stored run times are in seconds).
run_time_mean_df = summary_df.groupby(["dataset", "model"])[["run_time"]].mean().reset_index()
run_time_mean_df["run_time"] = run_time_mean_df["run_time"] / 60

# Share (in %) of the full dataset that each (sub)sampled dataset contains.
dataset_shares = {
    "slideseqv2_mouse_hippocampus": 100,
    "slideseqv2_mouse_hippocampus_subsample_50pct": 50,
    "slideseqv2_mouse_hippocampus_subsample_25pct": 25,
    "slideseqv2_mouse_hippocampus_subsample_10pct": 10,
    "slideseqv2_mouse_hippocampus_subsample_5pct": 5,
    "slideseqv2_mouse_hippocampus_subsample_1pct": 1,
}

def create_dataset_share_col(row):
    """Return the dataset share (%) for the row's ``dataset``, or None if unknown."""
    return dataset_shares.get(row["dataset"])

run_time_mean_df["dataset_share"] = run_time_mean_df.apply(create_dataset_share_col, axis=1)

ax = sns.lineplot(data=run_time_mean_df,
                  x="dataset_share",
                  y="run_time",
                  hue="model",
                  marker="o",
                  palette=model_palette)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
plt.title("SlideSeqV2 Mouse Hippocampus\n(41,786 Cells; 4,000 Genes)")
plt.ylabel("Run Time (Minutes)")
plt.xlabel("Dataset Size (%)")
custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]  # Adjust the tick positions as needed
plt.yscale("log")
plt.yticks(custom_y_ticks, custom_y_ticks)
# The saved panel has no legend. Remove seaborn's legend directly rather than
# rebuilding it through `Legend.legendHandles` (removed in Matplotlib 3.9) with
# a hard-coded 5-entry order for 6 models and then hiding it again.
legend = ax.get_legend()
if legend is not None:
    legend.remove()
plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_slideseqv2_mouse_hippocampus.svg")
plt.show()

In [None]:
# Plot table: render the per-dataset benchmarking metrics table and save it as SVG.
# Swap in the category_* lists below to plot category scores instead of metrics.
table_metric_titles = [title.replace(" ", "\n") for title in metric_col_titles_single_sample]
table_kwargs = dict(
    model_col="model",
    model_col_width=1.9,
    group_col="dataset",
    metric_cols=metric_cols_single_sample,  # or category_cols_single_sample
    metric_col_weights=metric_col_weights_single_sample,  # or category_col_weights_single_sample
    metric_col_titles=table_metric_titles,  # or category_col_titles_single_sample
    metric_col_width=0.7,
    aggregate_col_width=1.,
    plot_width=42,
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_slideseqv2_mouse_hippocampus.svg")
plot_metrics_table(df=unrolled_df, **table_kwargs)

#### 2.1.2 seqFISH Mouse Organogenesis

In [None]:
datasets = ["seqfish_mouse_organogenesis_embryo2",
            "seqfish_mouse_organogenesis_subsample_50pct_embryo2",
            "seqfish_mouse_organogenesis_subsample_25pct_embryo2",
            "seqfish_mouse_organogenesis_subsample_10pct_embryo2",
            "seqfish_mouse_organogenesis_subsample_5pct_embryo2",
            "seqfish_mouse_organogenesis_subsample_1pct_embryo2"]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "deeplinc",
          "graphst",
          "sagenet",
          #"scvi",
          #"expimap"
         ]

# Number of placeholder runs created for a (dataset, model) whose metrics file is missing.
n_runs = 8

# Load per-run metrics for every (dataset, model) and min-max scale each metric
# within a dataset so models are compared on the same [0, 1] scale.
dataset_dfs = []
for dataset in datasets:
    model_dfs = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows (NaN run time, no metrics) keep the missing model
            # represented in summary_df.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * n_runs,
                "run_number": list(range(1, n_runs + 1)),
                "run_time": [np.nan] * n_runs})
        benchmark_df["model"] = model
        model_dfs.append(benchmark_df)
    # pd.concat instead of DataFrame.append, which was removed in pandas 2.0.
    dataset_df = pd.concat(model_dfs, ignore_index=True)

    for metric_col in metric_cols_single_sample:
        # Guarantee the column exists even if every metrics file was missing.
        if metric_col not in dataset_df.columns:
            dataset_df[metric_col] = np.nan
        metric_values = dataset_df[metric_col]
        metric_min, metric_max = metric_values.min(), metric_values.max()
        dataset_df[f"{metric_col}_scaled"] = (metric_values - metric_min) / (metric_max - metric_min)

    dataset_dfs.append(dataset_df)

summary_df = pd.concat(dataset_dfs, ignore_index=True)

# Category scores: weighted average of the scaled metrics in each category.
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Reformat model identifiers into display names (only the model column).
summary_df["model"] = summary_df["model"].replace({
    "nichecompass_gatv2conv": "NicheCompass",
    "nichecompass_gcnconv": "NicheCompass Light",
    "staci": "STACI",
    "deeplinc": "DeepLinc",
    "expimap": "expiMap",
    "graphst": "GraphST",
    "sagenet": "SageNet",
    "scvi": "scVI"})

# Prepare metrics table plot: mean over runs per (dataset, model), ranked by
# Overall Score, then unrolled to long format.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                [metric_cols_single_sample + ["Overall Score"]]
                .reset_index())

unrolled_df = pd.melt(
    aggregate_df,
    id_vars=group_cols,
    value_vars=metric_cols_single_sample + ["Overall Score"],
    var_name="score_type",
    value_name="score")

# Create spatial indicator column. Names must match the display names above.
# NOTE(review): STACI is not listed here; confirm whether that is intended.
def is_spatially_aware_model(row):
    """Return True if the row's display model name is a spatially aware method."""
    return row["model"] in ["NicheCompass", "NicheCompass Light", "DeepLinc", "GraphST", "SageNet"]

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]].copy()

# Order datasets; a stable sort keeps the Overall-Score ranking within each dataset.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

#print(summary_df["model"].value_counts())
#print(summary_df[summary_df.isna().any(axis=1)])

In [None]:
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "DeepLinc": "#0E1C4A",
                 "GraphST": "#D78FF8",
                 "SageNet": "#F46AA2",
                 "scVI": "#FE8B3B",
                 "expiMap": "#7E0028",
                 }

# Mean run time per (dataset, model) over all runs, divided by 60 for display in
# minutes (the axis label assumes stored run times are in seconds).
run_time_mean_df = summary_df.groupby(["dataset", "model"])[["run_time"]].mean().reset_index()
run_time_mean_df["run_time"] = run_time_mean_df["run_time"] / 60

# Share (in %) of the full dataset that each (sub)sampled dataset contains.
dataset_shares = {
    "seqfish_mouse_organogenesis_embryo2": 100,
    "seqfish_mouse_organogenesis_subsample_50pct_embryo2": 50,
    "seqfish_mouse_organogenesis_subsample_25pct_embryo2": 25,
    "seqfish_mouse_organogenesis_subsample_10pct_embryo2": 10,
    "seqfish_mouse_organogenesis_subsample_5pct_embryo2": 5,
    "seqfish_mouse_organogenesis_subsample_1pct_embryo2": 1,
}

def create_dataset_share_col(row):
    """Return the dataset share (%) for the row's ``dataset``, or None if unknown."""
    return dataset_shares.get(row["dataset"])

run_time_mean_df["dataset_share"] = run_time_mean_df.apply(create_dataset_share_col, axis=1)

ax = sns.lineplot(data=run_time_mean_df,
                  x="dataset_share",
                  y="run_time",
                  hue="model",
                  marker="o",
                  palette=model_palette)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)
plt.title("seqFISH Mouse Organogenesis\n(14,891 Cells; 351 Genes)")
plt.ylabel("Run Time (Minutes)")
plt.xlabel("Dataset Size (%)")
custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]  # Adjust the tick positions as needed
plt.yscale("log")
plt.yticks(custom_y_ticks, None)
# The saved panel has no legend. Remove seaborn's legend directly rather than
# rebuilding it through `Legend.legendHandles` (removed in Matplotlib 3.9) with
# a hard-coded 5-entry order for 6 models and then hiding it again.
legend = ax.get_legend()
if legend is not None:
    legend.remove()
plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_seqfish_mouse_organogenesis.svg")
plt.show()

In [None]:
# Plot table: render the per-dataset benchmarking metrics table and save it as SVG.
# Swap in the category_* lists below to plot category scores instead of metrics.
table_metric_titles = [title.replace(" ", "\n") for title in metric_col_titles_single_sample]
table_kwargs = dict(
    model_col="model",
    model_col_width=1.9,
    group_col="dataset",
    metric_cols=metric_cols_single_sample,  # or category_cols_single_sample
    metric_col_weights=metric_col_weights_single_sample,  # or category_col_weights_single_sample
    metric_col_titles=table_metric_titles,  # or category_col_titles_single_sample
    metric_col_width=0.7,
    aggregate_col_width=1.,
    plot_width=42,
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_seqfish_mouse_organogenesis.svg")
plot_metrics_table(df=unrolled_df, **table_kwargs)

#### 2.1.3 nanoString CosMx SMI Human Non-Small-Cell Lung Cancer (NSCLC)

In [None]:
datasets = ["nanostring_cosmx_human_nsclc_batch5",
            "nanostring_cosmx_human_nsclc_subsample_50pct_batch5",
            "nanostring_cosmx_human_nsclc_subsample_25pct_batch5",
            "nanostring_cosmx_human_nsclc_subsample_10pct_batch5",
            "nanostring_cosmx_human_nsclc_subsample_5pct_batch5",
            "nanostring_cosmx_human_nsclc_subsample_1pct_batch5"]
models = ["nichecompass_gcnconv",
          "nichecompass_gatv2conv",
          "staci",
          "deeplinc",
          "graphst",
          "sagenet",
          #"scvi",
          #"expimap"
         ]

summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_df = pd.DataFrame()
    for model in models:
        try:
            benchmark_df = pd.read_csv(f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv")
            #adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
            #training_durations = []
            #for run_number in [1, 2, 3, 4, 5, 6, 7, 8]:
            #    training_durations.append(adata.uns[f"{model.split('_')[0]}_model_training_duration_run{run_number}"])
            #benchmark_df["run_time"] = training_durations
            #benchmark_df = benchmark_df[["dataset", "run_number", "run_time", "gcs", "mlami", "cas", "clisis", "nasw", "cnmi", "cari", "casw", "clisi"]]
            #benchmark_df.to_csv(f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv", index=False)
            benchmark_df["model"] = model
            dataset_df = pd.concat([dataset_df, benchmark_df], ignore_index=True)
        except FileNotFoundError:
            print(f"Did not find file {benchmarking_folder_path}/{dataset}_{model}_metrics.csv. Continuing...")
            missing_run_data = {
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8
            }
            missing_run_df = pd.DataFrame(missing_run_data)
            dataset_df = dataset_df.append(missing_run_df, ignore_index=True)
            
    # Apply min-max scaling to metric columns
    for i in range(len(metric_cols_single_sample)):
        min_val = dataset_df[metric_cols_single_sample[i]].min()
        max_val = dataset_df[metric_cols_single_sample[i]].max()
        dataset_df[metric_cols_single_sample[i] + "_scaled"] = ((
            dataset_df[metric_cols_single_sample[i]] - min_val) / (max_val - min_val))

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)
    continue
    
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]
    
summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)
 
# Reformat for plot
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "deeplinc": "DeepLinc",
                    "expimap": "expiMap",
                    "graphst": "GraphST",
                    "sagenet": "SageNet",
                    "scvi": "scVI"}, inplace=True)

# Plot over all loss weights combinations
# Prepare metrics table plot
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean("Overall Score").sort_values("Overall Score", ascending=False)[
    metric_cols_single_sample + ["Overall Score"]].reset_index()

unrolled_df = pd.melt(
    aggregate_df, 
    id_vars=group_cols,
    value_vars=metric_cols_single_sample + ["Overall Score"],
    var_name="score_type", 
    value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` names a spatially aware model.

    The model column holds display names as produced by the
    ``summary_df.replace`` renaming in this cell ("NicheCompass",
    "NicheCompass Light", ...). The legacy spellings "NicheCompass GATv2" /
    "NicheCompass GCN" are accepted as well, so that both NicheCompass
    variants are flagged correctly.
    """
    return row["model"] in {
        "NicheCompass", "NicheCompass Light",      # current display names
        "NicheCompass GATv2", "NicheCompass GCN",  # legacy display names
        "DeepLinc", "GraphST", "SageNet"}
unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. A stable sort keeps the descending
# "Overall Score" row order (set when building `aggregate_df`) within each
# dataset; pandas' default quicksort does not guarantee that.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

#print(summary_df["model"].value_counts())
#print(summary_df[summary_df.isna().any(axis=1)])

In [None]:
# One fixed colour per model so all benchmarking figures share a colour code.
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Mean run time per (dataset, model); divided by 60 to give minutes, the unit
# used on the plot's y-axis.
run_time_mean_df = (summary_df
                    .groupby(["dataset", "model"])[["run_time"]]
                    .mean()
                    .reset_index())
run_time_mean_df["run_time"] = run_time_mean_df["run_time"].div(60)

def create_dataset_share_col(row):
    """Map ``row["dataset"]`` to the share (%) of the full NSCLC dataset it holds.

    Returns None for dataset names that are not part of this benchmark.
    """
    share_by_dataset = {
        "nanostring_cosmx_human_nsclc_batch5": 100,
        "nanostring_cosmx_human_nsclc_subsample_50pct_batch5": 50,
        "nanostring_cosmx_human_nsclc_subsample_25pct_batch5": 25,
        "nanostring_cosmx_human_nsclc_subsample_10pct_batch5": 10,
        "nanostring_cosmx_human_nsclc_subsample_5pct_batch5": 5,
        "nanostring_cosmx_human_nsclc_subsample_1pct_batch5": 1,
    }
    return share_by_dataset.get(row["dataset"])
    
# x-axis value: share (%) of the full dataset used by each (sub)sample.
run_time_mean_df["dataset_share"] = run_time_mean_df.apply(create_dataset_share_col, axis=1)

ax = sns.lineplot(data=run_time_mean_df,
                  x="dataset_share",
                  y="run_time",
                  hue="model",
                  marker='o',
                  palette=model_palette)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.title("nanoString CosMx Human NSCLC\n(77,391 Cells; 883 Genes)")
plt.ylabel("Run Time (Minutes)")
plt.xlabel("Dataset Size (%)")
custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]  # minutes: 1 min ... 24 h
plt.yscale("log")
plt.yticks(custom_y_ticks, None)

legend = ax.get_legend()
# `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and
# the old name was removed in 3.9; support both.
handles = (legend.legend_handles if hasattr(legend, "legend_handles")
           else legend.legendHandles)
for handle in handles:
    handle.set_linewidth(4.0)  # thicker legend lines
labels = [text.get_text() for text in legend.get_texts()]
# Preferred legend order (indices into seaborn's hue order). Indices that do
# not exist are dropped and unlisted entries are appended, so no model is lost.
order = [3, 2, 4, 1, 0, 5]
order = [i for i in order if i < len(handles)]
order += [i for i in range(len(handles)) if i not in order]
ordered_handles = [handles[i] for i in order]
ordered_labels = [labels[i] for i in order]
lgd = plt.legend(ordered_handles, ordered_labels, bbox_to_anchor=(1.05, 1), loc='upper left')
#ax.legend().set_visible(False)
plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_nanostring_cosmx_human_nsclc.svg", bbox_inches="tight", bbox_extra_artists=[lgd])
plt.show()

In [None]:
# Plot the per-dataset metrics table (one row per model, grouped by dataset)
# and save it as SVG. For a category-level table, pass
# `category_cols_single_sample` / `category_col_weights_single_sample` instead.
plot_metrics_table(
    df=unrolled_df,
    group_col="dataset",
    model_col="model",
    model_col_width=1.9,
    metric_cols=metric_cols_single_sample,
    metric_col_weights=metric_col_weights_single_sample,
    metric_col_titles=[title.replace(" ", "\n") for title in metric_col_titles_single_sample],
    metric_col_width=0.7,
    aggregate_col_width=1.,
    plot_width=42,
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_nanostring_cosmx_human_nsclc.svg",
)

#### 2.1.4 Vizgen MERFISH Mouse Liver

In [None]:
datasets = ["vizgen_merfish_mouse_liver",
            "vizgen_merfish_mouse_liver_subsample_50pct",
            "vizgen_merfish_mouse_liver_subsample_25pct",
            "vizgen_merfish_mouse_liver_subsample_10pct",
            "vizgen_merfish_mouse_liver_subsample_5pct",
            "vizgen_merfish_mouse_liver_subsample_1pct"]
models = ["nichecompass_gcnconv",
          "nichecompass_gatv2conv",
          "staci",
          "deeplinc",
          "graphst",
          "sagenet",
          #"scvi",
          #"expimap"
         ]

# Collect the per-run metrics of every (dataset, model) pair and min-max scale
# each metric within a dataset, so that models are compared on a common scale.
summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_df = pd.DataFrame()
    for model in models:
        metrics_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_path}. Continuing...")
            # Placeholder rows for the 8 runs so the model still appears
            # (with NaN metrics) in the summary.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8,
            })
        else:
            #adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
            #training_durations = []
            #for run_number in [1, 2, 3, 4, 5, 6, 7, 8]:
            #    training_durations.append(adata.uns[f"{model.split('_')[0]}_model_training_duration_run{run_number}"])
            #benchmark_df["run_time"] = training_durations
            #benchmark_df = benchmark_df[["dataset", "run_number", "run_time", "gcs", "mlami", "cas", "clisis", "nasw", "cnmi", "cari", "casw", "clisi"]]
            #benchmark_df.to_csv(f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv", index=False)
            benchmark_df["model"] = model
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        dataset_df = pd.concat([dataset_df, benchmark_df], ignore_index=True)

    # Apply min-max scaling to metric columns (per dataset)
    for metric_col in metric_cols_single_sample:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)
    
# Scaled metric columns belonging to each of the three metric categories.
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

# Category scores: weighted average of the scaled metrics of each category.
summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
# Overall score: weighted average of the three category scores.
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Reformat for plot: internal model identifiers -> display names
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "deeplinc": "DeepLinc",
                    "expimap": "expiMap",
                    "graphst": "GraphST",
                    "sagenet": "SageNet",
                    "scvi": "scVI"}, inplace=True)

# Prepare metrics table plot: mean of every numeric column per (dataset, model),
# rows ordered by descending mean "Overall Score".
# `GroupBy.mean` takes no column argument -- its first positional parameter is
# `numeric_only` -- so pass that flag explicitly.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                .loc[:, metric_cols_single_sample + ["Overall Score"]]
                .reset_index())

# Long format: one row per (dataset, model, score_type).
unrolled_df = pd.melt(
    aggregate_df,
    id_vars=group_cols,
    value_vars=metric_cols_single_sample + ["Overall Score"],
    var_name="score_type",
    value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` names a spatially aware model.

    The model column holds display names as produced by the
    ``summary_df.replace`` renaming in this cell ("NicheCompass",
    "NicheCompass Light", ...). The legacy spellings "NicheCompass GATv2" /
    "NicheCompass GCN" are accepted as well, so that both NicheCompass
    variants are flagged correctly.
    """
    return row["model"] in {
        "NicheCompass", "NicheCompass Light",      # current display names
        "NicheCompass GATv2", "NicheCompass GCN",  # legacy display names
        "DeepLinc", "GraphST", "SageNet"}
unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. A stable sort keeps the descending
# "Overall Score" row order (set when building `aggregate_df`) within each
# dataset; pandas' default quicksort does not guarantee that.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

#print(summary_df["model"].value_counts())
#print(summary_df[summary_df.isna().any(axis=1)])

In [None]:
# One fixed colour per model so all benchmarking figures share a colour code.
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Mean run time per (dataset, model); divided by 60 to give minutes, the unit
# used on the plot's y-axis.
run_time_mean_df = (summary_df
                    .groupby(["dataset", "model"])[["run_time"]]
                    .mean()
                    .reset_index())
run_time_mean_df["run_time"] = run_time_mean_df["run_time"].div(60)

def create_dataset_share_col(row):
    """Map ``row["dataset"]`` to the share (%) of the full liver dataset it holds.

    Returns None for dataset names that are not part of this benchmark.
    """
    share_by_dataset = {
        "vizgen_merfish_mouse_liver": 100,
        "vizgen_merfish_mouse_liver_subsample_50pct": 50,
        "vizgen_merfish_mouse_liver_subsample_25pct": 25,
        "vizgen_merfish_mouse_liver_subsample_10pct": 10,
        "vizgen_merfish_mouse_liver_subsample_5pct": 5,
        "vizgen_merfish_mouse_liver_subsample_1pct": 1,
    }
    return share_by_dataset.get(row["dataset"])
    
# x-axis value: share (%) of the full dataset used by each (sub)sample.
run_time_mean_df["dataset_share"] = run_time_mean_df.apply(create_dataset_share_col, axis=1)

ax = sns.lineplot(data=run_time_mean_df,
                  x="dataset_share",
                  y="run_time",
                  hue="model",
                  marker='o',
                  palette=model_palette)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.title("MERFISH Mouse Liver\n(367,335 Cells; 347 Genes)")
plt.ylabel("Run Time (Minutes)")
plt.xlabel("Dataset Size (%)")
custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]  # minutes: 1 min ... 24 h
plt.yscale("log")
plt.yticks(custom_y_ticks, None)

legend = ax.get_legend()
# `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and
# the old name was removed in 3.9; support both.
handles = (legend.legend_handles if hasattr(legend, "legend_handles")
           else legend.legendHandles)
for handle in handles:
    handle.set_linewidth(4.0)  # thicker legend lines
labels = [text.get_text() for text in legend.get_texts()]
# Preferred legend order (indices into seaborn's hue order). Indices that do
# not exist are dropped and unlisted entries are appended, so no model is lost
# (this list names only 5 of the 6 benchmarked models).
order = [3, 2, 4, 1, 0]
order = [i for i in order if i < len(handles)]
order += [i for i in range(len(handles)) if i not in order]
ordered_handles = [handles[i] for i in order]
ordered_labels = [labels[i] for i in order]
plt.legend(ordered_handles, ordered_labels)
# The legend is hidden for this panel; delete this line to show it.
ax.legend().set_visible(False)
plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_vizgen_merfish_mouse_liver.svg")
plt.show()

In [None]:
# Plot the per-dataset metrics table (one row per model, grouped by dataset)
# and save it as SVG. For a category-level table, pass
# `category_cols_single_sample` / `category_col_weights_single_sample` /
# `category_col_titles_single_sample` instead.
plot_metrics_table(
    df=unrolled_df,
    group_col="dataset",
    model_col="model",
    model_col_width=1.9,
    metric_cols=metric_cols_single_sample,
    metric_col_weights=metric_col_weights_single_sample,
    metric_col_titles=[title.replace(" ", "\n") for title in metric_col_titles_single_sample],
    metric_col_width=0.7,
    aggregate_col_width=1.,
    plot_width=42,
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_vizgen_merfish_mouse_liver.svg",
)

#### 2.1.5 All Datasets

In [None]:
# Formatting parameters for the bar plots below.
# Figure width depends on how many x-axis ticks a panel shows.
fig_width_2_ticks = 5.1
fig_width_3_ticks = 5.5
fig_width_5_ticks = 6.2
fig_width_6_ticks = 6.6
fig_width_7_ticks = 7.0
fig_width_8_ticks = 7.4
fig_width_9_ticks = 7.8
fig_width_10_ticks = 8.2
fig_height = 5
# Font sizes: general axis text vs. the dataset row labels on the y-axis.
fontsize = 14
row_fontsize = 16

In [None]:
# Load metrics of all full-size datasets
datasets = ["slideseqv2_mouse_hippocampus",
            "seqfish_mouse_organogenesis_embryo2",
            "vizgen_merfish_mouse_liver",
            "nanostring_cosmx_human_nsclc_batch5",]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "deeplinc",
          "graphst",
          "sagenet",
          #"scvi",
          #"expimap"
         ]

# Collect the per-run metrics of every (dataset, model) pair and min-max scale
# each metric within a dataset, so that models are compared on a common scale.
summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_df = pd.DataFrame()
    for model in models:
        metrics_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_path}. Continuing...")
            # Placeholder rows for the 8 runs so the model still appears
            # (with NaN metrics) in the summary.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8,
            })
        else:
            #adata = sc.read_h5ad(f"../../artifacts/single_sample_method_benchmarking/{dataset}_{model}.h5ad")
            #training_durations = []
            #for run_number in [1, 2, 3, 4, 5, 6, 7, 8]:
            #    training_durations.append(adata.uns[f"{model.split('_')[0]}_model_training_duration_run{run_number}"])
            #benchmark_df["run_time"] = training_durations
            #benchmark_df = benchmark_df[["dataset", "run_number", "run_time", "gcs", "mlami", "cas", "clisis", "nasw", "cnmi", "cari", "casw", "clisi"]]
            #benchmark_df.to_csv(f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv", index=False)
            benchmark_df["model"] = model
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        dataset_df = pd.concat([dataset_df, benchmark_df], ignore_index=True)

    # Apply min-max scaling to metric columns (per dataset)
    for metric_col in metric_cols_single_sample:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Scaled metric columns belonging to each of the three metric categories.
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_single_sample[4:6]]

# Category scores: weighted average of the scaled metrics of each category.
summary_df[category_cols_single_sample[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[0:2],
                                                        axis=1)
summary_df[category_cols_single_sample[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[2:4],
                                                        axis=1)
summary_df[category_cols_single_sample[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_single_sample[4:6],
                                                        axis=1)
# Overall score: weighted average of the three category scores.
summary_df["Overall Score"] = np.average(summary_df[category_cols_single_sample[:3]],
                                         weights=category_col_weights_single_sample[:3],
                                         axis=1)

# Reformat for plot: internal model identifiers -> display names
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "deeplinc": "DeepLinc",
                    "expimap": "expiMap",
                    "graphst": "GraphST",
                    "sagenet": "SageNet",
                    "scvi": "scVI"}, inplace=True)

# Prepare metrics table plot: mean of every numeric column per (dataset, model),
# rows ordered by descending mean "Overall Score".
# `GroupBy.mean` takes no column argument -- its first positional parameter is
# `numeric_only` -- so pass that flag explicitly.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                .loc[:, metric_cols_single_sample + ["Overall Score"]]
                .reset_index())

# Long format: one row per (dataset, model, score_type).
unrolled_df = pd.melt(
    aggregate_df,
    id_vars=group_cols,
    value_vars=metric_cols_single_sample + ["Overall Score"],
    var_name="score_type",
    value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` names a spatially aware model.

    The model column holds display names as produced by the
    ``summary_df.replace`` renaming in this cell ("NicheCompass",
    "NicheCompass Light", ...). The legacy spellings "NicheCompass GATv2" /
    "NicheCompass GCN" are accepted as well, so that both NicheCompass
    variants are flagged correctly.
    """
    return row["model"] in {
        "NicheCompass", "NicheCompass Light",      # current display names
        "NicheCompass GATv2", "NicheCompass GCN",  # legacy display names
        "DeepLinc", "GraphST", "SageNet"}
unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. A stable sort keeps the descending
# "Overall Score" row order (set when building `aggregate_df`) within each
# dataset; pandas' default quicksort does not guarantee that.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

#print(summary_df["model"].value_counts())
#print(summary_df[summary_df.isna().any(axis=1)])

# Map legacy NicheCompass display names to the current ones (a no-op when the
# rename above already produced the current names).
summary_df["model"] = summary_df["model"].replace({"NicheCompass GCN": "NicheCompass Light",
                                                   "NicheCompass GATv2": "NicheCompass"})

# Human-readable dataset names for the bar plots.
summary_df["dataset"] = summary_df["dataset"].replace(
    {"slideseqv2_mouse_hippocampus": "SlideSeqV2 Mouse Hippocampus",
     "seqfish_mouse_organogenesis_embryo2": "seqFISH Mouse Organogenesis",
     "nanostring_cosmx_human_nsclc_batch5": "nanoString CosMx Human NSCLC",
     "vizgen_merfish_mouse_liver": "MERFISH Mouse Liver"})

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Horizontal bar plot of CAS per dataset, one bar per model.
plt.figure(figsize=(fig_width_8_ticks, fig_height))
ax = sns.barplot(data=summary_df, x="cas", y="dataset", hue="model",
                 orient="h", palette=model_palette)
# Dashed separator lines between dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + row_idx, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx".
new_labels = []
for label_idx, tick_label in enumerate(ax.get_yticklabels()):
    text = tick_label.get_text()
    if label_idx < 3:
        new_labels.append(text.replace(' ', '\n'))
    else:
        new_labels.append(text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx'))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)

plt.ylabel(None)
plt.xlabel("CAS", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], fontsize=fontsize)
plt.xlim(0.1, 0.9)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_cas.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["cas"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["cas"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "DeepLinc": "#0E1C4A",
                 "GraphST": "#D78FF8",
                 "SageNet": "#F46AA2",
                 "scVI": "#FE8B3B",
                 "expiMap": "#7E0028",
                 }

# Horizontal bar plot of MLAMI per dataset, one bar per model.
plt.figure(figsize=(fig_width_8_ticks, fig_height))
ax = sns.barplot(data=summary_df,
                 x="mlami",
                 y="dataset",
                 hue="model",
                 orient="h",
                 palette=model_palette)
# Dashed separator lines between dataset rows.
for i in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + 1 * i, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx". (The duplicated
# no-op `.replace(' N', ...)` was removed to match the other panels.)
new_labels = [label.get_text().replace(' ', '\n') if i < 3 else label.get_text().replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx') for i, label in enumerate(ax.get_yticklabels())]
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)
plt.ylabel(None)
plt.xlabel("MLAMI", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], fontsize=fontsize)
plt.xlim(0.1, 0.9)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_mlami.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["mlami"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["mlami"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "DeepLinc": "#0E1C4A",
                 "GraphST": "#D78FF8",
                 "SageNet": "#F46AA2",
                 "scVI": "#FE8B3B",
                 "expiMap": "#7E0028",
                 }

# Horizontal bar plot of CLISIS per dataset, one bar per model.
plt.figure(figsize=(fig_width_2_ticks, fig_height))
ax = sns.barplot(data=summary_df,
                 x="clisis",
                 y="dataset",
                 hue="model",
                 orient="h",
                 palette=model_palette)
# Dashed separator lines between dataset rows.
for i in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + 1 * i, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx". (The duplicated
# no-op `.replace(' N', ...)` was removed to match the other panels.)
new_labels = [label.get_text().replace(' ', '\n') if i < 3 else label.get_text().replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx') for i, label in enumerate(ax.get_yticklabels())]
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)
plt.ylabel(None)
plt.xlabel("CLISIS", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0.8, 0.9], fontsize=fontsize)
plt.xlim(0.8, 1.0)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_clisis.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["clisis"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["clisis"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "DeepLinc": "#0E1C4A",
                 "GraphST": "#D78FF8",
                 "SageNet": "#F46AA2",
                 "scVI": "#FE8B3B",
                 "expiMap": "#7E0028",
                 }

# Horizontal bar plot of GCS per dataset, one bar per model.
plt.figure(figsize=(fig_width_3_ticks, fig_height))
ax = sns.barplot(data=summary_df,
                 x="gcs",
                 y="dataset",
                 hue="model",
                 orient="h",
                 palette=model_palette)
# Dashed separator lines between dataset rows.
for i in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + 1 * i, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx". (The duplicated
# no-op `.replace(' N', ...)` was removed to match the other panels.)
new_labels = [label.get_text().replace(' ', '\n') if i < 3 else label.get_text().replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx') for i, label in enumerate(ax.get_yticklabels())]
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)
plt.ylabel(None)
plt.xlabel("GCS", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0.7, 0.8, 0.9], fontsize=fontsize)
plt.xlim(0.7, 1.0)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_gcs.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["gcs"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["gcs"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes. Printed explicitly
# because this is not the last expression of its notebook cell, so Jupyter
# would not display it.
print(metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"])  # NC vs STACI

sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Horizontal bar plot of the local spatial conservation category score.
plt.figure(figsize=(fig_width_9_ticks, fig_height))
ax = sns.barplot(data=summary_df, x="Local Spatial Conservation Score", y="dataset",
                 hue="model", orient="h", palette=model_palette)
# Dashed separator lines between dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + row_idx, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx".
new_labels = []
for label_idx, tick_label in enumerate(ax.get_yticklabels()):
    text = tick_label.get_text()
    if label_idx < 3:
        new_labels.append(text.replace(' ', '\n'))
    else:
        new_labels.append(text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx'))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)

plt.ylabel(None)
plt.xlabel("Local Spatial Conservation Score", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], fontsize=fontsize)
plt.xlim(0., 1.0)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_local_spatial_conservation_score.svg")
plt.show()

In [None]:
# Spatial conservation score: unweighted mean of the global and local spatial
# conservation scores (NaN whenever either input is NaN).
summary_df["Spatial Conservation Score"] = (
    summary_df["Global Spatial Conservation Score"]
    .add(summary_df["Local Spatial Conservation Score"])
    .div(2))

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Horizontal bar plot of the combined spatial conservation score.
plt.figure(figsize=(fig_width_10_ticks, fig_height))
ax = sns.barplot(data=summary_df, x="Spatial Conservation Score", y="dataset",
                 hue="model", orient="h", palette=model_palette)
# Dashed separator lines between dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + row_idx, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx".
new_labels = []
for label_idx, tick_label in enumerate(ax.get_yticklabels()):
    text = tick_label.get_text()
    if label_idx < 3:
        new_labels.append(text.replace(' ', '\n'))
    else:
        new_labels.append(text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx'))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)

plt.ylabel(None)
plt.xlabel("Spatial Conservation Score", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
plt.xticks([0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], fontsize=fontsize)
plt.xlim(0., 1.0)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_spatial_conservation_score.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["Spatial Conservation Score"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Spatial Conservation Score"]].mean()
# Select models by label, not position, so the ratio stays correct if the set
# of models (and hence the alphabetical row order) changes.
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Horizontal bar plot of CNMI per dataset, one bar per model.
plt.figure(figsize=(fig_width_6_ticks, fig_height))
ax = sns.barplot(data=summary_df, x="cnmi", y="dataset", hue="model",
                 orient="h", palette=model_palette)
# Dashed separator lines between dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + row_idx, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx".
new_labels = []
for label_idx, tick_label in enumerate(ax.get_yticklabels()):
    text = tick_label.get_text()
    if label_idx < 3:
        new_labels.append(text.replace(' ', '\n'))
    else:
        new_labels.append(text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx'))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)

plt.ylabel(None)
plt.xlabel("CNMI", fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks([0., 0.1, 0.2, 0.3, 0.4, 0.5], fontsize=fontsize)
plt.xlim(0., 0.6)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_cnmi.svg")
plt.show()

In [None]:
sns.set_style("whitegrid")

# One fixed colour per model (shared across all benchmarking figures).
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

# Horizontal bar plot of NASW per dataset, one bar per model.
plt.figure(figsize=(fig_width_2_ticks, fig_height))
ax = sns.barplot(data=summary_df, x="nasw", y="dataset", hue="model",
                 orient="h", palette=model_palette)
# Dashed separator lines between dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(0.5 + row_idx, color='gray', linestyle='--', linewidth=2)

# Wrap dataset names: the first three labels break at every space, the
# remaining one breaks before "Human", "NSCLC" and "CosMx".
new_labels = []
for label_idx, tick_label in enumerate(ax.get_yticklabels()):
    text = tick_label.get_text()
    if label_idx < 3:
        new_labels.append(text.replace(' ', '\n'))
    else:
        new_labels.append(text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx'))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_yticklabels(new_labels)

plt.ylabel(None)
plt.xlabel("NASW", fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks([0.5, 0.6], fontsize=fontsize)
plt.xlim(0.5, 0.7)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_nasw.svg")
plt.show()

In [None]:
# Ratio of NicheCompass's mean NASW to GraphST's on SlideSeqV2 Mouse Hippocampus.
# Models are selected by label, not position: groupby sorts the model index
# alphabetically, so positional .iloc lookups would silently compare the wrong
# models whenever a model has no rows for this dataset.
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["nasw"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
# Ratio of NicheCompass's mean NASW to STACI's on seqFISH Mouse Organogenesis.
# Label-based .loc keeps the comparison correct even if some model is missing
# for this dataset (positional .iloc would then pick a different model).
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["nasw"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
# Horizontal bar plot of the Niche Coherence Score per dataset with one bar per
# model (hue), written as an SVG into the benchmarking folder and shown inline.
sns.set_style("whitegrid")

# Fixed model -> color mapping so every benchmarking plot uses the same colors.
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

plt.figure(figsize=(fig_width_7_ticks, 5))
ax = sns.barplot(
    data=summary_df,
    x="Niche Coherence Score",
    y="dataset",
    hue="model",
    orient="h",
    palette=model_palette,
)

# One dashed separator at y = i + 0.5 for each dataset row.
for i in range(summary_df["dataset"].nunique()):
    ax.axhline(i + 0.5, color="gray", linestyle="--", linewidth=2)

# Wrap dataset names: the first three tick labels break at every space,
# the remaining ones only before "H", "N" and "CosMx".
new_labels = [
    tick.get_text().replace(" ", "\n")
    if row < 3
    else (tick.get_text()
          .replace(" H", "\nH")
          .replace(" N", "\nN")
          .replace(" CosMx", "\nCosMx"))
    for row, tick in enumerate(ax.get_yticklabels())
]

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# Labels must be replaced before plt.yticks() restyles them below.
ax.set_yticklabels(new_labels)
ax.set_ylabel(None)
ax.set_xlabel("Niche Coherence Score", fontsize=fontsize)
plt.yticks(fontsize=row_fontsize)
# set ticks first; the explicit x-limits afterwards override any view expansion.
plt.xticks([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], fontsize=fontsize)
ax.set_xlim(0.1, 0.8)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_niche_coherence_score.svg")
plt.show()

In [None]:
# Ratio of NicheCompass's mean Niche Coherence Score to GraphST's on
# SlideSeqV2 Mouse Hippocampus. Models are selected by label because groupby's
# alphabetically sorted index makes positional .iloc lookups shift whenever a
# model has no rows for this dataset.
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["Niche Coherence Score"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
# Ratio of NicheCompass's mean Niche Coherence Score to STACI's on
# seqFISH Mouse Organogenesis, selecting both models by label (not position).
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Niche Coherence Score"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI

In [None]:
# Horizontal bar plot of the Overall Score per dataset with one bar per model
# (hue), written as an SVG into the benchmarking folder and shown inline.
# Fixed model -> color mapping so every benchmarking plot uses the same colors.
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "DeepLinc": "#0E1C4A",
    "GraphST": "#D78FF8",
    "SageNet": "#F46AA2",
    "scVI": "#FE8B3B",
    "expiMap": "#7E0028",
}

plt.figure(figsize=(fig_width_6_ticks, fig_height))
ax = sns.barplot(
    data=summary_df,
    x="Overall Score",
    y="dataset",
    hue="model",
    orient="h",
    palette=model_palette,
)

# One dashed separator at y = i + 0.5 for each dataset row.
for i in range(summary_df["dataset"].nunique()):
    ax.axhline(i + 0.5, color="gray", linestyle="--", linewidth=2)

# Wrap dataset names: the first three tick labels break at every space,
# the remaining ones only before "H", "N" and "CosMx".
new_labels = [
    tick.get_text().replace(" ", "\n")
    if row < 3
    else (tick.get_text()
          .replace(" H", "\nH")
          .replace(" N", "\nN")
          .replace(" CosMx", "\nCosMx"))
    for row, tick in enumerate(ax.get_yticklabels())
]

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# Labels must be replaced before plt.yticks() restyles them below.
ax.set_yticklabels(new_labels)
ax.set_ylabel(None)
ax.set_xlabel("Overall Score", fontsize=fontsize)
plt.yticks(fontsize=fontsize)
# set ticks first; the explicit x-limits afterwards override any view expansion.
plt.xticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7], fontsize=fontsize)
ax.set_xlim(0.2, 0.8)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_barplot_overall_score.svg")
plt.show()

In [None]:
# Ratio of NicheCompass's mean Overall Score to GraphST's on
# SlideSeqV2 Mouse Hippocampus. Models are selected by label because groupby's
# alphabetically sorted index makes positional .iloc lookups shift whenever a
# model has no rows for this dataset.
temp_df = summary_df[summary_df["dataset"] == "SlideSeqV2 Mouse Hippocampus"]
metrics_temp_df = temp_df.groupby("model")[["Overall Score"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["GraphST"]  # NC vs GraphST

In [None]:
# Ratio of NicheCompass's mean Overall Score to STACI's on
# seqFISH Mouse Organogenesis, selecting both models by label (not position).
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Overall Score"]].mean()
metrics_temp_df.loc["NicheCompass"] / metrics_temp_df.loc["STACI"]  # NC vs STACI