# Sample Integration Method Benchmarking

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Date of Creation:** 20.01.2023
- **Date of Last Modification:** 21.12.2024

Run this notebook in the nichecompass-reproducibility environment, which can be installed from `../../../envs/environment.yaml`. Before running this notebook, compute the sample integration method benchmarking metrics by triggering the
respective jobs (under the **Sample Integration Method Benchmarking** header) in `../../slurm_job_submission.ipynb`.


## 1. Setup

### 1.1 Import Libraries

In [None]:
# Reload imported modules automatically before executing code so that edits to
# helper modules (e.g. benchmarking_utils) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Make the shared helper modules in ../../../utils importable
# (benchmarking_utils is star-imported in the next cell).
sys.path.append("../../../utils")

In [None]:
import argparse
import os
import gc
import pickle
import random
import shutil
import warnings
from copy import deepcopy
from datetime import datetime

import anndata as ad
#import cellcharter as cc
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plottable
import scanpy as sc
import scipy.sparse as sp
#import scvi
import seaborn as sns
import squidpy as sq
import torch
#from GraphST import GraphST
from matplotlib import gridspec
from matplotlib.pyplot import rc_context
from plottable import ColumnDefinition, Table
from plottable.cmap import normed_cmap
from plottable.formatters import tickcross
from plottable.plots import bar
from sklearn.decomposition import KernelPCA
from nichecompass.benchmarking import compute_benchmarking_metrics
from nichecompass.models import NicheCompass

from benchmarking_utils import *

### 1.2 Define Parameters

In [None]:
# Metrics of the sample integration benchmark, ordered by category.
metric_cols_sample_integration = [
    "cas", "mlami",  # global spatial consistency
    "clisis", "gcs",  # local spatial consistency
    "nasw", "cnmi",  # niche coherence
    "blisi", "pcr",  # batch correction
]
# Per-metric weights within each category (later multiplied with the
# category weights below).
metric_col_weights_sample_integration = [
    (1/2), (1/2),  # global spatial consistency
    (1/2), (1/2),  # local spatial consistency
    1, 1,  # niche coherence
    1, 1,  # batch correction
]
metric_col_titles_sample_integration = [
    "CAS",  # "Cell Type Affinity Similarity",
    "MLAMI",  # "Maximum Leiden Adjusted Mutual Info",
    "CLISIS",  # "Cell Type Local Inverse Simpson's Index Similarity",
    "GCS",  # "Graph Connectivity Similarity",
    "NASW",  # "Niche Average Silhouette Width",
    "CNMI",  # "Cell Type Normalized Mutual Info",
    "BLISI",  # "Batch Local Inverse Simpson's Index",
    "PCR",  # "Principal Component Regression"
]

# Single-sample benchmarks have no batch correction category, so drop exactly
# the trailing batch correction metrics (BLISI, PCR). A hard-coded [:-3] would
# also remove CNMI from the niche coherence category.
n_batch_correction_metrics = 2
metric_cols_single_sample = metric_cols_sample_integration[:-n_batch_correction_metrics]
metric_col_weights_single_sample = metric_col_weights_sample_integration[:-n_batch_correction_metrics]
metric_col_titles_single_sample = metric_col_titles_sample_integration[:-n_batch_correction_metrics]

category_cols_sample_integration = [
    "Global Spatial Consistency Score",
    "Local Spatial Consistency Score",
    "Niche Coherence Score",
    "Batch Correction Score"]
category_col_weights_sample_integration = [
    1,
    1,
    2,
    2]
category_col_titles_sample_integration = [
    "Global Spatial Consistency Score",
    "Local Spatial Consistency Score",
    "Niche Coherence Score",
    "Batch Correction Score"]
# Single-sample categories: everything except the trailing batch correction category.
category_col_weights_single_sample = category_col_weights_sample_integration[:-1]
category_cols_single_sample = category_cols_sample_integration[:-1]
category_col_titles_single_sample = [
    "Global Spatial Consistency Score",
    "Local Spatial Consistency Score",
    "Niche Coherence Score"]

### 1.3 Run Notebook Setup

In [None]:
# Default scanpy figure size, then seaborn's white-grid style with grid lines disabled.
sc.set_figure_params(figsize=(6, 6))
sns.set_style(style="whitegrid", rc={"axes.grid": False})

In [None]:
# Silence FutureWarning and UserWarning messages (in this order, as before).
for _ignored_warning_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action="ignore", category=_ignored_warning_category)

In [None]:
# Small Helvetica text for publication figures.
plt.rcParams.update({"font.family": "Helvetica",
                     "font.size": 5})

In [None]:
# Get time of notebook execution for timestamping saved artifacts
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths and Create Directories

In [None]:
# Input datasets, model artifacts, and the output folder for this benchmark's figures.
data_folder_path = "../../../datasets/st_data/gold"
artifact_folder_path = "../../../artifacts"
benchmarking_folder_path = artifact_folder_path + "/sample_integration_method_benchmarking"

## 2. Method Benchmarking Evaluation

- Run all model notebooks in this directory (`./`) before continuing.

### 2.1 Retrieve NicheCompass Runs & Fix Datasets

#### 2.1.1 seqFISH Mouse Organogenesis

#### 2.1.2 seqFISH Mouse Organogenesis Imputed

#### 2.1.3 nanoString CosMx SMI Human Non-Small-Cell Lung Cancer (NSCLC)

##### 2.1.3.2 Fix Observation Uniqueness

##### 2.1.3.3 Metrics Computation Split Due to Memory Constraints

### 2.2 Create Benchmarking Plots

#### 2.2.1 nanoString CosMx SMI Human Non-Small-Cell Lung Cancer (NSCLC)

In [None]:
# Display the seaborn "colorblind" palette as hex codes (indices 0, 3 and 8 are
# used as replicate colors in the batch plots).
sns.color_palette("colorblind").as_hex()

In [None]:
# Print the same palette as a plain list of hex strings.
print(sns.color_palette("colorblind").as_hex())

In [None]:
### Fig. 3e: 10 pct subsample comparison NicheCompass with FoV embedding ###
# Load the NicheCompass (GATv2Conv) sample integration run, cluster its latent
# space into niches, and plot niches, niche cell type compositions and batches.
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
run_number = 2
model = "nichecompass_gatv2conv"
latent_leiden_resolution = 0.21
timestamp = "26082023_191240_1"
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
adata = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/models/gatv2conv_sample_integration_method_benchmarking/{timestamp}/run{run_number}/{dataset}_gatv2conv_sample_integration_method_benchmarking.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key="nichecompass_latent")

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key="nichecompass_latent")

# Manual niche annotation of the Leiden clusters at this resolution.
niche_annotation_dict = {
    "0": "Endothelial/Stroma",
    "1": "Tumor 1",
    "2": "Lymphoid Structures",
    "3": "Epithelial/Stroma",
    "4": "Neutrophil/Myeloid/Stroma",
    "5": "Plasmablast/Stroma",
    "6": "Tumor-Stroma Boundary",
    "7": "Tumor 2"}

latent_cluster_colors = create_new_color_dict(
    adata=adata,
    cat_key=latent_leiden_key)

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Optionally restrict the composition plot to a subset of niches, e.g.:
#adata_temp = adata[adata.obs["niche"].isin(["Endothelial/Stroma",
#                                            "Plasmablast/Stroma",
#                                            "Tumor-Stroma Boundary"])]
adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Tumor-Stroma Boundary": "Tumor-Stroma\nBoundary"},
           inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Fig. 3e: 10 pct subsample comparison GraphST with prior alignment through PASTE algorithm ###
# Cluster the stored GraphST (PASTE-aligned) latent space into niches and plot
# niches, niche cell type compositions and batches.
# dataset and run_number are set explicitly (same values as the NicheCompass
# cell) instead of relying on state left by an earlier cell.
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
run_number = 2
model = "graphst_paste"
latent_leiden_resolution = 0.4
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
latent_neighbors_key = f"graphst_latent_run{run_number}"
adata = sc.read_h5ad(f"{artifact_folder_path}/sample_integration_method_benchmarking/{dataset}_{model}.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key=latent_neighbors_key)

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key=latent_neighbors_key)

# Colors chosen so that matching niches share colors across methods.
latent_cluster_colors = {
    "0": "#66C5CC",
    "1": "#9EB9F3",
    "2": "#C9DB74",
    "3": "#DCB0F2",
    "4": "#8B008B",
    "5": "#FE88B1",
    "6": "#F6CF71",
    "7": "#F89C74"}

niche_annotation_dict = {
    "0": "Endothelial/Stroma (Replicate 1 & 3)",
    "1": "Tumor 1",
    "2": "Lymphoid Structures",
    "3": "Neutrophil/Myeloid/Stroma",
    "4": "Endothelial/Stroma (Replicate 2)",
    "5": "Tumor-Stroma Boundary",
    "6": "Epithelial/Stroma",
    "7": "Tumor 2"}

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Optionally restrict the composition plot to a subset of niches, e.g.:
#adata_temp = adata[adata.obs["niche"].isin(["Endothelial/Stroma (Replicate 2)",
#                                            "Endothelial/Stroma (Replicate 1 & 3)",
#                                            "Tumor-Stroma Boundary"])]
adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)",
                  "Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                  "Tumor-Stroma Boundary": "Tumor-Stroma\nBoundary"},
           inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Fig. 3e: 10 pct subsample comparison STACI ###
# Build a neighbor graph on the STACI latent space, cluster it into niches, and
# plot niches, niche cell type compositions and batches.
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
model = "staci"
run_number = 1
latent_leiden_resolution = 0.19 # 0.18
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
latent_key = f"staci_latent_run{run_number}"
adata = sc.read_h5ad(f"{artifact_folder_path}/sample_integration_method_benchmarking/{dataset}_{model}.h5ad")

# The neighbor graph is stored under the same key as the latent representation.
sc.pp.neighbors(adata,
                use_rep=latent_key,
                key_added=latent_key)

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key=latent_key)

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key=latent_key)

# Colors chosen so that matching niches share colors across methods.
latent_cluster_colors = {
    "0": "#9EB9F3",
    "1": "#F6CF71",
    "2": "#C9DB74",
    "3": "#66C5CC",
    "4": "#87C55F",
    "5": "#DCB0F2",
    "6": "#8B008B",
    "7": "#F89C74"}

