# Environment Setup

In [None]:
!uv pip install bitsandbytes xformers triton unsloth vllm==0.10.2
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2

[2mUsing Python 3.12.11 environment at: /usr[0m
[2K[2mResolved [1m166 packages[0m [2min 1.00s[0m[0m
[2K[2mPrepared [1m47 packages[0m [2min 8.96s[0m[0m
[2mUninstalled [1m6 packages[0m [2min 125ms[0m[0m
[2K[2mInstalled [1m47 packages[0m [2min 88ms[0m[0m
 [32m+[39m [1mastor[0m[2m==0.8.1[0m
 [32m+[39m [1mbitsandbytes[0m[2m==0.48.1[0m
 [32m+[39m [1mblake3[0m[2m==1.0.7[0m
 [32m+[39m [1mcbor2[0m[2m==5.7.0[0m
 [32m+[39m [1mcompressed-tensors[0m[2m==0.11.0[0m
 [32m+[39m [1mcut-cross-entropy[0m[2m==25.1.1[0m
 [31m-[39m [1mdatasets[0m[2m==4.0.0[0m
 [32m+[39m [1mdatasets[0m[2m==4.1.1[0m
 [32m+[39m [1mdepyf[0m[2m==0.19.0[0m
 [32m+[39m [1mdiskcache[0m[2m==5.6.3[0m
 [32m+[39m [1mdnspython[0m[2m==2.8.0[0m
 [32m+[39m [1memail-validator[0m[2m==2.3.0[0m
 [32m+[39m [1mfastapi-cli[0m[2m==0.0.13[0m
 [32m+[39m [1mfastapi-cloud-cli[0m[2m==0.3.0[0m
 [32m+[39m [1mgguf[0m[2m==0.17.1[0m
 [32m+[

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/multi-reward-math-reasoning # Add this folder as shortcut to your Drive to save your results here
!ls

Mounted at /content/drive/
[Errno 2] No such file or directory: '/content/drive/MyDrive/multi-reward-math-reasoning # Add this folder as shortcut to your Drive to save your results here'
/content
drive  sample_data


In [None]:
from unsloth import FastLanguageModel
from trl import SFTConfig, GRPOConfig, SFTTrainer, GRPOTrainer
from vllm import SamplingParams

import gc
import re
import time
import torch
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm.notebook import tqdm
from datasets import load_dataset, Dataset
from safetensors import safe_open

from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, TextStreamer
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 10-09 06:14:58 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


# Model Setup

In [None]:
model_id = 'unsloth/Qwen3-1.7B-Base'          # Qwen3 1.7B *Base* checkpoint (per its name, not an instruct variant)
model_name = model_id.split('/')[-1].lower()  # Repo name without the org prefix, lower-cased
max_seq_length = 2048                         # Can increase for longer reasoning traces
lora_rank = 32                                # Larger rank = smarter, but slower

In [None]:
# Load the checkpoint through Unsloth (with vLLM-backed generation enabled), then wrap it with LoRA adapters.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_id,
    max_seq_length=max_seq_length,
    load_in_4bit=False,         # False for LoRA 16bit
    fast_inference=True,        # Enable vLLM fast inference
    max_lora_rank=lora_rank,    # Must cover the rank passed to get_peft_model below
    gpu_memory_utilization=0.9, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,                          # Rank: adaptation capacity (16 good for reasoning tasks)
    lora_alpha=lora_rank * 2,             # Scaling factor (typically 2x rank)
    # NOTE(review): the Unsloth log printed by this cell says only dropout = 0 gets fast patching and
    # reports "0 QKV layers, 0 O layers and 0 MLP layers" patched -- confirm 0.1 is worth the slowdown.
    lora_dropout=0.1,                     # Regularization to prevent overfitting
    target_modules=[                      # Remove QKVO if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    use_gradient_checkpointing='unsloth', # Reduces memory usage
    random_state=3407,                    # Fixed seed (the SFT config below uses the same value)
)

INFO 10-09 06:15:12 [vllm_utils.py:689] Unsloth: Patching vLLM v1 graph capture
INFO 10-09 06:15:12 [vllm_utils.py:717] Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.10.1: Fast Qwen3 patching. Transformers: 4.55.4. vLLM: 0.10.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-1.7B-Base with actual GPU utilization = 88.97%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 31.92 GB. Also swap space = 6 GB.
Unsloth: Not an error, but `device` is

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/166 [00:00<?, ?B/s]

INFO 10-09 06:15:40 [core.py:76] Initializing a V1 LLM engine (v0.10.2) with config: model='unsloth/Qwen3-1.7B-Base', speculative_config=None, tokenizer='unsloth/Qwen3-1.7B-Base', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=unsloth/Qwen3-1.7B-Base, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_co

model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

INFO 10-09 06:15:56 [weight_utils.py:369] Time spent downloading weights for unsloth/Qwen3-1.7B-Base: 13.201495 seconds
INFO 10-09 06:15:56 [weight_utils.py:406] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 10-09 06:15:57 [default_loader.py:268] Loading weights took 1.06 seconds
INFO 10-09 06:15:57 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 10-09 06:15:58 [gpu_model_runner.py:2392] Model loading took 3.2919 GiB and 15.639712 seconds
INFO 10-09 06:16:10 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/c4e82799bf/rank_0_0/backbone for vLLM's torch.compile
INFO 10-09 06:16:10 [backends.py:550] Dynamo bytecode transform time: 10.83 s


Unsloth: Compiling kernels: 100%|██████████| 7/7 [00:00<00:00, 10.86it/s, triton_poi_fused_view_6]


INFO 10-09 06:16:16 [backends.py:194] Cache the graph for dynamic shape for later use


Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 18.90it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 523.03it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 482.96it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 494.80it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 503.30it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 500.27it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 493.58it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 514.29it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 512.47it/s, triton_poi_fused_view_10]
Unsloth: Compiling kernels: 100%|██████████| 11/11 [00:00<00:00, 498.50it/

INFO 10-09 06:16:52 [backends.py:215] Compiling a graph for dynamic shape takes 40.77 s





INFO 10-09 06:17:05 [monitor.py:34] torch.compile takes 51.61 s in total
INFO 10-09 06:17:08 [gpu_worker.py:298] Available KV cache memory: 30.13 GiB
INFO 10-09 06:17:08 [kv_cache_utils.py:864] GPU KV cache size: 282,064 tokens
INFO 10-09 06:17:08 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 137.73x
INFO 10-09 06:17:08 [vllm_utils.py:694] Unsloth: Running patched vLLM v1 `capture_model`.


Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:17<00:00,  3.82it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 43/43 [00:05<00:00,  7.44it/s]

INFO 10-09 06:17:32 [gpu_model_runner.py:3118] Graph capturing finished in 23 secs, took 0.98 GiB
INFO 10-09 06:17:32 [vllm_utils.py:701] Unsloth: Patched vLLM v1 graph capture finished in 23 secs.





INFO 10-09 06:17:33 [gpu_worker.py:391] Free memory on device (39.03/39.56 GiB) on startup. Desired GPU memory utilization is (0.88969201168322, 35.19 GiB). Actual usage is 3.29 GiB for weight, 1.75 GiB for peak activation, 0.02 GiB for non-torch memory, and 0.98 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=31139481190` to fit into requested memory, or `--kv-cache-memory=35258589184` to fully utilize gpu memory. Current kv cache memory in use is 32349537894 bytes.
INFO 10-09 06:17:33 [core.py:218] init engine (profile, create kv cache, warmup model) took 95.37 seconds
INFO 10-09 06:17:35 [llm.py:295] Supported_tasks: ('generate',)
INFO 10-09 06:17:35 [__init__.py:36] No IOProcessor plugins requested by the model
Unsloth: Just some info: will skip parsing ['norm2', 'norm1', 'k_norm', 'post_layernorm', 'attention_norm', 'post_feedforward_layernorm', 'layer_norm1', 'pre_feedforward_layernorm', 'q_norm', 'ffn_norm', 'layer_norm2', 'input_layernorm

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.1 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


# Chat Template

In [None]:
# Define structured output format for mathematical reasoning.
# These tags are reused by the chat template, the SFT formatting and the reward regexes below.
REASONING_START = '<THINK>'   # Begin reasoning section
REASONING_END = '</THINK>'    # End reasoning section
SOLUTION_START = '<SOLUTION>' # Begin final answer
SOLUTION_END = '</SOLUTION>'  # End final answer

# System prompt that teaches the model our desired reasoning structure
# (the tags are interpolated once, here, when the f-string is evaluated)
SYSTEM_PROMPT = f'''You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between {REASONING_START} and {REASONING_END}.
2. Provide your final numerical answer between {SOLUTION_START} and {SOLUTION_END}.
3. Be precise and show all calculation steps clearly.'''
print(SYSTEM_PROMPT)

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.


In [None]:
chat_template = ( # Build the Jinja chat template (adjacent string literals are concatenated into one string)
    # If the very first message is a SYSTEM role, print it + <eos>:
    "{% if messages[0]['role'] == 'system' %}"
      "{{ messages[0]['content'] + eos_token }}"
      "{% set loop_messages = messages[1:] %}"
    "{% else %}"
      # Otherwise, inject our system_prompt + <eos>:
      "{{ '{system_prompt}' + eos_token }}"
      "{% set loop_messages = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages (either user or assistant):
    # user turns are emitted as-is, assistant turns are terminated with <eos>.
    "{% for message in loop_messages %}"
      "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}"
      "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token }}"
      "{% endif %}"
    "{% endfor %}"

    # If we asked for "add_generation_prompt", append REASONING_START (<THINK>) to the end:
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
    "{% endif %}"
)
# Substitute the placeholders with our specific prompt and tag, then assign the template to the tokenizer.
# NOTE(review): both values are spliced inside single-quoted Jinja string literals, so a single
# quote in SYSTEM_PROMPT or REASONING_START would break the template -- neither contains one today.
tokenizer.chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{SYSTEM_PROMPT}'")\
    .replace("'{reasoning_start}'", f"'{REASONING_START}'")

In [None]:
# No system message is supplied here, so the template's fallback branch injects SYSTEM_PROMPT.
example_messages = [ # Quick sanity check of the template
    {'role': 'user', 'content': 'Which country has the highest population density?'},
    {'role': 'assistant', 'content': (
        f'{REASONING_START}'
        'I know that country X is small in area but has a huge population, '
        'so its people per square kilometer is extremely high.'
        f'{REASONING_END}{SOLUTION_START}Monaco{SOLUTION_END}'
    )},
    {'role': 'user', 'content': 'Which planet is farthest from the Sun?'},
]
# add_generation_prompt=True makes the rendered text end with REASONING_START, ready for generation
print(tokenizer.apply_chat_template(example_messages, tokenize=False, add_generation_prompt=True))

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Which country has the highest population density?<THINK>I know that country X is small in area but has a huge population, so its people per square kilometer is extremely high.</THINK><SOLUTION>Monaco</SOLUTION><|endoftext|>Which planet is farthest from the Sun?<THINK>


# Pre Fine-tuning (SFT)

## Data preparation

In [None]:
# Use a subset of NVIDIA's Open Math Reasoning dataset, which was filtered to only include high quality DeepSeek R1 traces
sft_dataset = load_dataset('unsloth/OpenMathReasoning-mini', split='cot').to_pandas()
sft_dataset = sft_dataset[['expected_answer', 'problem', 'generated_solution']]

# Try converting to number - if not, replace with NaN
is_number = pd.to_numeric(pd.Series(sft_dataset['expected_answer']), errors='coerce').notnull()
# Keep only the rows whose answer parsed as a number; the original row labels are preserved,
# so the index is no longer contiguous after this step.
sft_dataset = sft_dataset.iloc[np.where(is_number)[0]] # Select only numbers
sft_dataset

README.md:   0%|          | 0.00/603 [00:00<?, ?B/s]

data/cot-00000-of-00001.parquet:   0%|          | 0.00/106M [00:00<?, ?B/s]

Generating cot split:   0%|          | 0/19252 [00:00<?, ? examples/s]

Unnamed: 0,expected_answer,problem,generated_solution
0,14,Given $\sqrt{x^2+165}-\sqrt{x^2-52}=7$ and $x$...,"<think>\nOkay, let's see. I need to solve the ..."
6,-2,Find the value of the parameter $a$ for which ...,"<think>\nOkay, so I need to find the value of ..."
9,18,What is the sum of all real numbers $x$ for wh...,"<think>\nOkay, so I need to solve the equation..."
13,2,Evaluate the sum \(\sum_{n=1}^\infty \frac{\ph...,"<think>\nOkay, so I need to evaluate the infin..."
17,30,What is the largest positive integer that divi...,"<think>\nAlright, so I need to find the larges..."
...,...,...,...
19243,244,"Let \( p \), \( q \), and \( r \) be the disti...","<think>\nOkay, so I need to find the value of ..."
19245,1,A bug is on the $0$ of a number line. At any p...,"<think>\nOkay, so I have this problem where a ..."
19247,4,A bus left point X for point Y. Two hours late...,"<think>\nOkay, let's tackle this problem step ..."
19248,18,Each interior angle of a regular n-gon measure...,"<think>\nOkay, let's see. I need to find the n..."


In [None]:
def format_dataset(x): # Format the dataset to follow our GRPO style formatting
    ''' Turn one dataset row into a system / user / assistant conversation.

    Reads the 'expected_answer', 'problem' and 'generated_solution' fields of `x`.
    The assistant turn is the reasoning trace wrapped in our reasoning tags, immediately
    followed by the expected answer wrapped in our solution tags.
    '''
    # Drop the dataset's own <think> / </think> markers, then trim surrounding whitespace
    trace = x['generated_solution']
    for marker in ('<think>', '</think>'):
        trace = trace.replace(marker, '')
    trace = trace.strip()

    # Re-wrap everything with our custom tags
    assistant_reply = ''.join([
        REASONING_START, trace, REASONING_END,
        SOLUTION_START, x['expected_answer'], SOLUTION_END,
    ])

    conversation = []
    for role, content in (
        ('system', SYSTEM_PROMPT),
        ('user', x['problem']),
        ('assistant', assistant_reply),
    ):
        conversation.append({'role': role, 'content': content})
    return conversation

# Render every row into a conversation, then preview the first one.
sft_dataset['messages'] = sft_dataset.apply(format_dataset, axis=1)
# Positional access: the numeric-answer filter left gaps in the row labels, so looking up
# label 0 would raise a KeyError whenever the first raw row happened to be filtered out.
print(tokenizer.apply_chat_template(sft_dataset['messages'].iloc[0], tokenize=False))

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Given $\sqrt{x^2+165}-\sqrt{x^2-52}=7$ and $x$ is positive, find all possible values of $x$.<THINK>Okay, let's see. I need to solve the equation √(x² + 165) - √(x² - 52) = 7, and find all positive values of x. Hmm, radicals can be tricky, but maybe if I can eliminate the square roots by squaring both sides. Let me try that.

First, let me write down the equation again to make sure I have it right:

√(x² + 165) - √(x² - 52) = 7.

Okay, so the idea is to isolate one of the radicals and then square both sides. Let me try moving the second radical to the other side:

√(x² + 165) = 7 + √(x² - 52).

Now, if I square both sides, maybe I can get rid of the square roots. Let's do that:

(√(x² + 165))² = (7 + √(x² - 52))².

Sim

In [None]:
# Keep only the SFT examples whose rendered conversation fits in half of max_seq_length, since we
# don't want too long reasoning traces (longer rows are dropped, not truncated)
sft_dataset['seq_length'] = sft_dataset['messages'].apply(lambda x: len(tokenizer.apply_chat_template(x)))
print('Token-length percentiles (50/90/99):', np.percentile(sft_dataset['seq_length'], [50, 90, 99]))

threshold = max_seq_length // 2 # Integer token budget (true division would yield a float such as 1024.0)
sft_dataset_filtered = sft_dataset.loc[sft_dataset['seq_length'] <= threshold].copy()
print(f'Remaining for training (<= {threshold} tokens): {len(sft_dataset_filtered)}/{len(sft_dataset)}')

# Render the final training text once, then hand the frame over to Hugging Face datasets
sft_dataset_filtered['text'] = tokenizer.apply_chat_template(sft_dataset_filtered['messages'].values.tolist(), tokenize=False)
sft_dataset_filtered = Dataset.from_pandas(sft_dataset_filtered)
sft_dataset_filtered

Token-length percentiles (50/90/99): [ 3729.    9034.   15685.84]
Remaining for training (<= 1024.0 tokens): 51/7507


Dataset({
    features: ['expected_answer', 'problem', 'generated_solution', 'messages', 'seq_length', 'text', '__index_level_0__'],
    num_rows: 51
})

## Pre fine-tune to understand custom GRPO formatting

In [None]:
# Short supervised warm-up on the filtered examples so the model picks up the tag format before RL.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset_filtered,
    args=SFTConfig(
        dataset_text_field='text',       # Column holding the fully rendered conversation
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,   # Together with the batch size above: one example per optimizer step
        optim='adamw_8bit',
        weight_decay=0.01,
        learning_rate=2e-4,
        lr_scheduler_type='cosine',
        warmup_steps=5,
        logging_steps=5,
        seed=3407,
        report_to='none', # 'none' disables experiment tracking; switch to 'wandb' to log to WandB
    )
)
trainer.train()

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/51 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 51 | Num Epochs = 3 | Total steps = 153
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 1 x 1) = 1
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)


Step,Training Loss
5,0.8727
10,0.7518
15,0.5836
20,0.4972
25,0.5197
30,0.4605
35,0.3991
40,0.4755
45,0.4506
50,0.4531


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=153, training_loss=0.3442266450987922, metrics={'train_runtime': 78.0589, 'train_samples_per_second': 1.96, 'train_steps_per_second': 1.96, 'total_flos': 1196951737651200.0, 'train_loss': 0.3442266450987922, 'epoch': 3.0})

## Check if model has learnt to follow the format

In [None]:
text = tokenizer.apply_chat_template( # Render system + user turns into a single string, ending with REASONING_START
    sft_dataset_filtered[1]['messages'][:2], # [:2] drops the reference assistant turn
    tokenize=False, add_generation_prompt=True, # Append the final REASONING_START (<THINK>)
)
# NOTE(review): temperature=0 is presumably meant as greedy decoding -- confirm how this
# transformers version treats it when do_sample is not set.
_ = model.generate(
    **tokenizer(text, return_tensors='pt').to('cuda'),
    temperature=0, max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=False), # Stream the model's generations (CoT + solution)
)

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>What is the average book width, in centimeters, of five books with the following widths: $6$, $\frac{1}{2}$, $1$, $2.5$, and $10$?<THINK>Okay, let's see. I need to find the average width of five books. The widths given are 6 cm, 1/2 cm, 1 cm, 2.5 cm, and 10 cm. Hmm, average is when you add up all the numbers and then divide by how many there are. So first, I should add these numbers together. Let me write them down: 6, 0.5 (which is 1/2), 1, 2.5, and 10. 

Wait, adding decimals can be tricky. Maybe I should convert all the numbers to decimals to make it easier. 1/2 is 0.5, so the numbers are 6, 0.5, 1, 2.5, and 10. Let me add those step by step. 

First, 6 + 0.5 is 6.5. Then add 1: 6.5 + 1 = 7.5. Next, add 2.5: 7.5 + 

In [None]:
# Drop the notebook-level references to the SFT data and release cached GPU memory before the RL stage.
# NOTE(review): `trainer` was built with train_dataset=sft_dataset_filtered, so that dataset stays
# referenced until `trainer` itself is deleted or rebound.
del sft_dataset, sft_dataset_filtered
gc.collect()
torch.cuda.empty_cache()

# Post Fine-tuning (RL)

## Data preparation

In [None]:
def process_dataset_sample(example): # Convert GSM8K example to conversation format for GRPO training
    ''' Build the GRPO prompt and the ground-truth answer for one example.

    Returns a dict with:
    - 'prompt': system + user messages (the system turn carries SYSTEM_PROMPT)
    - 'answer': the text following the first '####' marker in example['answer'] (stripped),
      or None when the marker is absent. This is the ground truth used by the reward functions.
    '''
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_turn = {'role': 'user', 'content': example['question']}

    # GSM8K format: 'Explanation... #### 42'
    pieces = example['answer'].split('####')
    if len(pieces) > 1:
        ground_truth = pieces[1].strip()
    else:
        ground_truth = None

    return {'prompt': [system_turn, user_turn], 'answer': ground_truth}

In [None]:
# Load the full GSM8K training split (for a quicker experiment, a slice such as split='train[:10%]' can be used instead)
train_dataset = load_dataset('openai/gsm8k', 'main', split='train')
# Adds the 'prompt' and 'answer' columns produced by process_dataset_sample
train_dataset = train_dataset.map(process_dataset_sample)

print(f'Training samples: {len(train_dataset):,}\n'
      f"- Sample question: {train_dataset[0]['prompt'][1]['content']}\n"
      f"- Sample answer (ground truth for rewards): {train_dataset[0]['answer']}\n"
      f"- Prompt (system + user):\n{train_dataset[0]['prompt']}")

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Training samples: 7,473
- Sample question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
- Sample answer (ground truth for rewards): 72
- Prompt (system + user):
[{'content': 'You are a mathematical reasoning assistant. When given a math problem:\n1. Show your step-by-step work between <THINK> and </THINK>.\n2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.\n3. Be precise and show all calculation steps clearly.', 'role': 'system'}, {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]


In [None]:
# Get the top 90% prompt length so we don't accidentally truncate them, i.e. we'll remove the top 10% long prompts
# First map: tokenize each rendered prompt (including the generation prompt); second map: record its token count.
tokenized_dataset = train_dataset.map(
    lambda x: {'tokens': tokenizer.apply_chat_template(x['prompt'], add_generation_prompt=True, tokenize=True)},
    batched=True,
).map(lambda x: {'length': len(x['tokens'])})
print(tokenizer.decode(tokenized_dataset[0]['tokens']))

thresholds = np.percentile(tokenized_dataset['length'], [50, 90, 99])
max_prompt_length = int(thresholds[1]) # 90th percentile, truncated to an integer token count
print('Token-length percentiles (50/90/99):', thresholds, '=> Choose max_prompt_length =', max_prompt_length)

# Filter only samples smaller than 90% max length
train_dataset = train_dataset.select(np.where(np.array(tokenized_dataset['length']) <= max_prompt_length)[0])
print(f'Remaining for training (<= {max_prompt_length} tokens): {len(train_dataset)}/{len(tokenized_dataset)}')
del tokenized_dataset # Only needed to measure lengths

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<THINK>
Token-length percentiles (50/90/99): [119. 152. 191.] => Choose max_prompt_length = 152
Remaining for training (<= 152 tokens): 6749/7473


## Regex Patterns

In [None]:
# Match the end of the reasoning section followed by the final answer.
# REASONING_START is deliberately not part of the pattern: the chat template always appends it
# to the generation prompt, so it is not expected inside the completion itself.
match_format = re.compile(
    rf'{re.escape(REASONING_END)}.*?'                             # Close of reasoning, then anything (non-greedy)
    rf'{re.escape(SOLUTION_START)}(.+?){re.escape(SOLUTION_END)}' # Solution section with capture group
    rf'[\s]{{0,}}(?:{re.escape(tokenizer.eos_token)})?'           # Add optional EOS token matching
    rf'[\s]{{0,}}$',                                              # Optional whitespace, then end of the response
    # No re.MULTILINE here: with it `$` also matches at every line break, so text trailing the
    # solution on a later line would still have counted as a "strict" format match.
    flags=re.DOTALL,                                              # Let . match newlines inside the sections
)
match_format.findall( # Verify it works
    f'{REASONING_START}Let me think!{REASONING_END}'\
    f'{SOLUTION_START}\n2\n{SOLUTION_END}\n\n',
)

['\n2\n']

In [None]:
# Sometimes it might not be 1 number as the answer, but like a sentence.
# For example: 'The solution is $20' -> we extract 20
# Thousands separators are kept in the capture (as in 123,456); callers strip them before parsing.
# The capture must contain at least one digit, so stray punctuation inside a sentence
# ('The answer, in dollars, is 20') is skipped instead of being returned as the "number".
match_number = re.compile(
    rf'{re.escape(SOLUTION_START)}.*?[\s]{{0,}}'    # Anything (non-greedy) after the solution tag ...
    r'(-?(?:\d+(?:,\d+)*(?:\.\d+)?|\.\d+))',        # ... up to the first number: -12 | 1,234.5 | .5
    flags=re.MULTILINE | re.DOTALL | re.IGNORECASE, # Flexible pattern matching
)
print(match_number.findall('<SOLUTION>  0.34  </SOLUTION>'))
print(match_number.findall('<SOLUTION>  123,456  </SOLUTION>'))
print(match_number.findall('<SOLUTION>  -0.234  </SOLUTION>'))
print(match_number.findall('<SOLUTION>17</SOLUTION>'))

['0.34']
['123,456']
['-0.234']
['17']


## Multi-reward design

In [None]:
def match_format_strictly(completions, **kwargs) -> list[float]:
    ''' Reward Function 1: Exact Format Compliance

    Scores the first message of every completion: 3.0 when `match_format` finds the full
    structured pattern in its content, 0.0 otherwise. Extra keyword arguments are ignored.
    '''
    scores = []
    for completion in completions:
        response = completion[0]['content']
        found = match_format.search(response)
        scores.append(3.0 if found else 0.0)
    return scores

In [None]:
# If it fails, reward the model if it at least follows the format partially, by counting each symbol
def match_format_softly(completions, **kwargs) -> list[float]:
    ''' Reward Function 2: Partial Format Credit

    Each expected tag contributes +0.5 when it occurs exactly once in the response and -0.5
    otherwise (missing or repeated), so a response scores between -1.5 and +1.5.
    REASONING_START is not scored because it is always prepended to the generation prompt.
    '''
    scored_tags = (REASONING_END, SOLUTION_START, SOLUTION_END)

    def tag_score(response):
        # Exactly one occurrence is rewarded; zero or several occurrences are penalized
        return sum(0.5 if response.count(tag) == 1 else -0.5 for tag in scored_tags)

    return [tag_score(completion[0]['content']) for completion in completions]

In [None]:
# Extract the generated answer, and reward or penalize it
def check_answer_correctness(completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 3: Graduated scoring for mathematical accuracy

    The answer is taken from the solution section captured by `match_format`.
    -  5.0: Exact string match gets full points
    -  3.5: Match once surrounding whitespace is stripped
    -  2.0: Numerically within 10% of the ground truth (close answer)
    -  1.5: Numerically within 20% of the ground truth (reasonable attempt)
    - -2.5: Wrong answer (penalty for incorrect math)
    - -4.5: The answer or the ground truth cannot be parsed as a number
    - -2.0: No extractable answer (the response does not follow the format)
    -  0.0: No ground truth available (`answer` entry is None), so nothing can be judged

    Thousands separators are ignored when comparing numerically ('1,000' == '1000').
    When the ground truth is 0 the ratio is undefined, so the guess is compared to zero directly.

    Args:
        completions: one list of messages per sample; only the first message's 'content' is scored.
        answer: ground-truth strings (or None), aligned with `completions`.
    Returns:
        One reward per (completion, answer) pair.
    '''
    def to_number(text):
        # Same normalisation as check_numbers_extraction: drop thousands separators
        return float(text.strip().replace(',', ''))

    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [ # Extract answers using format pattern
        found.group(1) if (found := match_format.search(r)) else None
        for r in responses
    ]
    rewards = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None: # No extractable answer
            rewards.append(-2.0)
            continue
        if true_answer is None: # No ground truth to compare against
            rewards.append(0.0)
            continue

        if guess == true_answer: # Correct answer gets 5 points!
            rewards.append(5.0)
            continue
        if guess.strip() == true_answer.strip(): # Match if spaces are seen, but less reward
            rewards.append(3.5)
            continue

        # Otherwise try a numerical comparison for partial credit
        try:
            guess_val = to_number(guess)
            true_val = to_number(true_answer)
        except ValueError:
            rewards.append(-4.5) # Invalid numerical format
            continue

        if true_val == 0: # Ratio is undefined: fall back to a direct comparison
            rewards.append(2.0 if guess_val == 0 else -2.5)
            continue

        ratio = guess_val / true_val # If the answer is within some range, reward it!
        if 0.9 <= ratio <= 1.1: rewards.append(2.0)   # Within 10%
        elif 0.8 <= ratio <= 1.2: rewards.append(1.5) # Within 20%
        else: rewards.append(-2.5)                    # Penalize wrong answers
    return rewards

In [None]:
def check_numbers_extraction(prompts, completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 4: Number Extraction Ability
    Tests the model's ability to extract numerical values from solution sections
    Complementary to exact format matching - focuses on parsing capability
    -  3.5: Extracted number equals the true answer numerically
    - -1.5: Extracted number differs from the true answer
    - -2.5: `match_number` found no number in the response
    -  0.0: Extracted text or true answer cannot be parsed as a float

    Thousands separators (commas) are removed from both the guess and the true
    answer before parsing, so '1,000' and '1000' compare equal on either side.

    Side effect: a call counter is kept on the function object, and on every
    100th call the question, first prediction and first response are printed.

    Args:
        prompts: Chat-formatted prompts; the last message of the first prompt
            is used as the question in the debug print.
        completions: One entry per completion; the text is read from
            `completion[0]['content']`.
        answer: Ground-truth answer strings, aligned with `completions`.
        **kwargs: Accepted and ignored.

    Returns:
        One reward per (completion, answer) pair, in input order.
    '''
    question = prompts[0][-1]['content'] # Exclude system prompt
    responses = [completion[0]['content'] for completion in completions]

    extracted_responses = [ # Extract numbers from solution sections using number pattern
        guess.group(1) if (guess := match_number.search(r)) else None
        for r in responses
    ]
    rewards = []

    # Print only every few steps
    check_numbers_extraction.counter = getattr(check_numbers_extraction, 'counter', 0) + 1
    if check_numbers_extraction.counter % 100 == 0:
        print(
            '==' * 100,
            f'\nQuestion: {question}'
            f'\nPrediction: {extracted_responses[0]}, GT Answer: {answer[0]}'
            f'\nResponse:\n{responses[0]}'
        )
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None: # No extractable number
            rewards.append(-2.5)
            continue

        try: # Simple numerical equality check
            # Remove commas like in 123,456 on BOTH sides before converting to numbers
            true_val = float(true_answer.strip().replace(',', ''))
            guess_val = float(guess.strip().replace(',', ''))
            rewards.append(3.5 if guess_val == true_val else -1.5)
        except (ValueError, TypeError):
            rewards.append(0.0) # Invalid number format
    return rewards

## GRPO training setup

In [None]:
# Prompt budget: 152 plus a safety margin of 1; the rest of the sequence is left for the completion
max_prompt_length = 152 + 1
max_completion_length = max_seq_length - max_prompt_length

# Sampling options used for generation during GRPO training:
# generation stops at the EOS token, which is kept in the output text
_vllm_sampling_options = {
    'min_p': 0.1,
    'top_p': 1.0,
    'top_k': -1,
    'stop': [tokenizer.eos_token],
    'include_stop_str_in_output': True,
}
vllm_sampling_params = SamplingParams(**_vllm_sampling_options)

In [None]:
# Configure GRPO training parameters for mathematical reasoning.
# All keyword arguments are collected in one dict (grouped by concern) and unpacked into GRPOConfig.
grpo_config_kwargs = {
    'output_dir': f'/tmp/{model_name}',              # Directory for checkpoints and logs
    'vllm_sampling_params': vllm_sampling_params,

    # --- Training speed control ---
    'num_train_epochs': 1,                           # Total number of training epochs
    'per_device_train_batch_size': 2,                # Small batch for GPU memory constraints
    'gradient_accumulation_steps': 8,                # Effective batch size = 2 * 8 = 16

    # --- Computing the loss: https://huggingface.co/docs/trl/main/grpo_trainer#computing-the-loss ---
    'scale_rewards': 'batch',                        # Mean at local/group level, std at global/batch level for more robust reward shaping
    'loss_type': 'dr_grpo',                          # Removes response length bias by dividing by a constant instead of the sequence length

    # --- Precision & optimization ---
    'optim': 'adamw_8bit',                           # Alternatives: adamw_torch_fused, paged_adamw_8bit
    'weight_decay': 0.1,                             # Regularization
    'max_grad_norm': 0.1,                            # Aggressive gradient clipping for stable training
    'gradient_checkpointing': True,
    'bf16': torch.cuda.is_available(),               # Mixed-precision training whenever a CUDA GPU is available

    # --- Learning rate scheduling ---
    'learning_rate': 1e-5,                           # Conservative LR to prevent destabilizing reasoning
    'warmup_ratio': 0.1,
    'lr_scheduler_type': 'cosine_with_min_lr',
    'lr_scheduler_kwargs': {'min_lr': 1e-6},

    # --- Generation control ---
    'temperature': 1.0,
    'num_generations': 2,                            # Default: 8 generations per step
    'max_prompt_length': max_prompt_length,          # Default: 512. Sufficient for complex word problems
    'max_completion_length': max_completion_length,  # Default: 256. Room for detailed step-by-step reasoning

    # --- Reporting and saving ---
    'report_to': 'wandb',
    'logging_steps': 10,
    'logging_strategy': 'steps',
    'save_total_limit': 1,
    # 'max_steps': 100,

    # --- Optional evaluation (disabled) ---
    # 'per_device_eval_batch_size': 4,
    # 'bf16_full_eval': torch.cuda.is_available(),
    # 'eval_strategy': 'steps',
    # 'load_best_model_at_end': True,                # Load the best model based on validation loss
}
training_args = GRPOConfig(**grpo_config_kwargs)

## Train the model

In [26]:
%%time
# The 4 complementary reward functions of the multi-reward system, in the order they are passed to the trainer
grpo_reward_funcs = [
    match_format_strictly,    # Perfect structure compliance
    match_format_softly,      # Partial format credit
    check_answer_correctness, # Mathematical accuracy
    check_numbers_extraction, # Number parsing ability
]

# Build the GRPO trainer around the LoRA-adapted quantized model and the processed GSM8K training split
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    reward_funcs=grpo_reward_funcs,
)

# Run training, then write the trained model to ./<model_name>_grpo
trainer.train()
trainer.save_model(f'./{model_name}_grpo')

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,749 | Num Epochs = 1 | Total steps = 843
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mwalteryeyint[0m ([33mwalteryeyint-university-of-technology-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std,rewards / check_numbers_extraction / mean,rewards / check_numbers_extraction / std
10,0.2561,10.671875,4.81173,839.05625,455.1,1613.6,0.04375,792.29043,455.1,1406.9,0.315366,2.83125,0.579939,1.3625,0.50247,3.834375,2.34847,2.64375,1.775513
20,0.2281,10.303125,5.179288,814.13125,459.1,1706.2,0.03125,778.455023,459.1,1418.8,0.288088,2.86875,0.47747,1.35,0.529929,3.653125,2.63255,2.43125,2.010059
30,0.1895,10.959375,4.54671,849.3375,502.4,1629.5,0.04375,804.677594,502.4,1339.8,0.296794,2.8125,0.654939,1.3375,0.579929,3.984375,2.087933,2.825,1.638401
40,0.2054,10.640625,5.100795,898.425,513.3,1689.7,0.0625,832.660413,513.3,1524.1,0.299149,2.75625,0.671807,1.29375,0.586094,3.903125,2.322403,2.6875,1.850548
50,0.2348,10.803125,4.894794,857.6125,446.8,1700.3,0.0625,786.148584,446.8,1433.3,0.295974,2.79375,0.578342,1.30625,0.528342,3.996875,2.23073,2.70625,1.649115
60,0.2388,10.171875,5.58958,903.275,491.4,1717.2,0.05,848.36911,491.4,1469.0,0.279545,2.79375,0.555098,1.325,0.491868,3.571875,2.832211,2.48125,2.072265
70,0.1508,11.19375,4.461209,834.0625,469.0,1684.8,0.05625,769.620917,469.0,1322.9,0.276271,2.8125,0.607409,1.34375,0.482409,4.18125,2.013164,2.85625,1.627516
80,0.0502,11.49375,3.743903,817.74375,481.8,1420.7,0.0125,804.720422,481.8,1303.6,0.310249,2.94375,0.225,1.45625,0.175,4.16875,2.074445,2.925,1.392046
90,0.2101,10.753125,4.48311,810.7875,439.7,1440.5,0.0375,768.363995,439.7,1327.5,0.289901,2.8875,0.236634,1.40625,0.209164,3.834375,2.369889,2.625,1.811529
100,0.1529,11.309375,4.388125,870.98125,470.5,1627.4,0.05,817.874084,470.5,1444.8,0.301106,2.83125,0.466868,1.34375,0.416868,4.246875,1.997014,2.8875,1.630611


Question: James wants to build a 16-foot by 20-foot quilt.  He uses patches that are each 4 square feet.  The first 10 patches cost $10 each and then each patch after that cost half as much.  How much do the patches for the quilt cost?
Prediction: 450, GT Answer: 450
Response:
Okay, let's see. James needs to make a quilt that's 16 feet by 20 feet. Each patch is 4 square feet. So first, I need to figure out how many patches he needs in total. Then, depending on how many patches are in the first 10, the rest will cost half as much. Let me start by calculating the total area of the quilt.

The area of the quilt is length times width, so 16 feet times 20 feet. Let me do that: 16 times 20. Hmm, 16 times 20 is 320 square feet. Each patch covers 4 square feet, so the total number of patches needed is 320 divided by 4. Let's compute that: 320 ÷ 4 = 80 patches. Okay, so the quilt requires 80 patches.

Now, the problem states that the first 10 patches cost $10 each, and then each patch after tha

Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std,rewards / check_numbers_extraction / mean,rewards / check_numbers_extraction / std
10,0.2561,10.671875,4.81173,839.05625,455.1,1613.6,0.04375,792.29043,455.1,1406.9,0.315366,2.83125,0.579939,1.3625,0.50247,3.834375,2.34847,2.64375,1.775513
20,0.2281,10.303125,5.179288,814.13125,459.1,1706.2,0.03125,778.455023,459.1,1418.8,0.288088,2.86875,0.47747,1.35,0.529929,3.653125,2.63255,2.43125,2.010059
30,0.1895,10.959375,4.54671,849.3375,502.4,1629.5,0.04375,804.677594,502.4,1339.8,0.296794,2.8125,0.654939,1.3375,0.579929,3.984375,2.087933,2.825,1.638401
40,0.2054,10.640625,5.100795,898.425,513.3,1689.7,0.0625,832.660413,513.3,1524.1,0.299149,2.75625,0.671807,1.29375,0.586094,3.903125,2.322403,2.6875,1.850548
50,0.2348,10.803125,4.894794,857.6125,446.8,1700.3,0.0625,786.148584,446.8,1433.3,0.295974,2.79375,0.578342,1.30625,0.528342,3.996875,2.23073,2.70625,1.649115
60,0.2388,10.171875,5.58958,903.275,491.4,1717.2,0.05,848.36911,491.4,1469.0,0.279545,2.79375,0.555098,1.325,0.491868,3.571875,2.832211,2.48125,2.072265
70,0.1508,11.19375,4.461209,834.0625,469.0,1684.8,0.05625,769.620917,469.0,1322.9,0.276271,2.8125,0.607409,1.34375,0.482409,4.18125,2.013164,2.85625,1.627516
80,0.0502,11.49375,3.743903,817.74375,481.8,1420.7,0.0125,804.720422,481.8,1303.6,0.310249,2.94375,0.225,1.45625,0.175,4.16875,2.074445,2.925,1.392046
90,0.2101,10.753125,4.48311,810.7875,439.7,1440.5,0.0375,768.363995,439.7,1327.5,0.289901,2.8875,0.236634,1.40625,0.209164,3.834375,2.369889,2.625,1.811529
100,0.1529,11.309375,4.388125,870.98125,470.5,1627.4,0.05,817.874084,470.5,1444.8,0.301106,2.83125,0.466868,1.34375,0.416868,4.246875,1.997014,2.8875,1.630611


Question: While Greg was camping with his family for a week, it rained for 3 days. When he looked at the weather records, he saw that the amount of rain was 3 mm, 6 mm, and 5 mm on the three days. During the same week, it rained 26 mm at his house. How much less rain did Greg experience while camping?
Prediction: 12, GT Answer: 12
Response:
Okay, let's see. So the question is asking for how much less rain Greg experienced while camping compared to the 26 mm of rain he faced at his house. 

The problem mentions that Greg camped for seven days, but only three days had rain. The amounts for those three days were 3 mm, 6 mm, and 5 mm. The total rain during his camping trip should be the sum of those three amounts. Then after calculating that, we can subtract it from the 26 mm of rain he experienced at his house to find the difference.

Let me check the math step by step. 

First, sum up the rain Greg got while camping:
3 mm + 6 mm + 5 mm = 14 mm.

Then, subtract that amount from the 26 mm 

# Evaluation

## Resource usage

In [27]:
# Memory stats: device capacity and peak reserved memory so far, both in GiB (3 decimals)
# NOTE(review): in the recorded run this cell executed AFTER training (In [27] follows the
# training cell In [26]), so `start_gpu_memory` already equals the training peak. That is why
# the next cell reports "VRAM for training: -0.0 GB". Run this cell before `trainer.train()`
# to obtain a meaningful baseline.
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024**3, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

print(f'GPU = {gpu_stats.name}. Max memory = {max_memory} GB.')
print(f'{start_gpu_memory} GB of memory reserved.')

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
38.053 GB of memory reserved.


In [28]:
# Runtime info comes from the last entry of the trainer's log history
last_log = trainer.state.log_history[-1]
train_seconds = last_log['train_runtime']
samples_per_second = last_log.get('train_samples_per_second')

# Peak reserved GPU memory (GiB) and its share of the device capacity;
# `start_gpu_memory` and `max_memory` come from the memory stats cell
used_memory   = round(torch.cuda.max_memory_reserved() / 1024**3, 2)
used_for_lora = round(used_memory - start_gpu_memory, 2)
used_pct      = round(used_memory / max_memory * 100, 2)
lora_pct      = round(used_for_lora / max_memory * 100, 2)

# Assemble the report line by line; throughput is only shown when it was logged
summary_lines = [f'Training time: {train_seconds:.1f} seconds ({train_seconds / 60:.2f} minutes)']
if samples_per_second:
    summary_lines.append(f'Throughput: {samples_per_second:.1f} samples/second')
summary_lines.append(f'Peak VRAM usage: {used_memory} GB ({used_pct}% of max memory)')
summary_lines.append(f'VRAM for training: {used_for_lora} GB ({lora_pct}% of max memory)')
print('\n'.join(summary_lines))

Training time: 12945.5 seconds (215.76 minutes)
Throughput: 0.5 samples/second
Peak VRAM usage: 38.05 GB (96.19% of max memory)
VRAM for training: -0.0 GB (-0.0% of max memory)


## Verify LoRA is actually trained

In [None]:
example_text = 'What is the sqrt of 101?'
# example_text = 'Solve (x + 2)^2 = 0'
# example_text = "How many r's are in strawberry?"

# Sampling settings shared by the generation cells that follow
sampling_params = SamplingParams(
    max_tokens=max_completion_length,
    temperature=1.0,
    top_k=50,
)

# Baseline: raw question text, no LoRA adapter applied
baseline_outputs = model.fast_generate(
    example_text,
    sampling_params=sampling_params,
    lora_request=None,
)
print(baseline_outputs[0].outputs[0].text)

In [None]:
# Verify that no saved adapter tensor (neither the LoRA A nor the B matrices) is entirely zero.
# The zero COUNT is compared with the element count; comparing the zero FRACTION with
# `numel()` could never fail for an all-zero tensor with more than one element.
with safe_open(f'./{model_name}_grpo/adapter_model.safetensors', framework='pt') as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        n_zeros = int((tensor == 0).sum().item())
        assert n_zeros < tensor.numel(), (
            f'LoRA tensor {key!r} is entirely zero ({n_zeros}/{tensor.numel()} elements)'
        )

In [None]:
# Generate with the trained LoRA but WITHOUT the system prompt;
# this should not (or only minimally) affect the model's original reasoning ability
user_only_chat = [{'role': 'user', 'content': example_text}]
text = tokenizer.apply_chat_template(
    user_only_chat,
    tokenize=False,
    add_generation_prompt=True,
)

grpo_lora_request = model.load_lora(f'./{model_name}_grpo')
lora_outputs = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=grpo_lora_request,
)
print(lora_outputs[0].outputs[0].text)

In [None]:
# Build the prompt WITH the system prompt this time
system_and_user_chat = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user'  , 'content': example_text},
]
text = tokenizer.apply_chat_template(
    system_and_user_chat,
    tokenize=False,
    add_generation_prompt=True,
)

# First the reference: system prompt but no LoRA
base_outputs = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=None,
)
print(base_outputs[0].outputs[0].text)

# Then the GRPO-trained adapter. The reasoning model is much better, but it is not always
# correct after a single training epoch (about 3.6 hours in the recorded run).
# It'll be better if we extend the sequence length and train for longer
grpo_lora_request = model.load_lora(f'./{model_name}_grpo')
lora_outputs = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=grpo_lora_request,
)
print(lora_outputs[0].outputs[0].text)

## Performance on Test set

In [33]:
# test_dataset = load_dataset('openai/gsm8k', 'main', split=['test[:10%]'])
# Full GSM8K test split, processed the same way as the training samples
test_dataset = load_dataset('openai/gsm8k', 'main', split='test').map(process_dataset_sample)

# One chat-formatted prompt string per test sample: system prompt + the sample's user message
test_texts = []
for sample in test_dataset:
    user_message = sample['prompt'][1]['content']
    chat = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_message},
    ]
    test_texts.append(
        tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    )

print('Testing samples:', len(test_dataset))

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

Testing samples: 1319


In [34]:
# Generate answers for every test prompt twice: first with the GRPO LoRA adapter, then without it
grpo_lora_request = model.load_lora(f'./{model_name}_grpo')
outputs_with_lora = model.fast_generate(
    test_texts,
    lora_request=grpo_lora_request,
    sampling_params=sampling_params,
)
outputs_without_lora = model.fast_generate(
    test_texts,
    lora_request=None,
    sampling_params=sampling_params,
)

Adding requests:   0%|          | 0/1319 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1319 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

Adding requests:   0%|          | 0/1319 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1319 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

In [35]:
# Compare the correct amount of using and not using LoRA
def _count_correct(outputs, answers):
    '''Count how many generations have a correct format, a correct answer, and both.

    Args:
        outputs: Generation results; the text is read from `output.outputs[0].text`.
        answers: Ground-truth answer strings, aligned with `outputs`.

    Returns:
        Tuple `(format_cnt, answer_cnt, all_cnt)`.
    '''
    format_cnt = answer_cnt = all_cnt = 0
    for output, answer in zip(outputs, answers):
        text = output.outputs[0].text
        correct_format = match_format.search(text)
        correct_answer = (guess := match_number.search(text)) and guess.group(1) == answer
        if correct_format: format_cnt += 1
        if correct_answer: answer_cnt += 1
        if correct_format and correct_answer: all_cnt += 1
    return format_cnt, answer_cnt, all_cnt


def _format_count(count):
    '''Render a count as "count/total (pct%)" relative to the test set size.'''
    return f'{count}/{num_test_samples} ({count / num_test_samples * 100:.2f}%)'


def _format_delta(delta):
    '''Render a count difference with an explicit sign, so a regression shows as "-5", not "+-5".'''
    return f'{delta:+d} ({delta / num_test_samples * 100:.2f}%)'


num_test_samples = len(test_dataset)
test_answers = list(test_dataset['answer'])

lora_correct_format_cnt, lora_correct_answer_cnt, lora_correct_all_cnt = _count_correct(
    outputs_with_lora, test_answers)
no_lora_correct_format_cnt, no_lora_correct_answer_cnt, no_lora_correct_all_cnt = _count_correct(
    outputs_without_lora, test_answers)

metric_labels = ('Correct Format', 'Correct Answer', 'Correct Both')
lora_counts = (lora_correct_format_cnt, lora_correct_answer_cnt, lora_correct_all_cnt)
no_lora_counts = (no_lora_correct_format_cnt, no_lora_correct_answer_cnt, no_lora_correct_all_cnt)

pd.DataFrame({
    'Without LoRA': {label: _format_count(cnt) for label, cnt in zip(metric_labels, no_lora_counts)},
    'With LoRA': {label: _format_count(cnt) for label, cnt in zip(metric_labels, lora_counts)},
    'Improvement': {
        label: _format_delta(with_cnt - without_cnt)
        for label, with_cnt, without_cnt in zip(metric_labels, lora_counts, no_lora_counts)
    },
}).T

Unnamed: 0,Correct Format,Correct Answer,Correct Both
Without LoRA,461/1319 (34.95%),316/1319 (23.96%),159/1319 (12.05%)
With LoRA,1306/1319 (99.01%),1014/1319 (76.88%),1013/1319 (76.80%)
Improvement,+845 (64.06%),+698 (52.92%),+854 (64.75%)


# Inference

In [36]:
def generate_with_reasoning(questions, max_completion_length=512):
    '''Generate answers for a batch of questions using the system prompt and chat template.

    Args:
        questions: Iterable of question strings.
        max_completion_length: Maximum number of new tokens to generate per question.

    Returns:
        Tuple `(responses, inference_duration, num_generated_tokens)`:
        the decoded generated texts (one per question), the wall-clock seconds spent
        in `model.generate`, and the length of the generated part of the output
        tensor (for a batch this is the longest generation, padding included).
        Empty input returns `([], 0.0, 0)` without calling the model.
    '''
    questions = list(questions)
    if not questions: # Nothing to generate; the tokenizer cannot encode an empty batch
        return [], 0.0, 0

    conversations = [[                        # Format input using conversation template
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': question},
    ] for question in questions]

    prompts = [tokenizer.apply_chat_template( # Apply chat template (returns strings; tokenized below)
        conversation,
        add_generation_prompt=True,         # Add assistant prompt
        tokenize=False,                     # Return string, not tokens
    ) for conversation in conversations]

    # Left padding keeps every prompt adjacent to its generated tokens, so one slice
    # at the padded prompt length separates prompt from completion for the whole batch
    inputs = tokenizer(
        prompts, return_tensors='pt', padding=True, padding_side='left',
    ).to(model.device)
    prompt_length = inputs['input_ids'].shape[1]

    # TextStreamer only supports batch size 1, so stream only for a single question
    streamer = TextStreamer(tokenizer, skip_prompt=True) if len(questions) == 1 else None

    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(           # Generate response with reasoning-optimized parameters
            **inputs,
            max_new_tokens=max_completion_length,
            temperature=0.7,                # Balance creativity and consistency
            top_p=0.9,                      # Nucleus sampling for quality
            do_sample=True,                 # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            # NOTE(review): length_penalty and early_stopping are beam-search options;
            # presumably they have no effect with sampling -- kept unchanged, confirm before removing
            length_penalty=1.0,
            early_stopping=True,
            streamer=streamer,
        )
    inference_duration = time.time() - start_time

    generated_ids = output_ids[:, prompt_length:] # Keep only the generated portion
    num_generated_tokens = generated_ids.shape[1]
    responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return responses, inference_duration, num_generated_tokens

In [37]:
# Take the first GSM8K test sample and stream the model's reasoning for it
test_dataset = load_dataset('openai/gsm8k', 'main', split='test').map(process_dataset_sample)
first_sample = test_dataset[0]
gsm8k_question = first_sample['question']
expected_answer = first_sample['answer']

print('Question:', gsm8k_question, '\nResponse:')
gsm8k_responses, inference_duration, num_generated_tokens = generate_with_reasoning(
    [gsm8k_question],
    max_completion_length=max_completion_length,
)
gsm8k_response = gsm8k_responses[0]

# Timing and output size reported by generate_with_reasoning
print('Inference time (secs):', inference_duration)
print('Generated tokens:', num_generated_tokens)

Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 
Response:
Okay, let's see... So Janet has some ducks that lay eggs each day. We need to figure out how many eggs she sells to her customers at the farmer's market.

First, we know that there are two groups of people who eat eggs: herself (for breakfast) and her friends (who get baked muffins). Each person consumes one egg, so together they use up 3 + 4 = 7 eggs per day. That means the remaining number of eggs is 16 - 7 = 9 eggs left over after Janet uses them for food. Then she can sell those 9 eggs to the market at $2 each. 

So multiplying the number of leftover eggs by the price per egg should give us the total earnings from selling those eggs. Let me calculate that now.
To determine how much mon

In [38]:
# Validate format compliance
has_solution = SOLUTION_START in gsm8k_response and SOLUTION_END in gsm8k_response
print('Reasoning section:', REASONING_END in gsm8k_response)
print('Solution section:', has_solution)

if has_solution: # Check answer accuracy if solution section exists
    # Both markers are present, so the text between them can be split out directly
    solution_text = gsm8k_response.split(SOLUTION_START)[1].split(SOLUTION_END)[0].strip()
    print('Extracted:', solution_text)
    print('Expected:', expected_answer)
    try:
        # Compare as numbers (thousands separators removed). Keeping only digits would
        # drop the sign and decimal point, making e.g. '1.8' or '-18' equal to '18'.
        is_correct = float(solution_text.replace(',', '')) == float(expected_answer.replace(',', ''))
    except ValueError:
        # Non-numeric text on either side: fall back to an exact text comparison
        is_correct = solution_text == expected_answer.strip()
    print('Correct:', is_correct)

Reasoning section: True
Solution section: True
Extracted: 18
Expected: 18
Correct: True


In [43]:
# List the current working directory to see which artifact directories exist
!ls

drive			      sample_data
grpo_trainer_lora_model       unsloth_compiled_cache
huggingface_tokenizers_cache  unsloth_training_checkpoints
qwen3-1.7b-base_grpo	      wandb


In [44]:
# Copy every directory listed by the previous cell into the Drive project folder.
# NOTE(review): presumably only `qwen3-1.7b-base_grpo` (the adapter written by `trainer.save_model`),
# the checkpoints and the wandb logs are results worth keeping; `sample_data` and the `*_cache`
# directories look regenerable -- confirm before copying them on every run.
!cp -r sample_data '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r grpo_trainer_lora_model '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r unsloth_compiled_cache '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r huggingface_tokenizers_cache '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r unsloth_training_checkpoints '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r qwen3-1.7b-base_grpo '/content/drive/My Drive/multi-reward-math-reasoning'
!cp -r wandb '/content/drive/My Drive/multi-reward-math-reasoning'