# **Group Relative Policy Optimization (GRPO)**  

**Group Relative Policy Optimization (GRPO)** enhances traditional Proximal Policy Optimization (PPO) by reducing noise and variance during training, especially when dealing with tasks like text generation. Instead of updating the policy using a single response per query, GRPO samples multiple responses and averages their rewards to capture a more representative signal. This approach helps smooth out noise that can arise from individual, potentially suboptimal responses and ensures that policy updates are guided by a more stable and generalizable signal. By reducing the reliance on any one response, GRPO stabilizes the optimization process and prevents the model from overfitting to specific samples.

The core reason GRPO reduces variance is through reward aggregation across multiple sampled responses. When relying on a single path (as in vanilla PPO), randomness in sampling or noisy feedback can cause unstable updates. By averaging rewards across multiple responses, GRPO approximates the expected reward, mitigating the effect of outliers and noisy samples. This variance reduction allows the policy to learn robustly, generalizing its behavior across a broader range of possible outputs without being overly sensitive to specific responses.

Probability ratio clipping further stabilizes GRPO by preventing excessive policy shifts during updates. The ratio, which compares how much the new policy diverges from the old, is clipped within a defined range to avoid over-updating based on extreme samples. Combined with reward averaging, this mechanism ensures that the model takes gradual, meaningful steps in optimization without erratic behavior. In text generation, this approach prevents the model from overreacting to rare or overly favorable completions, fostering controlled exploration of diverse outputs.

GRPO's effectiveness also comes from incorporating techniques like KL divergence regularization, which keeps the policy anchored to its original language generation capabilities. The grouping mechanism—sampling multiple responses per query—is particularly valuable when rewards are derived from human feedback, which can be inconsistent or noisy. By averaging across diverse responses, GRPO reduces the risk of overfitting to specific preferences, resulting in a more generalizable and scalable policy for tasks involving complex and sequential outputs.

This group-level aggregation is critical when working with long, complex outputs (like text generation) because the reward signal for a single output might be noisy. By considering multiple completions, we smooth out the noise.

## Full GRPO Loss Function

$$
\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E} \left[ q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(o | q) \right] \Bigg[ \frac{1}{G} \sum_{i=1}^G \min \Bigg( \frac{\pi_{\theta}(o_i | q)}{\pi_{\theta_{\text{old}}}(o_i | q)} A_i, \, \text{clip} \Bigg( \frac{\pi_{\theta}(o_i | q)}{\pi_{\theta_{\text{old}}}(o_i | q)}, 1 - \epsilon, 1 + \epsilon \Bigg) A_i \Bigg) - \beta \, \mathcal{D}_{\text{KL}}(\pi_{\theta} \| \pi_{\text{ref}}) \Bigg]
$$

$$
\mathcal{D}_{\text{KL}}(\pi_{\theta} \| \pi_{\text{ref}}) = \frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_{\theta}(o_i \mid q)} - \log \frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_{\theta}(o_i \mid q)} - 1
$$

$$
A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \dots, r_G\})}{\text{std}(\{r_1, r_2, \dots, r_G\})}
$$

where:
- $G$ is the number of responses sampled per query.
- $\pi_{\theta}(o_i \mid q)$ is the probability of response $o_i$ given query $q$ under the new policy.
- $\pi_{\theta_{\text{old}}}(o_i \mid q)$ is the probability of response $o_i$ given query $q$ under the old policy.
- $A_i$ is the advantage of response $o_i$.
- $\epsilon$ is the clipping parameter.
- $\beta$ is the KL divergence regularization weight.

## **Step 1: Sampling Queries $q \sim P(Q)$**
We randomly sample queries $q$ from the distribution of possible prompts $P(Q)$. This distribution represents **real-world tasks** or **user inputs** that the model will be trained to handle.

**Examples of prompts:**  
- "The best food in the world is"  
- "Corgis are"  

This step forms the basis of RLHF tasks where the model needs to generate responses aligned with user preferences or reward feedback.

## **Step 2: Sampling $G$ Full Responses Per Query**
For each query $q$, we sample **G possible responses (sequences of tokens)** from the **old policy** $\pi_\theta^{\text{old}}$. Instead of just predicting a single token, the model generates **entire responses of length $T$** by sampling the next token sequentially until the response is complete.

