env: cell_rank

# import

In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
# import torch
# from torch_geometric.data import Data
# from torch_geometric.data import DataLoader

import pickle
import sys
import requests

from types import MethodType
import importlib
# from scperturb import *
import anndata as ad

import scvelo as scv

Global seed set to 0


# initial

In [6]:
dataset_pert_dict = {
    'CAR_T': ['PDCD1'],
    'blood': ['GATA1', 'SPI1'],
    'OSKM': [['SOX2',
         'POU5F1',
         'KLF4',
         'MYC']],
    'ADM': ['PTF1A']
}
dataset_celltype_dict = {
    'CAR_T': 'Tex', 
    'blood': 'LMPP',
    'OSKM': 'Fibroblast-like',
    'ADM': 'Acinar'
}

dataset_dire_dict = {
    'CAR_T': 'down', 
    'blood': 'down',
    'OSKM': 'up',
    'ADM': 'down'
}


datasets = list(dataset_pert_dict.keys())

# correlation_method跑所有datasets

In [7]:
celltype_col = 'celltype_v2'

for dataset in datasets:
    # dataset = 'OSKM'
    # dataset = 'CAR_T'
    print('='*20, dataset, '='*20)
    save_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/real_case/data'
    adata_rna = sc.read(os.path.join(save_dir, dataset, 'adata_ctrl_v2.h5ad'))
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()
    adata = adata_rna.copy()

    # - consctrut corr mtx
    if not isinstance(adata_rna.X, np.ndarray):
        adata_rna.X = adata_rna.X.toarray()
    corr_mtx = np.corrcoef(adata_rna[adata_rna.obs[celltype_col].isin([dataset_celltype_dict[dataset]])].X.T)
    corr_mtx[np.isnan(corr_mtx)] = 0

    var_names = list(adata_rna.var_names)

    for pert in tqdm(dataset_pert_dict[dataset]):
        # pert = ['SOX2','POU5F1','KLF4','MYC']
        # pert = 'PDCD1'
        print('*'*20, pert)
        # - use for multiple perts
        if isinstance(pert, str):
            pert_combo = [pert]
        else:
            pert_combo = pert

        pert_prefix = '_'.join(pert_combo)

        if dataset_dire_dict[dataset] == 'down':
            exp_change = np.zeros(adata_rna.shape)
            for tmp_pert in pert_combo:
                # - get the corr change
                pert_value = adata_rna.X[:, var_names.index(tmp_pert)] # value of the pert gene, 500 cells
                pert_corr = corr_mtx[var_names.index(tmp_pert), :] # corr of the pert gene to total genes
                exp_change = np.dot(pert_value.reshape(-1, 1), pert_corr.reshape(1, -1))
            
            # - create adata_pert
            adata_pert = adata_rna.copy()
            adata_pert.X -= exp_change
            adata_pert.X[adata_pert.X < 0] = 0
            adata_pert.obs_names = [i+f'_{pert_prefix}' for i in adata_pert.obs_names]
        else:
            exp_change = np.zeros(adata_rna.shape)
            for tmp_pert in pert_combo:
                # - get the corr change
                pert_value = np.ones(adata_rna.shape[0])
                pert_corr = corr_mtx[var_names.index(tmp_pert), :] # corr of the pert gene to total genes
                exp_change = np.dot(pert_value.reshape(-1, 1), pert_corr.reshape(1, -1))
            
            # - create adata_pert
            adata_pert = adata_rna.copy()
            adata_pert.X += exp_change # the sign is very important!
            adata_pert.X[adata_pert.X < 0] = 0
            adata_pert.obs_names = [i+f'_{pert_prefix}' for i in adata_pert.obs_names]
        
        # - adata_ctrl
        adata_ctrl = adata_rna.copy()
        
        adata_ctrl.obs['batch'] = 'ctrl'
        adata_pert.obs['batch'] = 'pert'
        
        adata_concat = ad.concat([adata_ctrl, adata_pert])
        
        tmp_dir = f'/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/real_case/result/{dataset}'
        save_prefix = f'correlation_method/{pert_prefix}' # use result of K562 to do the direct transfer
        os.makedirs(os.path.join(tmp_dir, save_prefix), exist_ok=True)
        adata_pert.write(os.path.join(tmp_dir, save_prefix, 'adata_pert.h5ad'))

    #     break

    # break




  0%|          | 0/1 [00:00<?, ?it/s]

******************** PDCD1


100%|██████████| 1/1 [00:01<00:00,  1.15s/it]




  0%|          | 0/2 [00:00<?, ?it/s]

******************** GATA1


 50%|█████     | 1/2 [00:11<00:11, 11.33s/it]

******************** SPI1


100%|██████████| 2/2 [00:36<00:00, 18.09s/it]




  0%|          | 0/1 [00:00<?, ?it/s]

******************** ['SOX2', 'POU5F1', 'KLF4', 'MYC']


100%|██████████| 1/1 [00:45<00:00, 45.04s/it]




  0%|          | 0/1 [00:00<?, ?it/s]

******************** PTF1A


100%|██████████| 1/1 [00:02<00:00,  2.08s/it]


In [5]:
adata_pert.X - adata_ctrl.X

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)