In [None]:
%%capture
# Install Unsloth and vLLM; %%capture hides the pip output for this cell.
import os
# Colab is detected by searching for the substring "COLAB_" in the
# concatenation of all environment-variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # everywhere else. --no-deps skips dependency installation and vLLM is
    # pinned to 0.8.5.post1.
    # NOTE(review): the next cell starts with this same install and then adds
    # the remaining dependencies, so this cell looks redundant -- confirm.
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
%%capture
# Full environment setup. Outside Colab a plain pip install is used; inside
# Colab the packages are installed without dependencies and the remaining
# requirements are added by hand. Statement order matters in the Colab branch.
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    # Every already-imported module whose name contains "PIL" or "google" is
    # removed from sys.modules before the installs below run.
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # The requirements file is downloaded and every occurrence of
    # "transformers", "numpy" or "xformers" is deleted through the end of its
    # line before pip sees it.
    # NOTE(review): this fetches requirements from the `main` branch while vLLM
    # itself is pinned to 0.8.5.post1 above -- the two may drift apart; confirm.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
from unsloth import FastLanguageModel
import torch
# Shared hyperparameters: both values are reused by the loader and the adapter.
max_seq_length = 1024  # Can increase for longer reasoning traces
lora_rank = 32  # Larger rank = smarter, but slower

# Arguments for loading the base checkpoint.
load_kwargs = dict(
    model_name="unsloth/Qwen3-4B-Base",
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.7,  # Reduce if out of memory
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)

# Arguments for wrapping the loaded model with LoRA adapters.
peft_kwargs = dict(
    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=2 * lora_rank,  # *2 speeds up training
    use_gradient_checkpointing="unsloth",  # Reduces memory usage
    random_state=3407,
)
model = FastLanguageModel.get_peft_model(model, **peft_kwargs)

In [None]:
#4. Chat Template
# Marker strings that delimit the reasoning trace and the final answer.
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

# Instruction text describing the expected output layout.
system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

# Jinja fragments joined in order: system and assistant turns are followed by
# the EOS token, user turns are emitted as-is, and a generation prompt (when
# requested) opens the reasoning section.
template_parts = [
    "{% for message in messages %}",
    "{% if message['role'] == 'system' %}{{ message['content'] + eos_token }}",
    "{% elif message['role'] == 'user' %}{{ message['content'] }}",
    "{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}",
    "{% endif %}{% endfor %}{% if add_generation_prompt %}",
    reasoning_start,
    "{% endif %}",
]
tokenizer.chat_template = "".join(template_parts)

In [None]:
from trl import SFTTrainer, GRPOTrainer, GRPOConfig
from datasets import load_dataset, Dataset
from transformers import TrainingArguments
import random

In [None]:
#5. Load Datasets
# All four corpora are read from the same split.
split = "train"

# Math word problems (uses the "main" configuration).
gsm8k = load_dataset(path="openai/gsm8k", name="main", split=split)
# Multiple-choice commonsense questions.
commonsenseqa = load_dataset(path="tau/commonsense_qa", split=split)
# Science questions with distractors.
sciq = load_dataset(path="allenai/sciq", split=split)
# Logical-reasoning questions over a context passage.
logiqa = load_dataset(path="lucasmccabe/logiqa", split=split)

