# Pong Policy Training Baseline with RLlib

This notebook implements a baseline policy training pipeline using Ray RLlib, training directly on visual observations without any encoder. This serves as a comparison baseline for the V-JEPA2 and RSSM approaches.

## Environment Notes

- **Google Colab**: keep the install cell and default config (2 CPUs, 1 GPU)
- **Local Apple Silicon (M1/M2)**: use the optimized configuration below (multi-core CPU, optional MPS acceleration)
- Adjust the configuration variables if you run on a different machine


In [2]:
# Install dependencies for Google Colab
# Note: Colab comes with gymnasium, but we need ray[rllib] and ale-py
%pip install -q "ray[rllib]" "gymnasium[atari]" ale-py

print("Dependencies installed successfully!")
print("Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training")


Note: you may need to restart the kernel to use updated packages.
Dependencies installed successfully!
Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training


In [3]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import ale_py
from dataclasses import dataclass
import matplotlib.pyplot as plt  # NEW

# register ALE envs (gymnasium)
gym.register_envs(ale_py)

# ==========================
#  Hyperparameters
# ==========================
ENV_ID = "ALE/Pong-v5"   # note: we use obs_type="ram" below
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_EPS = 0.2
LR = 2.5e-4
VALUE_COEF = 0.5
ENTROPY_COEF = 0.01

STEPS_PER_UPDATE = 2048     # how many env steps before each PPO update
PPO_EPOCHS = 4              # how many passes over the collected data
MINI_BATCH_SIZE = 256

MAX_UPDATES = 20000           # you can bump this later if you want
MODEL_PATH = "ppo_pong_ram.pt"  # NEW: file to save model


# ==========================
#  Actor-Critic Network
# ==========================
class ActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_size=128):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_size, act_dim)
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = self.shared(x)
        logits = self.policy_head(x)
        value = self.value_head(x).squeeze(-1)
        dist = torch.distributions.Categorical(logits=logits)
        return dist, value


@dataclass
class RolloutBatch:
    obs: torch.Tensor       # [N, obs_dim]
    actions: torch.Tensor   # [N]
    log_probs: torch.Tensor # [N]
    returns: torch.Tensor   # [N]
    advantages: torch.Tensor# [N]


