In [1]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
from torch import __version__ as torch_version
from packaging.version import Version as V

xformers = "xformers==0.0.27" if V(torch_version) < V("2.4.0") else "xformers"
!pip install --no-deps {xformers} peft accelerate bitsandbytes datasets


In [2]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported

device = "cuda" if torch.cuda.is_available() else "cpu"
max_seq_length = 512
dtype = None
load_in_4bit = True


ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


In [3]:
model_name = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = model_name,
    max_seq_length = max_seq_length,
    dtype          = dtype,
    load_in_4bit   = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r          = 16,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 16,
    lora_dropout = 0.0,
    bias       = "none",
    use_gradient_checkpointing = "unsloth",
)

FastLanguageModel.for_training(model)
model.to(device)


==((====))==  Unsloth 2025.11.2: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.
Unsloth: Gemma3 does not support SDPA - switching to fast eager.


model.safetensors:   0%|          | 0.00/1.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/233 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

Unsloth: Making `model.base_model.model.model` require gradients


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForCausalLM(
      (model): Gemma3TextModel(
        (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
        (layers): ModuleList(
          (0-15): 16 x Gemma3DecoderLayer(
            (self_attn): Gemma3Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=1152, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1152, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
    

In [4]:
math_problems = [
    "What is 12 + 7?",
    "If you have 5 apples and you buy 9 more, how many apples do you have now?",
    "A box has 4 rows of 6 chocolates. How many chocolates are there in total?",
    "Tom had 25 marbles and gave 7 to his friend. How many does he have left?",
]

math_answers = [19, 14, 24, 18]

train_ds = Dataset.from_dict({
    "question": math_problems,
    "answer": math_answers,
})

train_ds


Dataset({
    features: ['question', 'answer'],
    num_rows: 4
})

In [5]:
import re

def format_prompt(question: str) -> str:
    return (
        "You are a careful reasoning assistant.\n"
        f"Question: {question}\n\n"
        "Think step by step and show your reasoning.\n"
        "At the very end, write exactly: Final answer: <number>\n"
    )

def extract_final_answer(text: str):
    m = re.search(r"Final answer:\s*([-+]?\d+)", text)
    if m:
        return int(m.group(1))
    return None

def compute_reward(text: str, correct: int) -> float:
    pred = extract_final_answer(text)
    if pred is None:
        return 0.0
    return 1.0 if pred == correct else 0.0


In [6]:
K = 3                    # candidates per question (group size)
num_epochs = 3
max_new_tokens = 64
lr = 5e-6

optimizer = AdamW(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    epoch_losses = []
    epoch_rewards = []

    for ex in train_ds:
        question = ex["question"]
        correct  = ex["answer"]

        prompt = format_prompt(question)
        prompts = [prompt] * K

        # ---- 1) Generate K candidate solutions (no grad) ----
        model.eval()
        enc = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            gen_out = model.generate(
                **enc,
                max_new_tokens = max_new_tokens,
                do_sample      = True,
                top_p          = 0.9,
                temperature    = 0.7,
                use_cache      = False,   # safer with quantized model
            )

        gen_texts = [
            tokenizer.decode(gen_out[i], skip_special_tokens=True)
            for i in range(K)
        ]

        rewards = torch.tensor(
            [compute_reward(t, correct) for t in gen_texts],
            dtype=torch.float32,
            device=device,
        )

        # ---- 2) Compute log-probs for each candidate (with grad) ----
        model.train()
        log_probs = []

        for i in range(K):
            full_text = prompts[i] + "\n" + gen_texts[i]
            enc2 = tokenizer(
                full_text,
                return_tensors="pt",
            ).to(device)

            # labels = input_ids to compute NLL over whole sequence
            outputs = model(**enc2, labels=enc2["input_ids"])
            seq_len = enc2["input_ids"].shape[1]
            # sum of log probs â‰ˆ -loss * seq_len
            sum_logp = -outputs.loss * seq_len
            log_probs.append(sum_logp)

        log_probs = torch.stack(log_probs)

        # ---- 3) GRPO-style group-relative advantage ----
        advantages = rewards - rewards.mean()
        if advantages.std() > 0:
            advantages = advantages / (advantages.std() + 1e-8)

        loss = -(advantages.detach() * log_probs).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
        epoch_rewards.extend(rewards.detach().cpu().tolist())

    print(
        f"Epoch {epoch+1}/{num_epochs} | "
        f"loss={sum(epoch_losses)/len(epoch_losses):.4f} | "
        f"avg_reward={sum(epoch_rewards)/len(epoch_rewards):.3f}"
    )


Epoch 1/3 | loss=-0.8857 | avg_reward=0.667
Epoch 2/3 | loss=1.6732 | avg_reward=0.500
Epoch 3/3 | loss=-0.4680 | avg_reward=0.500


In [7]:
FastLanguageModel.for_inference(model)
inference_dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
model = model.to(device=device, dtype=inference_dtype)

def reasoning_chat(question: str, max_new_tokens: int = 128):
    prompt = format_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens = max_new_tokens,
            do_sample      = True,
            top_p          = 0.9,
            temperature    = 0.7,
        )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

for q in math_problems:
    print("=" * 80)
    print("QUESTION:", q)
    reasoning_chat(q)


QUESTION: What is 12 + 7?
You are a careful reasoning assistant.
Question: What is 12 + 7?

Think step by step and show your reasoning.
At the very end, write exactly: Final answer: <number>
---
Final answer: 19
---

QUESTION: If you have 5 apples and you buy 9 more, how many apples do you have now?
You are a careful reasoning assistant.
Question: If you have 5 apples and you buy 9 more, how many apples do you have now?

Think step by step and show your reasoning.
At the very end, write exactly: Final answer: <number>
<number> apples
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
<number>
QUESTION: A box has 4 rows of 6 chocolates. How many chocolates are there in total?
You are a careful reasoning assistant.
Question: A box has 4 rows of 6 chocolates. How many choco

In [8]:
save_dir = "gemma-3-1b-grpo-reasoning-final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
print(f"Saved GRPO reasoning model to {save_dir}")


Saved GRPO reasoning model to gemma-3-1b-grpo-reasoning-final
