# Train a custom R1 model from scratch using MLX-LM-LoRA

In this notebook we will train a Zero model with the GRPO trainer, use it to create a reasoning dataset, and then finally train a custom R1 model on that dataset. Grab some popcorn and enjoy!

In [None]:
%%capture
%pip install -U mlx-lm-lora mlx-lm ipywidgets

In [26]:
# The trainers and evaluations
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo, evaluate_grpo
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset, TextDataset

# The reward functions
from mlx_lm_lora.trainer.grpo_reward_functions import (
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
)

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, fuse_and_save_model, calculate_iters

# For loading the dataset
from datasets import load_dataset, Dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.sample_utils import make_sampler
from mlx_lm.generate import generate
from mlx_lm.utils import save_config
from pathlib import Path
import json

# The optimizer
import mlx.optimizers as optim


# Set the datasets, models, and loading params

In [24]:
# --- Models, adapters and datasets ---
base_model_name = "Qwen/Qwen3-1.7B-Base"
zero_ref_model_name = "Qwen/Qwen3-1.7B-Base"
zero_adapter_path = "./Qwen3-1.7B-Zero"
zero_dataset_name = "mlx-community/gsm8k"
r1_dataset_generator_model_name = "Qwen/Qwen3-1.7B"
r1_model_name = "Qwen/Qwen3-1.7B"
r1_adapter_path = "./Qwen3-1.7B-R1"

# Number of reasoning samples to generate for fine-tuning the R1 model.
num_r1_samples = 10

# Shared sequence-length budget for training and generation.
max_seq_length = 1024

# LoRA adapter configuration.
lora_config = dict(
    rank=8,          # Low-rank bottleneck size. Suggested 8, 16, 32, 64, 128
    dropout=0.0,
    scale=10.0,      # Multiplier for how hard the LoRA update hits the base weights
    use_dora=False,
    num_layers=-1,   # Use -1 for all layers
)

# Quantization applied when loading a model.
quantized_config = dict(
    bits=4,          # Use 4 bit quantization. Suggested 4, 6, 8
    group_size=64,
)

# Let's first start with the zero model

In [None]:
# Reference model for GRPO: loaded quantized and without a LoRA config.
zero_ref_model, zero_ref_tokenizer = from_pretrained(
    model=zero_ref_model_name,
    quantized_load=quantized_config,
)

# Model to be trained: same quantization, plus the LoRA adapter configuration.
zero_model, zero_tokenizer = from_pretrained(
    model=base_model_name,
    lora_config=lora_config,
    quantized_load=quantized_config,
)
# Report how many of the model's parameters are trainable.
print_trainable_parameters(zero_model)

Loading model Qwen/Qwen3-1.7B-Base


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Quantizing model with 4 bits
Loading model Qwen/Qwen3-1.7B-Base


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading LoRA adapters with config: {'rank': 8, 'dropout': 0.0, 'scale': 10.0, 'use_dora': False, 'num_layers': -1}
Quantizing model with 4 bits
Trainable parameters: 0.507% (8.716M/1720.575M)


