In [5]:
import pickle as pkl
from pathlib import Path

import anndata as ad

from src.evaluator.decoder_evaluator import get_models_results
from src.evaluator.evaluator_utils import l2_loss

In [6]:
def train_different_featno(adata_path=None, run_name=None, res_savename=None, output_dim=None,
                           splits_dir="./data/drug_splits", n_trials=50):
    """Run the decoder hyperparameter search / training on one AnnData variant.

    Loads the fixed random drug train/valid/test splits from ``splits_dir``,
    reads the AnnData file at ``adata_path`` and forwards everything to
    ``get_models_results`` with the fmfp (1024-dim) drug encoding, an L2 loss
    and a 'min'-mode scheduler.

    Parameters
    ----------
    adata_path : str
        Path to the ``.h5ad`` file passed to ``anndata.read_h5ad``.
    run_name : str
        Forwarded as ``run_name``; the captured log shows it becomes the
        Optuna study name.
    res_savename : str
        Forwarded as ``save_path`` to ``get_models_results`` (presumably where
        the results pickle is written — confirm in get_models_results).
    output_dim : int
        Forwarded as ``output_dim``; the calls in this notebook set it to the
        HVG count encoded in the file name.
    splits_dir : str or path-like, optional
        Directory containing ``train_drugs_rand.pkl``, ``val_drugs_rand.pkl``
        and ``test_drugs_rand.pkl``. Defaults to ``"./data/drug_splits"``.
    n_trials : int, optional
        Number of Optuna trials, forwarded as ``n_trials``. Defaults to 50.
    """
    DRUG_ENCODING_NAME = "fmfp"
    DRUG_ENCODING_SIZE = 1024
    SCHEDULER_MODE = 'min'

    # Split key expected by get_models_results -> file-name prefix on disk.
    split_file_prefixes = {'train': 'train', 'valid': 'val', 'test': 'test'}

    drug_splits = dict()
    for split_key, prefix in split_file_prefixes.items():
        # pickle.load can execute arbitrary code: only point splits_dir at trusted files.
        with open(Path(splits_dir) / f"{prefix}_drugs_rand.pkl", 'rb') as f:
            drug_splits[split_key] = pkl.load(f)

    adata = ad.read_h5ad(adata_path)

    get_models_results(drug_splits=drug_splits,
                       loss_function=l2_loss,
                       adata=adata,
                       # NOTE(review): magic number; its meaning is not visible here — confirm in get_models_results.
                       input_dim=3,
                       output_dim=output_dim,
                       drug_rep_name=DRUG_ENCODING_NAME,
                       drug_emb_size=DRUG_ENCODING_SIZE,
                       n_trials=n_trials,
                       scheduler_mode=SCHEDULER_MODE,
                       run_name=run_name,
                       save_path=res_savename,
                       add_relu=True
                       )

In [None]:
# 500 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=500,
    adata_path="./data/feature_number/sciplex_hvg_500.h5ad",
    run_name="decoder_hvg_500",
    res_savename="./results/feature_number/decoder_hvg_500_res.pkl",
)

