<a href="https://colab.research.google.com/github/aswinaus/Reinforcement-Learning/blob/main/GRPO_Loss_in_PyTorch_using_CartPole_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim

# One CartPole-v1 environment is created here; the network's input and
# output sizes are derived from its observation/action spaces rather than
# being hard-coded.
env = gym.make("CartPole-v1")
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.n
# The network uses ReLU (Rectified Linear Unit) activations.
class GRPOAgent(nn.Module):
    """Shared-trunk policy/value network for a discrete-action environment.

    A two-layer ReLU MLP trunk feeds two linear heads: a policy head that
    emits one logit per action, and a value head that emits a single
    state-value estimate.

    NOTE(review): despite the class name, nothing here is specific to GRPO
    (Group Relative Policy Optimization). The separate value head makes this
    an actor-critic network; the training loop decides which loss is used.

    Args:
        in_dim: Length of the observation vector. Defaults to the
            module-level ``obs_dim`` (derived from the environment).
        n_actions: Number of discrete actions. Defaults to the module-level
            ``act_dim``.
        hidden_dim: Width of both hidden layers (64 in the original design).
    """

    def __init__(self, in_dim=None, n_actions=None, hidden_dim=64):
        super().__init__()
        # Fall back to the environment-derived globals so ``GRPOAgent()``
        # behaves exactly as before. They are looked up lazily here, so the
        # class can be instantiated without an environment when the sizes
        # are passed explicitly.
        if in_dim is None:
            in_dim = obs_dim
        if n_actions is None:
            n_actions = act_dim

        # Registration order (fc, policy_head, value_head) is kept identical
        # to the original so parameter initialisation consumes the RNG the
        # same way.
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Run the trunk and both heads.

        Args:
            x: Float tensor whose last dimension has size ``in_dim``; any
                leading dimensions are treated as batch dimensions by
                ``nn.Linear``.

        Returns:
            ``(logits, value)``: ``logits`` has last dimension ``n_actions``
            and ``value`` has last dimension 1.
        """
        features = self.fc(x)
        logits = self.policy_head(features)
        value = self.value_head(features)
        return logits, value

    def get_action(self, obs):
        """Sample one action from the policy for a single observation.

        Args:
            obs: 1-D float tensor of length ``in_dim``. ``action.item()``
                below requires exactly one sampled action, so batched input
                is not supported.

        Returns:
            ``(action, log_prob, value)``: the sampled action as a Python
            int, its log-probability (still attached to the autograd graph
            so the policy gradient can flow), and the value-head output.
        """
        logits, value = self.forward(obs)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value

agent = GRPOAgent()
optimizer = optim.Adam(agent.parameters(), lr=3e-4)

gamma = 0.99  # discount factor for the Monte-Carlo returns
num_episodes = 500


def _reset_obs(environment):
    """Reset *environment* and return only the initial observation.

    Gym >= 0.26 returns ``(obs, info)`` from ``reset()``, while older
    releases return the observation alone. Accepting both keeps the loop
    runnable whichever Gym version the kernel has installed.
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(environment, action):
    """Apply *action* and return ``(next_obs, reward, done)``.

    Gym >= 0.26 returns ``(obs, reward, terminated, truncated, info)``;
    older releases return ``(obs, reward, done, info)``. Here either a
    termination or a truncation ends the episode.
    """
    result = environment.step(action)
    if len(result) == 5:
        next_obs, reward, terminated, truncated, _info = result
        return next_obs, reward, bool(terminated or truncated)
    next_obs, reward, done, _info = result
    return next_obs, reward, done


def _discounted_returns(rewards, discount):
    """Return the discounted reward-to-go for every step as a float32 tensor."""
    running = 0.0
    backwards = []
    # Accumulate from the last step backwards, then flip once. This avoids
    # the O(n^2) cost of repeated list.insert(0, ...).
    for r in reversed(rewards):
        running = r + discount * running
        backwards.append(running)
    backwards.reverse()
    return torch.tensor(backwards, dtype=torch.float32)


# One policy update per episode:
#   policy_loss = -mean(log pi(a_t|s_t) * (G_t - V(s_t)))
#   value_loss  = mean((G_t - V(s_t))^2)
# NOTE(review): this is REINFORCE with a learned baseline (Monte-Carlo
# actor-critic), not GRPO proper. GRPO would normalise returns across a
# group of sampled trajectories and use a clipped probability ratio.
for episode in range(num_episodes):
    # Per-episode rollout buffers.
    obs_list, act_list, logp_list, reward_list, value_list = [], [], [], [], []

    obs = _reset_obs(env)
    done = False
    total_reward = 0

    # Roll out one full episode with the current stochastic policy.
    while not done:
        obs_tensor = torch.tensor(obs, dtype=torch.float32)
        action, logp, value = agent.get_action(obs_tensor)

        next_obs, reward, done = _step_env(env, action)

        obs_list.append(obs_tensor)
        act_list.append(action)
        logp_list.append(logp)
        value_list.append(value.squeeze())
        reward_list.append(reward)

        obs = next_obs
        total_reward += reward

    returns = _discounted_returns(reward_list, gamma)
    values = torch.stack(value_list)
    # The baseline is detached so the policy term does not push gradients
    # into the value head; the value head is trained only by value_loss.
    advantages = returns - values.detach()

    logps = torch.stack(logp_list)
    policy_loss = -(logps * advantages).mean()
    value_loss = (returns - values).pow(2).mean()
    loss = policy_loss + 0.5 * value_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (episode + 1) % 10 == 0:
        print(
            f"Episode {episode+1:3d} | Total Reward: {total_reward:6.1f} | "
            f"Policy Loss: {policy_loss.item():.4f} | Value Loss: {value_loss.item():.4f}"
        )
