# Train a custom Chat model using MLX-LM-LoRA's SFT trainer

In this notebook, I'll demonstrate the power of MLX-LM-LoRA through a fine-tuning example.

In [None]:
%%capture
%pip install -U mlx-lm-lora ipywidgets

# Import the necessary modules

In [None]:
# The trainer and evaluations
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft, evaluate_sft

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, TextDataset

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, fuse_and_save_model, calculate_iters

# For loading the dataset
from datasets import load_dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import save_config
from mlx_lm.generate import generate
from pathlib import Path

# The optimizer
import mlx.optimizers as optim


# Set the dataset, model, and loading parameters

In [None]:
# Base model id, a name for the fine-tuned model, the directory that will hold
# the adapter files, and the dataset to train on.
model_name = "Qwen/Qwen3-1.7B-Base"
new_model_name = "Custom-Qwen3-1.7B"
adapter_path = "./tests"
dataset_name = "mlx-community/Dolci-Instruct-SFT-No-Tools-100K"

# Maximum sequence length passed to training and evaluation.
max_seq_length = 8192

# LoRA adapter configuration.
lora_config = dict(
    rank=8,  # Low-rank bottleneck size; larger = smarter but slower. Suggested 8, 16, 32, 64, 128
    dropout=0.0,
    scale=10.0,  # Multiplier for how strongly the LoRA update affects the base weights
    use_dora=False,
    num_layers=8,  # Use -1 for all layers
)

# Quantization settings used when loading the base model.
quantized_config = dict(
    bits=4,  # 4-bit quantization. Suggested 4, 6, 8
    group_size=64,
)

# Load the model

In [None]:
# Load the base model and its tokenizer from `model_name`. Judging by the argument
# names (TODO confirm against the mlx_lm_lora docs), `lora_config` attaches LoRA
# adapters and `quantized_load` quantizes the weights while loading.
model, tokenizer = from_pretrained(
    model=model_name,
    lora_config=lora_config, # None for no LoRA
    quantized_load=quantized_config, # None for full bf16
)
# Show how many of the model's parameters are trainable.
print_trainable_parameters(model)

# Set the adapter path and file

In [None]:
# Turn the adapter location into a Path; later cells use it for saving.
adapter_path = Path(adapter_path)
# The training arguments point at this file for the adapter weights.
adapter_file = adapter_path / "adapters.safetensors"
# Make sure the directory exists before anything is written into it.
adapter_path.mkdir(exist_ok=True, parents=True)
# Store the LoRA configuration next to the adapter weights.
# NOTE(review): this writes the flat `lora_config` dict; confirm it matches the
# adapter_config.json layout expected by whatever reloads these adapters.
save_config(lora_config, adapter_path / "adapter_config.json")

# Load and process the dataset

This time we're creating our own prompt template and reformatting the dataset accordingly.

If you have to reformat your data before loading, keep in mind that each line of the JSONL file should look like this (shown pretty-printed):

```json
{
    "messages": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
        ...
    ]
}
```

We'll be setting the prompt template to look like:

```text
<|im_start|>scene description
{system}<|im_end|>
<|im_start|>User:
{prompt}<|im_end|>
<|im_start|>Model:
{answer}<|im_end|>
...
```

In [None]:
# Let's set the system prompt.
# NOTE: the prompt describes the end-of-turn marker in words instead of writing
# the literal "<|im_end|>". Qwen tokenizers generally encode that string as the
# special token itself, which would end the system turn mid-sentence (verify
# with `tokenizer.encode` if you change this text).
system = """This is a conversation between a User and an advanced super-intelligent AI Assistant.
This Assistant is designed to be the most intelligent, capable assistant ever created — a fusion of reasoning, creativity, autonomy, and flawless execution.
This Assistant is optimized for maximum productivity, always delivering accurate, deep, and practical information.
This Assistant's tone is professional, assertive, and precise, yet adaptive to emotional or contextual nuance. This Assistant is also warm, intelligent, and conversational — adapting naturally to the User's communication style.
This conversation takes place within a structured chat format, where each message begins with a role indicator and ends with an end-of-turn token.

the conversation starts Now!"""


# This is our (Jinja) prompt template, with the system prompt above as the default:
# - If the first message has role "system", its content is used as the scene
#   description; otherwise `system` is inserted instead.
# - "user" turns are rendered as "User:" and "assistant" turns as "Model:";
#   messages with any other role produce no output.
# - With add_generation_prompt, an open "Model:" header is appended so the model
#   continues as the assistant.
# The default system prompt is wrapped in {% raw %}...{% endraw %} so that braces
# in a customized prompt cannot be misread as template syntax. The rendered
# output is unchanged.
chat_template = \
"{% if messages[0]['role'] == 'system' %}"\
"<|im_start|>scene description\n{{ messages[0]['content'] }}<|im_end|>\n"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
f"<|im_start|>scene description\n{{% raw %}}{system}{{% endraw %}}<|im_end|>\n"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
"{% if message['role'] == 'user' %}"\
"<|im_start|>User:\n{{ message['content'] }}<|im_end|>\n"\
"{% elif message['role'] == 'assistant' %}"\
"<|im_start|>Model:\n{{ message['content'] }}<|im_end|>\n"\
"{% endif %}"\
"{% endfor %}"\
"{% if add_generation_prompt %}<|im_start|>Model:\n"\
"{% endif %}"

