# Lab D.4: Policy Gradients and PPO

**Module:** D - Reinforcement Learning (Optional)
**Time:** 2.5-3 hours
**Difficulty:** ⭐⭐⭐⭐⭐

---

## 🎯 Learning Objectives

By the end of this notebook, you will:
- [ ] Understand why policy gradients work differently from value-based methods
- [ ] Implement the REINFORCE algorithm from scratch
- [ ] Build an Actor-Critic architecture
- [ ] Implement Proximal Policy Optimization (PPO)
- [ ] Understand why PPO is used for RLHF in LLMs

---

## 📚 Prerequisites

- Completed: Lab D.3 (Deep Q-Networks)
- Knowledge of: PyTorch, neural networks, gradients
- Understanding of: Probability distributions, log probabilities

---

## 🌍 Real-World Context

**Why Policy Gradients for LLMs?**

DQN outputs Q-values for each action. But for LLMs:
- Action space = entire vocabulary (~50,000+ tokens!)
- Computing Q(s, a) for every token is expensive
- We already have a policy: the LLM outputs token probabilities!

**Solution: Directly optimize the policy (the LLM itself)**

This is exactly what RLHF does:
- The LLM is the policy π(a|s)
- PPO adjusts the LLM to generate better responses
- Assistants such as ChatGPT and Claude have been trained with RLHF-style methods

---

## 🧒 ELI5: Policy Gradients vs Value-Based Methods

> **Value-based (DQN)**: "Let me figure out how good each action is, then pick the best one."
> - Like a chess player calculating the value of each possible move
> - Works well when you can evaluate all actions
>
> **Policy-based**: "Let me directly learn which actions to take in each situation."
> - Like a tennis player developing muscle memory
> - Don't evaluate every option, just learn good habits
>
> **The key insight of policy gradients:**
> - If an action led to high reward, make it MORE likely
> - If an action led to low reward, make it LESS likely
> - Adjust probabilities proportionally to how good the outcome was
>
> **In AI terms:** Policy gradient methods directly optimize the policy parameters by computing the gradient of expected reward with respect to those parameters.
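
To make the "more likely / less likely" rule concrete, here is a minimal sketch on a made-up 2-armed bandit. The payouts, learning rate, and function names are assumptions chosen for illustration, not part of the lab's agents. Each step samples an arm from a softmax policy and nudges the logits along the gradient of the log probability, scaled by the reward received.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def train_bandit(n_steps=2000, lr=0.1, seed=0):
    """Policy-gradient updates on a hypothetical 2-armed bandit."""
    rng = random.Random(seed)
    arm_rewards = [0.2, 1.0]      # assumed payouts: arm 1 is better
    theta = [0.0, 0.0]            # one logit per arm
    for _ in range(n_steps):
        probs = softmax(theta)
        action = 0 if rng.random() < probs[0] else 1
        reward = arm_rewards[action]
        # d/d theta_k of log pi(action) = 1[k == action] - probs[k]
        for k in range(2):
            grad_log_pi = (1.0 if k == action else 0.0) - probs[k]
            theta[k] += lr * reward * grad_log_pi
    return softmax(theta)

final_probs = train_bandit()
print(f"P(arm 0) = {final_probs[0]:.3f}, P(arm 1) = {final_probs[1]:.3f}")
```

The policy never estimates how good each arm is. It only shifts probability toward whatever paid off, and the better arm ends up dominating.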

---

## Part 1: Setup

In [None]:
# Setup - run this first!
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional
from collections import deque
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

# Gymnasium
try:
    import gymnasium as gym
except ImportError:
    # LunarLander needs the Box2D extra, so install it along with Gymnasium
    !pip install "gymnasium[box2d]" -q
    import gymnasium as gym

# Set seeds
np.random.seed(42)
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("🚀 Module D.4: Policy Gradients and PPO")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
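# We'll use CartPole again, plus LunarLander for a harder challenge.
# Note: LunarLander needs Box2D (pip install "gymnasium[box2d]"), and
# Gymnasium 1.0+ registers it as "LunarLander-v3" instead of "-v2".
env_cartpole = gym.make("CartPole-v1")
env_lunar = gym.make("LunarLander-v2")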

print("🎮 Environments:")
print(f"\nCartPole-v1:")
print(f"   State: {env_cartpole.observation_space.shape}")
print(f"   Actions: {env_cartpole.action_space.n}")

print(f"\nLunarLander-v2:")
print(f"   State: {env_lunar.observation_space.shape}")
print(f"   Actions: {env_lunar.action_space.n}")
print("   (0=nothing, 1=left engine, 2=main engine, 3=right engine)")

---

## Part 2: The Policy Gradient Theorem

### The Math Behind Policy Gradients

Our goal: Maximize expected return $J(\theta) = \mathbb{E}_{\pi_\theta}[R]$

The **Policy Gradient Theorem** tells us:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot Q^{\pi}(s, a) \right]$$

In plain English:
1. Sample actions from your policy
2. Compute how much each action contributed to the outcome
3. Increase probability of good actions, decrease for bad ones

### The Log Probability Trick

Why $\log \pi$ instead of $\pi$?

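> 🧒 **ELI5**: If you got an A on a test, you should study MORE like you did. If you got an F, study LESS like that. The log trick lets us turn "multiply probability" into "add to log probability", which works better with gradient descent. Formally, because $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$, the gradient of the expected return becomes an expectation over actions sampled from the policy, so we can estimate it from the episodes we actually ran.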
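
As a sanity check on the log-probability form of the gradient, the sketch below compares it with a finite-difference gradient of the expected reward for a small softmax policy over three actions. The logits, rewards, and helper names are arbitrary values picked for illustration. Both gradients are computed exactly by enumerating the actions, so no sampling noise is involved.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_reward(theta, rewards):
    """J(theta) = sum_a pi(a) * R(a) for a softmax policy."""
    return sum(p * r for p, r in zip(softmax(theta), rewards))

def score_function_gradient(theta, rewards):
    """E_a[ grad log pi(a) * R(a) ], computed exactly by enumerating actions."""
    probs = softmax(theta)
    grad = [0.0] * len(theta)
    for a, (p_a, r_a) in enumerate(zip(probs, rewards)):
        for k in range(len(theta)):
            grad_log_pi = (1.0 if k == a else 0.0) - probs[k]
            grad[k] += p_a * grad_log_pi * r_a
    return grad

def finite_difference_gradient(theta, rewards, h=1e-6):
    """Central-difference approximation of dJ/dtheta."""
    grad = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += h
        down[k] -= h
        grad.append((expected_reward(up, rewards)
                     - expected_reward(down, rewards)) / (2 * h))
    return grad

theta = [0.5, -0.3, 0.1]       # arbitrary logits for illustration
rewards = [1.0, 0.0, 2.0]      # arbitrary per-action rewards
pg = score_function_gradient(theta, rewards)
fd = finite_difference_gradient(theta, rewards)
print("score-function gradient:", [round(g, 6) for g in pg])
print("finite differences:     ", [round(g, 6) for g in fd])
```

The two gradients agree up to numerical error. In practice the expectation is replaced by an average over sampled actions, which is where the variance discussed later comes from.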

---

## Part 3: REINFORCE Algorithm

REINFORCE is the simplest policy gradient algorithm:

1. Run a complete episode, collecting (state, action, reward)
2. Compute returns (total future reward) for each timestep
3. Update: $\theta \leftarrow \theta + \alpha \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t$

Where $G_t = \sum_{k=t}^T \gamma^{k-t} r_k$ is the return from timestep $t$.
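
The return can be computed in one backward pass using the recursion $G_t = r_t + \gamma G_{t+1}$. The sketch below is a standalone version with a small worked example. The function name is illustrative, and it fills a preallocated list rather than inserting at the front, which avoids quadratic cost on long episodes.

```python
def discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_T] with G_t = r_t + gamma * G_{t+1}."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one after it
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

example = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
print(example)  # [1.75, 1.5, 1.0]
```

With three rewards of 1 and $\gamma = 0.5$, the last step is worth 1, the middle step 1 + 0.5 = 1.5, and the first step 1 + 0.5 × 1.5 = 1.75.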

In [None]:
class PolicyNetwork(nn.Module):
    """
    Simple policy network that outputs action probabilities.
    
    Unlike Q-network (outputs values), this outputs a probability distribution!
    """
    
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)  # Output probabilities!
        )
    
    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Returns action probabilities."""
        return self.network(state)
    
    def get_action(self, state: torch.Tensor) -> Tuple[int, torch.Tensor]:
        """
        Sample an action from the policy.
        
        Returns:
            action: The sampled action
            log_prob: Log probability of that action (for gradient computation)
        """
        probs = self.forward(state)
        dist = Categorical(probs)  # Create categorical distribution
        action = dist.sample()     # Sample an action
        log_prob = dist.log_prob(action)  # Get log probability
        return action.item(), log_prob

# Test it
policy = PolicyNetwork(4, 2).to(device)
state = torch.randn(1, 4).to(device)
probs = policy(state)
print(f"Action probabilities: {probs.detach().cpu().numpy()[0]}")

action, log_prob = policy.get_action(state)
print(f"Sampled action: {action}, log_prob: {log_prob.item():.4f}")

In [None]:
class REINFORCEAgent:
    """
    REINFORCE: Monte Carlo Policy Gradient.
    
    The simplest policy gradient algorithm.
    Updates after each complete episode using total returns.
    """
    
    def __init__(self, state_dim: int, action_dim: int, 
                 lr: float = 1e-3, gamma: float = 0.99):
        self.gamma = gamma
        
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        
        # Episode storage
        self.log_probs = []
        self.rewards = []
        
        print(f"🤖 REINFORCE Agent Initialized")
        print(f"   Parameters: {sum(p.numel() for p in self.policy.parameters()):,}")
    
    def select_action(self, state: np.ndarray) -> int:
        """Select action and store log probability."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        action, log_prob = self.policy.get_action(state_tensor)
        self.log_probs.append(log_prob)
        return action
    
    def store_reward(self, reward: float):
        """Store reward for current timestep."""
        self.rewards.append(reward)
    
    def compute_returns(self) -> torch.Tensor:
        """
        Compute discounted returns for each timestep.
        
        G_t = r_t + γ*r_{t+1} + γ²*r_{t+2} + ...
        """
        returns = []
        G = 0
        
        # Work backwards from end of episode
        for reward in reversed(self.rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        
        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        
        # Normalize returns (reduces variance!)
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        return returns
    
    def update(self) -> float:
        """
        Update policy using REINFORCE.
        
        Loss = -sum(log_prob * return)  (negative because we maximize)
        """
        returns = self.compute_returns()
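        # Each stored log_prob has shape (1,), so stacking gives (T, 1).
        # Squeeze to (T,) so it multiplies element-wise with returns
        # instead of broadcasting to a (T, T) matrix.
        log_probs = torch.stack(self.log_probs).squeeze(-1)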
        
        # Policy gradient loss
        # We want to MAXIMIZE log_prob * return, so we MINIMIZE negative
        loss = -(log_probs * returns).sum()
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Clear episode storage
        self.log_probs = []
        self.rewards = []
        
        return loss.item()

# Create agent
reinforce_agent = REINFORCEAgent(
    state_dim=env_cartpole.observation_space.shape[0],
    action_dim=env_cartpole.action_space.n,
    lr=1e-3,
    gamma=0.99
)

In [None]:
def train_reinforce(env, agent, n_episodes: int = 1000, 
                    print_freq: int = 100) -> List[float]:
    """
    Train REINFORCE agent.
    """
    episode_rewards = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        
        # Run episode
        while True:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            agent.store_reward(reward)
            total_reward += reward
            state = next_state
            
            if done:
                break
        
        # Update at end of episode
        agent.update()
        episode_rewards.append(total_reward)
        
        # Print progress
        if (episode + 1) % print_freq == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {episode + 1:4d} | Avg Reward (last 100): {avg_reward:.1f}")
    
    return episode_rewards

# Train!
print("\n🏋️ Training REINFORCE on CartPole...\n")
reinforce_rewards = train_reinforce(env_cartpole, reinforce_agent, n_episodes=1000)

print(f"\n✅ Final average (last 100): {np.mean(reinforce_rewards[-100:]):.1f}")

In [None]:
# Visualize training
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(reinforce_rewards, alpha=0.3)
window = 50
smoothed = np.convolve(reinforce_rewards, np.ones(window)/window, mode='valid')
plt.plot(range(window-1, len(reinforce_rewards)), smoothed, label=f'{window}-ep avg')
plt.axhline(y=475, color='g', linestyle='--', label='Goal')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('REINFORCE Training')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
# Show variance in rewards
chunk_size = 50
n_chunks = len(reinforce_rewards) // chunk_size
variances = [np.var(reinforce_rewards[i*chunk_size:(i+1)*chunk_size]) 
             for i in range(n_chunks)]
plt.plot([i*chunk_size for i in range(n_chunks)], variances)
plt.xlabel('Episode')
plt.ylabel('Reward Variance')
plt.title('REINFORCE has HIGH VARIANCE')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### The Problem with REINFORCE: High Variance

Notice the jagged learning curve! REINFORCE has **high variance** because:
1. Uses Monte Carlo returns (waits until end of episode)
2. Each episode has different outcomes due to randomness
3. Credit assignment is noisy: which action was actually responsible?

**Solution: Actor-Critic methods!**
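
To see why a baseline helps, consider a made-up 2-armed bandit where both arms pay well (10 and 11) and the policy picks each with probability 0.5. The sketch below computes the exact mean and variance of the one-sample gradient estimate for the logit of arm 1, with and without subtracting a baseline. The numbers and the function name are assumptions chosen so the arithmetic is easy to follow.

```python
def estimator_stats(probs, rewards, baseline):
    """Exact mean and variance of the one-sample gradient estimate
    for the logit of arm 1 in a 2-armed softmax bandit."""
    samples = []
    for action, (p_a, r_a) in enumerate(zip(probs, rewards)):
        # d/d theta_1 of log pi(action) = 1[action == 1] - probs[1]
        grad_log_pi = (1.0 if action == 1 else 0.0) - probs[1]
        samples.append((p_a, grad_log_pi * (r_a - baseline)))
    mean = sum(p * g for p, g in samples)
    var = sum(p * (g - mean) ** 2 for p, g in samples)
    return mean, var

probs = [0.5, 0.5]
rewards = [10.0, 11.0]                     # hypothetical payouts

mean_raw, var_raw = estimator_stats(probs, rewards, baseline=0.0)
mean_base, var_base = estimator_stats(probs, rewards, baseline=10.5)

print(f"no baseline:   mean={mean_raw:.4f}, variance={var_raw:.4f}")
print(f"with baseline: mean={mean_base:.4f}, variance={var_base:.4f}")
```

Both estimators have the same mean, so the baseline does not bias the gradient, but the variance collapses once the rewards are measured relative to their average. A learned value function plays the role of that baseline when the average depends on the state.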

---

## Part 4: Actor-Critic

The key insight: Use a **learned value function** as a baseline to reduce variance.

- **Actor**: The policy π(a|s) - decides what to do
- **Critic**: The value function V(s) - evaluates how good states are

Instead of using raw returns $G_t$, we use the **advantage**:

$$A(s, a) = Q(s, a) - V(s) \approx r + \gamma V(s') - V(s)$$

> 🧒 **ELI5**: Instead of asking "Was that action good?", we ask "Was that action better than average?" This is less noisy because we compare to a baseline.
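
The one-step approximation $r + \gamma V(s') - V(s)$ is easy to work through by hand. The sketch below uses made-up critic values; the function name and the `done` flag (which drops the bootstrap term at the end of an episode) are illustrative choices.

```python
def td_advantage(reward, value_s, value_next, gamma=0.99, done=False):
    """One-step advantage estimate: r + gamma * V(s') - V(s)."""
    # No future value to bootstrap from once the episode has ended
    bootstrap = 0.0 if done else gamma * value_next
    return reward + bootstrap - value_s

# The critic expected 10 from this state
adv_good = td_advantage(reward=1.0, value_s=10.0, value_next=12.0, gamma=0.9)
adv_bad = td_advantage(reward=1.0, value_s=10.0, value_next=5.0, gamma=0.9)
adv_terminal = td_advantage(1.0, 10.0, 12.0, gamma=0.9, done=True)

print(f"landed in a better state: {adv_good:+.2f}")
print(f"landed in a worse state:  {adv_bad:+.2f}")
print(f"episode ended early:      {adv_terminal:+.2f}")
```

The same reward of 1 yields a positive advantage when the next state is better than expected and a negative one when it is worse, which is the "better than average?" question in numeric form.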

In [None]:
class ActorCriticNetwork(nn.Module):
    """
    Combined Actor-Critic Network.
    
    Shared feature extraction with separate heads for:
    - Actor (policy): outputs action probabilities
    - Critic (value): outputs state value V(s)
    """
    
    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        
        # Actor head (policy)
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # Critic head (value function)
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            action_probs: Probability distribution over actions
            value: State value V(s)
        """
        features = self.shared(state)
        action_probs = self.actor(features)
        value = self.critic(features)
        return action_probs, value
    
    def get_action_and_value(self, state: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Sample action and get log prob + value."""
        probs, value = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob, value

# Test it
ac_net = ActorCriticNetwork(4, 2).to(device)
state = torch.randn(1, 4).to(device)
probs, value = ac_net(state)
print(f"Action probs: {probs.detach().cpu().numpy()[0]}")
print(f"State value: {value.item():.4f}")

In [None]:
class ActorCriticAgent:
    """
    Advantage Actor-Critic (A2C).
    
    Uses the advantage function to reduce variance in policy gradients.
    """
    
    def __init__(self, state_dim: int, action_dim: int,
                 lr: float = 3e-4, gamma: float = 0.99,
                 value_coef: float = 0.5, entropy_coef: float = 0.01):
        self.gamma = gamma
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        
        self.network = ActorCriticNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        
        # Episode storage
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.entropies = []
        
        print(f"🤖 Actor-Critic Agent Initialized")
        print(f"   Parameters: {sum(p.numel() for p in self.network.parameters()):,}")
    
    def select_action(self, state: np.ndarray) -> int:
        """Select action, store log prob, value, and entropy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        
        probs, value = self.network(state_tensor)
        dist = Categorical(probs)
        action = dist.sample()
        
        self.log_probs.append(dist.log_prob(action))
        self.values.append(value.squeeze())
        self.entropies.append(dist.entropy())
        
        return action.item()
    
    def store_reward(self, reward: float):
        self.rewards.append(reward)
    
    def update(self) -> Dict[str, float]:
        """
        Update using advantage.
        
        Advantage = actual return - predicted value
        """
        # Compute returns
        returns = []
        G = 0
        for reward in reversed(self.rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        
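        # Stack tensors. log_probs and entropies are stored with shape (1,),
        # so squeeze to (T,) to match returns and values; otherwise
        # log_probs * advantages would broadcast to a (T, T) matrix.
        log_probs = torch.stack(self.log_probs).squeeze(-1)
        values = torch.stack(self.values)
        entropies = torch.stack(self.entropies).squeeze(-1)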
        
        # Compute advantages
        advantages = returns - values.detach()
        
        # Normalize advantages (reduces variance)
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Policy loss: -log_prob * advantage
        policy_loss = -(log_probs * advantages).mean()
        
        # Value loss: MSE between predicted and actual returns
        value_loss = F.mse_loss(values, returns)
        
        # Entropy bonus (encourages exploration)
        entropy_loss = -entropies.mean()
        
        # Total loss
        loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
        
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
        self.optimizer.step()
        
        # Clear storage
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.entropies = []
        
        return {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': -entropy_loss.item()
        }

# Create and train
ac_agent = ActorCriticAgent(
    state_dim=env_cartpole.observation_space.shape[0],
    action_dim=env_cartpole.action_space.n
)
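
The entropy bonus in the agent above rewards the policy for keeping its action distribution spread out. As a rough illustration of the quantity involved, the sketch below computes the entropy of a few hand-picked distributions in plain Python. The helper name is hypothetical, and it mirrors what `Categorical.entropy()` reports for a single distribution.

```python
import math

def categorical_entropy(probs):
    """Shannon entropy in nats, H = -sum p * log p (0 * log 0 treated as 0)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

uniform = categorical_entropy([0.5, 0.5])        # maximally uncertain
peaked = categorical_entropy([0.99, 0.01])       # almost always one action
deterministic = categorical_entropy([1.0, 0.0])  # no exploration left

print(f"uniform:       {uniform:.4f}")
print(f"peaked:        {peaked:.4f}")
print(f"deterministic: {deterministic:.4f}")
```

Because the loss subtracts a small multiple of the entropy, a policy that collapses onto one action too early pays a penalty, which keeps some exploration alive during training.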

In [None]:
def train_actor_critic(env, agent, n_episodes: int = 1000,
                       print_freq: int = 100) -> List[float]:
    """Train Actor-Critic agent."""
    episode_rewards = []
    
    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        
        while True:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            agent.store_reward(reward)
            total_reward += reward
            state = next_state
            
            if done:
                break
        
        agent.update()
        episode_rewards.append(total_reward)
        
        if (episode + 1) % print_freq == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {episode + 1:4d} | Avg Reward: {avg_reward:.1f}")
    
    return episode_rewards

print("\n🏋️ Training Actor-Critic on CartPole...\n")
ac_rewards = train_actor_critic(env_cartpole, ac_agent, n_episodes=1000)

print(f"\n✅ Final average: {np.mean(ac_rewards[-100:]):.1f}")

In [None]:
# Compare REINFORCE vs Actor-Critic
plt.figure(figsize=(12, 5))

window = 50

# Smooth both
reinforce_smooth = np.convolve(reinforce_rewards, np.ones(window)/window, mode='valid')
ac_smooth = np.convolve(ac_rewards, np.ones(window)/window, mode='valid')

plt.plot(range(window-1, len(reinforce_rewards)), reinforce_smooth, 
         label='REINFORCE', alpha=0.8)
plt.plot(range(window-1, len(ac_rewards)), ac_smooth, 
         label='Actor-Critic', alpha=0.8)

plt.axhline(y=475, color='g', linestyle='--', label='Goal')
plt.xlabel('Episode')
plt.ylabel('Reward (smoothed)')
plt.title('REINFORCE vs Actor-Critic on CartPole')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("\n📊 Comparison:")
print(f"   REINFORCE final avg:     {np.mean(reinforce_rewards[-100:]):.1f}")
print(f"   Actor-Critic final avg:  {np.mean(ac_rewards[-100:]):.1f}")

---

## Part 5: Proximal Policy Optimization (PPO)

### The Stability Problem

Policy gradient updates can be unstable:
- Large updates might destroy the policy
- Small updates might be too slow

**PPO's Solution: Clipped Objective**

Instead of directly maximizing $r(\theta) \cdot A$ where $r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)}$,

PPO maximizes:

$$L^{CLIP}(\theta) = \mathbb{E}\left[ \min\left( r(\theta) A, \text{clip}(r(\theta), 1-\epsilon, 1+\epsilon) A \right) \right]$$

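> 🧒 **ELI5**: "Don't change too much at once!" PPO puts guardrails on how much the policy can change in one update. Once the probability ratio moves outside $[1-\epsilon, 1+\epsilon]$ in the direction that would improve the objective, the clipped term stops rewarding further change, so the gradient for that sample becomes zero.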
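
The clipped objective is easiest to understand by evaluating it on the four sign combinations of "ratio too high or too low" and "advantage positive or negative". The sketch below does this for single samples in plain Python, with $\epsilon = 0.2$; the function name is illustrative.

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-sample PPO objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)

cases = [
    (1.5, +1.0),   # good action already pushed up a lot: gain capped at 1.2
    (0.5, +1.0),   # good action made less likely: not clipped, value 0.5
    (1.5, -1.0),   # bad action made more likely: not clipped, value -1.5
    (0.5, -1.0),   # bad action already pushed down a lot: capped at -0.8
]

results = [clipped_surrogate(r, a) for r, a in cases]
for (r, a), value in zip(cases, results):
    print(f"ratio={r:.1f}, advantage={a:+.1f} -> objective={value:+.2f}")
```

Clipping only applies when the policy has already moved in the helpful direction. When the policy has moved the wrong way (the two unclipped cases), the full penalty is kept so the update can still correct it.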

### Why PPO for LLMs?

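1. **Stable**: The clipped objective limits how far a single update can move the policy, which reduces the risk of one bad batch wrecking the model
2. **More sample efficient than vanilla policy gradients**: Each rollout is reused for several epochs of mini-batch updates
3. **Works with large models**: It has been applied to models with billions of parameters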

In [None]:
class PPOAgent:
    """
    Proximal Policy Optimization (PPO).
    
    A widely used algorithm for RLHF on large language models.
    
    Key features:
    - Clipped objective for stable updates
    - Multiple epochs per rollout (sample efficient)
    - Generalized Advantage Estimation (GAE)
    """
    
    def __init__(self, state_dim: int, action_dim: int,
                 lr: float = 3e-4,
                 gamma: float = 0.99,
                 gae_lambda: float = 0.95,
                 clip_epsilon: float = 0.2,
                 value_coef: float = 0.5,
                 entropy_coef: float = 0.01,
                 n_epochs: int = 10,
                 batch_size: int = 64):
        
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        
        self.network = ActorCriticNetwork(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        
        # Rollout storage
        self.states = []
        self.actions = []
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        
        print(f"🤖 PPO Agent Initialized")
        print(f"   Clip epsilon: {clip_epsilon}")
        print(f"   Epochs per update: {n_epochs}")
        print(f"   GAE lambda: {gae_lambda}")
    
    def select_action(self, state: np.ndarray) -> int:
        """Select action and store transition data."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        
        with torch.no_grad():
            probs, value = self.network(state_tensor)
        
        dist = Categorical(probs)
        action = dist.sample()
        
        self.states.append(state)
        self.actions.append(action.item())
        self.log_probs.append(dist.log_prob(action).item())
        self.values.append(value.item())
        
        return action.item()
    
    def store_transition(self, reward: float, done: bool):
        """Store reward and done flag."""
        self.rewards.append(reward)
        self.dones.append(done)
    
    def compute_gae(self, next_value: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute Generalized Advantage Estimation.
        
        GAE smoothly interpolates between:
        - TD(0): low variance, high bias (λ=0)
        - Monte Carlo: high variance, low bias (λ=1)
        """
        values = self.values + [next_value]
        advantages = []
        gae = 0
        
        # Work backwards
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + self.gamma * values[t+1] * (1 - self.dones[t]) - values[t]
            gae = delta + self.gamma * self.gae_lambda * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
        
        advantages = np.array(advantages)
        returns = advantages + np.array(self.values)
        
        return advantages, returns
    
    def update(self, next_state: np.ndarray) -> Dict[str, float]:
        """
        PPO update with clipped objective.
        
        This is the heart of PPO!
        """
        # Get final value for GAE computation
        with torch.no_grad():
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(device)
            _, next_value = self.network(next_state_tensor)
            next_value = next_value.item()
        
        # Compute advantages
        advantages, returns = self.compute_gae(next_value)
        
        # Convert to tensors
        states = torch.FloatTensor(np.array(self.states)).to(device)
        actions = torch.LongTensor(self.actions).to(device)
        old_log_probs = torch.FloatTensor(self.log_probs).to(device)
        advantages = torch.FloatTensor(advantages).to(device)
        returns = torch.FloatTensor(returns).to(device)
        
        # Normalize advantages (skip for a single sample, where std is NaN)
        if len(advantages) > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # PPO update: multiple epochs over same data
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        n_updates = 0
        
        for epoch in range(self.n_epochs):
            # Mini-batch updates
            indices = np.random.permutation(len(states))
            
            for start in range(0, len(states), self.batch_size):
                end = start + self.batch_size
                batch_idx = indices[start:end]
                
                batch_states = states[batch_idx]
                batch_actions = actions[batch_idx]
                batch_old_log_probs = old_log_probs[batch_idx]
                batch_advantages = advantages[batch_idx]
                batch_returns = returns[batch_idx]
                
                # Get current policy outputs
                probs, values = self.network(batch_states)
                dist = Categorical(probs)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()
                
                # Probability ratio
                ratio = (new_log_probs - batch_old_log_probs).exp()
                
                # Clipped surrogate objective
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 
                                    1 - self.clip_epsilon, 
                                    1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # Value loss (squeeze only the last dim so a batch of 1 keeps shape (1,))
                value_loss = F.mse_loss(values.squeeze(-1), batch_returns)
                
                # Total loss
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
                
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                self.optimizer.step()
                
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                n_updates += 1
        
        # Clear rollout storage
        self.states = []
        self.actions = []
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        
        return {
            'policy_loss': total_policy_loss / n_updates,
            'value_loss': total_value_loss / n_updates,
            'entropy': total_entropy / n_updates
        }

# Create PPO agent
ppo_agent = PPOAgent(
    state_dim=env_cartpole.observation_space.shape[0],
    action_dim=env_cartpole.action_space.n,
    lr=3e-4,
    clip_epsilon=0.2,
    n_epochs=10
)
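
The two ends of the GAE interpolation used by the agent above can be checked by hand on a tiny rollout. The sketch below reimplements the same recursion on plain Python lists, with made-up rewards and critic values, and evaluates it at λ=0 and λ=1. The standalone function and its argument names are illustrative and not part of `PPOAgent`.

```python
def compute_gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one rollout (plain lists)."""
    extended = list(values) + [next_value]
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # One-step TD error at time t
        delta = rewards[t] + gamma * extended[t + 1] * not_done - extended[t]
        # Exponentially weighted sum of future TD errors
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    return advantages

rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.4, 0.3]        # made-up critic outputs
dones = [False, False, True]    # episode ends on the last step

td_only = compute_gae(rewards, values, 0.0, dones, gamma=0.9, lam=0.0)
monte_carlo = compute_gae(rewards, values, 0.0, dones, gamma=0.9, lam=1.0)

print("lambda=0:", [round(a, 2) for a in td_only])
print("lambda=1:", [round(a, 2) for a in monte_carlo])
```

With λ=0 each advantage is just the one-step TD error. With λ=1 the recursion telescopes into the discounted return minus the critic's estimate, for example 2.71 − 0.5 = 2.21 at the first step. Values of λ in between trade the bias of the first for the variance of the second.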

In [None]:
def train_ppo(env, agent, n_episodes: int = 500,
              rollout_length: int = 2048,
              print_freq: int = 50) -> List[float]:
    """
    Train PPO agent.
    
    PPO collects a "rollout" of experiences, then updates multiple times on it.
    """
    episode_rewards = []
    current_reward = 0
    state, _ = env.reset()
    steps = 0
    episodes_done = 0
    
    while episodes_done < n_episodes:
        # Collect rollout
        for _ in range(rollout_length):
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            agent.store_transition(reward, done)
            current_reward += reward
            steps += 1
            
            if done:
                episode_rewards.append(current_reward)
                episodes_done += 1
                current_reward = 0
                state, _ = env.reset()
                
                if episodes_done % print_freq == 0:
                    avg_reward = np.mean(episode_rewards[-100:])
                    print(f"Episode {episodes_done:4d} | Avg Reward: {avg_reward:.1f}")
                
                if episodes_done >= n_episodes:
                    break
            else:
                state = next_state
        
        # Update policy
        if len(agent.states) > 0:
            agent.update(state)
    
    return episode_rewards

print("\n🏋️ Training PPO on CartPole...\n")
ppo_rewards = train_ppo(env_cartpole, ppo_agent, n_episodes=500)

print(f"\n✅ Final average: {np.mean(ppo_rewards[-100:]):.1f}")

In [None]:
# Compare all three methods
plt.figure(figsize=(14, 5))

window = 50

# Limit to same number of episodes for fair comparison
n_eps = min(len(reinforce_rewards), len(ac_rewards), len(ppo_rewards))

reinforce_smooth = np.convolve(reinforce_rewards[:n_eps], np.ones(window)/window, mode='valid')
ac_smooth = np.convolve(ac_rewards[:n_eps], np.ones(window)/window, mode='valid')
ppo_smooth = np.convolve(ppo_rewards[:n_eps], np.ones(window)/window, mode='valid')

plt.plot(range(window-1, n_eps), reinforce_smooth, label='REINFORCE', alpha=0.8)
plt.plot(range(window-1, n_eps), ac_smooth, label='Actor-Critic', alpha=0.8)
plt.plot(range(window-1, n_eps), ppo_smooth, label='PPO', alpha=0.8, linewidth=2)

plt.axhline(y=475, color='g', linestyle='--', label='Goal (475)')
plt.xlabel('Episode')
plt.ylabel('Reward (smoothed)')
plt.title('Policy Gradient Methods Comparison on CartPole')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("\n📊 Final Performance (last 100 episodes):")
print(f"   REINFORCE:     {np.mean(reinforce_rewards[-100:]):6.1f}")
print(f"   Actor-Critic:  {np.mean(ac_rewards[-100:]):6.1f}")
print(f"   PPO:           {np.mean(ppo_rewards[-100:]):6.1f}")

---

## Part 6: PPO on LunarLander (Harder Environment)

Let's test PPO on a more challenging environment!

In [None]:
# Train PPO on LunarLander
ppo_lunar = PPOAgent(
    state_dim=env_lunar.observation_space.shape[0],
    action_dim=env_lunar.action_space.n,
    lr=3e-4,
    gamma=0.99,
    clip_epsilon=0.2,
    n_epochs=10
)

print("\n🚀 Training PPO on LunarLander (this takes longer)...\n")
lunar_rewards = train_ppo(env_lunar, ppo_lunar, n_episodes=1000, print_freq=100)

print(f"\n✅ Final average: {np.mean(lunar_rewards[-100:]):.1f}")
print(f"   (Goal: >200 for 'solved')")

In [None]:
# Visualize LunarLander training
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(lunar_rewards, alpha=0.3)
window = 50
smoothed = np.convolve(lunar_rewards, np.ones(window)/window, mode='valid')
plt.plot(range(window-1, len(lunar_rewards)), smoothed, label=f'{window}-ep avg')
plt.axhline(y=200, color='g', linestyle='--', label='Solved (200)')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('PPO on LunarLander')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
# Histogram of final rewards
plt.hist(lunar_rewards[-100:], bins=20, edgecolor='black', alpha=0.7)
plt.axvline(x=200, color='g', linestyle='--', label='Solved')
plt.xlabel('Reward')
plt.ylabel('Count')
plt.title('Reward Distribution (Last 100 Episodes)')
plt.legend()

plt.tight_layout()
plt.show()

---

## ⚠️ Common Mistakes

### Mistake 1: Not Normalizing Advantages

```python
# ❌ Raw advantages can have huge variance
loss = -(log_probs * advantages).mean()

# ✅ Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
loss = -(log_probs * advantages).mean()
```

### Mistake 2: Wrong Probability Ratio

```python
# ❌ Subtracting probabilities (wrong!)
ratio = new_probs - old_probs

# ✅ Ratio of probabilities (in log space for stability)
ratio = (new_log_probs - old_log_probs).exp()
```
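
To illustrate why the ratio is computed in log space, the sketch below compares it with direct division, first on ordinary probabilities and then on log probabilities so negative that the probabilities themselves underflow to zero in float64. The function name and the specific numbers are illustrative.

```python
import math

def ratio_from_log_probs(new_log_prob, old_log_prob):
    """pi_new / pi_old computed as exp(log pi_new - log pi_old)."""
    return math.exp(new_log_prob - old_log_prob)

# Moderate probabilities: same answer as dividing directly
direct = 0.3 / 0.2
via_logs = ratio_from_log_probs(math.log(0.3), math.log(0.2))

# Very unlikely action: both probabilities underflow to 0.0 in float64,
# so dividing them would fail, but the log-space ratio is still fine
underflows = math.exp(-800.0) == 0.0 and math.exp(-801.0) == 0.0
tiny_ratio = ratio_from_log_probs(-800.0, -801.0)   # equals e

print(f"direct={direct:.4f}, via logs={via_logs:.4f}")
print(f"probabilities underflow: {underflows}, ratio still {tiny_ratio:.4f}")
```

Log probabilities this extreme are unusual for a single CartPole action, but they arise naturally when log probabilities are summed over long token sequences.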

### Mistake 3: Forgetting to Detach Old Log Probs

```python
# ❌ Computing gradients through old policy
old_log_probs = dist.log_prob(action)  # Still attached to graph!

# ✅ Detach when storing
old_log_probs = dist.log_prob(action).detach()  # or .item()
```

---

## 🎉 Checkpoint

You've learned:
- ✅ **Policy Gradient Theorem**: Directly optimize the policy
- ✅ **REINFORCE**: Simple but high variance
- ✅ **Actor-Critic**: Use value function as baseline
- ✅ **Advantage Function**: "How much better than average?"
- ✅ **PPO**: Stable training with clipped objective
- ✅ **GAE**: Balance bias-variance in advantage estimation

---

## 🔗 Connection to RLHF

PPO is a widely used algorithm for RLHF on LLMs. Here's how it maps:

| PPO Concept | RLHF Application |
|-------------|------------------|
| Policy π(a\|s) | The LLM generating tokens |
| State s | Prompt + tokens generated so far |
| Action a | Next token to generate |
| Reward | Reward model score for complete response |
| Clipped objective | Prevents LLM from changing too much |
| KL penalty | Extra constraint to stay close to original model |
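
One common way to combine the last two rows is to give every generated token a small penalty proportional to the log-probability gap between the current policy and the original (reference) model, and to add the reward model's score on the final token. The sketch below shows that arithmetic on made-up numbers. The function name, the coefficient `beta`, and this exact formulation are assumptions for illustration; real implementations differ in the details.

```python
def rlhf_token_rewards(policy_logprobs, ref_logprobs, score, beta=0.1):
    """Per-token rewards: -beta * (log pi - log pi_ref), with the
    reward-model score added on the final token."""
    rewards = [-beta * (lp - ref_lp)
               for lp, ref_lp in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += score
    return rewards

# Hypothetical 3-token response
policy_logprobs = [-1.0, -0.5, -2.0]   # current LLM
ref_logprobs = [-1.0, -1.5, -2.0]      # frozen original model
token_rewards = rlhf_token_rewards(policy_logprobs, ref_logprobs, score=2.0)

print([round(r, 3) for r in token_rewards])
```

Only the second token is penalized here, because it is the only one where the policy has drifted from the reference model. The reward model's score of 2.0 arrives on the last token, which is why credit assignment across tokens relies on the value function and GAE.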

In the next lab, we'll see exactly how this works with the TRL library!

---

## 📖 Further Reading

- [PPO Paper](https://arxiv.org/abs/1707.06347) - Original PPO algorithm
- [GAE Paper](https://arxiv.org/abs/1506.02438) - Generalized Advantage Estimation
- [Spinning Up PPO](https://spinningup.openai.com/en/latest/algorithms/ppo.html) - Great explanation
- [The 37 Implementation Details of PPO](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) - Deep dive

---

## 🧹 Cleanup

In [None]:
# Close environments
env_cartpole.close()
env_lunar.close()

# Clear GPU memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()

import gc
gc.collect()

print("✅ Notebook complete! Ready for Lab D.5: RLHF for Language Models")