In [1]:
import sys
import pathlib
import numpy as np
import pandas as pd
import hiplot
from optuna.visualization import plot_param_importances
import optuna
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

from betavae import BetaVAE, train_vae, evaluate_vae
from optimize_utils import get_optimize_args

In [2]:
# Load command line arguments
args = get_optimize_args()

# Load data

output_dir = pathlib.Path("data")

sys.path.insert(0, "../0.data-download/scripts/")
from data_loader import load_train_test_data

data_directory = pathlib.Path("../0.data-download/data")
dfs = load_train_test_data(data_directory, train_or_test="all", load_gene_stats=True)

train_feat = dfs[0]
test_feat = dfs[1]
load_gene_stats = dfs[2]

# Prepare data for training
train_features_df = train_feat.drop(columns=["ModelID", "age_and_sex"])
test_features_df = test_feat.drop(columns=["ModelID", "age_and_sex"])

# subsetting the genes

# create dataframe containing the genes that passed an initial QC (see Pan et al. 2022) and their corresponding gene label and extract the gene labels
gene_dict_df = pd.read_csv(
    "../0.data-download/data/CRISPR_gene_dictionary.tsv", delimiter="\t"
)
gene_list_passed_qc = gene_dict_df.query("qc_pass").dependency_column.tolist()

# create new training and testing dataframes that contain only the corresponding genes
train_df = train_feat.filter(gene_list_passed_qc, axis=1)
test_df = test_feat.filter(gene_list_passed_qc, axis=1)

In [3]:
# Normalize data
train_data = train_df.values.astype(np.float32)
test_data = test_df.values.astype(np.float32)

# Normalize based on data distribution
train_data = (train_data - np.min(train_data, axis=0)) / (
    np.max(train_data, axis=0) - np.min(train_data, axis=0)
)
test_data = (test_data - np.min(test_data, axis=0)) / (
    np.max(test_data, axis=0) - np.min(test_data, axis=0)
)

# Convert dataframes to tensors
train_tensor = torch.tensor(train_data, dtype=torch.float32)
test_tensor = torch.tensor(test_data, dtype=torch.float32)

In [4]:
def objective(trial):
    """
    Optuna objective function: optimized by study
    """
    # Define hyperparameters
    latent_dim = trial.suggest_int(
        "latent_dim", args.min_latent_dim, args.max_latent_dim
    )
    beta = trial.suggest_float("beta", args.min_beta, args.max_beta)
    learning_rate = trial.suggest_categorical(
        "learning_rate", [5e-3, 1e-3, 1e-4, 1e-5, 1e-6]
    )
    batch_size = trial.suggest_int(
        "batch_size", args.min_batch_size, args.max_batch_size, args.batch_size_step
    )
    epochs = trial.suggest_int(
        "epochs", args.min_epochs, args.max_epochs, args.epoch_step
    )

    # Create DataLoader
    train_loader = DataLoader(
        TensorDataset(train_tensor), batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        TensorDataset(test_tensor), batch_size=batch_size, shuffle=False
    )

    model = BetaVAE(input_dim=train_df.shape[1], latent_dim=latent_dim, beta=beta)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss = train_vae(model, train_loader, optimizer, epochs=epochs)

    # Evaluate VAE
    val_loss = evaluate_vae(model, test_loader)

    return val_loss

In [5]:
# Run Optuna optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=500)


[I 2024-07-18 13:54:01,233] A new study created in memory with name: no-name-0b99e275-038d-4c96-b053-784b78ddf84f
  batch_size = trial.suggest_int(
  epochs = trial.suggest_int(


Epoch 0, Loss: 220.94055655664542
Epoch 1, Loss: 198.20562699151566
Epoch 2, Loss: 181.11974827721897
Epoch 3, Loss: 164.37109135058353
Epoch 4, Loss: 150.37361703633675
Epoch 5, Loss: 136.88999530314229
Epoch 6, Loss: 124.61393224167882
Epoch 7, Loss: 113.45037541869817
Epoch 8, Loss: 104.65969481222167
Epoch 9, Loss: 95.30419052086532
Epoch 10, Loss: 87.75013106812423
Epoch 11, Loss: 80.67724324444295
Epoch 12, Loss: 75.07351677072135
Epoch 13, Loss: 69.95959345187251
Epoch 14, Loss: 66.18322334008369
Epoch 15, Loss: 62.319709121741596
Epoch 16, Loss: 59.590370796822214
Epoch 17, Loss: 56.82255730406365
Epoch 18, Loss: 54.741338031590715
Epoch 19, Loss: 53.083851057422834
Epoch 20, Loss: 51.864597779819945
Epoch 21, Loss: 50.35813045267391
Epoch 22, Loss: 49.61246882258235
Epoch 23, Loss: 48.775303346226195
Epoch 24, Loss: 48.21038818359375
Epoch 25, Loss: 47.73388919314823
Epoch 26, Loss: 46.993266412608456
Epoch 27, Loss: 46.949398256344054
Epoch 28, Loss: 46.4976043326263
Epoch 29

[I 2024-07-18 13:54:50,482] Trial 0 finished with value: 104.18366410997179 and parameters: {'latent_dim': 20, 'beta': 2.593511281419829, 'learning_rate': 0.001, 'batch_size': 16, 'epochs': 605}. Best is trial 0 with value: 104.18366410997179.


Epoch 603, Loss: 45.19778164950284
Epoch 604, Loss: 45.19711566147113
Epoch 0, Loss: 263.254773639051
Epoch 1, Loss: 241.24265748569948


  batch_size = trial.suggest_int(
  epochs = trial.suggest_int(


Epoch 2, Loss: 228.51446008330774
Epoch 3, Loss: 219.0900474004722
Epoch 4, Loss: 211.76500025193872
Epoch 5, Loss: 204.9959914748733
Epoch 6, Loss: 199.90538890941724
Epoch 7, Loss: 194.30519887563344
Epoch 8, Loss: 189.8046257150261
Epoch 9, Loss: 186.22251138523112
Epoch 10, Loss: 181.62340288841753
Epoch 11, Loss: 177.65415860862637
Epoch 12, Loss: 173.54070010173527
Epoch 13, Loss: 170.23655127073096
Epoch 14, Loss: 166.28147374558506
Epoch 15, Loss: 162.08485116419686
Epoch 16, Loss: 159.17537634847204
Epoch 17, Loss: 156.18361414503994
Epoch 18, Loss: 152.78321332251997
Epoch 19, Loss: 148.306043170301
Epoch 20, Loss: 146.1050755456273
Epoch 21, Loss: 143.71521415991631
Epoch 22, Loss: 139.40177898735027
Epoch 23, Loss: 136.56145925310963
Epoch 24, Loss: 131.93412807005336
Epoch 25, Loss: 129.5753329790195
Epoch 26, Loss: 125.50607142342982
Epoch 27, Loss: 126.17427036744664
Epoch 28, Loss: 125.50081580159706
Epoch 29, Loss: 118.28704099163083
Epoch 30, Loss: 115.5740426928171
E

In [None]:
# Save best hyperparameters
best_trial = study.best_trial
print(best_trial)
print(f"Best trial: {best_trial.values}")
print(f"Best hyperparameters: {best_trial.params}")