# Train a custom R1 model from scratch using MLX-LM-LoRA

In this notebook we will train a Zero model with the GRPO trainer, use it to create a reasoning dataset, and then finally train a custom R1 model. Grab some popcorn and enjoy!

In [None]:
%%capture
# Install/upgrade the training library, MLX-LM and ipywidgets; %%capture hides the pip output.
%pip install -U mlx-lm-lora mlx-lm ipywidgets

In [None]:
# The trainers and evaluations
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo, evaluate_grpo
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset, TextDataset

# The reward functions
from mlx_lm_lora.trainer.grpo_reward_functions import (
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
)

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, fuse_and_save_model, calculate_iters

# For loading the dataset
from datasets import load_dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.sample_utils import make_sampler
from mlx_lm.generate import generate
from mlx_lm.utils import save_config
from pathlib import Path
import gc
import json

# The optimizer
import mlx.optimizers as optim


# Set the datasets, models, and loading params

In [None]:
zero_model_name = "Qwen/Qwen3-1.7B-Base"  # Policy model trained with GRPO (gets the LoRA adapters)
zero_ref_model_name = "Qwen/Qwen3-1.7B-Base"  # Reference model passed as `ref_model` to the GRPO trainer
zero_adapter_path = "./Qwen3-1.7B-Zero"  # Output dir for zero adapters, the fused model and the reasoning traces
zero_dataset_name = "mlx-community/gsm8k"
r1_dataset_generator_model_name = "mlx-community/Josiefied-Qwen3-8B-abliterated-v1-8bit"  # Rewrites raw answers during distillation
r1_model_name = "Qwen/Qwen3-1.7B"
r1_adapter_path = "./Qwen3-1.7B-R1"  # Output dir for the distilled SFT dataset
num_r1_samples = 100  # How many reasoning samples we generate to finetune the R1 model.

max_seq_length = 1024
lora_config = {  # LoRA adapter configuration
    "rank": 8,  # Low-rank bottleneck size (Larger rank = smarter, but slower). Suggested 8, 16, 32, 64, 128
    "dropout": 0.0,
    "scale": 10.0,  # Multiplier for how hard the LoRA update hits the base weights
    "use_dora": False,
    "num_layers": -1,  # Use -1 for all layers
}
quantized_config = {
    "bits": 6,  # Quantize to 6 bits on load. Suggested 4, 6, 8
    "group_size": 64,
}

# Let's first start with the zero model

In [None]:
# Reference model: same base checkpoint, quantized on load, no LoRA adapters.
# It is frozen later when handed to the GRPO trainer / evaluator.
zero_ref_model, zero_ref_tokenizer = from_pretrained(
    model=zero_ref_model_name,
    quantized_load=quantized_config,
)

# Policy model: the same base weights plus LoRA adapters built from `lora_config`.
zero_model, zero_tokenizer = from_pretrained(
    model=zero_model_name,
    quantized_load=quantized_config,
    lora_config=lora_config,
)
print_trainable_parameters(zero_model)

In [None]:
# Output directory for the zero model. The LoRA config is saved next to the
# adapter weights file that training will write to `adapter_file`.
adapter_path = Path(zero_adapter_path)
adapter_path.mkdir(exist_ok=True, parents=True)
save_config(lora_config, adapter_path / "adapter_config.json")
adapter_file = adapter_path / "adapters.safetensors"

# Load and process the dataset

We don't have to format the dataset ourselves; the GRPODataset class does that for us.

If you have to reformat before loading, keep in mind it should be a jsonl looking like:

```json
{
    "prompt": "...",
    "answer": "..."
}
```

This base model does not come with the prompt format we want, so let's set that up first.

In [None]:
chat_template = """
{% if messages[0]['role'] == 'system' %}
{{ messages[0]['content'] }}
{% endif %}

User: {{ messages[1]['content'] }}

Assistant: """.strip()

zero_tokenizer.chat_template = chat_template

In [None]:
system = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks quickly in the mind and then provides the user with the answer. The assistant places it's think process between <think> and </think> tags. Then, provides the raw solution between <answer> </answer> tags."

# Load the dataset once and wrap each split in a GRPODataset with identical settings
# (GRPODataset applies `zero_tokenizer`'s chat template and the system prompt above).
zero_dataset = load_dataset(zero_dataset_name)


