In [None]:
!git clone https://github.com/JamesSand/grpo_again.git
%cd grpo_again
!pip install -r requirements.txt

Cloning into 'grpo_again'...
remote: Enumerating objects: 18, done.[K
remote: Counting objects: 100% (18/18), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 18 (delta 6), reused 15 (delta 3), pack-reused 0 (from 0)[K
Receiving objects: 100% (18/18), 30.93 KiB | 4.42 MiB/s, done.
Resolving deltas: 100% (6/6), done.
/content/grpo_again/grpo_again


In [None]:
from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoTokenizer, PreTrainedModel
from dataclasses import dataclass
from typing import Optional, Union, Tuple
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from copy import deepcopy
from datasets import load_dataset
from reward_func import *

### Define GRPO Hyperparameters and Load the Model and Tokenizer


In [None]:
@dataclass
class GRPOArguments:
    """Hyperparameters for GRPOTrainer.

    Every field has a default, so ``GRPOArguments()`` yields the standard configuration,
    while individual values can be overridden at construction time,
    e.g. ``GRPOArguments(beta=0.0, lr=5e-7)``.
    """

    output_dir: str = './output'  # checkpoints are written to <output_dir>/checkpoint_<update_step>
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    lr: float = 1e-6  # Adam learning rate
    save_steps: int = 100  # save a checkpoint every `save_steps` optimizer updates
    epoch: int = 3  # number of passes over the training dataset
    num_generations: int = 4  # number of sampled responses per prompt (the GRPO group size)
    max_prompt_length: int = 256  # prompts are padded/truncated to this many tokens
    max_generate_length: int = 512  # maximum number of generated tokens per response
    reward_weights: Optional[List[float]] = None  # one weight per reward function; None means all 1.0
    beta: float = 1e-2  # KL coefficient; 0 disables the KL term and the reference model
    clip_eps: float = 0.2  # clipping range for the importance-sampling ratio
    gradient_accumulation_steps: int = 2  # micro-batches accumulated per optimizer update
    num_iterations: int = 1  # optimization passes over each batch of generated samples
    batch_size: int = 1  # prompts per micro-batch

args = GRPOArguments()

# Seed Python's and torch's RNGs (torch.manual_seed also seeds CUDA) so data shuffling
# and response sampling are reproducible across runs (up to GPU nondeterminism).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

### Sample Utils

In [None]:
@dataclass
class Samples:
    """A group of sampled responses for one prompt, as built by GRPOTrainer.generate_samples.

    Every tensor field has one row per sampled response in the group.
    """

    prompt_response_ids: torch.Tensor  # prompt + response token ids, padded to a fixed length
    response_ids: torch.Tensor  # generated part of prompt_response_ids (prompt columns removed)
    prompt: Any  # raw prompt text, before the chat template is applied
    answer: Any  # reference answer forwarded to the reward functions; may be None
    attention_mask: Optional[torch.LongTensor]  # 1 where prompt_response_ids is not the pad token
    action_mask: Optional[torch.LongTensor]  # 1 for response tokens that are neither EOS nor pad
    num_actions: Union[int, torch.Tensor]  # number of response columns (width of action_mask)
    # Per-row count of action tokens; generate_samples stores a float tensor here, not an int.
    response_length: Union[int, torch.Tensor]



### Define the GRPO Trainer


In [None]:

SYSTEM_PROMPT = "Let's think step by step and output the final answer within \\boxed{}."


class GRPOTrainer:
    """Minimal single-device GRPO (Group Relative Policy Optimization) trainer.

    For every prompt the policy samples ``args.num_generations`` responses. Each response
    is scored by ``reward_funcs``, and the rewards are normalized within the group to give
    one advantage per response. The policy is then updated with a clipped
    importance-ratio objective. When ``args.beta != 0``, an additional k3 KL penalty
    against a frozen reference copy of the initial model is applied.
    """

    def __init__(self,
        model = None,
        reward_funcs: Union[List[str], List[Callable]] = None,
        args = None,
        train_dataset: Optional[Union[Dataset]] = None,
        eval_dataset: Optional[Union[Dataset]] = None,
        tokenizer = None):
        """Set up the policy, optional reference model, tokenizer, rewards and optimizer.

        Args:
            model: a causal-LM module, or a name/path for ``AutoModelForCausalLM``.
            reward_funcs: a reward callable / model name, or a list of them. Callables are
                invoked as ``f(prompts=..., responses=..., answers=..., sample_idx=...)``
                and must return one score (or None) per response.
            args: hyperparameter object such as ``GRPOArguments``; required.
            train_dataset: dataset yielding dicts with ``'prompt'`` and optionally ``'answer'``.
            eval_dataset: stored on the trainer but not used by it.
            tokenizer: a tokenizer instance, or a name/path for ``AutoTokenizer``.

        Raises:
            ValueError: if ``args`` or ``reward_funcs`` is missing.
        """
        if args is None:
            raise ValueError("GRPOTrainer requires `args` (e.g. GRPOArguments()).")
        if reward_funcs is None:
            raise ValueError("GRPOTrainer requires at least one reward function.")
        self.args = args

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model)
        self.model = model.to(self.args.device)

        # Frozen reference policy for the KL penalty. It is parked on CPU and only moved
        # to the accelerator while its log-probs are computed.
        self.ref_model = None
        if self.args.beta != 0.0:
            self.ref_model = deepcopy(model)
            self.ref_model.eval()
            self.ref_model = self.ref_model.cpu()

        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer = self.get_tokenizer(tokenizer)

        # Accept a single reward (callable or model name). Copy the list so the caller's
        # list is not modified when model names are replaced by loaded models.
        if isinstance(reward_funcs, str) or callable(reward_funcs):
            reward_funcs = [reward_funcs]
        reward_funcs = list(reward_funcs)
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                # NOTE(review): generate_experiences calls every reward with the keyword
                # interface above, which a sequence-classification model does not accept;
                # reward models need an adapter before this path can work.
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1).to(self.args.device)
        self.reward_funcs = reward_funcs

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # One slot per micro-batch; a full buffer is optimized `num_iterations` times
        # without regenerating samples.
        self.input_buffer = [None] * self.args.gradient_accumulation_steps

        # Number of optimizer updates so far, and the planned total (set in train()).
        self.update_steps = 0
        self.global_steps = 0

    def get_tokenizer(self, tokenizer):
        """Configure ``tokenizer`` for batched generation with a decoder-only model."""
        # Left padding keeps every prompt flush against the generated continuation.
        tokenizer.padding_side = "left"
        # Over-long prompts lose their beginning rather than the trailing chat-template
        # generation prompt, which the model needs in order to start answering.
        tokenizer.truncation_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def generate_samples(self, inputs):
        """Sample a group of ``args.num_generations`` responses for every prompt in ``inputs``.

        Args:
            inputs: batch mapping with a ``'prompt'`` sequence and, optionally, an
                ``'answer'`` sequence aligned with it.

        Returns:
            list[Samples]: one entry per prompt. ``prompt_response_ids`` always has
            ``max_prompt_length + max_generate_length`` columns; ``response_ids`` and
            ``action_mask`` hold the columns after the prompt.
        """
        samples_list = []
        self.model.eval()
        prompts = list(inputs['prompt'])
        answers = list(inputs['answer']) if 'answer' in inputs else [None] * len(prompts)

        max_length = self.args.max_generate_length + self.args.max_prompt_length
        pad_id = self.tokenizer.pad_token_id
        for prompt, answer in zip(prompts, answers):
            input_text = self.tokenizer.apply_chat_template(
                [{"role": "system", 'content': SYSTEM_PROMPT}, {"role": "user", 'content': prompt}],
                add_generation_prompt=True, tokenize=False)

            # The same prompt is repeated once per group member. Padding to exactly
            # max_prompt_length means every response starts at the same column.
            prompt_inputs = self.tokenizer(
                [input_text] * self.args.num_generations,
                padding='max_length', max_length=self.args.max_prompt_length,
                truncation=True, return_tensors='pt').to(self.args.device)
            prompt_len = prompt_inputs['input_ids'].size(1)

            with torch.no_grad():
                prompt_response_ids = self.model.generate(
                    **prompt_inputs,
                    max_new_tokens=self.args.max_generate_length,
                    do_sample=True,  # GRPO needs diverse samples within each group
                    temperature=0.9,
                    top_p=1,
                    top_k=50)

            # Bring every group to the same width so groups can be concatenated later.
            if prompt_response_ids.size(1) >= max_length:
                prompt_response_ids = prompt_response_ids[:, :max_length]
            else:
                padding = torch.full(
                    (prompt_response_ids.size(0), max_length - prompt_response_ids.size(1)),
                    fill_value=pad_id,
                    dtype=prompt_response_ids.dtype,
                    device=prompt_response_ids.device)
                prompt_response_ids = torch.cat([prompt_response_ids, padding], dim=1)

            attention_mask = prompt_response_ids.ne(pad_id).to(dtype=torch.long)

            # [num_generations, max_generate_length]
            response_ids = prompt_response_ids[:, prompt_len:]
            # NOTE(review): EOS itself is excluded from the loss, so the policy gets no
            # direct gradient for emitting it; confirm this is intended.
            action_mask = (response_ids.ne(self.tokenizer.eos_token_id) & response_ids.ne(pad_id)).to(dtype=torch.long)

            samples_list.append(Samples(
                prompt_response_ids=prompt_response_ids,
                response_ids=response_ids,
                prompt=prompt,
                answer=answer,
                attention_mask=attention_mask,
                action_mask=action_mask,
                num_actions=action_mask.size(1),
                response_length=action_mask.float().sum(dim=-1),
            ))

        return samples_list

    def _compute_group_advantages(self, rewards_per_func):
        """Combine per-function rewards and normalize them within one group.

        Args:
            rewards_per_func: float tensor ``[num_reward_funcs, group_size]``; NaN marks a
                reward function that returned None for that response.

        Returns:
            tuple ``(rewards, advantages)``, both of shape ``[group_size]``.

        Raises:
            ValueError: if ``args.reward_weights`` does not match the number of reward functions.
        """
        num_funcs = rewards_per_func.size(0)
        weights = self.args.reward_weights or [1.0] * num_funcs
        if len(weights) != num_funcs:
            raise ValueError("The number of reward weights must be equal to the number of reward functions.")
        weights = torch.tensor(weights, dtype=torch.float32, device=rewards_per_func.device)

        # nansum: a reward function that abstains (None -> NaN) contributes nothing
        # instead of turning the whole group's advantages into NaN.
        rewards = torch.nansum(rewards_per_func * weights.unsqueeze(1), dim=0)
        # std of a single sample is NaN; such a group carries no relative signal anyway.
        std = rewards.std() if rewards.numel() > 1 else rewards.new_zeros(())
        # GRPO advantages are sentence-level: one value per response, shared by its tokens.
        advantages = (rewards - rewards.mean()) / (std + 1e-8)
        return rewards, advantages

    def generate_experiences(self, inputs):
        """Sample responses for a batch of prompts and compute everything the loss needs.

        Returns:
            dict with ``prompt_response_ids``, ``attention_mask``, ``action_mask``,
            ``old_action_log_probs``, ``ref_action_log_probs`` (None without a reference
            model) and ``advantages``, each concatenated over all groups in the batch.
        """
        self.model.eval()
        samples_list = self.generate_samples(inputs)

        batch_prompt_response_ids = []
        batch_attention_mask = []
        batch_action_mask = []
        batch_advantages = []
        batch_old_action_log_probs = []
        batch_ref_action_log_probs = []

        for sample_idx, samples in enumerate(samples_list):
            prompt_response_ids = samples.prompt_response_ids
            attention_mask = samples.attention_mask
            num_actions = samples.num_actions
            batch_prompt_response_ids.append(prompt_response_ids)
            batch_attention_mask.append(attention_mask)
            batch_action_mask.append(samples.action_mask)

            with torch.no_grad():
                # Token log-probs under the current (sampling) policy.
                old_action_log_probs = self.get_action_log_probs(self.model, prompt_response_ids, attention_mask, num_actions)
                batch_old_action_log_probs.append(old_action_log_probs)

                if self.ref_model is not None:
                    # Swap models so only one occupies accelerator memory at a time.
                    self.model = self.model.cpu()
                    torch.cuda.empty_cache()
                    self.ref_model = self.ref_model.to(self.args.device)
                    ref_action_log_probs = self.get_action_log_probs(self.ref_model, prompt_response_ids, attention_mask, num_actions)
                    batch_ref_action_log_probs.append(ref_action_log_probs)
                    self.ref_model = self.ref_model.cpu()
                    torch.cuda.empty_cache()
                    self.model = self.model.to(self.args.device)

                response_texts = self.tokenizer.batch_decode(samples.response_ids, skip_special_tokens=True)
                group_size = len(response_texts)
                prompt_texts = [samples.prompt] * group_size
                answers = [samples.answer] * group_size

                # rewards_per_func: [num_funcs, group_size]
                rewards_per_func = torch.zeros(len(self.reward_funcs), group_size, device=self.args.device)
                for i, reward_func in enumerate(self.reward_funcs):
                    scores = reward_func(prompts=prompt_texts, responses=response_texts, answers=answers, sample_idx=sample_idx)
                    scores = [torch.nan if score is None else score for score in scores]
                    rewards_per_func[i] = torch.tensor(scores, dtype=torch.float32, device=self.args.device)

                rewards, advantages = self._compute_group_advantages(rewards_per_func)
                print(f'rewards: {rewards[0].item()}')
                batch_advantages.append(advantages)

        return {
            "prompt_response_ids": torch.cat(batch_prompt_response_ids, dim=0),
            "attention_mask": torch.cat(batch_attention_mask, dim=0),
            "action_mask": torch.cat(batch_action_mask, dim=0),
            "old_action_log_probs": torch.cat(batch_old_action_log_probs, dim=0),
            "ref_action_log_probs": torch.cat(batch_ref_action_log_probs, dim=0) if self.ref_model is not None else None,
            "advantages": torch.cat(batch_advantages, dim=0),
        }

    def compute_loss(self, model, inputs):
        """Clipped GRPO objective plus optional k3 KL penalty, averaged per response then over the batch.

        Args:
            model: the policy being optimized.
            inputs: a dict as returned by :meth:`generate_experiences`.

        Returns:
            Scalar loss tensor.
        """
        prompt_response_ids = inputs['prompt_response_ids']
        attention_mask = inputs['attention_mask']
        action_mask = inputs['action_mask']  # [B, num_actions]
        num_actions = action_mask.size(1)

        action_log_probs = self.get_action_log_probs(model, prompt_response_ids, attention_mask, num_actions)
        advantages = inputs['advantages']  # [B]

        # With a single iteration the samples come from the current policy, so the ratio is exactly 1.
        old_action_log_probs = inputs['old_action_log_probs'] if self.args.num_iterations > 1 else action_log_probs.detach()
        coef_1 = torch.exp(action_log_probs - old_action_log_probs)  # importance ratio
        coef_2 = torch.clamp(coef_1, 1 - self.args.clip_eps, 1 + self.args.clip_eps)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2) * action_mask

        if self.args.beta != 0.0:
            # k3 estimator of KL(policy || ref); masked positions have log_ratio 0, hence k3 0.
            log_ratio = (inputs['ref_action_log_probs'] - action_log_probs) * action_mask
            k3 = log_ratio.exp() - 1 - log_ratio
            per_token_loss = per_token_loss + self.args.beta * k3

        # Clamp so a response with no action tokens contributes 0 instead of NaN (0/0).
        num_action_tokens = action_mask.sum(dim=1).clamp(min=1)
        loss = per_token_loss.sum(dim=1) / num_action_tokens  # [B]
        return loss.mean()

    def get_action_log_probs(self, model, input_ids, attention_mask, num_actions):
        """Log-probabilities ``model`` assigns to the last ``num_actions`` tokens of ``input_ids``."""
        output = model(input_ids, attention_mask=attention_mask)
        log_probs = F.log_softmax(output.logits[:, :-1, :], dim=-1)
        # Position t predicts token t + 1, so gather with the ids shifted left by one.
        token_log_probs = log_probs.gather(dim=-1, index=input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        return token_log_probs[:, -num_actions:]

    def train_step(self, model, inputs, optimizer, step):
        """Accumulate gradients for one micro-batch; step the optimizer on the last one."""
        model.train()
        loss = self.compute_loss(model, inputs)
        loss = loss / self.args.gradient_accumulation_steps
        loss.backward()
        if (step + 1) % self.args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            print(f"step: {self.update_steps}/{self.global_steps}  grpo_loss: {loss.item():.8f}")
        torch.cuda.empty_cache()

    def train(self):
        """Run GRPO for ``args.epoch`` epochs over ``train_dataset``.

        Experiences are generated for each micro-batch and stored in ``input_buffer``.
        Once ``gradient_accumulation_steps`` micro-batches are buffered, the buffer is
        optimized ``num_iterations`` times, with one optimizer update per pass. A checkpoint
        is written every ``args.save_steps`` updates.
        """
        self.global_steps = self.args.num_iterations * self.args.epoch * len(self.train_dataset) // (self.args.batch_size * self.args.gradient_accumulation_steps)

        for _ in range(self.args.epoch):
            dataloader = DataLoader(self.train_dataset, batch_size=self.args.batch_size, shuffle=True)

            for idx, batch in enumerate(dataloader):
                inputs = self.generate_experiences(batch)
                self.input_buffer[idx % self.args.gradient_accumulation_steps] = inputs

                # NOTE(review): when an epoch's micro-batch count is not a multiple of
                # gradient_accumulation_steps, the trailing micro-batches are generated
                # but never trained on.
                if (idx + 1) % self.args.gradient_accumulation_steps == 0:
                    for _ in range(self.args.num_iterations):
                        for step, buffered_inputs in enumerate(self.input_buffer):
                            self.train_step(self.model, buffered_inputs, self.optimizer, step)

                        self.update_steps += 1
                        if self.update_steps % self.args.save_steps == 0:
                            checkpoint_dir = self.args.output_dir + f'/checkpoint_{self.update_steps}'
                            self.model.save_pretrained(checkpoint_dir)
                            self.tokenizer.save_pretrained(checkpoint_dir)

                del inputs

    def save_model(self):
        """Save the current policy weights and tokenizer to ``args.output_dir``."""
        self.model.save_pretrained(self.args.output_dir)
        self.tokenizer.save_pretrained(self.args.output_dir)



In [None]:



class GSM8KDataset(Dataset):
    """Prompt/answer pairs from a Hugging Face GSM8K-style dataset.

    Each item is ``{'prompt': sample['question'], 'answer': sample['answer_only']}``.
    """

    def __init__(self, data_path, tokenizer, filter_dataset_by_len=False, grpo_args=None, split='train'):
        """
        Args:
            data_path: dataset name or path passed to ``datasets.load_dataset``.
            tokenizer: tokenizer used for the optional length filter.
            filter_dataset_by_len: when True and ``grpo_args`` is given, drop samples whose
                question alone tokenizes to more than ``grpo_args.max_prompt_length`` tokens.
            grpo_args: object providing ``max_prompt_length``; only read when filtering.
            split: dataset split to load (default ``'train'``).
        """
        self.tokenizer = tokenizer
        self.data = load_dataset(data_path)[split]

        if filter_dataset_by_len and grpo_args is not None:
            max_len = grpo_args.max_prompt_length
            # Conservative: counts only the question's tokens, not the system prompt or
            # chat-template tokens added at generation time, so a few over-long prompts
            # can still reach truncation.
            self.data = self.data.filter(
                lambda sample: len(tokenizer(sample['question'])['input_ids']) <= max_len)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return {'prompt': sample['question'], 'answer': sample['answer_only']}


# Training prompts; GSM8KDataset reads the `question` and `answer_only` columns.
dataset_name = "meta-math/GSM8K_zh"

prompts_dataset = GSM8KDataset(
    dataset_name,
    tokenizer,
    filter_dataset_by_len=True,
    grpo_args=args,
)



In [None]:
# Sanity check: show the first training prompt and its reference answer.
first_example = prompts_dataset[0]
print(first_example["prompt"])

print(first_example["answer"])

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
72


### Define Reward Functions


In [None]:
from box_correct_score import compute_score
import re

# Boxed format reward - checks if answer is within \boxed{} format
def boxed_format_reward(prompts, responses, answers, sample_idx=None):
    r"""
    Reward function for \boxed{} format.
    Awards points based on:
    - Contains \boxed{...} format: 0.5 points
    - \boxed{} has non-empty content: additional 0.3 points
    - \boxed{} contains a number: additional 0.2 points
      (an optionally negative integer or decimal such as 72, -3.5, 5. or .5)
    Total maximum: 1.0 points

    Only the first \boxed{...} in each response is inspected, and its content ends at the
    first closing brace, so nested braces (e.g. \frac{1}{2}) are cut short.
    ``prompts``, ``answers`` and ``sample_idx`` are accepted for interface compatibility
    with the other reward functions and are not used.

    Returns:
        list[float]: one reward per response.
    """
    boxed_pattern = r'\\boxed\{([^}]*)\}'
    # ASCII digits only, with at most one leading minus and one decimal point. This rejects
    # strings like "1-2" and non-ASCII digits such as "²", which str.isdigit() accepts.
    number_pattern = r'-?(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)'

    rewards = []
    for response in responses:
        reward = 0.0
        match = re.search(boxed_pattern, response)

        if match:
            reward += 0.5  # Has boxed format
            content = match.group(1).strip()
            if content:  # Boxed content is non-empty
                reward += 0.3
                if re.fullmatch(number_pattern, content):
                    reward += 0.2

        rewards.append(reward)

    return rewards


def boxed_correctness_reward(prompts, responses, answers, sample_idx=None):
    """Score each response against its reference answer with ``compute_score``.

    Args:
        prompts: prompt texts of the group (only read for the debug print).
        responses: decoded model responses.
        answers: reference answers, aligned with ``responses``; converted with ``str``.
        sample_idx: index of the group within the batch; group 0 is printed for monitoring.

    Returns:
        list: one ``compute_score`` result per (response, answer) pair.
    """
    ret_rewards = [
        compute_score(solution_str=resp, ground_truth=str(ans), is_longcot=False, is_use_math_verify=True)
        for resp, ans in zip(responses, answers)
    ]

    if sample_idx != 0:
        return ret_rewards

    # Log the first member of the first group so training progress can be eyeballed.
    print("=" * 50)
    print("question:\n", prompts[0], "\n")
    print("response:\n", responses[0], "\n")
    print("answer:\n", answers[0], "\n")
    return ret_rewards


In [None]:
reward_func_list = [boxed_correctness_reward, boxed_format_reward]

trainer = GRPOTrainer(model=model,
                      reward_funcs=reward_func_list,
                      args=args,
                      train_dataset=prompts_dataset,
                      tokenizer=tokenizer)
trainer.train()
# train() only writes periodic checkpoints (every `args.save_steps` updates), so the
# final weights would otherwise be lost; persist them to `args.output_dir`.
trainer.save_model()

question:
 James finds 3 bills in his pocket.  They are all 20s.  If he already had $75 in his wallet how much money doe he have now? 

response:
 To determine how much money James has now, we need to follow these steps:

1. Calculate the total amount of money James found in his pocket.
2. Add the amount of money he had before finding the cash to the total amount of money he found.

First, let's calculate the total amount of money James found:
\[ 3 \text{ bills} \times \$20 \text{ per bill} = \$60 \]

Next, we add the amount of money he already had in his wallet to the total of the cash he found:
\[ \$75 + \$60 = \$135 \]

So, the final amount of money James has now is:
\[
\boxed{135}
\] 

answer:
 135 

rewards: 2.0
question:
 Georgina owns a parrot and is teaching it to speak. So far, the parrot knows seventeen phrases. If she is teaching it two phrases a week and it already knew three phrases when she bought it, how many days has Georgina had it? 

response:
 To determine how many w