In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import numpy as np
import random

# Move two directory levels up from the kernel's starting directory so that the
# relative paths used by later cells ('datasets/...', 'notebooks/...') resolve.
# The target is resolved once and cached in PROJECT_ROOT: without the guard,
# re-running this cell in a live kernel would climb two more levels each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("../..")
os.chdir(PROJECT_ROOT)
print(os.getcwd())

# cuBLAS workspace setting that PyTorch asks for when deterministic algorithms
# are enabled on CUDA; it is set before any CUDA work happens in this notebook.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 

# Seed every random number generator used by this notebook with the same value.
seed = 21
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# NOTE(review): deterministic algorithms and deterministic cuDNN are explicitly
# disabled and cuDNN autotuning is enabled, so GPU runs are seeded but not
# guaranteed to be bit-for-bit reproducible. Presumably a deliberate
# speed-over-reproducibility choice — confirm.
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [None]:
from modules.data_pipeline import DataPipeline
# Build the pipeline around the components table, then feed it the raw dataset.
pipeline = DataPipeline(
    components_csv='datasets/components.csv',
)
pipeline_result = pipeline.run_pipeline(
    raw_csv='datasets/dataset.csv',
)
# run_pipeline yields two values; the second is the graph collection that the
# following cells shuffle and split (presumably a list of graph objects —
# verify against modules.data_pipeline).
canonical_data, graph_list = pipeline_result

In [None]:
import modules.datasplit_module as dsm
# Shuffle a copy rather than graph_list itself. Shuffling in place mutated the
# output of the pipeline cell, so re-running this cell (even after re-seeding)
# started from an already-shuffled list and produced a different split.
# On a first run the resulting order is the same as with the in-place shuffle.
# Assumes graph_list is a plain list — TODO confirm.
sampled_graph_list = list(graph_list)
random.shuffle(sampled_graph_list)

# Split into train/validation/test subsets using the notebook-wide seed.
train, val, test = dsm.system_disjoint_split(
    sampled_graph_list,
    random_state=seed,
    stratify_by_components=True,
)

In [None]:
from torch_geometric.loader import DataLoader

def _build_loader(graphs, reshuffle):
    """Wrap one data split in a DataLoader using the settings shared by all splits.

    Args:
        graphs: the split (train, val or test) to iterate over.
        reshuffle: forwarded as ``shuffle``; True only for the training split.
    """
    # NOTE(review): 'component_batch' is passed to follow_batch unchanged;
    # confirm it names an attribute that exists on the graph objects.
    return DataLoader(
        dataset=graphs,
        batch_size=1024,
        shuffle=reshuffle,
        follow_batch=['component_batch'],
    )


# Only the training split is reshuffled; validation and test keep their order.
train_loader = _build_loader(train, True)
val_loader = _build_loader(val, False)
test_loader = _build_loader(test, False)

In [10]:
# --- 1. Configuration ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Boolean switches. include_gd is forwarded to the trainer and also selects the
# study name and the log/checkpoint directory names.
track_grad, include_gd = True, True

# Search budget: number of Optuna trials requested per run, and the number of
# training epochs given to each trial.
N_TRIALS, N_EPOCHS = 20, 50

In [11]:
import optuna
from optuna.pruners import MedianPruner
import modules.trainer_module as tm
import modules.dtmpnn as gm
from pathlib import Path
import joblib
import sys


# Echo the active configuration so the cell output records what was run.
print(
    f"Device: {device}",
    f"Include GD: {include_gd}",
    f"Trials: {N_TRIALS}",
    f"Epochs per trial: {N_EPOCHS}",
    sep="\n",
)

