<a href="https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/LLama3_RL_GRPO_Reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Llama 3.1 Reinforcement Learning & GRPO with Reasoning

In [None]:
#@title Colab Install { display-mode: "form" }
%%capture
# Install Unsloth (unpinned) and vLLM (pinned to 0.8.5.post1).
# --no-deps tells pip not to install the packages' own dependencies.
!pip install --no-deps unsloth vllm==0.8.5.post1

# Core dependencies for LoRA, TRL, and bitsandbytes on Colab (also installed with --no-deps).
# Only xformers and trl are pinned here; the other packages are left unpinned.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo

# Common NLP libraries (transformers is pinned, datasets has a lower bound)
!pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub transformers==4.51.3

# Evaluation metrics: `evaluate` plus the rouge_score and bert_score packages
!pip install evaluate rouge_score bert_score

# vLLM extra requirements (skip numpy/transformers/xformers to avoid conflicts)
import requests, re
reqs = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
filtered = re.sub(rb"(transformers|numpy|xformers)[^\n]*\n", b"", reqs)
with open("vllm_requirements.txt","wb") as f:
    f.write(filtered)
!pip install -r vllm_requirements.txt

In [None]:
import re, gc, os, getpass
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from unsloth import FastLanguageModel
from datasets import load_dataset, Dataset
from transformers import TextStreamer
from trl import SFTTrainer, SFTConfig, GRPOConfig, GRPOTrainer
from vllm import SamplingParams
from evaluate import load as load_metric

In [None]:
# Prompt for the Hugging Face token without echoing it. Pasted tokens often carry stray
# whitespace or a trailing newline, so strip it before use.
hf_token = getpass.getpass('Enter your HF access token and press enter: ').strip()

if hf_token:
    # Set the environment variable for this process
    os.environ['HF_TOKEN'] = hf_token
    print("HF_TOKEN environment variable has been set.")
else:
    # Do not export an empty value and do not claim success
    print("No token entered; HF_TOKEN was left unchanged.")

In [None]:
model = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
hub_model = "AManzoni/llama-grpo-rl"

In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name            = model,
    max_seq_length        = 512,
    load_in_4bit          = True,
    fast_inference        = True,
    max_lora_rank         = 16,
)

In [None]:
# Wrap the loaded model with LoRA adapters (rank 16, alpha 32, no bias terms)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=111,
)

In [None]:
# Tags that delimit the reasoning section and the final answer in every completion
reasoning_start, reasoning_end = "<REASONING>", "</REASONING>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# System prompt: four instructions joined by newlines (no trailing newline)
system_prompt = "\n".join([
    "You are given a problem.",
    "Think over it and describe your step‐by‐step reasoning.",
    f"Enclose reasoning between {reasoning_start} and {reasoning_end}.",
    f"Finally, give your answer between {solution_start} and {solution_end}",
])

In [None]:
# Build the Jinja chat template and assign it to the tokenizer.
# '{system_prompt}' and '{reasoning_start}' are placeholders; they are replaced with the
# current Python values by the .replace() calls further down, before the assignment.

chat_template = (
    # If the very first message has the system role, emit its content followed by eos_token:
    "{% if messages[0]['role'] == 'system' %}"
      "{{ messages[0]['content'] + eos_token }}"
      "{% set rest = messages[1:] %}"
    "{% else %}"
      # Otherwise, emit the default system prompt (placeholder) followed by eos_token:
      "{{ '{system_prompt}' + eos_token }}"
      "{% set rest = messages %}"
    "{% endif %}"

    # Now loop over the remaining messages:
    "{% for m in rest %}"
      "{% if m['role'] == 'user' %}"
        "{{ m['content'] }}"
      "{% else %}"  # every non-user role (assistant turns): content followed by eos_token
        "{{ m['content'] + eos_token }}"
      "{% endif %}"
    "{% endfor %}"

    # When add_generation_prompt is set, end with the opening reasoning tag (placeholder):
    "{% if add_generation_prompt %}"
      "{{ '{reasoning_start}' }}"
    "{% endif %}"
)

