# Create co-training perturbation data

## 1. Import the necessary libraries

In [1]:
import logging
import os
import pathlib
import pickle
import warnings
from collections import defaultdict
from types import MethodType

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from tqdm import tqdm

import xpert as xp
from xpert.data.utils import get_info_txt
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
# Only surface errors from scanpy/anndata so the notebook output stays readable.
for _noisy_logger in ('scanpy', 'anndata'):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)


## 2. Load the AnnData and select a specific cell line

In [2]:
# Cell line (LINCS `cell_id`) whose signatures and untreated controls are extracted below.
cell_line = "HT29"

In [3]:
adata_sig_path = '../../data/L1000_phase1_cotrain/adata_sig.h5ad'
adata_sig = sc.read(adata_sig_path)
adata_sig

AnnData object with n_obs × n_vars = 473647 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [4]:
# Keep compound (trt_cp) and consensus shRNA (trt_sh.cgs) signatures of the chosen cell line.
is_pert_type = adata_sig.obs['pert_type'].isin(['trt_sh.cgs', 'trt_cp'])
is_cell_line = adata_sig.obs['cell_id'] == cell_line
adata_pert = adata_sig[is_pert_type & is_cell_line].copy()
adata_pert

AnnData object with n_obs × n_vars = 17815 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [5]:
adata_pert.obs.pert_type.value_counts()

trt_cp        14513
trt_sh.cgs     3302
Name: pert_type, dtype: int64

Add SMILES strings to the perturbation metadata

In [6]:
# Sorted unique pert_ids of the selected signatures.
pert_id_unique = pd.Series(np.unique(adata_pert.obs['pert_id']))
n_unique_perts = len(pert_id_unique)
print(f"# of unique perturbations: {n_unique_perts}")

# of unique perturbations: 14215


In [7]:
# Perturbagen metadata (including canonical SMILES) for LINCS phase 1 (GSE92742).
reference_df = pd.read_csv(
    '../../data/L1000_phase1_cotrain/GSE92742_Broad_LINCS_pert_info.txt.gz',
    sep='\t',
)

In [8]:
# Keep only perturbagens present in the selected signatures, and only the columns used below.
in_selection = reference_df['pert_id'].isin(pert_id_unique)
reference_df = reference_df.loc[in_selection, ['pert_id', 'canonical_smiles']]
reference_df['canonical_smiles'].value_counts()

-666                                                                                                                                                             3338
restricted                                                                                                                                                         14
CS(=O)(=O)CCNCc1ccc(o1)-c1ccc2ncnc(Nc3ccc(OCc4cccc(F)c4)c(Cl)c3)c2c1                                                                                                2
CCOC(=O)C1=C(NC(=C(C1C2=CC=CC=C2Cl)C(=O)OC)C)COCCN                                                                                                                  2
CO[C@H]1\C=C\O[C@@]2(C)Oc3c(C2=O)c2c(O)c(\C=N\N4CCN(C)CC4)c(NC(=O)\C(C)=C/C=C/[C@H](C)[C@H](O)[C@@H](C)[C@@H](O)[C@@H](C)[C@H](OC(C)=O)[C@@H]1C)c(O)c2c(O)c3C       2
                                                                                                                                                                 ... 
COc1

In [9]:
# Attach the canonical SMILES of each pert_id.
# `pd.merge` returns a fresh RangeIndex, so the original obs_names are restored explicitly
# instead of being replaced by "0", "1", ... . A left merge keeps the left row order.
# validate='many_to_one' fails loudly if reference_df has several rows for one pert_id,
# which would otherwise add rows.
obs_names_before = adata_pert.obs_names
obs_with_smiles = pd.merge(adata_pert.obs, reference_df, on='pert_id', how='left',
                           validate='many_to_one')
obs_with_smiles.index = obs_names_before
adata_pert.obs = obs_with_smiles
adata_pert.obs_names = adata_pert.obs_names.astype(str)

In [10]:
pert_types = adata_pert.obs['pert_type']
is_drug, is_gene = pert_types == 'trt_cp', pert_types == 'trt_sh.cgs'
adata_drug = adata_pert[is_drug].copy()
adata_gene = adata_pert[is_gene].copy()
adata_drug, adata_gene

