Generate different sciplex datasets for training/testing scFiLM.

2000 , 3500, 5000 hvg kept

all cells vs. balanced vs. subsample

pathway activation, genes, embeddings

COATI + RDKit compound emveddings

In [41]:
import random

import anndata as ad
import scanpy as sc
import numpy as np

from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem

## Create Morgan Fingerprints and Coati Drug Representations

In [2]:
adata_raw = ad.read_h5ad("../../data/sciplex/sciplex3_uce_adata.h5ad")
adata_old = ad.read_h5ad("../../data/sciplex/sciplex_preprocessed.h5ad")
adata_raw = adata_raw[adata_raw.obs['product_name'].isin(list(adata_old.obs['product_name'].unique()))]

In [3]:
drugname_smiles_map = dict(zip(adata_old.obs['product_name'], adata_old.obs['smiles']))
drugname_pubchemid_map = dict(zip(adata_old.obs['product_name'], adata_old.obs['pubchem_ID']))
drugname_coati_map = dict(zip(adata_old.obs['product_name'], adata_old.obs['sm_embedding']))

In [None]:
smiles_rdkit_map = dict()

for smiles in drugname_smiles_map.values():
    if type(smiles) ==  float:
        continue
    else:
        mol = Chem.MolFromSmiles(smiles)
        fcfp4 = AllChem.GetMorganFingerprintAsBitVect(mol, 2, useFeatures=True, nBits=1024).ToBitString()
        fcfp4_list = np.array(list(fcfp4), dtype=np.float32)
    
        smiles_rdkit_map[smiles] = fcfp4_list

In [6]:
smiles = list()
pubchem_id = list()
coati_emb = list()
morgan_emb = list()

for product_name in list(adata_raw.obs['product_name']):
    if product_name == "Vehicle":
        smiles.append(None)
        pubchem_id.append(None)
        coati_emb.append(None)
        morgan_emb.append(None)

    else:
        smiles.append(drugname_smiles_map[product_name])
        pubchem_id.append(drugname_pubchemid_map[product_name])
        coati_emb.append(drugname_coati_map[product_name])
    
        sm = drugname_smiles_map[product_name]
        morgan_emb.append(smiles_rdkit_map[sm])

In [7]:
adata_raw.obs['SMILES'] = smiles
adata_raw.obs['pubchem_id'] = pubchem_id
adata_raw.obs['sm_coati_emb'] = coati_emb
adata_raw.obs['sm_morgan_emb'] = morgan_emb

adata_raw.obs['sm_morgan_emb'] = adata_raw.obs['sm_morgan_emb'].apply(
    lambda arr: ','.join(map(str, arr)) if arr is not None else None
)

  adata_raw.obs['SMILES'] = smiles


## Filter Low Quality cells and lowly expressed genes

In [11]:
sc.pp.filter_cells(adata_raw, min_genes=100)
sc.pp.filter_genes(adata_raw, min_cells=3)

In [12]:
adata_raw

AnnData object with n_obs × n_vars = 571696 × 17376
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'n_genes', 'SMILES', 'pubchem_id', 'sm_coati_emb', 'sm_morgan_emb'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'n_cells'
    obsm: 'X_uce'

## Match Random Controls to each perturbed condition

In [19]:
match_index = list()

control_pool_A549 = adata_raw[(adata_raw.obs['product_name'] == "Vehicle") & (adata_raw.obs['cell_type'] == "A549")]
control_pool_MCF7 = adata_raw[(adata_raw.obs['product_name'] == "Vehicle") & (adata_raw.obs['cell_type'] == "MCF7")]
control_pool_K562 = adata_raw[(adata_raw.obs['product_name'] == "Vehicle") & (adata_raw.obs['cell_type'] == "K562")]

for i, row in adata_raw.obs.iterrows():
    if row['product_name'] == "Vehicle":
        match_index.append("None")
    else:
        if row['cell_type'] == "A549":
            match_index.append(random.choice(list(control_pool_A549.obs_names)))
        elif row['cell_type'] == "MCF7":
            match_index.append(random.choice(list(control_pool_MCF7.obs_names)))
        elif row['cell_type'] == "K562":
            match_index.append(random.choice(list(control_pool_K562.obs_names)))

adata_raw.obs['match_index'] = match_index

## Create different-dimensional HVG representations for cells

In [22]:
def get_preprocessed_expr(adata_raw, nhvg):
    adata = adata_raw.copy()
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=nhvg, flavor='seurat')
    adata = adata[:, adata.var['highly_variable']]
    return adata.X, list(adata.var_names)

In [23]:
X_500_hvg, gene_names_500 = get_preprocessed_expr(adata_raw, 500)
X_1000_hvg, gene_names_1000 = get_preprocessed_expr(adata_raw, 1000)
X_2000_hvg, gene_names_2000 = get_preprocessed_expr(adata_raw, 2000)
X_3500_hvg, gene_names_3500 = get_preprocessed_expr(adata_raw, 3500)
X_5000_hvg, gene_names_5000 = get_preprocessed_expr(adata_raw, 5000)
X_7500_hvg, gene_names_7500 = get_preprocessed_expr(adata_raw, 7500)

