# Group Relative Policy Optimization (GRPO)

Install the Hugging Face libraries to run this notebook.

In [2]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

Your goal is to fill in the `GRPOTrainer` class. You have two options (and you can do both):
* the "normal GRPO" with clipped surrogate objective
* or the "vanilla GRPO" with original objective

In [None]:
import torch
import torch.nn.functional as F

class GRPOTrainer:
    """Single-prompt GRPO trainer draft for a causal language model.

    The class can sample a completion, score it with the model's own
    likelihood and normalise rewards into advantages. ``train_step`` is still
    unfinished: it stops before the policy probabilities and the loss are
    computed, so no optimisation happens yet.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 learning_rate=1e-5,
                 temperature=1.0,
                 max_length=100,
                 device="cpu",
                 clip_epsilon=0.2,  # Clipping threshold for "normal GRPO"
                 use_clipped=True):  # Clipped GRPO when True, vanilla GRPO otherwise
        """Store the model, tokenizer, optimizer and hyper-parameters.

        Args:
            model: Model moved to ``device``; it is later used through
                ``generate(...)`` and ``model(input_ids, labels=...)``.
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            learning_rate: Learning rate of the AdamW optimizer.
            temperature: Sampling temperature forwarded to ``generate``.
            max_length: ``max_length`` forwarded to ``generate``.
            device: Device on which the model and the input tensors live.
            clip_epsilon: Clipping threshold (stored; not used yet).
            use_clipped: Objective selector (stored; not used yet).
        """
        self.llm = model.to(device)
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.AdamW(self.llm.parameters(), lr=learning_rate)
        self.device = device
        self.temperature = temperature
        self.max_length = max_length
        self.clip_epsilon = clip_epsilon
        self.use_clipped = use_clipped  # Toggle between normal and vanilla GRPO

    def generate(self, prompt):
        """Sample one continuation of ``prompt`` and return it as decoded text.

        Sampling uses ``top_k=50`` with the configured temperature and
        ``max_length``. The decoded text is built from the full generated
        sequence with special tokens skipped.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        with torch.no_grad():
            output = self.llm.generate(
                input_ids,
                max_length=self.max_length,
                temperature=self.temperature,
                top_k=50,
                do_sample=True
            )

        text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return text

    def calculate_reward(self, text):
        """Score ``text`` with the model's own likelihood.

        Returns:
            float: The negated loss returned by the model when ``text`` is
            used as both input and labels. This is the log of the inverse
            perplexity, so higher is better.
        """
        input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)

        with torch.no_grad():
            outputs = self.llm(input_ids, labels=input_ids)
            loss = outputs.loss  # NLL loss (negative log-likelihood)

        # -NLL = log(1 / perplexity): maximising it minimises perplexity.
        reward = -loss.item()
        return reward

    def calculate_GRPO_advantages(self, rewards):
        """Normalise a group of rewards into GRPO advantages.

        Args:
            rewards: Sequence of scalar rewards belonging to one group.

        Returns:
            torch.Tensor: ``(rewards - mean) / (std + 1e-8)`` on
            ``self.device``. Groups with fewer than two rewards get all-zero
            advantages, because the sample standard deviation is undefined
            (NaN) for them.
        """
        rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
        if rewards.numel() < 2:
            # A lone sample has no group to be compared with.
            return torch.zeros_like(rewards)
        mean_reward = rewards.mean()
        std_reward = rewards.std() + 1e-8  # Avoid division by zero
        advantages = (rewards - mean_reward) / std_reward
        return advantages

    def train_step(self, prompt):
        """Begin one training step on a single prompt (unfinished).

        Generates a completion, scores it, computes its advantage and
        tokenizes the prompt and the completion. Nothing is optimised and
        ``None`` is returned.
        """
        # Generate an output
        generated_text = self.generate(prompt)

        # Compute the reward
        reward = self.calculate_reward(generated_text)

        # Compute the GRPO advantage (a single sample yields an advantage of 0)
        advantages = self.calculate_GRPO_advantages([reward])

        # Encode the prompt and the generated output as tokens
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        output_ids = self.tokenizer(generated_text, return_tensors="pt").input_ids.to(self.device)

        # TODO: compute the policy probabilities, the loss and the optimizer step.

In [None]:
#Inspired from HuggingFace
import re
import copy

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
# Left padding keeps every prompt adjacent to its generated continuation.
# `padding=True` is an option of the tokenizer call, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# GPT-2 ships without a pad token; reuse EOS so batched prompts can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

