# 1. Setup Environment

In this notebook, we will finetune LLama 3.2 1B with LoRA with GRPO

In [None]:
# Sanity check: list the NVIDIA GPU(s) visible to this runtime before installing anything.
!nvidia-smi

In [None]:
%%capture
# Install the notebook's dependencies; %%capture hides pip's output for this cell,
# so remove it temporarily when an installation problem needs debugging.
# NOTE(review): only vllm is pinned (0.7.3); unsloth, huggingface_hub and wandb
# are unpinned, so a re-run may pick up different versions.
!pip install unsloth vllm==0.7.3
!pip install -U huggingface_hub
!pip install -U wandb

In [None]:
import os
from getpass import getpass

import wandb

# Credentials must never be stored in the notebook: read the key from the
# WANDB_API_KEY environment variable, or prompt for it without echoing.
# NOTE: the key that used to be hardcoded in this cell must be treated as
# leaked and revoked in the Weights & Biases account settings.
wb_token = os.environ.get("WANDB_API_KEY") or getpass("Weights & Biases API key: ")
wandb.login(key=wb_token)
# wandb.init(project="Finetuning Qwen2.5 1.5B Math Instruct GRPO", name="track 1", reinit=True)

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Credentials must never be stored in the notebook: read the token from the
# HF_TOKEN environment variable, or prompt for it without echoing.
# NOTE: the token that used to be hardcoded in this cell must be treated as
# leaked and revoked in the Hugging Face account settings.
API_KEY = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=API_KEY)

In [None]:
import re

from vllm import SamplingParams
from unsloth import FastLanguageModel
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer

# 2. Load Base Model

In [None]:
# Global configuration shared by the rest of the notebook.
max_seq_length = 1024  # total token budget (prompt + completion) used when training
lora_rank = 32         # LoRA rank; also passed as the maximum rank for fast inference
SEED = 42
MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B-Instruct"

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    # Previously max_seq_length was defined but never passed, so the model was
    # loaded with the library default instead of the budget the GRPO config
    # (max_prompt_length + max_completion_length) is derived from.
    max_seq_length=max_seq_length,
    load_in_4bit=False,  # turn off quantization to increase accuracy for reasoning
    fast_inference=True,  # optimize generation throughput
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.8,
)

# Wrap the loaded model with LoRA adapters (rank `lora_rank`, alpha 64) on the
# listed projection modules; SEED fixes the adapter initialisation.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    random_state=SEED,
    use_gradient_checkpointing="unsloth",
)

In [None]:
# print(model.config)

In [None]:
# Report how many weights are trained versus frozen.
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
total_params = sum(map(lambda param: param.numel(), model.parameters()))

print("→ Total parameters:      {:,}".format(total_params))
print("→ Trainable parameters:  {:,}".format(trainable_params))
print("→ Frozen parameters:     {:,}".format(total_params - trainable_params))

# 3. Prepare Dataset

In [None]:
# Hub identifier of the translated MetaMathQA subset; only its train split is loaded.
DATASET_PATH = "5CD-AI/Vietnamese-meta-math-MetaMathQA-40K-gg-translated"
dataset = load_dataset(path=DATASET_PATH, split="train")

In [None]:
# Display the dataset summary (its repr) to check what was loaded.
dataset

In [None]:
# Peek at the first raw example; the cells below read its 'query_en' and 'response_en' fields.
dataset[0]

# 4. Configure LoRA
+ Standardize data for the model to learning the reasoning trace and answers distinctively.
+ Use `answer_pattern` to extract the answers.
+ Signal the start/end of the reasoning chain with <thinking>...</thinking> and answer with <answer>...</answer>.
+ Build `system_prompt` to guide the model to produce reasoning chain and then the answer.
+ Change `train_dataset` to 2 fields `prompt` and `answer`.

In [None]:
# answer_pattern = re.compile(
#     r"(đáp án là:|đáp án là :|câu trả lời là:|câu trả lời là :)\s*(.*)",
#     re.IGNORECASE
# )

# Finds the final-answer marker in an English response and captures the rest
# of that line; IGNORECASE makes the marker itself case-insensitive.
answer_pattern_en = re.compile(
    r"(?:the answer is:|answer:)\s*(.*)",
    re.IGNORECASE
)

formatted_dataset = []
# Keep only the examples whose response contains an extractable final answer.
for item in dataset:
    # The response is deliberately NOT lower-cased: the pattern already ignores
    # case, and lower-casing would also lower-case the captured ground truth,
    # which the case-sensitive answer reward later compares against.
    response = item['response_en'].strip()
    match = answer_pattern_en.search(response)
    if match:
        answer = match.group(1).strip()
        formatted_dataset.append({
            "question": item['query_en'],
            "answer": answer,
        })

