# Train a custom Chat model using MLX-LM-LoRA's SFT trainer

I'm about to demonstrate the power of MLX-LM-LoRA through a finetuning example.

In [None]:
%%capture
%pip install -U mlx-lm-lora ipywidgets

# Import the necessary modules

In [None]:
# The trainer and evaluations
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft, evaluate_sft

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, ChatDataset

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, save_pretrained_merged, calculate_iters

# For loading the dataset
from datasets import load_dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import save_config
from pathlib import Path

# The optimizer
import mlx.optimizers as optim


# Set the dataset, model, and loading params

In [None]:
# --- What to train, on which data, and where the adapter goes ---
model_name = "Qwen/Qwen3-1.7B-Base"
dataset_name = "mlx-community/Dolci-Instruct-SFT-No-Tools-100K"
adapter_path = "./tests"

# Sequence-length setting shared by the steps below.
max_seq_length = 512

# LoRA adapter configuration
lora_config = dict(
    rank=8,  # Low-rank bottleneck size (larger rank = smarter, but slower). Suggested 8, 16, 32, 64, 128
    dropout=0.0,
    scale=10.0,  # Multiplier for how hard the LoRA update hits the base weights
    use_dora=False,
    num_layers=8,  # Use -1 for all layers
)

# Quantization applied when the model is loaded
quantized_config = dict(
    bits=4,  # Use 4 bit quantization. Suggested 4, 6, 8
    group_size=64,
)

# Load the model

In [None]:
# Load the base model and tokenizer using the LoRA and quantization settings
# defined above. The third return value, `adapter_file`, is passed on to the
# training arguments later.
model, tokenizer, adapter_file = from_pretrained(
    model=model_name,
    new_adapter_path=adapter_path,
    lora_config=lora_config,
    quantized_load=quantized_config,
)

# Load and process the dataset

Since this dataset is already in the right format, we don't need to reformat it.

If you have to reformat your own data before loading it, keep in mind that it should be a JSONL file in which each entry looks like:

```json
{
    "messages": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
        ...
    ]
}
```

In [None]:
# Load the dataset once and reuse the result for every split, instead of
# calling `load_dataset` again for each of the three splits.
raw_splits = load_dataset(dataset_name)

# Wrap each split in a ChatDataset that reads conversations from the
# "messages" key. `mask_prompt=False` is passed for all three splits.
train_set = ChatDataset(
    raw_splits["train"],
    tokenizer,
    chat_key="messages",
    mask_prompt=False,
)
valid_set = ChatDataset(
    raw_splits["valid"],
    tokenizer,
    chat_key="messages",
    mask_prompt=False,
)
test_set = ChatDataset(
    raw_splits["test"],
    tokenizer,
    chat_key="messages",
    mask_prompt=False,
)

# Let's inspect the loaded dataset

In [None]:
# Quick sanity check: print the test dataset object itself, then its first
# example, to confirm the data was loaded before training starts.
print(test_set)
print(test_set[0])

# Now we're done with all the steps and can actually start the training phase

In [None]:
# Optimizer handed to the trainer below.
opt = optim.AdamW(learning_rate=1e-5)

# Training settings
args = SFTTrainingArgs(
    batch_size=1,
    iters=100,  # Or use calculate_iters() for epochs
    gradient_accumulation_steps=1,  # Increase for simulating higher batch size
    val_batches=1,
    steps_per_report=20,
    steps_per_eval=50,
    steps_per_save=50,
    # Reuse the value from the configuration cell instead of repeating the
    # literal, so changing `max_seq_length` there actually affects training.
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,
    grad_checkpoint=True,  # For memory saving
)

# Start training on the train split, validating on the valid split.
train_sft(
    model=model,
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),  # Or use WandBCallback()
)

# After training, let's test the trained model out!

In [None]:
# Evaluate the trained model on a single batch of the held-out test split and
# print the resulting loss.
eval_loss = evaluate_sft(
    model=model,
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    # Use the configured sequence length rather than a repeated literal so
    # evaluation stays in sync with the configuration cell.
    max_seq_length=max_seq_length,
)
print(eval_loss)

# Finally let's merge and save the final model

In [None]:
# Merge the trained adapter into the model and write the result, together
# with the tokenizer, to `adapter_path`.
save_pretrained_merged(
    model=model,
    tokenizer=tokenizer,
    save_path=adapter_path,
    de_quantize=True,  # Since we quantized the model on load
)

## That's it!

And we're done! You successfully trained your own custom model. You can upload it using the Hugging Face Hub API package (`huggingface_hub`). If you have any questions about MLX-LM-LoRA, find any bugs, or need help, feel free to visit my [GitHub](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)!

Cheers,
Gökdeniz