(AnnData object with n_obs × n_vars = 14513 × 978
     obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
     var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing',
 AnnData object with n_obs × n_vars = 3302 × 978
     obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
     var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing')

In [11]:
# remove invalid smiles: the '-666' / 'restricted' placeholders, plus missing values
# (a missing SMILES becomes the string 'nan' after the astype below)
invalid_smiles_tokens = ['-666', 'restricted', 'nan']
adata_drug.obs['canonical_smiles'] = adata_drug.obs['canonical_smiles'].astype('str')
invalid_smiles = adata_drug.obs['canonical_smiles'].isin(invalid_smiles_tokens)
n_total, n_invalid = len(adata_drug), invalid_smiles.sum()
print(f'Among {n_total} observations, {100*n_invalid/n_total:.2f}% ({n_invalid}) have an invalid SMILES string')
adata_drug = adata_drug[~invalid_smiles].copy()
adata_drug

Among 14513 observations, 0.76% (110) have an invalid SMILES string


AnnData object with n_obs × n_vars = 14403 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [12]:
# - construct pert_embed_dict: one ECFP embedding vector per perturbation (one CSV column each)
pert_embed = pd.read_csv('../../data/L1000_phase1_cotrain/embed_ecfp.csv', sep = ",", index_col=0)
pert_embed_dict = {pert: pert_embed.loc[:, pert].values for pert in pert_embed.columns}

# - create drug -> embedding; every drug left in adata_drug must have an embedding column
embed_drugs = np.unique(adata_drug.obs['pert_iname'])
missing_drugs = [str(drug) for drug in embed_drugs if drug not in pert_embed_dict]
if missing_drugs:
    raise KeyError(f'{len(missing_drugs)} drugs have no column in embed_ecfp.csv, e.g. {missing_drugs[:5]}')
drug_embedding_dict = {drug: pert_embed_dict[drug] for drug in embed_drugs}

# - group drugs that share an identical embedding vector
embedding_to_drugs = defaultdict(list)
for drug, emb in drug_embedding_dict.items():
    emb_key = tuple(emb.tolist())  # convert the ndarray to a hashable tuple
    embedding_to_drugs[emb_key].append(drug)

# - keep one drug per distinct embedding (the alphabetically first, since embed_drugs is sorted)
unique_drugs = [drugs[0] for drugs in embedding_to_drugs.values()]
adata_drug = adata_drug[adata_drug.obs['pert_iname'].isin(unique_drugs)].copy()

# - remove duplicated smiles, keeping the first observation of each SMILES string
# NOTE(review): this keeps a single signature per SMILES, so other doses / time points of
# the same compound are presumably dropped too -- confirm this is intended
dup_mask = adata_drug.obs['canonical_smiles'].duplicated(keep='first')  # True == drop
adata_drug = adata_drug[~dup_mask, :].copy()
adata_drug

AnnData object with n_obs × n_vars = 6353 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

## 3. Get control cells

In [13]:
adata_level3_path = '../../data/L1000_phase1_cotrain/adata_inst.h5ad'
adata_level3 = sc.read(adata_level3_path)
adata_level3

AnnData object with n_obs × n_vars = 1319138 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [14]:
# Untreated control profiles (ctl_untrt) of the chosen cell line.
is_ctrl_in_cell_line = (
    (adata_level3.obs['cell_id'] == cell_line)
    & (adata_level3.obs['pert_type'] == 'ctl_untrt')
)
adata_level3_part = adata_level3[is_ctrl_in_cell_line]
adata_level3_part

View of AnnData object with n_obs × n_vars = 1917 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

In [15]:
# Collapse all untreated control profiles of the cell line into one mean "control"
# observation; its obs metadata is copied from the first control profile.
if adata_level3_part.n_obs == 0:
    raise ValueError(f'No ctl_untrt profiles found for cell line {cell_line!r}')
# `.mean` + `np.asarray` handles both dense and sparse X (sparse `.mean` returns np.matrix).
ctrl_mean = np.asarray(adata_level3_part.X.mean(axis=0)).reshape(1, -1)
adata_level3_ctrl = ad.AnnData(X=ctrl_mean,
                               obs=pd.DataFrame(adata_level3_part.obs.iloc[0, :]).T,
                               var=pd.DataFrame(index=adata_level3_part.var_names))
adata_level3_ctrl

AnnData object with n_obs × n_vars = 1 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'

In [16]:
# Stack drug signatures, shRNA signatures and the single mean-control profile;
# join='outer' keeps the union of the obs columns of all three objects.
adata_concat = ad.concat(
    [adata_drug, adata_gene, adata_level3_ctrl],
    join='outer',
)
adata_concat

AnnData object with n_obs × n_vars = 9656 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles', 'inst_id', 'rna_plate', 'rna_well'

## 4. Add basic attributes

In [17]:
# Work on a copy so the obs columns added in the next cell do not modify adata_concat.
adata = adata_concat.copy()

In [18]:
def get_perturbation_group(x):
    """
    Build the 'perturbation | dose | cell' group key for one obs row.

    Untreated rows (pert_iname == 'UnTrt') are named 'control'. Every part is
    converted with ``str`` so numeric doses also work: a float dose of -666.0
    becomes '-666.0', the token the control key in set_control_barcode uses.

    Parameters
    ----------
    x : mapping (typically an obs row)
        Must provide 'pert_iname', 'dose' and 'cell_id'.

    Returns
    -------
    str
        e.g. 'control | -666.0 | HT29'.
    """
    name = 'control' if x['pert_iname'] == 'UnTrt' else x['pert_iname']
    return ' | '.join(str(part) for part in (name, x['dose'], x['cell_id']))

def get_perturbation_new(x):
    """Return the perturbation name of an obs row, mapping untreated ('UnTrt') rows to 'control'."""
    name = x['pert_iname']
    return 'control' if name == 'UnTrt' else name
    
def get_cell_type_new(x):
    """Return the cell type of an obs row (its 'cell_id' value)."""
    cell_type = x['cell_id']
    return cell_type

adata.obs['dose'] = adata.obs['pert_dose']

# Derived obs columns, computed row by row in this order ('perturbation_group' reads 'dose').
for new_column, row_fn in (
    ('perturbation_group', get_perturbation_group),
    ('perturbation_new', get_perturbation_new),
    ('celltype_new', get_cell_type_new),
):
    adata.obs[new_column] = adata.obs.apply(row_fn, axis=1)

    

## 5. Generate the Pert_Data object



In [19]:
# parameters

pert_cell_filter = 0 # this is used to filter perts, cell number less than this will be filtered
seed = 2024 # this is the random seed
split_type = 1 # 1 for unseen perts; 0 for unseen celltypes
# train:test:val; val is used to choose data, test is for final validation.
# 90/10 is the ratio the run cell sets on pert_data before splitting.
split_ratio = [0.9, 0.1, 0]
# selecting hvg number. NOTE(review): the data only has 978 genes and HVG selection
# (get_and_process_adata) is not run in this notebook, so this presumably has no effect -- confirm
var_num = 5000
num_de_genes = 20 # number of de genes
bs_train = 32 # batch size of trainloader
bs_test = 32 # batch size of testloader
data_dir = '../../data/L1000_phase1_cotrain/'

In [20]:
# Create Pert_Data object from the parameters cell
pert_data_kwargs = dict(
    prefix='L1000_phase1',
    pert_cell_filter=pert_cell_filter,
    seed=seed,
    split_ratio=split_ratio,
    split_type=split_type,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test,
)
pert_data = xp.data.Byte_Pert_Data(**pert_data_kwargs)

# Complete data processing pipeline
print("Step 1: Reading files...")
pert_data.read_files(adata)

# xpert's perturbation filtering is skipped: the split works directly on the loaded data.
# NOTE(review): this binds the same AnnData object (no copy), so obs changes made through
# adata_split also show up in adata_ori -- confirm this is intended.
pert_data.adata_split = pert_data.adata_ori
split_obs = pert_data.adata_split.obs
is_perturbed = split_obs['perturbation_new'] != 'control'
# record every non-control perturbation group, in order of first appearance
pert_data.filter_perturbation_list = list(split_obs.loc[is_perturbed, 'perturbation_group'].unique())



Step 1: Reading files...


## 6. Override the cell-pairing and DE-gene functions

In [21]:
def set_control_barcode(self, control_dose='-666.0'):
    """
    Pair every perturbed observation with a randomly drawn control observation.

    Each group key in ``self.filter_perturbation_list`` looks like
    'perturbation | dose | cell_line'. For each key, control barcodes (obs names)
    are sampled with replacement from the group
    'control | <control_dose> | <cell_line>' and written to the
    ``control_barcode`` column. Observations outside the listed groups keep the
    string 'None'. Sampling uses the global NumPy RNG seeded with ``self.seed``.

    Parameters
    ----------
    control_dose : str, default '-666.0'
        Dose token in the control group's key. '-666.0' is the dose the
        control group carries in this notebook's data.

    Side effects
    ------------
    Sets ``self.obs_df_split`` and replaces ``self.adata_split.obs`` with it.

    Raises
    ------
    ValueError
        If a listed perturbation has observations but no control observations
        exist for its cell line.
    """
    self.obs_df_split = self.adata_split.obs.copy()

    # - set all control_barcode to None
    self.obs_df_split['control_barcode'] = 'None'

    # - index the obs names of every group once (within-group order of appearance is kept),
    #   instead of rescanning the whole obs table twice per perturbation
    group_to_obs = {
        group: np.array(obs_names)
        for group, obs_names in self.obs_df_split.groupby(
            'perturbation_group', observed=True, sort=False).groups.items()
    }

    np.random.seed(self.seed)
    for pert in tqdm(self.filter_perturbation_list):
        obs_df_sub_idx = group_to_obs.get(pert)
        if obs_df_sub_idx is None or len(obs_df_sub_idx) == 0:
            continue  # no observations to pair
        # - controls of the same cell line (last field of the group key)
        control_key = ' | '.join(['control', control_dose, pert.split(' | ')[-1]])
        control_obs = group_to_obs.get(control_key)
        if control_obs is None or len(control_obs) == 0:
            raise ValueError(f'No control observations in group {control_key!r} to pair with {pert!r}')
        # - get the paired control and set the control barcode
        pair_control_obs = np.random.choice(control_obs, len(obs_df_sub_idx), replace=True)
        self.obs_df_split.loc[obs_df_sub_idx, 'control_barcode'] = pair_control_obs

    self.adata_split.obs = self.obs_df_split
    print('='*10,f'set control barcodes finished!')

def get_de_genes(self,
                rankby_abs = True,
                key_added = 'rank_genes_groups'):
    """
    Store placeholder differential-expression results in ``self.adata_split.uns``.

    No statistical test is run. Every group in ``self.filter_perturbation_list``
    gets the full gene list of ``self.adata_split`` (in var order) under
    ``uns[key_added]``. It also gets a constant 0.1 per gene under ``uns['pvals']``,
    ``uns['pvals_adj']``, ``uns['scores']`` and ``uns['logfoldchanges']``.
    ``rankby_abs`` is accepted but not used.
    """
    gene_names = list(self.adata_split.var_names)
    n_genes = len(gene_names)
    stat_keys = ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')
    per_key = {key: {} for key in (key_added,) + stat_keys}

    for pert in tqdm(self.filter_perturbation_list):
        per_key[key_added][pert] = list(gene_names)
        for key in stat_keys:
            per_key[key][pert] = [0.1] * n_genes

    for key, values in per_key.items():
        self.adata_split.uns[key] = values
    print('='*10,f'get de genes finished!')

# Replace xpert's pairing logic on this instance only (bound method).
setattr(pert_data, 'set_control_barcode', MethodType(set_control_barcode, pert_data))

## 7. Run the processing pipeline

In [22]:


# Standard xpert steps intentionally not run for this L1000 data:
# filter_perturbation, get_and_process_adata (HVG selection), get_edis_2, filter_sgRNA.

print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

print("Step 3: Data splitting...")
# 90/10 train/test split with no validation set, using split_type from the parameters cell
# (1 = unseen perts).
pert_data.split_ratio = [0.9, 0.1, 0]
pert_data.data_split_2(split_type=split_type,
                       test_perts=None)

print("Step 4: Getting differential genes...")
get_de_genes(pert_data)

print("Step 5: Saving processed data...")
# A context manager guarantees the pickle is flushed and the file handle closed.
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)

print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")

Step 2: Setting control barcodes...


100%|██████████| 9405/9405 [00:15<00:00, 600.49it/s]


Step 3: Data splitting...
Step 4: Getting differential genes...


100%|██████████| 9405/9405 [00:00<00:00, 11605.10it/s]


Step 5: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (9656, 978)
Number of perturbation groups: 9405
