In [3]:
import anndata as ad
import pickle as pkl

from src.evaluator.MLP_baseline_evaluator import get_models_results
from src.evaluator.evaluator_utils import l2_loss

In [4]:
def train_different_featno(adata_path=None, run_name=None, res_savename=None, input_dim=None, output_dim=None,
                           n_trials=50, drug_splits_dir="./data/drug_splits"):
    """Run the MLP baseline evaluation on a single AnnData file.

    Loads the pickled train/validation/test drug splits, reads the AnnData
    file at ``adata_path`` and forwards everything to ``get_models_results``
    with ``l2_loss`` as the loss function and ``add_relu=True``.

    Parameters
    ----------
    adata_path : str
        Path of the ``.h5ad`` file handed to ``anndata.read_h5ad``. Required,
        even though it defaults to ``None`` for signature compatibility.
    run_name : str, optional
        Forwarded as ``run_name`` (presumably the Optuna study name -- confirm
        in ``get_models_results``).
    res_savename : str, optional
        Forwarded as ``save_path``.
    input_dim, output_dim : int, optional
        Forwarded unchanged as ``input_dim`` / ``output_dim``.
    n_trials : int, default 50
        Forwarded as ``n_trials`` (presumably the number of hyperparameter
        search trials -- confirm in ``get_models_results``).
    drug_splits_dir : str, default "./data/drug_splits"
        Directory containing ``train_drugs_rand.pkl``, ``val_drugs_rand.pkl``
        and ``test_drugs_rand.pkl``.

    Raises
    ------
    ValueError
        If ``adata_path`` is ``None``.

    Returns
    -------
    None
        Whatever ``get_models_results`` returns is discarded.
    """
    DRUG_ENCODING_NAME = "fmfp"
    DRUG_ENCODING_SIZE = 1024
    SCHEDULER_MODE = 'min'

    # Fail fast with a clear message instead of an opaque reader error that
    # would only surface after the split files have already been loaded.
    if adata_path is None:
        raise ValueError("adata_path is required: pass the path of the .h5ad file to train on")

    # Key used in `drug_splits` -> pickle file name inside `drug_splits_dir`.
    split_files = {
        'train': "train_drugs_rand.pkl",
        'valid': "val_drugs_rand.pkl",
        'test': "test_drugs_rand.pkl",
    }

    drug_splits = dict()
    for split_name, file_name in split_files.items():
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point `drug_splits_dir` at split files you trust.
        with open(f"{drug_splits_dir}/{file_name}", 'rb') as f:
            drug_splits[split_name] = pkl.load(f)

    adata = ad.read_h5ad(adata_path)

    get_models_results(drug_splits=drug_splits,
                       loss_function=l2_loss,
                       adata=adata,
                       input_dim=input_dim,
                       output_dim=output_dim,
                       drug_rep_name=DRUG_ENCODING_NAME,
                       drug_emb_size=DRUG_ENCODING_SIZE,
                       n_trials=n_trials,
                       scheduler_mode=SCHEDULER_MODE,
                       run_name=run_name,
                       save_path=res_savename,
                       add_relu=True
                       )

In [None]:
# MLP baseline run on sciplex_hvg_500.h5ad (input/output dimension 500).
train_different_featno(
    input_dim=500,
    output_dim=500,
    run_name="mlp_hvg_500",
    adata_path="./data/feature_number/sciplex_hvg_500.h5ad",
    res_savename="./results/feature_number/mlp_hvg_500_res.pkl",
)

