# GRPO

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

# Hyperparameters: module-level defaults read by the training code below.
learning_rate: float = 5e-4  # Adam step size
eps_clip: float = 0.1  # probability ratio is clamped to [1 - eps_clip, 1 + eps_clip]
K_epoch: int = 3  # optimisation passes over each collected batch
num_trajectories: int = 5  # Number of trajectories per initial state

def layer_init(layer, std=1.41, bias_const=0.0):
    """Initialise ``layer`` in place and hand it back.

    The weight matrix gets an orthogonal initialisation scaled by ``std``
    (passed as the ``gain`` of ``orthogonal_``), and every bias entry is
    set to ``bias_const``. Returning the layer lets callers wrap a
    constructor call, e.g. ``layer_init(nn.Linear(4, 64))``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class GRPO(nn.Module):
    """Categorical policy (4 inputs -> 256 hidden -> 2 actions) trained with
    a PPO-style clipped surrogate on externally supplied advantages.

    Transitions are buffered with ``put_data`` and consumed (and cleared) by
    ``train_net``. Each hyperparameter falls back to the module-level
    ``learning_rate`` / ``eps_clip`` / ``K_epoch`` when not passed explicitly.
    """

    def __init__(self, lr=None, clip_eps=None, n_epochs=None):
        """
        Args:
            lr: Adam learning rate; defaults to the module-level ``learning_rate``.
            clip_eps: ratio clipping range; defaults to the module-level ``eps_clip``.
            n_epochs: update passes per batch; defaults to the module-level ``K_epoch``.
        """
        super().__init__()
        # Buffered (state, action, old log-prob, advantage) tuples.
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)

        # The module-level globals are only looked up when the matching
        # argument is omitted, so explicit arguments make the class standalone.
        self.eps_clip = eps_clip if clip_eps is None else clip_eps
        self.k_epoch = K_epoch if n_epochs is None else n_epochs
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr
        )

    def _logits(self, x):
        """Unnormalised action scores: fc1 -> ReLU -> fc_pi."""
        return self.fc_pi(F.relu(self.fc1(x)))

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for ``x``.

        Use ``softmax_dim=0`` for a single 1-D state and ``softmax_dim=1``
        for a 2-D batch of states.
        """
        return F.softmax(self._logits(x), dim=softmax_dim)

    def log_pi(self, x, softmax_dim=0):
        """Return log action probabilities, computed stably via log_softmax.

        Unlike ``torch.log(self.pi(x))``, this stays finite when a
        probability underflows to zero.
        """
        return F.log_softmax(self._logits(x), dim=softmax_dim)

    def put_data(self, transition):
        """Buffer one (state, action, old log-prob, advantage) tuple."""
        self.data.append(transition)

    def make_batch(self):
        """Drain the buffer into batched tensors.

        Returns:
            (states (N, obs_dim) float, actions (N, 1) int64,
             old log-probs (N, 1) float, advantages (N, 1) float)
        """
        s, a, log_prob_a, advantages = zip(*self.data)
        self.data = []
        return (
            torch.tensor(np.array(s), dtype=torch.float),
            torch.tensor(a).unsqueeze(1),
            torch.tensor(log_prob_a, dtype=torch.float).unsqueeze(1),
            torch.tensor(advantages, dtype=torch.float).unsqueeze(1),
        )

    def train_net(self):
        """Perform ``self.k_epoch`` clipped-surrogate updates on the buffer.

        The buffer is consumed. If it is empty, this is a no-op.
        """
        if not self.data:
            return
        s, a, log_prob_a, advantages = self.make_batch()
        for _ in range(self.k_epoch):
            # log_softmax rather than log(softmax): a probability that
            # underflows to 0 would give log = -inf and NaN gradients.
            log_pi_a = self.log_pi(s, softmax_dim=1).gather(1, a)
            ratio = torch.exp(log_pi_a - log_prob_a)  # pi_new / pi_old

            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2)  # Clipped policy loss

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

def generate_trajectory(env, policy, seed):
    """Roll out one episode with ``policy``, resetting ``env`` with ``seed``.

    Args:
        env: environment exposing ``reset(seed=...)`` and the 5-tuple ``step``.
        policy: object whose ``pi(state_tensor)`` returns action probabilities.
        seed: seed passed to ``env.reset``.

    Returns:
        (states, actions, log_probs, rewards), one entry per step. States are
        float32 numpy arrays, actions are ints, and log_probs are 0-dim
        tensors holding log pi(a|s) of the sampled action, detached from autograd.
    """
    s, _ = env.reset(seed=seed)
    s = np.array(s, dtype=np.float32)
    terminated, truncated = False, False
    states, actions, log_probs, rewards = [], [], [], []

    # The log-probs are recorded as fixed "old policy" values, so no autograd
    # graph is needed; building one per step would only hold memory.
    with torch.no_grad():
        while not (terminated or truncated):
            prob = policy.pi(torch.from_numpy(s).float())
            a = Categorical(prob).sample()
            s_prime, r, terminated, truncated, _ = env.step(a.item())

            log_probs.append(torch.log(prob[a]))
            rewards.append(r)
            states.append(s)
            actions.append(a.item())
            s = np.array(s_prime, dtype=np.float32)

    return states, actions, log_probs, rewards

## Main
env = gym.make('CartPole-v1')
policy = GRPO()
score = 0.0
log_each_iteration = 20  # Log once every this many iterations
total_episodes = 0  # Tracks the total number of episodes played

# try/finally so the environment is closed even if training raises or is
# interrupted.
try:
    for iteration in range(1000):
        # All trajectories in a group are reset with the same seed, so they
        # form a group sharing one initial state.
        seed = iteration
        trajectories = [generate_trajectory(env, policy, seed) for _ in range(num_trajectories)]
        total_episodes += num_trajectories

        # Group-relative advantage: each trajectory's return normalised by the
        # group's mean and std (+0.1 keeps the division finite when all match).
        returns = np.array([sum(traj[3]) for traj in trajectories])
        mean_r, std_r = returns.mean(), returns.std()
        normalized_advantages = (returns - mean_r) / (std_r + 0.1)

        # Every step of a trajectory shares that trajectory's advantage.
        for (states, actions, log_probs, _), adv in zip(trajectories, normalized_advantages):
            for state, action, log_prob in zip(states, actions, log_probs):
                policy.put_data((state, action, log_prob.item(), adv))

        policy.train_net()
        score += returns.mean()

        # Equivalent to total_episodes % (num_trajectories * log_each_iteration) == 0.
        if (iteration + 1) % log_each_iteration == 0:
            print(f"Total episodes: {total_episodes}, avg score: {score / log_each_iteration:.2f}")
            score = 0.0
finally:
    env.close()