In [None]:
!pip install --quiet unsloth vllm==0.7.3

# Load Base Model

In [None]:
import os
import re

from vllm import SamplingParams
from unsloth import FastLanguageModel
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer


In [None]:
# Hyperparameters shared by model loading, LoRA setup and the training config.
max_seq_length = 2048
lora_rank = 64

# Load the base checkpoint and its tokenizer without 4-bit quantisation.
# `max_lora_rank` is set to the same rank that the adapters use.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B-Instruct",
    max_seq_length=max_seq_length,
    max_lora_rank=lora_rank,
    load_in_4bit=False,
    fast_inference=True,
    gpu_memory_utilization=0.8,
)

# Wrap the model with LoRA adapters on the attention projections
# (q/k/v/o) and the MLP projections (gate/up/down).
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    lora_alpha=lora_rank,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    random_state=3407,
    use_gradient_checkpointing="unsloth",
)


In [None]:
# Fetch the "train" split of the Vietnamese-translated MetaMathQA dataset.
dataset = load_dataset(
    "5CD-AI/Vietnamese-meta-math-MetaMathQA-40K-gg-translated",
    split="train",
)

In [None]:
# Regex that finds the final-answer marker in a reference solution
# ("đáp án là:" / "câu trả lời là:", optionally with a space before the colon).
# Group 2 captures the text that follows the marker. Matching is
# case-insensitive, so the response does not need to be lower-cased first.
answer_pattern = re.compile(
    r"(đáp án là:|đáp án là :|câu trả lời là:|câu trả lời là :)\s*(.*)",
    re.IGNORECASE
)

# Keep only the examples whose reference solution contains an answer marker.
# The response keeps its original casing: the extracted answer is compared
# case-sensitively against model output by the reward functions, so
# lower-casing it here would corrupt answers that contain capital letters.
formatted_dataset = []
for item in dataset:
    response = item["response_vi"].strip()
    match = answer_pattern.search(response)
    if match:
        answer = match.group(2).strip()
        formatted_dataset.append({
            "question": item["query_vi"],
            "answer": answer
        })

# Tags the model must wrap its reasoning and its final answer in.
reasoning_start, reasoning_end = "<thinking>", "</thinking>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# System prompt describing the expected response layout. It starts and ends
# with a newline, so the first and last pieces below are deliberate.
system_prompt = (
    "\n"
    "You are given a problem.\n"
    "Think about the problem and provide your thought process.\n"
    f"Place it between {reasoning_start} and {reasoning_end}.\n"
    f"Then, provide your final answer between {solution_start} and {solution_end}.\n"
)

def _to_chat_example(example):
    """Build the chat-style prompt (system + user turn) for one record.

    Returns the new "prompt" column together with the unchanged "answer".
    """
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["question"]},
        ],
        "answer": example["answer"],
    }


# Use the first 8000 extracted examples and attach a prompt to each of them.
train_dataset = Dataset.from_list(formatted_dataset[:8000]).map(_to_chat_example)


In [31]:
# Regex that accepts a response only if it has the full expected layout:
# optional leading whitespace, a reasoning section, a solution section, and
# nothing but whitespace afterwards. Group 1 captures the solution text.
# The final `$` is left unescaped so it acts as the end-of-line anchor;
# written as `\$` it would require a literal dollar sign and the pattern
# would reject well-formatted responses.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL
)

def match_format_exactly(completions, **kwargs):
    """Reward completions that follow the required response format exactly.

    Args:
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        A list with one score per completion: 3.0 when ``match_format``
        matches the text, otherwise 0.
    """
    return [
        3.0 if match_format.search(completion[0]["content"]) is not None else 0
        for completion in completions
    ]

def match_format_approximately(completions, **kwargs):
    """Give partial credit for each formatting tag that appears exactly once.

    Args:
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        A list with one score per completion. Each of the four tags adds 0.5
        when it occurs exactly once in the text and subtracts 1.0 otherwise.
    """
    expected_tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        scores.append(
            sum(0.5 if text.count(tag) == 1 else -1.0 for tag in expected_tags)
        )
    return scores


# Reward

In [32]:
# Regex that captures the first number-like token after the solution start
# tag: an optional leading minus sign followed by digits, dots and commas.
# The sign is part of the capture so negative answers keep their sign when
# they are later parsed with float().
match_numbers = re.compile(
    solution_start + r".*?(-?[\d\.\,]{1,})",
    flags=re.MULTILINE | re.DOTALL
)

def check_answer(prompts, completions, answer, **kwargs):
    """Score each completion's extracted solution against the reference answer.

    Args:
        prompts: Unused.
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        answer: Reference answers, paired with ``completions`` by position.
        **kwargs: Ignored.

    Returns:
        A list with one score per (completion, answer) pair:
        0 when ``match_format`` does not match the response, 3.0 for an exact
        string match, 1.5 for a match after stripping surrounding whitespace,
        and -1.5 otherwise.
    """
    def _extract(text):
        # Solution text captured by the format regex, or None if no match.
        found = match_format.search(text)
        return None if found is None else found.group(1)

    def _grade(guess, truth):
        if guess is None:
            return 0
        if guess == truth:
            return 3.0
        if guess.strip() == truth.strip():
            return 1.5
        return -1.5

    guesses = [_extract(completion[0]["content"]) for completion in completions]
    return [_grade(guess, truth) for guess, truth in zip(guesses, answer)]

