<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install --upgrade gym

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from torch.distributions import Categorical

# Define Actor-Critic Model
class ActorCritic(nn.Module):
    """Shared-trunk network with a policy head and a value head.

    Two 128-unit ReLU layers feed both heads: ``actor`` produces a softmax
    distribution over ``output_dim`` actions and ``critic`` produces one scalar.
    """

    def __init__(self, input_dim, output_dim):
        """Build the trunk and both heads.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            output_dim: Width of the policy output (number of actions).
        """
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.actor = nn.Linear(hidden_units, output_dim)
        self.critic = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return ``(policy, value)`` for the input ``x``.

        Args:
            x: Float tensor whose last dimension has size ``input_dim``.

        Returns:
            A tuple ``(policy, value)``. ``policy`` holds probabilities that
            sum to 1 along the last dimension; ``value`` has a trailing
            dimension of size 1.
        """
        hidden = x
        # Both trunk layers use the same activation, so apply them in sequence.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        action_probs = torch.softmax(self.actor(hidden), dim=-1)
        state_value = self.critic(hidden)
        return action_probs, state_value

# PPO hyperparameters
gamma = 0.99                # discount applied to the bootstrapped next-state value
lr = 3e-4                   # learning rate handed to Adam
eps_clip = 0.2              # probability ratio is clamped to [1 - eps_clip, 1 + eps_clip]
batch_size = 64             # not read by any code visible in this file
update_epochs = 10          # optimisation passes over each collected batch
timesteps_per_batch = 2048  # environment steps gathered between updates

# Environment, compute device, network and optimizer
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ActorCritic(input_dim, output_dim)
model.to(device)  # nn.Module.to moves the parameters in place
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# PPO Update Function
def ppo_update(states, actions, rewards, next_states, dones, old_log_probs):
    """Run clipped-surrogate PPO optimisation on one batch of transitions.

    Reads the module-level ``model``, ``optimizer``, ``device``, ``gamma``,
    ``eps_clip`` and ``update_epochs``. One-step TD targets and advantages are
    computed once from the network as it is on entry; then ``update_epochs``
    full-batch gradient steps are taken on ``actor_loss + 0.5 * critic_loss``.

    Args:
        states: Sequence of observations, one per transition.
        actions: Sequence of integer action indices taken in ``states``.
        rewards: Sequence of scalar rewards.
        next_states: Sequence of observations that followed each action.
        dones: Sequence of flags; a truthy entry removes the bootstrapped
            next-state value from that transition's target.
        old_log_probs: Log-probability of each action under the policy that
            collected the batch.

    Returns:
        None. ``model`` is updated in place through ``optimizer``. An empty
        batch is a no-op.
    """
    import numpy as np  # local import: numpy is not imported at file level

    # An empty batch carries no learning signal and would fail inside the model.
    if len(states) == 0:
        return

    # Stack observations into one ndarray first; building a tensor directly
    # from a list of arrays goes through a slow per-element path.
    states = torch.tensor(np.asarray(states), dtype=torch.float32, device=device)
    next_states = torch.tensor(np.asarray(next_states), dtype=torch.float32, device=device)
    actions = torch.tensor(actions, dtype=torch.long, device=device)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    dones = torch.tensor(dones, dtype=torch.float32, device=device)
    old_log_probs = torch.tensor(old_log_probs, dtype=torch.float32, device=device)

    # Targets and advantages are constants for the optimisation below.
    with torch.no_grad():
        _, values = model(states)
        _, next_values = model(next_states)
        values = values.squeeze(-1)
        next_values = next_values.squeeze(-1)
        target_values = rewards + gamma * next_values * (1 - dones)
        advantages = target_values - values
        # The sample std of a single element is NaN and would turn every
        # weight into NaN after one step, so normalise only when it is defined.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    mse = nn.MSELoss()

    for _ in range(update_epochs):
        # Re-evaluate the batch under the current parameters.
        policy, values = model(states)
        dist = Categorical(policy)
        new_log_probs = dist.log_prob(actions)
        values = values.squeeze(-1)

        # Clipped surrogate objective: take the more pessimistic of the
        # unclipped and clipped terms.
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()

        critic_loss = mse(values, target_values)

        loss = actor_loss + 0.5 * critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Training loop: collect transitions episode by episode and run a PPO update
# every `timesteps_per_batch` environment steps.
num_episodes = 1000
states, actions, rewards, next_states, dones, old_log_probs = [], [], [], [], [], []
total_timesteps = 0

for episode in range(num_episodes):
    # Reset exactly once per episode. Newer Gym returns (observation, info);
    # older Gym returns the observation alone.
    reset_result = env.reset()
    state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
    done = False
    total_reward = 0

    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        # Only the sampled action and its log-probability (as plain numbers)
        # are kept, so no autograd graph is needed here.
        with torch.no_grad():
            policy, _ = model(state_tensor)
            dist = Categorical(policy)
            action = dist.sample()
            log_prob = dist.log_prob(action).item()

        # Step in the environment
        step_result = env.step(action.item())
        if len(step_result) == 4:  # Older Gym API: (obs, reward, done, info)
            next_state, reward, done, info = step_result
            # Older Gym marks a time-limit cut-off in `info`.
            terminated = done and not info.get("TimeLimit.truncated", False)
        else:
            next_state, reward, terminated, truncated, _ = step_result
            done = terminated or truncated

        total_reward += reward

        # Store transitions. The stored flag is `terminated`, not `done`: a
        # time-limit truncation ends the episode but the next state is still
        # worth bootstrapping from in the critic target.
        states.append(state)
        actions.append(action.item())
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(terminated)
        old_log_probs.append(log_prob)

        state = next_state
        total_timesteps += 1

        # Perform PPO update when enough timesteps are collected
        if total_timesteps >= timesteps_per_batch:
            ppo_update(states, actions, rewards, next_states, dones, old_log_probs)
            states, actions, rewards, next_states, dones, old_log_probs = [], [], [], [], [], []
            total_timesteps = 0

    print(f"Episode {episode + 1}, Total Reward: {total_reward}")