# import

In [3]:
import scanpy as sc
import pandas as pd
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader

import pickle
import sys
import requests

from types import MethodType
import importlib
from scperturb import *
import anndata as ad

In [4]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from tqdm import tqdm

# 在L1000数据上跑direct transfer

In [5]:
# - direc_change_dict: K562 direct change
import json
with open('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/utils_data/direct_change_dict.json', 'r') as f:
    direct_change_dict = json.load(f)
    
# - single_total_perts: the perts used in K562
gene_list = direct_change_dict['gene_list']
single_total_perts = list(direct_change_dict.keys())
single_total_perts.remove('gene_list')

In [6]:
len(direct_change_dict)

7992

In [7]:
# - get cell line name
common_cell_line = \
{   'A549': 'A549',
    'HEPG2': 'HepG2',
    'HT29': 'HT29',
    'MCF7': 'MCF7',
    # 'SKBR3': 'SK-BR-3',
    'SW480': 'SW480',
    'PC3': 'PC3',
    'A375': 'A375',
} # L1000 cell line : single-cell cell line

# - read adata_L1000, this is processed data
adata_L1000 = sc.read('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/GSE92742/adata_gene_pert.h5ad')
adata_L1000



AnnData object with n_obs × n_vars = 36720 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [22]:
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, auc
from scipy.spatial.distance import cdist

for cell_line_bulk in list(common_cell_line.keys())[:]:
    cell_line_single = common_cell_line[cell_line_bulk]
    print('='*20, f'cell line is {cell_line_single}')
    
    #===================prepare data
    if cell_line_bulk in ['PC3', 'A375']:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/SCP542'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))
        
        # - read adata_rna_raw
        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/SCP542/process/{cell_line_bulk}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata.h5ad'))
    
    else:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/CNP0003658'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))

        # - read adata_rna
        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/CNP0003658/process/RNA/{cell_line_single}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata_rna_{cell_line_single}.h5ad'))

    # - consctrut corr mtx
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()
    
    # - get var_names
    var_names = list(adata_rna.var_names)
    
    # - get common pert
    adata_L1000_sub = adata_L1000[adata_L1000.obs['cell_id']==cell_line_bulk]
    L1000_total_perts = np.unique(adata_L1000_sub.obs['pert_iname'])
    common_perts = np.intersect1d(single_total_perts, L1000_total_perts)
    
    # - get common var
    common_var = np.intersect1d(adata_rna.var_names, direct_change_dict['gene_list'])
    common_var_2 = np.intersect1d(common_var, adata_L1000.var_names)
    
    print('common_perts num: ', len(common_perts))
    print('common var of direct change and single-cell data is: ', len(common_var))
    print('common var to L1000 data is: ', len(common_var_2))
    
    # - get the most close genes of common genes
    matrix = adata_rna.X.T  # 示例数据
    # 假设我们限制的 index_list 是以下这些索引
    index_list = np.array([list(adata_rna.var_names).index(i) for i in common_var])  # 示例 index_list
    # 计算每个向量与其他向量的余弦距离
    distance_matrix = cdist(matrix, matrix, metric='cosine')
    # 将对角线（即每个向量到自己的距离）设置为无穷大，以排除自己
    np.fill_diagonal(distance_matrix, np.inf)
    # 创建一个布尔掩码，将不在 index_list 中的向量距离设置为无穷大
    mask = np.ones(distance_matrix.shape, dtype=bool)
    mask[:, index_list] = False  # 只保留 index_list 中的列
    distance_matrix[mask] = np.inf  # 其他列设置为无穷大
    # 找到每个向量最近的其他向量的索引
    nearest_indices = np.argmin(distance_matrix, axis=1)
    # 将结果转换为 list
    nearest_indices_list = nearest_indices.tolist()

    
    common_idx = [direct_change_dict['gene_list'].index(gene) if i in common_var else direct_change_dict['gene_list'].index(adata_rna.var_names[nearest_indices_list[i]]) for i, gene in enumerate(adata_rna.var_names)]
    
    adata_pert_list = []
    pert_gene_rank_dict = {}
    for pert in tqdm(common_perts):
        # - create adata_pert
        exp_change = np.array(direct_change_dict[pert])[common_idx]
        
        # - read adata_ctrl, drop all the obs for merging
        adata_ctrl = adata_rna.copy()
        # adata_ctrl.obs = adata_ctrl.obs.drop(adata_ctrl.obs.columns, axis=1)
        
        adata_pert = adata_ctrl.copy()
        adata_pert.X += exp_change
        adata_pert.X[adata_pert.X < 0] = 0
        adata_pert.obs_names = [i+f'_{pert}' for i in adata_pert.obs_names]
        
        adata_ctrl.obs['batch'] = 'ctrl'
        adata_pert.obs['batch'] = 'pert'
        
        adata_concat = ad.concat([adata_ctrl, adata_pert])
        
        # - cal de genes
        rankby_abs = False

        sc.tl.rank_genes_groups(
            adata_concat,
            groupby='batch',
            reference='ctrl',
            rankby_abs=rankby_abs,
            n_genes=len(adata_concat.var),
            use_raw=False,
            method = 'wilcoxon'
        )
        de_genes = pd.DataFrame(adata_concat.uns['rank_genes_groups']['names'])
        pvals = pd.DataFrame(adata_concat.uns['rank_genes_groups']['pvals'])
        pvals_adj = pd.DataFrame(adata_concat.uns['rank_genes_groups']['pvals_adj'])
        scores = pd.DataFrame(adata_concat.uns['rank_genes_groups']['scores'])
        logfoldchanges = pd.DataFrame(adata_concat.uns['rank_genes_groups']['logfoldchanges'])

        # - get gene_score
        gene_score = pd.DataFrame({'gene':list(de_genes['pert']),
                                   'z-score':list(scores['pert'])})
        
        pert_gene_rank_dict[pert] = (list(de_genes['pert']), list(scores['pert']))
        
        # break
        
    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/zero_shot/result'
    save_prefix = f'direct_transfer_K562/{cell_line_bulk}' # use result of K562 to do the direct transfer
    os.makedirs(os.path.join(save_dir, save_prefix), exist_ok=True)

    import json
    with open(os.path.join(save_dir, save_prefix, 'pert_gene_rank_dict.json'), 'w') as f:
        json.dump(pert_gene_rank_dict, f)
        
        
    # break





