# WaveletDiff Optuna Hyperparameter Optimization

Modern hyperparameter optimization with:
- 🎯 Multi-objective optimization (loss, speed, stability)
- 🧠 TPESampler for intelligent search
- ✂️ HyperbandPruner for early stopping
- 💾 Persistent SQLite storage (survives Colab restarts)
- 📊 Optuna Dashboard for visualization

### Workflow:
1. Configure which hyperparameters to tune (Cell 1)
2. Set up environment and mount Drive (Cell 2)
3. Initialize Fabric and data (Cell 3)
4. Create Optuna study (Cell 4)
5. Launch dashboard (Cell 5) - optional; run the "Configure Ngrok Auth Token" cell first
6. Run optimization (Cell 6)
7. Analyze results (Cell 7)
8. Export best configs (Cell 8)

In [None]:
# @title Cell 1: Hyperparameter Tuning Configuration
#
# Single place where the notebook is configured. Lines tagged "@param" are
# Colab form fields, so each of them stays a plain literal assignment.

# --- Study: identifier of the study and overall search budget ---
STUDY_NAME = "waveletdiff_multiobjective_v1"  # @param {type:"string"}
N_TRIALS = 50  # @param {type:"integer"}
TIMEOUT_HOURS = None  # @param {type:"number"}

# --- Per-trial budget: training steps and evaluation cadence ---
STEPS_PER_TRIAL = 2000  # @param {type:"integer"}
EVAL_INTERVAL = 100  # @param {type:"integer"}

# --- Optimization mode ---
# True  -> three separate objectives (Pareto optimization).
# False -> one objective built by weighted scalarization.
USE_MULTI_OBJECTIVE = True  # @param {type:"boolean"}

# Scalarization weights; only relevant when USE_MULTI_OBJECTIVE is False.
WEIGHT_LOSS = 1.0  # @param {type:"number"}
WEIGHT_SPEED = 0.001  # @param {type:"number"}
WEIGHT_STABILITY = 0.2  # @param {type:"number"}

# --- Pruner ---
ENABLE_PRUNING = True  # @param {type:"boolean"}
PRUNER_TYPE = "hyperband"  # @param ["hyperband", "median", "none"]
PRUNER_MIN_RESOURCE = 500  # @param {type:"integer"}
PRUNER_REDUCTION_FACTOR = 3  # @param {type:"integer"}

# --- Sampler ---
SAMPLER_TYPE = "tpe"  # @param ["tpe", "random"]
N_STARTUP_TRIALS = 10  # @param {type:"integer"}

# --- Dashboard ---
ENABLE_DASHBOARD = True  # @param {type:"boolean"}
DASHBOARD_PORT = 8080  # @param {type:"integer"}

print("=" * 60)
print("HYPERPARAMETERS TO TUNE")
print("=" * 60)
print("Toggle each parameter ON (True) or OFF (False)\n")

# --- Per-hyperparameter switches: True = searched, False = held at default ---
TUNE_LEARNING_RATE = True  # @param {type:"boolean"}
TUNE_MAX_LR = True  # @param {type:"boolean"}
TUNE_WEIGHT_DECAY = True  # @param {type:"boolean"}
TUNE_EMBED_DIM = True  # @param {type:"boolean"}
TUNE_NUM_HEADS = False  # @param {type:"boolean"}
TUNE_NUM_LAYERS = True  # @param {type:"boolean"}
TUNE_DROPOUT = True  # @param {type:"boolean"}
TUNE_BATCH_SIZE = True  # @param {type:"boolean"}
TUNE_PCT_START = False  # @param {type:"boolean"}
TUNE_GRAD_CLIP_NORM = False  # @param {type:"boolean"}
TUNE_TIME_EMBED_DIM = False  # @param {type:"boolean"}

# Switches gathered under the hyperparameter names; the key order here is the
# order used by the summary printed at the end of this cell.
TUNE_FLAGS = dict(
    learning_rate=TUNE_LEARNING_RATE,
    max_lr=TUNE_MAX_LR,
    weight_decay=TUNE_WEIGHT_DECAY,
    embed_dim=TUNE_EMBED_DIM,
    num_heads=TUNE_NUM_HEADS,
    num_layers=TUNE_NUM_LAYERS,
    dropout=TUNE_DROPOUT,
    batch_size=TUNE_BATCH_SIZE,
    pct_start=TUNE_PCT_START,
    grad_clip_norm=TUNE_GRAD_CLIP_NORM,
    time_embed_dim=TUNE_TIME_EMBED_DIM,
)

