# Single Sample Method Benchmarking

- **Creator:** Sebastian Birk (<sebastian.birk@helmholtz-munich.de>).
- **Affiliation:** Helmholtz Munich, Institute of Computational Biology (ICB), Talavera-López Lab
- **Date of Creation:** 06.01.2023
- **Date of Last Modification:** 15.08.2023

## 1. Setup

### 1.1 Import Libraries

In [None]:
# Enable the IPython autoreload extension in mode 2 so that imported modules are
# reloaded before code is executed (edits to helper modules are picked up
# without restarting the kernel)
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Add the shared utils folder to the module search path (presumably where
# `benchmarking_utils`, imported below, lives - confirm)
sys.path.append("../../../nichecompass-reproducibility/utils")

In [None]:
import os
import warnings
from datetime import datetime

import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import plottable
import scanpy as sc
import scib_metrics
import seaborn as sns
from matplotlib import gridspec
from matplotlib.colors import LinearSegmentedColormap
from plottable import ColumnDefinition, Table
from plottable.cmap import normed_cmap
from plottable.formatters import tickcross
from plottable.plots import bar

#from nichecompass.benchmarking import compute_benchmarking_metrics
from nichecompass.utils import (add_gps_from_gp_dict_to_adata,
                                extract_gp_dict_from_mebocost_es_interactions,
                                extract_gp_dict_from_nichenet_lrt_interactions,
                                extract_gp_dict_from_omnipath_lr_interactions,
                                filter_and_combine_gp_dict_gps)

from benchmarking_utils import (compute_combined_benchmarking_metrics_for_all_models,
                                compute_combined_benchmarking_metrics)

### 1.2 Define Parameters

In [None]:
# NOTE(review): not referenced in the visible cells; presumably the key under
# which spatial coordinates are stored - confirm before removing
spatial_key = "spatial"

### 1.3 Run Notebook Setup

In [None]:
# Set the figure size used for plots
sc.set_figure_params(figsize=(6, 6))

In [None]:
# Ignore future warnings, user warnings and deprecation warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=DeprecationWarning)

In [None]:
# Get time of notebook execution for timestamping saved artifacts
# (format: DDMMYYYY_HHMMSS)
now = datetime.now()
current_timestamp = now.strftime("%d%m%Y_%H%M%S")

### 1.4 Configure Paths and Directories

In [None]:
# Input / output locations, relative to the notebook directory
# NOTE(review): later cells use `figure_folder_path`, which is not defined in
# the visible part of this notebook - confirm where it is supposed to be set.
data_folder_path = "../../datasets/srt_data/gold"
artifact_folder_path = "../../artifacts"
gp_data_folder_path = "../../datasets/gp_data"  # gene program data

# Gene program source files inside the gene program data folder
nichenet_ligand_target_mx_file_path = f"{gp_data_folder_path}/nichenet_ligand_target_matrix.csv"
omnipath_lr_interactions_file_path = f"{gp_data_folder_path}/omnipath_lr_interactions.csv"

## 2. Method Benchmarking

- Run all model notebooks in the `notebooks/single_sample_method_benchmarking` directory before continuing.

### 2.1 Benchmarking Metrics

#### 2.1.1 Define Functions

#### 2.1.2 Compute Benchmarking Metrics

In [None]:
# Run the combined benchmarking metric computation once per dataset (one cell
# per dataset so that datasets can be (re-)run individually). The helper is
# imported from `benchmarking_utils`.
compute_combined_benchmarking_metrics_for_all_models(
    dataset="seqfish_mouse_organogenesis_embryo2")

In [None]:
compute_combined_benchmarking_metrics_for_all_models(
    dataset="vizgen_merfish_mouse_liver")

In [None]:
compute_combined_benchmarking_metrics_for_all_models(
    dataset="vizgen_merfish_mouse_liver_sample")

In [None]:
compute_combined_benchmarking_metrics_for_all_models(
    dataset="starmap_plus_mouse_cns")

In [None]:
compute_combined_benchmarking_metrics_for_all_models(
    dataset="starmap_plus_mouse_cns_sample")

In [None]:
compute_combined_benchmarking_metrics_for_all_models(
    dataset="nanostring_cosmx_human_nsclc")

In [None]:
# NOTE(review): the other cells of this notebook refer to this dataset as
# "nanostring_cosmx_human_nsclc_sample" - confirm which name is correct.
compute_combined_benchmarking_metrics_for_all_models(
    dataset="nanostring_cosmx_human_nsclc_subsample_20pct",
    )

#### 2.1.3 Visualize Results

In [None]:
# Load the benchmarking results of one comparison run per dataset and stack
# them into a single dataframe `df` with an additional "dataset" column.
# `datasets` and `timestamps` are paired by position.
datasets = ["seqfish_mouse_organogenesis_embryo2",
            "starmap_plus_mouse_cns_sample",
            "nanostring_cosmx_human_nsclc_sample"]
timestamps = ["07032023_193322",
              "07032023_071219",
              "07032023_080409"]

dataset_dfs = []
for dataset, timestamp in zip(datasets, timestamps):
    dataset_artifact_folder_path = f"../../artifacts/{dataset}/method_benchmarking/comparison/{timestamp}"

    # Read complete benchmarking data from disk
    # NOTE: unpickling can execute arbitrary code; only load artifacts that
    # were produced by trusted runs.
    with open(f"{dataset_artifact_folder_path}/benchmarking_dict_list.pickle", "rb") as f:
        benchmarking_dict_list = pickle.load(f)

    # The file handle is no longer needed once the data is loaded
    dataset_df = pd.DataFrame(benchmarking_dict_list)
    dataset_df["dataset"] = dataset
    dataset_dfs.append(dataset_df)

# Concatenate once instead of growing the dataframe inside the loop (avoids
# repeated copying and concatenating with an empty starting frame)
df = pd.concat(dataset_dfs)
df.head()

# Metric columns that are unrolled and aggregated
# (currently disabled: "gerr2", "cca")
columns = ["gcs", "mlami", "cas", "clisis"]

# Model names, in the order used when reindexing the aggregated results
rows = ["nichecompass", "deeplinc", "graphst", "sagenet", "pca", "scvi", "expimap"]

# Bring the results into long format: one row per (model, dataset, metric)
# observation, with the metric name in "score_type" and its value in "score"
unrolled_df = pd.melt(df,
                      id_vars=["model_name", "dataset"],
                      value_vars=columns,
                      var_name="score_type",
                      value_name="score")

# Compute metric means over all runs
mean_df = (unrolled_df
           .groupby(["model_name", "dataset", "score_type"])
           .mean()
           .reset_index())

# Remove pca for plot. Take an explicit copy so that the in-place edits below
# (and later column assignments) operate on an independent dataframe rather
# than on a slice of the aggregated frame (avoids chained-assignment issues).
mean_df = mean_df[mean_df["model_name"] != "pca"].copy()

# Reformat for plot: display names for models and short names for datasets
mean_df.replace({"nichecompass": "NicheCompass",
                 "deeplinc": "DeepLinc",
                 "expimap": "expiMap",
                 "graphst": "GraphST",
                 "sagenet": "SageNet",
                 "scvi": "scVI",
                 "nanostring_cosmx_human_nsclc_sample": "nanostring_cosmx_sample",
                 "starmap_plus_mouse_cns_sample": "starmap_plus_sample",
                 "seqfish_mouse_organogenesis_embryo2": "seqfish"}, inplace=True)

# Create spatial indicator column
def is_spatially_aware_model(row):
    """
    Return whether the model of a results row is flagged as spatially aware.

    Parameters
    ----------
    row:
        Mapping-like object (e.g. a dataframe row) with a "model_name" entry.

    Returns
    ----------
    True if the (display) model name is one of "NicheCompass", "DeepLinc",
    "GraphST" or "SageNet", False otherwise.
    """
    spatially_aware_models = ("NicheCompass", "DeepLinc", "GraphST", "SageNet")
    return row["model_name"] in spatially_aware_models
# Evaluate the indicator function row by row
mean_df["spatially_aware"] = mean_df.apply(is_spatially_aware_model, axis=1)

# Rows of the dataset labelled "seqfish" only
mean_df_seqfish = mean_df.loc[mean_df["dataset"] == "seqfish"]

In [None]:
# Plot the aggregated results for the seqFISH subset only
# NOTE(review): `plot_benchmarking_results` and `figure_folder_path` are not
# defined or imported in the visible part of this notebook - confirm where
# they are supposed to come from.
plot_benchmarking_results(mean_df_seqfish,
                          save_dir=figure_folder_path,
                          save_name="benchmarking_results_seqfish.svg")

In [None]:
# Plot the aggregated results for all datasets contained in `mean_df`
plot_benchmarking_results(mean_df,
                          save_dir=figure_folder_path,
                          save_name="benchmarking_results.svg")

In [None]:
# Compute metric means over all runs
# NOTE(review): the index is not reset here, so "model_name", "dataset" and
# "score_type" stay in the (multi-level) index and "score" is the only column.
mean_df = unrolled_df.groupby(["model_name", "dataset", "score_type"]).mean()

print(mean_df)
# NOTE(review): `rows` holds plain model names while the index has three
# levels - confirm that this reindex selects / orders the rows as intended.
mean_df = mean_df.reindex(rows)

# NOTE(review): `metric_ranks` is not used in the visible code.
metric_ranks = []
# Add a "<column>_rank" column for every existing column; ranks are computed
# in descending order, i.e. the highest value gets rank 1
for metric in mean_df.columns:
    mean_df[f"{metric}_rank"] = mean_df[metric].rank(ascending=False).astype(int)

mean_df["model_name"] = mean_df.index

# NOTE(review): "dataset" and the metric names in `columns` are passed as
# column names below, but this cell does not create them as columns - this
# looks like it expects a wide table (one column per metric); verify.
unrolled_mean_df = pd.melt(mean_df, 
   id_vars = ["model_name", "dataset"],
   value_vars = columns,
   var_name = "score_type", 
   value_name = "score")

In [None]:
# Grouped bar plot: one group per metric, one bar per model
plt.figure(figsize=(30, 15))
ax = sns.barplot(data=unrolled_mean_df,
                 x="score_type",
                 y="score",
                 hue="model_name",
                 palette="Accent")

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)

