env: cell_rank

# import

In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
# import torch
# from torch_geometric.data import Data
# from torch_geometric.data import DataLoader

import pickle
import sys
import requests

from types import MethodType
import importlib
# from scperturb import *
import anndata as ad

import scvelo as scv

Global seed set to 0


In [2]:
import celloracle as co
co.__version__

Matplotlib is building the font cache; this may take a moment.


'0.18.0'

In [3]:
def _tfdict_from_tf_info_matrix(tf_info):
    """Collapse a peak-level TF info matrix into {target gene: array of TF names}.

    `tf_info` has a 'peak_id' column, a 'gene_short_name' column and one numeric
    column per TF. Rows are summed per gene, and every TF column with a positive
    total is kept for that gene. The input frame is not modified.
    """
    per_gene = tf_info.drop(["peak_id"], axis=1).groupby(by="gene_short_name").sum()
    tf_names = per_gene.columns.to_numpy()
    # Boolean (gene x TF) mask; boolean indexing keeps the positive TFs per row.
    has_hit = per_gene.to_numpy() > 0
    return {gene: tf_names[mask] for gene, mask in zip(per_gene.index, has_hit)}


def import_TF_data(TF_info_matrix=None, TF_info_matrix_path=None, TFdict=None):
    """
    Load data about potential-regulatory TFs.
    You can import either TF_info_matrix or TFdict.
    For more information on how to make these files, please see the motif analysis module within the celloracle tutorial.

    Args:
        TF_info_matrix (pandas.DataFrame): TF_info_matrix (one row per peak, with
            'peak_id', 'gene_short_name' and one column per TF).

        TF_info_matrix_path (str): File path for TF_info_matrix (pandas.DataFrame),
            stored as parquet.

        TFdict (dictionary): Python dictionary of TF info.

    Returns:
        dict: {target gene: array of TF names}. When several sources are given,
        TF_info_matrix_path takes precedence over TF_info_matrix, which takes
        precedence over TFdict. Returns None if nothing is supplied.
    """
    if TF_info_matrix is not None:
        TFdict = _tfdict_from_tf_info_matrix(TF_info_matrix)

    if TF_info_matrix_path is not None:
        TFdict = _tfdict_from_tf_info_matrix(pd.read_parquet(TF_info_matrix_path))

    return TFdict

# Initial configuration: perturbations, starting cell type and direction per dataset

In [10]:
# Perturbations to simulate per dataset. A plain string is a single-gene
# perturbation; a nested list is one combined perturbation of all its genes.
dataset_pert_dict = dict(
    # CAR_T=['PDCD1'],  # PDCD1 does not seem to be in total_tf_list
    blood=['GATA1', 'SPI1'],
    OSKM=[['SOX2', 'POU5F1', 'KLF4', 'MYC']],
    ADM=['PTF1A'],
)

# Cell type whose fitted GRN (and mean expression, for 'up') drives the simulation.
dataset_celltype_dict = dict(
    CAR_T='Tex',
    blood='LMPP',
    OSKM='Fibroblast-like',
    ADM='Acinar',
)

# Perturbation direction: 'up' raises the gene, anything else knocks it out (sets it to 0).
dataset_dire_dict = dict(CAR_T='down', blood='down', OSKM='up', ADM='down')

# Datasets to process, in the order they appear in dataset_pert_dict.
datasets = [*dataset_pert_dict]

# Run CellOracle on all datasets

In [11]:
# -- Settings shared by all datasets
celltype_col = 'celltype_v2'  # obs column with the cell-type labels used as GRN clusters
DATA_ROOT = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/real_case/data'
RESULT_ROOT = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/real_case/result'
n_cells_downsample = 10000  # GRN-training data is randomly subsampled to this many cells if larger
threshold_number = 10000    # number of top links kept per cluster by links.filter_links

# The base GRN and the TF -> candidate-target mapping are identical for every
# dataset, so they are loaded/derived once instead of inside the loop.
base_GRN = co.data.load_human_promoter_base_GRN()  # prebuilt human promoter base GRN
print('base_GRN.shape: ', base_GRN.shape)

TFdict = import_TF_data(TF_info_matrix=base_GRN)  # target gene -> candidate regulatory TFs
tf_target_dict = {}                                # TF -> list of candidate target genes
for target, gene_set in TFdict.items():
    for tf in gene_set:
        tf_target_dict.setdefault(tf, []).append(target)
total_tf_list = list(tf_target_dict.keys())


def _to_dense(x):
    """Return `x` as a numpy array; sparse matrices and AnnData views expose `.toarray()`."""
    return x.toarray() if hasattr(x, 'toarray') else np.asarray(x)


