# GRPO

$$
\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim P(Q),\, \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(O \mid q)} \Bigg[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \Bigg\{ \min \Bigg( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t},\ \text{clip} \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \Bigg) - \beta\, D_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}] \Bigg\} \Bigg],
$$

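$$
\hat{A}_{i,t}=\tilde{r}_i=\frac{r_i-\text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}
$$

The KL term is estimated per token with the unbiased, always non-negative estimator:

$$
D_{\text{KL}}[\pi_\theta \| \pi_{\text{ref}}] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1
$$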

In [None]:
import torch

def grpo_loss(rewards, logp_per_token, ref_logp_per_token, old_logp_per_token, beta=0.01, clip_epsilon=0.25, mask=None):
    """Compute the GRPO loss: negated clipped surrogate objective plus a KL penalty.

    Sequence-level rewards are normalized within each group of generations
    (same prompt) and broadcast to every token as the advantage. A PPO-style
    clipped importance ratio against the old policy is applied, and a per-token
    KL estimate against the reference policy is added with weight ``beta``.
    Per-token terms are averaged over each sequence's valid tokens (the
    ``1/|o_i|`` factor), then over all sequences.

    Args:
        rewards (torch.Tensor): rewards, shape [batch_size, num_generation].
        logp_per_token (torch.Tensor): current policy log-probs,
            shape [batch_size, num_generation, seq_len].
        ref_logp_per_token (torch.Tensor): reference model log-probs,
            shape [batch_size, num_generation, seq_len].
        old_logp_per_token (torch.Tensor): old (sampling) policy log-probs,
            shape [batch_size, num_generation, seq_len].
        beta (float): KL regularization coefficient.
        clip_epsilon (float): clipping range for the importance ratio.
        mask (torch.Tensor, optional): 1/True for completion tokens, 0/False
            for padding, shape [batch_size, num_generation, seq_len]. Defaults
            to all ones, i.e. every token counts (the original behavior).

    Returns:
        torch.Tensor: 0-dim loss to minimize.
    """
    if mask is None:
        mask = torch.ones_like(logp_per_token)
    else:
        mask = mask.to(logp_per_token.dtype)

    # Group-relative advantage: normalize rewards within each prompt's group.
    mean_grouped_rewards = rewards.mean(dim=-1, keepdim=True)  # [batch_size, 1]
    if rewards.size(-1) > 1:
        std_grouped_rewards = rewards.std(dim=-1, keepdim=True)  # [batch_size, 1]
    else:
        # The unbiased std of a single sample is NaN. With one generation there
        # is no relative signal, so a zero std yields a zero advantage instead.
        std_grouped_rewards = torch.zeros_like(mean_grouped_rewards)

    advantage_per_sequence = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-8)  # [batch_size, num_generation]
    advantage_per_token = advantage_per_sequence.unsqueeze(-1).expand_as(logp_per_token)  # [batch_size, num_generation, seq_len]

    # PPO-style clipped surrogate: take the pessimistic (min) of unclipped/clipped terms.
    importance_ratio = torch.exp(logp_per_token - old_logp_per_token)
    clipped_importance_ratio = torch.clamp(importance_ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)
    policy_objective_per_token = torch.min(
        importance_ratio * advantage_per_token,
        clipped_importance_ratio * advantage_per_token,
    )  # [batch_size, num_generation, seq_len]

    # Per-token KL estimate: r - log r - 1 with r = pi_ref / pi_theta (always >= 0).
    log_ratio_ref = ref_logp_per_token - logp_per_token
    kl_per_token = torch.exp(log_ratio_ref) - log_ratio_ref - 1

    per_token_loss = beta * kl_per_token - policy_objective_per_token

    # Mean over each sequence's valid tokens (1/|o_i|), then over all sequences.
    # clamp avoids division by zero for fully-masked sequences.
    token_counts = mask.sum(dim=-1).clamp(min=1.0)  # [batch_size, num_generation]
    per_sequence_loss = (per_token_loss * mask).sum(dim=-1) / token_counts
    return per_sequence_loss.mean()