# GRPO with small models
In this notebook, we will attempt to recreate the "aha" moment as seen in DeepSeek r1 paper. I am also following along to [Philipp Schmid's blog post that was reposted on Hugging Face](https://huggingface.co/blog/open-r1/mini-r1-contdown-game).

I'll be using Hugging Face Hub as my remote model versioning service.

## Setup

In [1]:
from dotenv import load_dotenv
import os

load_dotenv()

True

In [2]:
from huggingface_hub import login

login(token=os.getenv("HF_TOKEN"), add_to_git_credential=True) 

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
model_name = "Qwen/Qwen3-0.6B"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="mps"
)


## Mapping data to format for GRPO
We will be using [Pan Jiayi's Countdown Tasks dataset on Hugging Face](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4). The motivation is that this was the dataset that both Jiayi and Phillipp used to replicate the DeepSeek aha moment. As we are currently running this locally on a MacBook M1, I've intentionally selected only 1k samples as our train data.

In [6]:
from datasets import load_dataset

In [7]:
dataset_id = "Jiayi-Pan/Countdown-Tasks-3to4"
dataset = load_dataset(dataset_id, split="train")
dataset = dataset.shuffle()#.select(range(5))

Next, we will be formatting each row of data to a suitable prompt for the language model.

In [8]:
def generate_r1_prompt(numbers, target):
    r1_prefix = [
        {
            "role": "system",
            "content": "You are a math expert. You will first reason carefully about the problem, then provide the user with the answer."
        },
        {
            "role": "user",
            "content": f"Given the numbers {numbers} and the target number {target}, please provide a solution to reach the target number using the four basic arithmetic operations: addition, subtraction, multiplication, and division (+, -, *, /). You can use each number only once. Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 = 1 </answer>."
        },
        # {
        #     "role": "assistant",
        #     "content": "Let mes solve this step by step.\n<think> "
        # }
    ]
    return {"prompt": tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=False), "target": target}

In [9]:
# convert dataset to r1 format
dataset = dataset.map(
    lambda x: generate_r1_prompt(x["nums"], x["target"]),
)

Map: 100%|██████████| 490364/490364 [00:41<00:00, 11782.80 examples/s]


In [10]:
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

## Train the model using GRPO
We will be using 2 reward functions:
1. Format Reward: Checks if the generated answer is in the correct format of `<think> [thinking content] </think><answer> [answer content] </answer>`.
2. Accuracy Reward: Extracts the equation from `<answer>` tag and evaluate it using the two conditions that (a) every number is used once and (b) how close it is to the target. 

In [11]:
import re
import logging # Optional: for more detailed error logging if needed

# --- Constants ---
# Regex for the full <think>...</think><answer>...</answer> format
# It ensures <think> and <answer> are direct children and not nested within each other incorrectly.
# The part ((?:(?!<\/think>).)*) captures content within <think> non-greedily.
# The part ((?:(?!<\/answer>).)*) captures content within <answer> non-greedily.
FORMAT_REGEX_PATTERN = r"^<think>((?:(?!<\/think>).)*)<\/think>\n<answer>((?:(?!<\/answer>).)*)<\/answer>$"
FORMAT_REGEX = re.compile(FORMAT_REGEX_PATTERN, re.DOTALL)

# Regex to extract content from <answer> tag
ANSWER_REGEX_PATTERN = r"<answer>((?:(?!<\/answer>).)*)<\/answer>"
ANSWER_REGEX = re.compile(ANSWER_REGEX_PATTERN, re.DOTALL) # re.DOTALL allows . to match newlines

# Regex for allowed characters in an equation
ALLOWED_EQUATION_CHARS_PATTERN = r'^[\d+\-*/().\s]+$'
ALLOWED_EQUATION_CHARS_REGEX = re.compile(ALLOWED_EQUATION_CHARS_PATTERN)

# Tolerance for float comparisons
FLOAT_COMPARISON_TOLERANCE = 1e-5

# Optional: Setup a logger if you want to see errors instead of just getting 0.0
# logger = logging.getLogger(__name__)
# logging.basicConfig(level=logging.INFO) # Or logging.DEBUG for more verbosity


