In [None]:
# Example taken from https://github.com/optuna/optuna-examples/blob/main/rl/sb3_simple.py
# Relevant links
# https://optuna.readthedocs.io/en/stable/reference/generated/optuna.pruners.SuccessiveHalvingPruner.html#optuna.pruners.SuccessiveHalvingPruner
# https://pennylane.ai/blog/2022/07/lightning-fast-simulations-with-pennylane-and-the-nvidia-cuquantum-sdk/
# https://stable-baselines.readthedocs.io/en/master/guide/examples.html
# https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/master/monitor_training.ipynb#scrollTo=mPXYbV39DiCj

In [None]:
""" 
Code source https://github.com/optuna/optuna-examples/blob/main/rl/sb3_simple.py

Optuna example that optimizes the hyperparameters of
a reinforcement learning agent using the DQN or PPO implementation from Stable-Baselines3
on an OpenAI Gym environment (the upstream example this is adapted from uses A2C).

This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.
"""
from typing import Any, Dict

import joblib
import numpy as np
import gym
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import PPO, DQN
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
import torch
import torch.nn as nn


# Study-level budget: total number of trials, and the startup-trial count that is
# handed to both the sampler and the pruner in the script entry point.
N_TRIALS = 100
N_STARTUP_TRIALS = 20
# Per-trial budget: training timesteps, split into equally spaced evaluations.
N_EVALUATIONS = 10
N_TIMESTEPS = 10_000
EVAL_FREQ = N_TIMESTEPS // N_EVALUATIONS
# Number of episodes played at each evaluation.
N_EVAL_EPISODES = 100

# Algorithm to tune. Must be one of the keys of _ALGORITHMS: "ppo" or "dqn".
policy = "dqn"

# The model class is looked up from `policy` so the two can never disagree: the
# objective chooses the hyperparameter sampler by `policy` but builds the agent
# with `MODEL`, and a mismatch would hand one algorithm's arguments to the other.
_ALGORITHMS = {"ppo": PPO, "dqn": DQN}
MODEL = _ALGORITHMS[policy]


# ENV_ID = "FrozenLake-v1"
ENV_ID = "MountainCar-v0"


# Constructor arguments shared by every trial; sampled values are merged on top.
DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
    "env": ENV_ID,
}

# Where the finished study is pickled (set to None to skip saving).
study_save_path = f"studies/{policy}_optuna.pkl"

random_seed = 542
np.random.seed(seed=random_seed)

def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]:
    """Draw one A2C hyperparameter configuration from ``trial``.

    ``gamma`` and ``gae_lambda`` are searched as their distance from 1 on a log
    scale, and ``n_steps`` as a power-of-two exponent; the derived values are
    stored on the trial as user attributes (``gamma_``, ``gae_lambda_``,
    ``n_steps``) so they can be read back directly.

    Args:
        trial: Optuna trial used to suggest the values.

    Returns:
        Keyword arguments for the model constructor, including a nested
        ``policy_kwargs`` dictionary.
    """
    # Raw suggestions, requested in a fixed order.
    gamma_gap = trial.suggest_float("gamma", 0.0001, 0.1, log=True)
    grad_clip = trial.suggest_float("max_grad_norm", 0.3, 5.0, log=True)
    lambda_gap = trial.suggest_float("gae_lambda", 0.001, 0.2, log=True)
    steps_exponent = trial.suggest_int("exponent_n_steps", 3, 10)
    step_size = trial.suggest_float("lr", 1e-5, 1, log=True)
    entropy_weight = trial.suggest_float("ent_coef", 1e-8, 0.1, log=True)
    use_ortho_init = trial.suggest_categorical("ortho_init", [False, True])
    arch_name = trial.suggest_categorical("net_arch", ["tiny", "small"])
    activation_name = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    # Convert the search-space values into the ones the model consumes.
    discount = 1.0 - gamma_gap
    lambda_value = 1.0 - lambda_gap
    rollout_length = 2 ** steps_exponent

    # Record the converted values on the trial for later inspection.
    derived_values = (
        ("gamma_", discount),
        ("gae_lambda_", lambda_value),
        ("n_steps", rollout_length),
    )
    for attr_name, attr_value in derived_values:
        trial.set_user_attr(attr_name, attr_value)

    # "tiny" is one hidden layer of 64 units per head, anything else is two.
    if arch_name == "tiny":
        head_layers = {"pi": [64], "vf": [64]}
    else:
        head_layers = {"pi": [64, 64], "vf": [64, 64]}

    activation_classes = {"tanh": nn.Tanh, "relu": nn.ReLU}
    network_options = {
        "net_arch": [head_layers],
        "activation_fn": activation_classes[activation_name],
        "ortho_init": use_ortho_init,
    }

    return {
        "n_steps": rollout_length,
        "gamma": discount,
        "gae_lambda": lambda_value,
        "learning_rate": step_size,
        "ent_coef": entropy_weight,
        "max_grad_norm": grad_clip,
        "policy_kwargs": network_options,
    }


