In [27]:
import anndata as ad
import pickle as pkl

from src.evaluator.MLP_baseline_evaluator import cross_validation_models
from src.evaluator.evaluator_utils import l2_loss

import optuna
import torch.optim as optim
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader

from src.models.MLP_concat import MLPModel
from src.dataset.dataset_unseen_compounds import SciplexDatasetUnseenPerturbations
from src.utils import get_model_stats
from src.evaluator.MLP_baseline_evaluator import MLPBaselineEvaluator

In [7]:
# --- Input files ---------------------------------------------------------
# AnnData file (read with ad.read_h5ad) and pickled drug splits.
ADATA_PATH = "./data/sciplex/sciplex_final.h5ad"
DRUG_SPLIT = "./data/sciplex/prnet_drug_splits.pkl"

# --- Drug representation -------------------------------------------------
# Name and size handed to the dataset as the drug representation; the size
# is also passed to the model as `drug_dim`.
DRUG_ENCODING_NAME, DRUG_ENCODING_SIZE = "sm_coati_emb", 256

# --- Cell input / output -------------------------------------------------
# Names and sizes handed to the dataset / model for the cell input and output.
CELL_INPUT_NAME, CELL_INPUT_SIZE = "X_2000_hvg", 2000
CELL_OUTPUT_NAME, CELL_OUTPUT_SIZE = "X_2000_hvg", 2000

# Key looked up in `adata.uns` to obtain the gene names used for the metrics.
GENE_NAMES = 'gene_names_2000'

# --- Training / tuning ---------------------------------------------------
LOSS_FUNCTION = l2_loss
SCHEDULER_MODE = 'min'
# NOTE(review): the tuning cell currently runs with `n_trials = 1`, not this value.
N_TRIALS = 50

In [8]:
# Load the AnnData object, then the pickled drug splits.
adata = ad.read_h5ad(ADATA_PATH)

# NOTE: unpickling can execute arbitrary code; DRUG_SPLIT must point at a trusted file.
with open(DRUG_SPLIT, "rb") as split_file:
    drug_splits = pkl.load(split_file)

In [9]:
# Metrics per evaluated split, keyed by split index (filled after evaluation).
output = {}
# Index of the drug split used throughout this notebook run.
i = 4

# Unpack the train / validation / test drugs of the selected split in one pass.
drugs_train, drugs_validation, drugs_test = (
    drug_splits[f'drug_split_{i}'][subset] for subset in ('train', 'valid', 'test')
)

In [13]:
def objective(trial, dataset_train=None, dataset_validation=None,
              input_dim=0, output_dim=0, drug_dim=0, scheduler_mode='min', loss_fn=None):
    """Optuna objective: sample hyperparameters and train one MLP baseline.

    Args:
        trial: Optuna trial used to draw the hyperparameters; it is also
            forwarded to the evaluator.
        dataset_train: Dataset handed to the evaluator for training.
        dataset_validation: Dataset handed to the evaluator for validation.
        input_dim: Forwarded to the evaluator as ``params['input_dim']``.
        output_dim: Forwarded to the evaluator as ``params['output_dim']``.
        drug_dim: Forwarded to the evaluator as ``params['drug_dim']``.
        scheduler_mode: Forwarded as ``params['scheduler_mode']``.
        loss_fn: Loss function passed to ``train_with_validation``.

    Returns:
        Whatever ``MLPBaselineEvaluator.train_with_validation`` returns
        (presumably the validation loss to minimise -- confirm in the evaluator).
    """
    # Draw every tunable value up front; the order of the suggest_* calls is
    # kept fixed so the sampler sees the parameters in the same sequence.
    sampled = {
        'lr': trial.suggest_float('lr', 1e-6, 1e-3, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
        'scheduler_factor': trial.suggest_float('scheduler_factor', 0.1, 0.5, log=False),
        'scheduler_patience': trial.suggest_int('scheduler_patience', 1, 20),
        'batch_size': trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256]),
        'dropout': trial.suggest_float('dropout', 0.05, 0.3, log=False),
        'hidden_dims': trial.suggest_categorical('hidden_dims', [64, 128, 256, 512, 1024, 2048, 4096]),
    }
    passthrough = ('lr', 'weight_decay', 'scheduler_factor', 'scheduler_patience', 'batch_size')

    params = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'drug_dim': drug_dim,
        'dropout': sampled['dropout'],
        'scheduler_mode': scheduler_mode,
        **{name: sampled[name] for name in passthrough},
        # The sampled width is wrapped in a 1-tuple before being handed over.
        'hidden_dims': (sampled['hidden_dims'],),
    }

    # Third positional argument is None: no test dataset is supplied during tuning.
    evaluator = MLPBaselineEvaluator(dataset_train, dataset_validation, None, params)
    return evaluator.train_with_validation(loss_fn, trial)