**Example:** If the query is "The capital of France is," the G possible responses could be:  
- $o_1$ = "Paris is a beautiful city with historic landmarks."  
- $o_2$ = "Paris, home to the Eiffel Tower and the Louvre."  
- $o_3$ = "Paris, the cultural and political hub of France."  

This process ensures that the model explores **different potential completions** for each prompt, enabling the policy to generalize across a variety of outputs.

## **Step 3: Why Are We Sampling $G$ Full Responses?**
Sampling **multiple full responses per query** helps reduce variance and ensures that the policy update is stable and generalizable. Instead of relying on a single sampled output (which may be noisy or suboptimal), the model evaluates multiple possible completions, making it robust across different scenarios.

**Key benefits of sampling $G$ responses:**
- **Variance reduction:** By averaging the PPO objective across $G$ outputs, the model avoids overfitting to any particular sampled response.
- **Stable policy updates:** Evaluating multiple responses per query ensures that the updates are consistent across different completions.
- **Improved exploration:** The model can explore different possible responses and select the most reward-aligned behavior.

## **Step 4: Computing the PPO Objective for Full Responses**
The main PPO objective evaluates the **probability ratios** between the new policy $\pi_\theta$ and the old policy $\pi_\theta^{\text{old}}$, weighted by the **advantage** of each response.

The **probability of a full response** $o_i = [o_1, o_2, ..., o_T]$ under a policy is calculated as:

$$
\pi_\theta(o \mid q) = \prod_{t=1}^{T} \pi_\theta(o_t \mid o_{<t}, q)
$$

In practice, we compute the **log-probability of the entire sequence** to avoid numerical underflow:

$$
\log \pi_\theta(o \mid q) = \sum_{t=1}^{T} \log \pi_\theta(o_t \mid o_{<t}, q)
$$

The PPO objective compares the **log-probabilities of the new policy and old policy**:

$$
\text{Ratio} = \exp\left( \log \pi_\theta(o \mid q) - \log \pi_\theta^{\text{old}}(o \mid q) \right)
$$

The **advantage function** $A_i$ quantifies how much better a sampled response $o_i$ is compared to the expected baseline:

$$
A_i = \frac{r_i - \text{mean}(r_1, ..., r_G)}{\text{std}(r_1, ..., r_G)}
$$

We then apply **clipping** to the probability ratio to ensure that updates do not deviate too much from the old policy:

$$
\text{Clipped objective} = \min \left( \text{Ratio} \times A_i, \text{clip}(\text{Ratio}, 1 - \epsilon, 1 + \epsilon) \times A_i \right)
$$

## **Step 5: Averaging the PPO Loss Across Queries and Responses**
We compute the expectation $\mathbb{E}$ by averaging the PPO objective over both the sampled queries and the $G$ sampled responses for each query:

$$
J_{\text{GRPO}}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_\theta^{\text{old}}} \left[ \frac{1}{G} \sum_{i=1}^G \text{Clipped objective} \right]
$$

This ensures that the policy update is guided by the **average performance across multiple responses** rather than being sensitive to individual samples.

## **Step 6: KL Regularization to Control Response Length and Safety**
The KL divergence penalty $D_{KL}(\pi_\theta \parallel \pi_{\text{ref}})$ prevents the new policy from deviating too far from a **reference policy** $\pi_{\text{ref}}$, which could be the original pre-trained model or a simpler model like DistilGPT-2.

$$
D_{KL}(\pi_\theta \parallel \pi_{\text{ref}}) = \sum_{o_i} \pi_{\text{ref}}(o_i \mid q) \log \frac{\pi_{\text{ref}}(o_i \mid q)}{\pi_\theta(o_i \mid q)}
$$

### **Why is KL regularization important?**
- **Maintains fluency and coherence:** Penalizing large deviations ensures that the model retains its general language generation ability while optimizing for rewards.
- **Controls response length:** A larger KL weight $\beta$ penalizes long or risky completions, encouraging the model to generate shorter, safer responses.
- **Balances exploration and exploitation:** By tuning $\beta$, we can allow the model to explore creative outputs while staying anchored to its pre-trained knowledge.

