In [None]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print

import gymnasium as gym

import wandb
from loguru import logger

In [None]:
# Open the Weights & Biases run first; sync_tensorboard=True asks wandb to
# mirror TensorBoard event files into the "ray_test" project.
wandb.init(project="ray_test", sync_tensorboard=True)

# Assemble the PPO configuration step by step, then build the algorithm.
ppo_config = PPOConfig()
ppo_config = ppo_config.rollouts(num_rollout_workers=8)  # 8 rollout workers
ppo_config = ppo_config.resources(num_gpus=0)            # no GPUs requested
ppo_config = ppo_config.environment(env="CartPole-v1")

algo = ppo_config.build()

In [None]:
# Train the PPO algorithm for a fixed number of iterations, saving a
# checkpoint every CHECKPOINT_EVERY completed iterations.
NUM_TRAINING_ITERATIONS = 20
CHECKPOINT_EVERY = 5

logger.info("starting training")

for i in range(NUM_TRAINING_ITERATIONS):
    logger.info(f"training step {i}")
    result = algo.train()

    # Count *completed* iterations (i + 1): checkpoints land after iterations
    # 5, 10, 15 and 20, so the final trained state is always saved. The
    # previous `i % 5 == 0` test saved after iterations 1, 6, 11 and 16 and
    # never captured the last four iterations of training.
    if (i + 1) % CHECKPOINT_EVERY == 0:
        checkpoint_dir = algo.save().checkpoint.path
        logger.debug(f"Checkpoint saved in directory {checkpoint_dir}")

logger.success("succesfully trained network")

In [None]:
import numpy as np
import cv2

# Roll out one episode with the trained policy and record it to a video file.
VIDEO_PATH = "test.webm"
VIDEO_FPS = 25
MAX_RENDER_STEPS = 500

env = gym.make("CartPole-v1", render_mode="rgb_array")

obs, info = env.reset()

logger.debug("started rendering")

frames = []          # RGB frames exactly as returned by env.render()
size = None          # (width, height), filled in from the first rendered frame
video_writer = None  # created lazily once the real frame size is known

try:
    for i in range(MAX_RENDER_STEPS):
        action = algo.compute_single_action(
            observation=obs,
        )
        obs, reward, done, truncated, _ = env.step(action)

        image_array = np.asanyarray(env.render(), dtype=np.uint8)

        if video_writer is None:
            # Derive the writer size from the frame instead of hard-coding it;
            # cv2.VideoWriter takes (width, height) while arrays are (height, width, 3).
            height, width = image_array.shape[:2]
            size = (width, height)
            video_writer = cv2.VideoWriter(
                VIDEO_PATH,
                cv2.VideoWriter_fourcc(*'VP90'),
                VIDEO_FPS,
                size,
            )

        # The env is created with render_mode="rgb_array", but OpenCV writers
        # expect BGR channel order; convert so colours are not swapped.
        video_writer.write(cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR))
        frames.append(image_array)

        # Stop when the episode ends for either reason reported by env.step().
        if done or truncated:
            logger.success("done!")
            break
finally:
    # Always finalise the video file and free the environment, even if the
    # rollout raised part-way through.
    if video_writer is not None:
        video_writer.release()
    env.close()

In [None]:
# Upload the recorded rollout to the active W&B run, then close the run.
# The declared format now matches the file's actual container (.webm);
# it previously claimed "mp4" for a WebM file.
wandb.log({"video": wandb.Video("test.webm", format="webm")})
wandb.finish()