<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Proximal_Policy_Optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gym
from torch.autograd import set_detect_anomaly

# Autograd anomaly detection is a debugging aid that PyTorch documents as
# slowing down execution; keep it off for normal training and flip this flag
# to True only while tracking down an autograd error.
DETECT_ANOMALY = False
set_detect_anomaly(DETECT_ANOMALY)

# Define the Actor-Critic model
class ActorCritic(nn.Module):
    """Policy/value network with two independent heads.

    Both heads read the raw state through their own 64-unit ReLU hidden layer.
    ``forward`` returns ``(action_probs, state_value)``: ``action_probs`` is a
    softmax over the last dimension (``action_dim`` entries) and
    ``state_value`` has a trailing dimension of size 1.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()

        def two_layer(out_features, *head):
            # Linear -> ReLU -> Linear, optionally followed by extra modules.
            return nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, out_features),
                *head,
            )

        # Actor is built before critic so parameter initialisation order
        # (and therefore RNG consumption) is unchanged.
        self.actor = two_layer(action_dim, nn.Softmax(dim=-1))
        self.critic = two_layer(1)

    def forward(self, state):
        """Return ``(action_probs, state_value)`` for ``state``."""
        return self.actor(state), self.critic(state)

# Define the PPO algorithm
def ppo_update(policy, optimizer, old_log_probs, states, actions, returns, advantages,
               clip_eps=0.2, ppo_epochs=10):
    """Run several clipped-surrogate PPO gradient steps on one batch.

    Args:
        policy: module whose ``forward(states)`` returns
            ``(action_probs, state_values)``. ``action_probs`` are per-action
            *probabilities* (e.g. a softmax output), not log-probabilities.
        optimizer: optimizer over ``policy``'s parameters.
        old_log_probs: log-probability of each taken action under the policy
            that collected the batch (one entry per action in ``actions``).
        states: batch of states fed to ``policy``.
        actions: 1-D tensor of integer action indices, one per state.
        returns: critic regression targets, one per state.
        advantages: advantage estimates, one per state.
        clip_eps: clipping range for the probability ratio.
        ppo_epochs: number of optimisation steps taken on the batch.
    """
    # Targets must not carry gradients back into whatever produced them.
    old_log_probs = old_log_probs.detach()
    advantages = advantages.detach()
    returns = returns.detach()

    for _ in range(ppo_epochs):
        action_probs, state_values = policy(states)
        # The policy outputs probabilities, so take the log of the chosen
        # action's probability before forming pi_new / pi_old.
        chosen_probs = action_probs.gather(1, actions.unsqueeze(-1)).squeeze(-1)
        new_log_probs = torch.log(chosen_probs)

        ratio = torch.exp(new_log_probs - old_log_probs)
        surrogate1 = ratio * advantages
        surrogate2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
        actor_loss = -torch.min(surrogate1, surrogate2).mean()

        critic_loss = nn.functional.mse_loss(state_values.squeeze(-1), returns)
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Training loop for PPO
def _initial_observation(env):
    """Return the first observation from ``env.reset()``.

    Accepts both a bare observation and an ``(observation, info)`` tuple.
    """
    out = env.reset()
    return out[0] if isinstance(out, tuple) else out


def _discounted_returns(rewards, gamma):
    """Return a float32 tensor with G_t = r_t + gamma * G_{t+1} for each step."""
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()  # append + reverse avoids O(n^2) insert(0, ...)
    return torch.tensor(returns, dtype=torch.float32)


env = gym.make("CartPole-v1", new_step_api=True)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy = ActorCritic(state_dim, action_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.01)

num_epochs = 1000
gamma = 0.99
max_steps = 200  # per-episode rollout cap

for epoch in range(num_epochs):
    states, actions, rewards, log_probs = [], [], [], []
    state = _initial_observation(env)
    for t in range(max_steps):
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        # Rollout only samples actions, so no autograd graph is needed.
        with torch.no_grad():
            action_probs, _ = policy(state)
        action = torch.multinomial(action_probs, num_samples=1).item()
        next_state, reward, done, truncated, _ = env.step(action)

        states.append(state)
        actions.append(action)
        rewards.append(reward)
        log_probs.append(torch.log(action_probs[0, action]))

        state = next_state
        if done or truncated:
            break

    returns = _discounted_returns(rewards, gamma)
    states = torch.cat(states)
    actions = torch.tensor(actions, dtype=torch.long)
    old_log_probs = torch.stack(log_probs)

    # Use the critic's value estimate as the baseline so the value head trained
    # in ppo_update actually feeds into the policy update, then standardise.
    with torch.no_grad():
        _, values = policy(states)
    advantages = returns - values.squeeze(-1)
    advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-8)

    ppo_update(policy, optimizer, old_log_probs, states, actions, returns, advantages)

    print(f"Epoch {epoch + 1}, Total Reward: {sum(rewards)}")

env.close()
print("Training complete.")