In [4]:
# Create the output directory for the zero model's adapter (no error if it already exists).
adapter_path = Path(zero_adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
# File path handed to the trainer via `adapter_file`.
adapter_file = adapter_path / "adapters.safetensors"
# Store the LoRA configuration alongside the adapter weights.
save_config(lora_config, adapter_path / "adapter_config.json")

# Load and process the dataset

We don't have to format the dataset ourselves; the `GRPODataset` class will do that for us.

If you have to reformat before loading, keep in mind it should be a jsonl looking like:

```json
{
    "prompt": "...",
    "answer": "..."
}
```

This model does not have the Prompt Format we want, so let's do that first.

In [5]:
chat_template = """
{% if messages[0]['role'] == 'system' %}
{{ messages[0]['content'] }}
{% endif %}

User: {{ messages[1]['content'] }}

Assistant: """.strip()

zero_tokenizer.chat_template = chat_template

In [6]:
# System prompt describing the <think>/<answer> output format the zero model should learn.
system = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks quickly in the mind and then provides the user with the answer. The assistant places it's think process between <think> and </think> tags. Then, provides the raw solution between <answer> </answer> tags."

# Load the dataset once and reuse it for every split instead of calling
# `load_dataset` separately for train, valid and test.
zero_dataset = load_dataset(zero_dataset_name)


def _build_grpo_split(split_name):
    """Wrap one split of the zero dataset in a GRPODataset with the shared settings.

    Args:
        split_name: Name of the split to wrap ("train", "valid" or "test").

    Returns:
        A GRPODataset built from that split with `zero_tokenizer` and the `system` prompt.
    """
    return GRPODataset(
        zero_dataset[split_name],
        zero_tokenizer,
        prompt_key="prompt",
        answer_key="answer",
        type_key="type",
        default_system_str=system,
    )


train_set = _build_grpo_split("train")
valid_set = _build_grpo_split("valid")
test_set = _build_grpo_split("test")

# Let's see what the dataset looks like
This is what will be fed into the model.

In [None]:
# Decode the first test example back to text so we can inspect it.
# NOTE(review): presumably each `_data` entry holds the prompt tokens at index 0 and the
# answer tokens at index 1 -- confirm against GRPODataset.
sample_input = zero_tokenizer.decode(test_set._data[0][0])
print(sample_input)
# Kept for the comparison printed after generation.
sample_input_answer = zero_tokenizer.decode(test_set._data[0][1])

Let's use this exact input to see what the untrained model generates. Since we know the actual answer to this question (18), we can judge how the model performs. In this case it does fine: the generated answer is correct!

In [None]:
# Baseline: generate from the not-yet-trained model using the decoded sample prompt.
test_untrained_zero = generate(
    model=zero_model,
    tokenizer=zero_tokenizer,
    prompt=sample_input,
    max_tokens=max_seq_length//2,  # half of the sequence budget is reserved for the completion
)

print(test_untrained_zero)

# Print a separator and the reference answer for comparison.
print("\n\n" + "-"*100)
print(f"Actual answer: {sample_input_answer}")

# Now we're done with all the steps and can actually start the training phase

In [None]:
# GRPO training of the zero model: optimizer, training arguments, then the training call.
opt = optim.AdamW(learning_rate=2e-4)  # Set the optimizer

args = GRPOTrainingArgs(
    batch_size=1,
    iters=100, # calculate_iters(train_set=train_set, batch_size=1, epochs=1),
    gradient_accumulation_steps=1,
    val_batches=1,
    steps_per_report=10,
    steps_per_eval=100,
    steps_per_save=200,
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,
    grad_checkpoint=True,
    group_size=2,
    beta=0.1,
    epsilon=0.0001,
    epsilon_high=0.1,
    max_completion_length=max_seq_length//2,
    reference_model_path=zero_ref_model_name,
    temperature=0.6,
    grpo_loss_type="grpo", # Choose one: "grpo", "bnpo", "dr_grpo"
    reward_weights=None,
    importance_sampling_level="sequence", # Choose one: "token", "sequence", None
)

# The reference model is passed through `.freeze()`; the LoRA model is the one being trained.
train_grpo(
    model=zero_model,
    tokenizer=zero_tokenizer,
    ref_model=zero_ref_model.freeze(),
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),
    reward_funcs=[r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml],
    end_answer_token="</answer>"
)

# After training, let's evaluate and test the trained model out!

In [None]:
# Evaluate the trained zero model on one batch of the test split and print the rewards.
# NOTE(review): beta, epsilon, epsilon_high and group_size here differ from the values
# used in the training arguments -- confirm this is intentional.
loss, _, rewards = evaluate_grpo(
    model=zero_model,
    tokenizer=zero_tokenizer,
    ref_model=zero_ref_model.freeze(),
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    max_seq_length=max_seq_length,
    beta=0.01,
    epsilon=0.1,
    epsilon_high=0.3,
    group_size=1,
    max_tokens=max_seq_length//2,
    temperature=0.6,
    reward_funcs=[
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml
    ],
    grpo_loss_type="grpo",
    importance_sampling_level="sequence",
    end_answer_token="</answer>"
)
print(rewards)

