In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import numpy as np
import random

def enter_project_root(marker="modules", relative_root="../.."):
    """Move the working directory to the project root, at most once.

    The original cell called ``os.chdir("../..")`` unconditionally, so
    re-running it climbed two more directories each time and broke every
    relative path used afterwards.

    Args:
        marker: Directory name whose presence in the current working directory
            means we are already at the project root. Defaults to ``modules``,
            the package the following cells import from (assumes the notebook's
            own directory has no ``modules`` folder -- TODO confirm).
        relative_root: Path to change into when ``marker`` is not found.

    Returns:
        The current working directory after the (possible) change.
    """
    if not os.path.isdir(marker):
        os.chdir(relative_root)
    return os.getcwd()


print(enter_project_root())

# NOTE(review): per the PyTorch reproducibility notes this variable is what
# deterministic cuBLAS needs; determinism is explicitly disabled further down,
# so it is kept only to preserve the original environment.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 
seed = 21

# Seed every RNG the notebook touches: NumPy, Python's `random`, and torch
# (CPU plus current / all CUDA devices).
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Determinism is deliberately off and cuDNN autotuning is on, so runs are
# seeded but not guaranteed to be bit-for-bit reproducible on GPU.
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

In [None]:
from modules.data_pipeline import DataPipeline
# Build the data pipeline from the component table, then run it on the raw
# dataset. `run_pipeline` returns a pair that is unpacked into the canonical
# data and the list of graphs consumed by the split cell below.
pipeline = DataPipeline(components_csv='datasets/components.csv')
_pipeline_outputs = pipeline.run_pipeline(raw_csv='datasets/dataset.csv')
canonical_data, graph_list = _pipeline_outputs

In [None]:
import modules.datasplit_module as dsm
# NOTE: shuffles `graph_list` in place, so re-running this cell without
# re-running the seeding cell produces a different ordering.
random.shuffle(graph_list)

# Split the shuffled graphs into train / validation / test subsets.
train, val, test = dsm.system_disjoint_split(
    graph_list,
    random_state=seed,
    stratify_by_components=True,
)

In [None]:
from torch_geometric.loader import DataLoader

def _build_loader(graphs, shuffle):
    """Wrap the first 1000 graphs of a split in a DataLoader.

    All three loaders share the same batch size and `follow_batch` setting;
    only the source split and the shuffle flag differ.
    """
    return DataLoader(
        dataset=graphs[:1000],
        batch_size=1024,
        shuffle=shuffle,
        follow_batch=['component_batch'],
    )


# Only the training loader reshuffles; validation and test keep their order.
train_loader = _build_loader(train, shuffle=True)
val_loader = _build_loader(val, shuffle=False)
test_loader = _build_loader(test, shuffle=False)

In [None]:
# --- 1. Configuration ---
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Study settings: these also determine the study name and the log directory.
constraint_type = 'hard'
include_gd = False

# Search budget: number of Optuna trials and training epochs per trial.
N_TRIALS = 10
N_EPOCHS = 20

In [None]:
import optuna
from optuna.pruners import MedianPruner
import modules.trainer_module as tm
import modules.dtmpnn as gm
from pathlib import Path
import joblib
import sys


# Echo the active configuration once, one setting per line.
_config_summary = [
    f"Device: {device}",
    f"Constraint Type: {constraint_type}",
    f"Include GD: {include_gd}",
    f"Trials: {N_TRIALS}",
    f"Epochs per trial: {N_EPOCHS}",
]
print("\n".join(_config_summary))

