<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import MultivariateNormal

# Actor-Critic network
class ActorCritic(nn.Module):
    """Separate actor and critic MLPs plus a learned, state-independent log std.

    ``forward(state)`` returns ``(action_mean, action_std, value)``: the actor
    output, ``exp(log_std)`` (one entry per action dimension), and the critic
    output (a single unit).
    """

    @staticmethod
    def _mlp(in_dim, out_dim, hidden=64):
        """Build Linear -> Tanh -> Linear -> Tanh -> Linear with ``hidden`` units."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, out_dim),
        )

    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Actor is built before critic so seeded weight initialisation is
        # reproduced layer for layer.
        self.actor = self._mlp(state_dim, action_dim)
        self.critic = self._mlp(state_dim, 1)
        # log std of 0 means every action dimension starts with unit std.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        std = self.log_std.exp()
        mean = self.actor(state)
        return mean, std, self.critic(state)

# PPO agent
class PPO:
    """Proximal Policy Optimization agent (clipped surrogate objective + GAE).

    ``train`` collects one episode at a time from ``env`` and hands it to
    ``update``, which runs ``epochs`` passes of shuffled minibatch updates.
    """

    def __init__(self, state_dim, action_dim, env):
        self.actor_critic = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=3e-4)
        self.env = env
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.clip_epsilon = 0.2  # ratio is clipped to [1 - eps, 1 + eps]
        self.gamma = 0.99        # discount factor
        self.lmbda = 0.95        # GAE lambda
        self.epochs = 10         # optimisation passes per collected episode
        self.batch_size = 64     # minibatch size within each pass

    def _policy(self, states):
        """Return (action distribution, values of shape (B,)) for a state batch."""
        action_mean, action_std, values = self.actor_critic(states)
        dist = MultivariateNormal(action_mean, torch.diag(action_std ** 2))
        return dist, values.squeeze(-1)

    def select_action(self, state):
        """Sample an action for a single observation.

        Returns:
            (action, log_prob): the sampled action as a flat numpy array and its
            log-probability under the current policy as a Python float.
        """
        state = torch.as_tensor(np.asarray(state), dtype=torch.float32).unsqueeze(0)
        # Rollout only; gradients are taken later in update().
        with torch.no_grad():
            dist, _ = self._policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.numpy().flatten(), log_prob.item()

    def compute_gae(self, rewards, values, dones, last_value=0.0):
        """Compute GAE(lambda) returns (advantage + value) for one trajectory.

        Args:
            rewards: per-step rewards, length T.
            values: critic estimates V(s_t), length T (list, 1-D numpy array or
                1-D tensor).
            dones: per-step flags; a truthy flag at step t stops bootstrapping
                from the next value and stops GAE accumulating across step t.
            last_value: value of the state after the final step, used to
                bootstrap when the trajectory was cut off rather than
                terminated. Defaults to 0.

        Returns:
            A list of T float returns.
        """
        # Build a plain list so the bootstrap slot is really appended
        # (``ndarray + [0]`` would broadcast-add 0 instead of extending).
        values = [float(v) for v in values] + [float(last_value)]
        gae = 0.0
        returns = [0.0] * len(rewards)
        for step in reversed(range(len(rewards))):
            not_done = 1.0 - float(dones[step])
            delta = (float(rewards[step])
                     + self.gamma * values[step + 1] * not_done
                     - values[step])
            gae = delta + self.gamma * self.lmbda * not_done * gae
            returns[step] = gae + values[step]
        return returns

    def update(self, memory):
        """Run PPO optimisation on one collected trajectory.

        Args:
            memory: time-ordered (state, action, log_prob, reward, next_state,
                done) tuples from one contiguous episode. ``done`` should be true
                only on real termination; if the final transition is not marked
                done, its return is bootstrapped from V(next_state).
        """
        if not memory:
            return
        states, actions, log_probs, rewards, next_states, dones = zip(*memory)
        states = torch.as_tensor(np.asarray(states), dtype=torch.float32)
        actions = torch.as_tensor(np.asarray(actions), dtype=torch.float32)
        old_log_probs = torch.as_tensor(log_probs, dtype=torch.float32)
        last_state = torch.as_tensor(
            np.asarray(next_states[-1]), dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            # squeeze(-1) keeps a length-1 batch 1-D instead of collapsing to 0-d.
            values = self.actor_critic(states)[2].squeeze(-1)
            last_value = self.actor_critic(last_state)[2].item()
            returns = self.compute_gae(rewards, values.tolist(), dones,
                                       last_value=last_value)

        returns = torch.tensor(returns, dtype=torch.float32)
        advantages = returns - values

        n = len(memory)
        for _ in range(self.epochs):
            # Shuffle so minibatches are not fixed contiguous time slices.
            order = torch.randperm(n)
            for start in range(0, n, self.batch_size):
                idx = order[start:start + self.batch_size]
                dist, new_values = self._policy(states[idx])
                new_log_probs = dist.log_prob(actions[idx])

                ratios = torch.exp(new_log_probs - old_log_probs[idx])
                batch_adv = advantages[idx]
                surrogate1 = ratios * batch_adv
                surrogate2 = torch.clamp(
                    ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_adv
                policy_loss = -torch.min(surrogate1, surrogate2).mean()

                value_loss = (returns[idx] - new_values).pow(2).mean()

                loss = policy_loss + 0.5 * value_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def train(self, max_episodes=1000):
        """Collect one episode per iteration, update on it, and print its return."""
        for episode in range(max_episodes):
            state = self.env.reset()
            # Newer gym APIs return (obs, info) from reset(); older ones return obs.
            if isinstance(state, tuple):
                state = state[0]
            memory = []
            episode_reward = 0

            terminated = truncated = False
            # An episode may end by truncation alone (e.g. a time limit), so both
            # flags must be checked or the loop can run forever.
            while not (terminated or truncated):
                action, action_log_prob = self.select_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                # Store only real termination so update() can bootstrap a
                # truncated final step from V(next_state).
                memory.append((state, action, action_log_prob, reward,
                               next_state, terminated))
                state = next_state
                episode_reward += reward

            self.update(memory)

            print(f"Episode {episode}, Reward: {episode_reward}")

# Older gym releases need new_step_api=True to get the 5-tuple step() API
# that PPO.train unpacks. Releases that removed the kwarg use that API by
# default and raise TypeError on the unknown keyword, so fall back to plain make.
try:
    env = gym.make("Pendulum-v1", new_step_api=True)
except TypeError:
    env = gym.make("Pendulum-v1")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

agent = PPO(state_dim, action_dim, env)
agent.train()