# Substitute the placeholders. The values are pasted between the single quotes of a Jinja
# string literal, so the text is frozen at this point: reassigning `system_prompt` later
# does not change the template.
# NOTE(review): a value containing a single quote would presumably break the Jinja
# literal — confirm before changing the prompt wording.
chat_template = chat_template\
    .replace("'{system_prompt}'",   f"'{system_prompt}'")\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")

tokenizer.chat_template = chat_template

In [None]:
# Render a short user/assistant/user conversation to inspect what the template produces
example_messages = [
    {
        "role": "user",
        "content": "Which country has the highest population density?",
    },
    {
        "role": "assistant",
        "content": "".join([
            reasoning_start,
            "I know that country X is small in area but has a huge population, ",
            "so its people per square kilometer is extremely high.",
            reasoning_end,
            solution_start, "Monaco", solution_end,
        ]),
    },
    {
        "role": "user",
        "content": "Which planet is farthest from the Sun?",
    },
]

# Header, then the rendered string with the generation prompt appended
print("Rendered example:\n")
print(
    tokenizer.apply_chat_template(
        example_messages,
        add_generation_prompt=True,
        tokenize=False,
    )
)

In [None]:
# Load the train split of GSM8K ("main" configuration)
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset  # bare expression: shown as the cell output

In [None]:
# Show the raw schema and one untouched record before any formatting is applied
print("=== Raw GSM8K columns ===", dataset.column_names, sep="\n")
print("\n=== First raw example ===", dataset[0], sep="\n")

In [None]:
# Formatting + token-count helper, applied to each GSM8K row
def format_and_count_gsm8k(example):
    """Turn one GSM8K row into a templated conversation and measure its length.

    The ``answer`` field is split on ``####``: the text before it (newlines replaced
    by spaces) is used as the reasoning, the text after it as the final answer.

    Returns a dict with:
      * ``token_len`` - number of token ids of the templated conversation
      * ``text``      - the templated conversation as a string
      * ``Messages``  - the system / user / assistant message list
    """
    user_question = example["question"].strip()
    answer_parts = example["answer"].split("####")
    rationale = answer_parts[0].replace("\n", " ").strip()
    final_answer = answer_parts[1].strip()

    assistant_reply = (
        f"{reasoning_start}{rationale}{reasoning_end}"
        f"{solution_start}{final_answer}{solution_end}"
    )
    messages = [
        {"role": "system",    "content": system_prompt},
        {"role": "user",      "content": user_question},
        {"role": "assistant", "content": assistant_reply},
    ]

    encoded = tokenizer.apply_chat_template(messages, tokenize=True)
    # A dict result carries the ids under "input_ids"; otherwise the result is the id sequence
    token_ids = encoded["input_ids"] if isinstance(encoded, dict) else encoded
    n_tokens = len(token_ids)

    rendered = tokenizer.apply_chat_template(messages, tokenize=False)

    return {"token_len": n_tokens, "text": rendered, "Messages": messages}

In [None]:
# Apply formatting to the dataset.
# NOTE: this rebinds `dataset` and removes the original columns ("question"/"answer" are
# what format_and_count_gsm8k reads), so re-running this cell needs the dataset reloaded.
dataset = dataset.map(
    format_and_count_gsm8k,
    remove_columns=dataset.column_names,
)

In [None]:
# Sanity check: show the first three formatted "text" fields before any filtering
print("\n=== Few formatted text examples (first 3) ===")
for idx in (0, 1, 2):
    header = f"\n--- Example {idx} ---"
    print(header)
    print(dataset[idx]["text"])

In [None]:
# Token-length statistics, then a short-example subset for the SFT step
lengths = np.array(dataset["token_len"])
print("\nToken-length percentiles (50/90/99):", np.percentile(lengths, [50, 90, 99]))

threshold = 200          # keep conversations of at most this many tokens
SFT_SUBSET_SIZE = 100    # how many of the short conversations are used for SFT

sft_ds_filtered     = dataset.filter(lambda ex: ex["token_len"] <= threshold)
# Cap the request at what the filter actually kept: select() fails on out-of-range indices
sft_ds_filtered     = sft_ds_filtered.select(range(min(SFT_SUBSET_SIZE, len(sft_ds_filtered))))
# Longer conversations are kept separately (one is used later as a generation prompt)
sft_ds_filtered_out = dataset.filter(lambda ex: ex["token_len"] >  threshold)