In [None]:
#6. Format Examples
def format_example(example, dataset_name):
    """Convert one raw dataset row into the shared problem/solution schema.

    Args:
        example: Mapping holding a single row of the source dataset.
        dataset_name: One of "gsm8k", "commonsenseqa", "sciq" or "logiqa".

    Returns:
        A dict with the keys "problem", "generated_solution" (reasoning and
        answer wrapped in the module-level tag strings) and "expected_answer",
        or None when the row is missing fields or is otherwise malformed.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    supported = ("gsm8k", "commonsenseqa", "sciq", "logiqa")
    if dataset_name not in supported:
        # A wrong name is a programming error, not a bad row: fail loudly
        # instead of silently producing an empty dataset.
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {supported}"
        )

    try:
        if dataset_name == "gsm8k":
            problem = example["question"]
            parts = example["answer"].split("####")
            if len(parts) < 2:
                # Without the delimiter there is no final answer to compare to.
                return None
            reasoning = parts[0].strip()
            answer = parts[-1].strip()
        elif dataset_name == "commonsenseqa":
            choices = example["choices"]
            problem = example["question"] + "\n" + "\n".join(
                f"{label}. {text}"
                for label, text in zip(choices["label"], choices["text"])
            )
            reasoning = "Let's use common sense to find the best answer."
            answer = example["answerKey"]
        elif dataset_name == "sciq":
            question = example["question"]
            options = list(example["distractors"]) + [example["correct_answer"]]
            # Shuffle so the correct answer is not always the last option.
            # Seeding from the question text keeps the order reproducible.
            random.Random(question).shuffle(options)
            problem = question + "\n" + "\n".join(
                f"{i + 1}. {opt}" for i, opt in enumerate(options)
            )
            reasoning = "Let's apply science to reason this out."
            answer = example["correct_answer"]
        else:  # logiqa
            options = example["options"]
            problem = example["context"] + "\n" + example["query"] + "\n" + "\n".join(
                f"{i + 1}. {opt}" for i, opt in enumerate(options)
            )
            reasoning = "Let's logically analyze the given context and query."
            answer = options[example["correct_option"]]
    except (KeyError, IndexError, TypeError, AttributeError):
        # Best effort: rows with missing or wrongly-typed fields are skipped by
        # the caller. Other errors (e.g. undefined tag globals) propagate.
        return None

    return {
        "problem": problem,
        "generated_solution": f"{reasoning_start}{reasoning}{reasoning_end}{solution_start}{answer}{solution_end}",
        "expected_answer": answer,
    }

In [None]:
#7. Dataset Sampling (30-25-20-15 Split)
total = 10000
# NOTE(review): these ratios sum to 0.90, so 9,000 rows are requested rather
# than `total` -- confirm whether the remaining 10% is intentional.
ratios = {"gsm8k": 0.30, "commonsenseqa": 0.25, "sciq": 0.20, "logiqa": 0.15}
sizes = {k: int(r * total) for k, r in ratios.items()}

# Name -> loaded dataset, in the order the samples are concatenated.
sources = {
    "gsm8k": gsm8k,
    "commonsenseqa": commonsenseqa,
    "sciq": sciq,
    "logiqa": logiqa,
}

formatted = []
for source_name, source in sources.items():
    # Clamp to the rows actually available so select() cannot index past the end.
    take = min(sizes[source_name], len(source))
    sampled = source.shuffle(seed=3407).select(range(take))
    formatted.extend(format_example(row, source_name) for row in sampled)

# format_example returns None for rows it could not convert.
formatted = [f for f in formatted if f is not None]
if not formatted:
    raise ValueError("No examples could be formatted; check the source datasets.")
dataset = Dataset.from_list(formatted)

In [None]:
#8. Add Prompt Field for GRPO
def _build_prompt(problem):
    """Render the system prompt and the problem through the chat template.

    The result is a plain string ending in the generation prompt, so the
    trainer's completions stay plain strings as well.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": problem},
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

# Previously the raw problem text was used as the prompt, so the system prompt
# and chat template defined earlier were never applied and the model was never
# told about the tags the reward function looks for.
# map() keeps the existing columns, so only the new "prompt" column is returned.
dataset = dataset.map(lambda ex: {"prompt": _build_prompt(ex["problem"])})

In [None]:
#9. Reward Function (Unified for All)
def reward_func(prompts, completions, completion_ids=None, **kwargs):
    """Score each completion 1.0 for an exact (case-insensitive) answer match.

    Args:
        prompts: Prompts the completions were generated from (unused).
        completions: Generated completions, either plain strings or lists of
            message dicts whose last entry carries the generated "content".
        completion_ids: Token ids of the completions (unused). Optional so the
            function also works when the caller does not supply them.
        **kwargs: Must contain "expected_answer", one reference per completion.

    Returns:
        A float32 tensor with one reward (0.0 or 1.0) per completion.

    Raises:
        KeyError: If "expected_answer" is not passed in ``kwargs``.
    """
    expected = kwargs["expected_answer"]
    rewards = []
    for comp, exp in zip(completions, expected):
        # Conversational format: use the text of the last message.
        if isinstance(comp, (list, tuple)):
            last = comp[-1] if comp else None
            comp = last.get("content", "") if isinstance(last, dict) else ""
        if not isinstance(comp, str) or exp is None:
            rewards.append(0.0)
            continue
        # Use the text after the last opening tag; a completion without the
        # tag gets no reward even if it happens to equal the answer.
        _, tag, tail = comp.rpartition(solution_start)
        if not tag:
            rewards.append(0.0)
            continue
        pred = tail.split(solution_end)[0].strip()
        target = str(exp).strip()
        rewards.append(1.0 if pred.lower() == target.lower() else 0.0)
    return torch.tensor(rewards, dtype=torch.float32)

In [None]:
# Pick mixed precision from the hardware instead of hard-coding fp16.
# NOTE(review): forcing fp16 on a bf16-capable GPU is expected to clash with
# the dtype Unsloth loads the model in -- confirm on the target GPU.
from unsloth import is_bfloat16_supported

use_bf16 = is_bfloat16_supported()

# NOTE(review): per_device_train_batch_size=1 with num_generations=4 relies on
# the trainer accepting or adjusting a batch smaller than the group size --
# confirm against the installed TRL/Unsloth versions.
grpo_config = GRPOConfig(
    beta=0.1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    max_steps=300,
    lr_scheduler_type="linear",
    learning_rate=2e-5,
    # num_train_epochs = 1,
    num_generations=4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=10,
    save_steps=100,
    output_dir="grpo_outputs",
    optim="adamw_8bit",
    seed=3407,
)

In [None]:
# Build the GRPO trainer. The dataset's "prompt" column is what the trainer
# generates from; other columns such as "expected_answer" reach the reward
# function as keyword arguments.
# `dataset_text_field` and `expected_answer_field` are not GRPOTrainer
# parameters and were removed; the tokenizer is passed as `processing_class`.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_func],
    args=grpo_config,
    train_dataset=dataset,
)

trainer.train()