# LoRA **GRPO** fine-tuning `Qwen2.5 0.5B` on a single T4


In [1]:
%%capture
!pip install uv
!pip install trl
!uv pip install --system triton==2.2.0
!uv pip install --system vllm
!uv pip install --system bitsandbytes

In [2]:
import os
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig

INFO 05-26 07:22:52 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-26 07:22:53 [__init__.py:239] Automatically detected platform cuda.


In [3]:
R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."


def preprocess_dataset(dataset_name, split="train", chunk_size=1000) -> Dataset:
    dataset = load_dataset(dataset_name, 'main')[split]
    total_samples = len(dataset)
    print(f"Loaded {total_samples} samples")

    def extract_hash_answer(text: str) -> str | None:
        try:
            return text.split("####")[1].strip()
        except IndexError:
            return None

    def process_batch(batch):
        prompts = [[
            {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
            {'role': 'user', 'content': "What is 2+2?"},
            {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"},
            {'role': 'user', 'content': q.strip()}
        ] for q in batch['question']]

        return {
            'prompt': prompts,
            'answer': [extract_hash_answer(a) for a in batch['answer']]
        }

    return dataset.map(process_batch, batched=True, batch_size=chunk_size)

dataset_name = 'openai/gsm8k'
dataset = preprocess_dataset(dataset_name, chunk_size=500)

def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        return ""


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded 7473 samples


# Understanding the Dataset Structure

This dataset represents a mathematical word problem along with its expected solution format. The main structure contains three key components:

1. **Question Component**: Contains the primary word problem about Natalia's clip sales: "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"

2. **Answer Component**: Stores the numerical solution: 72

3. **Prompt Component**: Demonstrates the expected solution format through a series of conversation examples:
   - A system message outlining the required format using XML tags (`<reasoning>` and `<answer>`)
   - A simple example using "What is 2+2?" to demonstrate proper formatting
   - An assistant's response showing how to structure the reasoning and answer
   - The original Natalia question repeated in the conversation format


In [4]:
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': '72',
 'prompt': [{'content': 'A conversation between User and Assistant. The user asks a question, and the Assistant solves it.\nThe assistant first thinks about the reasoning process in the mind and then provides the user\nwith the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and\n<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n<answer> answer here </answer>.\nThe answer must be a single integer.',
   'role': 'system'},
  {'content': 'What is 2+2?', 'role': 'user'},
  {'content': '<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>',
   'role': 'assistant'},
  {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in Ma

# Reward functions

In [5]:
def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the correct format."""
    pattern = r"^<reasoning>.*?</reasoning>\s*<answer>.*?</answer>$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [bool(re.match(pattern, r)) for r in responses]
    print(''.join('⭐' if match else '❌' for match in matches))
    return [1.0 if match else 0.0 for match in matches]

#Correctness function that shows only one generation
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that checks if the answer is correct."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(f"Question: {prompts[0][-1]['content']}\nAnswer: {answer[0]}\nResponse: {responses[0]}\nExtracted: {extracted_responses[0]}")
    print(''.join('✅' if r == a else '❌' for r, a in zip(extracted_responses, answer)))
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

# Correctness function that shows multiple generation
# Cut and paste this into the code section to use it

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward function that checks if the answer is correct."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    
    # Print all responses
    for i, (response, extracted) in enumerate(zip(responses, extracted_responses)):
        print(f"\nGeneration {i+1}:")
        print(f"Question: {prompts[0][-1]['content']}")
        print(f"Answer: {answer[0]}")
        print(f"Response: {response}")
        print(f"Extracted: {extracted}")
    
    # Compare each response with the answer
    print(''.join('✅' if r == answer[0] else '❌' for r in extracted_responses))
    return [2.0 if r == answer[0] else 0.0 for r in extracted_responses]

# Traning Settings

In [6]:
# @title
# model_name = "Qwen/Qwen2.5-0.5B"
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

output_dir = f"outputs/{model_name.split('/')[-1]}-GRPO"
run_name = f"{model_name.split('/')[-1]}-{dataset_name.split('/')[-1]}"


# Set memory-related environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

In [13]:
max_prompt_length=256
max_completion_length=512

peft_config = LoraConfig(
    r=16,                             # Rank of the LoRA decomposition (low-rank matrix size)
    lora_alpha=64,                   # Scaling factor for the LoRA update
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],            # Which modules to apply LoRA to
    task_type="CAUSAL_LM",           # Tells PEFT it's for autoregressive language modeling
    lora_dropout=0.05,               # Dropout applied to LoRA layers during training
)

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=1e-5,
    beta=0.005, # divergence coefficient – how much the policy is allowed to deviate from the reference model. higher value – more conservative updates. Default is 0.04
    optim="adamw_8bit",
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=4, #batch
    num_generations=4,  # group size
    gradient_accumulation_steps=4,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    num_train_epochs=1,
    save_steps=10,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2", # T4 is not supported
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length=training_args.max_completion_length,
)
tokenizer.pad_token = tokenizer.eos_token

# Initialize trainer
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        correctness_reward_func,
        format_reward_func
    ],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config
)


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [21]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [22]:
trainer.train()

Question: Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers?
Answer: 100
Response: <reasoning>To determine the minimum grade Ahmed needs on the final assignment to beat Emily, we need to consider that Ahmed's maximum possible grade on the final assignment is 91 (his last possible score), and Emily's maximum possible grade is 100 (her last possible score).
The difference between Ahmed's and Emily's maximum possible grades is 100-91 = 9. This would occur if Ahmed gets a perfect score on the final assignment. To beat Emily, Ahmed needs a grade that is at least as high as Emily, which is 90.</reasoning>
<answer>89</answer>
Extracted: 89
❌❌❌❌❌❌❌✅✅❌✅❌❌❌❌❌
❌⭐❌❌⭐❌⭐❌⭐❌⭐⭐❌❌⭐❌


Step,Training Loss
1,0.0789
2,0.0359
3,0.0408
4,0.1606
5,-0.0018
6,0.1157
7,0.0906
8,0.263
9,0.0454
10,-0.0595


Question: In a graveyard, there are 20 skeletons.  Half of these skeletons are adult women, and the remaining number are split evenly between adult men and children.  If an adult woman has 20 bones in their body, and a male has 5 more than this, and a child has half as many as an adult woman, how many bones are in the graveyard?
Answer: 375
Response: To solve this, let's first determine the number of each type of skeleton:
- Total number of skeletons = 20
- Number of adult women = 20/2 = 10
- Remaining number of skeletons = 20 - 10 = 10
- Number of adult men = 10/2 = 5
- Number of children = 10/2 = 5

Now, let's determine the number of bones for each category:
- Each adult woman has 20 bones, so 10 adult women have 10 * 20 = 200 bones.
- Each male has 5 more bones (than an adult woman), so 5 adult men have 5 * 20 = 100 bones.
- Each child has half as many bones (than an adult woman), so 5 children have 20 / 2 = 10 bones.

The total number of bones in the graveyard is equal to the sum o

KeyboardInterrupt: 

In [23]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM

# Your last checkpoint
checkpoint_path = f"{output_dir}/checkpoint-10"

# Load adapter config
peft_cfg = PeftConfig.from_pretrained(checkpoint_path)

# Load base model
base_model = AutoModelForCausalLM.from_pretrained(peft_cfg.base_model_name_or_path)

# Load adapter onto base
model = PeftModel.from_pretrained(base_model, checkpoint_path)

# Merge adapter into base weights
model = model.merge_and_unload()

# Save merged model (now a standalone full model)
merged_path = f"{output_dir}/merged"
model.save_pretrained(merged_path)

# Save tokenizer (so vLLM can tokenize properly)
tokenizer.save_pretrained(merged_path)

('outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/tokenizer_config.json',
 'outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/special_tokens_map.json',
 'outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/vocab.json',
 'outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/merges.txt',
 'outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/added_tokens.json',
 'outputs/Qwen2.5-0.5B-Instruct-GRPO/merged/tokenizer.json')

# Eval

In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm.notebook import tqdm
import numpy as np
from typing import List, Dict
import json
from datetime import datetime
import logging

# Disable VLLM's progress bars
logging.getLogger("vllm").setLevel(logging.WARNING)

# Constants from training script
R1_STYLE_SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and
<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>
<answer> answer here </answer>."""

TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

def extract_xml_answer(text: str) -> str:
    try:
        answer = text.split("<answer>")[-1].split("</answer>")[0].strip()
        return answer
    except IndexError:
        return ""

def extract_hash_answer(text: str) -> str | None:
    try:
        return text.split("####")[1].strip()
    except IndexError:
        return None

def evaluate_model(
    model_path: str,
    batch_size: int = 4,
    num_samples: int = None,
    save_results: bool = True,
    gpu_memory_utilization: float = 0.3,
) -> Dict:
    print("Initializing evaluation...")

    # Initialize VLLM with progress indicator
    with tqdm(total=2, desc="Loading model components") as pbar:
        llm = LLM(
            model=model_path,
            dtype="half",
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=768,
            device="cuda:0",
            enable_chunked_prefill=True,
        )
        pbar.update(1)

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            model_max_length=768,
            padding_side='right',
            truncation_side='right'
        )
        pbar.update(1)

    # Set up sampling parameters
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=512,  # Matching max_completion_length from training
        stop_token_ids=[tokenizer.eos_token_id],
    )

    # Load test dataset
    print("Loading dataset...")
    dataset = load_dataset('openai/gsm8k', 'main', split='test')
    if num_samples:
        dataset = dataset.select(range(num_samples))
    total_samples = len(dataset)
    print(f"Loaded {total_samples} samples")

    results = []
    correct = 0
    total = 0

    # Create progress bar
    progress_bar = tqdm(
        total=total_samples,
        desc="Processing samples",
        unit="examples",
        dynamic_ncols=True,
    )

    progress_bar.set_postfix({
        'acc': '0.00%',
        'correct': '0',
    })

    # Process in batches
    for i in range(0, total_samples, batch_size):
        batch_data = dataset[i:i + batch_size]
        current_batch_size = len(batch_data['question'])

        # Prepare prompts using same format as training
        prompts = [
            [
                {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
                {'role': 'user', 'content': "What is 2+2?"},
                {'role': 'assistant', 'content': "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n<answer>4</answer>"},
                {'role': 'user', 'content': q.strip()}
            ] for q in batch_data['question']
        ]

        # Convert to chat format
        formatted_prompts = [
            tokenizer.apply_chat_template(
                p,
                tokenize=False,
                add_generation_prompt=True
            )
            for p in prompts
        ]

        # Generate responses
        outputs = llm.generate(
            formatted_prompts,
            sampling_params,
        )

        # Process responses
        for j, output in enumerate(outputs):
            response = output.outputs[0].text

            # Extract answers
            generated_answer = extract_xml_answer(response)
            true_answer = extract_hash_answer(batch_data['answer'][j])

            # Store result
            result = {
                'question': batch_data['question'][j],
                'true_answer': true_answer,
                'generated_answer': generated_answer,
                'full_response': response,
                'correct': generated_answer == true_answer
            }
            results.append(result)

            # Update metrics
            if generated_answer == true_answer:
                correct += 1
            total += 1

        # Update progress
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{(correct/total)*100:.2f}%',
            'correct': f'{correct}/{total}',
        })

    progress_bar.close()

    # Calculate metrics
    accuracy = correct / total if total > 0 else 0
    metrics = {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'model_path': model_path,
        'timestamp': datetime.now().isoformat()
    }

    # Save results
    if save_results:
        save_path = f"gsm8k_eval_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(save_path, 'w') as f:
            json.dump({
                'metrics': metrics,
                'results': results
            }, f, indent=2)
        print(f"\nResults saved to {save_path}")

    return metrics

print("Starting GSM8K evaluation...")

checkpoint_path = "outputs/Qwen2.5-0.5B-Instruct-GRPO/merged"

metrics = evaluate_model(
    model_path=checkpoint_path,
    batch_size=4,
    num_samples=100,
    save_results=True,
    gpu_memory_utilization=0.3,
)

print("\nFinal Evaluation Results:")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Correct: {metrics['correct']}/{metrics['total']}")

INFO 05-26 08:12:26 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 05-26 08:12:26 [__init__.py:239] Automatically detected platform cuda.
Starting GSM8K evaluation...
Initializing evaluation...


Loading model components:   0%|          | 0/2 [00:00<?, ?it/s]



Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Capturing CUDA graph shapes:   0%|          | 0/35 [00:00<?, ?it/s]

Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded 100 samples


Processing samples:   0%|          | 0/100 [00:00<?, ?examples/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]


Results saved to gsm8k_eval_results_20250526_081430.json

Final Evaluation Results:
Accuracy: 19.00%
Correct: 19/100
