In [5]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import optuna
from train_multiple_jax_models import train_multiple_cls_models_vmap_lion_dynamic
from pca_datasets import *
from models_jax import *
from functools import partial

def objective(trial, model, batch_size=None):
    """Optuna objective: train one model with sampled hyperparameters and score it.

    Samples a weight decay, per-parameter-group init std-devs and learning
    rates (and the batch size when ``batch_size`` is None), trains a single
    model with ``train_multiple_cls_models_vmap_lion_dynamic`` and returns the
    mean, over the returned history entries, of each entry's minimum
    ``"smoothed_val_loss"``.

    Args:
        trial: the Optuna trial that supplies hyperparameters.
        model: model object passed straight through to the trainer.
        batch_size: fixed batch size; when None it is tuned in [10, 1000].

    Returns:
        float: the value the study minimises.

    Note:
        Reads the module-level ``X_train``/``y_train``/``X_val``/``y_val``/
        ``X_test``/``y_test`` globals, so the data-loading cell must have been
        executed before ``study.optimize`` is called.
    """
    # Parameter names are persisted in the study storage and used by TPE to
    # match past trials, so they must stay stable across runs (this is why the
    # "quanum" misspelling below is kept as-is).
    base_wd  = trial.suggest_float("weight_decay",            1e-3, 1e+0, log=True)
    init_std = trial.suggest_float("init_std",                1e-1, 1e+1, log=True)
    freq_std = trial.suggest_float("frequency_min_init",      1e-1, 1e+1, log=True)
    freq_lr  = trial.suggest_float("learning_rate_frequency", 1e-3, 1e-1, log=True)

    if batch_size is None:
        batch_size = trial.suggest_int("batch_size", 10, 1000)

    # Per-parameter-group init std-devs; keys mirror lr_dict below.
    init_std_dict = {
        "default": init_std,
        "frequency_min": freq_std,
        "quantum": trial.suggest_float("init_std_Q", 1e-1, 1e+1, log=True),
    }

    # Per-parameter-group learning rates; the trainer expects a "default" entry.
    lr_dict = {
        "default": trial.suggest_float("learning_rate", 1e-3, 1e-1, log=True),
        "frequency_min": freq_lr,
        "quantum": trial.suggest_float("learning_rate_quanum_weights", 1e-3, 1e-1, log=True),
    }

    # Draw the training seed with an integer bound and an explicit 64-bit
    # dtype: a float bound of 1e16 overflows NumPy's default int32 on Windows.
    # Store it on the trial (as a JSON-serialisable int) so any trial, in
    # particular the best one, can be re-run exactly.
    seed = int(np.random.randint(0, 10**16, dtype=np.int64))
    trial.set_user_attr("seed", seed)

    # NOTE(review): freq_std is passed both via init_std_dict["frequency_min"]
    # and frequency_min_init — confirm which one the trainer actually uses.
    _final_state, history, _test_metrics = train_multiple_cls_models_vmap_lion_dynamic(
        N=1,
        model=model,
        X_train=X_train, y_train=y_train,
        X_val=X_val,     y_val=y_val,
        X_test=X_test,   y_test=y_test,
        batch_size=batch_size,
        target_epochs=10000,
        init_std_dict=init_std_dict,
        frequency_min_init=freq_std,
        initial_wd=base_wd,
        smoothing=0.0,
        seed=seed,
        print_output=True,
        b1=0.9, b2=0.99,
        patience=100,
        ema_decay=1.0,
        lr_dict=lr_dict,
    )

    # Use the minimum smoothed validation loss (averaged over the history
    # entries returned by the trainer) as the objective.
    min_smoothed = [min(h["smoothed_val_loss"]) for h in history]
    return float(np.mean(min_smoothed))

In [2]:
# Number of PCA components to keep; also used as the QNN's num_features.
num_features = 4

# Load the PCA-reduced Semeion data and split it into train/val/test with a
# fixed random_state so the split is identical on every run.
# NOTE(review): whether val_size is a fraction of the full set or of the
# post-test remainder depends on generate_dataset_semeion_pca — confirm.
(X_train, y_train,
 X_val, y_val,
 X_test, y_test) = generate_dataset_semeion_pca(
    n_components=num_features,
    val_size=0.15,
    test_size=0.15,
    random_state=0,
)

In [3]:
# Also embedded in the Optuna study name, so changing it starts a new study.
num_frequencies = 1

# NOTE(review): num_output=10 presumably matches the number of classes in the
# Semeion labels — confirm against the dataset.
model = QNN(
    num_features=num_features,
    num_frequencies=num_frequencies,
    num_output=10,
    layer_depth=2,
)

In [6]:
# === Create or Load the Optuna Study ===
study_name = f"fnn_num_features={num_features}_num_frequencies={num_frequencies}"
storage = "sqlite:///fnn_hpo_study.db"

study = optuna.create_study(
    study_name=study_name,
    storage=storage,
    load_if_exists=True,
    direction="minimize",  # Minimizing validation loss.
    sampler=optuna.samplers.TPESampler(),
)

# === Run the Optimization ===
study.optimize( partial(objective, model = model, batch_size = 200), n_trials=10 )

[I 2025-04-23 19:38:12,894] Using an existing study with name 'fnn_num_features=4_num_frequencies=1' instead of creating a new one.


[Epoch    0] train_loss=2.2944, val_loss=2.2857
[Epoch   50] train_loss=1.8538, val_loss=1.8424


[W 2025-04-23 19:40:47,048] Trial 4 failed with parameters: {'weight_decay': 0.3527922613435429, 'init_std': 0.2224651494774224, 'frequency_min_init': 5.227807067240266, 'learning_rate_frequency': 0.049140105964527, 'init_std_Q': 0.2572063591961185, 'learning_rate': 0.0019317711356687595, 'learning_rate_quanum_weights': 0.003853614911043268} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_791311/3008990426.py", line 32, in objective
    final_state, history, test_metrics = train_multiple_cls_models_vmap_lion_dynamic(
  File "/root/jupyter/tutorial/VK_Folder/HPO_with_jax/train_multiple_jax_models.py", line 231, in train_multiple_cls_models_vmap_lion_dynamic
    losses, accs = eval_step_batch(params_batched, batch, smoothing, model)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/trace

KeyboardInterrupt: 