In [50]:
# Run-level settings. Everything that has a constant in the configuration cell
# is read from it, so the configuration cell stays the single place to edit.
drug_rep_name = DRUG_ENCODING_NAME
drug_emb_size = DRUG_ENCODING_SIZE
input_name = CELL_INPUT_NAME
output_name = CELL_OUTPUT_NAME
input_dim = CELL_INPUT_SIZE
output_dim = CELL_OUTPUT_SIZE
# Taken from the config constant instead of naming l2_loss directly, so a
# change to LOSS_FUNCTION actually reaches tuning and final training.
loss_function = LOSS_FUNCTION
scheduler_mode = SCHEDULER_MODE
gene_names_key = GENE_NAMES

# Debugging values: a single Optuna trial and a throw-away run name.
# NOTE(review): use N_TRIALS here for a real hyperparameter search.
n_trials = 1
run_name = "test_debugging"

In [None]:
# Optimize hyperparameters: build the training and validation datasets.

# Constructor arguments shared by both datasets; only the drug subset differs.
dataset_args = (drug_rep_name, drug_emb_size, input_name, output_name)
dataset_train = SciplexDatasetUnseenPerturbations(adata, drugs_train, *dataset_args)
dataset_validation = SciplexDatasetUnseenPerturbations(adata, drugs_validation, *dataset_args)

In [28]:
# Create (or reopen) the Optuna study for this split. With load_if_exists=True
# an existing study of the same name in the SQLite file is reused, so
# `study.best_trial` may come from an earlier run of this notebook.
study = optuna.create_study(
    direction='minimize',
    study_name=f"debugging_fold{i}",
    storage="sqlite:///optuna_study.db",
    load_if_exists=True,
)

# scheduler_mode is passed explicitly so tuning uses the same scheduler mode
# as the final retraining, instead of silently relying on objective's default.
study.optimize(
    lambda trial: objective(
        trial,
        dataset_train=dataset_train,
        dataset_validation=dataset_validation,
        input_dim=input_dim,
        output_dim=output_dim,
        drug_dim=drug_emb_size,
        scheduler_mode=scheduler_mode,
        loss_fn=loss_function,
    ),
    n_trials=n_trials,
)

[I 2025-04-12 17:24:42,599] Using an existing study with name 'debugging_fold4' instead of creating a new one.
[I 2025-04-12 17:26:40,325] Trial 3 finished with value: 0.02395641846299627 and parameters: {'lr': 8.91084980365547e-05, 'weight_decay': 2.46448350240984e-06, 'scheduler_factor': 0.22688200558305743, 'scheduler_patience': 1, 'batch_size': 32, 'dropout': 0.08592128016574645, 'hidden_dims': 256}. Best is trial 3 with value: 0.02395641846299627.


In [29]:
# Best trial of the study (possibly from a previous run when the study was reloaded).
best_trial = study.best_trial
# Copy the parameters: later cells edit `optimal_params` in place, and those
# edits must not leak into the dict held by the trial object.
optimal_params = dict(best_trial.params)
# Epoch stored on the trial as the "best_epoch" user attribute
# (presumably written by the evaluator during tuning -- confirm there).
best_epoch = best_trial.user_attrs["best_epoch"]

In [30]:
# Retrain on train + validation drugs with the best parameters; the test drugs
# are kept aside for the final evaluation.
drugs_train_final = [*drugs_train, *drugs_validation]

