In [1]:
import optuna
from typing import Dict, Any

def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]:
    """
    Draw one DQN hyperparameter configuration from an Optuna trial.

    Every value is requested from ``trial`` so Optuna records it under the
    given name. Two entries of the result are derived, not sampled directly:

    * ``gradient_steps`` = ``train_freq // subsample_steps``, floored at 1;
    * ``policy_kwargs["net_arch"]`` is the list of hidden-layer sizes that
      corresponds to the sampled ``net_arch`` name.

    NOTE(review): ``learning_starts`` and ``target_update_interval`` choices
    go up to 20_000 steps. If the training budget is shorter, many sampled
    configurations presumably never update the network. Confirm against the
    training length used by the caller.

    :param trial: Optuna trial that supplies (and records) each value.
    :return: keyword arguments for the DQN constructor, excluding the
        policy and environment.
    """
    layer_sizes_by_name = {"tiny": [64], "small": [64, 64], "medium": [256, 256]}

    # Discount factor, optimiser step size and replay-buffer settings.
    gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999])
    lr = trial.suggest_float("learning_rate", 1e-5, 1, log=True)
    batch = trial.suggest_categorical("batch_size", [16, 32, 64, 100, 128, 256, 512])
    replay_capacity = trial.suggest_categorical("buffer_size", [10_000, 50_000, 100_000, 1_000_000])

    # Exploration schedule, target-network sync period and warm-up length.
    final_eps = trial.suggest_float("exploration_final_eps", 0, 0.2)
    eps_fraction = trial.suggest_float("exploration_fraction", 0, 0.5)
    target_sync = trial.suggest_categorical("target_update_interval", [1, 1000, 5000, 10000, 15000, 20000])
    warmup = trial.suggest_categorical("learning_starts", [0, 1000, 5000, 10000, 20000])

    # Update cadence: train every `train_every` steps, doing
    # `train_every // subsample` gradient steps each time.
    train_every = trial.suggest_categorical("train_freq", [1, 4, 8, 16, 128, 256, 1000])
    subsample = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8])

    arch_name = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])

    return {
        "gamma": gamma,
        "learning_rate": lr,
        "batch_size": batch,
        "buffer_size": replay_capacity,
        "train_freq": train_every,
        "gradient_steps": max(1, train_every // subsample),
        "exploration_fraction": eps_fraction,
        "exploration_final_eps": final_eps,
        "target_update_interval": target_sync,
        "learning_starts": warmup,
        "policy_kwargs": {"net_arch": layer_sizes_by_name[arch_name]},
    }

In [6]:
from stable_baselines3.common.callbacks import EvalCallback
import gymnasium

class TrialEvalCallback(EvalCallback):
    """
    EvalCallback that reports each evaluation to an Optuna trial and can prune it.

    Every ``eval_freq`` callback calls, the parent ``EvalCallback`` runs an
    evaluation. The resulting ``last_mean_reward`` is then reported to
    ``trial`` with an increasing step index (1, 2, ...). If the trial's pruner
    asks to prune, ``is_pruned`` is set and training is stopped by returning
    ``False``. Otherwise the parent's own continue/stop decision is passed
    through.

    :param eval_env: environment used for evaluation.
    :param trial: Optuna trial that receives the intermediate values.
    :param n_eval_episodes: number of episodes per evaluation.
    :param eval_freq: evaluate every this many callback calls (<= 0 disables).
    :param deterministic: forwarded to ``EvalCallback``.
    :param verbose: verbosity level forwarded to ``EvalCallback``.
    :param kwargs: any further keyword arguments accepted by ``EvalCallback``.
    """

    def __init__(
        self,
        eval_env: gymnasium.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
        **kwargs: Any,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
            **kwargs,
        )
        self.trial = trial
        # Step index passed to ``trial.report``; incremented once per evaluation.
        self.eval_idx = 0
        # Set when the pruner stopped this trial, so the objective can raise TrialPruned.
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True
        # The parent runs the evaluation (refreshing ``last_mean_reward``).
        # Keep its verdict so a stop request from the parent is not lost.
        continue_training = super()._on_step()
        self.eval_idx += 1
        self.trial.report(self.last_mean_reward, self.eval_idx)
        # Prune trial if needed.
        if self.trial.should_prune():
            self.is_pruned = True
            return False
        return continue_training

In [7]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor

ENV_ID = "CartPole-v1"
# Base DQN arguments that are not part of the hyperparameter search space.
DEFAULT_HYPERPARAMS = dict(policy="MlpPolicy", env=ENV_ID)

# Per-trial budget: train for N_TIMESTEPS steps and evaluate every EVAL_FREQ
# callback calls over N_EVAL_EPISODES episodes.
# NOTE(review): 1_000 episodes per evaluation is likely to dominate trial
# runtime. Also, most sampled `learning_starts` values (5_000+) exceed
# N_TIMESTEPS, so those trials presumably never train. Confirm this budget
# is intended.
N_EVAL_EPISODES = 1_000
EVAL_FREQ = 100
N_TIMESTEPS = 1_000

def objective(trial: optuna.Trial) -> float:
    """
    Train one DQN agent with hyperparameters sampled from ``trial`` and score it.

    Constructor arguments are ``DEFAULT_HYPERPARAMS`` overlaid with the
    values returned by ``sample_dqn_params``. The agent trains for
    ``N_TIMESTEPS`` steps. Meanwhile ``TrialEvalCallback`` evaluates it every
    ``EVAL_FREQ`` callback calls over ``N_EVAL_EPISODES`` episodes and reports
    each result to ``trial``.

    :param trial: Optuna trial that supplies hyperparameters and receives reports.
    :return: mean reward of the last evaluation, or NaN if training raised an
        ``AssertionError`` (NaN tells the optimizer the trial failed).
    :raises optuna.exceptions.TrialPruned: if the pruner stopped the trial.
    """
    # Sampled values take precedence over the fixed defaults (policy, env).
    model_kwargs = {**DEFAULT_HYPERPARAMS, **sample_dqn_params(trial)}
    model = DQN(**model_kwargs)
    # Separate, Monitor-wrapped env used only for evaluation.
    eval_env = Monitor(gymnasium.make(ENV_ID))
    eval_callback = TrialEvalCallback(
        eval_env, trial, n_eval_episodes=N_EVAL_EPISODES, eval_freq=EVAL_FREQ, deterministic=True
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Release both environments even if training raised.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward

In [8]:
import torch
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

N_TRIALS = int(1e1)
# Used for both the sampler and the pruner. TPESampler samples randomly until
# this many trials have finished. MedianPruner does not prune until this many
# trials have finished. It must be smaller than N_TRIALS, otherwise the whole
# study is random search and nothing is ever pruned.
N_STARTUP_TRIALS = 5
# Number of intermediate reports per trial. It is derived from the training
# budget so the warmup below really corresponds to 1/3 of it.
N_EVALUATIONS = max(N_TIMESTEPS // EVAL_FREQ, 1)

# Set pytorch num threads to 1 for faster training.
torch.set_num_threads(1)

sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)
# Do not prune before 1/3 of the max budget is used.
pruner = MedianPruner(n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3)

study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
try:
    study.optimize(objective, n_trials=N_TRIALS)
except KeyboardInterrupt:
    # Allow stopping the search early from the notebook and still report results.
    pass

print("Number of finished trials: ", len(study.trials))
if not any(t.state == optuna.trial.TrialState.COMPLETE for t in study.trials):
    # `study.best_trial` raises when no trial completed (e.g. all returned NaN).
    print("No trial completed successfully.")
else:
    best_trial = study.best_trial
    print("Best trial:")
    print("  Value: ", best_trial.value)
    print("  Params: ")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")
    print("  User attrs:")
    for key, value in best_trial.user_attrs.items():
        print(f"    {key}: {value}")

[I 2024-03-08 19:11:29,221] A new study created in memory with name: no-name-53575a89-9597-49ec-bb68-a5f92a0c5f70
[I 2024-03-08 19:11:59,272] Trial 0 finished with value: 9.357 and parameters: {'gamma': 0.9999, 'learning_rate': 0.00677781642597295, 'batch_size': 256, 'buffer_size': 100000, 'exploration_final_eps': 0.18133605293970886, 'exploration_fraction': 0.45938155778571194, 'target_update_interval': 1000, 'learning_starts': 10000, 'train_freq': 4, 'subsample_steps': 2, 'net_arch': 'small'}. Best is trial 0 with value: 9.357.
[I 2024-03-08 19:12:35,461] Trial 1 finished with value: 12.836 and parameters: {'gamma': 0.9, 'learning_rate': 0.00015247067585542835, 'batch_size': 100, 'buffer_size': 50000, 'exploration_final_eps': 0.1932131373000979, 'exploration_fraction': 0.07141960799847397, 'target_update_interval': 20000, 'learning_starts': 20000, 'train_freq': 1000, 'subsample_steps': 1, 'net_arch': 'tiny'}. Best is trial 1 with value: 12.836.
[I 2024-03-08 19:13:15,633] Trial 2 fin

Number of finished trials:  8
Best trial:
  Value:  122.157
  Params: 
    gamma: 0.9999
    learning_rate: 0.00010095877974258623
    batch_size: 16
    buffer_size: 1000000
    exploration_final_eps: 0.1842663942898469
    exploration_fraction: 0.28841046597674574
    target_update_interval: 1000
    learning_starts: 20000
    train_freq: 16
    subsample_steps: 8
    net_arch: tiny
  User attrs:


In [9]:
from optuna.importance import get_param_importances

# Estimate how strongly each sampled hyperparameter influenced the objective.
# The returned mapping is displayed as the cell output.
param_importances = get_param_importances(study)
param_importances

{'batch_size': 0.21263297674406287,
 'exploration_final_eps': 0.20437989752519453,
 'train_freq': 0.12007779742006812,
 'gamma': 0.10538012421820936,
 'learning_rate': 0.09017272483090387,
 'learning_starts': 0.08332788446280265,
 'subsample_steps': 0.06423318695428099,
 'exploration_fraction': 0.058822222974243527,
 'net_arch': 0.05475177690790495,
 'target_update_interval': 0.006209247066744871,
 'buffer_size': 1.2160895584220742e-05}