# Value used for a hyperparameter whenever its switch above is off.
DEFAULT_HYPERPARAMS = dict(
    learning_rate=2e-4,
    max_lr=1e-3,
    weight_decay=1e-5,
    embed_dim=256,
    num_heads=8,
    num_layers=8,
    dropout=0.1,
    batch_size=512,
    pct_start=0.3,
    grad_clip_norm=1.0,
    time_embed_dim=128,
)

# --- Dataset ---
DATASET_NAME = "stocks"  # @param {type:"string"}
SEQ_LEN = 24  # @param {type:"integer"}
WAVELET_TYPE = "db2"  # @param {type:"string"}
WAVELET_LEVELS = "auto"
DATA_PATH = "src/copied_waveletDiff/data/stocks/stock_data.csv"  # @param {type:"string"}

# --- Locations: Drive-backed study database / checkpoints, and the repo ---
DRIVE_BASE_PATH = "/content/drive/MyDrive/personal_drive/trading"  # @param {type:"string"}
OPTUNA_DB_PATH = DRIVE_BASE_PATH + "/optuna_studies/waveletdiff.db"
CHECKPOINT_DIR = DRIVE_BASE_PATH + "/optuna_checkpoints/temp"
REPO_URL = "https://github.com/MilesHoffman/waveletDiff_synth_data.git"
REPO_DIR = "/content/waveletDiff_synth_data"

# --- Summary: which hyperparameters are searched and which stay fixed ---
tuned_params = [name for name in TUNE_FLAGS if TUNE_FLAGS[name]]
fixed_params = [name for name in TUNE_FLAGS if not TUNE_FLAGS[name]]

print(f"Tuning {len(tuned_params)} parameters:")
for param in tuned_params:
    print(f"  ‚úÖ {param}")

print(f"\nFixed {len(fixed_params)} parameters:")
for param in fixed_params:
    print(f"  ‚õî {param}: {DEFAULT_HYPERPARAMS[param]}")

print("=" * 60)

In [None]:
# @title Cell 2: Environment Setup
# Mounts Google Drive (when running on Colab), fetches the project repository,
# installs the Python dependencies, and prepares import paths and output
# directories. Expects the path constants from Cell 1 (REPO_URL, REPO_DIR,
# OPTUNA_DB_PATH, CHECKPOINT_DIR).
import os
import sys
import subprocess

# Mount Drive
DRIVE_MOUNT_POINT = '/content/drive'
try:
    # Only the import is guarded: it is what tells us whether this is Colab.
    from google.colab import drive
except ImportError:
    print("Not running on Colab. Skipping Drive mount.")
else:
    if not os.path.exists(DRIVE_MOUNT_POINT):
        drive.mount(DRIVE_MOUNT_POINT)
    elif not os.listdir(DRIVE_MOUNT_POINT):
        # Mount point exists but is empty (presumably a stale mount), so
        # force a fresh mount.
        print("Force remounting Drive...")
        drive.mount(DRIVE_MOUNT_POINT, force_remount=True)
    print("‚úÖ Drive mounted")

# Clone Repository (or update an existing checkout). check=True makes a git
# failure stop the cell instead of continuing with stale or missing code.
if os.path.exists(REPO_DIR):
    print(f"Repo exists at {REPO_DIR}, pulling changes...")
    subprocess.run(["git", "-C", REPO_DIR, "pull"], check=True)
else:
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

print("‚úÖ Repository ready")

# Install Dependencies
print("Installing dependencies...")
deps = [
    "lightning",
    "pywavelets",
    "scipy",
    "pandas",
    "tqdm",
    "optuna",
    "optuna-dashboard",
    "plotly",
    "kaleido",
    # Imported by the ngrok auth cell and the dashboard cell.
    "pyngrok",
]
# Run pip through the kernel's own interpreter so the packages land in the
# environment this notebook actually imports from.
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *deps], check=True)
print("‚úÖ Dependencies installed")

