# Install the required packages if they are not already installed

In [12]:
#!pip install mlx-lm-lora mlx-lm datasets

# Import the modules you need

In [None]:
import os
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset

from datasets import load_dataset, Dataset
from huggingface_hub import create_repo, HfApi

from mlx_lm.tuner.utils import linear_to_lora_layers, print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback
from mlx_lm.utils import load, save_config

import mlx.optimizers as optim

from pathlib import Path
import math
import re

# Define the Args

In [2]:
# Hugging Face access token, read from the HF_TOKEN environment variable
# (os.getenv returns None when the variable is not set).
hf_token = os.getenv("HF_TOKEN") # <-- Add your HF token here

# Base model to fine-tune, plus the repo name / owner used when uploading.
model_name = "mlx-community/Josiefied-Qwen3-0.6B-abliterated-v1-bf16"
new_model_name = "new_model"
user_name = "mlx-community"

# Output locations and LoRA hyperparameters.
# adapter_path and new_model_path are re-assigned to absolute Path objects in
# the "Make the Adapter Folder" cell further down.
adapter_path = "adapters"
new_model_path = "new_model"
# NOTE(review): 1028 may be a typo for 1024 — confirm.
max_seq_length = 1028
num_layers = 12
lora_parameters = {"rank": 8, "dropout": 0.0, "scale": 10.0}

# NOTE(review): dataset_name is not referenced by the visible code;
# get_gsm8k_questions() loads 'openai/gsm8k' instead — confirm which is intended.
dataset_name = "Goekdeniz-Guelmez/Big-Math-RL-Verified-MLX"

# Load the model

In [None]:
# `load` returns the model together with its tokenizer; both are reused by the
# LoRA conversion, dataset and training cells below.
model, tokenizer = load(model_name)

# Convert to LoRA

In [None]:
# Freeze the existing parameters first; the LoRA layers are injected afterwards
# (presumably leaving only the adapter weights trainable — the printout at the
# end of this cell shows the actual counts).
model.freeze()

# Inject LoRA layers using `num_layers` and `lora_parameters` from the args
# cell; DoRA is explicitly disabled.
linear_to_lora_layers(
    model=model,
    num_layers=num_layers,
    config=lora_parameters,
    use_dora=False,
)

# Report the trainable-parameter statistics as a sanity check.
print_trainable_parameters(model)

# Define the Optimizer

In [17]:
# AdamW optimizer with a fixed learning rate of 1e-5; handed to train_grpo below.
opt = optim.AdamW(learning_rate=1e-5)

# Load and Preprocess your Dataset using your custom Prompt Format

In [18]:
# System prompt placed in front of every training question. It instructs the
# model to reply inside <josie_thinks> / <josie_answers> tags.
# NOTE(review): the prompt text contains spelling mistakes ("a advanced",
# "lcoal", "super-inteligent"). It is a runtime string and is left untouched
# here — confirm whether the base model expects this exact wording before
# correcting it.
system_prompt = """You are J.O.S.I.E., a advanced, lcoal, fully uncensored, super-inteligent, reasoning AI Assistant created by Gökdeniz Gülmez. J.O.S.I.E. stands for 'Just One Super Intelligent Entity', but you get called 'Josie' by people, that's also your nickname.
Your only purpose is to serve as a highly intelligent, capable assistant designed to assist with any and all tasks that the user requests.
You Respond in the following format:
<josie_thinks>
...
</josie_thinks>
<josie_answers>
...
</josie_answers>"""