def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
    """Draw one PPO hyperparameter configuration from ``trial``.

    The discount factor is searched as ``1 - gamma`` on a log scale and the
    rollout length as a power-of-two exponent. The resulting discount factor is
    stored on the trial as the ``gamma_`` user attribute.

    Args:
        trial: Optuna trial used to suggest the values.

    Returns:
        Keyword arguments ``learning_rate``, ``n_steps`` and ``gamma``.
    """
    step_size = trial.suggest_float("lr", 1e-5, 1, log=True)
    steps_exponent = trial.suggest_int("exponent_n_steps", 3, 10)
    gamma_gap = trial.suggest_float("gamma", 0.0001, 0.1, log=True)

    discount = 1.0 - gamma_gap
    # Keep the actual discount factor next to the raw suggested value.
    trial.set_user_attr("gamma_", discount)

    return dict(
        learning_rate=step_size,
        n_steps=2 ** steps_exponent,
        gamma=discount,
    )

def sample_dqn_params(trial: optuna.Trial) -> Dict[str, Any]:
    """Draw one DQN hyperparameter configuration from ``trial``.

    The discount factor is searched as ``1 - gamma`` on a log scale, and the
    resulting value is stored on the trial as the ``gamma_`` user attribute.
    ``learning_starts`` currently has a single candidate value.

    Args:
        trial: Optuna trial used to suggest the values.

    Returns:
        Keyword arguments ``learning_rate``, ``learning_starts``, ``gamma``
        and ``max_grad_norm``.
    """
    step_size = trial.suggest_float("lr", 1e-5, 1, log=True)
    # Single-choice categorical; the original note gives 50_000 as the default.
    warmup_steps = trial.suggest_categorical("ls", [5_000])
    gamma_gap = trial.suggest_float("gamma", 0.0001, 0.1, log=True)
    grad_clip = trial.suggest_float("max_grad_norm", 0.3, 10.0, log=True)

    discount = 1.0 - gamma_gap
    # Keep the actual discount factor next to the raw suggested value.
    trial.set_user_attr("gamma_", discount)

    return dict(
        learning_rate=step_size,
        learning_starts=warmup_steps,
        gamma=discount,
        max_grad_norm=grad_clip,
    )


