# 一、粗糙版GRPO伪代码

存在的问题：
- 1. KL散度非标准实现，基于泰勒展开式的二阶估计
- 2. 未对优势值展开具体讲解，即过程监督和结果监督
- 3. 未对价值模型是如何被省略做具体讲解

In [None]:
import random


# NOTE: this cell is pseudocode. `concat`, `reward_model`, `compute_advantages`
# and `compute_KL` are stand-ins for real implementations, and the
# hyperparameters (num_iterations, steps_per_iteration, batch_size,
# num_generations, mu, beta) are assumed to be defined by the caller.
for _ in range(num_iterations):
    # Snapshot the current policy; the snapshot is the reference for the KL penalty.
    ref_model = copy.deepcopy(policy_model)

    for _ in range(steps_per_iteration):
        # Sample a batch of prompts and repeat every prompt num_generations
        # times so that num_generations responses are produced per prompt.
        batch_prompt = random.sample(prompts, batch_size)
        batch_prompt = [prompt for prompt in batch_prompt for _ in range(num_generations)]
        # Generate responses with the current policy.
        batch_response = policy_model.generate(batch_prompt)
        # Concatenate each prompt with its generated response.
        batch_data = concat(batch_prompt, batch_response)

        # There may be several reward models; each one scores the data and the
        # individual scores are summed into the final reward.
        batch_rewards = reward_model(batch_data)

        # The "old" policy and the reference model are constants during the
        # inner optimisation, so no autograd graph is recorded for them.
        # Each forward pass yields (full distribution over tokens,
        # probabilities of the tokens that were actually generated).
        with torch.no_grad():
            old_actor_all_probs, old_actor_probs = policy_model.forward(batch_data)
            ref_all_probs, ref_probs = ref_model.forward(batch_data)

        # Advantages can be computed with process or outcome supervision.
        advantages = compute_advantages(batch_rewards)

        # Importance sampling: reuse the same rollouts for mu gradient steps.
        for _ in range(mu):
            # Forward pass of the policy being updated (same return layout as above).
            new_actor_all_probs, new_actor_probs = policy_model.forward(batch_data)

            # Probability ratio between the updated and the old policy.
            ratio = new_actor_probs / old_actor_probs

            # Advantage-weighted policy loss.
            loss_adv = torch.mean(-advantages * ratio)

            # KL penalty between the current policy and the reference model.
            # A per-token estimator that can be used here is
            #   exp(ref - new) - (ref - new) - 1
            # evaluated on ref_all_probs / new_actor_all_probs.
            loss_kl = compute_KL(new_actor_all_probs, ref_all_probs)

            # Total loss: policy loss plus the beta-weighted KL penalty
            # (there is no critic / value loss in this loop).
            loss = loss_adv + beta * loss_kl

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# 二、详细版GRPO代码

In [None]:
import torch
import torch.nn as nn
import copy
import random

# Hyperparameters

# -- Sampling --
batch_size = 32        # prompts drawn per iteration
num_generations = 8    # sequences generated for each prompt (n)
max_len = 128          # maximum generation length

# -- Optimisation schedule --
num_iterations = 100   # total number of training iterations
mu = 10                # optimisation steps performed on each batch

# -- Loss terms --
epsilon = 0.2          # PPO clipping range
beta = 0.01            # coefficient of the KL penalty
gamma = 0.99           # discount factor (process supervision)

# -- Special tokens --
PAD_TOKEN_ID = 0       # padding token id
SEP_TOKEN_ID = 50256   # separator token id (e.g. GPT-2)

# Placeholder models and data (the model to be trained would be a large LM,
# e.g. DeepSeek). Replace every placeholder with a real implementation
# before running the training loop.

# Device on which the training loop allocates its tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

policy_model = nn.Module().to(device)    # policy model (e.g. a Transformer)
ref_model = copy.deepcopy(policy_model)  # reference model (the SFT model)
ref_model.eval()
# The reference model is never trained: turn off gradients for all its parameters.
ref_model.requires_grad_(False)

reward_models = [nn.Module().to(device)]  # list of reward models (several are supported); one here
weights = [1.0]                           # weight of each reward model; one here
# NOTE: a bare nn.Module() owns no parameters, so the optimizer cannot be
# built until policy_model is replaced by a real network.
optimizer = torch.optim.Adam(policy_model.parameters(), lr=1e-5)
prompts = [...]                           # prompt dataset (list or tensor)

