In [1]:
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from transformers import TrainingArguments
from transformers import AutoModelForCausalLM
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

  from .autonotebook import tqdm as notebook_tqdm


# Load GSM8K data and define reward functions

In [2]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
# System prompt instructing the model to wrap its reasoning and its final answer
# in <reasoning>/<answer> XML tags. The reward functions in this cell parse and
# score completions against this exact layout (tags on their own lines).
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for rendering a worked example in the same XML layout as
# SYSTEM_PROMPT; fill it with XML_COT_FORMAT.format(reasoning=..., answer=...).
# The leading backslash suppresses the initial newline of the literal.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Pull the final answer out of an XML-formatted completion.

    Takes everything after the LAST ``<answer>`` tag, cuts it at the first
    following ``</answer>`` tag, and strips surrounding whitespace. Missing
    tags are tolerated: without ``<answer>`` the whole text is used, and
    without ``</answer>`` the remainder runs to the end of the text.
    """
    after_open = text.rsplit("<answer>", 1)[-1]
    inner, _closing, _rest = after_open.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train", one_shot: bool = False,
                        dataset_path: str = './gsm8k', data_dir: str = 'main') -> Dataset:
    """Load a GSM8K split and convert it into conversational GRPO prompts.

    Each example gets:
      * ``prompt``: chat messages with ``SYSTEM_PROMPT`` as the system turn and
        the GSM8K question as the final user turn, optionally preceded by a
        one-shot demonstration (see ``one_shot``).
      * ``answer``: the reference answer after ``####`` in the original GSM8K
        ``answer`` field (``None`` if the marker is missing).

    Args:
        split: which split of the loaded dataset to return (e.g. "train").
        one_shot: if True, insert a worked user/assistant example rendered with
            ``XML_COT_FORMAT`` between the system prompt and the question.
            This replaces the old "uncomment middle messages" instruction,
            whose messages no longer existed.
        dataset_path: path (or hub id) passed to ``load_dataset``.
        data_dir: sub-directory/config of the dataset to load.

    Returns:
        The mapped ``Dataset`` for the requested split.
    """
    # Demonstration turns showing the expected XML reasoning layout.
    demo_turns = [
        {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
        {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            answer="7",
        )},
    ] if one_shot else []

    data = load_dataset(dataset_path, data_dir=data_dir)[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            *demo_turns,
            {'role': 'user', 'content': x['question']},
        ],
        'answer': extract_hash_answer(x['answer']),
    }) # type: ignore
    return data # type: ignore

# Training split as conversational prompts (a `prompt` message list plus an `answer` column).
dataset = get_gsm8k_questions(split="train")

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the extracted <answer> text exactly equals the reference answer.

    Args:
        prompts: batch of chat-message lists; only the first prompt's last
            message is used, for logging.
        completions: batch of completions, each a list whose first element is
            a message dict with a ``'content'`` key.
        answer: reference answers aligned with ``completions``.

    Returns:
        One float per completion: 2.0 on an exact string match, else 0.0.

    Side effect: prints the first sample of the batch (question, reference,
    raw response, extracted answer) for eyeballing progress.
    """
    texts = [c[0]['content'] for c in completions]
    question = prompts[0][-1]['content']
    parsed = [extract_xml_answer(t) for t in texts]
    print(
        '-' * 20,
        f"Question:\n{question}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{texts[0]}",
        f"\nExtracted:\n{parsed[0]}",
    )
    scores = []
    for guess, truth in zip(parsed, answer):
        scores.append(2.0 if guess == truth else 0.0)
    return scores

def int_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when the extracted <answer> text consists only of digits.

    Uses ``str.isdigit``, so signs, decimal points, commas and surrounding
    words all yield 0.0.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when a completion follows the exact line layout of SYSTEM_PROMPT.

    The whole completion must be ``<reasoning>``, newline, reasoning text,
    newline, ``</reasoning>``, then ``<answer>``, newline, answer text,
    newline, ``</answer>`` and a final newline.

    Args:
        completions: batch of completions, each a list whose first element is
            a message dict with a ``"content"`` key.

    Returns:
        One float per completion: 0.5 if it matches, else 0.0.
    """
    # re.DOTALL lets `.*?` span newlines. Without it, any multi-line reasoning
    # (the normal case) can never match and the reward is always 0.
    pattern = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$",
        re.DOTALL,
    )
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(r) else 0.0 for r in responses]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when a completion starts with a <reasoning> block followed by an <answer> block.

    This is a lenient check. Content inside the tags may span any number of
    lines, and only whitespace is allowed between ``</reasoning>`` and
    ``<answer>``. Text after ``</answer>`` is ignored, but the completion must
    begin with ``<reasoning>`` because ``match`` anchors at the start.

    Args:
        completions: batch of completions, each a list whose first element is
            a message dict with a ``"content"`` key.

    Returns:
        One float per completion: 0.5 if it matches, else 0.0.
    """
    # re.DOTALL is required: SYSTEM_PROMPT itself puts newlines inside the
    # tags, which a plain `.` cannot match.
    pattern = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if pattern.match(r) else 0.0 for r in responses]

