# Train a custom Chat model using MLX-LM-LoRA's DPO trainer

I'm about to demonstrate the power of MLX-LM-LoRA through a preference optimization example.

In [None]:
%%capture
%pip install -U mlx-lm-lora ipywidgets

In [None]:
# The trainer and evaluations
from mlx_lm_lora.trainer.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo

# The Datasets
from mlx_lm_lora.trainer.datasets import CacheDataset, PreferenceDataset

# For loading/saving the model and calculating the steps
from mlx_lm_lora.utils import from_pretrained, fuse_and_save_model, calculate_iters

# For loading the dataset
from datasets import load_dataset

# Other needed stuff
from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import save_config
from pathlib import Path

# The optimizer
import mlx.optimizers as optim


# Set the dataset, model, and loading params

In [None]:
# Hugging Face repos of the model to train and of the reference model.
model_name = "Qwen/Qwen3-1.7B"
ref_model_name = "Qwen/Qwen3-1.7B"

# Output directory for the adapter weights and the adapter config.
adapter_path = "./tests"

# Preference dataset providing "prompt", "chosen" and "rejected" columns.
dataset_name = "mlx-community/Josiefied-Qwen3-dpo-v1-flat"

max_seq_length = 512

# LoRA adapter configuration
lora_config = dict(
    rank=8,  # Low-rank bottleneck size (larger rank = smarter, but slower). Suggested 8, 16, 32, 64, 128
    dropout=0.0,
    scale=10.0,  # Multiplier for how hard the LoRA update hits the base weights
    use_dora=False,
    num_layers=8,  # Use -1 for all layers
)

# Quantization settings applied when loading the model to train
quantized_config = dict(
    bits=4,  # Use 4 bit quantization. Suggested 4, 6, 8
    group_size=64,
)

In [None]:
# Reference model: loaded without quantization (quantized_load=None), since
# the reference should be "smarter" than the student model. The tokenizer
# returned with it is discarded.
ref_model, _ = from_pretrained(model=ref_model_name, quantized_load=None)

# Student model: loaded with the quantization settings from quantized_config
# and the LoRA settings from lora_config. Its tokenizer is kept for later cells.
model, tokenizer = from_pretrained(
    model=model_name,
    quantized_load=quantized_config,
    lora_config=lora_config,
)

# Report the trainable-parameter count of the student model.
print_trainable_parameters(model)

In [None]:
# Turn the configured output location into a Path and make sure it exists.
adapter_path = Path(adapter_path)
adapter_path.mkdir(exist_ok=True, parents=True)

# Store the LoRA settings in the output directory.
save_config(lora_config, adapter_path / "adapter_config.json")

# File the trainer is pointed at for the adapter weights.
adapter_file = adapter_path / "adapters.safetensors"

# Load and process the dataset

We have to format the dataset before feeding it to the model during training.

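If you need to reformat your own data before loading it, keep in mind that it should be a JSONL file in which every line is a JSON object with the following keys (shown here pretty-printed for readability):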

```json
{
    "prompt": "...",
    "chosen": "...",
    "rejected": "..."
}
```

In [None]:
def format(sample):
    """Render both completions of a preference sample as chat-template text.

    The "chosen" and "rejected" entries of ``sample`` are each replaced, in
    place, by the untokenized string that ``tokenizer.apply_chat_template``
    produces for a two-message conversation: the sample's "prompt" as the
    user message, followed by that completion as the assistant message.

    Returns the same ``sample`` object that was passed in.
    """
    prompt = sample["prompt"]
    # Read both completions up front so nothing is overwritten before all
    # required keys have been looked up.
    completions = {key: sample[key] for key in ("chosen", "rejected")}

    for key, answer in completions.items():
        sample[key] = tokenizer.apply_chat_template(
            conversation=[
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ],
            add_generation_prompt=False,
            enable_thinking=False,
            tokenize=False,
        )

    return sample

# Take the "train" split and carve three disjoint slices out of its first
# 500 rows, applying the chat-template formatting to each slice.
dataset = load_dataset(dataset_name)["train"]
train_dataset = dataset.select(range(400)).map(format)  # rows 0-399: 400 samples for training
valid_dataset = dataset.select(range(400, 460)).map(format)  # rows 400-459: 60 samples for validation
test_dataset = dataset.select(range(460, 500)).map(format)  # rows 460-499: 40 samples for testing at the end

# Let's inspect the loaded dataset

In [None]:
# Show the first training sample: a banner line, then the formatted text,
# first for the preferred completion and then for the rejected one.
print(f"{'#' * 50} Chosen {'#' * 100}", train_dataset[0]["chosen"], sep="\n")
print(f"{'#' * 50} Rejected {'#' * 100}", train_dataset[0]["rejected"], sep="\n")

In [None]:
# Wrap each formatted split in a PreferenceDataset, reading the preferred
# text from the "chosen" column and the dispreferred text from "rejected".
train_set, valid_set, test_set = (
    PreferenceDataset(split, tokenizer, chosen_key="chosen", rejected_key="rejected")
    for split in (train_dataset, valid_dataset, test_dataset)
)

# Now we're done with all the steps and can actually start the training phase

In [None]:
opt = optim.Muon(learning_rate=1e-4)  # Set the optimizer

args = DPOTrainingArgs(
    batch_size=1,
    # Number of steps for one pass over the training set at this batch size.
    iters=calculate_iters(train_set, batch_size=1, epochs=1),
    gradient_accumulation_steps=1,
    val_batches=1,
    steps_per_report=1,
    steps_per_eval=10,
    steps_per_save=20,
    # Reuse the notebook-wide setting instead of repeating the literal.
    max_seq_length=max_seq_length,
    adapter_file=adapter_file,
    grad_checkpoint=True,
    beta=0.1,
    loss_type="sigmoid",  # Choose one: "sigmoid", "hinge", "ipo", "dpop"
    delta=0.01,
    # Point at the reference model's repo, not the policy model's.
    reference_model_path=ref_model_name,
)

train_dpo(
    model=model,
    # Use the separately loaded reference model. Module.freeze() works in
    # place and returns the module itself, so passing `model.freeze()` here
    # would freeze the LoRA parameters that are supposed to be trained AND
    # make the policy its own reference.
    ref_model=ref_model.freeze(),
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    training_callback=TrainingCallback(),
    loss_type="sigmoid",  # Choose one: "sigmoid", "hinge", "ipo", "dpop"
)

In [None]:
# Final evaluation of the trained model on the held-out test split.
evaluate_dpo(
    model=model,
    # Compare against the separately loaded reference model. Passing
    # `model.freeze()` would hand the trained model in as its own reference,
    # because Module.freeze() returns the very module it was called on.
    ref_model=ref_model.freeze(),
    dataset=CacheDataset(test_set),
    batch_size=1,
    num_batches=1,
    beta=0.1,
    delta=0.01,
    # Same sequence limit as configured for the rest of the notebook.
    max_seq_length=max_seq_length,
    loss_type="sigmoid",
)