In [1]:
!pip install stable-baselines3
!pip install gymnasium
!pip install flappy-bird-gymnasium
!pip install imageio
!pip install pygame

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.13->stable-baselines3)
  Downloading nvidia_cufft_cu12-11.2.

In [5]:
import os
import argparse
from pathlib import Path
import warnings
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, StopTrainingOnRewardThreshold
from gym.wrappers import RecordVideo

os.environ["XDG_RUNTIME_DIR"] = "/tmp"
warnings.filterwarnings("ignore")

def ensure_dir(path):
    """Create the directory ``path`` together with any missing parents.

    Does nothing if the directory already exists. Returns None.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)

def default_output_dir(save_dir):
    """Return ``<save_dir>/flappy_ppo`` as a Path, creating it on disk first."""
    target = Path(save_dir).joinpath("flappy_ppo")
    ensure_dir(target)
    return target

def make_flappy_env(render_mode=None, seed=0):
    """Return a zero-argument factory that builds one seeded FlappyBird env.

    DummyVecEnv takes callables rather than env instances, hence the closure.

    Args:
        render_mode: forwarded to ``gym.make`` (e.g. ``None`` or ``"rgb_array"``).
        seed: seed passed to the env's first ``reset``.

    Returns:
        A callable that creates, resets and returns the env.
    """
    def _init():
        # The "module:" prefix makes gymnasium import flappy_bird_gymnasium before
        # looking up the id. That import is what registers "FlappyBird-v0" (per the
        # package's README usage). Without it, a fresh kernel raises NameNotFound
        # unless some other cell happened to import the package first.
        env = gym.make("flappy_bird_gymnasium:FlappyBird-v0", render_mode=render_mode)
        env.reset(seed=seed)
        return env
    return _init

def make_vec_envs(n_envs=1, seed=0):
    """Build ``n_envs`` FlappyBird envs in a DummyVecEnv wrapped by VecMonitor.

    Env ``k`` is seeded with ``seed + k`` so parallel copies differ. VecMonitor
    records episode returns and lengths for logging.
    """
    factories = [make_flappy_env(seed=seed + offset) for offset in range(n_envs)]
    return VecMonitor(DummyVecEnv(factories))

def parse_args(argv=None):
    """Parse the training options.

    Args:
        argv: list of CLI tokens to parse. If None, the notebook default
            ``["--render_video"]`` is used. The real command line is not read,
            presumably because inside Jupyter ``sys.argv`` holds the kernel's own
            flags, which argparse would reject.

    Returns:
        argparse.Namespace with ``timesteps``, ``save_dir``, ``reward_threshold``
        and ``render_video``.
    """
    if argv is None:
        argv = ["--render_video"]

    p = argparse.ArgumentParser()
    p.add_argument("--timesteps", type=int, default=1_000_000)
    p.add_argument("--save_dir", type=str, default="outputs")
    p.add_argument("--reward_threshold", type=float, default=100.0)
    p.add_argument("--render_video", action="store_true")
    return p.parse_args(args=argv)

def _train(args, out_dir, log_dir):
    """Train PPO on FlappyBird with periodic checkpoints and evaluation, then save the final model.

    Writes ``final_model.zip``, ``best_model/`` (after the first evaluation),
    ``checkpoints/``, ``eval_logs/`` and TensorBoard logs under ``out_dir``.
    Both vectorized envs are closed even if training raises.
    """
    train_env = make_vec_envs(1)
    eval_env = make_vec_envs(1, seed=42)
    try:
        model = PPO(
            "MlpPolicy",
            train_env,
            verbose=1,
            tensorboard_log=str(log_dir),
            # SB3 falls back to CPU when CUDA is unavailable (the run log shows "Using cpu device").
            device="cuda",
            # NOTE(review): 10000 is not a multiple of 512, so each epoch ends with a
            # truncated minibatch. SB3 warns about this, but the module-level
            # warnings filter hides the warning.
            n_steps=10000,
            batch_size=512,
        )

        # Saves a snapshot every `save_freq` callback calls (env steps, with one env).
        # It is passed to learn() directly. As EvalCallback's `callback_on_new_best`,
        # its call counter only advanced on each new best model, so save_freq=100_000
        # effectively never fired.
        checkpoint_callback = CheckpointCallback(
            save_freq=100_000,
            save_path=str(out_dir / "checkpoints"),
            name_prefix="ppo_flappy",
        )

        # Runs after every evaluation and stops training once the eval reward
        # reaches args.reward_threshold.
        stop_callback = StopTrainingOnRewardThreshold(
            reward_threshold=args.reward_threshold,
            verbose=1,
        )

        eval_callback = EvalCallback(
            eval_env,
            best_model_save_path=str(out_dir / "best_model"),
            log_path=str(out_dir / "eval_logs"),
            eval_freq=50_000,
            n_eval_episodes=5,
            deterministic=True,
            render=False,
            callback_after_eval=stop_callback,
        )

        model.learn(total_timesteps=args.timesteps, callback=[checkpoint_callback, eval_callback])
        model.save(str(out_dir / "final_model.zip"))
    finally:
        train_env.close()
        eval_env.close()


def _record_video(out_dir, max_steps=10_000):
    """Roll out one deterministic episode with the best (or final) model and record it.

    The rollout is capped at ``max_steps`` steps so a policy that never dies
    cannot loop forever.
    """
    # Local import: the file-level `from gym.wrappers import RecordVideo` takes the
    # wrapper from the legacy `gym` package, which does not match the `gymnasium`
    # env created below.
    # NOTE(review): gymnasium's RecordVideo needs moviepy to write video files, and
    # moviepy is not in the pip cell. Confirm it is installed.
    from gymnasium.wrappers import RecordVideo

    video_path = out_dir / "flappy_video"
    ensure_dir(video_path)

    # best_model.zip only exists after EvalCallback has evaluated at least once
    # (eval_freq steps). Fall back to the final model for shorter runs.
    best_model_path = out_dir / "best_model" / "best_model.zip"
    model_path = best_model_path if best_model_path.exists() else out_dir / "final_model.zip"
    model = PPO.load(str(model_path))

    # Same module-prefixed id as make_flappy_env, so the env is registered on a fresh kernel.
    env = gym.make("flappy_bird_gymnasium:FlappyBird-v0", render_mode="rgb_array")
    env = RecordVideo(env, str(video_path), episode_trigger=lambda episode_id: True)
    try:
        obs, _ = env.reset(seed=0)
        for _step in range(max_steps):
            action, _state = model.predict(obs, deterministic=True)
            # predict returns an array; the env's step expects a plain int action.
            obs, _reward, terminated, truncated, _info = env.step(int(action))
            if terminated or truncated:
                break
    finally:
        # Close even on error so the wrapper can finish the recording and release the env.
        env.close()


def main():
    """Entry point: train PPO on FlappyBird and optionally record a video of the result.

    Options come from parse_args(). Artifacts go under ``<save_dir>/flappy_ppo``.
    """
    args = parse_args()
    out_dir = default_output_dir(args.save_dir)
    log_dir = out_dir / "logs"
    ensure_dir(log_dir)

    _train(args, out_dir, log_dir)

    if args.render_video:
        _record_video(out_dir)


if __name__ == "__main__":
    main()

Using cpu device
Logging to outputs/flappy_ppo/logs/PPO_4
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -7.47    |
| time/              |          |
|    fps             | 436      |
|    iterations      | 1        |
|    time_elapsed    | 22       |
|    total_timesteps | 10000    |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -6.276001   |
| time/                   |             |
|    fps                  | 422         |
|    iterations           | 2           |
|    time_elapsed         | 47          |
|    total_timesteps      | 20000       |
| train/                  |             |
|    approx_kl            | 0.015574029 |
|    clip_fraction        | 0.0905      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.679      |
|    explained