def count_xml(text) -> float:
    """Score how closely ``text`` follows the XML tag layout of SYSTEM_PROMPT.

    Adds 0.125 for each of the four tag lines (``<reasoning>``,
    ``</reasoning>``, ``<answer>``, ``</answer>``) that appears exactly once,
    and subtracts 0.001 per character trailing the closing answer tag.

    Args:
        text: a single completion string.

    Returns:
        The accumulated score as a float.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalise characters after the last "\n</answer>\n", but only if that
        # separator exists. split(...)[-1] returned the WHOLE text when it was
        # absent (e.g. a completion ending exactly at "\n</answer>"), which
        # penalised well-formed answers in proportion to their length.
        _head, sep, tail = text.rpartition("\n</answer>\n")
        if sep:
            count -= len(tail) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # One trailing character (the expected final newline) is not penalised;
        # an empty tail yields a tiny +0.001 adjustment.
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the first message's content of every completion."""
    return list(map(count_xml, (c[0]["content"] for c in completions)))

# Load the base model and attach LoRA adapters

In [3]:
# Path to the base causal-LM checkpoint (also used to load the tokenizer later).
model_name_or_path = "./Qwen2-0.5B-Instruct"

# Low-rank adapter settings applied to the base model below.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=4,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
)

# Wrap the freshly loaded base model with the LoRA adapters and report
# how many parameters remain trainable.
model = get_peft_model(
    AutoModelForCausalLM.from_pretrained(model_name_or_path),
    peft_config,
)
model.print_trainable_parameters()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


trainable params: 270,336 || all params: 494,303,104 || trainable%: 0.0547


# Configure the GRPO trainer and train

In [16]:
def reward_num_unique_chars(completions, **kwargs):
    """Toy reward: the number of distinct characters in each completion.

    Accepts plain-string completions and conversational completions, i.e. a
    list of message dicts whose first entry has a ``"content"`` key, as
    produced for the chat-format dataset used in this notebook.

    Returns:
        list[int]: one distinct-character count per completion.
    """
    def _text(completion):
        # For conversational completions, set(completion) would try to hash
        # the message dicts and raise TypeError, so unwrap to the text first.
        if isinstance(completion, list) and completion and isinstance(completion[0], dict):
            return completion[0]["content"]
        return completion

    return [len(set(_text(c))) for c in completions]

# NOTE(review): GRPO samples several generations per prompt (num_generations,
# left at its default here); the effective batch size is expected to be
# divisible by it — TODO confirm against the installed TRL version.
grpo_config = GRPOConfig(
    # Outputs / logging
    output_dir="./output_grpo",
    report_to="tensorboard",
    save_total_limit=2,
    # Optimisation
    num_train_epochs=2,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Optional switches, currently disabled:
    # bf16=True,
    # dataloader_num_workers=8,
)


# training_args = GRPOConfig(
#     output_dir=output_dir,
#     run_name=run_name,
#     learning_rate=5e-6,
#     adam_beta1 = 0.9,
#     adam_beta2 = 0.99,
#     weight_decay = 0.1,
#     warmup_ratio = 0.1,
#     lr_scheduler_type='cosine',
#     logging_steps=1,
#     bf16=True,
#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=4,
#     num_generations=8,
#     max_prompt_length=256,
#     max_completion_length=200,
#     num_train_epochs=1,
#     save_steps=100,
#     max_grad_norm=0.1,
#     log_on_each_node=False,
#     use_vllm=True,
#     vllm_gpu_memory_utilization=.2,
#     vllm_device="cuda:0",
#     report_to="none" # disable W&B logging
# )

from transformers import AutoTokenizer
# Tokenizer for the same checkpoint; its EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

# The model is already LoRA-wrapped via get_peft_model, so no peft_config is
# passed here.
trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
)

# trainer = GRPOTrainer(
#     model=model,
#     reward_funcs=reward_num_unique_chars,
#     args=grpo_config,
#     train_dataset=dataset,
# )

# Start GRPO training with the configuration, dataset and reward functions defined above.
trainer.train()

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


-------------------- Question:
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers? 
Answer:
100 
Response:
To find the minimum grade Ahmed needs to beat Emily, we first need to determine the total points available for Emily as a whole number. Ahmed has a 91 and Emily has a 92, so together their total is \(91 + 92 = 183\) points.

Emily's first assignment is worth the same as all the other assignments, which is 90 points. Since all grades are whole numbers, we divide Emily's total points by 9:

\(183 \div 9 = 20\) (rounded down to the nearest whole number because she cannot get a fraction of an assignment).

For the final assignment, we use the remaining 183 points to find

Step,Training Loss


-------------------- Question:
Marie has 98 unread messages on her phone. She decides to clear them by reading 20 messages a day. However, she also gets 6 new messages a day. How many days will it take her to read all her unread messages? 
Answer:
7 
Response:
To solve this problem, we can use the concept of fractional arithmetic. In this case, we can represent the initial state of her phone as m messages unread and the final state as n messages unread.

Initially, Marie has 98 unread messages, so the system or equations can be represented as follows:

m = 98

Each day, she reads 20 messages, but also, she gets 6 new messages. So, the net messages she reads each day is:

20 - 6 = 14 messages

To see how many days it would take her to read all her unread messages, we can use the formula:

n = m / (14 + 6)

Let's calculate this:

1. Compute the initial number of unread messages (m):
\[ m = m_{initial} = 98 \]
\[ m_{initial} / 14 + 6 = 98 / 14 + 6 = 7 + 6 = 13 \]

2. Calculate the final n

KeyboardInterrupt: 