This notebook is used to create perturbation data for X-Pert training

# import the necessary libraries

In [9]:
import xpert as xp
from xpert.data.utils import get_info_txt
import scanpy as sc
import numpy as np
from types import MethodType
from tqdm import tqdm
import os
import pickle
import warnings
import logging
# Keep the notebook output quiet: suppress Python warnings, turn scanpy's own
# verbosity down to 0, and only let ERROR-level records through for the
# 'scanpy' and 'anndata' loggers.
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
for _quiet_logger in ('scanpy', 'anndata'):
    logging.getLogger(_quiet_logger).setLevel(logging.ERROR)


# load original adata file of perturbation data

In [4]:
# Read the L1000 phase 1 AnnData file from disk. The bare `adata` expression on
# the last line makes the notebook display a summary of the loaded object.
adata = sc.read('../../data/L1000_phase1/adata_L1000_phase1.h5ad')
adata

AnnData object with n_obs × n_vars = 678401 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id', 'canonical_smiles', 'plate_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

Add the `obs` columns `dose`, `perturbation_group`, `perturbation_new` and `celltype_new`, which are used during data preprocessing.

In [5]:
def get_perturbation_group(x):
    """
    Build the perturbation-group label for one row of ``adata.obs``.

    The label has the form ``'<perturbation> | <dose> | <cell_id>'``; rows whose
    ``'pert_iname'`` is ``'DMSO'`` are reported under the name ``'control'``.

    Args:
        x: row-like mapping providing the keys ``'pert_iname'``, ``'dose'`` and
            ``'cell_id'`` (e.g. a row passed by ``DataFrame.apply(axis=1)``).

    Returns:
        str: the three fields joined by ``' | '``.
    """
    pert_name = 'control' if x['pert_iname'] == 'DMSO' else x['pert_iname']
    # str() keeps the join from raising TypeError when a field (for example a
    # numeric dose) is not already a string; string fields are left unchanged.
    return ' | '.join(str(part) for part in (pert_name, x['dose'], x['cell_id']))

def get_perturbation_new(x):
    """
    Return the perturbation name of one row, reporting DMSO rows as 'control'.

    Args:
        x: row-like mapping providing the key ``'pert_iname'``.

    Returns:
        ``'control'`` when ``x['pert_iname']`` equals ``'DMSO'``, otherwise the
        value of ``x['pert_iname']`` itself.
    """
    pert_name = x['pert_iname']
    return 'control' if pert_name == 'DMSO' else pert_name
    
def get_cell_type_new(x):
    """
    Return the cell-type label of one row, read from its 'cell_id' field.

    Args:
        x: row-like mapping providing the key ``'cell_id'``.

    Returns:
        The value stored under ``x['cell_id']``, unchanged.
    """
    cell_type = x['cell_id']
    return cell_type

# get_perturbation_group reads the dose from a column named 'dose', so expose
# the original 'pert_dose' values under that name first.
adata.obs['dose'] = adata.obs['pert_dose']

# Compute the three derived labels row by row. The row functions only read
# 'pert_iname', 'dose' and 'cell_id', so they do not depend on one another.
perturbation_group = adata.obs.apply(get_perturbation_group, axis=1)
perturbation_new = adata.obs.apply(get_perturbation_new, axis=1)
celltype_new = adata.obs.apply(get_cell_type_new, axis=1)

# Store each result in the obs column of the same name.
adata.obs['perturbation_group'] = perturbation_group
adata.obs['perturbation_new'] = perturbation_new
adata.obs['celltype_new'] = celltype_new

    

# generate Pert_Data object



In [14]:
# Parameters. All of them except data_dir are passed to xp.data.Byte_Pert_Data
# in the next cell; data_dir is where the pickled result is written.

