In [None]:
!pip install "typing_extensions==4.11.0"
!pip install unsloth vllm
!pip install --upgrade pillow
!pip install diffusers
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
!pip install wandb

# IMPORTANT restart the notebook after running this cell

In [11]:
import sys
import torch

print("Python version:", sys.version)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

Python version: 3.11.10 (main, Sep  7 2024, 18:35:41) [GCC 11.4.0]
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
GPU 0: NVIDIA H200


In [2]:
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 02-09 21:26:50 __init__.py:190] Automatically detected platform cuda.


In [3]:
import re
import torch
import wandb
from datasets import Dataset, load_dataset
from trl import GRPOConfig, GRPOTrainer
from unsloth import is_bfloat16_supported
from vllm import SamplingParams

In [4]:
# trainer config
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 1024
NUM_GENERATIONS = 12
MAX_STEPS = 1000

# model config
LORA_RANK = 64
GPU_MEMORY_UTILIZATION = 0.7
MAX_SEQ_LENGTH = 1024 + 256 + 8

SYSTEM_PROMPT = """\
You are a helpful assistant. You first think about the reasoning process and then provide the user with the answer.

Put your thinking process in <reasoning> tags.
- As you're reasoning, say "Wait," and think more to check your work until you're confident.

Put just the final answer in <answer> tags.

Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

In [5]:
def extract_completion_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def extract_completion_reasoning(text: str) -> str:
    reasoning = text.split("<reasoning>")[-1]
    reasoning = reasoning.split("</reasoning>")[0]
    return reasoning.strip()


def debug_reward(q, answer, responses, extracted_responses, rewards):
    print(
        "--------------",
        f"Q: {q}",
        "--------------",
        f"A: {answer[0]}",
        "--------------",
        f"Response:\n{responses[0]}",
        "--------------",
        f"Extracted: {extracted_responses}",
        "--------------",
        f"Reward: {rewards}",
        "--------------",
        sep="\n",
    )

In [6]:
def correctness_reward(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_completion_answer(r) for r in responses]

    def extract_float(text: str):
        text_no_commas = text.replace(",", "")
        match = re.search(r"([+-]?\d+(?:\.\d+)?)", text_no_commas)
        if match:
            try:
                return float(match.group(1))
            except ValueError:
                return None
        return None

    rewards = []
    for extracted_text, gold_text in zip(extracted_responses, answer):
        num_extracted = extract_float(extracted_text)
        num_gold = extract_float(gold_text)

        if (num_extracted is not None) and (num_gold is not None):
            if abs(num_extracted - num_gold) < 1e-6:
                rewards.append(2.0)
            else:
                rewards.append(0.0)
        else:
            rewards.append(0.0)

    debug_reward(q, answer, responses, extracted_responses, rewards)

    return rewards


def strict_format_reward(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def xmlcount_reward(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]

    def count_xml(text) -> float:
        count = 0.0
        if text.count("<reasoning>\n") == 1:
            count += 0.125
        if text.count("\n</reasoning>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
            count -= len(text.split("\n</answer>\n")[-1]) * 0.001
        if text.count("\n</answer>") == 1:
            count += 0.125
            count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
        return count

    return [count_xml(c) for c in contents]


def length_reward(completions, **kwargs) -> list[float]:
    chars_per_token = 4
    responses = [completion[0]["content"] for completion in completions]
    reasoning = [extract_completion_reasoning(r) for r in responses]
    return [2 * (len(r)) / (MAX_COMPLETION_LENGTH * chars_per_token) for r in reasoning]

In [None]:
wandb.login()

In [8]:
def extract_dataset_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset("openai/gsm8k", "main")[split]  # type: ignore
    data = data.map(  # type: ignore
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_dataset_answer(x["answer"]),
        }
    )
    return data


dataset = get_gsm8k_questions()

In [9]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=LORA_RANK,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

==((====))==  Unsloth 2025.2.5: Fast Qwen2 patching. Transformers: 4.48.3.
   \\   /|    GPU: NVIDIA H200. Max memory: 139.827 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-3b-instruct-unsloth-bnb-4bit with actual GPU utilization = 69.74%
Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 139.83 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1288. Num Sequences = 400.
Unsloth: vLLM's KV Cache can use up to 95.1 GB. Also swap space = 6 GB.
INFO 02-09 21:26:59 config.py:542] This model supports multiple tasks: {'generate', 'score', 'reward', 'classify', 'embed'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes config using kwargs = {'l



INFO 02-09 21:27:00 weight_utils.py:252] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 02-09 21:27:01 model_runner.py:1115] Loading model weights took 2.2160 GB
INFO 02-09 21:27:01 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 02-09 21:27:03 worker.py:267] Memory profiling takes 1.23 seconds
INFO 02-09 21:27:03 worker.py:267] the current vLLM instance can use total_gpu_memory (139.83GiB) x gpu_memory_utilization (0.70) = 97.52GiB
INFO 02-09 21:27:03 worker.py:267] model weights take 2.22GiB; non_torch_memory takes 0.16GiB; PyTorch activation peak memory takes 2.20GiB; the rest of the memory reserved for KV Cache is 92.95GiB.
INFO 02-09 21:27:03 executor_base.py:110] # CUDA blocks: 169212, # CPU blocks: 10922
INFO 02-09 21:27:03 executor_base.py:115] Maximum concurrency for 1288 tokens per request: 2102.01x
INFO 02-09 21:27:07 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory e

Capturing CUDA graph shapes: 100%|██████████| 53/53 [00:23<00:00,  2.28it/s]

INFO 02-09 21:27:30 model_runner.py:1562] Graph capturing finished in 23 secs, took 1.40 GiB
INFO 02-09 21:27:30 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 28.64 seconds



Unsloth 2025.2.5 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [12]:
training_args = GRPOConfig(
    use_vllm=True,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=NUM_GENERATIONS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    max_completion_length=MAX_COMPLETION_LENGTH,
    # num_train_epochs = 1, # set to 1 for full training run
    max_steps=MAX_STEPS,
    save_steps=500,
    max_grad_norm=0.1,
    report_to="wandb",
    output_dir="qwen3b-grpo",
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward,
        soft_format_reward,
        strict_format_reward,
        correctness_reward,
        length_reward,
    ],  # type: ignore
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 1 | Gradient Accumulation steps = 1
\        /    Total batch size = 1 | Total steps = 1,000
 "-____-"     Number of trainable parameters = 119,734,272
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


--------------
Q: Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers?
--------------
A: 100
--------------
Response:
<reasoning>
Wait, I need to determine the minimum grade Ahmed needs to surpass Emily's grade. First, I'll calculate Emily's current average with her 90 on the final assignment. Then, I'll figure out what Emily's total score would be if she got a 90 on the final. After that, I'll determine the minimum grade Ahmed needs on his final to beat this total.

Ahmed's current average: 91
Emily's current average: 92

The total score for 9 assignments: (91 * 9) + ε (Emily's final assignment) = Emily's total/9 = 92

Emily's total score with 90 on the final: 90 + (91 * 

Step,Training Loss,reward,reward_std,completion_length,kl
1,-0.0,0.937976,0.584777,452.5,0.0
2,0.0,2.433657,0.929238,266.833344,0.0
3,0.0,3.256969,0.256403,201.333344,0.000508
4,0.0,2.430234,0.989559,280.0,0.000384
5,0.0,2.087403,1.097325,278.583344,0.000848
6,0.0001,1.966974,1.035773,175.666672,0.001467
7,0.0001,1.498066,0.822459,174.916672,0.00138
8,0.0,3.199475,0.538804,194.583344,0.000506
9,0.0,2.615821,0.924391,169.083344,0.000647
10,0.0001,2.287473,1.284899,141.416672,0.001525


--------------
Q: The gauge on a water tank shows that the tank is 1/3 full of water. To fill the tank, 16 gallons of water are added. How many gallons of water does the tank hold when full?
--------------
A: 24
--------------
Response:
<reasoning>
Let's denote the total capacity of the water tank as \( T \) gallons.
Given that the tank is initially \( \frac{1}{3} \) full and adding 16 gallons fills the tank to its capacity, we can set up the following equation:
\[
\frac{T}{3} + 16 = T
\]

To solve for \( T \), we can start by eliminating the fraction:
\[
\frac{T}{3} + 16 = T
\]

Multiply every term by 3 to clear the fraction:
\[
T + 48 = 3T
\]

Subtract \( T \) from both sides to isolate \( T \):
\[
48 = 2T
\]

Now, divide both sides by 2:
\[
T = 24
\]

Thus, the total capacity \( T \) of the water tank is 24 gallons.
</reasoning>
<answer>
24
</answer>
--------------
Extracted: ['24', '24', '48', '24', '24', '24', '24', '24', '48', '48', '24', '24']
--------------
Reward: [2.0, 2.0, 0

TrainOutput(global_step=1000, training_loss=0.17141419771954497, metrics={'train_runtime': 11392.4532, 'train_samples_per_second': 0.088, 'train_steps_per_second': 0.088, 'total_flos': 0.0, 'train_loss': 0.17141419771954497})

In [13]:
model.save_lora("grpo_saved_lora")

In [14]:
def get_completion(model, tokenizer, prompt, lora=None):
    if lora:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}]
    else:
        messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)
    return model.fast_generate(text, sampling_params=sampling_params, lora_request=lora)[0].outputs[0].text

In [36]:
prompt = "How many r's in strawberry?"
print(get_completion(model, tokenizer, prompt, lora=None))
print(get_completion(model, tokenizer, prompt, lora=model.load_lora("grpo_saved_lora")))

Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  6.93it/s, est. speed input: 250.93 toks/s, output: 111.51 toks/s]


In the word "strawberry," there are 2 r's.


Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.25s/it, est. speed input: 25.40 toks/s, output: 114.06 toks/s]

<reasoning>
To determine the number of 'r's in the word "strawberry," we need to carefully count each instance of the letter 'r' within the word. The process should be thorough to avoid missing any instances. Let's examine the word step-by-piece.
First, let's look at the word "strawberry" and note the position of each letter.

1. The word "strawberry" consists of 10 letters.
2. I'll count the 'r's one by one:
   - First 'r' is in the 3rd position.
   - Second 'r' is in the 5th position.
   - Third 'r' is in the 7th position.

Now, let's check if there are any other instances of 'r':
- After the third 'r' in the 7th position, the next letter is 'b,' which is not an 'r.' We can confirm that there are no more 'r's after this point.
By this analysis, we can see that the 'r' letters appear at positions 3, 5, and 7. Thus, the total number of 'r's in the word "strawberry" is 3. 

To ensure accuracy, let's do a final check:
- The word "strawberry" = 10 letters.
- Counting the 'r's found: 3, 5,


