# Sample Integration Method Benchmarking

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 20.01.2023
- **Date of Last Modification:** 22.08.2023

## 1. Setup

### 1.1 Import Libraries

In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
import sys
sys.path.append("../../utils")

In [15]:
import argparse
import os
import pickle
import random
import shutil
import warnings
from copy import deepcopy
from datetime import datetime

import anndata as ad
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import pandas as pd
import plottable
import scanpy as sc
import scipy.sparse as sp
import scvi
import seaborn as sns
import squidpy as sq
import torch
from GraphST import GraphST
from matplotlib import gridspec
from matplotlib.pyplot import rc_context
from plottable import ColumnDefinition, Table
from plottable.cmap import normed_cmap
from plottable.formatters import tickcross
from plottable.plots import bar
from sklearn.decomposition import KernelPCA

from nichecompass.benchmarking import compute_clisis, compute_cas
from nichecompass.models import NicheCompass

from benchmarking_utils import *

### 1.2 Define Parameters

In [8]:
# Metric definitions for the sample integration benchmark. The three lists
# below are index-aligned: entry i of the weights and titles belongs to
# entry i of the metric columns.
metric_cols_sample_integration = [
    "gcs", "mlami", "cas", "clisis", # spatial conservation
    "cari", "cnmi", "casw", "clisi", # biological conservation
    "nasw", # cluster separability
    "basw", "bgc", "blisi" # batch correction
]
# Weights are defined separately for each category and sum to 1 within a
# category (later multiplied with category_col_weights)
metric_col_weights_sample_integration = [
    (1/6), (1/6), (1/3), (1/3), # spatial conservation
    (1/4), (1/4), (1/4), (1/4), # biological conservation
    1.0, # cluster separability
    (1/3), (1/3), (1/3) # batch correction (last entry must be a float, not the tuple (1, 3))
]
metric_col_titles_sample_integration = [
    "Graph Connectivity Similarity",
    "Maximum Leiden Adjusted Mutual Info",
    "Cell Type Affinity Similarity",
    "Cell Type Local Inverse Simpson's Index Similarity",
    "Cell Type Adjusted Rand Index",
    "Cell Type Normalized Mutual Info",
    "Cell Type Average Silhouette Width",
    "Cell Type Local Inverse Simpson's Index",
    "Cluster Average Silhouette Width",
    "Batch Average Silhouette Width",
    "Batch Graph Connectivity",
    "Batch Local Inverse Simpson's Index"
]

# The single-sample setting drops the three trailing batch correction metrics
metric_cols_single_sample = metric_cols_sample_integration[:-3]
metric_col_weights_single_sample = metric_col_weights_sample_integration[:-3]
metric_col_titles_single_sample = metric_col_titles_sample_integration[:-3]

# Aggregate score categories; every category carries the same weight
category_cols_sample_integration = [
    "Spatial Conservation Score",
    "Biological Conservation Score",
    "Cluster Separability Score",
    "Batch Correction Score",
]
category_col_weights_sample_integration = [1] * 4
# Display titles: category name followed by its share of the overall score
category_col_titles_sample_integration = [
    f"{category_col} (25%)" for category_col in category_cols_sample_integration
]

# The single-sample setting has no batch correction category (last entry)
category_col_weights_single_sample = category_col_weights_sample_integration[:-1]
category_cols_single_sample = category_cols_sample_integration[:-1]
category_col_titles_single_sample = [
    f"{category_col} (33%)" for category_col in category_cols_single_sample
]

### 1.3 Run Notebook Setup

In [9]:
# Notebook-wide plotting defaults: seaborn "whitegrid" style with the axes
# grid switched off, and 6x6 figures for scanpy
sns.set_style("whitegrid", {"axes.grid": False})
sc.set_figure_params(figsize=(6, 6))

In [10]:
# Silence FutureWarning and UserWarning for the rest of the notebook
for ignored_warning_category in (FutureWarning, UserWarning):
    warnings.simplefilter(action="ignore", category=ignored_warning_category)

In [11]:
# Capture the notebook execution time once; the formatted string
# (day-month-year_hour-minute-second) is used to timestamp saved artifacts
now = datetime.now()
current_timestamp = f"{now:%d%m%Y_%H%M%S}"

In [12]:
# Set mlflow tracking server (run it on the defined port)
# The tracking server has to be started separately and must listen on the
# port given in the URI below.
mlflow.set_tracking_uri("http://localhost:8889")