def _prepare_oracle_input(adata_log):
    """Copy `adata_log` and add what `Oracle.import_anndata_as_raw_count` needs.

    Assumes `adata_log.X` is natural-log log1p-normalized, so expm1(X) recovers
    count-scale values -- TODO confirm upstream. The returned copy has
    `.raw` (the log values), `layers['raw_count']`, a UMAP in `obsm['X_umap']`,
    and X set back to the un-scaled count matrix.
    """
    adata_prep = adata_log.copy()
    adata_prep.raw = adata_prep  # keep the log values before scaling
    adata_prep.layers['raw_count'] = np.expm1(_to_dense(adata_prep.raw.X))
    # scale -> PCA -> kNN graph -> UMAP, only to obtain the embedding for Oracle
    sc.pp.scale(adata_prep)
    sc.tl.pca(adata_prep, svd_solver='arpack', random_state=2022)
    sc.pp.neighbors(adata_prep, n_neighbors=4, n_pcs=20, random_state=2022)
    sc.tl.umap(adata_prep, random_state=2022)
    # Oracle takes the un-scaled counts as input.
    adata_prep.X = adata_prep.layers['raw_count'].copy()
    return adata_prep


def _load_or_fit_oracle(adata_log, base_grn, out_dir):
    """Return (oracle, links), loading the cached HDF5 files in `out_dir` when both exist.

    Otherwise an Oracle is built from `adata_log` (subsampled to
    `n_cells_downsample` cells), kNN-imputed, per-cluster GRNs are inferred with
    `get_links`, and both objects are saved to `out_dir`.
    """
    oracle_path = os.path.join(out_dir, 'ctrl.celloracle.oracle')
    links_path = os.path.join(out_dir, 'ctrl.celloracle.links')
    # Both files are required: an interrupted run can leave only the oracle behind.
    if os.path.exists(oracle_path) and os.path.exists(links_path):
        print('file exists')
        return co.load_hdf5(oracle_path), co.load_hdf5(file_path=links_path)

    adata_train = _prepare_oracle_input(adata_log)
    if adata_train.shape[0] > n_cells_downsample:
        sc.pp.subsample(adata_train, n_obs=n_cells_downsample, random_state=123)
    print(f"Cell number is :{adata_train.shape[0]}")
    print("Metadata columns :", list(adata_train.obs.columns))
    print("Dimensional reduction: ", list(adata_train.obsm.keys()))

    oracle = co.Oracle()
    oracle.import_anndata_as_raw_count(adata=adata_train,
                                       cluster_column_name=celltype_col,
                                       embedding_name="X_umap")
    oracle.import_TF_data(TF_info_matrix=base_grn)

    # kNN imputation: choose the number of PCs at the elbow of the cumulative
    # explained-variance curve (capped at 50), k = 2.5% of the cells.
    oracle.perform_PCA()
    cum_var = np.cumsum(oracle.pca.explained_variance_ratio_)
    elbow = np.flatnonzero(np.diff(np.diff(cum_var) > 0.002))
    n_comps = int(elbow[0]) if elbow.size else 50  # fall back to the cap if no elbow is found
    fig, ax = plt.subplots()
    ax.plot(cum_var[:100])
    ax.axvline(n_comps, c="k")
    ax.set(xlabel="number of PCs", ylabel="cumulative explained variance ratio")
    plt.show()
    print(n_comps)
    n_comps = min(n_comps, 50)

    n_cell = oracle.adata.shape[0]
    print(f"cell number is :{n_cell}")
    k = int(0.025 * n_cell)
    print(f"Auto-selected k is :{k}")
    oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k * 8,
                          b_maxl=k * 4, n_jobs=4)
    oracle.to_hdf5(oracle_path)

    # Infer one GRN per cell type in `celltype_col` (slow: ~30 minutes).
    links = oracle.get_links(cluster_name_for_GRN_unit=celltype_col, alpha=10,
                             verbose_level=10)
    links.to_hdf5(file_path=links_path)
    return oracle, links


def _tf_grn_dict(coef_mtx):
    """Map each regulator (row) of a coefficient matrix to {target: non-zero coef}.

    Rows whose coefficients are all zero are dropped.
    """
    return {
        tf: row[row != 0].to_dict()
        for tf, row in coef_mtx.iterrows()
        if (row != 0).any()
    }


def _build_ctrl_oracle(adata_log, base_grn, coef_matrix_per_cluster):
    """Oracle over ALL control cells that reuses the GRN coefficients fitted elsewhere.

    No kNN imputation is done: `imputed_count` is a copy of `normalized_count`,
    so the simulation starts from the observed (normalized) expression.
    """
    adata_ctrl = _prepare_oracle_input(adata_log)
    oracle_ctrl = co.Oracle()
    oracle_ctrl.import_anndata_as_raw_count(adata=adata_ctrl,
                                            cluster_column_name=celltype_col,
                                            embedding_name="X_umap")
    oracle_ctrl.import_TF_data(TF_info_matrix=base_grn)
    oracle_ctrl.adata.layers["imputed_count"] = oracle_ctrl.adata.layers["normalized_count"].copy()
    oracle_ctrl.coef_matrix_per_cluster = coef_matrix_per_cluster
    return oracle_ctrl