common_perts num:  1794
common var of direct change and single-cell data is:  2764
common var to L1000 data is:  712


  0%|          | 6/1794 [00:19<1:35:16,  3.20s/it]


KeyboardInterrupt: 

# debug - 并行化代码

In [8]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, auc
from scipy.spatial.distance import cdist
import concurrent.futures
import json

# 定义处理每个 cell_line_single 的函数
def process_cell_line(cell_line_bulk, cell_line_single, common_cell_line, adata_L1000, single_total_perts, direct_change_dict):
    print('=' * 20, f'cell line is {cell_line_single}')

    # 准备数据
    if cell_line_bulk in ['PC3', 'A375']:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/SCP542'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))

        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/SCP542/process/{cell_line_bulk}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata.h5ad'))

    else:
        save_dir_adata = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/benchmark_data/L1000/single_cell_data/CNP0003658'
        adata_rna = sc.read(os.path.join(save_dir_adata, cell_line_bulk, f'adata_{cell_line_bulk}.h5ad'))

        save_dir = f'/nfs/public/lichen/data/single_cell/cell_line/CNP0003658/process/RNA/{cell_line_single}'
        adata_rna_raw = sc.read(os.path.join(save_dir, f'adata_rna_{cell_line_single}.h5ad'))

    # 计算相关矩阵
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()
    corr_mtx = np.corrcoef(adata_rna.X.T)

    # 获取 common pert 和 common var
    adata_L1000_sub = adata_L1000[adata_L1000.obs['cell_id'] == cell_line_bulk]
    L1000_total_perts = np.unique(adata_L1000_sub.obs['pert_iname'])
    common_perts = np.intersect1d(single_total_perts, L1000_total_perts)
    common_var = np.intersect1d(adata_rna.var_names, direct_change_dict['gene_list'])
    common_var_2 = np.intersect1d(common_var, adata_L1000.var_names)

    print('common_perts num: ', len(common_perts))
    print('common var of direct change and single-cell data is: ', len(common_var))
    print('common var to L1000 data is: ', len(common_var_2))

    # 最近基因计算
    matrix = adata_rna.X.T
    index_list = np.array([list(adata_rna.var_names).index(i) for i in common_var])

    distance_matrix = cdist(matrix, matrix, metric='cosine')
    np.fill_diagonal(distance_matrix, np.inf)
    mask = np.ones(distance_matrix.shape, dtype=bool)
    mask[:, index_list] = False
    distance_matrix[mask] = np.inf
    nearest_indices = np.argmin(distance_matrix, axis=1)
    nearest_indices_list = nearest_indices.tolist()

    common_idx = [direct_change_dict['gene_list'].index(gene) if i in common_var else direct_change_dict['gene_list'].index(adata_rna.var_names[nearest_indices_list[i]]) for i, gene in enumerate(adata_rna.var_names)]

    pert_gene_rank_dict = {}
    for pert in tqdm(common_perts):
        exp_change = np.array(direct_change_dict[pert])[common_idx]

        adata_ctrl = adata_rna.copy()
        adata_pert = adata_ctrl.copy()
        adata_pert.X += exp_change
        adata_pert.X[adata_pert.X < 0] = 0
        adata_pert.obs_names = [i + f'_{pert}' for i in adata_pert.obs_names]

        adata_ctrl.obs['batch'] = 'ctrl'
        adata_pert.obs['batch'] = 'pert'

        adata_concat = sc.concat([adata_ctrl, adata_pert])

        sc.tl.rank_genes_groups(
            adata_concat,
            groupby='batch',
            reference='ctrl',
            rankby_abs=False,
            n_genes=len(adata_concat.var),
            use_raw=False,
            method='wilcoxon'
        )

        de_genes = pd.DataFrame(adata_concat.uns['rank_genes_groups']['names'])
        scores = pd.DataFrame(adata_concat.uns['rank_genes_groups']['scores'])

        pert_gene_rank_dict[pert] = (list(de_genes['pert']), list(scores['pert']))

    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/zero_shot/result'
    save_prefix = f'direct_transfer_K562/{cell_line_bulk}'
    os.makedirs(os.path.join(save_dir, save_prefix), exist_ok=True)

    with open(os.path.join(save_dir, save_prefix, 'pert_gene_rank_dict.json'), 'w') as f:
        json.dump(pert_gene_rank_dict, f)

