### Installation

In [None]:
%%capture
import os
# Colab is detected by looking for "COLAB_" anywhere in the concatenated
# environment-variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    # Outside Colab: plain install, letting pip resolve dependencies.
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # In Colab: install without dependency resolution and with vLLM pinned.
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
# Colab is detected by looking for "COLAB_" anywhere in the concatenated
# environment-variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    # Every already-imported module whose name contains "PIL" or "google" is
    # evicted from sys.modules before the installs below run.
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # NOTE(review): this downloads common.txt from vLLM's `main` branch while
    # vLLM itself is pinned to 0.8.5.post1 above - confirm the two stay compatible.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        # Removes everything from "transformers", "numpy" or "xformers" up to and
        # including the end of that line. The pattern is not anchored to the start
        # of a line, so it also fires when one of those names appears mid-line.
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

### Unsloth

In [None]:
# Load the base model through Unsloth (4-bit weights, vLLM fast inference
# enabled) and then wrap it with LoRA adapters.
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 16 # Larger rank = smarter, but slower

# max_lora_rank is set to the same rank that is used for the adapters below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-14B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank
)

# Adapters target the q/k/v/o projections and the gate/up/down projections;
# alpha is fixed at twice the rank.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank * 2, # *2 speeds up training
    use_gradient_checkpointing = True, # Reduces memory usage
    random_state = 3407,
)

### GRPO chat template


In [None]:
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
system_prompt

In [None]:
# Jinja chat template:
#   * a leading system message (or, when absent, the default system prompt) is
#     emitted followed by the EOS token;
#   * user messages are emitted verbatim, assistant messages are followed by EOS;
#   * with add_generation_prompt set, the reasoning-start tag is appended so the
#     model continues from inside the reasoning section.
chat_template = \
    "{% if messages[0]['role'] == 'system' %}"\
        "{{ messages[0]['content'] + eos_token }}"\
        "{% set loop_messages = messages[1:] %}"\
    "{% else %}"\
        "{{ '{system_prompt}' + eos_token }}"\
        "{% set loop_messages = messages %}"\
    "{% endif %}"\
    "{% for message in loop_messages %}"\
        "{% if message['role'] == 'user' %}"\
            "{{ message['content'] }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{{ message['content'] + eos_token }}"\
        "{% endif %}"\
    "{% endfor %}"\
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
    "{% endif %}"

# Replace with our specific template:
# The placeholders are swapped for the real values, which end up inside
# single-quoted Jinja string literals - a single quote in system_prompt or
# reasoning_start would therefore break the template.
chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template

In [None]:
tokenizer.apply_chat_template([
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user", "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True)

### Pre fine-tuning for formatting


In [None]:
from datasets import load_dataset
import pandas as pd
import numpy as np

dataset = load_dataset("gsm8k", "main", split="train")
dataset = dataset.to_pandas()[
    ["answer", "question"]
]

# Extract numeric answer from the answer field (GSM8K answers are in format "#### 123")
def extract_numeric_answer(answer_text):
    """Return the stripped text that follows the last "####" marker.

    When the marker is absent the input is handed back untouched (it is not
    stripped in that case).
    """
    _, marker, tail = answer_text.rpartition("####")
    return tail.strip() if marker else answer_text

dataset["expected_answer"] = dataset["answer"].apply(extract_numeric_answer)
dataset["problem"] = dataset["question"]

# Try converting to number - if not, replace with NaN
is_number = pd.to_numeric(pd.Series(dataset["expected_answer"]), errors="coerce").notnull()
# Select only numbers
dataset = dataset.iloc[np.where(is_number)[0]]
dataset

We have to format the dataset to follow our GRPO style formatting:

In [None]:
def format_dataset(x):
    """Turn one dataset row into a system / user / assistant conversation.

    The row must provide "answer" (full GSM8K answer text), "expected_answer"
    (final answer string) and "problem" (question text). The assistant turn is
    the reasoning wrapped in the reasoning tags, immediately followed by the
    expected answer wrapped in the solution tags.
    """
    # Everything before the first "####" is the worked reasoning; when the
    # marker is missing, the split leaves the whole answer in element 0.
    thoughts = x["answer"].split("####")[0].strip()

    assistant_reply = "".join([
        reasoning_start, thoughts, reasoning_end,
        solution_start, x["expected_answer"], solution_end,
    ])

    conversation = []
    for role, content in (
        ("system", system_prompt),
        ("user", x["problem"]),
        ("assistant", assistant_reply),
    ):
        conversation.append({"role" : role, "content" : content})
    return conversation

dataset["Messages"] = dataset.apply(format_dataset, axis = 1)

Check to see if it worked:

In [None]:
# Positional lookup: the frame was filtered earlier without resetting its
# index, so the label 0 is not guaranteed to exist; .iloc[0] always is the
# first remaining row.
tokenizer.apply_chat_template(dataset["Messages"].iloc[0], tokenize = False)

Drop examples whose tokenized length exceeds `max_seq_length/2` from the pre fine-tuning dataset, since we don't want overly long reasoning traces.

In [None]:
dataset["N"] = dataset["Messages"].apply(lambda x: len(tokenizer.apply_chat_template(x)))

dataset = dataset.loc[dataset["N"] <= max_seq_length/2].copy()
dataset.shape

Keep only 78 examples

In [None]:
dataset = dataset.sample(n=78, random_state=42)
dataset.shape

We then tokenize the messages and convert it to a Hugging Face compatible dataset format:

In [None]:
from datasets import Dataset

dataset["text"] = tokenizer.apply_chat_template(dataset["Messages"].values.tolist(), tokenize = False)
dataset = Dataset.from_pandas(dataset)
dataset

Let's now pre fine-tune the model so it follows our custom GRPO formatting!

In [None]:
from trl import SFTTrainer, SFTConfig
# Short supervised warm-up: trains on the already-templated strings held in
# the "text" column so the model picks up the custom tag format.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1, # Use GA to mimic batch size!
        warmup_steps = 5,
        num_train_epochs = 2, # Number of full passes over the SFT dataset
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)