print(f"\nRemaining for training (≤{threshold} tokens): {len(sft_ds_filtered)} / {len(dataset)}")

In [None]:
# Keep only the "text" column for SFT by removing the two helper columns
sft_dataset = sft_ds_filtered.remove_columns(
    ["token_len", "Messages"]
)
print("\n=== Final dataset ===", sft_dataset, sep="\n")

In [None]:
# Training arguments for the supervised fine-tuning step
sft_config = SFTConfig(
    # run control
    do_train=True,
    seed=111,
    num_train_epochs=2,
    # batching: 2 sequences per device, gradients accumulated over 2 steps
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    # optimisation schedule
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    weight_decay=0.01,
    # logging: every 5 steps, no external reporting integration
    logging_strategy="steps",
    logging_steps=5,
    report_to="none",
)

In [None]:
# Instantiate SFTTrainer.
# The tokenizer is passed as `processing_class`, the same keyword the GRPOTrainer call in
# this notebook uses; TRL deprecated the older `tokenizer=` keyword in its favour.
trainer = SFTTrainer(
    model            = model,
    args             = sft_config,
    train_dataset    = sft_dataset,
    processing_class = tokenizer,
)

In [None]:
# Run the supervised fine-tuning pass with the SFTTrainer built in the previous cell
trainer.train()

In [None]:
# Take an example that was filtered out of the SFT subset and keep only its
# first two messages (system + user) as the prompt
prompt_messages = sft_ds_filtered_out[0]["Messages"][0:2]

# Render to one string; add_generation_prompt=True makes the template end with <REASONING>
text = tokenizer.apply_chat_template(
    prompt_messages,
    add_generation_prompt=True,
    tokenize=False,
)

In [None]:
# Generate from the prompt and stream the text as it is produced
# (skip_prompt=False, so the prompt itself is echoed as well)
streamer = TextStreamer(tokenizer, skip_prompt=False)

_ = model.generate(
    **tokenizer(text, return_tensors="pt").to("cuda"),
    streamer=streamer,
    max_new_tokens=512,
    temperature=0.0,
)

In [None]:
# The GSM8K / SFT datasets are not used again; drop the names so the objects can be reclaimed
del dataset, sft_ds_filtered, sft_ds_filtered_out, sft_dataset

# Run a garbage-collection pass, then release PyTorch's cached GPU memory
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Load the train split of XSum (rows provide a "document" and its "summary")
dataset = load_dataset("EdinburghNLP/xsum", split="train")
dataset  # bare expression: shown as the cell output

In [None]:
# Collect the first TARGET_EXAMPLES rows whose document is at most DOC_CUTOFF tokens long
DOC_CUTOFF = 300
TARGET_EXAMPLES = 400

selected = []
for row in dataset:
    # Only the length matters here, so tokenize the source text without truncation
    n_doc_tokens = len(tokenizer(row["document"], truncation=False)["input_ids"])
    if n_doc_tokens > DOC_CUTOFF:
        continue

    # Keep just the source text and its gold summary
    selected.append({"document": row["document"], "summary": row["summary"]})
    if len(selected) >= TARGET_EXAMPLES:
        break

print(f"✔ Collected {len(selected)} examples with doc‐tokens ≤ {DOC_CUTOFF}.")

In [None]:
# Build a Hugging Face Dataset from the collected list of {"document", "summary"} dicts.
# This rebinds `dataset`, replacing the full XSum split loaded earlier.
dataset = Dataset.from_list(selected)
dataset  # bare expression: shown as the cell output

In [None]:
# Replace the system prompt with a summarization-oriented one (same tag conventions)
system_prompt = "\n".join([
    "You are given a text.",
    "Think carefully about its main points and organize your reasoning.",
    f"Enclose reasoning between {reasoning_start} and {reasoning_end}.",
    f"Finally, provide the final summary between {solution_start} and {solution_end}",
])