# Iterate over the bars and annotate each one with the rank
# NOTE(review): the index arithmetic assumes that `ax.patches` contains exactly
# one patch per (model, metric) pair, ordered model by model in the order of
# `rows` and `columns` - confirm for the seaborn version in use.
for i, p in enumerate(ax.patches):
    model_idx, metric_idx = divmod(i, len(columns))
    # Expects `mean_df` to be indexed by model name with "<metric>_rank" columns
    rank = mean_df.loc[rows[model_idx], columns[metric_idx] + "_rank"]
    # Place the rank label centered slightly above the top of the bar
    ax.annotate("%.0f" % rank,
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha="center",
                va="center",
                fontsize=20,
                color="gray",
                weight="bold",
                xytext=(0, 8),
                textcoords="offset points")

plt.xlabel("Metric",
           fontsize=30)
plt.ylabel("Metric Value (higher better)",
           fontsize=30)
# Put the legend outside the plot area (anchored at the upper right corner)
plt.legend(loc=2,
           bbox_to_anchor=(1, 1),
           prop={"size": 28})
plt.title("Model Benchmarking Evaluation Metrics",
          fontsize=30,
          pad=20)
# NOTE(review): the two scatter plot cells further down save to the same file
# name, so they overwrite this figure. `figure_folder_path` is not defined in
# the visible part of this notebook.
plt.savefig(f"{figure_folder_path}/eval_metrics.svg",
            dpi=300,
            bbox_inches="tight",
            format="svg")
