
<h1 align="center">Qwen 0.5b on GRPO</h1>

---

<h1 align="center">Training a small math reasoner with RL</h1>

This notebook is an alternate version of the [GRPO demo](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) by [Will Brown](https://x.com/willccbb), which trains llama-1b on the gsm8k math dataset.

We made only a few changes to make the code more workable on Colab:
* Replacement of llama-1b with Qwen-0.5b.
* Generation with vllm, which yields a significant speed-up. Qwen's small size makes it possible to run vllm on the same GPU as the one used for GRPO.
* Dropping flash-attn (a recurring bug when modeling Qwen; the cause is not clear).

## Setting up the models

First we install vllm. Note that you'll have to restart the session afterwards.

In [6]:
!pip install -U vllm

Collecting vllm
  Downloading vllm-0.7.1-cp38-abi3-manylinux1_x86_64.whl.metadata (12 kB)
Collecting blake3 (from vllm)
  Downloading blake3-1.0.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting transformers>=4.48.2 (from vllm)
  Downloading transformers-4.48.2-py3-none-any.whl.metadata (44 kB)
Collecting openai>=1.52.0 (from vllm)
  Downloading openai-1.61.0-py3-none-any.whl.metadata (27 kB)
Collecting lm-format-enforcer<0.11,>=0.10.9 (from vllm)
  Downloading lm_format_enforcer-0.10.9-py3-none-any.whl.metadata (17 kB)
Collecting outlines==0.1.11 (from vllm)
  Downloading outlines-0.1.11-py3-none-any.whl.metadata (17 kB)
Collecting xgrammar>=0.1.6 (from vllm)
  Downloading xgrammar-0.1.11-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (2.0 kB)
Collecting filelock>=3.16.1 (from vllm)
  Downloading filelock-3.17.0-py3-none-any.whl.metadata (2.9 kB)
Collecting mistral_common>=1.5.0 (from mistral_common[opencv]>=1.5.0->vllm)
  

Then we install trl and datasets. The order matters: for some reason, installing vllm after trl triggers a bug in trl.

In [4]:
!pip install -U trl datasets

Collecting trl
  Downloading trl-0.14.0-py3-none-any.whl.metadata (12 kB)
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Downloading trl-0.14.0-py3-none-any.whl (313 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
Installing collected packages: datasets, trl
  Attempting uninstall: datasets
    Found existing installation: datasets 2.20.0
    Uninstalling datasets-2.20.0:
      Successfully uninstalled datasets-2.20.0
  Attempting uninstall: trl
    Found existing installation: trl 0.8.6
    Uninstalling trl-0.8.6:
      Successfully uninstalled trl-0.8.6
Successfully installed datasets-3.2.0 trl-0.14.0
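
Since the install order matters and the session has to be restarted, it can help to confirm which versions ended up installed before going further. A minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and not part of any of these packages:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Packages this notebook relies on.
print(installed_versions(["vllm", "trl", "datasets"]))
```

A `None` entry means the package is not visible to the current kernel, which usually points to a missing restart or a failed install.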


## Defining the RL rewards

Now we have everything ready to set up our RL training set and reward policy.

First we set the general prompt structure (with the reasoning tags).

In [1]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""



Now we import the gsm8k dataset and restructure it to fit into a conversational prompt format:

In [2]:
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Build chat-style prompts (system + user) and keep the gold answer that follows "####"
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
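
To see what the two extraction helpers do, they can be run on a toy record. The sketch below redefines them so that it runs on its own. The sample solution string is made up for illustration; it only mimics the gsm8k convention of ending a solution with `#### <answer>`.

```python
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_hash_answer(text: str):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# Made-up solution in the gsm8k style: reasoning first, then "#### answer".
gold_solution = "She has 3 + 4 = 7 apples.\n#### 7"
gold = extract_hash_answer(gold_solution)

# A completion that follows the requested format.
completion = XML_COT_FORMAT.format(reasoning="3 + 4 = 7", answer="7")
predicted = extract_xml_answer(completion)

print(gold, predicted, gold == predicted)
```

Note that `extract_xml_answer` returns the whole stripped text when the tags are absent, so an unformatted completion is compared to the gold answer as a full string and will almost never match.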

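We move on now to the reward functions. The most important one is the "correctness" function, which acts as a verifier (it compares the answer extracted from each completion with the reference answer). The four others reward formatting: an integer answer, a strict tag layout, a soft tag layout, and a per-tag count.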

In [3]:
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that match the exact tag layout, newlines included."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines, so multi-line reasoning can match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with both tag pairs in order, with loose whitespace."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` span newlines, so multi-line reasoning can match
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

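def count_xml(text) -> float:
    """Give partial credit (0.125 each) for every tag that appears exactly once."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # small penalty per character after the closing answer tag
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # same penalty; the -1 allows for a single trailing newline
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count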

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
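
The reward functions above all share the same calling convention: each receives a batch of chat-style completions and returns one score per completion. The sketch below runs simplified copies of two of them on two made-up completions and adds the scores per completion. It assumes the trainer combines the functions by summing their outputs, which is my understanding of how `reward_funcs` are used; the names `correctness` and `soft_format` are shortened for this example.

```python
import re


def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def correctness(completions, answer, **kwargs):
    responses = [c[0]["content"] for c in completions]
    return [2.0 if extract_xml_answer(r) == a else 0.0
            for r, a in zip(responses, answer)]


def soft_format(completions, **kwargs):
    # re.DOTALL is used here so that `.` can span the newlines inside the tags.
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [c[0]["content"] for c in completions]
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0
            for r in responses]


# Two made-up completions for the same question, in the chat format used above.
completions = [
    [{"role": "assistant",
      "content": "<reasoning>\n3 + 4 = 7\n</reasoning>\n<answer>\n7\n</answer>\n"}],
    [{"role": "assistant", "content": "The answer is 7."}],
]
answer = ["7", "7"]

# One list of scores per reward function, then one total per completion.
per_func = [f(completions=completions, answer=answer)
            for f in (correctness, soft_format)]
totals = [sum(scores) for scores in zip(*per_func)]
print(totals)
```

The second completion states the right number but scores zero, because neither the extracted answer nor the format matches. This is why the formatting rewards matter early in training.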

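We now set the training arguments. The cell below loads `allenai/OLMoE-1B-7B-0924`, so the model and output names refer to OLMoE rather than Qwen: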

In [4]:
model_name = "allenai/OLMoE-1B-7B-0924"

output_dir="outputs/OLMoE-1B-7B-0924-GRPO"
run_name="OLMoE-1B-7B-0924-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.3,
    vllm_device="cuda:0",
    report_to="none"  # disable Wandb logging
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s]
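
With `num_generations=16`, each prompt is sampled 16 times, and GRPO scores every completion relative to the other completions for the same prompt. The sketch below illustrates that group-relative normalization with plain Python. The function name `group_advantages`, the choice of population standard deviation, and the `eps` value are assumptions of this example, not a copy of the trainer's code.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of completions sampled from the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps avoids a division by zero when all rewards in the group are equal.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Made-up total rewards for four completions of one prompt.
group = [2.5, 0.0, 0.5, 0.0]
print(group_advantages(group))

# A group where every completion got the same reward carries no signal.
print(group_advantages([0.5, 0.5, 0.5, 0.5]))
```

When every completion in a group receives the same reward, all advantages are zero, so that prompt does not push the policy in any direction. Sampling several completions per prompt makes it more likely that the group contains both better and worse answers.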


And launch the actual training:

In [5]:
# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()



INFO 02-02 05:36:40 config.py:526] This model supports multiple tasks: {'embed', 'reward', 'classify', 'score', 'generate'}. Defaulting to 'generate'.
INFO 02-02 05:36:40 llm_engine.py:232] Initializing a V0 LLM engine (v0.7.1) with config: model='allenai/OLMoE-1B-7B-0924', speculative_config=None, tokenizer='allenai/OLMoE-1B-7B-0924', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=allenai/OLMoE-1B-7B-0924, num_sche

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:01<00:02,  1.00s/it]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:02<00:01,  1.04s/it]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00,  1.07it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:02<00:00,  1.04it/s]



INFO 02-02 05:36:44 model_runner.py:1116] Loading model weights took 12.8890 GB
INFO 02-02 05:36:45 worker.py:266] Memory profiling takes 0.94 seconds
INFO 02-02 05:36:45 worker.py:266] the current vLLM instance can use total_gpu_memory (47.38GiB) x gpu_memory_utilization (0.30) = 14.21GiB
INFO 02-02 05:36:45 worker.py:266] model weights take 12.89GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 0.48GiB; the rest of the memory reserved for KV Cache is 0.76GiB.
INFO 02-02 05:36:46 executor_base.py:108] # CUDA blocks: 390, # CPU blocks: 2048
INFO 02-02 05:36:46 executor_base.py:113] Maximum concurrency for 4096 tokens per request: 1.52x
INFO 02-02 05:36:47 model_runner.py:1435] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_util

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:20<00:00,  1.68it/s]

INFO 02-02 05:37:08 model_runner.py:1563] Graph capturing finished in 21 secs, took 0.23 GiB
INFO 02-02 05:37:08 llm_engine.py:429] init engine (profile, create kv cache, warmup model) took 23.90 seconds





ValueError: Cannot use chat template functions because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating
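
The error above says that the tokenizer loaded for `allenai/OLMoE-1B-7B-0924` has no `chat_template`, while the dataset prompts are lists of chat messages. One possible workaround, sketched here under assumptions, is to assign a Jinja template string to `tokenizer.chat_template` before building the trainer. The template text and the `<|role|>` markers below are illustrative choices, not a format this model was trained on, and `render_chat` is a hypothetical pure-Python mirror of the template so that the resulting layout can be inspected without loading a tokenizer.

```python
# Illustrative template; the <|role|> markers are an assumption of this sketch.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)


def render_chat(messages, add_generation_prompt=True):
    """Pure-Python mirror of CHAT_TEMPLATE, for inspecting the prompt layout."""
    text = "".join(f"<|{m['role']}|>\n{m['content']}\n" for m in messages)
    if add_generation_prompt:
        text += "<|assistant|>\n"
    return text


messages = [
    {"role": "system", "content": "Respond in the following format: ..."},
    {"role": "user", "content": "What is 3 + 4?"},
]
print(render_chat(messages))

# In the notebook, the template would be attached before creating the trainer:
# tokenizer.chat_template = CHAT_TEMPLATE
```

Because a base model has not seen these markers during pretraining, a checkpoint that already ships a chat template may be a better starting point than a hand-written one.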