niche_annotation_dict = {
    "0": "Tumor 1",
    "1": "Epithelial/Stroma",
    "2": "Lymphoid Structures",
    "3": "Endothelial/Stroma (Replicate 1 & 3)",
    "4": "Plasmablast/Stroma",
    "5": "Neutrophil/Myeloid/Stroma",
    "6": "Endothelial/Stroma (Replicate 2)",
    "7": "Tumor 2"}

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Optionally restrict the composition plot to a subset of niches, e.g.:
#adata_temp = adata[adata.obs["niche"].isin(["Endothelial/Stroma (Replicate 2)",
#                                            "Endothelial/Stroma (Replicate 1 & 3)",
#                                            "Plasmablast/Stroma"])]
adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                  "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"}, inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Fig. 3e: 10 pct subsample comparison CellCharter ###
# Cluster the stored CellCharter latent space into niches and plot niches,
# niche cell type compositions and batches.
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
model = "cellcharter"
run_number = 1
latent_leiden_resolution = 0.21
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
latent_neighbors_key = f"{model}_latent_run{run_number}"
adata = sc.read_h5ad(f"{artifact_folder_path}/sample_integration_method_benchmarking/{dataset}_{model}.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key=latent_neighbors_key)

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key=latent_neighbors_key)

# Alternative (disabled): CellCharter's own GMM clustering instead of Leiden.
#print("\nComputing CellCharter clustering...")
#gmm = cc.tl.Cluster(
#    n_clusters=8,
#    random_state=1,
#    # If running on GPU
#    #trainer_params=dict(accelerator='gpu', devices=1)
#)
#gmm.fit(adata, use_rep=f"{model}_latent_run{run_number}")
#adata.obs["spatial_clusters"] = gmm.predict(adata, use_rep=f"{model}_latent_run{run_number}").astype("str")

# Colors chosen so that matching niches share colors across methods. "Tumor 2"
# uses #F89C74 as in the other methods; previously it duplicated the
# Tumor-Stroma Boundary color (#FE88B1), making the two niches indistinguishable.
latent_cluster_colors = {
    "0": "#66C5CC",
    "1": "#87C55F",
    "2": "#9EB9F3",
    "3": "#DCB0F2",
    "4": "#8B008B",
    "5": "#F89C74",
    "6": "#FE88B1",
    "7": "#2F4F4F",
    "8": "#FF00FF"}

niche_annotation_dict = {
    "0": "Endothelial/Stroma (Replicate 1 & 3)",
    "1": "Plasmablast/Stroma & Lymphoid Structures",
    "2": "Tumor 1",
    "3": "Neutrophil/Myeloid/Stroma",
    "4": "Endothelial/Stroma (Replicate 2)",
    "5": "Tumor 2",
    "6": "Tumor-Stroma Boundary",
    "7": "Tumor 3",
    "8": "Tumor 4"}

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Optionally restrict the composition plot to a subset of niches, e.g.:
#adata_temp = adata[adata.obs["niche"].isin(["Endothelial/Stroma (Replicate 2)",
#                                            "Endothelial/Stroma (Replicate 1 & 3)",
#                                            "Tumor-Stroma Boundary",
#                                            "Plasmablast/Stroma"])]
adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)",
                  "Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                  "Tumor-Stroma Boundary": "Tumor-Stroma\nBoundary",
                  "Plasmablast/Stroma & Lymphoid Structures": "Plasmablast/Stroma &\nLymphoid Structures"},
           inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Fig. 3e: 10 pct subsample comparison BANKSY ###
# Build a neighbor graph on the BANKSY latent space, cluster it into niches, and
# plot niches, niche cell type compositions and batches.
model = "banksy"
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
run_number = 1
latent_leiden_resolution = 0.48
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
latent_key = f"{model}_latent_run{run_number}"
adata = sc.read_h5ad(f"{artifact_folder_path}/sample_integration_method_benchmarking/{dataset}_{model}.h5ad")

# The neighbor graph is stored under the same key as the latent representation.
sc.pp.neighbors(adata,
                use_rep=latent_key,
                key_added=latent_key)

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key=latent_key)

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key=latent_key)

# Colors chosen so that matching niches share colors across methods.
latent_cluster_colors = {
    "0": "#F6CF71",
    "1": "#9EB9F3",
    "2": "#66C5CC",
    "3": "#DCB0F2",
    "4": "#8B4513",
    "5": "#FE88B1",
    "6": "#F89C74",
    "7": "#FF6347",
    "8": "#2F4F4F"}

niche_annotation_dict = {
    "0": "Epithelial/Stroma",
    "1": "Tumor 1",
    "2": "Endothelial/Stroma (Replicate 1 & 3)",
    "3": "Neutrophil/Myeloid/Stroma",
    "4": "Artifact 2",
    "5": "Tumor-Stroma Boundary", #"Plasmablast/Stroma",
    "6": "Tumor 2",
    "7": "Artifact 1",
    "8": "Tumor 3"}

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                  "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"}, inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
# Supplementary figure: PCR batch effect
# Concatenate the three replicate batches without any integration, embed them
# with PCA + UMAP, and color cells by batch and by field of view (FoV).
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
adata_batch1 = sc.read_h5ad(f"{data_folder_path}/nanostring_cosmx_human_nsclc_subsample_10pct_batch1.h5ad")
adata_batch2 = sc.read_h5ad(f"{data_folder_path}/nanostring_cosmx_human_nsclc_subsample_10pct_batch2.h5ad")
adata_batch3 = sc.read_h5ad(f"{data_folder_path}/nanostring_cosmx_human_nsclc_subsample_10pct_batch3.h5ad")

adata_combined = sc.concat([adata_batch1, adata_batch2, adata_batch3], axis=0)

sc.tl.pca(adata_combined)
sc.pp.neighbors(adata_combined, n_pcs=20)
sc.tl.umap(adata_combined)

# Replicate -> index into the seaborn palettes (same hues as the batch plots);
# FoV -> palette variant that distinguishes the FoVs of one replicate.
replicate_palette_idx = {1: 0, 2: 3, 3: 8}
fov_palettes = {1: "dark", 2: "pastel", 3: "muted", 4: "bright"}

adata_combined.obs["batch"] = adata_combined.obs["batch"].replace(
    {"lung5_rep1": "Replicate 1",
     "lung5_rep2": "Replicate 2",
     "lung5_rep3": "Replicate 3"})

# "lung5_rep{r}_{f}" -> "Replicate {r} FoV {f}" for all 3 replicates x 4 FoVs.
adata_combined.obs["fov"] = adata_combined.obs["fov"].replace(
    {f"lung5_rep{rep}_{fov}": f"Replicate {rep} FoV {fov}"
     for rep in replicate_palette_idx
     for fov in fov_palettes})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {f"Replicate {rep}": colorblind_palette[idx]
                for rep, idx in replicate_palette_idx.items()}
fov_colors = {f"Replicate {rep} FoV {fov}": sns.color_palette(palette).as_hex()[idx]
              for rep, idx in replicate_palette_idx.items()
              for fov, palette in fov_palettes.items()}

# Use adata_combined (not the `adata` left over from a previous cell) for the
# sample list and marker size.
plot_category_in_latent_and_physical_space(
        adata=adata_combined,
        plot_label="Batches",
        model_label=None,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata_combined.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata_combined)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_pcr_batches.svg")

plot_category_in_latent_and_physical_space(
        adata=adata_combined,
        plot_label="FoV",
        model_label=None,
        cat_key="fov",
        groups=None,
        sample_key="batch",
        samples=adata_combined.obs["batch"].unique().tolist(),
        cat_colors=fov_colors,
        size=(720000 / len(adata_combined)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_pcr_fovs.svg")

In [None]:
### Supplementary figure: 10 pct subsample comparison GraphST without prior alignment through PASTE algorithm ###
# Cluster the stored GraphST (no PASTE alignment) latent space into niches and
# plot niches, niche cell type compositions and batches.
dataset = "nanostring_cosmx_human_nsclc_subsample_10pct"
run_number = 2

model = "graphst"
latent_leiden_resolution = 0.5
latent_leiden_key = f"latent_leiden_{latent_leiden_resolution}"
latent_neighbors_key = f"graphst_latent_run{run_number}"
adata = sc.read_h5ad(f"{artifact_folder_path}/sample_integration_method_benchmarking/{dataset}_{model}.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key=latent_neighbors_key)

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=latent_leiden_key,
             neighbors_key=latent_neighbors_key)

# Colors chosen so that matching niches share colors across methods.
latent_cluster_colors = {
    "0": "#9EB9F3",
    "1": "#C9DB74",
    "2": "#66C5CC",
    "3": "#DCB0F2",
    "4": "#87C55F",
    "5": "#8B008B",
    "6": "#FE88B1",
    "7": "#F6CF71",
    "8": "#F89C74",
    "9": "#2F4F4F"}

niche_annotation_dict = {
    "0": "Tumor 1",
    "1": "Lymphoid Structures",
    "2": "Endothelial/Stroma (Replicate 1 & 3)",
    "3": "Neutrophil/Myeloid/Stroma",
    "4": "Plasmablast/Stroma",
    "5": "Endothelial/Stroma (Replicate 2)",
    "6": "Tumor-Stroma Boundary",
    "7": "Epithelial/Stroma",
    "8": "Tumor 2",
    "9": "Tumor 3"}

adata.obs["niche"] = adata.obs[latent_leiden_key].map(niche_annotation_dict)
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

# model_label must be the model name string; `{model}` would pass a one-element set.
plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

adata_temp = adata

# Row-normalized niche x cell type table: fraction of each niche's cells per cell type.
tmp = pd.crosstab(adata_temp.obs["niche"], adata_temp.obs["cell_type"], normalize="index")
tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                  "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"}, inplace=True)
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches="tight")

# Human-readable replicate names for the batch plot.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Supplementary figure: 10 pct subsample comparison NicheCompass (no fov embedding) ###
# Load a stored NicheCompass GATv2 run (no FoV embedding), compute UMAP and Leiden
# clusters from the neighbor graph stored under "nichecompass_latent", annotate the
# clusters as niches and plot niches, per-niche cell type proportions and batches.
# Relies on `artifact_folder_path`, `dataset`, `run_number` and
# `benchmarking_folder_path` defined in earlier cells.
model = "nichecompass_gatv2conv"
timestamp = "24082023_153432_1"
latent_leiden_resolution = 0.3 # 0.21
adata = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/models/gatv2conv_sample_integration_method_benchmarking/{timestamp}/run{run_number}/{dataset}_gatv2conv_sample_integration_method_benchmarking.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key="nichecompass_latent")

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=f"latent_leiden_{latent_leiden_resolution}",
             neighbors_key="nichecompass_latent")