## **Step 7: When to Use a Reference Policy**
The reference policy is essential in cases where:
- **We are fine-tuning a pre-trained model:** To ensure that the model doesn’t lose its general knowledge.
- **In RLHF settings:** To balance maximizing rewards with preserving fluency and coherence.
- **To control output length and verbosity:** Increasing $\beta$ encourages shorter, safer responses, while decreasing it allows for more diverse and creative outputs.


## **Final Summary:**  
1. **Sample queries** from the distribution P(Q) 
2. **Generate G full responses** per query using the old policy.  
3. **Compute the PPO objective** by comparing the new policy to the old policy, weighted by advantages.  
4. **Apply KL regularization** using the reference policy to balance exploration and safety.  
5. **Update the model parameters** by averaging the PPO loss across queries and responses.

This process ensures that the model improves consistently, remains stable during training, and balances reward maximization with coherence and safety. 😊

# References:
[DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://arxiv.org/abs/2501.12948)

[DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300)


In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, TopKLogitsWarper
import numpy as np

In [2]:
def sample_sequences(
    model, tokenizer, inputs, max_length=50, top_k=50, n_sample_sequences=1
):
    """
    Sample n_sample_sequences per prompt using top-k sampling.

    Args:
        model: The model used for generation.
        tokenizer: The tokenizer for the model.
        inputs: A dict with keys "input_ids" and "attention_mask" (as returned by the tokenizer).
        max_length: Maximum number of new tokens to generate.
        top_k: The top-k value used in sampling.
        n_sample_sequences: Number of sequences to sample per prompt.

    Returns:
        Tensor of shape (batch_size, n_sample_sequences, total_sequence_length)
        where total_sequence_length = prompt_length + generated tokens.
    """
    G = n_sample_sequences
    input_ids = inputs["input_ids"]  # shape: (batch_size, prompt_length)
    attention_mask = inputs["attention_mask"]
    batch_size = input_ids.size(0)

    # Initialize top-k warper
    top_k_warper = TopKLogitsWarper(top_k=top_k)

    # Expand inputs so that each prompt is repeated G times.
    # New shape: (batch_size * G, prompt_length)
    input_ids = (
        input_ids.unsqueeze(1).expand(batch_size, G, -1).reshape(batch_size * G, -1)
    )
    attention_mask = (
        attention_mask.unsqueeze(1)
        .expand(batch_size, G, -1)
        .reshape(batch_size * G, -1)
    )

    # We'll build generated sequences starting from the prompt.
    generated_sequences = input_ids.clone()  # shape: (batch_size * G, current_length)

    for _ in range(max_length):
        with torch.no_grad():
            outputs = model(
                input_ids=generated_sequences, attention_mask=attention_mask
            )
            # Get logits for the last token in each sequence
            logits = outputs.logits[:, -1, :]  # shape: (batch_size*G, vocab_size)

            # Apply top-k filtering
            filtered_logits = top_k_warper(None, logits)
            probs = F.softmax(filtered_logits, dim=-1)

            # Sample next token
            next_tokens = torch.multinomial(
                probs, num_samples=1
            )  # shape: (batch_size*G, 1)

            # If all sequences have generated an EOS token, stop early.
            if (next_tokens == tokenizer.eos_token_id).all():
                break

            # Append sampled token to sequences.
            generated_sequences = torch.cat([generated_sequences, next_tokens], dim=1)

            # Extend the attention mask accordingly.
            new_mask = torch.ones(
                (attention_mask.size(0), 1),
                device=attention_mask.device,
                dtype=attention_mask.dtype,
            )
            attention_mask = torch.cat([attention_mask, new_mask], dim=1)

    # Reshape back to (batch_size, G, total_sequence_length)
    generated_sequences = generated_sequences.view(batch_size, G, -1)
    return generated_sequences

In [3]:
def compute_sequence_log_probs(model, sequences, prompt_length, pad_token_id):
    """
    Compute the log–probability of the generated part (i.e. tokens after the prompt)
    for each sequence.

    Args:
        model: The language model.
        sequences: Tensor of shape (batch_size, G, total_sequence_length) where
                   the first prompt_length tokens are the prompt.
        prompt_length: Length of the prompt (number of tokens).
        pad_token_id: Token ID used for padding.

    Returns:
        Tensor of shape (batch_size, G) containing the summed log–probability
        of the generated tokens (i.e. excluding the prompt).
    """
    batch_size, G, total_seq_len = sequences.size()
    # Flatten the first two dimensions.
    flat_sequences = sequences.view(batch_size * G, total_seq_len)
    # Create an attention mask (assumes pad_token_id marks padded tokens)
    attention_mask = (flat_sequences != pad_token_id).long()

    with torch.no_grad():
        logits = model(input_ids=flat_sequences, attention_mask=attention_mask).logits
    # Shift logits and labels so that the probability for token t is given by logits[t-1]
    shift_logits = logits[:, :-1, :]  # shape: (B, total_seq_len-1, vocab_size)
    shift_labels = flat_sequences[:, 1:]  # shape: (B, total_seq_len-1)

    # We want the log–probabilities for generated tokens only.
    # For a prompt of length L, the generated tokens are from index L to end in flat_sequences.
    # Their probabilities are predicted at positions L-1 to (total_seq_len-1)-1.
    L = prompt_length
    N = total_seq_len - L  # number of generated tokens
    if N <= 0:
        raise ValueError("No generated tokens to compute log–probs for.")

    # Slice out only the generated portion.
    # The first generated token is predicted at position L-1 in shift_logits, corresponding to label at index L.
    gen_logits = shift_logits[:, L - 1 : L - 1 + N, :]  # shape: (B, N, vocab_size)
    gen_labels = shift_labels[:, L - 1 : L - 1 + N]  # shape: (B, N)

    # Compute log probabilities for each token.
    log_probs_all = F.log_softmax(gen_logits, dim=-1)  # shape: (B, N, vocab_size)
    token_log_probs = log_probs_all.gather(2, gen_labels.unsqueeze(-1)).squeeze(
        -1
    )  # shape: (B, N)
    # Sum log–probs over the generated tokens.
    seq_log_probs = token_log_probs.sum(dim=1)  # shape: (B,)

    return seq_log_probs.view(batch_size, G)

In [4]:
def compute_reward(sequences: torch.Tensor):
    """
    Generate rewards for the sampled sequences.

    At the moment this will be random rewards between 0 and 1 in the shape of the input tensor.

    Args:
        response_text: The tensor of sampled sequences.

    Returns:
        A tensor of rewards in the shape of the input tensor.
    """

    batch_size, G, _ = sequences.shape
    rewards = torch.rand(batch_size, G, device=sequences.device)
    return rewards

In [5]:
def compute_advantages(rewards):
    """
    Normalize rewards per query across the group.

    Args:
        rewards: Tensor of shape (batch_size, G)

    Returns:
        Tensor of shape (batch_size, G) of normalized advantages.
    """
    mean_reward = rewards.mean(dim=1, keepdim=True)
    std_reward = rewards.std(dim=1, unbiased=False, keepdim=True)
    advantages = (rewards - mean_reward) / (std_reward + 1e-8)
    return advantages

In [6]:
def ppo_loss(
    new_log_probs,
    old_log_probs,
    advantages,
    epsilon=0.2,
    beta=0.01,
    reference_log_probs=None,
):
    """
    Compute the PPO-style loss with a KL penalty term.

    Args:
        new_log_probs: Tensor (batch_size, G) – log–probs under the new policy.
        old_log_probs: Tensor (batch_size, G) – log–probs under the old policy.
        advantages: Tensor (batch_size, G) – normalized advantages.
        epsilon: Clipping parameter.
        beta: KL penalty coefficient.
        reference_log_probs: Tensor (batch_size, G) – log–probs under a reference policy.

    Returns:
        A scalar loss.
    """
    # Compute probability ratios: exp(new - old)
    ratios = torch.exp(new_log_probs - old_log_probs)
    clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
    objective = torch.min(ratios * advantages, clipped_ratios * advantages)

    # KL penalty using the formula: (pi_ref/pi_new) - log(pi_ref/pi_new) - 1.
    if reference_log_probs is not None:
        ratio_ref = torch.exp(reference_log_probs - new_log_probs)
        kl_div = ratio_ref - torch.log(ratio_ref) - 1
        kl_div = kl_div.mean()
    else:
        kl_div = 0.0

    loss = -objective.mean() + beta * kl_div
    return loss

# Checking individual components of GRPO

In [7]:
# Initialize tokenizer and add PAD token
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Initialize Models
new_model = AutoModelForCausalLM.from_pretrained(
    "gpt2"
)  # New policy model (to be optimized)
old_model = AutoModelForCausalLM.from_pretrained(
    "gpt2"
)  # Old policy model (before update)
reference_model = AutoModelForCausalLM.from_pretrained(
    "distilgpt2"
)  # Optional reference policy

In [8]:
# Sample batch of prompts (queries)
prompts = [
    "The best food in the world is",
    "Corgis are",
    "The meaning of life is",
    "I am",
    "The most important thing in the world is",
]

# Tokenize prompts and get logits
inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)