In [None]:
# Turn each document into a GRPO-ready row
def to_grpo_input(ex):
    """Build the GRPO prompt for one row.

    The document is wrapped as a system + user conversation and rendered with the chat
    template, with add_generation_prompt=True so the string ends with the opening
    reasoning tag.

    Returns a dict with:
      * ``prompt``       - the rendered prompt string
      * ``gold_summary`` - the row's reference summary
      * ``full_len``     - number of token ids of the rendered prompt
    """
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": ex["document"]},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )
    # Measure the full prompt without truncating it
    n_prompt_tokens = len(tokenizer(prompt_text, truncation=False)["input_ids"])

    return {
        "prompt": prompt_text,
        "gold_summary": ex["summary"],
        "full_len": n_prompt_tokens,
    }

dataset = dataset.map(to_grpo_input)


In [None]:
# Distribution of the rendered prompt lengths (in tokens)
full_lens = np.asarray(dataset["full_len"])

print(
    "✔ Final “prompt” lengths 50/90/99 pct:",
    np.percentile(full_lens, q=[50, 90, 99]),
)

In [None]:
# Keep just the two columns the GRPO step uses and drop everything else
rl_dataset = dataset.remove_columns(
    [name for name in dataset.column_names if name not in ("prompt", "gold_summary")]
)

rl_dataset

In [None]:
# Sanity check: column names plus the first three prompt / reference pairs
print("Ready for GRPO")

print("columns now:", rl_dataset.column_names)

for idx in range(3):
    print(f"\n─ Example {idx} ─")
    print(f"text : {rl_dataset[idx]['prompt']!r}")
    print(f"gold : {rl_dataset[idx]['gold_summary']!r}")

In [None]:
# Build a "</SOLUTION> + optional whitespace + optional EOS" pattern.
# Literals go through re.escape so tag / EOS text is never interpreted as regex syntax.
solution_end_regex = (
    re.escape(solution_end)
  + r"[\s]*"
  + "(?:" + re.escape(tokenizer.eos_token) + ")?"
)

# Single regex for "</REASONING> … <SOLUTION> … </SOLUTION>" as the tail of the text.
# The end anchor is \Z and re.MULTILINE is NOT set: under MULTILINE, `$` also matches
# just before any newline, so "</SOLUTION>\n<more text>" was accepted as "at the end".
match_full_format = re.compile(
    re.escape(reasoning_end)      # literally "</REASONING>"
    + r".*?"                      # any characters, newlines included (DOTALL)
    + re.escape(solution_start)   # literally "<SOLUTION>"
    + r"(.+?)"                    # the answer text (captured, not used by reward_format)
    + solution_end_regex          # "</SOLUTION>" + optional whitespace + optional EOS
    + r"[\s]*\Z",                 # only whitespace may follow, up to the true end of the string
    flags = re.DOTALL,
)

In [None]:
# Reward func for the format
def reward_format(completions, **kwargs):
    """
    Score how closely each completion follows the tag layout; one float per completion.

      • 3.0  when </REASONING>, <SOLUTION> and </SOLUTION> each occur exactly once
             AND `match_full_format` finds a match in the text.
      • −3.0 when none of the three tags occurs at all.
      • Otherwise, per tag: +0.5 if it occurs exactly once, −0.5 if it does not,
             for a total between −1.5 and +1.5.

    A completion may be a list whose first item is a dict with "content", a dict with
    "content", or any other object (converted with str()).
    """
    def _text_of(item):
        # Unwrap chat-style completions; fall back to the string form
        if isinstance(item, list) and isinstance(item[0], dict) and "content" in item[0]:
            return item[0]["content"]
        if isinstance(item, dict) and "content" in item:
            return item["content"]
        return str(item)

    def _score(text):
        tag_counts = [text.count(tag) for tag in (reasoning_end, solution_start, solution_end)]
        exactly_once = [n == 1 for n in tag_counts]

        # Perfect layout: every tag once and the strict tail pattern matches
        if all(exactly_once) and match_full_format.search(text):
            return 3.0

        # No tag at all: heavy penalty
        if not any(tag_counts):
            return -3.0

        # Partial credit / penalty, tag by tag
        return sum(0.5 if ok else -0.5 for ok in exactly_once)

    return [_score(_text_of(item)) for item in completions]

