In [None]:
%%capture
# `%%capture` has to stay on the first line of the cell; it silences the installer output below.
# Skip restarting message in Colab
# Every already-imported module whose name contains "PIL" or "google" is dropped from
# `sys.modules` — presumably so the packages installed below are re-imported fresh
# instead of Colab asking for a runtime restart (confirm on a fresh runtime).
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

# NOTE(review): `unsloth` is pinned, but `vllm` and `pillow` are not, so this cell
# can resolve to different versions over time.
!pip install "unsloth==2025.2.12" vllm
!pip install --upgrade pillow
# If you are running this notebook locally, you need to install `diffusers` too
# !pip install diffusers
# Temporarily install a specific TRL nightly version (pinned to an exact commit hash)
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b

In [None]:
from datasets import load_dataset

# Load the `open-r1/OpenR1-Math-Raw` dataset. Later cells index the result by
# split name (`ds["train"]`) and read the "problem" and "answer" fields of each row.
ds = load_dataset("open-r1/OpenR1-Math-Raw")

In [None]:
# Reduce the size of training set so the script won't run forever.
# Raise MAX_TRAIN_SAMPLES (or skip this cell) for a full training run.
MAX_TRAIN_SAMPLES = 10_000

print("Before", len(ds["train"]))
# Clamp to the actual split length: asking `select` for indices beyond the end of
# a smaller split would be out of range. This also keeps the cell safe to re-run.
ds["train"] = ds["train"].select(range(min(MAX_TRAIN_SAMPLES, len(ds["train"]))))
print("After", len(ds["train"]))

In [None]:
from unsloth import FastLanguageModel, PatchFastRL

# Apply Unsloth's "GRPO" patch to FastLanguageModel. This cell runs before `trl`
# is imported further down — presumably the patch has to be in place first, so
# keep the cell order (TODO confirm against the Unsloth docs).
PatchFastRL("GRPO", FastLanguageModel)

In [None]:
from unsloth import is_bfloat16_supported
import torch
# Sequence-length budget handed to the model loader. The GRPO config further down
# uses max_prompt_length = 256 and max_completion_length = 200 (456 in total);
# presumably prompt + completion must fit within this value — keep them in sync.
max_seq_length = 512 # Can increase for longer reasoning traces
# The same rank value is reused below for `max_lora_rank`, `r` and `lora_alpha`.
lora_rank = 32 # Larger rank = smarter, but slower

# Load the base model together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

# Replace `model` with its PEFT (LoRA) version, adapting the listed projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407, # Fixed seed value passed to the adapter setup
)

In [None]:
import re
from datasets import load_dataset, Dataset
import pandas as pd
import json
from tqdm import tqdm

# System message placed at the start of every training prompt (see `get_questions`).
# It asks for a JSON object with the keys "reasoning" and "answer", in that order.
# The backslash right after the opening quotes keeps the text from starting with a newline.
SYSTEM_PROMPT = """\
Respond in the following JSON format:
{
  "reasoning": "Chain of thought to achieve the goal",
  "answer": "Final answer"
}
"""

def format_to_json(reasoning: str, answer: str) -> str:
    """
    Serialize a reasoning/answer pair into the JSON layout shown in SYSTEM_PROMPT.

    Args:
        reasoning: Chain-of-thought text.
        answer: Final answer text.

    Returns:
        A JSON object string with the keys "reasoning" and "answer", in that order.
    """
    return json.dumps({"reasoning": reasoning, "answer": answer})

def get_questions(dataset, split="train") -> Dataset:
    """
    Turn one split of the raw dataset into prompt/answer training records.

    Rows whose "answer" field is empty (falsy) are skipped. Every remaining row
    becomes a record with three fields:
      - "question": the row's "problem" text.
      - "prompt":   a two-message chat — a system message holding SYSTEM_PROMPT,
                    then a user message holding the problem text.
      - "answer":   the row's "answer" value, copied unchanged.

    Args:
        dataset: Mapping from split name to an iterable of rows that expose
            "problem" and "answer" keys.
        split: Name of the split to process.

    Returns:
        A Dataset built from the collected records (via a pandas DataFrame).
    """
    records: list[dict] = []
    for row in tqdm(dataset[split]):
        # Only rows with a reference answer are usable for training.
        if row["answer"]:
            problem = row["problem"]
            chat = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": problem},
            ]
            records.append(
                {"question": problem, "prompt": chat, "answer": row["answer"]}
            )
    # list of dicts -> DataFrame -> Dataset
    frame = pd.DataFrame(data=records)
    return Dataset.from_pandas(frame)

In [None]:
# Build the prompt/answer training records from the (already truncated) "train" split.
dataset = get_questions(ds)

In [None]:
from pprint import pprint

# Show the first processed record to sanity-check the question/prompt/answer layout.
pprint(dataset[0])

In [None]:
import json
import re

# Reward function for checking strict JSON format adherence
def strict_json_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Scores each completion on whether it is exactly the requested JSON object.

    A completion earns 1.0 only when its text parses as a JSON object whose keys
    are "reasoning" followed by "answer" — no other keys, and in that order.
    Everything else (unparseable text, a non-object JSON value, missing, extra or
    reordered keys) earns 0.0.

    Args:
        completions: One entry per sample; the text is read from entry[0]["content"].
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    expected_keys = ["reasoning", "answer"]

    def score(text) -> float:
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return 0.0
        # List comparison makes the check order-sensitive as well as exact.
        well_formed = isinstance(payload, dict) and list(payload.keys()) == expected_keys
        return 1.0 if well_formed else 0.0

    return [score(sample[0]["content"]) for sample in completions]