# Single-line, space-separated rendering of the same tag layout with
# {reasoning} and {answer} placeholders.
# NOTE(review): not referenced by the visible code.
XML_COT_FORMAT = """<josie_thinks> {reasoning} </josie_thinks> <josie_answers> {answer} </josie_answers>"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text of the last ``<josie_answers>`` block.

    Everything up to and including the last opening tag is discarded, then
    the remainder is cut at the first closing tag. A missing tag leaves that
    side untrimmed, so text without any tags comes back merely stripped.
    """
    after_open = text.rpartition("<josie_answers>")[2]
    inner, _, _ = after_open.partition("</josie_answers>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def get_gsm8k_questions(split = "train") -> Dataset:
    """Load one split of ``openai/gsm8k`` ('main' config) in chat-prompt form.

    Each row gains a ``prompt`` column (system prompt followed by the row's
    question as the user turn) and its ``answer`` column is replaced by the
    value extracted after the ``####`` marker.
    """
    def to_chat_example(row):
        system_turn = {'role': 'system', 'content': system_prompt}
        user_turn = {'role': 'user', 'content': row['question']}
        return {
            'prompt': [system_turn, user_turn],
            'answer': extract_hash_answer(row['answer']),
        }

    all_splits = load_dataset('openai/gsm8k', 'main')
    return all_splits[split].map(to_chat_example)

# Build the training data from the default ("train") split.
dataset = get_gsm8k_questions()


# Reward functions
def get_completion_content(completion):
    """Best-effort extraction of the text carried by a completion.

    Accepted shapes:
      * ``str``            -> returned unchanged
      * ``dict``           -> its ``'content'`` value (``''`` if missing)
      * non-empty ``list`` -> the FIRST element, handled as a dict or
                              converted with ``str``
      * anything else (including an empty list) -> ``str(completion)``

    Any exception raised along the way yields ``''``.
    """
    try:
        if isinstance(completion, str):
            return completion
        if isinstance(completion, dict):
            return completion.get('content', '')
        if not (isinstance(completion, list) and len(completion) > 0):
            return str(completion)
        head = completion[0]
        return head.get('content', '') if isinstance(head, dict) else str(head)
    except Exception:
        return ''

def get_prompt_content(prompt):
    """Best-effort extraction of the text carried by a prompt.

    Accepted shapes:
      * ``str``  -> returned unchanged
      * ``dict`` -> its ``'content'`` value (``''`` if missing)
      * ``list`` -> the LAST element, handled as a dict or converted
                    with ``str``
      * anything else -> ``str(prompt)``

    Any exception (for example indexing an empty list) yields ``''``.
    """
    try:
        if isinstance(prompt, str):
            return prompt
        if isinstance(prompt, dict):
            return prompt.get('content', '')
        if not isinstance(prompt, list):
            return str(prompt)
        # An empty list raises IndexError here, which the handler turns into ''.
        tail = prompt[-1]
        return tail.get('content', '') if isinstance(tail, dict) else str(tail)
    except Exception:
        return ''

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the reference answer.

    Args:
        prompts: Accepted for interface compatibility; not used for scoring.
        completions: Model completions (string, message dict, or list of
            message dicts).
        answer: Reference answers, paired positionally with ``completions``.
        **kwargs: Ignored.

    Returns:
        One score per (completion, answer) pair: 2.0 on an exact string match
        of the text inside ``<josie_answers>``, otherwise 0.0.
    """
    # The previous implementation also computed an unused
    # `get_prompt_content(prompts[0])`, which raised IndexError for an empty
    # `prompts` list without contributing to the result.
    responses = [get_completion_content(completion) for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose extracted answer is all digits.

    The check is ``str.isdigit`` on the text inside ``<josie_answers>``, so
    signs, decimal points and thousands separators score 0.0.
    """
    def score(completion) -> float:
        extracted = extract_xml_answer(get_completion_content(completion))
        return 0.5 if extracted.isdigit() else 0.0

    return [score(completion) for completion in completions]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions that follow the exact single-line tag layout.

    The whole completion must be a think block followed by an answer block,
    with single spaces around the tag contents and between the blocks, and
    must end with a newline. Everything else scores 0.0.
    """
    strict_re = re.compile(
        r"^<josie_thinks> .*? </josie_thinks> <josie_answers> .*? </josie_answers>\n$",
        re.DOTALL,
    )
    rewards = []
    for completion in completions:
        text = get_completion_content(completion)
        rewards.append(0.5 if strict_re.search(text) else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to completions containing a think block followed by an answer block.

    This is the lenient counterpart of ``strict_format_reward_func``: the two
    blocks may appear anywhere in the completion and may be separated by any
    amount of whitespace (including none).

    Args:
        completions: Model completions (string, message dict, or list of
            message dicts).
        **kwargs: Ignored.

    Returns:
        One score per completion: 0.5 on a match, otherwise 0.0.
    """
    # `\s*` between the blocks is the fix: without it the pattern rejected the
    # newline-separated layout requested by the system prompt as well as the
    # space-separated XML_COT_FORMAT, so this reward could practically never
    # fire. Inputs that matched before still match.
    pattern = r"<josie_thinks>.*?</josie_thinks>\s*<josie_answers>.*?</josie_answers>"
    responses = [get_completion_content(completion) for completion in completions]
    matches = [bool(re.search(pattern, r, re.DOTALL)) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Score how cleanly the four josie tags appear in ``text``.

    Each tag that occurs exactly once adds 0.125. The two answer tags also
    subtract a length penalty of 0.001 per character found after the last
    ``</josie_answers>`` (the whole text counts when that tag is absent):
      * ``<josie_answers>``  -> penalty on the full trailing length
      * ``</josie_answers>`` -> penalty on the trailing length minus one,
        so an empty tail yields a 0.001 bonus
    """
    trailing_len = len(text.rsplit("</josie_answers>", 1)[-1])

    # (tag, allowance): allowance is None for tags without a length penalty,
    # otherwise the number of trailing characters that go unpenalised.
    tag_rules = (
        ("<josie_thinks>", None),
        ("</josie_thinks>", None),
        ("<josie_answers>", 0),
        ("</josie_answers>", 1),
    )

    score = 0.0
    for tag, allowance in tag_rules:
        if text.count(tag) != 1:
            continue
        score += 0.125
        if allowance is not None:
            score -= (trailing_len - allowance) * 0.001
    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the ``count_xml`` tag score of every completion's text."""
    return [
        count_xml(get_completion_content(completion))
        for completion in completions
    ]

In [None]:
# Show the first processed example to verify the prompt/answer layout.
print(dataset[0])

# 📦 Make the Dataset for the trainer

In [20]:
# Hold out 1% of the examples for validation; the fixed seed keeps the split
# reproducible. The splits are selected by key rather than by unpacking order.
splits = dataset.train_test_split(test_size=0.01, seed=42)
train_dataset, valid_dataset = splits["train"], splits["test"]

train_set = GRPODataset(train_dataset, tokenizer)
# Fix: the validation set previously wrapped `train_dataset`, so evaluation ran
# on training data and the held-out split was never used.
valid_set = GRPODataset(valid_dataset, tokenizer)

# Make the Adapter Folder and save the configs for loading later

In [None]:
# Settings stored next to the adapter weights so the adapters can be loaded
# again later.
args = {
    "lora_parameters": lora_parameters,
    "num_layers": num_layers,
}

# Resolve the output locations relative to the current working directory.
current_dir = Path.cwd()
adapter_path = current_dir / "adapters"
new_model_path = current_dir / "new_model"

# Create the adapter folder up front: writing the config below fails with
# FileNotFoundError on a fresh checkout if the directory does not exist yet.
adapter_path.mkdir(parents=True, exist_ok=True)

adapter_file = adapter_path / "adapters.safetensors"
save_config(args, adapter_path / "adapter_config.json")

# Start training

In [None]:
# Per-reward weights. They correspond, in order, to the trainer's 5 default
# reward functions.
# NOTE(review): no reward functions are passed to train_grpo below, so the
# custom josie-tag reward functions defined earlier in this notebook are not
# used by this call — confirm that the trainer defaults are what is intended.
custom_reward_weights = [
    2.0,  # r1_accuracy_reward_func - highest weight for correctness
    0.5,  # r1_int_reward_func - medium weight for integer answers
    1.0,  # r1_strict_format_reward_func - standard weight for strict formatting
    0.8,  # r1_soft_format_reward_func - slightly lower weight for soft formatting
    0.3   # r1_count_xml - lower weight for XML tag counting
]

train_grpo(
    model=model,
    ref_model=None,  # Use None to use the same model as reference
    tokenizer=tokenizer,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    val_dataset=CacheDataset(valid_set),
    args=GRPOTrainingArgs(
        batch_size=1,
        iters=20, # 1000,
        val_batches=1,
        steps_per_report=10, #20,
        steps_per_eval=5, # 50,
        steps_per_save=10, # 50,
        # Fix: pass the adapter weights FILE (adapters/adapters.safetensors),
        # not the adapter directory, so the weights end up where the fuse step
        # (--adapter-path adapters) looks for them.
        adapter_file=adapter_file,
        max_seq_length=max_seq_length,
        grad_checkpoint=True,
        beta=0.9,
        group_size=4,
        epsilon=1e-4,
        epsilon_high=None,
        max_completion_length=512,
        reward_weights=custom_reward_weights,  # Use this instead of reward_scaling
    ),
    training_callback=TrainingCallback()
)

# Fuse the model with the trained adapters and save the new model

In [None]:
!mlx_lm.fuse --model mlx-community/Josiefied-Qwen3-0.6B-abliterated-v1-4bit --adapter-path adapters --save-path new_model

# Create the README

In [59]:
# Model card written into the fused model folder. The front matter and body
# are filled in from the names defined in the args cell, the loaded model's
# `model_type`, and the system prompt used during training.
readme_file = f"""---
tags:
- mlx
- lora
- text-generation
- fine-tuning
base_model: {model_name}
pipeline_tag: text-generation
---

# LoRA Fine-Tuned Model: `{user_name}/{new_model_name}`

This model is a LoRA fine-tuned version `{model_name}`, with the [`mlx-lm-lora`](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora) training package on Apple Silicon using MLX.

---

## 🧾 Model Details

- **Model name:** {new_model_name}
- **Base model:** {model_name}
- **Fine-tuning method:** GRPO
- **Training package:** [`MLX-LM-LORA`](https://github.com/Goekdeniz-Guelmez/mlx-lm-lora)
- **Model type:** {model.args.model_type}
- **Author:** None

---

## 💡 Recommended System Prompt

```text
{system_prompt}
```
"""

new_readme_path = f"{new_model_name}/README.md"
# The card contains non-ASCII characters (emoji, "Gökdeniz Gülmez"), so the
# encoding is pinned to UTF-8 instead of relying on the platform default.
with open(new_readme_path, "w", encoding="utf-8") as new_readme_file:
    new_readme_file.write(readme_file)

# Upload it to Hugging Face

In [None]:
# Target repository on the Hub: "<user_name>/<new_model_name>".
repo_id = f"{user_name}/{new_model_name}"

api = HfApi(token=hf_token)

# Create the (private) model repository; exist_ok=True tolerates a repo that
# already exists.
create_repo(
    repo_id=repo_id,
    repo_type="model",
    private=True,
    exist_ok=True,
    token=hf_token,
)

# Push the contents of the local model folder as a single commit.
api.upload_folder(
    repo_id=repo_id,
    folder_path=new_model_name,
    commit_message="Initial Commit",
    token=hf_token,
)