env: scGen

In [1]:
import sys
import importlib
import scanpy as sc
import pandas as pd
import numpy as np
from glob import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import json
from anndata import AnnData
import scgen
from scvi.data import setup_anndata 

import torch

sys.path.append("/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/")

import v1
from v1.utils import *
from v1.dataloader import *

In [2]:
# Re-execute the local `v1` package so edits on disk are picked up without restarting the kernel.
importlib.reload(v1)
importlib.reload(v1.utils)
importlib.reload(v1.dataloader)
# `importlib.reload` updates the module objects in place, but names previously copied into this
# namespace by `from ... import *` still refer to the pre-reload objects; re-import to rebind them.
from v1.utils import *
from v1.dataloader import *
v1.dataloader

<module 'v1.dataloader' from '/data1/lichen/code/single_cell_perturbation/scPerturb/Byte_Pert_Data/v1/dataloader.py'>

# Direct-transfer baseline on an unseen cell type (K562 → RPE1)

In [3]:
# =================================== initial

# - init para (defined first so `seed` is the single source of truth for the shuffle below)
data_dir = '/nfs/public/lichen/data/single_cell/perturb_data/scPerturb/raw/scPerturb_rna/statistic_20240520' # not referenced in the cells below
pert_cell_filter = 100 # this is used to filter perts, cell number less than this will be filtered
seed = 2024 # this is the random seed
# 1 for unseen perts; 0 for unseen celltypes.
# In this notebook split_type is only used to build the path of the saved pert_data.pkl files loaded below.
split_type = 1
split_ratio = [0.7, 0.2, 0.1] # train:test:val; val is used to choose data, test is for final validation (not referenced in the cells below)
var_num = 5000 # selecting hvg number
num_de_genes = 20 # number of de genes
bs_train = 32 # batch size of trainloader
bs_test = 32 # batch size of testloader
lr = 1e-4 # not referenced in the cells below

# - load common_pert: genes / perturbations present in both K562 and RPE1
utils_data_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/utils_data'
with open(os.path.join(utils_data_dir, 'prefix_gene_dict.json'), 'r') as f:
    prefix_gene_dict = json.load(f)
with open(os.path.join(utils_data_dir, 'prefix_pert_dict.json'), 'r') as f:
    prefix_pert_dict = json.load(f)
common_genes = np.intersect1d(prefix_gene_dict['ReplogleWeissman2022_K562_essential'], prefix_gene_dict['ReplogleWeissman2022_rpe1'])
common_perts = np.intersect1d(prefix_pert_dict['ReplogleWeissman2022_K562_essential'], prefix_pert_dict['ReplogleWeissman2022_rpe1'])
# Fail early: an empty intersection would otherwise only surface later as an np.hstack error.
assert len(common_perts) > 0, 'K562 and RPE1 share no perturbations'
np.random.seed(seed)
np.random.shuffle(common_perts)  # in place; order is reproducible via `seed`

# - split the perts: first 90% of the shuffled perts -> test_perts, remaining 10% -> val_perts
split_point = int(len(common_perts) * 0.9)
test_perts = common_perts[:split_point]
val_perts = common_perts[split_point:]

