In [1]:
import anndata as ad
import pickle as pkl

from src.evaluator.MLP_baseline_evaluator import cross_validation_models
from src.evaluator.evaluator_utils import l2_loss


  from .autonotebook import tqdm as notebook_tqdm


Read the AnnData object, the drug split information, and the cell representations.

In [2]:
# --- Input files -------------------------------------------------------------
# AnnData file, read with ad.read_h5ad in the loading cell.
ADATA_PATH = "./data/sciplex_testing.h5ad"
# Pickle file holding the drug splits, unpickled in the loading cell.
DRUG_SPLIT = "./data/drug_split_test.pkl"

# --- Drug representation (passed as drug_rep_name / drug_emb_size) -----------
DRUG_ENCODING_NAME = "sm_coati_emb"
# Presumably the length of one drug embedding vector -- TODO confirm.
DRUG_ENCODING_SIZE = 256

# --- Cell representations (passed as input_* / output_* / gene_names_key) ----
# Input and output use the same representation name and size.
CELL_INPUT_NAME = "X_2000_hvg"
CELL_INPUT_SIZE = 2000
CELL_OUTPUT_NAME = "X_2000_hvg"
CELL_OUTPUT_SIZE = 2000
GENE_NAMES = "gene_names_2000"

# --- Training / search settings ----------------------------------------------
LOSS_FUNCTION = l2_loss
# Forwarded as n_trials to cross_validation_models.
N_TRIALS = 2
# Forwarded as scheduler_mode; presumably the direction used by a learning-rate
# scheduler -- verify in src.evaluator.
SCHEDULER_MODE = "min"

In [3]:
# Unpickle the drug splits.
# NOTE(security): pickle.load can execute arbitrary code while loading, so
# DRUG_SPLIT must only ever point to a trusted file.
with open(DRUG_SPLIT, "rb") as split_file:
    drug_splits = pkl.load(split_file)

# Read the AnnData object from the .h5ad file.
adata = ad.read_h5ad(ADATA_PATH)

In [7]:
# Display the AnnData summary (dimensions plus the available obs / var / uns /
# obsm keys) to check that the keys named in the configuration cell exist.
adata

AnnData object with n_obs × n_vars = 22166 × 17376
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'n_genes', 'SMILES', 'pubchem_id', 'sm_coati_emb', 'sm_morgan_emb', 'condition_aggregate', 'match_index', 'split_50', 'split_100', 'split_200', 'split_300', 'split_400'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'n_cells'
    uns: 'gene_names_2000', 'gene_names_3500', 'gene_names_5000'
    obsm: 'X_2000_hvg', 'X_3500_hvg', 'X_5000_hvg', 'X_uce'

In [4]:
# Run the MLP baseline evaluation with the settings from the configuration
# cell. Judging from the logged output, this runs a hyperparameter search of
# N_TRIALS trials and then reports metrics per split -- confirm the details in
# src.evaluator.MLP_baseline_evaluator.
performance = cross_validation_models(
    drug_splits=drug_splits,
    loss_function=LOSS_FUNCTION,
    adata=adata,
    input_name=CELL_INPUT_NAME,
    input_dim=CELL_INPUT_SIZE,
    output_name=CELL_OUTPUT_NAME,
    output_dim=CELL_OUTPUT_SIZE,
    drug_rep_name=DRUG_ENCODING_NAME,
    drug_emb_size=DRUG_ENCODING_SIZE,
    n_trials=N_TRIALS,
    gene_names_key=GENE_NAMES,
    scheduler_mode=SCHEDULER_MODE,
    run_name="test",
)

100%|██████████| 22166/22166 [00:02<00:00, 10393.54it/s]
100%|██████████| 22166/22166 [00:02<00:00, 9276.86it/s]
[I 2025-04-03 17:32:04,376] A new study created in memory with name: no-name-3d31f000-4e7b-4700-a084-58ca52278a04
[I 2025-04-03 17:32:07,545] Trial 0 finished with value: 0.03590430940190951 and parameters: {'lr': 3.548163038778531e-05, 'weight_decay': 3.820431275383771e-06, 'scheduler_factor': 0.44169448967592684, 'scheduler_patience': 3, 'batch_size': 64, 'dropout': 0.2937699435340559, 'hidden_dims': 128}. Best is trial 0 with value: 0.03590430940190951.
[I 2025-04-03 17:32:18,546] Trial 1 finished with value: 0.03393331895991464 and parameters: {'lr': 2.4901541292684264e-05, 'weight_decay': 4.3295107133509236e-05, 'scheduler_factor': 0.28146238882892716, 'scheduler_patience': 19, 'batch_size': 32, 'dropout': 0.08314597247052781, 'hidden_dims': 512}. Best is trial 1 with value: 0.03393331895991464.
100%|██████████| 22166/22166 [00:03<00:00, 6329.40it/s]
100%|██████████| 22

In [6]:
# Show the metrics returned by cross_validation_models. The result is a nested
# dict (integer key -> metric name -> value), so it is left as the cell's last
# expression: IPython then pretty-prints it over multiple lines instead of the
# single unreadable line that print() produces.
performance

{0: {'key': 'test', 'mse_A549': np.float32(0.092338674), 'mse_K562': np.float32(0.10662203), 'mse_MCF7': np.float32(0.10835916), 'r2_A549': np.float64(-4.143093109130859), 'r2_K562': np.float64(-8.82184886932373), 'r2_MCF7': np.float64(-20.8350887298584), 'rank_logfc_A549': np.float64(0.5), 'rank_logfc_K562': np.float64(0.5), 'rank_logfc_MCF7': np.float64(0.49999999999999994), 'edistance_A549': np.float64(369.35471036568447), 'edistance_K562': np.float64(426.4881808196507), 'edistance_MCF7': np.float64(433.43673219807357)}, 1: {'key': 'test', 'mse_A549': np.float32(0.1947694), 'mse_K562': np.float32(0.24085543), 'mse_MCF7': np.float32(0.23698658), 'r2_A549': np.float64(-9.848292350769043), 'r2_K562': np.float64(-21.187213897705078), 'r2_MCF7': np.float64(-46.75435256958008), 'rank_logfc_A549': np.float64(0.5), 'rank_logfc_K562': np.float64(0.5), 'rank_logfc_MCF7': np.float64(0.49999999999999994), 'edistance_A549': np.float64(779.0776495949892), 'edistance_K562': np.float64(963.42139708