In [2]:
import anndata as ad
import pickle as pkl

from src.evaluator.MLP_baseline_evaluator import get_models_results
from src.evaluator.evaluator_utils import l2_loss
from src.utils import get_model_stats


In [10]:
def train_different_normalization(adata_path=None, run_name=None, res_savename=None, stats_savename=None):
    """Run the MLP baseline on one AnnData file and pickle the resulting statistics.

    Steps, in order:
      1. Load the random train/valid/test drug splits from ``./data/drug_splits``.
      2. Read the AnnData object stored at ``adata_path``.
      3. Call ``get_models_results`` with ``l2_loss``, ``n_trials=50`` and
         ``save_path=res_savename``.
      4. Re-open ``res_savename``, pass its content to ``get_model_stats`` together
         with the observations whose ``product_name`` is ``'Vehicle'``, and pickle
         the returned object to ``stats_savename``.

    Args:
        adata_path: Path of the ``.h5ad`` file to read.
        run_name: Name forwarded to ``get_models_results`` and ``get_model_stats``.
        res_savename: Path handed to ``get_models_results`` as ``save_path`` and
            read back afterwards.
        stats_savename: Path the statistics pickle is written to.

    Returns:
        None. The outputs are the files at ``res_savename`` and ``stats_savename``.

    Raises:
        ValueError: If any of the four arguments is left at its ``None`` default.
            All of them are required; the defaults only exist so the call can be
            written with keywords in any order.
    """
    from pathlib import Path

    # Fail fast: without this check a missing output path is only noticed after
    # the whole (long) training run has finished.
    required = {
        "adata_path": adata_path,
        "run_name": run_name,
        "res_savename": res_savename,
        "stats_savename": stats_savename,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(
            "train_different_normalization: required argument(s) missing: "
            + ", ".join(missing)
        )

    # Make sure both output locations are writable before any expensive work.
    for out_path in (res_savename, stats_savename):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    DRUG_ENCODING_NAME = "fmfp"
    DRUG_ENCODING_SIZE = 1024
    N_TRIALS = 50
    SCHEDULER_MODE = 'min'

    split_files = {
        'train': "./data/drug_splits/train_drugs_rand.pkl",
        'valid': "./data/drug_splits/val_drugs_rand.pkl",
        'test': "./data/drug_splits/test_drugs_rand.pkl",
    }

    # NOTE: pickle.load can execute arbitrary code; only point these paths at
    # files produced by this project.
    drug_splits = dict()
    for split_name, split_path in split_files.items():
        with open(split_path, 'rb') as f:
            drug_splits[split_name] = pkl.load(f)

    adata = ad.read_h5ad(adata_path)

    # NOTE(review): 1878 is hard-coded for both dimensions; presumably it is the
    # number of genes in the filtered AnnData files -- confirm against adata.n_vars.
    get_models_results(drug_splits=drug_splits,
                       loss_function=l2_loss,
                       adata=adata,
                       input_dim=1878,
                       output_dim=1878,
                       drug_rep_name=DRUG_ENCODING_NAME,
                       drug_emb_size=DRUG_ENCODING_SIZE,
                       n_trials=N_TRIALS,
                       scheduler_mode=SCHEDULER_MODE,
                       run_name=run_name,
                       save_path=res_savename
                       )

    # Relies on get_models_results having written its results to `save_path`.
    with open(res_savename, 'rb') as f:
        res_raw = pkl.load(f)

    # Only the 'Vehicle' observations are handed to the statistics step.
    adata_control = adata[adata.obs.product_name == 'Vehicle'].copy()
    gene_names = list(adata_control.var_names)
    raw_stats = get_model_stats(res_raw, adata_control, gene_names, run_name)

    with open(stats_savename, 'wb') as f:
        pkl.dump(raw_stats, f)

In [11]:
# Run the pipeline on sciplex_raw_filt.h5ad under the run name "mlp_rawcount_norm".
train_different_normalization(
    run_name="mlp_rawcount_norm",
    adata_path="./data/normalization/sciplex_raw_filt.h5ad",
    res_savename="./results/mlp_rawcount_norm_res.pkl",
    stats_savename="./results/mlp_rawcount_norm_stats.pkl",
)

Loading Datasets ...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [03:07<00:00, 2144.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [01:13<00:00, 5438.82it/s]
[I 2025-05-24 11:17:02,145] Using an existing study with name 'mlp_rawcount_norm' instead of creating a new one.


Optimizing Hyperparameters with Optuna ...


[W 2025-05-24 11:17:04,674] Trial 1 failed with parameters: {'lr': 0.0001, 'weight_decay': 0.0001, 'scheduler_factor': 0.1, 'scheduler_patience': 10, 'batch_size': 256, 'dropout': 0.2, 'hidden_dims': 1024} because of the following error: NameError("name 'avg_loss' is not defined").
Traceback (most recent call last):
  File "/apps/miniconda3/envs/dege-fm/lib/python3.12/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/home/victor/projects/dege-fm/src/evaluator/MLP_baseline_evaluator.py", line 180, in <lambda>
    study.optimize(lambda trial: objective(trial,
                                 ^^^^^^^^^^^^^^^^
  File "/home/victor/projects/dege-fm/src/evaluator/MLP_baseline_evaluator.py", line 161, in objective
    return ev.train_with_validation(loss_fn, trial)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/victor/projects/dege-fm/src/evaluator/MLP_baseline_evaluator.py", line 6

Epoch:	 0 Val Loss:	 0.7124263897202789


NameError: name 'avg_loss' is not defined

In [None]:
# Run the pipeline on sciplex_cpm_filt.h5ad under the run name "mlp_cpm_norm".
train_different_normalization(
    run_name="mlp_cpm_norm",
    adata_path="./data/normalization/sciplex_cpm_filt.h5ad",
    res_savename="./results/mlp_cpm_norm_res.pkl",
    stats_savename="./results/mlp_cpm_norm_stats.pkl",
)

In [None]:
# Run the pipeline on sciplex_shiftedlog_filt.h5ad under the run name "mlp_shiftedlog_norm".
train_different_normalization(
    run_name="mlp_shiftedlog_norm",
    adata_path="./data/normalization/sciplex_shiftedlog_filt.h5ad",
    res_savename="./results/mlp_shiftedlog_norm_res.pkl",
    stats_savename="./results/mlp_shiftedlog_norm_stats.pkl",
)

In [None]:
# Run the pipeline on sciplex_analyticpearson_filt.h5ad under the run name "mlp_analyticpearson_norm".
train_different_normalization(
    run_name="mlp_analyticpearson_norm",
    adata_path="./data/normalization/sciplex_analyticpearson_filt.h5ad",
    res_savename="./results/mlp_analyticpearson_norm_res.pkl",
    stats_savename="./results/mlp_analyticpearson_norm_stats.pkl",
)