# Bonus Activity - Unsloth GRPO Training on Open R1 Math Raw

### 1. Insall Unsloth for Collab

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm==0.7.3
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.7.3

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [3]:
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams
from google.colab import drive
import re
import json
import torch
from google.colab import files
import numpy as np
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import wandb
from datasets import load_dataset, Dataset
import gc

# Mount Drive
drive.mount('/content/drive')

# Initialize wandb
wandb.login()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-06 07:57:11 __init__.py:207] Automatically detected platform cuda.
Mounted at /content/drive


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjulien-thomazo[0m ([33mjulien-thomazo-inria[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Helper functions to extract information and formatting system prompt

In [4]:
SYSTEM_PROMPT = """
Solve the following mathematical problem step by step. Your response must follow this format:

<reasoning>
Provide a clear, step-by-step solution to the problem. Include all mathematical steps and logic.
</reasoning>

<answer>
Write your final answer here concisely.
</answer>
"""

# 3. Improved XML Parsing Functions with Error Handling
def extract_xml_answer(text: str) -> str:
    """Extract text between <answer> and </answer> tags with robust error handling."""
    try:
        if "<answer>" not in text or "</answer>" not in text:
            return ""
        answer_start = text.find("<answer>") + len("<answer>")
        answer_end = text.find("</answer>", answer_start)
        if answer_end == -1:  # If closing tag not found
            return ""
        return text[answer_start:answer_end].strip()
    except Exception:
        # Return empty string on any error
        return ""

def extract_xml_reasoning(text: str) -> str:
    """Extract text between <reasoning> and </reasoning> tags with robust error handling."""
    try:
        if "<reasoning>" not in text or "</reasoning>" not in text:
            return ""
        reasoning_start = text.find("<reasoning>") + len("<reasoning>")
        reasoning_end = text.find("</reasoning>", reasoning_start)
        if reasoning_end == -1:  # If closing tag not found
            return ""
        return text[reasoning_start:reasoning_end].strip()
    except Exception:
        # Return empty string on any error
        return ""

def clean_math_text(text: str) -> str:
    """Clean mathematical text by normalizing spacing and preserving LaTeX."""
    if not text:
        return ""
    try:
        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)
        # Preserve LaTeX delimiters
        text = re.sub(r'(\$+)(.*?)(\$+)', lambda m: m.group(1) + m.group(2).replace(' ', '') + m.group(3), text)
        return text.strip()
    except Exception:
        # Return original text on any error
        return text

### Streamlined Dataset Processing

In [5]:
def get_open_r1_math_data(split="train", max_train_examples=1000, max_test_examples=50, random_seed=42):
    """
    Load the OpenR1-Math-Raw dataset with improved processing.
    """
    streaming_dataset = load_dataset("open-r1/OpenR1-Math-Raw", split=split, streaming=True)

    filtered_dataset = streaming_dataset.filter(
        lambda x: x["problem_is_valid"] == "Yes"  and
                  x["solution_is_valid"] == "Yes" and
                  "proof" not in x["answer"]
    )

    transformed_dataset = filtered_dataset.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["problem"]}
            ],
            "expected_answer": x["answer"], # target to achieve
            "reference_solution": x["solution"], # reasoning steps
        }
    )

    shuffled_dataset = transformed_dataset.shuffle(buffer_size=10000, seed=random_seed)

    train_stream = shuffled_dataset.take(max_train_examples)
    train_dataset = Dataset.from_list(list(train_stream))

    test_stream = shuffled_dataset.take(max_test_examples)
    test_dataset = Dataset.from_list(list(test_stream))

    print(f"Dataset loaded: {len(train_dataset)} training examples, {len(test_dataset)} test examples")

    return train_dataset, test_dataset

### Reward functions

We define several reward functions:
- for checking the correctnes of the answer
- for checking  the reasoning length, the cpresence of mathematical symbols and LaTeX expressions
- for checking  the compliance with the expected xml format