def _make_grpo_split(split):
    """Build a GRPODataset for one split of `zero_dataset` using the shared settings."""
    return GRPODataset(
        zero_dataset[split],
        zero_tokenizer,
        prompt_key="prompt",
        answer_key="answer",
        type_key="type",
        default_system_str=system,
    )


train_set = _make_grpo_split("train")
valid_set = _make_grpo_split("valid")
test_set = _make_grpo_split("test")

# Let's see what the dataset looks like
This is what will be fed into the model.

In [None]:
# Decode the first test example back to text: element 0 is the prompt tokens
# (what the model sees), element 1 the reference answer tokens.
first_test_example = test_set._data[0]
sample_input = zero_tokenizer.decode(first_test_example[0])
sample_input_answer = zero_tokenizer.decode(first_test_example[1])
print(sample_input)

Let's use this exact input to see what the untrained model generates. Since we know the actual answer to this question (18), we can judge how well the model performs. It does okay: the generated answer is correct!

In [None]:
# Baseline: what does the model produce for the sample prompt before any training?
test_untrained_zero = generate(
    model=zero_model,
    tokenizer=zero_tokenizer,
    prompt=sample_input,
    max_tokens=max_seq_length // 2,
)

separator = "-" * 100
print(test_untrained_zero)
print(f"\n\n{separator}")
print(f"Actual answer: {sample_input_answer}")

# Now we're done with all the steps and can actually start the training phase

In [None]:
# Optimizer for the GRPO run.
opt = optim.AdamW(learning_rate=2e-4)

# Reward functions applied to every sampled completion.
zero_reward_funcs = [
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
]

args = GRPOTrainingArgs(
    # Schedule / logging
    batch_size=1,
    iters=1000,  # Alternative: calculate_iters(train_set=train_set, batch_size=1, epochs=1)
    gradient_accumulation_steps=1,
    steps_per_report=25,
    steps_per_eval=100,
    steps_per_save=200,
    val_batches=1,
    # Lengths, checkpointing and output
    max_seq_length=max_seq_length,
    max_completion_length=max_seq_length // 2,
    adapter_file=adapter_file,
    grad_checkpoint=True,
    # GRPO objective
    group_size=4,
    beta=0.1,
    epsilon=0.0001,
    epsilon_high=0.1,
    grpo_loss_type="grpo",  # Choose one: "grpo", "bnpo", "dr_grpo"
    importance_sampling_level="sequence",  # Choose one: "token", "sequence", None
    reward_weights=None,
    reference_model_path=zero_ref_model_name,
    temperature=0.6,
    low_mem_usage=True,  # Reduces memory usage but doesn't do batch generation, so it will take longer
)

train_grpo(
    model=zero_model,
    tokenizer=zero_tokenizer,
    ref_model=zero_ref_model.freeze(),
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),
    reward_funcs=zero_reward_funcs,
    end_answer_token="</answer>",
)

# After training, let's evaluate and test the trained model out!

In [None]:
# Evaluate the trained zero model on the held-out test split. The loss
# hyperparameters come from the training `args`, so the reported loss uses the
# same objective (beta / clipping epsilons / loss type) as training.
# NOTE(review): with group_size=1 there are no other samples to compare against,
# so the group-relative advantage is presumably ~0 and the loss mostly reflects
# the KL term; `rewards` is the more informative number here.
loss, _, rewards = evaluate_grpo(
    model=zero_model,
    tokenizer=zero_tokenizer,
    ref_model=zero_ref_model.freeze(),
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    max_seq_length=args.max_seq_length,
    beta=args.beta,
    epsilon=args.epsilon,
    epsilon_high=args.epsilon_high,
    group_size=1,  # one completion per prompt keeps evaluation cheap
    max_tokens=args.max_completion_length,
    temperature=args.temperature,
    reward_funcs=[
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    ],
    grpo_loss_type=args.grpo_loss_type,
    importance_sampling_level=args.importance_sampling_level,
)
print(rewards)

In [None]:
# Same prompt as the untrained baseline, now with the GRPO-trained adapters.
zero_generation_budget = max_seq_length // 2
test_trained_zero = generate(
    model=zero_model,
    tokenizer=zero_tokenizer,
    prompt=sample_input,
    max_tokens=zero_generation_budget,
)

print(test_trained_zero)

# Finally let's merge and save the final zero model