def format_reward_func(completions: list[str], **kwargs) -> list[float]:
    """
    Checks if completions strictly follow the <think>...</think>\n<answer>...</answer> format.
    The model is expected to generate the full string including the opening <think> tag.

    Args:
        completions (list[str]): Generated outputs from the assistant, each expected to
                                 start with "<think>" and follow the full format.
                                 Example: "<think>I will solve it.</think>\n<answer>42</answer>"
        **kwargs: Additional keyword arguments (ignored by this function).

    Returns:
        list[float]: Reward scores (1.0 for correct format, 0.0 otherwise).
    """
    rewards = []
    for completion_text in completions:
        try:
            # The completion_text itself is expected to be the full string
            match = FORMAT_REGEX.search(completion_text)

            if match and len(match.groups()) == 2:
                # Both <think> and <answer> content captured
                rewards.append(1.0)
            else:
                # logger.debug(f"Format mismatch for: {completion_text}")
                rewards.append(0.0)
        except Exception as e:
            # logger.error(f"Error processing completion for format check: {completion_text}, Error: {e}")
            rewards.append(0.0)
    return rewards


def equation_reward_func(
    completions: list[str],
    target: list[str],
    nums: list[list[str]],
    **kwargs
) -> list[float]:
    rewards = []
    for completion_text, target_str, available_nums_str in zip(completions, target, nums):
        try:
            current_reward = 0.0 

            answer_match = ANSWER_REGEX.search(completion_text)
            if not answer_match:
                rewards.append(current_reward)
                continue

            full_answer_content = answer_match.group(1).strip()
            if not full_answer_content:
                rewards.append(current_reward)
                continue

            # Try to split "expression = result"
            parts = full_answer_content.rsplit('=', 1)
            expression_part_str = parts[0].strip()
            
            # It's possible the model only gives the expression, or gives something else.
            # If no '=', the whole thing is the expression.
            if len(parts) == 1: # No '=' found or it's at the very beginning
                expression_part_str = full_answer_content 
            elif not expression_part_str: # Handles cases like "= 138"
                rewards.append(current_reward) # Invalid format if expression is empty
                continue


            if not expression_part_str: # Handle empty expression after split
                rewards.append(current_reward)
                continue

            # 1. Check for allowed characters in the EXPRESSION PART
            if not ALLOWED_EQUATION_CHARS_REGEX.match(expression_part_str):
                # print(f"Debug: Expression '{expression_part_str}' contains forbidden characters.")
                rewards.append(current_reward)
                continue

            # 2. Check number usage (using expression_part_str)
            try:
                expected_numbers_int = sorted([int(n) for n in available_nums_str])
            except ValueError:
                rewards.append(current_reward)
                continue

            used_numbers_str = re.findall(r'\d+', expression_part_str) # Check numbers in expression only
            try:
                used_numbers_int = sorted([int(n) for n in used_numbers_str])
            except ValueError:
                rewards.append(current_reward)
                continue

            if used_numbers_int != expected_numbers_int:
                # print(f"Debug: Number usage mismatch. Used: {used_numbers_int}, Expected: {expected_numbers_int} in '{expression_part_str}'")
                rewards.append(current_reward)
                continue

            # 3. Evaluate the EXPRESSION PART and check correctness against target_str
            try:
                target_val = float(target_str)
                eval_globals = {"__builtins__": {}} 
                eval_locals = {}
                result = eval(expression_part_str, eval_globals, eval_locals) # Evaluate only the expression

                if abs(float(result) - target_val) < FLOAT_COMPARISON_TOLERANCE:
                    current_reward = 1.0
                # else:
                    # print(f"Debug: Equation result mismatch. Eq: '{expression_part_str}' -> {result}, Target: {target_val}")
            except SyntaxError:
                # print(f"Debug: Syntax error in expression: {expression_part_str}")
                pass 
            except TypeError:
                pass 
            except ZeroDivisionError:
                pass 
            except Exception as eval_e:
                # print(f"Debug: Unexpected error evaluating expression '{expression_part_str}': {eval_e}")
                pass 

            rewards.append(current_reward)

        except Exception as e:
            rewards.append(0.0)
    return rewards

Let's try our reward function with a few samples.