In [24]:
adata_raw.obsm['X_500_hvg'] = X_500_hvg
adata_raw.obsm['X_1000_hvg'] = X_1000_hvg
adata_raw.obsm['X_2000_hvg'] = X_2000_hvg
adata_raw.obsm['X_3500_hvg'] = X_3500_hvg
adata_raw.obsm['X_5000_hvg'] = X_5000_hvg
adata_raw.obsm['X_7500_hvg'] = X_7500_hvg

adata_raw.uns['gene_names_500'] = gene_names_500
adata_raw.uns['gene_names_1000'] = gene_names_1000
adata_raw.uns['gene_names_2000'] = gene_names_2000
adata_raw.uns['gene_names_3500'] = gene_names_3500
adata_raw.uns['gene_names_5000'] = gene_names_5000
adata_raw.uns['gene_names_7500'] = gene_names_7500

In [25]:
adata_raw

AnnData object with n_obs × n_vars = 571696 × 17376
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'n_genes', 'SMILES', 'pubchem_id', 'sm_coati_emb', 'sm_morgan_emb', 'match_index'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'n_cells'
    uns: 'gene_names_500', 'gene_names_1000', 'gene_names_2000', 'gene_names_3500', 'gene_names_5000', 'gene_names_7500'
    obsm: 'X_uce', 'X_500_hvg', 'X_1000_hvg', 'X_2000_hvg', 'X_3500_hvg', 'X_5000_hvg', 'X_7500_hvg'

## Create a mask for each individual perturbed condition at different sizes

In [42]:
def create_subset_mask(adata, max_n):
    indices_to_keep = list()

    for drug in tqdm(adata.obs['product_name'].unique()):
        for cell_type in adata.obs['cell_type'].unique():
            for dose in adata.obs['dose'].unique():
                if drug == "Vehicle":
                    continue
    
                if dose == 0.0:
                    continue
                
                else:
                    ad_subset = adata[(adata.obs['cell_type'] == cell_type) 
                    & (adata.obs['product_name'] == drug)
                    & (adata.obs['dose'] == dose)].copy()
    

                    if ad_subset.n_obs > max_n:
                        selected = np.random.choice(list(ad_subset.obs_names), size=max_n, replace=False)
                    else:
                        selected = list(ad_subset.obs_names)

                    indices_to_keep.extend(selected)

    control_indices = adata[adata.obs['product_name'] == "Vehicle"].obs_names
    indices_to_keep.extend(control_indices)

    mask = [1 if x in indices_to_keep else 0 for x in list(adata.obs_names)]
    return mask


In [43]:
mask_50 = create_subset_mask(adata_raw, 50)
mask_100 = create_subset_mask(adata_raw, 100)
mask_250 = create_subset_mask(adata_raw, 250)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 186/186 [00:15<00:00, 12.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 186/186 [00:14<00:00, 12.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 186/186 [00:14<00:00, 12.67it/s]


In [45]:
adata_raw.obs['mask_50'] = mask_50
adata_raw.obs['mask_100'] = mask_100
adata_raw.obs['mask_250'] = mask_250

In [46]:
adata_raw

AnnData object with n_obs × n_vars = 571696 × 17376
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'n_genes', 'SMILES', 'pubchem_id', 'sm_coati_emb', 'sm_morgan_emb', 'match_index', 'mask_50', 'mask_100', 'mask_250'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'n_cells'
    uns: 'gene_names_500', 'gene_names_1000', 'gene_names_2000', 'gene_names_3500', 'gene_names_5000', 'gene_names_7500'
    obsm: 'X_uce', 'X_500_hvg', 'X_1000_hvg', 'X_2000_hvg', 'X_3500_hvg', 'X_5000_hvg', 'X_7500_hvg'

In [47]:
adata_raw.write_h5ad("../../data/sciplex/sciplex_full_v4.h5ad")

In [49]:
adata1 = adata_raw[adata_raw.obs['mask_50'] == 1]

In [51]:
adata1.obs['product_name'].unique()

['Raltitrexed', 'Vehicle', 'Lenalidomide (CC-5013)', 'MLN8054', 'Celecoxib', ..., 'PHA-680632', 'Bisindolylmaleimide IX (Ro 31-8220 Mesylate)', 'Selisistat (EX 527)', 'Quercetin', 'Tucidinostat (Chidamide)']
Length: 186
Categories (186, object): ['2-Methoxyestradiol (2-MeOE2)', '(+)-JQ1', 'A-366', 'ABT-737', ..., 'XAV-939', 'YM155 (Sepantronium Bromide)', 'ZM 447439', 'Zileuton']