In [1]:
import os
import rpy2
import scib
import scanpy
import scipy
import warnings
import numpy as np
import pandas as pd
from rpy2 import robjects
from IPython.utils import io
from tqdm  import tqdm as tqdm
from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor

In [2]:
# Register rpy2's IPython extension, which provides the %R magic used below.
%load_ext rpy2.ipython

In [3]:
# Load the R `lisi` package; scib_metrics calls lisi::compute_lisi through %R.
%R library(lisi)

0,1,2,3,4,5,6
'lisi','tools','stats',...,'datasets','methods','base'


# 1. Load data

In [4]:
# Dataset evaluated in this run (must be a key of `data_keys`).
dataset = "small_atac_peaks"

# Per-dataset .obs column names holding the batch and the cell-type annotation.
data_keys = {
    "human_lung_atlas": dict(batch_key="batch", label_key="cell_type"),
    "human_pancreas": dict(batch_key="tech", label_key="celltype"),
    "small_atac_peaks": dict(batch_key="batchname", label_key="final_cell_label"),
    "small_atac_windows": dict(batch_key="batchname", label_key="final_cell_label"),
}

# Integration methods whose outputs can be scored.
methods = "SCITUNA Scanorama fastMNN Seurat SAUCIE".split()

### 1.2  Original Datasets (unintegrated)

In [5]:
# Read the unintegrated AnnData; it is the pre-integration reference passed to scib_metrics.
print("Loading dataset")
unintegrated_data = scanpy.read_h5ad(f"../data/{dataset}.h5ad")

Loading dataset


  d[k] = read_elem(f[k])
  categories = read_elem(categories_dset)
  read_elem(dataset), categories, ordered=ordered
  categories = read_elem(categories_dset)
  read_elem(dataset), categories, ordered=ordered
  categories = read_elem(categories_dset)
  read_elem(dataset), categories, ordered=ordered
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)
  return read_elem(dataset)


### 1.3 Batch pairs

In [6]:
# Retrieve every unordered pair of batches as (a, b) tuples; np.unique returns
# the batch names sorted, so a precedes b in that order.
batch_names = np.unique(unintegrated_data.obs[data_keys[dataset]["batch_key"]])
batch_pairs = [
    (first, second)
    for pos, first in enumerate(batch_names)
    for second in batch_names[pos + 1:]
]
print("There are :", len(batch_pairs), " batch pairs.")
batch_pairs[:4]

There are : 3  batch pairs.


[('10x Genomics', 'Cusanovich et al.'),
 ('10x Genomics', 'Fang et al.'),
 ('Cusanovich et al.', 'Fang et al.')]

# 2. Evaluation

In [7]:
# Display names for scib's raw metric keys; applied with DataFrame.rename(index=...).
metric_alias = dict([
    # scib bio-conservation metrics
    ('NMI_cluster/label', "NMI cluster/label"),
    ('ARI_cluster/label', "ARI cluster/label"),
    ('ASW_label', "Cell type ASW"),
    ('isolated_label_F1', "Isolated label F1"),
    ('isolated_label_silhouette', "Isolated label silhouette"),
    ('cell_cycle_conservation', "CC conservation"),
    ('hvg_overlap', "HVG conservation"),
    ('cLISI', "cLISI"),
    # scib batch-correction metrics
    ('PCR_batch', "PCR batch"),
    ('ASW_label/batch', "Batch ASW"),
    ('iLISI', "iLISI"),
    ('graph_conn', "Graph connectivity"),
])

In [8]:
def overcorrection_score(emb, celltype, n_neighbors=100, n_pools=100, n_samples_per_pool=100, seed=124, n_jobs=8):
    """
    Over-correction score of an embedding (lower is better).

    Adapted from https://doi.org/10.1038/s41467-022-33758-z. For random pools of
    cells, computes the fraction of each sampled cell's nearest neighbours that
    share its label. At most min(#cells with that label, n_neighbors) nearest
    neighbours are considered. Returns 1 - (mean fraction over pools).

    Parameters
    ----------
    emb : array-like, shape (n_cells, n_dims)
        Embedding coordinates (the notebook passes a UMAP).
    celltype : array-like of length n_cells
        Label per cell, in the same row order as ``emb``. A pandas Series is
        used positionally; its index is ignored.
    n_neighbors : int
        Maximum neighbours per cell; clipped to ``n_cells - 1``.
    n_pools : int
        Number of random pools that are averaged.
    n_samples_per_pool : int
        Cells sampled without replacement per pool; clipped to ``n_cells``.
    seed : int or None
        Seed of the sampling RNG, making the score reproducible.
    n_jobs : int
        Parallel jobs for the neighbour search.

    Returns
    -------
    float
    """
    emb = np.asarray(emb)
    # Positional labels: indexing a string-indexed Series with integers relies on
    # a deprecated pandas fallback.
    labels = np.asarray(celltype)
    n_cells = emb.shape[0]
    n_neighbors = min(n_neighbors, n_cells - 1)

    nne = NearestNeighbors(n_neighbors=1 + n_neighbors, n_jobs=n_jobs)
    nne.fit(emb)
    # Each row lists neighbour indices by increasing distance, so the truncation
    # below keeps the *nearest* ones. The column order of a sparse
    # kneighbors_graph after arithmetic does not follow distance.
    knn = nne.kneighbors(emb, return_distance=False)

    label_counts = pd.Series(labels).value_counts().to_dict()
    rng = np.random.default_rng(seed)
    n_samples = min(n_samples_per_pool, n_cells)

    score = 0.0
    for _ in range(n_pools):
        indices = rng.choice(n_cells, size=n_samples, replace=False)
        fractions = []
        for i in indices:
            # Drop the query cell itself. If duplicate points pushed it out of
            # the list, the extra (farthest) entry is cut instead.
            nbrs = knn[i][knn[i] != i][:n_neighbors]
            k = min(label_counts[labels[i]], n_neighbors)
            fractions.append(np.mean(labels[nbrs[:k]] == labels[i]))
        score += np.mean(fractions)

    return 1 - score / float(n_pools)

