In [1]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# System message placed first in every chat prompt built below. It asks the
# model to put its working inside <reasoning> tags and its final answer inside
# <answer> tags. The leading and trailing newlines are part of the value.
SYSTEM_PROMPT = (
    "\n"
    "Respond in the following format:\n"
    "<reasoning>\n"
    "...\n"
    "</reasoning>\n"
    "<answer>\n"
    "...\n"
    "</answer>\n"
)

def extract_xml_answer(text: str) -> str:
    """
    Return the content of the last ``<answer>...</answer>`` section of *text*.

    Everything after the final ``<answer>`` opening tag is taken. It is then cut
    at the first ``</answer>`` closing tag that follows, and surrounding
    whitespace is stripped. Missing tags are tolerated rather than rejected:
    with no ``<answer>`` tag the whole text is used, and with no ``</answer>``
    tag everything after the opening tag is kept.

    Args:
        text (str): Model output, normally containing <answer> tags.

    Returns:
        str: The extracted answer text, stripped of surrounding whitespace.
    """
    # rpartition yields ('', '', text) when the tag is absent, so the whole text is kept.
    _, _, after_open = text.rpartition("<answer>")
    # partition yields (after_open, '', '') when the closing tag is absent.
    body, _, _ = after_open.partition("</answer>")
    return body.strip()


def extract_hash_answer(text: str) -> str | None:
    """
    Extracts the numerical answer from a text that contains a hash (####) marker.
    Removes commas and dollar signs from the extracted answer.
    
    Args:
        text (str): The text containing the answer marked with ####
        
    Returns:
        str | None: The cleaned numerical answer, or None if no hash marker is found
    """
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")


def get_gsm8k_questions(split="train") -> Dataset:
    """
    Load one split of GSM8K and turn each example into a chat-formatted prompt.

    Each example gains a ``prompt`` field: a two-message chat made of a system
    message (``SYSTEM_PROMPT``) followed by a user message holding the example's
    ``question``. Its ``answer`` field is replaced by the value returned from
    ``extract_hash_answer`` on the original answer text. Other existing columns
    (e.g. ``question``) are left in place.

    Args:
        split (str): Which split of ``openai/gsm8k`` (config ``"main"``) to load,
            e.g. ``"train"`` or ``"test"``.

    Returns:
        Dataset: The mapped split.
    """

    def _to_chat_example(example):
        # Conversation shown to the model: system instructions, then the question.
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
        ]
        return {"prompt": chat, "answer": extract_hash_answer(example["answer"])}

    all_splits = load_dataset("openai/gsm8k", "main")  # type: ignore
    return all_splits[split].map(_to_chat_example)  # type: ignore


In [4]:
# Chat-formatted GSM8K training split ("train" is also the function's default).
dataset = get_gsm8k_questions(split="train")

Map: 100%|██████████| 7473/7473 [00:00<00:00, 24817.99 examples/s]


In [5]:
# Peek at the first formatted example; prints nothing if the dataset is empty.
first_example = next(iter(dataset), None)
if first_example is not None:
    print(first_example)

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'answer': '72', 'prompt': [{'content': '\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n', 'role': 'system'}, {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]}
