<a href="https://colab.research.google.com/github/Ratapakorn/chatbot-profile/blob/main/nb/Gemma3_(1B)-GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your local device, follow [our guide](https://docs.unsloth.ai/get-started/install-and-update). This notebook is licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save).


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1" # [NEW] Extra 30% context lengths!
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install or uv pip install
    !pip install unsloth vllm
else:
    pass # For Colab / Kaggle, we need extra instructions hidden below \/

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
!pip install --upgrade -qqq uv
if "COLAB_" not in "".join(os.environ.keys()):
    # If you're not in Colab, just use pip install!
    !pip install unsloth vllm
else:
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: import subprocess; is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

### Unsloth

Load up `Gemma 3 1B Instruct` and set the parameters:

In [3]:
from unsloth import FastModel
import torch
max_seq_length = 1024

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",

    # Other popular models!
    "unsloth/Llama-3.1-8B",
    "unsloth/Llama-3.2-3B",
    "unsloth/Llama-3.3-70B",
    "unsloth/mistral-7b-instruct-v0.3",
    "unsloth/Phi-4",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-1b-it",
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)


We now add LoRA adapters, so we only need to update a small fraction of the parameters!

In [None]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention is good for GRPO
    finetune_mlp_modules       = True,  # Should always leave on!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 8,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)
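To see why LoRA only updates a small fraction of the weights: an adapter on a frozen matrix `W` of shape `d_out × d_in` trains two low-rank factors, `A` (`r × d_in`) and `B` (`d_out × r`), so it adds `r * (d_in + d_out)` trainable parameters instead of `d_in * d_out`. The sketch below uses a hypothetical 1024 × 1024 projection (not Gemma 3's actual layer shapes) with the `r = 8` and `lora_alpha = 8` chosen above; the `alpha / r` scaling follows the standard LoRA formulation.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters in one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank

def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters in the frozen dense weight matrix W (d_out x d_in)."""
    return d_in * d_out

# Hypothetical projection size, for illustration only
d_in, d_out, rank, alpha = 1024, 1024, 8, 8

lora = lora_param_count(d_in, d_out, rank)
full = full_param_count(d_in, d_out)
scaling = alpha / rank  # the adapted weight is W + (alpha / rank) * B @ A

print(f"LoRA params: {lora:,} vs full: {full:,} ({100 * lora / full:.2f}%)")
print(f"scaling = {scaling}")
```

With `alpha == r` the scaling factor is 1.0, so the low-rank update is added to `W` unscaled.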

### Data Prep
<a name="Data"></a>

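The original template uses OpenAI's GSM8K dataset (kept commented out below). Here we instead load a local `labels.jsonl` file with one JSON object per line, each holding a `user` log entry and the expected `response`.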

In [None]:
# from datasets import load_dataset
# dataset = load_dataset("openai/gsm8k", "main", split = "train")
# dataset

from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files="labels.jsonl",  # one JSON object per line
    split="train"
)
print(dataset)
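`load_dataset("json", data_files=...)` reads JSON Lines: one JSON object per line. Based on the fields used below (`user` and `response`), here is a minimal sketch of writing and reading such a file with only the standard library. The two rows are made-up placeholders, not the real `labels.jsonl` contents, and the file name `labels_example.jsonl` is hypothetical.

```python
import json
from pathlib import Path

# Hypothetical example rows: "user" holds a log line, "response" the
# expected "<AI Model> - <Reasoning>" answer.
rows = [
    {"user": "Failed password for root from 10.0.0.5 port 22 ssh2 (x50 in 60s)",
     "response": "isolation forest - Bursts of failed logins are rare outliers."},
    {"user": "GET /index.php?id=1' OR '1'='1 HTTP/1.1",
     "response": "bert - The payload is text whose meaning signals SQL injection."},
]

path = Path("labels_example.jsonl")

# JSON Lines: exactly one JSON object per line
with path.open("w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Read it back line by line, skipping blank lines
with path.open(encoding="utf-8") as f:
    loaded = [json.loads(line) for line in f if line.strip()]

print(loaded[0]["response"])
```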

Let's look at the first row:

In [None]:
str(dataset[0]["user"])

In [None]:
dataset[0]["response"]

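GSM8K-style answers end with `####` followed by the final answer, so we extract whatever comes after that marker (the helper returns `None` when the marker is absent):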

In [None]:
def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()
extract_hash_answer(dataset[0]["response"])

We now create a system prompt, which can be customized. It asks the model to pick exactly one model type for a given log entry and to justify the choice in one sentence, in the format `<AI Model> - <Reasoning>`:

In [None]:


system_prompt = \
"""I'll give you a log entry, and you pick which AI model should be used to analyze it.

Options: cnn, random forest, isolation forest, naivebayes, bert, autoencoder, dqn, other

Pick ONLY ONE, and give a short one-sentence reasoning.

The format is <AI Model> - <Reasoning>
"""
system_prompt
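The prompt asks for a single line of the form `<AI Model> - <Reasoning>`, with the model chosen from a fixed list. A minimal sketch of checking a completion against that contract, assuming the separator is the first ` - ` and that model names are compared case-insensitively with underscores treated like spaces (`parse_choice` and `ALLOWED_MODELS` are hypothetical names, not part of the notebook):

```python
ALLOWED_MODELS = {
    "cnn", "random forest", "isolation forest", "naivebayes",
    "bert", "autoencoder", "dqn", "other",
}

def _canon(name: str) -> str:
    """Lowercase, treat underscores like spaces, and collapse whitespace."""
    return " ".join(name.replace("_", " ").lower().split())

def parse_choice(text: str):
    """Return (model, reasoning) if text follows '<AI Model> - <Reasoning>', else None."""
    model, sep, reasoning = text.strip().partition(" - ")
    if not sep:
        return None  # separator missing
    model = _canon(model)
    reasoning = reasoning.strip()
    if model not in ALLOWED_MODELS or not reasoning:
        return None  # unknown model or empty reasoning
    return model, reasoning

print(parse_choice("Isolation_Forest - Rare login bursts are outliers."))
print(parse_choice("cnn and bert - both could work"))
```

Normalizing names this way also reconciles spellings such as `random forest` and `random_forest`.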

Let's map the dataset and look at the first row:

In [None]:
system_prompt = str(system_prompt)

system_prompt

In [None]:
# from datasets import Features, Sequence, Value

# features = Features({
#     "prompt": Sequence({"role": Value("string"), "content": Value("string")}),
#     "answer": Value("string"),
# })

# dataset = dataset.map(
#     lambda x: {
#         "prompt": [
#             {"role": "system", "content": (system_prompt)},
#             {"role": "user",   "content": '"' + str(x.get("user", "")).replace('"', '\\"') + '"'},
#         ],
#         "answer": (x.get("response", "")),
#     },
#     features=features,
#     remove_columns=dataset.column_names,
# )

# dataset[0]


In [None]:
# dataset = dataset.map(lambda x: {
#     "prompt" : [
#         {"role": "system", "content": system_prompt},
#         {"role": "user",   "content": x["user"]},
#     ],
#     "answer": x["response"],
# })
# str(dataset[0])

#original
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": str(x["user"])},
    ],
    "answer": (x["response"]),
})
dataset[0]

We create a regex format to match the reasoning sections and answers:

In [None]:
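import re

# Tag strings for the reasoning and solution sections; these match the
# sample string checked in the next cell.
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)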

We verify it works:

In [None]:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

We now want to create a reward function to match the format exactly - we reward it with 3 points if it succeeds:

In [None]:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

If it fails, we want to reward the model if it at least follows the format partially, by counting each symbol:

In [None]:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        scores.append(score)
    return scores
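The partial-credit scheme is easiest to see on a few strings. A standalone sketch, assuming the `<start_working_out>` / `<SOLUTION>` tags used in the regex check above (`approx_format_score` is a hypothetical name): each tag seen exactly once earns +0.5, and a missing or repeated tag costs 0.5.

```python
TAGS = ["<start_working_out>", "<end_working_out>", "<SOLUTION>", "</SOLUTION>"]

def approx_format_score(response: str, tags=TAGS) -> float:
    """+0.5 per tag seen exactly once, -0.5 otherwise (missing or repeated)."""
    return sum(0.5 if response.count(tag) == 1 else -0.5 for tag in tags)

full    = "<start_working_out>think<end_working_out><SOLUTION>2</SOLUTION>"
no_work = "<SOLUTION>2</SOLUTION>"
doubled = "<SOLUTION>2</SOLUTION><SOLUTION>3</SOLUTION>"

for text in (full, no_work, doubled):
    print(approx_format_score(text), text)
```

Penalizing repeats keeps the model from gaming the reward by emitting the same tag many times.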

In [None]:
import re

# Helper: extract whatever the model put as its final answer (text)
# 1) Prefer content between your tags:  <solution_start> ... <solution_end>
# 2) Otherwise fall back to a line like: "#### <answer>"
FINAL_LINE_RE = re.compile(r"(?m)^\s*####\s*(?P<ans>[^\n]+?)\s*$")
ONLY_NUMBER_RE = re.compile(r"^\s*[-+]?\d+(?:\.\d+)?\s*$")

def extract_solution_text(response: str) -> str:
    if solution_start in response and solution_end in response:
        try:
            return response.split(solution_start, 1)[1].split(solution_end, 1)[0].strip()
        except Exception:
            pass
    m = FINAL_LINE_RE.search(response)
    return m.group("ans").strip() if m else ""

def is_text_answer(ans: str) -> bool:
    # Accept if it has at least one letter and is not just a number
    return bool(ans) and bool(re.search(r"[A-Za-z]", ans)) and not ONLY_NUMBER_RE.match(ans)

def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        ans = extract_solution_text(response)
        score = 3.0 if is_text_answer(ans) else 0.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0.0

        # keep your tag-count heuristics
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5

        # new: text-specific checks for the extracted final answer
        ans = extract_solution_text(response)
        if is_text_answer(ans):
            score += 0.5  # good: looks like text, not just a number
            n_words = len(ans.split())
            if 1 <= n_words <= 40:
                score += 0.5   # concise final text
            elif n_words > 120:
                score -= 0.5   # overly long for a "final answer"
        else:
            score -= 0.5       # missing/empty or purely numeric

        scores.append(score)
    return scores


Finally, we want to extract the generated answer, and reward or penalize it! We also reward it based on how close the answer is to the true one via ratios:

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.25
                else: score -= 1.0 # Penalize wrong answers
            except:
                score -= 0.5 # Penalize
        scores.append(score)
    return scores
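The ratio buckets in `check_answer` are symmetric bands around the true value. A standalone sketch of the same scoring on plain strings, handy for checking the thresholds by hand (`ratio_reward` is a hypothetical helper that mirrors the logic above, including the -0.5 penalty when a value cannot be parsed or the true answer is zero):

```python
def ratio_reward(guess: str, true_answer: str) -> float:
    """Mirror of the ratio-based partial credit in check_answer."""
    if guess == true_answer:
        return 3.0
    if guess.strip() == true_answer.strip():
        return 1.5
    try:
        ratio = float(guess) / float(true_answer)
    except (ValueError, ZeroDivisionError):
        return -0.5  # unparseable guess or zero true answer
    if 0.9 <= ratio <= 1.1:
        return 0.5
    if 0.8 <= ratio <= 1.2:
        return 0.25
    return -1.0  # wrong answer

for g in ["72", " 72 ", "75", "60", "hello"]:
    print(repr(g), ratio_reward(g, "72"))
```

Here 75/72 ≈ 1.04 lands in the ±10% band and 60/72 ≈ 0.83 in the ±20% band.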

Sometimes the answer is not a bare number but a sentence, for example "The solution is $20"; in that case we extract 20.

In [None]:
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)
match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>")

In [None]:
def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            guess       = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except:
            scores.append(0)
            continue
    return scores

In [None]:
import re, difflib
from fractions import Fraction

# --- helpers ---
_OnlyNum = re.compile(r'^\s*[-+]?(\d+(?:[,_]\d{3})*|\d+)(?:\.\d+)?(?:[eE][-+]?\d+)?\s*%?\s*$')

def _normalize_text(s: str) -> str:
    s = s.strip().lower()
    # collapse whitespace
    s = re.sub(r'\s+', ' ', s)
    # drop most punctuation (keep % for possible semantics)
    s = re.sub(r'[^\w\s%]', '', s)  # remove punctuation except %; adjust if needed
    return s.strip()

def _text_similarity(a: str, b: str) -> float:
    # similarity on normalized strings
    a_n, b_n = _normalize_text(a), _normalize_text(b)
    return difflib.SequenceMatcher(None, a_n, b_n).ratio()

def _to_float(s: str):
    if s is None:
        return None
    t = s.strip()
    if not t:
        return None
    # percentage?
    is_pct = t.endswith('%')
    if is_pct:
        t = t[:-1]

    # fractions like "1/2"
    if '/' in t and not re.search(r'[A-Za-z]', t):
        try:
            val = float(Fraction(t.replace(' ', '')))
            return val / 100.0 if is_pct else val
        except Exception:
            pass

    # remove thousands separators/underscores/spaces
    t = t.replace(',', '').replace('_', '').replace(' ', '')

    try:
        val = float(t)
        return val / 100.0 if is_pct else val
    except Exception:
        return None

# --- TEXT ANSWER CHECKER ---
def check_answer(prompts, completions, answer, **kwargs):
    # Extract candidate answers using your existing regex `match_format`
    responses = [completion[0]["content"] for completion in completions]
    extracted = [
        m.group(1) if (m := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted, answer):
        if guess is None:
            scores.append(0.0)
            continue

        # 1) exact match (verbatim)
        if guess == true_answer:
            scores.append(3.0)
            continue

        # 2) whitespace-only differences
        if guess.strip() == true_answer.strip():
            scores.append(2.0)
            continue

        # 3) case/punct/spacing-insensitive equality
        if _normalize_text(guess) == _normalize_text(true_answer):
            scores.append(1.5)
            continue

        # 4) fuzzy similarity on normalized text
        sim = _text_similarity(guess, true_answer)
        if   sim >= 0.90: scores.append(1.0)
        elif sim >= 0.80: scores.append(0.5)
        else:             scores.append(-0.5)  # penalize clear mismatch

    return scores

# --- NUMERIC ANSWER CHECKER (more robust parsing) ---
def check_numbers(prompts, completions, answer, **kwargs):
    # Extract numeric-looking substrings using your existing `match_numbers`
    responses = [completion[0]["content"] for completion in completions]
    extracted = [
        m.group(1) if (m := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted, answer):
        if guess is None:
            scores.append(0.0)
            continue
        try:
            g = _to_float(guess)
            t = _to_float(true_answer if isinstance(true_answer, str) else str(true_answer))
            if g is None or t is None:
                scores.append(0.0)
                continue

            # exact numeric equality
            if g == t:
                scores.append(1.5)
                continue

            # relative tolerance buckets
            # avoid division by zero: use max(|t|, 1e-12)
            denom = max(abs(t), 1e-12)
            rel_err = abs(g - t) / denom

            if   rel_err <= 0.01: scores.append(1.0)   # within 1%
            elif rel_err <= 0.05: scores.append(0.75)  # within 5%
            elif rel_err <= 0.10: scores.append(0.5)   # within 10%
            else:                  scores.append(0.0)
        except Exception:
            scores.append(0.0)
    return scores
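`check_numbers` scores a parsed guess by its relative error, `abs(g - t) / max(abs(t), 1e-12)`. A self-contained sketch of that bucketing, with a simplified parser that is an assumption (it handles only commas, a trailing `%`, and `a/b` fractions, like `_to_float` above), so you can see where a few inputs land:

```python
from fractions import Fraction

def parse_number(s: str):
    """Simplified parser: commas, trailing %, and a/b fractions."""
    t = s.strip()
    pct = t.endswith("%")
    if pct:
        t = t[:-1]
    t = t.replace(",", "")
    try:
        val = float(Fraction(t)) if "/" in t else float(t)
    except (ValueError, ZeroDivisionError):
        return None
    return val / 100.0 if pct else val

def rel_err_reward(guess: str, truth: str) -> float:
    g, t = parse_number(guess), parse_number(truth)
    if g is None or t is None:
        return 0.0
    if g == t:
        return 1.5
    rel_err = abs(g - t) / max(abs(t), 1e-12)
    if rel_err <= 0.01:
        return 1.0   # within 1%
    if rel_err <= 0.05:
        return 0.75  # within 5%
    if rel_err <= 0.10:
        return 0.5   # within 10%
    return 0.0

cases = [("1,000", "1000"), ("1005", "1000"), ("1040", "1000"),
         ("1090", "1000"), ("1200", "1000"), ("50%", "0.5"), ("1/2", "0.5")]
for guess, truth in cases:
    print(guess, truth, rel_err_reward(guess, truth))
```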


In [None]:
import re
from difflib import SequenceMatcher

# --- helpers ---------------------------------------------------------------

_PAIR_LINE = re.compile(
    r'^\s*[“"]?(?P<model>[^-–—:|]+?)\s*[-–—:|]\s*(?P<reason>.+?)\s*[”"]?\s*$'
)

def _normalize(s: str) -> str:
    """Lowercase, trim, collapse whitespace."""
    s = re.sub(r'\s+', ' ', s or '').strip()
    return s.casefold()

def _similar(a: str, b: str) -> float:
    """Similarity in [0,1] using SequenceMatcher."""
    return SequenceMatcher(None, _normalize(a), _normalize(b)).ratio()

def _extract_pair(text: str):
    """
    Return (model, reason) by scanning lines and grabbing the first
    'X - Y' (or –, —, :, |) pattern. If none, return (None, None).
    """
    if not text:
        return None, None
    for line in text.splitlines():
        m = _PAIR_LINE.match(line)
        if m:
            return m.group('model').strip(), m.group('reason').strip()
    return None, None

def _coerce_answer_pair(ans):
    """
    Accept answers as:
      - "Model - Reason" string
      - ("Model", "Reason") tuple/list
      - {"model": "...", "reasoning": "..."} or {"model": "...", "reason": "..."}
    """
    if isinstance(ans, (list, tuple)) and len(ans) == 2:
        return str(ans[0]), str(ans[1])
    if isinstance(ans, dict):
        model = ans.get('model') or ans.get('ai_model') or ans.get('name')
        reason = ans.get('reasoning') or ans.get('reason') or ans.get('rationale')
        return (str(model) if model is not None else None,
                str(reason) if reason is not None else None)
    if isinstance(ans, str):
        return _extract_pair(ans)
    # Fallback
    return None, None

# --- main -----------------------------------------------------------------

def check_answer(prompts, completions, answer, **kwargs):
    """
    Scores completions against expected 'AI Model - Reasoning' pairs.

    Inputs:
      - prompts: unused here but kept for API compatibility.
      - completions: list like [[{"content": "..."}], ...]
      - answer: list of expected pairs; each item can be:
          * "Model - Reasoning"
          * ("Model", "Reasoning")
          * {"model": "...", "reasoning": "..."} (keys are flexible)
    Output:
      - List[float] of scores, one per completion (max 3.0).
    """
    # Flatten completion texts
    texts = [c[0]["content"] if c and isinstance(c[0], dict) else "" for c in completions]

    # Extract predicted pairs
    preds = [_extract_pair(t) for t in texts]

    # Coerce ground-truth pairs
    truths = [_coerce_answer_pair(a) for a in answer]

    scores = []
    for (pred_model, pred_reason), (true_model, true_reason) in zip(preds, truths):
        # If extraction failed or truth missing, score 0
        if not pred_model or not pred_reason or not true_model or not true_reason:
            scores.append(0.0)
            continue

        # Similarity per field
        sm = _similar(pred_model, true_model)
        sr = _similar(pred_reason, true_reason)

        # Convert similarities to points (max 1.5 per field)
        def field_points(sim: float) -> float:
            if sim >= 0.999:     # effectively exact after normalization
                return 1.5
            elif sim >= 0.90:
                return 1.0
            elif sim >= 0.80:
                return 0.5
            elif sim >= 0.65:
                return 0.25
            else:
                return -0.25

        total = field_points(sm) + field_points(sr)

        # Clamp to [0, 3]
        total = max(0.0, min(3.0, total))
        scores.append(total)

    return scores
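To see how the pair-based `check_answer` above splits and scores a completion, here is a self-contained sketch that reuses an equivalent `Model - Reasoning` line regex (the dash characters are written as `\u` escapes) and the same similarity-to-points buckets, applied to one made-up completion and reference label:

```python
import re
from difflib import SequenceMatcher

# Equivalent to _PAIR_LINE above: optional curly/straight quotes, then
# "<model> <sep> <reason>" where sep is a hyphen, en dash, em dash, colon or pipe
PAIR_LINE = re.compile(
    r'^\s*[\u201c"]?(?P<model>[^-\u2013\u2014:|]+?)\s*[-\u2013\u2014:|]\s*'
    r'(?P<reason>.+?)\s*[\u201d"]?\s*$'
)

def extract_pair(text: str):
    """Return (model, reason) from the first matching line, else (None, None)."""
    for line in text.splitlines():
        m = PAIR_LINE.match(line)
        if m:
            return m.group("model").strip(), m.group("reason").strip()
    return None, None

def _norm(s: str) -> str:
    """Collapse whitespace and casefold, like _normalize above."""
    return " ".join(s.split()).casefold()

def similarity(a: str, b: str) -> float:
    return SequenceMatcher(None, _norm(a), _norm(b)).ratio()

def field_points(sim: float) -> float:
    """Same buckets as check_answer: up to 1.5 points per field."""
    if sim >= 0.999:
        return 1.5
    if sim >= 0.90:
        return 1.0
    if sim >= 0.80:
        return 0.5
    if sim >= 0.65:
        return 0.25
    return -0.25

# Hypothetical completion and reference label
pred = extract_pair("Isolation Forest - Rare bursts of failed logins are outliers.")
gold = extract_pair("isolation forest - Rare bursts of failed logins are outliers")

score = field_points(similarity(pred[0], gold[0])) + field_points(similarity(pred[1], gold[1]))
score = max(0.0, min(3.0, score))  # clamp to [0, 3]
print(pred, gold, score)
```

The model name matches exactly after casefolding (1.5 points), while the trailing period drops the reasoning similarity just below 0.999 (1.0 point).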


In [None]:
import re
from difflib import SequenceMatcher

_PAIR_LINE = re.compile(r'^\s*[“"]?(?P<model>[^-–—:|]+?)\s*[-–—:|]\s*(?P<reason>.+?)\s*[”"]?\s*$')

def _normalize(s):
    return re.sub(r'\s+', ' ', (s or '')).strip().casefold()

def _sim(a, b):
    return SequenceMatcher(None, _normalize(a), _normalize(b)).ratio()

def _extract_pair(text):
    if not text:
        return None, None
    for line in str(text).splitlines():
        m = _PAIR_LINE.match(line)
        if m:
            return m.group('model').strip(), m.group('reason').strip()
    return None, None

def _coerce_truth(ans):
    if isinstance(ans, (list, tuple)) and len(ans) == 2:
        return str(ans[0]), str(ans[1])
    if isinstance(ans, dict):
        m = ans.get('model') or ans.get('ai_model') or ans.get('name')
        r = ans.get('reasoning') or ans.get('reason') or ans.get('rationale')
        return (str(m) if m is not None else None, str(r) if r is not None else None)
    if isinstance(ans, str):
        return _extract_pair(ans)
    return None, None

def _to_text_list(completions):
    out = []
    for c in completions:
        if isinstance(c, str):
            out.append(c)
        elif isinstance(c, list):
            # e.g., [[{"content": "..."}]] or a token list; join safely
            if c and isinstance(c[0], dict) and "content" in c[0]:
                out.append(c[0]["content"])
            else:
                out.append(" ".join([ (x.get("content","") if isinstance(x,dict) else str(x)) for x in c ]))
        elif isinstance(c, dict):
            out.append(c.get("content",""))
        else:
            out.append(str(c))
    return out

def check_answer(completions=None, **kwargs):
    """
    Reward function for TRL GRPOTrainer.
    Expects completions (plain strings or chat-style message lists) and a
    dataset column 'answer' in kwargs.
    Returns a list[float] with the same length as completions.
    """
    answers = kwargs.get("answer")
    if not answers:
        # No ground truth provided -> reward nothing
        return [0.0] * (len(completions) if completions else 0)

    texts = _to_text_list(completions or [])
    preds = [_extract_pair(t) for t in texts]

    # TRL usually repeats each row's columns per generated completion,
    # so len(answers) normally equals len(texts). If answers are per-prompt,
    # expand them to per-completion.
    if len(answers) != len(texts):
        if len(texts) % len(answers) == 0:
            # Completions are grouped per prompt (p1, p1, ..., p2, p2, ...),
            # so repeat each answer rather than tiling the whole list.
            k = len(texts) // len(answers)
            answers = [a for a in answers for _ in range(k)]
        else:
            # Fallback: tile and truncate to the number of completions
            answers = (answers * ((len(texts) + len(answers) - 1) // len(answers)))[:len(texts)]

    truths = [_coerce_truth(a) for a in answers]

    rewards = []
    for (pm, pr), (tm, tr) in zip(preds, truths):
        if not pm or not pr or not tm or not tr:
            rewards.append(0.0)
            continue
        sm, sr = _sim(pm, tm), _sim(pr, tr)
        # up to 1.5 per field → max 3.0
        def pts(s):
            return 1.5 if s >= 0.999 else 1.0 if s >= 0.90 else 0.5 if s >= 0.80 else 0.25 if s >= 0.65 else -0.25
        total = max(0.0, min(3.0, pts(sm) + pts(sr)))
        rewards.append(total)
    return rewards
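The length check above has to guess how per-prompt answers line up with completions. TRL's `GRPOTrainer` generates `num_generations` completions per prompt and, as far as we know, keeps each prompt's completions consecutive, so a per-prompt list should be expanded by repeating each element rather than by tiling the whole list. A small sketch contrasting the two (`repeat_each` and `tile` are hypothetical helper names):

```python
def repeat_each(per_prompt: list, num_generations: int) -> list:
    """[a, b] -> [a, a, b, b]: matches completions grouped per prompt."""
    return [x for x in per_prompt for _ in range(num_generations)]

def tile(per_prompt: list, num_generations: int) -> list:
    """[a, b] -> [a, b, a, b]: only correct if completions are interleaved."""
    return per_prompt * num_generations

example_answers = ["cnn - a", "bert - b"]
completion_source = ["p1", "p1", "p2", "p2"]  # which prompt produced each completion

print(list(zip(completion_source, repeat_each(example_answers, 2))))
print(list(zip(completion_source, tile(example_answers, 2))))
```

With grouped completions, tiling pairs the second completion of `p1` with the answer for `p2`, which silently corrupts the reward signal.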


In [None]:
import re, difflib
from fractions import Fraction

# ====== Tag defaults (override if you use different ones) ======
reasoning_start = "<reasoning_start>"
reasoning_end   = "</reasoning_end>"
solution_start  = "<solution_start>"
solution_end    = "</solution_end>"

# ====== Regexes (available if other code uses them) ======
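# Numeric extractor: first numeric-looking token (supports commas, decimals, exp, %, simple fractions).
# Fractions are tried first, and at least one digit is required so the pattern never matches an empty string.
match_numbers = re.compile(r'(?<!\w)([-+]?\d+/\d+%?|[-+]?(?:\d[\d,_]*(?:\.\d+)?|\.\d+)(?:[eE][-+]?\d+)?%?)(?!\w)')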

# If some legacy code still uses `match_format.search(resp).group(1)`, prefer using the helpers below.
# (We keep this here but the functions below do not rely on it.)
match_format = re.compile(r'(?m)^\s*####\s*(?P<ans>[^\n]+?)\s*$')


# ====== Helpers ======
def extract_response_text(completion):
    """Return the assistant's final text from various completion shapes."""
    # 1) Already a string
    if isinstance(completion, str):
        return completion

    # 2) List of chat messages
    if isinstance(completion, list) and completion:
        # prefer the last assistant message
        for msg in reversed(completion):
            if isinstance(msg, dict) and msg.get("role") in ("assistant", "assistant_final"):
                c = msg.get("content")
                if isinstance(c, str):
                    return c
        # fallback: last content if present
        last = completion[-1]
        if isinstance(last, dict) and isinstance(last.get("content"), str):
            return last["content"]

    # 3) Dict-like generations
    if isinstance(completion, dict):
        for k in ("content", "generated_text", "text", "response"):
            v = completion.get(k)
            if isinstance(v, str):
                return v

    # 4) Last resort
    return str(completion)


def _between_tags(text, start_tag, end_tag):
    if start_tag in text and end_tag in text:
        try:
            return text.split(start_tag, 1)[1].split(end_tag, 1)[0].strip()
        except Exception:
            return None
    return None


FINAL_LINE_RE = re.compile(r"(?m)^\s*####\s*(?P<ans>[^\n]+?)\s*$")
ONLY_NUMBER_RE = re.compile(r"^\s*[-+]?\d+(?:\.\d+)?\s*$")

def extract_solution_text(response: str) -> str:
    """Best-effort final answer extraction."""
    # 1) Preferred: tags
    tagged = _between_tags(response, solution_start, solution_end)
    if tagged:
        return tagged

    # 2) Markdown final line: "#### answer"
    m = FINAL_LINE_RE.search(response)
    if m:
        return m.group("ans").strip()

    # 3) Fallback: whole response
    return response.strip()


def _normalize_text(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r'\s+', ' ', s)
    s = re.sub(r'[^\w\s%]', '', s)  # drop punctuation except %
    return s.strip()


def _text_similarity(a: str, b: str) -> float:
    return difflib.SequenceMatcher(None, _normalize_text(a), _normalize_text(b)).ratio()


def _to_float(s: str):
    if s is None:
        return None
    t = s.strip()
    if not t:
        return None
    is_pct = t.endswith('%')
    if is_pct:
        t = t[:-1]

    # fractions like "1/2"
    if '/' in t and not re.search(r'[A-Za-z]', t):
        try:
            val = float(Fraction(t.replace(' ', '')))
            return val / 100.0 if is_pct else val
        except Exception:
            pass

    t = t.replace(',', '').replace('_', '').replace(' ', '')
    try:
        val = float(t)
        return val / 100.0 if is_pct else val
    except Exception:
        return None


def _is_text_answer(ans: str) -> bool:
    # At least one letter & not purely numeric
    return bool(ans) and bool(re.search(r"[A-Za-z]", ans)) and not ONLY_NUMBER_RE.match(ans)


# ====== Format scorers ======
def match_format_exactly(completions, **kwargs):
    """Score 3.0 if the output has exactly one pair of reasoning/solution tags and a non-empty text answer."""
    scores = []
    for c in completions:
        resp = extract_response_text(c)
        score = 0.0
        ok = (
            resp.count(reasoning_start) == 1 and
            resp.count(reasoning_end)   == 1 and
            resp.count(solution_start)  == 1 and
            resp.count(solution_end)    == 1
        )
        if ok:
            ans = extract_solution_text(resp)
            if _is_text_answer(ans):
                score = 3.0
        scores.append(score)
    return scores


def match_format_approximately(completions, **kwargs):
    """Heuristic: right number of tags + plausible final text."""
    scores = []
    for c in completions:
        resp = extract_response_text(c)
        score = 0.0
        score += 0.5 if resp.count(reasoning_start) == 1 else -0.5
        score += 0.5 if resp.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if resp.count(solution_start)  == 1 else -0.5
        score += 0.5 if resp.count(solution_end)    == 1 else -0.5

        ans = extract_solution_text(resp)
        if _is_text_answer(ans):
            score += 0.5
            n_words = len(ans.split())
            if 1 <= n_words <= 40:
                score += 0.5
            elif n_words > 120:
                score -= 0.5
        else:
            score -= 0.5

        scores.append(score)
    return scores


# ====== Answer checkers ======
def check_answer(prompts, completions, answer, **kwargs):
    """
    Text answer scorer.
    - If `true_answer` is a dict with keys like {"model": "...", "reason": "..."},
      we reward presence/overlap of both parts.
    - If it's a plain string, we do normalized + fuzzy matching.
    """
    responses = [extract_response_text(c) for c in completions]
    preds = [extract_solution_text(r) or "" for r in responses]

    scores = []
    for pred, true_answer in zip(preds, answer):
        score = 0.0

        # Dict form: e.g. {"model": "GPT-4", "reason": "..."}
        if isinstance(true_answer, dict):
            gold_model  = str(true_answer.get("model", "")).strip()
            gold_reason = str(true_answer.get("reason", "")).strip()

            p_norm = _normalize_text(pred)
            m_norm = _normalize_text(gold_model)
            r_norm = _normalize_text(gold_reason)

            # model presence
            if m_norm and (m_norm in p_norm or _text_similarity(p_norm, m_norm) >= 0.90):
                score += 1.5

            # reason similarity
            if r_norm:
                sim = _text_similarity(pred, gold_reason)
                if   sim >= 0.90: score += 1.5
                elif sim >= 0.80: score += 1.0
                elif sim >= 0.60: score += 0.5

            # small penalty if answer is suspiciously short
            if len(pred.split()) < 2:
                score -= 0.25

        # String form
        else:
            true_str = str(true_answer)

            if pred == true_str:
                score += 3.0
            elif pred.strip() == true_str.strip():
                score += 2.0
            elif _normalize_text(pred) == _normalize_text(true_str):
                score += 1.5
            else:
                sim = _text_similarity(pred, true_str)
                if   sim >= 0.90: score += 1.0
                elif sim >= 0.80: score += 0.5
                else:             score -= 0.5

        scores.append(score)

    return scores


def check_numbers(prompts, completions, answer, **kwargs):
    """
    Numeric scorer with robust parsing (commas, percents, simple fractions).
    Uses relative error buckets.
    """
    responses = [extract_response_text(c) for c in completions]
    extracted = [
        (m.group(1) if (m := match_numbers.search(r)) else None)
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted, answer):
        if guess is None:
            scores.append(0.0)
            continue

        g = _to_float(guess)
        t = _to_float(true_answer if isinstance(true_answer, str) else str(true_answer))
        if g is None or t is None:
            scores.append(0.0)
            continue

        if g == t:
            scores.append(1.5)
            continue

        denom = max(abs(t), 1e-12)
        rel_err = abs(g - t) / denom

        if   rel_err <= 0.01: scores.append(1.0)   # within 1%
        elif rel_err <= 0.05: scores.append(0.75)  # within 5%
        elif rel_err <= 0.10: scores.append(0.5)   # within 10%
        else:                 scores.append(0.0)

    return scores
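The relative-error buckets in `check_numbers` can be tried in isolation. Below is a minimal, self-contained sketch that reuses the same thresholds. `to_float_sketch` is a hypothetical stand-in for the notebook's `_to_float`, whose exact rules (for example, how it treats percents) are not shown here. It assumes commas are thousands separators, that `75%` means `0.75`, and that `a/b` is a simple fraction.

```python
from fractions import Fraction

def to_float_sketch(text):
    """Hypothetical parser: '1,234' -> 1234.0, '75%' -> 0.75, '3/4' -> 0.75, else None."""
    s = str(text).strip().replace(",", "")
    try:
        if s.endswith("%"):
            return float(s[:-1]) / 100.0
        if "/" in s:
            return float(Fraction(s))
        return float(s)
    except (ValueError, ZeroDivisionError):
        return None

def relative_error_score(guess, truth):
    """Same buckets as check_numbers: exact 1.5, <=1% 1.0, <=5% 0.75, <=10% 0.5, else 0."""
    g, t = to_float_sketch(guess), to_float_sketch(truth)
    if g is None or t is None:
        return 0.0
    if g == t:
        return 1.5
    rel_err = abs(g - t) / max(abs(t), 1e-12)  # floor avoids dividing by zero when truth is 0
    if rel_err <= 0.01:
        return 1.0
    if rel_err <= 0.05:
        return 0.75
    if rel_err <= 0.10:
        return 0.5
    return 0.0

for guess in ["1,000", "1005", "1040", "1090", "1500", "n/a"]:
    print(guess, relative_error_score(guess, "1000"))
```

Using relative rather than absolute error keeps the buckets meaningful across magnitudes: being off by 40 is "within 5%" for a target of 1000 but far off for a target of 50.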


In [None]:
import re, difflib, json

# ------- helpers -------
ALLOWED = {
    "cnn", "random_forest", "isolation_forest", "naivebayes",
    "bert", "autoencoder", "dqn", "other"
}

def _extract_response_text(completion):
    """Return assistant text from string / list-of-messages / dict shapes."""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list) and completion:
        for msg in reversed(completion):
            if isinstance(msg, dict) and msg.get("role") in ("assistant", "assistant_final"):
                c = msg.get("content")
                if isinstance(c, str):
                    return c
        last = completion[-1]
        if isinstance(last, dict) and isinstance(last.get("content"), str):
            return last["content"]
    if isinstance(completion, dict):
        for k in ("content", "generated_text", "text", "response"):
            v = completion.get(k)
            if isinstance(v, str):
                return v
    return str(completion)

def _canon_model(s: str) -> str:
    t = s.strip().lower()
    # normalize common spellings
    t = t.replace("-", " ").replace("_", " ")
    t = re.sub(r"\s+", " ", t)

    # map synonyms → canonical labels
    if "random" in t and "forest" in t:   return "random_forest"
    if "isolation" in t and "forest" in t:return "isolation_forest"
    if "naive" in t and "bayes" in t:     return "naivebayes"
    if t in ("cnn", "convnet", "convolutional neural network"): return "cnn"
    if t.startswith("bert"):              return "bert"
    if "autoencoder" in t or "auto encoder" in t: return "autoencoder"
    if t == "dqn" or "deep q" in t:       return "dqn"
    if t == "other":                      return "other"

    # plain canonicalization fallback
    t = re.sub(r"[^\w]", "", t)
    if t == "randomforest":    return "random_forest"
    if t == "isolationforest": return "isolation_forest"
    if t == "naivebayes":      return "naivebayes"
    if t in ALLOWED:           return t
    return "other"

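# Separator is a hyphen with whitespace on both sides, so hyphens inside model
# names (e.g. "GPT-4", "random-forest") are not mistaken for the separator.
_SPLIT = re.compile(r"^\s*(?P<model>[^\n]{1,80}?)\s+-\s+(?P<reason>.+?)\s*$", re.S)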

def _parse_pred(line: str):
    """Return (model_canon, reason_str, format_ok)."""
    m = _SPLIT.match(line.strip())
    if not m:
        return ("other", line.strip(), False)
    model_raw = m.group("model")
    reason    = m.group("reason").strip()
    return (_canon_model(model_raw), reason, True)

def _norm(s: str) -> str:
    s = s.strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def _sim(a: str, b: str) -> float:
    return difflib.SequenceMatcher(None, _norm(a), _norm(b)).ratio()

# ------- TEXT-ONLY REWARD -------
def check_answer(prompts, completions, answer=None, label=None, rationale=None, **kwargs):
    """
    Scores purely text outputs of the form '<MODEL> - <REASON>'.
    Prefers row keys:
      - label: canonical gold model (e.g., 'random_forest')
      - rationale: gold one-line reason
      - answer: full string '<MODEL> - <REASON>' (fallback if label/rationale are missing)
    """
    # TRL's GRPOTrainer forwards every extra dataset column to reward functions as a
    # keyword argument, so 'answer', 'label' and 'rationale' arrive here if the dataset has them.
    # Any missing field becomes a list of None so zip() still works.
    n = len(completions)
    gold_answers    = answer    if answer    is not None else [None] * n
    gold_labels     = label     if label     is not None else [None] * n
    gold_rationales = rationale if rationale is not None else [None] * n

    preds = [_extract_response_text(c) for c in completions]

    scores = []
    for pred, gold_ans, gold_lab, gold_rat in zip(preds, gold_answers, gold_labels, gold_rationales):
        score = 0.0

        # Whole-string match against the gold answer. The component checks below
        # still apply, so an exact match always outranks a merely normalized one.
        if isinstance(gold_ans, str) and gold_ans:
            if pred == gold_ans:
                score += 3.0
            elif _norm(pred) == _norm(gold_ans):
                score += 2.0

        # Parse prediction into model + reason (format '<model> - <reason>')
        pred_model, pred_reason, fmt_ok = _parse_pred(pred)
        score += 0.5 if fmt_ok else -0.25  # format bonus / small penalty

        # Gold model/reason: parse them from 'answer' first, then let explicit
        # 'label' / 'rationale' columns override.
        gold_model, gold_reason = None, None
        if isinstance(gold_ans, str) and gold_ans:
            m = _SPLIT.match(gold_ans)
            if m:
                gold_model = m.group("model")
                gold_reason = m.group("reason").strip()
        if isinstance(gold_lab, str) and gold_lab:
            gold_model = gold_lab
        if isinstance(gold_rat, str) and gold_rat:
            gold_reason = gold_rat

        # Model correctness (no credit for a wrong model; no penalty if gold is unknown)
        if gold_model and pred_model == _canon_model(gold_model):
            score += 1.5

        # Reasoning similarity (if we have a target)
        if gold_reason:
            sim = _sim(pred_reason, gold_reason)
            if   sim >= 0.92: score += 1.5
            elif sim >= 0.85: score += 1.0
            elif sim >= 0.70: score += 0.5
        elif len(pred_reason.split()) >= 4:
            # No gold rationale: lightly reward a non-trivial reason
            score += 0.25

        # Guard against trivially short outputs
        if len(pred.strip()) < 6:
            score -= 0.25

        scores.append(score)

    return scores
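The `<MODEL> - <REASON>` format is only parsed reliably if the separator cannot be confused with hyphens inside model names. A pattern that forbids hyphens in the model part, such as `[^-\n]{1,80}`, splits `GPT-4 - Uses ...` into `GPT` and `4 - Uses ...`. A minimal sketch of a whitespace-delimited separator follows, assuming model names never contain ` - ` themselves. `SEP` and `split_pred_sketch` are hypothetical names used only for illustration.

```python
import re

# Hypothetical separator pattern: a hyphen with whitespace on BOTH sides, so hyphens
# inside model names ("GPT-4", "random-forest") are not treated as the separator.
SEP = re.compile(r"^\s*(?P<model>[^\n]+?)\s+-\s+(?P<reason>.+?)\s*$", re.S)

def split_pred_sketch(line):
    """Return (model, reason) for '<MODEL> - <REASON>', or None if the format is wrong."""
    m = SEP.match(line)
    if m is None:
        return None
    return m.group("model"), m.group("reason")

print(split_pred_sketch("GPT-4 - Uses chain-of-thought to explain steps"))
print(split_pred_sketch("random-forest - Captures nonlinear feature interactions"))
print(split_pred_sketch("random_forest"))  # no ' - ' separator, so None
```

The lazy `+?` on the model group makes the match stop at the first whitespace-surrounded hyphen, so any later ` - ` stays inside the reason.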


In [None]:
# check_answer() scores '<MODEL> - <REASON>' strings, so the gold answers use that format too.
fake_completions = [
    "GPT-4 - Uses chain-of-thought to explain steps",
    "BERT - Bidirectional masked LM pretraining"
]
fake_answers = [
    "GPT-4 - uses chain-of-thought to explain steps",
    "BERT - bidirectional masked LM pretraining",
]
# `prompts` is a required positional parameter, but check_answer() does not use it
print(check_answer(prompts=None, completions=fake_completions, answer=fake_answers))
# Expect clearly positive scores for these near-matches (not all 0)


<a name="Train"></a>
### Train the model

Now set up the GRPO trainer and all its configurations!

In [None]:
max_prompt_length = 256

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 50,
    save_steps = 50,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

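And let's run the trainer! While it runs, the cell output shows a table of rewards. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps before anything happens, and the reward may barely move for the first 100 steps. Please be patient! Note that `max_steps = 50` above is only a short demo run; increase it (or set `num_train_epochs`) for a real training run.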

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
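GRPO learns from the differences between the `num_generations` completions sampled for the same prompt. Below is a minimal sketch of the group-relative advantage, assuming the standard formulation (r_i - mean) / (std + eps) within one group. TRL's exact choices (sample vs. population standard deviation, the value of eps, optional reward scaling) vary by version, and `group_advantages` is a hypothetical helper, not TRL's API.

```python
import statistics

def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages for the completions sampled from ONE prompt.

    Each reward is centred on the group mean and divided by the group's standard
    deviation (sample std here, an assumption); eps guards against std == 0.
    """
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # needs at least 2 rewards, i.e. num_generations >= 2
    return [(r - mean) / (std + eps) for r in rewards]

# All 4 generations scored the same, so every advantage is 0.
print(group_advantages([0.125, 0.125, 0.125, 0.125]))
# Mixed rewards: above-average completions get positive advantages, the rest negative.
print(group_advantages([1.0, 0.0, 0.0, 1.0]))
```

This is why the `reward_std` column matters: a value of 0, as in step 1 of the table, means every completion in each group scored the same, so all advantages are zero and those groups provide no learning signal.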


In [None]:
# GRPOTrainer only needs a "prompt" column: completions are sampled from the model,
# so no "completion" column is required. Any other column (e.g. "answer") is passed
# to the reward functions as a keyword argument, so keep the gold answer column.
if "label" in dataset.column_names:
    # With "label" dropped, check_answer() takes the gold model name from the
    # '<MODEL> - <REASON>' string in "answer" instead.
    dataset = dataset.remove_columns(["label"])

print(dataset.column_names)  # expect 'prompt' plus the gold columns, e.g. 'answer'
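Because `GRPOTrainer` passes dataset columns other than the prompt to each reward function as keyword arguments, a column dropped here never reaches `check_answer`, and a column that is kept must be accepted by the reward function's signature (or absorbed by `**kwargs`). Below is a minimal sketch of that forwarding, assuming a batch is a list of row dicts. `build_reward_kwargs`, `call_reward` and `toy_reward` are hypothetical helpers, not TRL's actual internals.

```python
def build_reward_kwargs(batch, skip=("prompt", "completion", "completion_ids")):
    """Hypothetical re-creation: turn a list of row dicts into {column: [values...]}."""
    keys = [k for k in batch[0] if k not in skip]
    return {k: [row[k] for row in batch] for k in keys}

def call_reward(reward_func, batch, completions):
    """Call a reward function the way a GRPO-style trainer would (sketch)."""
    prompts = [row["prompt"] for row in batch]
    return reward_func(prompts=prompts, completions=completions, **build_reward_kwargs(batch))

# Toy reward with the same signature style as check_answer
def toy_reward(prompts, completions, answer, **kwargs):
    return [1.0 if c == a else 0.0 for c, a in zip(completions, answer)]

batch = [
    {"prompt": "log 1", "answer": "random_forest - tabular features"},
    {"prompt": "log 2", "answer": "cnn - image-like input"},
]
print(build_reward_kwargs(batch))
print(call_reward(toy_reward, batch, ["random_forest - tabular features", "bert - text"]))
```

Each extra column arrives as a list aligned with `completions`, which is why reward functions in this notebook `zip` over their gold fields.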


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        check_answer,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

In [None]:
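from google.colab import files
# The merged weights in this folder are written by save_pretrained_merged() in the "Saving to float16" cell below; run that cell first.
files.download("/content/gemma-3-finetune/model.safetensors")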


<a name="Inference"></a>
### Inference
Now let's try the model we just trained!

In [None]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": """
  "event_id": "3b00489f-5c83-4b94-a0cc-54e19c790edc",
  "timestamp": "2025-01-18T20:07:20",
  "severity": "info",
  "raw_log": "CEF:0|Carbon Black v7.8.0|SIEM|1.0|100|firewall|info| desc=Firewall drop SSH traffic from 100.11.91.183:51767 to 35.35.93.73:580 No additional info",
  "user": null,
  "action": "drop",
  "description": "Firewall drop SSH traffic from 100.11.91.183:51767 to 35.35.93.73:580 No additional info",
  "src_ip": "100.11.91.183",
  "dst_ip": "35.35.93.73",
  "category": null
"""},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
model.save_pretrained("gemma-3")  # Local saving
tokenizer.save_pretrained("gemma-3")
# model.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving
# tokenizer.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving

### Saving to float16 for vLLM

We also support saving to `float16` directly for deployment! We save it in the folder `gemma-3-finetune`. The cell below is set to `if True`, so it runs by default; change it to `if False` to skip this step.

In [None]:
if True: # Set to False to skip saving the merged 16-bit model
    model.save_pretrained_merged("gemma-3-finetune", tokenizer)

If you want to upload / push to your Hugging Face account, set `if False` to `if True` and add your Hugging Face token and upload location!

In [None]:
if False: # Change to True to upload finetune
    model.push_to_hub_merged(
        "HF_ACCOUNT/gemma-3-finetune", tokenizer,
        token = "hf_..."
    )

### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we now support it natively for all models! For now, you can easily convert to `Q8_0`, `F16` or `BF16` precision; `Q4_K_M` (4-bit) will come later!

In [None]:
if True: # Set to False to skip the GGUF conversion
    model.save_pretrained_gguf(
        "gemma-3-finetune",
        tokenizer,
        quantization_method = "Q8_0", # For now only Q8_0, BF16, F16 supported
    )

Likewise, if you want to push the GGUF to your Hugging Face account instead, set `if False` to `if True` and add your Hugging Face token and upload location!

In [None]:
if False: # Change to True to upload GGUF
    model.push_to_hub_gguf(
        "gemma-3-finetune",
        tokenizer,
        quantization_method = "Q8_0", # Only Q8_0, BF16, F16 supported
        repo_id = "HF_ACCOUNT/gemma-finetune-gguf",
        token = "hf_...",
    )

Now use the `.gguf` file produced by the conversion above (here `Q8_0`) in llama.cpp.

And we're done! If you have any questions about Unsloth, find any bugs, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our [Discord](https://discord.gg/unsloth) channel!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, continued pretraining, conversational finetuning and more in our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
</div>

  This notebook and all Unsloth notebooks are licensed [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).
