# Method Benchmarking

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>)
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 06.01.2023
- **Date of Last Modification:** 24.02.2023

## 1. Setup

### 1.1 Import Libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../../../autotalker")

In [3]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import scanpy as sc
import scib
import seaborn as sns

from autotalker.benchmarking import compute_benchmarking_metrics
from autotalker.utils import (add_gps_from_gp_dict_to_adata,
                              extract_gp_dict_from_mebocost_es_interactions,
                              extract_gp_dict_from_nichenet_ligand_target_mx,
                              extract_gp_dict_from_omnipath_lr_interactions,
                              filter_and_combine_gp_dict_gps)

  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  return UNKNOWN_SERVER_VERSION
  warn(


### 1.2 Define Parameters

In [4]:
# Default dataset identifier; interpolated into the figure folder path and the
# summary artifact folder path further down in this notebook
dataset = "seqfish_mouse_organogenesis_embryo2"
# Cell type annotation key of the default dataset
# (presumably an `adata.obs` column -- confirm against the dataset)
cell_type_key = "celltype_mapped_refined"
# Passed as `spatial_key` to `compute_benchmarking_metrics`
spatial_key = "spatial"

In [5]:
# Mapping from cell type label to hex color; passed as `palette` to the scanpy
# plots in `compute_latent_space_comparison`. Several labels deliberately share
# a color (e.g. "Def. endoderm" / "Definitive endoderm", "Gut" / "Gut tube").
# NOTE(review): the labels look specific to the mouse organogenesis data --
# confirm that the palette also covers the cell types of the other datasets.
cell_type_colors = {"Epiblast" : "#635547",
                    "Primitive Streak" : "#DABE99",
                    "Caudal epiblast" : "#9e6762",
                    "PGC" : "#FACB12",
                    "Anterior Primitive Streak" : "#c19f70",
                    "Notochord" : "#0F4A9C",
                    "Def. endoderm" : "#F397C0",
                    "Definitive endoderm" : "#F397C0",
                    "Gut" : "#EF5A9D",
                    "Gut tube" : "#EF5A9D",
                    "Nascent mesoderm" : "#C594BF",
                    "Mixed mesoderm" : "#DFCDE4",
                    "Intermediate mesoderm" : "#139992",
                    "Caudal Mesoderm" : "#3F84AA",
                    "Paraxial mesoderm" : "#8DB5CE",
                    "Somitic mesoderm" : "#005579",
                    "Pharyngeal mesoderm" : "#C9EBFB",
                    "Splanchnic mesoderm" : "#C9EBFB",
                    "Cardiomyocytes" : "#B51D8D",
                    "Allantois" : "#532C8A",
                    "ExE mesoderm" : "#8870ad",
                    "Lateral plate mesoderm" : "#8870ad",
                    "Mesenchyme" : "#cc7818",
                    "Mixed mesenchymal mesoderm" : "#cc7818",
                    "Haematoendothelial progenitors" : "#FBBE92",
                    "Endothelium" : "#ff891c",
                    "Blood progenitors 1" : "#f9decf",
                    "Blood progenitors 2" : "#c9a997",
                    "Erythroid1" : "#C72228",
                    "Erythroid2" : "#f79083",
                    "Erythroid3" : "#EF4E22",
                    "Erythroid" : "#f79083",
                    "Blood progenitors" : "#f9decf",
                    "NMP" : "#8EC792",
                    "Rostral neurectoderm" : "#65A83E",
                    "Caudal neurectoderm" : "#354E23",
                    "Neural crest" : "#C3C388",
                    "Forebrain/Midbrain/Hindbrain" : "#647a4f",
                    "Spinal cord" : "#CDE088",
                    "Surface ectoderm" : "#f7f79e",
                    "Visceral endoderm" : "#F6BFCB",
                    "ExE endoderm" : "#7F6874",
                    "ExE ectoderm" : "#989898",
                    "Parietal endoderm" : "#1A1A1A",
                    "Low quality" : "#e6e6e6",
                    "Cranial mesoderm" : "#77441B",
                    "Anterior somitic tissues" : "#F90026",
                    "Sclerotome" : "#A10037",
                    "Dermomyotome" : "#DA5921",
                    "Posterior somitic tissues" : "#E1C239",
                    "Presomitic mesoderm" : "#9DD84A",
                    "None" : "#D3D3D3"}

### 1.3 Run Notebook Setup

In [6]:
# Use a square 6 x 6 default figure size for scanpy plots
sc.set_figure_params(figsize=(6, 6))

  IPython.display.set_matplotlib_formats(*ipython_format)


In [7]:
# Get time of notebook execution for timestamping saved artifacts
# (naive local time, formatted as DDMMYYYY_HHMMSS)
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths and Directories

In [8]:
# Folder from which all `.h5ad` files are read (the trailing slash is required
# because file names are appended by plain string concatenation)
data_folder_path = "../../datasets/srt_data/gold/"
# Figures of this notebook run; keyed on the default `dataset` and the run timestamp
figure_folder_path = f"../../figures/{dataset}/method_benchmarking/comparison/{current_timestamp}"
gp_data_folder_path = "../../datasets/gp_data" # gene program data
# Files handed to the NicheNet and OmniPath gene program extraction functions
nichenet_ligand_target_mx_file_path = gp_data_folder_path + "/nichenet_ligand_target_matrix.csv"
omnipath_lr_interactions_file_path = gp_data_folder_path + "/omnipath_lr_interactions.csv"

In [9]:
# Create required directories (no error is raised if the folder already exists)
os.makedirs(figure_folder_path, exist_ok=True)

## 2. Method Benchmarking

- Run all notebooks in the `method_benchmarking` directory before continuing.

### 3.1 Latent Space Comparison

#### 3.1.1 Define Function

In [11]:
def compute_latent_space_comparison(dataset,
                                    cell_type_key,
                                    n_neighbors=12,
                                    run_number=5):
    """
    Plot the physical space of a dataset next to UMAPs of the latent spaces of
    all benchmarked methods and save the figure to disk.

    For every method except SageNet, a neighbor graph and a UMAP are computed
    from the latent representation of the given run. The SageNet latent
    representation is used directly as UMAP coordinates.

    Parameters
    ----------
    dataset:
        Dataset name; for each method, the file `{dataset}_{method}.h5ad` is
        read from `data_folder_path`.
    cell_type_key:
        Key of the cell type annotations used to color all plots.
    n_neighbors:
        Number of neighbors used to compute the neighbor graphs.
    run_number:
        Run whose latent representation `{method}_latent_run{run_number}` is
        visualized.

    Notes
    -----
    The figure is saved as `latent_comparison_{dataset}.png` in
    `figure_folder_path`. The dataset is part of the file name because
    `figure_folder_path` is derived from the notebook-level default dataset, so
    a fixed file name would be overwritten by every call.

    NOTE(review): `cell_type_colors` is passed as palette for every dataset;
    confirm that it contains the cell types of all datasets this is run on.
    """
    methods = ["pca", "scvi", "expimap", "sagenet", "deeplinc", "graphst", "autotalker"]

    # Load data
    adatas = {method: sc.read_h5ad(data_folder_path + f"{dataset}_{method}.h5ad")
              for method in methods}

    # Latent representations of SageNet are already UMAP features
    adatas["sagenet"].obsm["X_umap"] = adatas["sagenet"].obsm[f"sagenet_latent_run{run_number}"]
    for method, method_adata in adatas.items():
        if method == "sagenet":
            continue
        sc.pp.neighbors(method_adata,
                        use_rep=f"{method}_latent_run{run_number}",
                        n_neighbors=n_neighbors)
        sc.tl.umap(method_adata)

    fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
    fig.suptitle("Latent Space Comparison", fontsize=25, x=0.575)
    fig.subplots_adjust(hspace=0.25, wspace=0.25, top=0.9)
    axs = axs.flatten()

    # The physical space is plotted from the Autotalker adata; it is selected
    # explicitly instead of relying on the variable leaked by the loop above
    sc.pl.spatial(adata=adatas["autotalker"],
                  color=[cell_type_key],
                  palette=cell_type_colors,
                  spot_size=0.03,
                  ax=axs[0],
                  show=False)
    axs[0].set_title("Physical Space", fontsize=17)

    # Move the legend of the first panel to the figure level so that it is
    # shared by all panels
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(1.07, 0.845))
    axs[0].get_legend().remove()

    panels = [("autotalker", "Autotalker"),
              ("deeplinc", "DeepLinc"),
              ("graphst", "GraphST"),
              ("sagenet", "SageNet"),
              ("pca", "Log Normalized Counts PCA"),
              ("scvi", "scVI"),
              ("expimap", "expiMap")]
    for ax, (method, title) in zip(axs[1:], panels):
        sc.pl.umap(adatas[method],
                   color=[cell_type_key],
                   palette=cell_type_colors,
                   ax=ax,
                   show=False,
                   legend_loc=None)
        ax.set_title(title, fontsize=17)

    fig.savefig(f"{figure_folder_path}/latent_comparison_{dataset}.png",
                bbox_inches="tight")
    plt.show()

