In [1]:
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances

import torch
import torch.nn as nn

from stable_baselines3 import PPO

import gym

#### Optuna

In [20]:
n_trials = 25
n_jobs = 1
n_startup_trials = 5
n_evaluations = 8
n_timesteps = 1000000
eval_freq = int(n_timesteps / n_evaluations)
n_eval_envs = 5
n_eval_episodes = 10

env_id = "LunarLander-v2"


default_hyperparams = {
    "policy": "MlpPolicy",
    "env": env_id,
    "n_steps": 1024,
    "batch_size": 256,
    "n_epochs": 4,
}

In [21]:
def sample_ppo_params(trial: optuna.Trial) -> dict:
    """
    return: The sampled hyperparameters for the given trial.
    """

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64, 128])
    gamma = 1 - trial.suggest_float("gamma", 1e-5, 1e-1, log=True)
    gae_lambda = 1 - trial.suggest_float("gae_lambda", 1e-5, 1e-1, log=True)
    ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-1, log=True)

    
    return {
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "gamma": gamma,
        "gae_lambda": gae_lambda,
        "ent_coef": ent_coef,
    }


In [22]:
from stable_baselines3.common.callbacks import EvalCallback

class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting a trial.
    
    :param eval_env: Evaluation environement
    :param trial: Optuna trial object
    :param n_eval_episodes: Number of evaluation episodes
    :param eval_freq:   Evaluate the agent every ``eval_freq`` call of the callback.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic policy.
    :param verbose:
    """

    def __init__(
        self,
        eval_env: gym.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):

        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Evaluate policy (done in the parent class)
            super()._on_step()
            self.eval_idx += 1
            # Send report to Optuna
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if need
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True

In [23]:
from stable_baselines3.common.env_util import make_vec_env

def objective(trial: optuna.Trial) -> float:
    """
    return: The mean reward of the evaluated agent.
    """

    hyperparams = default_hyperparams.copy()
    hyperparams.update(sample_ppo_params(trial))
    
    model = PPO(**hyperparams)
    
    eval_envs = make_vec_env(env_id, n_envs=n_eval_envs)
    
    eval_callback = TrialEvalCallback(
        eval_env=eval_envs,
        trial=trial,
        n_eval_episodes=n_eval_episodes,
        eval_freq=eval_freq,
        verbose=1,
    )

    nan_encountered = False
    try:
        # Train the model
        model.learn(n_timesteps, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN
        print(e)
        nan_encountered = True
    finally:
        # Free memory
        model.env.close()
        eval_envs.close()

    # Tell the optimizer that the trial failed
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward    

In [None]:
torch.set_num_threads(4)

sampler = TPESampler(n_startup_trials=n_startup_trials)
pruner = MedianPruner(n_startup_trials=n_startup_trials, n_warmup_steps=n_evaluations // 3)

study = optuna.create_study(
    sampler=sampler,
    pruner=pruner,
    direction="maximize",
)


try:
    study.optimize(
        objective,
        n_trials=n_trials,
        n_jobs=n_jobs,
    )
except KeyboardInterrupt:
    pass


print("Number of finished trials: ", len(study.trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

print("  User attrs:")
for key, value in trial.user_attrs.items():
    print("    {}: {}".format(key, value))

    
# Write report
study.trials_dataframe().to_csv("study_results_a2c_cartpole.csv")

fig1 = plot_optimization_history(study)
fig2 = plot_param_importances(study)

fig1.show()
fig2.show()

#### Training

In [None]:
torch.set_num_threads(4)

env = gym.make("LunarLander-v2")

model = PPO(
    policy="MlpPolicy",
    env=env,
    n_steps=1024,
    batch_size=16,
    n_epochs=4,
    learning_rate=0.0006877588267892911,
    gamma=0.0103131621676421333,
    gae_lambda=1.9666131053362384e-05,
    ent_coef=0.00020273282332977947,
    verbose=1,
)

model.learn(10_000_000)

model_name = "ppo_lunarlander-v2"
model.save(model_name)

In [2]:
model = PPO.load("ppo_lunarlander-v2.zip")

In [3]:
from stable_baselines3.common.evaluation import evaluate_policy

eval_env = gym.make("LunarLander-v2")

mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)
print(f"mean_reward:{mean_reward:.2f} +/- {std_reward}")



mean_reward:-20.97 +/- 156.19383567103202
