# Pong Policy Training Baseline with RLlib

This notebook implements a baseline policy training pipeline using Ray RLlib, training directly on visual observations without any encoder. This serves as a comparison baseline for the V-JEPA2 and RSSM approaches.

## Google Colab Setup

**Important**: This notebook is configured for Google Colab. Before running it:
1. Enable a GPU runtime: `Runtime > Change runtime type > GPU`
2. Run the installation cell first to install the dependencies
3. Keep in mind that the configuration targets Colab's limited resources (2 CPUs, 1 GPU)


In [None]:
# Install dependencies for Google Colab
# Note: Colab comes with gymnasium, but we need ray[rllib] and ale-py
%pip install -q "ray[rllib]" "gymnasium[atari]" ale-py

print("Dependencies installed successfully!")
print("Note: Make sure you're using a GPU runtime (Runtime > Change runtime type > GPU) for faster training")


In [None]:
import os
import numpy as np
import gymnasium as gym
import ale_py
from gymnasium import spaces
import matplotlib.pyplot as plt
from pathlib import Path

# Register ALE environments
gym.register_envs(ale_py)

# Configuration
GAME_ID = os.environ.get("ATARI_GAME", "PongNoFrameskip-v4")
NUM_TRAIN_ITERATIONS = int(os.environ.get("NUM_TRAIN_ITERATIONS", "100"))
EVAL_EPISODES = int(os.environ.get("NUM_EVAL_EPISODES", "5"))

print(f"Game: {GAME_ID}")
print(f"Training iterations: {NUM_TRAIN_ITERATIONS}")
print(f"Evaluation episodes: {EVAL_EPISODES}")


## Environment Setup

Create the Pong environment with standard preprocessing: resize each frame to 84x84, transpose it to channel-first `(C, H, W)` layout, and scale pixel values to [0, 1].
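As a minimal sketch of what this preprocessing does to one frame, the snippet below applies the same transpose-and-scale step to a synthetic array. The random frame is an assumption standing in for an already resized Atari observation; the resize itself is left to the Gymnasium wrapper.

```python
import numpy as np

# Synthetic stand-in for a resized Atari frame: (H, W, C) uint8 in [0, 255].
rng = np.random.default_rng(0)
frame = rng.integers(0, 256, size=(84, 84, 3), dtype=np.uint8)

# Move channels first and scale to [0, 1], mirroring the notebook's transform.
chw = np.transpose(frame, (2, 0, 1)).astype(np.float32) / 255.0

print(chw.shape, chw.dtype)
```

The result has shape `(3, 84, 84)` and dtype `float32`, which is what the declared observation space expects.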


In [None]:
def transform_obs(obs):
    """Transform observation to (C, H, W) format and normalize to [0, 1]."""
    obs_t = np.transpose(obs, (2, 0, 1)).astype(np.float32) / 255.0
    return obs_t

# Create environment with wrappers
env = gym.make(GAME_ID)
env = gym.wrappers.ResizeObservation(env, (84, 84))

new_obs_space = spaces.Box(low=0.0, high=1.0, shape=(3, 84, 84), dtype=np.float32)

env = gym.wrappers.TransformObservation(
    env,
    func=transform_obs,
    observation_space=new_obs_space,
)

# Test environment
obs, info = env.reset()
assert obs.shape == (3, 84, 84), f"Expected (3, 84, 84), got {obs.shape}"
print(f"Environment initialized successfully")
print(f"Observation shape: {obs.shape}")
print(f"Action space: {env.action_space}")
print(f"Number of actions: {env.action_space.n}")


## RLlib Training Setup

Configure Ray RLlib to train a policy using PPO (Proximal Policy Optimization), which works well for Atari games.
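For intuition about what `clip_param` controls, here is a textbook sketch of PPO's clipped surrogate objective in NumPy. It is an illustration under simplifying assumptions (per-sample values, no value or entropy terms), not RLlib's implementation, and the ratios and advantages are made-up numbers.

```python
import numpy as np

def ppo_clip_objective(ratio, advantage, clip_param=0.2):
    """Per-sample clipped surrogate objective (a quantity to maximize)."""
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage
    # Taking the minimum removes the incentive to move the ratio far from 1.
    return np.minimum(unclipped, clipped)

ratios = np.array([0.5, 1.0, 1.5])      # new_prob / old_prob (hypothetical)
advantages = np.array([1.0, 1.0, 1.0])  # hypothetical positive advantages
values = ppo_clip_objective(ratios, advantages)
print(values)
```

With a positive advantage, a ratio of 1.5 is capped at 1.2, so the objective stops rewarding further increases in that action's probability.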


In [None]:
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig

# Check if GPU is available
import torch
has_gpu = torch.cuda.is_available()
print(f"GPU available: {has_gpu}")

# Initialize Ray for Google Colab
# Colab typically has 2 CPUs, so we use fewer workers
if not ray.is_initialized():
    # Colab-friendly initialization: use available CPUs, enable GPU if available
    ray.init(
        ignore_reinit_error=True,
        num_cpus=2,  # Colab typically has 2 CPUs
        num_gpus=1 if has_gpu else 0,
        object_store_memory=2_000_000_000,  # 2GB object store
    )
    print("Ray initialized for Colab")
else:
    print("Ray already initialized")


In [None]:
# Create a function to register and return the environment
def env_creator(env_config):
    """Create and return the Pong environment."""
    import gymnasium as gym
    import ale_py
    from gymnasium import spaces
    
    gym.register_envs(ale_py)
    
    env = gym.make(GAME_ID)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    
    def transform_obs(obs):
        obs_t = np.transpose(obs, (2, 0, 1)).astype(np.float32) / 255.0
        return obs_t
    
    new_obs_space = spaces.Box(low=0.0, high=1.0, shape=(3, 84, 84), dtype=np.float32)
    env = gym.wrappers.TransformObservation(
        env,
        func=transform_obs,
        observation_space=new_obs_space,
    )
    
    return env

# Register the environment
tune.register_env("pong_env", env_creator)
print("Environment registered with RLlib")


In [None]:
# Configure PPO algorithm for Colab
# Colab-friendly settings: fewer workers, smaller batches, GPU if available
num_env_runners = 1  # Colab has limited CPUs, use 1 env runner
use_gpu = has_gpu  # Use GPU if available

# Note: minibatch_size is not set here, so RLlib's PPO default applies
# (128 in recent releases; it is not derived from train_batch_size)
# Uses the new RLlib config API (env_runners instead of rollouts)
config = (
    PPOConfig()
    .environment("pong_env")
    .framework("torch")
    .training(
        lr=3e-4,
        train_batch_size=1000,  # Reduced further to see episode completion sooner
        # minibatch_size is left at the PPO default (128 in recent RLlib releases)
        num_epochs=30,  # Updated from num_sgd_iter (deprecated)
        gamma=0.99,
        lambda_=0.95,
        clip_param=0.2,
        entropy_coeff=0.01,
        vf_loss_coeff=0.5,
    )
    .resources(
        num_gpus=1 if use_gpu else 0,
    )
    .env_runners(
        num_env_runners=num_env_runners,  # Updated from num_rollout_workers
        num_envs_per_env_runner=1,  # Updated from num_envs_per_worker
        num_cpus_per_env_runner=1,  # Updated from num_cpus_per_worker in resources
    )
    .evaluation(
        evaluation_interval=10,
        evaluation_duration=EVAL_EPISODES,
        evaluation_num_env_runners=0,  # 0 = no separate eval runners; evaluation runs in the algorithm process (Colab resource constraint)
    )
)

print("PPO configuration (Colab-optimized):")
print(f"  Learning rate: {config.lr}")
print(f"  Training batch size: {config.train_batch_size}")
print(f"  Number of epochs: {config.num_epochs}")
print(f"  Number of env runners: {config.num_env_runners}")
print(f"  Evaluation env runners: {config.evaluation_num_env_runners} (0 = evaluate in the algorithm process)")
print(f"  GPU enabled: {use_gpu}")
print(f"  Gamma (discount): {config.gamma}")
print(f"  Minibatch size: {getattr(config, 'minibatch_size', 'RLlib default')}")
print("  Note: Evaluation runs in the algorithm process, so it needs no extra env-runner CPUs")


## Diagnostics

Check Ray cluster status and algorithm readiness before training.
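The readiness check boils down to simple arithmetic, sketched below as a pure function. The helper name `cpu_budget` is hypothetical, and the estimate counts only env-runner CPUs; it ignores whatever the driver and learner processes use.

```python
def cpu_budget(total_cpus, num_env_runners, cpus_per_env_runner=1,
               num_eval_env_runners=0):
    """Return (required_cpus, fits) for a rough env-runner CPU estimate."""
    required = (num_env_runners + num_eval_env_runners) * cpus_per_env_runner
    return required, required <= total_cpus

# Colab-like setup from this notebook: 2 CPUs, 1 env runner, no eval runners.
required, fits = cpu_budget(total_cpus=2, num_env_runners=1)
print(required, fits)
```

If `fits` is false, Ray may wait indefinitely for resources that never become available, which looks like a hang.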


In [None]:
# Check Ray cluster status
import ray
print("Ray cluster status:")
cluster_resources = {}
available_resources = {}
try:
    cluster_resources = ray.cluster_resources()
    available_resources = ray.available_resources()
    print(f"  Total CPUs: {cluster_resources.get('CPU', 0)}")
    print(f"  Available CPUs: {available_resources.get('CPU', 0)}")
    print(f"  Total GPUs: {cluster_resources.get('GPU', 0)}")
    print(f"  Available GPUs: {available_resources.get('GPU', 0)}")
    used_cpus = cluster_resources.get('CPU', 0) - available_resources.get('CPU', 0)
    print(f"  Used CPUs: {used_cpus:.1f}")
except Exception as e:
    print(f"  Could not get cluster resources: {e}")

print("\nAlgorithm configuration check:")
try:
    num_train_cpus = config.num_env_runners * getattr(config, 'num_cpus_per_env_runner', 1)
    num_eval_cpus = config.evaluation_num_env_runners * getattr(config, 'num_cpus_per_env_runner', 1)
    total_required = num_train_cpus + num_eval_cpus
    print(f"  Required CPUs (training): {num_train_cpus}")
    print(f"  Required CPUs (evaluation): {num_eval_cpus}")
    print(f"  Total required: {total_required}")
    
    if cluster_resources:
        # Compare against the cluster total, not the currently free CPUs
        total_cpus = cluster_resources.get('CPU', 0)
        if total_required > total_cpus:
            print(f"  ⚠️  WARNING: Required CPUs ({total_required}) > Total cluster CPUs ({total_cpus})")
            print("     This may cause training to hang. Consider reducing num_env_runners or setting evaluation_num_env_runners=0")
        else:
            print("  ✓ Resource requirements are within the cluster's CPU total")
    else:
        print(f"  (Could not verify against available resources)")
except Exception as e:
    print(f"  Could not check configuration: {e}")
    print(f"  Config attributes: num_env_runners={config.num_env_runners}, evaluation_num_env_runners={config.evaluation_num_env_runners}")


## Build Algorithm

**Expected time: 1-3 minutes**

The build runs at the start of the training cell. If it takes longer than 5 minutes, it may be stuck:
1. Check the Ray cluster status (diagnostics cell above)
2. Confirm there are no CPU resource warnings
3. If it is still stuck, restart Ray and rebuild


## Training

Train the policy using RLlib's PPO algorithm.
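Because metric names and nesting in the training result differ between RLlib versions, a small lookup helper can keep the extraction logic short. The sketch below uses a hypothetical helper, `first_metric`, and a hand-written result dict; the key names are the ones this notebook already probes for.

```python
def first_metric(result, paths, default=None):
    """Return the first value found at any of the nested key paths."""
    for path in paths:
        node = result
        for key in path:
            if not isinstance(node, dict) or key not in node:
                break  # this path does not exist, try the next one
            node = node[key]
        else:
            return node  # every key in the path was found
    return default

# Hypothetical result dict shaped like a new-API-stack training result.
fake_result = {"env_runners": {"episode_return_mean": -20.5}}
paths = [
    ("env_runners", "episode_return_mean"),
    ("env_runners", "episode_reward_mean"),
    ("episode_reward_mean",),
]
mean_return = first_metric(fake_result, paths)
print(mean_return)
```

Returning `None` (rather than 0) when nothing is found makes it easier to tell "no episode finished yet" apart from a genuine return of zero.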


In [None]:
from ray.rllib.algorithms.ppo import PPO
import time

# Build the algorithm (using new API method)
print("Building algorithm (this may take 1-3 minutes on first run)...")
print("Components being initialized:")
print("  - Neural network models")
print("  - Environment runners")
print("  - Learner group")
print("  - ALE environment (Pong ROM)")
print()

build_start = time.time()
try:
    algo = config.build_algo()
    build_time = time.time() - build_start
    print(f"\n✓ Algorithm built successfully in {build_time:.1f} seconds ({build_time/60:.1f} minutes)")
except Exception as e:
    build_time = time.time() - build_start
    print(f"\n✗ Algorithm build failed after {build_time:.1f} seconds")
    print(f"Error: {e}")
    raise

# Access the model - new API stack uses modules differently
try:
    # Try new API stack access
    if hasattr(algo, 'learner_group'):
        module = algo.learner_group.get_module()
        print(f"Policy module type: {type(module)}")
    elif hasattr(algo, 'get_policy'):
        # Old API stack
        policy = algo.get_policy()
        print(f"Policy network: {policy.model}")
    else:
        print("Algorithm built, but model access method not found")
        print(f"Algorithm type: {type(algo)}")
        print(f"Available attributes: {[attr for attr in dir(algo) if not attr.startswith('_')]}")
except Exception as e:
    print(f"Note: Could not access model directly: {e}")
    print("Algorithm is ready for training")


In [None]:
# Training loop
training_history = {
    "episode_reward_mean": [],
    "episode_reward_min": [],
    "episode_reward_max": [],
    "episode_len_mean": [],
}

print(f"Starting training for {NUM_TRAIN_ITERATIONS} iterations...")
print("=" * 60)
print("Note: First iteration may take longer to initialize environments")
print("=" * 60)

import time
start_time = time.time()

for i in range(NUM_TRAIN_ITERATIONS):
    iteration_start = time.time()
    
    # Train for one iteration
    try:
        result = algo.train()
    except Exception as e:
        print(f"\nERROR at iteration {i+1}: {e}")
        print("Training failed. Check resource constraints and Ray status.")
        break
    
    # Diagnostic check on first iteration: print all available metrics
    if i == 0:
        print("\nFirst iteration diagnostics:")
        print(f"  Result type: {type(result)}")
        print(f"  Result keys (first 20): {list(result.keys())[:20]}")
        
        # Check for environment runner metrics (new API - note: plural "env_runners")
        if "env_runners" in result:
            env_metrics = result["env_runners"]
            print(f"  Env runners metrics type: {type(env_metrics)}")
            if isinstance(env_metrics, dict):
                print(f"  Env runners keys: {list(env_metrics.keys())[:20]}")
                # Print some values
                for k in list(env_metrics.keys())[:10]:
                    print(f"    {k}: {env_metrics[k]}")
        elif "env_runner" in result:
            env_metrics = result["env_runner"]
            print(f"  Env runner (singular) metrics type: {type(env_metrics)}")
            if isinstance(env_metrics, dict):
                print(f"  Env runner keys: {list(env_metrics.keys())[:15]}")
        
        # Check for learner metrics
        if "learners" in result:
            learner_metrics = result["learners"]
            print(f"  Learners metrics keys: {list(learner_metrics.keys())[:15] if isinstance(learner_metrics, dict) else 'N/A'}")
        
        # Look for episode metrics in different places
        print("\n  Searching for episode/reward metrics:")
        for key in result.keys():
            if "episode" in key.lower() or "reward" in key.lower():
                val = result[key]
                print(f"    {key}: {val}")
        
        # Check env_runners dict for episode metrics
        if "env_runners" in result and isinstance(result["env_runners"], dict):
            print("\n  env_runners metrics:")
            for key in result["env_runners"].keys():
                if "episode" in key.lower() or "reward" in key.lower() or "step" in key.lower():
                    val = result["env_runners"][key]
                    print(f"    {key}: {val}")
            
            # Also check num_env_steps_sampled to see if steps are being collected
            if "num_env_steps_sampled" in result.get("env_runners", {}):
                print(f"\n  Total steps sampled: {result['env_runners']['num_env_steps_sampled']}")
        print()
    
    # Store metrics - new RLlib API uses different key names
    # In env_runners: episode_return_* (not episode_reward_*)
    # Also check agent_episode_return_mean for agent-specific rewards
    env_runners_metrics = result.get("env_runners", {})
    env_runner_metrics = result.get("env_runner", {})  # Also check singular for compatibility
    
    # Try env_runners first (new API), then env_runner, then top-level
    if isinstance(env_runners_metrics, dict) and len(env_runners_metrics) > 0:
        # New API uses "episode_return" not "episode_reward"
        episode_reward_mean = env_runners_metrics.get("episode_return_mean",
                                                      env_runners_metrics.get("episode_reward_mean",
                                                      result.get("episode_return_mean",
                                                                 result.get("episode_reward_mean", 0))))
        episode_reward_min = env_runners_metrics.get("episode_return_min",
                                                     env_runners_metrics.get("episode_reward_min",
                                                     result.get("episode_return_min",
                                                                result.get("episode_reward_min", 0))))
        episode_reward_max = env_runners_metrics.get("episode_return_max",
                                                     env_runners_metrics.get("episode_reward_max",
                                                     result.get("episode_return_max",
                                                                result.get("episode_reward_max", 0))))
        episode_len_mean = env_runners_metrics.get("episode_len_mean",
                                                   result.get("episode_len_mean", 0))
        
        # If we got rewards from agent_episode_return_mean, use that (more accurate)
        agent_returns = env_runners_metrics.get("agent_episode_return_mean", {})
        if isinstance(agent_returns, dict) and len(agent_returns) > 0:
            # Use the first agent's return (typically "default_agent")
            first_agent_return = list(agent_returns.values())[0]
            if first_agent_return != 0 or episode_reward_mean == 0:
                episode_reward_mean = first_agent_return
    elif isinstance(env_runner_metrics, dict) and len(env_runner_metrics) > 0:
        episode_reward_mean = env_runner_metrics.get("episode_return_mean",
                                                     env_runner_metrics.get("episode_reward_mean",
                                                     result.get("episode_return_mean",
                                                                result.get("episode_reward_mean", 0))))
        episode_reward_min = env_runner_metrics.get("episode_return_min",
                                                    env_runner_metrics.get("episode_reward_min",
                                                    result.get("episode_return_min",
                                                               result.get("episode_reward_min", 0))))
        episode_reward_max = env_runner_metrics.get("episode_return_max",
                                                    env_runner_metrics.get("episode_reward_max",
                                                    result.get("episode_return_max",
                                                               result.get("episode_reward_max", 0))))
        episode_len_mean = env_runner_metrics.get("episode_len_mean",
                                                  result.get("episode_len_mean", 0))
    else:
        # Fallback to top-level keys
        episode_reward_mean = result.get("episode_return_mean", result.get("episode_reward_mean", 0))
        episode_reward_min = result.get("episode_return_min", result.get("episode_reward_min", 0))
        episode_reward_max = result.get("episode_return_max", result.get("episode_reward_max", 0))
        episode_len_mean = result.get("episode_len_mean", 0)
    
    training_history["episode_reward_mean"].append(episode_reward_mean)
    training_history["episode_reward_min"].append(episode_reward_min)
    training_history["episode_reward_max"].append(episode_reward_max)
    training_history["episode_len_mean"].append(episode_len_mean)
    
    iteration_time = time.time() - iteration_start
    elapsed_time = time.time() - start_time
    
    # Print progress (always print first iteration, then every 10)
    if (i + 1) % 10 == 0 or i == 0:
        # Use the same values we stored (already extracted from correct location)
        mean_reward = episode_reward_mean
        min_reward = episode_reward_min
        max_reward = episode_reward_max
        mean_len = episode_len_mean
        
        # Get episode count - new API uses "num_episodes" not "episodes_this_iter"
        if isinstance(env_runners_metrics, dict):
            num_episodes = env_runners_metrics.get("num_episodes",
                                                   env_runners_metrics.get("episodes_this_iter",
                                                   result.get("num_episodes",
                                                              result.get("episodes_this_iter", 0))))
        elif isinstance(env_runner_metrics, dict):
            num_episodes = env_runner_metrics.get("num_episodes",
                                                  env_runner_metrics.get("episodes_this_iter",
                                                  result.get("num_episodes",
                                                             result.get("episodes_this_iter", 0))))
        else:
            num_episodes = result.get("num_episodes", result.get("episodes_this_iter", 0))
        
        print(f"Iteration {i+1}/{NUM_TRAIN_ITERATIONS} | "
              f"Mean Reward: {mean_reward:.2f} | "
              f"Min: {min_reward:.2f} | Max: {max_reward:.2f} | "
              f"Episodes: {num_episodes} | "
              f"Mean Len: {mean_len:.1f} | "
              f"Time: {iteration_time:.1f}s")
        
        # Diagnostic: Print all result keys on first iteration
        if i == 0:
            print(f"  Available metrics keys: {[k for k in result.keys() if 'reward' in k.lower() or 'episode' in k.lower()][:10]}")
        
    elif i < 5:  # Also print first 5 iterations to show it's working
        mean_reward = episode_reward_mean
        mean_len = episode_len_mean
        
        # Get episode count - new API uses "num_episodes" not "episodes_this_iter"
        if isinstance(env_runners_metrics, dict):
            num_episodes = env_runners_metrics.get("num_episodes",
                                                   env_runners_metrics.get("episodes_this_iter",
                                                   result.get("num_episodes",
                                                              result.get("episodes_this_iter", 0))))
        elif isinstance(env_runner_metrics, dict):
            num_episodes = env_runner_metrics.get("num_episodes",
                                                  env_runner_metrics.get("episodes_this_iter",
                                                  result.get("num_episodes",
                                                             result.get("episodes_this_iter", 0))))
        else:
            num_episodes = result.get("num_episodes", result.get("episodes_this_iter", 0))
        print(f"Iteration {i+1}/{NUM_TRAIN_ITERATIONS} | "
              f"Mean Reward: {mean_reward:.2f} | "
              f"Episodes: {num_episodes} | "
              f"Mean Len: {mean_len:.1f} | "
              f"Time: {iteration_time:.1f}s")

print("=" * 60)
total_time = time.time() - start_time
print(f"Training complete! Total time: {total_time/60:.1f} minutes")
print(f"Final mean reward: {training_history['episode_reward_mean'][-1] if training_history['episode_reward_mean'] else 'N/A'}")


## Results Visualization


In [None]:
%matplotlib inline

plt.figure(figsize=(12, 5))

# Plot 1: Episode rewards
plt.subplot(1, 2, 1)
plt.plot(training_history["episode_reward_mean"], label="Mean Reward", linewidth=2)
plt.fill_between(
    range(len(training_history["episode_reward_mean"])),
    training_history["episode_reward_min"],
    training_history["episode_reward_max"],
    alpha=0.3,
    label="Min-Max Range"
)
plt.xlabel("Training Iteration")
plt.ylabel("Episode Reward")
plt.title("Training Progress - Episode Rewards")
plt.legend()
plt.grid(True, alpha=0.3)

# Plot 2: Episode length
plt.subplot(1, 2, 2)
plt.plot(training_history["episode_len_mean"], label="Mean Episode Length", color="green", linewidth=2)
plt.xlabel("Training Iteration")
plt.ylabel("Episode Length")
plt.title("Training Progress - Episode Length")
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final statistics
final_mean = training_history["episode_reward_mean"][-1] if training_history["episode_reward_mean"] else 0
best_mean = max(training_history["episode_reward_mean"]) if training_history["episode_reward_mean"] else 0
print(f"\nFinal mean reward: {final_mean:.2f}")
print(f"Best mean reward: {best_mean:.2f}")
if "Pong" in GAME_ID:
    print(f"Pong theoretical max: 21 (winning 21-0)")


## Evaluation

Run evaluation episodes to see how well the trained policy performs.
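Evaluation usually turns exploration off (the manual evaluation code passes `explore=False`). Assuming a categorical policy over discrete actions, that amounts to picking the highest-scoring action instead of sampling. The sketch below contrasts the two with made-up logits; it illustrates the idea and is not RLlib's internal code.

```python
import numpy as np

def greedy_action(logits):
    """Pick the highest-scoring action index (no exploration)."""
    return int(np.argmax(logits))

def sampled_action(logits, rng):
    """Sample an action from softmax(logits), as an exploring policy would."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    return int(rng.choice(len(logits), p=probs))

# Hypothetical scores for the six actions in Pong's action space.
logits = np.array([0.1, 2.0, -1.0, 0.5, 0.0, 0.3])
rng = np.random.default_rng(0)
print(greedy_action(logits), sampled_action(logits, rng))
```

The greedy choice is deterministic for a given observation, which makes evaluation runs easier to compare across checkpoints.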


In [None]:
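# Run evaluation
eval_results = algo.evaluate()

# Metric keys vary by RLlib version (new stack: env_runners/episode_return_*).
metrics = eval_results.get("evaluation", eval_results)
metrics = metrics.get("env_runners", metrics)

def fmt(*keys):  # format the first metric found under `keys`, else "N/A"
    value = next((metrics[k] for k in keys if k in metrics), None)
    return f"{value:.2f}" if value is not None else "N/A"

print("Evaluation Results:")
print(f"  Mean episode reward: {fmt('episode_return_mean', 'episode_reward_mean')}")
print(f"  Min episode reward: {fmt('episode_return_min', 'episode_reward_min')}")
print(f"  Max episode reward: {fmt('episode_return_max', 'episode_reward_max')}")
print(f"  Mean episode length: {fmt('episode_len_mean')}")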


In [None]:
# Manual evaluation: run a few episodes and record video
import imageio
from PIL import Image, ImageDraw

def evaluate_and_record(algo, env, num_episodes=3, save_video=True):
    """Run evaluation episodes and optionally save videos."""
    # Handle both old and new API stacks
    def get_action(obs):
        try:
            # compute_single_action belongs to the old API stack; on the new
            # stack it may be missing or raise, which triggers the fallback below
            if hasattr(algo, 'compute_single_action'):
                result = algo.compute_single_action(obs, explore=False)
                # Result might be the action directly or a tuple
                return result[0] if isinstance(result, (list, tuple)) else result
            # Otherwise go through the policy object (also old API stack)
            elif hasattr(algo, 'get_policy'):
                policy = algo.get_policy()
                return policy.compute_single_action(obs, explore=False)[0]
            else:
                raise AttributeError("No method found to compute actions")
        except Exception as e:
            print(f"Error computing action: {e}")
            # WARNING: falls back to a random action, so rewards reported after
            # this message reflect random play, not the trained policy
            return env.action_space.sample()
    
    all_rewards = []
    
    for ep in range(num_episodes):
        obs, info = env.reset(seed=1000 + ep)
        done = False
        total_reward = 0.0
        frames = []
        
        while not done:
            # Get action from policy (works with both API stacks)
            action = get_action(obs)
            
            # Render frame
            frame = (np.clip(obs, 0, 1) * 255).astype(np.uint8).transpose(1, 2, 0)
            
            # Add action overlay
            img = Image.fromarray(frame)
            draw = ImageDraw.Draw(img)
            txt = f"action={action}, reward={total_reward:.1f}"
            draw.rectangle([0, 0, 200, 20], fill=(0, 0, 0, 200))
            draw.text((5, 3), txt, fill=(255, 255, 255))
            frame = np.array(img)
            frames.append(frame)
            
            # Step environment
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            total_reward += reward
        
        all_rewards.append(total_reward)
        
        # Save video
        if save_video:
            video_path = f"rllib_eval_episode_{ep+1}.mp4"
            with imageio.get_writer(video_path, fps=30, macro_block_size=1) as writer:
                for f in frames:
                    writer.append_data(f)
            print(f"Episode {ep+1}: Reward = {total_reward:.1f}, Saved to {video_path}")
        else:
            print(f"Episode {ep+1}: Reward = {total_reward:.1f}")
    
    print(f"\nEvaluation Summary:")
    print(f"  Mean reward: {np.mean(all_rewards):.2f}")
    print(f"  Std reward: {np.std(all_rewards):.2f}")
    print(f"  Min reward: {np.min(all_rewards):.2f}")
    print(f"  Max reward: {np.max(all_rewards):.2f}")
    
    return all_rewards

# Run evaluation
eval_rewards = evaluate_and_record(algo, env, num_episodes=EVAL_EPISODES, save_video=True)


## Save and Load Model

Save the trained model for later use.


In [None]:
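# Save the trained model
# Use an absolute path: recent RLlib versions can reject relative checkpoint
# paths, and the return type of algo.save() differs across versions
checkpoint_dir = os.path.abspath("./rllib_pong_checkpoint")
algo.save(checkpoint_dir)
print(f"Model saved to: {checkpoint_dir}")

# Example: Load the model later
# from ray.rllib.algorithms.ppo import PPO
# algo_loaded = PPO.from_checkpoint(checkpoint_dir)
# print("Model loaded successfully")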


## Notes

### Google Colab Considerations

- **GPU Runtime**: Enable GPU for faster training (Runtime > Change runtime type > GPU)
- **Resource Limits**: Configuration uses 1 env runner and a reduced training batch size to fit Colab's constraints
- **Memory**: If you encounter OOM errors, reduce `train_batch_size` further
- **Ray Shutdown**: Colab may require restarting the runtime if Ray doesn't shut down cleanly
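To see why `train_batch_size` matters for memory, the rough estimate below computes how many bytes one batch of float32 observations occupies. The helper `batch_obs_bytes` is hypothetical, and the figure covers observation tensors only; model weights, activations, and Ray's object store overhead come on top.

```python
def batch_obs_bytes(train_batch_size, obs_shape=(3, 84, 84), bytes_per_value=4):
    """Bytes needed to hold one batch of float32 observations."""
    values_per_obs = 1
    for dim in obs_shape:
        values_per_obs *= dim
    return train_batch_size * values_per_obs * bytes_per_value

# 1000 is this notebook's setting; 4000 is a larger value for comparison.
for size in (1000, 4000):
    print(size, f"{batch_obs_bytes(size) / 1e6:.1f} MB")
```

Each `(3, 84, 84)` float32 observation takes 84,672 bytes, so a batch of 1000 needs about 84.7 MB before any copies are made.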

### Baseline Comparison

This notebook provides a baseline for comparing against:
- **V-JEPA2 encoder approach**: Uses pre-trained frozen encoder
- **RSSM approach**: Uses learned dynamics model

### Key Differences

- **No encoder**: Policy learns directly from raw visual observations
- **RLlib framework**: Uses well-tested RL algorithms (PPO) with built-in optimizations
- **Standard preprocessing**: 84x84 RGB frames, normalized to [0, 1]
- **Colab-optimized**: Reduced workers and batch sizes for Colab's resource constraints

### Configuration

You can adjust training parameters via environment variables:
- `NUM_TRAIN_ITERATIONS`: Number of training iterations (default: 100)
- `NUM_EVAL_EPISODES`: Number of evaluation episodes (default: 5)
- `ATARI_GAME`: Game environment (default: PongNoFrameskip-v4)
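In a notebook, one way to set these variables is from Python, before the configuration cell runs. The sketch below shows the override and the same parsing pattern the configuration cell uses; the values 20 and 3 are arbitrary examples.

```python
import os

# Override the defaults before the configuration cell runs.
os.environ["NUM_TRAIN_ITERATIONS"] = "20"
os.environ["NUM_EVAL_EPISODES"] = "3"

# The configuration cell reads them like this (values arrive as strings).
num_iterations = int(os.environ.get("NUM_TRAIN_ITERATIONS", "100"))
eval_episodes = int(os.environ.get("NUM_EVAL_EPISODES", "5"))
game_id = os.environ.get("ATARI_GAME", "PongNoFrameskip-v4")
print(num_iterations, eval_episodes, game_id)
```

Re-run the configuration cell after changing a variable, since the values are read only once when that cell executes.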

### Next Steps

- Compare training curves with V-JEPA2 and RSSM approaches
- Experiment with different RLlib algorithms (IMPALA, APPO, DQN, etc.)
- Tune hyperparameters for better performance
- Add frame stacking for temporal information
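A single frame does not show which way the ball is moving, which is why frame stacking helps. The class below is a minimal sketch of the idea for channel-first frames; `FrameStacker` is a hypothetical helper, not part of this notebook or of RLlib. Gymnasium also ships a frame-stacking wrapper, whose name differs across versions.

```python
from collections import deque

import numpy as np

class FrameStacker:
    """Keep the last `k` frames, concatenated along the channel axis."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # Fill the buffer with the first frame so the output shape is fixed.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.stacked()

    def step(self, frame):
        self.frames.append(frame)  # the oldest frame drops out automatically
        return self.stacked()

    def stacked(self):
        return np.concatenate(list(self.frames), axis=0)

stacker = FrameStacker(k=4)
obs = stacker.reset(np.zeros((3, 84, 84), dtype=np.float32))
obs = stacker.step(np.ones((3, 84, 84), dtype=np.float32))
print(obs.shape)
```

With four RGB frames the stacked observation has shape `(12, 84, 84)`, so the declared observation space would need to change accordingly.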


In [None]:
# Cleanup
try:
    env.close()
except Exception:
    # Ignore errors from an environment that is already closed
    pass

# Shutdown Ray (recommended for Colab to free resources)
# Note: In Colab, you may need to restart the runtime if Ray doesn't shut down cleanly
try:
    ray.shutdown()
    print("Ray shut down successfully")
except Exception:
    print("Ray shutdown encountered an issue. You may need to restart the Colab runtime.")