#### 3.1.2 Run Function

In [None]:
# Latent space comparison for the seqFISH mouse organogenesis dataset
compute_latent_space_comparison(dataset="seqfish_mouse_organogenesis_embryo2",
                                cell_type_key="celltype_mapped_refined",
                                n_neighbors=12,
                                run_number=5)

In [None]:
# Latent space comparison for the Vizgen MERFISH mouse liver dataset
compute_latent_space_comparison(dataset="vizgen_merfish_mouse_liver",
                                cell_type_key="Cell_Type",
                                n_neighbors=12,
                                run_number=5)

In [None]:
# Latent space comparison for the STARmap PLUS mouse CNS dataset
compute_latent_space_comparison(dataset="starmap_plus_mouse_cns",
                                cell_type_key="Main_molecular_cell_type",
                                n_neighbors=12,
                                run_number=5)

In [None]:
# Latent space comparison for the NanoString CosMx human NSCLC dataset
compute_latent_space_comparison(dataset="nanostring_cosmx_human_nsclc",
                                cell_type_key="cell_type",
                                n_neighbors=12,
                                run_number=5)

### 3.2 Benchmarking Metrics

#### 3.2.1 Define Functions

In [12]:
def compute_combined_benchmarking_metrics(model_adata,
                                          model_name,
                                          cell_type_key,
                                          run_number_list=list(np.arange(1, 11)),
                                          n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                          ger_genes=None):
    """
    Compute the Autotalker and scib benchmarking metrics of one model for
    multiple runs.

    Parameters
    ----------
    model_adata:
        Adata holding the latent representations of the model under the keys
        `{model_name}_latent_run{run_number}`. A neighbor graph is recomputed
        on it for every run, and scib is asked to store its clustering under
        the key "cluster".
    model_name:
        Name of the model; used as prefix of the adata keys and stored in the
        results.
    cell_type_key:
        Key of the cell type annotations used as labels by the metrics.
    run_number_list:
        Run numbers to evaluate.
    n_neighbors_list:
        Number of neighbors used for each run; must have the same length as
        `run_number_list` (the two lists are paired elementwise).
    ger_genes:
        Passed through to `compute_benchmarking_metrics` as `ger_genes`.

    Returns
    -------
    benchmarking_dict_list:
        One dictionary per run, containing the metrics returned by
        `compute_benchmarking_metrics` plus the keys "ari", "clisi", "nmi",
        "asw", "model_name" and "run".

    Raises
    ------
    ValueError:
        If `run_number_list` and `n_neighbors_list` differ in length.
    """
    run_number_list = list(run_number_list)
    n_neighbors_list = list(n_neighbors_list)
    # `zip` would silently drop the surplus runs, so fail loudly instead
    if len(run_number_list) != len(n_neighbors_list):
        raise ValueError(
            "'run_number_list' and 'n_neighbors_list' must have the same "
            f"length (got {len(run_number_list)} and {len(n_neighbors_list)}).")

    benchmarking_dict_list = []
    for run_number, n_neighbors in zip(run_number_list, n_neighbors_list):
        latent_key = f"{model_name}_latent_run{run_number}"

        # Compute Autotalker metrics
        benchmarking_dict = compute_benchmarking_metrics(
            adata=model_adata,
            latent_key=latent_key,
            active_gp_names_key=f"{model_name}_active_gp_names_run{run_number}",
            cell_type_key=cell_type_key,
            spatial_key=spatial_key,
            spatial_knng_key=f"spatial_{n_neighbors}nng",
            latent_knng_key=f"{model_name}_latent_{n_neighbors}nng_run{run_number}",
            ger_genes=ger_genes)

        # Compute scib metrics; the neighbor graph has to be recomputed for
        # every run as it depends on the latent representation of that run
        sc.pp.neighbors(adata=model_adata,
                        use_rep=latent_key,
                        n_neighbors=n_neighbors)
        scib.me.cluster_optimal_resolution(adata=model_adata,
                                           cluster_key="cluster",
                                           label_key=cell_type_key)
        benchmarking_dict["ari"] = scib.me.ari(model_adata,
                                               cluster_key="cluster",
                                               label_key=cell_type_key)
        benchmarking_dict["clisi"] = scib.me.clisi_graph(adata=model_adata,
                                                         label_key=cell_type_key,
                                                         type_="embed",
                                                         use_rep=latent_key)
        benchmarking_dict["nmi"] = scib.me.nmi(adata=model_adata,
                                               cluster_key="cluster",
                                               label_key=cell_type_key)
        benchmarking_dict["asw"] = scib.me.silhouette(adata=model_adata,
                                                      label_key=cell_type_key,
                                                      embed=latent_key)

        benchmarking_dict["model_name"] = model_name
        benchmarking_dict["run"] = run_number
        benchmarking_dict_list.append(benchmarking_dict)
    return benchmarking_dict_list