def gather_log_probs(logits, actions):
    """Pick out the log-probability assigned to each action that was taken.

    Args:
        logits: Unnormalised scores; the last dimension indexes the candidate
            actions (e.g. ``[batch_size * n, seq_len, vocab_size]``).
        actions: Integer indices into that last dimension, one per position
            (e.g. ``[batch_size * n, seq_len]``).

    Returns:
        Tensor shaped like ``actions`` holding the log-softmax value of the
        selected action at every position.
    """
    token_index = actions.unsqueeze(-1)            # add a trailing axis for gather
    normalized = logits.log_softmax(dim=-1)        # log-probabilities over all actions
    selected = normalized.gather(-1, token_index)  # keep only the taken action
    return selected.squeeze(-1)

def compute_KL(new_logits, ref_logits, response_mask):
    """Masked mean of the exact per-token KL(new || ref) over the vocabulary.

    The divergence is evaluated on the full categorical distribution at every
    position (not only on the sampled token) and then averaged over the
    positions selected by ``response_mask``.

    Args:
        new_logits: Logits of the policy being trained,
            ``[batch_size * n, seq_len, vocab_size]``.
        ref_logits: Logits of the reference policy, same shape as ``new_logits``.
        response_mask: ``[batch_size * n, seq_len]`` mask, non-zero on the
            positions that count and zero on padding.

    Returns:
        Scalar tensor: the mean KL divergence over unmasked positions, or 0
        when the mask selects no position at all.
    """
    # Work in log space: log_softmax stays finite where softmax underflows,
    # so no additive epsilon (which would bias the result) is required.
    new_log_probs = torch.log_softmax(new_logits, dim=-1)
    ref_log_probs = torch.log_softmax(ref_logits, dim=-1)
    kl = torch.sum(new_log_probs.exp() * (new_log_probs - ref_log_probs), dim=-1)  # [batch_size * n, seq_len]

    mask = response_mask.to(kl.dtype)
    # Clamp the token count so an all-padding batch yields 0 instead of NaN.
    return (kl * mask).sum() / mask.sum().clamp(min=1.0)

def compute_process_advantages(batch_rewards, batch_size, num_generations, seq_len, gamma=0.99):
    """Process-supervision advantages: group-normalised discounted returns.

    For every token position t the discounted return
    ``R_{i,t} = r_{i,t} + gamma * R_{i,t+1}`` is accumulated from the end of
    the sequence backwards, then standardised across the ``num_generations``
    sequences that share a prompt:
    ``A_{i,t} = (R_{i,t} - mean_j R_{j,t}) / std_j R_{j,t}``.

    Args:
        batch_rewards: Per-token rewards, ``[batch_size * n, seq_len]``.
        batch_size: Number of prompts in the batch.
        num_generations: Sequences generated per prompt (n).
        seq_len: Sequence length.
        gamma: Discount factor.

    Returns:
        Per-token advantages, ``[batch_size * n, seq_len]``.
    """
    grouped = batch_rewards.view(batch_size, num_generations, seq_len)  # [batch_size, n, seq_len]
    normalized = torch.zeros_like(grouped)
    # Return-to-go of the position currently being processed; starts at zero
    # past the end of the sequence.
    running_return = torch.zeros(batch_size, num_generations).to(grouped.device)
    for step in reversed(range(seq_len)):
        running_return = grouped[:, :, step] + gamma * running_return
        group_mean = running_return.mean(dim=1, keepdim=True)
        # The small constant keeps the division finite when a group is constant.
        group_std = running_return.std(dim=1, keepdim=True) + 1e-8
        normalized[:, :, step] = (running_return - group_mean) / group_std
    return normalized.view(batch_size * num_generations, seq_len)

def compute_outcome_advantages(batch_rewards, batch_size, num_generations, seq_len=None):
    """Outcome-supervision advantages: group-normalised sequence rewards.

    Each sequence receives a single reward, which is standardised across the
    ``num_generations`` sequences sharing a prompt:
    ``A_i = (r_i - mean_j r_j) / std_j r_j``. Every token of a sequence gets
    that same advantage value.

    Args:
        batch_rewards: Sequence-level rewards, ``[batch_size * n]``.
        batch_size: Number of prompts in the batch.
        num_generations: Sequences generated per prompt (n).
        seq_len: Width the advantages are broadcast to. When omitted, the
            module-level ``max_len`` is used, as before.

    Returns:
        Advantages of shape ``[batch_size * n, seq_len]``. This is an expanded
        (broadcast) view: all columns of a row share the same storage.
    """
    if seq_len is None:
        # Backward-compatible default: fall back to the global generation length.
        seq_len = max_len
    grouped = batch_rewards.reshape(batch_size, num_generations)  # [batch_size, n]
    group_mean = grouped.mean(dim=1, keepdim=True)
    # The small constant keeps the division finite when all rewards in a group are equal.
    group_std = grouped.std(dim=1, keepdim=True) + 1e-8
    normalized = (grouped - group_mean) / group_std
    return normalized.reshape(batch_size * num_generations, 1).expand(-1, seq_len)