class TrialEvalCallback(EvalCallback):
    """Evaluation callback that reports each result to an Optuna trial.

    Every ``eval_freq`` calls the parent evaluation runs, the latest mean reward
    is reported to the trial under an increasing step index, and the trial is
    asked whether it should be pruned.

    Attributes:
        trial: The Optuna trial receiving the intermediate values.
        eval_idx: Number of evaluations reported so far.
        is_pruned: Set to True once the trial has asked to be pruned.
    """

    def __init__(
        self,
        eval_env: gym.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 1_000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        """Forward the evaluation settings to the parent and attach ``trial``."""
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial, self.eval_idx, self.is_pruned = trial, 0, False

    def _on_step(self) -> bool:
        """Evaluate and report when an evaluation is due.

        Returns:
            False when the trial asks to be pruned (a False return from a
            Stable-Baselines3 callback aborts training), True otherwise.
        """
        evaluation_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        if not evaluation_due:
            return True

        # The parent performs the evaluation and refreshes last_mean_reward.
        super()._on_step()
        self.eval_idx += 1
        self.trial.report(self.last_mean_reward, self.eval_idx)

        if self.trial.should_prune():
            # Remember the reason for stopping so the caller can tell pruning
            # apart from a normal end of training.
            self.is_pruned = True
            return False
        return True


def objective(trial: optuna.Trial) -> float:
    """Train one agent with hyperparameters sampled from ``trial`` and score it.

    The sampler matching the module-level ``policy`` fills in the constructor
    arguments on top of ``DEFAULT_HYPERPARAMS``; the agent is then trained for
    ``N_TIMESTEPS`` while a ``TrialEvalCallback`` periodically evaluates it and
    reports the results to the trial.

    Args:
        trial: Optuna trial supplying the hyperparameters and receiving the
            intermediate evaluation results.

    Returns:
        The mean reward of the last evaluation, or NaN when training raised an
        ``AssertionError`` (which marks the trial as failed).

    Raises:
        ValueError: If ``policy`` is neither ``"ppo"`` nor ``"dqn"``.
        optuna.exceptions.TrialPruned: If the callback stopped training because
            the trial asked to be pruned.
    """
    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters.
    if policy == "ppo":
        kwargs.update(sample_ppo_params(trial))
    elif policy == "dqn":
        kwargs.update(sample_dqn_params(trial))
    else:
        raise ValueError(f"Invalid policy {policy}")
    # Create the RL model.
    model = MODEL(**kwargs)
    # Create env used for evaluation.
    eval_env = Monitor(gym.make(ENV_ID))
    # Create the callback that will periodically evaluate and report the performance.
    eval_callback = TrialEvalCallback(
        eval_env, trial, n_eval_episodes=N_EVAL_EPISODES, eval_freq=EVAL_FREQ, deterministic=True
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Free memory, whether or not training succeeded.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    # The callback only flags pruning; the exception has to be raised from the
    # objective itself so that Optuna records the trial as pruned.
    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Script entry point: run the study, save it, and print the best trial.
    # Local import: only this entry point needs filesystem helpers.
    from pathlib import Path

    # Set pytorch num threads to 1 for faster training.
    torch.set_num_threads(1)

    # Pass the seed to the sampler explicitly: TPESampler takes its own `seed`
    # argument, so seeding NumPy globally does not make its suggestions repeatable.
    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, seed=random_seed)
    # Do not prune before 1/3 of the max budget is used.
    pruner = MedianPruner(n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3)

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, timeout=600)
    except KeyboardInterrupt:
        # A manual interrupt ends the search early; the trials finished so far
        # are still saved and reported below.
        pass

    if study_save_path is not None:
        # Create the target directory first so that a missing folder cannot
        # discard the results of the whole optimisation run.
        Path(study_save_path).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(study, study_save_path)

    print("Number of finished trials: ", len(study.trials))

    try:
        # Raises ValueError when no trial has completed (for example after an
        # immediate interrupt, or when every trial was pruned or failed).
        trial = study.best_trial
    except ValueError:
        print("No completed trial to report.")
    else:
        print("Best trial:")

        print("  Value: ", trial.value)

        print("  Params: ")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")

        print("  User attrs:")
        for key, value in trial.user_attrs.items():
            print(f"    {key}: {value}")

In [None]:
# Notebook display: the sampled parameters of the study's best trial.
study.best_trial.params

In [None]:
# Plot the study's optimization history, write it to an image file and display it.
# NOTE(review): the file name hard-codes "frozen_lake" while ENV_ID is currently
# "MountainCar-v0" — confirm the intended name before relying on the saved image.
# Presumably ../images must already exist; verify before running.
fig = optuna.visualization.plot_optimization_history(study)
fig.write_image(f"../images/{policy}_optuna_opt_frozen_lake_default.png")
fig.show()

In [None]:
# Notebook display: plot the intermediate values reported by the trials of the study.
optuna.visualization.plot_intermediate_values(study)