In [9]:
# Sample G sequences per query using each policy
new_policy_sequences = sample_sequences(
    new_model,
    tokenizer,
    inputs,
    max_length=10,
    top_k=50,
    n_sample_sequences=5,
)
old_policy_sequences = sample_sequences(
    old_model,
    tokenizer,
    inputs,
    max_length=10,
    top_k=50,
    n_sample_sequences=5,
)
reference_policy_sequences = sample_sequences(
    reference_model,
    tokenizer,
    inputs,
    max_length=10,
    top_k=50,
    n_sample_sequences=5,
)

In [10]:
# The prompt length (assumed the same for all examples)
prompt_length = inputs["input_ids"].size(1)
pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)

# Compute log–probabilities for the generated (i.e. non–prompt) part.
new_log_probs = compute_sequence_log_probs(
    new_model, new_policy_sequences, prompt_length, pad_token_id
)
old_log_probs = compute_sequence_log_probs(
    old_model, old_policy_sequences, prompt_length, pad_token_id
)
if reference_policy_sequences is not None:
    reference_log_probs = compute_sequence_log_probs(
        reference_model, reference_policy_sequences, prompt_length, pad_token_id
    )
else:
    reference_log_probs = None


print("New policy log-probs:", new_log_probs)
print("Old policy log-probs:", old_log_probs)
print("Reference policy log-probs:", reference_log_probs)

