# Group Relative Policy Optimization (GRPO)

Install the Hugging Face libraries to run this notebook.

In [None]:
# IPython magic: runs pip for the current kernel to install Hugging Face Transformers.
%pip install transformers

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # last expression of the cell, so the notebook displays the choice

Your goal is to fill in the `GRPOTrainer` class. You have two options (and you can do both):
* the "normal GRPO" with clipped surrogate objective
* or the "vanilla GRPO" with original objective

In [None]:
class GRPOTrainer:
    """Skeleton trainer for Group Relative Policy Optimization (GRPO).

    The constructor wires together a model, its tokenizer and an AdamW
    optimizer over the model's parameters. The generation, reward, advantage
    and training-step methods are exercise placeholders: until they are
    filled in they raise ``NotImplementedError`` so that callers get an
    explicit message instead of an unrelated ``TypeError`` on ``None``.
    """

    def __init__(self,
                 model,
                 tokenizer,
                 learning_rate=1e-5,
                 temperature=1.0,
                 max_length=100,
                 device="cpu"):
        """Store the model, tokenizer, optimizer and hyperparameters.

        Args:
            model: Module to train; it is moved to ``device``.
            tokenizer: Object used to encode prompts and decode outputs.
            learning_rate: Learning rate given to ``torch.optim.AdamW``.
            temperature: Kept on ``self.temperature``; not used by the
                skeleton yet (presumably the sampling temperature).
            max_length: Kept on ``self.max_length``; not used by the
                skeleton yet (presumably the generation length limit).
            device: Device the model and the encoded prompts are moved to.
        """
        self.llm = model.to(device)
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.AdamW(self.llm.parameters(), lr=learning_rate)
        self.device = device
        self.temperature = temperature
        self.max_length = max_length

    def generate(self, prompt):
        """Encode ``prompt`` and return ``(loss, text)`` for one completion.

        The sampling step is left to be implemented: it must set ``output``
        (whose first element is decoded into ``text``) and ``loss``.

        Raises:
            NotImplementedError: While the sampling step has not been
                written, i.e. ``output`` is still ``None``.
        """
        # Named `inputs` rather than `input` to avoid shadowing the builtin.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # TODO: produce `output` and `loss` from `self.llm` and `inputs`.
        output = None
        loss = None

        if output is None:
            # Fail explicitly instead of crashing on `None[0]` below.
            raise NotImplementedError(
                "GRPOTrainer.generate is not implemented yet"
            )

        text = self.tokenizer.decode(output[0])
        return loss, text

    def calculate_reward(self, output):
        """
            Calculate the reward of a single output
        """
        raise NotImplementedError(
            "GRPOTrainer.calculate_reward is not implemented yet"
        )

    def calculate_GRPO_advantages(self, outputs):
        """
            Calculate the advantages of each output
        """
        raise NotImplementedError(
            "GRPOTrainer.calculate_GRPO_advantages is not implemented yet"
        )

    def train_step(self, prompt):
        """
            A training step on a single prompt
        """
        raise NotImplementedError(
            "GRPOTrainer.train_step is not implemented yet"
        )

In [None]:
# Checkpoint name passed to both `from_pretrained` calls below.
model_name = "gpt2"

# The two loads do not depend on each other; the tokenizer is loaded first here.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [None]:
# Pass the device chosen earlier in the notebook; otherwise the trainer
# falls back to its "cpu" default even when a GPU is available.
trainer = GRPOTrainer(model, tokenizer, device=device)
prompts = ["The best way to learn coding is", "The future of AI is"]

for epoch in range(3):  # Train for a few epochs
    loss = 0
    for prompt in prompts:
        # `train_step` handles a single prompt, so pass the loop variable
        # rather than the whole `prompts` list.
        loss += trainer.train_step(prompt)
    # Report the mean loss over the prompts of this epoch.
    print(f"Epoch {epoch+1}, Loss: {loss / len(prompts)}")

In [None]:
# GRPOTrainer has no `generate_text`; its `generate(prompt)` method takes one
# prompt and returns a (loss, text) pair, so call it per prompt and show the text.
for prompt in prompts:
    _, text = trainer.generate(prompt)
    print(text)