# 2024-04-18-Analysis: Verify CPA performance using legacy Theis lab fork

In [1]:
import pandas as pd
import scanpy as sc
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from analysis.benchmarks.evaluation import Evaluation

sns.set_theme(font="Calibri")
sns.set_style("whitegrid", {'axes.grid' : False})

%load_ext autoreload
%autoreload 2

## Load reference data

In [None]:
data_path = "./notebooks/neurips2025/perturbench_data/"

In [None]:
adata = sc.read_h5ad(f'{data_path}/norman19_cpa_hvg_normalized_curated.h5ad')
adata

AnnData object with n_obs × n_vars = 111122 × 5044
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'n_counts', 'condition', 'pert_type', 'cell_type', 'source', 'condition_ID', 'control', 'dose_value', 'pathway', 'cov_cond', 'pert', 'split_hardest', 'split_1', 'split_2', 'split_3', 'split_4', 'split_5', 'split_6', 'cond_harm', 'split'
    var: 'ensemble_id', 'ncounts', 'ncells', 'symbol', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'cell_type_colors', 'gene_embedding_path', 'hvg', 'log1p', 'neighbors', 'rank_genes_groups_cov', 'source_colors', 'split_1_colors', 'split_2_colors', 'split_3_colors', 'split_4_colors', 'split_5_colors', 'split_hardest_colors', 'umap'
    obsm: 'X_pca', 'X_umap'
    layers: 'counts'
    obsp: 'connectivi

In [None]:
split = pd.read_csv(
    f'{data_path}/norman19_cpa_hvg_normalized_splits/split_6.csv',
    index_col=0,
    header=None,
).iloc[:, 0]
split.head()

0
CCCATACCATTCTTAC     test
CTCATTAGTAAGAGAG      val
CACACCTCATGAACCT    train
AGATCTGTCACCAGGC    train
ATGTGTGCAAGCCGCT    train
Name: 1, dtype: object

In [4]:
adata.obs['split'] = split.loc[adata.obs_names]
adata.obs['split'].value_counts()

split
train    83375
val      21038
test      6709
Name: count, dtype: int64

In [5]:
adata_ref = adata[adata.obs['split'] == 'test']
adata_ref

View of AnnData object with n_obs × n_vars = 6709 × 5044
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'n_counts', 'condition', 'pert_type', 'cell_type', 'source', 'condition_ID', 'control', 'dose_value', 'pathway', 'cov_cond', 'pert', 'split_hardest', 'split_1', 'split_2', 'split_3', 'split_4', 'split_5', 'split_6', 'cond_harm', 'split'
    var: 'ensemble_id', 'ncounts', 'ncells', 'symbol', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'cell_type_colors', 'gene_embedding_path', 'hvg', 'log1p', 'neighbors', 'rank_genes_groups_cov', 'source_colors', 'split_1_colors', 'split_2_colors', 'split_3_colors', 'split_4_colors', 'split_5_colors', 'split_hardest_colors', 'umap'
    obsm: 'X_pca', 'X_umap'
    layers: 'counts'
    obsp: 'conn

## Load best Theis lab predictions

In [None]:
public_cpa_pred = sc.read_h5ad(
    f'{data_path}/theis_cpa_pred_raw.h5ad'
)
public_cpa_pred

AnnData object with n_obs × n_vars = 111122 × 5044
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'n_counts', 'condition', 'pert_type', 'cell_type', 'source', 'condition_ID', 'control', 'dose_value', 'pathway', 'cov_cond', 'pert', 'split_hardest', 'split_1', 'split_2', 'split_3', 'split_4', 'split_5', 'split_6', 'cond_harm', 'split', 'CPA_cat', 'CPA_ctrl', '_scvi_cond_harm', '_scvi_cell_type', '_scvi_CPA_cat'
    var: 'ensemble_id', 'ncounts', 'ncells', 'symbol', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'

In [7]:
sc.pp.normalize_total(public_cpa_pred)
sc.pp.log1p(public_cpa_pred)

In [8]:
public_cpa_pred_test = public_cpa_pred[public_cpa_pred.obs['split'] == 'ood']
public_cpa_pred_test

View of AnnData object with n_obs × n_vars = 6709 × 5044
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'n_counts', 'condition', 'pert_type', 'cell_type', 'source', 'condition_ID', 'control', 'dose_value', 'pathway', 'cov_cond', 'pert', 'split_hardest', 'split_1', 'split_2', 'split_3', 'split_4', 'split_5', 'split_6', 'cond_harm', 'split', 'CPA_cat', 'CPA_ctrl', '_scvi_cond_harm', '_scvi_cell_type', '_scvi_CPA_cat'
    var: 'ensemble_id', 'ncounts', 'ncells', 'symbol', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p'

In [9]:
public_cpa_pred_test.obs.condition.value_counts()

condition
control          1500
UBASH3B+OSR2      796
SET+CEBPE         658
MAPK1+PRTG        500
MAPK1+TGFBR2      497
KLF1+BAK1         395
PTPN12+PTPN9      364
SGK1+TBX2         349
ETS2+MAP7D1       324
PTPN12+ZBTB25     303
TGFBR2+PRTG       265
ZBTB10+PTPN12     265
FOXA3+FOXF1       175
IGDCC3+PRTG       140
PRDM1+CBFA2T3     106
SAMD1+TGFBR2       72
Name: count, dtype: int64

