In [None]:
import os
# Set before `unsloth` is imported in the next cell.
# NOTE(review): presumably the library reads this flag at import time — confirm.
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"  # To get extra 30% context length

# Install dependencies
# vllm, triton, transformers and trl are pinned to exact versions; unsloth_zoo and
# the other packages on the `--upgrade` line are not pinned.
# The pinned installs run after the `--upgrade` step, presumably so the pins win
# over whatever that step resolved — keep this order unless verified otherwise.
# trl is installed with --no-deps, so its own dependency list is not resolved here.
!pip install unsloth_zoo
!uv pip install --upgrade unsloth vllm==0.9.2 numpy torchvision bitsandbytes xformers
!uv pip install triton==3.2.0
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2

In [None]:
from unsloth import FastLanguageModel
import torch

# Context length
# The GRPO config further down uses max_prompt_length = 256 and
# max_completion_length = 200, which together (456) stay below this value.
max_seq_length = 1024

# Load model and tokenizer
# Returns the model and its tokenizer as a pair; both names are reused by the
# LoRA, training and inference cells below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = 8, # same value as `r` passed to get_peft_model in the next cell
    gpu_memory_utilization = 0.9, # presumably the share of GPU memory vLLM may reserve — confirm
)

In [3]:
# LoRA for parameter-efficient fine-tuning
# Rebinds `model` to the LoRA-wrapped model; re-running this cell without
# re-running the load cell would wrap an already wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r = 8, # matches max_lora_rank = 8 given to from_pretrained above
    # Modules to fine-tune
    # (attention projections q/k/v/o plus the MLP gate/up/down projections)
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 8, # equal to r
    use_gradient_checkpointing = "unsloth", # NOTE(review): presumably selects Unsloth's checkpointing variant — confirm
    random_state = 1234, # fixed seed value
)

Unsloth 2025.9.6 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [None]:
import re
from datasets import load_dataset, Dataset

# System prompt
# Sent as the "system" turn of every training prompt (see get_gsm8k_dataset) and
# of the post-training inference prompt. The literal starts with a newline
# because the opening quotes are not followed by a backslash.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for wrapping the reasoning and answer
# Takes the `reasoning` and `answer` format fields; the trailing backslash after
# the opening quotes suppresses the leading newline.
# NOTE(review): not referenced by any other cell shown in this notebook.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text):
    """Pull the answer out of a model completion.

    Returns the whitespace-stripped text that follows the first ``<answer>``
    tag, cut at the next ``</answer>`` tag (or at the end of the text when no
    closing tag follows). Returns an empty string when either tag is missing
    from ``text`` altogether.
    """
    open_tag, close_tag = "<answer>", "</answer>"
    if open_tag in text and close_tag in text:
        after_open = text.partition(open_tag)[2]
        return after_open.partition(close_tag)[0].strip()
    return ""

def extract_hash_answer(text):
    """Return the final answer from a GSM8K label of the form '... #### final_answer'.

    The label is split on ``####`` and the last piece is returned with
    surrounding whitespace removed. Returns ``None`` when the marker is absent.
    """
    marker = "####"
    if marker not in text:
        return None
    *_, final_piece = text.split(marker)
    return final_piece.strip()

def get_gsm8k_dataset(split = "train"):
    """Load one split of GSM8K ("main" config) and shape it for chat-style training.

    Every row gains a ``prompt`` column holding a two-message conversation
    (the system prompt followed by the question as the user turn), and its
    ``answer`` column is replaced by the value extracted after ``####``.
    """
    def to_chat_example(row):
        system_turn = {"role": "system", "content": SYSTEM_PROMPT}
        user_turn = {"role": "user", "content": row["question"]}
        return {
            "prompt": [system_turn, user_turn],
            "answer": extract_hash_answer(row["answer"]),
        }

    gsm8k = load_dataset("openai/gsm8k", "main")
    return gsm8k[split].map(to_chat_example)

# Build the training set (get_gsm8k_dataset defaults to the "train" split);
# `dataset` is handed to GRPOTrainer as train_dataset further down.
dataset = get_gsm8k_dataset()