In [None]:
# Load the ROUGE and BERTScore metrics via `evaluate.load` (imported as load_metric);
# both objects are used by the content reward function.
rouge = load_metric("rouge")
bertscore = load_metric("bertscore")

In [None]:
# Non-greedy capture of whatever sits between the opening and closing solution tags
solution_regex = re.compile(
    re.escape(solution_start) + r"(.*?)" + re.escape(solution_end),
    re.DOTALL | re.MULTILINE,
)

In [None]:
# Extract text from completion
def extract_solution_text(raw: str) -> str:
    """Return the stripped text inside the first <SOLUTION>…</SOLUTION> pair.

    Falls back to the whole stripped input when there is no such pair, or when the
    text between the tags is empty after stripping.
    """
    whole = raw.strip()
    found = solution_regex.search(raw)
    if found is None:
        return whole
    inner = found.group(1).strip()
    return inner if inner else whole

In [None]:
# Soft length reward: 1.0 for equal word counts, decreasing linearly to 0.0
def length_reward(generated: str, reference: str) -> float:
    """Reward in [0.0, 1.0] based on how close the two whitespace word counts are.

    The absolute difference in word count is divided by the reference word count
    (treated as 1 when the reference is empty) and subtracted from 1.0; negative
    results are clamped to 0.0.
    """
    n_generated = len(generated.split())
    n_reference = len(reference.split())
    relative_gap = abs(n_generated - n_reference) / max(1, n_reference)
    score = 1.0 - relative_gap
    return score if score > 0.0 else 0.0

In [None]:
# Build the reward_content function
ROUGE_WEIGHT  = 9.0   # maximum points from the averaged ROUGE-2 / BERTScore similarity
LENGTH_WEIGHT = 1.0   # maximum points from length

def reward_content(prompts, completions, gold_summary, **kwargs):
    """
    Content reward: one float per (completion, reference) pair.

    For each completion:
      a) Extract `gen_summary` between <SOLUTION>…</SOLUTION> (raw text if no tags).
      b) rouge2_f1 = ROUGE-2 F1 and bert_f1 = BERTScore F1 against the reference;
         combined_similarity = (rouge2_f1 + bert_f1) / 2
      c) length_score = length_reward(gen_summary, reference)
      d) final_score = combined_similarity × ROUGE_WEIGHT + length_score × LENGTH_WEIGHT

    Both metrics are computed in ONE batched call over all completions instead of one
    call per completion; ROUGE is asked for per-sample scores (use_aggregator=False).

    Args:
        prompts:      unused; accepted because the trainer passes it.
        completions:  generated completions (str, dict with "content", or a list whose
                      first item is such a dict).
        gold_summary: reference summaries, aligned with `completions`.

    Returns:
        list[float] with one score per pair; empty list for empty input.
    """
    def _completion_text(completion):
        # Unwrap chat-style completions; fall back to the string form
        if isinstance(completion, list) and isinstance(completion[0], dict) and "content" in completion[0]:
            return completion[0]["content"]
        if isinstance(completion, dict) and "content" in completion:
            return completion["content"]
        return str(completion)

    pairs = list(zip(completions, gold_summary))
    if not pairs:
        # Nothing to score; avoid calling the metrics with empty inputs
        return []

    references    = [reference for _, reference in pairs]
    gen_summaries = [extract_solution_text(_completion_text(completion)) for completion, _ in pairs]

    # ROUGE-2 F1 (lexical overlap), one value per sample
    rouge2_f1 = rouge.compute(
        predictions    = gen_summaries,
        references     = references,
        rouge_types    = ["rouge2"],
        use_stemmer    = True,
        use_aggregator = False,
    )["rouge2"]

    # BERTScore F1 (semantic similarity), one value per sample
    bert_f1 = bertscore.compute(
        predictions = gen_summaries,
        references  = references,
        model_type  = "sentence-transformers/all-MiniLM-L6-v2",
        num_layers  = 6
    )["f1"]

    scores = []
    for gen_summary, reference, r2, bf in zip(gen_summaries, references, rouge2_f1, bert_f1):
        # Scale & sum; cast so the result is a plain Python float
        combined_sim = 0.5 * (r2 + bf)
        lp = length_reward(gen_summary, reference)
        scores.append(float((combined_sim * ROUGE_WEIGHT) + (lp * LENGTH_WEIGHT)))

    return scores