def objective(trial: optuna.trial.Trial):
    """
    Optuna objective: sample hyperparameters, build a DTMPNN, train it, and
    return the best validation loss.

    Reads the notebook-level names ``constraint_type``, ``include_gd``,
    ``device``, ``N_EPOCHS``, ``train`` and the three data loaders.

    Args:
        trial: Optuna trial used to sample the hyperparameters. It is also
            passed to the trainer as ``optuna_trial`` (presumably so the
            trainer can report intermediate values and prune -- confirm in
            ``trainer_module``).

    Returns:
        ``trainer.best_val_loss`` on success. If training raises an unexpected
        exception, ``float('nan')`` is returned: Optuna treats a NaN objective
        as a FAILED trial and carries on with the study. The previous
        ``float('inf')`` was stored as a COMPLETE trial, which inflated the
        "Completed" count and could surface as ``best_trial``.

    Raises:
        optuna.TrialPruned: Re-raised unchanged so Optuna marks the trial as
            pruned.
    """
    import traceback  # local import: only needed on the failure path

    log_dir = Path(f"notebooks/hyperparams_search/HPO_reports/optuna_logs/{constraint_type}_constraint/GD_backprop_{include_gd}")
    log_dir.mkdir(exist_ok=True, parents=True)
    log_file_path = log_dir / f'trial_{trial.number}.log'

    # --- 1. DEFINE THE HYPERPARAMETER SEARCH SPACE ---
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-7, 1e-4, log=True)
    gd_weight = 1.0  # fixed, not part of the search space
    graph_hidden_dim = trial.suggest_categorical('graph_hidden_dim', [16, 32, 64, 128])
    latent_dim = trial.suggest_categorical('latent_dim', [16, 32, 64])
    context_dim = trial.suggest_categorical('context_dim', [16, 32, 64])
    graph_layers = trial.suggest_int('graph_layers', 2, 5)

    # Bound up front so the `finally` clause can always clear them, even when
    # construction itself fails.
    model = None
    trainer = None

    # --- 2. RUN THE TRIAL ---
    try:
        model = gm.DTMPNN(
            # Feature widths are read off the first training graph.
            node_dim=train[0].x.shape[1],
            edge_dim=train[0].edge_attr.shape[1],
            graph_hidden_dim=graph_hidden_dim,
            latent_dim=latent_dim,
            context_dim=context_dim,
            graph_layers=graph_layers,
            constraint_type=constraint_type
        ).to(device)

        trainer = tm.DTMPNNTrainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            include_gd=include_gd,
            device=device,
            lr=lr,
            weight_decay=weight_decay,
            gd_weight=gd_weight
        )

        # Train the model; nothing is checkpointed during the search.
        trainer.train(
            epochs=N_EPOCHS,
            save_dir=None,
            log_file_path=log_file_path,
            save_best=False,
            save_every=None,
            optuna_trial=trial
        )

        # --- 3. RETURN THE METRIC TO MINIMIZE ---
        return trainer.best_val_loss

    except optuna.TrialPruned:
        print(f"Trial {trial.number} was pruned.")
        raise

    except Exception as e:
        print(f"Trial {trial.number} failed with error: {e}", file=sys.stderr)
        # Keep the full stack: the one-line message alone rarely locates the bug.
        traceback.print_exc()
        # NaN (not inf) so Optuna records this trial as FAILED, not COMPLETE.
        return float('nan')

    finally:
        # Drop our references first: empty_cache() can only hand back cached
        # blocks that no live tensor still occupies.
        model = None
        trainer = None
        torch.cuda.empty_cache()

In [None]:
# --- 5. Study Runner ---
def run_study():
    """
    Create (or resume) the Optuna study for the current configuration, run
    ``N_TRIALS`` trials of ``objective``, and print a summary.

    The study is stored in a SQLite database and looked up by a name derived
    from ``constraint_type`` and ``include_gd``, so calling this again resumes
    the same study and adds ``N_TRIALS`` more trials. The database directory
    is created here, so the function does not depend on another cell having
    made it.

    Returns:
        The ``optuna`` study object, also after a ``KeyboardInterrupt``.
    """
    # Median pruning, disabled for the first 5 trials and for the first
    # 5 steps of every trial; afterwards checked at every step.
    pruner = MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=5,
        interval_steps=1
    )

    study_name = f"DTMPNN_hpo_{constraint_type}_constrained_{'gd' if include_gd else 'no_gd'}"

    # SQLite does not create missing parent directories itself.
    db_file = Path("notebooks/hyperparams_search/HPO_reports/dashboard/master_hpo_study.db")
    db_file.parent.mkdir(parents=True, exist_ok=True)
    study_db_path = f"sqlite:///{db_file.as_posix()}"

    study = optuna.create_study(
        study_name=study_name,
        storage=study_db_path,
        load_if_exists=True,
        direction='minimize',
        pruner=pruner
    )

    print(f"--- Starting/Resuming study: {study_name} ---")
    print(f"--- Database at: {study_db_path} ---")

    try:
        study.optimize(
            objective,
            n_trials=N_TRIALS,
            timeout=None,
            gc_after_trial=True
        )
    except KeyboardInterrupt:
        # Finished trials are already in the database; fall through to the summary.
        print("--- HPO interrupted by user. Study is saved. ---")

    print("--- Study complete ---")

    # --- Summary ---
    def _trials_in(state):
        """Return the study's trials that ended in the given state."""
        return [t for t in study.trials if t.state == state]

    pruned_trials = _trials_in(optuna.trial.TrialState.PRUNED)
    completed_trials = _trials_in(optuna.trial.TrialState.COMPLETE)
    failed_trials = _trials_in(optuna.trial.TrialState.FAIL)

    print(f"Total trials: {len(study.trials)}")
    print(f"  Completed: {len(completed_trials)}")
    print(f"  Pruned:    {len(pruned_trials)}")
    print(f"  Failed:    {len(failed_trials)}")

    # `study.best_trial` raises when no trial has completed, hence the guard.
    if completed_trials:
        print("\nBest trial:")
        best_trial = study.best_trial
        print(f"  Value (min val_loss): {best_trial.value:.6f}")
        print("  Params: ")
        for key, value in best_trial.params.items():
            print(f"    {key}: {value}")
    else:
        print("No trials completed successfully.")

    return study

In [None]:
# --- 6. Run the HPO Study ---
if __name__ == "__main__":
    # One call with parents=True creates HPO_reports and dashboard together,
    # and no longer raises if an intermediate directory is missing.
    Path("notebooks/hyperparams_search/HPO_reports/dashboard").mkdir(parents=True, exist_ok=True)
    study = run_study()