### 1.4 Configure Paths and Create Directories

In [13]:
# Folder locations, relative to the directory this notebook is run from
artifact_folder_path = "../../artifacts"
data_folder_path = "../../datasets/srt_data/gold"

## 2. Method Benchmarking

- Run all model notebooks in the `notebooks/sample_integration_method_benchmarking` directory before continuing.

### 2.1 Retrieve NicheCompass Runs

#### 2.1.1 seqFISH Mouse Organogenesis

In [17]:
# Copy the NicheCompass GCN encoder result files of all seqFISH dataset
# variants from their timestamped run folders into the benchmarking folder
task = "sample_integration_method_benchmarking"
conv_layer_encoder = "gcnconv"

# Full dataset first, followed by its subsamples in decreasing size
datasets = ["seqfish_mouse_organogenesis"] + [
    f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct"
    for subsample_pct in (50, 25, 10, 5, 1)
]
# Run timestamps, index-aligned with `datasets`
# NOTE(review): the entries for the 50pct and 25pct subsamples are identical;
# confirm that this is intended
timestamps = [
    "22082023_135318_1",
    "22082023_143133_1",
    "22082023_143133_1",
    "22082023_150526_1",
    "22082023_151248_1",
    "22082023_152633_1",
]

# Encoder and task together name both the results subfolder and the file suffix
run_name = f"{conv_layer_encoder}_{task}"
for dataset, timestamp in zip(datasets, timestamps):
    source_path = (
        f"{artifact_folder_path}/{dataset}/results/{run_name}/"
        f"{timestamp}/{dataset}_{run_name}.h5ad"
    )
    destination_path = (
        f"{artifact_folder_path}/{task}/"
        f"{dataset}_nichecompass_{conv_layer_encoder}.h5ad"
    )
    shutil.copy(source_path, destination_path)

In [None]:
# Copy the NicheCompass GATv2 encoder result files of all seqFISH dataset
# variants from their timestamped run folders into the benchmarking folder
# NOTE(review): the task is "single_sample_method_benchmarking" although this
# notebook covers sample integration; confirm that this is intended
task = "single_sample_method_benchmarking"
conv_layer_encoder = "gatv2conv"

# Full dataset first, followed by its subsamples in decreasing size
datasets = ["seqfish_mouse_organogenesis"] + [
    f"seqfish_mouse_organogenesis_subsample_{subsample_pct}pct"
    for subsample_pct in (50, 25, 10, 5, 1)
]
# Run timestamps, index-aligned with `datasets`; still empty placeholders that
# have to be filled in before the copy below can find the source files
timestamps = [""] * 6

# Encoder and task together name both the results subfolder and the file suffix
run_name = f"{conv_layer_encoder}_{task}"
for dataset, timestamp in zip(datasets, timestamps):
    source_path = (
        f"{artifact_folder_path}/{dataset}/results/{run_name}/"
        f"{timestamp}/{dataset}_{run_name}.h5ad"
    )
    destination_path = (
        f"{artifact_folder_path}/{task}/"
        f"{dataset}_nichecompass_{conv_layer_encoder}.h5ad"
    )
    shutil.copy(source_path, destination_path)

#### 2.1.2 nanoString CosMx SMI Human Non-Small-Cell Lung Cancer (NSCLC)

In [None]:
# Copy the NicheCompass GCN encoder result files of the nanoString CosMx NSCLC
# dataset variants from their timestamped run folders into the benchmarking
# folder
task = "sample_integration_method_benchmarking"
conv_layer_encoder = "gcnconv"

# Full dataset first, followed by its subsamples in decreasing size. The 25pct
# subsample (run "22082023_084950_1") is currently left out.
datasets = ["nanostring_cosmx_human_nsclc"] + [
    f"nanostring_cosmx_human_nsclc_subsample_{subsample_pct}pct"
    for subsample_pct in (50, 10, 5, 1)
]
# Run timestamps, index-aligned with `datasets`
timestamps = [
    "21082023_162330_1",
    "21082023_205148_1",
    "22082023_093531_1",
    "22082023_093728_1",
    "22082023_093729_1",
]

# Encoder and task together name both the results subfolder and the file suffix
run_name = f"{conv_layer_encoder}_{task}"
for dataset, timestamp in zip(datasets, timestamps):
    source_path = (
        f"{artifact_folder_path}/{dataset}/results/{run_name}/"
        f"{timestamp}/{dataset}_{run_name}.h5ad"
    )
    destination_path = (
        f"{artifact_folder_path}/{task}/"
        f"{dataset}_nichecompass_{conv_layer_encoder}.h5ad"
    )
    shutil.copy(source_path, destination_path)