def objective(trial: optuna.trial.Trial):
    """
    Optuna objective: sample hyperparameters, build and train a DTMPNN, and
    return the trainer's best validation loss.

    Args:
        trial: Optuna trial used to sample the hyperparameters. It is also
            passed to the trainer as ``optuna_trial`` (presumably so the
            trainer can report intermediate values and request pruning —
            confirm in modules.trainer_module).

    Returns:
        ``trainer.best_val_loss`` when training finishes, or ``float('inf')``
        when the trial raised an unexpected exception (the study keeps going).

    Raises:
        optuna.TrialPruned: re-raised unchanged so Optuna records the trial
            as pruned.
    """
    log_dir = Path(f"notebooks/hyperparams_search/HPO_reports/optuna_logs_{include_gd}")
    log_dir.mkdir(exist_ok=True, parents=True)
    log_file_path = log_dir / f'trial_{trial.number}.log'

    # --- 1. DEFINE THE HYPERPARAMETER SEARCH SPACE ---
    # Parameter names and ranges are kept exactly as they are: they are stored
    # in the study database and an existing study is resumed.
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-7, 1e-4, log=True)
    gd_weight = trial.suggest_float('gd_weight', 1e-3, 1, log=True)

    graph_hidden_dim = trial.suggest_categorical('graph_hidden_dim', [16, 32, 64, 128])
    latent_dim = trial.suggest_categorical('latent_dim', [16, 32, 64])
    context_dim = trial.suggest_categorical('context_dim', [16, 32, 64])
    graph_layers = trial.suggest_int('graph_layers', 2, 5)

    # --- 2. RUN THE TRIAL ---
    try:
        # Feature widths are read from the first training graph.
        model = gm.DTMPNN(
            node_dim=train[0].x.shape[1],
            edge_dim=train[0].edge_attr.shape[1],
            graph_hidden_dim=graph_hidden_dim,
            latent_dim=latent_dim,
            context_dim=context_dim,
            graph_layers=graph_layers,
            # Honour the notebook-level switch instead of a hard-coded True.
            track_grad=track_grad,
        ).to(device)

        trainer = tm.DTMPNNTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            include_gd=include_gd,
            device=device,
            lr=lr,
            weight_decay=weight_decay,
            gd_weight=gd_weight
        )

        # Train the model
        trainer.train(
            epochs=N_EPOCHS,
            save_dir=f'notebooks/hyperparams_search/HPO_reports/optuna_checkpoints_{include_gd}/trial_{trial.number}',
            log_file_path=log_file_path,
            save_best=False,
            save_every=None,
            optuna_trial=trial
        )

        # --- 3. RETURN THE METRIC TO MINIMIZE ---
        return trainer.best_val_loss

    except optuna.TrialPruned:
        print(f"Trial {trial.number} was pruned.")
        raise

    except Exception as e:
        # Best effort: a failing configuration must not abort the whole study.
        # The traceback is printed as well, because the message alone is
        # rarely enough to diagnose a failed trial.
        import traceback

        print(f"Trial {trial.number} failed with error: {e}", file=sys.stderr)
        traceback.print_exc()
        return float('inf')

    finally:
        # Runs on every exit path (success, pruned, failed, interrupted).
        torch.cuda.empty_cache()

Device: cuda
Include GD: True
Trials: 20
Epochs per trial: 50


In [12]:
# --- 5. Study Runner ---
def run_study():
    """
    Create or resume the Optuna study, run ``N_TRIALS`` more trials of
    ``objective`` and print a summary.

    The study is stored in a SQLite database, so calling this again with the
    same ``include_gd`` setting resumes the existing study
    (``load_if_exists=True``). A KeyboardInterrupt during optimisation is
    caught so the summary is still printed and the study object returned.

    Returns:
        The ``optuna.study.Study`` object after optimisation.
    """
    pruner = MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=5,
        interval_steps=1
    )

    study_name = f"DTMPNN_hpo_{'gd' if include_gd else 'no_gd'}"
    study_db_path = "sqlite:///notebooks/hyperparams_search/HPO_reports/dashboard/master_hpo_study.db"

    # SQLite cannot create the database file inside a missing directory, so
    # make sure it exists here instead of relying on another cell.
    Path("notebooks/hyperparams_search/HPO_reports/dashboard").mkdir(parents=True, exist_ok=True)

    study = optuna.create_study(
        study_name=study_name,
        storage=study_db_path,
        load_if_exists=True,
        direction='minimize',
        pruner=pruner
    )

    print(f"--- Starting/Resuming study: {study_name} ---")
    print(f"--- Database at: {study_db_path} ---")

    try:
        study.optimize(
            objective, 
            n_trials=N_TRIALS,
            timeout=None,
            gc_after_trial=True
        ) 
    except KeyboardInterrupt:
        print("--- HPO interrupted by user. Study is saved. ---")

    print("--- Study complete ---")

    # --- 5. PRINT RESULTS ---
    # Counts cover every trial stored in the study, including earlier runs.
    pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    failed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]

    print(f"Total trials: {len(study.trials)}")
    print(f"  Completed: {len(completed_trials)}")
    print(f"  Pruned:    {len(pruned_trials)}")
    # Only shown when relevant, so the usual summary stays unchanged.
    if failed_trials:
        print(f"  Failed:    {len(failed_trials)}")

    # best_trial is only read when at least one trial completed.
    if completed_trials:
        print("\nBest trial:")
        best_trial = study.best_trial
        print(f"  Value (min val_loss): {best_trial.value:.6f}")
        print("  Params: ")
        for key, value in best_trial.params.items():
            print(f"    {key}: {value}")
    else:
        print("No trials completed successfully.")

    return study

In [13]:
# --- 6. Run the HPO Study ---
if __name__ == "__main__":
    # parents=True creates the whole chain (including HPO_reports itself), so
    # this also works when the intermediate directories do not exist yet.
    Path("notebooks/hyperparams_search/HPO_reports/dashboard").mkdir(parents=True, exist_ok=True)
    study = run_study()

[I 2025-11-18 14:38:59,530] Using an existing study with name 'DTMPNN_hpo_gd' instead of creating a new one.


--- Starting/Resuming study: DTMPNN_hpo_gd ---
--- Database at: sqlite:///notebooks/hyperparams_search/HPO_reports/dashboard/master_hpo_study.db ---


[I 2025-11-18 14:46:55,748] Trial 51 pruned.               