In [None]:
# Same prompt and token budget as the untrained baseline, now with the trained model.
test_trained_zero = generate(
    model=zero_model,
    tokenizer=zero_tokenizer,
    prompt=sample_input,
    max_tokens=max_seq_length//2,
)

print(test_trained_zero)

# Finally let's merge and save the final zero model

In [None]:
# Merge the adapters into the zero model and save it.
# `adapter_path` still points at the zero adapter directory at this point.
fuse_and_save_model(
    model=zero_model,
    tokenizer=zero_tokenizer,
    save_path=adapter_path,
    de_quantize=True # Since we quantized the model on load
)

# Let's also remove the reference model from RAM, we don't need it anymore

This frees up some RAM before we continue...

In [7]:
# Drop the names that are no longer needed: the reference model and its tokenizer,
# plus the validation and test datasets. `zero_model`, `zero_tokenizer` and `system`
# are kept because the trace-generation step still uses them.
del zero_ref_model
del zero_ref_tokenizer
del valid_set
del test_set

# Dataset Curation Phase

Now we can move on to the dataset curation phase. Here we will first generate some reasoning traces using the zero model. Once we've collected a sufficient number of traces, we need to distill them into a format suitable for SFT training.

## Why Distillation?

The zero model outputs structured responses with raw answers:
```
<think>
reasoning steps
</think>
<answer> raw answer </answer>.
```

We want to transform this into natural language while preserving the reasoning:
```
<think> reasoning steps </think>
fluent natural language answer
```

## Distillation Process

We'll use a strong base model to rewrite the raw answers into natural language. This creates high-quality SFT data that teaches the model to:
1. Maintain the reasoning process (thinking tags)
2. Output polished, fluent answers
3. Preserve correctness from the RL training

### Step 1: Generate Zero Reasoning Traces
We'll sample from our dataset, format prompts with the chat template, and generate some reasoning traces.

In [None]:
# Take the first `num_r1_samples` training examples. The count is clamped to the split
# size so `select` can never index past the end of the dataset.
zero_train_split = load_dataset(zero_dataset_name)["train"]
num_traces = min(num_r1_samples, len(zero_train_split))
distil_dataset = zero_train_split.select(range(num_traces))
zero_reasoning_traces = []
prompts = []

# Sampling settings used for every trace.
sampler = make_sampler(
    temp=0.6,
    top_p=0.95,
    min_p=0.05,
    top_k=20,
)

for idx, example in enumerate(distil_dataset):
    print(f"Generating trace {idx+1}/{num_traces}...")

    # Raw question text; stored later next to its generated trace.
    prompt_str = example["prompt"]

    # Render the conversation with the chat template. With tokenize=False this
    # yields the prompt as text rather than token ids.
    prompt_input = zero_tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt_str},
        ],
        add_generation_prompt=True,
        tokenize=False,
    )

    # Generate the reasoning trace with the trained zero model.
    response = generate(
        model=zero_model,
        tokenizer=zero_tokenizer,
        prompt=prompt_input,
        max_tokens=max_seq_length // 2,
        sampler=sampler,
    )

    prompts.append(prompt_str)
    zero_reasoning_traces.append(response)

print(f"\n✓ Generated {len(zero_reasoning_traces)} zero reasoning traces")

# Persist prompts and traces so they can be inspected or reused later. The directory
# is created if needed, so this cell does not depend on an earlier cell having made it.
traces_dir = Path(zero_adapter_path)
traces_dir.mkdir(parents=True, exist_ok=True)
with open(traces_dir / "zero_reasoning_traces.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "prompts": prompts,
            "traces": zero_reasoning_traces
        },
        f,
        indent=2
    )