New policy log-probs: tensor([[-37.0535, -24.5370, -20.2848, -33.6020, -30.9356],
        [-31.7608, -34.2596, -35.6612, -26.7157, -33.7347],
        [-33.5248, -34.9668, -43.7043, -38.8953, -29.9209],
        [-31.5192, -31.2574, -20.3213, -43.9429, -27.3535],
        [-25.6883, -26.4976, -25.9713, -33.7142, -29.4448]])
Old policy log-probs: tensor([[-39.1005, -29.6932, -30.6909, -29.8440, -30.2227],
        [-31.7843, -37.6550, -30.7647, -36.2704, -33.3683],
        [-28.6793, -10.1107, -28.9684, -37.1744, -25.7923],
        [-28.0612, -26.5376, -37.6537, -29.1060, -35.9013],
        [-22.7740, -22.2103, -26.1856, -26.3483, -31.9297]])
Reference policy log-probs: tensor([[-33.8827, -20.7227, -27.5098, -32.8465, -36.5514],
        [-13.8102, -27.4110, -38.8218, -33.9706, -17.7042],
        [-32.3322, -29.1900, -35.0639, -38.1682,  -5.4566],
        [-44.2237, -29.9062, -43.3479, -33.0118, -28.1428],
        [-32.6752, -31.0077, -28.6882, -24.5660, -33.0195]])


In [11]:
# Mock rewards for demonstration (one reward per sampled token)
rewards = compute_reward(new_policy_sequences)
print("Rewards:", rewards)

Rewards: tensor([[0.6701, 0.8050, 0.0647, 0.9789, 0.1421],
        [0.8898, 0.0237, 0.8894, 0.1515, 0.1535],
        [0.0225, 0.6304, 0.3054, 0.0383, 0.7066],
        [0.8884, 0.7723, 0.4005, 0.9084, 0.9065],
        [0.9083, 0.8879, 0.1165, 0.0722, 0.8682]])


In [12]:
# Compute normalized advantages
def compute_advantages(rewards: torch.Tensor) -> torch.Tensor:
    mean_reward = rewards.mean(dim=1, keepdim=True)
    std_reward = rewards.std(dim=1, unbiased=False, keepdim=True)
    advantages = (rewards - mean_reward) / (std_reward + 1e-8)
    return advantages