In [None]:
# Add missing runs for 'nanostring_cosmx_human_nsclc'
# The benchmarking file of the dataset is completed with the runs that are
# stored in separate, timestamped result files.
import gc  # imported here because the notebook's import cell does not import it

dataset = "nanostring_cosmx_human_nsclc"
conv_layer_encoder = "gcnconv"

# (timestamp of the result file, run numbers to take over from that file)
missing_runs = [
    ("21082023_162250_1", [3, 4]),
    ("22082023_010227_1", [5]),
    ("21082023_163844_1", [6]),
    ("21082023_162134_1", [7]),
    ("21082023_155619_1", [8]),
]

# NOTE(review): `benchmarking_folder_path` is not defined in the cells visible
# above; confirm that it is set (e.g. by `benchmarking_utils`) before running.
# NOTE(review): the runs are read from the "single_sample_method_benchmarking"
# results although this notebook covers sample integration; confirm.
adata1 = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_nichecompass_{conv_layer_encoder}.h5ad")

for timestamp, run_numbers in missing_runs:
    adata2 = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/results/{conv_layer_encoder}_single_sample_method_benchmarking/{timestamp}/{dataset}_{conv_layer_encoder}_single_sample_method_benchmarking.h5ad")

    for run_number in run_numbers:
        run_latent_key = f"nichecompass_latent_run{run_number}"
        run_duration_key = f"nichecompass_model_training_duration_run{run_number}"

        # Take over everything stored under the run-specific keys
        adata1.uns[f"{run_latent_key}_umap"] = adata2.uns[f"{run_latent_key}_umap"]
        adata1.uns[run_duration_key] = adata2.uns[run_duration_key]
        adata1.obsm[run_latent_key] = adata2.obsm[run_latent_key]
        for graph_suffix in ("connectivities", "distances"):
            adata1.obsp[f"{run_latent_key}_{graph_suffix}"] = adata2.obsp[f"{run_latent_key}_{graph_suffix}"]

    # Release the source object before the next result file is loaded
    del adata2
    gc.collect()

adata1.write(f"{benchmarking_folder_path}/{dataset}_nichecompass_{conv_layer_encoder}.h5ad")

In [None]:
# Add missing runs for 'nanostring_cosmx_human_nsclc_subsample_50pct'
# The benchmarking file of the dataset is completed with the runs that are
# stored in separate, timestamped result files.
import gc  # imported here because the notebook's import cell does not import it

dataset = "nanostring_cosmx_human_nsclc_subsample_50pct"
conv_layer_encoder = "gcnconv"

# (timestamp of the result file, run numbers to take over from that file)
missing_runs = [
    ("21082023_204746_1", [5, 6]),
    ("22082023_084437_1", [7]),
    ("22082023_084415_1", [8]),
]

# NOTE(review): `benchmarking_folder_path` is not defined in the cells visible
# above; confirm that it is set (e.g. by `benchmarking_utils`) before running.
# NOTE(review): the runs are read from the "single_sample_method_benchmarking"
# results although this notebook covers sample integration; confirm.
adata1 = sc.read_h5ad(f"{benchmarking_folder_path}/{dataset}_nichecompass_{conv_layer_encoder}.h5ad")

for timestamp, run_numbers in missing_runs:
    adata2 = sc.read_h5ad(f"{artifact_folder_path}/{dataset}/results/{conv_layer_encoder}_single_sample_method_benchmarking/{timestamp}/{dataset}_{conv_layer_encoder}_single_sample_method_benchmarking.h5ad")

    for run_number in run_numbers:
        run_latent_key = f"nichecompass_latent_run{run_number}"
        run_duration_key = f"nichecompass_model_training_duration_run{run_number}"

        # Take over everything stored under the run-specific keys
        adata1.uns[f"{run_latent_key}_umap"] = adata2.uns[f"{run_latent_key}_umap"]
        adata1.uns[run_duration_key] = adata2.uns[run_duration_key]
        adata1.obsm[run_latent_key] = adata2.obsm[run_latent_key]
        for graph_suffix in ("connectivities", "distances"):
            adata1.obsp[f"{run_latent_key}_{graph_suffix}"] = adata2.obsp[f"{run_latent_key}_{graph_suffix}"]

    # Release the source object before the next result file is loaded
    del adata2
    gc.collect()