# Tags that delimit the model's reasoning trace and its final answer.
reasoning_start, reasoning_end = "<thinking>", "</thinking>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# System prompt: ask for the reasoning first, then the tagged final answer.
system_prompt = "\n".join([
    "You are given a problem.",
    "Think about the problem and provide your thought process.",
    f"Place it between {reasoning_start} and {reasoning_end}.",
    f"Then, provide your final answer between {solution_start}{solution_end}",
])

In [None]:
# Number of examples that survived answer extraction.
len(formatted_dataset)

In [None]:
# Spot-check one extracted question/answer pair.
formatted_dataset[2]

In [None]:
# Train on the first 8000 formatted examples. Each one gets a chat-style
# `prompt` (system instructions + the question); `answer` is carried through.
train_dataset = Dataset.from_list(formatted_dataset[:8000]).map(
    lambda example: {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
        ],
        "answer": example["answer"],
    }
)

In [None]:
# Display the training set summary (columns and row count).
train_dataset

In [None]:
from pprint import pprint

# Show the first training example, then the exact text the chat template
# renders for its prompt (generation prompt appended, not tokenized).
sample = train_dataset[0]
print(sample)
print("*" * 100)

text = tokenizer.apply_chat_template(
    sample["prompt"],
    tokenize=False,
    add_generation_prompt=True,
)
pprint(text)

# 5. Training LLM

For reinforcment learning algorithm, evaluating the efficiency of the model is through the reward function. The reward functions evaluate the output based on: correct format reasoning and the correct answer.

In [None]:
# Reward for correct formatting.
# Verbose-mode pattern describing a well-formed completion: optional leading
# whitespace, a <thinking>...</thinking> section, optional text, then a
# <SOLUTION>...</SOLUTION> section (captured as group 1) followed only by whitespace.
# The inline "#" notes below live inside the pattern string (re.VERBOSE ignores
# them), so they are left untouched; in English they read, in order:
#   any leading whitespace / <thinking> / chain-of-thought (non-greedy) /
#   </thinking> / other text may appear in between / <SOLUTION> /
#   group 1: solution content / </SOLUTION> / optional trailing whitespace /
#   end of string.
# NOTE(review): with re.MULTILINE, ^ and $ anchor at any line boundary rather
# than only at the start/end of the whole string, so extra lines before or after
# the tagged block are tolerated — confirm this looseness is intended.
match_format = re.compile(rf"""
    ^\s*                              # bất kỳ khoảng trắng đầu dòng
    {re.escape(reasoning_start)}     # <thinking>
    .*?                               # chain-of-thought (non-greedy)
    {re.escape(reasoning_end)}        # </thinking>
    .*?                               # có thể có text khác giữa
    {re.escape(solution_start)}       # <SOLUTION>
    (.+?)                             # nhóm 1: nội dung solution
    {re.escape(solution_end)}         # </SOLUTION>
    \s*                               # optional trailing whitespace
    $                                 # kết thúc chuỗi
""", flags=re.DOTALL | re.MULTILINE | re.VERBOSE)

def match_format_exactly(completions, **kwargs):
    """Score 3.0 for every completion that follows the full tag layout, else 0.

    Args:
        completions: one entry per sampled completion; the generated text is
            read from ``completion[0]['content']``.
        **kwargs: any extra arguments are accepted and ignored.

    Returns:
        A list of scores, one per completion, in the same order.
    """
    return [
        3.0 if match_format.search(completion[0]['content']) is not None else 0
        for completion in completions
    ]

def match_format_approximately(completions, **kwargs):
    """Give partial credit for each of the four tags in a completion.

    A tag that appears exactly once earns +0.5; a tag that is missing or
    repeated costs -1.0. The four contributions are summed per completion.

    Args:
        completions: one entry per sampled completion; the generated text is
            read from ``completion[0]['content']``.
        **kwargs: any extra arguments are accepted and ignored.

    Returns:
        A list of scores, one per completion, in the same order.
    """
    tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    scores = []
    for completion in completions:
        text = completion[0]['content']
        per_tag = (0.5 if text.count(tag) == 1 else -1.0 for tag in tags)
        scores.append(sum(per_tag, 0))
    return scores

Next, we define a function `check_answer(` that p
+ Parses the string between <answer> tags.
+ Awards +3.0 points if the answer is exactly correct.
+ Awards +1.5 points if it only differs by whitespace.
+ Deducts 1.5 points if it’s completely wrong.