In [17]:
def compute_combined_benchmarking_metrics_for_all_models(dataset,
                                                         cell_type_key,
                                                         run_number_list=list(np.arange(1, 11)),
                                                         n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                         use_only_gp_mask_target_genes_for_gene_expr_regr=True,
                                                         species="mouse"):
    """
    Compute the combined benchmarking metrics of all benchmarked models for
    one dataset and store them on disk.

    The results of all models are collected in one list that is pickled to
    `benchmarking_dict_list.pickle` in the dataset artifact folder. The file is
    rewritten after every model so that intermediate results survive if a
    later model fails.

    Parameters
    ----------
    dataset:
        Dataset name; for each model, the file `{dataset}_{model_name}.h5ad`
        is read from `data_folder_path`.
    cell_type_key:
        Key of the cell type annotations used as labels by the metrics.
    run_number_list:
        Run numbers to evaluate for each model.
    n_neighbors_list:
        Number of neighbors used for each run.
    use_only_gp_mask_target_genes_for_gene_expr_regr:
        If `True`, the gene program mask is built on `{dataset}.h5ad` and the
        genes indexed by `adata.uns["autotalker_target_genes_idx"]` are passed
        to the metrics computation as `ger_genes`. Otherwise `ger_genes` is
        `None`.
    species:
        Species passed to `extract_gp_dict_from_mebocost_es_interactions`.
    """
    # Configure dataset artifact folder path
    dataset_artifact_folder_path = f"../../artifacts/{dataset}/method_benchmarking/comparison/{current_timestamp}"
    os.makedirs(dataset_artifact_folder_path, exist_ok=True)
    benchmarking_file_path = f"{dataset_artifact_folder_path}/benchmarking_dict_list.pickle"

    if use_only_gp_mask_target_genes_for_gene_expr_regr:
        # Identify genes that are available in gp mask as target genes
        print("Retrieving gp mask target genes...")
        adata = sc.read_h5ad(data_folder_path + f"{dataset}.h5ad")

        nichenet_gp_dict = extract_gp_dict_from_nichenet_ligand_target_mx(
            keep_target_ratio=0.01,
            load_from_disk=False,
            save_to_disk=False,
            file_path=nichenet_ligand_target_mx_file_path)

        omnipath_gp_dict = extract_gp_dict_from_omnipath_lr_interactions(
            min_curation_effort=0,
            load_from_disk=False,
            save_to_disk=False,
            file_path=omnipath_lr_interactions_file_path)

        mebocost_gp_dict = extract_gp_dict_from_mebocost_es_interactions(
            dir_path=f"{gp_data_folder_path}/metabolite_enzyme_sensor_gps/",
            species=species,
            genes_uppercase=True)

        # Combine gene programs into one dictionary
        combined_gp_dict = dict(nichenet_gp_dict)
        combined_gp_dict.update(omnipath_gp_dict)
        combined_gp_dict.update(mebocost_gp_dict)

        # Filter and combine gene programs
        combined_new_gp_dict = filter_and_combine_gp_dict_gps(
            gp_dict=combined_gp_dict,
            gp_filter_mode="subset",
            combine_overlap_gps=True,
            overlap_thresh_source_genes=0.9,
            overlap_thresh_target_genes=0.9,
            overlap_thresh_genes=0.9,
            verbose=True)

        # Add the gene program dictionary as binary masks to the adata loaded
        # above (the Autotalker adata is only loaded further down and must not
        # be referenced here)
        add_gps_from_gp_dict_to_adata(
            gp_dict=combined_new_gp_dict,
            adata=adata,
            genes_uppercase=True,
            gp_targets_mask_key="autotalker_gp_targets",
            gp_sources_mask_key="autotalker_gp_sources",
            gp_names_key="autotalker_gp_names",
            min_genes_per_gp=1,
            min_source_genes_per_gp=0,
            min_target_genes_per_gp=0,
            max_genes_per_gp=None,
            max_source_genes_per_gp=None,
            max_target_genes_per_gp=None)

        ger_genes = adata.var_names[
            adata.uns["autotalker_target_genes_idx"]].tolist()
        del adata
    else:
        ger_genes = None

    # (model name used in file names and adata keys, display name used in logs)
    models = [("pca", "PCA"),
              ("scvi", "scVI"),
              ("expimap", "expiMap"),
              ("sagenet", "SageNet"),
              ("deeplinc", "DeepLinc"),
              ("graphst", "GraphST"),
              ("autotalker", "Autotalker")]

    benchmarking_dict_list = []
    for model_name, display_name in models:
        print(f"Computing metrics for {display_name}...")
        model_adata = sc.read_h5ad(data_folder_path + f"{dataset}_{model_name}.h5ad")
        benchmarking_dict_list += compute_combined_benchmarking_metrics(
            model_adata=model_adata,
            model_name=model_name,
            run_number_list=run_number_list,
            n_neighbors_list=n_neighbors_list,
            cell_type_key=cell_type_key,
            ger_genes=ger_genes)

        # Store the results collected so far after every model
        with open(benchmarking_file_path, "wb") as f:
            pickle.dump(benchmarking_dict_list, f)

        # Free the memory of the adata before the next one is loaded
        del model_adata
        print("")