# 主函数
if __name__ == "__main__":
    # common_cell_line = \
    # {   'A549': 'A549',
    #     'HEPG2': 'HepG2',
    #     'HT29': 'HT29',
    #     'MCF7': 'MCF7',
    #     # 'SKBR3': 'SK-BR-3',
    #     'SW480': 'SW480',
    #     'PC3': 'PC3',
    #     'A375': 'A375',
    # } # L1000 cell line : single-cell cell line


    # adata_L1000 = sc.read("/path/to/adata_L1000.h5ad")  # 示例
    # single_total_perts = []  # 定义 single_total_perts
    # direct_change_dict = {}  # 定义 direct_change_dict

    # 使用并行执行
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(process_cell_line, cell_line_bulk, common_cell_line[cell_line_bulk], common_cell_line, adata_L1000, single_total_perts, direct_change_dict)
            for cell_line_bulk in common_cell_line.keys()
        ]
        
        # 等待所有任务完成
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()  # 获取每个任务的结果，如果有异常，将在此处抛出
            except Exception as e:
                print(f"Error: {e}")






common_perts num:  2184
common var of direct change and single-cell data is:  2764
common var to L1000 data is:  712












common_perts num:  1980
common var of direct change and single-cell data is:  2983
common var to L1000 data is:  712
common_perts num:  2033
common var of direct change and single-cell data is:  2663
common var to L1000 data is:  708








common_perts num:  2215
common var of direct change and single-cell data is:  3184
common var to L1000 data is:  713


  0%|          | 0/2184 [00:00<?, ?it/s]

common_perts num:  161
common var of direct change and single-cell data is:  1547
common var to L1000 data is:  712


  0%|          | 3/2184 [00:03<36:25,  1.00s/it]

common_perts num:  2235
common var of direct change and single-cell data is:  3298
common var to L1000 data is:  707


  0%|          | 5/2184 [00:04<33:57,  1.07it/s]



  0%|          | 4/2235 [00:02<17:14,  2.16it/s]]

common_perts num:  2195
common var of direct change and single-cell data is:  3190
common var to L1000 data is:  711


100%|██████████| 161/161 [01:45<00:00,  1.52it/s]]
100%|██████████| 2235/2235 [10:39<00:00,  3.49it/s] 
100%|██████████| 2215/2215 [26:54<00:00,  1.37it/s]  
 90%|█████████ | 1791/1980 [28:15<02:34,  1.22it/s]