Generating trace 1/10...
Generating trace 2/10...
Generating trace 3/10...
Generating trace 4/10...
Generating trace 5/10...
Generating trace 6/10...
Generating trace 7/10...
Generating trace 8/10...
Generating trace 9/10...
Generating trace 10/10...

✓ Generated 10 zero reasoning traces


# Great, let's take a look at one of the generated traces

In [9]:
# Show the first prompt and its generated trace beneath a 500-character dash separator.
print("-"*500, "\n", f"Prompt: {prompts[0]}", "\n", f"Generation: {zero_reasoning_traces[0]}")

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- 
 Prompt: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
 Generation: <think>
1. First, Natalia sold clips to 48 of her friends in April, so she sold 48 clips in April.
2. In May, Natalia sold half as many clips as in April, so she sold 48 / 2 = 24 clips in May.
3. To find out how many clips Natalia sold altogether in April and May, we add the number of clips sold in April and May: 48 +

In [10]:
# The traces are collected, so the zero model and its tokenizer are no longer needed.
del zero_model
del zero_tokenizer

### Step 2: Distill to Natural Language

Now we'll use a strong model to rewrite the raw answers into fluent natural language.

In [11]:
# Load the model used to rewrite raw answers; no quantization config is passed here.
distill_model, distill_tokenizer = from_pretrained(
    model=r1_dataset_generator_model_name,
    quantized_load=None,
)

Loading model Qwen/Qwen3-1.7B


Fetching 9 files:   0%|          | 0/9 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/622M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

In [None]:
def extract_between(text, start_tag, end_tag):
    """Return the stripped text between the first `start_tag` and the next `end_tag`.

    The end tag is searched for only after the start tag, so a stray end tag that
    appears earlier in the text does not produce an empty or wrong slice.

    Args:
        text: String to search.
        start_tag: Opening marker, e.g. "<think>".
        end_tag: Closing marker, e.g. "</think>".

    Returns:
        The content between the tags with surrounding whitespace removed, or None
        when the start tag is missing or no end tag follows it.
    """
    start_idx = text.find(start_tag)
    if start_idx == -1:
        return None

    content_start = start_idx + len(start_tag)
    # Begin the search after the opening tag so the two tags always pair up in order.
    end_idx = text.find(end_tag, content_start)
    if end_idx == -1:
        return None

    return text[content_start:end_idx].strip()

def distill_trace(trace, model, tokenizer):
    """Convert one zero trace to SFT format with a natural-language answer.

    Args:
        trace: Generated text expected to contain <think>...</think> and
            <answer>...</answer> sections.
        model: Model used to rewrite the raw answer.
        tokenizer: Tokenizer paired with `model`; used both to render the chat
            prompt and for generation.

    Returns:
        A string of the form "<think>\\n{reasoning}\\n</think>\\n{natural answer}",
        or None when the reasoning or the raw answer is missing or empty.
    """
    # Extract reasoning and raw answer
    reasoning = extract_between(trace, "<think>", "</think>")
    raw_answer = extract_between(trace, "<answer>", "</answer>")

    # Skip traces that lack either section.
    if not reasoning or not raw_answer:
        return None

    # Rewrite raw answer to natural language
    distill_prompt = f"""Given this reasoning and answer, rewrite the answer in clear, natural language. Only return the natural answer, no additional text:

Reasoning: {reasoning}
Raw answer: {raw_answer}

Natural answer:"""

    sampler = make_sampler(
        temp=0.8,
        top_p=0.95,
        min_p=0.0,
        top_k=20,
    )

    # Use the tokenizer that was passed in rather than a notebook-level global, so the
    # function works with whichever model/tokenizer pair the caller supplies.
    distil_input = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": distill_prompt},
        ],
        add_generation_prompt=True,
        tokenize=False,
        enable_thinking=False
    )

    natural_answer = generate(
        model,
        tokenizer,
        prompt=distil_input,
        max_tokens=max_seq_length,
        sampler=sampler,
    )

    # Keep the original reasoning and replace the raw answer with the rewritten one.
    sft_completion = f"<think>\n{reasoning}\n</think>\n{natural_answer.strip()}"

    return sft_completion