# Setup Paths: make both the repo root and the copied WaveletDiff sources
# importable, without adding duplicates when the cell is re-run.
source_path = os.path.join(REPO_DIR, "src", "copied_waveletDiff", "src")
for import_root in (REPO_DIR, source_path):
    if import_root not in sys.path:
        sys.path.append(import_root)

# Create Directories for the Optuna database and the checkpoints.
os.makedirs(os.path.dirname(OPTUNA_DB_PATH), exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("‚úÖ Setup complete")

In [None]:
# @title Cell 3: Initialize Fabric
# Creates the Fabric object through the project's trainer module and assembles
# the base configuration that Cell 6 hands to the Optuna trainer as config_base.
from src.torch_gpu_waveletDiff.train import trainer

# Fabric with bf16 mixed precision and "high" matmul precision.
fabric = trainer.setup_fabric(
    matmul_precision="high",
    precision="bf16-mixed",
)

# Configuration grouped by section; values come from Cell 1 where available.
BASE_CONFIG = dict(
    dataset=dict(name=DATASET_NAME, seq_len=SEQ_LEN),
    training=dict(batch_size=DEFAULT_HYPERPARAMS['batch_size'], epochs=1),
    data=dict(data_dir='src/copied_waveletDiff/data/stocks', normalize_data=False),
    wavelet=dict(type=WAVELET_TYPE, levels=WAVELET_LEVELS),
    model=dict(prediction_target='noise'),
    attention=dict(use_cross_level_attention=True),
    noise=dict(schedule='exponential'),
    sampling=dict(ddim_eta=0.0, ddim_steps=None),
    energy=dict(weight=0.0),
    optimizer=dict(scheduler_type='onecycle'),
)

# Report the device Fabric selected and the precision it was created with.
print("‚úÖ Fabric initialized")
print(f"   Device: {fabric.device}")
print(f"   Precision: bf16-mixed")

In [None]:
# @title Cell 4: Create Optuna Study
# Builds the sampler and pruner selected in Cell 1, then creates the study in
# the SQLite storage or reloads it when a study of that name already exists.
import optuna
from optuna.pruners import HyperbandPruner, MedianPruner, NopPruner
from optuna.samplers import TPESampler, RandomSampler

# Storage
storage_url = f"sqlite:///{OPTUNA_DB_PATH}"
print(f"üìÅ Storage: {storage_url}")

# Sampler: anything other than "tpe" selects random search.
if SAMPLER_TYPE == "tpe":
    sampler = TPESampler(
        n_startup_trials=N_STARTUP_TRIALS,
        multivariate=True,
        group=True,
        constant_liar=True
    )
    print(f"üß† Sampler: TPE (startup trials: {N_STARTUP_TRIALS})")
else:
    sampler = RandomSampler()
    print("üé≤ Sampler: Random")

# Pruner
# NOTE(review): Optuna documents Trial.report()/should_prune() as unsupported
# for multi-objective studies, so the pruner presumably only takes effect when
# USE_MULTI_OBJECTIVE is False - confirm against the objective implementation.
if not ENABLE_PRUNING or PRUNER_TYPE == "none":
    pruner = NopPruner()
    print("‚úÇÔ∏è Pruner: Disabled")
elif PRUNER_TYPE == "hyperband":
    pruner = HyperbandPruner(
        min_resource=PRUNER_MIN_RESOURCE,
        reduction_factor=PRUNER_REDUCTION_FACTOR
    )
    print(f"‚úÇÔ∏è Pruner: Hyperband (min_resource: {PRUNER_MIN_RESOURCE}, reduction: {PRUNER_REDUCTION_FACTOR})")
elif PRUNER_TYPE == "median":
    pruner = MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=500
    )
    print("‚úÇÔ∏è Pruner: Median")
else:
    # Fail here with a clear message instead of a NameError on `pruner` below.
    raise ValueError(
        f"Unknown PRUNER_TYPE {PRUNER_TYPE!r}; expected 'hyperband', 'median' or 'none'"
    )