def get_common_pert(pert_data, common_genes, verbose=True):
    """Restrict ``pert_data.adata_split`` to ``common_genes`` and re-align its per-perturbation ``.uns`` lists.

    Despite the name, this subsets *genes* (columns), not perturbations.

    Steps:
      1. Set ``pert_data.var_genes = list(common_genes)`` and replace ``pert_data.adata_split``
         with a copy restricted to those columns (in that order).
      2. For each perturbation in ``adata_split.uns['rank_genes_groups']``, record the positions
         of ranked genes that survived the column subset.
      3. Apply those positions to every per-perturbation list in ``adata_split.uns`` (e.g. the
         ranked gene names themselves, pvals, scores, ...), so all lists stay index-aligned.

    Args:
        pert_data: object exposing an AnnData-like ``adata_split`` and a writable ``var_genes``.
        common_genes: iterable of gene names; all must be present in ``adata_split.var_names``.
        verbose: print each ``.uns`` key as it is processed (the original behaviour).

    Returns:
        ``pert_data`` itself (it is modified in place), for convenience.
    """
    pert_data.var_genes = list(common_genes)
    pert_data.adata_split = pert_data.adata_split[:, pert_data.var_genes].copy()
    uns = pert_data.adata_split.uns
    kept_genes = set(pert_data.adata_split.var_names)

    # Positions (within each pert's ranked gene list) of genes kept by the subset.
    # rank_genes_groups is the reference ordering the other per-pert lists are aligned to,
    # so the positions are computed before any list is filtered.
    pert_idx_dict = {
        pert: [i for i, gene in enumerate(ranked) if gene in kept_genes]
        for pert, ranked in uns['rank_genes_groups'].items()
    }

    for key in list(uns.keys()):
        if verbose:
            print(key)
        ele = uns[key]
        if not hasattr(ele, 'keys'):
            # Not a per-perturbation mapping (e.g. a scalar or array): leave it untouched.
            continue
        for pert in ele.keys():
            if pert not in pert_idx_dict:
                # No ranked-gene reference for this key: nothing to align against.
                continue
            ele[pert] = list(np.array(ele[pert])[pert_idx_dict[pert]])
    return pert_data

tmp_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/scPerturb'

# - load k562 (pert_data_1) and rpe1 (pert_data_2) pert_data, each restricted to common_genes.
# NOTE: pickle.load can execute arbitrary code — only load pickles produced by this pipeline.
# After the loop, prefix / save_prefix / save_dir hold the rpe1 values (as in the original two stanzas).
loaded_pert_data = {}
for prefix, label in (('ReplogleWeissman2022_K562_essential', 'k562'),
                      ('ReplogleWeissman2022_rpe1', 'rpe1')):
    save_prefix = (f'GEARS_v2-prefix_{prefix}-pert_cell_filter_{pert_cell_filter}-'
                   f'seed_{seed}-split_type_{split_type}-var_num_{var_num}-num_de_genes_{num_de_genes}-'
                   f'bs_train_{bs_train}-bs_test_{bs_test}')
    save_dir = os.path.join(tmp_dir, prefix, save_prefix)
    # `with` closes the handle deterministically instead of leaking it
    with open(os.path.join(save_dir, 'pert_data.pkl'), 'rb') as pkl_file:
        loaded_pert_data[label] = pickle.load(pkl_file)
    get_common_pert(loaded_pert_data[label], common_genes)
    print('='*20, f'{label} pert_data loaded')
pert_data_1, pert_data_2 = loaded_pert_data['k562'], loaded_pert_data['rpe1']

# - modify 2 pert_data
pert_data_1.modify_gears(without_subgroup=True)
pert_data_2.modify_gears(without_subgroup=True)

# - give celltypes
cell_type_1, cell_type_2 = 'K562', 'retinal pigment epithelial cells'

# - get pert_groups: '<pert> | <cell type>' labels, compared against obs['perturbation_group']
#   in the evaluation loop; index i of both lists refers to the same perturbation
pert_groups_1 = [i+' | '+cell_type_1 for i in common_perts]
pert_groups_2 = [i+' | '+cell_type_2 for i in common_perts]

adata_1 = pert_data_1.adata_split.copy()
adata_2 = pert_data_2.adata_split.copy()


rank_genes_groups
pvals
pvals_adj
scores
logfoldchanges
rank_genes_groups
pvals
pvals_adj
scores
logfoldchanges
add adata finished
add condition finished
add set2conditions finished
add adata finished
add condition finished
add set2conditions finished


In [4]:
# Direct-transfer baseline (K562 -> RPE1): for each shared perturbation, take the mean
# per-gene shift (perturbed - matched control) observed in K562 and add it to the matched
# RPE1 control cells; the real RPE1 perturbed cells are the ground truth.

def _to_dense(X):
    """Return ``X`` as a dense ndarray (accepts scipy.sparse matrices or array-likes)."""
    return X.toarray() if hasattr(X, 'toarray') else np.asarray(X)