adata1.write(f"{benchmarking_folder_path}/{dataset}_nichecompass_{conv_layer_encoder}.h5ad")

## 3. Sample Integration Evaluation

### 3.1 One-Shot Batch Integration

#### 3.1.1 Latent Space Comparison Visualization

In [None]:
# Training run whose latent spaces are compared
run_number = 5

# Latent space comparison for the seqFISH mouse organogenesis dataset; the
# figure is saved. NicheCompass is currently left out of `included_models`.
compute_latent_space_comparison(
    dataset="seqfish_mouse_organogenesis",
    run_number=run_number,
    srt_data_results_folder_path=f"{srt_data_results_folder_path}/sample_integration_method_benchmarking",
    cell_type_colors=seqfish_mouse_organogenesis_cell_type_colors,
    dataset_title_string="seqFISH Mouse Organogenesis",
    cell_type_key="cell_type",
    condition_key="batch",
    figure_folder_path=figure_folder_path,
    cell_type_groups=None,
    spot_size=0.03,
    included_models=["GraphST", "scVI"],
    save_fig=True,
)

In [None]:
# Training run whose latent spaces are compared
run_number = 5

# Latent space comparison for the STARmap PLUS mouse CNS dataset; the figure
# is saved. Title and cell type palette are the STARmap PLUS ones so that
# they match the `dataset` argument.
# NOTE(review): `starmap_pluse_mouse_cns_cell_type_colors` uses the spelling
# found in the other cells of this notebook; confirm that it is the name
# that is actually defined.
compute_latent_space_comparison(dataset="starmap_plus_mouse_cns",
                                run_number=run_number,
                                srt_data_results_folder_path=f"{srt_data_results_folder_path}/sample_integration_method_benchmarking",
                                cell_type_colors=starmap_pluse_mouse_cns_cell_type_colors,
                                dataset_title_string="STARmap PLUS Mouse CNS",
                                cell_type_key="cell_type",
                                condition_key="batch",
                                figure_folder_path=figure_folder_path,
                                cell_type_groups=None,
                                spot_size=0.03,
                                included_models=[# "NicheCompass",
                                                 "GraphST",
                                                 "scVI"],
                                save_fig=True)

##### 4.1.3.3 Compute Batch Integration Metrics

In [None]:
# Load the GraphST sample integration benchmarking file of the seqFISH mouse
# organogenesis dataset
# NOTE(review): the path is hard-coded here while other cells build it from
# `srt_data_results_folder_path`; confirm both point to the same folder
adata = sc.read_h5ad("../../datasets/srt_data/results/sample_integration_method_benchmarking/seqfish_mouse_organogenesis_graphst_sample_integration_method_benchmarking.h5ad")

In [None]:
# Display the loaded object as cell output for inspection
adata

In [None]:
# Batch integration metrics for the seqFISH mouse organogenesis dataset.
# NicheCompass is currently left out of `included_models`.
# NOTE(review): unlike the latent space comparison cells, the results folder
# is passed without the "/sample_integration_method_benchmarking" suffix;
# confirm that this is intended
compute_batch_integration_metrics(
    dataset="seqfish_mouse_organogenesis",
    condition_key="batch",
    cell_type_key="cell_type",
    srt_data_results_folder_path=srt_data_results_folder_path,
    metric_artifacts_folder_path=metric_artifacts_folder_path,
    included_models=["GraphST", "scVI"],
)

##### 4.1.3.4 Visualize Batch Integration Results

In [None]:
# Collect the stored batch integration metrics of all models and datasets and
# bring them into the long format consumed by the plotting function.
datasets = ["starmap_plus_mouse_cns",]
timestamps = ["27032023_184359"]

dataset_dfs = []
for dataset, timestamp in zip(datasets, timestamps):
    # NOTE(review): this path starts with "../artifacts" while
    # `artifact_folder_path` is "../../artifacts"; confirm which one is right
    dataset_metric_artifacts_folder_path = f"../artifacts/{dataset}/metrics/{timestamp}"

    metrics_dict_list = []
    for model in ["NicheCompass", "GraphST", "scVI"]:
        # Read complete benchmarking data from disk
        with open(f"{dataset_metric_artifacts_folder_path}/metrics_{model.lower()}_oneshot_integrated.pickle", "rb") as f:
            metrics_dict = pickle.load(f)
        metrics_dict["model_name"] = model.lower()
        metrics_dict_list.append(metrics_dict)

    dataset_df = pd.DataFrame(metrics_dict_list)
    dataset_df["dataset"] = dataset
    # Collect inside the loop so that every dataset ends up in `df`, not only
    # the last one
    dataset_dfs.append(dataset_df)

