In [1]:
import gymnasium
import math
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium import logger, spaces
import numpy as np
from gymnasium.envs.registration import register


In [16]:
class SwingPole(CartPoleEnv):
    """https://gymnasium.farama.org/environments/classic_control/cart_pole/"""
    def __init__(self, render_mode: str | None = None):
        super().__init__(render_mode)
        self.x_threshold = 3
        self.step_num = 0
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                np.finfo(np.float32).max,
                # 4*math.pi,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
    
    def step(self, action):
        obs, rew, _, _, info =  super().step(action)
        self.step_num += 1
        term = False
        x, x_dot, theta, theta_dot = self.state
        term = bool(
            x < -self.x_threshold
            or x > self.x_threshold
            or self.step_num > 1000
            or abs(theta_dot) < 0.01
        )
        rew = theta_dot

        return obs, rew, term, False, info

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        self.step_num = 0
        return super().reset(seed=seed, options=options)        

In [17]:
register(
     id="SwingPole",
     entry_point=SwingPole,
)

env = gymnasium.make("SwingPole", render_mode='rgb_array')

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [4]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

vec_env = make_vec_env(SwingPole, n_envs=10)
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./ppo_swingpole_tensorboard/")
model.learn(total_timesteps=10000, tb_log_name="PPO_SwingPole")
model.save("ppo_swingpole")

2025-07-10 17:38:52.089600: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1752169132.102227    2388 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1752169132.106061    2388 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1752169132.115576    2388 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752169132.115589    2388 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1752169132.115591    2388 computation_placer.cc:177] computation placer alr

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to ./ppo_swingpole_tensorboard/PPO_SwingPole_2


  logger.warn(


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 133      |
|    ep_rew_mean     | 686      |
| time/              |          |
|    fps             | 976      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 150         |
|    ep_rew_mean          | 401         |
| time/                   |             |
|    fps                  | 799         |
|    iterations           | 2           |
|    time_elapsed         | 5           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008893371 |
|    clip_fraction        | 0.0394      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.685      |
|    explained_variance   | -0.00775    |
|    learning_rate        | 0.

In [19]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecVideoRecorder
vec_env = make_vec_env(SwingPole, n_envs=1, env_kwargs={"render_mode": "rgb_array"})
vec_env.render_mode = "rgb_array"
vec_env = VecVideoRecorder(
    venv=vec_env,
    video_folder="videos",
    record_video_trigger=lambda x: True,
    video_length=200,
    name_prefix="ppo_swingpole",
)
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=2)


AssertionError: 

In [None]:
from ray.rllib.algorithms.ppo import PPOConfig

# Build an RLlib PPO trainer for SwingPole
rllib_trainer = (
    PPOConfig()
    .environment(env=SwingPole, env_config={'render_mode': 'rgb_array'})
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .build()
)

# Training loop
for i in range(10):
    result = rllib_trainer.train()
    print(f"Iteration {i}: mean reward = {result['episode_reward_mean']}")

# Save a checkpoint
ckpt_path = rllib_trainer.save()
print("Checkpoint saved at:", ckpt_path)

In [None]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms import PPO
from ray.tune.logger import pretty_print


alg = (
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .framework('tf2')
    #  # this is a simple way to change the default ANN. Can also pass a custom model object instead. 
    # # By default, Ray will look at your obs and actions and use a reasonable ANN--ie a small dense network for vector inputs, a small CNN for image inputs
    # .training(model={'fcnet': [64, 64]})
    .resources(num_gpus=0)
    .environment(env=SwingPole, render_env=True, env_config={'render_mode': 'rgb_array'})
    .build()
)
iterations = 10
for i in range(iterations):
    result = alg.train()
    # print(pretty_print(result))

    # # This could be used as a very simple way to save checkpoints in lieu of telling Ray to handle it with configs
    # if i % 5 == 0:
    #     checkpoint_dir = algo.save().checkpoint.path
    #     print(f"Checkpoint saved in directory {checkpoint_dir}")

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-07-25 17:27:03,619	INFO worker.py:1762 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
[36m(RolloutWorker pid=146542)[0m   logger.warn(


In [None]:
import moviepy as mpy

episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
img_list = []
while not terminated and not truncated:
    action = alg.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    img_list.append(env.render())

print('episode reward was: ', episode_reward)
clip = mpy.ImageSequenceClip(img_list, fps=30)
clip.write_videofile('cartswing.mp4', logger=None)

episode reward was:  2574.0651628366572


PPO is not deterministic here, the same obs is used each action compute but it does not always choose the same direction to push the cart.

In [7]:
obs, info = env.reset()
for i in range(100):
    print("observation used", obs)
    action = alg.compute_single_action(obs)
    print("action chosen", action)

observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 1
observation used [-0.01838353 -0.03525767  0.00583611 -0.03221606]
action chosen 0
obse