# Environment Setup

In [1]:
!uv pip install trackio bitsandbytes xformers triton unsloth vllm==0.10.2
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2

[2mUsing Python 3.12.11 environment at: /usr[0m
[2mAudited [1m6 packages[0m [2min 112ms[0m[0m
[2mUsing Python 3.12.11 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 113ms[0m[0m
[2mUsing Python 3.12.11 environment at: /usr[0m
[2mAudited [1m1 package[0m [2min 98ms[0m[0m


In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/multi-reward-math-reasoning # Add this folder as shortcut to your Drive to save your results here
!ls

In [2]:
from unsloth import FastLanguageModel
from trl import SFTConfig, GRPOConfig, SFTTrainer, GRPOTrainer
from vllm import SamplingParams

import gc
import re
import time
import torch
import trackio
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm.notebook import tqdm
from datasets import load_dataset, Dataset
from safetensors import safe_open

from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig, TextStreamer
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 10-06 20:25:11 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


# Model Setup

## With Hugging Face (very slow)

In [None]:
model_id = 'Qwen/Qwen3-4B-Base'               # Select model optimized for instruction-following and reasoning
model_name = model_id.split('/')[-1].lower()  # Extract model name from ID
max_seq_length = 2048                         # Can increase for longer reasoning traces
lora_rank = 32                                # Larger rank = smarter, but slower

In [None]:
bnb_config = BitsAndBytesConfig(           # Configure 4-bit quantization for ~75% memory reduction
    load_in_4bit=True,                     # Enable 4-bit precision (vs 16-bit default)
    bnb_4bit_quant_type='nf4',             # NormalFloat4: optimal for neural network weights
    bnb_4bit_compute_dtype=torch.float16,  # Use fp16 for forward/backward passes
    bnb_4bit_use_double_quant=True,        # Further quantize quantization constants
)

In [None]:
model = AutoModelForCausalLM.from_pretrained( # Load model with quantization and automatic device mapping
    model_id,
    max_length=max_seq_length,                # Token limit for mathematical problems (reduce if OOM)
    # quantization_config=bnb_config,           # Apply 4-bit quantization
    device_map='auto',                        # Auto-distribute across available GPUs/CPU
    trust_remote_code=True,                   # Allow custom model code execution
    dtype=torch.float16,                      # Use fp16 for non-quantized operations
)
print(f'Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M')
print(f"Quantized parameters: {sum(p.numel() for p in model.parameters() if hasattr(p, 'quant_type')) / 1e6:.2f}M")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Model parameters: 4022.47M
Quantized parameters: 0.00M


In [None]:
lora_config = LoraConfig(
    r=lora_rank,                           # Rank: adaptation capacity (16 good for reasoning tasks)
    lora_alpha=lora_rank * 2,              # Scaling factor (typically 2x rank)
    lora_dropout=0.1,                      # Regularization to prevent overfitting
    target_modules=[                       # Remove QKVO if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    task_type=TaskType.CAUSAL_LM,          # Causal language modeling task
    bias='none',                           # Skip bias adaptation for simplicity
)
model = get_peft_model(model, lora_config) # Apply LoRA configuration to create trainable adapter
model.print_trainable_parameters()         # Shows trainable vs total parameters

trainable params: 66,060,288 || all params: 4,088,528,384 || trainable%: 1.6157


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
if tokenizer.pad_token is None or tokenizer.pad_token_id is None:
    # Ensure tokenizer has proper padding token for batch processing
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
else: print(f'Pad token ({tokenizer.pad_token_id}): {tokenizer.pad_token}')
print(f'EOS token ({tokenizer.eos_token_id}): {tokenizer.eos_token}')

Pad token (151643): <|endoftext|>
EOS token (151643): <|endoftext|>


## With Unsloth (much better)

In [3]:
model_id = 'unsloth/Qwen3-1.7B-Base'          # Select model optimized for instruction-following and reasoning
model_name = model_id.split('/')[-1].lower()  # Extract model name from ID
max_seq_length = 2048                         # Can increase for longer reasoning traces
lora_rank = 32                                # Larger rank = smarter, but slower

In [4]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_id,
    max_seq_length = max_seq_length,
    load_in_4bit = False,         # False for LoRA 16bit
    fast_inference = True,        # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.9, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,                          # Rank: adaptation capacity (16 good for reasoning tasks)
    lora_alpha = lora_rank * 2,             # Scaling factor (typically 2x rank)
    lora_dropout=0.1,                       # Regularization to prevent overfitting
    target_modules = [                      # Remove QKVO if out of memory
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ],
    use_gradient_checkpointing = 'unsloth', # Reduces memory usage
    random_state = 3407,
)

INFO 10-06 20:25:23 [vllm_utils.py:689] Unsloth: Patching vLLM v1 graph capture
INFO 10-06 20:25:23 [vllm_utils.py:717] Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.10.1: Fast Qwen3 patching. Transformers: 4.55.4. vLLM: 0.10.2.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Qwen3-1.7B-Base with actual GPU utilization = 88.97%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 2048. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 31.92 GB. Also swap space = 6 GB.
Unsloth: Not an error, but `device` is

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 10-06 20:25:42 [default_loader.py:268] Loading weights took 1.05 seconds
INFO 10-06 20:25:42 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 10-06 20:25:44 [gpu_model_runner.py:2392] Model loading took 3.2919 GiB and 1.584866 seconds
INFO 10-06 20:25:56 [backends.py:539] Using cache directory: /root/.cache/vllm/torch_compile_cache/c4e82799bf/rank_0_0/backbone for vLLM's torch.compile
INFO 10-06 20:25:56 [backends.py:550] Dynamo bytecode transform time: 11.24 s
INFO 10-06 20:26:01 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 3.723 s
INFO 10-06 20:26:03 [monitor.py:34] torch.compile takes 11.24 s in total
INFO 10-06 20:26:04 [gpu_worker.py:298] Available KV cache memory: 30.13 GiB
INFO 10-06 20:26:05 [kv_cache_utils.py:864] GPU KV cache size: 282,080 tokens
INFO 10-06 20:26:05 [kv_cache_utils.py:868] Maximum concurrency for 2,048 tokens per request: 137.73x
INFO 10-06 20:26:05 [vllm_utils.py:694] Unsloth: Running patched vLLM v1 `

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 67/67 [00:09<00:00,  7.25it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 43/43 [00:05<00:00,  7.27it/s]

INFO 10-06 20:26:20 [gpu_model_runner.py:3118] Graph capturing finished in 15 secs, took 0.98 GiB
INFO 10-06 20:26:20 [vllm_utils.py:701] Unsloth: Patched vLLM v1 graph capture finished in 15 secs.





INFO 10-06 20:26:22 [gpu_worker.py:391] Free memory on device (39.03/39.56 GiB) on startup. Desired GPU memory utilization is (0.88969201168322, 35.19 GiB). Actual usage is 3.29 GiB for weight, 1.75 GiB for peak activation, 0.02 GiB for non-torch memory, and 0.98 GiB for CUDAGraph memory. Replace gpu_memory_utilization config with `--kv-cache-memory=31143675494` to fit into requested memory, or `--kv-cache-memory=35262783488` to fully utilize gpu memory. Current kv cache memory in use is 32351635046 bytes.
INFO 10-06 20:26:22 [core.py:218] init engine (profile, create kv cache, warmup model) took 38.27 seconds
INFO 10-06 20:26:23 [llm.py:295] Supported_tasks: ('generate',)
INFO 10-06 20:26:23 [__init__.py:36] No IOProcessor plugins requested by the model
Unsloth: Just some info: will skip parsing ['ffn_norm', 'k_norm', 'input_layernorm', 'layer_norm1', 'post_feedforward_layernorm', 'attention_norm', 'q_norm', 'pre_feedforward_layernorm', 'norm2', 'norm1', 'layer_norm2', 'post_attention

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.10.1 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


# Chat Template

In [5]:
# Define structured output format for mathematical reasoning
REASONING_START = '<THINK>'   # Begin reasoning section
REASONING_END = '</THINK>'    # End reasoning section
SOLUTION_START = '<SOLUTION>' # Begin final answer
SOLUTION_END = '</SOLUTION>'  # End final answer

# System prompt that teaches the model our desired reasoning structure
SYSTEM_PROMPT = f'''You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between {REASONING_START} and {REASONING_END}.
2. Provide your final numerical answer between {SOLUTION_START} and {SOLUTION_END}.
3. Be precise and show all calculation steps clearly.'''
print(SYSTEM_PROMPT)

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.


In [6]:
chat_template = ( # Build and assign chat_template to the tokenizer
    # If the very first message is a SYSTEM role, print it + <eos>:
    "{% if messages[0]['role'] == 'system' %}"
      "{{ messages[0]['content'] + eos_token }}"
      "{% set loop_messages = messages[1:] %}"
    "{% else %}"
      # Otherwise, inject our system_prompt + <eos>:
      "{{ '{system_prompt}' + eos_token }}"
      "{% set loop_messages = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages (either user or assistant):
    "{% for message in loop_messages %}"
      "{% if message['role'] == 'user' %}"
        "{{ message['content'] }}"
      "{% elif message['role'] == 'assistant' %}"
        "{{ message['content'] + eos_token }}"
      "{% endif %}"
    "{% endfor %}"

    # If we asked for "add_generation_prompt", append <REASONING> to the end:
    "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"
    "{% endif %}"
)
# Replace with out specific template:
tokenizer.chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{SYSTEM_PROMPT}'")\
    .replace("'{reasoning_start}'", f"'{REASONING_START}'")

In [7]:
example_messages = [ # Quick sanity check of the template
    {'role': 'user', 'content': 'Which country has the highest population density?'},
    {'role': 'assistant', 'content': (
        f'{REASONING_START}'
        'I know that country X is small in area but has a huge population, '
        'so its people per square kilometer is extremely high.'
        f'{REASONING_END}{SOLUTION_START}Monaco{SOLUTION_END}'
    )},
    {'role': 'user', 'content': 'Which planet is farthest from the Sun?'},
]
print(tokenizer.apply_chat_template(example_messages, tokenize=False, add_generation_prompt=True))

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Which country has the highest population density?<THINK>I know that country X is small in area but has a huge population, so its people per square kilometer is extremely high.</THINK><SOLUTION>Monaco</SOLUTION><|endoftext|>Which planet is farthest from the Sun?<THINK>


# Pre Fine-tuning (SFT)

## Data preparation

In [8]:
# Use a subset of NVIDIA's Open Math Reasoning dataset, which was filtered to only include high quality DeepSeek R1 traces
sft_dataset = load_dataset('unsloth/OpenMathReasoning-mini', split='cot').to_pandas()
sft_dataset = sft_dataset[['expected_answer', 'problem', 'generated_solution']]

# Try converting to number - if not, replace with NaN
is_number = pd.to_numeric(pd.Series(sft_dataset['expected_answer']), errors='coerce').notnull()
sft_dataset = sft_dataset.iloc[np.where(is_number)[0]] # Select only numbers
sft_dataset

README.md:   0%|          | 0.00/603 [00:00<?, ?B/s]

data/cot-00000-of-00001.parquet:   0%|          | 0.00/106M [00:00<?, ?B/s]

Generating cot split:   0%|          | 0/19252 [00:00<?, ? examples/s]

Unnamed: 0,expected_answer,problem,generated_solution
0,14,Given $\sqrt{x^2+165}-\sqrt{x^2-52}=7$ and $x$...,"<think>\nOkay, let's see. I need to solve the ..."
6,-2,Find the value of the parameter $a$ for which ...,"<think>\nOkay, so I need to find the value of ..."
9,18,What is the sum of all real numbers $x$ for wh...,"<think>\nOkay, so I need to solve the equation..."
13,2,Evaluate the sum \(\sum_{n=1}^\infty \frac{\ph...,"<think>\nOkay, so I need to evaluate the infin..."
17,30,What is the largest positive integer that divi...,"<think>\nAlright, so I need to find the larges..."
...,...,...,...
19243,244,"Let \( p \), \( q \), and \( r \) be the disti...","<think>\nOkay, so I need to find the value of ..."
19245,1,A bug is on the $0$ of a number line. At any p...,"<think>\nOkay, so I have this problem where a ..."
19247,4,A bus left point X for point Y. Two hours late...,"<think>\nOkay, let's tackle this problem step ..."
19248,18,Each interior angle of a regular n-gon measure...,"<think>\nOkay, let's see. I need to find the n..."


In [9]:
def format_dataset(x): # Format the dataset to follow our GRPO style formatting
    expected_answer = x['expected_answer']
    problem = x['problem']

    # Remove generated <think> and </think>
    thoughts = x['generated_solution'].replace('<think>', '').replace('</think>', '')
    thoughts = thoughts.strip()

    # Add our custom formatting
    final_prompt = REASONING_START + thoughts + REASONING_END + \
                   SOLUTION_START + expected_answer + SOLUTION_END
    return [
        {'role': 'system'   , 'content': SYSTEM_PROMPT},
        {'role': 'user'     , 'content': problem},
        {'role': 'assistant', 'content': final_prompt},
    ]

sft_dataset['messages'] = sft_dataset.apply(format_dataset, axis=1)
print(tokenizer.apply_chat_template(sft_dataset['messages'][0], tokenize=False))

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Given $\sqrt{x^2+165}-\sqrt{x^2-52}=7$ and $x$ is positive, find all possible values of $x$.<THINK>Okay, let's see. I need to solve the equation √(x² + 165) - √(x² - 52) = 7, and find all positive values of x. Hmm, radicals can be tricky, but maybe if I can eliminate the square roots by squaring both sides. Let me try that.

First, let me write down the equation again to make sure I have it right:

√(x² + 165) - √(x² - 52) = 7.

Okay, so the idea is to isolate one of the radicals and then square both sides. Let me try moving the second radical to the other side:

√(x² + 165) = 7 + √(x² - 52).

Now, if I square both sides, maybe I can get rid of the square roots. Let's do that:

(√(x² + 165))² = (7 + √(x² - 52))².

Sim

In [10]:
# Truncate pre fine-tuning sft_dataset to max_seq_length / 2 since we don't want too long reasoning traces
sft_dataset['seq_length'] = sft_dataset['messages'].apply(lambda x: len(tokenizer.apply_chat_template(x)))
print('Token-length percentiles (50/90/99):', np.percentile(sft_dataset['seq_length'], [50, 90, 99]))

threshold = max_seq_length / 2
sft_dataset_filtered = sft_dataset.loc[sft_dataset['seq_length'] <= threshold].copy()
print(f'Remaining for training (<= {threshold} tokens): {len(sft_dataset_filtered)}/{len(sft_dataset)}')

sft_dataset_filtered['text'] = tokenizer.apply_chat_template(sft_dataset_filtered['messages'].values.tolist(), tokenize=False)
sft_dataset_filtered = Dataset.from_pandas(sft_dataset_filtered)
sft_dataset_filtered

Token-length percentiles (50/90/99): [ 3729.    9034.   15685.84]
Remaining for training (<= 1024.0 tokens): 51/7507


Dataset({
    features: ['expected_answer', 'problem', 'generated_solution', 'messages', 'seq_length', 'text', '__index_level_0__'],
    num_rows: 51
})

## Pre fine-tune to understand custom GRPO formatting

In [11]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset_filtered,
    args=SFTConfig(
        dataset_text_field='text',
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        optim='adamw_8bit',
        weight_decay=0.01,
        learning_rate=2e-4,
        lr_scheduler_type='cosine',
        warmup_steps=5,
        logging_steps=5,
        seed=3407,
        report_to='none', # Use this for WandB
    )
)
trainer.train()

Unsloth: Tokenizing ["text"] (num_proc=16):   0%|          | 0/51 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 51 | Num Epochs = 3 | Total steps = 153
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 1 x 1) = 1
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)


Step,Training Loss
5,0.8731
10,0.7529
15,0.5826
20,0.4971
25,0.5201
30,0.4614
35,0.4004
40,0.4738
45,0.4496
50,0.4534


Unsloth: Will smartly offload gradients to save VRAM!


TrainOutput(global_step=153, training_loss=0.34412570365893297, metrics={'train_runtime': 75.8841, 'train_samples_per_second': 2.016, 'train_steps_per_second': 2.016, 'total_flos': 1196951737651200.0, 'train_loss': 0.34412570365893297, 'epoch': 3.0})

## Check if model has learnt to follow the format

In [12]:
text = tokenizer.apply_chat_template( # Render into a single string and append <REASONING> for generation
    sft_dataset_filtered[1]['messages'][:2],
    tokenize=False, add_generation_prompt=True, # Append the final <REASONING>
)
_ = model.generate(
    **tokenizer(text, return_tensors='pt').to('cuda'),
    temperature=0, max_new_tokens=1024,
    streamer=TextStreamer(tokenizer, skip_prompt=False), # Stream the model's generations (CoT + solution)
)

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>What is the average book width, in centimeters, of five books with the following widths: $6$, $\frac{1}{2}$, $1$, $2.5$, and $10$?<THINK>Okay, let's see. I need to find the average width of five books. The widths given are 6 cm, 1/2 cm, 1 cm, 2.5 cm, and 10 cm. Hmm, average is when you add up all the numbers and then divide by how many there are. So first, I should add these numbers together. Let me write them down: 6, 0.5 (which is 1/2), 1, 2.5, and 10. 

Adding them step by step. Let's start with 6 and 0.5. 6 plus 0.5 is 6.5. Then add 1 to that. 6.5 plus 1 is 7.5. Next, add 2.5. 7.5 plus 2.5 is 10. Then add the last number, 10. So 10 plus 10 is 20. So the total sum is 20 centimeters. 

Now, there are 5 books, so to 

In [13]:
del sft_dataset, sft_dataset_filtered
gc.collect()
torch.cuda.empty_cache()

# Post Fine-tuning (RL)

## Data preparation

In [8]:
def process_dataset_sample(example): # Convert GSM8K example to conversation format for GRPO training
    return {
        'prompt': [ # Create conversation with system prompt for structured reasoning
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['question']},
        ],
        # Extract numerical answer from GSM8K format ('Explanation... #### 42') as Ground truth for reward functions
        'answer': example['answer'].split('####')[1].strip() if '####' in example['answer'] else None
    }

In [9]:
# train_dataset = load_dataset('openai/gsm8k', 'main', split=['train[:10%]'])
train_dataset = load_dataset('openai/gsm8k', 'main', split='train')
train_dataset = train_dataset.map(process_dataset_sample)

print(f'Training samples: {len(train_dataset):,}\n'
      f"- Sample question: {train_dataset[0]['prompt'][1]['content']}\n"
      f"- Sample answer: {train_dataset[0]['answer']} (ground truth for rewards)\n"
      f"- Prompt (system + user):\n{train_dataset[0]['prompt']}")

Training samples: 7,473
- Sample question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
- Sample answer: 72 (ground truth for rewards)
- Prompt (system + user):
[{'content': 'You are a mathematical reasoning assistant. When given a math problem:\n1. Show your step-by-step work between <THINK> and </THINK>.\n2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.\n3. Be precise and show all calculation steps clearly.', 'role': 'system'}, {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]


In [10]:
# Get the top 90% prompt length so we don't accidentally truncate them, i.e. we'll remove the top 10% long prompts
tokenized_dataset = train_dataset.map(
    lambda x: {'tokens': tokenizer.apply_chat_template(x['prompt'], add_generation_prompt=True, tokenize=True)},
    batched=True,
).map(lambda x: {'length': len(x['tokens'])})
print(tokenizer.decode(tokenized_dataset[0]['tokens']))

thresholds = np.percentile(tokenized_dataset['length'], [50, 90, 99])
max_prompt_length = int(thresholds[1])
print('Token-length percentiles (50/90/99):', thresholds, '=> Choose max_prompt_length =', max_prompt_length)

# Filter only samples smaller than 90% max length
train_dataset = train_dataset.select(np.where(np.array(tokenized_dataset['length']) <= max_prompt_length)[0])
print(f'Remaining for training (<= {max_prompt_length} tokens): {len(train_dataset)}/{len(tokenized_dataset)}')
del tokenized_dataset

You are a mathematical reasoning assistant. When given a math problem:
1. Show your step-by-step work between <THINK> and </THINK>.
2. Provide your final numerical answer between <SOLUTION> and </SOLUTION>.
3. Be precise and show all calculation steps clearly.<|endoftext|>Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<THINK>
Token-length percentiles (50/90/99): [119. 152. 191.] => Choose max_prompt_length = 152
Remaining for training (<= 152 tokens): 6749/7473


## Regex Patterns

In [11]:
# Match the reasoning sections and answers
match_format = re.compile(
    # rf'^[\s]{{0,}}'                                     # Optional whitespace at start
    # rf'{REASONING_START}.+?{REASONING_END}.*?'          # Reasoning section (non-greedy)
    rf'{REASONING_END}.*?'                              # We always prepend REASONING_START
    rf'{SOLUTION_START}(.+?){SOLUTION_END}'             # Solution section with capture group
    rf'[\s]{{0,}}(?:{re.escape(tokenizer.eos_token)})?' # Add optional EOS token matching
    rf'[\s]{{0,}}$',                                    # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL,                     # Multi-line matching with . matching newlines
)
match_format.findall( # Verify it works
    f'{REASONING_START}Let me think!{REASONING_END}'\
    f'{SOLUTION_START}\n2\n{SOLUTION_END}\n\n',
)

['\n2\n']

In [12]:
# Sometimes it might not be 1 number as the answer, but like a sentence.
# For example: 'The solution is $20' -> we extract 20
# We also remove possible commas for example as in 123,456
match_numbers = re.compile(
    rf'{SOLUTION_START}.*?[\s]{{0,}}([-]?[\d\.\,]{{1,}})', # Extract numbers from solution section
    flags=re.MULTILINE | re.DOTALL,  # Flexible pattern matching
)
print(match_numbers.findall('<SOLUTION>  0.34  </SOLUTION>'))
print(match_numbers.findall('<SOLUTION>  123,456  </SOLUTION>'))
print(match_numbers.findall('<SOLUTION>  -0.234  </SOLUTION>'))
print(match_numbers.findall('<SOLUTION>17</SOLUTION>'))

['0.34']
['123,456']
['-0.234']
['17']


## Multi-reward design

In [13]:
def match_format_strictly(completions, **kwargs) -> list[float]:
    ''' Reward Function 1: Exact Format Compliance
    High reward (3.0) for perfect format adherence
    Ensures model learns the complete structured output pattern
    '''
    return [
        3.0 if match_format.search(completion[0]['content']) else 0.0
        for completion in completions
    ]

In [14]:
# If it fails, reward the model if it at least follows the format partially, by counting each symbol
def match_format_softly(completions, **kwargs) -> list[float]:
    ''' Reward Function 2: Partial Format Credit
    Graduated scoring for format elements
    Encourages learning individual components even if not perfect
    '''
    rewards = []
    for completion in completions:
        reward = 0
        response = completion[0]['content']

        # Count how many keywords are seen - we penalize if too many!
        # Award +0.5 for correct token count, -0.5 for wrong count
        # No need to reward REASONING_START since we always prepend it!
        # reward += 0.5 if response.count(REASONING_START) == 1 else -0.5
        reward += 0.5 if response.count(REASONING_END) == 1 else -0.5
        reward += 0.5 if response.count(SOLUTION_START) == 1 else -0.5
        reward += 0.5 if response.count(SOLUTION_END) == 1 else -0.5
        rewards.append(reward)
    return rewards

In [15]:
# Extract the generated answer, and reward or penalize it
def check_answer_correctness(completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 3: Graduated scoring for mathematical accuracy
    - 5.0: Exact string match gets full points
    - 2.0: Within 10% (close answer)
    - 1.5: Within 20% (reasonable attempt)
    - -2.5: Wrong answer (penalty for incorrect math)
    '''
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [ # Extract answers using format pattern
        guess.group(1) if (guess := match_format.search(r)) else None
        for r in responses
    ]
    rewards = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None: # No extractable answer
            rewards.append(-2.0)
            continue

        if guess == true_answer: rewards.append(5.0)                   # Correct answer gets 5 points!
        elif guess.strip() == true_answer.strip(): rewards.append(3.5) # Match if spaces are seen, but less reward
        else: # Try numerical comparison for partial credit
            try: # We also reward it based on how close the answer is to the true one via ratios
                ratio = float(guess) / float(true_answer)     # If the answer is within some range, reward it!
                if 0.9 <= ratio <= 1.1: rewards.append(2.0)   # Within 10%
                elif 0.8 <= ratio <= 1.2: rewards.append(1.5) # Within 20%
                else: rewards.append(-2.5)                    # Penalize wrong answers
            except (ValueError, ZeroDivisionError):
                rewards.append(-4.5)                          # Invalid numerical format
    return rewards

In [16]:
def check_numbers_extraction(prompts, completions, answer, **kwargs) -> list[float]:
    ''' Reward Function 4: Number Extraction Ability
    Tests the model's ability to extract numerical values from solution sections
    Complementary to exact format matching - focuses on parsing capability
    '''
    question = prompts[0][-1]['content'] # Exclude system prompt
    responses = [completion[0]['content'] for completion in completions]

    extracted_responses = [ # Extract numbers from solution sections using number pattern
        guess.group(1) if (guess := match_numbers.search(r)) else None
        for r in responses
    ]
    rewards = []

    # Print only every few steps
    check_numbers_extraction.counter = getattr(check_numbers_extraction, 'counter', 0) + 1
    if check_numbers_extraction.counter % 100 == 0:
        print(
            '==' * 100,
            f'\nQuestion: {question}'
            f'\nPrediction: {extracted_responses[0]}, GT Answer: {answer[0]}'
            f'\nResponse:\n{responses[0]}'
        )
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None: # No extractable number
            rewards.append(-2.5)
            continue

        try: # Simple numerical equality check
            true_val = float(true_answer.strip())             # Convert to numbers
            guess_val = float(guess.strip().replace(',', '')) # Remove commas like in 123,456
            rewards.append(3.5 if guess_val == true_val else -1.5)
        except (ValueError, TypeError):
            rewards.append(0) # Invalid number format
    return rewards

## GRPO training setup

In [17]:
max_prompt_length = 152 + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

In [18]:
training_args = GRPOConfig(          # Configure GRPO training parameters for mathematical reasoning
    output_dir=f'/tmp/{model_name}', # Directory for checkpoints and logs
    vllm_sampling_params=vllm_sampling_params,
    # Training speed control
    num_train_epochs=1,              # Total number of training epochs
    per_device_train_batch_size=2,   # Small batch for GPU memory constraints
    gradient_accumulation_steps=8,   # Effective batch size = 2 * 8 = 16
    # Computing the loss: https://huggingface.co/docs/trl/main/grpo_trainer#computing-the-loss
    scale_rewards='batch',           # Calculate mean at local/group level and std at global/batch level enables more robust reward shaping
    loss_type='dr_grpo',             # Fully remove response length bias, dividing by a constant instead of the sequence length
    # Precision & Optimization
    optim='adamw_8bit',              # adamw_torch_fused, adamw_8bit, paged_adamw_8bit
    weight_decay=0.1,                # Regularization
    max_grad_norm=0.1,               # Aggressive gradient clipping for stable training
    gradient_checkpointing=True,
    bf16=torch.cuda.is_available(),  # Enable mixed-precision training if a CUDA GPU is available (faster, less memory)
    # Learning rate scheduling
    learning_rate=1e-5,              # Conservative LR to prevent destabilizing reasoning
    warmup_ratio=0.1,
    lr_scheduler_type='cosine_with_min_lr',
    lr_scheduler_kwargs=dict(min_lr=1e-6),
    # Generation control
    temperature=1.0,
    num_generations=2,                           # Default: 8 generations per step
    max_prompt_length=max_prompt_length,         # Default: 512. Sufficient for complex word problems
    max_completion_length=max_completion_length, # Default: 256. Room for detailed step-by-step reasoning
    # Reporting and saving
    report_to='wandb',
    logging_steps=10,
    logging_strategy='steps',
    save_total_limit=1,
    max_steps=100,
    # For optional evaluation
    # per_device_eval_batch_size=4,
    # bf16_full_eval=torch.cuda.is_available(),
    # eval_strategy='steps',                       # Evaluate after each epoch
    # load_best_model_at_end=True,                 # Load the best model based on validation loss
)

## Train the model

In [25]:
%%time
trainer = GRPOTrainer(            # Initialize GRPO trainer with multi-reward system
    model=model,                  # LoRA-adapted quantized model
    processing_class=tokenizer,
    train_dataset=train_dataset,  # Processed GSM8K dataset
    args=training_args,           # Training configuration
    reward_funcs=[                # 4 complementary reward functions
        match_format_strictly,    # Perfect structure compliance
        match_format_softly,      # Partial format credit
        check_answer_correctness, # Mathematical accuracy
        check_numbers_extraction, # Number parsing ability
    ]
)
trainer.train()
trainer.save_model(f'./{model_name}_grpo')

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 6,749 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 34,865,152 of 1,755,440,128 (1.99% trained)
  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m18520339[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, mcp, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,rewards / match_format_strictly / mean,rewards / match_format_strictly / std,rewards / match_format_softly / mean,rewards / match_format_softly / std,rewards / check_answer_correctness / mean,rewards / check_answer_correctness / std,rewards / check_numbers_extraction / mean,rewards / check_numbers_extraction / std
10,0.053,10.865625,4.441904,835.2625,446.4,1587.4,0.04375,787.431207,446.4,1377.4,0.383798,2.85,0.457409,1.3625,0.407409,3.946875,2.219778,2.70625,1.658784
20,0.1379,10.709375,4.913189,789.94375,477.7,1614.3,0.03125,753.260077,477.7,1299.2,0.301318,2.86875,0.429939,1.3875,0.40247,3.828125,2.552707,2.625,1.86255
30,0.2469,11.13125,4.56005,801.7375,483.0,1470.3,0.025,774.304181,483.0,1236.1,0.307156,2.90625,0.32747,1.41875,0.30246,4.05,2.282706,2.75625,1.825852
40,0.0669,11.4125,3.47258,807.01875,472.5,1552.4,0.025,779.038049,472.5,1407.9,0.311523,2.90625,0.270934,1.41875,0.253078,4.20625,1.707592,2.88125,1.358486
50,0.1852,11.11875,4.154785,774.50625,440.8,1682.5,0.03125,737.067517,440.8,1280.6,0.27451,2.86875,0.47747,1.39375,0.40246,4.05625,2.011801,2.8,1.543678
60,0.2087,10.615625,4.381963,806.09375,450.6,1520.4,0.03125,772.862524,450.6,1383.7,0.296932,2.8875,0.40247,1.4125,0.30247,3.721875,2.307364,2.59375,1.769619
70,0.2193,10.946875,4.341668,775.9,443.5,1617.6,0.04375,725.158234,443.5,1371.8,0.289751,2.86875,0.429939,1.36875,0.429939,4.00625,2.09538,2.703125,1.652172
80,0.1535,10.6375,4.603097,800.45,450.3,1531.8,0.01875,778.101678,450.3,1298.5,0.299942,2.90625,0.32747,1.40625,0.32747,3.7625,2.321766,2.5625,1.775115
90,0.0336,11.346875,4.099768,752.5,456.9,1468.9,0.01875,730.906262,456.9,1320.4,0.288545,2.94375,0.225,1.44375,0.225,4.134375,2.096604,2.825,1.668246
100,0.1313,11.584375,3.374236,795.54375,449.1,1493.4,0.00625,788.530005,449.1,1420.0,0.279975,2.98125,0.075,1.48125,0.075,4.253125,1.818517,2.86875,1.483587


Question: James wants to build a 16-foot by 20-foot quilt.  He uses patches that are each 4 square feet.  The first 10 patches cost $10 each and then each patch after that cost half as much.  How much do the patches for the quilt cost?
Prediction: 450, GT Answer: 450
Response:
Okay, let's try to figure out how much James needs to spend on patches for his 16 by 20-foot quilt. The patches are each 4 square feet, so I first need to calculate the total number of patches required.

The quilt is 16 feet by 20 feet, so the area is 16 * 20 = 320 square feet. Each patch covers 4 square feet. So dividing the total area by the size of each patch, that would be 320 / 4 = 80 patches needed. If each of the first 10 patches costs $10, then those would be 10 * 10 = $100. The rest would be 80 - 10 = 70 patches. Since each subsequent patch costs half as much, $10 divided by 2 is $5 per patch. So for the remaining 70 patches, that would be 70 * 5 = $350. Adding the two amounts together, $100 + $350 = $45

# Evaluation

## Resource usage

In [26]:
# Memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

print(f'GPU = {gpu_stats.name}. Max memory = {max_memory} GB.')
print(f'{start_gpu_memory} GB of memory reserved.')

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.
38.051 GB of memory reserved.


In [27]:
# Extract runtime info
last_log = trainer.state.log_history[-1] # Final memory and time stats
train_seconds = last_log['train_runtime']
samples_per_second = last_log.get('train_samples_per_second', None)

# Recompute GPU memory stats
used_memory   = round(torch.cuda.max_memory_reserved() / 1024**3, 2)
used_for_lora = round(used_memory - start_gpu_memory, 2)
used_pct      = round(used_memory / max_memory * 100, 2)
lora_pct      = round(used_for_lora / max_memory * 100, 2)

print(f'Training time: {train_seconds:.1f} seconds ({train_seconds / 60:.2f} minutes)')
if samples_per_second: print(f'Throughput: {samples_per_second:.1f} samples/second')
print(f'Peak VRAM usage: {used_memory} GB ({used_pct}% of max memory)')
print(f'VRAM for training: {used_for_lora} GB ({lora_pct}% of max memory)')

Training time: 1853.0 seconds (30.88 minutes)
Throughput: 0.9 samples/second
Peak VRAM usage: 38.05 GB (96.19% of max memory)
VRAM for training: -0.0 GB (-0.0% of max memory)


## Verify LoRA is actually trained

In [24]:
example_text = 'What is the sqrt of 101?'
# example_text = 'Solve (x + 2)^2 = 0'
# example_text = "How many r's are in strawberry?"

sampling_params = SamplingParams(
    temperature=1.0,
    top_k=50,
    max_tokens=max_completion_length,
)
print(model.fast_generate( # Try the model without any GRPO trained
    example_text, sampling_params=sampling_params,
    lora_request=None
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 How do you use a calculator to find the nth root? Can you show me how to find the square root and cube root of 101 by doing long division and the Newton-Raphson Method? Can you explain how you arrived at the square root and cube root?

To find the square root of 101, we can use a few methods. First, using a calculator, the square root of 101 is approximately 10.049875621. For other roots, using a calculator is the easiest method. 

To find the nth root of 101 by long division, we can break it down into steps of finding integer values for the nth root. Here’s how you can do it:

1. **Estimate the integer part of the root:**
   - Since \(10^2 = 100\) and \(11^2 = 121\), the integer part of the root must be 10 because \(10^2 < 101 < 11^2\).

2. **Subtract the largest perfect square less than 101:**
   - \(101 - 10^2 = 101 - 100 = 1\).

3. **Divide the remainder by the integer part and adjust the exponent:**
   - Now, we have \(10^{(0.1)} = 10.0499\) (approximately) because \(101 \approx 

In [19]:
tensors = {}
with safe_open(f'./{model_name}_grpo/adapter_model.safetensors', framework='pt') as f:
    for key in f.keys(): # Verify both A and B are non zero
        tensor = f.get_tensor(key)
        n_zeros = (tensor == 0).sum() / tensor.numel()
        assert(n_zeros.item() != tensor.numel())

In [26]:
# Load the LoRA and test without using system prompt
# which should not (or minimal) affect the model's original reasoning ability
text = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': example_text}],
    add_generation_prompt=True, tokenize=False,
)
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Okay, let me see. I need to find the square root of 101. Hmm, square root... So, I'm looking for a number that when multiplied by itself gives 101. Let me think. Let me start with some numbers I know. Like, 10 squared is 100, right? That's close to 101. So, maybe 10 is a bit too low. Let me try a little higher. Let's take 11. 11 squared is 121. Oh, that's too high. So, between 10 and 11, the square root has to be some number between those two. Let me try 10.5. Let's calculate 10.5 squared. 10.5 times 10.5. Let me do that step by step. 10*10 is 100, then 10*0.5 is 5, add those to get 105. Then 0.5*10 is 5, plus 5 more is 110. So, 10.5 squared is 110.25. That's a little high, but way above 101. Hmm, 10.4 maybe? Let me check. 10.4 squared. 10*10 is 100, 10*0.4 is 4, so 104. Plus 0.4*0.4 is 0.16, total 104.16. That's closer but still a bit high. Let me try 10.2. 10.2 squared. 10*10=100, 10*0.2=2, 0.2*10=2, plus 0.2*0.2=0.04, total 102.04. That's lower than 101. So, between 10.2 and 10.4. L

In [21]:
# Test using system prompt
text = tokenizer.apply_chat_template([
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user'  , 'content': example_text},
], add_generation_prompt=True, tokenize=False)

# Compare results with system prompt but without LoRA
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=None,
)[0].outputs[0].text)

# Reasoning model is much better - it's not always correct, since we only trained it for an hour
# It'll be better if we extend the sequence length and train for longer
print(model.fast_generate(
    text, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)[0].outputs[0].text)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 We need to find the square root of 101.
We can estimate this using known perfect squares.
The perfect squares closest to 101 are $36^2 = 1296$ and $26^2 = 676$.
Since 101 is closer to 676, we estimate the square root of 101 to be around 10.<SOLUTION> The square root of 101 is around 10, as it lies between 10 and 11.<SOLUTION>


Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Okay, so I need to find the square root of 101. Hmm, square roots... I remember that some numbers like 100 have a square root that's a whole number, which is 10, but 101 doesn't seem to be a perfect square. Let me think.

Wait, maybe I can estimate it. Since 10² is 100, the square root of 101 should be slightly more than 10. So around 10.1? Let me check 10.1 squared. 10.1² would be 100 plus 2*10*0.1 + 0.1², which is 100 + 2 + 0.01 = 102.01. That's a bit higher than 101. So 10.1 is too high. Let's try 10.05. 10.05² would be 100 + 2*10*0.05 + 0.05² = 100 + 1 + 0.0025 = 101.0025. That's very close to 101. So the square root must be 10.05. Or maybe even 10.0505? I don't think the decimal goes up that far. But let me check if there's a simpler way.

Is there a better method than trial and error? Maybe using a calculator? But since I don't have one right now, I'll stick with estimating. Alternatively, can 101 be expressed as a product of perfect squares? 101 divided by 100 is 1.01, which isn

## Performance on Test set

In [23]:
# test_dataset = load_dataset('openai/gsm8k', 'main', split=['test[:10%]'])
test_dataset = load_dataset('openai/gsm8k', 'main', split='test').map(process_dataset_sample)
test_texts = [
    tokenizer.apply_chat_template([
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': sample['prompt'][1]['content']},
    ], add_generation_prompt=True, tokenize=False)
for sample in test_dataset]
print(f'Testing samples:', len(test_dataset))

Testing samples: 1319


In [28]:
outputs_with_lora = model.fast_generate(
    test_texts, sampling_params=sampling_params,
    lora_request=model.load_lora(f'./{model_name}_grpo'),
)
outputs_without_lora = model.fast_generate(
    test_texts, sampling_params=sampling_params,
    lora_request=None,
)

Adding requests:   0%|          | 0/1319 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1319 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s…

In [34]:
# Compare the correct amount of using and not using LoRA
lora_correct_count = no_lora_correct_count = 0
num_test_samples = len(test_dataset)

for output_with_lora, output_without_lora in zip(outputs_with_lora, outputs_without_lora):
    if match_format.search(output_with_lora.outputs[0].text): lora_correct_count += 1
    if match_format.search(output_without_lora.outputs[0].text): no_lora_correct_count += 1

print(f'With LoRA: {lora_correct_count}/{num_test_samples} ({lora_correct_count / num_test_samples * 100:.2f}%)')
print(f'No LoRA: {no_lora_correct_count}/{num_test_samples} ({no_lora_correct_count / num_test_samples * 100:.2}%)')
print(f'Improvement: +{lora_correct_count - no_lora_correct_count} correct responses with LoRA')

With LoRA: 1258/1319 (95.38%)
No LoRA: 437/1319 (3.3e+01%)
Improvement: +821 correct responses with LoRA


# Inference

In [19]:
def generate_with_reasoning(questions, max_length=512):
    conversations = [[                        # Format input using conversation template
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': question},
    ] for question in questions]

    prompts = [tokenizer.apply_chat_template( # Apply chat template and tokenize
        conversation,
        add_generation_prompt=True,         # Add assistant prompt
        tokenize=False,                     # Return string, not tokens
    ) for conversation in conversations]

    # Generate response with reasoning-optimized parameters
    inputs = tokenizer(prompts, return_tensors='pt', padding=True).to(model.device)
    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(           # Generate response with reasoning-optimized parameters
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,                # Balance creativity and consistency
            top_p=0.9,                      # Nucleus sampling for quality
            do_sample=True,                 # Enable sampling for varied reasoning paths
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,         # Reduce repetitive reasoning steps
            length_penalty=1.0,             # Neutral preference for response length
            early_stopping=True,            # Stop at natural completion
            # streamer=TextStreamer(tokenizer, skip_prompt=True),
        )
    end_time = time.time()
    inference_duration = end_time - start_time
    num_generated_tokens = output_ids.shape[1] - inputs['input_ids'].shape[1]

    output_ids = output_ids[:, inputs['input_ids'][0].shape[-1]:output_ids.shape[-1]]
    responses = tokenizer.batch_decode(output_ids, skip_special_tokens=True) # Decode and extract only the generated portion
    return responses, inference_duration, num_generated_tokens

In [None]:
test_dataset = load_dataset('openai/gsm8k', 'main', split='test').map(process_dataset_sample)
gsm8k_question = test_dataset[0]['question']
expected_answer = test_dataset[0]['answer']

gsm8k_response, inference_duration, num_generated_tokens = generate_with_reasoning(gsm8k_question, max_length=768)
print('Question:', gsm8k_question)
print('Response:', gsm8k_response)
print('Inference time (secs):', inference_duration)
print('Generated tokens:', num_generated_tokens)

In [None]:
# Validate format compliance
has_reasoning = REASONING_START in gsm8k_response and REASONING_END in gsm8k_response
has_solution = SOLUTION_START in gsm8k_response and SOLUTION_END in gsm8k_response
print('Reasoning section:', has_reasoning)
print('Solution section:', has_solution)

if has_solution: # Check answer accuracy if solution section exists
    try:
        solution_text = gsm8k_response.split(SOLUTION_START)[1].split(SOLUTION_END)[0].strip()
        extracted_number = ''.join(filter(str.isdigit, solution_text))
        expected_number = ''.join(filter(str.isdigit, expected_answer))
        print('Extracted:', solution_text)
        print('Expected:', expected_answer)
        print('Correct:', extracted_number == expected_number)
    except:
        print('Could not extract solution')