# ==========================
#  GAE / advantage computation
# ==========================
def compute_gae(rewards, values, dones, gamma=GAMMA, lam=GAE_LAMBDA):
    """
    rewards, values, dones are 1D numpy arrays of length T
    values has an extra value for V(s_{T}) at the end (so len(values) = T+1).
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        mask = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * values[t + 1] * mask - values[t]
        gae = delta + gamma * lam * mask * gae
        advantages[t] = gae
    returns = advantages + values[:-1]
    return advantages, returns


# ==========================
#  Collect rollout
# ==========================
def collect_rollout(env, model, steps_per_update):
    obs_buf, act_buf, logp_buf = [], [], []
    rew_buf, done_buf, val_buf = [], [], []

    obs, _ = env.reset()
    obs = np.array(obs, dtype=np.float32)

    for _ in range(steps_per_update):
        obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            dist, value = model(obs_tensor)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        next_obs, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        next_obs = np.array(next_obs, dtype=np.float32)

        obs_buf.append(obs)
        act_buf.append(action.item())
        logp_buf.append(log_prob.item())
        rew_buf.append(reward)
        done_buf.append(done)
        val_buf.append(value.item())

        obs = next_obs
        if done:
            obs, _ = env.reset()
            obs = np.array(obs, dtype=np.float32)

    # final value for GAE bootstrap
    obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        _, last_value = model(obs_tensor)
    val_buf.append(last_value.item())

    # convert to numpy
    rewards = np.array(rew_buf, dtype=np.float32)
    dones = np.array(done_buf, dtype=np.bool_)
    values = np.array(val_buf, dtype=np.float32)

    advantages, returns = compute_gae(rewards, values, dones)

    # normalize advantages
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # to tensors
    obs_tensor = torch.tensor(np.array(obs_buf), dtype=torch.float32, device=DEVICE)
    actions_tensor = torch.tensor(np.array(act_buf), dtype=torch.long, device=DEVICE)
    logp_tensor = torch.tensor(np.array(logp_buf), dtype=torch.float32, device=DEVICE)
    adv_tensor = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
    ret_tensor = torch.tensor(returns, dtype=torch.float32, device=DEVICE)

    return RolloutBatch(
        obs=obs_tensor,
        actions=actions_tensor,
        log_probs=logp_tensor,
        returns=ret_tensor,
        advantages=adv_tensor,
    )


# ==========================
#  PPO Update
# ==========================
def ppo_update(model, optimizer, batch: RolloutBatch):
    N = batch.obs.size(0)
    idxs = np.arange(N)

    for _ in range(PPO_EPOCHS):
        np.random.shuffle(idxs)
        for start in range(0, N, MINI_BATCH_SIZE):
            end = start + MINI_BATCH_SIZE
            mb_idx = idxs[start:end]

            obs = batch.obs[mb_idx]
            actions = batch.actions[mb_idx]
            old_log_probs = batch.log_probs[mb_idx]
            returns = batch.returns[mb_idx]
            advantages = batch.advantages[mb_idx]

            dist, values = model(obs)
            new_log_probs = dist.log_prob(actions)
            entropy = dist.entropy().mean()

            # ratio = π_new / π_old
            ratio = torch.exp(new_log_probs - old_log_probs)

            # clipped surrogate objective
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1.0 - CLIP_EPS, 1.0 + CLIP_EPS) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            value_loss = (returns - values).pow(2).mean()

            loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


# ==========================
#  Training (with reward logging)
# ==========================
def train():
    # note render_mode=None for training
    env = gym.make(ENV_ID, obs_type="ram")
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    model = ActorCritic(obs_dim, act_dim, hidden_size=128).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    running_reward = None
    avg_returns_history = []   # NEW: per-update average "return" proxy
    running_history = []       # NEW: smoothed version

    for update in range(1, MAX_UPDATES + 1):
        batch = collect_rollout(env, model, STEPS_PER_UPDATE)

        # logging reward info from rollout
        approx_return = batch.returns.mean().item()
        if running_reward is None:
            running_reward = approx_return
        else:
            running_reward = 0.99 * running_reward + 0.01 * approx_return

        avg_returns_history.append(approx_return)
        running_history.append(running_reward)

        print(
            f"Update {update:4d} | "
            f"Approx return: {approx_return:6.2f} | "
            f"Running avg: {running_reward:6.2f}"
        )

        ppo_update(model, optimizer, batch)

    env.close()

    # save model
    torch.save(model.state_dict(), MODEL_PATH)
    print("Training finished. Model saved to", MODEL_PATH)

    return model, avg_returns_history, running_history

model, avg_hist, run_hist = train()

A.L.E: Arcade Learning Environment (version 0.11.0+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/lavan/miniconda3/envs/rl-pong/lib/python3.11/site-packages/ale_py/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -1654354918


Update    1 | Approx return:  -0.26 | Running avg:  -0.26
Update    2 | Approx return:  -1.20 | Running avg:  -0.27
Update    3 | Approx return:  -1.51 | Running avg:  -0.29
Update    4 | Approx return:  -1.67 | Running avg:  -0.30
Update    5 | Approx return:  -1.89 | Running avg:  -0.32
Update    6 | Approx return:  -2.02 | Running avg:  -0.33
Update    7 | Approx return:  -2.18 | Running avg:  -0.35
Update    8 | Approx return:  -2.25 | Running avg:  -0.37


KeyboardInterrupt: 

In [2]:
# ==========================
#  Plot rewards
# ==========================
# ==========================
#  Plot rewards (skip initial burn-in)
# ==========================
def plot_rewards(avg_returns_history, running_history, skip=300):
    """
    Plot rewards but skip the first `skip` updates (burn-in period).
    """
    import numpy as np
    plt.figure(figsize=(8, 4))

    n = len(avg_returns_history)
    skip = min(skip, n - 1) if n > 1 else 0  # safety

    # x-axis = update index (1-based), starting after burn-in
    xs = np.arange(skip, n)

    plt.plot(xs, avg_returns_history[skip:], label="Per-update avg return")
    plt.plot(xs, running_history[skip:], label="Smoothed return")

    plt.xlabel("Update")
    plt.ylabel("Return (approx)")
    plt.title(f"PPO on Pong (GT)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()



# ==========================
#  Watch trained agent
# ==========================
def watch_agent(model=None, n_episodes=3):
    """
    Renders the agent playing Pong.
    - If `model` is None, loads weights from MODEL_PATH.
    - Uses render_mode='human' (window) – good for local runs.
      In a notebook, switch to render_mode='rgb_array' and draw frames manually.
    """
    # for local desktop: render_mode="human"
    env = gym.make(ENV_ID, obs_type="ram", render_mode="human")

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    if model is None:
        model = ActorCritic(obs_dim, act_dim, hidden_size=128).to(DEVICE)
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    for ep in range(n_episodes):
        obs, _ = env.reset()
        obs = np.array(obs, dtype=np.float32)
        done = False
        ep_return = 0.0

        while not done:
            obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                dist, _ = model(obs_tensor)
                action = dist.probs.argmax(dim=-1)  # greedy for demo
            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ep_return += reward
            obs = np.array(next_obs, dtype=np.float32)

        print(f"Episode {ep + 1}: return = {ep_return}")

    env.close()


plot_rewards(avg_hist, run_hist)


NameError: name 'avg_hist' is not defined

In [12]:
# ==========================
#  Watch trained agent
# ==========================
def watch_agent(model=None, n_episodes=3):
    """
    Renders the agent playing Pong.
    - If `model` is None, loads weights from MODEL_PATH.
    - Uses render_mode='human' (window) – good for local runs.
      In a notebook, switch to render_mode='rgb_array' and draw frames manually.
    """
    # for local desktop: render_mode="human"
    env = gym.make(ENV_ID, obs_type="ram", render_mode="human")

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    if model is None:
        model = ActorCritic(obs_dim, act_dim, hidden_size=128).to(DEVICE)
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    for ep in range(n_episodes):
        obs, _ = env.reset()
        obs = np.array(obs, dtype=np.float32)
        done = False
        ep_return = 0.0

        while not done:
            obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                dist, _ = model(obs_tensor)
                action = dist.probs.argmax(dim=-1)  # greedy for demo
            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ep_return += reward
            obs = np.array(next_obs, dtype=np.float32)

        print(f"Episode {ep + 1}: return = {ep_return}")

    env.close()

watch_agent(None, n_episodes=3)

A.L.E: Arcade Learning Environment (version 0.11.0+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/lavan/miniconda3/envs/rl-pong/lib/python3.11/site-packages/ale_py/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -1569516431
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Episode 1: return = 8.0
Episode 2: return = 12.0
Episode 3: return = 16.0


In [10]:
import imageio.v2 as imageio

def watch_agent_and_save(model=None, n_episodes=3, video_prefix="pong_ep", fps=30):
    """
    Runs the trained agent, saves each episode as an MP4.
    Assumes obs preprocessing matches training (adjust /255.0 if needed).
    """
    env = gym.make(ENV_ID, obs_type="ram", render_mode="rgb_array")

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n

    if model is None:
        model = ActorCritic(obs_dim, act_dim, hidden_size=128).to(DEVICE)
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    for ep in range(n_episodes):
        obs, _ = env.reset()
        # IMPORTANT: match training preprocessing here
        obs = np.array(obs, dtype=np.float32)           # or /255.0 if you trained that way
        done = False
        ep_return = 0.0
        frames = []

        while not done:
            # render current frame
            frame = env.render()        # H x W x 3 RGB
            frames.append(frame)

            obs_tensor = torch.from_numpy(obs).unsqueeze(0).to(DEVICE)
            with torch.no_grad():
                dist, _ = model(obs_tensor)
                action = dist.probs.argmax(dim=-1)  # greedy
            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            ep_return += reward
            obs = np.array(next_obs, dtype=np.float32)  # or /255.0

        print(f"Episode {ep + 1}: return = {ep_return}")

        # save video
        filename = f"{video_prefix}{ep}.mp4"
        imageio.mimsave(filename, frames, fps=fps)
        print(f"Saved {filename}")

    env.close()

watch_agent(None, n_episodes=3)

A.L.E: Arcade Learning Environment (version 0.11.0+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/lavan/miniconda3/envs/rl-pong/lib/python3.11/site-packages/ale_py/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -408401050
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Episode 1: return = -21.0
Episode 2: return = -21.0
Episode 3: return = -21.0
Videos saved to: videos/
