# 🧠 Fine-Tuning LLaMA on the Countdown Task with GRPO

Welcome to this notebook where we bring together two powerful ideas:

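- 🔁 **Group Relative Policy Optimization (GRPO)** – an efficient reinforcement learning technique for fine-tuning large language models that scores each sampled completion relative to the other completions generated for the same prompt, so no separate value (critic) model is needed, and  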
- 🔢 **The Countdown Task** – a symbolic reasoning challenge where models must compose equations using basic arithmetic to hit a target value.

---

## 🎯 The Countdown Task

Given a list of numbers and a target value, the model must generate an equation using **each number exactly once**, applying basic operations:

$$
\text{Operations: } +,\ -,\ \times,\ \div
$$

**Example:**

```text
Numbers: [3, 7, 50], Target: 29
Attempt: (50 - 7) + 3       = 46 → ❌ (wrong value)
Attempt: (50 / (7 + 3)) + 3 = 8  → ❌ (wrong value)
Answer:  50 - (7 * 3)       = 29 → ✅ (3, 7 and 50 each used exactly once)
```

In [None]:
from lmalign.model import load_model_and_tokenizer
from trl import GRPOConfig, GRPOTrainer
from datasets import DatasetDict, load_dataset, Dataset
import random
import torch
import regex as re
from typing import List, Dict, Any, Union

In [None]:
import numpy as np
from transformers import set_seed

SEED = 42

# Seed every RNG used by the notebook: Python's `random`, NumPy's global RNG,
# PyTorch on CPU and PyTorch on all visible CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Prefer deterministic cuDNN kernels over autotuned ones (optional, but it makes
# debugging runs comparable).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Let Hugging Face transformers apply its own seeding as well.
set_seed(SEED)

In [None]:
# Model / tokenizer configuration consumed by lmalign's load_model_and_tokenizer:
# Llama-3.1-8B-Instruct loaded in 4-bit with bf16 compute, plus LoRA settings
# (rank 16, alpha 32, dropout 0.05, applied to all linear layers).
config = dict(
    model=dict(
        id="meta-llama/Llama-3.1-8B-Instruct",
        quantization=dict(
            enabled=True,
            type="4bit",
            compute_dtype="bf16",
        ),
        peft_settings=dict(
            enabled=True,
            mode="create",
            lora_config=dict(
                r=16,
                alpha=32,
                dropout=0.05,
                target_modules="all-linear",
            ),
        ),
    ),
    tokenizer=dict(id="meta-llama/Llama-3.1-8B-Instruct"),
)

# Model loading

In [None]:
# Load the model and tokenizer through the project's lmalign helper. `config`
# requests 4-bit quantization and LoRA (peft mode "create"); the exact semantics
# live in lmalign.model — presumably a fresh LoRA adapter is attached (TODO confirm).
model, tokenizer = load_model_and_tokenizer(config=config)

# Load the dataset

In [None]:
def format_batch_to_chatml(batch, system_prompt, user_template, assistant_template):
    """Turn a batch of Countdown rows into chat-format prompts.

    Intended for ``Dataset.map(..., batched=True)``. Every row becomes a three-turn
    conversation: the system prompt, the user question (``user_template`` filled
    with the comma-separated numbers and the target) and a partial assistant turn
    holding ``assistant_template``.

    Args:
        batch: Mapping with parallel ``"target"`` and ``"nums"`` columns.
        system_prompt: Content of the system message.
        user_template: ``str.format`` template with ``{numbers}`` and ``{target}`` fields.
        assistant_template: Content of the trailing assistant message.

    Returns:
        dict: ``{"prompt": [conversation, ...]}``, one conversation per row.
    """

    def _conversation_for(target, numbers):
        # e.g. [3, 7, 50] -> "3, 7, 50"
        joined_numbers = ", ".join(str(n) for n in numbers)
        user_message = user_template.format(numbers=joined_numbers, target=target)
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_template},
        ]

    rows = zip(batch["target"], batch["nums"])
    return {"prompt": [_conversation_for(target, numbers) for target, numbers in rows]}

In [None]:
# System prompt: asks for reasoning inside <think> tags, then the answer inside <answer> tags.
system_prompt = (
    "A conversation between User and Assistant. "
    "The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind "
    "and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within "
    "<think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)
# User turn; filled with str.format(numbers=..., target=...).
user_template = (
    "Using the numbers {numbers}, create an equation that equals {target} "
    "and provide the solution step by step. "
    "You can use basic arithmetic operations (+, -, *, /) "
    "and parentheses to create the equation. "
    "Each number can only be used once. "
    "If it is not possible to create an equation, please say 'impossible'."
)
# Prefilled start of the assistant turn; it already opens the <think> block.
# (The misspelled name is kept because later cells refer to it.)
assistent_template = "Let me solve this step by step. " "<think>"

In [None]:
# Load the full Countdown training split and display it to inspect its columns
# and size before preprocessing.
dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
dataset

In [None]:
# Load the Countdown data, keep a fixed 50k-row subset, add a chat-format "prompt"
# column and split it 90/10 into train/test. The original "target" / "nums"
# columns are kept because the reward functions read them.
dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split="train")
dataset = dataset.select(range(50_000))
dataset = dataset.map(
    format_batch_to_chatml,  # Use the batched version
    batched=True,  # Process in batches (highly recommended)
    fn_kwargs={
        "system_prompt": system_prompt,
        "user_template": user_template,
        "assistant_template": assistent_template,  # Pass the variable
    },
)
# Seed the split explicitly. Without `seed`, the permutation depends on the state of
# NumPy's global RNG, i.e. on what earlier cells (or a re-run of this cell) consumed,
# so the train/test partition is not reproducible from SEED alone.
train_test_split = dataset.train_test_split(test_size=0.1, seed=SEED)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

In [None]:
# Peek at one formatted training example: the original columns plus the new "prompt" conversation.
train_dataset[0]

In [None]:
# Bundle both splits into one DatasetDict so later cells can use dataset["train"]
# and dataset["test"]. Note: this rebinds `dataset`.
dataset = DatasetDict(
    {
        "train": train_dataset,
        "test": test_dataset,
    }
)

In [None]:
# Drop the intermediate split references; both splits remain reachable via `dataset`.
del train_test_split
del train_dataset
del test_dataset

In [None]:
# # Test getting an output:

# sample = dataset["train"][0]
# input_text = sample["prompt"]

# # Convert the conversation into a string using the tokenizer's chat template
# input_text_str = tokenizer.apply_chat_template(
#     input_text,
#     tokenize=False,  # set to True if you want token IDs instead
#     add_generation_prompt=False,  # or True if you're about to generate a reply
# )

# inputs = tokenizer(input_text_str, return_tensors="pt", padding=True, truncation=True)
# input_ids = inputs["input_ids"].to(model.device)


# attention_mask = inputs["attention_mask"].to(model.device)

# outputs = model.generate(
#     input_ids=input_ids,
#     attention_mask=attention_mask,
#     max_new_tokens=256,
#     do_sample=True,
#     temperature=0.7,
#     top_p=0.95,
#     top_k=50,
# )

# output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print("Output Text:")
# print(output_text)

# Reward functions

In [None]:
def format_reward_func(
    completions: List[Union[str, List[Dict[str, str]], Dict[str, str]]], **kwargs: Any
) -> List[float]:
    """
    Reward completions that follow the <think>...</think><answer>...</answer> format.

    The opening <think> tag belongs to the prompt: the assistant turn is prefilled with
    "Let me solve this step by step. <think>". The model's own output therefore normally
    starts inside the reasoning. Depending on the trainer, a conversational completion may
    or may not include that prefill. The check therefore:

    1. extracts the completion text (plain string, [{"content": ...}] or {"content": ...});
    2. if a <think> tag appears before the first </think>, drops everything up to and
       including it (the prefill);
    3. prepends a single synthetic <think>;
    4. requires the whole text to be exactly one think block (no nested or repeated think
       tags), followed by one <answer>...</answer> block. Whitespace is allowed between
       and after the two blocks.

    Args:
        completions: List of model completions. Can be one of:
            - List of strings
            - List of lists containing dictionaries with "content" key
            - List of dictionaries with "content" key
        **kwargs: Additional keyword arguments (e.g. dataset columns); unused.

    Returns:
        List[float]: A list of reward scores (1.0 for matching format, 0.0 otherwise)

    Example:
        >>> format_reward_func(["<think>Some reasoning</think><answer>42</answer>"])
        [1.0]

        >>> format_reward_func([" reasoning </think> <answer> 1 + 2 </answer>"])
        [1.0]

        >>> format_reward_func([[{"content": "<think>a</think><think>b</think><answer>c</answer>"}]])
        [0.0]
    """
    open_tag, close_tag = "<think>", "</think>"
    # Whole-string match: one think block whose body contains no <think>/</think>,
    # then one answer block.
    pattern = r"<think>((?:(?!</?think>).)*)</think>\s*<answer>(.*?)</answer>\s*"
    rewards = []

    for completion in completions:
        try:
            # Handle different input formats
            if isinstance(completion, list):
                content = completion[0]["content"]
            elif isinstance(completion, dict):
                content = completion["content"]
            else:
                content = completion
        except (KeyError, IndexError, TypeError):
            # Handle cases where the expected structure isn't found
            rewards.append(0.0)
            continue

        if not isinstance(content, str):
            rewards.append(0.0)
            continue

        # Strip a leading prefill ("... <think>") so that exactly one synthetic
        # opening tag is present, whether or not the prefill was included.
        open_idx = content.find(open_tag)
        close_idx = content.find(close_tag)
        if open_idx != -1 and (close_idx == -1 or open_idx < close_idx):
            content = content[open_idx + len(open_tag):]
        content = open_tag + content

        # DOTALL lets reasoning and answers span multiple lines.
        match = re.fullmatch(pattern, content, re.DOTALL)
        rewards.append(1.0 if match else 0.0)

    return rewards

In [None]:
def equation_reward_func(completions, target, nums, **kwargs):
    """
    Reward completions whose <answer> holds a correct Countdown equation.

    A completion earns 1.0 when all of the following hold for the (stripped) content
    of its first <answer>...</answer> block:
    - it uses every number in ``nums`` exactly once;
    - it contains only digits, ``+ - * /``, parentheses, dots and whitespace, with
      no ``**`` (power) or ``//`` (floor division);
    - it evaluates to ``target`` within 1e-5.

    Anything else scores 0.0: missing tags, wrong numbers, disallowed characters, or
    evaluation errors such as a syntax error or division by zero.

    Args:
        completions: Generated outputs. Either plain strings or conversational
            completions ([{"role": "assistant", "content": ...}] or {"content": ...}),
            which is the shape used for chat-formatted prompts.
        target (list[int | str]): Expected results, one per completion.
        nums (list[list[int]]): Available numbers, one list per completion.
        **kwargs: Other dataset columns; ignored.

    Returns:
        list[float]: Reward scores (1.0 or 0.0 per completion)
    """
    # Only numbers, operators, parentheses, and whitespace may appear in the equation.
    allowed_pattern = r"^[\d+\-*/().\s]+$"
    rewards = []
    for completion, gt, numbers in zip(completions, target, nums):
        try:
            # Conversational completions carry the text in a "content" field.
            if isinstance(completion, list):
                completion = completion[0]["content"]
            elif isinstance(completion, dict):
                completion = completion["content"]

            # DOTALL so an answer placed on its own line(s) is still found.
            match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
            if match is None:
                rewards.append(0.0)
                continue

            # Extract the "answer" part from the completion
            equation = match.group(1).strip()

            # Check if all numbers are used exactly once
            used_numbers = [int(n) for n in re.findall(r"\d+", equation)]
            if sorted(used_numbers) != sorted(numbers):
                rewards.append(0.0)
                continue

            # "**" passes the character whitelist but is not an allowed operation, and an
            # expression like 55**36**19**7 would make eval() run effectively forever.
            # "//" (floor division) is not an allowed operation either.
            if not re.match(allowed_pattern, equation) or "**" in equation or "//" in equation:
                rewards.append(0.0)
                continue

            # eval() runs on model output. It is acceptable only because of the strict
            # whitelist above and the disabled builtins.
            result = eval(equation, {"__builtins__": None}, {})

            # Check if the equation matches the ground truth
            rewards.append(1.0 if abs(float(result) - float(gt)) < 1e-5 else 0.0)
        except Exception:
            # Deliberate best-effort: any malformed completion or failed evaluation
            # (SyntaxError, ZeroDivisionError, missing keys, ...) simply earns 0.
            rewards.append(0.0)
    return rewards

In [None]:
# Sanity-check both reward functions on hand-written completions before training.
# As in training, the opening "<think>" tag is treated as part of the prompt, so
# most samples begin mid-reasoning and contain only the closing "</think>".

# Well-formed completion with a correct equation: 55 + 36 - 7 - 19 = 65.
correct_sample_1 = """We need to find an equation using the numbers 19, 36, 55, and 7
exactly once, with basic arithmetic operations, that equals 65. One possible
combination is 55 + 36 - 19 + 7... </think>
<answer> 55 + 36 - 7 - 19 </answer>"""

# Minimal well-formed completion with the same correct equation.
correct_sample_2 = """ ... </think>
<answer> 55 + 36 - 7 - 19 </answer>"""

# No <think>/<answer> structure at all (it only echoes the question).
wrong_format = """User: Using the numbers [19, 36, 55, 7], create an equation that equals 65."""

# Prose before the tags, a duplicated <think> block, and an answer that is not an
# equation over the given numbers.
wrong_format_2 = """To find the equation that equals 79 using the numbers 95, 78, 6, 88, I'll start by adding 88 and 95:                      
95 + 88 = 183                                                                                                              
Now, let's subtract 104 from 183 to get 79:
183 - 104 = 79
<think> 183 - 104 = 79 </think><think> 183 - 104 = 79 </think><answer> 183 - 104 = 79 </answer>"""

# Well-formed, but the equation uses 18 instead of 19 (it evaluates to 66).
wrong_result = """ ... </think>
<answer> 55 + 36 - 7 - 18 </answer>"""


# Expected format rewards: the two correct samples and wrong_result are well-formed;
# wrong_format and wrong_format_2 are not.
test_rewards = format_reward_func(
    completions=[correct_sample_1, correct_sample_2, wrong_format, wrong_format_2, wrong_result],
    target=["65", "65", "65", "65", "65"],
    nums=[[19, 36, 55, 7]] * 5,
)
assert test_rewards == [1.0, 1.0, 0.0, 0.0, 1.0], "Reward function is not working"
# Expected equation rewards: only the two correct samples earn 1.0.
test_rewards = equation_reward_func(
    completions=[correct_sample_1, correct_sample_2, wrong_format, wrong_format_2, wrong_result],
    target=["65", "65", "65", "65", "65"],
    nums=[[19, 36, 55, 7]] * 5,
)
assert test_rewards == [1.0, 1.0, 0.0, 0.0, 0.0], "Reward function is not working"

In [None]:
# Inspect the first training example after wrapping the splits in a DatasetDict.
dataset["train"][0]

# Model Training

In [None]:
# GRPO training configuration.
training_args = GRPOConfig(
    output_dir="./output",
    learning_rate=1e-5,
    # Keep the "target" / "nums" columns: the reward functions receive them as kwargs.
    remove_unused_columns=False,
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    bf16=True,
    lr_scheduler_type="cosine",
    logging_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # GRPO specific parameters
    # Prompts longer than this are truncated from the LEFT. The formatted prompt
    # (chat-template header + system prompt + user question + assistant prefill) is
    # far longer than 64 tokens; the system prompt alone exceeds that. A 64-token cap
    # therefore cut off the numbers and the target, leaving unanswerable prompts.
    # Re-check against the tokenized prompt length if the templates change.
    max_prompt_length=256,
    max_completion_length=128,  # max length of the generated output for our solution
    num_generations=4,  # completions sampled per prompt; rewards are compared within each group
    beta=0.001,  # KL penalty coefficient towards the reference model
)

In [None]:
# Trainer
# Seems to create a column "completions" with the generated outputs
# Trainer
# GRPOTrainer samples `num_generations` completions per prompt (see training_args),
# scores each completion with every function in `reward_funcs`, and updates the
# policy based on how each completion's reward compares with the rest of its group.
# Dataset columns other than "prompt" (here "target" and "nums") are forwarded to
# the reward functions as keyword arguments, which is why training_args sets
# remove_unused_columns=False.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        format_reward_func,  # rewards the <think>...</think><answer>...</answer> layout
        equation_reward_func,  # rewards an <answer> equation that hits the target
    ],
    args=training_args,
    train_dataset=dataset["train"],
    # NOTE(review): only used if an evaluation strategy is enabled; training_args sets none.
    eval_dataset=dataset["test"],
)

In [None]:
# Run GRPO training; logs and checkpoints are written under training_args.output_dir ("./output").
trainer.train()

# Generating Text

In [None]:
# Demo question for the trained model. It is built from the same `user_template` used
# for training (rather than a hand-copied string), so its wording cannot drift from
# the training prompts. One valid solution is 95 - 21 / 3 = 88.
user_prompt_sample = user_template.format(numbers="95, 21, 3", target=88)

In [None]:
# Mirror the training prompt: system and user turns plus the prefilled assistant turn
# ("Let me solve this step by step. <think>") that the model learned to continue.
input_text_sample_output = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt_sample},
    {"role": "assistant", "content": assistent_template},
]

# Convert the conversation into a string using the tokenizer's chat template.
# continue_final_message leaves the prefilled assistant turn open so generation
# continues inside the <think> block. It cannot be combined with add_generation_prompt.
input_text_sample_str = tokenizer.apply_chat_template(
    input_text_sample_output,
    tokenize=False,  # set to True if you want token IDs instead
    add_generation_prompt=False,
    continue_final_message=True,
)

# The chat template already emits the BOS token, so don't let the tokenizer add a
# second one (the trainer likewise tokenizes templated prompts without special tokens).
inputs_sample = tokenizer(
    input_text_sample_str,
    return_tensors="pt",
    padding=True,
    truncation=True,
    add_special_tokens=False,
)
input_ids_sample = inputs_sample["input_ids"].to(model.device)
attention_mask_sample = inputs_sample["attention_mask"].to(model.device)
outputs_sample = model.generate(
    input_ids=input_ids_sample,
    attention_mask=attention_mask_sample,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)
# Decodes the prompt together with the generated continuation.
output_text_sample = tokenizer.decode(outputs_sample[0], skip_special_tokens=True)
print("Output Text:")
print(output_text_sample)