Loading Datasets ...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [03:07<00:00, 2142.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 401917/401917 [01:14<00:00, 5414.96it/s]
[I 2025-06-18 12:14:28,947] A new study created in RDB with name: decoder_hvg_500


Optimizing Hyperparameters with Optuna ...
Epoch:	 0 Val Loss:	 0.036040958764560725
Epoch:	 1 Val Loss:	 0.0359940535447188
Epoch:	 2 Val Loss:	 0.03582629210122803
Epoch:	 3 Val Loss:	 0.03582519927864855
Epoch:	 4 Val Loss:	 0.036094724330722236
Epoch:	 5 Val Loss:	 0.03576840181352143
Epoch:	 6 Val Loss:	 0.03585878534789644
Epoch:	 7 Val Loss:	 0.03574844200110263
Epoch:	 8 Val Loss:	 0.035793500481529184
Epoch:	 9 Val Loss:	 0.035686842017892274
Epoch:	 10 Val Loss:	 0.035923281869885436
Epoch:	 11 Val Loss:	 0.03587923205216949
Epoch:	 12 Val Loss:	 0.035885136326353394
Epoch:	 13 Val Loss:	 0.03603785663221277
Epoch:	 14 Val Loss:	 0.03586215011321044
Epoch:	 15 Val Loss:	 0.03581007752705061
Epoch:	 16 Val Loss:	 0.035877912994444275
Epoch:	 17 Val Loss:	 0.0357632360407667
Epoch:	 18 Val Loss:	 0.03576978373776278


[I 2025-06-18 12:15:12,035] Trial 0 finished with value: 0.035686842017892274 and parameters: {'lr': 0.001, 'weight_decay': 0.001, 'scheduler_factor': 0.3, 'scheduler_patience': 20, 'batch_size': 128, 'dropout': 0.2, 'hidden_dims': 128}. Best is trial 0 with value: 0.035686842017892274.


Epoch:	 19 Val Loss:	 0.035808348478106586
Epoch:	 0 Val Loss:	 0.038022296488258195
Epoch:	 1 Val Loss:	 0.03644553445778941
Epoch:	 2 Val Loss:	 0.03604772642398139
Epoch:	 3 Val Loss:	 0.03582695532314729
Epoch:	 4 Val Loss:	 0.03566017445080554
Epoch:	 5 Val Loss:	 0.03580968765878911
Epoch:	 6 Val Loss:	 0.03574560487508905
Epoch:	 7 Val Loss:	 0.03575120633122821
Epoch:	 8 Val Loss:	 0.035745682103394794
Epoch:	 9 Val Loss:	 0.03573907113109143
Epoch:	 10 Val Loss:	 0.03569552591871049
Epoch:	 11 Val Loss:	 0.03577713483369433
Epoch:	 12 Val Loss:	 0.03574741372782263
Epoch:	 13 Val Loss:	 0.03578403928550194


[I 2025-06-18 12:18:07,000] Trial 1 finished with value: 0.03566017445080554 and parameters: {'lr': 1e-05, 'weight_decay': 0.001, 'scheduler_factor': 0.3, 'scheduler_patience': 1, 'batch_size': 16, 'dropout': 0.1, 'hidden_dims': 256}. Best is trial 1 with value: 0.03566017445080554.


Epoch:	 14 Val Loss:	 0.035710386979479314
Epoch:	 0 Val Loss:	 0.06744520060634346
Epoch:	 1 Val Loss:	 0.054380036465955205
Epoch:	 2 Val Loss:	 0.05049358516405329
Epoch:	 3 Val Loss:	 0.048896340052350565
Epoch:	 4 Val Loss:	 0.04797175038683663
Epoch:	 5 Val Loss:	 0.047165043215546903
Epoch:	 6 Val Loss:	 0.04657300886956685
Epoch:	 7 Val Loss:	 0.04588967263339802
Epoch:	 8 Val Loss:	 0.04522650343265809
Epoch:	 9 Val Loss:	 0.044609905329695866
Epoch:	 10 Val Loss:	 0.044051297122603625
Epoch:	 11 Val Loss:	 0.0434296712410871
Epoch:	 12 Val Loss:	 0.0428082597760385
Epoch:	 13 Val Loss:	 0.04227996485406476
Epoch:	 14 Val Loss:	 0.0417388075511203
Epoch:	 15 Val Loss:	 0.041248632837116814
Epoch:	 16 Val Loss:	 0.040865078412988975
Epoch:	 17 Val Loss:	 0.04048172684078423
Epoch:	 18 Val Loss:	 0.040164403384177684
Epoch:	 19 Val Loss:	 0.0398281691341779
Epoch:	 20 Val Loss:	 0.03953457482003668
Epoch:	 21 Val Loss:	 0.03934901378104048
Epoch:	 22 Val Loss:	 0.039177880468471

[I 2025-06-18 12:21:05,873] Trial 2 finished with value: 0.03663911332300061 and parameters: {'lr': 1e-06, 'weight_decay': 0.001, 'scheduler_factor': 0.1, 'scheduler_patience': 5, 'batch_size': 128, 'dropout': 0.15, 'hidden_dims': 512}. Best is trial 1 with value: 0.03566017445080554.


Epoch:	 77 Val Loss:	 0.03673161994695759
Epoch:	 0 Val Loss:	 0.03589680688351058
Epoch:	 1 Val Loss:	 0.03581555143752125
Epoch:	 2 Val Loss:	 0.03592527249639624
Epoch:	 3 Val Loss:	 0.03596980319048964
Epoch:	 4 Val Loss:	 0.03581739229002121
Epoch:	 5 Val Loss:	 0.036124419803651914
Epoch:	 6 Val Loss:	 0.0360926061892586
Epoch:	 7 Val Loss:	 0.035857010411508204
Epoch:	 8 Val Loss:	 0.03592477311315927
Epoch:	 9 Val Loss:	 0.03580556617658077
Epoch:	 10 Val Loss:	 0.03583789116928704
Epoch:	 11 Val Loss:	 0.035950139747935546
Epoch:	 12 Val Loss:	 0.03596626109418192
Epoch:	 13 Val Loss:	 0.035727262688295416
Epoch:	 14 Val Loss:	 0.036039589149324314
Epoch:	 15 Val Loss:	 0.036013912884922124
Epoch:	 16 Val Loss:	 0.03574597854675489
Epoch:	 17 Val Loss:	 0.035921253897787865
Epoch:	 18 Val Loss:	 0.035837215199516444
Epoch:	 19 Val Loss:	 0.0361100450265752
Epoch:	 20 Val Loss:	 0.03607459604309612
Epoch:	 21 Val Loss:	 0.03586650731260283
Epoch:	 22 Val Loss:	 0.03581407007553

[I 2025-06-18 12:21:59,725] Trial 3 finished with value: 0.035727262688295416 and parameters: {'lr': 0.001, 'weight_decay': 0.001, 'scheduler_factor': 0.3, 'scheduler_patience': 20, 'batch_size': 128, 'dropout': 0.2, 'hidden_dims': 64}. Best is trial 1 with value: 0.03566017445080554.


Epoch:	 23 Val Loss:	 0.03587372512534189
Epoch:	 0 Val Loss:	 0.11475929431018339
Epoch:	 1 Val Loss:	 0.10340974122956635
Epoch:	 2 Val Loss:	 0.09600670254786298
Epoch:	 3 Val Loss:	 0.08848802822004177
Epoch:	 4 Val Loss:	 0.08338948058928708
Epoch:	 5 Val Loss:	 0.07860563199523944
Epoch:	 6 Val Loss:	 0.07466411928367768
Epoch:	 7 Val Loss:	 0.07124727530782246
Epoch:	 8 Val Loss:	 0.0686368321898091
Epoch:	 9 Val Loss:	 0.06607697804689024
Epoch:	 10 Val Loss:	 0.06436591809825115
Epoch:	 11 Val Loss:	 0.0622215700781997
Epoch:	 12 Val Loss:	 0.06145662667044106
Epoch:	 13 Val Loss:	 0.05993783255961164
Epoch:	 14 Val Loss:	 0.05909316336370741
Epoch:	 15 Val Loss:	 0.058370859976079305


In [None]:
# 1000 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=1000,
    adata_path="./data/feature_number/sciplex_hvg_1000.h5ad",
    run_name="decoder_hvg_1000",
    res_savename="./results/feature_number/decoder_hvg_1000_res.pkl",
)

In [None]:
# 2000 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=2000,
    adata_path="./data/feature_number/sciplex_hvg_2000.h5ad",
    run_name="decoder_hvg_2000",
    res_savename="./results/feature_number/decoder_hvg_2000_res.pkl",
)

In [None]:
# 3500 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=3500,
    adata_path="./data/feature_number/sciplex_hvg_3500.h5ad",
    run_name="decoder_hvg_3500",
    res_savename="./results/feature_number/decoder_hvg_3500_res.pkl",
)

In [None]:
# 5000 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=5000,
    adata_path="./data/feature_number/sciplex_hvg_5000.h5ad",
    run_name="decoder_hvg_5000",
    res_savename="./results/feature_number/decoder_hvg_5000_res.pkl",
)

In [None]:
# 7500 highly variable genes; output_dim mirrors the HVG count in the file name.
train_different_featno(
    output_dim=7500,
    adata_path="./data/feature_number/sciplex_hvg_7500.h5ad",
    run_name="decoder_hvg_7500",
    res_savename="./results/feature_number/decoder_hvg_7500_res.pkl",
)