In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Policy Gradients and PPO for Language Models

**Vizuara AI**

In this notebook, we will build the optimization engine of RLHF — the algorithm that uses reward signals to improve a language model. We will implement vanilla policy gradients from scratch, understand why they are unstable, and then build PPO (Proximal Policy Optimization) step by step.

By the end, you will have a working PPO implementation that can optimize any policy given a reward signal.


## 1. Why Does This Matter?

In Notebook 1, we built a reward model that can score any text completion. But scoring alone does not improve the model — we need an algorithm that takes these scores and adjusts the model's behavior accordingly.

This is where policy gradient methods come in. The idea is elegant: if an action (token) led to a high reward, increase its probability. If it led to a low reward, decrease its probability. Do this over many iterations, and the model gradually learns to produce better outputs.

But there is a catch: vanilla policy gradients are noisy, and a single large step can push the policy into a much worse region from which it may not recover. PPO addresses this with a simple trick: **clip the objective so there is no incentive to move the policy far from the one that collected the data in a single update**. PPO is the optimizer OpenAI reported using for the RLHF stage of InstructGPT and ChatGPT.

In [None]:
# Setup
!pip install torch numpy matplotlib -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from collections import deque

torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Building Intuition

Let us build intuition with a simple analogy. Imagine you are learning to play darts.

- **Policy** = your throwing technique (angle, force, release point)
- **Action** = each individual throw
- **Reward** = points scored based on where the dart lands

After each round, you reflect: "When I threw with more arc, I hit closer to the bullseye. When I threw flat, I missed." You then adjust your technique to favor the higher-arc throws.

Policy gradients work exactly the same way. The model generates text (throws darts), receives a reward score, and adjusts its parameters to favor the sequences that scored well.

Let us build a simple environment to see this in action.

In [None]:
# A simple "bandit" environment to build intuition
# The agent must learn to pick the action with highest reward

class SimpleBandit:
    """
    Multi-armed bandit: 5 actions with different mean rewards.
    The agent must learn which action gives the best reward.
    """
    def __init__(self):
        # True mean rewards (agent does not know these)
        self.true_rewards = torch.tensor([0.2, 0.5, 0.8, 0.3, 0.1])

    def step(self, action):
        # Reward = true mean + Gaussian noise, returned as a plain Python float
        reward = self.true_rewards[action].item() + 0.1 * torch.randn(1).item()
        return reward

# The policy is a simple softmax over action logits
class BanditPolicy(nn.Module):
    def __init__(self, n_actions):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(n_actions))

    def forward(self):
        return F.softmax(self.logits, dim=0)

    def sample_action(self):
        probs = self.forward()
        action = torch.multinomial(probs, 1).item()
        return action, probs[action]

env = SimpleBandit()
policy = BanditPolicy(5)

print("True reward means:", env.true_rewards.numpy())
print("Initial action probabilities:", policy().detach().numpy())
print("\nThe best action is action 2 (reward = 0.8)")
print("The policy starts uniform — it does not know which action is best.")

## 3. The Mathematics

The policy gradient theorem gives us the gradient of the expected reward:

$$\nabla_\theta J(\theta) = \mathbb{E}\left[\nabla_\theta \log \pi_\theta(a|s) \cdot A(s, a)\right]$$

Where:
- $\pi_\theta(a|s)$ is the policy probability of action $a$ in state $s$
- $A(s, a)$ is the **advantage**: how much better action $a$ is than the policy's average behavior in state $s$ (for example, $A(s, a) = Q(s, a) - V(s)$)
- $\nabla_\theta \log \pi_\theta(a|s)$ points in the direction that increases $\pi_\theta(a|s)$
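
As a sanity check of the theorem, a single-state bandit with a softmax policy lets us compute the expectation exactly by enumerating the actions instead of sampling. The sketch below is illustrative: the logits are arbitrary, the rewards are `SimpleBandit`'s true means, and the raw reward $r_a$ (which equals $Q(a)$ in a one-step bandit) stands in for $A$. It compares the autograd gradient of $J(\theta) = \sum_a \pi_\theta(a) \, r_a$ with $\mathbb{E}_{a \sim \pi_\theta}\left[\nabla_\theta \log \pi_\theta(a) \cdot r_a\right]$.

```python
import torch
import torch.nn.functional as F

# Illustrative values: rewards are SimpleBandit's true means, logits are arbitrary
rewards = torch.tensor([0.2, 0.5, 0.8, 0.3, 0.1])
logits = torch.tensor([0.1, -0.3, 0.4, 0.0, 0.2], requires_grad=True)

# Exact gradient of J(theta) = sum_a pi(a) * r_a, computed by autograd
probs = F.softmax(logits, dim=0)
J = (probs * rewards).sum()
(exact_grad,) = torch.autograd.grad(J, logits)

# Policy-gradient form E_a[grad log pi(a) * r_a], with the expectation taken
# exactly by enumerating every action and weighting it by pi(a)
pg_grad = torch.zeros(len(rewards))
for a in range(len(rewards)):
    log_pi_a = F.log_softmax(logits, dim=0)[a]
    (score,) = torch.autograd.grad(log_pi_a, logits)
    pg_grad += probs[a].detach() * score * rewards[a]

print("autograd dJ/dlogits: ", exact_grad)
print("policy-gradient form:", pg_grad)
print("match:", torch.allclose(exact_grad, pg_grad, atol=1e-6))
```

The two vectors should agree to floating-point precision; this identity is what makes the sampled estimate used by REINFORCE unbiased.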

**Numerical example:** Suppose the policy currently assigns $\pi_\theta(a_2) = 0.3$ to action 2 and its advantage is $A = 2.0$. This sample contributes

$$\nabla_\theta \log \pi_\theta(a_2) \cdot 2.0$$

to the gradient estimate, evaluated at the current parameters (where $\pi_\theta(a_2) = 0.3$). The vector $\nabla_\theta \log \pi_\theta(a_2)$ points in the direction that increases the probability of action 2. Multiplying by $A = 2.0$ (positive: a good action) means we step in that direction, increasing the probability of this action.

If instead $A = -1.0$ (a bad action), we step in the opposite direction, decreasing its probability. This is exactly what we want.
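
The sign behavior can be checked directly. For a softmax policy over logits, $\nabla_{\text{logits}} \log \pi(a) = \mathbf{e}_a - \pi$ (the one-hot vector of $a$ minus the probability vector). The sketch below uses hypothetical logits for 5 actions chosen so that $\pi_\theta(a_2) = 0.3$, takes one plain gradient-descent step (learning rate 0.5, an arbitrary choice) on the per-sample loss $-\log \pi_\theta(a_2) \cdot A$, and reports $\pi_\theta(a_2)$ before and after for $A = 2.0$ and $A = -1.0$.

```python
import math
import torch
import torch.nn.functional as F

def one_step(advantage, lr=0.5, action=2):
    """One gradient step on -log pi(action) * advantage; returns pi(action) before/after."""
    # Hypothetical logits for 5 actions: zero except `action`, so that pi(action) = 0.3
    logits = torch.zeros(5)
    logits[action] = math.log(12 / 7)
    logits.requires_grad_(True)

    p_before = F.softmax(logits, dim=0)[action].item()
    loss = -F.log_softmax(logits, dim=0)[action] * advantage
    loss.backward()

    # Check: grad of log pi(a) w.r.t. the logits is one_hot(a) - pi
    probs = F.softmax(logits, dim=0).detach()
    one_hot = F.one_hot(torch.tensor(action), 5).float()
    assert torch.allclose(logits.grad, -(one_hot - probs) * advantage, atol=1e-6)

    with torch.no_grad():
        new_logits = logits - lr * logits.grad   # plain gradient-descent step
    p_after = F.softmax(new_logits, dim=0)[action].item()
    return p_before, p_after

for A in (2.0, -1.0):
    before, after = one_step(A)
    print(f"A = {A:+.1f}: pi(a_2) goes from {before:.3f} to {after:.3f}")
```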

**The PPO clipping trick:**

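Vanilla policy gradients can cause wild updates: one large step can move the policy far from the one that collected the data, after which gradient estimates computed from that data are no longer reliable. PPO limits this with a clipped surrogate objective: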

$$L^{\text{CLIP}} = \mathbb{E}\left[\min\left(r_t A_t, \; \text{clip}(r_t, 1-\epsilon, 1+\epsilon) A_t\right)\right]$$

Where $r_t = \pi_\theta(a_t|s_t) / \pi_{\theta_\text{old}}(a_t|s_t)$ is the ratio between new and old policy.

**Numerical example:** With $\epsilon = 0.2$, $r_t = 1.5$, $A_t = 2.0$:
- Unclipped: $1.5 \times 2.0 = 3.0$
- Clipped: $\text{clip}(1.5, 0.8, 1.2) \times 2.0 = 1.2 \times 2.0 = 2.4$
- PPO objective: $\min(3.0, 2.4) = 2.4$

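Because the clipped term $\text{clip}(r_t, 0.8, 1.2) \times 2.0 = 2.4$ is a constant once $r_t > 1 + \epsilon$, this sample contributes zero gradient: the objective gives no further incentive to raise the probability of $a_t$, which prevents an overly aggressive update.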
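
The arithmetic above, and the fact that the clipped branch passes no gradient, can be reproduced with autograd. To keep it minimal, the sketch differentiates with respect to the ratio $r_t$ directly (an assumption for illustration; in training, $r_t$ depends on $\theta$ through $\pi_\theta$), using $\epsilon = 0.2$ and $A_t = 2.0$ as in the example.

```python
import torch

def clipped_objective(r, A, eps=0.2):
    # PPO per-sample objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)
    return torch.min(r * A, torch.clamp(r, 1 - eps, 1 + eps) * A)

A = 2.0
for r_value in (1.5, 1.1):
    r = torch.tensor(r_value, requires_grad=True)
    obj = clipped_objective(r, A)
    obj.backward()
    print(f"r = {r_value}: objective = {obj.item():.2f}, d(objective)/dr = {r.grad.item():.2f}")
```

At $r_t = 1.5$ the objective is 2.4 with zero gradient; at $r_t = 1.1$, inside the clip range, the gradient is the ordinary $A_t = 2.0$.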

In [None]:
# Visualize PPO clipping
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
epsilon = 0.2
ratio = np.linspace(0.0, 2.5, 500)

# Positive advantage
A_pos = 1.0
unclipped_pos = ratio * A_pos
clipped_pos = np.clip(ratio, 1 - epsilon, 1 + epsilon) * A_pos
ppo_pos = np.minimum(unclipped_pos, clipped_pos)

axes[0].plot(ratio, unclipped_pos, 'b--', alpha=0.5, label='Unclipped')
axes[0].plot(ratio, ppo_pos, 'b-', linewidth=2.5, label='PPO (clipped)')
axes[0].axvline(x=1.0, color='gray', linestyle=':', alpha=0.5)
axes[0].fill_between(ratio, ppo_pos, unclipped_pos,
                     where=(unclipped_pos > ppo_pos), alpha=0.15, color='red',
                     label='Clipped away')
axes[0].set_xlabel('Policy Ratio r(θ)', fontsize=12)
axes[0].set_ylabel('Objective', fontsize=12)
axes[0].set_title('Positive Advantage (A > 0)', fontsize=13)
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Negative advantage
A_neg = -1.0
unclipped_neg = ratio * A_neg
clipped_neg = np.clip(ratio, 1 - epsilon, 1 + epsilon) * A_neg
ppo_neg = np.minimum(unclipped_neg, clipped_neg)

axes[1].plot(ratio, unclipped_neg, 'r--', alpha=0.5, label='Unclipped')
axes[1].plot(ratio, ppo_neg, 'r-', linewidth=2.5, label='PPO (clipped)')
axes[1].axvline(x=1.0, color='gray', linestyle=':', alpha=0.5)
axes[1].fill_between(ratio, ppo_neg, unclipped_neg,
                     where=(unclipped_neg > ppo_neg), alpha=0.15, color='blue',
                     label='Clipped away')
axes[1].set_xlabel('Policy Ratio r(θ)', fontsize=12)
axes[1].set_ylabel('Objective', fontsize=12)
axes[1].set_title('Negative Advantage (A < 0)', fontsize=13)
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
print("PPO removes the incentive to push the ratio above 1 + eps when A > 0, or below 1 - eps when A < 0.")

## 4. Let's Build It — Component by Component

### Step 1: REINFORCE (Vanilla Policy Gradient)

Let us first implement the simplest policy gradient algorithm — REINFORCE — on our bandit problem.

In [None]:
# REINFORCE on the bandit problem
policy = BanditPolicy(5).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=0.05)
env = SimpleBandit()

rewards_history = []
probs_history = []

for episode in range(500):
    # Sample action from policy
    action, prob = policy.sample_action()

    # Get reward from environment
    reward = env.step(action)

    # REINFORCE loss: -log(pi(a)) * R
    loss = -torch.log(prob) * reward

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    rewards_history.append(reward)
    probs_history.append(policy().detach().cpu().numpy().copy())

# Plot results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Smoothed rewards
window = 50
smoothed = [np.mean(rewards_history[max(0,i-window):i+1]) for i in range(len(rewards_history))]
axes[0].plot(smoothed, 'b-', linewidth=1.5)
axes[0].axhline(y=0.8, color='green', linestyle='--', label='Optimal reward (0.8)')
axes[0].set_xlabel('Episode', fontsize=12)
axes[0].set_ylabel('Reward (smoothed)', fontsize=12)
axes[0].set_title('REINFORCE Training Rewards', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Action probabilities over time
probs_array = np.array(probs_history)
for i in range(5):
    axes[1].plot(probs_array[:, i], label=f'Action {i} (r={env.true_rewards[i]:.1f})')
axes[1].set_xlabel('Episode', fontsize=12)
axes[1].set_ylabel('Probability', fontsize=12)
axes[1].set_title('Action Probabilities During Training', fontsize=13)
axes[1].legend(fontsize=9)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal action probabilities: {policy().detach().cpu().numpy()}")
print(f"Best action (action 2) probability: {policy().detach().cpu()[2]:.3f}")
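
The REINFORCE loss above weights each log-probability by the raw reward. Because the rewards here are almost always positive, nearly every sampled action gets pushed up, and learning relies on the better actions being pushed up more, which makes the updates noisy. A standard remedy is to subtract a baseline $b$: the expected gradient is unchanged because $\mathbb{E}_{a \sim \pi}[\nabla \log \pi(a)] = 0$, but the variance can drop. (The `rewards_t - values_t` advantage in the PPO training loop later uses the critic as a state-dependent baseline of this kind.) The sketch below computes both quantities exactly by enumerating the five actions of a uniform policy, using `SimpleBandit`'s mean rewards and ignoring the observation noise (an assumption).

```python
import torch
import torch.nn.functional as F

rewards = torch.tensor([0.2, 0.5, 0.8, 0.3, 0.1])  # SimpleBandit means, noise ignored
logits = torch.zeros(5)                             # uniform policy, as at initialization
probs = F.softmax(logits, dim=0)

def gradient_stats(baseline):
    """Exact mean and total variance of g(a) = grad log pi(a) * (r_a - baseline)."""
    # For a softmax policy, grad_logits log pi(a) = one_hot(a) - probs (one row per action)
    score = torch.eye(len(rewards)) - probs
    g = score * (rewards - baseline).unsqueeze(1)    # shape: (n_actions, n_logits)
    mean = (probs.unsqueeze(1) * g).sum(dim=0)
    second_moment = (probs * (g ** 2).sum(dim=1)).sum()
    variance = second_moment - (mean ** 2).sum()
    return mean, variance.item()

mean_raw, var_raw = gradient_stats(baseline=0.0)
b = (probs * rewards).sum().item()                  # baseline = expected reward J
mean_base, var_base = gradient_stats(baseline=b)

print("same expected gradient:", torch.allclose(mean_raw, mean_base, atol=1e-6))
print(f"variance without baseline: {var_raw:.4f}")
print(f"variance with baseline b = {b:.2f}: {var_base:.4f}")
```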

### Step 2: PPO Implementation

Now let us implement PPO. The key difference is that we collect a batch of experiences, then do **multiple optimization steps** on the same batch, using the clipped objective to prevent wild updates.
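
To see why the clip matters once the same batch is reused, consider a single hypothetical sample with positive advantage and run many gradient steps on it. With the unclipped surrogate $r_t A_t$ the ratio keeps growing; with the PPO objective the gradient vanishes as soon as $r_t$ passes $1 + \epsilon$, so the policy stops moving (it can overshoot by up to one step). The sketch uses plain gradient descent on raw logits; the learning rate, step count, and starting logits are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def run(clip, steps=50, lr=0.1, eps=0.2, advantage=1.0, action=0):
    """Repeated updates on one fixed sample; returns the final ratio pi_new / pi_old."""
    logits = torch.zeros(5, requires_grad=True)        # uniform policy over 5 actions
    log_prob_old = F.log_softmax(logits, dim=0)[action].detach()
    for _ in range(steps):
        log_prob_new = F.log_softmax(logits, dim=0)[action]
        ratio = torch.exp(log_prob_new - log_prob_old)
        if clip:
            objective = torch.min(ratio * advantage,
                                  torch.clamp(ratio, 1 - eps, 1 + eps) * advantage)
        else:
            objective = ratio * advantage
        loss = -objective
        logits.grad = None
        loss.backward()
        with torch.no_grad():
            logits -= lr * logits.grad                  # plain gradient-descent step
    with torch.no_grad():
        final = torch.exp(F.log_softmax(logits, dim=0)[action] - log_prob_old)
    return final.item()

print(f"final ratio, unclipped surrogate: {run(clip=False):.2f}")
print(f"final ratio, PPO clipped:         {run(clip=True):.2f}")
```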

In [None]:
def ppo_loss(log_probs_new, log_probs_old, advantages, epsilon=0.2):
    """
    PPO clipped surrogate loss.

    Args:
        log_probs_new: log pi_theta(a|s) under current policy
        log_probs_old: log pi_theta_old(a|s) under old policy (detached)
        advantages: A(s, a) advantage estimates
        epsilon: clipping parameter (default 0.2)

    Returns:
        Scalar loss: the negated clipped objective, so that minimizing it
        performs gradient ascent on the objective
    """
    # Compute importance sampling ratio
    ratio = torch.exp(log_probs_new - log_probs_old)

    # Unclipped objective
    unclipped = ratio * advantages

    # Clipped objective
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages

    # PPO takes the minimum (pessimistic bound)
    loss = -torch.min(unclipped, clipped).mean()

    return loss

# Test it
log_new = torch.tensor([-0.5, -1.0, -0.3])
log_old = torch.tensor([-0.7, -0.8, -0.5])
adv = torch.tensor([2.0, -1.0, 0.5])

loss = ppo_loss(log_new, log_old, adv)
ratios = torch.exp(log_new - log_old)
print(f"Ratios: {ratios.numpy()}")
print(f"Advantages: {adv.numpy()}")
print(f"PPO Loss: {loss.item():.4f}")
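
At the start of each update the current and old policies coincide, so every ratio equals 1 and lies inside the clip range. The gradient of the PPO loss then equals that of the vanilla policy-gradient loss $-\log \pi_\theta(a) \cdot A$ averaged over the batch. The check below uses a hypothetical two-sample batch over five actions and re-implements the clipped loss inline so it runs on its own.

```python
import torch
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)              # local RNG; leaves the global seed alone
logits = torch.randn(2, 5, generator=gen, requires_grad=True)  # hypothetical: 2 states, 5 actions
actions = torch.tensor([1, 3])
advantages = torch.tensor([1.5, -0.7])
eps = 0.2

def batch_log_probs(logits):
    # log pi(a_i | s_i) for each sample's chosen action
    return F.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1)

# Old policy == current policy at the start of the update
log_probs_old = batch_log_probs(logits).detach()

# PPO clipped loss (same formula as ppo_loss above)
ratio = torch.exp(batch_log_probs(logits) - log_probs_old)
ppo = -torch.min(ratio * advantages,
                 torch.clamp(ratio, 1 - eps, 1 + eps) * advantages).mean()
(grad_ppo,) = torch.autograd.grad(ppo, logits)

# Vanilla policy-gradient loss on the same batch
vanilla = -(batch_log_probs(logits) * advantages).mean()
(grad_vanilla,) = torch.autograd.grad(vanilla, logits)

print("ratios:", ratio.detach())
print("gradients match:", torch.allclose(grad_ppo, grad_vanilla, atol=1e-6))
```

Clipping therefore changes nothing on the first epoch; it only matters on later epochs over the same batch, once some ratios have drifted outside $[1-\epsilon, 1+\epsilon]$.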

## 5. Your Turn

### TODO 1: Implement GAE (Generalized Advantage Estimation)

In PPO-based RLHF, advantages are typically computed with GAE, whose parameter $\lambda$ trades bias against variance.
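
The two extreme values of $\lambda$ show what this trade-off means. Writing the GAE sum over a finite episode (terms beyond $T$ are zero):

$$A_t = \sum_{l=0}^{T-t} (\gamma\lambda)^l \, \delta_{t+l}, \qquad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

$$\lambda = 0: \quad A_t = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

$$\lambda = 1: \quad A_t = \sum_{l=0}^{T-t} \gamma^l r_{t+l} + \gamma^{T-t+1} V(s_{T+1}) - V(s_t)$$

The $\lambda = 1$ case follows because the value terms telescope. With $\lambda = 0$ the estimate uses one real reward and leans on the critic, so it is low-variance but biased whenever $V$ is wrong; with $\lambda = 1$ it sums all observed rewards, which removes most of that bias (only the final bootstrap term $V(s_{T+1})$ remains) at the cost of higher variance. The default `lam=0.95` sits between the two.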

In [None]:
def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """
    Compute Generalized Advantage Estimation.

    GAE computes advantages as an exponentially weighted average of
    TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)

    A_t = sum_{l=0}^{inf} (gamma * lambda)^l * delta_{t+l}

    Args:
        rewards: list of rewards [r_0, r_1, ..., r_T]
        values: list of value estimates [V(s_0), V(s_1), ..., V(s_T), V(s_{T+1})]
        gamma: discount factor
        lam: GAE lambda parameter

    Returns:
        advantages: list of advantage estimates [A_0, A_1, ..., A_T]

    Hint:
        1. Compute TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        2. Work backwards from t=T to t=0
        3. A_t = delta_t + gamma * lambda * A_{t+1}
    """
    # TODO: Implement GAE
    # Step 1: Initialize advantages list and set A_{T+1} = 0
    # Step 2: Loop backwards from T to 0
    # Step 3: Compute delta_t = rewards[t] + gamma * values[t+1] - values[t]
    # Step 4: Compute A_t = delta_t + gamma * lam * A_{t+1}
    pass

# Test (uncomment after implementing):
# rewards = [1.0, 0.5, 2.0, 0.0, 1.0]
# values = [0.5, 0.6, 0.4, 0.8, 0.3, 0.0]  # includes V(s_{T+1}) = 0
# advantages = compute_gae(rewards, values)
# print(f"Advantages: {advantages}")

### TODO 2: Add Value Function Loss to PPO

PPO typically also trains a value function alongside the policy. Implement the combined loss.

In [None]:
def ppo_combined_loss(log_probs_new, log_probs_old, advantages,
                      values, returns, epsilon=0.2, vf_coeff=0.5):
    """
    Combined PPO loss = Policy loss + Value function loss.

    Args:
        log_probs_new: current policy log probs
        log_probs_old: old policy log probs
        advantages: GAE advantages
        values: value function predictions V(s)
        returns: discounted returns (targets for value function)
        epsilon: clipping parameter
        vf_coeff: coefficient for value function loss

    Returns:
        total_loss, policy_loss, value_loss

    TODO: Implement this function
    Hint:
        1. Compute PPO clipped policy loss (reuse ppo_loss function)
        2. Compute value loss as MSE: (values - returns)^2
        3. Combine: total = policy_loss + vf_coeff * value_loss
    """
    # TODO: Implement
    pass

# Test (uncomment after implementing):
# total, pol, val = ppo_combined_loss(log_new, log_old, adv,
#                                      torch.tensor([1.0, 2.0, 0.5]),
#                                      torch.tensor([1.5, 1.8, 0.7]))
# print(f"Total: {total:.4f}, Policy: {pol:.4f}, Value: {val:.4f}")

## 6. Putting It All Together

Let us now train PPO on a slightly more complex problem — a contextual bandit where the optimal action depends on the state.

In [None]:
class ContextualBandit:
    """
    Contextual bandit: the best action depends on the state (context).
    State is a state_dim-dimensional vector (4 by default).
    Optimal action = argmax(state @ weights).
    """
    def __init__(self, n_actions=5, state_dim=4):
        self.n_actions = n_actions
        self.state_dim = state_dim
        self.weights = torch.randn(state_dim, n_actions) * 0.5

    def get_state(self):
        return torch.randn(self.state_dim)

    def step(self, state, action):
        true_values = state @ self.weights
        reward = true_values[action].item() + 0.1 * np.random.randn()
        return reward

class PPOPolicy(nn.Module):
    def __init__(self, state_dim, n_actions, hidden_dim=32):
        super().__init__()
        # Actor (policy)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )
        # Critic (value function)
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        logits = self.actor(state)
        value = self.critic(state).squeeze(-1)
        return logits, value

    def get_action(self, state):
        logits, value = self.forward(state)
        probs = F.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        log_prob = F.log_softmax(logits, dim=-1)[action]
        return action, log_prob, value

# Initialize
env = ContextualBandit(n_actions=5, state_dim=4)
policy = PPOPolicy(state_dim=4, n_actions=5).to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

print(f"Policy parameters: {sum(p.numel() for p in policy.parameters()):,}")
print("Ready for PPO training!")

## 7. Training and Results
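
Each episode in this contextual bandit lasts a single step: the agent sees one state $s_0$, takes one action, receives $r_0$, and the episode ends, so the value after the last step is taken to be $V(s_1) = 0$. Plugging $T = 0$ into GAE leaves a single TD residual:

$$A_0 = \delta_0 = r_0 + \gamma \cdot 0 - V(s_0) = r_0 - V(s_0)$$

and the return used as the value target is just $r_0$. This is why the loop below uses `rewards_t - values_t` as the advantage and `rewards_t` as the value-function target, and why `GAMMA` never enters the computation.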

In [None]:
# PPO Training Loop
GAMMA = 0.99  # unused here: each episode is a single step, so there is nothing to discount
EPSILON = 0.2
PPO_EPOCHS = 4
BATCH_SIZE = 64
NUM_ITERATIONS = 200

all_rewards = []

for iteration in range(NUM_ITERATIONS):
    # --- Collect batch of experiences ---
    states, actions, rewards, log_probs_old, values = [], [], [], [], []

    for _ in range(BATCH_SIZE):
        state = env.get_state().to(device)
        action, log_prob, value = policy.get_action(state)
        reward = env.step(state.cpu(), action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        log_probs_old.append(log_prob.detach())
        values.append(value.detach())

    # Convert to tensors
    states_t = torch.stack(states)
    actions_t = torch.tensor(actions, device=device)
    rewards_t = torch.tensor(rewards, dtype=torch.float32, device=device)
    log_probs_old_t = torch.stack(log_probs_old)
    values_t = torch.stack(values)

    # Compute advantages (simple: reward - value baseline)
    advantages = rewards_t - values_t
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # --- PPO update (multiple epochs on same batch) ---
    for ppo_epoch in range(PPO_EPOCHS):
        logits, new_values = policy(states_t)
        log_probs_all = F.log_softmax(logits, dim=-1)
        log_probs_new = log_probs_all.gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # PPO clipped policy loss
        ratio = torch.exp(log_probs_new - log_probs_old_t)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * advantages
        policy_loss = -torch.min(unclipped, clipped).mean()

        # Value function loss
        value_loss = F.mse_loss(new_values, rewards_t)

        # Combined loss
        loss = policy_loss + 0.5 * value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    all_rewards.append(np.mean(rewards))

    if (iteration + 1) % 50 == 0:
        recent = np.mean(all_rewards[-50:])
        print(f"Iteration {iteration+1}/{NUM_ITERATIONS} — "
              f"Mean Reward: {recent:.3f}")

# Plot training curve
plt.figure(figsize=(10, 5))
window = 20
smoothed = [np.mean(all_rewards[max(0,i-window):i+1]) for i in range(len(all_rewards))]
plt.plot(smoothed, 'b-', linewidth=2)
plt.xlabel('Iteration', fontsize=12)
plt.ylabel('Mean Reward (smoothed)', fontsize=12)
plt.title('PPO Training on Contextual Bandit', fontsize=13)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nFinal mean reward (last 20 iterations): {np.mean(all_rewards[-20:]):.3f}")
print("Compare this with the random and optimal baselines in the next section.")

## 8. Final Output

Let us compare the behavior of the trained policy with a random baseline to see how much PPO improved.

In [None]:
# Compare trained policy vs random policy
policy.eval()
n_test = 200

trained_rewards = []
random_rewards = []
optimal_rewards = []

# Evaluate all three policies on the same sampled contexts for a fair comparison
with torch.no_grad():
    for _ in range(n_test):
        state = env.get_state()  # CPU tensor; env.weights lives on the CPU

        # Trained policy (greedy: take the highest-scoring action)
        logits, _ = policy(state.to(device))
        action = logits.argmax().item()
        trained_rewards.append(env.step(state, action))

        # Random policy
        random_action = np.random.randint(0, env.n_actions)
        random_rewards.append(env.step(state, random_action))

        # Optimal (cheating: picks the best action using the hidden weights)
        optimal_action = (state @ env.weights).argmax().item()
        optimal_rewards.append(env.step(state, optimal_action))

plt.figure(figsize=(8, 5))
data = [random_rewards, trained_rewards, optimal_rewards]
bp = plt.boxplot(data, patch_artist=True)
plt.xticks([1, 2, 3], ['Random', 'PPO-Trained', 'Optimal'])
colors = ['#ff9999', '#66b3ff', '#99ff99']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
plt.ylabel('Reward', fontsize=12)
plt.title('Policy Comparison: Random vs PPO vs Optimal', fontsize=13)
plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

print(f"Random mean reward:  {np.mean(random_rewards):.3f}")
print(f"PPO mean reward:     {np.mean(trained_rewards):.3f}")
print(f"Optimal mean reward: {np.mean(optimal_rewards):.3f}")
improvement = (np.mean(trained_rewards) - np.mean(random_rewards)) / (np.mean(optimal_rewards) - np.mean(random_rewards)) * 100
print(f"\nPPO closes {improvement:.0f}% of the gap between random and optimal!")

## 9. Reflection and Next Steps

**What we built:**
- Vanilla REINFORCE algorithm from scratch
- PPO with clipped surrogate objective
- Actor-Critic architecture with separate actor (policy) and critic (value) networks
- Training on a contextual bandit environment

**Key takeaways:**
1. Policy gradients increase the probability of actions that led to high rewards
2. Vanilla REINFORCE is simple but high-variance
3. PPO clips the policy ratio to prevent catastrophic updates
4. Reusing each batch for several epochs makes PPO more sample-efficient than vanilla policy gradients, which take a single gradient step per batch

**Think about:**
- Why does PPO use multiple epochs on the same batch instead of collecting fresh data each time?
- What would happen if epsilon were set to 0 (the ratio clipped to exactly 1), or made very large (effectively no clipping)?
- How does PPO scale to language models where the "action space" is the entire vocabulary?
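
One way to approach the last question: for a language model, the "state" is the prompt plus the tokens generated so far and the "action" is the next token. The per-action log-probabilities PPO needs can then be obtained by gathering the chosen token's entry from a log-softmax over the vocabulary at every position, and a common setup computes the PPO ratio per generated token. The sketch below uses random logits in place of a real model (the sizes `B`, `T`, `V` are hypothetical) and checks the gather against `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)
B, T, V = 2, 6, 50                        # hypothetical batch size, sequence length, vocab size
input_ids = torch.randint(0, V, (B, T), generator=gen)
logits = torch.randn(B, T, V, generator=gen)  # stand-in for a model's output logits

# Logits at position t score the token at position t + 1, so shift by one
next_token_logits = logits[:, :-1, :]      # (B, T-1, V)
next_tokens = input_ids[:, 1:]             # (B, T-1)

log_probs = F.log_softmax(next_token_logits, dim=-1)
token_log_probs = log_probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)  # (B, T-1)

# Same quantity via cross-entropy, which returns the negative log-likelihood
nll = F.cross_entropy(next_token_logits.reshape(-1, V), next_tokens.reshape(-1),
                      reduction="none").view(B, T - 1)
print(token_log_probs.shape, torch.allclose(token_log_probs, -nll, atol=1e-5))

# Each position now plays the role of one (state, action) pair; a per-token
# PPO ratio would be exp(token_log_probs - old_token_log_probs)
```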

**Next notebook:** We will combine the reward model from Notebook 1 with the PPO algorithm from this notebook to build a complete RLHF pipeline that aligns a language model.