# Loop-invariant lookups: build once instead of once per perturbation
# (the condition -> condition_name dict scans every cell in obs).
var_index_2 = pert_data_2.adata.var.index.values
geneid2idx = dict(zip(var_index_2, range(len(var_index_2))))
pert2pert_full_id = dict(pert_data_2.adata.obs[['condition', 'condition_name']].values)
de_genes_by_pert = pert_data_2.adata.uns['rank_genes_groups_cov_all']
# NOTE(review): de_idx indexes columns of pert_data_2.adata, while pred/truth come from
# adata_2 (a copy of pert_data_2.adata_split); this assumes both share the same gene order — confirm.

# Nothing in the loop draws random numbers, so seeding once is equivalent to the
# previous per-iteration fix_seed(2024).
fix_seed(2024)

pert_cat_list, pred_list, truth_list, pred_de_list, truth_de_list = [], [], [], [], []
for pert_group_1, pert_group_2 in tqdm(zip(pert_groups_1, pert_groups_2), total=len(pert_groups_1)):
    # - perturbed cells and their matched control cells in each cell type
    adata_pert_1 = adata_1[adata_1.obs['perturbation_group'] == pert_group_1]
    adata_pert_2 = adata_2[adata_2.obs['perturbation_group'] == pert_group_2]
    adata_ctrl_1 = adata_1[adata_pert_1.obs['control_barcode']]
    adata_ctrl_2 = adata_2[adata_pert_2.obs['control_barcode']]

    # - mean K562 shift (1-D over genes), added to every RPE1 control cell; negatives clipped to 0
    delta = np.asarray(np.mean(adata_pert_1.X - adata_ctrl_1.X, axis=0)).ravel()
    pred_X = _to_dense(adata_ctrl_2.X) + delta
    pred_X[pred_X < 0] = 0
    truth_X = _to_dense(adata_pert_2.X)

    # - top-`num_de_genes` DE genes for this perturbation, as column indices
    pert_gears = transform_name(pert_group_2)
    de_genes = de_genes_by_pert[pert2pert_full_id[pert_gears]][:num_de_genes]
    de_idx = [geneid2idx[gene] for gene in de_genes]

    # - add to list
    pert_cat_list.append(np.array([pert_gears] * adata_pert_2.shape[0]))
    pred_list.append(pred_X)
    truth_list.append(truth_X)
    pred_de_list.append(pred_X[:, de_idx])
    truth_de_list.append(truth_X[:, de_idx])
    

100%|██████████| 490/490 [01:59<00:00,  4.10it/s]


In [5]:
# - stack the per-perturbation arrays into the test_res dict consumed by get_metric
test_res = {
    'pert_cat': np.hstack(pert_cat_list),
    'pred': np.vstack(pred_list),
    'truth': np.vstack(truth_list),
    'pred_de': np.vstack(pred_de_list),
    'truth_de': np.vstack(truth_de_list),
}
# - get the out
# NOTE(review): get_metric emits many divide-by-zero RuntimeWarnings
# (`change_ratio = np.abs(pred-truth)/truth`) wherever truth == 0 — check how inf/nan are handled.
out = get_metric(pert_data_2.adata, test_res)
# out_non_dropout = non_dropout_analysis(pert_data_2.adata, test_res)

tmp_dir = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark_202410/unseen_celltype/K562_RPE1/result'
prefix = 'direct_transfer_v1'
save_dir = os.path.join(tmp_dir, prefix)
os.makedirs(save_dir, exist_ok=True)
# `with` guarantees each pickle is flushed and closed before moving on
for file_name, obj in (('out.pkl', out), ('test_res.pkl', test_res)):
    with open(os.path.join(save_dir, file_name), 'wb') as pkl_file:
        pickle.dump(obj, pkl_file)
# pickle.dump(out_non_dropout, open(os.path.join(save_dir,f'out_non_dropout.pkl'), 'wb'))

get metrics... ...


  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth)/truth
  change_ratio = np.abs(pred-truth