In [None]:
trainer.train()

Let's check if the model has learnt to follow the custom format:

In [None]:
text = tokenizer.apply_chat_template(
    dataset[0]["Messages"][:2],
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    temperature = 0.01,
    max_new_tokens = 1024,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

Yes it did follow the formatting! Great! Let's remove some items before the GRPO step

In [None]:
import gc

# Drop the SFT dataset, run the garbage collector first so unreachable objects
# are actually freed, and only then hand the cached CUDA blocks back.
del dataset
gc.collect()
torch.cuda.empty_cache()

### Data Prep


In [None]:
from datasets import load_dataset
dataset = load_dataset("gsm8k", "main", split="train")

dataset = dataset.train_test_split(train_size=0.15, seed=42)["train"]
dataset

First split off a small evaluation set; then we'll look at the first row:

In [None]:
# Create evaluation dataset from a portion of your 15% dataset
eval_split = dataset.train_test_split(test_size=0.1, seed=42)  # 10% for eval, 90% for train
train_dataset = eval_split["train"]
eval_dataset = eval_split["test"]

print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

In [None]:
dataset[0]["question"]

In [None]:
dataset[0]["answer"]

In GSM8K, we notice that all answers, like the one above, contain a `####` marker before the final result, so we extract what follows it.

In [None]:
def extract_hash_answer(text):
    """Return the stripped segment that follows the first "####" in `text`.

    Raises IndexError when `text` contains no "####" marker.
    """
    segments = text.split("####")
    return segments[1].strip()
extract_hash_answer(dataset[0]["answer"])

Let's map the dataset! and see the first row:

In [None]:
# Apply the mapping to both datasets
def _to_grpo_example(example):
    """Build the GRPO prompt (system + user turns) and the reference answer.

    `example` must carry the raw GSM8K "question" and "answer" fields; the
    returned "answer" replaces the raw one with just the final result.
    """
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
        ],
        "answer": extract_hash_answer(example["answer"]),
    }

# One shared mapper for both splits, so train and eval cannot drift apart.
train_dataset = train_dataset.map(_to_grpo_example)
eval_dataset = eval_dataset.map(_to_grpo_example)

# Check format
print("Train dataset example:")
print(train_dataset[0])
print("\nEval dataset example:")
print(eval_dataset[0])

We create a regex format to match the reasoning sections and answers:

In [None]:
import re

# Closing solution tag, optional whitespace, then an optional EOS token.
# The tag comes from `solution_end` (escaped) rather than a hard-coded copy,
# so changing the tag constants keeps the regex in sync.
solution_end_regex = re.escape(solution_end) + r"[\s]{0,}" + \
    "(?:" + re.escape(tokenizer.eos_token) + ")?"

