<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Reinforcement_Learning_with_Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym

# Autograd anomaly detection is a debugging aid: it records a traceback for each
# forward op and checks backward results for NaNs, which noticeably slows training.
# Set this flag to True only while chasing a NaN / in-place-modification error.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# Policy network (actor)
class Policy(nn.Module):
    """Two-layer MLP that outputs a probability distribution over actions.

    Architecture: ``Linear(input_dim, 128) -> ReLU -> Linear(128, action_dim) -> Softmax``.
    The softmax runs over the last dimension, so every row of the output sums to 1
    and can be handed directly to ``torch.distributions.Categorical`` as ``probs``.
    """

    def __init__(self, input_dim, action_dim):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Softmax(dim=-1),
        ]
        # Attribute name ``fc`` is kept so state_dict keys (fc.0.*, fc.2.*) are unchanged.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return action probabilities for the observations in ``x``."""
        probs = self.fc(x)
        return probs

# Value network (critic)
class Value(nn.Module):
    """Two-layer MLP critic mapping observations to a state-value estimate.

    Architecture: ``Linear(input_dim, 128) -> ReLU -> Linear(128, 1)``; the output
    keeps a trailing dimension of size 1.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden = 128
        # Attribute name ``fc`` is kept so state_dict keys stay fc.0.* / fc.2.*.
        self.fc = nn.Sequential(
            *[nn.Linear(input_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1)]
        )

    def forward(self, x):
        """Return the value estimate for each observation in ``x``."""
        v = self.fc(x)
        return v

# PPO training
def _collect_episode(env, policy_net, max_steps):
    """Roll out one episode with ``policy_net`` and return detached batch tensors.

    Returns:
        ``(states, actions, log_probs, rewards, dones)``. ``states`` stacks the
        observations along a new leading dimension. The other four are 1-D tensors
        with one entry per step. ``dones`` is 1.0 on the step where the episode
        terminated or was truncated.
    """
    state = env.reset()
    # gym >= 0.26 / gymnasium return ``(obs, info)`` from reset(); gym 0.25 returns obs.
    if isinstance(state, tuple):
        state = state[0]

    states, actions, log_probs, rewards, dones = [], [], [], [], []
    # Rollout tensors are only used as constants later, so skip building autograd graphs.
    with torch.no_grad():
        for _ in range(max_steps):
            state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            dist = Categorical(policy_net(state_t))
            action = dist.sample()
            next_state, reward, done, truncated, _ = env.step(action.item())

            states.append(state_t)
            actions.append(action)
            log_probs.append(dist.log_prob(action))
            rewards.append(float(reward))
            dones.append(float(done or truncated))

            if done or truncated:
                break
            state = next_state

    return (
        torch.cat(states),
        torch.cat(actions),
        torch.cat(log_probs),
        torch.tensor(rewards, dtype=torch.float32),
        torch.tensor(dones, dtype=torch.float32),
    )


def _discounted_returns(rewards, dones, gamma):
    """Compute G_t = r_t + gamma * G_{t+1} * (1 - done_t), iterating back to front."""
    returns = torch.empty_like(rewards)
    running = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running * (1 - dones[t])
        returns[t] = running
    return returns


def train_ppo(env, policy_net, value_net, optimizer, epochs, gamma=0.99, eps_clip=0.2, lr=0.001,
              max_steps=1000, update_iters=5):
    """Train ``policy_net`` and ``value_net`` with a simple single-episode PPO loop.

    Each epoch collects one episode (at most ``max_steps`` steps) with the current
    policy. It then computes discounted returns and advantages (``returns - V(s)``)
    and runs ``update_iters`` clipped-surrogate PPO steps on that episode. Truncation
    is treated like termination: no bootstrapping from the last state's value.

    Args:
        env: Environment with ``reset()`` and a 5-tuple
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
            ``reset()`` may return either the observation or ``(obs, info)``.
        policy_net: Module mapping a batch of states to action probabilities
            (its output is used as ``Categorical(probs)``).
        value_net: Module mapping a batch of states to a ``(batch, 1)`` value.
        optimizer: Optimizer over the parameters of both networks.
        epochs: Number of collect-then-update iterations.
        gamma: Discount factor.
        eps_clip: Clipping range for the PPO probability ratio.
        lr: Unused; kept for backward compatibility. The learning rate is
            whatever the caller configured on ``optimizer``.
        max_steps: Maximum number of environment steps per episode.
        update_iters: Number of gradient steps per collected episode.

    Raises:
        ValueError: If ``max_steps`` is less than 1.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be at least 1")

    for epoch in range(epochs):
        states, actions, old_log_probs, rewards, dones = _collect_episode(env, policy_net, max_steps)
        returns = _discounted_returns(rewards, dones, gamma)

        # Baseline from the critic before this epoch's updates; no grad, so the
        # advantages are constants in the policy loss.
        with torch.no_grad():
            old_values = value_net(states).squeeze(-1)
        advs = returns - old_values

        print(f"Epoch {epoch + 1} - Episode length: {len(rewards)}, Return: {rewards.sum().item():.1f}")

        loss = None
        for _ in range(update_iters):
            dists = Categorical(policy_net(states))
            new_log_probs = dists.log_prob(actions)
            ratio = torch.exp(new_log_probs - old_log_probs)
            surr1 = ratio * advs
            surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advs
            policy_loss = -torch.min(surr1, surr2).mean()

            # Recompute values WITH gradients so the critic is actually trained;
            # using the detached rollout values would leave value_net frozen.
            values = value_net(states).squeeze(-1)
            value_loss = 0.5 * (returns - values).pow(2).mean()

            loss = policy_loss + value_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if loss is not None:
            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# Script entry point. In a notebook kernel ``__name__`` is "__main__", so this still
# runs there and the names below remain module-level globals.
if __name__ == "__main__":
    # gym 0.25.x needs ``new_step_api=True`` for the 5-tuple ``step`` API that
    # ``train_ppo`` unpacks. Newer gym removed that flag (5-tuple is the default)
    # and is expected to reject it with TypeError, so fall back to the plain call.
    try:
        env = gym.make('CartPole-v1', new_step_api=True)
    except TypeError:
        env = gym.make('CartPole-v1')

    # Size the networks from the environment rather than hard-coding CartPole's 4 / 2.
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    policy_net = Policy(input_dim=obs_dim, action_dim=n_actions)
    value_net = Value(input_dim=obs_dim)
    optimizer = optim.Adam(list(policy_net.parameters()) + list(value_net.parameters()), lr=0.001)

    train_ppo(env, policy_net, value_net, optimizer, epochs=100)
    env.close()