Loading Datasets ...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [03:05<00:00, 2162.11it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [01:13<00:00, 5441.96it/s]
[I 2025-06-18 12:16:22,461] A new study created in RDB with name: mlp_hvg_500


Optimizing Hyperparameters with Optuna ...
Epoch:	 0 Val Loss:	 0.03604320379219638
Epoch:	 1 Val Loss:	 0.03647944459937202
Epoch:	 2 Val Loss:	 0.03639627547509417
Epoch:	 3 Val Loss:	 0.03625103992999942
Epoch:	 4 Val Loss:	 0.036146133072027437
Epoch:	 5 Val Loss:	 0.03616646544295109
Epoch:	 6 Val Loss:	 0.036119568470110844
Epoch:	 7 Val Loss:	 0.036087531889893616
Epoch:	 8 Val Loss:	 0.03601778980787736
Epoch:	 9 Val Loss:	 0.03627077741039336
Epoch:	 10 Val Loss:	 0.036098425063960424
Epoch:	 11 Val Loss:	 0.03624313283364298
Epoch:	 12 Val Loss:	 0.03614858539709133
Epoch:	 13 Val Loss:	 0.03606382787970293
Epoch:	 14 Val Loss:	 0.036161963531442014
Epoch:	 15 Val Loss:	 0.03600762937804894
Epoch:	 16 Val Loss:	 0.03607207453447331
Epoch:	 17 Val Loss:	 0.03608266689434312
Epoch:	 18 Val Loss:	 0.03614292048229282
Epoch:	 19 Val Loss:	 0.03615481744840789
Epoch:	 20 Val Loss:	 0.036189854743971316
Epoch:	 21 Val Loss:	 0.03608942449308477
Epoch:	 22 Val Loss:	 0.0359626189533

[I 2025-06-18 12:17:16,203] Trial 0 finished with value: 0.03596261895330581 and parameters: {'lr': 0.001, 'weight_decay': 0.001, 'scheduler_factor': 0.1, 'scheduler_patience': 10, 'batch_size': 256, 'dropout': 0.15, 'hidden_dims': 1024}. Best is trial 0 with value: 0.03596261895330581.


Epoch:	 32 Val Loss:	 0.036150687586791644
Epoch:	 0 Val Loss:	 0.035262776605532606
Epoch:	 1 Val Loss:	 0.03517082932887909
Epoch:	 2 Val Loss:	 0.03516092370789512
Epoch:	 3 Val Loss:	 0.03519663018510537
Epoch:	 4 Val Loss:	 0.035213203563132296
Epoch:	 5 Val Loss:	 0.0351970288387504
Epoch:	 6 Val Loss:	 0.03514785990432469
Epoch:	 7 Val Loss:	 0.03521427960047154
Epoch:	 8 Val Loss:	 0.03521955643312722
Epoch:	 9 Val Loss:	 0.035254226473520785
Epoch:	 10 Val Loss:	 0.035228751850579595
Epoch:	 11 Val Loss:	 0.03516889295865994
Epoch:	 12 Val Loss:	 0.03519169475224173
Epoch:	 13 Val Loss:	 0.03519844716689858


In [None]:
# MLP baseline run on sciplex_hvg_1000.h5ad (input/output dimension 1000).
train_different_featno(
    input_dim=1000,
    output_dim=1000,
    run_name="mlp_hvg_1000",
    adata_path="./data/feature_number/sciplex_hvg_1000.h5ad",
    res_savename="./results/feature_number/mlp_hvg_1000_res.pkl",
)

In [None]:
# MLP baseline run on sciplex_hvg_2000.h5ad (input/output dimension 2000).
train_different_featno(
    input_dim=2000,
    output_dim=2000,
    run_name="mlp_hvg_2000",
    adata_path="./data/feature_number/sciplex_hvg_2000.h5ad",
    res_savename="./results/feature_number/mlp_hvg_2000_res.pkl",
)

In [None]:
# MLP baseline run on sciplex_hvg_3500.h5ad (input/output dimension 3500).
train_different_featno(
    input_dim=3500,
    output_dim=3500,
    run_name="mlp_hvg_3500",
    adata_path="./data/feature_number/sciplex_hvg_3500.h5ad",
    res_savename="./results/feature_number/mlp_hvg_3500_res.pkl",
)

In [None]:
# MLP baseline run on sciplex_hvg_5000.h5ad (input/output dimension 5000).
train_different_featno(
    input_dim=5000,
    output_dim=5000,
    run_name="mlp_hvg_5000",
    adata_path="./data/feature_number/sciplex_hvg_5000.h5ad",
    res_savename="./results/feature_number/mlp_hvg_5000_res.pkl",
)

In [None]:
# MLP baseline run on sciplex_hvg_7500.h5ad (input/output dimension 7500).
train_different_featno(
    input_dim=7500,
    output_dim=7500,
    run_name="mlp_hvg_7500",
    adata_path="./data/feature_number/sciplex_hvg_7500.h5ad",
    res_savename="./results/feature_number/mlp_hvg_7500_res.pkl",
)