In [18]:
def create_summary_plot(dataset, artifact_folder_path=None):
    """
    Load the benchmarking results of a dataset into a dataframe.

    Despite its name, this function does not draw a plot yet; it only reads
    the pickled benchmarking results and returns them as a dataframe.

    Parameters
    ----------
    dataset:
        Dataset name; used to derive the artifact folder if
        `artifact_folder_path` is not given.
    artifact_folder_path:
        Folder containing `benchmarking_dict_list.pickle`. If `None`, the
        folder of the current notebook run (`current_timestamp`) for `dataset`
        is used.

    Returns
    -------
    df:
        Dataframe built from the list of benchmarking dictionaries.
    """
    if artifact_folder_path is None:
        artifact_folder_path = f"../../artifacts/{dataset}/method_benchmarking/comparison/{current_timestamp}"

    # Read complete benchmarking data from disk
    # NOTE: only load pickle files that were written by this notebook;
    # unpickling untrusted files can execute arbitrary code
    with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", "rb") as f:
        benchmarking_dict_list = pickle.load(f)

    df = pd.DataFrame(benchmarking_dict_list)
    return df

#### 3.2.2 Run Functions

In [None]:
# Benchmarking metrics for the seqFISH mouse organogenesis dataset
compute_combined_benchmarking_metrics_for_all_models(dataset="seqfish_mouse_organogenesis_embryo2",
                                                     cell_type_key="celltype_mapped_refined")

