# ppo.py

Auto-generated implementation from the Agentic RL PhD codebase.

### Original Implementations & References
The following link points to a high-quality community reference implementation of the paper covered in this notebook:

- https://github.com/nikhilbarhate99/PPO-PyTorch

*Note: The code below is a simplified pedagogical implementation.*

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

# Paper: "Proximal Policy Optimization Algorithms" (Schulman et al., 2017)
# Category: Foundational Model-Free RL

class PPOAgent(nn.Module):
    """Actor-critic network for discrete-action PPO.

    Two independent tanh MLPs read the same state:

    * ``actor``  maps a state to a probability vector over ``action_dim``
      actions (its last layer is a softmax over the final dimension);
    * ``critic`` maps a state to a single value estimate (last dim of size 1).

    Args:
        state_dim: size of the input state features.
        action_dim: number of discrete actions.
        hidden_dim: width of both hidden layers in each MLP.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Policy network ("what to do"): ends in a softmax, so it emits probabilities.
        self.actor = self._tanh_mlp(state_dim, hidden_dim, action_dim, nn.Softmax(dim=-1))
        # Value network ("how good is this state"): one scalar per state.
        self.critic = self._tanh_mlp(state_dim, hidden_dim, 1)

    @staticmethod
    def _tanh_mlp(in_dim, hidden_dim, out_dim, *tail):
        """Build Linear-Tanh-Linear-Tanh-Linear, followed by any extra ``tail`` modules."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_dim),
            *tail,
        )

    def forward(self):
        # Deliberately unsupported: callers go through get_action_and_value().
        raise NotImplementedError

    def get_action_and_value(self, state, action=None):
        """Evaluate the policy and value function on ``state``.

        Args:
            state: input state tensor fed to both networks.
            action: optional actions to score. When None, one action is sampled
                from the current policy.

        Returns:
            ``(action, log_prob, entropy, value)``: the scored or sampled action,
            its log-probability under the policy, the policy entropy, and the
            critic output (trailing dimension of size 1).
        """
        policy = Categorical(self.actor(state))
        chosen = policy.sample() if action is None else action
        state_value = self.critic(state)
        return chosen, policy.log_prob(chosen), policy.entropy(), state_value

def compute_gae(rewards, values, next_value, gamma, lam, masks=None):
    """Compute value targets (lambda-returns) via Generalized Advantage Estimation.

    Walks the rollout backwards:

        delta_t = r_t + gamma * m_t * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * lam * m_t * A_{t+1}

    and returns ``A_t + V(s_t)`` for every step. The advantages themselves can
    be recovered as ``returns[t] - values[t]``.

    Args:
        rewards: sequence of T rewards.
        values: sequence of T value estimates V(s_t) (list, tuple, or 1-D tensor).
        next_value: bootstrap value V(s_T) for the state after the last step.
        gamma: discount factor.
        lam: GAE lambda (0 = one-step TD advantage, 1 = Monte-Carlo-style).
        masks: optional sequence of T "episode continues" flags. ``masks[t]``
            should be 0 when the episode terminated at step t, which stops both
            bootstrapping from V(s_{t+1}) and advantage propagation across the
            boundary, and 1 otherwise. ``None`` treats the rollout as a single
            uninterrupted episode (the original behavior).

    Returns:
        list of T returns in chronological order.

    Raises:
        ValueError: if ``values`` or ``masks`` does not have one entry per reward.
    """
    num_steps = len(rewards)
    # Without this check a too-long `values` would silently use values[T]
    # instead of next_value as the bootstrap.
    if len(values) != num_steps:
        raise ValueError(f"expected {num_steps} values, got {len(values)}")
    if masks is not None and len(masks) != num_steps:
        raise ValueError(f"expected {num_steps} masks, got {len(masks)}")

    # Append the bootstrap so values[step + 1] is defined at the final step;
    # list() also accepts tuples and 1-D tensors.
    values = list(values) + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(num_steps)):
        not_done = 1 if masks is None else masks[step]
        delta = rewards[step] + gamma * values[step + 1] * not_done - values[step]
        gae = delta + gamma * lam * not_done * gae
        returns.append(gae + values[step])
    # Built back-to-front; reverse once instead of O(T^2) insert(0, ...).
    returns.reverse()
    return returns

def ppo_loss(old_log_probs, log_probs, advantages, returns, values, clip_param=0.2,
             value_coef=0.5, entropy=None, entropy_coef=0.0):
    """Return the PPO actor-critic loss to *minimize*.

    Uses the clipped surrogate objective (Eq. 7 of Schulman et al., 2017):

        L_CLIP(theta) = E[ min( r(theta) * A, clip(r(theta), 1 - eps, 1 + eps) * A ) ]
        r(theta)      = pi_theta(a|s) / pi_theta_old(a|s)

    It is combined with a squared-error value loss and an optional entropy
    bonus, as in the paper's Eq. 9:

        loss = -L_CLIP + value_coef * L_VF - entropy_coef * mean(entropy)

    Args:
        old_log_probs: log pi_old(a|s) recorded at rollout time. Detached here
            because the old policy is a fixed reference.
        log_probs: log pi_theta(a|s) under the current policy (carries gradients).
        advantages: advantage estimates A; treated as constants (detached).
        returns: value-regression targets; treated as constants (detached).
        values: current critic predictions. May carry a trailing size-1 dim
            (e.g. shape (B, 1)); it is flattened to align with ``returns``.
        clip_param: clipping range epsilon.
        value_coef: weight of the value loss (0.5 preserves the original behavior).
        entropy: optional per-sample policy entropy; ignored when None.
        entropy_coef: weight of the entropy bonus (0.0 = no bonus).

    Returns:
        Scalar loss tensor.
    """
    # Flatten per-sample tensors: mixing (B,) with (B, 1) would otherwise
    # broadcast to (B, B) and silently corrupt the loss. When shapes already
    # agree, the mean over all elements is unchanged.
    log_probs = log_probs.reshape(-1)
    old_log_probs = old_log_probs.detach().reshape(-1)
    advantages = advantages.detach().reshape(-1)
    returns = returns.detach().reshape(-1)
    values = values.reshape(-1)

    ratio = (log_probs - old_log_probs).exp()
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages

    # Negate because optimizers minimize while L_CLIP is to be maximized.
    actor_loss = -torch.min(surr1, surr2).mean()
    critic_loss = (returns - values).pow(2).mean()

    loss = actor_loss + value_coef * critic_loss
    if entropy is not None:
        # Entropy is a bonus (encourages exploration), so it is subtracted.
        loss = loss - entropy_coef * entropy.mean()
    return loss
