# GRPO + PPO算法demo

## 一、整体结构说明
1. **配置模块（Config）**：用于设置参数，其中包括选择算法（algo = "ppo" 或 "grpo"）。
2. **损失函数（ppo_policy_loss / grpo_loss）**：分别对应两种算法，由 policy_loss_wrapper 统一调用接口，便于切换。
3. **统一入口（policy_loss_wrapper）**：根据配置信息 cfg.algo 自动选用对应的损失函数；遇到未知的 algo 时抛出 ValueError。
4. **训练步骤（train_step）**：每步自动调用当前算法对应的loss函数进行优化。
5. **主训练循环（main）**：训练主流程，只需在配置中更改cfg.algo即可实现算法的无缝切换。

In [2]:
import torch
from torch import nn
from dataclasses import dataclass

# ======================
# 1. 配置
# ======================

@dataclass
class Config:
    """Hyperparameters and runtime settings shared by the model, the losses and the training loop."""

    algo: str = "ppo"           # <<< change only this line: "ppo" or "grpo"
    lr: float = 1e-4            # learning rate passed to the Adam optimizer in main()
    clip_range: float = 0.2     # half-width of the clipping interval around ratio == 1
    beta_kl: float = 0.01       # used by GRPO: weight of the KL term
    vf_coef: float = 0.5        # not referenced in the visible part of this file
    ent_coef: float = 0.01      # not referenced in the visible part of this file
    group_num: int = 4          # GRPO: number of candidate responses per prompt
    input_len: int = 8          # prompt length; positions before it are masked out of the loss
    seq_len: int = 32           # second dimension of the fake input_ids / actions tensors
    vocab_size: int = 50257     # embedding rows and output-logit width of TinyPolicy
    hidden_size: int = 256      # embedding width of TinyPolicy
    # The default is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


# ======================
# 2. 一个极简 policy 模型
# ======================

class TinyPolicy(nn.Module):
    """Minimal per-token policy: embedding -> LayerNorm -> linear head over the vocabulary."""

    def __init__(self, cfg: Config):
        super().__init__()
        # Creation order is kept as-is: it fixes both the parameter
        # initialisation order and the state_dict layout.
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.ln = nn.LayerNorm(cfg.hidden_size)
        self.head = nn.Linear(cfg.hidden_size, cfg.vocab_size)

    def forward(self, input_ids):
        """
        Map token ids to vocabulary logits.

        input_ids: [bs, seq_len]
        returns logits: [bs, seq_len, vocab_size]
        """
        hidden = self.ln(self.embed(input_ids))   # [bs, seq_len, hid]
        return self.head(hidden)                  # [bs, seq_len, vocab]

    def log_prob(self, input_ids, actions):
        """
        Return the log-probability assigned to each action token.

        input_ids: [bs, seq_len]
        actions:   [bs, seq_len]
        returns:   [bs, seq_len]
        """
        all_logprobs = self.forward(input_ids).log_softmax(dim=-1)   # [bs, seq_len, vocab]
        # Pick the entry of each position's distribution that matches its action.
        chosen = all_logprobs.gather(dim=-1, index=actions[..., None])
        return chosen[..., 0]


# ======================
# 3. PPO / GRPO Loss 实现（统一接口）
# ======================

def compute_mask(bs, seq_len, input_len, device, dtype):
    """
    Build a [bs, seq_len] mask that is 1 from column `input_len` onward
    (the response part) and 0 before it (the prompt part).
    """
    response_mask = torch.zeros((bs, seq_len), dtype=dtype, device=device)
    # The slice is a view, so filling it writes into response_mask itself.
    response_mask[:, input_len:].fill_(1.0)
    return response_mask


def ppo_policy_loss(pi_logprob, pi_old_logprob, advantage, mask, clip_range):
    """
    Clipped-surrogate PPO policy loss (no value or entropy term).

    pi_logprob / pi_old_logprob / mask: [bs, seq_len]
    advantage: one scalar per sample, viewed as [bs, 1] and broadcast over tokens
    returns: scalar, the masked token loss averaged over the number of mask-selected tokens
    """
    n_samples, _ = pi_logprob.shape

    importance = (pi_logprob - pi_old_logprob).exp()                      # [bs, seq_len]
    importance_clipped = importance.clamp(1.0 - clip_range, 1.0 + clip_range)

    adv_col = advantage.view(n_samples, 1)                                # [bs, 1]
    surrogate = torch.minimum(importance * adv_col, importance_clipped * adv_col)

    # Negate to turn the objective into a loss, then drop masked-out tokens.
    masked_loss = (-surrogate) * mask                                     # [bs, seq_len]

    # clamp_min keeps the division finite when the mask selects nothing.
    n_valid = mask.sum().clamp_min(1.0)
    return masked_loss.sum() / n_valid