In [None]:
# Merge the trained LoRA adapters into the base weights and save the result
# (with the tokenizer) into the zero model's output directory.
fuse_and_save_model(
    save_path=adapter_path,
    model=zero_model,
    tokenizer=zero_tokenizer,
    de_quantize=True,  # The base weights were quantized on load via `quantized_config`
)

# Let's also remove the reference model from RAM, we don't need it anymore

This frees up some RAM before we continue...

In [None]:
# Drop the reference model, its tokenizer and the evaluation splits, then run
# the garbage collector so their memory can be reclaimed before the next phase.
del zero_ref_model
del zero_ref_tokenizer
del valid_set
del test_set
gc.collect()

# Dataset Curation Phase

Now we can move on to the dataset curation phase. First, we generate some reasoning traces using the zero model. Once we've collected a sufficient number of traces, we distill them into a format suitable for SFT training.

## Why Distillation?

The zero model outputs structured responses with raw answers:
```
<think>
reasoning steps
</think>
<answer> raw answer </answer>.
```

We want to transform this into natural language while preserving the reasoning:
```
<think> reasoning steps </think>
fluent natural language answer
```

## Distillation Process

We'll use a stronger, larger model to rewrite the raw answers into natural language. This creates high-quality SFT data that teaches the model to:
1. Maintain the reasoning process (thinking tags)
2. Output polished, fluent answers
3. Preserve correctness from the RL training

### Step 1: Generate Zero Reasoning Traces
We'll sample from our dataset, format prompts with the chat template, and generate some reasoning traces.

In [None]:
# Take the first prompts of the training split and let the trained zero model
# produce <think>/<answer> traces for them.
distil_source = load_dataset(zero_dataset_name)["train"]
# Dataset.select fails on out-of-range indices, so never ask for more than exist.
num_trace_samples = min(num_r1_samples, len(distil_source))
distil_dataset = distil_source.select(range(num_trace_samples))
zero_reasoning_traces = []
prompts = []

sampler = make_sampler(
    temp=0.6,
    top_p=0.95,
    min_p=0.05,
    top_k=20,
)

for idx, example in enumerate(distil_dataset):
    print(f"Generating trace {idx+1}/{num_trace_samples}...")

    prompt_str = example["prompt"]

    # Format with the chat template set earlier; tokenize=True returns token ids.
    prompt_tokens = zero_tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt_str},
        ],
        add_generation_prompt=True,
        tokenize=True,
    )

    response = generate(
        model=zero_model,
        tokenizer=zero_tokenizer,
        prompt=prompt_tokens,
        max_tokens=max_seq_length // 2,
        sampler=sampler,
    )

    prompts.append(prompt_str)
    zero_reasoning_traces.append(response)

print(f"\n✓ Generated {len(zero_reasoning_traces)} zero reasoning traces")

# Persist the traces next to the zero adapters so distillation can be re-run without regenerating.
traces_file = Path(zero_adapter_path) / "zero_reasoning_traces.json"
with open(traces_file, "w", encoding="utf-8") as f:
    json.dump(
        {
            "prompts": prompts,
            "traces": zero_reasoning_traces,
        },
        f,
        indent=2,
    )

In [None]:
# The zero model is no longer needed: its traces are in memory and saved to disk.
del zero_model
del zero_tokenizer
gc.collect()  # Reclaim that memory before loading the larger 8B distillation model

### Step 2: Distill to Natural Language

Now we'll use a strong model to rewrite the raw answers into fluent natural language.

In [None]:
# The distillation model is only used for inference (rewriting answers), so no
# LoRA adapters are attached. The original cell printed the trainable parameters
# of `zero_model`, which has already been deleted above (NameError).
# NOTE(review): this checkpoint is already 8-bit; confirm that `quantized_load`
# leaves already-quantized layers untouched.
distill_model, distill_tokenizer = from_pretrained(
    model=r1_dataset_generator_model_name,
    quantized_load=quantized_config,
)

In [None]:
def extract_between(text, start_tag, end_tag):
    """Return the stripped text between the first ``start_tag`` and the next ``end_tag``.

    The end tag is searched for only *after* the start tag. A stray end tag
    earlier in ``text`` therefore cannot produce an empty or reversed slice.

    Args:
        text: Text to search (``None`` or empty yields ``None``).
        start_tag: Opening marker, e.g. ``"<answer>"``.
        end_tag: Closing marker, e.g. ``"</answer>"``.

    Returns:
        The content between the tags with surrounding whitespace stripped, or
        ``None`` if either tag is missing.
    """
    if not text:
        return None
    start_idx = text.find(start_tag)
    if start_idx == -1:
        return None
    content_start = start_idx + len(start_tag)
    end_idx = text.find(end_tag, content_start)
    if end_idx == -1:
        return None
    return text[content_start:end_idx].strip()