In [12]:
test_completions = [
    "<think>I need to use 2, 3, and 4 to make 14. I can multiply 4 by 3, which is 12, and then add 2 to get 14.</think>\n<answer>4 * 3 + 2</answer>",
    "<think>I will try to make 14. 4 times 3 is 12, plus 2 is 14.</think><answer>(4*3)+2</answer>", 
    "<think>I need to get 14. What if I use 7 and 2? 7 multiplied by 2 is 14.</think>\n<answer>7 * 2</answer>", # Wrong numbers
    "<think>I think the answer involves multiplication and addition. Let's try to spell it out.</think>\n<answer>four times three plus two equals 14</answer>" # Forbidden chars
]

In [13]:
# Shared data for all test completions
test_targets = ["14", "14", "14", "14"]
test_nums_list = [
    ["2", "3", "4"],
    ["2", "3", "4"],
    ["2", "3", "4"],
    ["2", "3", "4"],
]

In [14]:
format_rewards = format_reward_func(test_completions)
format_expected_rewards = [1.0, 0.0, 1.0, 1.0]
assert format_rewards == format_expected_rewards, f"Format rewards: {format_rewards}, Expected: {format_expected_rewards}"

In [15]:
equation_rewards = equation_reward_func(test_completions, test_targets, test_nums_list)
equation_expected_rewards = [1.0, 1.0, 0.0, 0.0]
assert equation_rewards == equation_expected_rewards, f"Equation rewards: {equation_rewards}, Expected: {equation_expected_rewards}"

Looking good! Now we just need to define our training parameters, create the trainer, and start training.

In [16]:
print(model)

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): Qwe

In [17]:
from trl import GRPOConfig, GRPOTrainer, get_peft_config, ModelConfig
from peft import LoraConfig

model_config = ModelConfig(
    model_name_or_path=model_name,
    torch_dtype="bfloat16",
    use_peft=True,
    load_in_4bit=True
)

peft_config = LoraConfig(
    r=16,                       # LoRA rank - higher means more capacity but more parameters
    lora_alpha=32,              # LoRA alpha - scaling factor
    lora_dropout=0.05,          # Dropout probability for LoRA layers
    bias="none",                # Don't train bias parameters to save memory
    task_type="CAUSAL_LM",      # Task type for causal language modeling
    target_modules=[
        # Attention layers
        "q_proj", 
        "k_proj", 
        "v_proj", 
        "o_proj",
        # MLP/FFN layers
        "gate_proj",  
        "up_proj", 
        "down_proj"
    ],
    # QLoRA specific settings
    fan_in_fan_out=False,       # Set to True for specific architectures that need this
    modules_to_save=None,       # Specific modules to fully fine-tune if needed
)

training_args = GRPOConfig(
    output_dir="qwen-r1-aha-countdown-tasks",
    learning_rate=5e-7,
    lr_scheduler_type="cosine",
    logging_steps=10,
    max_steps=100,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    # GRPO specific arguments
    max_prompt_length=256,
    max_completion_length=1024,
    num_generations=2,
    beta=0.001,
)

INFO 05-10 22:17:42 [importing.py:17] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 05-10 22:17:42 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-10 22:17:42 [__init__.py:239] Automatically detected platform cpu.


In [18]:
trainer = GRPOTrainer(
    model=model_config.model_name_or_path,
    reward_funcs=[format_reward_func, equation_reward_func],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [19]:
trainer.train()
trainer.save_model(training_args.output_dir)

`generation_config` default values have been modified to match model-specific defaults: {'top_k': 20, 'top_p': 0.95, 'bos_token_id': 151643}. If this is not desired, please set these values explicitly.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss


KeyboardInterrupt: 

In [None]:
train_dataset

Dataset({
    features: ['target', 'nums', 'prompt'],
    num_rows: 900
})

## View inference of trained model
We want to further inspect the thought process behind the trained model to verify if the model has sufficiently learnt how to reason and solve the problem.

To ensure that we can run this section independently of the other sections above, there will be some repeat of code here.

In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

In [9]:
base_model_id = "Qwen/Qwen3-0.6B"
adapter_id = "qwen-r1-aha-countdown-tasks/checkpoint-400"

In [10]:
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

In [11]:
# Qwen models usually require a pad token, and it's often set to eos_token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.eos_token})")

In [13]:
print(f"Loading base model '{base_model_id}'...")
try:
    # Try bfloat16 first, as it was in your training config
    print("Attempting to load with torch_dtype=torch.bfloat16 (as per training config)...")
    base_model_for_peft = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.bfloat16, # From your ModelConfig
        device_map="auto",          # Should map to MPS on Mac
        # trust_remote_code=True # Uncomment if Qwen3 requires it (often needed for Qwen models)
    )
    print("Base model loaded successfully with bfloat16.")