In [None]:
# Sampling settings handed to vLLM for the GRPO rollouts
vllm_sampling_params = SamplingParams(
    # sampling distribution
    temperature=1.0,
    top_p=0.95,
    top_k=64,
    min_p=0.1,
    # stop at the tokenizer's EOS string and keep that string in the returned text
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
    # fixed seed
    seed=111,
)

In [None]:
# Training arguments for the GRPO step
grpo_config = GRPOConfig(
    # run control
    do_train=True,
    seed=111,
    num_train_epochs=2.0,
    # generation: vLLM with the sampling parameters defined for it
    use_vllm=True,
    vllm_sampling_params=vllm_sampling_params,
    num_generations=8,
    max_completion_length=150,
    # batching
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    # optimisation
    optim="adamw_8bit",
    learning_rate=5e-6,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    weight_decay=0.01,
    # logging: every 10 steps, no external reporting integration
    logging_strategy="steps",
    logging_steps=10,
    report_to="none",
    # checkpoints and Hub upload
    output_dir="grpo_outputs",
    overwrite_output_dir=True,
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id=hub_model,
)

In [None]:
# GRPO trainer over the prompt dataset, with the format and content reward functions
trainer = GRPOTrainer(
    model=model,
    args=grpo_config,
    train_dataset=rl_dataset,
    processing_class=tokenizer,
    reward_funcs=[reward_format, reward_content],
)

In [None]:
# GPU baseline before GRPO training: device capacity and peak reserved memory so far (GB)
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / 1024**3, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run GRPO training with the GRPOTrainer built above
trainer.train()

In [None]:
# Final time and memory report, read from the last entry of the trainer's log history
last_log = trainer.state.log_history[-1]

# Runtime figures (throughput may be absent from the log entry)
train_seconds      = last_log["train_runtime"]
samples_per_second = last_log.get("train_samples_per_second")

# Peak reserved GPU memory now, and the part added since the pre-training baseline
used_memory   = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 2)
used_for_lora = round(used_memory - start_gpu_memory, 2)
used_pct      = round(used_memory / max_memory * 100, 2)
lora_pct      = round(used_for_lora / max_memory * 100, 2)

# Assemble the report line by line, then print it in one go
summary_lines = [
    f"Training time      : {train_seconds:.1f} seconds ({train_seconds/60:.2f} minutes)",
]
if samples_per_second:
    summary_lines.append(f"Throughput         : {samples_per_second:.1f} samples/second")
summary_lines.append(f"Peak VRAM usage    : {used_memory} GB ({used_pct}% of max memory)")
summary_lines.append(f"VRAM for training  : {used_for_lora} GB ({lora_pct}% of max memory)")
print("\n".join(summary_lines))

In [None]:
#@title Visualize GRPO Training Metrics { display-mode: "form" }
plt.style.use('seaborn-v0_8')
sns.set_palette("pastel")

logs = trainer.state.log_history

# Only log entries that carry a "loss" value are plotted; missing reward / kl become None
loss_logs  = [entry for entry in logs if "loss" in entry]
steps      = [entry["step"] for entry in loss_logs]
losses     = [entry["loss"] for entry in loss_logs]
rewards    = [entry.get("reward", None) for entry in loss_logs]
kl_penalty = [entry.get("kl", None) for entry in loss_logs]

# One stacked panel per metric; the label is used as both y-axis label and title
fig, axes = plt.subplots(3, 1, figsize=(8, 6))
panels = [
    ("Policy Loss",  losses,     dict(marker="o", linestyle="-",  color="C0")),
    ("Total Reward", rewards,    dict(marker="s", linestyle="--", color="C1")),
    ("KL Penalty",   kl_penalty, dict(marker="x", linestyle="-.", color="C2")),
]
for ax, (label, series, line_style) in zip(axes, panels):
    ax.plot(steps, series, **line_style)
    ax.set_ylabel(label)
    ax.grid(True, alpha=0.3)
    ax.set_title(label)

# Only the bottom panel carries the x-axis label
axes[-1].set_xlabel("Training Step")

plt.tight_layout()
plt.show()