def distill_trace(trace, model, tokenizer):
    """Convert one zero-model trace to an SFT completion with a natural-language answer.

    Args:
        trace: Text generated by the zero model. It must contain a
            ``<think>...</think>`` section and an ``<answer>...</answer>`` section.
        model: Model used to rewrite the raw answer.
        tokenizer: Tokenizer belonging to ``model``.

    Returns:
        ``"<think>\\n{reasoning}\\n</think>\\n{natural_answer}"``, or ``None`` when
        either section is missing/empty or the rewrite comes back empty.
    """
    # Extract reasoning and raw answer from the zero trace.
    reasoning = extract_between(trace, "<think>", "</think>")
    raw_answer = extract_between(trace, "<answer>", "</answer>")

    if not reasoning or not raw_answer:
        return None

    # Ask the generator to rephrase only the answer; the reasoning is kept verbatim.
    distill_prompt = f"""Rewrite this answer in clear, natural language. Keep it accurate and complete. Only reply with the rewritten answer and no additional text.

Raw answer: {raw_answer}

Natural answer:"""

    # The generator is a chat model, so send the request as a user turn through its
    # chat template instead of as raw text. `enable_thinking=False` is forwarded to
    # the template; Qwen3-style templates use it to skip the model's own <think>
    # block (templates that don't use the variable typically ignore it).
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": distill_prompt}],
            add_generation_prompt=True,
            enable_thinking=False,
        )
    else:
        prompt = distill_prompt

    sampler = make_sampler(
        temp=0.8,
        top_p=0.95,
        min_p=0.0,
        top_k=20,
    )

    natural_answer = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=max_seq_length * 2,
        sampler=sampler,
    )

    # If the generator still emitted its own reasoning block, keep only the text
    # after it so the SFT completion contains exactly one <think> section.
    if "</think>" in natural_answer:
        natural_answer = natural_answer.rsplit("</think>", 1)[1]
    natural_answer = natural_answer.strip()
    if not natural_answer:
        return None

    return f"<think>\n{reasoning}\n</think>\n{natural_answer}"

In [None]:
# Distill every zero trace into a {"prompt", "completion"} SFT record.
sft_dataset = []
num_traces = len(zero_reasoning_traces)
num_skipped = 0

for idx, (prompt, trace) in enumerate(zip(prompts, zero_reasoning_traces), start=1):
    print(f"Distilling {idx}/{num_traces}...")

    sft_completion = distill_trace(trace, distill_model, distill_tokenizer)

    if sft_completion:
        sft_dataset.append({
            "prompt": prompt,
            "completion": sft_completion,
        })
    else:
        # distill_trace returned None, e.g. the trace lacked <think>/<answer> sections.
        num_skipped += 1

    if idx % 10 == 0:
        print(f"✓ Processed {idx} traces ({len(sft_dataset)} kept)")

print(f"\n✓ Built {len(sft_dataset)} SFT examples ({num_skipped} traces skipped)")

### Step 3: Save Final SFT Dataset

Yay! The dataset has been generated. Let's save it and take a look at how it turned out.

In [None]:
# The R1 output directory is not created by any earlier cell, so create it
# before writing (otherwise open() raises FileNotFoundError).
r1_output_dir = Path(r1_adapter_path)
r1_output_dir.mkdir(parents=True, exist_ok=True)
sft_dataset_file = r1_output_dir / "sft_dataset.json"

with open(sft_dataset_file, "w", encoding="utf-8") as f:
    json.dump(sft_dataset, f, indent=2)

print(f"Saved {len(sft_dataset)} examples to {sft_dataset_file}")
# Show one example so we can see how the distilled data turned out.
if sft_dataset:
    print(json.dumps(sft_dataset[0], indent=2))

## That's it!

And we're done! You successfully trained your own custom model. You can upload it using Hugging Face's `huggingface_hub` package. If you have any questions about MLX-LM-LoRA, find any bugs, or need help, feel free to visit my [GitHub](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)!

Cheers,
Gökdeniz