100%|██████████| 1980/1980 [30:47<00:00,  1.07it/s]
100%|██████████| 2195/2195 [33:57<00:00,  1.08it/s]
100%|██████████| 2184/2184 [34:53<00:00,  1.04it/s]


# debug

In [23]:
for pert in tqdm(common_perts):
    # - create adata_pert
    exp_change = np.array(direct_change_dict[pert])[common_idx]
    
    # - read adata_ctrl, drop all the obs for merging
    adata_ctrl = adata_rna.copy()
    # adata_ctrl.obs = adata_ctrl.obs.drop(adata_ctrl.obs.columns, axis=1)
    
    adata_pert = adata_ctrl.copy()
    adata_pert.X += exp_change
    adata_pert.X[adata_pert.X < 0] = 0
    adata_pert.obs_names = [i+f'_{pert}' for i in adata_pert.obs_names]
    
    adata_ctrl.obs['batch'] = 'ctrl'
    adata_pert.obs['batch'] = 'pert'
    
    adata_concat = ad.concat([adata_ctrl, adata_pert])
    
    # - cal de genes
    rankby_abs = False

    sc.tl.rank_genes_groups(
        adata_concat,
        groupby='batch',
        reference='ctrl',
        rankby_abs=rankby_abs,
        n_genes=len(adata_concat.var),
        use_raw=False,
        method = 'wilcoxon'
    )
    de_genes = pd.DataFrame(adata_concat.uns['rank_genes_groups']['names'])
    pvals = pd.DataFrame(adata_concat.uns['rank_genes_groups']['pvals'])
    pvals_adj = pd.DataFrame(adata_concat.uns['rank_genes_groups']['pvals_adj'])
    scores = pd.DataFrame(adata_concat.uns['rank_genes_groups']['scores'])
    logfoldchanges = pd.DataFrame(adata_concat.uns['rank_genes_groups']['logfoldchanges'])

    # - get gene_score
    gene_score = pd.DataFrame({'gene':list(de_genes['pert']),
                                'z-score':list(scores['pert'])})
    
    pert_gene_rank_dict[pert] = (list(de_genes['pert']), list(scores['pert']))

  0%|          | 8/1794 [00:15<56:09,  1.89s/it]  


KeyboardInterrupt: 

In [17]:
common_idx = [direct_change_dict['gene_list'].index(gene) if gene in common_var else direct_change_dict['gene_list'].index(adata_rna.var_names[nearest_indices_list[i]]) for i, gene in enumerate(adata_rna.var_names)]

ValueError: 'CALCB' is not in list

In [18]:
direct_change_dict['gene_list'].index('CALCB')

ValueError: 'CALCB' is not in list

In [20]:
'CALCB' in common_var
len(common_var)

2764

In [21]:
import numpy as np
from scipy.spatial.distance import cdist

# 假设我们有一个 500*100 的矩阵
matrix = adata_rna.X.T  # 示例数据

# 假设我们限制的 index_list 是以下这些索引
index_list = np.array([list(adata_rna.var_names).index(i) for i in common_var])  # 示例 index_list

# 计算每个向量与其他向量的余弦距离
distance_matrix = cdist(matrix, matrix, metric='cosine')

# 将对角线（即每个向量到自己的距离）设置为无穷大，以排除自己
np.fill_diagonal(distance_matrix, np.inf)

# 创建一个布尔掩码，将不在 index_list 中的向量距离设置为无穷大
mask = np.ones(distance_matrix.shape, dtype=bool)
mask[:, index_list] = False  # 只保留 index_list 中的列
distance_matrix[mask] = np.inf  # 其他列设置为无穷大

# 找到每个向量最近的其他向量的索引
nearest_indices = np.argmin(distance_matrix, axis=1)

# 将结果转换为 list
nearest_indices_list = nearest_indices.tolist()

print(nearest_indices_list)


