In [1]:
import logging
from pathlib import Path

import optuna
import torch
import yaml
from monai.networks.nets import ViTAutoEnc

from src.dataset import create_dataloaders, ClinicalDataset, ImagingDataset
from src.models import SimpleNN, ViTBinaryClassifier
from src.train import train_and_evaluate_model
from src.utils import load_and_preprocess_data, split_and_scale_data, set_random

In [2]:
# Route INFO-and-above log records to a file; every record is written as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    filename='training_logs_imaging.log',
)

# Project helper imported from src.utils.
# NOTE(review): presumably seeds the random number generators for
# reproducibility — confirm against its implementation.
set_random()

In [3]:
modality = "imaging"  # can be "clinical", "imaging", or "multimodal"

assert modality in ["clinical", "imaging", "multimodal"], f"Modality {modality} not supported"

# Common parameters, read from the project configuration file.
# The encoding is explicit so the file is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open("config.yml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# All settings used by this notebook live under the top-level "data" key.
data_cfg = config['data']

# CSV inputs handed to load_and_preprocess_data in the tuning cell.
geo_csv_path = data_cfg["geo_csv_path"]
curated_csv_path = data_cfg["curated_csv_path"]
# Used as data_dir for ImagingDataset.
img_seq_path = data_cfg["img_seq_path"]
# Checkpoint file loaded with torch.load for the imaging modality.
pretrained_model_path = data_cfg["pretrained_model_path"]
# Name of the label column and the columns excluded from the features.
label_col = data_cfg["label_col"]
exclude_columns = data_cfg["exclude_columns"]

