In [1]:
import optuna
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

In [4]:
def optimize_ppo(trial):
    # Define the hyperparameter search space
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2,log=True)
    n_steps = trial.suggest_int('n_steps', 2048, 4096)
    gamma = trial.suggest_float('gamma', 0.9, 0.9999)
    clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
    
    # Create the environment
    env = gym.make('CartPole-v1')
    
    # Create the PPO model with the hyperparameters
    model = PPO('MlpPolicy', env, learning_rate=learning_rate, n_steps=n_steps, gamma=gamma, clip_range=clip_range, verbose=0)
    
    # Train the model
    model.learn(total_timesteps=10000)
    
    # Evaluate the model
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=10)
    
    return mean_reward

In [5]:
study = optuna.create_study(direction='maximize')
study.optimize(optimize_ppo, n_trials=20)

[I 2024-05-20 23:32:01,591] A new study created in memory with name: no-name-f17f64a7-9333-44c0-8ebd-50668f2152ef
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=4021 and n_envs=1)
  if not isinstance(terminated, (bool, np.bool8)):
[I 2024-05-20 23:34:44,892] Trial 0 finished with value: 424.9 and parameters: {'learning_rate': 0.0011544728932953933, 'n_steps': 4021, 'gamma': 0.9655739078237566, 'clip_range': 0.23017312846065438}. Best is trial 0 with value: 424.9.
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=2679 and n_envs=1)
[I 2024-05-20 23:37:03,576] Trial 1 finished with value: 441.5 and parameters: {'learning_rate': 0.006648366083205107, 'n_steps': 2679, 'gamma': 0.9356482941434853, 'clip_range': 0.20748303001261154}. Best is trial 1 with value: 441.5.
We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.
Info: (n_steps=2578 and n_envs=1)
[I 2024-05-20 23:39:14,382] Trial 2 fi

KeyboardInterrupt: 

In [None]:
print('Best hyperparameters: ', study.best_params)