except Exception as e_bf16:
    print(f"Error loading base model with bfloat16: {e_bf16}")
    print("Attempting to load with torch_dtype=torch.float16 as a common MPS fallback...")
    try:
        base_model_for_peft = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16, # Common MPS-friendly dtype
            device_map="auto",
            # trust_remote_code=True
        )
        print("Base model loaded successfully with float16.")
    except Exception as e_f16:
        print(f"Error loading base model with float16: {e_f16}")
        print("Attempting to load with torch_dtype='auto' as a general fallback...")
        base_model_for_peft = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype="auto",       # General fallback
            device_map="auto",
            # trust_remote_code=True
        )
        print("Base model loaded with torch_dtype='auto'.")

Loading base model 'Qwen/Qwen3-0.6B'...
Attempting to load with torch_dtype=torch.bfloat16 (as per training config)...
Base model loaded successfully with bfloat16.


In [14]:
print(f"Loading PEFT adapter from '{adapter_id}'...")
# The adapter should be compatible with the precision of the loaded base_model_for_peft
ft_model = PeftModel.from_pretrained(base_model_for_peft, adapter_id)
ft_model.eval()  # Set the model to evaluation mode for inference
print("Fine-tuned PEFT model loaded and set to evaluation mode.")
print(f"Fine-tuned model is on device: {ft_model.device} with dtype: {ft_model.dtype}")

Loading PEFT adapter from 'qwen-r1-aha-countdown-tasks/checkpoint-400'...
Fine-tuned PEFT model loaded and set to evaluation mode.
Fine-tuned model is on device: mps:0 with dtype: torch.bfloat16


In [17]:
if 'generate_r1_prompt' in globals():
    print("\n--- Example Inference ---")
    sample_numbers = ["25", "10", "4", "2"]
    sample_target = "138"

    r1_example_data = generate_r1_prompt(sample_numbers, sample_target)
    prompt_text = r1_example_data["prompt"]
    
    print(f"\nFormatted Prompt:\n{prompt_text}")

    inputs = tokenizer(prompt_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    
    # Move inputs to the same device as the model
    inputs = {k: v.to(ft_model.device) for k, v in inputs.items()}

    print(f"\nGenerating completion for target {sample_target} with numbers {sample_numbers}...")
    with torch.no_grad():
        outputs = ft_model.generate(
            **inputs,
            max_new_tokens=1024,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.7,
        )
    
    decoded_full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\n--- Generated Full Output ---")
    print(decoded_full_output)

    prompt_tokens_count = inputs["input_ids"].shape[1]
    generated_part_tokens = outputs[0][prompt_tokens_count:]
    decoded_generated_part = tokenizer.decode(generated_part_tokens, skip_special_tokens=True)
    print("\n--- Generated Assistant's Response ---")
    print(decoded_generated_part)
else:
    print("\n`generate_r1_prompt` function not found in global scope. Skipping inference example.")


--- Example Inference ---

Formatted Prompt:
<|im_start|>system
You are a math expert. You will first reason carefully about the problem, then provide the user with the answer.<|im_end|>
<|im_start|>user
Given the numbers ['25', '10', '4', '2'] and the target number 138, please provide a solution to reach the target number using the four basic arithmetic operations: addition, subtraction, multiplication, and division (+, -, *, /). You can use each number only once. Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 = 1 </answer>.<|im_end|>


Generating completion for target 138 with numbers ['25', '10', '4', '2']...

--- Generated Full Output ---
system
You are a math expert. You will first reason carefully about the problem, then provide the user with the answer.
user
Given the numbers ['25', '10', '4', '2'] and the target number 138, please provide a solution to reach the target number using 

Oh wait that's weird! I thought 0.6B model did not manage to get anything correct? Seems like I need to work on my training code to spot if there are any issues!