# 2024-04-18 Analysis: Verify CPA performance using the legacy Theis lab fork

In [None]:
import optuna
import pandas as pd
import scanpy as sc
import seaborn as sns
import numpy as np

import anubis
from analysis.benchmarks.evaluation import Evaluation

# Apply the whitegrid style, the Calibri font and a grid-less axes in one call.
# Calling sns.set_style() after sns.set_theme(font=...) re-applies seaborn's
# base style, which resets "font.family" to its default sans-serif list and
# silently discards Calibri. set_theme applies `rc` last, so the axes.grid
# override still wins over the whitegrid default.
sns.set_theme(style="whitegrid", font="Calibri", rc={"axes.grid": False})

# IPython autoreload: mode 2 reloads imported modules before each cell runs,
# which is handy while iterating on project packages such as `anubis`.
%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the PerturBench data. No trailing slash: every path below
# is built as f'{data_path}/...', and a trailing slash produced '//'.
data_path = "./notebooks/neurips2025/perturbench_data"

## Load reference data (validation split)

In [None]:
# Full preprocessed dataset (all train/val/test cells).
adata = sc.read_h5ad(
    f"{data_path}/srivatsan20_highest_dose_preprocessed.h5ad"
)
adata

AnnData object with n_obs × n_vars = 178213 × 8630
    obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'perturbation_raw', 'pert_cell_type', 'ood_split', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'control', 'dose_val', 'cov_drug_dose_name', '_scvi_cell_type'
    var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_rank', 'm

In [None]:
# Per-cell split labels. The CSV has no header row: column 0 becomes the
# index and the first remaining column holds the label.
split_vals = pd.read_csv(
    f"{data_path}/srivatsan20_highest_dose_splits/ood_split.csv",
    index_col=0,
    header=None,
).iloc[:, 0]

In [25]:
# Overwrite the existing `ood_split` column with the labels from the split
# file, looked up by obs name (raises KeyError if a cell is missing), then
# show how many cells fall into each split.
adata.obs["ood_split"] = split_vals.loc[adata.obs_names]
adata.obs["ood_split"].value_counts(dropna=True)

ood_split
train    120222
test      30947
val       27044
Name: count, dtype: int64

In [26]:
# Reference cells: the validation split only. This is an AnnData view of
# `adata`, not a copy.
adata_ref = adata[adata.obs["ood_split"].eq("val")]
adata_ref

View of AnnData object with n_obs × n_vars = 27044 × 8630
    obs: 'ncounts', 'well', 'plate', 'cell_line', 'replicate', 'time', 'dose_value', 'pathway_level_1', 'pathway_level_2', 'perturbation', 'target', 'pathway', 'dose_unit', 'celltype', 'disease', 'cancer', 'tissue_type', 'organism', 'perturbation_type', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts', 'chembl-ID', 'dataset', 'cell_type', 'treatment', 'condition', 'dose', 'perturbation_raw', 'pert_cell_type', 'ood_split', 'cov_merged', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'control', 'dose_val', 'cov_drug_dose_name', '_scvi_cell_type'
    var: 'ensembl_id', 'ncounts', 'ncells', 'gene_symbol', 'n_cells', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'highly_variable_ra

## Load Theis lab fork predictions (best-hyperparameter run)

In [None]:
# Predictions from the Theis lab CPA fork, best-hyperparameter run
# (validation split, going by the file name).
val_pred = sc.read_h5ad(
    f"{data_path}/srivatsan20-test/cpa-pred-theis-fork/best-hparams/"
    "test-preds/srivatsan20_ood_split_val_0.h5ad"
)
val_pred

In [None]:
# Compare the Theis-fork predictions against the reference cells.
# NOTE(review): `pert_col`/`cov_cols`/`ctrl` presumably name the perturbation
# column, the covariate columns and the control label in `condition` - confirm
# against analysis.benchmarks.evaluation.Evaluation.
ev = Evaluation(
    ref_adata=adata_ref,
    model_adatas=[val_pred],
    model_names=["CPA_theis_fork"],
    pert_col="condition",
    cov_cols=["cell_type"],
    ctrl="control",
)
ev