def grpo_kl(pi_logprob, pi_ref_logprob, mask):
    """
    Simple KL estimate: masked mean of (log pi - log pi_ref) for each sample.

    Inputs are [bs, seq_len]; the result is per-sample with shape [bs, 1].
    """
    masked_log_ratio = (pi_logprob - pi_ref_logprob) * mask               # [bs, seq_len]
    # clamp_min avoids dividing by zero for rows whose mask is all zero.
    tokens_per_sample = mask.sum(dim=1, keepdim=True).clamp_min(1.0)
    return masked_log_ratio.sum(dim=1, keepdim=True) / tokens_per_sample


def grpo_loss(
    pi_logprob,
    pi_old_logprob,
    pi_ref_logprob,
    advantage,
    mask,
    len_oi,        # [bs]
    group_num,
    clip_range,
    beta_kl,
):
    """
    GRPO-style loss.

    Per token the objective is
        min(ratio * A, clip(ratio) * A) - beta_kl * (log pi - log pi_ref)
    restricted to the tokens selected by `mask`. Each sample's token sum is
    divided by its length `len_oi`, the result is scaled by 1 / group_num,
    summed over the batch and negated to give a loss.

    pi_logprob / pi_old_logprob / pi_ref_logprob / mask: [bs, seq_len]
    advantage: one scalar per sample, reshaped to [bs, 1]
    len_oi:    [bs], response length of each sample (clamped to >= 1)
    returns:   scalar loss

    The KL term is kept at token level so that it is length-normalised exactly
    once, by the same 1 / len_oi factor as the policy term. Spreading an
    already-averaged per-sample KL over the tokens and then dividing by the
    length again would shrink the penalty to beta_kl / len_oi.
    """
    device = pi_logprob.device
    dtype = pi_logprob.dtype
    bs = pi_logprob.shape[0]

    ratio = torch.exp(pi_logprob - pi_old_logprob)
    ratio_clip = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)

    # reshape (not view) so a non-contiguous advantage tensor is accepted too.
    advantage = advantage.reshape(bs, 1)
    pg_unclipped = ratio * advantage
    pg_clipped = ratio_clip * advantage
    policy_gradient = torch.minimum(pg_unclipped, pg_clipped) * mask  # [bs, seq_len]

    # Token-level KL estimate, zero outside the mask.
    kl_token = (pi_logprob - pi_ref_logprob) * mask                   # [bs, seq_len]

    # Token-level objective.
    objective_token = policy_gradient - beta_kl * kl_token            # [bs, seq_len]

    # Length normalisation; clamp_min guards against zero-length responses.
    len_oi = len_oi.reshape(bs, 1).to(device=device, dtype=dtype).clamp_min(1.0)
    loss_token = (-1.0 / group_num) * (1.0 / len_oi) * objective_token

    return loss_token.sum()


# 再包一层，把接口统一成一个函数
def policy_loss_wrapper(
    algo: str,
    pi_logprob,
    pi_old_logprob,
    pi_ref_logprob,
    advantage,
    mask,
    len_oi,
    cfg: Config,
):
    """
    Single entry point for both losses: `algo` ("ppo" or "grpo") selects which
    one is evaluated. Any other value raises ValueError.
    """
    if algo == "ppo":
        # PPO ignores the reference log-probs, len_oi and group_num.
        return ppo_policy_loss(pi_logprob, pi_old_logprob, advantage, mask, cfg.clip_range)

    if algo == "grpo":
        grpo_kwargs = dict(
            pi_logprob=pi_logprob,
            pi_old_logprob=pi_old_logprob,
            pi_ref_logprob=pi_ref_logprob,
            advantage=advantage,
            mask=mask,
            len_oi=len_oi,
            group_num=cfg.group_num,
            clip_range=cfg.clip_range,
            beta_kl=cfg.beta_kl,
        )
        return grpo_loss(**grpo_kwargs)

    raise ValueError(f"Unknown algo: {algo}")


# ======================
# 4. 假数据构造 & Advantage 计算示例
# ======================

