In [1]:
import scib
import anndata as ad
import pandas as pd
import numpy as np
import os
from multiprocessing import Pool 
from scipy.io import mmread
from scipy.sparse import csr_matrix
import muon
import scarches as sca
import scanpy as sc
from scib_metrics.benchmark import Benchmarker
import scib_metrics
from typing import Optional

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm
  warn(
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)
 captum (see https://github.com/pytorch/captum).
  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
# Work from the project root so the relative data/, run_res/ and bench_res/ paths
# used below resolve. Guarded so re-running this cell does not climb further up.
if not os.path.isdir("data"):
    os.chdir("../")
if not os.path.isdir("data"):
    raise FileNotFoundError(f"no data/ directory under {os.getcwd()}; check the notebook location")
os.getcwd()

'/home/CMML_mini2_final'

In [3]:
# Datasets to benchmark. Each entry is a path relative to data/ and to the
# run_res/vertical/ and bench_res/ result trees.
dt_list = [
    'GSE156478/Control',
    'GSE156478/Stim',
    'brain_ISSAAC_seq',
    'brain_SNARE',
]

In [4]:
def _read_tsv_column(file_path):
    """Read a header-less, single-column TSV file as a Series of strings.

    ``dtype=str`` and ``keep_default_na=False`` keep identifiers verbatim: there is
    no numeric coercion, and "NA"/"nan" are not turned into NaN.
    Raises ValueError if the file has more than one column.
    """
    table = pd.read_csv(file_path, sep='\t', header=None, index_col=None,
                        dtype=str, keep_default_na=False)
    if table.shape[1] != 1:
        raise ValueError(f"{file_path}: expected 1 column, found {table.shape[1]}")
    return table[0]


def _read_feature_matrix(path, feature_name):
    """Load ``barcodes.tsv``, ``features.tsv`` and ``matrix.mtx`` from ``path``.

    ``matrix.mtx`` is stored features x cells and is transposed into a
    cells x features CSR matrix. Every '.' in a barcode is replaced by '-'
    (literal replacement, not a regex). The obs index is named ``cell_ids``
    and the var index is named ``feature_name``.
    """
    cell_ids = _read_tsv_column(path + 'barcodes.tsv').str.replace('.', '-', regex=False)
    feature_ids = _read_tsv_column(path + 'features.tsv')
    X = csr_matrix(mmread(path + 'matrix.mtx').T)
    return ad.AnnData(
        X,
        obs=pd.DataFrame(index=pd.Index(cell_ids, name='cell_ids')),
        var=pd.DataFrame(index=pd.Index(feature_ids, name=feature_name)),
    )


def read_RNA_ATAC(RNA_path, ATAC_path):
    """Load paired RNA and ATAC count matrices.

    Parameters
    ----------
    RNA_path, ATAC_path : str
        Directory prefixes (ending in '/') that each contain ``barcodes.tsv``,
        ``features.tsv`` (one ID per line) and ``matrix.mtx`` (features x cells).

    Returns
    -------
    (adata_RNA, adata_ATAC)
        Cells x features AnnData objects. Gene names in ``adata_RNA`` are made
        unique; peak names are left unchanged.
    """
    adata_RNA = _read_feature_matrix(RNA_path, 'gene_ids')
    adata_RNA.var_names_make_unique()
    adata_ATAC = _read_feature_matrix(ATAC_path, 'peak_ids')
    return adata_RNA, adata_ATAC

In [12]:
def _build_label_adata(data_path, obs_names):
    """Return an observation-only AnnData that carries the ground-truth labels.

    Parameters
    ----------
    data_path : str
        Dataset directory containing ``meta_data.csv`` with a ``celltype`` column.
    obs_names : pandas.Index
        Cell barcodes. Rows of ``meta_data.csv`` are matched to them by position.

    Returns
    -------
    AnnData with ``obs['cell_type']`` (categorical; missing labels become the
    explicit category ``'NaN'``) and a constant ``obs['batch']`` column, which is
    passed as ``batch_key`` to ``isolated_labels_asw``.

    Raises
    ------
    ValueError if ``meta_data.csv`` does not have one row per barcode.
    """
    adata = sc.AnnData(obs=pd.DataFrame(index=obs_names))
    metadata = pd.read_csv(data_path + "/meta_data.csv")
    if len(metadata) != adata.n_obs:
        raise ValueError(
            f"{data_path}/meta_data.csv has {len(metadata)} rows, expected {adata.n_obs} cells"
        )
    # Attach the labels positionally to the barcodes. Assigning the raw column
    # would align its RangeIndex against the barcode index and could yield all-NaN labels.
    cell_type = pd.Series(
        pd.Categorical(metadata['celltype'].to_numpy()), index=adata.obs_names
    )
    if cell_type.isna().any():
        cell_type = cell_type.cat.add_categories(['NaN']).fillna('NaN')
    adata.obs['cell_type'] = cell_type
    adata.obs['batch'] = ['batch1'] * adata.n_obs
    return adata


def _attach_precomputed_graph(adata, connectivities, distances):
    """Install an externally computed kNN graph as scanpy's default neighbours.

    Both matrices are stored in ``adata.obsp`` under exactly the keys that
    ``uns['neighbors']`` points to. This replaces any graph built earlier,
    e.g. by ``sc.pp.neighbors`` on the MIRA embedding.
    """
    adata.obsp['connectivities'] = csr_matrix(connectivities)
    adata.obsp['distances'] = csr_matrix(distances)
    # NOTE(review): n_neighbors=20 is assumed to match how the exported graphs
    # were built — confirm against the Seurat/PCA export scripts.
    adata.uns['neighbors'] = {
        'connectivities_key': 'connectivities',
        'distances_key': 'distances',
        'params': {'n_neighbors': 20, 'method': 'umap', 'random_state': 0,
                   'metric': 'euclidean'},
    }


def _score_embedding(adata, embed, method):
    """Cluster on the current neighbour graph and score it against ``cell_type``.

    ``scib.metrics.cluster_optimal_resolution`` re-clusters over a grid of
    resolutions and writes the best clustering (by NMI, per the printed log) to
    ``obs['cluster']``. ARI and NMI compare that clustering with the labels;
    the isolated-label ASW and the silhouette use ``adata.obsm[embed]``.

    Returns ``[ari, iso_asw, nmi, sht, method]``, one row of the results table.
    """
    scib.metrics.cluster_optimal_resolution(adata, cluster_key="cluster", label_key="cell_type")
    ari = scib.metrics.ari(adata, cluster_key="cluster", label_key="cell_type")
    iso_asw = scib.metrics.isolated_labels_asw(
        adata, label_key="cell_type", batch_key='batch', embed=embed, verbose=False
    )
    nmi = scib.metrics.nmi(adata, cluster_key="cluster", label_key="cell_type")
    sht = scib.metrics.silhouette(
        adata, label_key="cell_type", embed=embed, metric='euclidean', scale=True
    )
    return [ari, iso_asw, nmi, sht, method]


def count_metrics(dataset):
    """Score the MIRA, Seurat and PCA results for one dataset and save the table.

    Reads the cell barcodes (via ``read_RNA_ATAC``) and ``meta_data.csv`` from
    ``data/<dataset>``. Method outputs are read from ``run_res/vertical/<dataset>/``:
    ``MIRA.csv`` (a latent embedding with one row per cell) and
    ``<method>_connectivities.mtx`` / ``<method>_distance.mtx`` (kNN graphs).
    Writes ``bench_res/<dataset>/metrics_result.csv``, creating the directory if
    needed, then prints the dataset name and returns the metrics DataFrame.
    """
    data_path = 'data/' + dataset
    res_path = 'run_res/vertical/' + dataset + "/"

    # Only the RNA barcodes are needed: every score below comes from the method
    # outputs, not from the expression or peak matrices.
    adata_RNA, _ = read_RNA_ATAC(data_path + "/RNA/", data_path + "/ATAC/")
    adata = _build_label_adata(data_path, adata_RNA.obs_names)
    del adata_RNA

    metrics_list = []

    # MIRA exports a latent embedding: build the kNN graph from it and score it directly.
    method = "MIRA"
    latent = pd.read_csv(res_path + "MIRA.csv", header=None)
    adata.obsm[method] = latent.to_numpy()
    sc.pp.neighbors(adata, use_rep=method)
    metrics_list.append(_score_embedding(adata, embed=method, method=method))

    # Seurat and PCA export only a kNN graph, so the ASW-based metrics are
    # computed on a 20-dimensional UMAP of that graph.
    for method in ("Seurat", "PCA"):
        con = mmread(res_path + method + '_connectivities.mtx')
        dis = mmread(res_path + method + '_distance.mtx')
        _attach_precomputed_graph(adata, con, dis)
        sc.tl.umap(adata, n_components=20)
        metrics_list.append(_score_embedding(adata, embed="X_umap", method=method))

    df = pd.DataFrame(metrics_list, columns=['ari', 'iso_asw', 'nmi', 'sht', 'method'])
    df['Dataset'] = dataset

    bench_path = "bench_res/" + dataset
    os.makedirs(bench_path, exist_ok=True)
    df.to_csv(bench_path + "/metrics_result.csv", index=False)
    print(dataset)
    return df

In [7]:
# Benchmark each dataset in turn. count_metrics writes
# bench_res/<dataset>/metrics_result.csv and prints the dataset name when done.
for dataset in dt_list:
    count_metrics(dataset)

resolution: 0.1, nmi: 0.5521641819000227
resolution: 0.2, nmi: 0.5921957184716324
resolution: 0.3, nmi: 0.5599437717752949
resolution: 0.4, nmi: 0.5878816907643294
resolution: 0.5, nmi: 0.564556028965464
resolution: 0.6, nmi: 0.5624852861992561
resolution: 0.7, nmi: 0.5623885102626962
resolution: 0.8, nmi: 0.5315809652451378
resolution: 0.9, nmi: 0.551577738726008
resolution: 1.0, nmi: 0.5317240198445923
resolution: 1.1, nmi: 0.5228314619396398
resolution: 1.2, nmi: 0.5239421587895879
resolution: 1.3, nmi: 0.5112765935429152
resolution: 1.4, nmi: 0.5198080591436551
resolution: 1.5, nmi: 0.4964232216363137
resolution: 1.6, nmi: 0.49412256296263074
resolution: 1.7, nmi: 0.49367816328663416
resolution: 1.8, nmi: 0.48608873365299143
resolution: 1.9, nmi: 0.47562166553743224
resolution: 2.0, nmi: 0.4827102642396751
optimised clustering against cell_type
optimal cluster resolution: 0.2
optimal score: 0.5921957184716324
resolution: 0.1, nmi: 0.5735345541657895
resolution: 0.2, nmi: 0.59029283