# Train a custom reasoning model using MLX-LM-LoRA's GRPO trainer

This notebook demonstrates MLX-LM-LoRA with a reinforcement-learning (RL) example: training a reasoning model with GRPO.

In [None]:
%%capture
# Install/upgrade MLX-LM-LoRA, MLX-LM and ipywidgets; %%capture suppresses the (long) pip output in this cell.
%pip install -U mlx-lm-lora mlx-lm ipywidgets

In [None]:
# The trainer and evaluations
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo, evaluate_grpo

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset

# The reward functions
from mlx_lm_lora.trainer.grpo_reward_functions import (
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml
)

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, fuse_and_save_model, calculate_iters

# For loading the dataset
from datasets import load_dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.generate import generate
from mlx_lm.utils import save_config
from pathlib import Path

# The optimizer
import mlx.optimizers as optim


# Set the dataset, model, and loading parameters

In [None]:
# Policy model, KL-reference model, output directory and RL dataset.
model_name = "mistralai/Ministral-3-3B-Base-2512"
ref_model_name = model_name  # same checkpoint: the reference is a frozen copy of the starting policy
adapter_path = "./Ministral-3-3B-Zero"
dataset_name = "mlx-community/Dolci-Think-RL-7B-2k"

max_seq_length = 4096

# LoRA adapter configuration.
lora_config = dict(
    rank=8,  # low-rank bottleneck size; larger = more capacity but slower (common: 8, 16, 32, 64, 128)
    dropout=0.0,
    scale=10.0,  # multiplier on the LoRA update before it is added to the base weights
    use_dora=False,
    num_layers=8,  # number of layers to adapt; -1 adapts all layers
)

# Quantization applied to the weights at load time.
quantized_config = dict(
    bits=4,  # 4-bit weights (common: 4, 6, 8)
    group_size=64,
)

In [None]:
# Reference model for the GRPO KL term (weighted by `beta` in the training args).
# It is loaded from the same checkpoint as the policy and WITHOUT LoRA adapters, so it
# stays a fixed anchor describing the starting policy -- it is not a "smarter" teacher.
ref_model, _ = from_pretrained(
    model=ref_model_name,
    quantized_load=quantized_config,
)

# Policy (student) model: same base weights, with the LoRA adapters from lora_config attached.
model, tokenizer = from_pretrained(
    model=model_name,
    lora_config=lora_config,
    quantized_load=quantized_config,
)
# Report how many parameters will actually be trained.
print_trainable_parameters(model)

In [None]:
# Turn the output location into a Path, make sure the directory exists, write the LoRA
# settings next to where the adapter weights will be saved, and remember that weights file.
# NOTE(review): mlx_lm's own adapter loader reads `num_layers` plus a nested
# `lora_parameters` entry from adapter_config.json; confirm this flat dict is the format
# mlx-lm-lora expects if the adapters are ever reloaded without fusing.
adapter_path = Path(adapter_path)
adapter_path.mkdir(exist_ok=True, parents=True)
save_config(lora_config, adapter_path.joinpath("adapter_config.json"))
adapter_file = adapter_path.joinpath("adapters.safetensors")

# Load and process the dataset

We don't have to format the dataset ourselves; the GRPODataset class does that for us.

If you do need to reformat the data before loading, keep in mind that each line of the JSONL file should look like:

```json
{
    "prompt": "...",
    "answer": "..."
}
```

This base model does not come with the prompt format we want, so let's define a chat template first.

In [None]:
chat_template = """
{% if messages[0]['role'] == 'system' %}
{{ messages[0]['content'] }}
{% endif %}

User: {{ messages[1]['content'] }}

Assistant: Let me solve this step by step.
""".strip()

tokenizer.chat_template = chat_template

In [None]:
# System prompt asking for reasoning inside <think>...</think> and the final answer inside
# <answer>...</answer> (the tags the r1_* format reward functions score).
system = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The assistant places its reasoning between <think> and </think>. Then, provides the solution between <answer> </answer>."

# Load the dataset a single time and reuse it for every split.
raw_dataset = load_dataset(dataset_name)


def _make_grpo_split(split):
    """Wrap one split of ``raw_dataset`` in a GRPODataset.

    All splits share the same column mapping (prompt / answer / type) and pass the system
    prompt above as ``default_system_str``.
    """
    return GRPODataset(
        raw_dataset[split],
        tokenizer,
        prompt_key="prompt",
        answer_key="answer",
        type_key="type",
        default_system_str=system,
    )


