In [1]:
import gymnasium
import math
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium import logger, spaces
import numpy as np
from gymnasium.envs.registration import register


2024-07-17 20:40:23.182372: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-17 20:40:23.182405: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-17 20:40:23.183151: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-17 20:40:23.188097: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class SwingPole(CartPoleEnv):
    """Cart-pole variant whose reward is the pole's angular velocity.

    Built on Gymnasium's CartPole
    (https://gymnasium.farama.org/environments/classic_control/cart_pole/),
    with these differences:

    * the cart may travel within ``[-3, 3]`` instead of the parent's limit;
    * the parent's pole-angle failure check is disabled;
    * the per-step reward is ``theta_dot`` (signed angular velocity);
    * an episode terminates when the cart leaves the track or the pole's
      angular speed drops below ``MIN_SPIN_RATE``;
    * an episode is truncated once more than ``MAX_STEPS`` steps were taken.
    """

    # Step budget; exceeding it truncates (does not terminate) the episode.
    MAX_STEPS = 1000
    # Angular speed below which the pole is treated as stalled.
    MIN_SPIN_RATE = 0.01

    def __init__(self, render_mode: str | dict | None = None):
        """Create the environment.

        Args:
            render_mode: A Gymnasium render mode string or ``None``. A mapping
                is also accepted, in which case its ``"render_mode"`` entry is
                used; this supports frameworks that call the constructor with
                a single config dict (RLlib does so with ``env_config``).
        """
        if isinstance(render_mode, dict):
            render_mode = render_mode.get("render_mode")
        # Keyword form on purpose: newer Gymnasium releases put another
        # parameter ahead of ``render_mode`` in CartPoleEnv.__init__.
        super().__init__(render_mode=render_mode)

        self.x_threshold = 3
        # Disable the parent's balance-failure check so its own termination
        # bookkeeping does not trip while the pole swings past the upright band.
        self.theta_threshold_radians = math.inf
        self.step_num = 0

        unbounded = np.finfo(np.float32).max
        high = np.array(
            [
                self.x_threshold * 2,
                unbounded,
                unbounded,  # the angle is not wrapped, so it is left unbounded
                unbounded,
            ],
            dtype=np.float32,
        )

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)

    def step(self, action):
        """Advance the parent's physics one step and apply SwingPole's rules.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``obs`` and
            ``info`` come from the parent, ``reward`` is ``theta_dot``,
            ``terminated`` reports cart-out-of-bounds or a stalled pole, and
            ``truncated`` reports that the step budget was exceeded.
        """
        # The parent's reward and termination flags are intentionally ignored.
        obs, _, _, _, info = super().step(action)
        self.step_num += 1

        x, _, _, theta_dot = self.state
        terminated = bool(
            abs(x) > self.x_threshold
            or abs(theta_dot) < self.MIN_SPIN_RATE
        )
        # Running out of steps is a time limit, not a failure state, so it is
        # reported through the ``truncated`` flag.
        truncated = self.step_num > self.MAX_STEPS

        return obs, float(theta_dot), terminated, truncated, info

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        """Reset the step counter and delegate the state reset to the parent."""
        self.step_num = 0
        return super().reset(seed=seed, options=options)

IndentationError: expected an indented block after class definition on line 1 (1112221372.py, line 2)

In [3]:
# Expose the custom environment under a name Gymnasium can look up.
register(id="SwingPole", entry_point=SwingPole)

# Local instance used for the rollout/rendering cells further down; episodes
# are capped at 1000 steps and rendering is requested in "rgb_array" mode.
env = gymnasium.make(
    "SwingPole",
    render_mode="rgb_array",
    max_episode_steps=1000,
)

In [5]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms import PPO
from ray.tune.logger import pretty_print


# Configure PPO on the SwingPole class and build the algorithm in one chain.
alg = (
    PPOConfig()
    .environment(
        env=SwingPole,
        env_config={"render_mode": "rgb_array"},
        render_env=True,
    )
    .framework("tf2")
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    # Optional: override the default network instead of letting Ray pick one
    # from the observation/action spaces (small dense net for vector inputs,
    # small CNN for images). A custom model object can be passed as well.
    # .training(model={'fcnet_hiddens': [64, 64]})
    .build()
)

# Run a fixed number of training iterations; only the latest result is kept.
iterations = 10
for i in range(iterations):
    result = alg.train()
    # print(pretty_print(result))

    # # Simple manual checkpointing, in lieu of configuring Ray to handle it:
    # if i % 5 == 0:
    #     checkpoint_dir = alg.save().checkpoint.path
    #     print(f"Checkpoint saved in directory {checkpoint_dir}")

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
  self.pid = _fork_exec(
  self.pid = _fork_exec(
2024-04-19 16:16:14,545	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
[36m(pid=62404)[0m 2024-04-19 16:16:15.988863: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable

In [6]:
import moviepy.editor as mpy

# Roll out one episode with the trained policy, collecting a frame per step.
obs, info = env.reset()
img_list = []
episode_reward = 0
terminated = truncated = False

while not (terminated or truncated):
    action = alg.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward
    # Frame rendered after the step that was just taken.
    img_list.append(env.render())

print('episode reward was: ', episode_reward)

# Stitch the collected frames into a video file.
clip = mpy.ImageSequenceClip(img_list, fps=30)
clip.write_videofile('cartswing.mp4', logger=None)

2024-04-19 16:17:51.639447: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: cond/then/_0/cond/cond/branch_executed/_61


episode reward was:  281.3212336766097


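PPO is not deterministic here: the same observation is passed to every `compute_single_action` call, yet the policy does not always push the cart in the same direction. This is expected, because by default the action is sampled from the policy's action distribution (exploration is enabled); passing `explore=False` should make the choice deterministic.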

In [7]:
# Query the policy 100 times with one fixed observation (the environment is
# never stepped) to see whether the chosen action varies between calls.
obs, info = env.reset()
for i in range(100):
    print("observation used", obs)
    print("action chosen", action := alg.compute_single_action(obs))

observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 1
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 1
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 1
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
observation used [-0.004715    0.022643    0.0129916  -0.03718836]
action chosen 0
obse