tokenizer.chat_template = chat_template # With this we have set the prompt template

# A custom formatting function, written out so you can see how each sample is rendered.
def format_prompts_func(sample):
    """Render ``sample["messages"]`` into one training string.

    The conversation is formatted with the tokenizer's chat template as
    untokenized text, without a trailing generation prompt. The result is
    stored under ``sample["text"]``, and the same (mutated) sample is returned
    for ``Dataset.map``.
    """
    rendered = tokenizer.apply_chat_template(
        tokenize=False,
        add_generation_prompt=False,
        conversation=sample["messages"],
    )
    sample["text"] = rendered
    return sample

# Load the dataset once, then format each split. Every split is mapped through
# `format_prompts_func`, its raw "messages" column is dropped, and it is wrapped
# as a TextDataset that reads the rendered "text" column.
# NOTE(review): confirm the dataset names its validation split "valid" (many
# Hugging Face datasets use "validation").
dataset = load_dataset(dataset_name)


def _to_text_dataset(split):
    """Format one split of ``dataset`` and wrap it as a ``TextDataset``."""
    return TextDataset(
        dataset[split].map(format_prompts_func).remove_columns(["messages"]),
        tokenizer,
        text_key="text",
    )


train_set = _to_text_dataset("train")
valid_set = _to_text_dataset("valid")
test_set = _to_text_dataset("test")

# Let's inspect the dataset

In [None]:
# Print the first formatted test example to check that the chat template was applied.
print(test_set[0]["text"])

# Before we start training, let's test out the untrained model

In [None]:
# Build the prompt with the chat template. add_generation_prompt=True appends the
# open "<|im_start|>Model:\n" header, so the model answers as the assistant
# instead of continuing the user's turn. This `input_text` is reused after
# training for a before/after comparison.
input_text = tokenizer.apply_chat_template(
    conversation=[
        {"role": "system", "content": system},
        {"role": "user", "content": "What is your name?"},
    ],
    add_generation_prompt=True,
    tokenize=False
)

print(input_text)
print("-"*50)

# print() renders the completion's newlines instead of displaying the escaped repr.
print(generate(
    model=model,
    tokenizer=tokenizer,
    prompt=input_text,
))

# With all the setup steps done, we can now start training

In [None]:
opt = optim.AdamW(learning_rate=1e-4)  # Set the optimizer (AdamW, learning rate 1e-4)

# Training settings
# NOTE(review): steps_per_eval and steps_per_save (50) are larger than iters (40),
# so no step that is a multiple of 50 occurs during this run. Confirm whether
# train_sft still evaluates/saves at the final step, or lower these values.
args = SFTTrainingArgs(
    batch_size=1,
    iters=40,  # Or use calculate_iters() for epochs
    gradient_accumulation_steps=1,  # Increase for simulating higher batch size
    val_batches=1,
    steps_per_report=20,
    steps_per_eval=50,
    steps_per_save=50,
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,  # Adapter weights are saved to this file
    grad_checkpoint=True,  # For memory saving
    seq_step_size=1024,  # This enables the efficient long context training
)

# Start Training: the train/valid TextDatasets are wrapped in CacheDataset before
# being passed to the trainer.
train_sft(
    model=model,
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),  # Or use WandBCallback()
)

# After training, let's test the trained model out!

In [None]:
# Evaluate the fine-tuned model on the test split and print the result.
# NOTE(review): with batch_size=1 and num_batches=1 this presumably scores a
# single test example, so treat it as a quick sanity check, not a full evaluation.
eval_loss = evaluate_sft(
    model=model,
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    max_seq_length=max_seq_length
)
print(eval_loss)

In [None]:
# Generate a reply to the same `input_text` used before training, for comparison.
# print() renders the completion's newlines instead of displaying the escaped repr.
print(generate(
    model=model,
    tokenizer=tokenizer,
    prompt=input_text,
))

# Finally, let's fuse the adapters into the model and save it

In [None]:
# Merge the trained adapters into the model and save it. The fused model goes
# into its own directory named after `new_model_name` (from the configuration
# cell) rather than being mixed into the adapter directory.
fuse_and_save_model(
    model=model,
    tokenizer=tokenizer,
    save_path=Path(new_model_name),
    de_quantize=True # Since we quantized the model on load
)

## That's it!

And we're done! You have successfully trained your own custom model. You can upload it to the Hugging Face Hub with the `huggingface_hub` package. If you have any questions about MLX-LM-LoRA, find any bugs, or need help, feel free to visit my [GitHub](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)!

Cheers,
Gökdeniz