# Evaluator API Demo

In [1]:
import scanpy as sc
import anndata as ad
import random
from sklearn.model_selection import train_test_split

import perturbench.analysis.benchmarks as benchmarks
from perturbench.data.datasplitter import PerturbationDataSplitter
from perturbench.analysis.benchmarks.evaluator import Evaluator

%reload_ext autoreload
%autoreload 2

In this demo, we'll show how to use the Evaluator API on the srivatsan20-transfer task.

In [3]:
# Relative path to the local data cache: the processed .h5ad file is read from
# here, and the same directory is passed to the Evaluator as `local_data_cache`.
data_cache_dir = '../neurips2024/perturbench_data'

In [4]:
# Read the processed srivatsan20 AnnData from the local data cache
srivatsan20_path = f'{data_cache_dir}/srivatsan20_processed.h5ad'
adata = sc.read_h5ad(srivatsan20_path)
adata

AnnData object with n_obs × n_vars = 183856 × 9198
    obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
    var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg', 'log1p', 'rank_genes_groups_cov'
 

Let's create a datasplit using our data splitting class

In [5]:
# Hand the splitter its own copy of the per-cell metadata table.
obs_metadata = adata.obs.copy()

# Perturbations are read from the 'condition' column (with 'control' as the
# control value) and 'cell_type' is the covariate to split across.
transfer_splitter = PerturbationDataSplitter(
    obs_metadata,
    covariate_keys=['cell_type'],
    perturbation_key='condition',
    perturbation_control_value='control',
)

Generate a split. Setting a seed ensures that you get the same split every time.

In [None]:
# Options for the covariate-transfer split
split_options = dict(
    # Fixed seed so the same split is produced on every run
    seed=0,
    # Print a summary of the split if True
    print_split=True,
    # Maximum number of held out covariates (in this case cell types)
    max_heldout_covariates=2,
    # Maximum fraction of perturbations held out per covariate
    max_heldout_fraction_per_covariate=0.3,
)
split = transfer_splitter.split_covariates(**split_options)

In [7]:
# Keep only the cells assigned to the test split. Subsetting an AnnData returns
# a view of `adata`; `.copy()` materializes it as an independent object so that
# columns can be added to `adata_test.obs` afterwards without anndata warning
# about (and implicitly copying on) modification of a view.
adata_test = adata[split == 'test'].copy()
adata_test.shape

(29635, 9198)

We'll simulate predictions by randomly subsampling the test data and by shuffling its labels, and treat the subsample and the shuffle as two different "model predictions".

In [8]:
# Combined "<condition>_<cell_type>" label for every test cell (used as the
# stratification key when subsampling). It is built from `adata_test.obs`
# itself, so the new column is computed only for the test cells and does not
# depend on index alignment against the full `adata.obs` table.
adata_test.obs['condition_cell_type'] = (
    adata_test.obs['condition'].astype(str)
    + '_'
    + adata_test.obs['cell_type'].astype(str)
)

  adata_test.obs['condition_cell_type'] = adata.obs['condition'].astype(str) + '_' + adata.obs['cell_type'].astype(str)


In [9]:
# Stratified 75/25 partition of the test cell names by condition + cell type;
# only the first (75%) partition is kept as the "sampled" cells.
cell_partitions = train_test_split(
    adata_test.obs_names,
    test_size=0.25,
    stratify=adata_test.obs['condition_cell_type'],
    random_state=54,
)
sampled_cells = cell_partitions[0]

In [10]:
# Subset the data to the sampled cells; this subsample is later used as the
# 'sampled' entry of the simulated model predictions.
sampled_adata = adata[sampled_cells, :]
sampled_adata.shape

(22226, 9198)

Now we'll create the shuffled predictions to serve as a negative control

In [11]:
# Seed Python's RNG so the shuffled labels are identical on every run
random.seed(54)

# Shuffle the condition labels separately within each cell type
shuffled_per_cell_type = []
for ct in sampled_adata.obs['cell_type'].unique():
    in_cell_type = sampled_adata.obs['cell_type'] == ct
    ct_adata = sampled_adata[in_cell_type, :].copy()
    condition_labels = list(ct_adata.obs['condition'].astype(str))
    # Drawing all n_obs labels without replacement yields a random permutation
    ct_adata.obs['condition'] = random.sample(condition_labels, k=ct_adata.n_obs)
    shuffled_per_cell_type.append(ct_adata)

# Stitch the per-cell-type pieces back into a single AnnData
random_adata = ad.concat(shuffled_per_cell_type)
random_adata.shape

(22226, 9198)

List all tasks in the Evaluator class

In [13]:
# Show the task names the Evaluator class knows about
Evaluator.list_tasks()

['srivatsan20-transfer', 'norman19-combo', 'mcfaline23-transfer']

Create an evaluator object with the srivatsan20-transfer task

In [14]:
# Build the evaluator for the srivatsan20-transfer task, pointing it at the
# local data cache directory. Uses the `Evaluator` class imported at the top of
# the notebook, consistent with the `Evaluator.list_tasks()` call.
srivatsan20_eval = Evaluator(
    task='srivatsan20-transfer',
    local_data_cache=data_cache_dir,
)

The input to our evaluator class is a dictionary that maps each model's name to its predictions.

In [15]:
# Model name -> AnnData standing in for that model's predictions
simulated_predictions = dict(
    sampled=sampled_adata,
    random=random_adata,
)

We then evaluate our simulated model predictions

In [None]:
# Evaluate every simulated model and request the metrics as a dataframe
metrics_df = srivatsan20_eval.evaluate(
    return_metrics_dataframe=True,
    model_predictions=simulated_predictions,
)

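We can then look at the summary metrics returned by the evaluation. Each value is an average of the metric computed on a per-perturbation basis. As the table below shows, the sampled data is very close to the full observed data (low RMSE, cosine similarity of log fold-changes near 0.9), while the shuffled "random" data carries essentially no information (cosine similarity near 0, rank metrics near 0.5).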

In [17]:
# Display the summary metrics (one column per model)
metrics_df

model,random,sampled
rmse_average,0.02821,0.005487
rmse_rank_average,0.4893,0.0
cosine_logfc,0.000577,0.8902
cosine_rank_logfc,0.5098,0.0
