In [6]:
import logging
from pathlib import Path

import optuna

from src.dataset import create_dataloaders, ClinicalDataset, ImagingDataset
from src.utils import load_and_preprocess_data, split_and_scale_data
from src.train import train_and_evaluate_model
from src.models import SimpleNN, SimpleNNWithBatchNorm

In [7]:
# Route root-logger records at INFO and above to a log file, each line
# prefixed with a timestamp and the level name.
# NOTE(review): basicConfig does nothing when the root logger already has
# handlers, so re-running this cell in a live kernel will not reconfigure it.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="training_logs.log",
)

In [8]:
# Input modality to train on; one of "clinical", "imaging", or "multimodal".
modality = "imaging"

assert modality in ("clinical", "imaging", "multimodal"), f"Modality {modality} not supported"

# Common parameters
geo_csv_path = "dataframes/threshold_df_new.csv"
curated_csv_path = "dataframes/molab_df_curated.csv"
img_seq_path = "representations/molab-hardy-leaf-97_embeddings.npy"
label_col = "label-1RN-0Normal"
exclude_columns = [
    "label-1RN-0Normal",
    "Patient ID",
    "id",
    "BASELINE_TIME_POINT",
    "CROSSING_TIME_POINT",
    "BASELINE_VOLUME",
    "scan_date",
]

# The loader returns the frame together with an exclusion list, which
# rebinds `exclude_columns` for everything downstream.
geo_df, exclude_columns = load_and_preprocess_data(
    geo_csv_path,
    curated_csv_path,
    label_col,
    exclude_columns,
)

# Hand every column that is not excluded to the split/scale helper.
geo_df_train, geo_df_test = split_and_scale_data(
    geo_df,
    label_col,
    [name for name in geo_df.columns if name not in exclude_columns],
)

In [9]:
# Pick the dataset class, the model class and the dataset constructor kwargs
# that match the selected modality. All three names are consumed later by the
# Optuna objective, so every branch must either bind them or fail loudly.
if modality == "imaging":
    ds_cls = ImagingDataset
    model = SimpleNNWithBatchNorm
    ds_cls_kwargs = {"data_dir": img_seq_path, "is_gap": True}

elif modality == "clinical":
    ds_cls = ClinicalDataset
    model = SimpleNN
    ds_cls_kwargs = {"columns_to_drop": exclude_columns}

elif modality == "multimodal":
    # TODO: Future implementation
    # Fail here rather than leaving ds_cls / model / ds_cls_kwargs unbound
    # (NameError inside the objective) or, worse, silently reusing values left
    # in the kernel by an earlier run with a different modality.
    raise NotImplementedError('Modality "multimodal" is not implemented yet')

else:
    raise ValueError(f"Modality {modality} not supported")


In [10]:
epochs = 50  # Number of training epochs used for every trial


def objective(trial):
    """Optuna objective: sample hyperparameters, train one model, return its AUC.

    Relies on notebook-level state defined in earlier cells: ``geo_df_train``,
    ``geo_df_test``, ``label_col``, ``exclude_columns``, ``modality``,
    ``ds_cls``, ``ds_cls_kwargs``, ``model`` and ``epochs``.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters. It is also
            forwarded to ``train_and_evaluate_model``.
            NOTE(review): the recorded output lists a ``weight_decay``
            parameter that is not suggested here, so the training helper
            presumably samples it from the same trial -- confirm.

    Returns:
        The ``'auc'`` entry of the metrics returned by
        ``train_and_evaluate_model``; the study maximizes this value.
    """
    # Define the hyperparameters to tune
    hidden_size = trial.suggest_categorical("hidden_size", [64, 128, 256, 512])
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64])
    num_layers = trial.suggest_int("num_layers", 1, 5)
    # suggest_loguniform is deprecated since Optuna 3.0; suggest_float with
    # log=True samples from the same log-uniform range.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)

    # Create dataloaders
    dataloaders, feature_columns = create_dataloaders(
        geo_df_train,
        label_col,
        exclude_columns,
        batch_size,
        dataset_cls=ds_cls,
        dataset_kwargs=ds_cls_kwargs
    )

    # Clinical input width follows the feature columns; every other modality
    # currently uses a fixed width of 384.
    input_size = len(feature_columns) if modality == "clinical" else 384 # TODO: Remove hardcoded value

    # Model kwargs for model agnostic training
    # NOTE(review): the key is "num_layer" (singular) while the keyword passed
    # below is num_layers -- confirm it matches the model constructors.
    model_kwargs = {"input_size": input_size, "hidden_size": hidden_size, "num_layer": num_layers}

    # Train and evaluate the model
    metrics = train_and_evaluate_model(
        trial, dataloaders, feature_columns, geo_df_test, exclude_columns,
        num_epochs=epochs, hidden_size=hidden_size, num_layers=num_layers,
        batch_size=batch_size, learning_rate=learning_rate,
        model_cls=model, model_kwargs=model_kwargs,
        dataset_cls=ds_cls, dataset_kwargs=ds_cls_kwargs
    )

    # Return the AUC as the objective value.
    # NOTE(review): geo_df_test is passed to the training helper; verify that
    # this AUC comes from a validation split and not from the test set.
    return metrics['auc']


# Run the hyperparameter search and export every trial to CSV.
# No storage backend is passed to create_study, so the study is held in
# memory only (see the "created in memory" log line below).
study_name = "pretrained-encoder"  # Unique identifier of the study.

# Resolve the output path and create its directory *before* the search starts,
# so a missing folder cannot throw away the finished trials at the final
# to_csv call.
results_path = Path(f'optuna_results/optuna_results_{modality}_cv.csv')
results_path.parent.mkdir(parents=True, exist_ok=True)

study = optuna.create_study(study_name=study_name, direction="maximize")
study.optimize(objective, n_trials=20)

# Get the trial data as a DataFrame
trial_data = study.trials_dataframe()

# Save the trial data to a CSV file
trial_data.to_csv(results_path, index=False)

[I 2025-04-29 15:05:24,067] A new study created in memory with name: pretrained-encoder
[I 2025-04-29 15:05:29,966] Trial 0 finished with value: 0.6958518518518518 and parameters: {'hidden_size': 256, 'batch_size': 32, 'num_layers': 4, 'learning_rate': 3.760604424190815e-05, 'weight_decay': 0.005703178376401477}. Best is trial 0 with value: 0.6958518518518518.
[I 2025-04-29 15:05:39,018] Trial 1 finished with value: 0.6128888888888889 and parameters: {'hidden_size': 256, 'batch_size': 8, 'num_layers': 3, 'learning_rate': 0.005916188244347682, 'weight_decay': 0.0002380077167959243}. Best is trial 0 with value: 0.6958518518518518.
[I 2025-04-29 15:05:48,715] Trial 2 finished with value: 0.7143703703703703 and parameters: {'hidden_size': 256, 'batch_size': 8, 'num_layers': 4, 'learning_rate': 0.0268818559293548, 'weight_decay': 0.0017879960977177773}. Best is trial 2 with value: 0.7143703703703703.
[32m[I 2025-04-29 15:05:53,403][0m Tr