Finally, we have a function `check_numbers()` that extracts numeric values from the response then compares them as floats.
+ Awards +1.5 points for each correct number.
+ Deducts 0.5 points for each incorrect number.

In [None]:
# match_numbers = re.compile(
#     solution_start + r".*?([\d\.\,]{1,})",
#     flags=re.MULTILINE | re.DOTALL
# )
    
# def check_answer(prompts, completions, answer, **kwargs):
#     responses = [completion[0]['content'] for completion in completions]
    
#     extracted_responses = [
#     m.group(1) if (m := match_numbers.search(r)) else None
#     for r in responses
#     ]
    
#     scores = []
#     for guess, true_answer in zip(extracted_responses, answer):
#         score = 0
#         if guess is None:
#             scores.append(0)
#             continue
#         if guess == true_answer:
#             score += 3.0
#         elif guess.strip() == true_answer.strip():
#             score += 1.5
#         else:
#             score -= 1.5
#         scores.append(score)
#     return scores

# def check_numbers(prompts, completions, answer, **kwargs):
#     question = prompts[0][-1]["content"]
#     responses = [completion[0]["content"] for completion in completions]

#     extracted_responses = [
#         guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses
#     ]

#     count = getattr(check_numbers, 'counter', 0) + 1
#     check_numbers.counter = count

#     if count % 5 == 0:
#         print('*'*20, f"Question:{question}", f"\nResponse:\n{responses[0]}",
#               f"\nExtracted: {extracted_responses[0]}", f"\nGT Answer: {answer[0]}")

#     scores = []
#     for guess, true_answer in zip(extracted_responses, answer):
#         if guess is None:
#             scores.append(0)
#             continue
#         try:
#             true_answer = float(true_answer.strip())
#             # Remove commas like in 123,456
#             guess = float(guess.strip().replace(",", ""))
#             scores.append(1.5 if guess == true_answer else -0.5)
#         except:
#             scores.append(0)
#     return scores

In [None]:
# Reward for correct answer.
# Captures the text of the first <SOLUTION>...</SOLUTION> pair, with the
# whitespace next to the tags left outside the capture group.
match_solution = re.compile(
    re.escape(solution_start) + r"\s*(.+?)\s*" + re.escape(solution_end),
    re.DOTALL,
)

def check_answer(prompts, completions, answer, **kwargs):
    """Reward completions whose tagged solution equals the reference answer.

    Per completion: 0 when no solution tags are found, 3.0 when the extracted
    (stripped) solution equals the reference exactly, 1.5 when they are equal
    only after stripping the reference too, and -1.5 otherwise.

    Args:
        prompts: unused; part of the reward-function signature.
        completions: one entry per sampled completion; the generated text is
            read from ``completion[0]['content']``.
        answer: reference answers aligned with ``completions``.
        **kwargs: any extra arguments are accepted and ignored.

    Returns:
        A list of scores, one per (completion, answer) pair.
    """
    def _extract(text):
        found = match_solution.search(text)
        return found.group(1).strip() if found else None

    texts = [completion[0]['content'] for completion in completions]
    guesses = [_extract(text) for text in texts]

    scores = []
    for guess, true in zip(guesses, answer):
        if guess is None:
            scores.append(0)
        elif guess == true:
            scores.append(3.0)
        elif guess.strip() == true.strip():
            scores.append(1.5)
        else:
            scores.append(-1.5)
    return scores