class GRPOTrainer:
    """Group Relative Policy Optimization trainer for a causal language model.

    For every prompt the policy samples ``num_generations`` completions,
    scores them with :meth:`reward_func`, normalises the rewards inside each
    group to obtain advantages, and optimises the clipped surrogate objective
    with a per-token KL penalty towards a reference model.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 learning_rate=1e-5,
                 temperature=1.0,
                 max_length=100,
                 device="cpu",
                 num_generations=5,
                 num_iterations=1,
                 beta=0.1,
                 epsilon=0.2,
                 max_completion_length=50):
        """Store the policy, tokenizer, optimizer and hyper-parameters.

        Args:
            model: Policy to optimise. It is called as
                ``model(input_ids=..., attention_mask=...)`` (result must
                expose ``.logits``) and through ``model.generate(...)``.
            tokenizer: Tokenizer used to encode prompts and decode completions.
            learning_rate: Learning rate of the AdamW optimizer.
            temperature: Sampling temperature (must be > 0).
            max_length: Maximum number of prompt tokens kept (left-truncated).
            device: Device on which the model and tensors live.
            num_generations: Completions sampled per prompt (at least 2, so
                that the group standard deviation is defined).
            num_iterations: Optimisation steps taken on each generated batch.
            beta: Weight of the KL penalty; ``0.0`` disables it.
            epsilon: Clipping range of the probability ratio.
            max_completion_length: Maximum number of new tokens per completion.

        Raises:
            ValueError: If a hyper-parameter is outside its valid range.
        """
        if num_generations < 2:
            raise ValueError("num_generations must be >= 2 to compute group-relative advantages")
        if num_iterations < 1:
            raise ValueError("num_iterations must be >= 1")
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        self.model = model.to(device)  # model to optimize
        self.ref_model = None  # reference model, set by train()
        self.tokenizer = tokenizer
        # Batched prompts need a pad token (GPT-2 has none) and left padding so
        # that each completion directly follows its prompt.
        if getattr(self.tokenizer, "pad_token", None) is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate)
        self.device = device
        self.temperature = temperature
        self.max_prompt_length = max_length
        self.max_completion_length = max_completion_length
        self.eps = 1e-10
        self.num_generations = num_generations  # number of generations per prompt
        self.num_iterations = num_iterations
        self.beta = beta
        self.epsilon = epsilon  # read by compute_loss for ratio clipping

    def reward_func(self, prompts, completions, **kwargs):
        """Reward completions that follow the ``<think>…</think><answer>…</answer>`` format.

        Args:
            prompts: Prompts matching ``completions`` (unused by this reward).
            completions: Either plain strings or conversational completions
                (a list of message dicts whose first item has a ``"content"`` key).

        Returns:
            list[float]: ``1.0`` for each completion matching the format, else ``0.0``.
        """
        pattern = r"^<think>.*?</think><answer>.*?</answer>$"
        rewards = []
        for completion in completions:
            content = completion if isinstance(completion, str) else completion[0]["content"]
            rewards.append(1.0 if re.match(pattern, content) else 0.0)
        return rewards

    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        """Return the log-probability ``model`` assigns to each of the last ``logits_to_keep`` tokens.

        Args:
            model: Policy or reference model.
            input_ids: ``(N, L)`` token ids (prompt followed by completion).
            attention_mask: ``(N, L)`` mask matching ``input_ids``.
            logits_to_keep: Number of trailing (completion) tokens to score.

        Returns:
            torch.Tensor: ``(N, logits_to_keep)`` per-token log-probabilities.
        """
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logits = logits[:, :-1, :]  # (N, L-1, V): the last logit predicts a token that is not in input_ids

        input_ids = input_ids[:, -logits_to_keep:]
        logits = logits[:, -logits_to_keep:]
        # Score under the same temperature that was used for sampling.
        logits = logits / self.temperature
        return torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

    def _generate_and_score_completions(self, prompts):
        """Sample ``num_generations`` completions per prompt and score them.

        Args:
            prompts: A prompt string or a list of prompt strings.

        Returns:
            dict: ``prompt_ids``/``prompt_mask`` ``(B*G, P)``,
            ``completion_ids``/``completion_mask`` ``(B*G, C)``,
            ``old_per_token_logps`` (``None`` when ``num_iterations == 1``),
            ``ref_per_token_logps`` ``(B*G, C)`` and ``advantages`` ``(B*G,)``,
            where rows belonging to the same prompt are contiguous.

        Raises:
            ValueError: If the reward function returns a wrong number of rewards.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        prompts = list(prompts)

        # Tokenize the prompts and fetch the attention mask (for the padding)
        prompt_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.device)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        # Truncate prompts that are too long (keep the rightmost tokens)
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length:]
            prompt_mask = prompt_mask[:, -self.max_prompt_length:]

        # Sampling is required to obtain several distinct sequences per prompt.
        # no_grad (rather than inference_mode) keeps the tensors usable in the
        # autograd graph built later by compute_loss.
        with torch.no_grad():
            prompt_completion_ids = self.model.generate(
                prompt_ids,
                attention_mask=prompt_mask,
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=self.max_completion_length,
                num_return_sequences=self.num_generations,
                pad_token_id=self.tokenizer.pad_token_id,
            )  # (B*G, P+C), generations of one prompt are consecutive rows

        # Split prompt and completion ids; repeat the prompt mask per generation
        prompt_length = prompt_ids.size(1)
        prompt_ids = prompt_completion_ids[:, :prompt_length]  # (B*G, P)
        completion_ids = prompt_completion_ids[:, prompt_length:]  # (B*G, C)
        prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)  # (B*G, P)

        # Mask everything after the first EOS token of each completion
        is_eos = completion_ids == self.tokenizer.eos_token_id
        has_eos = is_eos.any(dim=1)
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=self.device)
        eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
        sequence_indices = torch.arange(is_eos.size(1), device=self.device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt and completion masks for the forward passes
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
        logits_to_keep = completion_ids.size(1)  # only completion tokens are scored

        # Score the sampled tokens in eval mode, then restore the previous mode
        # so the gradient step is not silently run in eval mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                # With num_iterations == 1 the old log-probs equal the current
                # ones, so compute_loss uses per_token_logps.detach() instead.
                if self.num_iterations > 1:
                    old_per_token_logps = self._get_per_token_logps(
                        self.model, prompt_completion_ids, attention_mask, logits_to_keep
                    )
                else:
                    old_per_token_logps = None

                # Without a reference model, the current policy is the reference
                scorer = self.ref_model if self.ref_model is not None else self.model
                ref_per_token_logps = self._get_per_token_logps(
                    scorer, prompt_completion_ids, attention_mask, logits_to_keep
                )
        finally:
            self.model.train(was_training)

        # Decode the completions
        completions = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)

        # Compute the rewards (each prompt is repeated once per generation)
        repeated_prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
        output_rewards = self.reward_func(prompts=repeated_prompts, completions=completions)
        if len(output_rewards) != len(completions):
            raise ValueError(
                f"reward_func returned {len(output_rewards)} rewards for {len(completions)} completions"
            )
        rewards = torch.tensor(output_rewards, dtype=torch.float32, device=self.device)

        # Group-relative advantages: normalise rewards within each prompt's group
        grouped_rewards = rewards.view(-1, self.num_generations)
        mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    def _prepare_inputs(self, inputs):
        """Generate and score completions for ``inputs`` (no buffering)."""
        return self._generate_and_score_completions(inputs)

    def compute_loss(self, model, inputs):
        """Compute the GRPO loss: clipped surrogate objective plus KL penalty.

        Args:
            model: Policy whose parameters receive the gradient.
            inputs: Dictionary produced by :meth:`_generate_and_score_completions`.

        Returns:
            torch.Tensor: Scalar loss averaged over unmasked completion tokens.
        """
        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # only completion tokens are scored

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Per-token KL estimate between the policy and the reference model:
        # exp(d) - d - 1 with d = ref_logp - logp (always >= 0)
        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            log_ratio_to_ref = ref_per_token_logps - per_token_logps
            per_token_kl = torch.exp(log_ratio_to_ref) - log_ratio_to_ref - 1

        advantages = inputs["advantages"]
        # With num_iterations == 1 the old policy is the current one, so the
        # detached current log-probs stand in for the old ones (ratio == 1).
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)  # clipped ratio

        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl  # add the KL penalty
        # Average over completion tokens only (padding after EOS is masked out)
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1)

        return loss

    def _snapshot_reference(self):
        """Return a frozen, eval-mode copy of the current policy."""
        reference = copy.deepcopy(self.model)
        reference.eval()
        reference.requires_grad_(False)
        return reference

    def train(self, num_epochs, dataset, batch_size=8):
        """Train for ``num_epochs`` epochs over mini-batches of ``dataset``.

        The reference model is initialised from the policy before training and
        refreshed after every epoch.

        Args:
            num_epochs: Number of passes over ``dataset``.
            dataset: Sliceable sequence of prompt strings.
            batch_size: Number of prompts per training step.

        Returns:
            list[float]: Average loss of each epoch.

        Raises:
            ValueError: If ``dataset`` is empty or ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if len(dataset) == 0:
            raise ValueError("dataset is empty")

        self.ref_model = self._snapshot_reference()  # initialise the reference model

        epoch_losses = []
        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0  # counted, so a trailing partial batch is included

            for start in range(0, len(dataset), batch_size):
                batch_prompts = dataset[start:start + batch_size]
                total_loss += self.train_step(batch_prompts)
                num_batches += 1

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
            epoch_losses.append(avg_loss)

            # Refresh the reference model after each epoch
            self.ref_model = self._snapshot_reference()

        return epoch_losses

    def train_step(self, prompts):
        """Run ``num_iterations`` optimisation steps on one mini-batch of prompts.

        Args:
            prompts: A prompt string or a list of prompt strings.

        Returns:
            float: Loss of the last optimisation step.
        """
        self.model.train()

        # Sample completions with the current (old) policy and score them
        inputs = self._prepare_inputs(prompts)

        # Each GRPO iteration updates the model on the same generated batch;
        # the loss is measured against the old-policy log-probs stored in inputs.
        for _ in range(self.num_iterations):
            loss = self.compute_loss(self.model, inputs)

            # Backpropagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return loss.item()




In [7]:
from transformers import AutoTokenizer

# Exploratory check: tokenize two sentences of different lengths with padding,
# then inspect the encoded batch and the tokenizer's EOS token id.
sample_sentences = ["Hello, how are you?", "Fine, thanks!"]
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer(sample_sentences, return_tensors="pt", padding=True)

print(inputs)
print(tokenizer.eos_token_id)


{'input_ids': tensor([[ 101, 7592, 1010, 2129, 2024, 2017, 1029,  102],
        [ 101, 2986, 1010, 4283,  999,  102,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0, 0]])}
None


In [None]:
class GRPOTrainer:
    """Exercise template for a GRPO trainer.

    Only the constructor is implemented. Every other method raises
    ``NotImplementedError`` until it is filled in, so a forgotten step fails
    at the call site instead of propagating ``None`` into later code.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 learning_rate=1e-5,
                 temperature=1.0,
                 max_length=100,
                 device="cpu"):
        """Store the model, tokenizer, optimizer and sampling settings.

        Args:
            model: Model moved to ``device`` and optimised with AdamW.
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            learning_rate: Learning rate of the AdamW optimizer.
            temperature: Sampling temperature (stored for ``generate``).
            max_length: Maximum generation length (stored for ``generate``).
            device: Device on which the model and tensors live.
        """
        self.llm = model.to(device)
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.AdamW(self.llm.parameters(), lr=learning_rate)
        self.device = device
        self.temperature = temperature
        self.max_length = max_length

    def generate(self, prompt):
        """Generate text for ``prompt``; expected to return ``(loss, text)``.

        Raises:
            NotImplementedError: Until the generation step is implemented.
        """
        # Starting point: the prompt encoded as tensors on the trainer's device.
        encoded_prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # TODO: produce `output` and `loss` from `encoded_prompt`, then
        # return loss, self.tokenizer.decode(output[0])
        raise NotImplementedError("GRPOTrainer.generate is not implemented yet")

    def calculate_reward(self, output):
        """
            Calculate the reward of a single output
        """
        raise NotImplementedError("GRPOTrainer.calculate_reward is not implemented yet")

    def calculate_GRPO_advantages(self, outputs):
        """
            Calculate the advantages of each output
        """
        raise NotImplementedError("GRPOTrainer.calculate_GRPO_advantages is not implemented yet")

    def train_step(self, prompt):
        """
            A training step on a single prompt
        """
        raise NotImplementedError("GRPOTrainer.train_step is not implemented yet")

In [10]:
model_name = "gpt2"
#model = AutoModelForCausalLM.from_pretrained(model_name)
# `padding=True` is an option of the tokenizer call, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# GPT-2 ships without a pad token; reuse EOS so batched prompts can be padded.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.eos_token_id)

50256


In [None]:
trainer = GRPOTrainer(model, tokenizer)
prompts = ["The best way to learn coding is", "The future of AI is"]

for epoch in range(3):  # Train for a few epochs
    loss = 0
    for prompt in prompts:
        # One training step per prompt, so the epoch average below is per prompt.
        loss += trainer.train_step(prompt)
    print(f"Epoch {epoch+1}, Loss: {loss / len(prompts)}")

In [None]:
# The trainer exposes `generate`, which takes a single prompt at a time.
[trainer.generate(prompt) for prompt in prompts]