df = pd.concat(dataset_dfs)

columns = ["cas",
           "clisis",
           "asw",
           "ilisi",
           ]

rows = ["nichecompass",
        "graphst",
        "scvi"]

# Long format: one row per (model, dataset, metric)
unrolled_df = pd.melt(df,
                      id_vars=["model_name", "dataset"],
                      value_vars=columns,
                      var_name="score_type",
                      value_name="score")

# Compute metric means over all runs
mean_df = unrolled_df.groupby(["model_name", "dataset", "score_type"]).mean()
mean_df.reset_index(inplace=True)

# Reformat for plot
mean_df.replace({"nichecompass": "NicheCompass",
                 "graphst": "GraphST",
                 "scvi": "scVI"}, inplace=True)

# Sort for right order of columns in plottable
mean_df["score_type"] = pd.Categorical(mean_df["score_type"], columns)
mean_df.sort_values(["model_name", "score_type"], inplace=True)

# Create spatial indicator column
def is_spatially_aware_model(row):
    """Return True if the model named in ``row["model_name"]`` is spatially aware."""
    return row["model_name"] in ["NicheCompass", "GraphST"]
mean_df["spatially_aware"] = mean_df.apply(is_spatially_aware_model, axis=1)

In [None]:
# Show the aggregated batch integration results; no save directory is passed
plot_batch_integration_results(
    mean_df,
    show=True,
    save_dir=None,
    save_name="batch_integration_results.svg",
)

##### 4.1.3.5 Scalability

In [None]:
# Build a table with the mean training runtime per dataset, model and
# dataset size from the stored benchmarking result files.
datasets = ["seqfish_mouse_organogenesis", "starmap_plus_mouse_cns"]
models = ["NicheCompass (Attention Aggregator)", "NicheCompass (Norm Aggregator)", "GraphST", "scVI"]

# Display name -> (model part of the file name, prefix of the runtime keys in
# `adata.uns`); models not listed here use their display name for both
nichecompass_variants = {
    "NicheCompass (Attention Aggregator)": ("nichecompass_one-hop-attention", "nichecompass"),
    "NicheCompass (Norm Aggregator)": ("nichecompass_one-hop-norm", "nichecompass"),
}

run_time_dict = {"Dataset": [],
                 "Model": [],
                 "Dataset Size (%)": [],
                 "Mean Runtime (s)": []}

for dataset in datasets:
    for model in models:
        # Does not depend on the subsample size, so resolve it once per model
        model_subtype, model_str = nichecompass_variants.get(model, (model, model))

        for subsample_pct in [1, 5, 10, 25, 50, 100]:
            # The full dataset carries no subsample suffix in its file name
            subsample_str = "" if subsample_pct == 100 else f"_subsample_{subsample_pct}pct"
            file_name = f"{dataset}{subsample_str}_{model_subtype.lower()}_sample_integration_method_benchmarking.h5ad"

            try:
                adata = sc.read_h5ad(f"{srt_data_results_folder_path}/{file_name}")
            except Exception as error:
                # Best effort: skip results that cannot be loaded, but report
                # why and do not swallow KeyboardInterrupt / SystemExit
                print(f"Did not load {file_name} ({type(error).__name__}: {error})")
                continue

            print(f"Loaded {file_name}")

            # Runs 1 to 10 are read; a missing run raises a KeyError
            run_times = [
                adata.uns[f"{model_str.lower()}_model_training_duration_run{run_number}"]
                for run_number in range(1, 11)
            ]
            run_time_dict["Dataset"].append(dataset)
            run_time_dict["Model"].append(model)
            run_time_dict["Dataset Size (%)"].append(subsample_pct)
            run_time_dict["Mean Runtime (s)"].append(np.mean(run_times))

run_time_df = pd.DataFrame(run_time_dict)
display(run_time_df)

In [None]:
# One runtime panel per dataset, side by side with a shared y-axis
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