# Reward function for correctness of answer
def correctness_json_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Rewards responses whose JSON "answer" field matches the expected answer.

    Steps:
    - Parses each completion as JSON and reads its "answer" field.
    - Compares it to the expected answer as whitespace-stripped strings.
    - Assigns a reward of 2.0 if they are equal, otherwise 0.0.

    A completion that is not valid JSON, is valid JSON but not an object
    (e.g. `42` or `[1]`), or has no "answer" key never matches and scores 0.0.

    Args:
        prompts: Unused; accepted because the trainer passes it.
        completions: One entry per sample; the text is read from entry[0]["content"].
        answer: Expected answers, aligned with `completions`.
        **kwargs: Ignored.

    Returns:
        One score per (completion, expected answer) pair, in input order.
    """
    scores = []
    for completion, expected in zip(completions, answer):
        try:
            parsed = json.loads(completion[0]["content"])
        except (json.JSONDecodeError, TypeError):
            # Unparseable (or non-text) output simply earns no reward.
            parsed = None

        # Valid JSON is not necessarily an object: guard before reading the key
        # so a bare number/list/string does not raise and abort the training step.
        if isinstance(parsed, dict) and "answer" in parsed:
            predicted = str(parsed["answer"]).strip()
            # Strip both sides so stray whitespace in the reference cannot cause a miss.
            is_correct = predicted == str(expected).strip()
        else:
            is_correct = False

        scores.append(2.0 if is_correct else 0.0)
    return scores

# Reward function for reasoning quality
def reasoning_length_reward_func(completions, **kwargs) -> list[float]:
    """
    Rewards responses that provide a sufficiently detailed reasoning section.

    Steps:
    - Parses each completion as JSON and reads its "reasoning" field.
    - Counts the whitespace-separated words in it.
    - Rewards 1.5 for 30+ words, 1.0 for 10-29 words, 0.5 for 1-9 words, 0.0 for none.

    A completion that is not valid JSON, is valid JSON but not an object, or
    whose "reasoning" value is missing or not a string scores 0.0.

    Args:
        completions: One entry per sample; the text is read from entry[0]["content"].
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    scores = []
    for completion in completions:
        try:
            parsed = json.loads(completion[0]["content"])
        except (json.JSONDecodeError, TypeError):
            scores.append(0.0)
            continue

        # Valid JSON may be a bare number/list/string, and "reasoning" may hold a
        # non-string value; neither should raise and abort the training step.
        reasoning = parsed.get("reasoning") if isinstance(parsed, dict) else None
        if not isinstance(reasoning, str):
            scores.append(0.0)
            continue

        # split() with no argument already ignores leading/trailing whitespace.
        word_count = len(reasoning.split())
        if word_count >= 30:
            scores.append(1.5)
        elif word_count >= 10:
            scores.append(1.0)
        elif word_count > 0:
            scores.append(0.5)
        else:
            scores.append(0.0)

    return scores

# Reward function for JSON validity
def valid_json_reward_func(completions, **kwargs) -> list[float]:
    """
    Scores each completion on whether its text is parseable JSON.

    The content of the JSON is not inspected: any text accepted by
    `is_valid_json` earns 1.0, everything else earns 0.0.

    Args:
        completions: One entry per sample; the text is read from entry[0]["content"].
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    texts = [sample[0]["content"] for sample in completions]
    scores = []
    for text in texts:
        if is_valid_json(text):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores

def is_valid_json(text: str) -> bool:
    """Report whether `text` parses as JSON (any JSON value, not only an object)."""
    try:
        json.loads(text)
    except json.JSONDecodeError:
        return False
    else:
        return True

# Reward function for penalizing extra content
def no_extra_content_reward_func(completions, **kwargs) -> list[float]:
    """
    Scores each completion on whether it is wrapped in braces with nothing around them.

    Apart from surrounding whitespace, the text must start with "{" and end with
    "}" to earn 1.0; otherwise it earns 0.0. This is a shape check only — the
    text between the braces is not validated as JSON.

    Args:
        completions: One entry per sample; the text is read from entry[0]["content"].
        **kwargs: Ignored.

    Returns:
        One score per completion, in input order.
    """
    # DOTALL lets "." span newlines so multi-line objects are accepted.
    brace_wrapped = re.compile(r'^\s*\{.*\}\s*$', re.DOTALL)
    scores = []
    for sample in completions:
        text = sample[0]["content"]
        scores.append(1.0 if brace_wrapped.match(text) else 0.0)
    return scores


In [None]:
from trl import GRPOConfig, GRPOTrainer

# Hyper-parameters for the GRPO run; consumed by the GRPOTrainer cell below.
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    # Exactly one of bf16 / fp16 is enabled: fp16 is the negation of the bf16 check.
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 3, # Decrease if out of memory
    # 256 + 200 = 456 tokens, within the max_seq_length of 512 set when loading the
    # model — presumably these need to stay in sync if either is changed.
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    # save_steps equals max_steps, so presumably a single checkpoint is written at
    # the end of the run — TODO confirm.
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

In [None]:
# Wire the LoRA model, tokenizer, reward functions and prompt dataset into the
# trainer, then start training.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    # All five reward functions defined above are used.
    # NOTE(review): presumably the trainer combines their per-completion scores and
    # forwards the dataset's "answer" column to `correctness_json_reward_func` —
    # confirm against the TRL GRPOTrainer docs.
    reward_funcs = [
        strict_json_format_reward_func,
        correctness_json_reward_func,
        reasoning_length_reward_func,
        valid_json_reward_func,
        no_extra_content_reward_func,
    ],
    args = training_args,
    train_dataset = dataset, # records built by get_questions: question / prompt / answer
)
trainer.train()