# GRPO with small models
In this notebook, we will attempt to recreate the "aha" moment described in the DeepSeek-R1 paper. I am following along with [Philipp Schmid's blog post, which was reposted on Hugging Face](https://huggingface.co/blog/open-r1/mini-r1-contdown-game).

I'll be using Hugging Face Hub as my remote model versioning service.

## Setup

In [5]:
from dotenv import load_dotenv
import os

load_dotenv()

True

In [8]:
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True) 

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer



In [2]:
model_name = "Qwen/Qwen3-0.6B"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="mps"
)


## Mapping data to format for GRPO
We will be using [Pan Jiayi's Countdown Tasks dataset on Hugging Face](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4). This is the dataset that both Jiayi and Philipp used to replicate the DeepSeek "aha" moment. Because I am running this locally on an M1 MacBook, I've intentionally selected only 1k samples, which are later split into train and test sets.

In [12]:
from datasets import load_dataset

In [14]:
dataset_id = "Jiayi-Pan/Countdown-Tasks-3to4"
dataset = load_dataset(dataset_id, split="train")
# A fixed seed keeps the 1k-sample subset reproducible across runs
dataset = dataset.shuffle(seed=42).select(range(1000))

Next, we will be formatting each row of data to a suitable prompt for the language model.

In [20]:
def generate_r1_prompt(numbers, target):
    r1_prefix = [
        {
            "role": "system",
            "content": "You are a math expert. You will first reason carefully about the problem, then provide the user with the answer."
        },
        {
            "role": "user",
            "content": f"Given the numbers {numbers} and the target number {target}, please provide a solution to reach the target number using the four basic arithmetic operations: addition, subtraction, multiplication, and division (+, -, *, /). You can use each number only once. Show your work in <think> </think> tags, then return only the final equation in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>."
        },
        # {
        #     "role": "assistant",
        #     "content": "Let me solve this step by step.\n<think> "
        # }
    ]
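    # add_generation_prompt=True ends the prompt with the assistant header so the
    # model starts its own turn (the assistant prefill above is commented out).
    prompt = tokenizer.apply_chat_template(
        r1_prefix, tokenize=False, add_generation_prompt=True
    )
    return {"prompt": prompt, "target": target}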

In [21]:
# convert dataset to r1 format
dataset = dataset.map(
    lambda x: generate_r1_prompt(x["nums"], x["target"]),
)

Map: 100%|██████████| 1000/1000 [00:00<00:00, 11415.40 examples/s]
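
Each `prompt` is now a chat-template rendering of the system and user messages. To see what the template contributes, here is a hand-rolled approximation of a ChatML-style template. It is only a sketch: `render_chatml` is a hypothetical helper, the message contents are shortened, and the real Qwen3 template applied by `tokenizer.apply_chat_template` contains more logic, so treat the exact strings as an assumption.

```python
def render_chatml(messages: list[dict], add_generation_prompt: bool = False) -> str:
    """Hand-rolled ChatML-style rendering; an approximation, not Qwen's real template."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so that generation starts inside the reply.
        text += "<|im_start|>assistant\n"
    return text


# Shortened, made-up messages for illustration only
messages = [
    {"role": "system", "content": "You are a math expert."},
    {"role": "user", "content": "Use [2, 3, 4] to reach 14."},
]

print(render_chatml(messages, add_generation_prompt=True))
```

With `add_generation_prompt=True` the rendered text ends with an open assistant turn. Without it, the text stops right after the user turn.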


In [22]:
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

## Train the model using GRPO
We will be using two reward functions:
1. Format Reward: Checks that the completion follows the exact format `<think> [thinking content] </think>\n<answer> [answer content] </answer>`, with a newline between the two blocks.
2. Accuracy Reward: Extracts the equation from the `<answer>` tag and awards 1.0 only if (a) every given number is used exactly once and (b) the equation evaluates to the target within a small floating-point tolerance. Otherwise the reward is 0.0.
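
To make the two checks concrete before the full implementation, here is a minimal sketch on a single completion. `SIMPLE_FORMAT` is a simplified stand-in for the stricter pattern defined in the next cell, and `uses_each_number_once` is a hypothetical helper name. Both are assumptions made for illustration.

```python
import re

# Simplified stand-in for the notebook's format pattern (illustration only)
SIMPLE_FORMAT = re.compile(r"^<think>.*?</think>\n<answer>(.*?)</answer>$", re.DOTALL)


def uses_each_number_once(equation: str, numbers: list[int]) -> bool:
    """True if the integers in `equation` are exactly `numbers`, ignoring order."""
    return sorted(int(n) for n in re.findall(r"\d+", equation)) == sorted(numbers)


good = "<think>4 * 3 is 12, plus 2 is 14.</think>\n<answer>4 * 3 + 2</answer>"
no_newline = good.replace("</think>\n", "</think>")

print(bool(SIMPLE_FORMAT.match(good)))                # True
print(bool(SIMPLE_FORMAT.match(no_newline)))          # False: newline is required
print(uses_each_number_once("4 * 3 + 2", [2, 3, 4]))  # True
print(uses_each_number_once("7 * 2", [2, 3, 4]))      # False: wrong numbers
```

The two rewards are independent: a completion with a missing newline can fail the format check and still earn the accuracy reward if its equation is valid.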

In [24]:
import re
import logging # Optional: for more detailed error logging if needed

# --- Constants ---
# Regex for the full <think>...</think>\n<answer>...</answer> format.
# The completion must start with <think>, end with </answer>, and have exactly one newline between the two blocks.
# ((?:(?!<\/think>).)*) captures the <think> content one character at a time, stopping at the first </think>,
# so the captured content can never contain a closing </think> tag.
# ((?:(?!<\/answer>).)*) does the same for the <answer> content.
FORMAT_REGEX_PATTERN = r"^<think>((?:(?!<\/think>).)*)<\/think>\n<answer>((?:(?!<\/answer>).)*)<\/answer>$"
FORMAT_REGEX = re.compile(FORMAT_REGEX_PATTERN, re.DOTALL)

# Regex to extract content from <answer> tag
ANSWER_REGEX_PATTERN = r"<answer>((?:(?!<\/answer>).)*)<\/answer>"
ANSWER_REGEX = re.compile(ANSWER_REGEX_PATTERN, re.DOTALL) # re.DOTALL allows . to match newlines

# Regex for allowed characters in an equation
ALLOWED_EQUATION_CHARS_PATTERN = r'^[\d+\-*/().\s]+$'
ALLOWED_EQUATION_CHARS_REGEX = re.compile(ALLOWED_EQUATION_CHARS_PATTERN)

# Tolerance for float comparisons
FLOAT_COMPARISON_TOLERANCE = 1e-5

# Optional: Setup a logger if you want to see errors instead of just getting 0.0
# logger = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO) # Or logging.DEBUG for more verbosity


def format_reward_func(completions: list[str], **kwargs) -> list[float]:
    """
    Checks if completions strictly follow the <think>...</think>\n<answer>...</answer> format.
    The model is expected to generate the full string including the opening <think> tag.

    Args:
        completions (list[str]): Generated outputs from the assistant, each expected to
                                 start with "<think>" and follow the full format.
                                 Example: "<think>I will solve it.</think>\n<answer>42</answer>"
        **kwargs: Additional keyword arguments (ignored by this function).

    Returns:
        list[float]: Reward scores (1.0 for correct format, 0.0 otherwise).
    """
    rewards = []
    for completion_text in completions:
        try:
            # The completion_text itself is expected to be the full string
            match = FORMAT_REGEX.search(completion_text)

            if match and len(match.groups()) == 2:
                # Both <think> and <answer> content captured
                rewards.append(1.0)
            else:
                # logger.debug(f"Format mismatch for: {completion_text}")
                rewards.append(0.0)
        except Exception as e:
            # logger.error(f"Error processing completion for format check: {completion_text}, Error: {e}")
            rewards.append(0.0)
    return rewards


def equation_reward_func(
    completions: list[str],
    targets: list[str],
    nums_list: list[list[str]],
    **kwargs
) -> list[float]:
    """
    Evaluates completions based on:
    1. Presence of a valid <answer>...</answer> tag within the completion.
    2. Mathematical correctness of the equation in the <answer> tag.
    3. Usage of all numbers from the `nums_list` exactly once in the equation.
    4. Equation only contains allowed characters (numbers, operators, parentheses, whitespace).

    The model is expected to generate the full string including the opening <think> tag.

    Args:
        completions (list[str]): Generated outputs from the assistant, each expected to
                                 start with "<think>" and contain an <answer> tag.
                                 Example: "<think>The equation is 2*3.</think>\n<answer>2*3</answer>"
        targets (list[str]): Expected numerical answers (as strings).
        nums_list (list[list[str]]): For each completion, a list of available numbers (as strings)
                                     that must be used in the equation. Example: [["2", "3"], ["1", "5", "7"]]
        **kwargs: Additional keyword arguments (ignored by this function).

    Returns:
        list[float]: Reward scores (1.0 for correct, 0.0 otherwise).
    """
    rewards = []
    for completion_text, target_str, available_nums_str in zip(completions, targets, nums_list):
        try:
            current_reward = 0.0 # Default to 0.0, set to 1.0 only on full success

            # The completion_text itself is expected to be the full string
            answer_match = ANSWER_REGEX.search(completion_text)
            if not answer_match:
                # logger.debug(f"No <answer> tag in: {completion_text}")
                rewards.append(current_reward)
                continue

            equation_str = answer_match.group(1).strip()
            if not equation_str: # Handle empty answer
                # logger.debug(f"Empty <answer> content in: {completion_text}")
                rewards.append(current_reward)
                continue

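            # 1. Check for allowed characters in the equation.
            # "**" and "//" pass the character whitelist but are not among the four
            # allowed operations, and a chained power can make eval() hang.
            if (
                not ALLOWED_EQUATION_CHARS_REGEX.match(equation_str)
                or "**" in equation_str
                or "//" in equation_str
            ):
                # logger.debug(f"Equation '{equation_str}' contains forbidden characters.")
                rewards.append(current_reward)
                continue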

            # 2. Check number usage
            try:
                expected_numbers_int = sorted([int(n) for n in available_nums_str])
            except ValueError:
                # logger.error(f"Invalid non-integer number in available_nums_str: {available_nums_str}")
                rewards.append(current_reward)
                continue

            used_numbers_str = re.findall(r'\d+', equation_str)
            try:
                used_numbers_int = sorted([int(n) for n in used_numbers_str])
            except ValueError:
                # logger.debug(f"Invalid number format in equation '{equation_str}'.")
                rewards.append(current_reward)
                continue

            if used_numbers_int != expected_numbers_int:
                # logger.debug(f"Number usage mismatch. Used: {used_numbers_int}, Expected: {expected_numbers_int} in '{equation_str}'")
                rewards.append(current_reward)
                continue

            # 3. Evaluate the equation and check correctness
            try:
                target_val = float(target_str)
                eval_globals = {"__builtins__": {}}
                eval_locals = {}
                result = eval(equation_str, eval_globals, eval_locals)

                if abs(float(result) - target_val) < FLOAT_COMPARISON_TOLERANCE:
                    current_reward = 1.0
                # else:
                    # logger.debug(f"Equation result mismatch. Eq: '{equation_str}' -> {result}, Target: {target_val}")
            except SyntaxError:
                # logger.debug(f"Syntax error in equation: {equation_str}")
                pass # current_reward remains 0.0
            except TypeError:
                # logger.debug(f"Type error during evaluation (e.g. trying to operate on non-numerics): {equation_str}")
                pass # current_reward remains 0.0
            except ZeroDivisionError:
                # logger.debug(f"Zero division error in equation: {equation_str}")
                pass # current_reward remains 0.0
            except Exception as eval_e:
                # logger.warning(f"Unexpected error evaluating equation '{equation_str}': {eval_e}")
                pass # current_reward remains 0.0

            rewards.append(current_reward)

        except Exception as e:
            # logger.error(f"Outer error processing completion for equation check: {completion_text}, Error: {e}")
            rewards.append(0.0)
    return rewards
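
The accuracy reward above evaluates the model's equation with `eval`, guarded by a character whitelist and empty builtins. An alternative that avoids `eval` altogether is to walk the expression's syntax tree and accept only the four allowed operators, which also rejects operators such as `**` by construction. The sketch below is not what this notebook uses. `safe_arithmetic_eval` is a hypothetical helper shown only as one possible design.

```python
import ast
import operator

# Only the four operations the Countdown task allows
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def safe_arithmetic_eval(expression: str) -> float:
    """Evaluate +, -, *, / arithmetic without calling eval() (hypothetical helper)."""

    def _eval(node: ast.AST) -> float:
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return float(node.value)
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
            value = _eval(node.operand)
            return -value if isinstance(node.op, ast.USub) else value
        # Anything else (names, calls, **, //, comparisons) is rejected
        raise ValueError(f"Unsupported expression element: {type(node).__name__}")

    return _eval(ast.parse(expression.strip(), mode="eval"))


print(safe_arithmetic_eval("4 * 3 + 2"))    # 14.0
print(safe_arithmetic_eval("(1 + 2) / 3"))  # 1.0
```

A caller would still need to catch `SyntaxError`, `ValueError`, and `ZeroDivisionError`, just as the `eval`-based version does.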

Let's try our reward functions on a few samples.

In [None]:
test_completions = [
    "<think>I need to use 2, 3, and 4 to make 14. I can multiply 4 by 3, which is 12, and then add 2 to get 14.</think>\n<answer>4 * 3 + 2</answer>", # Valid format, correct equation
    "<think>I will try to make 14. 4 times 3 is 12, plus 2 is 14.</think><answer>(4*3)+2</answer>", # No newline between the tags: fails the format check only
    "<think>I need to get 14. What if I use 7 and 2? 7 multiplied by 2 is 14.</think>\n<answer>7 * 2</answer>", # Wrong numbers
    "<think>I think the answer involves multiplication and addition. Let's try to spell it out.</think>\n<answer>four times three plus two equals 14</answer>" # Forbidden characters
]

In [26]:
# Shared data for all test completions
test_targets = ["14", "14", "14", "14"]
test_nums_list = [
    ["2", "3", "4"],
    ["2", "3", "4"],
    ["2", "3", "4"],
    ["2", "3", "4"],
]

In [None]:
format_rewards = format_reward_func(test_completions)
format_expected_rewards = [1.0, 0.0, 1.0, 1.0]
assert format_rewards == format_expected_rewards, f"Format rewards: {format_rewards}, Expected: {format_expected_rewards}"

In [None]:
equation_rewards = equation_reward_func(test_completions, test_targets, test_nums_list)
equation_expected_rewards = [1.0, 1.0, 0.0, 0.0]
assert equation_rewards == equation_expected_rewards, f"Equation rewards: {equation_rewards}, Expected: {equation_expected_rewards}"

Looking good! Now we just need to define our training parameters, create the trainer, and start training.