plt.show()

In [None]:
# Combined category label (score type directly followed by the dataset name,
# without separator); used as x-axis variable of the combined scatter plot
mean_df["dataset_score_type"] = mean_df["score_type"] + mean_df["dataset"]

In [None]:
# Inspect the rows belonging to the "nichecompass" model
mean_df[mean_df["model_name"] == "nichecompass"]

In [None]:
# Two side-by-side scatter plots of the mean scores, one panel per dataset
fig, axs = plt.subplots(ncols=2, figsize=(30, 15))

# NOTE(review): these pyplot calls act on the current axes only and are made
# before anything is plotted - confirm that labels, legend and title end up
# on the intended panel(s).
plt.xlabel("Metric",
           fontsize=30)
plt.ylabel("Metric Value (higher better)",
           fontsize=30)
plt.legend(loc=2,
           bbox_to_anchor=(1, 1),
           prop={"size": 28})
plt.title("Model Benchmarking Evaluation Metrics",
          fontsize=30,
          pad=20)

# NOTE(review): an earlier cell replaces the dataset names in `mean_df` with
# short names (e.g. "starmap_plus_sample"); confirm that the full names used
# as filters below match the state of `mean_df` when this cell is run.
sns.scatterplot(data=mean_df[mean_df["dataset"] == "starmap_plus_mouse_cns_sample"],
                ax=axs[0],
                 x="score_type",
                 y="score",
                 hue="model_name",
                 palette="Accent",
                 s=1000,
                 style="dataset")