In [9]:
def scib_metrics(m_path, unintegrated_data):
    
    #check if file is empty or corrupted
    if os.stat(m_path).st_size == 0:
        print(f'{m_path} is empty, setting all metrics to NA.')
        return

    else:
        #integrated data
        adata_int = scanpy.read(m_path, cache=True)
        #anndata object of the data before integration
        adata_pre=unintegrated_data[adata_int.obs_names]

    #check if the number of genes in the integrated dataset is less than the desired number of HVG
    if (n_hvgs is not None):
        if (adata_int.n_vars < n_hvgs):
            raise ValueError("There are less genes in the corrected adata than specified for HVG selection")



    # check input files
    if adata_pre.n_obs != adata_int.n_obs:
        print("Error detected: Observations")
        message = "The datasets have different numbers of cells before and after integration."
        message += "Please make sure that both datasets match."
        raise ValueError(message)

    # check if the obsnames were changed and rename them in that case
    if len(set(adata_pre.obs_names).difference(set(adata_int.obs_names))) > 0:
        print("Error detected: Observation Mames")
        # rename adata_int.obs[batch_key] labels by overwriting them with the pre-integration labels
        new_obs_names = ['-'.join(idx.split('-')[:-1]) for idx in adata_int.obs_names]

        if len(set(adata_pre.obs_names).difference(set(new_obs_names))) == 0:
            adata_int.obs_names = new_obs_names
        else:
            raise ValueError('obs_names changed after integration!')

    # batch_key might be overwritten, so we match it to the pre-integrated labels
    adata_int.obs[data_keys[dataset]["batch_key"]] = adata_int.obs[data_keys[dataset]["batch_key"]].astype('category')
    batch_u = adata_pre.obs[data_keys[dataset]["batch_key"]].value_counts().index
    batch_i = adata_int.obs[data_keys[dataset]["batch_key"]].value_counts().index
    if not batch_i.equals(batch_u):
        # pandas uses the table index to match the correct labels
        adata_int.obs[data_keys[dataset]["batch_key"]] = adata_pre.obs[data_keys[dataset]["batch_key"]]


    #with io.capture_output() as captured:
    with io.capture_output() as captured:
        scib.preprocessing.reduce_data(
            adata_int,
            n_top_genes=n_hvgs,
            neighbors=True,
            use_rep='X_pca',
            pca=True,
            umap=False
        )

    #print("| Batch & Bio metrics",end="\t")

    # DEFAULT
    silhouette_ = True
    nmi_ = True
    ari_ = True
    pcr_ = True
    cell_cycle_ = True
    isolated_labels_ = True
    hvg_score_ = True
    graph_conn_ = True

    if assay == "simulation":
        cell_cycle_ = False
    elif assay == "atac":
        cell_cycle_ = False
        hvg_score_ = False


    with io.capture_output() as captured:
        metrics = scib.me.metrics(
            adata_pre,
            adata_int,
            verbose=False,
            hvg_score_=hvg_score_,
            cluster_nmi=None,
            batch_key=data_keys[dataset]["batch_key"],
            label_key=data_keys[dataset]["label_key"],
            silhouette_=silhouette_,
            nmi_=nmi_,
            nmi_method='arithmetic',
            nmi_dir=None,
            ari_=ari_,
            pcr_=pcr_,
            cell_cycle_=cell_cycle_,
            organism=organism,
            isolated_labels_=isolated_labels_,
            n_isolated=None,
            graph_conn_=graph_conn_,
            kBET_=False,
            lisi_graph_=False,
            trajectory_=False
        )

    ###### Calculate iLISI, cLISI ######
    #print("| LISI",end="\t")
    integrated_df=adata_int.to_df()
    celltypes_df=pd.DataFrame(adata_int.obs[data_keys[dataset]["label_key"]].loc[integrated_df.index])
    batches_df=pd.DataFrame(adata_int.obs[data_keys[dataset]["batch_key"]].loc[integrated_df.index])
    %R -i integrated_df,celltypes_df,batches_df
    %R cLISI=lisi::compute_lisi(integrated_df, data.frame(celltypes_df), colnames(celltypes_df))
    %R iLISI=lisi::compute_lisi(integrated_df, data.frame(batches_df), colnames(batches_df))
    %R -o cLISI,iLISI


    #scale ilISI score
    nbatches = adata_pre.obs[data_keys[dataset]["batch_key"]].nunique()
    scaled_ilisi = (np.nanmean(iLISI) - 1) / (nbatches - 1)
    metrics[0]["iLISI"]=scaled_ilisi


    #scale clISI score
    nlabs = adata_pre.obs[data_keys[dataset]["label_key"]].nunique()
    scaled_clisi = (nlabs - np.nanmean(cLISI)) / (nlabs - 1)
    metrics[0]["cLISI"]=scaled_clisi
    ####################################


    ##############Over correction
    #print("| OC",end="\t")
    scanpy.pp.neighbors(adata_int)
    scanpy.tl.umap(adata_int, min_dist=0.1)
    metrics.loc["1 - Over correction"] = 1. - overcorrection_score(adata_int.obsm["X_umap"], adata_int.obs[data_keys[dataset]["label_key"]])
    ####################################

    #print("| Save..",end="\n")
    metrics.columns=[pair]
    metrics.rename( index=metric_alias, inplace=True)
    return metrics

