In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer # type: ignore
from datasets import load_dataset
from vllm import SamplingParams

from speedy_utils import * # type: ignore
from llm_utils import * # type: ignore

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 09-07 19:15:38 [__init__.py:241] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
assert 'unsloth' in str(GRPOTrainer).lower(), "Expected GRPOTrainer from unsloth"


# ==================== Model Setup ====================
max_seq_length = 32_000
lora_rank = 16

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/mnt/data/hf-models/unsloth/Qwen3-14B-bnb-4bit",  # Using 0.6B as per project
    # model_name='/mnt/data/models/Qwen-14B-250808/',
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    load_in_8bit=False,
    fast_inference=False,
    max_lora_rank=lora_rank,
    # gpu_memory_utilization=0.7,
)


==((====))==  Unsloth 2025.9.1: Fast Qwen3 patching. Transformers: 4.56.1. vLLM: 0.10.1.1.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 23.527 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
items = load_by_ext('/home/anhvth5/projects/TRANSLATE_UI/assets/LC_STANDARD/TSN_PW_ZH_TH/training_data.json')

In [5]:
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank * 2,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

Unsloth 2025.9.1 patched 40 layers with 40 QKV layers, 40 O layers and 40 MLP layers.


In [11]:
from datasets import Dataset
dataset = Dataset.from_list(items[:10])
def f(ex):
    # content = ex["messages"][0]['content']+'\n\n/no_think'
    # prompt = tokenizer.apply_chat_template(
    #     [
    #         {"role": "user", "content": content},
    #     ], add_generation_prompt=True, tokenize=False
    # )
    messages = ex["messages"][:1]
    messages[0]['content']+='\n\n/no_think'
    return {"prompt": messages, "answer": jloads(ex["metadata"])["target_text"]}


dataset = dataset.map(
    f,
    num_proc=4,
    remove_columns=dataset.column_names,
)

Map (num_proc=4):   0%|          | 0/10 [00:00<?, ? examples/s]

In [12]:
# from rewards import func_rewards

In [13]:
# ==================== Reward Functions ====================

# Regex patterns for reward functions
global PRINTED_TIMES, PRINT_EVERY_STEPS
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 1


from tabulate import tabulate
import re

# Define the regex pattern to match the translation format
translation_format = re.compile(
    r"<think>.*</think><translation>\n(.+?)\n</translation>",
    flags=re.MULTILINE | re.DOTALL
)

# Global variables for printing control
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def check_translation_format(prompts, completions, answer, **kwargs):
    """
    Reward function to check if the completion matches the translation format
    and if the extracted translation matches the expected answer.
    Prints a formatted table with question, answer, response, and extracted translation for debugging.
    """
    global PRINTED_TIMES, PRINT_EVERY_STEPS
    scores = []
    table_data = []

    for prompt, completion, true_answer in zip(prompts, completions, answer):
        question = prompt[-1]["content"] if isinstance(prompt, list) else prompt
        response = completion[0]["content"] if isinstance(completion, list) else completion
        score = 0

        # Check if the response matches the exact format
        format_match = translation_format.search(response)
        if format_match:
            score += 3.0  # Reward for correct format
            extracted_translation = format_match.group(1).strip()
        else:
            extracted_translation = None
            score -= 1.0  # Penalize for incorrect format

        # Check if the extracted translation matches the true answer
        if extracted_translation == true_answer:
            score += 5.0  # Reward for correct translation
        elif extracted_translation and extracted_translation.strip() == true_answer.strip():
            score += 3.5  # Reward for correct translation ignoring whitespace
        else:
            score -= 2.0  # Penalize for incorrect or missing translation

        # Prepare data for logging
        table_data.append([
            question[:50] + ("..." if len(question) > 50 else ""),  # Truncate for readability
            true_answer,
            response[:50] + ("..." if len(response) > 50 else ""),
            extracted_translation if extracted_translation else "None",
            score
        ])

        scores.append(score)

    # Print table every PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        headers = ["Question", "True Answer", "Response", "Extracted", "Score"]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

    PRINTED_TIMES += 1
    return scores


# ==================== GRPO Configuration ====================
vllm_sampling_params = SamplingParams(
    min_p=0.1,
    top_p=1.0,
    top_k=-1,
    seed=3407,
    stop=[tokenizer.eos_token, '</translation>'],
    include_stop_str_in_output=True,
)

training_args = GRPOConfig(
    generation_kwargs={
        "min_p": 0.1,
        "top_p": 1.0,
        "top_k": 10,
        "stop": [tokenizer.eos_token],
        "include_stop_str_in_output": True,
    },
    disable_tqdm=True,
    torch_compile=False,
    temperature=1.0,
    learning_rate=5e-6,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=4,
    max_prompt_length=30_000,
    max_completion_length=256,
    max_steps=10000,
    save_steps=100,
    report_to="none",
    output_dir="outputs",
)

# ==================== Trainer Creation ====================
from rewards import func_rewards
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[  # type: ignore
        check_translation_format, *func_rewards
    ],
    args=training_args,
    train_dataset=dataset,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


In [None]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10 | Num Epochs = 1,000 | Total steps = 10,000
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 64,225,280 of 14,832,532,480 (0.43% trained)


In [10]:
%debug

> [32m/home/anhvth5/projects/opensloth/src/async-grpo/rewards.py[39m([92m58[39m)[36m_extract_translation_data[39m[34m()[39m
[32m     56[39m     [38;5;28;01mtry[39;00m:
[32m     57[39m         [38;5;66;03m# 1. Extract source text from the user prompt[39;00m
[32m---> 58[39m         user_prompt = prompts[[32m0[39m][-[32m1[39m][[33m"content"[39m]
[32m     59[39m         source_match = re.search(
[32m     60[39m             [33mr"Source Text \(.*? → .*?\):\n(.*)"[39m, user_prompt, re.DOTALL

['<|im_start|>user\n### Role:\nYou are a **game-localisation specialist**. Use the examples below to infer tone, genre terminology, and style.\n\nTranslate the following text faithfully, preserving tone and meaning. Use the examples and glossaries as guidance for terminology and style. Do not include any reasoning or explanation—output only the translation.\n\n**Output schema:**\ntranslation: "<final translation>"\n\n<translation>\n<your final translation here>\n</translati

In [30]:
prompts, completions, answer = load_by_ext("debug_grpo.pkl")

In [40]:
x = jloads(items[100]['metadata'])

Using model: Qwen/Qwen3-235B-A22B-Instruct-2507-FP8


In [None]:
INSTRUCTION = """
You are tasked to evaluate the quality of translations produced by a language model for game translation tasks.
"""

def build_input(source_text, search_examples, glossaries)->str
    return 

lm = MOpenAI(base_url='http://localhost:7999/v1')

SyntaxError: expected ':' (2858548705.py, line 5)