# MLX-LM-LoRA GRPO Training Tutorial

This notebook demonstrates how to fine-tune language models using **GRPO (Group Relative Policy Optimization)** on Apple Silicon with MLX-LM-LoRA. GRPO is a reinforcement learning technique that trains models to generate outputs that maximize custom reward functions.

## What You'll Learn:
- Setting up quantized models with LoRA adapters for efficient training
- Defining custom reward functions for structured output
- Training with GRPO on mathematical reasoning tasks
- Evaluating model improvements before and after training
- Saving and sharing your trained models

## Requirements:
- Apple Silicon Mac (M1/M2/M3/M4)
- MLX-LM-LoRA library
- Sufficient memory for model loading (4-bit quantization helps!)

Let's get started! 🚀

---

## Step 1: Import Required Libraries

In [None]:
from mlx_lm_lora.utils import save_pretrained_merged, from_pretrained, calculate_iters, push_to_hub
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo, evaluate_grpo
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset

from datasets import load_dataset

from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback, WandBCallback
from mlx_lm.generate import generate, make_sampler

import mlx.optimizers as optim

import re

---

## Step 2: Configure Training Parameters

Here we set up all the hyperparameters and configuration needed for GRPO training:

### Model Configuration:
- **Base Model**: LiquidAI's LFM2.5-1.2B-Instruct - a fast, capable instruction-following model
- **Max Sequence Length**: 2048 tokens - balancing memory and context
- **LoRA Configuration**: Rank 8, Scale 10.0, targeting 8 layers for efficient fine-tuning
- **Quantization**: 4-bit MXFP4 for the trainable model, 6-bit for reference model to save memory
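To get a feel for why a rank-8 adapter is cheap to train, the sketch below counts the trainable parameters LoRA adds to a single linear layer. It assumes the usual LoRA factorization into two matrices of shapes `(d_in, rank)` and `(rank, d_out)`. The 2048x2048 layer shape is a hypothetical value chosen for illustration, not the actual LFM2.5 dimensions.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the two low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)

# Hypothetical layer shape, for illustration only
d_in, d_out, rank = 2048, 2048, 8

full = d_in * d_out
lora = lora_trainable_params(d_in, d_out, rank)
print(f"full layer: {full:,} params, LoRA factors: {lora:,} params ({lora / full:.2%})")
```

For this hypothetical shape the two factors hold 32,768 parameters, under 1% of the 4,194,304 weights in the full matrix.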

### Dataset:
- Using GSM8K - a dataset of grade school math problems for reasoning training

### Special Tokens:
We define custom tokens to structure the model's reasoning output:
- `<think>` / `</think>` - For step-by-step reasoning
- `<answer>` / `</answer>` - For the final numerical answer
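As a quick illustration of the layout these tokens are meant to produce, the sketch below parses a hand-written completion with the same pattern this notebook uses for its format reward. The sample text is hypothetical. It leaves out the opening `<think>` tag because the approximate format reward subtracts points when that tag appears in a completion, which suggests the prompt is expected to supply it; treat that reading as an assumption.

```python
import re

reasoning_end = "</think>"
solution_start = "<answer>"
solution_end = "</answer>"

# Hypothetical completion, written by hand to illustrate the target layout
sample_completion = (
    "There are 3 boxes with 4 apples each, so 3 * 4 = 12.\n"
    "</think>\n"
    "<answer>12</answer>"
)

# Reasoning text, then the closing reasoning tag and the answer tags on their own lines
pattern = re.compile(
    rf".+?\n{reasoning_end}\n{solution_start}(.+?){solution_end}",
    flags=re.MULTILINE | re.DOTALL,
)

match = pattern.search(sample_completion)
extracted = match.group(1) if match else None
print(extracted)
```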

In [None]:
model_name = "LiquidAI/LFM2.5-1.2B-Instruct"
new_model_name = "LFM2.5-1.2B-Zero"
user_name = "Goekdeniz-Guelmez"

max_seq_length = 2048

adapter_path = f"./{new_model_name}"

grpo_dataset_name = "mlx-community/gsm8k"
lora_config = {
    "rank": 8,
    "dropout": 0.0,
    "scale": 10.0,
    "use_dora": False,
    "num_layers": 8
}
quantized_load={
    "bits": 4,
    "group_size": 32,
    "mode": "mxfp4"
}
ref_quantized_load={
    "bits": 6,
    "group_size": 128
}

reasoning_start = "<think>"
reasoning_end   = "</think>"
solution_start = "<answer>"
solution_end = "</answer>"

---

## Step 3: Load Models and Tokenizer

Now we load two models:
1. **Training Model**: Quantized with LoRA adapters attached - this is what we'll fine-tune
2. **Reference Model**: A frozen copy used for KL divergence calculation in GRPO

The reference model prevents the trained model from deviating too far from the original behavior, ensuring stable learning.
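To make the KL penalty concrete, here is a minimal sketch of a per-token estimator that GRPO implementations commonly use: `exp(d) - d - 1` with `d = ref_logprob - policy_logprob`. Whether MLX-LM-LoRA uses exactly this form is an assumption, and the log-probabilities below are hand-picked illustration values.

```python
import math

def kl_estimate(policy_logprob: float, ref_logprob: float) -> float:
    """Per-token estimator exp(d) - d - 1 with d = ref - policy; never negative."""
    d = ref_logprob - policy_logprob
    return math.exp(d) - d - 1.0

# Identical log-probabilities: no penalty
print(kl_estimate(policy_logprob=-1.0, ref_logprob=-1.0))

# The policy has drifted from the reference: positive penalty
print(kl_estimate(policy_logprob=-0.5, ref_logprob=-1.5))
```

The penalty is zero when the two models agree on a token and grows as they diverge. The `beta` training argument scales how strongly this term pulls the policy back.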

**Memory Tip**: Both models stay in memory during training. Loading the training model at 4 bits and the reference model at 6 bits keeps their combined footprint small on Apple Silicon.
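As a rough back-of-the-envelope check, weight storage scales with the number of bits per parameter. The sketch below assumes about 1.2 billion parameters (read off the model name) and ignores per-group quantization scales, unquantized layers, activations, and optimizer state, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits: int) -> float:
    """Approximate weight storage in GB, ignoring quantization scales and other overhead."""
    return num_params * bits / 8 / 1e9

num_params = 1.2e9  # assumption: taken from the "1.2B" in the model name

train_gb = weight_memory_gb(num_params, bits=4)
ref_gb = weight_memory_gb(num_params, bits=6)
print(f"4-bit training model:  ~{train_gb:.2f} GB")
print(f"6-bit reference model: ~{ref_gb:.2f} GB")
```

Under these assumptions the weights take roughly 0.6 GB and 0.9 GB respectively.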

In [None]:
model, tokenizer, adapter_file = from_pretrained(
    model=model_name,
    lora_config=lora_config,
    quantized_load=quantized_load,
    new_adapter_path=adapter_path
)

ref_model, _, _ = from_pretrained(
    model=model_name,
    quantized_load=ref_quantized_load,
)
print_trainable_parameters(model)

---

## Step 4: Define Reward Functions

Reward functions are the heart of GRPO training! They evaluate how good each generated response is.

We define **four reward functions** to train the model on both format correctness and answer accuracy:

### 1. `match_format_exactly` (+3.0 points)
   - Checks if the output follows the exact expected format with proper reasoning and answer tags

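### 2. `match_format_approximately` (+0.5 or -0.5 per tag)
   - More lenient scoring for partial format compliance
   - Adds 0.5 for each of `</think>`, `<answer>`, and `</answer>` that appears exactly once, and subtracts 0.5 otherwise
   - Subtracts 0.5 if the completion contains `<think>`

### 3. `check_answer` (up to +3.0 points)
   - Extracts the answer and compares it to the ground truth
   - +3.0 for an exact match, +1.5 for a match after trimming whitespace
   - Partial credit for close numerical answers: +0.5 within 10%, +0.25 within 20%
   - Penalizes wrong answers (-1.0) and answers that cannot be parsed as numbers (-0.5)

### 4. `check_numbers` (+1.5 points)
   - Secondary check that extracts the first number after `<answer>`
   - Awards +1.5 when that number equals the ground truth, 0 otherwise
   - Helps reinforce numerical output generation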

These functions work together to guide the model toward generating well-formatted, accurate mathematical reasoning!
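The tiered scoring in `check_answer` is easiest to see on a single guess. The helper below, `score_guess`, is a simplified stand-in written for illustration: it mirrors the tiers for one already-extracted answer string and is not part of the library.

```python
def score_guess(guess: str, true_answer: str) -> float:
    """Mirror of the check_answer tiers for a single already-extracted guess."""
    if guess == true_answer:
        return 3.0   # exact string match
    if guess.strip() == true_answer.strip():
        return 1.5   # match after trimming whitespace
    try:
        ratio = float(guess) / float(true_answer)
    except (ValueError, ZeroDivisionError):
        return -0.5  # not parsable as a number, or the true answer is zero
    if 0.9 <= ratio <= 1.1:
        return 0.5   # within 10% of the true answer
    if 0.8 <= ratio <= 1.2:
        return 0.25  # within 20% of the true answer
    return -1.0      # clearly wrong

for guess in ["540", " 540 ", "500", "450", "100", "five hundred"]:
    print(repr(guess), score_guess(guess, "540"))
```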

In [None]:
match_format = re.compile(
    rf".+?\n{reasoning_end}\n{solution_start}(.+?){solution_end}",
    flags = re.MULTILINE | re.DOTALL
)

def match_format_exactly(prompts, completions, answer, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"] if isinstance(completion, list) else completion
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

def match_format_approximately(prompts, completions, answer, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"] if isinstance(completion, list) else completion
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        score -= 0.5 if response.count(reasoning_start) >= 1 else 0
        scores.append(score)
    return scores

def check_answer(prompts, completions, answer, **kwargs):
    responses = [completion[0]["content"] if isinstance(completion, list) else completion for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.25
                else: score -= 1.0
            except (ValueError, ZeroDivisionError):
                score -= 0.5
        scores.append(score)
    return scores

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)

def check_numbers(prompts, completions, answer, **kwargs):
    responses = [completion[0]["content"] if isinstance(completion, list) else completion for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            true_answer = float(true_answer.strip())
            guess       = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except ValueError:
            # The guess or the reference could not be parsed as a number
            scores.append(0)
    return scores

---

## Step 5: Load and Prepare Datasets

We create two dataset splits using the `GRPODataset` wrapper:

- **Training Set**: 100 examples - for learning
- **Test Set**: 10 examples - for final evaluation

The `GRPODataset` handles:
- Tokenization of prompts and answers
- Proper formatting with system messages
- Storage of ground truth answers for reward calculation

**Note**: We're using small subsets for this tutorial. For production training, you'd use the full datasets!

In [None]:
train_set = GRPODataset(
    load_dataset(grpo_dataset_name)["train"].take(100),
    tokenizer,
    prompt_key="prompt",
    answer_key="answer",
    system_key="system",
    type_key="type",
)
test_set = GRPODataset(
    load_dataset(grpo_dataset_name)["test"].take(10),
    tokenizer,
    prompt_key="prompt",
    answer_key="answer",
    system_key="system",
    type_key="type",
)

---

## Step 6: Inspect a Sample Prompt

Let's look at what a typical training example looks like after processing. This helps us understand what the model will see as input.

In [None]:
text = tokenizer.decode(test_set.process(test_set[0][0]))
print(text)

---

## Step 7: Test Model BEFORE Training

Before we start training, let's see how the base model performs on a test example. This gives us a baseline to compare against after GRPO training.

We use:
- **Temperature**: 0.1 (low) for more deterministic output
- **Top-p**: 0.95 for nucleus sampling
- **Top-k**: 50 for limiting vocabulary choices
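To show how these three settings interact, here is a toy sketch in plain Python over a hypothetical four-token vocabulary. It applies temperature scaling, then top-k, then top-p. The order in which `make_sampler` applies its filters may differ, so treat this as an illustration of the ideas, not of the library's implementation.

```python
import math

def filter_distribution(logits, temp, top_k, top_p):
    """Temperature-scale logits, keep the top_k tokens, then keep the smallest
    prefix of them whose cumulative probability reaches top_p."""
    scaled = [value / temp for value in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted from most to least likely, truncated to top_k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the surviving tokens
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}

toy_logits = [2.0, 1.5, 0.5, -1.0]  # hypothetical 4-token vocabulary

low = filter_distribution(toy_logits, temp=0.1, top_k=50, top_p=0.95)
high = filter_distribution(toy_logits, temp=1.0, top_k=50, top_p=0.95)
print("temp=0.1 keeps tokens:", sorted(low))
print("temp=1.0 keeps tokens:", sorted(high))
```

At a temperature of 0.1 the distribution is so sharp that the top token alone exceeds the 0.95 nucleus threshold, which is why low-temperature output is close to deterministic.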

In [None]:
before_test_output = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=text,
    verbose=True,
    max_tokens=max_seq_length//2,
    sampler=make_sampler(temp=0.1, top_p=0.95, top_k=50)
)

---

## Step 8: Evaluate Pre-Training Rewards

Now let's run our reward functions on the pre-training output to see the baseline scores.

This shows us:
- Whether the model already follows the format
- How accurate the initial answers are
- Which rewards need the most improvement

These baseline scores help us track training progress!

In [None]:
test_answer = "540"
completions = [before_test_output]

print(f"match_format_exactly: {match_format_exactly([text], completions, [test_answer])[0]}")
print(f"match_format_approximately: {match_format_approximately([text], completions, [test_answer])[0]}")
print(f"check_answer: {check_answer([text], completions, [test_answer])[0]}")
print(f"check_numbers: {check_numbers([text], completions, [test_answer])[0]}")

# Extract the matched number if found
generated_match = match_numbers.search(before_test_output)
generated_answer = generated_match.group(1) if generated_match else "None"
print(f"Answer: {test_answer}, Generated answer: {generated_answer}")

---

## Step 9: Train with GRPO! 🎯

Now for the main event - GRPO training!

### Key Training Parameters:
- **Batch Size**: 1 (conservative for memory)
- **Epochs**: 1 full pass through the training data
- **Beta**: 0.4 - KL divergence penalty weight
- **Group Size**: 4 - number of completions sampled per prompt
- **Gradient Checkpointing**: Enabled to save memory
- **Loss Type**: Standard GRPO
- **Importance Sampling**: Sequence-level - one importance ratio per completion instead of one per token
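These settings determine how much generation work a run involves. The sketch below uses `estimate_iters`, a hypothetical stand-in that assumes `calculate_iters` computes batches per epoch times epochs; check the library source for its exact behavior.

```python
import math

def estimate_iters(num_examples: int, batch_size: int, epochs: int) -> int:
    """Hypothetical stand-in for calculate_iters: batches per epoch times epochs."""
    return math.ceil(num_examples / batch_size) * epochs

batch_size = 1
group_size = 4

iters = estimate_iters(num_examples=100, batch_size=batch_size, epochs=1)

# Each step samples group_size completions for every prompt in the batch
total_completions = iters * batch_size * group_size
print(f"{iters} training steps, {total_completions} sampled completions")
```

Under that assumption, 100 training examples give 100 steps and 400 sampled completions. Generation usually dominates the run time.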

### How GRPO Works:
1. For each prompt, the model generates multiple completions (group_size=4)
2. Each completion is scored using our reward functions
3. The model learns to increase probability of high-reward completions
4. The reference model ensures we don't deviate too far from the original behavior
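Steps 2 and 3 can be sketched with the standard group-relative advantage: each completion's reward is compared with the mean of its group and scaled by the group's standard deviation. The reward values below are hypothetical. The library's loss variants (for example `dr_grpo`) may normalize differently, so this is a minimal sketch of the general idea only.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group: (reward - group mean) / (group std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Hypothetical total rewards for the 4 completions sampled from one prompt
group_rewards = [6.0, 1.5, -1.0, 1.5]
advantages = group_advantages(group_rewards)

for reward, advantage in zip(group_rewards, advantages):
    print(f"reward={reward:5.1f}  advantage={advantage:+.3f}")
```

Completions that beat their group's average get a positive advantage and are made more likely. If every completion in a group earns the same reward, all advantages are zero and that prompt contributes no learning signal.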

### Training Progress:
You'll see reports every 5 steps showing:
- Average rewards
- Training loss
- KL divergence from reference model

**Note**: Uncomment the WandBCallback section if you want to log training metrics to Weights & Biases!

This may take a while depending on your hardware. Grab a coffee! ☕

In [None]:
# Define custom reward weights if you want to weight them differently
custom_reward_weights = [
    1.0,  # match_format_exactly
    1.0,  # match_format_approximately
    1.0,  # check_answer
    1.0,  # check_numbers
]

opt = optim.AdamW(learning_rate=8e-5)

args=GRPOTrainingArgs(
    batch_size=1,
    iters=calculate_iters(train_set, batch_size=1, epochs=1),
    val_batches=1,
    steps_per_report=5,
    steps_per_eval=100,
    steps_per_save=200,
    adapter_file=adapter_file,
    max_seq_length=max_seq_length,
    grad_checkpoint=True,
    gradient_accumulation_steps=1,
    beta=0.4,
    group_size=4,
    epsilon=1e-3,
    epsilon_high=2e-3,
    temperature=0.1,
    top_p=0.95,
    top_k=50,
    max_completion_length=max_seq_length//2,
    reward_weights=custom_reward_weights,
    grpo_loss_type="grpo", # grpo, bnpo, or dr_grpo
    importance_sampling_level="sequence", # token, sequence, None for basic grpo
)

train_grpo(
    model=model,
    ref_model=ref_model.freeze(),
    tokenizer=tokenizer,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    args=args,
    training_callback=TrainingCallback(),
    # training_callback=WandBCallback(
    #     project_name=new_model_name,
    #     log_dir=adapter_path,
    #     config=vars(args),
    #     wrapped_callback=TrainingCallback(),
    # ),
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    end_answer_token=solution_end
)

---

## Step 10: Test Model AFTER Training

Training is complete! Now let's test the same example again to see how the model has improved.

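**Note**: This run uses different sampling parameters than the baseline (temp=0.6 instead of 0.1, top_k=20 instead of 50) to encourage some diversity while still maintaining quality. Keep this in mind when comparing against Step 7, since the two outputs are not sampled under identical settings.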

In [None]:
after_test_output = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=text,
    verbose=True,
    max_tokens=max_seq_length//2,
    sampler=make_sampler(temp=0.6, top_p=0.95, top_k=20)
)

---

## Step 11: Evaluate Post-Training Rewards

Let's compare the post-training scores with our baseline!

**Look for improvements in:**
- Format compliance (higher format scores)
- Answer accuracy (higher check_answer scores)
- Numerical extraction (check_numbers scores)

The difference between pre- and post-training scores shows the effectiveness of GRPO training! 📈

In [None]:
test_answer = "540"
completions = [after_test_output]

print(f"match_format_exactly: {match_format_exactly([text], completions, [test_answer])[0]}")
print(f"match_format_approximately: {match_format_approximately([text], completions, [test_answer])[0]}")
print(f"check_answer: {check_answer([text], completions, [test_answer])[0]}")
print(f"check_numbers: {check_numbers([text], completions, [test_answer])[0]}")

# Extract the matched number if found
generated_match = match_numbers.search(after_test_output)
generated_answer = generated_match.group(1) if generated_match else "None"
print(f"Answer: {test_answer}, Generated answer: {generated_answer}")

---

## Step 12: Full Evaluation on Test Set

Time for a comprehensive evaluation! This runs the trained model on the entire test set (10 examples) and computes aggregate statistics.

The evaluation provides:
- **Average rewards** across all test examples
- **Individual reward breakdowns** for each reward function
- **Total reward** - the combined score
- **GRPO loss metrics** - showing model optimization

This gives us a quantitative measure of how well the model performs on unseen data!

In [None]:
evaluate_grpo(
    model=model,
    ref_model=ref_model.freeze(),
    dataset=CacheDataset(test_set),
    tokenizer=tokenizer,
    batch_size=args.batch_size,
    num_batches=None,
    beta=args.beta,
    epsilon=args.epsilon,
    epsilon_high=args.epsilon_high,
    group_size=args.group_size,
    max_seq_length=max_seq_length,
    max_tokens=args.max_completion_length,
    temperature=args.temperature,
    top_p=args.top_p,
    top_k=args.top_k,
    min_p=args.min_p,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    reward_weights=custom_reward_weights,
    grpo_loss_type=args.grpo_loss_type,
    importance_sampling_level=args.importance_sampling_level,
    end_answer_token=solution_end
)

---

## Step 13: Save the Merged Model

Now we merge the trained LoRA adapters with the base model and save the full model to disk.

This creates a standalone model that includes all the improvements from training, ready to use without needing to load adapters separately.

The merged model will be saved in the `LFM2.5-1.2B-Zero` directory.

In [None]:
save_pretrained_merged(
    model=model,
    tokenizer=tokenizer,
    save_path=new_model_name
)

---

## Step 14: Share on Hugging Face Hub 🤗

Finally, let's share our trained model with the community!

This uploads the LoRA adapters to Hugging Face Hub, making them accessible to others who want to use or build upon your work.

**Before running**: Replace `"HF_KEY"` with your actual Hugging Face API token!

### What gets uploaded:
- LoRA adapter weights
- Adapter configuration
- (Optional) Full merged model if `remove_adapters=False`
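Rather than pasting a token into the notebook, one option is to read it from an environment variable. The helper below is a hypothetical convenience function, not part of MLX-LM-LoRA. It assumes you have exported the token as `HF_TOKEN` in the shell that launched the notebook.

```python
import os

def read_hf_token(var_name: str = "HF_TOKEN") -> str:
    """Return the token stored in the given environment variable, or raise if it is unset."""
    token = os.environ.get(var_name)
    if not token:
        raise RuntimeError(f"Set the {var_name} environment variable before pushing to the Hub.")
    return token
```

You could then pass `api_key=read_hf_token()` to `push_to_hub`, which keeps the secret out of the saved notebook.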

---

## 🎉 Congratulations!

You've successfully:
- ✅ Loaded and quantized a model for efficient Apple Silicon training
- ✅ Defined custom reward functions for structured output
- ✅ Trained a model using GRPO reinforcement learning
- ✅ Evaluated improvements in model performance
- ✅ Saved and shared your trained model

### Next Steps:
- Experiment with different reward functions
- Try longer training (more epochs)
- Test on different datasets
- Adjust hyperparameters like beta, group_size, and temperature
- Train larger models with different quantization settings

Happy training with MLX-LM-LoRA! 🚀

In [None]:
push_to_hub(
  model_path=adapter_path,
  hf_repo=f"{user_name}/{new_model_name}",
  api_key="HF_KEY",
  private=False,
  commit_message="Add GRPO adapters",
  remove_adapters=False
)