pert_cell_filter = 0 # perturbations with fewer cells than this threshold are filtered out
seed = 2024 # random seed
split_type = 1 # 1 for unseen perturbations; 0 for unseen cell types
split_ratio = [0.8, 0.2, 0] # train:test:val fractions; val is used to choose data, test is for final validation
var_num = 5000 # number of highly variable genes (HVGs) to select
num_de_genes = 20 # number of differentially expressed (DE) genes
bs_train = 32 # batch size of the train loader
bs_test = 32 # batch size of the test loader
data_dir = '../../data/L1000_phase1/'

In [15]:
# Collect the constructor settings in one mapping, then build the Pert_Data object.
_pert_data_kwargs = dict(
    prefix='L1000_phase1',
    pert_cell_filter=pert_cell_filter,
    seed=seed,
    split_ratio=split_ratio,
    split_type=split_type,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test,
)
pert_data = xp.data.Byte_Pert_Data(**_pert_data_kwargs)

# First step of the processing pipeline: hand the annotated AnnData to the object.
print("Step 1: Reading files...")
pert_data.read_files(adata)

# Use the original data directly as the split data
# (adata_ori is presumably populated by read_files -- TODO confirm).
pert_data.adata_split = pert_data.adata_ori

# Record every distinct perturbation group found among the non-control rows.
_is_perturbed = pert_data.adata_split.obs['perturbation_new'] != 'control'
tmp_obs = pert_data.adata_split[_is_perturbed].obs
_perturbed_groups = tmp_obs['perturbation_group'].unique()
pert_data.filter_perturbation_list = list(_perturbed_groups)



Step 1: Reading files...


# Override the cell-pairing function (control barcode assignment) and define placeholder DE genes

In [16]:
def set_control_barcode(self):
    """
    Pair every perturbed sample with a randomly drawn matching control sample.

    Samples are grouped by (plate_id, celltype_new, pert_time). For each row of
    ``self.adata_split.obs`` whose ``perturbation_new`` is not ``'control'``,
    one control row from the same group is drawn uniformly at random and its
    index label is written to the ``'control_barcode'`` column. Control rows
    keep the placeholder string ``'None'``.

    Reads ``self.adata_split.obs`` (columns ``plate_id``, ``celltype_new``,
    ``pert_time``, ``perturbation_new``) and ``self.seed``.

    Side effects:
        * adds/overwrites ``self.adata_split.obs['control_barcode']``;
        * re-seeds NumPy's global random state with ``self.seed``.

    Raises:
        KeyError: if a perturbed sample has no control sample sharing its
            plate, cell type and perturbation time.

    Notes:
        Rows are processed in positional order and the same
        ``np.random.choice`` call is issued per perturbed row, so for an index
        labelled ``'0', '1', ..., 'n-1'`` in that order the drawn controls are
        the same as with a per-row ``obs.loc[str(i)]`` lookup. Other index
        labellings are supported as well; the stored barcode is always the
        actual index label of the chosen control row.
    """
    obs_df = self.adata_split.obs
    n_obs = len(obs_df)

    # Pull the needed columns out once; per-row ``.loc`` access on a large
    # DataFrame is far slower than indexing plain arrays.
    labels = obs_df.index.to_numpy()
    is_control = (obs_df['perturbation_new'] == 'control').to_numpy()
    plate_ids = obs_df['plate_id'].to_numpy()
    cell_types = obs_df['celltype_new'].to_numpy()
    pert_times = obs_df['pert_time'].to_numpy()

    def _group_key(pos):
        # Controls and perturbed samples are matched on this composite key.
        return ' | '.join([str(plate_ids[pos]), str(cell_types[pos]), str(pert_times[pos])])

    # - collect the index labels of the control samples of every group
    control_pool = {}
    for pos in tqdm(range(n_obs)):
        if is_control[pos]:
            control_pool.setdefault(_group_key(pos), []).append(labels[pos])

    # - draw one control per perturbed sample
    np.random.seed(self.seed)
    control_barcode = np.full(n_obs, 'None', dtype=object)
    for pos in tqdm(range(n_obs)):
        if is_control[pos]:
            continue
        key = _group_key(pos)
        if key not in control_pool:
            raise KeyError(
                f"no control sample found for group '{key}' "
                f"(needed by sample '{labels[pos]}')"
            )
        control_barcode[pos] = np.random.choice(control_pool[key], 1, replace=True)[0]

    obs_df['control_barcode'] = control_barcode
    self.adata_split.obs = obs_df

