## PPO

Models: policy (actor), value head (critic), reward model (frozen), reference policy (frozen)


* Use the current policy to **generate** responses (collect data).
* **Reward**: a per-token penalty −β × KL(current policy ‖ reference policy) on every generated token, plus the reward-model score added at the final (EOS) token.
* Use the **value head** to estimate values and compute **advantages** (GAE recommended).
* **Freeze/cache** the batch’s old log-probs / advantages / return targets.
* Do **multiple small PPO updates** on this batch:

  * Recompute new log-probs → **ratio** $r = \exp(\text{logp}_\text{new} - \text{logp}_\text{old})$.
  * **Clip** $r$ to $[1-\varepsilon,\, 1+\varepsilon]$ → compute the policy loss.
  * Do **value regression** + **entropy bonus**.
  * **Backprop** and **update parameters**.
  * **Monitor KL**; if it gets too large, **early-stop** the epoch/iteration.



  policy objective (maximised; the training loss is its negative): $\arg \max _\theta \mathbb{E}_{s \sim \nu^{\pi_{\theta_k}}, a \sim \pi_{\theta_k}(\cdot \mid s)}\left[\min \left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_k}(a \mid s)} A^{\pi_{\theta_k}}(s, a), \operatorname{clip}\left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_k}(a \mid s)}, 1-\epsilon, 1+\epsilon\right) A^{\pi_{\theta_k}}(s, a)\right)\right]$, where $\nu^{\pi_{\theta_k}}$ is the state distribution visited by the old policy $\pi_{\theta_k}$.


In [4]:
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def compute_reward(
    r: torch.Tensor,
    kl: torch.Tensor,
    kl_coef: float,
    action_mask: torch.Tensor,
) -> torch.Tensor:
    """Build per-token PPO rewards from a sequence score and per-token KL.

    Every position receives the shaping term ``-kl_coef * kl``; the last valid
    position of each row (rightmost nonzero entry of ``action_mask``) also
    receives the sequence-level score ``r``. Positions where the mask is 0 are
    zeroed in the result.

    Args:
        r: One score per row; it is unsqueezed to ``[B, 1]`` and scattered.
        kl: Per-token KL values, ``[B, T]``.
        kl_coef: Penalty weight (beta).
        action_mask: ``[B, T]`` validity mask.

    Returns:
        ``[B, T]`` tensor in ``kl.dtype``.
    """
    _, seq_len = kl.shape
    out_dtype = kl.dtype

    # argmax returns the first maximal index, so on the column-reversed mask it
    # lands on the last 1 of the original row; map it back to a forward index.
    reversed_mask = action_mask.long().flip(1)
    terminal_idx = (seq_len - 1) - reversed_mask.argmax(dim=1, keepdim=True)  # [B, 1]

    terminal = torch.zeros_like(kl, dtype=out_dtype)
    terminal.scatter_(1, terminal_idx, r.to(out_dtype).unsqueeze(1))

    penalty = -kl_coef * kl  # negative shaping on every token
    return (terminal + penalty) * action_mask.to(out_dtype)


@torch.no_grad()
def compute_td_delta(
    rewards: torch.Tensor,      # [B, T] per-token rewards (e.g., EOS reward + -beta*KL)
    values: torch.Tensor,       # [B, T] V(s_t) from the value head (current parameters)
    gamma: float,
    action_mask: torch.Tensor,  # [B, T] 1 for valid tokens (incl. EOS), 0 for padding
) -> torch.Tensor:
    """Per-token one-step TD residuals used by GAE.

    delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t

    ``done_t`` is 1 only at the last valid position of each row (the rightmost
    nonzero entry of ``action_mask``), so nothing is bootstrapped past it. The
    value after the final column is taken as 0. Positions with mask 0 are zeroed.

    Returns:
        ``[B, T]`` tensor in ``rewards.dtype``.
    """
    n_rows, n_cols = rewards.shape

    # Rightmost valid index per row ([B, 1]): argmax picks the first maximum,
    # which on the column-reversed mask is the last 1 of the original row.
    reversed_mask = action_mask.long().flip(1)
    terminal_idx = (n_cols - 1) - reversed_mask.argmax(dim=1, keepdim=True)

    # V_{t+1}: shift values one step left and pad the final column with zero.
    #   t       | 0   1   2   3
    #   V_t     | 10  9   7   0
    #   V_{t+1} | 9   7   0   0
    next_values = torch.cat([values[:, 1:], values.new_zeros(n_rows, 1)], dim=1)

    # 1.0 exactly at the terminal token, 0.0 elsewhere.
    is_terminal = torch.zeros_like(action_mask, dtype=values.dtype)
    is_terminal.scatter_(
        1,
        terminal_idx,
        torch.ones(n_rows, 1, device=rewards.device, dtype=values.dtype),
    )

    bootstrap = gamma * (1.0 - is_terminal) * next_values
    delta = rewards + bootstrap - values

    # Zero padding positions so the [B, T] shape stays aligned with the mask.
    return delta * action_mask.to(rewards.dtype)


@torch.no_grad()
def compute_advantage(gamma: float, lam: float, td_delta: torch.Tensor) -> torch.Tensor:
    """Generalised Advantage Estimation from TD residuals.

    Runs the backward recursion A_t = delta_t + (gamma * lam) * A_{t+1}, with
    A_T = 0 after the final column. ``td_delta`` is ``[B, T]`` and is expected
    to be zero on padding (e.g. the output of ``compute_td_delta``), so trailing
    padding contributes nothing.

    Returns:
        ``[B, T]`` advantages with the same dtype/device as ``td_delta``.
    """
    n_rows, n_cols = td_delta.shape
    decay = gamma * lam
    advantages = torch.zeros_like(td_delta)
    running = td_delta.new_zeros(n_rows)
    for step in reversed(range(n_cols)):
        running = td_delta[:, step] + decay * running
        advantages[:, step] = running
    return advantages


@torch.no_grad()
def normalize_advantage(adv: torch.Tensor, mask: Optional[torch.Tensor] = None, eps: float = 1e-8) -> torch.Tensor:
    """Whiten advantages to zero mean / unit (population) std.

    Args:
        adv: Advantage tensor of any shape.
        mask: Optional tensor broadcastable to ``adv``; statistics are computed
            only over positions where it is nonzero (weighted by its value).
            Bool/int masks are accepted and treated as 0/1.
        eps: Added to the std to avoid division by zero.

    Returns:
        ``(adv - mean) / (std + eps)``. Masked-out positions are transformed
        too (not zeroed), so downstream losses must still apply the mask.
    """
    if mask is None:
        mean = adv.mean()
        std = adv.std(unbiased=False)
    else:
        # Integer/bool masks would make the count an integer tensor; cast them
        # so every mask type behaves like a 0/1 float mask. Float masks are used
        # as-is to keep their precision.
        weights = mask if mask.is_floating_point() else mask.to(adv.dtype)
        denom = weights.sum().clamp_min(1.0)
        mean = (adv * weights).sum() / denom
        var = (weights * (adv - mean) ** 2).sum() / denom
        std = var.sqrt()

    return (adv - mean) / (std + eps)

class PolicyLoss(nn.Module):
    """PPO clipped-surrogate policy loss.

    L_t = -min( ratio * A_t, clip(ratio, 1-ε, 1+ε) * A_t )
      ratio = π_θ(a_t | s_t) / π_{θ_old}(a_t | s_t) = exp(log_prob_new - log_prob_old)

    Per-token losses are averaged over valid tokens within each row, then
    averaged over rows.
    """

    def __init__(self, clip_ratio: float, eps: float = 1e-8):
        super().__init__()
        self.clip_ratio = float(clip_ratio)  # ε in the clip range [1-ε, 1+ε]
        self.eps = eps  # lower bound on the per-row valid-token count

    def forward(
        self,
        old_log_prob: torch.Tensor,  # [B, T]
        log_prob: torch.Tensor,      # [B, T]
        advantage: torch.Tensor,     # [B, T] (should be detached/treated as constant)
        action_mask: Optional[torch.Tensor] = None,  # [B, T] 0/1
    ) -> torch.Tensor:
        """Return the scalar clipped-surrogate loss.

        Only ``log_prob`` receives gradient: the behaviour-policy log-probs and
        the advantages are constants of the surrogate and are detached here.
        """
        old_log_prob = old_log_prob.detach()
        advantage = advantage.detach()
        ratio = torch.exp(log_prob - old_log_prob)                     # [B, T]
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantage
        loss_t = -torch.min(surr1, surr2)                              # [B, T]

        if action_mask is not None:
            # Bool/int masks are cast so the row count is a float tensor.
            mask = action_mask if action_mask.is_floating_point() else action_mask.to(loss_t.dtype)
            denom = mask.sum(dim=-1).clamp_min(self.eps)               # [B]
            loss = (loss_t * mask).sum(dim=-1) / denom                 # [B]
        else:
            loss = loss_t.mean(dim=-1)                                  # [B]

        return loss.mean()  # scalar


class ValueLoss(nn.Module):
    """Critic regression loss: 0.5 * MSE between values and return targets.

    With ``clip_ratio`` set, PPO-style value clipping is applied: the new value
    may move at most ``clip_ratio`` away from ``old_values``, and the larger of
    the clipped/unclipped squared errors is used per token.
    """

    def __init__(self, clip_ratio: Optional[float], eps: float = 1e-8):
        super().__init__()
        self.clip_ratio = clip_ratio  # None disables value clipping
        self.eps = eps  # lower bound on the per-row valid-token count

    def forward(
        self,
        values: torch.Tensor,         # [B, T]
        old_values: torch.Tensor,     # [B, T]
        return_targets: torch.Tensor, # [B, T]  (e.g., A + V or discounted returns)
        action_mask: Optional[torch.Tensor] = None,  # [B, T] 0/1
    ) -> torch.Tensor:
        """Return the scalar value loss; only ``values`` receives gradient."""
        # Targets and the clipping anchor are constants of the regression; if
        # they were derived from the same value head they must not backprop.
        old_values = old_values.detach()
        return_targets = return_targets.detach()

        if self.clip_ratio is not None:
            v_clipped = old_values + torch.clamp(values - old_values, -self.clip_ratio, self.clip_ratio)
            mse_unclipped = (values - return_targets).pow(2)
            mse_clipped = (v_clipped - return_targets).pow(2)
            loss_t = torch.max(mse_unclipped, mse_clipped)
        else:
            loss_t = (values - return_targets).pow(2)

        if action_mask is not None:
            # Bool/int masks are cast so the row count is a float tensor.
            mask = action_mask if action_mask.is_floating_point() else action_mask.to(loss_t.dtype)
            denom = mask.sum(dim=-1).clamp_min(self.eps)
            loss = (loss_t * mask).sum(dim=-1) / denom
        else:
            loss = loss_t.mean(dim=-1)

        return 0.5 * loss.mean()  # scalar


class PairWiseRMLoss(nn.Module):
    """Pairwise Bradley–Terry reward-model loss.

    loss = mean( -log σ( (r_chosen - r_rejected) - margin ) ), where the margin
    term is omitted when ``margin`` is None.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        chosen_rewards: torch.Tensor,    # [B]
        rejected_rewards: torch.Tensor,  # [B]
        margin: Optional[float] = None,
    ) -> torch.Tensor:
        gap = chosen_rewards - rejected_rewards
        shifted = gap if margin is None else gap - float(margin)
        per_pair = -F.logsigmoid(shifted)
        return per_pair.mean()


## DPO

$\mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta, \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right]$

In [5]:
class DPOLoss(nn.Module):
    """Direct Preference Optimization loss, with optional cDPO label smoothing or IPO.

    logits = [log π(y_w|x) - log π(y_l|x)] - [log π_ref(y_w|x) - log π_ref(y_l|x)]

    * DPO / cDPO: -(1 - ls) * log σ(β·logits) - ls * log σ(-β·logits)
    * IPO:        (logits - 1 / (2β))²

    Inputs are log-probabilities of whole responses, presumably ``[B]`` sums over
    response tokens (confirm with the caller); the per-element losses are
    averaged over every element into a scalar.
    """

    def __init__(self, beta, label_smoothing=0.0, ipo=False):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing  # soft-label weight for cDPO
        self.ipo = ipo  # use the IPO squared loss instead of the logistic loss

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        """Return ``(loss, chosen_rewards, rejected_rewards)``.

        The rewards are the implicit DPO rewards β·(log π - log π_ref), detached
        so they can be logged without affecting gradients.
        """
        policy_log_ratio = policy_chosen_logps - policy_rejected_logps
        reference_log_ratio = reference_chosen_logps - reference_rejected_logps
        logits = policy_log_ratio - reference_log_ratio

        if self.ipo:
            # Identity Preference Optimization: regress the log-ratio gap toward
            # 1/(2β) instead of pushing it to infinity.
            losses = (logits - 1.0 / (2.0 * self.beta)) ** 2
        else:
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards

## GRPO



$\begin{aligned} & \mathcal{J}_{G R P O}(\theta)=\mathbb{E}\left[q \sim P(Q),\left\{o_i\right\}_{i=1}^G \sim \pi_{\theta_{\text {old }}}(O \mid q)\right] \\ & \quad \frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|}\left\{\min \left[\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)} \hat{A}_{i, t}, \operatorname{clip}\left(\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i, t}\right]-\beta \mathbb{D}_{K L}\left[\pi_\theta \| \pi_{r e f}\right]\right\}\end{aligned}$

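The KL term is estimated per token with the (unbiased, always non-negative) estimator $\mathbb{D}_{K L}\left[\pi_\theta \| \pi_{\mathrm{ref}}\right]=\frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{\mathrm{ref}}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1$, and with outcome supervision every token of completion $o_i$ shares the group-normalised advantage $\hat{A}_{i, t}=\frac{r_i-\operatorname{mean}(\{r_1,\dots,r_G\})}{\operatorname{std}(\{r_1,\dots,r_G\})}$.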

In [None]:
@torch.no_grad()
def compute_grpo_advantages(
    rewards: torch.Tensor,   # [B], one scalar reward per completion
    group_size: int,
    eps: float = 1e-4,
    scale_rewards: bool = True,
) -> torch.Tensor:
    """Group-relative advantages for GRPO.

    ``rewards`` is laid out as consecutive groups of ``group_size`` completions
    (one group per prompt). Each completion gets

        A_i = (r_i - mean_group) / (std_group + eps)

    where std is the population std of its group (so a group of size 1, or a
    group whose rewards are all equal, yields 0 rather than NaN). With
    ``scale_rewards=False`` only the group mean is subtracted.

    Raises:
        ValueError: if ``group_size`` is not positive or does not divide B.

    Returns:
        ``[B]`` advantages in the same order as ``rewards``.
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}.")

    B = rewards.shape[0]
    if B % group_size != 0:
        raise ValueError(
            f"Batch size {B} is not divisible by group_size {group_size}. "
        )

    # Each row is one group of completions for the same prompt.
    grouped_rewards = rewards.reshape(B // group_size, group_size)  # [num_groups, G]

    mean = grouped_rewards.mean(dim=1, keepdim=True)                # [num_groups, 1]
    advantages_grouped = grouped_rewards - mean

    if scale_rewards:
        std = grouped_rewards.std(dim=1, keepdim=True, unbiased=False)
        advantages_grouped = advantages_grouped / (std + eps)

    return advantages_grouped.reshape(B)



class GRPOLoss(nn.Module):
    def __init__(self, epsilon: float = 0.2, beta: float = 0.1, reduction: str = "mean"):
        super().__init__()
        self.epsilon = epsilon
        self.beta = beta

    def forward(
        self,
        per_token_logps: torch.Tensor,
        old_per_token_logps: torch.Tensor,
        ref_per_token_logps: torch.Tensor,
        advantages: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):

        B, T = per_token_logps.shape

        # If advantages is [B], broadcast it to [B, T] so every token of a sequence
        # shares the same sequence-level advantage.
        if advantages.dim() == 1:
            advantages = advantages.unsqueeze(1).expand(B, T)
        elif advantages.shape != (B, T):
            raise ValueError("advantages must be shape [B] or [B, T]")

        # If no mask is given, treat all tokens as valid (mask = 1)
        if mask is None:
            mask = torch.ones_like(per_token_logps, dtype=torch.float32)
        else:
            mask = mask.to(per_token_logps.dtype)


        # 1) PPO-style clipped surrogate:
        #    ratio = pi_theta(a_t|s_t) / pi_old(a_t|s_t) = exp(logp_new - logp_old)
        log_ratio = per_token_logps - old_per_token_logps       
        ratio = torch.exp(log_ratio)             

        # Clip the ratio to [1 - epsilon, 1 + epsilon]
        clipped_ratio = torch.clamp(ratio, 1.0 - self.epsilon, 1.0 + self.epsilon)

        # Unclipped and clipped objectives per token: r_t * A_t, clip(r_t) * A_t
        # Note: advantage can be positive or negative.
        surrogate_unclipped = ratio * advantages
        surrogate_clipped = clipped_ratio * advantages

        # PPO uses min(r*A, clip(r)*A) to avoid overly large policy updates.
        # We *maximize* this surrogate, so in a loss (to minimize) we take the negative.
        policy_loss_per_token = -torch.min(surrogate_unclipped, surrogate_clipped)

        # 2) KL penalty per token between current policy and reference policy:
        #    D_KL[pi_theta || pi_ref] ~= pi_ref/pi_theta - log(pi_ref/pi_theta) - 1
        #    Using the log-space form: Δ = log pi_ref - log pi_theta
        log_ratio_ref_policy = ref_per_token_logps - per_token_logps  # Δ = log(pi_ref/pi_theta)
        ratio_ref_policy = torch.exp(log_ratio_ref_policy)            # pi_ref / pi_theta

        # Closed-form per-token KL approximation:
        # D_KL = ratio_ref_policy - log_ratio_ref_policy - 1
        kl_per_token = ratio_ref_policy - log_ratio_ref_policy - 1.0

        # Total per-token loss: PPO surrogate loss + beta * KL penalty
        total_per_token_loss = policy_loss_per_token + self.beta * kl_per_token

        # 3) Mask out invalid tokens (e.g., padding / prompt) and reduce
        total_per_token_loss = total_per_token_loss * mask
        kl_per_token = kl_per_token * mask

        denom = mask.sum().clamp_min(1.0)
        loss = total_per_token_loss.sum() / denom

        return loss