Retrieving gp mask target genes...
Downloading NicheNet ligand target potential matrix from the web. This might take a while...


In [None]:
# Benchmarking metrics for the Vizgen MERFISH mouse liver dataset
compute_combined_benchmarking_metrics_for_all_models(dataset="vizgen_merfish_mouse_liver",
                                                     cell_type_key="Cell_Type",)

In [None]:
# Benchmarking metrics for the STARmap PLUS mouse CNS dataset
compute_combined_benchmarking_metrics_for_all_models(dataset="starmap_plus_mouse_cns",
                                                     cell_type_key="Main_molecular_cell_type")

In [None]:
# Benchmarking metrics for the NanoString CosMx human NSCLC dataset
# NOTE(review): the dataset name indicates human data, but this call ends up
# extracting the MEBOCOST gene programs with species="mouse" -- confirm that
# this is intended.
compute_combined_benchmarking_metrics_for_all_models(dataset="nanostring_cosmx_human_nsclc",
                                                     cell_type_key="cell_type")

#### 3.2.8 Summary

In [None]:
# Artifact folder of the benchmarking run to summarize. The timestamp is
# hard-coded to a specific earlier run and the dataset is the notebook-level
# default `dataset`; adjust both to summarize another run.
dataset_artifact_folder_path = f"../../artifacts/{dataset}/method_benchmarking/comparison/23022023_081459"

