# Method Benchmarking

- **Creator**: Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 06.01.2023
- **Date of Last Modification:** 22.02.2023

## 1. Setup

### 1.1 Import Libraries

In [1]:
# Automatically reload imported modules (e.g. the local autotalker package
# imported below) before executing code, so edits to their source take effect
# without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Make the local autotalker repository importable; the relative path is
# resolved against the notebook's working directory
sys.path.append("../../autotalker")

In [3]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import scanpy as sc
import scib
import seaborn as sns

from autotalker.benchmarking import compute_benchmarking_metrics

  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  res = Downloader(opt).maybe_download(
  return UNKNOWN_SERVER_VERSION
  warn(


### 1.2 Define Parameters

In [4]:
# Dataset identifier; used in the input h5ad file names and output folder paths
dataset = "seqfish_mouse_organogenesis_embryo2"
# Cell type label column (used as plot color and as scib `label_key`)
cell_type_key = "celltype_mapped_refined"
# Key of the spatial coordinates passed as `spatial_key` to the metric
# functions (presumably an `adata.obsm` key - confirm)
spatial_key = "spatial"

### 1.3 Run Notebook Setup

In [5]:
# Default figure size for scanpy plots in this notebook
sc.set_figure_params(figsize=(6, 6))

  IPython.display.set_matplotlib_formats(*ipython_format)


In [6]:
# Get time of notebook execution for timestamping saved artifacts
# (local, timezone-naive time formatted as DDMMYYYY_HHMMSS)
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths and Directories

In [7]:
# Folder with the per-method h5ad files produced by the benchmarking notebooks
data_folder_path = "../datasets/srt_data/gold/"
# Output folders for figures and pickled results of this notebook.
# NOTE: both end with "/", so later f"{...}/file" paths contain "//"
# (harmless on POSIX file systems).
figure_folder_path = f"../figures/{dataset}/method_benchmarking/"
artifact_folder_path = f"../artifacts/{dataset}/method_benchmarking/"

In [8]:
# Create required directories (no error if they already exist)
os.makedirs(artifact_folder_path, exist_ok=True)
os.makedirs(figure_folder_path, exist_ok=True)

### 1.5 Define Functions

In [9]:
def compute_combined_benchmarking_metrics(model_adata,
                                          model_name,
                                          cell_type_key,
                                          run_number_list=None,
                                          n_neighbors_list=None,
                                          spatial_model=False,
                                          spatial_key="spatial",
                                          batch_key="sample",
                                          cluster_key="cluster"):
    """Compute Autotalker and scib benchmarking metrics for several model runs.

    For every run, the latent representation stored under the key
    ``f"{model_name}_latent_run{run_number}"`` is evaluated with
    ``compute_benchmarking_metrics`` (Autotalker metrics) and with the scib
    metrics ARI, cLISI, NMI, ASW and isolated-label ASW.

    Parameters
    ----------
    model_adata : AnnData
        Object holding the per-run entries referenced by the keys
        ``{model_name}_latent_run{n}``, ``{model_name}_active_gp_names_run{n}``
        and ``{model_name}_latent_knng_run{n}``. Modified in place: the
        neighbor graph is recomputed and ``obs[cluster_key]`` is rewritten
        for every run.
    model_name : str
        Prefix of the per-run keys, e.g. ``"autotalker"``.
    cell_type_key : str
        Column with the cell type labels (scib ``label_key``).
    run_number_list : list of int, optional
        Run numbers to evaluate. Defaults to runs 1 to 10.
    n_neighbors_list : list of int, optional
        One entry per run. Defaults to ``[4, 4, 8, 8, 12, 12, 16, 16, 20, 20]``.
        Only checked for a matching length; not used by any metric here.
    spatial_model : bool, default False
        Whether the evaluated method uses spatial information. Accepted so
        that the calls in this notebook work; not used by any metric here.
    spatial_key : str, default "spatial"
        Spatial coordinates key passed to ``compute_benchmarking_metrics``.
    batch_key : str, default "sample"
        Batch column passed to ``scib.me.isolated_labels_asw``.
    cluster_key : str, default "cluster"
        ``obs`` column that receives the optimal-resolution clustering.

    Returns
    -------
    list of dict
        One metrics dict per run, extended with ``"model_name"`` and ``"run"``.

    Raises
    ------
    ValueError
        If ``run_number_list`` and ``n_neighbors_list`` differ in length.
    """
    # Fresh lists per call instead of shared mutable default arguments
    if run_number_list is None:
        run_number_list = list(np.arange(1, 11))
    if n_neighbors_list is None:
        n_neighbors_list = [4, 4, 8, 8, 12, 12, 16, 16, 20, 20]
    # zip() would silently drop unmatched runs
    if len(run_number_list) != len(n_neighbors_list):
        raise ValueError(
            f"run_number_list ({len(run_number_list)} entries) and "
            f"n_neighbors_list ({len(n_neighbors_list)} entries) must have "
            "the same length.")

    benchmarking_dict_list = []
    # NOTE(review): the per-run n_neighbors value is not passed to any metric;
    # confirm whether it was meant to be used (e.g. for a kNN graph).
    for run_number, _n_neighbors in zip(run_number_list, n_neighbors_list):
        latent_key = f"{model_name}_latent_run{run_number}"

        # Compute Autotalker metrics
        benchmarking_dict = compute_benchmarking_metrics(
            adata=model_adata,
            latent_key=latent_key,
            active_gp_names_key=f"{model_name}_active_gp_names_run{run_number}",
            cell_type_key=cell_type_key,
            spatial_key=spatial_key,
            spatial_knng_key="spatial_knng",
            latent_knng_key=f"{model_name}_latent_knng_run{run_number}")

        # Compute scib metrics on a neighbor graph of this run's latent space
        sc.pp.neighbors(adata=model_adata,
                        use_rep=latent_key)
        # Drop the previous run's clustering first: recent scib versions refuse
        # to overwrite an existing cluster key unless force=True, which would
        # make every run after the first one fail
        if cluster_key in model_adata.obs.columns:
            del model_adata.obs[cluster_key]
        scib.me.cluster_optimal_resolution(adata=model_adata,
                                           cluster_key=cluster_key,
                                           label_key=cell_type_key)
        benchmarking_dict["ari"] = scib.me.ari(model_adata,
                                               cluster_key=cluster_key,
                                               label_key=cell_type_key)
        benchmarking_dict["clisi"] = scib.me.clisi_graph(adata=model_adata,
                                                         label_key=cell_type_key,
                                                         type_="embed",
                                                         use_rep=latent_key)
        benchmarking_dict["nmi"] = scib.me.nmi(adata=model_adata,
                                               cluster_key=cluster_key,
                                               label_key=cell_type_key)
        benchmarking_dict["asw"] = scib.me.silhouette(adata=model_adata,
                                                      label_key=cell_type_key,
                                                      embed=latent_key)
        benchmarking_dict["ilasw"] = scib.me.isolated_labels_asw(adata=model_adata,
                                                                 batch_key=batch_key,
                                                                 label_key=cell_type_key,
                                                                 embed=latent_key)

        benchmarking_dict["model_name"] = model_name
        benchmarking_dict["run"] = run_number
        benchmarking_dict_list.append(benchmarking_dict)
    return benchmarking_dict_list

## 2. Data

In [10]:
# The h5ad files below are written by the notebooks in the
# 'method_benchmarking' folder; run those before loading (one file per method)
adata_pca = sc.read_h5ad(f"{data_folder_path}{dataset}_pca.h5ad")
adata_scvi = sc.read_h5ad(f"{data_folder_path}{dataset}_scvi.h5ad")
adata_expimap = sc.read_h5ad(f"{data_folder_path}{dataset}_expimap.h5ad")
adata_sagenet = sc.read_h5ad(f"{data_folder_path}{dataset}_sagenet.h5ad")
adata_deeplinc = sc.read_h5ad(f"{data_folder_path}{dataset}_deeplinc.h5ad")
adata_graphst = sc.read_h5ad(f"{data_folder_path}{dataset}_graphst.h5ad")
adata_autotalker = sc.read_h5ad(f"{data_folder_path}{dataset}_autotalker.h5ad")

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = '../datasets/srt_data/gold/seqfish_mouse_organogenesis_embryo2_pca.h5ad', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

## 3. Method Benchmarking

- Run all notebooks in the `method_benchmarking` directory before continuing; they produce the `<dataset>_<method>.h5ad` files loaded in Section 2.

### 3.1 Latent Space Comparison

In [None]:
# Baseline: neighbor graph + UMAP on the PCA latent representation
# (the PCA baseline only has run 1, see Section 3.2.1)
sc.pp.neighbors(adata_pca, use_rep=f"pca_latent_run1")
sc.tl.umap(adata_pca)

# Methods: neighbor graph + UMAP on the latent representation of one
# selected run per method
run_number = 5
adata_sagenet.obsm["X_umap"] = adata_sagenet.obsm[f"sagenet_latent_run{run_number}"] # latent representation of SageNet are already UMAP features
# NOTE: after this loop the global `adata` remains bound to adata_autotalker;
# the plotting cell below reads it for the physical-space panel.
for adata, method in zip([adata_scvi, adata_expimap, adata_deeplinc, adata_graphst, adata_autotalker],
                         ["scvi", "expimap", "deeplinc", "graphst", "autotalker"]):
    sc.pp.neighbors(adata, use_rep=f"{method}_latent_run{run_number}")
    sc.tl.umap(adata)

In [None]:
# 2x4 grid: physical space in the first panel, one latent-space UMAP per method
fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
plt.suptitle("Latent Space Comparison", fontsize=25, x=0.575)
plt.subplots_adjust(hspace=0.25, wspace=0.25, top=0.9)
axs = axs.flatten()

# Physical space: plot the Autotalker AnnData explicitly instead of relying on
# the `adata` loop variable left over from a previous cell
sc.pl.spatial(adata=adata_autotalker,
              color=[cell_type_key],
              spot_size=0.03,
              ax=axs[0],
              show=False)
axs[0].set_title("Physical Space", fontsize=17)
# Move the cell type legend of the first panel to a single figure-level legend
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(1.07, 0.845))
axs[0].get_legend().remove()

latent_panels = [(adata_autotalker, "Autotalker"),
                 (adata_deeplinc, "DeepLinc"),
                 (adata_graphst, "GraphST"),
                 (adata_sagenet, "SageNet"),
                 (adata_pca, "Log Normalized Counts PCA"),
                 (adata_scvi, "scVI"),
                 (adata_expimap, "expiMap")]
for ax, (model_adata, title) in zip(axs[1:], latent_panels):
    sc.pl.umap(model_adata,
               color=[cell_type_key],
               ax=ax,
               show=False,
               legend_loc=None)
    ax.set_title(title, fontsize=17)

fig.savefig(f"{figure_folder_path}/latent_comparison_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()

### 3.2 Benchmarking Metrics

#### 3.2.1 PCA

- Evaluate PCA of log-normalized gene expression.

In [None]:
# PCA baseline has a single run; use the notebook-wide `cell_type_key`
# (Section 1.2) instead of repeating its value
benchmarking_dict_list_pca = compute_combined_benchmarking_metrics(model_adata=adata_pca,
                                                                   model_name="pca",
                                                                   spatial_model=False,
                                                                   run_number_list=[1],
                                                                   n_neighbors_list=[12],
                                                                   cell_type_key=cell_type_key)

In [None]:
# Start the accumulated results with a *copy* of the PCA results: the later
# `+=` cells extend `benchmarking_dict_list` in place, which would otherwise
# also grow `benchmarking_dict_list_pca` through aliasing
benchmarking_dict_list = list(benchmarking_dict_list_pca)

# Store to disk (checkpoint, overwritten after every evaluated method)
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", "wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Clean from memory
del adata_pca

#### 3.2.2 scVI

- Evaluate scVI.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_scvi = compute_combined_benchmarking_metrics(model_adata=adata_scvi,
                                                                    model_name="scvi",
                                                                    spatial_model=False,
                                                                    run_number_list=list(np.arange(1, 11)),
                                                                    n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                    cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the scVI runs
benchmarking_dict_list.extend(benchmarking_dict_list_scvi)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the scVI AnnData, which is no longer used below
del adata_scvi

#### 3.2.3 expiMap

- Evaluate expiMap.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_expimap = compute_combined_benchmarking_metrics(model_adata=adata_expimap,
                                                                       model_name="expimap",
                                                                       spatial_model=False,
                                                                       run_number_list=list(np.arange(1, 11)),
                                                                       n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                       cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the expiMap runs
benchmarking_dict_list.extend(benchmarking_dict_list_expimap)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the expiMap AnnData, which is no longer used below
del adata_expimap

#### 3.2.4 SageNet

- Evaluate SageNet.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_sagenet = compute_combined_benchmarking_metrics(model_adata=adata_sagenet,
                                                                       model_name="sagenet",
                                                                       spatial_model=True,
                                                                       run_number_list=list(np.arange(1, 11)),
                                                                       n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                       cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the SageNet runs
benchmarking_dict_list.extend(benchmarking_dict_list_sagenet)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the SageNet AnnData, which is no longer used below
del adata_sagenet

#### 3.2.5 DeepLinc

- Evaluate DeepLinc.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_deeplinc = compute_combined_benchmarking_metrics(model_adata=adata_deeplinc,
                                                                        model_name="deeplinc",
                                                                        spatial_model=True,
                                                                        run_number_list=list(np.arange(1, 11)),
                                                                        n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                        cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the DeepLinc runs
benchmarking_dict_list.extend(benchmarking_dict_list_deeplinc)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the DeepLinc AnnData, which is no longer used below
del adata_deeplinc

#### 3.2.6 GraphST

- Evaluate GraphST.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_graphst = compute_combined_benchmarking_metrics(model_adata=adata_graphst,
                                                                       model_name="graphst",
                                                                       spatial_model=True,
                                                                       run_number_list=list(np.arange(1, 11)),
                                                                       n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                       cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the GraphST runs
benchmarking_dict_list.extend(benchmarking_dict_list_graphst)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the GraphST AnnData, which is no longer used below
del adata_graphst

#### 3.2.7 Autotalker

- Evaluate Autotalker.

In [None]:
# Ten runs; use the notebook-wide `cell_type_key` (Section 1.2)
benchmarking_dict_list_autotalker = compute_combined_benchmarking_metrics(model_adata=adata_autotalker,
                                                                          model_name="autotalker",
                                                                          spatial_model=True,
                                                                          run_number_list=list(np.arange(1, 11)),
                                                                          n_neighbors_list=[4, 4, 8, 8, 12, 12, 16, 16, 20, 20],
                                                                          cell_type_key=cell_type_key)

In [None]:
# Extend the accumulated per-run results in place with the Autotalker runs
benchmarking_dict_list.extend(benchmarking_dict_list_autotalker)

# Overwrite the on-disk checkpoint with all results gathered so far
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", mode="wb") as f:
    pickle.dump(benchmarking_dict_list, f)

# Release the Autotalker AnnData, which is no longer used below
del adata_autotalker

#### 3.2.8 Summary

In [None]:
# Read complete benchmarking data from disk (the checkpoint written by the
# cells above), so the summary can be produced without re-running them.
# NOTE: only unpickle files this notebook wrote itself; pickle is unsafe
# on untrusted input.
with open(f"{artifact_folder_path}/benchmarking_dict_list.pickle", "rb") as f:
    benchmarking_dict_list = pickle.load(f)

In [None]:
# One row per (model_name, run) with one column per metric
df = pd.DataFrame(benchmarking_dict_list)
df.head()

In [None]:
# Compute metric means over all runs. `numeric_only=True` restricts the mean
# to numeric columns: pandas >= 2.0 raises a TypeError on non-numeric columns
# instead of silently dropping them as older versions did.
mean_df = df.groupby("model_name").mean(numeric_only=True)

# Metric columns in display order
columns = ["gcd",
           "mlnmi",
           "cad",
           "arclisi",
           "germse",
           "cca",
           "ari",
           "clisi",
           "nmi",
           "asw",
           "ilasw"]

# Method rows in display order (spatially-aware methods first, then baselines)
rows = ["autotalker",
        "deeplinc",
        "graphst",
        "sagenet",
        "pca",
        "scvi",
        "expimap"]

mean_df = mean_df[columns]
mean_df = mean_df.reindex(rows)

mean_df

##### 3.2.8.1 Metrics Plot

In [None]:
# One bar plot per metric, arranged in two rows
n_metrics = len(columns)
fig, axs = plt.subplots(nrows=2, ncols=int(np.ceil(n_metrics / 2)), figsize=(3 * n_metrics, 8))
axs = axs.flatten()

for ax, col in zip(axs, columns):
    sns.barplot(data=mean_df, x=mean_df.index, y=col, ax=ax)
    ax.set_xlabel("")
    # Rotate the existing method tick labels; unlike set_xticklabels this does
    # not need fixed tick locations (avoids matplotlib's FixedFormatter warning)
    ax.tick_params(axis="x", labelrotation=45)
plt.suptitle("Method Benchmarking Metrics", fontsize=25)
plt.subplots_adjust(hspace=0.5, wspace=0.5, top=0.9)

# Remove grid cells that have no metric (one cell when the count is odd)
for ax in axs[n_metrics:]:
    fig.delaxes(ax)

fig.savefig(f"{figure_folder_path}/metrics_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()

##### 3.2.8.2 Metrics Ranking Plot

In [None]:
# Metrics where a lower value is better
mean_df_min_best = mean_df[["gcd", "cad", "arclisi", "germse"]]
# Metrics where a higher value is better
mean_df_max_best = mean_df[["mlnmi", "cca", "ari", "clisi", "nmi", "asw", "ilasw"]]

# Rank 1 = best method per metric; tied methods all get the highest rank
# of their group (method="max")
rank_df_min = mean_df_min_best.rank(ascending=True, method="max")
rank_df_max = mean_df_max_best.rank(ascending=False, method="max")

# Recombine both halves and restore the metric order of `columns`
rank_df = pd.concat([rank_df_min, rank_df_max], axis=1)[columns]

In [None]:
# Heatmap of the per-metric method ranks (rows: methods, columns: metrics),
# annotated with the rank values
heatmap = sns.heatmap(rank_df, annot=True, cmap="YlGnBu")
fig = heatmap.get_figure()
plt.title("Method Benchmarking Metrics Ranking", fontsize=20, pad=25)
plt.xticks(rotation=45)
fig.savefig(f"{figure_folder_path}/metrics_ranking_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()