# Leiden cluster id -> plot color.
latent_cluster_colors = {
    "0": "#9EB9F3",
    "1": "#FE88B1",
    "2": "#66C5CC",
    "3": "#C9DB74",
    "4": "#DCB0F2",
    "5": "#87C55F",
    "6": "#8B008B",
    "7": "#F6CF71",
    "8": "#F89C74",
    "9": "#2F4F4F"}

# Leiden cluster id -> manual niche annotation.
niche_annotation_dict = {
    "0": "Tumor 1",
    "1": "Tumor-Stroma Boundary",
    "2": "Endothelial/Stroma (Replicate 1 & 3)",
    "3": "Lymphoid Structures",
    "4": "Neutrophil/Myeloid/Stroma",
    "5": "Plasmablast/Stroma",
    "6": "Endothelial/Stroma (Replicate 2)",
    "7": "Epithelial/Stroma",
    "8": "Tumor 2",
    "9": "Tumor 3"}

adata.obs["niche"] = adata.obs[f"latent_leiden_{latent_leiden_resolution}"].map(niche_annotation_dict)
# Re-key the cluster colors by niche name.
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,  # pass the name string itself; `{model}` would build a one-element set
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Fraction of each cell type within each niche (every row sums to 1).
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize='index')
# Break the long niche labels over two lines for the y-axis.
tmp = tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                        "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"})
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches='tight')

# Relabel the raw lung5 sample ids with readable replicate names.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Supplementary figure: 10 pct subsample comparison NicheCompass Light ###
# Load a stored NicheCompass Light (GCN, with FoV embedding) run, compute UMAP and
# Leiden clusters from the neighbor graph stored under "nichecompass_latent",
# annotate the clusters as niches and plot niches, per-niche cell type proportions
# and batches. Relies on `artifact_folder_path`, `dataset`, `run_number` and
# `benchmarking_folder_path` defined in earlier cells.
model = "nichecompass_gcnconv_fov"
timestamp = "23082023_135406_1"
latent_leiden_resolution = 0.25
adata = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/models/gcnconv_sample_integration_method_benchmarking/{timestamp}/run{run_number}/{dataset}_gcnconv_sample_integration_method_benchmarking.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key="nichecompass_latent")

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=f"latent_leiden_{latent_leiden_resolution}",
             neighbors_key="nichecompass_latent")

# Leiden cluster id -> plot color.
latent_cluster_colors = {
    "0": "#66C5CC",
    "1": "#9EB9F3",
    "2": "#C9DB74",
    "3": "#87C55F",
    "4": "#DCB0F2",
    "5": "#F6CF71",
    "6": "#FE88B1",
    "7": "#F89C74"}

# Leiden cluster id -> manual niche annotation.
niche_annotation_dict = {
    "0": "Endothelial/Stroma",
    "1": "Tumor 1",
    "2": "Lymphoid Structures",
    "3": "Plasmablast/Stroma",
    "4": "Neutrophil/Myeloid/Stroma",
    "5": "Epithelial/Stroma",
    "6": "Tumor-Stroma Boundary",
    "7": "Tumor 2"}

adata.obs["niche"] = adata.obs[f"latent_leiden_{latent_leiden_resolution}"].map(niche_annotation_dict)
# Re-key the cluster colors by niche name.
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,  # pass the name string itself; `{model}` would build a one-element set
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Fraction of each cell type within each niche (every row sums to 1).
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize='index')
# Break long niche labels over two lines; these keys do not occur in this
# annotation (it has a single "Endothelial/Stroma" niche), so this is a no-op here.
tmp = tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                        "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"})
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches='tight')

# Relabel the raw lung5 sample ids with readable replicate names.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Supplementary figure: 10 pct subsample comparison NicheCompass Light (no fov embedding) ###
# Load a stored NicheCompass Light (GCN, no FoV embedding) run, compute UMAP and
# Leiden clusters from the neighbor graph stored under "nichecompass_latent",
# annotate the clusters as niches and plot niches, per-niche cell type proportions
# and batches. Relies on `artifact_folder_path`, `dataset`, `run_number` and
# `benchmarking_folder_path` defined in earlier cells.
model = "nichecompass_gcnconv"
timestamp = "22082023_093531_1"
latent_leiden_resolution = 0.21
adata = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/models/gcnconv_sample_integration_method_benchmarking/{timestamp}/run{run_number}/{dataset}_gcnconv_sample_integration_method_benchmarking.h5ad")

print("\nComputing UMAP embedding...")
sc.tl.umap(adata,
           neighbors_key="nichecompass_latent")

print("\nComputing Leiden clustering...")
sc.tl.leiden(adata=adata,
             resolution=latent_leiden_resolution,
             key_added=f"latent_leiden_{latent_leiden_resolution}",
             neighbors_key="nichecompass_latent")

# Leiden cluster id -> plot color.
latent_cluster_colors = {
    "0": "#9EB9F3",
    "1": "#C9DB74",
    "2": "#FE88B1",
    "3": "#66C5CC",
    "4": "#DCB0F2",
    "5": "#87C55F",
    "6": "#8B008B",
    "7": "#F89C74"}

# Leiden cluster id -> manual niche annotation.
niche_annotation_dict = {
    "0": "Tumor 1",
    "1": "Lymphoid Structures",
    "2": "Tumor-Stroma Boundary",
    "3": "Endothelial/Stroma (Replicate 1 & 3)",
    "4": "Neutrophil/Myeloid/Stroma",
    "5": "Plasmablast/Stroma",
    "6": "Endothelial/Stroma (Replicate 2)",
    "7": "Tumor 2"}

adata.obs["niche"] = adata.obs[f"latent_leiden_{latent_leiden_resolution}"].map(niche_annotation_dict)
# Re-key the cluster colors by niche name.
niche_colors = {niche: latent_cluster_colors[cluster] for cluster, niche in niche_annotation_dict.items()}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Niches",
        model_label=model,  # pass the name string itself; `{model}` would build a one-element set
        cat_key="niche",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=niche_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_niches_{model}.svg")

cell_type_colors = create_new_color_dict(
    adata=adata,
    color_palette="cell_type_20",
    cat_key="cell_type")

# Fraction of each cell type within each niche (every row sums to 1).
tmp = pd.crosstab(adata.obs["niche"], adata.obs["cell_type"], normalize='index')
# Break the long niche labels over two lines for the y-axis.
tmp = tmp.rename(index={"Endothelial/Stroma (Replicate 1 & 3)": "Endothelial/Stroma\n(Replicate 1 & 3)",
                        "Endothelial/Stroma (Replicate 2)": "Endothelial/Stroma\n(Replicate 2)"})
ax = tmp.plot.barh(color=cell_type_colors, stacked=True, figsize=(3, 6))
ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
plt.yticks(fontsize=14)
plt.xlabel("Cell Type Proportions", fontsize=16)
plt.savefig(f"{benchmarking_folder_path}/{model}_{dataset}_niche_cell_type_proportions.svg", bbox_inches='tight')

# Relabel the raw lung5 sample ids with readable replicate names.
adata.obs["batch"] = adata.obs["batch"].replace({"lung5_rep1": "Replicate 1",
                                                 "lung5_rep2": "Replicate 2",
                                                 "lung5_rep3": "Replicate 3"})

colorblind_palette = sns.color_palette("colorblind").as_hex()
batch_colors = {"Replicate 1": colorblind_palette[0],
                "Replicate 2": colorblind_palette[3],
                "Replicate 3": colorblind_palette[8]}

plot_category_in_latent_and_physical_space(
        adata=adata,
        plot_label="Batches",
        model_label=model,
        cat_key="batch",
        groups=None,
        sample_key="batch",
        samples=adata.obs["batch"].unique().tolist(),
        cat_colors=batch_colors,
        size=(720000 / len(adata)),
        spot_size=200,
        save_fig=True,
        file_path=f"{benchmarking_folder_path}/{dataset}_batches_{model}.svg")

In [None]:
### Fig. 3e: 10 pct subsample comparison metrics (not included) ###
# Load the per-run sample integration metrics of each model on the 10% subsample,
# min-max scale every metric within the dataset, aggregate the scaled metrics into
# category and overall scores, and plot a metrics table for one run per model.
datasets = ["nanostring_cosmx_human_nsclc_subsample_10pct"]
models = ["nichecompass_gatv2conv_fov",
          "staci",
          "graphst_paste",
          "cellcharter",
          "banksy"]

summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_frames = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for 8 runs without metric values so the model is still listed.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        dataset_frames.append(benchmark_df)
    dataset_df = pd.concat(dataset_frames, ignore_index=True)

    # Min-max scale each metric across all models and runs of this dataset.
    for metric_col in metric_cols_sample_integration:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Missing PCR values (NaN) are scored as 0.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)

# Scaled metric columns of the four metric categories (two metrics each, in the
# order of `metric_cols_sample_integration`).
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[4:6]]
cat_3_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[6:8]]

