In [1]:
import torch
import numpy as np
import gymnasium as gym
from matplotlib import pyplot as plt
from IPython.display import clear_output
from stable_baselines3 import TD3

In [2]:
env_string = "LunarLander-v2"

In [3]:
def play(env, policy, steps=1000, render_every=4):
    state, info = env.reset()
    done = False
    step = 0
    rewards = []
    while not done and step < steps:
        if step % render_every == 0:
            clear_output(wait=True)
            plt.imshow(env.render())
            plt.show()
        action = policy(torch.tensor(state, dtype=torch.float32))
        state, reward, done, _, _ = env.step(action)
        rewards.append(reward)
        step += 1
    env.close()
    return np.sum(rewards)

In [4]:
env = gym.make(env_string, render_mode="rgb_array", continuous=True)

In [5]:
# create ddpg stable baselines agent
agent = TD3(
    "MlpPolicy",
    env,
    verbose=1,
)
# agent.target_policy_noise = 0.1

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [9]:
# play
agent.learn(40000, progress_bar=False)

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 999      |
|    ep_rew_mean     | -1.07    |
| time/              |          |
|    episodes        | 4        |
|    fps             | 94       |
|    time_elapsed    | 42       |
|    total_timesteps | 3996     |
| train/             |          |
|    actor_loss      | 0.256    |
|    critic_loss     | 7.71e-05 |
|    learning_rate   | 0.001    |
|    n_updates       | 43956    |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 999      |
|    ep_rew_mean     | -0.66    |
| time/              |          |
|    episodes        | 8        |
|    fps             | 72       |
|    time_elapsed    | 110      |
|    total_timesteps | 7992     |
| train/             |          |
|    actor_loss      | 0.272    |
|    critic_loss     | 5.67e-05 |
|    learning_rate   | 0.001    |
|    n_updates       | 47952    |
--------------

<stable_baselines3.td3.td3.TD3 at 0x7f9dc1e99df0>

In [8]:
policy = lambda x: agent.predict(x, deterministic=True)[0]
play(env, policy, steps=800, render_every=4)

KeyboardInterrupt: 