# scIB Benchmark

In [None]:
# Load libraries

# Python packages
import numpy as np
import scanpy as sc
import scvi
import bbknn
import scib
import harmonypy
# import scgen

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

# R interface
from rpy2.robjects import pandas2ri
from rpy2.robjects import r
import rpy2.rinterface_lib.callbacks
import anndata2ri

pandas2ri.activate()
anndata2ri.activate()

%load_ext rpy2.ipython

# suppress warnings
import warnings
# Silence FutureWarning messages for the rest of the session.
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import os
# Keep a reference to the current stderr stream and open a writable handle on the null device.
# NOTE(review): neither `_stderr` nor `null` is used in the visible cells, and `null` is
# never closed — confirm whether they are still needed.
_stderr = sys.stderr
null = open(os.devnull,'wb')

In [None]:
# Progress marker: reaching this cell means the import cell above ran without error.
print("environment loaded correctly")

In [None]:
# set up working directory
# Project root; the figure directory and the metrics CSV paths in later cells are built from it.
work_dir = "/scratch_isilon/groups/singlecell/gdeuner/SERPENTINE/"

In [None]:
# Point scanpy's figure output at <work_dir>/figures/combined/ and set both
# `dpi` and `dpi_save` to 600.
combined_fig_dir = os.path.join(work_dir, "figures", "combined/")
sc.settings.figdir = combined_fig_dir
sc.set_figure_params(dpi=600, dpi_save=600)

In [None]:
# define integration vars
# Names of the `adata.obs` columns passed as label and batch keys to the
# (currently commented-out) scib.metrics.metrics_all calls below.
label_key = "Annotation_2.0"
batch_key = "sample"

In [None]:
## read unintegrated data
#adata = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "Combined_SCR_CO2_annotated_2.0_TCR_14-02-24.h5ad"))
#adata_hvg = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_HVG_22-02-24.h5ad"))

In [None]:
#print("adata object read correctly", file=sys.stdout)

In [None]:
## read anndata objects
#adata_scvi = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scVI_integrated_22-02-24.h5ad"))
#adata_scanvi = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scANVI_anno2.0_integrated_22-02-24.h5ad"))
#adata_scanvi_2 = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scANVI_anno1.0_integrated_22-02-24.h5ad"))
#adata_bbknn = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_BBKNN_integrated_22-02-24.h5ad"))
#adata_harmony = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_harmony_integrated_22-02-24.h5ad"))
#adata_seurat = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_MNN_integrated_22-02-24.h5ad"))
#adata_scanorama = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scanorama_integrated_22-02-24.h5ad"))
#adata_scgen = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scGen_anno2.0_integrated_22-02-24.h5ad"))
#adata_scgen_2 = sc.read_h5ad(os.path.join(work_dir, "data", "outputdata", "combined", "Combined_SCR_CO2_annotated_2.0_TCR_scGen_anno1.0_integrated_22-02-24.h5ad"))

In [None]:
#print("integration adata objects loaded correctly", file=sys.stdout)

In [None]:
## remove NOISE clusters from previous patient-specific cell type annotation
#adata = adata[adata.obs["Annotation_2.0"] != "NOISE"]

In [None]:
#adata.obs["Annotation_2.0"]=adata.obs["Annotation_2.0"].astype("category")
#adata.obs["Annotation_1.0"]=adata.obs["Annotation_1.0"].astype("category")
#adata.obs["sample"]=adata.obs["sample"].astype("category")