In [4]:
# Select the model class, dataset class and dataset kwargs for the chosen
# modality. The tuning cell reads `model`, `ds_cls`, `ds_cls_kwargs` and
# `pre_trained_model` from here.
if modality == "imaging":
    # Encoder that receives the pretrained weights.
    pre_trained_model = ViTAutoEnc(
        img_size=(64, 64, 64),
        patch_size=8,
        in_channels=1,
        out_channels=1,
        num_layers=12,
        num_heads=12,
        hidden_size=384,
        mlp_dim=2048,
    )

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
    # objects and can execute code embedded in the checkpoint. Only load
    # trusted files; prefer weights_only=True if the checkpoint is a plain
    # tensor state_dict.
    state_dict = torch.load(pretrained_model_path, map_location="cpu", weights_only=False)

    # Strip only a *leading* "module." from each key (presumably left by a
    # DataParallel-style wrapper — confirm). str.replace would also rewrite
    # "module." occurring in the middle of a key and corrupt that name.
    wrapper_prefix = "module."
    state_dict = {
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates key mismatches, so report them instead of
    # silently ending up with a partially (or not at all) initialised encoder.
    load_result = pre_trained_model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logging.warning(
            "Pretrained checkpoint mismatch: %d missing keys, %d unexpected keys",
            len(load_result.missing_keys),
            len(load_result.unexpected_keys),
        )

    model = ViTBinaryClassifier
    ds_cls = ImagingDataset

    ds_cls_kwargs = {"data_dir": img_seq_path, "is_gap": False}
    # Always true for the encoder built above; kept as a guard in case a
    # different encoder type is substituted.
    if isinstance(pre_trained_model, ViTAutoEnc):
        ds_cls_kwargs["is_img"] = True

elif modality == "clinical":
    ds_cls = ClinicalDataset
    model = SimpleNN
    ds_cls_kwargs = {"columns_to_drop": exclude_columns}
    # NOTE(review): the tuning cell's objective() reads `pre_trained_model`,
    # which this branch does not define — running it with this modality
    # raises NameError until that is reconciled.

elif modality == "multimodal":
    # Fail fast: without this, model / ds_cls / ds_cls_kwargs stay undefined
    # and the failure only surfaces later as a NameError inside the objective.
    raise NotImplementedError("Modality 'multimodal' is not implemented yet")


done init


In [5]:
epochs = 70      # training epochs per trial (forwarded as num_epochs)
embed_dim = 384  # forwarded as hidden_size; same value as the ViTAutoEnc hidden_size


def objective(trial):
    """Optuna objective: sample hyperparameters, train one model, return its AUC.

    Args:
        trial: Optuna trial used to sample the hyperparameters; it is also
            forwarded to ``train_and_evaluate_model``.

    Returns:
        The value stored under ``'auc'`` in the metrics returned by
        ``train_and_evaluate_model``.
    """
    # Define the hyperparameters to tune
    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    num_layers = trial.suggest_int("num_layers", 1, 7)
    # suggest_loguniform is deprecated; suggest_float(..., log=True) is the
    # supported way to sample on a log scale.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    # NOTE(review): treating the split seed as a tunable parameter lets the
    # search pick a favourable train/test split, which can inflate the
    # reported AUC — consider fixing the seed instead.
    random_state = trial.suggest_int("random_state", 0, 10)

    geo_df = load_and_preprocess_data(geo_csv_path, curated_csv_path, label_col)
    kept_columns = [col for col in geo_df.columns if col not in exclude_columns]
    geo_df_train, geo_df_test = split_and_scale_data(
        geo_df, label_col, kept_columns, random_state=random_state
    )

    # Create dataloaders
    dataloaders, feature_columns = create_dataloaders(
        geo_df_train,
        label_col,
        exclude_columns,
        batch_size,
        dataset_cls=ds_cls,
        dataset_kwargs=ds_cls_kwargs
    )

    # Model kwargs for model agnostic training; the sampled num_layers is
    # passed as the number of layers to unfreeze.
    model_kwargs = {
        "unfreeze_last_n": num_layers,
        "pretrained_model": pre_trained_model,
    }

    # Train and evaluate the model
    metrics = train_and_evaluate_model(
        trial, dataloaders, feature_columns, geo_df_test, exclude_columns,
        num_epochs=epochs, hidden_size=embed_dim, num_layers=num_layers,
        batch_size=batch_size, learning_rate=learning_rate,
        model_cls=model, model_kwargs=model_kwargs,
        dataset_cls=ds_cls, dataset_kwargs=ds_cls_kwargs
    )

    # The AUC is the objective value the study maximizes
    return metrics['auc']


study_name = "pretrained-encoder"  # Unique identifier of the study.

# Create the output directory before the (long) optimization starts, so a
# missing folder cannot make the final CSV export fail after every trial has
# already been run.
results_path = Path("optuna_results") / f"optuna_results_{modality}_cv.csv"
results_path.parent.mkdir(parents=True, exist_ok=True)

# In-memory study that maximizes the value returned by objective().
study = optuna.create_study(study_name=study_name, direction="maximize")
study.optimize(objective, n_trials=75)

# Get the trial data as a DataFrame
trial_data = study.trials_dataframe()

# Save the trial data to a CSV file
trial_data.to_csv(results_path, index=False)

[32m[I 2025-05-05 21:16:57,687][0m A new study created in memory with name: pretrained-encoder[0m
[32m[I 2025-05-05 21:24:14,074][0m Trial 0 finished with value: 0.6397037037037038 and parameters: {'batch_size': 32, 'num_layers': 6, 'learning_rate': 0.09109497989555919, 'random_state': 6, 'weight_decay': 0.001190819215839948}. Best is trial 0 with value: 0.6397037037037038.[0m
[32m[I 2025-05-05 21:31:36,983][0m Trial 1 finished with value: 0.5280740740740741 and parameters: {'batch_size': 16, 'num_layers': 6, 'learning_rate': 1.4274703835631646e-05, 'random_state': 8, 'weight_decay': 0.2146027134226971}. Best is trial 0 with value: 0.6397037037037038.[0m
[32m[I 2025-05-05 21:39:05,869][0m Trial 2 finished with value: 0.9099999999999999 and parameters: {'batch_size': 8, 'num_layers': 6, 'learning_rate': 0.00996939991420049, 'random_state': 10, 'weight_decay': 0.03927755247627527}. Best is trial 2 with value: 0.9099999999999999.[0m
[32m[I 2025-05-05 21:45:59,959][0m Trial 3