# Create or load study. Both modes share everything except the direction(s).
study_kwargs = dict(
    study_name=STUDY_NAME,
    storage=storage_url,
    sampler=sampler,
    pruner=pruner,
    load_if_exists=True,
)
if USE_MULTI_OBJECTIVE:
    study = optuna.create_study(
        directions=["minimize", "minimize", "minimize"],
        **study_kwargs
    )
    print("üéØ Mode: Multi-objective (Pareto optimization)")
    print("   Objectives: [loss, step_time_ms, grad_norm_variance]")
else:
    study = optuna.create_study(direction="minimize", **study_kwargs)
    print("üéØ Mode: Single-objective (weighted scalarization)")
    print(f"   Weights: loss={WEIGHT_LOSS}, speed={WEIGHT_SPEED}, stability={WEIGHT_STABILITY}")

# Summarize what is already stored. Trials are fetched once and only read, so
# no deep copy is needed.
existing_trials = study.get_trials(deepcopy=False)
n_completed = sum(t.state == optuna.trial.TrialState.COMPLETE for t in existing_trials)
n_pruned = sum(t.state == optuna.trial.TrialState.PRUNED for t in existing_trials)

print(f"\nüìä Study: {STUDY_NAME}")
print(f"   Previous trials: {len(existing_trials)}")
print(f"   Completed: {n_completed}")
print(f"   Pruned: {n_pruned}")

In [None]:
# @title Configure Ngrok Auth Token
# Authenticates the local ngrok client so the dashboard cell can open a tunnel.
import os

# Get your token from: https://dashboard.ngrok.com/get-started/your-authtoken
# SECURITY: do not save a real token in the notebook source. Paste it into the
# form for the current session only, or provide it through the
# NGROK_AUTH_TOKEN environment variable. Any token that has been committed to
# version control should be treated as leaked and revoked in the ngrok dashboard.
NGROK_AUTH_TOKEN = ""  # @param {type:"string"}

# The form value wins; otherwise fall back to the environment.
NGROK_AUTH_TOKEN = NGROK_AUTH_TOKEN or os.environ.get("NGROK_AUTH_TOKEN", "")

if NGROK_AUTH_TOKEN:
    try:
        from pyngrok import ngrok
    except ImportError:
        print("pyngrok is not installed. Run `pip install pyngrok` and re-run this cell.")
    else:
        ngrok.set_auth_token(NGROK_AUTH_TOKEN)
        print("‚úÖ Ngrok authenticated")
else:
    print("‚ö†Ô∏è Please set your ngrok auth token")

In [None]:
# @title Cell 5: Launch Optuna Dashboard (Optional)
# Starts optuna-dashboard against the study storage and exposes it through an
# ngrok tunnel. Expects storage_url from Cell 4 and DASHBOARD_PORT from Cell 1.
if ENABLE_DASHBOARD:
    import os
    import shutil
    import subprocess
    import sys
    import tempfile
    import time

    # Kill any existing dashboard. Best effort: pkill exits non-zero when
    # nothing matched, and the tool may be missing outside Linux.
    if shutil.which("pkill"):
        subprocess.run(["pkill", "-f", "optuna-dashboard"], check=False)

    # Start dashboard in background. Its output goes to a log file rather than
    # subprocess.PIPE: nothing in this notebook reads those pipes, so the
    # server could block once the OS pipe buffer filled up.
    dashboard_log_path = os.path.join(tempfile.gettempdir(), "optuna_dashboard.log")
    with open(dashboard_log_path, "w", encoding="utf-8") as dashboard_log:
        dashboard_process = subprocess.Popen(
            ["optuna-dashboard", storage_url, "--port", str(DASHBOARD_PORT)],
            stdout=dashboard_log,
            stderr=subprocess.STDOUT,
        )

    # Give the server a moment to start before checking on it.
    time.sleep(3)

    if dashboard_process.poll() is not None:
        # The process already exited, so a tunnel would point at nothing.
        print(
            f"Dashboard process exited early (code {dashboard_process.returncode}). "
            f"See log: {dashboard_log_path}"
        )
    else:
        # Create ngrok tunnel
        try:
            from pyngrok import ngrok
        except ImportError:
            print("Installing pyngrok...")
            # Install through the kernel's interpreter; best effort, the user
            # is asked to re-run the cell either way.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", "pyngrok"],
                check=False,
            )
            print("Please re-run this cell after installation.")
        else:
            public_url = ngrok.connect(DASHBOARD_PORT)

            print("="*60)
            print("üé® OPTUNA DASHBOARD READY")
            print("="*60)
            # ngrok.connect returns a tunnel object; show only its URL.
            print(f"üåê Public URL: {public_url.public_url}")
            print(f"üìä Study: {STUDY_NAME}")
            print("="*60)
            print("\n‚ö†Ô∏è Keep this cell running! Dashboard will stop if interrupted.")