--- Trial pruned at epoch 15. ---
Trial 51 was pruned.


[I 2025-11-18 14:50:06,381] Trial 52 pruned.               


--- Trial pruned at epoch 6. ---
Trial 52 was pruned.


[I 2025-11-18 15:16:52,837] Trial 53 finished with value: 0.2586904537219268 and parameters: {'lr': 0.0019215732323425677, 'weight_decay': 1.9545838402942607e-06, 'gd_weight': 0.0013655666413148622, 'graph_hidden_dim': 64, 'latent_dim': 32, 'context_dim': 16, 'graph_layers': 3}. Best is trial 43 with value: 0.22746023054306325.



--- Full log successfully finalized and written to: trial_53.log ---


[I 2025-11-18 15:19:33,015] Trial 54 pruned.               


--- Trial pruned at epoch 5. ---
Trial 54 was pruned.


[I 2025-11-18 15:22:14,193] Trial 55 pruned.               


--- Trial pruned at epoch 5. ---
Trial 55 was pruned.


[I 2025-11-18 15:36:32,416] Trial 56 pruned.               


--- Trial pruned at epoch 30. ---
Trial 56 was pruned.


[I 2025-11-18 15:50:56,762] Trial 57 pruned.               


--- Trial pruned at epoch 27. ---
Trial 57 was pruned.


[I 2025-11-18 15:54:03,319] Trial 58 pruned.               


--- Trial pruned at epoch 5. ---
Trial 58 was pruned.


[I 2025-11-18 15:57:17,767] Trial 59 pruned.               


--- Trial pruned at epoch 6. ---
Trial 59 was pruned.


[I 2025-11-18 15:59:28,211] Trial 60 pruned.               


--- Trial pruned at epoch 5. ---
Trial 60 was pruned.


[I 2025-11-18 16:34:42,400] Trial 61 finished with value: 0.30038498387886925 and parameters: {'lr': 0.0012225379349194712, 'weight_decay': 6.5366654949903755e-06, 'gd_weight': 0.0020397623811694815, 'graph_hidden_dim': 64, 'latent_dim': 32, 'context_dim': 16, 'graph_layers': 5}. Best is trial 43 with value: 0.22746023054306325.



--- Full log successfully finalized and written to: trial_61.log ---


[I 2025-11-18 17:10:26,074] Trial 62 finished with value: 0.2578068522306589 and parameters: {'lr': 0.0015599654300943553, 'weight_decay': 6.8610544705611225e-06, 'gd_weight': 0.0020885603570798363, 'graph_hidden_dim': 64, 'latent_dim': 32, 'context_dim': 16, 'graph_layers': 5}. Best is trial 43 with value: 0.22746023054306325.



--- Full log successfully finalized and written to: trial_62.log ---


[I 2025-11-18 17:46:11,843] Trial 63 finished with value: 0.2839798535291965 and parameters: {'lr': 0.0014715943600803396, 'weight_decay': 5.9140780871704e-06, 'gd_weight': 0.0020075789738053865, 'graph_hidden_dim': 64, 'latent_dim': 32, 'context_dim': 16, 'graph_layers': 5}. Best is trial 43 with value: 0.22746023054306325.



--- Full log successfully finalized and written to: trial_63.log ---


[I 2025-11-18 17:49:46,463] Trial 64 pruned.               


--- Trial pruned at epoch 5. ---
Trial 64 was pruned.


[I 2025-11-18 17:56:11,797] Trial 65 pruned.               


--- Trial pruned at epoch 9. ---
Trial 65 was pruned.


[I 2025-11-18 17:59:46,323] Trial 66 pruned.               


--- Trial pruned at epoch 5. ---
Trial 66 was pruned.


[I 2025-11-18 18:03:20,955] Trial 67 pruned.               


--- Trial pruned at epoch 5. ---
Trial 67 was pruned.


[I 2025-11-18 18:18:22,019] Trial 68 pruned.               


--- Trial pruned at epoch 21. ---
Trial 68 was pruned.


[I 2025-11-18 18:54:25,087] Trial 69 finished with value: 0.2204171061515808 and parameters: {'lr': 0.001593760129459641, 'weight_decay': 1.0546290780917184e-05, 'gd_weight': 0.001013983616183676, 'graph_hidden_dim': 64, 'latent_dim': 32, 'context_dim': 64, 'graph_layers': 5}. Best is trial 69 with value: 0.2204171061515808.



--- Full log successfully finalized and written to: trial_69.log ---


[I 2025-11-18 19:23:13,998] Trial 70 pruned.               


--- Trial pruned at epoch 40. ---
Trial 70 was pruned.
--- Study complete ---
Total trials: 71
  Completed: 21
  Pruned:    49

Best trial:
  Value (min val_loss): 0.220417
  Params: 
    lr: 0.001593760129459641
    weight_decay: 1.0546290780917184e-05
    gd_weight: 0.001013983616183676
    graph_hidden_dim: 64
    latent_dim: 32
    context_dim: 64
    graph_layers: 5
