In [2]:
import logging
from pathlib import Path

import optuna

from src.dataset import create_dataloaders, ClinicalDataset, ImagingDataset
from src.utils import load_and_preprocess_data, split_and_scale_data
from src.train import train_and_evaluate_model
from src.models import SimpleNN, SimpleNNWithBatchNorm, VisionModel

In [3]:
# Configure the root logger: INFO-and-above records go to a log file, one
# "<time> - <level> - <message>" line per record.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    filename="training_logs.log",
)

In [4]:
# Which input the study runs on: "clinical", "imaging", or "multimodal".
modality = "imaging"

assert modality in ("clinical", "imaging", "multimodal"), f"Modality {modality} not supported"

# Input locations and label column shared by every modality.
geo_csv_path = "dataframes/threshold_df_new.csv"
curated_csv_path = "dataframes/molab_df_curated.csv"
img_seq_path = "representations/molab-hardy-leaf-97_embeddings.npy"
label_col = "label-1RN-0Normal"

# Columns kept out of the feature set (the label itself plus identifier /
# bookkeeping columns).
exclude_columns = [
    label_col,
    "Patient ID",
    "id",
    "BASELINE_TIME_POINT",
    "CROSSING_TIME_POINT",
    "BASELINE_VOLUME",
    "scan_date",
]

# load_and_preprocess_data returns its own exclusion list, which replaces the
# one defined above for the rest of the notebook.
geo_df, exclude_columns = load_and_preprocess_data(
    geo_csv_path, curated_csv_path, label_col, exclude_columns
)

# Every column that is not excluded is passed on as the third argument.
geo_df_train, geo_df_test = split_and_scale_data(
    geo_df,
    label_col,
    [name for name in geo_df.columns if name not in exclude_columns],
)

In [5]:
# Select the dataset class, model class, and dataset kwargs for the chosen
# modality; the Optuna objective reads ds_cls, model, and ds_cls_kwargs.
if modality == "imaging":
    ds_cls = ImagingDataset
    model = VisionModel
    ds_cls_kwargs = {"data_dir": img_seq_path, "is_gap": False}

elif modality == "clinical":
    ds_cls = ClinicalDataset
    model = SimpleNN
    ds_cls_kwargs = {"columns_to_drop": exclude_columns}

elif modality == "multimodal":
    # Fail here instead of falling through: with a bare `pass`, ds_cls / model /
    # ds_cls_kwargs would be undefined on a fresh kernel, or would silently keep
    # the values from an earlier run of this cell under a different modality.
    # TODO: Future implementation
    raise NotImplementedError("Modality 'multimodal' is not implemented yet")


In [6]:
# Settings held fixed across all trials (not tuned by Optuna).
epochs: int = 60  # forwarded to train_and_evaluate_model as num_epochs
embed_dim: int = 384  # forwarded as hidden_size, both directly and via model_kwargs


def objective(trial):
    """Run a single Optuna trial and return its AUC.

    Samples the batch size, number of layers, learning rate, and number of
    heads from ``trial``, builds dataloaders from ``geo_df_train``, and hands
    training and evaluation over to ``train_and_evaluate_model``.

    Args:
        trial: Optuna trial used to sample hyperparameters. It is also passed
            through to ``train_and_evaluate_model``.

    Returns:
        The value stored under ``'auc'`` in the metrics returned by
        ``train_and_evaluate_model``.
    """
    # Define the hyperparameters to tune
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])
    num_layers = trial.suggest_int("num_layers", 1, 7)
    # suggest_float(..., log=True) is the supported replacement for the
    # deprecated suggest_loguniform; same range, same log-scale sampling.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    n_heads = trial.suggest_categorical("n_heads", [2, 4, 8])

    # Create dataloaders
    dataloaders, feature_columns = create_dataloaders(
        geo_df_train,
        label_col,
        exclude_columns,
        batch_size,
        dataset_cls=ds_cls,
        dataset_kwargs=ds_cls_kwargs
    )

    # Model kwargs for model agnostic training
    model_kwargs = {"hidden_size": embed_dim, "num_heads_img": n_heads, "num_layers_img": num_layers}

    # Train and evaluate the model
    metrics = train_and_evaluate_model(
        trial, dataloaders, feature_columns, geo_df_test, exclude_columns,
        num_epochs=epochs, hidden_size=embed_dim, num_layers=num_layers,
        batch_size=batch_size, learning_rate=learning_rate,
        model_cls=model, model_kwargs=model_kwargs,
        dataset_cls=ds_cls, dataset_kwargs=ds_cls_kwargs
    )

    # Return the AUC reported by train_and_evaluate_model as the objective value.
    # NOTE(review): geo_df_test is passed into the trainer above; confirm this
    # AUC comes from a validation split and not from the held-out test set.
    return metrics['auc']


# Create the output directory before the long-running search starts, so a
# missing folder cannot fail the final CSV export after every trial has run.
results_dir = Path("optuna_results")
results_dir.mkdir(parents=True, exist_ok=True)

# Create a study (no storage argument is given) that maximizes the value
# returned by `objective`, then run the search.
study_name = "pretrained-encoder"  # Unique identifier of the study.
study = optuna.create_study(study_name=study_name, direction="maximize")
study.optimize(objective, n_trials=75)

# Get the trial data as a DataFrame
trial_data = study.trials_dataframe()

# Save the trial data to a CSV file
trial_data.to_csv(results_dir / f'optuna_results_{modality}_cv.csv', index=False)

[32m[I 2025-04-30 12:41:00,778][0m A new study created in memory with name: pretrained-encoder[0m
[32m[I 2025-04-30 12:45:12,289][0m Trial 0 finished with value: 0.6211111111111112 and parameters: {'batch_size': 16, 'num_layers': 7, 'learning_rate': 0.04006517719476767, 'n_heads': 8, 'weight_decay': 0.005108256994135997}. Best is trial 0 with value: 0.6211111111111112.[0m
[32m[I 2025-04-30 12:47:28,788][0m Trial 1 finished with value: 0.4834074074074074 and parameters: {'batch_size': 32, 'num_layers': 4, 'learning_rate': 5.0053701838715566e-05, 'n_heads': 4, 'weight_decay': 0.04254583156635979}. Best is trial 0 with value: 0.6211111111111112.[0m
[32m[I 2025-04-30 12:49:34,163][0m Trial 2 finished with value: 0.6636296296296297 and parameters: {'batch_size': 64, 'num_layers': 4, 'learning_rate': 0.018083186139654466, 'n_heads': 2, 'weight_decay': 0.033382270329062325}. Best is trial 2 with value: 0.6636296296296297.[0m
[32m[I 2025-04-30 12:51:06,758][0m Trial 3 finished wi