<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Reinforcement_Learning_with_Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gym

# Debugging aid: autograd records extra information during the forward pass so
# that a failing backward pass can point at the operation responsible.  This
# adds overhead to every training step.
torch.autograd.set_detect_anomaly(True)

# Prefer the GPU when CUDA is usable, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Policy Network
class PolicyNetwork(nn.Module):
    """Small MLP that turns an observation into a probability per action.

    The network is a single hidden layer of 128 ReLU units followed by a
    softmax over the last dimension.
    """

    def __init__(self, input_dim, output_dim):
        """Create the layers.

        Args:
            input_dim: Number of features in one observation.
            output_dim: Number of discrete actions.
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
            nn.Softmax(dim=-1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the action probabilities for ``x`` (normalised over the last dim)."""
        action_probs = self.fc(x)
        return action_probs

# Value Network
class ValueNetwork(nn.Module):
    """Small MLP that estimates a scalar state value for an observation.

    One hidden layer of 128 ReLU units feeds a single linear output unit.
    """

    def __init__(self, input_dim):
        """Create the layers.

        Args:
            input_dim: Number of features in one observation.
        """
        super().__init__()
        hidden = nn.Linear(input_dim, 128)
        head = nn.Linear(128, 1)
        self.fc = nn.Sequential(hidden, nn.ReLU(), head)

    def forward(self, x):
        """Return the value estimate for ``x``; the last dimension has size 1."""
        state_value = self.fc(x)
        return state_value

# PPO Agent
class PPOAgent:
    """Clipped-surrogate PPO agent for a discrete action space.

    The agent owns a policy network and a value network, each with its own
    Adam optimizer.  ``select_action`` remembers every state it was given and
    every action it sampled, so that ``update`` can re-evaluate the *current*
    policy on the rollout.  Callers keep collecting log-probabilities, values,
    rewards and done flags themselves and pass them to ``update``.
    """

    def __init__(self, env):
        """Build networks and optimizers from the spaces of ``env``.

        Args:
            env: Environment whose ``observation_space.shape[0]`` is the
                observation size and whose ``action_space.n`` is the number
                of discrete actions.
        """
        self.env = env
        self._device = device
        obs_dim = env.observation_space.shape[0]
        self.policy = PolicyNetwork(obs_dim, env.action_space.n).to(self._device)
        self.value = ValueNetwork(obs_dim).to(self._device)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=1e-3)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=1e-3)
        self.gamma = 0.99      # discount factor
        self.eps_clip = 0.2    # half-width of the PPO ratio clipping interval
        # Rollout memory filled by select_action() and consumed by update().
        self._states = []
        self._actions = []

    def select_action(self, state):
        """Sample an action for ``state`` and record the step for ``update``.

        Args:
            state: One observation (array-like of floats).

        Returns:
            Tuple ``(action, log_prob, value)``: the sampled action as a
            Python int, the log-probability of that action under the current
            policy, and the value network's output for the state.  Both
            tensors carry a leading batch dimension of size 1.
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self._device).unsqueeze(0)
        probs = self.policy(state)
        value = self.value(state)
        dist = Categorical(probs)
        action = dist.sample()
        # Keep what update() needs to recompute log-probs under the new policy.
        self._states.append(state)
        self._actions.append(action)
        return action.item(), dist.log_prob(action), value

    def compute_returns(self, rewards, next_value, dones):
        """Compute discounted returns, restarting at every episode boundary.

        Args:
            rewards: Sequence of per-step rewards.
            next_value: Bootstrap value for the state after the last step; a
                number or a one-element tensor.
            dones: Sequence of per-step flags; a true flag stops the return
                of that step from including anything that follows it.

        Returns:
            1-D float32 tensor with one return per step, on the agent's device.
        """
        # Accumulate in plain floats so the result is always shape (T,),
        # independent of the shape or device of ``next_value``.
        running = float(next_value)
        returns = []
        for reward, done in zip(reversed(rewards), reversed(dones)):
            running = float(reward) + (0.0 if done else self.gamma * running)
            returns.append(running)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float32, device=self._device)

    def update(self, log_probs, values, rewards, dones):
        """Run four epochs of clipped PPO updates on the collected rollout.

        Args:
            log_probs: Log-probabilities returned by ``select_action``.
            values: Value estimates returned by ``select_action``.
            rewards: Reward received after each action.
            dones: Episode-end flag for each step.

        Raises:
            ValueError: If the inputs do not contain exactly one entry per
                ``select_action`` call made since the previous ``update``.
        """
        states, actions = self._states, self._actions
        # Always start the next rollout with empty buffers.
        self._states, self._actions = [], []

        steps = len(rewards)
        if steps == 0:
            return
        if not (len(states) == len(log_probs) == len(values) == len(dones) == steps):
            raise ValueError(
                "rollout length mismatch: update() needs one log_prob, value, "
                "reward and done per select_action() call since the last update()"
            )

        states = torch.cat(states)
        actions = torch.cat(actions)
        # Rollout-time quantities are constants of the optimisation problem.
        old_log_probs = torch.cat([lp.reshape(-1) for lp in log_probs]).detach()
        old_values = torch.cat([v.reshape(-1) for v in values]).detach()

        returns = self.compute_returns(rewards, 0.0, dones)
        advantages = returns - old_values

        value_criterion = nn.MSELoss()
        for _ in range(4):  # Multiple epochs over the same rollout
            # Re-evaluate the current policy so the ratio can move away from 1.
            new_log_probs = Categorical(self.policy(states)).log_prob(actions)
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # Fresh forward pass each epoch: the graph is rebuilt after every step.
            value_loss = value_criterion(self.value(states).squeeze(-1), returns)

            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

# Main Training Loop
env = gym.make('CartPole-v1', new_step_api=True)
agent = PPOAgent(env)

for episode in range(1000):
    # Fresh rollout storage for this episode.
    state = env.reset()
    log_probs = []
    rewards = []
    values = []
    dones = []
    done = False

    while not done:
        action, log_prob, value = agent.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)

        # The episode ends on a terminal state or on a time-limit cut-off.
        # Without this assignment the loop would never exit.
        done = terminated or truncated

        log_probs.append(log_prob)
        rewards.append(reward)
        values.append(value)
        dones.append(done)

        state = next_state

    # One policy/value update per finished episode.
    agent.update(log_probs, values, rewards, dones)

    if episode % 10 == 0:
        print(f"Episode {episode}: Total Reward: {sum(rewards)}")

# Release the environment's resources once training is finished.
env.close()