<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Proximal_Policy_Optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym
import numpy as np

# Debugging aid: per the PyTorch autograd docs, anomaly detection makes an
# error raised during backward (e.g. a NaN-producing op) report the forward
# operation that created the failing node.
# NOTE(review): the PyTorch docs describe this mode as slowing execution and
# meant for debugging only; consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(True)  # Enable anomaly detection

# Define the policy network
class PolicyNetwork(nn.Module):
    """Two-layer MLP mapping a state to a probability distribution over actions.

    Args:
        state_dim: Size of the last (feature) dimension of the input.
        action_dim: Number of discrete actions.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded.

    ``forward`` applies softmax over the last dimension, so it accepts either a
    single state of shape ``(state_dim,)`` or a batch ``(..., state_dim)``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Attribute names (fc1/fc2) and creation order are kept stable so
        # state_dict keys and seeded initialization match earlier versions.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities (softmax over the last dim) for ``x``."""
        x = torch.relu(self.fc1(x))
        # Probabilities, not logits: callers pass the result to
        # Categorical(...) positionally (i.e. as probs) and index into it.
        return torch.softmax(self.fc2(x), dim=-1)

# Define the value network
class ValueNetwork(nn.Module):
    """Two-layer MLP mapping a state to a single scalar value estimate.

    Args:
        state_dim: Size of the last (feature) dimension of the input.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded.

    ``forward`` returns a tensor whose last dimension has size 1, e.g.
    ``(1,)`` for a single state or ``(batch, 1)`` for a batch.
    """

    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        # Attribute names (fc1/fc2) and creation order are kept stable so
        # state_dict keys and seeded initialization match earlier versions.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the (unbounded) value estimate for ``x``."""
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Initialize the environment, the policy and value networks, and one Adam
# optimizer per network (both with learning rate 1e-3).
# NOTE(review): `new_step_api=True` looks like a gym-0.25-era flag. The
# training loop relies on this API shape: it unpacks five values from
# `env.step(...)` (next_state, reward, done, truncated, info) and uses
# `env.reset()` as if it returns only the observation. Newer gym / gymnasium
# releases are believed to drop this flag and return `(obs, info)` from
# `reset()` -- confirm the pinned gym version before upgrading.
env = gym.make('CartPole-v1', new_step_api=True)
# Length of the observation vector (input size of both networks) and the
# number of discrete actions (output size of the policy network).
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

policy_net = PolicyNetwork(state_dim, action_dim)
value_net = ValueNetwork(state_dim)
# Separate optimizers so the policy and value networks are stepped independently.
optimizer_policy = optim.Adam(policy_net.parameters(), lr=1e-3)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)

# PPO hyperparameters
gamma = 0.99       # discount factor passed to compute_returns
epsilon = 0.2      # presumably the PPO clipping range; its use is not visible here -- confirm
epochs = 3         # number of update passes over each episode's collected data
batch_size = 32    # minibatch size when slicing an episode's steps during updates

def compute_returns(rewards, gamma):
    """Return the discounted return for every step of a single episode.

    Uses the recursion G_t = r_t + gamma * G_{t+1}, with G = 0 after the final
    reward, so the result has the same length and order as ``rewards``.

    Args:
        rewards: Per-step rewards of one episode, in time order. Must support
            ``reversed()`` (e.g. a list or tuple).
        gamma: Discount factor applied per step.

    Returns:
        A list of discounted returns aligned index-for-index with ``rewards``;
        an empty list for an empty episode.
    """
    returns = []
    G = 0
    # Walk backwards and append (O(1)), then reverse once at the end.
    # Prepending with list.insert(0, ...) is O(n) per call, i.e. O(n^2) overall.
    for reward in reversed(rewards):
        G = reward + gamma * G
        returns.append(G)
    returns.reverse()
    return returns

# Training loop
for episode in range(1000):
    state = env.reset()
    log_probs = []
    values = []
    rewards = []
    states = []
    dones = []
    done = False
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32)
        dist = policy_net(state_tensor)
        value = value_net(state_tensor)
        action = Categorical(dist).sample()
        log_prob = torch.log(dist[action])

        next_state, reward, done, truncated, _ = env.step(action.item())

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(reward)
        states.append(state)
        dones.append(done or truncated)

        state = next_state

    # Compute returns and advantages
    returns = compute_returns(rewards, gamma)
    returns = torch.tensor(returns, dtype=torch.float32)
    values = torch.stack(values)
    log_probs = torch.stack(log_probs)
    advantages = returns - values.detach()

    # Ensure sizes are consistent
    assert log_probs.size(0) == advantages.size(0), "Mismatch in sizes between log_probs and advantages"
    assert values.size(0) == advantages.size(0), "Mismatch in sizes between values and advantages"

    # Update policy
    for _ in range(epochs):
        for i in range(0, len(rewards), batch_size):
            batch_indices = np.arange(i, min(i + batch_size, len(rewards)))
            batch_states = torch.tensor(np.array(states)[batch_indices], dtype=torch.float32)
            batch_log_probs = log_probs[batch_indices]
            batch_advantages = advantages[batch_indices]

            # Ensure batch_advantages matches batch size and is 1-dimensional