sns.scatterplot(data=mean_df[mean_df["dataset"] == "nanostring_cosmx_human_nsclc_sample"],
                ax=axs[1],
                 x="score_type",
                 y="score",
                 hue="model_name",
                 palette="Accent",
                 s=1000,
                 style="dataset")

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)


# NOTE(review): same output file name as the bar plot cell above and the
# combined scatter plot cell below - the figures overwrite each other.
plt.savefig(f"{figure_folder_path}/eval_metrics.svg",
            dpi=300,
            bbox_inches="tight",
            format="svg")
plt.show()

In [None]:
# Single scatter plot of the mean scores over all score type / dataset
# combinations; models are distinguished by color, datasets by marker style
plt.figure(figsize=(15, 10))
ax = sns.scatterplot(data=mean_df,
                 x="dataset_score_type",
                 y="score",
                 hue="model_name",
                 palette="Accent",
                 s=500,
                 style="dataset")

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)

# Iterate over the bars and annotate each one with the rank
# NOTE(review): this loop is identical to the one in the bar plot cell and
# relies on bar patches (`get_width` / `get_height`); confirm whether a
# scatter plot adds any patches to `ax` at all - otherwise this is a no-op.
for i, p in enumerate(ax.patches):
    model_idx, metric_idx = divmod(i, len(columns))
    rank = mean_df.loc[rows[model_idx], columns[metric_idx] + "_rank"]
    ax.annotate("%.0f" % rank,
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha="center",
                va="center",
                fontsize=20,
                color="gray",
                weight="bold",
                xytext=(0, 8),
                textcoords="offset points")

plt.xlabel("Metric",
           fontsize=30)
plt.ylabel("Metric Value (higher better)",
           fontsize=30)
# Put the legend outside the plot area (anchored at the upper right corner)
plt.legend(loc=2,
           bbox_to_anchor=(1, 1),
           prop={"size": 28})
plt.title("Model Benchmarking Evaluation Metrics",
          fontsize=30,
          pad=20)
# NOTE(review): same output file name as the two plotting cells above - the
# figures overwrite each other.
plt.savefig(f"{figure_folder_path}/eval_metrics.svg",
            dpi=300,
            bbox_inches="tight",
            format="svg")
plt.show()

##### 2.1.3.1 Metrics Ranking Plot

In [None]:
# Rank the models per metric; rank 1 is the best value. Ties receive the
# highest rank of the tied group (method="max").
# NOTE(review): none of the metric names selected below appear in `columns` as
# defined in the visible part of this notebook ("gcs", "mlami", "cas",
# "clisis"), and this cell expects one column per metric in `mean_df` -
# verify that this cell still matches the current metric set.
mean_df_min_best = mean_df[["gcd", "cad", "arclisi", "germse"]] # lower values are better
mean_df_max_best = mean_df[["mlnmi", "cca", "ari", "clisi", "nmi", "asw", "ilasw", ]] # higher values are better
rank_df_min = mean_df_min_best.rank(method="max", ascending=True)
rank_df_max = mean_df_max_best.rank(method="max", ascending=False)
# Combine both groups column-wise and bring the columns into the order of `columns`
rank_df = pd.concat([rank_df_min, rank_df_max], axis=1)
rank_df = rank_df[columns]

In [None]:
# Annotated heatmap of the per-metric ranks in `rank_df`
heatmap = sns.heatmap(rank_df, annot=True, cmap="YlGnBu")
fig = heatmap.get_figure()
plt.title("Method Benchmarking Metrics Ranking", fontsize=20, pad=25)
plt.xticks(rotation=45)
# The file name carries the notebook execution timestamp
# NOTE(review): `figure_folder_path` is not defined in the visible part of
# this notebook.
fig.savefig(f"{figure_folder_path}/metrics_ranking_{current_timestamp}.png",
            bbox_inches="tight")