In [None]:
# Run both aggregation methods ("average" and "logfc") that the evaluation
# pipelines below refer to.
ev.aggregate(aggr_method="average")
ev.aggregate(aggr_method="logfc")

In [30]:
# Each pipeline: which aggregation to use, which metric to compute on it,
# and whether to also compute the pairwise rank metric.
evaluation_pipelines = [
    dict(aggregation='average', metric='rmse', rank=True),
    dict(aggregation='logfc', metric='cosine', rank=True),
]

In [31]:
# For every pipeline, evaluate the metric and store its per-model mean under
# "<metric>_<aggr>". When `rank` is requested, also store the per-model mean
# rank under "<metric>_rank_<aggr>". Each value is a Series indexed by model.
summary_metrics_dict = {}
for eval_dict in evaluation_pipelines:
    aggr = eval_dict['aggregation']
    metric = eval_dict['metric']
    ev.evaluate(aggr_method=aggr, metric=metric)

    # Select the column *before* averaging. The first positional argument of
    # GroupBy.mean is `numeric_only`, not a column name, so the old
    # `.mean('metric')` only worked because a non-empty string is truthy.
    df = ev.evals[aggr][metric]
    summary_metrics_dict[f'{metric}_{aggr}'] = df.groupby('model')['metric'].mean()

    if eval_dict.get('rank'):
        ev.evaluate_pairwise(aggr_method=aggr, metric=metric)
        ev.evaluate_rank(aggr_method=aggr, metric=metric)

        rank_df = ev.rank_evals[aggr][metric]
        summary_metrics_dict[f'{metric}_rank_{aggr}'] = rank_df.groupby('model')['rank'].mean()

In [32]:
def round_sig_figs(x, digits=4):
    """Round `x` to `digits` significant figures (e.g. 0.024357 -> 0.02436)."""
    return float(
        np.format_float_positional(
            x, precision=digits, unique=False, fractional=False, trim='k'
        )
    )


# Rows: summary metric names; columns: models.
summary_metrics = pd.DataFrame(summary_metrics_dict).T
# DataFrame.applymap is deprecated since pandas 2.1 (this cell used to emit a
# FutureWarning) in favour of DataFrame.map; fall back on older pandas.
if hasattr(pd.DataFrame, 'map'):
    summary_metrics = summary_metrics.map(round_sig_figs)
else:
    summary_metrics = summary_metrics.applymap(round_sig_figs)
summary_metrics

  summary_metrics = pd.DataFrame(summary_metrics_dict).T.applymap(


model,CPA_theis_fork
rmse_average,0.02436
rmse_rank_average,0.3935
cosine_logfc,0.1156
cosine_rank_logfc,0.4884


In [None]:
# Summary metrics of the CPA run from the `cpa-pred-anubis` directory, indexed
# by metric name so the rows line up with `summary_metrics`. index_col
# replaces the manual index assignment plus in-place drop of the column.
cpa_anubis_summary = pd.read_csv(
    f'{data_path}/srivatsan20-test/cpa-pred-anubis/logs/train/runs/2024-04-19_18-47-50/evaluation/summary.csv',
    index_col='metric',
)
cpa_anubis_summary

Unnamed: 0_level_0,CPA
metric,Unnamed: 1_level_1
rmse_average,0.02319
rmse_rank_average,0.3946
cosine_logfc,0.3489
cosine_rank_logfc,0.3867


In [34]:
# Side-by-side comparison: one column per model, rows aligned on metric name.
summary_metrics = pd.concat(
    objs=(summary_metrics, cpa_anubis_summary),
    axis="columns",
)
summary_metrics

Unnamed: 0,CPA_theis_fork,CPA
rmse_average,0.02436,0.02319
rmse_rank_average,0.3935,0.3946
cosine_logfc,0.1156,0.3489
cosine_rank_logfc,0.4884,0.3867
