# 2024-04-24-Preprocessing: Generating nested mcfaline23 subsets

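Perturb-seq screen of interactions between chemical and genetic perturbations. This notebook generates nested small/medium/full train/val/test splits, scaling first across covariates (cell type × treatment combinations) and then across perturbations, and writes one CSV per split.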

In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import random
import os
from sklearn.model_selection import train_test_split

from perturbench.data.datasplitter import PerturbationDataSplitter

%load_ext autoreload
%autoreload 2

## Load data and generate subsets

In [None]:
data_cache_dir = './perturbench_data' ## Change this to your local data directory

# Directory that receives the split CSVs generated by this notebook.
splits_dir = f'{data_cache_dir}/mcfaline23_gxe_splits'
# exist_ok=True keeps this cell idempotent on re-runs and avoids the
# check-then-create race of `if not os.path.exists(...)`.
os.makedirs(splits_dir, exist_ok=True)

In [None]:
# Open the processed dataset with backed='r'; subsets that need to be
# materialised are pulled into memory later with .to_memory().
processed_h5ad_path = f'{data_cache_dir}/mcfaline23_gxe_processed.h5ad'
adata = sc.read_h5ad(processed_h5ad_path, backed='r')
adata

## Scale across covariates

In [3]:
# Combined covariate label of the form '<cell_type>_<treatment>' for every cell.
cell_type_labels = adata.obs['cell_type'].astype(str)
treatment_labels = adata.obs['treatment'].astype(str)
adata.obs['cell_type_treat'] = cell_type_labels + '_' + treatment_labels

### Generate subsets

In [4]:
# Distinct '<cell_type>_<treatment>' labels, in the order .unique() returns them.
observed_covariates = adata.obs.cell_type_treat.unique()
unique_covariates = [label for label in observed_covariates if label != 'control']
# Cell-type prefix (text before the first '_') of each covariate label.
unique_covariates_cell_type = [label.partition('_')[0] for label in unique_covariates]
len(unique_covariates)

15

In [5]:
# Show the covariate labels so the hand-picked subsets below can be checked against them.
unique_covariates

['a172_none',
 'a172_nintedanib',
 'a172_zstk474',
 'a172_lapatinib',
 'a172_trametinib',
 't98g_nintedanib',
 't98g_lapatinib',
 't98g_none',
 't98g_trametinib',
 't98g_zstk474',
 'u87mg_zstk474',
 'u87mg_lapatinib',
 'u87mg_trametinib',
 'u87mg_nintedanib',
 'u87mg_none']

In [6]:
# Covariates from which perturbations are held out for val/test
# (one per cell line: a172, t98g, u87mg).
small_covariates_holdout = ['a172_nintedanib', 't98g_lapatinib', 'u87mg_none']

# Covariates that are used for training only (again one per cell line).
small_covariates_train = ['a172_none', 't98g_nintedanib', 'u87mg_lapatinib']

# The small subset is the holdout covariates followed by the train-only ones.
small_covariates = [*small_covariates_holdout, *small_covariates_train]

### Generate splits

#### Small

In [7]:
# Keep only cells whose covariate label is in the small subset and load them into memory.
in_small_subset = adata.obs.cell_type_treat.isin(small_covariates)
adata_small = adata[in_small_subset].to_memory()
adata_small

AnnData object with n_obs × n_vars = 407469 × 15009
    obs: 'orig.ident', 'ncounts', 'ngenes', 'cell', 'sample', 'Size_Factor', 'n.umi', 'PCR_plate', 'new_cell', 'dose', 'treatment', 'gRNA_id', 'gene_id', 'guide_number', 'cell_type', 'drug_dose', 'perturbation_type', 'dataset', 'gene_dose', 'perturbation', 'pert_cl_tr', 'condition', 'condition_plus_treatment', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'control', 'dose_val', 'cov_drug_dose_name', 'cell_type_treat'
    var: 'ensembl_id', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg', 'log1p', 'rank_genes_groups_cov'
    layers: 'counts'

In [8]:
# Every non-control condition present in the small subset.
observed_conditions = adata_small.obs.condition.unique()
unique_perturbations = [cond for cond in observed_conditions if cond != 'control']
len(unique_perturbations)

528

In [10]:
# Choose held-out (perturbation, covariate) pairs and held-out control cells.
#
# For each holdout covariate:
#   * 70% of the perturbations are sampled and recorded as '<perturbation>_<covariate>'.
#   * 50% of that covariate's control cells are sampled and divided evenly
#     between validation and test.
# The pooled held-out pairs are then divided in half into validation and test.
#
# A single numpy Generator supplies the seeds for both `random` and
# `train_test_split`; the order of the rng.integers() calls therefore
# determines the result and must not be changed.
rng = np.random.default_rng(12345)

control_val_ix = []
control_test_ix = []
heldout_pert_covs = []
for cov in small_covariates_holdout:
    random.seed(int(rng.integers(0, 2**16)))
    sampled_perts = random.sample(unique_perturbations, int(0.7*len(unique_perturbations)))
    heldout_pert_covs += [p + '_' + cov for p in sampled_perts]

    # Control cells of this covariate; half of them leave the training set.
    cov_controls = adata_small[
        (adata_small.obs.cell_type_treat == cov) & (adata_small.obs.condition == 'control')
    ].obs_names.to_list()
    cov_controls_heldout = random.sample(cov_controls, int(0.5*len(cov_controls)))
    cov_controls_heldout_val, cov_controls_heldout_test = train_test_split(
        cov_controls_heldout, test_size=0.5, random_state=int(rng.integers(0, 2**16))
    )
    control_val_ix += cov_controls_heldout_val
    control_test_ix += cov_controls_heldout_test

random.seed(int(rng.integers(0, 2**16)))
val_pert_covs = random.sample(heldout_pert_covs, int(0.5*len(heldout_pert_covs)))
# Set lookup keeps the complement linear; the list comprehension still
# preserves the original order of heldout_pert_covs.
val_pert_cov_set = set(val_pert_covs)
test_pert_covs = [x for x in heldout_pert_covs if x not in val_pert_cov_set]

len(val_pert_covs), len(test_pert_covs)


(553, 554)

In [11]:
# Every cell starts as 'train'; held-out cells are relabelled below.
split = pd.Series('train', index=adata_small.obs.index)

# '<condition>_<cell_type_treat>' key, matching the format of val_pert_covs / test_pert_covs.
condition_labels = adata_small.obs.condition.astype(str)
covariate_labels = adata_small.obs.cell_type_treat.astype(str)
adata_small.obs['pert_cov'] = condition_labels + '_' + covariate_labels

# Names of the perturbed cells belonging to each held-out group.
in_val = adata_small.obs.pert_cov.isin(val_pert_covs).to_numpy()
in_test = adata_small.obs.pert_cov.isin(test_pert_covs).to_numpy()
val_ix = adata_small.obs_names[in_val].tolist()
test_ix = adata_small.obs_names[in_test].tolist()

# Held-out perturbed cells plus the held-out control cells of each group.
split.loc[val_ix + control_val_ix] = 'val'
split.loc[test_ix + control_test_ix] = 'test'

split.value_counts()

train    268141
val       70300
test      69028
Name: count, dtype: int64

In [15]:
# Sanity check: reduce to the distinct (split, condition, covariate) rows and
# verify that each covariate has controls in every split it participates in.
small_obs = (
    adata_small.obs
    .assign(split=split)
    .loc[:, ['split', 'condition', 'cell_type_treat']]
    .drop_duplicates()
)

for cov in small_covariates:
    print(cov)
    cov_obs = small_obs[small_obs.cell_type_treat == cov]
    for spl in ('train', 'val', 'test'):
        conditions_in_split = cov_obs.loc[cov_obs.split == spl, 'condition']
        if conditions_in_split.empty:
            # A covariate may be absent from val/test, but never from train.
            assert spl != 'train'
        else:
            # Any split that contains this covariate must include its controls.
            assert 'control' in conditions_in_split.unique()

    print(cov_obs.split.value_counts())

a172_nintedanib
split
test     190
val      173
train    157
Name: count, dtype: int64
t98g_lapatinib
split
val      189
test     178
train    155
Name: count, dtype: int64
u87mg_none
split
val      183
test     178
train    155
Name: count, dtype: int64
a172_none
split
train    524
Name: count, dtype: int64
t98g_nintedanib
split
train    513
Name: count, dtype: int64
u87mg_lapatinib
split
train    521
Name: count, dtype: int64


In [16]:
# Expand the small split to one entry per cell of the full dataset; cells
# outside the small subset stay missing.
# dtype=object is required: pd.Series(None, index=...) otherwise defaults to
# float64, and writing string labels into it raises pandas' incompatible-dtype
# FutureWarning (an error in future pandas versions).
split_padded = pd.Series(None, index=adata.obs.index, dtype=object)
split_padded.loc[split.index] = split
split_padded.value_counts()

  split_padded.loc[split.index] = split


train    268141
val       70300
test      69028
Name: count, dtype: int64

In [17]:
# The padded split is indexed by adata.obs.index, so this is its number of entries
# (one per cell of the full dataset).
len(split_padded)

878229

In [18]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
split_padded.to_csv(f'{splits_dir}/small_covariate_split.csv', header=False)

#### Medium

In [19]:
# The medium subset adds one more covariate per cell line (a172, t98g, u87mg)
# in front of the small subset, so the small subset is nested inside it.
medium_covariates = [
    'a172_lapatinib',
    't98g_none',
    'u87mg_nintedanib',
    *small_covariates,
]
medium_covariates

In [None]:
# Cells whose covariate label belongs to the medium subset.
in_medium_subset = adata.obs.cell_type_treat.isin(medium_covariates)
adata_medium = adata[in_medium_subset]
adata_medium

In [None]:
# Default every medium-subset cell to 'train', then copy over the labels of
# the cells already assigned by `split` so the held-out cells stay identical.
medium_split = pd.Series('train', index=adata_medium.obs.index)
for label in split.unique():
    members = split.index[(split == label).to_numpy()]
    medium_split.loc[members] = label
medium_split.value_counts()

train    514733
val       57752
test      55544
Name: count, dtype: int64

In [None]:
# Expand the medium split to one entry per cell of the full dataset; cells
# outside the medium subset stay missing.
# dtype=object avoids the float64 default of an all-missing Series, which
# would trigger pandas' incompatible-dtype FutureWarning when string labels
# are written into it.
medium_split_padded = pd.Series(None, index=adata.obs.index, dtype=object)
medium_split_padded.loc[medium_split.index] = medium_split
len(medium_split_padded)

  medium_split_padded.loc[medium_split.index] = medium_split


878229

In [None]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
medium_split_padded.to_csv(f'{splits_dir}/medium_covariate_split.csv', header=False)

#### Full

In [None]:
# Default every cell of the full dataset to 'train', then copy over the labels
# of the cells already assigned by `split` so the held-out cells stay identical.
full_split = pd.Series('train', index=adata.obs.index)
for label in split.unique():
    members = split.index[(split == label).to_numpy()]
    full_split.loc[members] = label
full_split.value_counts()

train    764933
val       57752
test      55544
Name: count, dtype: int64

In [None]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
full_split.to_csv(f'{splits_dir}/full_covariate_split.csv', header=False)

## Scale across perturbations

### Generate subsets

In [None]:
# NOTE: despite its name, from here on `unique_covariates` holds the distinct
# non-control *perturbations* (values of obs.condition), not covariate labels.
all_conditions = adata.obs.condition.unique()
unique_covariates = [cond for cond in all_conditions if cond != 'control']
len(unique_covariates)

528

In [None]:
# Nested subsets: medium is a random half of all perturbations, small is a
# random half of medium. Both draws come from the same seeded stream, so the
# order of the two sample() calls matters.
random.seed(84)
n_medium = len(unique_covariates) // 2
medium_perturbations = random.sample(unique_covariates, n_medium)
n_small = len(medium_perturbations) // 2
small_perturbations = random.sample(medium_perturbations, n_small)
len(small_perturbations), len(medium_perturbations)

(132, 264)

### Generate splits

#### Small

In [None]:
# Cells carrying one of the small-subset perturbations, plus all control cells.
small_conditions = [*small_perturbations, 'control']
adata_small = adata[adata.obs.condition.isin(small_conditions)]
adata_small

In [6]:
# Build the splitter from a copy of the small subset's obs table so the
# splitter never works on adata_small.obs itself.
small_subset_obs = adata_small.obs.copy()
splitter = PerturbationDataSplitter(
    small_subset_obs,
    perturbation_key='condition',
    covariate_keys=['cell_type', 'treatment'],
    perturbation_control_value='control',
)

In [None]:
# Generate the train/val/test assignment for the small perturbation subset.
# NOTE(review): the exact semantics of these arguments are defined by
# PerturbationDataSplitter.split_covariates, which is not visible here.
split = splitter.split_covariates(
    seed=57,
    print_split=True, ## Print a summary of the split if True
    max_heldout_covariates=7, ## Maximum number of held out covariates (the splitter was built with covariate_keys=['cell_type', 'treatment'], so presumably cell type/treatment combinations rather than cell types alone -- confirm)
    max_heldout_fraction_per_covariate=0.3, ## Maximum fraction of perturbations held out per covariate
)

In [8]:
# Number of cells assigned to each split label by the splitter.
split.value_counts()

transfer_split_seed57
train    222968
test      70828
val       70551
Name: count, dtype: int64

In [9]:
# Expand the small split to one entry per cell of the full dataset; cells
# outside the small subset stay missing.
# dtype=object is required: pd.Series(None, index=...) otherwise defaults to
# float64, and writing string labels into it raises pandas' incompatible-dtype
# FutureWarning (an error in future pandas versions).
split_padded = pd.Series(None, index=adata.obs.index, dtype=object)
split_padded.loc[split.index] = split
split_padded.value_counts()

  split_padded.loc[split.index] = split


train    222968
test      70828
val       70551
Name: count, dtype: int64

In [10]:
# The padded split is indexed by adata.obs.index, so this is its number of entries
# (one per cell of the full dataset).
len(split_padded)

878229

In [11]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
split_padded.to_csv(f'{splits_dir}/small_split.csv', header=False)

#### Medium

In [None]:
# Cells carrying one of the medium-subset perturbations, plus all control cells.
medium_conditions = [*medium_perturbations, 'control']
adata_medium = adata[adata.obs.condition.isin(medium_conditions)]
adata_medium

In [13]:
# Default every medium-subset cell to 'train', then copy over the labels of
# the cells already assigned by `split` so the held-out cells stay identical.
medium_split = pd.Series('train', index=adata_medium.obs.index)
for label in split.unique():
    members = split.index[(split == label).to_numpy()]
    medium_split.loc[members] = label
medium_split.value_counts()

train    387896
test      70828
val       70551
Name: count, dtype: int64

In [14]:
# Expand the medium split to one entry per cell of the full dataset; cells
# outside the medium subset stay missing.
# dtype=object avoids the float64 default of an all-missing Series, which
# would trigger pandas' incompatible-dtype FutureWarning when string labels
# are written into it.
medium_split_padded = pd.Series(None, index=adata.obs.index, dtype=object)
medium_split_padded.loc[medium_split.index] = medium_split
medium_split_padded.value_counts()

  medium_split_padded.loc[medium_split.index] = medium_split


train    387896
test      70828
val       70551
Name: count, dtype: int64

In [15]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
medium_split_padded.to_csv(f'{splits_dir}/medium_split.csv', header=False)

#### Full

In [16]:
# Default every cell of the full dataset to 'train', then copy over the labels
# of the cells already assigned by `split` so the held-out cells stay identical.
full_split = pd.Series('train', index=adata.obs.index)
for label in split.unique():
    members = split.index[(split == label).to_numpy()]
    full_split.loc[members] = label
full_split.value_counts()

train    736850
test      70828
val       70551
Name: count, dtype: int64

In [17]:
# Write index,label rows without a header. Reuse splits_dir (the directory
# created at the top of the notebook) instead of re-spelling its path.
full_split.to_csv(f'{splits_dir}/full_split.csv', header=False)