def _perturb_condition(adata_log, genes, tf_grn, direction, celltype):
    """Build the `perturb_condition` dict {gene: target value} for `simulate_shift`.

    'up': mean of the gene in `celltype` cells + 1; any other direction: 0.
    Genes that are missing, have zero mean in the control, or have no fitted
    targets in `tf_grn` are skipped with a message.
    """
    # Mask from adata_log itself so it always aligns with the cells being indexed.
    in_celltype = (adata_log.obs[celltype_col] == celltype).to_numpy()
    condition = {}
    for gene in genes:
        if gene not in adata_log.var_names:
            print(f'{gene} is not in adata.var_names')
            continue
        if np.mean(_to_dense(adata_log[:, gene].X)) == 0:
            print(f'{gene} ctrl expression is 0')
            continue
        if gene not in tf_grn:
            print(f'{gene} is not in the tf_GRN_dict, no targets')
            continue
        if direction == 'up':
            condition[gene] = np.mean(_to_dense(adata_log[in_celltype, gene].X)) + 1
        else:
            condition[gene] = 0
    return condition


for dataset in datasets:
    print('=' * 20, dataset, '=' * 20)
    adata_rna = sc.read(os.path.join(DATA_ROOT, dataset, 'adata_ctrl_v2.h5ad'))
    adata_rna.X = _to_dense(adata_rna.X)
    print('adata.shape is: ', adata_rna.shape)

    out_dir = os.path.join(RESULT_ROOT, dataset, 'CellOracle')
    os.makedirs(out_dir, exist_ok=True)
    oracle, links = _load_or_fit_oracle(adata_rna, base_GRN, out_dir)

    # Keep the most confident links, then refit the GRN used for simulation.
    links.filter_links(threshold_number=threshold_number,
                       p=0.001,
                       weight='coef_abs')
    oracle.get_cluster_specific_TFdict_from_Links(links_object=links)
    oracle.fit_GRN_for_simulation(alpha=10,
                                  use_cluster_specific_TFdict=True)

    # Regulators with targets in the starting cell type; a perturbed gene needs one.
    celltype = dataset_celltype_dict[dataset]
    tf_GRN_dict = _tf_grn_dict(oracle.coef_matrix_per_cluster[celltype])

    oracle_ctrl = _build_ctrl_oracle(adata_rna, base_GRN, oracle.coef_matrix_per_cluster)

    for pert in tqdm(dataset_pert_dict[dataset], desc=dataset):
        print('*' * 20, pert)
        pert_combo = [pert] if isinstance(pert, str) else list(pert)
        pert_prefix = '_'.join(pert_combo)

        goi_dict = _perturb_condition(adata_rna, pert_combo, tf_GRN_dict,
                                      dataset_dire_dict[dataset], celltype)
        if not goi_dict:
            print(f'{pert_prefix} is filtered')
            continue

        # Simulate signal propagation after the perturbation.
        oracle_ctrl.simulate_shift(perturb_condition=goi_dict,
                                   n_propagation=3)
        simulated_count = oracle_ctrl.adata.layers["simulated_count"]

        adata_pert = adata_rna.copy()
        # np.clip returns a new array, so the Oracle's own layer is not mutated.
        adata_pert.X = np.clip(simulated_count, 0, None)
        adata_pert.obs_names = [f'{name}_{pert_prefix}' for name in adata_pert.obs_names]
        adata_pert.obs['batch'] = 'pert'

        pert_dir = os.path.join(out_dir, pert_prefix)
        os.makedirs(pert_dir, exist_ok=True)
        adata_pert.write(os.path.join(pert_dir, 'adata_pert.h5ad'))


adata.shape is:  (7073, 1206)
Loading prebuilt promoter base-GRN. Version: hg19_gimmemotifsv5_fpr2
base_GRN.shape:  (37003, 1096)
file exists


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

******************** GATA1


... storing 'batch' as categorical
 50%|█████     | 1/2 [01:13<01:13, 73.83s/it]

******************** SPI1


... storing 'batch' as categorical
100%|██████████| 2/2 [01:42<00:00, 51.09s/it]


