## PPO

Models: policy (actor), value head (critic), reward model (frozen), reference policy (frozen)


* Use the current policy to **generate** responses (collect data).
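* **Reward** = reward-model score − β × KL(current policy || reference policy). With token-level shaping, the KL penalty is applied at every generated token and the reward-model score is added only at the last generated (EOS) token.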
* Use the **value head** to estimate values and compute **advantages** (GAE recommended).
* **Freeze/cache** the batch’s old log-probs / advantages / return targets.
* Do **multiple small PPO updates** on this batch:

  * Recompute new log-probs → **ratio** $r = \exp(\text{logp}_\text{new} - \text{logp}_\text{old})$.
  * **Clip** $r$ to $[1-\varepsilon,\, 1+\varepsilon]$ → compute the policy loss.
  * Do **value regression** + **entropy bonus**.
  * **Backprop** and **update parameters**.
  * **Monitor KL**; if it gets too large, **early-stop** the epoch/iteration.



  Policy objective (the PPO policy loss is its negation): $\arg \max _\theta \mathbb{E}_{s \sim \nu^{\pi_{\theta_k}}, a \sim \pi_{\theta_k}(\cdot \mid s)}\left[\min \left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_k}(a \mid s)} A^{\pi_{\theta_k}}(s, a), \operatorname{clip}\left(\frac{\pi_\theta(a \mid s)}{\pi_{\theta_k}(a \mid s)}, 1-\varepsilon, 1+\varepsilon\right) A^{\pi_{\theta_k}}(s, a)\right)\right]$


In [None]:
def compute_advantage(gamma, lam, td_delta):
    """Compute Generalized Advantage Estimation (GAE) by a backward recursion.

    A_t = sum_l (gamma * lam)^l * td_delta_{t+l}, evaluated as
    A_t = td_delta_t + gamma * lam * A_{t+1}, scanning dimension 0 from the end.

    Args:
        gamma: discount factor.
        lam: GAE lambda.
        td_delta: tensor of TD residuals. Dimension 0 is treated as time;
            any trailing dimensions are processed element-wise.

    Returns:
        A float32 tensor with the same shape and device as ``td_delta``,
        detached from the autograd graph.
    """
    # Stay in torch on the input's own device: the previous `.numpy()` round
    # trip failed for CUDA tensors and the list -> tensor conversion was slow.
    # Accumulate in float64 for accuracy, then cast to float32 on return.
    deltas = td_delta.detach().to(torch.float64)
    advantage = torch.zeros_like(deltas)
    decay = gamma * lam
    acc = 0.0
    for t in reversed(range(deltas.size(0))):
        acc = deltas[t] + decay * acc
        advantage[t] = acc
    return advantage.to(torch.float)

def compute_reward(r, kl, kl_coef, action_mask):
    """Build token-level rewards: a per-token KL penalty plus a terminal score.

    Every position receives ``-kl_coef * kl``. In addition, the sequence-level
    score ``r`` is added at the last position of each row where
    ``action_mask`` is nonzero, i.e. the final generated token.

    Args:
        r: one score per sequence, shape (batch,).
        kl: per-token KL estimates, shape (batch, seq_len).
        kl_coef: KL penalty coefficient (beta).
        action_mask: (batch, seq_len) mask marking generated tokens.

    Returns:
        A tensor shaped like ``kl`` holding ``score_at_eos - kl_coef * kl``.
    """
    kl_reward = -kl_coef * kl
    # Index of the last nonzero mask entry per row: flip, find the first 1,
    # then map the position back to the unflipped coordinates.
    eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
    last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype))
    # NOTE(review): positions after the last generated token still receive the
    # KL penalty unless `kl` was already masked by the caller.
    # Reward = score - beta * KL. kl_reward already carries the minus sign.
    return last_reward + kl_reward
        
class PolicyLoss(nn.Module):
    """PPO clipped-surrogate policy loss.

    Args:
        clip_ratio: epsilon; the probability ratio is clamped to
            [1 - clip_ratio, 1 + clip_ratio].
    """

    def __init__(self, clip_ratio):
        super().__init__()
        self.clip_ratio = clip_ratio

    def forward(self, old_log_prob, log_prob, advantage, action_mask):
        """Return the negated clipped surrogate, averaged per sample, then over the batch.

        With a mask, each sample's loss is the masked mean over the last
        dimension. Without one, it is the plain mean over the last dimension.
        """
        low, high = 1 - self.clip_ratio, 1 + self.clip_ratio
        ratio = (log_prob - old_log_prob).exp()
        unclipped = advantage * ratio
        clipped = advantage * ratio.clamp(low, high)
        # Pessimistic bound; negated because the optimizer minimizes.
        per_token = -torch.minimum(unclipped, clipped)
        if action_mask is None:
            per_sample = per_token.mean(dim=-1)
        else:
            per_sample = (per_token * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)
        return per_sample.mean()


class ValueLoss(nn.Module):
    """Squared-error value loss with optional PPO-style value clipping.

    Args:
        clip_ratio: maximum allowed change of the value prediction relative to
            ``old_values``; ``None`` disables clipping.
    """

    def __init__(self, clip_ratio):
        super().__init__()
        self.clip_ratio = clip_ratio

    def forward(self, values, old_values, return_targets, action_mask):  # targets = A_t + V_t
        """Return 0.5 * batch mean of the per-sample (optionally masked) squared error."""
        sq_err = (values - return_targets).pow(2)
        if self.clip_ratio is None:
            loss = sq_err
        else:
            # Limit how far the new prediction may move from the old one, and
            # take the worse (larger) of the clipped/unclipped errors.
            step = (values - old_values).clamp(-self.clip_ratio, self.clip_ratio)
            v_clipped = old_values + step
            loss = torch.maximum(sq_err, (v_clipped - return_targets).pow(2))
        if action_mask is None:
            per_sample = loss.mean(dim=-1)
        else:
            per_sample = (loss * action_mask).sum(dim=-1) / action_mask.sum(dim=-1)
        return 0.5 * per_sample.mean()

class PairWiseRMLoss(nn.Module):
    """Pairwise (Bradley-Terry style) reward-model loss: mean of -log sigmoid(chosen - rejected [- margin])."""

    def forward(self, chosen_rewards, rejected_rewards, margin):
        gap = chosen_rewards - rejected_rewards
        # An optional margin requires the chosen reward to win by at least that much.
        if margin is not None:
            gap = gap - margin
        return (-F.logsigmoid(gap)).mean()
    

## DPO

$\mathcal{L}_{\mathrm{DPO}}\left(\pi_\theta, \pi_{\mathrm{ref}}\right)=-\mathbb{E}_{\left(x, y_w, y_l\right) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta\left(y_w \mid x\right)}{\pi_{\mathrm{ref}}\left(y_w \mid x\right)}-\beta \log \frac{\pi_\theta\left(y_l \mid x\right)}{\pi_{\mathrm{ref}}\left(y_l \mid x\right)}\right)\right]$

In [None]:
class DPOLoss(nn.Module):
    """Direct Preference Optimization loss, with optional label smoothing or IPO.

    Args:
        beta: temperature on the policy/reference log-ratio difference.
        label_smoothing: probability mass given to the flipped preference label
            (conservative DPO); 0.0 gives standard DPO.
        ipo: if True, use the IPO squared loss instead of the sigmoid loss.
    """

    def __init__(self, beta, label_smoothing=0.0, ipo=False):
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing  # soft label
        self.ipo = ipo

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        """Return (loss, chosen_rewards, rejected_rewards).

        The inputs are sequence log-probabilities of the chosen/rejected
        responses under the policy and the reference model. The rewards are the
        detached implicit rewards beta * (log pi_theta - log pi_ref).
        """
        policy_log_ratio = policy_chosen_logps - policy_rejected_logps
        reference_log_ratio = reference_chosen_logps - reference_rejected_logps
        logits = policy_log_ratio - reference_log_ratio
        if self.ipo:
            # IPO: regress the log-ratio gap towards 1 / (2 * beta) with a
            # squared loss. Previously this path referenced an undefined
            # `self.margin` and raised AttributeError.
            losses = (logits - 1 / (2 * self.beta)) ** 2
        else:
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean(dim=-1)
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards

## GRPO



$\begin{aligned} & \mathcal{J}_{G R P O}(\theta)=\mathbb{E}\left[q \sim P(Q),\left\{o_i\right\}_{i=1}^G \sim \pi_{\theta_{\text {old }}}(O \mid q)\right] \\ & \quad \frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|}\left\{\min \left[\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)} \hat{A}_{i, t}, \operatorname{clip}\left(\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i, t}\right]-\beta \mathbb{D}_{K L}\left[\pi_\theta \| \pi_{r e f}\right]\right\}\end{aligned}$

$\mathbb{D}_{K L}\left[\pi_\theta \| \pi_{r e f}\right]=\frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1$

In [None]:
def compute_loss(model, inputs):
    """Compute the per-token GRPO clipped-surrogate loss (visible portion).

    Args:
        model: policy model, passed to ``self._get_per_token_logps``.
        inputs: dict read for 'prompt_ids', 'prompt_mask', 'completion_ids',
            'completion_mask', 'ref_per_token_logps', 'advantages' and, when
            ``self.num_iterations > 1``, 'old_per_token_logps'.

    NOTE(review): this is defined as a module-level function, yet the body reads
    ``self._get_per_token_logps``, ``self.num_iterations`` and ``self.epsilon``.
    As written it raises NameError; it presumably belongs inside a trainer class
    with ``self`` as the first parameter. TODO confirm.
    NOTE(review): the code shown stops after building ``per_token_loss``.
    ``per_token_kl``, ``completion_mask`` and ``logits_to_keep`` are not used
    yet, and nothing is returned. The body presumably continues (KL term,
    masking, reduction) beyond this view.
    """
    prompt_ids, prompt_mask = inputs['prompt_ids'], inputs['prompt_mask']
    completion_ids, completion_mask = inputs['completion_ids'], inputs['completion_mask']
    # Concatenate prompt and completion along the sequence dimension for one forward pass.
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
    # Number of completion tokens.
    # NOTE(review): not passed to _get_per_token_logps below. Verify that the
    # helper returns completion-only log-probs; otherwise the shapes will not
    # match ref_per_token_logps.
    logits_to_keep = completion_ids.size(1)

    # Per-token log-probs under the current policy (gradients flow through these).
    per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask)

    ref_per_token_logps = inputs['ref_per_token_logps']

    # Per-token KL estimate pi_ref/pi_theta - log(pi_ref/pi_theta) - 1, as in the
    # D_KL formula above, written in log space.
    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1

    # Presumably one scalar advantage per completion (shape (batch,)); it is
    # broadcast across tokens via unsqueeze(1) below.
    advantages = inputs['advantages']

    # With a single update per batch the "old" policy is the current one,
    # detached: the ratio evaluates to 1 but still carries gradient.
    old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)   # r = new/old
    coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)  # r clip to [1-ε, 1+ε]

    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)

    # Pessimistic (min) surrogate, negated so that minimizing the loss maximizes the objective.
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