def get_de_genes(self,
                rankby_abs = True,
                key_added = 'rank_genes_groups'):
    """
    Fill ``self.adata_split.uns`` with placeholder differential-expression results.

    No statistical test is run. For every entry of
    ``self.filter_perturbation_list`` the gene list is simply all of
    ``self.adata_split.var_names`` (in their existing order), and every
    p-value, adjusted p-value, score and log fold change is the constant 0.1.

    Args:
        rankby_abs: accepted but not used by this implementation.
        key_added: key of ``uns`` under which the per-perturbation gene lists
            are stored.

    Side effects:
        Writes five dictionaries keyed by perturbation into
        ``self.adata_split.uns``: ``key_added``, ``'pvals'``, ``'pvals_adj'``,
        ``'scores'`` and ``'logfoldchanges'``; prints a completion message.
    """
    split_adata = self.adata_split
    stat_keys = ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')

    ranked_genes = {}
    placeholder_stats = {stat_key: {} for stat_key in stat_keys}

    for pert in tqdm(self.filter_perturbation_list):
        n_genes = len(split_adata.var_names)
        ranked_genes[pert] = list(split_adata.var_names)
        # a fresh list per perturbation and per statistic, never a shared one
        for stat_key in stat_keys:
            placeholder_stats[stat_key][pert] = [0.1] * n_genes

    split_adata.uns[key_added] = ranked_genes
    for stat_key in stat_keys:
        split_adata.uns[stat_key] = placeholder_stats[stat_key]
    print('=' * 10, 'get de genes finished!')

# Bind the notebook-level set_control_barcode function to this pert_data
# instance, so that pert_data.set_control_barcode() runs the version defined
# in this cell. get_de_genes is not bound; it is called as a plain function.
pert_data.set_control_barcode = MethodType(set_control_barcode, pert_data)

# Run process

In [17]:


# The commented-out steps in this cell are not executed in this notebook.
# print("Step 2: Filtering perturbations...")
# pert_data.filter_perturbation()

# print("Step 3: Preprocessing adata and selecting HVGs...")
# pert_data.get_and_process_adata(var_num=pert_data.var_num)

print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

# print("Step 5: Calculating E-distances...")
# pert_data.get_edis_2()

# print("Step 6: Filtering sgRNAs...")
# pert_data.adata_split.obs['sgRNA_new'] = 'control'
# pert_data.filter_sgRNA()

print("Step 3: Data splitting...")
# Run the split once per split type (0, 1 and 2), in that order, always
# letting data_split_2 choose the test perturbations itself (test_perts=None).
for _split_type in (0, 1, 2):
    pert_data.data_split_2(split_type=_split_type,
                           test_perts=None)

print("Step 8: Getting differential genes...")
get_de_genes(pert_data)

print("Step 9: Saving processed data...")
# Save the processed data. The context manager guarantees the file is flushed
# and closed, even if pickling raises part-way through.
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)

print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")

Step 2: Setting control barcodes...


  0%|          | 0/678401 [00:00<?, ?it/s]

100%|██████████| 678401/678401 [00:16<00:00, 40795.08it/s]
100%|██████████| 678401/678401 [01:18<00:00, 8594.04it/s] 


Step 3: Data splitting...
Step 8: Getting differential genes...


100%|██████████| 128725/128725 [00:18<00:00, 6873.53it/s] 


Step 9: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (678401, 978)
Number of perturbation groups: 128725
