In [8]:
import torch
import torch.nn.functional as F
import numpy as np
import gymnasium as gym

from typing import List, Tuple
from torch import nn
from torch.optim import Adam
from torch.distributions.categorical import Categorical

In [46]:
class PolicyNetwork(nn.Module):
    """MLP policy that maps a state vector to a probability distribution over actions.

    The network is two ReLU-activated hidden layers of equal width followed by a
    linear output layer; `forward` applies a softmax to the output.

    Args:
        input_dim: Size of the state vector.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of discrete actions.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities for `x`.

        The softmax is taken over the last dimension, so both a single state of
        shape (input_dim,) and a batch of shape (batch, input_dim) are accepted.
        """
        return F.softmax(self.network(x), dim=-1)

    def get_action(self, state: List[float]) -> Tuple[int, float]:
        """Sample an action for a single state.

        Args:
            state: One state vector (list or array of `input_dim` numbers).

        Returns:
            A tuple `(action, prob)`: the sampled action index and the probability
            the current policy assigned to it. No gradient is tracked.
        """
        # Add a batch dimension of 1 so the state looks like a batch to the network
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            probs = self(state_tensor)

        # Sample action
        action = Categorical(probs).sample().item()
        action_prob = probs[0, action].item()

        return action, action_prob


In [44]:
class Trajectory:
    """Step-by-step record of one episode.

    Stores, per step, the state, the action taken, the reward received, the
    probability recorded for that action, and whether the step ended the episode.
    """

    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.probs = []
        self.dones = []

    def add(self, state: np.ndarray, action: int, reward: float, prob: float, done: bool):
        """Append one step to the trajectory."""
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.probs.append(prob)
        self.dones.append(done)

    def _compute_returns(self, gamma: float = 0.99) -> List[float]:
        """Return the discounted return-to-go for every step.

        Walks the episode backwards accumulating `G = r + gamma * G`; a step
        flagged `done` drops the return accumulated from the steps after it.
        """
        returns = []
        G = 0

        for r, d in zip(reversed(self.rewards), reversed(self.dones)):
            G = r + gamma * G * (1 - int(d))
            returns.append(G)

        return list(reversed(returns))

    def get_return(self, gamma: float) -> float:
        """Return the mean per-step discounted return (not reward) of the trajectory.

        The mean is used instead of the sum because summing the per-step returns
        grows wildly with episode length; the mean still grows with longer
        episodes, but far less dramatically. An empty trajectory scores 0.0.
        """
        returns = self._compute_returns(gamma)
        if not returns:
            # np.mean([]) would emit a RuntimeWarning and return NaN
            return 0.0
        return float(np.mean(returns))

    def to_tensor(self, gamma: float = 0.99):
        """Convert the trajectory to tensors.

        Returns:
            A tuple `(states, actions, probs, returns)` of float32, int64, float32
            and float32 tensors respectively, each with one entry per step.
            An empty trajectory yields four empty tensors.
        """
        states = torch.as_tensor(np.array(self.states), dtype=torch.float32)
        actions = torch.as_tensor(np.array(self.actions), dtype=torch.long)
        # torch.tensor (unlike torch.stack) accepts an empty list
        probs = torch.tensor(self.probs, dtype=torch.float32)
        returns = torch.tensor(self._compute_returns(gamma), dtype=torch.float32)

        return states, actions, probs, returns


In [45]:
# Smoke-check Trajectory on a four-step episode with unit rewards;
# only the final step is marked as done.
t = Trajectory()
for is_terminal in (False, False, False, True):
    t.add(state=np.array([0, 1]), action=0, reward=1, prob=0.2, done=is_terminal)

gamma = 0.99
# Per-step discounted returns, followed by their mean
print(t._compute_returns(gamma), t.get_return(gamma))

# One stored probability per step is expected
states, actions, probs, returns = t.to_tensor()
print(probs.shape)

[3.9403989999999998, 2.9701, 1.99, 1.0] 2.47512475
torch.Size([4])


In [47]:

class GRPO:
    """Group-relative policy optimisation for an environment with discrete actions.

    Each update collects a batch of trajectories with the current policy, ranks
    them by mean discounted return, splits the ranking into at most `n_groups`
    groups, and applies a clipped-ratio surrogate update per trajectory whose
    objective is scaled by the rank of the trajectory's group.

    Args:
        env: Environment used for training. Its observation space must expose
            `shape[0]` and its action space must expose `n`.
        hidden_dim: Hidden width of the policy network.
        lr_policy: Learning rate of the Adam optimizer.
        gamma: Discount factor.
        n_groups: Maximum number of rank groups per update (must be >= 1).
        clip_param: Clipping range for the probability ratio.

    Raises:
        ValueError: If `n_groups` is smaller than 1.
    """

    def __init__(self, env: gym.Env, hidden_dim: int = 64, lr_policy: float = 1e-3, gamma: float = 0.99, n_groups: int = 3, clip_param: float = 0.2):
        if n_groups < 1:
            raise ValueError("n_groups must be >= 1, got {}".format(n_groups))

        self.env = env
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.gamma = gamma
        self.n_groups = n_groups
        self.clip_param = clip_param

        self.policy_net = PolicyNetwork(input_dim=self.state_dim, hidden_dim=hidden_dim, output_dim=self.action_dim)
        self.policy_optimizer = Adam(self.policy_net.parameters(), lr=lr_policy)

    def collect_trajectories(self, n_trajectories: int):
        """Roll out `n_trajectories` full episodes in `self.env` with the current policy.

        Returns:
            A list of `Trajectory` objects, one per episode.
        """
        trajectories = []

        for _ in range(n_trajectories):
            traj = Trajectory()
            state, _ = self.env.reset()
            done = False

            while not done:
                action, prob = self.policy_net.get_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                # Both natural termination and truncation end the episode here
                done = terminated or truncated

                traj.add(state=state, action=action, reward=reward, prob=prob, done=done)
                state = next_state

            trajectories.append(traj)

        return trajectories

    def group_trajectories(self, trajectories: List[Trajectory]):
        """Group trajectories into at most `self.n_groups` groups for comparisons.

        Trajectories are sorted in ascending order of mean discounted return
        (returns, not raw rewards) and cut into consecutive chunks, so a higher
        group index means better trajectories.

        Raises:
            ValueError: If `self.n_groups` is smaller than 1.
        """
        if self.n_groups < 1:
            # Without this guard 0 divides by zero and negatives never terminate below
            raise ValueError("n_groups must be >= 1, got {}".format(self.n_groups))

        sorted_trajectories = sorted(trajectories, key=lambda traj: traj.get_return(self.gamma))

        group_size = max(1, len(sorted_trajectories) // self.n_groups)
        grouped_trajectories = [
            sorted_trajectories[i:i + group_size]
            for i in range(0, len(sorted_trajectories), group_size)
        ]

        # Integer division can leave more than n_groups chunks; fold the surplus
        # into the last kept group (i.e. the one with the highest returns).
        while len(grouped_trajectories) > self.n_groups:
            surplus = grouped_trajectories.pop()
            grouped_trajectories[-1].extend(surplus)

        return grouped_trajectories

    def update_policy(self, grouped_trajectories: List[List[Trajectory]]):
        """Update the policy using the group-relative approach.

        Runs one optimizer step per non-empty trajectory. The stored action
        probabilities serve as the old policy in the clipped probability ratio.
        """
        n_groups = len(grouped_trajectories)

        for group_idx, group in enumerate(grouped_trajectories):
            # Group weight scales with group index: because we sorted them in ascending order by scores
            group_weight = (group_idx + 1) / n_groups

            for trajectory in group:
                # Skip before tensor conversion so an empty trajectory never reaches the network
                if len(trajectory.states) == 0:
                    continue

                states, actions, old_probs, returns = trajectory.to_tensor(self.gamma)

                current_probs = self.policy_net(states)
                dist = Categorical(current_probs)

                # Log-probability of the actions that were actually taken
                log_probs = dist.log_prob(actions)
                old_log_probs = torch.log(old_probs + 1e-10) # add small epsilon to avoid log(0)

                ratios = torch.exp(log_probs - old_log_probs)
                clipped_ratios = torch.clamp(ratios, 1.0 - self.clip_param, 1.0 + self.clip_param)
                surr1 = ratios * returns * group_weight
                surr2 = clipped_ratios * returns * group_weight
                # Element-wise minimum of the clipped and unclipped objectives, negated for gradient descent
                policy_loss = -torch.min(surr1, surr2).mean()

                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()

    def train(self, n_episodes: int, n_trajectories_per_update: int = 10):
        """Run `n_episodes` collect/group/update iterations.

        Args:
            n_episodes: Number of update iterations (each collects a fresh batch).
            n_trajectories_per_update: Episodes collected per iteration.

        Returns:
            A list with the average undiscounted episode reward of each iteration.
        """
        rewards_history = []

        for episode in range(n_episodes):
            trajectories = self.collect_trajectories(n_trajectories_per_update)

            # Just for logging:
            avg_reward = np.mean([sum(traj.rewards) for traj in trajectories])
            rewards_history.append(avg_reward)

            grouped_trajectories = self.group_trajectories(trajectories)

            self.update_policy(grouped_trajectories)

            if (episode + 1) % 10 == 0:
                print("Episode {}, Avg reward: {:.2f}".format(episode+1, avg_reward))

        return rewards_history

    def evaluate(self, env: gym.Env, n_episodes: int = 10, render: bool = False):
        """Run the current policy for `n_episodes` episodes in `env` and report the mean reward.

        Actions are sampled from the policy (via `get_action`), not chosen greedily.

        Args:
            env: Environment to evaluate in (may differ from the training env).
            n_episodes: Number of evaluation episodes.
            render: If True, call `env.render()` before every step.

        Returns:
            The average total reward per episode.
        """
        rewards = []

        for _ in range(n_episodes):
            state, _ = env.reset()
            done = False
            total_reward = 0

            while not done:
                if render:
                    env.render()

                action, _ = self.policy_net.get_action(state)
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward

            rewards.append(total_reward)

        avg_reward = np.mean(rewards)
        print("Evaluation: Average reward over {} episodes: {:.2f}".format(n_episodes, avg_reward))

        return avg_reward


In [35]:
# Init grpo algorithm and env

ENV_NAME = 'CartPole-v1' # Possible values: CartPole-v1, Acrobot-v1
SEED = 0 # Fixed seed so a fresh "Restart & Run All" is reproducible

# Seeds policy weight initialisation and action sampling (both use torch's global RNG)
torch.manual_seed(SEED)

env = gym.make(ENV_NAME, max_episode_steps=500)
# Seed the environment's RNG once; later reset() calls without a seed continue from it
env.reset(seed=SEED)
grpo = GRPO(env=env, hidden_dim=64, lr_policy=0.001, gamma=0.99, n_groups=3, clip_param=0.2)

In [36]:
# Training:
# Re-running this cell keeps training the same model, which can improve its performance
rewards = grpo.train(
    n_episodes=300,
    n_trajectories_per_update=5,
)
# Quick check of the trained policy on the (non-rendering) training env
avg_reward = grpo.evaluate(env=env, n_episodes=10)

Episode 10, Avg reward: 29.60
Episode 20, Avg reward: 29.40
Episode 30, Avg reward: 40.20
Episode 40, Avg reward: 100.40
Episode 50, Avg reward: 72.20
Episode 60, Avg reward: 72.60
Episode 70, Avg reward: 192.80
Episode 80, Avg reward: 103.00
Episode 90, Avg reward: 313.40
Episode 100, Avg reward: 341.60
Episode 110, Avg reward: 171.80
Episode 120, Avg reward: 311.80
Episode 130, Avg reward: 500.00
Episode 140, Avg reward: 317.80
Episode 150, Avg reward: 192.00
Episode 160, Avg reward: 164.60
Episode 170, Avg reward: 317.80
Episode 180, Avg reward: 486.40
Episode 190, Avg reward: 500.00
Episode 200, Avg reward: 500.00
Episode 210, Avg reward: 500.00
Episode 220, Avg reward: 476.00
Episode 230, Avg reward: 500.00
Episode 240, Avg reward: 479.20
Episode 250, Avg reward: 500.00
Episode 260, Avg reward: 500.00
Episode 270, Avg reward: 485.40
Episode 280, Avg reward: 500.00
Episode 290, Avg reward: 500.00
Episode 300, Avg reward: 500.00
Evaluation: Average reward over 10 episodes: 500.00


In [40]:
## Evaluate:
# Use a dedicated name for the rendering env so the global `env` used by the
# training cell is not replaced by a human-render environment that gets closed below.
render_env = gym.make(ENV_NAME, max_episode_steps=500, render_mode='human')
try:
    _ = grpo.evaluate(render_env, 1, True)
finally:
    # Close the render window even if evaluation is interrupted or raises
    render_env.close()

Evaluation: Average reward over 1 episodes: 500.00