In [6]:
def correctness_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward the model for generating the correct answer.
    This function has the highest weight (2.0) to prioritize correctness.
    """
    responses = [completion[0]["content"] for completion in completions]
    extracted_answers = [extract_xml_answer(r) for r in responses]

    # Extract the expected answers from kwargs
    expected_answers = kwargs.get("expected_answer", [""] * len(completions))
    if not isinstance(expected_answers, list):
        expected_answers = [expected_answers] * len(completions)

    # Clean up answers for comparison
    expected_answers_clean = [a.strip() if a else "" for a in expected_answers]
    extracted_answers_clean = [a.strip() for a in extracted_answers]

    # Assign high reward (2.0) for correct answers
    return [2.0 if a == e else 0.0 for a, e in zip(extracted_answers_clean, expected_answers_clean)]

def reasoning_quality_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward the model for providing mathematical reasoning.
    This function has a medium weight (1.5 max) to encourage good reasoning.
    """
    responses = [completion[0]["content"] for completion in completions]
    reasoning_parts = [extract_xml_reasoning(r) for r in responses]

    rewards = []
    for reasoning in reasoning_parts:
        # Check if reasoning exists
        if not reasoning:
            rewards.append(0.0)
            continue

        # Check reasoning length
        word_count = len(reasoning.split())
        length_reward = min(word_count / 50, 0.8)  # Clip to 0.8

        # Check for mathematical symbols
        math_expressions = len(re.findall(r'[\+\-\*/=><\(\)\[\]\{\}]', reasoning))
        math_reward = min(math_expressions / 10, 0.4)

        # Check for LaTeX expressions
        latex_count = len(re.findall(r'\$.*?\$|\$\$.*?\$\$', reasoning))
        latex_reward = min(latex_count * 0.1, 0.3)

        rewards.append(length_reward + math_reward + latex_reward)

    return rewards