In [None]:
# Distill every trace into a chat-style SFT example. Traces that cannot be distilled
# (distill_trace returned nothing) are counted and reported instead of vanishing silently.
sft_dataset = []
skipped_traces = 0
for idx, (prompt, trace) in enumerate(zip(prompts, zero_reasoning_traces)):
    print(f"Distilling {idx+1}/{len(zero_reasoning_traces)}...")
    sft_completion = distill_trace(trace, distill_model, distill_tokenizer)
    if not sft_completion:
        skipped_traces += 1
        continue

    # Format as messages structure
    sft_dataset.append({
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": sft_completion}
        ]
    })

# Report the number of examples actually kept, not the loop index.
print(f"✓ Distilled {len(sft_dataset)} traces ({skipped_traces} skipped)")

Distilling 1/10...
<think>
1. First, Natalia sold clips to 48 of her friends in April, so she sold 48 clips in April.
2. In May, Natalia sold half as many clips as in April, so she sold 48 / 2 = 24 clips in May.
3. To find out how many clips Natalia sold altogether in April and May, we add the number of clips sold in April and May: 48 + 24 = 72.
</think>
Natalia sold 48 clips in April and 24 clips in May, totaling 72 clips.
Distilling 2/10...
Distilling 3/10...
Distilling 4/10...
<think>
Yesterday, Julie read 12 pages, and today she read twice as many, which is 12 x 2 = 24 pages. In total, she has read 12 + 24 = 36 pages so far. The remaining pages are 120 - 36 = 84 pages. If she wants to read half of the remaining pages tomorrow, she needs to read 84 / 2 = 42 pages.
</think>
42
Distilling 5/10...
Distilling 6/10...
<think>
First, calculate the number of purple flowers by taking 10 flowers (yellow) and multiplying it by 1.80 (80% more). This gives 18 purple flowers.
Next, calculate the

FileNotFoundError: [Errno 2] No such file or directory: './Qwen3-1.7B-R1/sft_dataset.jsonl'

In [21]:
# Save as JSONL (one JSON object per line)
# Write the SFT examples as JSON Lines: one serialized example per line.
with open("./sft_dataset.jsonl", "w") as jsonl_file:
    jsonl_file.writelines(json.dumps(example) + "\n" for example in sft_dataset)

In [19]:
# Each entry stores a "messages" list: the user turn first, then the assistant turn.
first_messages = sft_dataset[0]["messages"]
print(first_messages[0]["content"])  # the user prompt
print(first_messages[1]["content"])  # the distilled assistant completion

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
<think>
1. First, Natalia sold clips to 48 of her friends in April, so she sold 48 clips in April.
2. In May, Natalia sold half as many clips as in April, so she sold 48 / 2 = 24 clips in May.
3. To find out how many clips Natalia sold altogether in April and May, we add the number of clips sold in April and May: 48 + 24 = 72.
</think>
Natalia sold 48 clips in April and 24 clips in May, for a total of 72 clips.


### Step 3: Save Final SFT Dataset

Yay! The dataset has been generated; let's look at how it turned out.

In [22]:
# Remove the distillation model, its tokenizer, the sampled dataset and the
# `distill_trace` function from the namespace; `sft_dataset` is kept for SFT training.
del distill_model
del distill_tokenizer
del distil_dataset
del distill_trace

# OK, now that we have our R1 dataset, we can SFT fine-tune the base model

In [25]:
# Load a fresh copy of the base model with LoRA adapters and quantization for SFT.
r1_model, r1_tokenizer = from_pretrained(
    model=base_model_name,
    lora_config=lora_config,
    quantized_load=quantized_config,
)

Loading model Qwen/Qwen3-1.7B-Base


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading LoRA adapters with config: {'rank': 8, 'dropout': 0.0, 'scale': 10.0, 'use_dora': False, 'num_layers': -1}
Quantizing model with 4 bits