# Matches: end-of-reasoning tag, anything, the solution tags with a captured
# (non-greedy) answer between them, then only whitespace up to an end of line.
# All tags are escaped so regex metacharacters in them are taken literally.
match_format = re.compile(
    rf"{re.escape(reasoning_end)}.*?"\
    rf"{re.escape(solution_start)}(.+?){solution_end_regex}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)
match_format

We verify it works:

In [None]:
match_format.findall(
    "Let me think!<end_working_out>"\
    f"<SOLUTION>\n2\n</SOLUTION>",
)

In [None]:
match_format.findall(
    "<start_working_out>Let me think!<end_working_out>"\
    f"<SOLUTION>  2  </SOLUTION>\n\n",
)

We now want to create a reward function to match the format exactly - we reward it with 3 points if it succeeds:

In [None]:
def match_format_exactly(completions, **kwargs):
    """Reward completions that follow the expected output format.

    For every completion, the content of its first message is searched with
    `match_format`; a hit scores 3.0 and a miss scores 0. Returns one score
    per completion, in order.
    """
    return [
        3.0 if match_format.search(completion[0]["content"]) is not None else 0
        for completion in completions
    ]

If it fails, we want to reward the model if it at least follows the format partially, by counting each symbol:

In [None]:
def match_format_approximately(completions, **kwargs):
    """Give partial credit for each format tag that appears exactly once.

    The end-of-reasoning tag and both solution tags are each worth +0.5 when
    they occur exactly once in the first message's content and -1.0 otherwise
    (missing or repeated). The start-of-reasoning tag is not scored because
    the generation prompt always prepends it.
    """
    scored_tags = (reasoning_end, solution_start, solution_end)
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        scores.append(sum(
            0.5 if response.count(tag) == 1 else -1.0
            for tag in scored_tags
        ))
    return scores

Finally, we want to extract the generated answer, and reward or penalize it! We also reward it based on how close the answer is to the true one via ratios:

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    """Score each completion's extracted answer against the reference answer.

    Args:
        prompts: Unused; accepted because reward functions are called with it.
        completions: One list of messages per completion; the content of the
            first message is searched with `match_format`.
        answer: Reference answer strings, aligned with `completions`.

    Returns:
        One score per completion:
        -2.0 no answer could be extracted,
         5.0 exact string match,
         3.5 match after stripping surrounding whitespace,
         2.0 / 1.5 numeric ratio within [0.9, 1.1] / [0.8, 1.2],
        -2.5 numeric but outside those ranges,
        -4.5 not comparable as numbers (unparseable, or reference is zero).
    """
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        found.group(1) if (found := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.0)
            continue
        # Correct answer gets 5 points!
        if guess == true_answer:
            score = 5.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score = 3.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
            except (ValueError, ZeroDivisionError):
                # Only the two expected failures are handled here, so that
                # interrupts and genuine bugs are no longer swallowed.
                score = -4.5 # Penalize
            else:
                if   0.9 <= ratio <= 1.1: score = 2.0
                elif 0.8 <= ratio <= 1.2: score = 1.5
                else: score = -2.5 # Penalize wrong answers
        scores.append(score)
    return scores

Also sometimes it might not be 1 number as the answer, but like a sentence for example "The solution is $20" -> we extract 20.

We also remove possible commas for example as in 123,456

In [None]:
# Captures the first run of digits / dots / commas (with an optional leading
# minus sign) that appears anywhere after the opening solution tag.
# The character class also accepts a lone "." or ",", so the captured text is
# not guaranteed to be parseable as a number.
match_numbers = re.compile(
    solution_start + r".*?[\s]{0,}([-]?[\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
# Quick checks: decimal, thousands separator, negative, and no padding.
print(match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>"))
print(match_numbers.findall("<SOLUTION>  123,456  </SOLUTION>"))
print(match_numbers.findall("<SOLUTION>  -0.234  </SOLUTION>"))
print(match_numbers.findall("<SOLUTION>17</SOLUTION>"))

We now prepare our main function which will print out the generated responses and the true answer, along with another reward function which converts text to float via `float` and sees if it's the same.

In [None]:
# Counters controlling how often a sample generation is printed.
# (No `global` statement is needed at module level.)
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def check_numbers(prompts, completions, answer, **kwargs):
    """Score completions by numeric equality with the reference answer.

    The first number found after the opening solution tag (via
    `match_numbers`) is compared with the reference as floats. Thousands
    separators are removed from BOTH sides before conversion, so a reference
    such as "1,200" is comparable.

    Args:
        prompts: Prompt conversations; the last message of the first prompt is
            printed as the question.
        completions: One list of messages per completion; the first message's
            content is scored.
        answer: Reference answer strings, aligned with `completions`.

    Returns:
        One score per completion: -2.5 when no number is found, 3.5 when the
        numbers are equal, -1.5 when they differ, and 0 when either side
        cannot be converted to a float.

    Side effects:
        Every PRINT_EVERY_STEPS-th call prints the question, reference answer,
        first response and its extracted number; PRINTED_TIMES is incremented
        on every call.
    """
    global PRINTED_TIMES

    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        found.group(1) if (found := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    # Print only every few steps
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            '*'*20 + f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}"
        )
    PRINTED_TIMES += 1

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-2.5)
            continue
        # Convert to numbers, removing commas like in 123,456 on both sides
        try:
            true_value  = float(true_answer.strip().replace(",", ""))
            guess_value = float(guess.strip().replace(",", ""))
        except ValueError:
            # Not parseable as a number: neutral score rather than a penalty.
            scores.append(0)
            continue
        scores.append(3.5 if guess_value == true_value else -1.5)
    return scores

Get the 90th-percentile prompt length so we don't accidentally truncate prompts!

We'll remove the top 10% long prompts.

In [None]:
# Filter train_dataset
# Tokenize every prompt (with the generation prompt appended) and record the
# token count in column "L".
tokenized_train = train_dataset.map(
    lambda x: {"tokens": tokenizer.apply_chat_template(x["prompt"], add_generation_prompt=True, tokenize=True)}
)
tokenized_train = tokenized_train.map(lambda x: {"L": len(x["tokens"])})

# Filter eval_dataset
# Same measurement for the evaluation split.
tokenized_eval = eval_dataset.map(
    lambda x: {"tokens": tokenizer.apply_chat_template(x["prompt"], add_generation_prompt=True, tokenize=True)}
)
tokenized_eval = tokenized_eval.map(lambda x: {"L": len(x["tokens"])})

# Use the same maximum_length threshold for both
import numpy as np
# You can recalculate or reuse the previous maximum_length
# The threshold is the 0.9 quantile of the TRAIN prompt lengths only; the eval
# split is filtered against that same value.
maximum_length = int(np.quantile(tokenized_train["L"], 0.9))  # Or use your previous value: 136
print("Max Length = ", maximum_length)

# Filter both datasets
# Rows are selected by position, keeping prompts at or below the threshold.
train_dataset = train_dataset.select(np.where(np.array(tokenized_train["L"]) <= maximum_length)[0])
eval_dataset = eval_dataset.select(np.where(np.array(tokenized_eval["L"]) <= maximum_length)[0])

# Clean up
del tokenized_train, tokenized_eval

print(f"Filtered train dataset size: {len(train_dataset)}")
print(f"Filtered eval dataset size: {len(eval_dataset)}")

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [None]:
# Token budget: the prompt gets the filtered maximum (+1), and the completion
# gets whatever is left of max_seq_length.
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

# Generation stops at the EOS token and keeps it in the output text.
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",

    logging_steps = 2,
    logging_first_step = True,

    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 4,
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    num_train_epochs = 1,
    #max_steps = 172,

    # Save and evaluate on the same 57-step cadence.
    # NOTE(review): the run appears to have ~171 steps (checkpoint-171 is
    # downloaded at the end of the notebook), so a 57-step interval gives
    # checkpoints at 57 / 114 / 171 - thirds, not the four quarter-point
    # checkpoints an earlier comment described. Confirm the intended cadence.
    save_steps = 57,
    eval_steps = 57,

    report_to = "none", # Can use wandb
    output_dir = "outputs_14B", # Change based on model

    per_device_eval_batch_size = 4,
    eval_accumulation_steps = 1,
    eval_strategy = "steps",
)