adata.shape is:  (7349, 3563)
Loading prebuilt promoter base-GRN. Version: hg19_gimmemotifsv5_fpr2
base_GRN.shape:  (37003, 1096)
Cell number is :7349
Metadata columns : ['nCount_RNA', 'nFeature_RNA', 'sample', 'percent.oskm', 'barcode_sample', 'cell type', 'celltype', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'n_counts_all', 'celltype_v2']
Dimensional reduction:  ['X_pca', 'X_umap']
28
cell number is :7349
Auto-selected k is :183


  0%|          | 0/3 [00:00<?, ?it/s]

Inferring GRN for Fibroblast-like...


  0%|          | 0/3087 [00:00<?, ?it/s]

Inferring GRN for Keratinocyte-like...


  0%|          | 0/3087 [00:00<?, ?it/s]

Inferring GRN for Stem-like...


  0%|          | 0/3087 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

******************** ['SOX2', 'POU5F1', 'KLF4', 'MYC']


... storing 'batch' as categorical
100%|██████████| 1/1 [00:55<00:00, 55.62s/it]


adata.shape is:  (968, 4880)
Loading prebuilt promoter base-GRN. Version: hg19_gimmemotifsv5_fpr2
base_GRN.shape:  (37003, 1096)
Cell number is :968
Metadata columns : ['CELL', 'Patient', 'Type', 'Cell_type', 'celltype0', 'celltype1', 'celltype2', 'celltype3', 'Patient2', 'nCount_RNA', 'nFeature_RNA', 'Project', 'orig.ident', 'Stage', 'Grade', 'Gender', 'Age', 'Percent_mito', 'Percent_ribo', 'Percent_hemo', 'percent.mt', 'predicted.id', 'prediction.score.Fibroblast.cell', 'prediction.score.Stellate.cell', 'prediction.score.Macrophage.cell', 'prediction.score.Endothelial.cell', 'prediction.score.T.cell', 'prediction.score.B.cell', 'prediction.score.Ductal.cell.type.2', 'prediction.score.Endocrine.cell', 'prediction.score.Ductal.cell.type.1', 'prediction.score.Acinar.cell', 'prediction.score.max', 'classical_score1', 'basal_score1', 'classical_score21', 'basal_score21', 'endocrine_score_1', 'immune_score_1', 'exocrine_score_1', 'activated_stroma_score_1', 'histone_score_1', 'normal_strom

  0%|          | 0/3 [00:00<?, ?it/s]

Inferring GRN for Acinar...


  0%|          | 0/4380 [00:00<?, ?it/s]

Inferring GRN for Ductal...


  0%|          | 0/4380 [00:00<?, ?it/s]

Inferring GRN for Stellate...


  0%|          | 0/4380 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

4880 genes were found in the adata. Note that Celloracle is intended to use around 1000-3000 genes, so the behavior with this number of genes may differ from what is expected.


  0%|          | 0/1 [00:00<?, ?it/s]

******************** PTF1A


... storing 'batch' as categorical
100%|██████████| 1/1 [00:08<00:00,  8.43s/it]


# Debug: compare the control Oracle's normalized_count with the input matrix

In [6]:
# Debug: the normalized matrix of the control Oracle; the main loop copies this
# layer into imputed_count, so it is the starting point of simulate_shift.
oracle_ctrl.adata.layers["normalized_count"]

array([[0.69125253, 0.        , 0.        , ..., 0.69125253, 0.69125253,
        0.        ],
       [1.26249027, 0.        , 0.        , ..., 0.8185041 , 1.26249027,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 1.23931146,
        1.04377818],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 3.74183989,
        0.        ],
       [0.59633952, 0.        , 0.        , ..., 0.        , 2.80298233,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 1.19598246,
        0.        ]])

In [9]:
# Debug: the control input matrix; in the saved run its values match the
# control Oracle's normalized_count shown in the previous cell.
adata_rna.X

array([[0.69125253, 0.        , 0.        , ..., 0.69125253, 0.69125253,
        0.        ],
       [1.26249027, 0.        , 0.        , ..., 0.8185041 , 1.26249027,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 1.23931146,
        1.04377818],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 3.74183989,
        0.        ],
       [0.59633952, 0.        , 0.        , ..., 0.        , 2.80298233,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 1.19598246,
        0.        ]])

In [8]:
# Debug: recover count-scale values from the log1p matrix (as the raw_count layer
# is built); np.expm1 is the numerically accurate inverse of log1p.
np.expm1(adata_rna.X)

array([[ 0.99621429,  0.        ,  0.        , ...,  0.99621429,
         0.99621429,  0.        ],
       [ 2.53421169,  0.        ,  0.        , ...,  1.26710592,
         2.53421169,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         2.45323494,  1.83992653],
       ...,
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
        41.17551695,  0.        ],
       [ 0.81546117,  0.        ,  0.        , ...,  0.        ,
        15.49376334,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         2.30680497,  0.        ]])