# 💎 Diamond PPO Demo

This notebook demonstrates the core features of Diamond PPO, a lightweight PyTorch implementation of Proximal Policy Optimization.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/auxeno/diamond-ppo/blob/main/notebooks/demo.ipynb)

## Installation

First, let's install Diamond PPO and its dependencies:

In [None]:
# Install Diamond PPO from GitHub
!pip install -q git+https://github.com/auxeno/diamond-ppo

# Install additional dependencies for visualization
!apt-get install -qq xvfb
!pip install -q pyvirtualdisplay pygame moviepy imageio

import gymnasium as gym
import numpy as np
import torch
from IPython.display import HTML
from base64 import b64encode
import io

# Print the installed PyTorch and Gymnasium versions and whether PyTorch
# can see a CUDA device in this runtime.
print(f"PyTorch version: {torch.__version__}")
print(f"Gymnasium version: {gym.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Basic Usage - Discrete Actions

Let's start with a simple example using CartPole, a classic control task with discrete actions:

In [None]:
from diamond import PPO, PPOConfig

# Hyperparameters for a short CartPole run.
config = PPOConfig(
    # Training budget and data collection
    total_steps=50_000,   # total training steps
    rollout_steps=128,    # steps per rollout
    num_envs=4,           # parallel environments
    # Optimisation and advantage estimation
    lr=3e-4,              # learning rate
    gamma=0.99,           # discount factor
    gae_lambda=0.95,      # GAE lambda
    ppo_clip=0.2,         # PPO clipping parameter
    # Reporting
    verbose=True,         # print training progress
)

# The agent receives a factory, so every call builds a fresh CartPole env.
agent = PPO(env_fn=lambda: gym.make("CartPole-v1"), cfg=config)

print("Training PPO on CartPole...")
agent.train()
print("\nTraining complete!")

## 2. Continuous Control

For continuous action spaces, use `ContinuousPPO`:

In [None]:
from diamond import ContinuousPPO, ContinuousPPOConfig

# Settings for the continuous-action run; anything not listed here is left
# at whatever ContinuousPPOConfig defines as its default.
config = ContinuousPPOConfig(
    total_steps=100_000,
    rollout_steps=256,
    num_envs=4,
    lr=3e-4,
    verbose=True,
)

# Pendulum is the continuous control task; the factory is invoked by the
# agent whenever it needs an environment instance.
agent = ContinuousPPO(env_fn=lambda: gym.make("Pendulum-v1"), cfg=config)

print("Training Continuous PPO on Pendulum...")
agent.train()
print("\nTraining complete!")

## 3. Custom Neural Networks

Diamond PPO supports custom network architectures. Here's an example with a larger network that shares one feature extractor between separate actor and critic heads:

In [None]:
import torch.nn as nn

class CustomNetwork(nn.Module):
    """Actor-critic MLP with a shared trunk and separate policy/value heads.

    The trunk is two fully connected layers with ReLU activations. Its
    output feeds a linear actor head (one output per action) and a linear
    critic head (a single scalar).

    Args:
        obs_dim: Size of the last dimension of the observations fed in.
        action_dim: Number of outputs produced by the actor head.
        hidden_dim: Width of both hidden layers. Defaults to 128.
    """

    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        # Shared feature extractor. Layer order and attribute names are part
        # of the state_dict layout, so they are kept stable.
        self.base = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Separate heads for actor and critic
        self.actor_head = nn.Linear(hidden_dim, action_dim)
        self.critic_head = nn.Linear(hidden_dim, 1)

    def actor(self, x: torch.Tensor) -> torch.Tensor:
        """Return action logits with ``action_dim`` as the last dimension."""
        return self.actor_head(self.base(x))

    def critic(self, x: torch.Tensor) -> torch.Tensor:
        """Return value estimates with the trailing size-1 dimension removed."""
        return self.critic_head(self.base(x)).squeeze(-1)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, values)`` from a single pass through the trunk.

        Equivalent to ``(self.actor(x), self.critic(x))`` but evaluates the
        shared feature extractor only once.
        """
        features = self.base(x)
        logits = self.actor_head(features)
        values = self.critic_head(features).squeeze(-1)
        return logits, values

# Probe a throwaway CartPole instance for the input/output sizes the
# network needs, then release it straight away.
env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]   # length of the observation vector
action_dim = env.action_space.n            # number of discrete actions
env.close()

custom_net = CustomNetwork(obs_dim, action_dim)

# Shorter run than the basic example; unspecified options keep their defaults.
config = PPOConfig(total_steps=30_000, rollout_steps=128, num_envs=4, verbose=True)

# Hand the pre-built network to the agent alongside the usual env factory.
agent = PPO(
    env_fn=lambda: gym.make("CartPole-v1"),
    cfg=config,
    custom_network=custom_net,
)

print("Training with custom network architecture...")
agent.train()
print("\nTraining complete!")

## 4. Training Utilities

Diamond PPO includes helpful utilities for monitoring training:

In [None]:
from diamond.utils import Logger, Timer

# Feed the Logger synthetic metrics so its plotting can be shown without
# running a real training job.
logger = Logger()

# Fixed seed: the fake curves come out the same on every run.
np.random.seed(42)
for step in range(100):
    # Noise is drawn reward-first, then loss, once per step.
    reward_noise = np.random.randn() * 10
    loss_noise = np.random.randn() * 0.01

    # Upward-trending reward and a decaying loss, each with some jitter.
    reward = 100 + step * 2 + reward_noise
    loss = 1.0 / (1 + step * 0.1) + loss_noise

    logger.log("episode_reward", step, reward)
    logger.log("policy_loss", step, loss)

# Plot each logged metric under its own heading.
for heading, metric in (
    ("Episode Rewards:", "episode_reward"),
    ("\nPolicy Loss:", "policy_loss"),
):
    print(heading)
    logger.plot(metric)

In [None]:
# Profile a mock training loop with the Timer utility
from diamond.utils import Timer
import time

timer = Timer()

# (timer label, seconds slept) for each simulated stage of one iteration
simulated_stages = (
    ("environment_step", 0.01),   # stands in for env.step()
    ("network_forward", 0.005),   # stands in for a network forward pass
    ("optimization", 0.008),      # stands in for an optimization step
)

for i in range(5):
    # Stages run in the listed order, each inside its own timed section.
    for stage_name, sleep_seconds in simulated_stages:
        with timer.time(stage_name):
            time.sleep(sleep_seconds)

# Display timing statistics
timer.plot_timings()

## 5. Evaluation and Visualization

Let's evaluate a trained agent and visualize its performance:

In [None]:
def evaluate_agent(agent, env_name, num_episodes=5, render=False):
    """Run evaluation episodes in a fresh environment and report the returns.

    When ``agent`` has a non-None ``network`` attribute, the action at each
    step is the argmax over the first element of ``agent.network(obs)``.
    That path assumes a discrete-action policy whose call returns a
    ``(logits, value)`` pair. Otherwise (for example ``agent=None``) actions
    are sampled from the environment's action space.

    Args:
        agent: Object optionally exposing a callable ``network``; may be None.
        env_name: Environment id passed to ``gym.make``.
        num_episodes: Number of episodes to run.
        render: If True, the environment is created with
            ``render_mode="rgb_array"`` and one frame per step of the first
            episode is collected.

    Returns:
        A ``(episode_rewards, frames)`` tuple: the total reward of each
        episode, and the collected frames (empty unless ``render`` is True).
    """
    network = getattr(agent, "network", None)

    # Build observations on the same device as the network's parameters so a
    # GPU-resident policy does not fail with a device mismatch. Networks
    # without parameters (or no network at all) use the default device.
    parameters = getattr(network, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = getattr(first_param, "device", None)

    env = gym.make(env_name, render_mode="rgb_array" if render else None)

    episode_rewards = []
    frames = []

    try:
        for episode in range(num_episodes):
            obs, _ = env.reset()
            episode_reward = 0
            done = False

            while not done:
                if network is None:
                    # No policy available: use random actions
                    action = env.action_space.sample()
                else:
                    # Greedy action from the network, without tracking gradients
                    with torch.no_grad():
                        obs_tensor = torch.as_tensor(
                            obs, dtype=torch.float32, device=device
                        ).unsqueeze(0)
                        logits, _ = network(obs_tensor)
                        action = torch.argmax(logits, dim=-1).item()

                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                episode_reward += reward

                if render and episode == 0:  # Only record first episode
                    frames.append(env.render())

            episode_rewards.append(episode_reward)
            print(f"Episode {episode + 1}: Reward = {episode_reward:.2f}")
    finally:
        # Release the environment even if a step or the network call raised.
        env.close()

    # With zero episodes there is nothing to average; skipping the summary
    # avoids NumPy's empty-slice warnings and a "nan ± nan" line.
    if episode_rewards:
        print(f"\nAverage Reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}")

    return episode_rewards, frames

# Smoke-test the evaluation loop: with no agent supplied, evaluate_agent
# falls back to randomly sampled actions.
print("Evaluating agent on CartPole...")
rewards, _ = evaluate_agent(
    agent=None,
    env_name="CartPole-v1",
    num_episodes=5,
    render=False,
)