[3054, 1001, 3625, 2137, 399, 3826, 2704, 1595, 1255, 4283, 4603, 691, 1975, 3298, 662, 3298, 2437, 2323, 3453, 2137, 2137, 2138, 2147, 4603, 1207, 917, 2772, 908, 4182, 212, 1255, 3584, 2839, 3632, 1204, 2097, 1195, 2315, 2050, 622, 883, 4407, 3518, 1439, 2410, 3001, 3140, 2872, 1330, 4324, 2315, 3981, 2392, 2007, 1439, 2915, 5138, 4482, 4430, 2315, 956, 653, 3270, 2296, 4221, 2315, 4430, 2206, 3680, 1039, 4108, 572, 2839, 4879, 2206, 3020, 524, 2406, 3918, 1001, 2130, 0, 3474, 402, 1303, 2406, 1277, 533, 1749, 4917, 480, 4118, 2161, 2107, 4694, 2997, 2713, 2091, 867, 2776, 2315, 2870, 4304, 2392, 1640, 4836, 4503, 970, 5047, 5058, 3835, 3373, 2717, 3033, 3478, 3521, 2371, 4894, 4430, 1322, 1743, 2050, 571, 3585, 2206, 1502, 524, 3986, 3124, 555, 2563, 4626, 4026, 667, 3236, 659, 2169, 4335, 3091, 597, 558, 3579, 4029, 2206, 3663, 1509, 2262, 880, 3804, 4869, 2713, 3734, 1662, 723, 4231, 4756, 529, 703, 832, 4836, 2598, 4836, 4546, 3752, 850, 4038, 2741, 1137, 2392, 1115, 2713, 1572, 

In [11]:
import numpy as np
from scipy.spatial.distance import cdist

# 假设我们有一个 500*100 的矩阵
matrix = adata_rna.X.T  # 示例数据

# 计算每个向量与其他向量的余弦距离
distance_matrix = cdist(matrix, matrix, metric='cosine')

# 将对角线（即每个向量到自己的距离）设置为无穷大，以排除自己
np.fill_diagonal(distance_matrix, np.inf)

# 找到每个向量最近的其他向量的索引
nearest_indices = np.argmin(distance_matrix, axis=1)

# 将结果转换为 list
nearest_indices_list = nearest_indices.tolist()

print(nearest_indices_list)


[81, 786, 3625, 2137, 2530, 3826, 180, 1595, 1396, 4283, 4085, 691, 1975, 2835, 662, 3298, 2437, 2323, 3453, 2137, 2137, 2138, 2147, 1747, 301, 917, 2772, 908, 4182, 212, 1255, 3584, 2241, 3632, 4886, 2097, 4873, 1702, 465, 622, 4666, 1190, 3518, 786, 331, 5048, 781, 2245, 2090, 3360, 100, 3981, 108, 242, 58, 2240, 185, 4482, 3250, 1702, 1180, 114, 1006, 935, 4221, 83, 3938, 3360, 2416, 52, 4592, 572, 4499, 853, 117, 114, 83, 740, 155, 1515, 5092, 0, 122, 2428, 363, 341, 1277, 533, 347, 3622, 190, 4118, 1515, 2107, 347, 1912, 3796, 2091, 867, 2192, 114, 461, 837, 4564, 4004, 4836, 124, 4280, 52, 4561, 3938, 156, 117, 3033, 100, 3521, 329, 1676, 369, 319, 364, 176, 82, 562, 3965, 234, 2363, 1711, 3124, 318, 3228, 511, 4026, 4238, 3236, 312, 4565, 4751, 3091, 3421, 364, 4947, 365, 347, 359, 615, 2545, 3210, 1070, 4869, 3223, 3128, 1097, 3016, 2222, 2228, 111, 703, 4571, 169, 2598, 105, 2228, 4889, 2635, 1945, 769, 2105, 4560, 2068, 150, 2256, 2966, 2394, 175, 2207, 121, 4275, 1999, 4889,

In [14]:
len(nearest_indices_list)

5155

In [13]:
cdist(matrix, matrix, metric='cosine')

array([[0.00000000e+00, 8.70379673e-01, 6.72934804e-01, ...,
        9.58046734e-01, 8.41414129e-01, 7.68164595e-01],
       [8.70379673e-01, 0.00000000e+00, 8.64498517e-01, ...,
        1.00000000e+00, 9.14451413e-01, 8.87232965e-01],
       [6.72934804e-01, 8.64498517e-01, 0.00000000e+00, ...,
        9.12601185e-01, 6.50920434e-01, 5.15319435e-01],
       ...,
       [9.58046734e-01, 1.00000000e+00, 9.12601185e-01, ...,
        0.00000000e+00, 1.00000000e+00, 8.98259869e-01],
       [8.41414129e-01, 9.14451413e-01, 6.50920434e-01, ...,
        1.00000000e+00, 0.00000000e+00, 7.85272243e-01],
       [7.68164595e-01, 8.87232965e-01, 5.15319435e-01, ...,
        8.98259869e-01, 7.85272243e-01, 1.11022302e-16]])