def format_compliance_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward the model for following the specified format.
    This function has a lower weight (1.0 max) as format is important but secondary to correctness.
    """
    responses = [completion[0]["content"] for completion in completions]

    rewards = []
    for response in responses:
        reward = 0.0

        # Check for proper reasoning tag usage
        if "<reasoning>" in response and "</reasoning>" in response:
            reward += 0.4

        # Check for proper answer tag usage
        if "<answer>" in response and "</answer>" in response:
            reward += 0.3

        # Check for correct ordering (reasoning before answer)
        if response.find("<reasoning>") < response.find("<answer>") and \
           response.find("</reasoning>") < response.find("</answer>"):
            reward += 0.3

        rewards.append(reward)

    return rewards

In [7]:
# Model and training configuration
def setup_and_train():
    """Set up and train the model using GRPO."""

    gc.collect()
    torch.cuda.empty_cache()

    # Load dataset with train/test split
    train_dataset, test_dataset = get_open_r1_math_data(
    max_train_examples=1000,
    max_test_examples=50,
    random_seed=42
)

    # Save the test dataset for later evaluation
    test_dataset.save_to_disk("/content/drive/MyDrive/test_math_dataset_improved")

    max_seq_length = 2048
    lora_rank = 16

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "meta-llama/Llama-3.2-3B-Instruct",
        max_seq_length = max_seq_length,
        load_in_4bit = True,
        fast_inference = True,
        max_lora_rank = lora_rank,
        gpu_memory_utilization = 0.7
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ], # Remove QKVO if out of memory
        lora_alpha = lora_rank,
        use_gradient_checkpointing = "unsloth", # Enable long context finetuning
        random_state = 3407,
    )

    # Calculate maximum prompt length
    max_prompt_length = max(train_dataset.map(
        lambda x: {"tokens": tokenizer.apply_chat_template(x["prompt"], add_generation_prompt=True, tokenize=True)},
        batched=True,
    ).map(lambda x: {"length": len(x["tokens"])})["length"])

    max_prompt_length = max_prompt_length + 10  # Extra margin

    # Configure training
    training_args = GRPOConfig(
        learning_rate = 5e-6,
        weight_decay = 0.1,
        warmup_ratio = 0.1,
        lr_scheduler_type = "cosine",
        optim = "adamw_torch_fused",
        logging_steps = 1,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,
        num_generations = 8,
        max_prompt_length = max_prompt_length,
        max_completion_length = max_seq_length - max_prompt_length,
        # num_train_epochs = 1, # Set to 1 for a full training run
        max_steps = 175,
        save_steps = 25,
        max_grad_norm = 0.1,
        report_to = "wandb",
        output_dir = "outputs_math_reasoning_improved",
    )

    # Create and configure GRPO trainer
    trainer = GRPOTrainer(
        model = model,
        processing_class = tokenizer,
        reward_funcs = [
            correctness_reward_func,
            reasoning_quality_reward_func,
            format_compliance_reward_func
        ],
        args = training_args,
        train_dataset = train_dataset,
    )

    # Start training
    trainer.train()

    # Save model after training
    model.save_lora("/content/drive/MyDrive/math_reasoning_grpo_lora_improved")

    return model, tokenizer, test_dataset


### Train the model and save test dataset

In [8]:
# Train the model and save test dataset
model, tokenizer, test_dataset = setup_and_train()

README.md:   0%|          | 0.00/3.29k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/39 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/39 [00:00<?, ?it/s]

Dataset loaded: 1000 training examples, 50 test examples


Saving the dataset (0/1 shards):   0%|          | 0/50 [00:00<?, ? examples/s]

==((====))==  Unsloth 2025.4.7: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.7.3.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 69.2%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 24.95 GB. Also swap space = 6 GB.
INFO 05-06 07:58:32 config.py:549] This model supports multiple tasks: {'generate', 'classify', 'embed', 'reward', 'score'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes 

tokenizer_config.json:   0%|          | 0.00/54.7k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

INFO 05-06 07:58:35 cuda.py:229] Using Flash Attention backend.
INFO 05-06 07:58:35 model_runner.py:1110] Starting to load model unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit...
INFO 05-06 07:58:35 loader.py:1089] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 05-06 07:58:36 weight_utils.py:254] Using model weights format ['*.safetensors']


model.safetensors:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

INFO 05-06 07:59:07 weight_utils.py:270] Time spent downloading weights for unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit: 31.326314 seconds


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 05-06 07:59:10 model_runner.py:1115] Loading model weights took 2.2405 GB
INFO 05-06 07:59:10 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 05-06 07:59:19 worker.py:267] Memory profiling takes 8.43 seconds
INFO 05-06 07:59:19 worker.py:267] the current vLLM instance can use total_gpu_memory (39.56GiB) x gpu_memory_utilization (0.69) = 27.37GiB
INFO 05-06 07:59:19 worker.py:267] model weights take 2.24GiB; non_torch_memory takes 0.09GiB; PyTorch activation peak memory takes 1.49GiB; the rest of the memory reserved for KV Cache is 23.55GiB.
INFO 05-06 07:59:19 executor_base.py:111] # cuda blocks: 13781, # CPU blocks: 3510
INFO 05-06 07:59:19 executor_base.py:116] Maximum concurrency for 2048 tokens per request: 107.66x
INFO 05-06 07:59:22 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error

Capturing CUDA graph shapes: 100%|██████████| 43/43 [00:57<00:00,  1.34s/it]

INFO 05-06 08:00:20 model_runner.py:1562] Graph capturing finished in 58 secs, took 0.75 GiB
INFO 05-06 08:00:20 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 70.58 seconds





tokenizer_config.json:   0%|          | 0.00/54.7k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

Unsloth 2025.4.7 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,000 | Num Epochs = 1 | Total steps = 175
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 24,313,856/3,000,000,000 (0.81% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / correctness_reward_func,rewards / reasoning_quality_reward_func,rewards / format_compliance_reward_func
1,-0.0,1.3375,0.898345,526.5,0.0,0.0,0.9,0.4375
2,0.0,1.575,0.384161,358.65625,0.0,0.0,1.075,0.5
3,0.0,1.24375,0.640645,882.3125,0.000358,0.0,0.8125,0.43125
4,0.0,1.7875,0.664862,585.75,0.000342,0.0,1.2,0.5875
5,0.0,1.334375,0.607039,675.4375,0.000356,0.0,0.890625,0.44375
6,0.0,1.18125,0.908065,824.28125,0.000312,0.125,0.75,0.30625
7,0.0,1.0125,0.9303,898.34375,0.000368,0.0,0.65625,0.35625
8,0.0,0.88125,0.742076,515.0,0.000391,0.0625,0.55625,0.2625
9,0.0,1.1225,1.065785,771.125,0.000379,0.0,0.69125,0.43125
10,0.0,1.303125,0.78889,733.5,0.000366,0.0,0.809375,0.49375


My training curves in W&B :

![Screeshot](./images/W&B_Train_GRPO.png)