# Let's make Gemma 3 think! 🏥 🧠

In this notebook we'll use TRL's `GRPOTrainer` to make Gemma 3 think before it answers first-aid questions.

👩‍🎓 If you want to learn more about making models think and reason, check out [The Reasoning Course](https://huggingface.co/reasoning-course).

### Installation

In [None]:
# Install the Gemma 3 release tag of transformers, TRL from main, bitsandbytes, and the peft / wandb packages imported below
#!pip install -qqq git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 git+https://github.com/huggingface/trl.git@main bitsandbytes peft wandb

In [1]:
from huggingface_hub import login
import os
import wandb
import datetime

login(token=os.getenv('hf_api_wtoken'))
os.environ["WANDB_INIT_TIMEOUT"] ='120'
current_datetime = datetime.datetime.now()
formatted_datetime = current_datetime.strftime('%Y%m%dT%H:%M:%S')

wandb.login()
wandb_project = "gemma3-12b-grpo-firstaid"
wandb.init(
    project=wandb_project, name=formatted_datetime, entity="alfredcs_team",
)

wandb: Currently logged in as: alfredcs (alfredcs_team) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


CommError: failed to upsert bucket: returned error 422: {"data":{"upsertBucket":null},"errors":[{"message":"Error 3988 (HY000): Conversion from collation utf8mb4_0900_ai_ci into utf8mb3_general_ci impossible for parameter","path":["upsertBucket"]}]}

In [2]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, AutoTokenizer
from peft import LoraConfig, get_peft_model
ckpt = "unsloth/gemma-3-27b-it"

model = AutoModelForImageTextToText.from_pretrained(
    ckpt, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
# Wrap the model with LoRA adapters
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,  # LoRA rank: higher trains more parameters, lower trains faster
    lora_alpha=32,  # the LoRA update is scaled by lora_alpha / r (here 32 / 16 = 2) before it is added to the base weights
    target_modules="all-linear",  # attach adapters to every linear layer except the output head
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints the summary itself (the method returns None)

processor = AutoProcessor.from_pretrained(ckpt)

Fetching 12 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [02:39<00:00, 13.25s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [03:24<00:00, 17.06s/it]


trainable params: 122,211,840 || all params: 27,554,618,480 || trainable%: 0.4435
None


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
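For intuition about the two LoRA settings above: for a linear layer with weight shape `(d_out, d_in)`, LoRA trains two small matrices, `A` (`r × d_in`) and `B` (`d_out × r`), and adds `(lora_alpha / r) · B A` to the frozen weight (PEFT's default scaling, i.e. without rsLoRA). The plain-Python sketch below only does that bookkeeping. `lora_linear_params` and `lora_scaling` are hypothetical helper names, and the 4096 × 4096 layer is an illustrative size, not a value read from the Gemma 3 config.

```python
def lora_linear_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one bias-free linear layer: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Factor applied to the low-rank update B @ A before it is added to the frozen weight."""
    return lora_alpha / r


# Settings from the LoraConfig cell above.
r, lora_alpha = 16, 32

# Illustrative square projection (hypothetical size).
d_in = d_out = 4096
full = d_in * d_out
added = lora_linear_params(d_in, d_out, r)

print(f"scaling = {lora_scaling(lora_alpha, r)}")  # prints "scaling = 2.0"
print(f"LoRA params = {added:,} vs full {full:,} ({added / full:.2%})")
```

Summed over every linear layer the adapters touch, this per-layer count is what `print_trainable_parameters()` reports as trainable.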


### Process data to create reasoning chains

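Borrowing from [Will Brown's gist](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb), which sets up reasoning prompts for GSM8K, we'll build the same system-prompt format for the `lextale/FirstAidInstructionsDataset` first-aid dataset.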

In [4]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Build conversational prompts (system + user) and keep the reference answer for the reward functions
def get_firstaid_questions(split = "train") -> Dataset:
    #data = load_dataset('lextale/FirstAidInstructionsDataset', 'main')[split] # type: ignore
    data = load_dataset('lextale/FirstAidInstructionsDataset')[split]
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        # GSM8K-style extraction: returns None when the answer has no "####" marker
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_firstaid_questions()
#dataset_train_test = dataset.train_test_split(test_size=0.1)

Generating Superdataset split: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 75688/75688 [00:00<00:00, 325129.26 examples/s]
Generating train split: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 267/267 [00:00<00:00, 144612.50 examples/s]
Generating chatGPTGenerated split: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 267/267 [00:00<00:00, 185349.08 examples/s]
Generating MedQuAD split: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [None]:
dataset[2]
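`extract_hash_answer` comes from the GSM8K recipe, where every answer ends with `#### <final answer>`. For text without that marker it returns `None`, so it is worth checking what `dataset[2]['answer']` actually holds before training. The sketch below mirrors the mapping above on made-up records and adds a hypothetical fallback (`answer_or_full_text`) that keeps the full reference text when the marker is missing. Whether that fallback suits this dataset is an assumption to verify. The `SYSTEM_PROMPT` here is a stand-in so the cell runs on its own.

```python
# Stand-in for the SYSTEM_PROMPT defined above, so this cell runs on its own.
SYSTEM_PROMPT = "Respond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n"


def answer_or_full_text(text: str) -> str:
    """Hypothetical fallback: GSM8K-style final answer if '####' is present, else the whole answer."""
    if "####" in text:
        return text.split("####")[1].strip()
    return text.strip()


def to_grpo_record(question: str, answer: str) -> dict:
    """Build one row in the shape GRPOTrainer consumes: a conversational 'prompt' plus extra columns."""
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
        # Extra columns such as 'answer' are forwarded to every reward function as keyword arguments.
        "answer": answer_or_full_text(answer),
    }


gsm8k_row = to_grpo_record("What is 3 + 4?", "3 + 4 = 7\n#### 7")
first_aid_row = to_grpo_record(  # hypothetical first-aid record, not taken from the dataset
    "How should I treat a minor burn?",
    "Cool the burn under cool running water, then cover it loosely with a clean dressing.",
)
print(gsm8k_row["answer"])      # 7
print(first_aid_row["answer"])  # the full reference text
```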

# Reward Functions

Now, let's define the reward functions. `GRPOTrainer` calls each of them on every batch of sampled completions. The scores are summed and compared within each group of completions for the same prompt to compute advantages.

The additional reward functions further below are based on this [blog](https://medium.com/@mb20261/llm-by-examples-fine-tune-llama-3-1-with-unsloth-and-distilled-grpo-on-your-computer-f6a3a78dc0f1).

GRPO fine-tuning commonly combines several reward functions, each scoring a different aspect of a completion. Here we use the following:

* The `correctness_reward_func` rewards exact matches between the extracted answer and the reference answer.
* The `int_reward_func` rewards answers that consist only of digits.
* Two formatting functions, `strict_format_reward_func` and `soft_format_reward_func`, check adherence to the XML-like structure at different levels of strictness.
* The `xmlcount_reward_func` scores completions with the `count_xml` helper. It awards 0.125 for each expected tag (with its surrounding newlines) that appears exactly once, minus 0.001 per character of text trailing the closing `</answer>` tag.

Together, these functions steer the model toward accurate, well-structured, properly formatted responses during training.

| Reward Function | Purpose |
|---|---|
| `correctness_reward_func` | Rewards the model when its answer matches the correct answer |
| `int_reward_func` | Rewards the model for providing a numeric answer |
| `strict_format_reward_func` and `soft_format_reward_func` | Reward the model for following the specified format |
| `xmlcount_reward_func` | Rewards proper XML tag usage and penalizes extra content after the closing tags |
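All of these functions follow the same calling convention. As we understand TRL's `GRPOTrainer`, each function receives the batch's `prompts` and `completions`, plus every extra dataset column (such as `answer`) as a keyword argument, and must return one float per completion. For conversational data, each completion is a list of message dicts, which is why the functions read `completion[0]['content']`. Below is a minimal sketch of that contract. `conciseness_reward_func` and its 800-character budget are hypothetical and are not used in this notebook's training run.

```python
def conciseness_reward_func(completions, **kwargs) -> list[float]:
    """Hypothetical reward: 0.5 within a character budget, decreasing linearly to 0 at twice the budget."""
    budget = 800  # assumed character budget, purely illustrative
    scores = []
    for completion in completions:
        text = completion[0]["content"]  # first (only) assistant message
        overshoot = max(0, len(text) - budget)
        scores.append(max(0.0, 0.5 - 0.5 * overshoot / budget))
    return scores


# Shape of a batch as the reward functions see it (values are made up).
completions = [
    [{"role": "assistant", "content": "<reasoning>\nShort.\n</reasoning>\n<answer>\nOK\n</answer>\n"}],
    [{"role": "assistant", "content": "x" * 1600}],
]
print(conciseness_reward_func(completions, answer=["OK", "OK"]))  # [0.5, 0.0]
```

Accepting `**kwargs` lets the same function ignore columns it does not need, such as `answer` here.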

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the completion matches the exact newline-delimited format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span multi-line reasoning and answers
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 if the completion has the reasoning/answer tags, with any whitespace between them."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span multi-line reasoning and answers
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
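The regex-based format checks only accept multi-line reasoning when `re.DOTALL` is set. Without it, `.` does not match a newline, so `<reasoning>.*?</reasoning>` cannot span the multi-line block that `SYSTEM_PROMPT` asks for. Here is a quick self-contained check using a made-up completion:

```python
import re

soft_pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
completion = (
    "<reasoning>\nThe cut is bleeding steadily.\nApply firm pressure.\n</reasoning>\n"
    "<answer>\nPress a clean cloth on the wound.\n</answer>\n"
)

without_dotall = re.match(soft_pattern, completion)
with_dotall = re.match(soft_pattern, completion, flags=re.DOTALL)
print(without_dotall)           # None: ".*?" cannot cross the newline after <reasoning>
print(with_dotall is not None)  # True
```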

### Extra reward functions

In [None]:
# We create a regex format to match the reasoning sections and answers:
import re

reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

# ... and we verify it works:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

# A regex that captures the first number after the <SOLUTION> tag
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)
match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>")

In [None]:
# We now want to create a reward function to match the format exactly - we reward it with 3 points if it succeeds:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

# If it fails, we still want to reward the model for following the format partially, by counting each tag:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        scores.append(score)
    return scores

# Finally, we want to extract the generated answer, and reward or penalize it! We also reward it based on how close the answer is to the true one via ratios:
def check_answer(prompts, completions, answer, **kwargs):
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        # No formatted answer, or no reference answer to compare against
        if guess is None or true_answer is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # i.e. if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.25
                else: score -= 1.0 # Penalize wrong answers
            except (ValueError, ZeroDivisionError):
                score -= 0.5 # Penalize non-numeric answers
        scores.append(score)
    return scores

# Also sometimes it might not be 1 number as the answer, but like a sentence for example "The solution is $20" -> we extract 20.
def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    for guess, true_answer in zip(extracted_responses, answer):
        # No number found, or no reference answer to compare against
        if guess is None or true_answer is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            guess       = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except ValueError:
            scores.append(0)
    return scores
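These extra functions look for `<start_working_out>`/`<SOLUTION>` tags, while `SYSTEM_PROMPT` asks for `<reasoning>`/`<answer>`. The tally in `match_format_approximately` shows the consequence. Each of its four tags contributes +0.5 when it appears exactly once and -0.5 otherwise, so a completion that follows `SYSTEM_PROMPT` perfectly scores -2.0 here. Below is a standalone restatement of that tally. The helper name `tag_tally` is hypothetical, and the sample completions are made up.

```python
TAGS = ["<start_working_out>", "<end_working_out>", "<SOLUTION>", "</SOLUTION>"]


def tag_tally(text: str) -> float:
    # +0.5 for each tag seen exactly once, -0.5 otherwise (same rule as match_format_approximately)
    return sum(0.5 if text.count(tag) == 1 else -0.5 for tag in TAGS)


follows_system_prompt = "<reasoning>\nApply pressure.\n</reasoning>\n<answer>\nCall for help.\n</answer>\n"
follows_extra_format = "<start_working_out>Apply pressure.<end_working_out><SOLUTION>Call for help.</SOLUTION>"

print(tag_tally(follows_system_prompt))  # -2.0
print(tag_tally(follows_extra_format))   # 2.0
```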

# Train with GRPOTrainer

Now we'll configure training with `GRPOConfig`:


* learning_rate: The initial learning rate for the optimizer. A value of `5e-6` means the model takes relatively small update steps, which helps stability when fine-tuning.
* adam_beta1: The exponential decay rate for the first-moment (mean) estimates in Adam. `0.9` is the typical value.
* adam_beta2: The exponential decay rate for the second-moment (squared-gradient) estimates in Adam. `0.99` controls how strongly past squared gradients influence each update.
* weight_decay: The decoupled weight decay applied by AdamW, set to `0.1` here to help prevent overfitting.
* warmup_ratio: The fraction of total training steps used for warmup. `0.1` means the learning rate rises linearly from zero to its initial value over the first 10% of steps.
* lr_scheduler_type: The learning rate schedule. `"cosine"` decays the learning rate along a cosine curve after warmup.
* optim: The optimizer. `"adamw_8bit"` is AdamW with 8-bit optimizer states (via bitsandbytes), which reduces the memory the optimizer needs.
* logging_steps: How often (in training steps) metrics are logged. `1` logs every step.
* bf16 / fp16: Mixed-precision flags for bfloat16 or float16 training. They are not set explicitly in the cell below, so the library defaults apply; the model itself is already loaded in `torch.bfloat16`.
* per_device_train_batch_size: The batch size per device, set to `4` here.
* gradient_accumulation_steps: The number of batches whose gradients are accumulated before each weight update. `4` gives a larger effective batch without needing more memory.
* num_generations: How many completions are sampled per prompt to form a GRPO group, set to `2` here.
* max_prompt_length: The maximum length (in tokens) of the input prompts. `256` here; longer prompts are truncated.
* max_completion_length: The maximum length (in tokens) of the generated completions. It is set to `max_seq_length - max_prompt_length`, i.e. `1024 - 256 = 768`.
* num_train_epochs / max_steps: `max_steps = 250` stops training after 250 steps and takes precedence over `num_train_epochs`.
* save_steps: How often (in training steps) to save a checkpoint. `250`, equal to `max_steps`, saves once at the end of training.
* max_grad_norm: The gradient-clipping threshold. `0.1` rescales gradients whose total norm exceeds `0.1`.
* report_to: Where logs are sent. `["wandb"]` reports to Weights & Biases; `"none"` disables reporting.
* output_dir: The directory for checkpoints, logs, and other output files. `"outputs"` writes them to an `outputs` subdirectory of the current working directory.
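To make `warmup_ratio` and `lr_scheduler_type = "cosine"` concrete for the values used below (`learning_rate = 5e-6`, `max_steps = 250`), here is a plain-Python sketch of a linear-warmup, half-cosine-decay schedule. It follows the formula of `transformers`' `get_cosine_schedule_with_warmup` with its default half cycle as we understand it, taking warmup steps as `ceil(warmup_ratio * max_steps)`. Treat it as an approximation, not the trainer's exact code path.

```python
import math


def cosine_lr_with_warmup(step: int, base_lr: float, total_steps: int, warmup_ratio: float) -> float:
    """Learning rate at `step`: linear warmup from 0, then a half-cosine decay to 0 at `total_steps`."""
    warmup_steps = math.ceil(warmup_ratio * total_steps)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Values from the GRPOConfig cell below: 25 warmup steps, peak at step 25, zero at step 250.
for step in (0, 10, 25, 100, 250):
    print(step, f"{cosine_lr_with_warmup(step, 5e-6, 250, 0.1):.2e}")
```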

In [None]:
from trl import GRPOConfig, GRPOTrainer

max_prompt_length = 256
max_seq_length = 1024


training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 4, # was 2
    gradient_accumulation_steps = 4, # was 1
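    num_generations = 2, # completions sampled per prompt (GRPO group size)
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length, # 1024 - 256 = 768 tokens
    num_train_epochs = 1,
    max_steps = 250, # takes precedence over num_train_epochs
    save_steps = 250,
    max_grad_norm = 0.1,
    #report_to = "none",
    report_to=["wandb"],
    output_dir = "outputs",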
)
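Here is a quick sanity check on how these numbers combine. The sketch assumes the convention of recent TRL releases: `per_device_train_batch_size` counts completions rather than prompts, each prompt is sampled `num_generations` times, and the batch must split evenly into groups. It also assumes a single GPU (`num_gpus = 1` is an assumption). Check the `GRPOConfig` documentation of your installed TRL version, since the exact divisibility rule has changed between releases.

```python
num_gpus = 1  # assumption: single-GPU run
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_generations = 2
max_prompt_length, max_seq_length = 256, 1024

# Completions that contribute to one optimizer step.
completions_per_step = num_gpus * per_device_train_batch_size * gradient_accumulation_steps
assert completions_per_step % num_generations == 0, "batch must split evenly into GRPO groups"
prompts_per_step = completions_per_step // num_generations
max_completion_length = max_seq_length - max_prompt_length

print(f"{completions_per_step} completions = {prompts_per_step} prompts x {num_generations} generations per optimizer step")
print(f"max_completion_length = {max_completion_length} tokens")
```

With only two generations per prompt, each group-relative advantage compares just two samples, so raising `num_generations` (memory permitting) gives a less noisy baseline.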

# Training Run

In [None]:
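# Copy the special-token ids from the underlying tokenizer onto the processor.
# Gemma's tokenizer uses <pad> = 0, <eos> = 1, <bos> = 2, so hard-coding bos = 1 / eos = 2 would swap them.
processor.pad_token_id = processor.tokenizer.pad_token_id
processor.bos_token_id = processor.tokenizer.bos_token_id
processor.eos_token_id = processor.tokenizer.eos_token_id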

trainer = GRPOTrainer(
    model = model,
    processing_class = processor,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
        match_format_exactly, # extra reward functions (defined in the section above)
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)

In [None]:
# Train the model
#wandb.init(project=wandb_project, entity="alfredcs_team", name="ft-test-01")
trainer.train()

### Inference

In [None]:
from transformers import TextStreamer

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user",   "content": "What is the sqrt of 101?"},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)

_ = model.generate(
    # The chat template already inserts <bos>, so don't let the tokenizer add a second one
    **tokenizer(text, return_tensors = "pt", add_special_tokens = False).to(model.device),
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 sampling settings!
    do_sample = True, temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

### Saving, loading finetuned models
To save the final model as LoRA adapters, use either Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

[NOTE] This ONLY saves the LoRA adapters, not the full model. To save to 16-bit or GGUF, scroll down!

In [None]:
model_save_name = "gemma-3-12b-firstaid"

In [None]:
model.save_pretrained(model_save_name)  # Local saving
tokenizer.save_pretrained(model_save_name)
model.push_to_hub(f"alfredcs/{model_save_name}") # Online saving
# Push the tokenizer as well
tokenizer.push_to_hub(f"alfredcs/{model_save_name}") # Online saving

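### Saving to float16 for vLLM
The next cell can save a merged float16 model directly for deployment, into the folder `gemma-3-finetune`. Change `if False` to `if True` to run it. Note that `save_pretrained_merged`, `push_to_hub_merged`, and `save_pretrained_gguf` are Unsloth helpers: they are available on models loaded through Unsloth, not on the plain PEFT model built in this notebook, where `merge_and_unload()` in the "Merge and save" cell below does the merging.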

In [None]:
# To save a merged float16 model locally, change `if False` to `if True`.
if False: # Change to True to save finetune!
    model.save_pretrained_merged("gemma-3-finetune", tokenizer)

# To upload / push to your Hugging Face account, change `if False` to `if True` and add your Hugging Face token and upload location!
if False: # Change to True to upload finetune
    model.push_to_hub_merged(
        "HF_ACCOUNT/gemma-3-finetune", tokenizer,
        token = "hf_..."
    )

# GGUF / llama.cpp Conversion
# Unsloth's GGUF / llama.cpp export currently supports Q8_0, F16 or BF16 precision (Q4_K_M for 4-bit was announced as coming later).
if False: # Change to True to save to GGUF
    model.save_pretrained_gguf(
        "gemma-3-finetune",
        quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
    )

In [None]:
from transformers import pipeline

question = "How often does a patient need to repeat endoscopy if he is found to have metaplasia but no dysplasia the last 2 exams during 5 years?"
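generator = pipeline("text-generation", model=trainer.model, tokenizer=processor.tokenizer)
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": question},
]
# Open the model turn and prefill the opening <reasoning> tag (the result must be assigned)
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt += "<reasoning>"
output = generator(prompt, max_new_tokens=1024)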

In [None]:
output
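The text-generation pipeline returns a list of dicts whose `"generated_text"` field holds the prompt followed by the continuation (its default is `return_full_text=True`). Below is a hypothetical helper, `split_reasoning_answer`, that pulls the reasoning and answer out of that text. It assumes the model followed the `<reasoning>`/`<answer>` format and returns `None` otherwise. It takes the *last* match because the prompt itself may contain the format template from `SYSTEM_PROMPT`. The sample string is made up.

```python
import re
from typing import Optional

FORMAT = re.compile(r"<reasoning>(.*?)</reasoning>\s*<answer>(.*?)</answer>", flags=re.DOTALL)


def split_reasoning_answer(generated_text: str) -> Optional[tuple[str, str]]:
    """Return (reasoning, answer) from the last formatted block, or None if the format is missing."""
    matches = FORMAT.findall(generated_text)
    if not matches:
        return None
    reasoning, answer = matches[-1]
    return reasoning.strip(), answer.strip()


# Usage with the pipeline output above: split_reasoning_answer(output[0]["generated_text"])
sample = "<start_of_turn>model\n<reasoning>\nstep one\nstep two\n</reasoning>\n<answer>\nfinal answer\n</answer>"
print(split_reasoning_answer(sample))
```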

In [None]:
## Merge and save
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("gemma-3-12b-firstaid-merged")

In [None]:
merged_model.push_to_hub(
    "torchrun-gemma-3-12b-grpo-firstaid-merged", private=False, tags=["GRPO", "Reasoning-Course"]
)
# Push the tokenizer as well
tokenizer.push_to_hub("alfredcs/torchrun-gemma-3-12b-grpo-firstaid-merged") # Online saving

### Test on the merged model

In [None]:
from transformers import pipeline

question = "How often does a patient need to repeat endoscopy if he is found to have metaplasia but no dysplasia the last 2 exams during 5 years?"
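generator = pipeline("text-generation", model=merged_model, tokenizer=processor.tokenizer)
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": question},
]
# Open the model turn and prefill the opening <reasoning> tag (the result must be assigned)
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
prompt += "<reasoning>"
output = generator(prompt, max_new_tokens=1024)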

In [None]:
output

# Next Steps!

Check out [The Reasoning Course](https://huggingface.co/reasoning-course) for more info on GRPO.

In the coming days we'll release a version of this notebook with Unsloth!

<a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>