else:
    print("üìä Dashboard disabled. Set ENABLE_DASHBOARD=True to enable.")

In [None]:
# @title Cell 6: Run Hyperparameter Optimization
# Wraps the project's Optuna trainer around the Fabric object and base config,
# then runs the study created in Cell 4 for N_TRIALS trials (or until the
# optional timeout).
from src.torch_gpu_waveletDiff.train.optuna_trainer import OptunaWaveletDiffTrainer

# Create Optuna trainer
optuna_trainer = OptunaWaveletDiffTrainer(
    fabric=fabric,
    config_base=BASE_CONFIG,
    repo_dir=REPO_DIR,
    data_path=DATA_PATH,
    tune_flags=TUNE_FLAGS,
    default_hyperparams=DEFAULT_HYPERPARAMS,
    checkpoint_dir=CHECKPOINT_DIR,
    trial_steps=STEPS_PER_TRIAL,
    eval_interval=EVAL_INTERVAL
)

# Select objective function matching the mode the study was created with.
objective_fn = (
    optuna_trainer.objective if USE_MULTI_OBJECTIVE else optuna_trainer.objective_single
)

# Optuna takes the timeout in seconds; None (or 0 hours) means no time limit.
timeout_seconds = TIMEOUT_HOURS * 3600 if TIMEOUT_HOURS else None
n_tuned = sum(1 for enabled in TUNE_FLAGS.values() if enabled)

# Run optimization
print("="*60)
print("üöÄ STARTING OPTIMIZATION")
print("="*60)
print(f"Trials: {N_TRIALS}")
print(f"Steps per trial: {STEPS_PER_TRIAL}")
print(f"Timeout: {TIMEOUT_HOURS or 'None'} hours")
print(f"Mode: {'Multi-objective' if USE_MULTI_OBJECTIVE else 'Single-objective'}")
print(f"Tuning {n_tuned} hyperparameters")
print("="*60)

try:
    study.optimize(
        objective_fn,
        n_trials=N_TRIALS,
        timeout=timeout_seconds,
        n_jobs=1,
        show_progress_bar=True
    )
except KeyboardInterrupt:
    print("\n‚ö†Ô∏è Optimization interrupted by user")
else:
    # Report success only when the run was not interrupted.
    print("\n‚úÖ Optimization complete!")

In [None]:
# @title Cell 7: Analyze Results
# Reloads the study from storage, prints trial statistics and the best
# trial(s), and renders the standard Optuna plots.
import optuna
from optuna.visualization import (
    plot_optimization_history,
    plot_param_importances,
    plot_parallel_coordinate,
    plot_pareto_front
)

# Reload study
study = optuna.load_study(
    study_name=STUDY_NAME,
    storage=f"sqlite:///{OPTUNA_DB_PATH}"
)

# Fetch the trials once; they are only read, so no deep copy is needed.
analysis_trials = study.get_trials(deepcopy=False)


def _count_trials(state):
    """Return how many of the loaded trials are in the given state."""
    return sum(t.state == state for t in analysis_trials)


print("="*60)
print("üìä OPTIMIZATION RESULTS")
print("="*60)
print(f"Total trials: {len(analysis_trials)}")
print(f"Completed: {_count_trials(optuna.trial.TrialState.COMPLETE)}")
print(f"Pruned: {_count_trials(optuna.trial.TrialState.PRUNED)}")
print(f"Failed: {_count_trials(optuna.trial.TrialState.FAIL)}")

