# 2024-02-14-Analysis: PerturbationDataSplitter Demo

In [3]:
import scanpy as sc
import pandas as pd
from perturbench.data.datasplitter import PerturbationDataSplitter

%load_ext autoreload
%autoreload 2



In [8]:
data_cache_dir = '../neurips2024/perturbench_data' ## Change this to your local data directory

## Generating a cell type transfer split

This split models the task where perturbations are measured across a set of cell types, but not every perturbation is measured in every cell type. The goal is to predict the effects of perturbations in the cell types where they have not been measured. We simulate this task by iterating over perturbations and, for each perturbation, randomly choosing some of the cell types in which it was measured to hold out.

Note: in the current version of this data splitter, every cell type has at least some measured perturbations.
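The procedure above can be sketched in plain Python. This is a toy illustration under our own assumptions, not perturbench's implementation: the function name `sketch_transfer_split`, the per-cell-type budget of `floor(fraction × perturbations measured in that cell type)`, and the random val/test assignment of held-out pairs are all hypothetical. Only the parameter names `max_heldout_covariates`, `max_heldout_fraction_per_covariate` and `min_train_covariates` are borrowed from the splitter call and its stored `split_params`.

```python
import random
from collections import defaultdict


def sketch_transfer_split(
    measured,
    max_heldout_covariates=2,
    max_heldout_fraction_per_covariate=0.3,
    min_train_covariates=1,
    seed=0,
):
    """Toy cell-type transfer split (hypothetical; not perturbench's code).

    measured: dict mapping each perturbation to the cell types it was measured in.
    Returns a dict mapping (perturbation, cell_type) to 'train', 'val' or 'test'.
    """
    rng = random.Random(seed)  # seeded RNG makes the split reproducible

    # Assumed per-cell-type budget: at most floor(fraction * n) held-out perturbations
    n_per_cov = defaultdict(int)
    for covs in measured.values():
        for cov in covs:
            n_per_cov[cov] += 1
    budget = {
        cov: int(max_heldout_fraction_per_covariate * n) for cov, n in n_per_cov.items()
    }

    split = {}
    perts = sorted(measured)
    rng.shuffle(perts)  # visit perturbations in a random order
    for pert in perts:
        covs = sorted(measured[pert])
        # Hold out at most max_heldout_covariates, keeping min_train_covariates in training
        k_max = max(0, min(max_heldout_covariates, len(covs) - min_train_covariates))
        k = rng.randint(0, k_max)
        # Only cell types with remaining budget can receive another held-out perturbation
        candidates = [c for c in covs if budget[c] > 0]
        heldout = set(rng.sample(candidates, min(k, len(candidates))))
        for cov in covs:
            if cov in heldout:
                budget[cov] -= 1
                split[(pert, cov)] = rng.choice(["val", "test"])
            else:
                split[(pert, cov)] = "train"
    return split


cell_types = ["a549", "k562", "mcf7"]
measured = {f"drug{i}": cell_types for i in range(10)}
toy_split = sketch_transfer_split(measured, seed=0)
```

Lowering `max_heldout_covariates` in this sketch raises the minimum number of training cell types per perturbation, which mirrors the behavior explored later in this notebook.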

In [9]:
balanced_transfer_adata = sc.read_h5ad(f'{data_cache_dir}/srivatsan20_processed.h5ad', backed='r')
balanced_transfer_adata

AnnData object with n_obs × n_vars = 183856 × 9198 backed at '../neurips2024/perturbench_data/srivatsan20_processed.h5ad'
    obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
    var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highl

Initialize the splitter.

In [10]:
balanced_transfer_splitter = PerturbationDataSplitter(
    balanced_transfer_adata.obs.copy(),
    perturbation_key='condition',
    covariate_keys=['cell_type'],
    perturbation_control_value='control',
)
balanced_transfer_splitter

<perturbench.data.datasplitter.PerturbationDataSplitter at 0x7fd5ac054490>

Generate a split. Setting a seed ensures you get the same split every time.

In [None]:
balanced_transfer_split = balanced_transfer_splitter.split_covariates(
    seed=0, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_covariates=2, ## Maximum number of held out covariates (in this case cell types)
    max_heldout_fraction_per_covariate=0.3, ## Maximum fraction of perturbations held out per covariate
)

The split is returned as a pandas Series, but it is also stored in the splitter's `obs_dataframe` slot, as a column whose name combines the type of split (`transfer_split`) and the random seed (`seed0`).
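In this notebook the split column name follows the pattern `<split type>_seed<seed>` (here `transfer_split_seed0`). A small helper can rebuild such names, for example to loop over several seeds. `split_column_name` is our own hypothetical helper, the toy list of records stands in for `obs_dataframe`, and `collections.Counter` plays the role of `value_counts()`.

```python
from collections import Counter


def split_column_name(split_type, seed):
    """Hypothetical helper: rebuild a split column name such as 'transfer_split_seed0'."""
    return f"{split_type}_seed{seed}"


# Toy per-cell records standing in for splitter.obs_dataframe
obs_records = [
    {"condition": "drugA", "cell_type": "mcf7", "transfer_split_seed0": "train"},
    {"condition": "drugA", "cell_type": "k562", "transfer_split_seed0": "test"},
    {"condition": "drugB", "cell_type": "mcf7", "transfer_split_seed0": "train"},
    {"condition": "drugB", "cell_type": "a549", "transfer_split_seed0": "val"},
]

key = split_column_name("transfer_split", 0)
# Counts cells per split label, like obs_dataframe[key].value_counts()
split_counts = Counter(rec[key] for rec in obs_records)
print(split_counts)
```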

In [12]:
balanced_transfer_splitter.obs_dataframe.loc[:,'transfer_split_seed0'].value_counts()

transfer_split_seed0
train    124972
test      29635
val       29249
Name: count, dtype: int64

We can access the parameters used to generate this split in the `split_params` slot.

In [13]:
balanced_transfer_splitter.split_params['transfer_split_seed0']

{'min_train_covariates': 1,
 'max_heldout_covariates': 2,
 'max_heldout_fraction_per_cov': 0.3,
 'train_control_fraction': 0.5}

We can also access the split summary dataframe in the `summary_dataframes` slot.
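Judging from its output, the summary dataframe appears to tabulate, for each cell type, how many distinct perturbations fall into each split. A minimal sketch of that tabulation under this assumption follows; `summarize_split` is hypothetical, not the splitter's own code, and the toy records reuse this notebook's column names.

```python
from collections import defaultdict


def summarize_split(records, perturbation_key, covariate_key, split_key):
    """Count distinct perturbations per (covariate, split) pair (hypothetical sketch)."""
    perts = defaultdict(set)
    for rec in records:
        perts[(rec[covariate_key], rec[split_key])].add(rec[perturbation_key])
    # Sort by (covariate, split) so the result has a stable order
    return {pair: len(names) for pair, names in sorted(perts.items())}


records = [
    {"condition": "control", "cell_type": "mcf7", "split": "train"},
    {"condition": "control", "cell_type": "mcf7", "split": "test"},
    {"condition": "drugA", "cell_type": "mcf7", "split": "train"},
    {"condition": "drugA", "cell_type": "mcf7", "split": "train"},  # second cell, same perturbation
    {"condition": "drugB", "cell_type": "mcf7", "split": "val"},
]
summary = summarize_split(records, "condition", "cell_type", "split")
print(summary)
```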

In [14]:
balanced_transfer_splitter.summary_dataframes['transfer_split_seed0']

           train  val  test
('mcf7',)    132   30    29
('k562',)    132   29    30
('a549',)    132   29    30


Since we set `max_heldout_covariates=2` and this dataset has 3 cell types in total, some perturbations should be trained in only one cell type and some in two. Let's look at the number of training cell types per perturbation.

In [15]:
def get_num_train_cell_types(
    splitter,
    split_key,
):
    """Returns the number of training cell types per perturbation"""
    num_train_cell_types = []
    for pert in splitter.obs_dataframe.condition.unique():
        pert_df = splitter.obs_dataframe[splitter.obs_dataframe.condition == pert]
        pert_df = pert_df.loc[:,['cell_type', split_key]].drop_duplicates()
        num_train_cell_types.append(pert_df.loc[pert_df[split_key] == 'train', 'cell_type'].nunique())

    num_train_cell_types = pd.Series(num_train_cell_types, index=splitter.obs_dataframe.condition.unique())
    return num_train_cell_types

When we look at the number of training cell types, we see the distribution we expect; perturbations with no held-out cell types are trained in all 3.

In [16]:
num_train_cell_types = get_num_train_cell_types(balanced_transfer_splitter, 'transfer_split_seed0')
num_train_cell_types.value_counts()


3    77
1    59
2    53
Name: count, dtype: int64

If we set `max_heldout_covariates=1`, there should be at least 2 training cell types per perturbation, which makes the task somewhat easier for the model.
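The bound follows from simple arithmetic: a perturbation measured in n cell types, with at most `max_heldout_covariates` of them held out, keeps at least n minus that cap in training, and never fewer than `min_train_covariates` (1 in the stored `split_params`). A small sketch, assuming the perturbation is measured in all 3 cell types; the function name is hypothetical.

```python
def min_train_cell_types(n_measured, max_heldout_covariates, min_train_covariates=1):
    """Lower bound on training cell types for one perturbation (hypothetical helper)."""
    return max(min_train_covariates, n_measured - max_heldout_covariates)


# 3 cell types in this dataset (mcf7, k562, a549)
for cap in (2, 1):
    print(f"max_heldout_covariates={cap}: at least {min_train_cell_types(3, cap)} training cell type(s)")
```

This matches the two value counts in this notebook: the smallest count is 1 training cell type with a cap of 2, and 2 with a cap of 1.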

In [None]:
balanced_transfer_split = balanced_transfer_splitter.split_covariates(
    seed=1, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_covariates=1, ## Maximum number of held out covariates (in this case cell types)
    max_heldout_fraction_per_covariate=0.3, ## Maximum fraction of perturbations held out per covariate
)

In [18]:
num_train_cell_types = get_num_train_cell_types(balanced_transfer_splitter, 'transfer_split_seed1')
num_train_cell_types.value_counts()


2    171
3     18
Name: count, dtype: int64

We can adjust the balance between training and held-out perturbations with the `max_heldout_fraction_per_covariate` parameter.

In [None]:
balanced_transfer_split = balanced_transfer_splitter.split_covariates(
    seed=2, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_covariates=1, ## Maximum number of held out covariates (in this case cell types)
    max_heldout_fraction_per_covariate=0.2, ## Maximum fraction of perturbations held out per covariate
)

## Generating a combo split

This split models the task where we have measured the single perturbations and a handful of combinations, and want to predict the effects of the remaining combinations. We simulate this task by keeping all single perturbations in training and randomly holding out a tunable fraction of the combinations in each cell type.
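A toy version of this procedure, under stated assumptions: combinations are assumed to be written as perturbation names joined by `+` (e.g. `A+B`), held-out combinations are split roughly evenly between val and test, and `sketch_combo_split` is a hypothetical function rather than perturbench's `split_combinations`. It handles a single covariate group; a covariate-aware version would repeat it per cell type.

```python
import random


def sketch_combo_split(conditions, heldout_fraction=0.3, delimiter="+", seed=0):
    """Toy combo split: all singles (and control) train, a fraction of combos held out."""
    rng = random.Random(seed)
    unique = set(conditions)
    singles = sorted(c for c in unique if delimiter not in c)  # includes control
    combos = sorted(c for c in unique if delimiter in c)
    rng.shuffle(combos)

    n_heldout = int(heldout_fraction * len(combos))
    n_val = n_heldout // 2
    split = {c: "train" for c in singles}
    # First n_val shuffled combos go to val, the rest of the held-out ones to test
    for i, combo in enumerate(combos):
        if i < n_val:
            split[combo] = "val"
        elif i < n_heldout:
            split[combo] = "test"
        else:
            split[combo] = "train"
    return split


conditions = ["control", "A", "B", "C", "D", "A+B", "A+C", "A+D", "B+C", "B+D", "C+D"]
toy = sketch_combo_split(conditions, heldout_fraction=0.5, seed=0)
```

Raising `heldout_fraction` in the sketch moves more combinations out of training, analogous to increasing `max_heldout_fraction_per_covariate` in the real splitter.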

In [20]:
combo_adata = sc.read_h5ad(f'{data_cache_dir}/norman19_processed.h5ad', backed='r')
combo_adata

AnnData object with n_obs × n_vars = 111445 × 5850 backed at '../neurips2024/perturbench_data/norman19_processed.h5ad'
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'condition', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
    var: 'ensemble_id', 'ncounts', 'ncells', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg', 'log1p', 'rank_genes_groups_cov'
    layers: 'counts'

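Initialize the splitter. No covariates are used for this dataset, so we pass `covariate_keys=None`; the split summaries below then report a single group, `('1',)`.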

In [21]:
combo_splitter = PerturbationDataSplitter(
    combo_adata.obs.copy(),
    perturbation_key='condition',
    covariate_keys=None,
    perturbation_control_value='control',
)
combo_splitter

<perturbench.data.datasplitter.PerturbationDataSplitter at 0x7fd574f39510>

Generate a split. Setting a seed ensures you get the same split every time.

In [22]:
combo_split = combo_splitter.split_combinations(
    seed=0, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_fraction_per_covariate=0.3, ## Maximum fraction of perturbations held out per covariate
)

        train  val  test
('1',)    198   20    21


We can modify the `max_heldout_fraction_per_covariate` parameter to increase or decrease the number of held-out combinations.

In [23]:
combo_split = combo_splitter.split_combinations(
    seed=0, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_fraction_per_covariate=0.7, ## Maximum fraction of perturbations held out per covariate
)

        train  val  test
('1',)    146   46    47


## Generating an inverse combo split

This split models the task where we have observed all combinations of perturbations and some of the single perturbations, and want to predict the effects of the remaining single perturbations.
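The inverse split can be sketched the same way with the roles swapped: every combination and the control stay in training, and a fraction of the single perturbations is held out. As before, the `+` delimiter, the even val/test division and the function name `sketch_inverse_combo_split` are our own assumptions, not perturbench's `split_combinations_inverse`.

```python
import random


def sketch_inverse_combo_split(conditions, heldout_fraction=0.3, delimiter="+",
                               control="control", seed=0):
    """Toy inverse combo split: all combos (and control) train, a fraction of singles held out."""
    rng = random.Random(seed)
    unique = set(conditions)
    singles = sorted(c for c in unique if delimiter not in c and c != control)
    rng.shuffle(singles)

    n_heldout = int(heldout_fraction * len(singles))
    n_val = n_heldout // 2
    # Combinations and the control are always available for training
    split = {c: "train" for c in unique if delimiter in c or c == control}
    for i, single in enumerate(singles):
        if i < n_val:
            split[single] = "val"
        elif i < n_heldout:
            split[single] = "test"
        else:
            split[single] = "train"
    return split


conditions = ["control", "A", "B", "C", "D", "A+B", "A+C", "B+D", "C+D"]
toy_inverse = sketch_inverse_combo_split(conditions, heldout_fraction=0.5, seed=0)
```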

In [24]:
inverse_combo_split = combo_splitter.split_combinations_inverse(
    seed=0, 
    print_split=True, ## Print a summary of the split if True
    max_heldout_fraction_per_covariate=0.3, ## Maximum fraction of perturbations held out per covariate
)

        train  val  test
('1',)    206   16    17