In [11]:
# Evaluation settings. scib_metrics reads n_hvgs/organism/assay as notebook
# globals; `version` selects the evaluation loop below.
n_hvgs, organism, assay, version = (
    2000,        # HVGs requested from scib preprocessing
    "human",     # organism: human | mouse
    "atac",      # assay: expression | atac | simulation
    "pairwise",  # version: pairwise | MBI
)

In [12]:
# Score every available integration output and save per-run and combined metric tables.
# NOTE: this silences *all* warnings for the rest of the session.
warnings.filterwarnings('ignore')
o_folder = f"../output/{dataset}/{version}/"
# Create the metrics folder (and any parents). A bare `except: pass` around
# os.mkdir used to hide real failures until to_csv crashed later.
os.makedirs(f"{o_folder}/metrics/", exist_ok=True)

if version == "pairwise":
    # method -> list of one-column metric frames, concatenated once at the end
    combined_metrics = {}
    for pair in tqdm(batch_pairs):
        # NOTE: only the first entry of `methods` is evaluated.
        for method in methods[:1]:
            stem = f"{method}_[{pair[0]}]_[{pair[1]}]"
            m_path = f"{o_folder}/{stem}.h5ad"
            csv_path = f"{o_folder}/metrics/{stem}.csv"
            if not os.path.isfile(m_path):
                continue

            if os.path.isfile(csv_path):
                # Reuse the cached result so the combined table also covers pairs
                # scored in earlier runs. The column label is the (batch_a, batch_b)
                # tuple, which pandas writes as a two-row header.
                metrics = pd.read_csv(csv_path, header=[0, 1], index_col=0)
            else:
                metrics = scib_metrics(m_path, unintegrated_data)
                if metrics is None:  # empty integration output
                    continue
                metrics.to_csv(csv_path)

            combined_metrics.setdefault(method, []).append(metrics)

    for method, frames in combined_metrics.items():
        method_metrics = pd.concat(frames, axis=1)
        print(method, method_metrics.shape)
        method_metrics.to_csv(f"{o_folder}/{method}_metrics.csv")

elif version == "MBI":
    for method in tqdm(methods[:1]):
        m_path = f"{o_folder}/{method}.h5ad"
        csv_path = f"{o_folder}/metrics/{method}.csv"
        if not os.path.isfile(m_path) or os.path.isfile(csv_path):
            continue

        # scib_metrics names its result column after the notebook-global `pair`,
        # which is otherwise undefined (or stale from a pairwise run) here.
        pair = method
        metrics = scib_metrics(m_path, unintegrated_data)
        if metrics is not None:  # None signals an empty integration output
            metrics.to_csv(csv_path)

else:
    raise ValueError(f"Unknown version {version!r}; expected 'pairwise' or 'MBI'.")

  0%|                                                                                                                                                                                  | 0/3 [00:00<?, ?it/s]2025-02-13 22:03:22.641680: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2025-02-13 22:03:22.641761: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2025-02-13 22:03:23.929010: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2025-02-13 22:03:23.929145: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnv

Only considering the two last: ['.]', '.h5ad'].
Only considering the two last: ['.]', '.h5ad'].


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [25:37<00:00, 512.47s/it]

SCITUNA (15, 3)



