In [None]:
# Remove any already-installed unsloth / unsloth_zoo packages, then install both
# straight from their GitHub repositories.
# NOTE: no version or commit is pinned, so each run installs whatever the
# repositories' default branches contain at that moment.
!pip uninstall unsloth unsloth_zoo -y
!pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"

In [None]:
# Quietly (-qq) install datasets, peft and bitsandbytes; versions are not pinned.
!pip install datasets peft bitsandbytes -qq

In [None]:
from unsloth import FastLanguageModel, PatchFastRL
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

from pydantic import BaseModel
import random
import os
import numpy as np

# Register Unsloth's 'GRPO' patch against FastLanguageModel. This runs before the
# model is loaded and before the GRPOTrainer is constructed in the cells below.
PatchFastRL('GRPO', FastLanguageModel)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
class Config(BaseModel):
    """Settings object for this notebook: model, dataset, seed and sampling values."""
    model_name: str = 'unsloth/gemma-3-270m-it'  # checkpoint id given to FastLanguageModel.from_pretrained
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'  # evaluated once, when the class is defined
    dataset_name: str = 'openai/gsm8k'  # dataset id given to load_dataset
    subset: str = 'main' # dataset configuration name; the literal string 'none' means "load without one"
    seed: int = 3407  # passed to seed_everything and used as random_state for the model / LoRA setup
    max_input_length: int = 1024  # token budget; the model-loading and GRPOConfig cells derive their length limits from it
    batch_size: int = 2  # NOTE(review): not referenced by any cell shown here
    num_epochs: int = 10  # NOTE(review): not referenced by any cell shown here

    # GRPO-specific values
    epsilon: float = 0.2  # NOTE(review): not referenced by any cell shown here
    rollout_N: int = 2  # NOTE(review): not referenced by any cell shown here

    # Sampling parameters for text generation
    temperature: float = 1.0
    top_k: int = 64
    top_p: float = 0.95
    min_p: float = 0.0
    do_sample: bool = True

# Single shared instance read by the rest of the notebook.
config = Config()

In [None]:
# Utilities: system prompt, answer-parsing helpers and RNG seeding

# System message placed in front of every question by preproc(). It asks the model
# to wrap its output in <reasoning>...</reasoning> and <answer>...</answer> tags,
# which are the tags extract_answer() and get_format_reward() look for.
SYSTEM_PROMPT = """
  Respond in the following format:
  <reasoning>
  ...
  </reasoning>
  <answer>
  ...
  </answer>
"""

def get_clear_target(text: str) -> str:
    """Return the final answer portion of a reference solution.

    Takes whatever follows the last '####' marker (or the whole string when
    the marker is absent), removes every comma and trims surrounding
    whitespace.
    """
    _, _, tail = text.rpartition('####')
    without_commas = tail.replace(',', '')
    return without_commas.strip()

def get_reasoning_target(text: str) -> str:
    """Return the part of a reference solution that precedes the first '####'.

    When the marker is absent the whole string is returned. Surrounding
    whitespace is trimmed in both cases.
    """
    head, _, _ = text.partition('####')
    return head.strip()

def extract_answer(text: str) -> str:
    """Pull the content of the last <answer>...</answer> section out of `text`.

    The result is the text after the last '<answer>' up to the next
    '</answer>', stripped of surrounding whitespace. Missing tags are
    tolerated: without an opening tag the search starts at the beginning of
    `text`, and without a closing tag it runs to the end.
    """
    _, _, after_open = text.rpartition('<answer>')
    inner, _, _ = after_open.partition('</answer>')
    return inner.strip()

def seed_everything(seed: int):
    """Seed Python's, NumPy's and PyTorch's random number generators with `seed`.

    Also exports the seed as PYTHONHASHSEED and switches cuDNN to its
    deterministic, non-benchmarking mode.
    """
    # Same seed for every generator: stdlib, NumPy, torch CPU, current CUDA device, all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# Seed all generators once, before the model and dataset cells run.
seed_everything(config.seed)

