In [1]:
import gym
import numpy as np

from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

In [2]:
env = gym.make("Pendulum-v0")

# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

In [3]:
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("ddpg_pendulum")
env = model.get_env()

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.38e+03 |
| time/              |           |
|    episodes        | 10        |
|    fps             | 49        |
|    time_elapsed    | 40        |
|    total_timesteps | 2000      |
| train/             |           |
|    actor_loss      | 55.3      |
|    critic_loss     | 0.256     |
|    learning_rate   | 0.001     |
|    n_updates       | 1800      |
----------------------------------
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.12e+03 |
| time/              |           |
|    episodes        | 20        |
|    fps             | 39        |
|    time_elapsed    | 100       |
|    total_timesteps | 4000      |
| train/             |           |
|    actor_loss      | 83.9    

In [4]:
scores = []
for i in range(100):
    obs = env.reset()
    score = 0
    done = False
    
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        score += reward
    
    scores.append(score)

scores

[array([-124.33081], dtype=float32),
 array([-126.95367], dtype=float32),
 array([-116.89191], dtype=float32),
 array([-321.63882], dtype=float32),
 array([-239.00049], dtype=float32),
 array([-312.9167], dtype=float32),
 array([-1.656664], dtype=float32),
 array([-1.5864592], dtype=float32),
 array([-1.4678514], dtype=float32),
 array([-121.05871], dtype=float32),
 array([-345.7332], dtype=float32),
 array([-118.3028], dtype=float32),
 array([-234.81154], dtype=float32),
 array([-126.44826], dtype=float32),
 array([-223.98578], dtype=float32),
 array([-120.044075], dtype=float32),
 array([-4.2052064], dtype=float32),
 array([-120.47144], dtype=float32),
 array([-122.239876], dtype=float32),
 array([-123.69418], dtype=float32),
 array([-115.3385], dtype=float32),
 array([-225.2278], dtype=float32),
 array([-4.332299], dtype=float32),
 array([-2.4494426], dtype=float32),
 array([-119.09205], dtype=float32),
 array([-238.41634], dtype=float32),
 array([-120.038574], dtype=float32),
 arra