# Proximal Policy Optimization (PPO)

Proximal Policy Optimization (PPO) is a policy-gradient algorithm for training reinforcement learning agents. It searches for an optimal policy by updating the policy parameters iteratively in small steps, so that each updated policy does not deviate significantly from the previous one (the policy that collected the data). PPO achieves this with a clipped surrogate objective: the probability ratio between the new and the old policy is clipped to the interval $[1 - \epsilon, 1 + \epsilon]$, which removes the incentive for large policy updates that could cause training to diverge. By balancing exploration and exploitation, PPO learns near-optimal policies in a wide range of environments, making it a popular choice for training RL agents.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

class ActorCritic(nn.Module):
    """Two-headed network: one shared hidden layer, a policy head and a value head.

    ``forward`` returns ``(logits, value)`` where ``logits`` has one entry per
    action and ``value`` has a single entry per input row.
    """

    def __init__(self, obs_size, hidden_size, n_actions):
        super().__init__()
        # Hidden layer shared by both heads.
        self.fc1 = nn.Linear(obs_size, hidden_size)
        # Policy head: unnormalised scores (logits), one per action.
        self.fc_actor = nn.Linear(hidden_size, n_actions)
        # Value head: scalar state-value estimate.
        self.fc_critic = nn.Linear(hidden_size, 1)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc_actor(hidden), self.fc_critic(hidden)

class PPO_Agent:
    """PPO agent for discrete action spaces, built on an ``ActorCritic`` network.

    The agent samples actions from the softmax policy produced by the network
    and improves it with the clipped surrogate objective, using one-step TD
    errors as advantage estimates.

    Args:
        action_space: Object exposing ``n``, the number of discrete actions.
        observation_space: Object exposing ``shape``; ``shape[0]`` is used as
            the observation size.
        hidden_size: Width of the network's shared hidden layer.
        gamma: Discount factor used in the TD target.
        clip_ratio: PPO clipping parameter; the probability ratio is clipped
            to ``[1 - clip_ratio, 1 + clip_ratio]``.
        ppo_epochs: Number of gradient passes over the data per ``update``.
        batch_size: Stored on the agent; ``update`` currently always uses the
            whole batch it is given and does not read this value.
        lr_actor: Learning rate for the shared layer and the policy head.
        lr_critic: Learning rate for the value head.
    """

    def __init__(self,
                 action_space,
                 observation_space,
                 hidden_size,
                 gamma,
                 clip_ratio,
                 ppo_epochs,
                 batch_size,
                 lr_actor,
                 lr_critic,
                 ):
        self.action_space = action_space
        self.observation_space = observation_space
        self.hidden_size = hidden_size
        self.gamma = gamma
        self.clip_ratio = clip_ratio
        self.ppo_epochs = ppo_epochs
        self.batch_size = batch_size
        self.actor_lr = lr_actor
        self.critic_lr = lr_critic

        # Initialize the Actor-Critic model
        self.actor_critic = ActorCritic(observation_space.shape[0], hidden_size, action_space.n)

        # The shared layer ``fc1`` must belong to an optimizer, otherwise it is
        # never trained and both heads sit on frozen random features. It is
        # grouped with the policy head and therefore uses the actor learning rate.
        actor_params = (list(self.actor_critic.fc1.parameters())
                        + list(self.actor_critic.fc_actor.parameters()))
        self.optimizer_actor = optim.Adam(actor_params, lr=self.actor_lr)
        self.optimizer_critic = optim.Adam(self.actor_critic.fc_critic.parameters(), lr=self.critic_lr)

    def get_action(self, state) -> int:
        """Sample an action index for a single observation from the current policy."""
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Acting needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits, _ = self.actor_critic(state)
            action_probs = F.softmax(logits, dim=1)
            action = torch.multinomial(action_probs, num_samples=1)
        return action.item()

    def update(self, states, actions, rewards, dones, next_states) -> None:
        """Run ``ppo_epochs`` clipped-PPO updates on one batch of transitions.

        Args:
            states: Sequence of N observations.
            actions: Sequence of N integer action indices that were taken.
            rewards: Sequence of N scalar rewards.
            dones: Sequence of N episode-termination flags (truthy = terminal).
            next_states: Sequence of N successor observations.

        An empty batch is a no-op.
        """
        if len(states) == 0:
            return

        # Go through numpy first: building a tensor directly from a list of
        # arrays is slow and emits a warning.
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.int64).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(rewards), dtype=torch.float32).unsqueeze(1)
        dones = torch.as_tensor(np.asarray(dones), dtype=torch.float32).unsqueeze(1)
        next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float32)

        # Everything derived from the pre-update ("old") network is a constant
        # for the optimisation below, so it is computed without gradients.
        # Keeping a graph here would make the second epoch backpropagate
        # through an already-freed graph and fail.
        with torch.no_grad():
            old_logits, old_values = self.actor_critic(states)
            _, next_values = self.actor_critic(next_states)
            old_log_probs = F.log_softmax(old_logits, dim=1).gather(1, actions)
            # One-step TD target; (1 - dones) drops the bootstrap term on terminal steps.
            td_targets = rewards + self.gamma * (1 - dones) * next_values
            advantages = td_targets - old_values

        for _ in range(self.ppo_epochs):
            logits, values = self.actor_critic(states)
            new_log_probs = F.log_softmax(logits, dim=1).gather(1, actions)

            # pi_new(a|s) / pi_old(a|s), computed in log space for stability.
            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages

            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = F.smooth_l1_loss(values, td_targets)

            # Single backward pass over the summed loss: each head only receives
            # gradient from its own term, while the shared layer receives both.
            self.optimizer_actor.zero_grad()
            self.optimizer_critic.zero_grad()
            (actor_loss + critic_loss).backward()
            self.optimizer_actor.step()
            self.optimizer_critic.step()

    def get_value(self, state) -> float:
        """Return the critic's value estimate for a single observation."""
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            _, value = self.actor_critic(state)
        return value.item()