In [None]:
# Load the 4-bit base model and its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config.model_name,
    max_seq_length=config.max_input_length*2,  # context length: twice the prompt budget
    load_in_4bit=True,
    fast_inference=False, # 'Gemma3ForCausalLM' object has no attribute 'vllm_engine'
    random_state=config.seed,
    # Let Unsloth pick the dtype for the current GPU. An explicit torch.float16 is
    # rejected for Gemma 3 ("Using float16 precision for gemma3 won't work! Using
    # float32."), so hard-coding it only produced a warning and an override.
    dtype=None,
)

==((====))==  Unsloth 2025.8.9: Fast Gemma3 patching. Transformers: 4.55.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.


model.safetensors:   0%|          | 0.00/393M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/233 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

In [None]:
# Wrap the base model with LoRA adapters; only the adapter weights are trained
# (see the trainable-parameter count printed by the next cell).
model = FastLanguageModel.get_peft_model(
    model,
    finetune_vision_layers=False,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    r=16,  # adapter rank
    lora_alpha=32,  # scaling factor, set to 2 * r here
    lora_dropout=0,
    bias='none',
    random_state=config.seed  # same seed as the rest of the notebook
)

Unsloth: Making `model.base_model.model.model` require gradients


In [None]:
# Sanity check: report how many parameters are trainable after adding the adapters.
model.print_trainable_parameters()

trainable params: 3,796,992 || all params: 271,895,168 || trainable%: 1.3965


In [None]:
# Load the training split. The dataset configuration name is passed only when
# config.subset is something other than the 'none' sentinel.
load_args = [config.dataset_name]
if config.subset != 'none':
    load_args.append(config.subset)
dataset = load_dataset(*load_args, split='train')

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [None]:
# Display the dataset summary (column names and row count).
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [None]:
def preproc(example):
    """Turn one raw dataset row into the record used for training.

    The question becomes the user turn that follows the shared system prompt,
    and the reference solution is reduced to its final answer string.

    Args:
        example: Mapping with 'question' and 'answer' entries.

    Returns:
        A dict with 'prompt' (list of chat messages) and 'answer' (str).
    """
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_turn = {'role': 'user', 'content': example['question']}
    final_answer = get_clear_target(example['answer'])
    return dict(prompt=[system_turn, user_turn], answer=final_answer)

# Apply preproc to every row: adds the 'prompt' column and replaces 'answer'
# with the final-answer string.
dataset = dataset.map(preproc)

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

In [None]:
def get_answer_reward(prompts, completions, answer, **kwargs):
    """Score each completion by whether its extracted answer equals the reference.

    Args:
        prompts: Prompts that produced the completions. Unused; kept because it
            is part of the signature this function is called with.
        completions: One entry per sample; the generated text is read from
            ``completion[0]['content']``.
        answer: Reference answers aligned with ``completions``. These are the
            strings produced by ``get_clear_target`` (ints are accepted too).
        **kwargs: Any additional keyword arguments; ignored.

    Returns:
        A list with one float per completion:
        1.0 if the extracted integer equals the reference,
        0.0 if an integer was extracted but differs from the reference,
        -0.5 if no integer could be parsed from the completion.
    """
    responses = [completion[0]['content'] for completion in completions]
    scores = []
    for policy_answer, target in zip(responses, answer):
        try:
            predicted = int(extract_answer(policy_answer))
        except (ValueError, TypeError, AttributeError):
            # The completion did not contain a parseable integer answer.
            scores.append(-0.5)
            continue

        # The reference is a string, so it has to be converted before it is
        # compared: `int == str` is always False, which would make a reward of
        # 1.0 unreachable.
        try:
            expected = int(str(target).replace(',', '').strip())
        except ValueError:
            expected = None  # a non-integer reference can never match an integer answer

        scores.append(1.0 if predicted == expected else 0.0)

    return scores

def get_format_reward(completions, **kwargs):
    """Reward completions for containing the requested tag pairs.

    Each completion earns 0.5 when both '<reasoning>' and '</reasoning>'
    occur somewhere in its text, plus another 0.5 when both '<answer>' and
    '</answer>' occur. Only presence is checked, not order or nesting.

    Args:
        completions: One entry per sample; the generated text is read from
            ``completion[0]['content']``.
        **kwargs: Any additional keyword arguments; ignored.

    Returns:
        A list with one float per completion: 0.0, 0.5 or 1.0.
    """
    tag_pairs = (
        ('<reasoning>', '</reasoning>'),
        ('<answer>', '</answer>'),
    )

    def score_text(text):
        # Count the tag pairs whose opening and closing tags both appear.
        complete_pairs = sum(
            1 for opening, closing in tag_pairs if opening in text and closing in text
        )
        return 0.5 * complete_pairs

    return [score_text(completion[0]['content']) for completion in completions]