# Category score = weighted average of its scaled metrics;
# overall score = weighted average of the category scores.
summary_df[category_cols_sample_integration[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[0:2],
                                                        axis=1)
summary_df[category_cols_sample_integration[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[2:4],
                                                        axis=1)
summary_df[category_cols_sample_integration[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[4:6],
                                                        axis=1)
summary_df[category_cols_sample_integration[3]] = np.average(summary_df[cat_3_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[6:8],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_sample_integration[:4]],
                                         weights=category_col_weights_sample_integration[:4],
                                         axis=1)

# Reformat for plot
summary_df.replace({"nichecompass_gatv2conv_fov": "NicheCompass",
                    "staci": "STACI",
                    "graphst_paste": "GraphST",
                    "cellcharter": "CellCharter",
                    "banksy": "BANKSY"}, inplace=True)

# Keep one run per model: run 2, except run 1 for STACI. The explicit copy makes
# the filtered frame independent before columns are assigned below.
summary_df = summary_df[((summary_df["run_number"] == 2) & (summary_df["model"] != "STACI")) |
                        ((summary_df["run_number"] == 1) & (summary_df["model"] == "STACI"))].copy()

# Prepare metrics table plot: average the numeric columns per dataset and model
# (`numeric_only=True` drops non-numeric columns such as identifiers).
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean(numeric_only=True).sort_values("Overall Score", ascending=False)[
    metric_cols_sample_integration + ["Overall Score"]].reset_index()

# Long format: one row per (dataset, model, score_type).
unrolled_df = pd.melt(aggregate_df,
   id_vars=group_cols,
   value_vars=metric_cols_sample_integration + ["Overall Score"], # metric_cols_sample_integration, category_cols_sample_integration
   var_name="score_type",
   value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` is one of the models flagged as spatially aware in this table."""
    # NOTE(review): the all-subsample metrics cell's version of this helper also
    # lists "BANKSY"; confirm which classification this figure should use.
    return row["model"] in {"NicheCompass", "NicheCompass Light", "GraphST", "GraphST (No Prior Alignment)", "CellCharter"}

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset")

# NOTE(review): this rename only affects `summary_df`; the table below is drawn
# from `unrolled_df`, which keeps the raw dataset name.
summary_df["dataset"] = summary_df["dataset"].replace(
    {"nanostring_cosmx_human_nsclc_subsample_10pct": "nanoString CosMx Human NSCLC"})

# Plot table
plot_simple_metrics_table(
    df=unrolled_df,
    model_col="model",
    model_col_width=2.1,
    group_col="dataset",
    metric_cols=metric_cols_sample_integration, # metric_cols_sample_integration, category_cols_sample_integration
    metric_col_weights=metric_col_weights_sample_integration, # metric_col_weights_sample_integration, category_col_weights_sample_integration
    metric_col_titles=[col.replace(" ", "\n") for col in metric_col_titles_sample_integration], # category_col_titles_sample_integration
    metric_col_width=1.1, # 0.8,
    plot_width=11, # 32,
    plot_height=5,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_nanostring_cosmx_human_nsclc_subsample_10pct_run2.svg")

In [None]:
### Supplementary fig. 14: 10 pct subsample comparison NC vs STACI vs GraphST variations vs CellCharter vs BANKSY ###
# Same pipeline as Fig. 3e, extended with the NicheCompass / NicheCompass Light
# variants without FoV embedding and GraphST without prior alignment.
datasets = ["nanostring_cosmx_human_nsclc_subsample_10pct"]
models = ["nichecompass_gatv2conv_fov",
          "nichecompass_gatv2conv",
          "nichecompass_gcnconv_fov",
          "nichecompass_gcnconv",
          "staci",
          "graphst_paste",
          "graphst",
          "cellcharter",
          "banksy"]

summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_frames = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for 8 runs without metric values so the model is still listed.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        dataset_frames.append(benchmark_df)
    dataset_df = pd.concat(dataset_frames, ignore_index=True)

    # Min-max scale each metric across all models and runs of this dataset.
    for metric_col in metric_cols_sample_integration:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Missing PCR values (NaN) are scored as 0.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)

# Scaled metric columns of the four metric categories (two metrics each, in the
# order of `metric_cols_sample_integration`).
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[4:6]]
cat_3_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[6:8]]

# Category score = weighted average of its scaled metrics;
# overall score = weighted average of the category scores.
summary_df[category_cols_sample_integration[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[0:2],
                                                        axis=1)
summary_df[category_cols_sample_integration[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[2:4],
                                                        axis=1)
summary_df[category_cols_sample_integration[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[4:6],
                                                        axis=1)
summary_df[category_cols_sample_integration[3]] = np.average(summary_df[cat_3_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[6:8],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_sample_integration[:4]],
                                         weights=category_col_weights_sample_integration[:4],
                                         axis=1)

# Reformat for plot
summary_df.replace({"nichecompass_gatv2conv_fov": "NicheCompass",
                    "nichecompass_gatv2conv": "NicheCompass (No FoV Embedding)",
                    "nichecompass_gcnconv_fov": "NicheCompass Light",
                    "nichecompass_gcnconv": "NicheCompass Light (No FoV Embedding)",
                    "staci": "STACI",
                    "graphst_paste": "GraphST",
                    "graphst": "GraphST (No Prior Alignment)",
                    "cellcharter": "CellCharter",
                    "banksy": "BANKSY"}, inplace=True)

# Keep one run per model: run 2, except run 1 for STACI. PCR NaNs were already
# filled above, so no second fillna is needed on the filtered frame.
summary_df = summary_df[((summary_df["run_number"] == 2) & (summary_df["model"] != "STACI")) |
                        ((summary_df["run_number"] == 1) & (summary_df["model"] == "STACI"))].copy()

# Prepare metrics table plot: average the numeric columns per dataset and model
# (`numeric_only=True` drops non-numeric columns such as identifiers).
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean(numeric_only=True).sort_values("Overall Score", ascending=False)[
    metric_cols_sample_integration + ["Overall Score"]].reset_index()

# Long format: one row per (dataset, model, score_type).
unrolled_df = pd.melt(aggregate_df,
   id_vars=group_cols,
   value_vars=metric_cols_sample_integration + ["Overall Score"], # metric_cols_sample_integration, category_cols_sample_integration
   var_name="score_type",
   value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` is one of the models flagged as spatially aware in this table."""
    # NOTE(review): the all-subsample metrics cell's version of this helper also
    # lists "BANKSY"; confirm which classification this figure should use.
    return row["model"] in {"NicheCompass",
                            "NicheCompass (No FoV Embedding)",
                            "NicheCompass Light",
                            "NicheCompass Light (No FoV Embedding)",
                            "GraphST",
                            "GraphST (No Prior Alignment)",
                            "CellCharter"}

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset")

# NOTE(review): this rename only affects `summary_df`; the table below is drawn
# from `unrolled_df`, which keeps the raw dataset name.
summary_df["dataset"] = summary_df["dataset"].replace(
    {"nanostring_cosmx_human_nsclc_subsample_10pct": "nanoString CosMx Human NSCLC"})

# Plot table
plot_simple_metrics_table(
    df=unrolled_df,
    model_col="model",
    model_col_width=5.5,
    group_col="dataset",
    metric_cols=metric_cols_sample_integration, # metric_cols_sample_integration, category_cols_sample_integration
    metric_col_weights=metric_col_weights_sample_integration, # metric_col_weights_sample_integration, category_col_weights_sample_integration
    metric_col_titles=[col.replace(" ", "\n") for col in metric_col_titles_sample_integration], # category_col_titles_sample_integration
    metric_col_width=1.1, # 0.8,
    plot_width=14, # 32,
    plot_height=8,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_nanostring_cosmx_human_nsclc_subsample_10pct_run2_supplement.svg")

In [None]:
# Load metrics
# Collect the per-run sample integration metrics of every model on the full dataset
# and all subsamples, min-max scale each metric within its dataset and compute
# category and overall scores. `summary_df` (all runs, prettified model names) is
# reused by the scalability plot.
datasets = ["nanostring_cosmx_human_nsclc",
            "nanostring_cosmx_human_nsclc_subsample_50pct",
            "nanostring_cosmx_human_nsclc_subsample_25pct",
            "nanostring_cosmx_human_nsclc_subsample_10pct",
            "nanostring_cosmx_human_nsclc_subsample_5pct",
            "nanostring_cosmx_human_nsclc_subsample_1pct"]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "graphst_paste",
          "graphst",
          "cellcharter",
          "banksy"]

summary_df = pd.DataFrame()
for dataset in datasets:
    dataset_frames = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for 8 runs without metric values so the model is still listed.
            benchmark_df = pd.DataFrame({
                "dataset": [dataset] * 8,
                "model": [model] * 8,
                "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        dataset_frames.append(benchmark_df)
    dataset_df = pd.concat(dataset_frames, ignore_index=True)

    # Min-max scale each metric across all models and runs of this dataset.
    for metric_col in metric_cols_sample_integration:
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[metric_col + "_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df = pd.concat([summary_df, dataset_df], ignore_index=True)

# Missing PCR values (NaN) are scored as 0.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)

# Scaled metric columns of the four metric categories (two metrics each, in the
# order of `metric_cols_sample_integration`).
cat_0_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[0:2]]
cat_1_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[2:4]]
cat_2_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[4:6]]
cat_3_scaled_metric_cols = [metric_col + "_scaled" for metric_col in metric_cols_sample_integration[6:8]]

# Category score = weighted average of its scaled metrics;
# overall score = weighted average of the category scores.
summary_df[category_cols_sample_integration[0]] = np.average(summary_df[cat_0_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[0:2],
                                                        axis=1)
summary_df[category_cols_sample_integration[1]] = np.average(summary_df[cat_1_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[2:4],
                                                        axis=1)
summary_df[category_cols_sample_integration[2]] = np.average(summary_df[cat_2_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[4:6],
                                                        axis=1)
summary_df[category_cols_sample_integration[3]] = np.average(summary_df[cat_3_scaled_metric_cols],
                                                        weights=metric_col_weights_sample_integration[6:8],
                                                        axis=1)
summary_df["Overall Score"] = np.average(summary_df[category_cols_sample_integration[:4]],
                                         weights=category_col_weights_sample_integration[:4],
                                         axis=1)

# Reformat for plot
summary_df.replace({"nichecompass_gatv2conv": "NicheCompass",
                    "nichecompass_gcnconv": "NicheCompass Light",
                    "staci": "STACI",
                    "graphst_paste": "GraphST",
                    "graphst": "GraphST (No Prior Alignment)",
                    "cellcharter": "CellCharter",
                    "banksy": "BANKSY"}, inplace=True)

# Average the numeric columns over all runs per dataset and model
# (`numeric_only=True` drops non-numeric columns such as identifiers).
group_cols = ["dataset", "model"]
aggregate_df = summary_df.groupby(group_cols).mean(numeric_only=True).sort_values("Overall Score", ascending=False)[
    metric_cols_sample_integration + ["Overall Score"]].reset_index()

# Long format: one row per (dataset, model, score_type).
unrolled_df = pd.melt(aggregate_df,
   id_vars=group_cols,
   value_vars=metric_cols_sample_integration + ["Overall Score"], # metric_cols_sample_integration, category_cols_sample_integration
   var_name="score_type",
   value_name="score")

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if ``row["model"]`` is one of the models flagged as spatially aware."""
    return row["model"] in {"NicheCompass", "NicheCompass Light", "GraphST", "GraphST (No Prior Alignment)", "CellCharter", "BANKSY"}

unrolled_df["spatially_aware"] = unrolled_df.apply(is_spatially_aware_model, axis=1)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset")

In [None]:
### Supplementary figure: scalability ###
# Plot the mean run time per model against dataset size. Uses `summary_df` from
# the previous cell (all runs, prettified model names).
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "GraphST": "#D78FF8",
                 "GraphST (No Prior Alignment)": "#b5bd61",
                 "CellCharter": "#F46AA2",
                 "BANKSY": "#556B2F"}

run_time_mean_df = summary_df.groupby(["dataset", "model"])[["run_time"]].mean().reset_index()
# Convert to minutes for the y-axis (run_time presumably stored in seconds).
run_time_mean_df["run_time"] = run_time_mean_df["run_time"] / 60

# Dataset name -> size of the (sub)sample in percent of the full dataset.
dataset_share_by_name = {
    "nanostring_cosmx_human_nsclc": 100,
    "nanostring_cosmx_human_nsclc_subsample_50pct": 50,
    "nanostring_cosmx_human_nsclc_subsample_25pct": 25,
    "nanostring_cosmx_human_nsclc_subsample_10pct": 10,
    "nanostring_cosmx_human_nsclc_subsample_5pct": 5,
    "nanostring_cosmx_human_nsclc_subsample_1pct": 1}

def create_dataset_share_col(row):
    """Return the dataset share (%) for ``row["dataset"]``, or None for unknown datasets."""
    return dataset_share_by_name.get(row["dataset"])

run_time_mean_df["dataset_share"] = run_time_mean_df.apply(create_dataset_share_col, axis=1)

with sns.axes_style("ticks"):
    ax = sns.lineplot(data=run_time_mean_df,
                      x="dataset_share",
                      y="run_time",
                      hue="model",
                      marker='o',
                      palette=model_palette)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.title("nanoString CosMx Human NSCLC\n(232,671 Cells; 960 Genes)")
    plt.ylabel("Run Time (Minutes)")
    plt.xlabel("Dataset Size (%)")
    custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]  # Adjust the tick positions as needed
    plt.yscale("log")
    plt.yticks(custom_y_ticks, custom_y_ticks)
    # The legend is hidden in the saved figure; colors follow `model_palette`.
    # (No handle restyling/reordering: `Legend.legendHandles` no longer exists in
    # Matplotlib >= 3.9, and any reordering would be discarded by hiding it.)
    ax.legend().set_visible(False)
    plt.grid(True)
    plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_nanostring_cosmx_human_nsclc.svg")
    plt.show()

In [None]:
### Supplementary fig. 16b: sample integration metric averages ###
# Display order of the models in the metrics table.
custom_order = ["NicheCompass", "NicheCompass Light", "BANKSY", "CellCharter", "STACI", "GraphST", "GraphST (No Prior Alignment)"]

# Make 'model' an ordered categorical so sorting follows `custom_order`.
# Models not listed there are appended alphabetically instead of silently
# turning into NaN.
model_categories = custom_order + sorted(set(unrolled_df["model"].dropna()) - set(custom_order))
unrolled_df["model"] = pd.Categorical(unrolled_df["model"], categories=model_categories, ordered=True)

# Sort rows by dataset, then by model display order.
unrolled_df = unrolled_df.sort_values(["dataset", "model"])

plot_metrics_table(
    df=unrolled_df,
    model_col="model",
    model_col_width=3,
    group_col="dataset",
    metric_cols=metric_cols_sample_integration, # metric_cols_sample_integration, category_cols_sample_integration
    metric_col_weights=metric_col_weights_sample_integration, # metric_col_weights_sample_integration, category_col_weights_sample_integration
    metric_col_titles=[col.replace(" ", "\n") for col in metric_col_titles_sample_integration], # category_col_titles_sample_integration
    metric_col_width=0.5, # 0.8,
    aggregate_col_width=0.8,
    plot_width=43.5, # 32,
    plot_height=6,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_nanostring_cosmx_human_nsclc.svg")

#### 2.1.2 seqFISH Mouse Organogenesis

In [None]:
# Load metrics
datasets = ["seqfish_mouse_organogenesis",
            "seqfish_mouse_organogenesis_subsample_50pct",
            "seqfish_mouse_organogenesis_subsample_25pct",
            "seqfish_mouse_organogenesis_subsample_10pct",
            "seqfish_mouse_organogenesis_subsample_5pct",
            "seqfish_mouse_organogenesis_subsample_1pct"]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "graphst_paste",
          "graphst",
          "cellcharter",
          "banksy"]

# Collect per-dataset frames and concatenate once (repeatedly concatenating
# onto a growing frame is quadratic).
summary_df_parts = []
for dataset in datasets:
    model_dfs = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for runs 1-8 (all metrics NaN) keep the missing
            # model present in the downstream tables and plots.
            benchmark_df = pd.DataFrame({"dataset": [dataset] * 8,
                                         "model": [model] * 8,
                                         "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                                         "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        model_dfs.append(benchmark_df)
    dataset_df = pd.concat(model_dfs, ignore_index=True)

    # Min-max scale each metric to [0, 1] across all models and runs of this
    # dataset; NaNs (missing models) are ignored by min/max and stay NaN.
    for metric_col in metric_cols_sample_integration:
        if metric_col not in dataset_df:
            # No metrics file provided this metric (e.g. all files missing):
            # add it as NaN so the scoring code still finds the column.
            dataset_df[metric_col] = np.nan
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[f"{metric_col}_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df_parts.append(dataset_df)

summary_df = pd.concat(summary_df_parts, ignore_index=True)

# Missing PCR values are counted as 0 instead of propagating NaN into the
# batch-correction score.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)
            
# Split the scaled metrics into the four score categories, two consecutive
# metrics each: global spatial consistency, local spatial consistency,
# niche coherence and batch correction.
cat_0_scaled_metric_cols, cat_1_scaled_metric_cols, cat_2_scaled_metric_cols, cat_3_scaled_metric_cols = (
    [f"{name}_scaled" for name in metric_cols_sample_integration[first:first + 2]]
    for first in (0, 2, 4, 6))

# Category score = weighted mean of the category's two scaled metrics.
for cat_idx, scaled_cols in enumerate((cat_0_scaled_metric_cols, cat_1_scaled_metric_cols,
                                       cat_2_scaled_metric_cols, cat_3_scaled_metric_cols)):
    summary_df[category_cols_sample_integration[cat_idx]] = np.average(
        summary_df[scaled_cols],
        weights=metric_col_weights_sample_integration[2 * cat_idx:2 * cat_idx + 2],
        axis=1,
    )
del cat_idx, scaled_cols

# Overall score = weighted mean of the four category scores.
summary_df["Overall Score"] = np.average(
    summary_df[category_cols_sample_integration[:4]],
    weights=category_col_weights_sample_integration[:4],
    axis=1,
)
 
# Reformat for plot: map internal model identifiers to display names. Only the
# "model" column is rewritten so no other column can be changed by accident.
summary_df["model"] = summary_df["model"].replace({"nichecompass_gatv2conv": "NicheCompass",
                                                   "nichecompass_gcnconv": "NicheCompass Light",
                                                   "staci": "STACI",
                                                   "graphst_paste": "GraphST",
                                                   "graphst": "GraphST (No Prior Alignment)",
                                                   "cellcharter": "CellCharter",
                                                   "banksy": "BANKSY"})

# Prepare metrics table plot: average every (dataset, model) over its runs and
# rank the models by their mean overall score.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                # `numeric_only` is GroupBy.mean's first positional parameter;
                # passing a column name there only worked because it is truthy.
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                [metric_cols_sample_integration + ["Overall Score"]]
                .reset_index())

# Long format: one row per (dataset, model, score type).
unrolled_df = pd.melt(aggregate_df,
                      id_vars=group_cols,
                      value_vars=metric_cols_sample_integration + ["Overall Score"],
                      var_name="score_type",
                      value_name="score")

# Create spatial indicator column
spatially_aware_models = {"NicheCompass", "NicheCompass Light", "GraphST",
                          "GraphST (No Prior Alignment)", "CellCharter", "BANKSY"}


def is_spatially_aware_model(row):
    """Return whether ``row["model"]`` is one of ``spatially_aware_models``."""
    return row["model"] in spatially_aware_models


unrolled_df["spatially_aware"] = unrolled_df["model"].isin(spatially_aware_models)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. The sort is stable so rows keep the
# order produced above within each dataset instead of an arbitrary one.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

In [None]:
### Supplementary fig. 17: sample integration runtimes ###
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "GraphST": "#D78FF8",
                 "GraphST (No Prior Alignment)": "#b5bd61",
                 "CellCharter": "#F46AA2",
                 "BANKSY": "#556B2F"}

# Mean run time per (dataset, model) over runs; divided by 60 because the
# y axis is labelled in minutes (assumes `run_time` is stored in seconds).
run_time_mean_df = summary_df.groupby(["dataset", "model"])[["run_time"]].mean().reset_index()
run_time_mean_df["run_time"] = run_time_mean_df["run_time"] / 60

# Share (%) of the full dataset contained in each (sub)sampled dataset.
dataset_shares = {"seqfish_mouse_organogenesis": 100,
                  "seqfish_mouse_organogenesis_subsample_50pct": 50,
                  "seqfish_mouse_organogenesis_subsample_25pct": 25,
                  "seqfish_mouse_organogenesis_subsample_10pct": 10,
                  "seqfish_mouse_organogenesis_subsample_5pct": 5,
                  "seqfish_mouse_organogenesis_subsample_1pct": 1}


def create_dataset_share_col(row):
    """Return the dataset share (%) of ``row["dataset"]``, or None if unknown."""
    return dataset_shares.get(row["dataset"])


run_time_mean_df["dataset_share"] = run_time_mean_df["dataset"].map(dataset_shares)

with sns.axes_style("ticks"):
    # The figure is saved without a legend (model colors are fixed by
    # `model_palette`), so seaborn is asked not to draw one at all.
    ax = sns.lineplot(data=run_time_mean_df,
                      x="dataset_share",
                      y="run_time",
                      hue="model",
                      marker="o",
                      palette=model_palette,
                      legend=False)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.title("seqFISH Mouse Organogenesis\n(57,536 Cells; 351 Genes)")
    plt.ylabel("Run Time (Minutes)")
    plt.xlabel("Dataset Size (%)")
    # Log-scaled y axis with ticks at 1 min, 10 min, 1 h, 3 h, 6 h, 12 h, 24 h.
    custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]
    plt.yscale("log")
    plt.yticks(custom_y_ticks, custom_y_ticks)
    plt.grid(True)
    plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_seqfish_mouse_organogenesis.svg")
    plt.show()

In [None]:
### Supplementary fig. 16b: sample integration metric averages ###
# Render the per-dataset metrics table and save it as SVG.
plot_metrics_table(
    df=unrolled_df,
    group_col="dataset",
    model_col="model",
    # Per-metric columns; the category_*_sample_integration variants can be
    # swapped in for a category-level table.
    metric_cols=metric_cols_sample_integration,
    metric_col_weights=metric_col_weights_sample_integration,
    # One word per line keeps the metric column headers narrow.
    metric_col_titles=[title.replace(" ", "\n") for title in metric_col_titles_sample_integration],
    model_col_width=3,
    metric_col_width=0.5,  # alternative: 0.8
    aggregate_col_width=0.8,
    plot_width=43.5,  # alternative: 32
    plot_height=6,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_seqfish_mouse_organogenesis.svg",
)

#### 2.1.3 seqFISH Mouse Organogenesis Imputed

In [None]:
# Load metrics
datasets = ["seqfish_mouse_organogenesis_imputed",
            "seqfish_mouse_organogenesis_imputed_subsample_50pct",
            "seqfish_mouse_organogenesis_imputed_subsample_25pct",
            "seqfish_mouse_organogenesis_imputed_subsample_10pct",
            "seqfish_mouse_organogenesis_imputed_subsample_5pct",
            "seqfish_mouse_organogenesis_imputed_subsample_1pct"]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "graphst_paste",
          "graphst",
          "cellcharter",
          "banksy"]

# Collect per-dataset frames and concatenate once (repeatedly concatenating
# onto a growing frame is quadratic).
summary_df_parts = []
for dataset in datasets:
    model_dfs = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for runs 1-8 (all metrics NaN) keep the missing
            # model present in the downstream tables and plots.
            benchmark_df = pd.DataFrame({"dataset": [dataset] * 8,
                                         "model": [model] * 8,
                                         "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                                         "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        model_dfs.append(benchmark_df)
    dataset_df = pd.concat(model_dfs, ignore_index=True)

    # Min-max scale each metric to [0, 1] across all models and runs of this
    # dataset; NaNs (missing models) are ignored by min/max and stay NaN.
    for metric_col in metric_cols_sample_integration:
        if metric_col not in dataset_df:
            # No metrics file provided this metric (e.g. all files missing):
            # add it as NaN so the scoring code still finds the column.
            dataset_df[metric_col] = np.nan
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[f"{metric_col}_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df_parts.append(dataset_df)

summary_df = pd.concat(summary_df_parts, ignore_index=True)

# Missing PCR values are counted as 0 instead of propagating NaN into the
# batch-correction score.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)
            
# Split the scaled metrics into the four score categories, two consecutive
# metrics each: global spatial consistency, local spatial consistency,
# niche coherence and batch correction.
cat_0_scaled_metric_cols, cat_1_scaled_metric_cols, cat_2_scaled_metric_cols, cat_3_scaled_metric_cols = (
    [f"{name}_scaled" for name in metric_cols_sample_integration[first:first + 2]]
    for first in (0, 2, 4, 6))

# Category score = weighted mean of the category's two scaled metrics.
for cat_idx, scaled_cols in enumerate((cat_0_scaled_metric_cols, cat_1_scaled_metric_cols,
                                       cat_2_scaled_metric_cols, cat_3_scaled_metric_cols)):
    summary_df[category_cols_sample_integration[cat_idx]] = np.average(
        summary_df[scaled_cols],
        weights=metric_col_weights_sample_integration[2 * cat_idx:2 * cat_idx + 2],
        axis=1,
    )
del cat_idx, scaled_cols

# Overall score = weighted mean of the four category scores.
summary_df["Overall Score"] = np.average(
    summary_df[category_cols_sample_integration[:4]],
    weights=category_col_weights_sample_integration[:4],
    axis=1,
)
 
# Reformat for plot: map internal model identifiers to display names. Only the
# "model" column is rewritten so no other column can be changed by accident.
summary_df["model"] = summary_df["model"].replace({"nichecompass_gatv2conv": "NicheCompass",
                                                   "nichecompass_gcnconv": "NicheCompass Light",
                                                   "staci": "STACI",
                                                   "graphst_paste": "GraphST",
                                                   "graphst": "GraphST (No Prior Alignment)",
                                                   "cellcharter": "CellCharter",
                                                   "banksy": "BANKSY"})

# Prepare metrics table plot: average every (dataset, model) over its runs and
# rank the models by their mean overall score.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                # `numeric_only` is GroupBy.mean's first positional parameter;
                # passing a column name there only worked because it is truthy.
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                [metric_cols_sample_integration + ["Overall Score"]]
                .reset_index())

# Long format: one row per (dataset, model, score type).
unrolled_df = pd.melt(aggregate_df,
                      id_vars=group_cols,
                      value_vars=metric_cols_sample_integration + ["Overall Score"],
                      var_name="score_type",
                      value_name="score")

# Create spatial indicator column
spatially_aware_models = {"NicheCompass", "NicheCompass Light", "GraphST",
                          "GraphST (No Prior Alignment)", "CellCharter", "BANKSY"}


def is_spatially_aware_model(row):
    """Return whether ``row["model"]`` is one of ``spatially_aware_models``."""
    return row["model"] in spatially_aware_models


unrolled_df["spatially_aware"] = unrolled_df["model"].isin(spatially_aware_models)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. The sort is stable so rows keep the
# order produced above within each dataset instead of an arbitrary one.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

In [None]:
### Supplementary fig. 17: sample integration runtimes ###
model_palette = {"NicheCompass": "#8C96C6",
                 "NicheCompass Light": "#42B6C7",
                 "STACI": "#FFD700",
                 "GraphST": "#D78FF8",
                 "GraphST (No Prior Alignment)": "#b5bd61",
                 "CellCharter": "#F46AA2",
                 "BANKSY": "#556B2F"}

# Mean run time per (dataset, model) over runs; divided by 60 because the
# y axis is labelled in minutes (assumes `run_time` is stored in seconds).
run_time_mean_df = summary_df.groupby(["dataset", "model"])[["run_time"]].mean().reset_index()
run_time_mean_df["run_time"] = run_time_mean_df["run_time"] / 60

# Share (%) of the full dataset contained in each (sub)sampled dataset.
dataset_shares = {"seqfish_mouse_organogenesis_imputed": 100,
                  "seqfish_mouse_organogenesis_imputed_subsample_50pct": 50,
                  "seqfish_mouse_organogenesis_imputed_subsample_25pct": 25,
                  "seqfish_mouse_organogenesis_imputed_subsample_10pct": 10,
                  "seqfish_mouse_organogenesis_imputed_subsample_5pct": 5,
                  "seqfish_mouse_organogenesis_imputed_subsample_1pct": 1}


def create_dataset_share_col(row):
    """Return the dataset share (%) of ``row["dataset"]``, or None if unknown."""
    return dataset_shares.get(row["dataset"])


run_time_mean_df["dataset_share"] = run_time_mean_df["dataset"].map(dataset_shares)

# Legend entries, by model display name, in the order they should appear.
legend_order = ["NicheCompass", "NicheCompass Light", "CellCharter", "BANKSY",
                "GraphST (No Prior Alignment)", "GraphST", "STACI"]

with sns.axes_style("ticks"):
    ax = sns.lineplot(data=run_time_mean_df,
                      x="dataset_share",
                      y="run_time",
                      hue="model",
                      marker="o",
                      palette=model_palette)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.title("seqFISH Mouse Organogenesis (Imputed)\n(57,536 Cells; 3,000 Genes)")
    plt.ylabel("Run Time (Minutes)")
    plt.xlabel("Dataset Size (%)")
    # Log-scaled y axis with ticks at 1 min, 10 min, 1 h, 3 h, 6 h, 12 h, 24 h.
    custom_y_ticks = [1, 10, 60, 180, 360, 720, 1440]
    plt.yscale("log")
    plt.yticks(custom_y_ticks, custom_y_ticks)

    legend = ax.get_legend()
    # `Legend.legendHandles` was renamed `legend_handles` in Matplotlib 3.7 and
    # removed in 3.9; use whichever attribute this Matplotlib provides.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    labels = [text.get_text() for text in legend.get_texts()]
    for handle in handles:
        handle.set_linewidth(4.0)  # thicker legend lines are easier to read
    # Reorder by label instead of by position so the result does not depend on
    # the order in which seaborn encountered the models; unexpected models are
    # appended at the end rather than dropped.
    handle_by_label = dict(zip(labels, handles))
    ordered_labels = ([label for label in legend_order if label in handle_by_label]
                      + [label for label in labels if label not in legend_order])
    ordered_handles = [handle_by_label[label] for label in ordered_labels]
    lgd = plt.legend(ordered_handles, ordered_labels, bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.grid(True)
    # Pass the legend as an extra artist so bbox_inches="tight" keeps it in the SVG.
    plt.savefig(benchmarking_folder_path + "/benchmarking_runtimes_seqfish_mouse_organogenesis_imputed.svg",
                bbox_inches="tight", bbox_extra_artists=[lgd])
    plt.show()

In [None]:
### Supplementary fig. 16b: sample integration metric averages ###
# Render the per-dataset metrics table and save it as SVG.
plot_metrics_table(
    df=unrolled_df,
    group_col="dataset",
    model_col="model",
    # Per-metric columns; the category_*_sample_integration variants can be
    # swapped in for a category-level table.
    metric_cols=metric_cols_sample_integration,
    metric_col_weights=metric_col_weights_sample_integration,
    # One word per line keeps the metric column headers narrow.
    metric_col_titles=[title.replace(" ", "\n") for title in metric_col_titles_sample_integration],
    model_col_width=3,
    metric_col_width=0.5,  # alternative: 0.8
    aggregate_col_width=0.8,
    plot_width=43.5,  # alternative: 32
    plot_height=6,
    show=True,
    save_dir=benchmarking_folder_path,
    save_name="benchmarking_metrics_seqfish_mouse_organogenesis_imputed.svg",
)

#### 2.1.4 All Datasets

In [None]:
# Define params for plot formatting.
# Figure widths (matplotlib figsize units, i.e. inches), named after the
# number of axis ticks each width is intended for.
fig_width_2_ticks = 5.1
fig_width_3_ticks = 5.5
fig_width_4_ticks = 5.85
fig_width_5_ticks = 6.2
fig_width_6_ticks = 6.6
fig_width_7_ticks = 7.0
fig_width_8_ticks = 7.4
fig_width_9_ticks = 7.8
fig_width_10_ticks = 8.2

fig_height = 6  # alternative: 5
fontsize = 22

In [None]:
# Load metrics
datasets = ["seqfish_mouse_organogenesis",
            "seqfish_mouse_organogenesis_imputed",
            "nanostring_cosmx_human_nsclc"]
models = ["nichecompass_gatv2conv",
          "nichecompass_gcnconv",
          "staci",
          "graphst_paste",
          "cellcharter",
          # "graphst",
          "banksy"
         ]

# Collect per-dataset frames and concatenate once (repeatedly concatenating
# onto a growing frame is quadratic).
summary_df_parts = []
for dataset in datasets:
    model_dfs = []
    for model in models:
        metrics_file_path = f"{benchmarking_folder_path}/{dataset}_{model}_metrics.csv"
        try:
            benchmark_df = pd.read_csv(metrics_file_path)
        except FileNotFoundError:
            print(f"Did not find file {metrics_file_path}. Continuing...")
            # Placeholder rows for runs 1-8 (all metrics NaN) keep the missing
            # model present in the downstream tables and plots.
            benchmark_df = pd.DataFrame({"dataset": [dataset] * 8,
                                         "model": [model] * 8,
                                         "run_number": [1, 2, 3, 4, 5, 6, 7, 8],
                                         "run_time": [np.nan] * 8})
        else:
            benchmark_df["model"] = model
        model_dfs.append(benchmark_df)
    dataset_df = pd.concat(model_dfs, ignore_index=True)

    # Min-max scale each metric to [0, 1] across all models and runs of this
    # dataset; NaNs (missing models) are ignored by min/max and stay NaN.
    for metric_col in metric_cols_sample_integration:
        if metric_col not in dataset_df:
            # No metrics file provided this metric (e.g. all files missing):
            # add it as NaN so the scoring code still finds the column.
            dataset_df[metric_col] = np.nan
        min_val = dataset_df[metric_col].min()
        max_val = dataset_df[metric_col].max()
        dataset_df[f"{metric_col}_scaled"] = (dataset_df[metric_col] - min_val) / (max_val - min_val)

    summary_df_parts.append(dataset_df)

summary_df = pd.concat(summary_df_parts, ignore_index=True)
            
# Missing PCR values are counted as 0 instead of propagating NaN into the
# batch-correction score.
summary_df['pcr'] = summary_df['pcr'].fillna(0)
summary_df['pcr_scaled'] = summary_df['pcr_scaled'].fillna(0)

# Split the scaled metrics into the four score categories, two consecutive
# metrics each: global spatial consistency, local spatial consistency,
# niche coherence and batch correction.
cat_0_scaled_metric_cols, cat_1_scaled_metric_cols, cat_2_scaled_metric_cols, cat_3_scaled_metric_cols = (
    [f"{name}_scaled" for name in metric_cols_sample_integration[first:first + 2]]
    for first in (0, 2, 4, 6))

# Category score = weighted mean of the category's two scaled metrics.
for cat_idx, scaled_cols in enumerate((cat_0_scaled_metric_cols, cat_1_scaled_metric_cols,
                                       cat_2_scaled_metric_cols, cat_3_scaled_metric_cols)):
    summary_df[category_cols_sample_integration[cat_idx]] = np.average(
        summary_df[scaled_cols],
        weights=metric_col_weights_sample_integration[2 * cat_idx:2 * cat_idx + 2],
        axis=1,
    )
del cat_idx, scaled_cols

# Overall score = weighted mean of the four category scores.
summary_df["Overall Score"] = np.average(
    summary_df[category_cols_sample_integration[:4]],
    weights=category_col_weights_sample_integration[:4],
    axis=1,
)
 
# Reformat for plot: map internal model identifiers to display names. Only the
# "model" column is rewritten so no other column can be changed by accident.
summary_df["model"] = summary_df["model"].replace({"nichecompass_gatv2conv": "NicheCompass",
                                                   "nichecompass_gcnconv": "NicheCompass Light",
                                                   "staci": "STACI",
                                                   "graphst_paste": "GraphST",
                                                   "graphst": "GraphST (No Prior Alignment)",
                                                   "cellcharter": "CellCharter",
                                                   "banksy": "BANKSY"})

# Prepare metrics table plot: average every (dataset, model) over its runs and
# rank the models by their mean overall score.
group_cols = ["dataset", "model"]
aggregate_df = (summary_df.groupby(group_cols)
                # `numeric_only` is GroupBy.mean's first positional parameter;
                # passing a column name there only worked because it is truthy.
                .mean(numeric_only=True)
                .sort_values("Overall Score", ascending=False)
                [metric_cols_sample_integration + ["Overall Score"]]
                .reset_index())

# Long format: one row per (dataset, model, score type).
unrolled_df = pd.melt(aggregate_df,
                      id_vars=group_cols,
                      value_vars=metric_cols_sample_integration + ["Overall Score"],
                      var_name="score_type",
                      value_name="score")

# Create spatial indicator column. The set uses the display names assigned
# above, so NicheCompass models are flagged correctly.
spatially_aware_models = {"NicheCompass", "NicheCompass Light", "GraphST",
                          "GraphST (No Prior Alignment)", "CellCharter", "BANKSY"}


def is_spatially_aware_model(row):
    """Return whether ``row["model"]`` is one of ``spatially_aware_models``."""
    return row["model"] in spatially_aware_models


unrolled_df["spatially_aware"] = unrolled_df["model"].isin(spatially_aware_models)
unrolled_df = unrolled_df[["dataset", "spatially_aware", "model", "score_type", "score"]]

# Order datasets as listed in `datasets`. The sort is stable so rows keep the
# order produced above within each dataset instead of an arbitrary one.
unrolled_df["dataset"] = pd.Categorical(unrolled_df["dataset"], categories=datasets, ordered=True)
unrolled_df = unrolled_df.sort_values(by="dataset", kind="stable")

# Human-readable dataset names for the summary_df-based plots
# (unrolled_df keeps the raw dataset identifiers).
summary_df["dataset"] = summary_df["dataset"].replace(
    {"seqfish_mouse_organogenesis": "seqFISH Mouse Organogenesis",
     "seqfish_mouse_organogenesis_imputed": "seqFISH Mouse Organogenesis (Imputed)",
     "nanostring_cosmx_human_nsclc": "nanoString CosMx Human NSCLC"})

In [None]:
metric = "cas"  # summary_df column shown on the x-axis

sns.set_style("whitegrid")

# One fixed colour per model display name; the boxplot cells all use this palette.
model_palette = {
    "NicheCompass": "#8C96C6",
    "NicheCompass Light": "#42B6C7",
    "STACI": "#FFD700",
    "GraphST": "#D78FF8",
    # "GraphST (No Prior Alignment)": "#b5bd61",  # currently not plotted
    "CellCharter": "#F46AA2",
    "BANKSY": "#556B2F",
}

plt.figure(figsize=(fig_width_8_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 0.8)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
metric = "mlami"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_7_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0.2, 0.9)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
# Spatial consistency: unweighted mean of the global and local spatial consistency scores.
global_sc_score = summary_df["Global Spatial Consistency Score"]
local_sc_score = summary_df["Local Spatial Consistency Score"]
summary_df["Spatial Consistency Score"] = (global_sc_score + local_sc_score) / 2

In [None]:
score_col = "Spatial Consistency Score"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_10_ticks * 0.85, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=score_col, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(score_col, fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 1.0)  # x range [0, 1]

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_spatial_consistency_score.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Spatial Consistency Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Spatial Consistency Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis (Imputed)"]
metrics_temp_df = temp_df.groupby("model")[["Spatial Consistency Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Spatial Consistency Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "nanoString CosMx Human NSCLC"]
metrics_temp_df = temp_df.groupby("model")[["Spatial Consistency Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Spatial Consistency Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
metric = "clisis"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_5_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0.6, 1.0)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
metric = "gcs"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_4_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0.7, 1.0)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
score_col = "Niche Coherence Score"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_9_ticks * 0.85, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=score_col, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(score_col, fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 1.0)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_niche_coherence_score.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Niche Coherence Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Niche Coherence Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis (Imputed)"]
metrics_temp_df = temp_df.groupby("model")[["Niche Coherence Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Niche Coherence Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "nanoString CosMx Human NSCLC"]
metrics_temp_df = temp_df.groupby("model")[["Niche Coherence Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Niche Coherence Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
metric = "cnmi"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_7_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 0.7)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
metric = "nasw"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_3_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0.4, 0.7)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()


In [None]:
score_col = "Batch Correction Score"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_9_ticks * 0.85, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=score_col, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(score_col, fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 1.0)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_batch_correction_score.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Batch Correction Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Batch Correction Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis (Imputed)"]
metrics_temp_df = temp_df.groupby("model")[["Batch Correction Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Batch Correction Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "nanoString CosMx Human NSCLC"]
metrics_temp_df = temp_df.groupby("model")[["Batch Correction Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Batch Correction Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
metric = "blisi"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_7_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 0.6)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()

In [None]:
metric = "pcr"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_10_ticks, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=metric, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(metric.upper(), fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0., 1.0)

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_{metric}_score.svg")
plt.show()


In [None]:
score_col = "Overall Score"  # summary_df column shown on the x-axis

plt.figure(figsize=(fig_width_7_ticks * 0.85, fig_height))

# Arguments common to both layers: one horizontal row per dataset, one hue per model.
layer_kwargs = dict(data=summary_df, x=score_col, y="dataset", hue="model", palette=model_palette)

# Box layer: semi-transparent boxes without caps or outlier markers, thicker whiskers.
ax = sns.boxplot(orient="h", showcaps=False, fliersize=0,
                 boxprops={'alpha': 0.6}, whiskerprops={'linewidth': 1.5},
                 **layer_kwargs)

# Dashed separators at y = 0.5, 1.5, ... between the dataset rows.
for row_idx in range(summary_df["dataset"].nunique()):
    ax.axhline(y=row_idx + 0.5, color='gray', linestyle='--', linewidth=2, zorder=0)

# Point layer: one dodged, jittered, semi-transparent dot per summary_df row.
sns.stripplot(dodge=True, alpha=0.6, size=5, jitter=0.2,
              edgecolor="black", linewidth=0.5,
              **layer_kwargs)

ax.grid(axis='x', linestyle='--', alpha=0.7)

# Both layers add legend entries; keep only the first len(model_palette) to avoid duplicates.
n_legend = len(model_palette)
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(legend_handles[:n_legend], legend_labels[:n_legend],
          loc='upper left', bbox_to_anchor=(1, 1))

# Wrap dataset tick labels: the first three break at every space, later ones
# only before the " H", " N" and " CosMx" substrings.
wrapped_labels = []
for tick_idx, tick in enumerate(ax.get_yticklabels()):
    text = tick.get_text()
    if tick_idx < 3:
        text = text.replace(' ', '\n')
    else:
        text = text.replace(' H', '\nH').replace(' N', '\nN').replace(' CosMx', '\nCosMx')
    wrapped_labels.append(text)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_yticklabels(wrapped_labels)
ax.set_ylabel(None)
ax.set_xlabel(score_col, fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.xticks(fontsize=fontsize)
ax.set_xlim(0.1, 0.8)  # x range [0.1, 0.8]

plt.tight_layout()
plt.savefig(f"{benchmarking_folder_path}/benchmarking_boxplot_overall_score.svg")
plt.show()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis"]
metrics_temp_df = temp_df.groupby("model")[["Overall Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Overall Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "seqFISH Mouse Organogenesis (Imputed)"]
metrics_temp_df = temp_df.groupby("model")[["Overall Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Overall Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
temp_df = summary_df[summary_df["dataset"] == "nanoString CosMx Human NSCLC"]
metrics_temp_df = temp_df.groupby("model")[["Overall Score"]].mean()
# Work on the score column as a Series: `.loc[...][0]` relies on deprecated positional
# Series indexing, and `np.max` on a DataFrame may return a Series or a scalar
# depending on the pandas version.
model_scores = metrics_temp_df["Overall Score"]
is_competitor = ~model_scores.index.isin(["NicheCompass", "NicheCompass Light"])
# Margin of NicheCompass over the best non-NicheCompass model (positive = NicheCompass ahead).
model_scores.loc["NicheCompass"] - model_scores[is_competitor].max()

In [None]:
### Supplementary fig. 17: gene scalability analysis ###
model_label = "gatv2conv_sample_integration_method_benchmarking"
run = "run1"

# (dataset, training timestamp) of each stored model whose size is recorded.
model_runs = [
    ("seqfish_mouse_organogenesis", "22082023_184821_1"),          # 256 edge batch size
    ("seqfish_mouse_organogenesis_imputed", "28082023_181323_1"),  # 2048 edge batch size
    ("nanostring_cosmx_human_nsclc", "21082023_190305_1"),         # 512 edge batch size
]

# Column-oriented accumulator; the key order fixes the CSV column order.
size_dict = {"dataset": [], "n_genes": [], "n_gps": [], "n_params": []}

for dataset, timestamp in model_runs:
    model_folder_path = f"{artifact_folder_path}/{dataset}/models/{model_label}/{timestamp}/{run}"
    model = NicheCompass.load(dir_path=model_folder_path,
                              adata=None,
                              adata_file_name=f"{dataset}_{model_label}.h5ad",
                              gp_names_key="nichecompass_gp_names")

    size_dict["dataset"].append(dataset)
    # Number of genes = rows of adata.var.
    size_dict["n_genes"].append(len(model.adata.var))
    # Number of GPs = width of the stored latent representation.
    size_dict["n_gps"].append(model.adata.obsm["nichecompass_latent"].shape[1])
    # Total element count over all parameter tensors of the underlying module.
    size_dict["n_params"].append(sum(param.numel() for param in model.model.parameters()))

size_df = pd.DataFrame(size_dict)
size_df.to_csv(f"{benchmarking_folder_path}/model_sizes.csv", index=False)
size_df.head()

In [None]:
# Render the saved model-size table (model_sizes.csv) as an SVG figure.
size_df = pd.read_csv(f"{benchmarking_folder_path}/model_sizes.csv")

size_df.columns = ["Dataset", "Number of Genes", "Number of GPs", "Number of Params"]

# Assign the result back rather than calling `replace(..., inplace=True)` on
# the column: that chained in-place call raises a FutureWarning in pandas 2.x
# and does not update `size_df` at all under Copy-on-Write.
size_df["Dataset"] = size_df["Dataset"].replace({
    "seqfish_mouse_organogenesis": "seqFISH Mouse Organogenesis",
    "seqfish_mouse_organogenesis_imputed": "seqFISH Mouse Organogenesis (Imputed)",
    "nanostring_cosmx_human_nsclc": "Nanostring CosMx Human NSCLC"})

fig, ax = plt.subplots(figsize=(10, 2))
ax.axis('off')
ax.axis('tight')
# Header cells: first column ("Dataset") light blue, all other headers dark grey.
table = ax.table(cellText=size_df.values,
                 colLabels=size_df.columns,
                 cellLoc='center',
                 loc='center',
                 colColours=["lightblue"] + ["darkgrey"] * (len(size_df.columns) - 1))
table.auto_set_font_size(False)
table.set_fontsize(14)
table.scale(2.2, 2.)
plt.savefig(f"{benchmarking_folder_path}/model_sizes.svg", bbox_inches="tight", dpi=300)
plt.show()

## 3. Compute Metrics on Analysis Datasets

### 3.1 Xenium Human Breast Cancer

In [None]:
# Analysis dataset and the trained reference model to evaluate.
dataset = "xenium_human_breast_cancer"
model_label = "reference"
timestamp = "26102023_153021"
model_folder_path = f"{artifact_folder_path}/{dataset}/models/{model_label}/{timestamp}"

# Space-separated names of the benchmarking metrics to compute.
metrics = "gcs mlami cas clisis nasw cnmi cari casw clisi blisi pcr"

# Column keys in adata.obs (cell type, batch) and adata.obsm (spatial, latent).
cell_type_key, batch_key = "cell_type", "batch"
spatial_key, latent_key = "spatial", "nichecompass_latent"

In [None]:
model = NicheCompass.load(dir_path=model_folder_path,
                          adata=None,
                          adata_file_name=f"{dataset}_{model_label}.h5ad",
                          gp_names_key="nichecompass_gp_names")

if dataset == "xenium_human_breast_cancer":
    model.adata.obs["batch"] = model.adata.obs["batch"].replace({"sample1": "Replicate 1", "sample2": "Replicate 2"})

    # Coarse cell types: every fine-grained cell state listed in trans_from[i]
    # is relabeled to trans_to[i]; states not listed keep their own label.
    trans_from=[['Epi_ABCC11+', 'Epi_FOXA1+', 'Epi_AGR3+', 'Epi_CENPF+', 'mgEpi_KRT14+', 'Epi_KRT14+'],['EC_CLEC14A+', 'EC_CAVIN2+'],['adipo_FB', 'GJB2+iKC-FB'],['EMT-Epi1_CEACAM6+', 'EMT-Epi2_CEACAM6+', 'EMT-Epi_SERPINA3+', 'EMT-Epi_KRT23+'],['DERL3+B', 'BANK1+B', 'B'],['eff_CD8+T1', 'eff_CD8+T2',],['tcm_CD4+T', 'CD161+FOXP3+T'],['NK/T'],['ADIPOQ+Mast'],['M2MØ', 'MMP12+miMØ'], ['DC1']]
    trans_to = ['Epithelial', 'Endothelial', 'Fibroblast', 'EMT', 'B_cells', 'CD8+T', 'CD4+T', 'NK/T', 'Mast', 'MØ', 'DC']

    # Flatten the groups into a single cell state -> cell type lookup.
    cell_state_to_type = {
        cell_state: cell_type
        for cell_states, cell_type in zip(trans_from, trans_to)
        for cell_state in cell_states}

    # Build the whole column in one assignment. The previous per-state
    # `obs['cell_type'][mask] = ...` was chained assignment: it warns in
    # pandas 2.x and silently does nothing under Copy-on-Write.
    model.adata.obs['cell_type'] = (
        model.adata.obs['cell_states'].astype(str).replace(cell_state_to_type))

    model.adata.obs[cell_type_key] = model.adata.obs[cell_type_key].replace("MØ", "Mɸ")
    model.adata.obs["cell_states"] = model.adata.obs["cell_states"].replace("M2MØ", "M2Mɸ").replace("MMP12+miMØ", "MMP12+miMɸ")

model.adata.obs[cell_type_key] = model.adata.obs[cell_type_key].astype('category')

In [None]:
# `metrics` may be a space-separated string (as in the parameters cell).
# Split it into a list so membership tests, both here and inside
# `compute_benchmarking_metrics`, match whole metric names rather than
# substrings (e.g. "cas" is a substring of "casw", "clisi" of "clisis").
metrics_list = metrics.split() if isinstance(metrics, str) else list(metrics)

if "pcr" in metrics_list:
    # PCA of model.adata is passed as `pcr_X_pre` for the "pcr" metric
    # (presumably the pre-integration reference — confirm in nichecompass).
    sc.tl.pca(model.adata, use_highly_variable=False)
    pcr_X_pre = model.adata.obsm["X_pca"]
else:
    pcr_X_pre = None

benchmark_dict = compute_benchmarking_metrics(
        adata=model.adata,
        metrics=metrics_list,
        cell_type_key=cell_type_key,
        batch_key=batch_key,
        spatial_key=spatial_key,
        latent_key=latent_key,
        pcr_X_pre=pcr_X_pre,
        n_jobs=1,
        seed=0,
        mlflow_experiment_id=None)

# One-row table: one column per computed metric.
benchmark_df = pd.DataFrame(benchmark_dict.values(), index=benchmark_dict.keys()).T
benchmark_df.to_csv(f"{benchmarking_folder_path}/{dataset}_metrics.csv", index=False)