In [13]:
advantages = compute_advantages(rewards)
print("Advantages:", advantages)

Advantages: tensor([[ 0.3786,  0.7489, -1.2829,  1.2261, -1.0707],
        [ 1.2160, -1.0333,  1.2151, -0.7014, -0.6964],
        [-1.1086,  1.0098, -0.1228, -1.0536,  1.2752],
        [ 0.5832, -0.0150, -1.9314,  0.6866,  0.6766],
        [ 0.8673,  0.8149, -1.1663, -1.2803,  0.7644]])


In [14]:
# Compute PPO loss across G sampled outputs
ppo_loss_value = ppo_loss(
    new_log_probs=new_log_probs,
    old_log_probs=old_log_probs.detach(),  # Old policy is detached to prevent gradient flow
    advantages=advantages,
    epsilon=0.2,
    beta=0.01,
    reference_log_probs=reference_log_probs,
)

print("PPO Loss:", ppo_loss_value.item())

PPO Loss: 19490114.0


# Putting it all together

In [15]:
def grpo_train_step(
    new_model,
    old_model,
    reference_model,
    tokenizer,
    inputs,
    max_length=10,
    top_k=50,
    n_sample_sequences=5,
    epsilon=0.2,
    beta=0.01,
    optimizer=None,
):
    """
    Perform one GRPO update step.

    Args:
        new_model: The current (new) policy to optimize.
        old_model: A frozen copy of the policy before the update.
        reference_model: A reference policy (e.g. a pretrained model) for KL penalty.
        tokenizer: The tokenizer corresponding to the models.
        inputs: Tokenized prompts (a dict with "input_ids" and "attention_mask").
        max_length: Number of new tokens to generate.
        top_k: Top-k parameter for sampling.
        n_sample_sequences: Number of responses (G) to sample per prompt.
        epsilon: PPO clipping parameter.
        beta: KL divergence regularization weight.
        optimizer: The optimizer used for updating new_model.

    Returns:
        loss: The computed loss (a scalar tensor).
    """
    # Sample responses using each policy.
    new_policy_sequences = sample_sequences(
        new_model,
        tokenizer,
        inputs,
        max_length,
        top_k,
        n_sample_sequences,
    )
    old_policy_sequences = sample_sequences(
        old_model,
        tokenizer,
        inputs,
        max_length,
        top_k,
        n_sample_sequences,
    )
    if reference_model is not None:
        reference_policy_sequences = sample_sequences(
            reference_model,
            tokenizer,
            inputs,
            max_length,
            top_k,
            n_sample_sequences,
        )
    else:
        reference_policy_sequences = None

    # The prompt length (assumed the same for all examples)
    prompt_length = inputs["input_ids"].size(1)
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )

    # Compute log–probabilities for the generated (i.e. non–prompt) part.
    new_log_probs = compute_sequence_log_probs(
        new_model, new_policy_sequences, prompt_length, pad_token_id
    )
    old_log_probs = compute_sequence_log_probs(
        old_model, old_policy_sequences, prompt_length, pad_token_id
    )
    if reference_policy_sequences is not None:
        reference_log_probs = compute_sequence_log_probs(
            reference_model, reference_policy_sequences, prompt_length, pad_token_id
        )
    else:
        reference_log_probs = None

    # Compute rewards and advantages.
    rewards = compute_reward(new_policy_sequences)  # shape: (batch_size, G)
    advantages = compute_advantages(rewards)  # shape: (batch_size, G)

    # Compute the GRPO (PPO-style) loss.
    loss = ppo_loss(
        new_log_probs,
        old_log_probs.detach(),
        advantages,
        epsilon,
        beta,
        reference_log_probs,
    )

    if optimizer is not None:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss

In [17]:
# Run one GRPO training step.
loss = grpo_train_step(
    new_model,
    old_model,
    reference_model,
    tokenizer,
    inputs,
    max_length=10,
    top_k=50,
    n_sample_sequences=5,
    epsilon=0.2,
    beta=0.01,
)

print("GRPO Loss:", loss.item())
print("The numbers can vary due to the random nature of the sampling and rewards.")

GRPO Loss: 2426870.5
The numbers can vary due to the random nature of the sampling and rewards.