In [None]:
## run metrics for each integration method
#metrics_scvi = scib.metrics.metrics_all(adata, adata_scvi, batch_key, label_key, embed="X_scVI", organism="human")
#print("metrics computed for scvi correctly", file=sys.stdout)
#metrics_scanvi = scib.metrics.metrics_all(adata, adata_scanvi, batch_key, label_key, embed="X_scANVI", organism="human")
#print("metrics computed for scanvi correctly", file=sys.stdout)
#metrics_scanvi_2 = scib.metrics.metrics_all(adata, adata_scanvi_2, batch_key, label_key, embed="X_scANVI", organism="human")
#print("metrics computed for scanvi2 correctly", file=sys.stdout)
#metrics_bbknn = scib.metrics.metrics_all(adata, adata_bbknn, batch_key, label_key, organism="human")
#print("metrics computed for bbknn correctly", file=sys.stdout)
#metrics_harmony = scib.metrics.metrics_all(adata, adata_harmony, batch_key, label_key, embed="X_pca_harmony", organism="human")
#print("metrics computed for harmony correctly", file=sys.stdout)
#metrics_seurat = scib.metrics.metrics_all(adata, adata_seurat, batch_key, label_key, organism="human")
#print("metrics computed for mnn correctly", file=sys.stdout)
#metrics_scgen = scib.metrics.metrics_all(adata, adata_scgen, batch_key, label_key, embed="corrected_latent", organism="human")
#print("metrics computed for scgen correctly", file=sys.stdout)
#metrics_scgen_2 = scib.metrics.metrics_all(adata, adata_scgen_2, batch_key, label_key, embed="corrected_latent", organism="human")
#print("metrics computed for scgen2 correctly", file=sys.stdout)
#metrics_hvg = scib.metrics.metrics_all(adata, adata_hvg, batch_key, label_key, organism="human")
#print("metrics computed for hvg correctly", file=sys.stdout)

In [None]:
# load metrics dfs
# Every integration method has its own metrics CSV in <work_dir>/data/metrics.
metrics_dir = os.path.join(work_dir, "data", "metrics")


def _read_full_table(filename):
    """Read every column of the metrics CSV called *filename*."""
    return pd.read_csv(os.path.join(metrics_dir, filename))


def _read_score_column(filename):
    """Read only the column named "0" of the metrics CSV called *filename*."""
    return pd.read_csv(os.path.join(metrics_dir, filename), usecols=["0"])


# The scVI table is read in full so that its extra "Unnamed: 0" column
# (presumably the metric names — it becomes the index later on) is kept once.
metrics_scvi = _read_full_table("Combined_integration_scvi_metrics.csv")
metrics_scanvi = _read_score_column("Combined_integration_scanvi_anno2.0_metrics.csv")
metrics_scanvi_2 = _read_score_column("Combined_integration_scanvi_anno1.0_metrics.csv")
metrics_bbknn = _read_score_column("Combined_integration_bbknn_metrics.csv")
metrics_harmony = _read_score_column("Combined_integration_harmony_metrics.csv")
metrics_seurat = _read_score_column("Combined_integration_mnn_metrics.csv")
metrics_scanorama = _read_score_column("Combined_integration_scanorama_metrics.csv")
metrics_scgen = _read_score_column("Combined_integration_scgen_anno2.0_metrics.csv")
metrics_scgen_2 = _read_score_column("Combined_integration_scgen_anno1.0_metrics.csv")
metrics_hvg = _read_score_column("Combined_integration_hvg_metrics.csv")

In [None]:
# Put the results in one table: the per-method frames are placed side by side,
# in the order listed here.
per_method_tables = [
    metrics_scvi,
    metrics_scanvi,
    metrics_scanvi_2,
    metrics_bbknn,
    metrics_harmony,
    metrics_seurat,
    metrics_scanorama,
    metrics_scgen,
    metrics_scgen_2,
    metrics_hvg,
]
metrics = pd.concat(per_method_tables, axis=1)

In [None]:
# Use the "Unnamed: 0" column as the row index of the combined table.
metrics = metrics.set_index('Unnamed: 0')


In [None]:
# Display the combined table after re-indexing.
metrics

In [None]:

# Set methods as column names. The order must match the order in which the
# per-method tables were concatenated.
method_names = [
    "scVI",
    "scANVI_2.0",
    "scANVI_1.0",
    "BBKNN",
    "Harmony",
    "Seurat",
    "Scanorama",
    "scGen_2.0",
    "scGen_1.0",
    "Unintegrated",
]
metrics = metrics.set_axis(method_names, axis="columns")

# Select only the fast metrics. "hvg_overlap" is left out because it's not
# relevant to embedding outputs; "cell_cycle_conservation" and "trajectory"
# are not used either.
fast_metrics = [
    "NMI_cluster/label",
    "ARI_cluster/label",
    "ASW_label",
    "ASW_label/batch",
    "PCR_batch",
    "isolated_label_F1",
    "isolated_label_silhouette",
    "graph_conn",
    "kBET",
    "iLISI",
    "cLISI",
]
metrics = metrics.loc[fast_metrics, :]