# (dataset, panel title, extra lineplot arguments); only the first panel
# keeps its legend
# NOTE(review): `hue` is a column of the full frame while `data` is filtered
# per dataset; `hue="Model"` may be the clearer spelling (confirm)
runtime_panels = [
    ("seqfish_mouse_organogenesis", "seqFISH Mouse Organogenesis", {}),
    ("starmap_plus_mouse_cns", "STARmap PLUS Mouse CNS", {"legend": False}),
]
for panel_ax, (panel_dataset, panel_title, panel_kwargs) in zip(axs, runtime_panels):
    sns.lineplot(x="Dataset Size (%)",
                 y="Mean Runtime (s)",
                 hue=run_time_df["Model"],
                 data=run_time_df[run_time_df["Dataset"] == panel_dataset],
                 marker="o",
                 ax=panel_ax,
                 **panel_kwargs)
    panel_ax.set_title(panel_title)

# The legend of the first panel is shown without a title
leg = axs[0].legend()
leg.set_title("")

# Thin light grey grid on both panels
for ax in axs:
    ax.grid(True, linewidth=0.2, color="lightgrey")

plt.subplots_adjust(wspace=0.05)

plt.savefig(f"{figure_folder_path}/method_comparison_scalability.svg",
            bbox_inches="tight")
plt.show()

### 4.2 Query-to-Reference Mapping

#### 3.3.1 Building the Reference

#### 3.3.2 Mapping the Query

##### 3.3.2.1 Initialize, Train & Save Model

In [None]:
# Marker size is scaled inversely with the number of observations
umap_point_size = 240000/len(adata_reference)

# UMAP colored by batch
fig = sc.pl.umap(
    adata_reference,
    color=[condition_key],
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(7, 7)
plt.title(f"One-Shot Integration: {model_name} Latent Batch Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_batches_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)

# UMAP colored by cell type, using the STARmap PLUS cell type palette
fig = sc.pl.umap(
    adata_reference,
    color=[cell_type_key],
    palette=starmap_pluse_mouse_cns_cell_type_colors,
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(7, 7)
plt.title(f"One-Shot Integration: {model_name} Latent Cell Type Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_cell_types_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)

##### 4.1.1.2 GraphST

In [None]:
# Marker size is scaled inversely with the number of observations
umap_point_size = 240000/len(adata_reference)

# UMAP colored by batch
fig = sc.pl.umap(
    adata_reference,
    color=[condition_key],
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(7, 7)
plt.title(f"One-Shot Integration: {model_name} Latent Batch Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_batches_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)

# UMAP colored by cell type, using the STARmap PLUS cell type palette
fig = sc.pl.umap(
    adata_reference,
    color=[cell_type_key],
    palette=starmap_pluse_mouse_cns_cell_type_colors,
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(7, 7)
plt.title(f"One-Shot Integration: {model_name} Latent Cell Type Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_cell_types_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)

#### 4.1.2 NicheCompass

In [None]:
model_name = "NicheCompass"

# Restore the trained model from the "reference" subfolder of the model
# artifacts; the AnnData object comes from the file stored alongside it
model = NicheCompass.load(
    dir_path=f"{model_artifacts_folder_path}/reference",
    adata=None,
    adata_file_name=f"{dataset}_reference.h5ad",
    gp_names_key="nichecompass_gp_names",
)
adata_reference = model.adata

# Neighbor graph and UMAP are computed on the NicheCompass latent representation
latent_rep_key = f"{model_name.lower()}_{latent_key}"
sc.pp.neighbors(adata_reference, use_rep=latent_rep_key)
sc.tl.umap(adata_reference)

# Write the integrated AnnData object to disk
adata_reference.write(f"{srt_data_gold_folder_path}/{dataset}_{model_name.lower()}_oneshot_integrated.h5ad")

In [None]:
# Marker size is scaled inversely with the number of observations
umap_point_size = 240000/len(adata_reference)

# UMAP colored by batch
fig = sc.pl.umap(
    adata_reference,
    color=[condition_key],
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(10, 10)
plt.title(f"One-Shot Integration: {model_name} Latent Batch Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_batches_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)

# UMAP colored by cell type, using the STARmap PLUS cell type palette
fig = sc.pl.umap(
    adata_reference,
    color=[cell_type_key],
    palette=starmap_pluse_mouse_cns_cell_type_colors,
    legend_fontsize=12,
    size=umap_point_size,
    return_fig=True,
)
fig.set_size_inches(10, 10)
plt.title(f"One-Shot Integration: {model_name} Latent Cell Type Annotations", size=20, pad=15)
fig.savefig(
    f"{figure_folder_path}/latent_cell_types_{model_name.lower()}_oneshot_integrated.svg",
    bbox_inches="tight",
    format="svg",
)