In [1]:
import scib
import anndata as ad
import pandas as pd
import numpy as np
import os
from multiprocessing import Pool 
from scipy.io import mmread
from scipy.sparse import csr_matrix
import muon
import scarches as sca
import scanpy as sc
from scib_metrics.benchmark import Benchmarker
import scib_metrics
from typing import Optional

import warnings
warnings.filterwarnings("ignore")

  warn(
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)
 captum (see https://github.com/pytorch/captum).
  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
# Names of the dataset folders to benchmark; a later cell iterates over this
# list and calls count_metrics() once per entry.
dt_list = ['10_GSE201402_down']

In [3]:
def _read_modality(path, feature_col):
    """Read one modality folder (barcodes.tsv, matrix.mtx, features.tsv) into an AnnData.

    `path` is used as a plain string prefix, so it must end with a path
    separator. `feature_col` is the name given to the single feature column
    and therefore to the resulting ``var`` index.
    """
    barcodes = pd.read_csv(path + 'barcodes.tsv', sep='\t', header=None, index_col=None)
    barcodes.columns = ['cell_ids']
    # regex=False: the dot is meant literally. Interpreted as a regular
    # expression it would match (and replace) every character of the barcode,
    # and the default of `regex` differs between pandas versions.
    barcodes['cell_ids'] = barcodes['cell_ids'].str.replace('.', '-', regex=False)

    # Transpose so that matrix rows line up with the barcodes used as obs index.
    X = csr_matrix(mmread(path + 'matrix.mtx').T)

    # NOTE(review): this requires features.tsv to contain exactly one column;
    # a multi-column file makes the column assignment below raise.
    features = pd.read_csv(path + 'features.tsv', sep='\t', header=None, index_col=None)
    features.columns = [feature_col]

    return ad.AnnData(
        X,
        obs=pd.DataFrame(index=barcodes['cell_ids']),
        var=pd.DataFrame(index=features[feature_col]),
    )


def read_RNA_ATAC(RNA_path,ATAC_path):
    """Load paired RNA and ATAC matrices from two folders.

    Each folder must contain ``barcodes.tsv``, ``matrix.mtx`` and
    ``features.tsv``. File names are appended directly to the given paths, so
    both paths must end with a path separator (e.g. ``".../RNA/"``).

    Parameters
    ----------
    RNA_path : str
        Folder prefix holding the gene-expression files.
    ATAC_path : str
        Folder prefix holding the peak files.

    Returns
    -------
    (adata_RNA, adata_ATAC)
        Two AnnData objects with sparse CSR ``X``, barcodes (with ``.``
        replaced by ``-``) as ``obs`` index and feature names as ``var``
        index. Gene names are made unique; peak names are left as read.
    """
    # gene expression
    adata_RNA = _read_modality(RNA_path, 'gene_ids')
    adata_RNA.var_names_make_unique()
    # peak information
    adata_ATAC = _read_modality(ATAC_path, 'peak_ids')
    return adata_RNA, adata_ATAC

In [8]:
def _score_embedding(adata, embed, label):
    """Return [ARI, isolated-label ASW, NMI, silhouette, label] for one method.

    Uses the neighbour graph currently stored on `adata` for clustering and
    ``adata.obsm[embed]`` for the two embedding-based scores.
    """
    # Searches clustering resolutions and writes the best one to obs["cluster"].
    scib.metrics.cluster_optimal_resolution(adata, cluster_key="cluster", label_key="cell_type")
    ari = scib.metrics.ari(adata, cluster_key="cluster", label_key="cell_type")
    iso_asw = scib.metrics.isolated_labels_asw(adata, label_key="cell_type", batch_key='batch', embed=embed, verbose=False)
    nmi = scib.metrics.nmi(adata, cluster_key="cluster", label_key="cell_type")
    sht = scib.metrics.silhouette(adata, label_key="cell_type", embed=embed, metric='euclidean', scale=True)
    return [ari, iso_asw, nmi, sht, label]


def count_metrics(dataset, data_root='/mnt/', latent_root='/mnt/appealing/', method_list=('scMVP', 'Schema')):
    """Compute clustering metrics for every integration result of one dataset.

    Parameters
    ----------
    dataset : str
        Dataset folder name. Raw data and ``metadata.csv`` are read from
        ``data_root/dataset``; method outputs from ``latent_root/dataset``.
    data_root : str
        Root folder of the raw datasets.
    latent_root : str
        Root folder of the per-method outputs (``<method>.csv`` latent
        matrices, ``Seurat_connectivities.mtx``, ``Seurat_distance.mtx``).
    method_list : iterable of str
        Embedding-based methods to score; methods without a latent file are
        skipped.

    Returns
    -------
    pandas.DataFrame
        One row per scored method with columns ARI, ISO_ASW, NMI, Silhouette,
        method and Dataset. The same table is written to
        ``data_root/dataset/metrics_result1.csv``.

    Raises
    ------
    ValueError
        If ``metadata.csv`` does not have one row per RNA barcode.
    """
    data_path = os.path.join(data_root, dataset)
    # Derive the output folder from `dataset` instead of a hard-coded name so
    # that other datasets are not scored against the wrong latent files.
    latent_dir = os.path.join(latent_root, dataset)
    adata_RNA, adata_ATAC = read_RNA_ATAC(data_path+"/RNA/",data_path+"/ATAC/")

    # NOTE(review): the preprocessed matrices are not used by the metrics
    # below; only the RNA barcode order is. Kept for the (currently disabled)
    # multiome organisation step.
    adata_RNA.layers["counts"] = adata_RNA.X.copy()
    sc.pp.normalize_total(adata_RNA)
    sc.pp.log1p(adata_RNA)
    sc.pp.highly_variable_genes(
        adata_RNA,
        flavor="seurat_v3",
        n_top_genes=4000,
        # seurat_v3 expects raw counts, but X is log-normalised at this point.
        layer="counts",
        subset=False
    )
    adata_RNA = adata_RNA[:, adata_RNA.var.highly_variable].copy()

    adata_ATAC.layers['counts'] = adata_ATAC.X.copy()
    sc.pp.normalize_total(adata_ATAC, target_sum=1e4)
    sc.pp.log1p(adata_ATAC)
    adata_ATAC.layers['log-norm'] = adata_ATAC.X.copy()
    sc.pp.highly_variable_genes(adata_ATAC, n_top_genes=30000)
    adata_ATAC = adata_ATAC[:, adata_ATAC.var.highly_variable].copy()

    # adata = sca.models.organize_multiome_anndatas(
    #     adatas = [[adata_RNA], [adata_ATAC]],    # a list of anndata objects per modality, RNA-seq always goes first
    #     layers = [['counts'], ['log-norm']], # if need to use data from .layers, if None use .X
    # )
    adata = sc.AnnData(obs=pd.DataFrame(index=adata_RNA.obs_names))

    metadata = pd.read_csv(os.path.join(data_path, "metadata.csv"))
    if len(metadata) != adata.n_obs:
        raise ValueError(
            "metadata.csv has %d rows but the RNA matrix has %d cells"
            % (len(metadata), adata.n_obs)
        )
    # Build a new Series carrying the barcode index. Assigning `.index` on
    # `metadata['celltype']` only touches a temporary object and can be lost,
    # which would leave every label NaN after index alignment.
    # NOTE(review): assumes metadata rows are in the same order as the RNA
    # barcodes — confirm.
    cell_types = pd.Series(metadata['celltype'].to_numpy(), index=adata.obs_names, dtype='category')
    if cell_types.isna().any():
        # Missing labels become an explicit 'NaN' category.
        cell_types = cell_types.cat.add_categories(['NaN']).fillna('NaN')
    adata.obs['cell_type'] = cell_types
    adata.obs['batch'] = ['batch1'] * adata.n_obs

    metrics_list = []
    for method in method_list:
        latent_file = os.path.join(latent_dir, method + '.csv')
        if not os.path.exists(latent_file):
            print(f"skipping {method}: {latent_file} not found")
            continue
        latent = pd.read_csv(latent_file, header=None)
        latent.index = adata.obs_names
        adata.obsm[method] = latent
        sc.pp.neighbors(adata, use_rep=method)
        sc.tl.umap(adata)
        sc.tl.leiden(adata, key_added="cluster")
        metrics_list.append(_score_embedding(adata, embed=method, label=method))

    # Seurat provides a neighbour graph rather than a latent matrix: install it
    # as the active graph, using the key names declared in uns['neighbors'].
    connectivities = csr_matrix(mmread(os.path.join(latent_dir, 'Seurat_connectivities.mtx')))
    distances = csr_matrix(mmread(os.path.join(latent_dir, 'Seurat_distance.mtx')))
    adata.uns['neighbors'] = {'connectivities_key': 'connectivities', 'distances_key': 'distances',
                              'params': {'n_neighbors': 20, 'method': 'umap', 'random_state': 0,
                                         'metric': 'euclidean'}}
    adata.uns['neighbors']['distances'] = distances
    adata.uns['neighbors']['connectivities'] = connectivities
    adata.obsp['distances'] = distances
    adata.obsp['connectivities'] = connectivities
    sc.tl.umap(adata, n_components=20)
    # Score Seurat on the UMAP computed from its own graph, not on whichever
    # method the loop above processed last.
    # NOTE(review): using the 20-component UMAP as Seurat's embedding is
    # inferred from the n_components=20 call above — confirm.
    metrics_list.append(_score_embedding(adata, embed='X_umap', label='Seurat'))

    result = pd.DataFrame(metrics_list, columns=['ARI', 'ISO_ASW', 'NMI', 'Silhouette', 'method'])
    result['Dataset'] = dataset
    result.to_csv(data_path + "/metrics_result1.csv",index = False)
    print(dataset)
    return result

In [9]:
# Score every configured dataset in turn; count_metrics() writes one
# metrics_result1.csv per dataset and prints the dataset name when done.
for dataset in dt_list:
    count_metrics(dataset)

resolution: 0.1, nmi: 0.8650851990842399
resolution: 0.2, nmi: 0.886368427310364
resolution: 0.3, nmi: 0.7616788267552257
resolution: 0.4, nmi: 0.7433073581995274
resolution: 0.5, nmi: 0.7401207335007871
resolution: 0.6, nmi: 0.671553022305028
resolution: 0.7, nmi: 0.6457632772713251
resolution: 0.8, nmi: 0.6428685439184371
resolution: 0.9, nmi: 0.6109411856141576
resolution: 1.0, nmi: 0.5835742867058128
resolution: 1.1, nmi: 0.56721866735733
resolution: 1.2, nmi: 0.5579618390800895
resolution: 1.3, nmi: 0.5472011904365633
resolution: 1.4, nmi: 0.5597468616185625
resolution: 1.5, nmi: 0.5403220387831961
resolution: 1.6, nmi: 0.5344644824059293
resolution: 1.7, nmi: 0.5244313255146976
resolution: 1.8, nmi: 0.5263683387789307
resolution: 1.9, nmi: 0.5179633982544485
resolution: 2.0, nmi: 0.5229305114499204
optimised clustering against cell_type
optimal cluster resolution: 0.2
optimal score: 0.886368427310364
resolution: 0.1, nmi: 0.6983722434067615
resolution: 0.2, nmi: 0.686852294832766