def make_fake_batch(cfg: Config, bs: int):
    """
    Fabricate one random batch on cfg.device:
    - input_ids: [bs, seq_len] integers in [0, vocab_size)
    - actions:   [bs, seq_len] integers in [0, vocab_size)
    - rewards:   [bs] standard-normal scalars, one per sample
    """
    token_shape = (bs, cfg.seq_len)

    # The three draws stay in this order so the RNG stream is consumed identically.
    input_ids = torch.randint(0, cfg.vocab_size, token_shape, device=cfg.device)
    actions = torch.randint(0, cfg.vocab_size, token_shape, device=cfg.device)
    rewards = torch.randn(bs, device=cfg.device)

    return input_ids, actions, rewards


def compute_advantage_from_reward(rewards: torch.Tensor):
    """
    Minimal advantage: each reward minus the mean reward of the whole tensor.
    A real setup could substitute GAE, a learned baseline, a critic, etc.
    """
    baseline = rewards.mean()
    return rewards - baseline


# ======================
# 5. 单次 train_step：只改 cfg.algo 即可切换
# ======================

def train_step(
    cfg: Config,
    policy: TinyPolicy,
    ref_policy: TinyPolicy,   # only evaluated when cfg.algo is not "ppo"
    optimizer: torch.optim.Optimizer,
    input_ids,
    actions,
    rewards,
):
    """
    Run one optimisation step and return the loss as a Python float.

    - Compute the current policy's log-probs (with gradient).
    - Use a detached copy of them as the "old" log-probs. This demo has no
      separate rollout phase, so old and current coincide numerically; a real
      multi-epoch setup would cache the old log-probs at sampling time.
    - Evaluate `ref_policy` without gradient, unless cfg.algo is "ppo"
      (the PPO loss ignores the reference log-probs).
    - Dispatch on cfg.algo through policy_loss_wrapper and apply one
      optimizer update.
    """
    # Current-policy log-probs; this is the only forward pass through `policy`.
    pi_logprob = policy.log_prob(input_ids, actions)

    # Same values as a second no-grad forward pass, without the extra compute.
    pi_old_logprob = pi_logprob.detach()

    # Frozen reference log-probs, skipped for PPO where they are unused.
    pi_ref_logprob = None
    if cfg.algo != "ppo":
        with torch.no_grad():
            pi_ref_logprob = ref_policy.log_prob(input_ids, actions)

    # Advantage, one scalar per sample.
    advantage = compute_advantage_from_reward(rewards)            # [bs]

    bs, seq_len = pi_logprob.shape
    # Follow the tensors' own device rather than cfg.device so the mask always
    # lives where the log-probs do.
    device = pi_logprob.device
    mask = compute_mask(bs, seq_len, cfg.input_len, device, pi_logprob.dtype)

    # Every sample is assumed to have the same response length.
    len_oi = torch.full((bs,), seq_len - cfg.input_len, device=device, dtype=torch.long)

    # Changing cfg.algo is all it takes to switch PPO <-> GRPO.
    loss = policy_loss_wrapper(
        algo=cfg.algo,
        pi_logprob=pi_logprob,
        pi_old_logprob=pi_old_logprob,
        pi_ref_logprob=pi_ref_logprob,
        advantage=advantage,
        mask=mask,
        len_oi=len_oi,
        cfg=cfg,
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


# ======================
# 6. 主训练循环
# ======================

def main():
    """Build the policy and its reference copy, then run ten training steps on random batches."""
    # Pick the algorithm here.
    cfg = Config(algo="grpo")   # or "ppo"
    print("Using algo:", cfg.algo)

    policy = TinyPolicy(cfg).to(cfg.device)
    # The reference model starts as an exact copy of the policy; it is never
    # handed to an optimizer.
    ref_policy = TinyPolicy(cfg).to(cfg.device)
    ref_policy.load_state_dict(policy.state_dict())

    optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.lr)

    # Total batch size = group_num * number of prompts per batch.
    prompts_per_batch = 2
    batch_size = cfg.group_num * prompts_per_batch

    n_steps = 10
    for step in range(n_steps):
        batch = make_fake_batch(cfg, batch_size)   # (input_ids, actions, rewards)
        loss = train_step(cfg, policy, ref_policy, optimizer, *batch)
        print(f"step {step}: loss = {loss:.4f}")


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()


Using algo: grpo
step 0: loss = -0.0000
step 1: loss = -0.0000
step 2: loss = -0.0000
step 3: loss = -0.0000
step 4: loss = -0.0000
step 5: loss = -0.0000
step 6: loss = -0.0000
step 7: loss = 0.0000
step 8: loss = -0.0000
step 9: loss = 0.0000