if USE_MULTI_OBJECTIVE:
    print("\nüéØ Top Pareto-Optimal Trials:")
    # First five entries of the Pareto front, in the order Optuna returns them.
    pareto_trials = study.best_trials[:5]
    for trial in pareto_trials:
        print(f"\nTrial {trial.number}:")
        print(f"  Loss: {trial.values[0]:.6f}")
        print(f"  Step Time: {trial.values[1]:.2f}ms")
        print(f"  Grad Variance: {trial.values[2]:.6f}")
        print(f"  Key Params: embed_dim={trial.params.get('embed_dim', 'N/A')}, "
              f"layers={trial.params.get('num_layers', 'N/A')}, "
              f"batch={trial.params.get('batch_size', 'N/A')}")
else:
    try:
        # Raises ValueError while no trial has completed.
        best_trial = study.best_trial
    except ValueError:
        print("\nNo completed trials yet - nothing to report.")
    else:
        print(f"\nüèÜ Best Trial: {best_trial.number}")
        print(f"   Best Value: {best_trial.value:.6f}")
        print(f"   Best Params:")
        for key, value in best_trial.params.items():
            print(f"      {key}: {value}")

# Visualizations
print("\nüìà Generating visualizations...")


def _show_plot(plot_fn, description, **plot_kwargs):
    """Render one Optuna plot; on failure report why and return None.

    Plotting is best effort: a failing figure must not abort the analysis,
    but the cause is printed so it can be diagnosed.
    """
    try:
        figure = plot_fn(study, **plot_kwargs)
        figure.show()
    except Exception as exc:
        print(f"Could not plot {description}: {exc}")
        return None
    return figure


# The single-value plots need an explicit `target` on a multi-objective study
# (Optuna raises otherwise), so they show the first objective, the loss.
if USE_MULTI_OBJECTIVE:
    objective_kwargs = {"target": lambda t: t.values[0], "target_name": "loss"}
else:
    objective_kwargs = {}

fig1 = _show_plot(plot_optimization_history, "optimization history", **objective_kwargs)

if len(analysis_trials) > 10:
    fig2 = _show_plot(plot_param_importances, "parameter importances", **objective_kwargs)

fig3 = _show_plot(plot_parallel_coordinate, "parallel coordinates", **objective_kwargs)

if USE_MULTI_OBJECTIVE:
    fig4 = _show_plot(
        plot_pareto_front,
        "Pareto front",
        target_names=["loss", "step_time_ms", "grad_norm_variance"],
    )

print("\n‚úÖ Analysis complete")

In [None]:
# @title Cell 8: Export Best Configurations
# Writes the best trial (single-objective) or the first three Pareto-optimal
# trials (multi-objective) as JSON files into CHECKPOINT_DIR and echoes them.
import json

if USE_MULTI_OBJECTIVE:
    print("üéØ Exporting top 3 Pareto-optimal configurations:\n")
    best_trials = study.best_trials[:3]
else:
    print("üèÜ Exporting best configuration:\n")
    try:
        # Raises ValueError while no trial has completed.
        best_trials = [study.best_trial]
    except ValueError:
        best_trials = []

if not best_trials:
    print("No completed trials to export.")
else:
    for trial in best_trials:
        config_export = {
            "trial_number": trial.number,
            "hyperparameters": trial.params,
            "user_attrs": dict(trial.user_attrs),
            "state": str(trial.state)
        }

        if USE_MULTI_OBJECTIVE:
            # Objective order follows the study definition in Cell 4.
            config_export["objectives"] = {
                "loss": trial.values[0],
                "step_time_ms": trial.values[1],
                "grad_variance": trial.values[2]
            }
        else:
            config_export["objective_value"] = trial.value

        # Save to file (one JSON file per trial, named by trial number)
        filename = f"{CHECKPOINT_DIR}/best_config_trial_{trial.number}.json"
        with open(filename, 'w', encoding="utf-8") as f:
            json.dump(config_export, f, indent=2)

        print(f"Trial {trial.number}:")
        print(json.dumps(config_export, indent=2))
        print(f"\nüíæ Saved to: {filename}\n")
        print("-"*60)

    print("\n‚úÖ Configurations exported")
    print("\nTo use these hyperparameters, update your training notebook with values from above.")