# Transpose so that metrics are columns and methods are rows
metrics = metrics.T

# Clear leftover axis labels. After the transpose, the name of the former row
# index (e.g. "Unnamed: 0") sits on the columns, so it has to be cleared there
# as well as on the index.
metrics.index.name = None
metrics.columns.name = None
metrics

In [None]:
# Write the method-by-metric table next to the per-method metrics files.
combined_metrics_path = os.path.join(
    work_dir, "data", "metrics", "Combined_integration_metrics.csv"
)
metrics.to_csv(combined_metrics_path)

In [None]:
# Progress marker: the combined metrics CSV has been written.
print("metrics saved!")

In [None]:
# style the table: colour the cell backgrounds with a gradient from the "bone" colormap
metrics.style.background_gradient(cmap="bone")

In [None]:
# Inspect a single metric column (one value per method).
metrics["NMI_cluster/label"]

In [None]:
# Min-max scale every metric column across methods: the lowest score in a
# column maps to 0 and the highest to 1.
column_min = metrics.min()
column_range = metrics.max() - column_min
metrics_scaled = (metrics - column_min) / column_range
metrics_scaled.style.background_gradient(cmap="bone")

In [None]:
# Group the scaled metrics into two summary scores: removal of batch effects
# ("Batch") and conservation of biological variation ("Bio"). Each summary is
# the row-wise mean of its member metrics.
batch_metrics = ["ASW_label/batch", "PCR_batch", "graph_conn", "iLISI", "kBET"]
bio_metrics = [
    "ASW_label",
    "isolated_label_silhouette",
    "ARI_cluster/label",
    "isolated_label_F1",
    "NMI_cluster/label",
    "cLISI",
]
metrics_scaled["Batch"] = metrics_scaled[batch_metrics].mean(axis=1)
metrics_scaled["Bio"] = metrics_scaled[bio_metrics].mean(axis=1)
metrics_scaled.style.background_gradient(cmap="bone")

# MISSING: ADD OTHER METRICS AS TECH OR BIO!!!

In [None]:
# use benchmarker function

In [None]:
# Scatter the two summary scores against each other, one point per method,
# with both axes fixed to the [0, 1] range.
fig, ax = plt.subplots()
ax.set(xlim=(0, 1), ylim=(0, 1))
metrics_scaled.plot.scatter(x="Batch", y="Bio", c=range(len(metrics_scaled)), ax=ax)

# Label each point with its method name, offset slightly from the marker.
label_style = dict(
    xytext=(6, -3),
    textcoords="offset points",
    family="sans-serif",
    fontsize=12,
)
for method, scores in metrics_scaled[["Batch", "Bio"]].iterrows():
    ax.annotate(method, scores, **label_style)

In [None]:
# Combine both summaries into a single score, weighting biological
# conservation (0.6) above batch removal (0.4).
batch_weight, bio_weight = 0.4, 0.6
metrics_scaled["Overall"] = (
    batch_weight * metrics_scaled["Batch"] + bio_weight * metrics_scaled["Bio"]
)
metrics_scaled.style.background_gradient(cmap="bone")

In [None]:
# overall performance: bar chart of the weighted "Overall" score, one bar per method
metrics_scaled.plot.bar(y="Overall")


In [None]:
# save figures
# NOTE(review): `leiden_umap` and `lgd` are not defined in any visible cell of this
# notebook, and the output name refers to a Harmony/Leiden UMAP — this looks carried
# over from another notebook; confirm what should be saved here before running.
fig = leiden_umap.get_figure()
fig.set_size_inches(5, 5)
# Written into the scanpy figure directory at 400 dpi; `lgd` is passed as an extra
# artist alongside the tight bounding box.
fig.savefig(str(sc.settings.figdir) + '/umap_lgd_harmony_leiden',
    dpi=400, bbox_extra_artists=(lgd,), bbox_inches='tight')

In [None]:
# which metrics are bio consv and batch corr? https://scib.readthedocs.io/en/latest/api.html