### Training a model on math problems from the GSM8k dataset, guided by GRPO reward functions

```python
# train_grpo.py
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
```
This section imports necessary libraries:
- `re`:  For regular expressions, used for pattern matching in text, specifically for checking the output format.
- `torch`: PyTorch, a deep learning framework, used for tensor operations and model training.
- `datasets`: From Hugging Face `datasets` library, used for loading and processing datasets, like GSM8k in this case.
- `Dataset`:  Specifically imports the `Dataset` class from `datasets` for type hinting and clarity.
- `transformers`: From Hugging Face `transformers` library, provides pre-trained models and tokenizers.
    - `AutoTokenizer`:  Automatically loads the correct tokenizer for a given pre-trained model.
    - `AutoModelForCausalLM`: Automatically loads a pre-trained language model for causal language modeling (text generation).
- `peft`: From Hugging Face `peft` library (Parameter-Efficient Fine-Tuning), used for applying techniques like LoRA to fine-tune models efficiently.
    - `LoraConfig`: Configuration for LoRA (Low-Rank Adaptation), a PEFT technique.
- `trl`: Hugging Face's `trl` library (Transformer Reinforcement Learning), which provides tools and trainers for reinforcement learning with language models.
    - `GRPOConfig`:  Configuration for GRPO (Group Relative Policy Optimization) training.
    - `GRPOTrainer`: Trainer class specifically designed for GRPO training.

```python
# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
```
This part defines constants used for prompt engineering and output formatting:
- `SYSTEM_PROMPT`:  A system-level instruction given to the language model at the beginning of the conversation. It tells the model to respond in a specific XML format with `<reasoning>` and `<answer>` tags. This encourages the model to perform Chain-of-Thought (CoT) reasoning and present the answer clearly.
- `XML_COT_FORMAT`: A string template defining the XML format for Chain-of-Thought outputs. It will be used to structure example outputs (though not used in the current configuration of the prompt, see commented section below).
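
To make the template concrete, the short sketch below fills `XML_COT_FORMAT` with the one-shot example that appears (commented out) in the dataset code. The variable name `example` is only illustrative and is not part of the original script.

```python
# The backslash after the opening quotes suppresses the leading newline,
# so the formatted string starts directly with "<reasoning>".
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Fill the two placeholders with the one-shot example from the script.
example = XML_COT_FORMAT.format(
    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
    answer="7",
)
print(example)
```

The result is exactly the layout that `SYSTEM_PROMPT` asks the model to produce: each tag on its own line, with the content in between.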

```python
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
```
These are utility functions to extract answers from text:
- `extract_xml_answer(text: str) -> str`:
    - Takes a text string as input, which is expected to be in XML format with `<answer>` tags.
    - `text.split("<answer>")[-1]`: Splits the text on the `<answer>` tag and keeps the last piece, i.e. everything after the last `<answer>` tag. Any text before the tag, such as the reasoning block, is discarded. If the tag is missing, `split` returns the whole text as a single piece, so the function falls back to the full response.
    - `.split("</answer>")[0]`: Splits the result from the previous step by the `</answer>` tag and takes the first part (everything before the first `</answer>` tag).
    - `.strip()`: Removes leading/trailing whitespace from the extracted answer.
    - Returns the extracted answer as a string.
    - This function is designed to extract the answer content from within `<answer>` XML tags.
- `extract_hash_answer(text: str) -> str | None`:
    - Takes a text string (likely from the GSM8k dataset) as input.
    - `if "####" not in text`: Checks if the separator "####" is present in the text. If not, it means there is no answer in the expected format (for GSM8k, answers are typically marked after "####").
    - `return None`: If "####" is not found, it returns `None`, indicating that an answer could not be extracted in the expected hash format.
    - `text.split("####")[1].strip()`: If "####" is found, it splits the text by "####", takes the second part (index 1) which is expected to be the answer, and removes leading/trailing whitespace.
    - Returns the extracted answer as a string, or `None` if "####" is not found.
    - This function is designed to extract the answer from the GSM8k dataset's original format, which uses "####" as a separator.
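
A quick way to see the edge cases is to run both helpers on a few hand-written strings. The sample texts below are made up for illustration, and `Optional[str]` is used in place of `str | None` only so the sketch also runs on older Python versions.

```python
from typing import Optional

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> Optional[str]:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Hypothetical model outputs and a made-up record in GSM8k's answer style.
well_formed = "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"
no_tags = "The answer is 42."
gsm8k_style = "She has 6 * 7 = 42 apples.\n#### 42"

print(repr(extract_xml_answer(well_formed)))    # '42'
print(repr(extract_xml_answer(no_tags)))        # 'The answer is 42.' (no tags: whole text comes back)
print(repr(extract_hash_answer(gsm8k_style)))   # '42'
print(repr(extract_hash_answer(no_tags)))       # None
```

Note that `extract_xml_answer` never signals a missing tag; a response without `<answer>` is returned in full, which then simply fails the equality check in the correctness reward.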

```python
# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
```
This section defines a function to load and process the GSM8k dataset:
- `get_gsm8k_questions(split = "train") -> Dataset`:
    - Takes `split` (defaulting to "train") as input, specifying which split of the dataset to load (e.g., "train", "test").
    - `data = load_dataset('openai/gsm8k', 'main')[split]`: Loads the GSM8k dataset with the Hugging Face `datasets` library. `'openai/gsm8k'` identifies the dataset and `'main'` selects its configuration (GSM8k is published with two configurations, `main` and `socratic`). `[split]` selects the desired dataset split. The `# type: ignore` comments suppress static type-checker warnings about the return types of `load_dataset` and `map`.
    - `data = data.map(lambda x: { ... })`: Applies a mapping function to each example in the dataset to transform it.
        - `lambda x: { ... }`:  A lambda function that takes a dataset example `x` and returns a new dictionary representing the processed example.
        - `'prompt'`: Creates a 'prompt' field, which is a list of dictionaries, formatted for conversational models.
            - `{'role': 'system', 'content': SYSTEM_PROMPT}`: Adds the system prompt defined earlier, setting the overall instruction for the model.
            - **Commented Part (1-shot prompting - currently disabled):**
                - `#{'role': 'user', 'content': 'What is the largest single-digit prime number?'}`:  An example question to enable 1-shot learning (showing the model an example in the prompt).
                - `#{'role': 'assistant', 'content': XML_COT_FORMAT.format(...) }`: An example assistant response in the desired XML format for the example question, showing the model the desired output format.
            - `{'role': 'user', 'content': x['question']}`: Adds the actual question from the GSM8k dataset (`x['question']`) as the user's turn in the dialogue.
        - `'answer'`: Creates an 'answer' field by calling `extract_hash_answer(x['answer'])` to extract the numerical answer from the original GSM8k answer format (which is marked by "####").
    - `return data`: Returns the processed `Dataset`.
- `dataset = get_gsm8k_questions()`: Calls the `get_gsm8k_questions` function with the default "train" split to load and process the training dataset, storing it in the `dataset` variable.
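
The effect of the mapping can be seen without downloading anything by applying the same transformation to one hand-written record. This is a sketch under assumptions: the `raw` record is invented in GSM8k's style, and `to_grpo_example` is a hypothetical named version of the lambda passed to `data.map()`.

```python
from typing import Optional

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def extract_hash_answer(text: str) -> Optional[str]:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def to_grpo_example(x: dict) -> dict:
    """Same transformation as the lambda in get_gsm8k_questions (hypothetical name)."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    }

# Made-up record with the same two fields as a GSM8k row.
raw = {
    "question": "Tom has 3 boxes with 4 pens each. How many pens does he have?",
    "answer": "Tom has 3 * 4 = 12 pens.\n#### 12",
}

processed = to_grpo_example(raw)
print([m["role"] for m in processed["prompt"]])  # ['system', 'user']
print(processed["answer"])                       # 12
```

When used with `Dataset.map`, the returned keys are merged into each row, so the original `question` column is kept and the `answer` column is overwritten with the extracted final answer.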

```python
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
```
This section defines several reward functions. Reward functions are crucial in reinforcement learning, guiding the model to behave as desired. Each function takes `completions` (model generated outputs), and potentially other arguments like `prompts` and `answer` in this GRPO setup, and returns a list of reward scores.

- `correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]`:
    - `prompts`: The input prompts given to the model.
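    - `completions`: A list with one entry per generated completion. Because the prompts are in conversational format, each entry is itself a list of message dictionaries; here it holds a single assistant message, so `completion[0]['content']` is the generated text.
    - `answer`: The ground-truth answers from the dataset's `answer` column, passed as a list aligned with `completions` (one entry per completion).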
    - `**kwargs`: Accepts any additional keyword arguments.
    - `responses = [completion[0]['content'] for completion in completions]`: Extracts the text content of the model's responses from the `completions`.
    - `q = prompts[0][-1]['content']`: Extracts the question from the input prompts. It assumes the question is the last message (`[-1]`) in the first prompt (`[0]`).
    - `extracted_responses = [extract_xml_answer(r) for r in responses]`: Extracts the answer part from the XML-formatted responses using the `extract_xml_answer` function.
    - `print(...)`: Prints debug information: the original question, the ground truth answer, the raw response, and the extracted answer. This is helpful for monitoring the training process.
    - `return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]`: Compares each extracted response `r` with the ground truth answer `a` (assuming `answer` is also a list). Returns a reward of `2.0` if they are the same (correct answer) and `0.0` otherwise. This is a correctness-based reward.

- `int_reward_func(completions, **kwargs) -> list[float]`:
    - `completions`: Model completions as before.
    - `**kwargs`: Accepts any additional keyword arguments.
    - `responses = [completion[0]['content'] for completion in completions]`: Extracts the text content of the responses.
    - `extracted_responses = [extract_xml_answer(r) for r in responses]`: Extracts the answer part from the XML responses.
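    - `return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]`: Checks whether each extracted response `r` consists of digits only using `r.isdigit()`, and returns `0.5` if so and `0.0` otherwise. `str.isdigit()` is true only for non-empty strings made up entirely of digit characters, so negative numbers (`-5`), decimals (`3.5`), and numbers with separators (`1,000`) earn no reward. This reward encourages the model to output plain non-negative integer answers.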

- `strict_format_reward_func(completions, **kwargs) -> list[float]`:
    - `completions`: Model completions.
    - `**kwargs`: Accepts any additional keyword arguments.
    - `pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"`: Defines a regular expression pattern for a **strict** XML format.
        - `^`: Matches the beginning of the string.
        - `<reasoning>\n.*?\n</reasoning>\n`: Matches `<reasoning>`, a newline, any characters except a newline matched non-greedily (`.*?`), a newline, then `</reasoning>` and a newline.
        - `<answer>\n.*?\n</answer>\n`: Matches `<answer>`, a newline, any characters except a newline matched non-greedily, a newline, then `</answer>` and a newline.
        - `$`: Matches the end of the string.
        - `\n`: Explicitly requires newlines within the XML structure.
        - Because the pattern is used without `re.DOTALL`, `.` does not match a newline, so the reasoning and the answer must each fit on a single line for the pattern to match.
    - `responses = [completion[0]["content"] for completion in completions]`: Extracts response content.
    - `matches = [re.match(pattern, r) for r in responses]`: Applies the regex pattern to each response using `re.match()`. `re.match()` only matches at the beginning of the string. It returns a match object if successful, and `None` otherwise.
    - `return [0.5 if match else 0.0 for match in matches]`: Returns a reward of `0.5` if there is a match (response follows the strict format), and `0.0` otherwise. Only responses that adhere exactly to the defined XML layout, newlines included, are rewarded.

- `soft_format_reward_func(completions, **kwargs) -> list[float]`:
    - `completions`: Model completions.
    - `**kwargs`: Accepts any additional keyword arguments.
    - `pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"`: Defines a regular expression pattern for a **softer** XML format.
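        - `<reasoning>.*?</reasoning>`: Matches `<reasoning>`, any characters except a newline matched non-greedily, and `</reasoning>`.
        - `\s*`: Matches zero or more whitespace characters (spaces, tabs, newlines, etc.).
        - `<answer>.*?</answer>`: Matches `<answer>`, any characters except a newline matched non-greedily, and `</answer>`.
        - Newlines are not required around the tags, and any whitespace is allowed between `</reasoning>` and `<answer>`. However, because `.` does not match a newline without `re.DOTALL`, the content of each block must sit on the same line as its tags. A response laid out as the system prompt requests (tag, newline, content, newline, closing tag) therefore does not match this pattern.
    - `responses = [completion[0]["content"] for completion in completions]`: Extracts response content.
    - `matches = [re.match(pattern, r) for r in responses]`: Applies the regex pattern using `re.match()`, which anchors the match at the start of the string, so the response must begin with `<reasoning>`. Text after `</answer>` is allowed because the pattern has no `$`.
    - `return [0.5 if match else 0.0 for match in matches]`: Returns a reward of `0.5` if there is a match (response follows the soft format), and `0.0` otherwise. The pattern tolerates trailing text and flexible whitespace between the two blocks, but not newlines inside them.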

- `count_xml(text) -> float`:
    - `text`: The response text.
    - This function is designed to provide a more granular reward based on the presence and structure of XML tags, instead of a binary reward like the format reward functions.
    - `count = 0.0`: Initializes a reward counter to 0.
    - `if text.count("<reasoning>\n") == 1:`: Checks if there is exactly one occurrence of `<reasoning>\n`. If yes, adds 0.125 to the count.
    - `if text.count("\n</reasoning>\n") == 1:`: Checks for one occurrence of `\n</reasoning>\n`. If yes, adds 0.125.
    - `if text.count("\n<answer>\n") == 1:`: Checks for one occurrence of `\n<answer>\n`. If yes, adds 0.125 and then subtracts 0.001 for every character that follows the last `\n</answer>\n`, discouraging extraneous output after the answer. If the response contains no `\n</answer>\n` at all, `split` returns the whole text, so the penalty is based on the length of the entire response.
    - `if text.count("\n</answer>") == 1:`: Checks for one occurrence of `\n</answer>`. If yes, adds 0.125 and subtracts a similar penalty for text after `\n</answer>`; the `- 1` allows for a single trailing character such as the final newline.
    - `return count`: Returns the accumulated score. It reaches 0.5 when all four components appear exactly once and nothing follows the closing tag; the penalties lower it, and long trailing text can push it below zero.

- `xmlcount_reward_func(completions, **kwargs) -> list[float]`:
    - `completions`: Model completions.
    - `**kwargs`: Accepts any additional keyword arguments.
    - `contents = [completion[0]["content"] for completion in completions]`: Extracts response content.
    - `return [count_xml(c) for c in contents]`: Applies the `count_xml` function to each response content `c` and returns the list of reward scores. This reward function encourages the model to generate responses with the correct XML tags and penalizes extraneous text after the answer.
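
The behavior of the two format patterns and of `count_xml` can be checked on a few hand-written responses. The sketch below is illustrative: the sample strings and the `matches` helper are made up for this example, and the `re.DOTALL` variant is shown only for comparison; it is not part of the original script.

```python
import re

STRICT = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
SOFT = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

def matches(pattern: str, text: str, flags: int = 0) -> bool:
    """Hypothetical helper: True if re.match finds the pattern at the start of text."""
    return re.match(pattern, text, flags) is not None

def count_xml(text) -> float:
    # Copied from the script above.
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

# Made-up responses in three layouts.
one_line = "<reasoning>\n6 * 7 = 42\n</reasoning>\n<answer>\n42\n</answer>\n"
multi_line = "<reasoning>\nFirst, 6 * 7.\nThat is 42.\n</reasoning>\n<answer>\n42\n</answer>\n"
inline = "<reasoning>6 * 7 = 42</reasoning>\n<answer>42</answer>"

print(matches(STRICT, one_line))               # True
print(matches(STRICT, multi_line))             # False: '.' stops at the newline in the reasoning
print(matches(SOFT, one_line))                 # False: newline right after <reasoning>
print(matches(SOFT, inline))                   # True
print(matches(STRICT, multi_line, re.DOTALL))  # True once '.' may match newlines
print(matches(SOFT, one_line, re.DOTALL))      # True once '.' may match newlines

print(count_xml(one_line))                           # 0.5
print(round(count_xml(one_line + "extra text"), 3))  # 0.48 (10 trailing characters, penalized twice)
```

As written in the script, multi-line reasoning earns neither format reward, so in practice `count_xml` is the reward that gives partial credit for the tag structure.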

```python
#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"
```
This section sets up the model name, output directory, and run name for the training.
- `model_name = "Qwen/Qwen2.5-1.5B-Instruct"`: Specifies the pre-trained model to use. Initially, it seems "meta-llama/Llama-3.2-1B-Instruct" was considered but is commented out, and "Qwen/Qwen2.5-1.5B-Instruct" is chosen.
- The `if/else` block dynamically sets `output_dir` (directory to save training outputs) and `run_name` (name for the training run, often used for logging and saving) based on the `model_name`. If "Llama" is in the model name, it uses names specific to Llama-1B, otherwise, it uses names for Qwen-1.5B. This helps organize output files for different models.

```python
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
```
This section configures the training process using `GRPOConfig` from the `trl` library.  `GRPOConfig` holds all the hyperparameters and settings for the GRPO training.

- `output_dir`: Directory where training outputs (checkpoints, logs, etc.) will be saved (set in the previous block based on model name).
- `run_name`: Name of the training run (set in the previous block). This is often used for logging, e.g., in WandB.
- `learning_rate=5e-6`: The learning rate for the optimizer, a crucial hyperparameter that controls the step size during optimization.
- `adam_beta1=0.9`, `adam_beta2=0.99`:  Beta parameters for the Adam optimizer, which control the exponential decay rates for the first and second moment estimates.
- `weight_decay=0.1`: Weight decay regularization, a technique to prevent overfitting by penalizing large weights.
- `warmup_ratio=0.1`: Ratio of training steps used for learning rate warmup. The learning rate will increase linearly from 0 to the initial `learning_rate` during this warmup phase.
- `lr_scheduler_type='cosine'`: Type of learning rate scheduler. 'cosine' scheduler reduces the learning rate following a cosine curve after the warmup phase.
- `logging_steps=1`: Number of training steps between logging metrics. Set to 1 to log after every step.
- `bf16=True`: Enables bfloat16 mixed precision training, which can speed up training and reduce memory usage on compatible hardware (like NVIDIA GPUs with Ampere architecture or later).
- `per_device_train_batch_size=1`: Batch size per GPU for training. Here it's set to 1.
- `gradient_accumulation_steps=4`: Number of forward/backward passes whose gradients are accumulated before each optimizer step. The effective batch size per device is `per_device_train_batch_size * gradient_accumulation_steps` (here, 1 * 4 = 4), multiplied by the number of devices in multi-GPU training. This is useful when you need a larger effective batch size than can fit into GPU memory at once.
- `num_generations=16`: Number of completions sampled for each prompt. GRPO scores every completion in this group with the reward functions and compares each completion's reward against the rest of its group.
- `max_prompt_length=256`: Maximum length of the input prompt (number of tokens). Prompts longer than this will likely be truncated.
- `max_completion_length=786`: Maximum length of the generated completion (number of tokens). Completions longer than this will likely be truncated.
- `num_train_epochs=1`: Number of training epochs to run. An epoch is a complete pass through the training dataset. Here it's set to 1, indicating a single pass over the dataset.
- `save_steps=100`: Number of training steps between saving model checkpoints. Model will be saved every 100 steps.
- `max_grad_norm=0.1`: Maximum norm for gradient clipping. Gradients are clipped to this value to prevent exploding gradients during training and improve stability.
- `report_to="wandb"`: Specifies where to report training logs. "wandb" means it will use Weights & Biases (wandb.ai) for experiment tracking.
- `log_on_each_node=False`: In distributed training setups, this controls whether logging is done on each node or just the main node. Set to `False` here.
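
To visualize how `learning_rate`, `warmup_ratio`, and `lr_scheduler_type='cosine'` interact, the following sketch computes the learning rate at a given step for a linear warmup followed by a cosine decay to zero. It is an approximation for intuition only: `lr_at` is a hypothetical helper, the total of 1000 steps is an assumed value, and the trainer's own scheduler may round step counts differently.

```python
import math

def lr_at(step: int, total_steps: int, base_lr: float = 5e-6,
          warmup_ratio: float = 0.1) -> float:
    """Linear warmup to base_lr, then cosine decay to zero (illustrative sketch)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup: grow linearly from 0 to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay: follow half a cosine wave from base_lr down to 0.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

total = 1000  # assumed number of optimizer steps, for illustration only
for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at(step, total):.2e}")
```

With these assumed numbers the rate climbs during the first 100 steps, peaks at `5e-6`, passes through half of the peak midway through the decay, and ends at zero.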

```python
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
```
This section defines the PEFT (Parameter-Efficient Fine-Tuning) configuration using LoRA (Low-Rank Adaptation). Although it is defined, it's commented out in the `GRPOTrainer` initialization later, so LoRA is currently **not being used** in the training.

- `LoraConfig(...)`: Creates a configuration object for LoRA.
    - `r=16`: Rank of the low-rank matrices in LoRA. A higher rank means more parameters are tuned and potentially better performance but also more resources.
    - `lora_alpha=64`: Scaling factor for LoRA. The adapter's update is scaled by `lora_alpha / r` (here 64 / 16 = 4), which controls the magnitude of the LoRA updates.
    - `target_modules=[...]`: List of module names in the transformer model where LoRA adapters will be injected. Here, it targets the attention and MLP layers (`q_proj`, `k_proj`, `v_proj`, `o_proj` for attention, and `up_proj`, `down_proj`, `gate_proj` for MLP in a typical transformer decoder architecture).
    - `task_type="CAUSAL_LM"`: Specifies the task type as causal language modeling.
    - `lora_dropout=0.05`: Dropout probability for LoRA layers.
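
The arithmetic behind `r` and `lora_alpha` can be worked through in a few lines. This is a back-of-the-envelope sketch: `lora_extra_params` is a hypothetical helper, and the 1536 x 1536 projection size is an assumed dimension chosen for illustration rather than read from the model.

```python
def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

r, lora_alpha = 16, 64
scaling = lora_alpha / r  # factor applied to the low-rank update B @ A

d = 1536                            # assumed width of a square projection layer
full = d * d                        # weights in the frozen layer
extra = lora_extra_params(d, d, r)  # weights trained by the adapter

print(scaling)                  # 4.0
print(full, extra)              # 2359296 49152
print(f"{extra / full:.2%}")    # 2.08%
```

Under these assumptions the adapter trains about 2% as many weights as the layer it wraps, which is where the memory savings of LoRA come from.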

```python
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
```
This part loads the pre-trained model and tokenizer.

- `model = AutoModelForCausalLM.from_pretrained(...)`: Loads the pre-trained causal language model from Hugging Face Hub, specified by `model_name`.
    - `model_name`: The name of the pre-trained model (e.g., "Qwen/Qwen2.5-1.5B-Instruct").
    - `torch_dtype=torch.bfloat16`: Loads the model weights in bfloat16 data type, matching the `bf16=True` training setting.
    - `attn_implementation="flash_attention_2"`: Enables FlashAttention-2, a faster and more memory-efficient attention mechanism, if available and compatible with the model.
    - `device_map=None`: Disables automatic device placement. The model is loaded without a device map (on the CPU by default) and is moved explicitly in the next step.
    - `.to("cuda")`: Moves the entire model to the CUDA device (GPU) for training.

- `tokenizer = AutoTokenizer.from_pretrained(model_name)`: Loads the tokenizer associated with the pre-trained model.
- `tokenizer.pad_token = tokenizer.eos_token`: Sets the tokenizer's padding token to be the same as the end-of-sequence token (`eos_token`). This is a common practice for causal language models to ensure consistent padding behavior during training and generation.

```python
# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
```
This section initializes the `GRPOTrainer` and starts the training process.

- `trainer = GRPOTrainer(...)`: Creates an instance of the `GRPOTrainer` class.
    - `model=model`: Passes the loaded language model to the trainer.
    - `processing_class=tokenizer`: Passes the tokenizer to the trainer, used for tokenizing inputs and outputs.
    - `reward_funcs=[...]`:  A list of reward functions to be used during GRPO training. Here, it uses all the reward functions defined earlier: `xmlcount_reward_func`, `soft_format_reward_func`, `strict_format_reward_func`, `int_reward_func`, and `correctness_reward_func`. For each completion, the trainer adds up the rewards from all functions (in recent TRL versions they are weighted equally unless `reward_weights` is set in `GRPOConfig`), so the order of the list does not express priority.
    - `args=training_args`: Passes the `GRPOConfig` object containing all training hyperparameters to the trainer.
    - `train_dataset=dataset`: Passes the processed GSM8k dataset to be used for training.
    - `#peft_config=peft_config`:  PEFT configuration is commented out, so LoRA is not used during training in this configuration.

- `trainer.train()`: Starts the GRPO training process using the configured trainer, model, dataset, reward functions, and training arguments. This will train the language model to solve math problems from the GSM8k dataset, guided by the specified reward functions to encourage correctness, integer answers, and output in the desired XML format. For each prompt, GRPO samples a group of completions, scores them with these reward functions, and updates the policy so that completions scoring above their group's average become more likely.
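
The group-relative idea behind GRPO can be sketched in plain Python. This is a simplified illustration, not the library's implementation: the per-function reward values are invented, `group_advantages` is a hypothetical helper, and the use of the population standard deviation and the value of `eps` are assumptions.

```python
from statistics import mean, pstdev

def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize each reward against the mean and spread of its own group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Invented rewards for four completions of the same prompt.
rewards_per_func = {
    "xmlcount":      [0.5, 0.5, 0.25, 0.0],
    "soft_format":   [0.0, 0.0, 0.0, 0.0],
    "strict_format": [0.5, 0.5, 0.0, 0.0],
    "int":           [0.5, 0.5, 0.5, 0.0],
    "correctness":   [2.0, 0.0, 0.0, 0.0],
}

# Total reward per completion: the sum over all reward functions.
total = [sum(column) for column in zip(*rewards_per_func.values())]
print(total)  # [3.5, 1.5, 0.75, 0.0]

advantages = group_advantages(total)
print([round(a, 2) for a in advantages])
```

Completions that beat their group's average receive a positive advantage and are reinforced; those below it receive a negative one. If every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal.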

In [None]:
# train_grpo.py
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()