# Pong Policy Training Baseline with Ray RLlib (PPO on V-JEPA 2 Latents)

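This notebook implements a policy training pipeline using Ray RLlib's PPO (new API stack). Each RGB frame is resized to 256×256 and encoded by a frozen V-JEPA 2 (ViT-L) encoder. The PPO policy is a small MLP trained on the resulting 1024-D latent vectors, not on RAM observations. It serves as a comparison baseline for the other approaches in this project (e.g. RSSM).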

## Key Improvements

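- **Configurable parallelism**: RLlib env runners (`NUM_ENV_RUNNERS`, `NUM_ENVS_PER_RUNNER`) can collect data in parallel. The default is 0 remote runners with 1 env.
- **RLlib PPO**: Uses RLlib's PPO implementation with a default MLP policy/value network on the latents.
- **GPU utilization**: The PPO learner is given a GPU when CUDA is available; the V-JEPA encoder runs on CPU inside the environment.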

## Environment Notes

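- **Google Colab**: Use a GPU runtime so the PPO learner can train on the GPU.
- **Local Apple Silicon (M1/M2)**: Training runs on CPU (the notebook only checks for CUDA; MPS is not used).
- Adjust `NUM_ENV_RUNNERS` / `NUM_ENVS_PER_RUNNER` based on your CPU cores (4-8 total envs is typical). Each env runner process loads its own copy of the V-JEPA encoder.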


In [6]:
# ==========================
#  Install dependencies (run once per runtime)
# ==========================
%pip install -q "ray[rllib]" "gymnasium[atari]" ale-py tensorboard "transformers>=4.44.0"

# For Colab: Install tensorboard extension for inline viewing
try:
    from google.colab import drive
    %load_ext tensorboard
    print("Running on Google Colab - TensorBoard extension loaded")
except Exception:
    print("Running locally - TensorBoard can be viewed with: tensorboard --logdir <log_dir>")

print("Dependencies installed successfully!")
print("Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training")

# ==========================
#  Imports and setup
# ==========================
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import ale_py
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from transformers import AutoVideoProcessor, AutoModel
import ray
import matplotlib.pyplot as plt
import os
from gymnasium import spaces
from gymnasium.wrappers import ResizeObservation, TransformObservation

from IPython.display import clear_output

# Colab-specific: Mount Google Drive (optional, for saving models)
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    USE_DRIVE = True
    DRIVE_PATH = '/content/drive/MyDrive/rl-pong'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"Google Drive mounted at {DRIVE_PATH}")
except Exception:
    USE_DRIVE = False
    print("Google Drive not available (running locally)")

# GPU optimization settings
if torch.cuda.is_available():
    print(f"CUDA available: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
Running on Google Colab - TensorBoard extension loaded
Dependencies installed successfully!
Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted at /content/drive/MyDrive/rl-pong
CUDA available: NVIDIA A100-SXM4-40GB
CUDA memory: 42.5 GB


In [None]:
# ==========================
#  Install dependencies (run once per runtime)
# ==========================
%pip install -q "ray[rllib]" "gymnasium[atari]" ale-py tensorboard "transformers>=4.44.0"

# For Colab: Install tensorboard extension for inline viewing
try:
    from google.colab import drive
    %load_ext tensorboard
    print("Running on Google Colab - TensorBoard extension loaded")
except Exception:
    print("Running locally - TensorBoard can be viewed with: tensorboard --logdir <log_dir>")

print("Dependencies installed successfully!")
print("Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training")

# ==========================
#  Imports and setup
# ==========================
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import ale_py
import ray
import matplotlib.pyplot as plt
import os

from transformers import AutoVideoProcessor, AutoModel
from gymnasium import spaces
from gymnasium.wrappers import ResizeObservation

from ray.rllib.algorithms.ppo import PPOConfig

# Colab-specific: Mount Google Drive (optional, for saving models)
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    USE_DRIVE = True
    DRIVE_PATH = '/content/drive/MyDrive/rl-pong'
    os.makedirs(DRIVE_PATH, exist_ok=True)
    print(f"Google Drive mounted at {DRIVE_PATH}")
except Exception:
    USE_DRIVE = False
    print("Google Drive not available (running locally)")

# GPU info (for PPO / policy net)
if torch.cuda.is_available():
    print(f"CUDA available: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Register ALE envs
gym.register_envs(ale_py)

# ==========================
#  Hyperparameters
# ==========================
ENV_ID = "ALE/Pong-v5"

HF_REPO = "facebook/vjepa2-vitl-fpc64-256"
IMG_SIZE = 256       # frames are resized to IMG_SIZE x IMG_SIZE before encoding
LATENT_DIM = 1024    # V-JEPA 2 ViT-L feature dim

# NOTE(review): DEVICE_PPO is not referenced by the other visible cells.
DEVICE_PPO = "cuda" if torch.cuda.is_available() else "cpu"

# PPO hyperparams
GAMMA = 0.99
GAE_LAMBDA = 0.95
CLIP_RANGE = 0.2
LR = 2.5e-4
VF_COEF = 0.5
ENT_COEF = 0.01

# Training / parallelism
STEPS_PER_UPDATE = 2048        # per env per update
NUM_EPOCHS = 2

NUM_ENV_RUNNERS = 0            # number of RLlib env runner processes
NUM_ENVS_PER_RUNNER = 1        # envs per runner (vectorized within a process)

TOTAL_TIMESTEPS = 500_000    # start smaller; bump once working
MODEL_PATH = "ppo_pong_vjepa_latent_rllib"
EVAL_INTERVAL = 0              # disable evaluation to keep it simple (currently unused)

if USE_DRIVE:
    MODEL_PATH = os.path.join(DRIVE_PATH, MODEL_PATH)

# Use an absolute checkpoint directory so saving/restoring does not depend on the
# kernel's current working directory. NOTE(review): Ray checkpoint I/O may also
# reject relative paths depending on version — confirm for the installed Ray.
MODEL_PATH = os.path.abspath(MODEL_PATH)
os.makedirs(MODEL_PATH, exist_ok=True)
print(f"Models will be saved to: {MODEL_PATH}")



# ==========================
#  Global encoder cache (per process)
# ==========================
# Lazily populated singleton; stays None until get_vjepa_encoder() is first called.
_GLOBAL_VJEPA_ENCODER = None


def get_vjepa_encoder():
    """
    Fetch this process's shared V-JEPA encoder, creating it on first use.

    All PongVJEPAEnv instances in the same process reuse one FrozenVJEPAEncoder
    built with device="cpu", so the large Hugging Face checkpoint is loaded once
    per process instead of once per env.
    """
    global _GLOBAL_VJEPA_ENCODER
    if _GLOBAL_VJEPA_ENCODER is not None:
        return _GLOBAL_VJEPA_ENCODER
    _GLOBAL_VJEPA_ENCODER = FrozenVJEPAEncoder(latent_dim=LATENT_DIM, device="cpu")
    return _GLOBAL_VJEPA_ENCODER


# ==========================
#  V-JEPA encoder (CPU by default)
# ==========================
class FrozenVJEPAEncoder(nn.Module):
    """
    Frozen V-JEPA 2 encoder using Hugging Face.

    Each input frame becomes a one-frame "video". That video is tiled
    `num_frames` times along the time axis and passed through the V-JEPA 2
    vision encoder. The output tokens are then mean-pooled.

    Args:
      latent_dim: expected feature size. It is stored for reference only; the
        actual size is whatever the HF model produces (LATENT_DIM = 1024 for ViT-L).
      hf_repo: Hugging Face repo id for both the video processor and the model.
      device: torch device for the model. Defaults to "cpu"; env-side callers
        keep it on CPU to avoid Ray GPU scheduling issues.
      num_frames: how many times each frame is repeated along the time axis.

    Expects:
      x: (B, C, H, W) float32 in [0, 1]
    Returns:
      z: (B, D) pooled V-JEPA 2 features
    """

    def __init__(self, latent_dim: int, hf_repo: str = HF_REPO, device: str = "cpu", num_frames: int = 16):
        super().__init__()
        self.latent_dim = latent_dim
        # Honour the `device` argument. It used to be silently ignored, with CPU
        # forced. The env-side caller passes "cpu", so its behaviour is unchanged.
        self.device = torch.device(device)
        self.hf_repo = hf_repo
        self.num_frames = num_frames

        print(f"[VJEPA] Loading V-JEPA 2 encoder on {self.device} from HF repo: {hf_repo}")
        self.processor = AutoVideoProcessor.from_pretrained(hf_repo)
        self.model = AutoModel.from_pretrained(hf_repo).to(self.device)

        # Frozen: inference mode and no gradients.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W), float32 in [0, 1]
        returns: (B, D) mean-pooled vision features

        Raises:
          ValueError: if x is not 4-D or the batch is empty.
        """
        if x.dim() != 4 or x.shape[0] == 0:
            raise ValueError(f"expected a non-empty (B, C, H, W) tensor, got shape {tuple(x.shape)}")

        # [0,1] -> uint8 HWC numpy frames, the image format the processor receives.
        x_uint8 = (x.clamp(0.0, 1.0) * 255.0).to(torch.uint8)   # (B, C, H, W)
        frames = x_uint8.permute(0, 2, 3, 1).cpu().numpy()       # (B, H, W, C)

        # Preprocess each frame separately, as a one-element list. This exactly
        # matches the original single-frame (B == 1) call. Passing all B frames in
        # one list risks the processor treating them as a single B-frame video,
        # which would collapse the batch dimension to 1.
        # NOTE(review): confirm against the V-JEPA 2 video processor docs.
        per_frame = [
            self.processor([frame], return_tensors="pt")["pixel_values_videos"]
            for frame in frames
        ]
        pixel_values = torch.cat(per_frame, dim=0).to(self.device)  # (B, T, C, H, W)

        # Repeat frames along the time dimension so the video encoder sees a clip (simple hack).
        pixel_values = pixel_values.repeat(1, self.num_frames, 1, 1, 1)

        features = self.model.get_vision_features(pixel_values)  # (B, num_patches, D)
        return features.mean(dim=1)  # (B, D)

# ==========================
#  Pong env that outputs V-JEPA latents
# ==========================
class PongVJEPAEnv(gym.Env):
    """
    Pong env with observations = V-JEPA latent vector (LATENT_DIM-D).

    Pipeline per step:
      - step ALE/Pong-v5 (obs_type='rgb')
      - Resize to (IMG_SIZE, IMG_SIZE)
      - Run the per-process shared FrozenVJEPAEncoder on the frame
      - Return z ∈ ℝ^{LATENT_DIM}

    Config (optional dict, e.g. RLlib's env_config):
      render_mode: one of metadata["render_modes"] or None (default). It is
        forwarded to the underlying ALE env so that render() returns or shows
        frames. Without it, the base env has no render mode.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}

    def __init__(self, config=None):
        super().__init__()
        gym.register_envs(ale_py)

        config = config or {}
        render_mode = config.get("render_mode")
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode {render_mode!r}; expected one of {self.metadata['render_modes']}"
            )
        # Gymnasium envs expose the active render mode as `self.render_mode`.
        self.render_mode = render_mode

        # Base Pong env with RGB observations (render_mode=None is gym.make's default)
        base_env = gym.make(ENV_ID, obs_type="rgb", render_mode=render_mode)

        # Resize to IMG_SIZE x IMG_SIZE
        self._env = ResizeObservation(base_env, (IMG_SIZE, IMG_SIZE))

        # Shared per-process encoder, so the HF model is not reloaded for every env instance.
        self.encoder = get_vjepa_encoder()

        # Latent observation space
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(LATENT_DIM,),
            dtype=np.float32,
        )
        self.action_space = self._env.action_space

    def _encode_obs(self, obs: np.ndarray) -> np.ndarray:
        """
        obs: (H, W, C) uint8 [0,255] -> z: (LATENT_DIM,) float32
        """
        x = torch.from_numpy(obs).permute(2, 0, 1).unsqueeze(0).float() / 255.0  # (1, C, H, W)
        with torch.no_grad():
            z = self.encoder(x)  # (1, D)
        return z.squeeze(0).cpu().numpy().astype(np.float32)

    def reset(self, *, seed=None, options=None):
        # Seed gym.Env's own RNG (self.np_random) as the Gymnasium API expects,
        # then reset the wrapped ALE env with the same seed.
        super().reset(seed=seed)
        obs, info = self._env.reset(seed=seed, options=options)
        return self._encode_obs(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self._env.step(action)
        return self._encode_obs(obs), reward, terminated, truncated, info

    def render(self):
        # Returns frames only if a render_mode was configured.
        return self._env.render()

    def close(self):
        self._env.close()

# ==========================
#  Initialize Ray
# ==========================
ray.shutdown()
ray.init(ignore_reinit_error=True, num_cpus=NUM_ENV_RUNNERS + 1)
print(f"Environment '{ENV_ID}' (V-JEPA latent) will be used with RLlib")
print("Ray initialized for RLlib")

# ==========================
#  PPO config (new API stack, default MLP on latents)
# ==========================
# STEPS_PER_UPDATE is a per-env count, and so is rollout_fragment_length below.
# One training batch must therefore cover every env that samples: the remote
# runners (or the single local runner when NUM_ENV_RUNNERS == 0) times the envs
# per runner. With the default 0 runners x 1 env this equals STEPS_PER_UPDATE.
TOTAL_ENVS = max(NUM_ENV_RUNNERS, 1) * NUM_ENVS_PER_RUNNER
TRAIN_BATCH_SIZE = STEPS_PER_UPDATE * TOTAL_ENVS


config = (
    PPOConfig()
    .framework("torch")
    .environment(PongVJEPAEnv)
    .env_runners(
        num_env_runners=NUM_ENV_RUNNERS,
        num_envs_per_env_runner=NUM_ENVS_PER_RUNNER,
        rollout_fragment_length=STEPS_PER_UPDATE,
        sample_timeout_s=600,            # <-- give it 10 minutes if needed (V-JEPA encoding on CPU is slow)
    )
    .training(
        lr=LR,
        gamma=GAMMA,
        lambda_=GAE_LAMBDA,
        clip_param=CLIP_RANGE,
        vf_loss_coeff=VF_COEF,
        entropy_coeff=ENT_COEF,
        train_batch_size=TRAIN_BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        # NOTE(review): on the new API stack the network is normally configured via
        # .rl_module(model_config=...); confirm this legacy `model` dict is honoured.
        model={
            "fcnet_hiddens": [256, 256],
            "fcnet_activation": "relu",
            "vf_share_layers": True,
        },
    )
    # NOTE(review): new-API-stack docs use .learners(num_gpus_per_learner=...) for
    # learner GPUs; confirm `resources(num_gpus=...)` still takes effect.
    .resources(
        num_gpus=1 if torch.cuda.is_available() else 0,
    )
)

algo = config.build_algo()


print("RLlib PPO configured (new API stack, MLP on V-JEPA latents)")
print(f"Total timesteps target: {TOTAL_TIMESTEPS:,}")
print(f"Env runners: {NUM_ENV_RUNNERS}, envs/runner: {NUM_ENVS_PER_RUNNER}")
print(f"Steps per update per env: {STEPS_PER_UPDATE}")
print(f"Train batch size per update: {TRAIN_BATCH_SIZE}")

# ==========================
#  Training Loop
# ==========================
CHECKPOINT_EVERY = 50   # updates between intermediate checkpoints


def _saved_checkpoint_path(save_result):
    """
    Return a printable checkpoint location from `algo.save()`'s return value.

    Depending on the Ray version, save() returns either a path string or a
    result object exposing `.checkpoint.path` (NOTE(review): confirm for the
    installed Ray version). Anything else is returned unchanged.
    """
    checkpoint = getattr(save_result, "checkpoint", None)
    return getattr(checkpoint, "path", save_result)


print("\nStarting training...")
print("=" * 60)

results = []
episode_rewards = []
episode_lengths = []

NUM_UPDATES = TOTAL_TIMESTEPS // TRAIN_BATCH_SIZE
print(f"Training for {NUM_UPDATES} updates (~{TOTAL_TIMESTEPS:,} timesteps)")

for i in range(NUM_UPDATES):
    result = algo.train()
    results.append(result)

    # New API stack reports episode stats under "env_runners"; fall back to legacy keys.
    env_metrics = result.get("env_runners", {}) or {}

    # Missing metrics (e.g. no episode has finished yet) become NaN, not 0.0.
    # 0.0 is a plausible Pong score and would distort the mean, the best reward
    # and the plotted curve.
    reward = env_metrics.get(
        "episode_return_mean",
        result.get("episode_reward_mean", np.nan),
    )
    length = env_metrics.get(
        "episode_len_mean",
        result.get("episode_len_mean", np.nan),
    )
    timesteps = env_metrics.get(
        "num_env_steps_sampled_lifetime",
        result.get("timesteps_total", 0),
    )

    episode_rewards.append(reward)
    episode_lengths.append(length)

    print(
        f"Update {i+1:4d}/{NUM_UPDATES} | "
        f"Reward: {reward:7.2f} | "
        f"Length: {length:6.1f} | "
        f"Timesteps: {timesteps:,}"
    )

    # Intermediate checkpoint (overwrites the same directory each time)
    if (i + 1) % CHECKPOINT_EVERY == 0:
        algo.save(MODEL_PATH)

# Save final model
checkpoint_path = _saved_checkpoint_path(algo.save(MODEL_PATH))
print(f"\nTraining finished. Model saved to {checkpoint_path}")

# ==========================
#  Final Metrics + Simple Plot
# ==========================
print("\nFinal metrics (last 100 updates):")
rewards_arr = np.asarray(episode_rewards, dtype=float)
finite_mask = np.isfinite(rewards_arr)
if finite_mask.any():
    recent = rewards_arr[-100:]
    recent = recent[np.isfinite(recent)]
    if recent.size:
        print(f"  Mean reward: {recent.mean():.2f}")
    print(f"  Best reward: {rewards_arr[finite_mask].max():.2f}")

    # Moving average over updates; windows containing NaN stay NaN (gaps in the plot).
    window = min(20, len(rewards_arr))
    ma = np.convolve(
        rewards_arr, np.ones(window) / window, mode="valid"
    )

    plt.figure(figsize=(8, 4))
    plt.plot(rewards_arr, label="Episode return mean (per update)")
    plt.plot(range(window - 1, window - 1 + len(ma)), ma, label=f"{window}-update MA")
    plt.xlabel("Update")
    plt.ylabel("Return (env score)")
    plt.title("PPO on Pong with V-JEPA latents (RLlib)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
else:
    print("  No completed episodes were reported.")

# RLlib writes TensorBoard event files to the Algorithm's own log directory
# (typically under ~/ray_results), not to MODEL_PATH, which only holds checkpoints.
tb_logdir = getattr(algo, "logdir", None) or MODEL_PATH
print("\nTo view training progress with TensorBoard, run:")
print(f"  tensorboard --logdir {tb_logdir}")


The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
Running on Google Colab - TensorBoard extension loaded
Dependencies installed successfully!
Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted at /content/drive/MyDrive/rl-pong
CUDA available: NVIDIA A100-SXM4-40GB
CUDA memory: 42.5 GB
Models will be saved to: /content/drive/MyDrive/rl-pong/ppo_pong_vjepa_latent_rllib


2025-11-24 03:50:26,002	INFO worker.py:2023 -- Started a local Ray instance.


Environment 'ALE/Pong-v5' (V-JEPA latent) will be used with RLlib
Ray initialized for RLlib


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


[VJEPA] Loading V-JEPA 2 encoder on cpu from HF repo: facebook/vjepa2-vitl-fpc64-256




RLlib PPO configured (new API stack, MLP on V-JEPA latents)
Total timesteps target: 500,000
Env runners: 0, envs/runner: 1
Steps per update per env: 2048
Train batch size per update: 2048

Starting training...
Training for 244 updates (~500,000 timesteps)


[36m(pid=gcs_server)[0m [2025-11-24 03:50:53,608 E 30904 30904] (gcs_server) gcs_server.cc:303: Failed to establish connection to the event+metrics exporter agent. Events and metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[33m(raylet)[0m [2025-11-24 03:50:55,947 E 31017 31017] (raylet) main.cc:979: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[36m(pid=31069)[0m [2025-11-24 03:50:57,720 E 31069 31177] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14


Update    1/244 | Reward:  -21.00 | Length:  764.0 | Timesteps: 2,048
Update    2/244 | Reward:  -20.75 | Length:  859.8 | Timesteps: 4,096


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def extract_learning_curves(results):
    """
    Pull per-iteration learning-curve data out of a list of RLlib result dicts.

    Each metric is looked up first in the result's ``"env_runners"`` sub-dict
    (new API stack). If it is absent there, the legacy top-level key is used.

    Returns:
      (timesteps, rewards, lengths) as three NumPy arrays with one entry per
      result. Missing rewards/lengths become NaN; missing timesteps become 0.
    """
    # (env_runners key, legacy top-level key, default) for each output series
    lookups = (
        ("num_env_steps_sampled_lifetime", "timesteps_total", 0),
        ("episode_return_mean", "episode_reward_mean", np.nan),
        ("episode_len_mean", "episode_len_mean", np.nan),
    )
    series = ([], [], [])

    for res in results:
        runner_stats = res.get("env_runners", {}) or {}
        for column, (runner_key, legacy_key, fallback) in zip(series, lookups):
            column.append(runner_stats.get(runner_key, res.get(legacy_key, fallback)))

    return tuple(np.array(column) for column in series)


def plot_training(results, smooth_window=10, title="PPO Training on Pong (V-JEPA latents)"):
    """
    Plot mean episode reward (optionally smoothed) and mean episode length vs timesteps.

    Args:
      results: list of RLlib result dicts, e.g. the `results` list filled by the
        training loop.
      smooth_window: moving-average window (in training iterations) for the
        reward curve. Smoothing is skipped when this is <= 1 or when there are
        fewer results than the window.
      title: figure title. The policy is trained on V-JEPA latents, not RAM.
    """
    ts, rew, leng = extract_learning_curves(results)
    if ts.size == 0:
        print("No training results to plot.")
        return

    # Simple moving average smoothing for rewards
    if smooth_window > 1 and len(rew) >= smooth_window:
        kernel = np.ones(smooth_window) / smooth_window
        rew_smooth = np.convolve(rew, kernel, mode="valid")
        # "valid" mode drops the first smooth_window - 1 points; align x accordingly.
        ts_smooth = ts[smooth_window - 1 :]
    else:
        rew_smooth = rew
        ts_smooth = ts

    fig, ax1 = plt.subplots(figsize=(8, 4))

    # Reward curve
    ax1.plot(ts_smooth, rew_smooth, label="Mean episode reward")
    ax1.set_xlabel("Timesteps")
    ax1.set_ylabel("Reward")
    ax1.grid(True)

    # Overlay episode length on a secondary axis
    ax2 = ax1.twinx()
    ax2.plot(ts, leng, alpha=0.3, linestyle="--", label="Episode length")
    ax2.set_ylabel("Episode length")

    # Combine legends from both axes
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines + lines2, labels + labels2, loc="lower right")

    ax1.set_title(title)
    fig.tight_layout()
    plt.show()


NameError: name 'avg_hist' is not defined

In [None]:
# ==========================
#  Watch Trained Agent (RLlib)
# ==========================
def watch_agent_rllib(checkpoint_path=None, n_episodes=3, render_mode="human"):
    """
    Play Pong with a trained RLlib PPO policy and print per-episode results.

    The policy was trained on PongVJEPAEnv observations, i.e. V-JEPA latents of
    IMG_SIZE x IMG_SIZE RGB frames. Each frame here goes through the same
    resize + encoder pipeline before reaching the policy. RAM observations
    (obs_type="ram") do not match the policy's input.

    Args:
      checkpoint_path: checkpoint directory to restore; defaults to MODEL_PATH.
      n_episodes: number of episodes to play.
      render_mode: 'human' for a window, 'rgb_array' for frames, or None.

    Returns:
      list[float]: the return of each episode.
    """
    if checkpoint_path is None:
        checkpoint_path = MODEL_PATH

    encoder = get_vjepa_encoder()
    returns = []

    # build_algo() matches how the training cell builds the algorithm.
    algo = config.build_algo()
    env = None
    try:
        algo.restore(checkpoint_path)
        print(f"Loaded model from {checkpoint_path}")

        # New API stack: query the RLModule directly for greedy actions.
        # (compute_single_action belongs to the old API stack.)
        module = algo.get_module()
        module_device = next(module.parameters()).device

        def greedy_action(frame):
            """Encode one resized RGB frame with V-JEPA and return the argmax action."""
            x = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            latent = encoder(x).float().to(module_device)  # (1, D)
            with torch.no_grad():
                logits = module.forward_inference({"obs": latent})["action_dist_inputs"]
            return int(torch.argmax(logits, dim=-1)[0])

        # Same observation pipeline as PongVJEPAEnv, plus rendering.
        env = ResizeObservation(
            gym.make(ENV_ID, obs_type="rgb", render_mode=render_mode),
            (IMG_SIZE, IMG_SIZE),
        )

        for ep in range(n_episodes):
            obs, _ = env.reset()
            done = False
            ep_return = 0.0
            steps = 0

            while not done:
                obs, reward, terminated, truncated, _ = env.step(greedy_action(obs))
                done = terminated or truncated
                ep_return += reward
                steps += 1

            returns.append(ep_return)
            print(f"Episode {ep + 1}: return = {ep_return:.1f}, length = {steps}")
    finally:
        # Always release the env and the extra Algorithm (workers / GPU memory).
        if env is not None:
            env.close()
        algo.stop()

    if returns:
        print(f"\nMean return over {n_episodes} episodes: {np.mean(returns):.2f} ± {np.std(returns):.2f}")
    return returns

# Evaluate the trained model
# Uncomment to run:
# watch_agent_rllib(n_episodes=5)

A.L.E: Arcade Learning Environment (version 0.11.0+unknown)
[Powered by Stella]
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
Game console created:
  ROM file:  /Users/lavan/miniconda3/envs/rl-pong/lib/python3.11/site-packages/ale_py/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is 281464273


Episode 1: return = 14.0
Episode 2: return = 3.0
Episode 3: return = 16.0


In [None]:
# ==========================
#  Save Agent Videos (RLlib)
# ==========================
import imageio.v2 as imageio

def watch_agent_and_save_rllib(checkpoint_path=None, n_episodes=3, video_prefix="pong_ep", fps=30):
    """
    Run the trained agent and save each episode as ``{video_prefix}_{ep}.mp4``.

    Observations use the training pipeline: RGB frame, resized to
    IMG_SIZE x IMG_SIZE, encoded to a V-JEPA latent, then a greedy action from
    the policy RLModule. The saved frames are the env's ``rgb_array`` renders,
    from the reset frame through the final frame of the episode.

    Args:
      checkpoint_path: checkpoint directory to restore; defaults to MODEL_PATH.
      n_episodes: number of episodes to record.
      video_prefix: output path prefix; its directory is created if needed.
      fps: frame rate of the saved videos.
    """
    if checkpoint_path is None:
        checkpoint_path = MODEL_PATH

    video_dir = os.path.dirname(video_prefix)
    if video_dir:
        os.makedirs(video_dir, exist_ok=True)

    encoder = get_vjepa_encoder()

    # build_algo() matches how the training cell builds the algorithm.
    algo = config.build_algo()
    env = None
    try:
        algo.restore(checkpoint_path)
        print(f"Loaded model from {checkpoint_path}")

        # New API stack: query the RLModule directly for greedy actions.
        module = algo.get_module()
        module_device = next(module.parameters()).device

        def greedy_action(frame):
            """Encode one resized RGB frame with V-JEPA and return the argmax action."""
            x = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            latent = encoder(x).float().to(module_device)  # (1, D)
            with torch.no_grad():
                logits = module.forward_inference({"obs": latent})["action_dist_inputs"]
            return int(torch.argmax(logits, dim=-1)[0])

        # Same observation pipeline as PongVJEPAEnv, with rgb_array rendering for video.
        env = ResizeObservation(
            gym.make(ENV_ID, obs_type="rgb", render_mode="rgb_array"),
            (IMG_SIZE, IMG_SIZE),
        )

        for ep in range(n_episodes):
            obs, _ = env.reset()
            frames = [env.render()]
            done = False
            ep_return = 0.0

            while not done:
                obs, reward, terminated, truncated, _ = env.step(greedy_action(obs))
                frames.append(env.render())  # includes the terminal frame
                done = terminated or truncated
                ep_return += reward

            filename = f"{video_prefix}_{ep}.mp4"
            imageio.mimsave(filename, frames, fps=fps)
            print(f"Episode {ep + 1}: return = {ep_return:.1f}, saved to {filename}")
    finally:
        # Always release the env and the extra Algorithm (workers / GPU memory).
        if env is not None:
            env.close()
        algo.stop()

    print(f"Videos saved with prefix: {video_prefix}")

# Uncomment to save videos:
# watch_agent_and_save_rllib(n_episodes=3)

A.L.E: Arcade Learning Environment (version 0.11.0+unknown)
[Powered by Stella]
Game console created:
  ROM file:  /Users/lavan/miniconda3/envs/rl-pong/lib/python3.11/site-packages/ale_py/roms/pong.bin
  Cart Name: Video Olympics (1978) (Atari)
  Cart MD5:  60e0ea3cbe0913d39803477945e9e5ec
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is -408401050
  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


Episode 1: return = -21.0
Episode 2: return = -21.0
Episode 3: return = -21.0
Videos saved to: videos/


## TensorBoard Integration (Colab)

View training progress in real-time using TensorBoard inline in Colab.


In [None]:
# ==========================
#  Start TensorBoard (Colab)
# ==========================
# This will display TensorBoard inline in Colab
# Run this cell after training starts to view progress

log_dir = f"{MODEL_PATH}_logs/"

try:
    from google.colab import output
    # Start TensorBoard in background
    %tensorboard --logdir {log_dir} --port 6006
except:
    print(f"TensorBoard logs available at: {log_dir}")
    print("Run: tensorboard --logdir", log_dir)


## Visualization and Evaluation Tools

Enhanced visualization functions for viewing agent performance and comparing results.


In [None]:
# ==========================
#  Enhanced Evaluation & Visualization
# ==========================
from IPython.display import HTML
from IPython import display
import pandas as pd

def evaluate_agent_comprehensive_rllib(checkpoint_path=None, n_episodes=10, render_video=True):
    """
    Evaluate a trained RLlib PPO policy over several episodes.

    Observations go through the same pipeline as training (PongVJEPAEnv): RGB
    frame, resized to IMG_SIZE x IMG_SIZE, encoded by the frozen V-JEPA
    encoder, then a greedy action from the policy RLModule.

    Args:
      checkpoint_path: checkpoint directory to restore; defaults to MODEL_PATH.
      n_episodes: number of evaluation episodes (must be >= 1).
      render_video: if True, record rgb_array frames for each episode and show
        the first episode inline.

    Returns:
      dict with 'returns' and 'lengths' (np.ndarray), 'mean_return',
      'std_return', and 'videos' (a list of frame lists, or None when
      render_video is False).

    Raises:
      ValueError: if n_episodes < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    if checkpoint_path is None:
        checkpoint_path = MODEL_PATH

    encoder = get_vjepa_encoder()
    returns = []
    lengths = []
    videos = [] if render_video else None

    # build_algo() matches how the training cell builds the algorithm.
    algo = config.build_algo()
    env = None
    try:
        algo.restore(checkpoint_path)
        print(f"Loaded model from {checkpoint_path}")

        # New API stack: query the RLModule directly for greedy actions.
        module = algo.get_module()
        module_device = next(module.parameters()).device

        def greedy_action(frame):
            """Encode one resized RGB frame with V-JEPA and return the argmax action."""
            x = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
            latent = encoder(x).float().to(module_device)  # (1, D)
            with torch.no_grad():
                logits = module.forward_inference({"obs": latent})["action_dist_inputs"]
            return int(torch.argmax(logits, dim=-1)[0])

        env = ResizeObservation(
            gym.make(ENV_ID, obs_type="rgb", render_mode="rgb_array"),
            (IMG_SIZE, IMG_SIZE),
        )

        for ep in range(n_episodes):
            obs, _ = env.reset()
            done = False
            ep_return = 0.0
            steps = 0
            frames = [] if render_video else None

            while not done:
                if render_video:
                    frames.append(env.render())

                obs, reward, terminated, truncated, _ = env.step(greedy_action(obs))
                done = terminated or truncated
                ep_return += reward
                steps += 1

            returns.append(ep_return)
            lengths.append(steps)
            if render_video and frames:
                videos.append(frames)

            print(f"Episode {ep+1:2d}: return = {ep_return:6.1f}, length = {steps:4d}")
    finally:
        # Always release the env and the extra Algorithm (workers / GPU memory).
        if env is not None:
            env.close()
        algo.stop()

    # Print statistics
    returns = np.array(returns)
    lengths = np.array(lengths)
    print(f"\n{'='*60}")
    print(f"Evaluation Results ({n_episodes} episodes):")
    print(f"{'='*60}")
    print(f"Mean return:  {returns.mean():7.2f} ± {returns.std():6.2f}")
    print(f"Max return:   {returns.max():7.2f}")
    print(f"Min return:   {returns.min():7.2f}")
    print(f"Mean length:  {lengths.mean():7.1f} ± {lengths.std():6.1f}")
    print(f"{'='*60}")

    # Display video inline (first episode)
    if render_video and videos:
        print("\nDisplaying first episode video...")
        display_video_colab(videos[0], fps=30)

    return {
        'returns': returns,
        'lengths': lengths,
        'mean_return': returns.mean(),
        'std_return': returns.std(),
        'videos': videos
    }

def display_video_colab(frames, fps=30):
    """
    Display a sequence of frames as an inline MP4 video in Colab/Jupyter.

    The frames are encoded to a temporary MP4 file with imageio, then embedded
    as a base64 data URI in an HTML <video> tag. The temporary file is always
    removed, even if encoding fails.

    Args:
      frames: sequence of image frames accepted by imageio.mimsave (here,
        env.render() rgb_array outputs).
      fps: playback frame rate.
    """
    import base64
    import tempfile

    import imageio
    # Import the display *function*. This cell's module-level
    # `from IPython import display` binds the module, which is not callable.
    from IPython.display import HTML, display as show

    # Reserve a temp filename, closed so imageio can write to it.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        tmp_path = tmp.name

    try:
        imageio.mimsave(tmp_path, frames, fps=fps)
        with open(tmp_path, 'rb') as f:
            video_data = f.read()
    finally:
        os.unlink(tmp_path)

    video_base64 = base64.b64encode(video_data).decode('utf-8')
    video_html = f'''
    <video width="640" height="480" controls>
        <source src="data:video/mp4;base64,{video_base64}" type="video/mp4">
    </video>
    '''
    show(HTML(video_html))

def compare_models_rllib(checkpoint_paths, model_names=None, n_episodes=10):
    """
    Compare multiple trained RLlib models side by side and plot mean returns.

    Each model is evaluated greedily on the training observation pipeline:
    RGB frame, resized to IMG_SIZE x IMG_SIZE, encoded to a V-JEPA latent.

    Args:
      checkpoint_paths: iterable of checkpoint directories.
      model_names: display names, one per checkpoint; defaults to "Model i".
      n_episodes: episodes per model (must be >= 1).

    Returns:
      dict mapping name -> {'mean', 'std', 'max', 'returns'}.

    Raises:
      ValueError: if n_episodes < 1, or model_names and checkpoint_paths
        differ in length (zip would otherwise silently drop models).
    """
    checkpoint_paths = list(checkpoint_paths)
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    if model_names is None:
        model_names = [f"Model {i+1}" for i in range(len(checkpoint_paths))]
    if len(model_names) != len(checkpoint_paths):
        raise ValueError(
            f"got {len(model_names)} model_names for {len(checkpoint_paths)} checkpoint_paths"
        )

    encoder = get_vjepa_encoder()

    def greedy_action(module, module_device, frame):
        """Encode one resized RGB frame with V-JEPA and return `module`'s argmax action."""
        x = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
        latent = encoder(x).float().to(module_device)  # (1, D)
        with torch.no_grad():
            logits = module.forward_inference({"obs": latent})["action_dist_inputs"]
        return int(torch.argmax(logits, dim=-1)[0])

    results = {}
    for path, name in zip(checkpoint_paths, model_names):
        print(f"\nEvaluating {name}...")
        # build_algo() matches how the training cell builds the algorithm.
        algo = config.build_algo()
        env = None
        returns = []
        try:
            algo.restore(path)
            module = algo.get_module()
            module_device = next(module.parameters()).device
            env = ResizeObservation(gym.make(ENV_ID, obs_type="rgb"), (IMG_SIZE, IMG_SIZE))

            for _ in range(n_episodes):
                obs, _ = env.reset()
                done = False
                ep_return = 0.0

                while not done:
                    action = greedy_action(module, module_device, obs)
                    obs, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    ep_return += reward

                returns.append(ep_return)
        finally:
            # Release each model's env and Algorithm before loading the next one.
            if env is not None:
                env.close()
            algo.stop()

        results[name] = {
            'mean': np.mean(returns),
            'std': np.std(returns),
            'max': np.max(returns),
            'returns': returns
        }
        print(f"  Mean return: {results[name]['mean']:.2f} ± {results[name]['std']:.2f}")

    if not results:
        print("No models to compare.")
        return results

    # Plot comparison
    fig, ax = plt.subplots(figsize=(10, 6))
    names = list(results.keys())
    means = [results[n]['mean'] for n in names]
    stds = [results[n]['std'] for n in names]

    x = np.arange(len(names))
    ax.bar(x, means, yerr=stds, capsize=5, alpha=0.7)
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel('Mean Episode Return')
    ax.set_title('Model Comparison')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    return results

# Test evaluation (uncomment after training):
# results = evaluate_agent_comprehensive_rllib(n_episodes=5, render_video=True)


## Training Progress Dashboard

Plot training metrics (episode return/length, policy and value loss) from the RLlib result dicts collected by the training loop (`results`).


In [None]:
# ==========================
#  Training Progress Dashboard (RLlib)
# ==========================
def plot_comprehensive_training_metrics_rllib(results=None):
    """
    Plot episode return, episode length, policy loss and value loss per training
    iteration from a list of RLlib result dicts.

    Each metric is read from its new-API-stack location first, then from the
    legacy locations. The return/length lookup matches the one in the training
    loop. Missing values are plotted as gaps (NaN) rather than 0, so every
    panel's x-axis stays aligned with the training iterations.

    Args:
      results: list of result dicts returned by `algo.train()`.
    """
    if results is None or len(results) == 0:
        print("No training results available. Run training first.")
        return

    def lookup(result, *paths):
        """Return the value at the first key path present in `result`, else NaN."""
        for path in paths:
            node = result
            for key in path:
                if not isinstance(node, dict) or key not in node:
                    break
                node = node[key]
            else:
                return node
        return np.nan

    # NOTE(review): the learner paths assume the module/policy id "default_policy".
    # New API stack: result["learners"][module_id][...].
    # Old stack: result["info"]["learner"][policy_id]["learner_stats"][...].
    rewards = [lookup(r, ("env_runners", "episode_return_mean"), ("episode_reward_mean",)) for r in results]
    lengths = [lookup(r, ("env_runners", "episode_len_mean"), ("episode_len_mean",)) for r in results]
    policy_loss = [
        lookup(
            r,
            ("learners", "default_policy", "policy_loss"),
            ("info", "learner", "default_policy", "learner_stats", "policy_loss"),
            ("info", "learner", "default_policy", "policy_loss"),
        )
        for r in results
    ]
    value_loss = [
        lookup(
            r,
            ("learners", "default_policy", "vf_loss"),
            ("info", "learner", "default_policy", "learner_stats", "vf_loss"),
            ("info", "learner", "default_policy", "vf_loss"),
        )
        for r in results
    ]

    iterations = np.arange(len(results))
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    panels = (
        (axes[0, 0], rewards, 'Episode Return', 'Mean Return'),
        (axes[0, 1], lengths, 'Episode Length', 'Mean Length'),
        (axes[1, 0], policy_loss, 'Policy Loss', 'Loss'),
        (axes[1, 1], value_loss, 'Value Loss', 'Loss'),
    )
    for ax, values, title, ylabel in panels:
        values = np.asarray(values, dtype=float)
        ax.set_title(title)
        ax.set_xlabel('Training Iteration')
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        if np.isfinite(values).any():
            ax.plot(iterations, values)
        else:
            ax.text(0.5, 0.5, "not reported", ha="center", va="center", transform=ax.transAxes)

    plt.tight_layout()
    plt.show()

    print(f"Plotting {len(results)} training iterations")

# Uncomment after training to see comprehensive plots:
# plot_comprehensive_training_metrics_rllib(results)