# Constructor arguments shared by both datasets; only the drug subset differs.
final_dataset_args = (drug_rep_name, drug_emb_size, input_name, output_name)
dataset_train_final = SciplexDatasetUnseenPerturbations(adata, drugs_train_final, *final_dataset_args)
dataset_test = SciplexDatasetUnseenPerturbations(adata, drugs_test, *final_dataset_args)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 571696/571696 [01:19<00:00, 7164.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 571696/571696 [00:44<00:00, 12920.17it/s]


In [37]:
# Assemble the final training parameters from the best trial.
# The dict is rebuilt from `best_trial.params` on every run instead of being
# edited in place, so re-running this cell never wraps `hidden_dims` in a
# tuple twice and never modifies the trial's own parameter dict.
optimal_params = {
    **best_trial.params,
    'input_dim': input_dim,
    'output_dim': output_dim,
    'drug_dim': drug_emb_size,
    'scheduler_mode': scheduler_mode,
    # The tuned value is a single width; it is wrapped in a 1-tuple exactly as
    # `objective` does during tuning.
    'hidden_dims': (best_trial.params['hidden_dims'],),
}

In [44]:
# Manual override of the tuned batch size.
# NOTE(review): the best trial logged by the study reports batch_size 32;
# confirm that training the final model with 16 is intended.
optimal_params.update(batch_size=16)

In [45]:
# Manual override of the hidden layer sizes.
# NOTE(review): 256 is the width the logged best trial reports, but it is
# hard-coded here and will go stale if a different trial becomes the best.
optimal_params.update(hidden_dims=(256,))

In [46]:
# Final model: trained on the combined train + validation drugs and given the
# test dataset. The second positional slot (the validation dataset in the
# tuning call) is None here.
final_ev = MLPBaselineEvaluator(dataset_train_final, None, dataset_test, optimal_params)
# Train for the epoch value recorded on the best trial.
# NOTE(review): confirm best_epoch is an epoch count rather than a 0-based index.
final_ev.train(loss_function, num_epochs=best_epoch)

NameError: name 'gene_names_key' is not defined

In [51]:
# Get model performance metrics.
# Reference subset: rows of `adata` whose product_name is "Vehicle".
is_vehicle = adata.obs['product_name'] == "Vehicle"
adata_control = adata[is_vehicle]
gene_names = adata_control.uns[gene_names_key]
predictions = final_ev.test()

# Keep the metrics both under this split's index in `output` and in `performance`.
output[i] = performance = get_model_stats(
    predictions, adata_control, output_name, gene_names, run_name
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4003/4003 [00:01<00:00, 2477.22it/s]
  utils.warn_names_duplicates("obs")


In [52]:
# Display the metrics dictionary computed for this run.
performance

{'key': 'test_debugging',
 'mse_A549': np.float32(0.000900806),
 'mse_K562': np.float32(0.00075966056),
 'mse_MCF7': np.float32(0.00041178125),
 'r2_A549': np.float64(0.9496960107769284),
 'r2_K562': np.float64(0.9298760344584783),
 'r2_MCF7': np.float64(0.9324416461445036),
 'rank_logfc_A549': np.float64(0.4730349971313827),
 'rank_logfc_K562': np.float64(0.5024383247274813),
 'rank_logfc_MCF7': np.float64(0.5035857716580608),
 'edistance_A549': np.float64(3.603224000393934),
 'edistance_K562': np.float64(3.038642212446841),
 'edistance_MCF7': np.float64(1.6471251163199943),
 'logfc_corr_A549': np.float64(0.591986006908933),
 'logfc_corr_K562': np.float64(0.46390072334513616),
 'logfc_corr_MCF7': np.float64(0.35091833792997623),
 'top_logfc_corr_A549': np.float64(0.5285331505343687),
 'top_logfc_corr_K562': np.float64(0.5551241497598857),
 'top_logfc_corr_MCF7': np.float64(0.23199707959154908),
 'predicted_bio_rep_A549': np.float64(0.961904761904762),
 'predicted_bio_rep_K562': np.flo