In [None]:
# GRPO run with four reward functions: two score the output format (exact
# and per-tag) and two score the answer (string/ratio based and numeric).
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = eval_dataset
)

trainer.train()

In [None]:
import shutil
import os
from google.colab import files
from zipfile import ZipFile

def download_lora_checkpoint(checkpoint_folder):
    """Zip the LoRA-relevant files of a checkpoint and start a browser download.

    Only the adapter weights/config, tokenizer files, chat template and
    trainer state listed below are archived; files that are absent from the
    checkpoint are skipped. The archive is written to the current working
    directory as "<folder name>_grpo.zip", with entries stored under
    "<folder name>/".

    Args:
        checkpoint_folder: Path to a checkpoint directory, with or without a
            trailing slash.

    Raises:
        FileNotFoundError: If `checkpoint_folder` is not an existing directory
            (previously this silently produced and downloaded an empty zip).
    """
    if not os.path.isdir(checkpoint_folder):
        raise FileNotFoundError(f"Checkpoint folder not found: {checkpoint_folder}")

    # normpath drops a trailing slash, which would otherwise make basename()
    # return "" and yield an archive called "_grpo.zip" with misplaced entries.
    folder_name = os.path.basename(os.path.normpath(checkpoint_folder))
    zip_filename = folder_name + "_grpo.zip"

    files_to_keep = [
        "adapter_model.safetensors",
        "adapter_config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.json",
        "merges.txt",
        "special_tokens_map.json",
        "added_tokens.json",
        "chat_template.jinja",
        "trainer_state.json",
        "training_args.bin",
    ]

    with ZipFile(zip_filename, 'w') as zipf:
        for file_name in files_to_keep:
            file_path = os.path.join(checkpoint_folder, file_name)
            if os.path.exists(file_path):
                zipf.write(file_path, os.path.join(folder_name, file_name))

    files.download(zip_filename)

download_lora_checkpoint("outputs_14B/checkpoint-171") # Select the best checkpoint