plt.show()

### 2.1 Latent Space Comparison

#### 2.1.2 Visualize Results

##### 2.1.2.1 seqFISH Mouse Organogenesis

In [None]:
# Latent space comparison for the seqFISH dataset without cell type groups
# NOTE(review): `compute_latent_space_comparison` and `color_utils` are not
# defined or imported in the visible part of this notebook - confirm where
# they are supposed to come from.
compute_latent_space_comparison(
    dataset="seqfish_mouse_organogenesis_embryo2",
    cell_type_colors=color_utils.seqfish_mouse_organogenesis_cell_type_colors,
    cell_type_groups=None,
    n_neighbors=12,
    run_number=5,
    spot_size=0.03)

In [None]:
# Same comparison with a single cell type group passed
# NOTE(review): the color mapping is referenced by a different name than in
# the cell above (bare name with "embryo2" vs. attribute of `color_utils`),
# and it is not defined in the visible part of this notebook - confirm.
compute_latent_space_comparison(dataset="seqfish_mouse_organogenesis_embryo2",
                                cell_type_colors=seqfish_mouse_organogenesis_embryo2_cell_type_colors,
                                cell_type_groups="Mixed mesenchymal mesoderm", # "Gut tube",
                                n_neighbors=12,
                                run_number=5,
                                spot_size=0.03)

##### 2.1.2.2 MERFISH Mouse Liver

In [None]:
# DeepLinc did not work on this dataset
# Full dataset: only NicheCompass, expiMap and scVI are included
# NOTE(review): `compute_latent_space_comparison` and
# `vizgen_merfish_mouse_liver_cell_type_colors` are not defined or imported in
# the visible part of this notebook - confirm where they come from.
compute_latent_space_comparison(dataset="vizgen_merfish_mouse_liver",
                                cell_type_colors=vizgen_merfish_mouse_liver_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                included_models=["NicheCompass",
                                                 "expiMap",
                                                 "scVI"],
                                spot_size=20.0)

In [None]:
# DeepLinc did not work on this dataset
# Sampled dataset: all models except DeepLinc are included
compute_latent_space_comparison(dataset="vizgen_merfish_mouse_liver_sample",
                                cell_type_colors=vizgen_merfish_mouse_liver_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                included_models=["NicheCompass",
                                                 "GraphST",
                                                 "SageNet",
                                                 "expiMap",
                                                 "scVI"],
                                spot_size=20.0)

##### 2.1.2.3 STARmap PLUS Mouse CNS

In [None]:
# Full dataset: only NicheCompass, expiMap and scVI are included
# NOTE(review): `compute_latent_space_comparison` and
# `starmap_pluse_mouse_cns_cell_type_colors` are not defined or imported in the
# visible part of this notebook; "pluse" also looks like a typo of "plus" -
# confirm against the module that defines the color mapping.
compute_latent_space_comparison(dataset="starmap_plus_mouse_cns",
                                cell_type_colors=starmap_pluse_mouse_cns_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                included_models=["NicheCompass",
                                                 "expiMap",
                                                 "scVI"],
                                spot_size=0.2)

In [None]:
# Sampled dataset: no explicit model selection is passed
compute_latent_space_comparison(dataset="starmap_plus_mouse_cns_sample",
                                cell_type_colors=starmap_pluse_mouse_cns_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                spot_size=0.2)

##### 2.1.2.4 NanoString CosMx Human NSCLC

In [None]:
# Full dataset: only NicheCompass, expiMap and scVI are included
# NOTE(review): `compute_latent_space_comparison` and
# `nanostring_cosmx_human_nsclc_cell_type_colors` are not defined or imported
# in the visible part of this notebook - confirm where they come from.
compute_latent_space_comparison(dataset="nanostring_cosmx_human_nsclc",
                                cell_type_colors=nanostring_cosmx_human_nsclc_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                included_models=["NicheCompass",
                                                 "expiMap",
                                                 "scVI"],
                                spot_size=100.0)

In [None]:
# Sampled dataset: no explicit model selection is passed
compute_latent_space_comparison(dataset="nanostring_cosmx_human_nsclc_sample",
                                cell_type_colors=nanostring_cosmx_human_nsclc_cell_type_colors,
                                n_neighbors=12,
                                run_number=5,
                                spot_size=100.0)