# Proximal Policy Optimization

This notebook uses Proximal Policy Optimization (PPO), an actor-critic method, to train an agent that plays Pong. To visualize the training metrics I used `wandb`, and the PPO implementation comes from `stable_baselines3`.

In [None]:
import wandb
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

import ale_py

In [None]:
# Register the environments shipped with ale_py (the Atari games) with
# gymnasium so that Atari environment IDs can be created later on.
gym.register_envs(ale_py)

# Experiment settings: read by the W&B run, the environment factory and the
# training call further down in the notebook.
config = dict(
    env_name="PongNoFrameskip-v4",
    num_envs=8,
    total_timesteps=10_000_000,
    seed=42,
)

In [None]:
# Start the Weights & Biases run that tracks this experiment.
run = wandb.init(
    project="PPO_Pong",
    config=config,          # store the experiment settings with the run
    sync_tensorboard=True,  # forward the TensorBoard logs written by SB3
    monitor_gym=True,       # upload environment videos recorded during the run
    save_code=True,         # keep a copy of the code alongside the run
)

This part is important: 4 consecutive frames are stacked and passed to the model as a single input. This allows the agent to analyze the motion of both the opponent and the ball and to prepare a counter.

In [None]:
# Build the vectorized Atari environment (PongNoFrameskip-v4) from the
# settings in `config`, then add frame stacking on top of it.
env = make_atari_env(
    config["env_name"],
    n_envs=config["num_envs"],
    seed=config["seed"],
)

print("Environment Action Space: ", env.action_space.n)

# Stack the 4 most recent frames into a single observation.
env = VecFrameStack(env, n_stack=4)

In [None]:
# PPO agent with a CNN policy operating on the stacked frames.
model = PPO(
    policy="CnnPolicy",
    env=env,
    # Rollout collection and optimisation schedule.
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    learning_rate=2.5e-4,
    # Return / advantage estimation.
    gamma=0.99,
    gae_lambda=0.9,
    # Loss shaping and gradient clipping.
    clip_range=0.1,
    ent_coef=0.01,
    vf_coef=0.5,
    max_grad_norm=0.5,
    # Logging.
    tensorboard_log="runs",
    verbose=1,
)

## The Training Script

In [None]:
# Video recorder for WandB integration/validation recording: a clip of up to
# 2000 steps is started whenever the recorder's step counter is a multiple of
# 100000.
env = VecVideoRecorder(
    env,
    "videos",
    record_video_trigger=lambda x: x % 100000 == 0,
    video_length=2000,
)

# The model was constructed with the un-wrapped environment, so it has to be
# pointed at the recording wrapper explicitly. Without this, training never
# steps through the recorder and no video is ever written.
model.set_env(env)

# Main training script - I recommend (and personally did) running this on a
# GPU. The free T4 GPU on Colab is a great option!
model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        # Logs gradients to W&B and stores the model under models/<run id>.
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        # Periodic checkpoints in ./pong. NOTE: per the SB3 docs, save_freq
        # counts callback calls, and each call advances every parallel
        # environment by one step.
        CheckpointCallback(
            save_freq=10000,
            save_path="./pong",
            name_prefix=config["env_name"],
        ),
    ],
)

# Final weights after the full training budget.
model.save("ppo-PongNoFrameskip-v4.zip")

## View Model Performance

In [None]:
# Load the saved model. A forward slash is used as separator: it works on
# every OS, whereas "pong\P..." only worked on Windows and relied on "\P" being
# an unrecognised escape sequence (which newer Python versions warn about).
model = PPO.load("pong/PongNoFrameskip-v4_5680000_steps.zip")

# Create and wrap the environment the same way as for training
# (single environment, 4 stacked frames).
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=53)
env = VecFrameStack(env, n_stack=4)

# Run the model; `finally` makes sure the environment (and its render window)
# is released even if the cell is interrupted.
try:
    obs = env.reset()
    for _ in range(5000):
        action, _states = model.predict(obs)
        # A VecEnv resets a finished episode on its own and already returns
        # the first observation of the next one, so no manual reset is needed
        # when `dones` is set.
        obs, rewards, dones, info = env.step(action)
        env.render("human")
finally:
    env.close()

## Save as Video
This cell records the trained model playing and saves the video in a `videos` folder.

In [None]:
import os
import numpy as np

# Ensure the videos directory exists
os.makedirs("videos", exist_ok=True)

# Load the trained model
model = PPO.load("pong/PongNoFrameskip-v4_5680000_steps.zip")

# Create and wrap the environment (single environment, 4 stacked frames)
env = make_atari_env("PongNoFrameskip-v4", n_envs=1, seed=153)
env = VecFrameStack(env, n_stack=4)

# Wrap with video recorder
env = VecVideoRecorder(
    env,
    video_folder="videos",                        # save videos here
    record_video_trigger=lambda step: step == 0,  # record first rollout
    video_length=3000,                            # number of steps to record
    name_prefix="ppo_pong_eval",                  # filename prefix
)

# Run the model
try:
    obs = env.reset()
    for _ in range(3000):
        action, _ = model.predict(obs, deterministic=True)
        # A VecEnv resets a finished episode automatically, so the returned
        # observation already belongs to the next episode. No manual
        # env.reset() here: it would throw that observation away and, in some
        # stable-baselines3 releases, restart the recording mid-run.
        obs, _, _, _ = env.step(action)
finally:
    # Important: this finalizes and writes the MP4, even if the loop above is
    # interrupted.
    env.close()

print("Video saved in ./videos/")