train_set = _make_grpo_split("train")
valid_set = _make_grpo_split("valid")
test_set = _make_grpo_split("test")

# Let's see what the dataset looks like
This is exactly the text that will be fed into the model.

In [None]:
# Decode the first test example's token ids back to text -- exactly what the model will see.
# NOTE(review): this reads GRPODataset's private `_data`; `_data[0][0]` is assumed to be
# the prompt token ids of the first example -- confirm against the GRPODataset implementation.
first_prompt_ids = test_set._data[0][0]
sample_input = tokenizer.decode(first_prompt_ids)
print(sample_input)

Let's use this exact input to see what the untrained model generates.

In [None]:
# Baseline: generate from the *untrained* policy on the sample prompt.
# Uses the same token budget as training completions (max_completion_length) and as the
# post-training generation, so the before/after outputs are directly comparable.
test_untrained = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=sample_input,
    max_tokens=max_seq_length // 2,
)

print(test_untrained)

# Now that all the setup steps are done, we can start the training phase

In [None]:
opt = optim.Muon(learning_rate=2e-4)  # Set the optimizer

batch_size = 1

# GRPO computes each completion's advantage relative to the other completions sampled for
# the same prompt (its "group"). With group_size=1 every group has one member, so the
# advantage is always zero and the reward functions never influence the update.
# Sample several completions per prompt; lower this if generation runs out of memory.
group_size = 4

# Reward functions applied to every sampled completion.
reward_funcs = [
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
]

args = GRPOTrainingArgs(
    batch_size=batch_size,
    iters=calculate_iters(train_set=train_set, batch_size=batch_size, epochs=1),  # one pass over train_set
    gradient_accumulation_steps=1,
    val_batches=1,
    steps_per_report=25,
    steps_per_eval=100,
    steps_per_save=200,
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,
    grad_checkpoint=True,
    group_size=group_size,
    beta=0.01,  # weight of the KL penalty towards the reference model
    epsilon=0.1,
    epsilon_high=0.3,
    max_completion_length=max_seq_length // 2,
    reference_model_path=ref_model_name,
    temperature=0.6,
    grpo_loss_type="grpo",  # Choose one: "grpo", "bnpo", "dr_grpo"
    reward_weights=None,
    importance_sampling_level=None,  # Choose one: "token", "sequence", None
)

train_grpo(
    model=model,
    tokenizer=tokenizer,
    ref_model=ref_model.freeze(),  # the reference model must not receive updates
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),
    reward_funcs=reward_funcs,
)

# After training, let's evaluate and test the trained model out!

In [None]:
# Evaluate the trained policy on one batch of the held-out test split.
# group_size > 1: the GRPO loss is group-relative, so with a single completion per prompt
# the policy term is always zero and the reported loss would only reflect the KL penalty.
loss, _, rewards = evaluate_grpo(
    model=model,
    tokenizer=tokenizer,
    ref_model=ref_model.freeze(),
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    max_seq_length=max_seq_length,
    beta=0.01,
    epsilon=0.1,
    epsilon_high=0.3,
    group_size=4,
    max_tokens=max_seq_length // 2,
    temperature=0.7,
    reward_funcs=[
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    ],
    grpo_loss_type="grpo",
    importance_sampling_level=None,
)
print(f"Test loss: {loss}")
print(rewards)

In [None]:
# Run the same sample prompt through the now-trained policy, for comparison with the
# untrained output printed earlier.
trained_generation_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    prompt=sample_input,
    max_tokens=max_seq_length // 2,
)
test_trained = generate(**trained_generation_kwargs)
print(test_trained)

# Finally let's merge and save the final model

In [None]:
# Merge the trained LoRA adapters into the base weights and write the model and tokenizer
# into the adapter directory. de_quantize=True because the weights were quantized at load
# time (quantized_load=quantized_config).
fuse_and_save_model(
    model=model,
    tokenizer=tokenizer,
    de_quantize=True,
    save_path=adapter_path,
)

## That's it!

And we're done! You have successfully trained your own custom model. You can upload it to the Hugging Face Hub using the `huggingface_hub` package. If you have any questions about MLX-LM-LoRA, find any bugs, or need help, feel free to visit my [GitHub](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)!

Cheers,
Gökdeniz