<a href="https://colab.research.google.com/github/AlperYildirim1/gemma-pipeline/blob/main/grpo_gemma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord%20button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants, which outperform other quantization methods!

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [None]:
# Optional: mount Google Drive (no later cell reads from or writes to it)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    pass # In Colab, the pinned installs are handled by the "Colab Extra Install" cell below

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] The steps below are Colab-only! Outside Colab, just use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt
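
The filtering step above drops the `transformers`, `numpy` and `xformers` lines from vLLM's requirements before installing the rest. The sketch below runs that exact `re.sub` on a small made-up requirements file, not vLLM's real `common.txt`. The pattern requires at least one more character after the package name, so a line containing only `numpy` would survive. The line-anchored variant is a hypothetical alternative and is not what the notebook uses.

```python
import re

# Made-up requirements content for illustration; not vLLM's actual common.txt.
sample = (
    b"numpy\n"
    b"transformers >= 4.51.1\n"
    b"requests >= 2.26.0\n"
    b"xformers == 0.0.29; platform_system == 'Linux'\n"
)

# The notebook's pattern: package name followed by one or more characters, then a newline.
notebook_filtered = re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", sample)

# Hypothetical stricter variant: anchored at line start, also removes a bare package name.
anchored_filtered = re.sub(rb"(?m)^(transformers|numpy|xformers)\b[^\n]*\n?", b"", sample)

print(notebook_filtered.decode())  # the bare "numpy" line is still there
print(anchored_filtered.decode())  # only the requests line remains
```

This sketch does not check whether vLLM's current requirements file actually contains a bare `numpy` line.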

### Unsloth

Load `Yujivus/gemma-3-1b-sft1` (Gemma 3 1B Instruct after a supervised fine-tuning stage) and set parameters.

In [None]:
from unsloth import FastModel
import torch
max_seq_length = 2048

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",

    # Other popular models!
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/Llama-3.3-70B",
    "unsloth/mistral-7b-instruct-v0.3",
    "unsloth/Phi-4",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "Yujivus/gemma-3-1b-sft1",
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = False,  # Set True for 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] Set True for 8 bit: a bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

In [None]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 32,           # Larger = higher accuracy, but might overfit
    lora_alpha = 64,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)
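
For reference: standard LoRA (not rsLoRA) scales each low-rank update by `lora_alpha / r`, so `r = 32, lora_alpha = 64` gives a factor of 2. Each adapted matrix adds `r * (d_in + d_out)` trainable parameters. The sketch uses a made-up 1024 × 1024 projection, not Gemma 3's real layer sizes.

```python
def lora_scaling(r: int, lora_alpha: int) -> float:
    # Standard LoRA multiplies the update B @ A by alpha / r (rsLoRA would use alpha / sqrt(r)).
    return lora_alpha / r

def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    # A has shape (r, d_in) and B has shape (d_out, r); the frozen base weight is not trained.
    return r * d_in + d_out * r

# Hypothetical square projection, not an actual Gemma 3 layer.
d = 1024
print(lora_scaling(r=32, lora_alpha=64))             # 2.0
print(lora_trainable_params(d, d, r=32))             # 65536
print(lora_trainable_params(d, d, r=32) / (d * d))   # 0.0625
```

At this size, the adapter trains about 6% as many parameters as the full matrix would. Doubling `r` doubles both the adapter size and the denominator of the scaling factor.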

In [None]:
print(model)

<a name="Data"></a>
### Data Prep

Load the `Yujivus/mmlu_grpo` dataset and turn each example into a chat prompt paired with its ground-truth answer.

In [None]:
from datasets import load_dataset

dataset = load_dataset("Yujivus/mmlu_grpo", "default", split = "train")
dataset

In [None]:
dataset[0]["question"]

In [None]:
dataset[0]["answer"]

In [None]:
def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()
extract_hash_answer(dataset[0]["answer"])

In [None]:
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
system_prompt

In [None]:
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})
dataset[0]
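
`extract_hash_answer` returns `None` whenever the raw answer has no `####` marker. If that happened for this dataset, every mapped `answer` would be `None`. The correctness reward compares against that value, so it could never award +5.0. The sketch below shows the check on made-up strings.

```python
def extract_hash_answer(text):
    # Same logic as the notebook helper: keep only the text after "####".
    if "####" not in text: return None
    return text.split("####")[1].strip()

def count_missing_answers(raw_answers) -> int:
    """Number of raw answers that extract_hash_answer turns into None."""
    return sum(extract_hash_answer(a) is None for a in raw_answers)

# Made-up examples: one with the "####" marker, one without.
raw_answers = [
    "Working... #### D. Vitamin B12 (cyanocobalamin)",
    "D. Vitamin B12 (cyanocobalamin)",
]
print(extract_hash_answer(raw_answers[0]))  # D. Vitamin B12 (cyanocobalamin)
print(count_missing_answers(raw_answers))   # 1
```

On the mapped dataset, the equivalent check is `sum(a is None for a in dataset["answer"])`, which should be 0.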

We compile a regex that checks for the full format (a reasoning section followed by a solution section) and captures the text inside the solution tags:

In [None]:
import re

match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

In [None]:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)
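
The sketch below rebuilds `match_format` from the same tag strings and tests it on a few made-up responses. It also shows a side effect of `re.MULTILINE`: `$` matches just before any newline. As a result, non-whitespace text after `</SOLUTION>` on the same line fails the check, while text on a following line still passes.

```python
import re

reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# Same pattern and flags as match_format above.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

good = "<start_working_out>Let me think!<end_working_out><SOLUTION>2</SOLUTION>"
trailing_same_line = good + " extra text"   # extra text on the same line
no_reasoning = "<SOLUTION>2</SOLUTION>"       # reasoning section missing
trailing_next_line = good + "\nextra text"  # extra text on a new line

print(match_format.search(good).group(1))                   # 2
print(match_format.search(trailing_same_line))              # None
print(match_format.search(no_reasoning))                    # None
print(match_format.search(trailing_next_line) is not None)  # True
```

Dropping `re.MULTILINE` (keeping `re.DOTALL`) would make `$` match only at the very end of the string, apart from a single trailing newline, which is stricter.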

In [None]:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

In [None]:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.2 if response.count(reasoning_start) == 1 else -0.5
        score += 0.2 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.2 if response.count(solution_start)  == 1 else -0.5
        score += 0.2 if response.count(solution_end)    == 1 else -0.5
        scores.append(score)
    return scores
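
Some worked numbers for `match_format_approximately`:

- With all four tags present exactly once, it gives 4 × 0.2 = 0.8.
- With no tags at all, it gives 4 × (-0.5) = -2.0.
- A well-formed response therefore collects 3.0 + 0.8 = 3.8 from the two format rewards combined.

The sketch re-implements the same counting rule on made-up strings so these values can be checked. Printed values are rounded to two decimals.

```python
reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

def approx_format_score(response: str) -> float:
    # Same rule as match_format_approximately: +0.2 per tag seen exactly once, -0.5 otherwise.
    score = 0.0
    for tag in (reasoning_start, reasoning_end, solution_start, solution_end):
        score += 0.2 if response.count(tag) == 1 else -0.5
    return score

# Made-up responses, not real model outputs.
well_formed = f"{reasoning_start}Think.{reasoning_end}{solution_start}B{solution_end}"
no_tags = "The answer is B."
two_solutions = well_formed + f"{solution_start}C{solution_end}"  # solution tags appear twice

for name, text in [("well_formed", well_formed), ("no_tags", no_tags), ("two_solutions", two_solutions)]:
    print(name, round(approx_format_score(text), 2))
```

The response with a duplicated solution block scores 0.2 + 0.2 - 0.5 - 0.5 = -0.6, so repeating tags costs more than omitting the answer format entirely would gain.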

In [None]:
import re

# ==============================================================================
# 1. HELPER FUNCTIONS
# ==============================================================================

def extract_model_solution(completion_text: str) -> str | None:
    """Extracts text from between <SOLUTION> tags."""
    pattern = re.compile(f"{re.escape(solution_start)}(.*?){re.escape(solution_end)}", re.DOTALL)
    match = pattern.search(completion_text)
    if match: return match.group(1) # Return the raw text, including whitespace
    return None

def extract_all_option_pairs(question_text: str) -> list[str]:
    """
    Finds all multiple-choice option lines (e.g., "A. ...", "B. ...")
    and returns them as a list of exact strings.
    """
    pattern = re.compile(r"^[A-Z][\.\)].*", re.MULTILINE)
    # Strip whitespace from each found option for clean matching.
    return [opt.strip() for opt in pattern.findall(question_text)]

# ==============================================================================
# 2. THE REWARD FUNCTION
# ==============================================================================
def check_for_unique_pair_final_batched(completions, **kwargs):
    """
    Looks for the presence of exact "option-answer pair" strings.
    Batch-aware: extra dataset columns (`question`, `answer`) arrive as lists aligned with `completions`.

    - Rewards +5.0 if exactly one correct pair is found.
    - Penalizes -2.0 if more than one pair is found (ambiguous).
    - All other cases result in a neutral 0.0 score.
    """
    all_questions_in_batch = kwargs.get("question")
    all_answers_in_batch = kwargs.get("answer")

    if not all_questions_in_batch or not all_answers_in_batch:
        return [0.0] * len(completions)

    scores = []
    for i, completion in enumerate(completions):
        question = all_questions_in_batch[i]
        ground_truth_pair = all_answers_in_batch[i]

        # `completion` is a list of messages like [{"role": "assistant", "content": "..."}],
        # so take the first message's "content".
        response_text = completion[0]["content"]

        solution_text = extract_model_solution(response_text)

        if solution_text is None:
            scores.append(0.0)
            continue

        all_possible_pairs = extract_all_option_pairs(question)
        found_pairs = [pair for pair in all_possible_pairs if pair in solution_text]

        if len(found_pairs) > 1:
            scores.append(-2.0)
        elif len(found_pairs) == 1:
            if found_pairs[0] == ground_truth_pair:
                scores.append(5.0)
            else:
                scores.append(0.0)
        else: # len(found_pairs) == 0
            scores.append(0.0)

    return scores
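
`extract_all_option_pairs` treats any line that starts with a capital letter followed by `.` or `)` as an option. That covers both `A.` and `A)` styles. It would also pick up a Roman-numeral statement such as `I. ...` in the question stem, while skipping `II. ...`. If such questions occur in the dataset, a solution that quotes statement I alongside its chosen option would count as two pairs and receive -2.0. The questions below are made up for illustration.

```python
import re

def extract_all_option_pairs(question_text: str) -> list[str]:
    # Same pattern as the notebook helper: capital letter plus "." or ")" at a line start.
    pattern = re.compile(r"^[A-Z][\.\)].*", re.MULTILINE)
    return [opt.strip() for opt in pattern.findall(question_text)]

# Hypothetical question with "A)"-style options.
paren_style = "Pick one.\nA) red\nB) blue"

# Hypothetical question whose stem lists Roman-numeral statements.
roman_style = "Which statements are true?\nI. X is even\nII. X is odd\nA. I only\nB. II only"

print(extract_all_option_pairs(paren_style))  # ['A) red', 'B) blue']
print(extract_all_option_pairs(roman_style))  # ['I. X is even', 'A. I only', 'B. II only']
```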


In [None]:
# ==============================================================================
# TEST DATA AND TEST RUNNER
# ==============================================================================

sample_data = {
    "question": """A 73-year-old woman...deficiency of which of the following vitamins?

A. Vitamin B1 (thiamine)
B. Vitamin B2 (riboflavin)
C. Vitamin B6 (pyridoxine)
D. Vitamin B12 (cyanocobalamin)""",
    "answer": "D. Vitamin B12 (cyanocobalamin)"
}

final_test_cases = [
    {"description": "✅ [Correct] Simple case: Contains only the correct PAIR", "completion": [{"content": "<SOLUTION>The correct choice is D. Vitamin B12 (cyanocobalamin).</SOLUTION>"}], "expected_score": 5.0},
    {"description": "✅ [Correct] Mentions other letters, but only one full PAIR", "completion": [{"content": "<SOLUTION>The answer is not B or C. The correct answer is D. Vitamin B12 (cyanocobalamin).</SOLUTION>"}], "expected_score": 5.0},
    {"description": "❌ [Penalized] Ambiguous: Contains two PAIRS", "completion": [{"content": "<SOLUTION>It could be A. Vitamin B1 (thiamine). However, the better answer is D. Vitamin B12 (cyanocobalamin).</SOLUTION>"}], "expected_score": -2.0},
    {"description": "➖ [Neutral] Incorrect: Contains only one, but incorrect, PAIR", "completion": [{"content": "<SOLUTION>The answer is definitely A. Vitamin B1 (thiamine).</SOLUTION>"}], "expected_score": 0.0},
    {"description": "➖ [Neutral] No PAIR: Only mentions the letter", "completion": [{"content": "<SOLUTION>The answer is D.</SOLUTION>"}], "expected_score": 0.0},
]

# --- TEST RUNNER ---
print("🚀 Starting FINAL BATCH-AWARE Test Suite (Reward: +5)...\n" + "="*60)
all_passed = True
for test in final_test_cases:
    description = test["description"]
    completion_to_test = test["completion"] # This is `[{"content": "..."}]`
    expected = test["expected_score"]

    # Pass the data in the same batched list format the trainer uses.
    actual_scores = check_for_unique_pair_final_batched(
        completions=[completion_to_test],      # Simulates a batch: [[{"content":...}]]
        question=[sample_data["question"]],    # Simulates a batch: ["The question..."]
        answer=[sample_data["answer"]]         # Simulates a batch: ["D. The answer..."]
    )
    actual = actual_scores[0]

    print(f"🧪 TESTING: {description}")
    print(f"   - Expected Score: {expected}")
    print(f"   - Actual Score:   {actual}")

    if actual == expected:
        print("   - STATUS: ✅ PASS")
    else:
        print(f"   - STATUS: ❌ FAIL")
        all_passed = False
    print("-" * 60)

print("\n🏁 Test Suite Finished.")
if all_passed:
    print("🎉 All tests passed successfully!")
else:
    print("🔥 Some tests failed. Please review the output above.")

<a name="Train"></a>
### Train the model

Now set up the GRPO trainer and its configuration!

In [None]:
max_prompt_length = 512

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Higher values give smoother training
    num_generations = 12, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    num_train_epochs = 1, # Set to 1 for a full training run
    #max_steps = 50,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 12
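
The message above means each micro-batch now holds 12 completions. The arithmetic below assumes a single GPU and uses the values from the config cell. The "prompts per step" line further assumes that TRL groups every `num_generations` completions under one prompt.

```python
# Values copied from the GRPOConfig cell; a single GPU is an assumption.
max_seq_length = 2048
max_prompt_length = 512
num_generations = 12
gradient_accumulation_steps = 4
per_device_train_batch_size = num_generations  # Unsloth raised it from 1 to 12 (message above)
num_gpus = 1

max_completion_length = max_seq_length - max_prompt_length
completions_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
prompts_per_step = completions_per_step // num_generations  # each prompt gets num_generations samples

print(max_completion_length, completions_per_step, prompts_per_step)  # 1536 48 4
```

Under these assumptions, each optimizer step sees 48 completions covering 4 distinct questions. Lowering `num_generations` to save memory also lowers the per-device batch that Unsloth sets.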


In [None]:

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_for_unique_pair_final_batched,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()
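
GRPO sums the reward functions for each completion. In recent TRL versions each function has weight 1 unless `reward_weights` is set. It then compares the totals of completions generated for the same prompt. With the three functions above, the best possible total is 3.0 + 0.8 + 5.0 = 8.8.

The sketch below shows a simplified group-relative advantage. Each total is centred on its group's mean and divided by the group's standard deviation. The use of the sample standard deviation and the `1e-4` epsilon are assumptions. TRL's exact choices, including options that disable the std scaling, vary by version.

```python
import statistics

def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Standardize each completion's total reward against its own generation group."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample std; library implementations may differ
    return [(r - mean) / (std + eps) for r in rewards]

# Total per completion = match_format_exactly + match_format_approximately
#                        + check_for_unique_pair_final_batched
best_total = 3.0 + 0.8 + 5.0

# Hypothetical group of 4 completions for one prompt (the notebook samples 12):
# one correct, two well-formatted but wrong, one with no tags at all (0 - 2.0 + 0).
mixed_group = [best_total, 3.0 + 0.8, 3.0 + 0.8, -2.0]
tied_group = [3.8, 3.8, 3.8, 3.8]

print([round(a, 3) for a in group_relative_advantages(mixed_group)])
print(group_relative_advantages(tied_group))  # [0.0, 0.0, 0.0, 0.0]
```

When every completion for a prompt ties, for example all correct or all wrong, each advantage is zero and the prompt adds no advantage signal in this sketch. More generations per prompt make such ties less likely.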


<a name="Inference"></a>
### Inference
Now let's try the model we just trained!

In [None]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": """An expected side effect of creatine supplementation is:

A. muscle weakness.
B. gain in body mass.
C. muscle cramps.
D. loss of electrolytes."""},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 2000, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)
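
To check the answer programmatically instead of reading the stream, the generated text can be parsed the same way the correctness reward parses it. The helper `chosen_option` below is hypothetical. It returns the single option line quoted inside the solution tags, or `None` if there is none or more than one. The model output in the example is made up, not a real generation.

```python
import re
from typing import Optional

solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

def chosen_option(generated_text: str, question_text: str) -> Optional[str]:
    """Return the one option line quoted between the solution tags, or None (hypothetical helper)."""
    match = re.search(
        f"{re.escape(solution_start)}(.*?){re.escape(solution_end)}", generated_text, re.DOTALL
    )
    if match is None:
        return None
    # Option lines look like "A. ..." or "A) ...", as in extract_all_option_pairs.
    options = [o.strip() for o in re.findall(r"^[A-Z][\.\)].*", question_text, re.MULTILINE)]
    found = [o for o in options if o in match.group(1)]
    return found[0] if len(found) == 1 else None

question = """An expected side effect of creatine supplementation is:

A. muscle weakness.
B. gain in body mass.
C. muscle cramps.
D. loss of electrolytes."""

# Made-up model output, not an actual generation from the trained model.
generated = ("<start_working_out>Creatine draws water into muscle.<end_working_out>"
             "<SOLUTION>B. gain in body mass.</SOLUTION>")
print(chosen_option(generated, question))  # B. gain in body mass.
```

One way to get the generated text in the notebook is to keep the return value of `model.generate` and decode only the new tokens. For example, `tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens = True)`, where `inputs` is the tokenized prompt. This suggestion is untested here.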

<a name="Save"></a>
### Saving the model

The cell below pushes the LoRA adapters and the tokenizer to the Hugging Face Hub. To save locally instead, uncomment the first two lines.

In [None]:
#model.save_pretrained("gemma-3-grpo1")  # Local saving (LoRA adapters only)
#tokenizer.save_pretrained("gemma-3-grpo1")
model.push_to_hub("Yujivus/gemma-3_sft_grpo", token = "") # Online saving; put your Hugging Face write token in `token`
tokenizer.push_to_hub("Yujivus/gemma-3_sft_grpo", token = "") # Online saving

In [None]:
# ==============================================================================
# FINAL CELL: SHUTDOWN RUNTIME
# ==============================================================================

print("\n--- All tasks complete. Shutting down the Colab runtime to save resources. ---")
print("You will be disconnected shortly.")

# This is the official command to disconnect and terminate the Colab runtime.
from google.colab import runtime
runtime.unassign()