# GRPO training loop
def _response_logits(logits, prompt_len, response_len):
    """Return the logits that score the response tokens.

    ``logits`` comes from a forward pass over [prompt, separator, response].
    If the model already returns one row per response token it is passed
    through unchanged. Otherwise the rows are sliced assuming the usual
    next-token convention (row i scores token i + 1), so response token j is
    scored by row ``prompt_len + j`` -- TODO confirm against the real model.
    """
    if logits.size(1) == response_len:
        return logits
    return logits[:, prompt_len:prompt_len + response_len, :]


for iteration in range(num_iterations):
    # 1. Data preparation: sample prompts and repeat each one n times
    #    (n is num_generations).
    batch_prompt = random.sample(prompts, batch_size)
    batch_prompt = torch.tensor(batch_prompt).repeat_interleave(num_generations, dim=0).to(device)  # [batch_size * n, len_prompt]
    prompt_len = batch_prompt.size(1)

    # 2. Sequence generation with the current ("old") policy; no graph is needed.
    with torch.no_grad():
        batch_response, old_log_probs = policy_model.generate(
            batch_prompt, max_len=max_len, return_log_probs=True
        )  # batch_response: [batch_size * n, seq_len], old_log_probs: [batch_size * n, seq_len]
    response_len = batch_response.size(1)

    # NOTE: the reward buffers below are allocated with max_len columns, so the
    # generated responses are assumed to be padded to exactly max_len.
    response_mask = (batch_response != PAD_TOKEN_ID).float()  # 1 on real tokens, 0 on padding
    # Number of real response tokens; clamped so an all-padding batch cannot divide by zero.
    num_response_tokens = response_mask.sum().clamp(min=1.0)
    sep_token = torch.tensor([[SEP_TOKEN_ID]] * (batch_size * num_generations)).to(batch_prompt.device)
    batch_data = torch.cat([batch_prompt, sep_token, batch_response], dim=1)  # [batch_size * n, len_prompt + 1 + seq_len]

    # 3. Rewards
    # Process supervision: per-token rewards, weighted sum over the reward models.
    batch_rewards_process = torch.zeros(batch_size * num_generations, max_len).to(device)
    for w, reward_model in zip(weights, reward_models):
        batch_rewards_process += w * reward_model(batch_data)  # [batch_size * n, seq_len]
    batch_rewards_process = batch_rewards_process * response_mask  # ignore padding

    # Outcome supervision: one reward per sequence.
    batch_rewards_outcome = torch.zeros(batch_size * num_generations).to(device)
    for w, reward_model in zip(weights, reward_models):
        batch_rewards_outcome += w * reward_model(batch_data, mode='outcome')  # [batch_size * n]

    # 4. Advantages
    # Process supervision (considerably more expensive).
    advantages_process = compute_process_advantages(
        batch_rewards_process, batch_size, num_generations, max_len, gamma
    )  # [batch_size * n, seq_len]

    # Outcome supervision (much cheaper).
    advantages_outcome = compute_outcome_advantages(
        batch_rewards_outcome, batch_size, num_generations
    )  # [batch_size * n, seq_len]

    # Choose which advantages to optimise (example: process supervision).
    advantages = advantages_process  # or advantages_outcome

    # 5. Reference-model forward pass, restricted to the response positions so
    #    that it lines up with batch_response / response_mask.
    with torch.no_grad():
        ref_logits = _response_logits(ref_model.forward(batch_data), prompt_len, response_len)  # [batch_size * n, seq_len, vocab_size]
        ref_log_probs = gather_log_probs(ref_logits, batch_response)  # [batch_size * n, seq_len]

    # 6. Optimisation: mu gradient steps on the same rollouts.
    for _ in range(mu):
        # Forward pass of the policy being updated, again on response positions only.
        new_logits = _response_logits(policy_model.forward(batch_data), prompt_len, response_len)  # [batch_size * n, seq_len, vocab_size]
        new_log_probs = gather_log_probs(new_logits, batch_response)  # [batch_size * n, seq_len]

        # PPO clipped surrogate loss.
        ratio = torch.clamp(torch.exp(new_log_probs - old_log_probs.detach()), 1e-10, 1e10)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
        # Average over real tokens only, so padding neither contributes to the
        # sum nor dilutes the mean (same normalisation as the KL term).
        loss_adv = -(torch.min(surr1, surr2) * response_mask).sum() / num_response_tokens

        # KL penalty against the reference model (full-vocabulary form).
        loss_kl = compute_KL(new_logits, ref_logits, response_mask)

        # Total loss.
        loss = loss_adv + beta * loss_kl

        # Gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 7. Optional: report the losses of the last inner step.
    print(f"Iteration {iteration}, Loss: {loss.item():.4f}, Adv Loss: {loss_adv.item():.4f}, KL Loss: {loss_kl.item():.4f}")
        