def check_numbers(prompts, completions, answer, **kwargs):
    """Reward completions whose tagged solution is numerically correct.

    Per completion: 0 when no solution tags are found or when either value
    cannot be parsed as a number, 1.5 when the two parsed floats are equal,
    and -0.5 when they differ. Thousands separators (",") are removed before
    parsing.

    Args:
        prompts: unused; part of the reward-function signature.
        completions: one entry per sampled completion; the generated text is
            read from ``completion[0]['content']``.
        answer: reference answers aligned with ``completions``.
        **kwargs: any extra arguments are accepted and ignored.

    Returns:
        A list of scores, one per (completion, answer) pair.
    """
    responses = [c[0]['content'] for c in completions]
    extracted = [
        m.group(1).strip() if (m := match_solution.search(r)) else None
        for r in responses
    ]
    scores = []
    for guess, true in zip(extracted, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            t = float(true.replace(",", ""))
            g = float(guess.replace(",", ""))
        except (ValueError, TypeError, AttributeError):
            # Non-numeric (or missing) values are neutral. Only parsing errors
            # are caught: a bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit and make training hard to stop.
            scores.append(0)
            continue
        scores.append(1.5 if g == t else -0.5)
    return scores

# 6. Finetuning & Saving Checkpoints

In [None]:
# Re-display the training set summary right before training.
train_dataset

In [None]:
# Tokenize every prompt with the chat template (generation prompt included)
# and find the longest one, which sizes the prompt budget for training.
tokenized_prompts = train_dataset.map(
    lambda batch: {
        "tokens": tokenizer.apply_chat_template(
            batch["prompt"],
            add_generation_prompt=True,
            tokenize=True,
        )
    },
    batched=True,
)
prompt_lengths = tokenized_prompts.map(
    lambda row: {"length": len(row["tokens"])}
)["length"]
max_len = max(prompt_lengths)

print(max_len)

In [None]:
# One token of headroom above the longest tokenized prompt.
max_prompt_length = max_len + 1
# Whatever remains of the sequence budget is available for the completion.
max_completion_length = max_seq_length - max_prompt_length
if max_completion_length <= 0:
    # Fail early with a clear message instead of handing the trainer a
    # zero or negative completion length.
    raise ValueError(
        f"Longest prompt ({max_len} tokens) leaves no room for a completion "
        f"within max_seq_length={max_seq_length}; increase max_seq_length "
        "or filter out the longest prompts."
    )

training_args = GRPOConfig(
    # Diagnostics
    report_to="wandb",
    output_dir="output_bz2",
    logging_steps=1,
    logging_dir="output_bz2/logs",  # directory that holds the TensorBoard logs
    run_name="grpo-run1",

    # Optimization
    learning_rate=5e-6,
    weight_decay=5e-4,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim='adamw_torch_fused',
    max_grad_norm=0.1,

    # Batch
    per_device_train_batch_size=8,
    gradient_accumulation_steps=32,

    # Specific settings
    num_generations=8,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    num_train_epochs=1,
    max_steps=-1,
    save_steps=50,

    # Reproducibility: reuse the notebook-wide seed explicitly.
    seed=SEED,
)

# Build the GRPO trainer. The four reward functions are passed in this order:
# two format rewards followed by two answer rewards.
trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
)

# Start fine-tuning.
trainer.train()

In [None]:
# Report how many optimisation steps ran and the runtime recorded in the last log entry.
trainer_state = trainer.state
print(trainer_state.global_step, "/", trainer_state.max_steps)
print("Train time:", trainer_state.log_history[-1]["train_runtime"])

# 7. Run Evaluate

In [None]:
# Save the trained LoRA adapter to ./grpo_saved_lora; the evaluation and upload cells read it from there.
model.save_lora("grpo_saved_lora")

In [None]:
# Qualitative check on one example.
# NOTE: the example comes from the training set, so this is a sanity check,
# not a held-out evaluation.
idx = 2
question = train_dataset[idx]["question"]
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": question},
]

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=1024,
)

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

# Generate with the saved LoRA adapter applied.
path_lora = "grpo_saved_lora"
output = model.fast_generate(
    [text],
    sampling_params=sampling_params,
    lora_request=model.load_lora(path_lora),
)[0].outputs[0].text

# The question is read into a variable first: reusing double quotes inside an
# f-string replacement field is a SyntaxError on Python versions before 3.12.
print(f"Problem:\n{question}")
print(f"Response:\n{output}")
print("GT Answer:", train_dataset[idx]["answer"])

## We could expand the method for more complicated problems or integrate with other diverse evaluating signals to further optimize the reasoning chians of LLMs. 

In [None]:
from huggingface_hub import create_repo

# Target model repository on the Hugging Face Hub; exist_ok=True lets this
# cell be re-run when the repository already exists.
repo_id = "Savoxism/grpo-lora-vietnam-llm"
create_repo(repo_id=repo_id, exist_ok=True)

In [None]:
from huggingface_hub import HfApi, upload_folder

# Push the saved LoRA adapter directory to the model repository created above.
api = HfApi()
api.upload_folder(
    repo_id=repo_id,
    folder_path="grpo_saved_lora",
    repo_type="model",
)


In [None]:
from peft import AutoPeftModelForCausalLM, PeftModel

# PeftModel.from_pretrained(model, model_id) needs an already-loaded base
# model as its first argument, so calling it with only the adapter repository
# id raises a TypeError. AutoPeftModelForCausalLM reads the base model name
# from the adapter's config, loads it, and attaches the adapter in one call.
model = AutoPeftModelForCausalLM.from_pretrained("Savoxism/grpo-lora-vietnam-llm")