In [6]:
def correctness_reward_func(prompts, completions, answer, **kwargs):
    """Reward completions whose extracted answer equals the ground truth exactly.

    ``completions`` is a list of message lists; the first message's content is
    read from each. ``answer`` holds the matching ground-truth strings.
    Returns 2.0 per exact string match and 0.0 otherwise. As a side effect the
    first sample of the batch (question, answer, response, extraction) is
    printed for inspection.
    """
    responses = []
    extracted_responses = []
    for completion in completions:
        content = completion[0]['content']
        responses.append(content)
        extracted_responses.append(extract_xml_answer(content))

    # The question is the content of the last message of the first prompt.
    question = prompts[0][-1]['content']
    print('-'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    rewards = []
    for predicted, expected in zip(extracted_responses, answer):
        rewards.append(2.0 if predicted == expected else 0.0)
    return rewards

def int_reward_func(completions, **kwargs):
    """Reward completions whose extracted answer consists only of digits.

    Returns 0.5 for each completion whose ``<answer>`` content passes
    ``str.isdigit()`` and 0.0 otherwise; signed or decimal values such as
    "-5" or "3.5" therefore score 0.0, as does an empty extraction.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs):
    r"""Reward completions that follow the exact XML layout.

    The content of each completion's first message must have the structure
    ``<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n``, with every
    tag on its own line and a newline after the closing answer tag.

    Args:
        completions: list of message lists; ``completion[0]["content"]`` is checked.
        **kwargs: ignored; accepted so the trainer can pass extra columns.

    Returns:
        A list with 0.5 for each completion in the strict format, else 0.0.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets `.` cross newlines. Without it, any completion whose
    # reasoning or answer spans more than one line could never match and the
    # reward would stay at 0.0 for well-formed multi-line output.
    return [0.5 if re.match(pattern, r, flags = re.DOTALL) else 0.0 for r in responses]

def soft_format_reward_func(completions, **kwargs):
    """Reward completions that contain the XML sections, with loose spacing.

    The content of each completion's first message must contain a
    ``<reasoning>...</reasoning>`` section followed, after optional whitespace,
    by an ``<answer>...</answer>`` section. Text before or after the two
    sections and newlines inside them are allowed.

    Args:
        completions: list of message lists; ``completion[0]["content"]`` is checked.
        **kwargs: ignored; accepted so the trainer can pass extra columns.

    Returns:
        A list with 0.5 for each completion containing the pattern, else 0.0.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.search: the sections only need to be contained in the response, so the
    #   check is not anchored to the first character.
    # re.DOTALL: `.` must cross newlines, because the prompted format puts the
    #   section bodies on their own lines.
    return [0.5 if re.search(pattern, r, flags = re.DOTALL) else 0.0 for r in responses]

def count_xml(text):
    """Score how closely ``text`` follows the expected XML tag layout.

    Four markers are checked in order; each one that occurs exactly once adds
    0.125. The two answer-tag checks additionally subtract 0.001 per character
    of the last piece obtained by splitting ``text`` on a closing-answer marker
    (the second check discounts one character). When that marker is absent the
    split yields the whole text, and when the text ends right at
    ``\\n</answer>`` the second deduction becomes a 0.001 bonus.
    """
    # (marker that must occur once, marker whose trailing piece is penalised, free chars)
    checks = (
        ("<reasoning>\n", None, 0),
        ("\n</reasoning>\n", None, 0),
        ("\n<answer>\n", "\n</answer>\n", 0),
        ("\n</answer>", "\n</answer>", 1),
    )
    score = 0.0
    for marker, tail_marker, free_chars in checks:
        if text.count(marker) != 1:
            continue
        score += 0.125
        if tail_marker is not None:
            trailing = len(text.split(tail_marker)[-1])
            score -= (trailing - free_chars)*0.001
    return score

def xmlcount_reward_func(completions, **kwargs):
    """Apply ``count_xml`` to the first message's content of every completion.

    Returns one score per completion, in input order: partial credit for each
    expected tag, reduced by ``count_xml``'s penalty for trailing content.
    """
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores

In [19]:
from trl import GRPOConfig, GRPOTrainer

# Training arguments
# NOTE(review): the per-argument notes below follow the argument names; confirm
# exact semantics against the TRL GRPOConfig documentation for trl==0.22.2.
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1, # log at every step
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 1,
    num_generations = 4, # presumably completions sampled per prompt — confirm
    max_prompt_length = 256,
    max_completion_length = 200, # 256 + 200 = 456, below max_seq_length = 1024 set at model load
    max_steps = 250,
    save_steps = 250, # same value as max_steps
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

In [20]:
# GRPO Trainer
# Combines the LoRA-wrapped model, the tokenizer, the GSM8K prompts and the five
# reward functions defined above.
# Largest value each function can return, as defined in this notebook:
# xmlcount 0.5, soft format 0.5, strict format 0.5, integer answer 0.5,
# correctness 2.0.
# NOTE(review): how the trainer combines the five rewards (presumably a sum) is
# not visible here — confirm against the TRL GRPOTrainer documentation.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)

In [None]:
# Start training
# Uses the configuration in training_args (max_steps = 250). correctness_reward_func
# prints one sample per call, so expect verbose output from this cell.
trainer.train()

In [22]:
# Save LoRA
# Written under the name "grpo_saved_lora"; the final inference cell reloads it
# from the same name via model.load_lora.
model.save_lora("grpo_saved_lora")

In [23]:
from vllm import SamplingParams

# Baseline inference: generation with lora_request = None, i.e. no LoRA adapter
# is requested for this call.
# NOTE: this cell sits after trainer.train(), so it does not literally run
# "before training". Unlike the next cell, the prompt has no system message.
query = "How many r's are in strawberry?"

# Render the single user turn to a prompt string (tokenize = False) that ends
# with the generation prompt.
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

# Take the text of the first candidate of the first (only) prompt.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

print(output)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

In the word "strawberry," there is one occurrence of the letter "r."


In [24]:
# Inference with model after training
# Same question as the baseline cell (`query`), now with the training-time
# system prompt and the LoRA adapter saved under "grpo_saved_lora".
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : query},
], tokenize = False, add_generation_prompt = True)

sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)

# Prompts are passed as a list, consistent with the baseline call, so the
# [0].outputs[0] indexing reads as "first prompt, first candidate" in both cells.
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

print(output)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

<reasoning>
To find out how many times the letter 'r' appears in the word 'strawberry', I will go through the word character by character and count each occurrence of 'r'.
</reasoning>
<answer>
I found that the letter 'r' appears 3 times in the word 'strawberry'. 
</answer>
