<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Proximal_Policy_Optimization_(PPO).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to action probabilities.

    The defaults (4 inputs, 128 hidden units, 2 actions) reproduce the
    original fixed-size network, so ``PolicyNetwork()`` is unchanged.

    Args:
        state_dim: Size of the last dimension of the input tensor.
        hidden_dim: Number of units in the hidden layer.
        action_dim: Number of discrete actions (size of the output).
    """

    def __init__(self, state_dim: int = 4, hidden_dim: int = 128,
                 action_dim: int = 2) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities for ``x``.

        The softmax is taken over the last dimension, so each trailing
        vector of the result sums to 1.
        """
        hidden = torch.relu(self.fc1(x))
        return torch.softmax(self.fc2(hidden), dim=-1)

def compute_returns(rewards, gamma=0.99):
    """Compute the discounted return for every step of one episode.

    For step ``t`` the return is
    ``rewards[t] + gamma * rewards[t+1] + gamma**2 * rewards[t+2] + ...``.

    Args:
        rewards: Sequence of per-step rewards in chronological order
            (must support ``reversed``).
        gamma: Discount factor applied once per step into the future.

    Returns:
        A list with the same length as ``rewards`` where element ``t`` is
        the discounted return from step ``t``. Empty input gives ``[]``.
    """
    running = 0
    returns = []
    # Walk backwards so each return reuses the one after it. Appending and
    # reversing once at the end keeps this linear, unlike insert(0, ...)
    # which shifts the whole list on every step.
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns

# Train the policy on CartPole-v1 with a Monte Carlo policy-gradient update:
# one optimizer step per episode on sum(-log_prob * discounted_return).
# NOTE(review): the notebook link names PPO, but this loop has no clipped
# surrogate objective or value function -- confirm that is intended.
# NOTE(review): `new_step_api=True` is presumably what makes env.step return
# the 5-tuple unpacked below; it is version-specific -- verify against the
# installed gym release.
env = gym.make('CartPole-v1', new_step_api=True)
policy = PolicyNetwork()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)

try:
    for episode in range(1000):
        # Reset exactly once per episode. Some gym versions return
        # (observation, info) instead of the bare observation, so unpack
        # the tuple form rather than passing it on to torch.tensor.
        reset_result = env.reset()
        if isinstance(reset_result, tuple):
            state = reset_result[0]
        else:
            state = reset_result

        log_probs = []
        rewards = []
        done = False
        while not done:
            state = torch.tensor(state, dtype=torch.float32)
            probs = policy(state)
            dist = Categorical(probs)
            action = dist.sample()
            # Keep the log-probability (with its graph) for the update.
            log_probs.append(dist.log_prob(action))
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            # The episode ends on either flag returned by env.step.
            done = terminated or truncated
            rewards.append(reward)
            state = next_state

        # The loop body runs at least once, so log_probs is never empty here.
        returns = torch.tensor(compute_returns(rewards), dtype=torch.float32)
        # Weight each action's log-probability by the return that followed
        # it; negated because the optimizer minimizes.
        loss = -(torch.stack(log_probs) * returns).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}, loss: {loss.item()}")
finally:
    # Release the environment's resources even if training is interrupted.
    env.close()