In [None]:
# Read complete benchmarking data from disk
# NOTE: only load pickle files that were written by this notebook; unpickling
# untrusted files can execute arbitrary code
with open(f"{dataset_artifact_folder_path}/benchmarking_dict_list.pickle", "rb") as f:
    benchmarking_dict_list = pickle.load(f)

In [None]:
# One row per benchmarking dictionary, i.e. per model and run
df = pd.DataFrame(benchmarking_dict_list)
df.head()

In [None]:
# Compute metric means over all runs
# (`numeric_only=True` keeps the aggregation from failing on non-numeric
# columns with recent pandas versions)
mean_df = df.groupby("model_name").mean(numeric_only=True)

# Metrics to report, in display order
# NOTE(review): "ilasw" is not among the metrics added in
# `compute_combined_benchmarking_metrics`; confirm that it is provided by
# `compute_benchmarking_metrics`.
columns = ["gcd",
           "mlnmi",
           "cad",
           "arclisi",
           "rclisi",
           "germse",
           "cca",
           "ari",
           "clisi",
           "nmi",
           "asw",
           "ilasw"]

# Models to report, in display order
rows = ["autotalker",
        "deeplinc",
        "graphst",
        "sagenet",
        "pca",
        "scvi",
        "expimap"]

mean_df = mean_df[columns]
mean_df = mean_df.reindex(rows)

mean_df

##### 3.2.8.1 Metrics Plot

In [None]:
# One bar plot per metric, arranged in two rows
n_cols = int(np.ceil(len(columns) / 2))
fig, axs = plt.subplots(nrows=2, ncols=n_cols, figsize=(3 * len(columns), 8))
axs = axs.flatten()

for ax, col in zip(axs, columns):
    sns.barplot(data=mean_df, x=mean_df.index, y=col, ax=ax)
    ax.set_xlabel("")
    ax.set_xticklabels(mean_df.index, rotation=45)
fig.suptitle("Method Benchmarking Metrics", fontsize=25)
fig.subplots_adjust(hspace=0.5, wspace=0.5, top=0.9)

# With an odd number of metrics the last panel of the grid stays empty
if len(columns) % 2 != 0:
    fig.delaxes(axs[-1])

fig.savefig(f"{figure_folder_path}/metrics_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()

##### 3.2.8.2 Metrics Ranking Plot

In [None]:
# Rank the models per metric (rank 1 is best); tied models share the highest
# rank of their group (`method="max"`)
min_best_columns = ["gcd", "cad", "arclisi", "germse"] # lower values are better
max_best_columns = ["mlnmi", "cca", "ari", "clisi", "nmi", "asw", "ilasw"] # higher values are better
mean_df_min_best = mean_df[min_best_columns]
mean_df_max_best = mean_df[max_best_columns]
rank_df_min = mean_df_min_best.rank(method="max", ascending=True)
rank_df_max = mean_df_max_best.rank(method="max", ascending=False)
rank_df = pd.concat([rank_df_min, rank_df_max], axis=1)

# Restore the display order of `columns`, keeping only the ranked metrics:
# "rclisi" is listed in `columns` but belongs to neither group, so indexing
# with all of `columns` would raise a KeyError
rank_df = rank_df[[col for col in columns if col in rank_df.columns]]

In [None]:
# Heatmap of the per-metric model ranks, annotated with the rank values
heatmap = sns.heatmap(rank_df, annot=True, cmap="YlGnBu")
# `sns.heatmap` returns the axes; get the figure from it for saving
fig = heatmap.get_figure()
plt.title("Method Benchmarking Metrics Ranking", fontsize=20, pad=25)
plt.xticks(rotation=45)
fig.savefig(f"{figure_folder_path}/metrics_ranking_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()