In [None]:
# GRPO training hyperparameters.
training_args = GRPOConfig(
    # Optimizer and schedule
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.15,
    lr_scheduler_type='cosine',
    optim='adamw_8bit',
    max_grad_norm=0.1,
    # Batching: 4 x 4 = 16 completions per optimizer step, 4 completions per prompt
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_generations=4,
    # Length limits. The model is loaded with max_seq_length = 2 * max_input_length,
    # so prompt and completion each get max_input_length tokens; a completion
    # budget of 2 * max_input_length would let prompt + completion exceed the
    # context length the model was configured with.
    max_prompt_length=config.max_input_length,
    max_completion_length=config.max_input_length,
    # Sampling settings and seed come from Config, so the values declared there
    # are the ones actually used for rollouts.
    temperature=config.temperature,
    top_p=config.top_p,
    top_k=config.top_k,
    min_p=config.min_p,
    seed=config.seed,
    # Run length, logging and checkpoints
    max_steps=500,
    save_steps=100,
    logging_steps=1,
    report_to='none',
    output_dir='outputs',
    use_vllm=False # 'Gemma3ForCausalLM' object has no attribute 'vllm_engine'
)

In [None]:
# Build the GRPO trainer from the LoRA-wrapped model, the tokenizer, the two
# reward functions defined above and the preprocessed dataset
# ('prompt' / 'answer' columns).
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        get_answer_reward,  # answer correctness: 1.0 / 0.0 / -0.5
        get_format_reward   # tag presence: 0.0 / 0.5 / 1.0
    ],
    args=training_args,
    train_dataset=dataset
)

Unsloth: Switching to float32 training since model cannot work with float16


In [None]:
# Start training; runs for the max_steps configured in training_args and logs
# reward statistics every step (logging_steps=1).
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 3,796,992 of 271,895,168 (1.40% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 32768, 'top_p': 0.95}. If this is not desired, please set these values explicitly.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,kl,entropy,rewards / get_answer_reward / mean,rewards / get_answer_reward / std,rewards / get_format_reward / mean,rewards / get_format_reward / std
1,0.0,-0.5,0.0,474.3125,10.0,2048.0,0.1875,111.153847,10.0,805.0,0.0,0,-0.5,0.0,0.0,0.0
2,0.0,-0.5,0.0,385.4375,26.0,2048.0,0.125,147.928574,26.0,816.0,0.0,No Log,-0.5,0.0,0.0,0.0
3,0.0,-0.5,0.0,215.625,9.0,2048.0,0.0625,93.466675,9.0,688.0,0.0,No Log,-0.5,0.0,0.0,0.0
4,0.0,-0.5,0.0,66.6875,15.0,236.0,0.0,66.6875,15.0,236.0,0.0,No Log,-0.5,0.0,0.0,0.0
5,0.0,-0.5,0.0,539.6875,11.0,2048.0,0.25,36.916668,11.0,94.0,0.0,No Log,-0.5,0.0,0.0,0.0
6,0.0,-0.5,0.0,433.4375,8.0,2048.0,0.1875,60.846157,8.0,329.0,0.0,No Log,-0.5,0.0,0.0,0.0
7,0.0,-0.46875,0.0625,168.5,17.0,2048.0,0.0625,43.200001,17.0,103.0,0.0,No Log,-0.46875,0.125,0.0,0.0
8,0.0,-0.5,0.0,559.9375,13.0,2048.0,0.25,63.916668,13.0,221.0,0.0,No Log,-0.5,0.0,0.0,0.0
9,0.0,-0.5,0.0,34.125,12.0,165.0,0.0,34.125,12.0,165.0,0.0,No Log,-0.5,0.0,0.0,0.0
10,0.0,-0.5,0.0,286.5,11.0,2048.0,0.125,34.857143,11.0,78.0,0.0,No Log,-0.5,0.0,0.0,0.0