In [30]:
def format_prompts_func(sample):
    """Render a sample's chat messages into one string stored under "text".

    The sample's "messages" list is passed through the R1 tokenizer's chat template
    (as text, without a trailing generation prompt). The sample is updated in place
    and returned.
    """
    rendered_text = r1_tokenizer.apply_chat_template(
        conversation=sample["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
    sample["text"] = rendered_text
    return sample

dataset = Dataset.from_list(sft_dataset) # Turn it into a pyarrow.Table to make Dataset class happy

# Apply the chat template once and reuse the result for both datasets instead of
# running the identical map + remove_columns pipeline twice.
formatted_dataset = dataset.map(format_prompts_func).remove_columns(["messages"])

train_set = TextDataset(
    formatted_dataset,
    r1_tokenizer,
    text_key="text",
)

# NOTE(review): validation uses the same examples as training, so the validation
# loss is not a held-out measurement.
valid_set = TextDataset(
    formatted_dataset,
    r1_tokenizer,
    text_key="text",
)

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

In [31]:
# Inspect the first chat-templated training text.
print(valid_set[0]["text"])

<|im_start|>user
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>
<|im_start|>assistant
<think>
1. First, Natalia sold clips to 48 of her friends in April, so she sold 48 clips in April.
2. In May, Natalia sold half as many clips as in April, so she sold 48 / 2 = 24 clips in May.
3. To find out how many clips Natalia sold altogether in April and May, we add the number of clips sold in April and May: 48 + 24 = 72.
</think>

Natalia sold 48 clips in April and 24 clips in May, totaling 72 clips.<|im_end|>



In [33]:
# Re-point `adapter_path` / `adapter_file` (previously used for the zero model) at the
# R1 output directory, creating it if needed, and store the LoRA config there.
adapter_path = Path(r1_adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(lora_config, adapter_path / "adapter_config.json")

# Fresh optimizer for the SFT run.
opt = optim.AdamW(learning_rate=2e-4)

# Training settings
# `iters` is derived from the dataset size for one epoch at batch size 1.
args = SFTTrainingArgs(
    batch_size=1,
    iters=calculate_iters(train_set, batch_size=1, epochs=1),
    gradient_accumulation_steps=1,
    val_batches=1,
    steps_per_report=50,
    steps_per_eval=500,
    steps_per_save=200,
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,
    grad_checkpoint=True,
)

# Start Training
train_sft(
    model=r1_model,
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),
)

[INFO] Calculated 5 iterations from 1 epochs (dataset size: 5, batch size: 1)
Starting training..., iters: 5


Training:   0%|          | 0/5 [00:00<?, ?it/s]

Iter 1: Val loss 0.680, Val took 0.403s


Training:  80%|████████  | 4/5 [00:05<00:01,  1.32s/it]

Iter 5: Val loss 0.573, Val took 0.261s


Training: 100%|██████████| 5/5 [00:06<00:00,  1.33s/it, loss=0.742, it/s=8.346]


Iter 5: loss 0.742, lr 2.000e-04, it/s 8.346, tok/s 174.091, trained_tok 1043, peak_mem 6.882GB
Saved final weights to Qwen3-1.7B-R1/adapters.safetensors.





# Sooooo, finally! We're finished

We just created and trained our own reasoning model completely from scratch.

The only thing left to do is save the R1 model.

In [34]:
# Merge the adapters into the R1 model and save it.
# `adapter_path` now refers to the R1 adapter directory set up before SFT training.
fuse_and_save_model(
    model=r1_model,
    tokenizer=r1_tokenizer,
    save_path=adapter_path,
    de_quantize=True # Since we quantized the model on load
)

De-quantizing model
Created README.md in Qwen3-1.7B-R1


## That's it!

And we're done! You have successfully trained your own custom model. You can upload it using the Hugging Face API package. If you have any questions about MLX-LM-LoRA, find any bugs, or need help, feel free to visit my [GitHub](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)!

Cheers,
Gökdeniz