# Policy Gradient: Direct Policy Optimization

> **"Policy gradient methods optimize a parameterized policy directly, instead of deriving it from a learned value function."**

## Learning Objectives
- Understand the policy gradient theorem and its mathematical foundation
- Implement REINFORCE algorithm from scratch
- Learn about advantage functions and variance reduction
- Master policy gradient variants (A2C, PPO)
- Apply policy gradient methods to continuous control problems


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import random

# Notebook-wide configuration: seed NumPy's global RNG so that parameter
# initialisation and action sampling below are reproducible, then choose the
# plotting theme shared by every figure.
np.random.seed(42)
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("Libraries imported successfully!")


## 1. Policy Gradient Methods

### What are Policy Gradient Methods?
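Policy gradient methods are a class of reinforcement learning algorithms that directly optimize a parameterized policy instead of deriving the policy from a learned value function. They work by adjusting the policy parameters in the direction that increases the expected return. (Pure methods such as REINFORCE need no value function at all; variants such as actor-critic/A2C and PPO additionally learn one as a baseline to reduce the variance of the gradient estimate.)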

### Key Concepts

#### 1. Policy Function
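A policy π(a|s) defines the probability of taking action a in state s. In policy gradient methods the policy is parameterized by a vector θ, written $\pi_\theta(a \mid s)$, and learning means adjusting θ.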

#### 2. Policy Gradient Theorem
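The gradient of the expected return $J(\theta)$ with respect to the policy parameters is:
$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a \mid s)\, Q^{\pi_\theta}(s,a)\right]$
where the expectation is over states and actions visited while following $\pi_\theta$.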

#### 3. REINFORCE Algorithm
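A simple Monte-Carlo policy gradient algorithm. After each full episode, it uses the sampled discounted return from each timestep onward, $G_t = \sum_{k=t}^{T-1} \gamma^{k-t} r_k$ (with $r_k$ the reward received after action $a_k$), as an unbiased estimate of $Q(s_t, a_t)$.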


In [None]:
class REINFORCEAgent:
    """REINFORCE policy gradient agent implementation.

    Tabular Monte-Carlo policy-gradient agent. The policy is a softmax over a
    table of action preferences ``theta`` with shape ``(n_states, n_actions)``::

        pi(a|s) = exp(theta[s, a]) / sum_b exp(theta[s, b])

    Args:
        n_states: Number of discrete states (rows of ``theta``).
        n_actions: Number of discrete actions (columns of ``theta``).
        learning_rate: Step size for gradient ascent on ``theta``.
        gamma: Discount factor; also the default used by ``compute_returns``.
    """

    def __init__(self, n_states, n_actions, learning_rate=0.01, gamma=0.99):
        self.n_states = n_states
        self.n_actions = n_actions
        self.learning_rate = learning_rate
        self.gamma = gamma

        # Small random preferences keep the initial policy close to uniform.
        # Drawn from NumPy's global RNG, so np.random.seed makes it reproducible.
        self.theta = np.random.randn(n_states, n_actions) * 0.1

    def softmax_policy(self, state):
        """Return pi(.|state) as a probability vector of length ``n_actions``."""
        logits = self.theta[state]
        # Subtracting the max does not change the softmax but avoids overflow in exp.
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / np.sum(exp_logits)

    def select_action(self, state):
        """Sample an action index from the current policy pi(.|state)."""
        probabilities = self.softmax_policy(state)
        return np.random.choice(self.n_actions, p=probabilities)

    def compute_returns(self, rewards, gamma=None):
        """Compute the discounted return-to-go for every timestep.

        Uses the recursion G_t = rewards[t] + gamma * G_{t+1}, with G_T = 0.

        Args:
            rewards: Rewards in the order they were received.
            gamma: Discount factor; defaults to ``self.gamma`` when omitted.

        Returns:
            A list of returns with the same length and order as ``rewards``.
        """
        if gamma is None:
            gamma = self.gamma

        returns = []
        G = 0
        # Accumulate backwards, then flip once. This is O(n), whereas
        # inserting at the front of a list on every step is O(n^2).
        for reward in reversed(rewards):
            G = reward + gamma * G
            returns.append(G)
        returns.reverse()
        return returns

    def update_policy(self, states, actions, returns):
        """Apply the REINFORCE update for one recorded trajectory.

        For each (state, action, return) triple, performs
        ``theta[state] += learning_rate * G * grad log pi(action|state)``.

        Notes:
            ``theta`` is updated step by step, so later steps of the same
            trajectory use the already-updated policy. The ``gamma**t``
            factor of the textbook episodic update is not applied.

        Args:
            states: State indices visited, in order.
            actions: Action indices taken in those states.
            returns: Return-to-go for each step (e.g. from ``compute_returns``).
        """
        for state, action, G in zip(states, actions, returns):
            probabilities = self.softmax_policy(state)

            # For a softmax over theta[state]:
            #   d log pi(a|s) / d theta[s, b] = 1{b == a} - pi(b|s)
            # The chosen action is pushed up, and every other action is pushed
            # down in proportion to its current probability.
            grad_log_pi = -probabilities
            grad_log_pi[action] += 1.0

            self.theta[state] += self.learning_rate * G * grad_log_pi


print("REINFORCE agent class defined successfully!")


In [None]:
# Simple Grid World Environment for Policy Gradient
class GridWorldPG:
    """Square grid world used to train the policy-gradient agent.

    The agent starts in the top-left cell and the episode ends when it
    reaches the bottom-right cell. Cells are numbered row-major, so the cell
    at (row, col) is state ``row * size + col``.

    Rewards: +100 on entering the goal, -1 for every other step.
    """

    def __init__(self, size=4):
        self.size = size
        self.n_states = size * size
        self.n_actions = 4  # Up, Down, Left, Right

        self.grid = np.zeros((size, size))
        self.start_pos = (0, 0)
        self.goal_pos = (size - 1, size - 1)
        self.current_pos = self.start_pos

        # (row delta, col delta) indexed by action id: 0=Up, 1=Down, 2=Left, 3=Right
        self.action_effects = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def reset(self):
        """Move the agent back to the start cell and return its state index."""
        self.current_pos = self.start_pos
        return self._pos_to_state(self.current_pos)

    def _pos_to_state(self, pos):
        """Map a (row, col) position to its row-major state index."""
        row, col = pos
        return row * self.size + col

    def _state_to_pos(self, state):
        """Map a row-major state index back to its (row, col) position."""
        return divmod(state, self.size)

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        A move that would leave the grid is ignored and the agent stays in
        place (still paying the step penalty).
        """
        d_row, d_col = self.action_effects[action]
        row, col = self.current_pos
        candidate = (row + d_row, col + d_col)

        if all(0 <= coord < self.size for coord in candidate):
            self.current_pos = candidate

        done = self.current_pos == self.goal_pos
        reward = 100 if done else -1

        return self._pos_to_state(self.current_pos), reward, done

# Train REINFORCE agent
def train_reinforce(agent, env, episodes=1000, max_steps=100, log_every=100):
    """Train a REINFORCE-style agent with one Monte-Carlo rollout per episode.

    Each episode runs until the environment reports ``done`` or ``max_steps``
    steps have been taken. The agent's policy is then updated once from the
    whole recorded trajectory.

    Args:
        agent: Object providing ``select_action(state)``,
            ``compute_returns(rewards, gamma)``,
            ``update_policy(states, actions, returns)`` and a ``gamma`` attribute.
        env: Object providing ``reset() -> state`` and
            ``step(action) -> (next_state, reward, done)``.
        episodes: Number of training episodes.
        max_steps: Per-episode step cap, which truncates episodes that never
            reach a terminal state.
        log_every: Print the average reward of the last ``log_every``
            episodes every ``log_every`` episodes. 0 or None disables printing.

    Returns:
        ``(episode_rewards, episode_lengths)``: per-episode undiscounted total
        reward and number of steps taken.
    """
    episode_rewards = []
    episode_lengths = []

    for episode in range(episodes):
        states, actions, rewards = _rollout(agent, env, max_steps)

        # REINFORCE is Monte-Carlo: returns (and therefore the update) are
        # only available once the full trajectory is known.
        returns = agent.compute_returns(rewards, agent.gamma)
        agent.update_policy(states, actions, returns)

        episode_rewards.append(sum(rewards))
        episode_lengths.append(len(states))

        if log_every and episode % log_every == 0:
            avg_reward = np.mean(episode_rewards[-log_every:])
            print(f"Episode {episode}, Average Reward: {avg_reward:.2f}")

    return episode_rewards, episode_lengths


def _rollout(agent, env, max_steps):
    """Run one episode with the agent's current policy; return (states, actions, rewards)."""
    states, actions, rewards = [], [], []

    state = env.reset()
    done = False
    while not done and len(states) < max_steps:
        action = agent.select_action(state)
        next_state, reward, done = env.step(action)

        # Record the state the action was taken in, not the successor state.
        states.append(state)
        actions.append(action)
        rewards.append(reward)

        state = next_state

    return states, actions, rewards

# Create environment and agent
env = GridWorldPG(size=4)
agent = REINFORCEAgent(n_states=env.n_states, n_actions=env.n_actions,
                       learning_rate=0.01, gamma=0.99)

print("Training REINFORCE agent...")
print("=" * 50)

# Train the agent
episode_rewards, episode_lengths = train_reinforce(agent, env, episodes=1000)

# Plot training progress: reward per episode (left), steps per episode (right)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Plot episode rewards
axes[0].plot(episode_rewards, alpha=0.6)
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Total Reward')
axes[0].set_title('Episode Rewards During Training')
axes[0].grid(True, alpha=0.3)

# Plot moving average
window = 50
if len(episode_rewards) >= window:
    # mode='valid' yields len(episode_rewards) - window + 1 points. Each point
    # averages a trailing window, so the first one belongs at episode window-1.
    moving_avg = np.convolve(episode_rewards, np.ones(window) / window, mode='valid')
    axes[0].plot(range(window - 1, len(episode_rewards)), moving_avg, 'r-',
                 linewidth=2, label=f'Moving Average ({window})')
    axes[0].legend()

# Plot episode lengths
axes[1].plot(episode_lengths, alpha=0.6)
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Episode Length')
axes[1].set_title('Episode Lengths During Training')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTraining completed!")
print(f"Final average reward (last 100 episodes): {np.mean(episode_rewards[-100:]):.2f}")
print(f"Final average episode length: {np.mean(episode_lengths[-100:]):.2f}")