def check_numbers(prompts, completions, answer, **kwargs):
    """Score the first number found in each completion against the reference.

    Args:
        prompts: Chat prompts; the content of the last message of the first
            prompt is used only for the periodic debug printout.
        completions: One entry per completion; the text is read from
            ``completion[0]["content"]``.
        answer: Reference answers, paired with ``completions`` by position.
        **kwargs: Ignored.

    Returns:
        A list with one score per (completion, answer) pair: 1.5 when the
        parsed numbers are equal, -0.5 when they differ, and 0 when no number
        was extracted or either side cannot be parsed as a float.

    Side effects:
        Increments ``check_numbers.counter`` on every call and prints the
        first question/response/extraction/answer on every fifth call.
    """
    def _grade(guess, truth):
        if guess is None:
            return 0
        try:
            expected = float(truth.strip())
            # Commas are stripped from the guess before parsing.
            predicted = float(guess.strip().replace(",", ""))
        except ValueError:
            return 0
        return 1.5 if predicted == expected else -0.5

    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = []
    for text in responses:
        found = match_numbers.search(text)
        extracted_responses.append(found.group(1) if found is not None else None)

    # The call counter lives on the function object so it survives across calls.
    call_count = getattr(check_numbers, 'counter', 0) + 1
    check_numbers.counter = call_count

    # Print a sample every fifth call to monitor training qualitatively.
    if call_count % 5 == 0:
        print('*' * 20, f"Question: {question}",
              f"\nResponse:\n{responses[0]}",
              f"\nExtracted: {extracted_responses[0]}",
              f"\nGT Answer: {answer[0]}")

    return [_grade(guess, truth) for guess, truth in zip(extracted_responses, answer)]


# Training

In [None]:
# Tokenise every chat prompt (with the generation prompt appended) ...
tokenized_prompts = train_dataset.map(
    lambda x: {"tokens": tokenizer.apply_chat_template(
        x["prompt"], add_generation_prompt=True, tokenize=True)},
    batched=True,
)

# ... measure each one, and keep the longest.
prompt_lengths = tokenized_prompts.map(
    lambda x: {"length": len(x["tokens"])}
)["length"]
max_len = max(prompt_lengths)

# Prompt budget: the longest prompt plus one token.
max_prompt_length = max_len + 1

# GRPO training configuration, grouped by concern.
training_args = GRPOConfig(
    # Optimiser and learning-rate schedule
    optim="adamw_torch_fused",
    learning_rate=5e-6,
    weight_decay=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    max_grad_norm=0.1,
    # Batching and number of sampled completions per prompt
    per_device_train_batch_size=2,
    gradient_accumulation_steps=64,
    num_generations=8,
    # Token budget: the completion gets whatever the prompt leaves free
    # inside max_seq_length.
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    # Run length
    num_train_epochs=1,
    max_steps=-1,
    # Logging and checkpointing
    logging_steps=1,
    save_steps=250,
    report_to="wandb",
    output_dir="outputs_bz2",
)

# Build the GRPO trainer with the four reward functions.
# It must be given `train_dataset` (the prepared data with the "prompt" and
# "answer" columns), not the raw `dataset`: the raw data only has the original
# query/response fields, and `max_prompt_length` was measured on
# `train_dataset`.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args=training_args,
    train_dataset=train_dataset,
)




In [None]:
# Run training with the configured `trainer`.
trainer.train()

# **Save LoRA**

In [None]:
# Save the LoRA weights under the path "grpo_saved_lora"; the inference cell
# loads them back from the same path.
model.save_lora("grpo_saved_lora")

#  **Inference**

In [None]:
# Index of the training example used for this check.
idx = 0

# Sampling settings used for generation.
sampling_params = SamplingParams(
    max_tokens=1024,
    temperature=0.8,
    top_p=0.95,
)

# Chat messages: the system prompt followed by the selected question.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": train_dataset[idx]["question"]},
]

# Render the chat into one prompt string that ends with the generation prompt.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Load the saved LoRA weights and generate one completion with them.
path_lora = "grpo_saved_lora"
lora_request = model.load_lora(path_lora)
generations = model.fast_generate(
    [text],
    sampling_params=sampling_params,
    lora_request=lora_request,
)
output = generations[0].outputs[0].text

# Show the question, the model response and the reference answer.
print(f"Problem:\n{train_dataset[idx]['question']}")
print(f"Response:\n{output}")
print(f"GT Answer: {train_dataset[idx]['answer']}")