In [10]:
ev = Evaluation(
    model_adatas=[public_cpa_pred],
    model_names=['CPA_theis'],
    ref_adata=adata_ref,
    pert_col='condition',
    cov_cols=['cell_type'],
    ctrl='control',
)
ev

<analysis.benchmarks.evaluation.Evaluation at 0x7fe5d333f680>

In [11]:
ev.aggregate(aggr_method='average')
ev.aggregate(aggr_method='logfc')
ev.aggregate(aggr_method='pca')
ev.aggregate(aggr_method='pca_average')
ev.aggregate(aggr_method='none')
ev.aggregate(aggr_method='scores')

In [12]:
ev.aggregate(aggr_method='scores')

In [13]:
evaluation_pipelines = [
    {
        'aggregation': 'average',
        'metric': 'rmse',
        'rank': True
    },
    {
        'aggregation': 'pca_average',
        'metric': 'cosine',
        'rank': True
    },
    {
        'aggregation': 'logfc',
        'metric': 'cosine',
        'rank': True
    },
    {
        'aggregation': 'scores',
        'metric': 'r2_score',
        'rank': False
    },
    {
        'aggregation': 'scores', 
        'metric': 'top_k_recall',
        'rank': False
    },
    {
        'aggregation': 'pca',
        'metric': 'mmd',
        'rank': True
    },
    {
        'aggregation': 'none',
        'metric': 'mmd',
        'rank': True
    }
]

In [14]:
summary_metrics_dict = {}
for eval_dict in evaluation_pipelines:
    aggr = eval_dict['aggregation']
    metric = eval_dict['metric']
    ev.evaluate(aggr_method=aggr, metric=metric)
    
    df = ev.evals[aggr][metric].copy()
    avg = df.groupby('model').mean('metric')
    summary_metrics_dict[metric + '_' + aggr] = avg['metric']
    
    if eval_dict.get('rank'):
        ev.evaluate_pairwise(aggr_method=aggr, metric=metric)
        ev.evaluate_rank(aggr_method=aggr, metric=metric)
        
        rank_df = ev.rank_evals[aggr][metric].copy()
        avg_rank = rank_df.groupby('model').mean('rank')
        summary_metrics_dict[metric + '_rank_' + aggr] = avg_rank['rank']

In [None]:
summary_metrics = pd.DataFrame(summary_metrics_dict).T.applymap(
    lambda x: float(np.format_float_positional(x, precision=4, unique=False, fractional=False, trim='k')),
)
summary_metrics

In [None]:
cpa_perturbench_metrics = pd.Series({
    'rmse_average': 0.027290,
    'rmse_rank_average': 0.008889,
    'cosine_pca_average': 0.880700,
    'cosine_rank_pca_average': 0.000000,
    'cosine_logfc': 0.749800,
    'cosine_rank_logfc': 0.013330,
    'r2_score_scores': 0.224700,
    'top_k_recall_scores': 0.384000,
    'mmd_pca': 2.308000,
    'mmd_rank_pca': 0.008889,
    'mmd_none': 3.769000,
    'mmd_rank_none': 0.008889
})

In [None]:
summary_metrics = pd.concat(
    [summary_metrics, pd.DataFrame(cpa_perturbench_metrics, columns=['CPA_perturbench'])],
    axis=1
)
summary_metrics

Unnamed: 0,CPA_theis,CPA_anubis
rmse_average,0.1124,0.02729
rmse_rank_average,0.4267,0.008889
cosine_pca_average,0.08203,0.8807
cosine_rank_pca_average,0.2356,0.0
cosine_logfc,0.2058,0.7498
cosine_rank_logfc,0.2622,0.01333
r2_score_scores,-1.004,0.2247
top_k_recall_scores,0.09867,0.384
mmd_pca,3.283,2.308
mmd_rank_pca,0.4222,0.008889


In [None]:
# Convert summary metrics to markdown table
markdown_table = summary_metrics.to_markdown()
print("\nMetrics Summary Table:\n")
print(markdown_table)


Metrics Summary Table:

|                         |   CPA_theis |   CPA_anubis |
|:------------------------|------------:|-------------:|
| rmse_average            |     0.1124  |     0.02729  |
| rmse_rank_average       |     0.4267  |     0.008889 |
| cosine_pca_average      |     0.08203 |     0.8807   |
| cosine_rank_pca_average |     0.2356  |     0        |
| cosine_logfc            |     0.2058  |     0.7498   |
| cosine_rank_logfc       |     0.2622  |     0.01333  |
| r2_score_scores         |    -1.004   |     0.2247   |
| top_k_recall_scores     |     0.09867 |     0.384    |
| mmd_pca                 |     3.283   |     2.308    |
| mmd_rank_pca            |     0.4222  |     0.008889 |
| mmd_none                |     4.237   |     3.769    |
| mmd_rank_none           |     0.4133  |     0.008889 |
