# 🚀 GRPO Training Tutorial: Teaching Models to Reason with MLX-LM-LoRA

## Overview

This notebook demonstrates **Group Relative Policy Optimization (GRPO)** for training reasoning models on **Apple Silicon** using **MLX-LM-LoRA**. GRPO is a reinforcement learning technique that improves model outputs using reward signals from multiple reward functions.

### What You'll Learn

1. **Two-stage training approach**: Cold-start SFT → GRPO fine-tuning
2. **Efficient training on Apple Silicon** with 8-bit quantization and LoRA
3. **Custom reward functions** for structured reasoning outputs
4. **Long context handling** (up to 4096 tokens)

### Training Pipeline

**Stage 1: Cold Start (SFT)** → Teach the model the basic reasoning format  
**Stage 2: GRPO** → Optimize the model using reward-based feedback to improve reasoning quality

### Requirements

- Apple Silicon Mac (M1/M2/M3/M4)
- MLX-LM-LoRA installed
- Sufficient RAM (16GB+ recommended)

Let's begin! 🎯

---

## Step 1: Import Required Libraries

First, we import all necessary libraries from MLX-LM-LoRA for training, generation, and utilities.

In [None]:
from mlx_lm_lora.utils import save_pretrained_merged, from_pretrained, calculate_iters, push_to_hub
from mlx_lm_lora.trainer.grpo_trainer import GRPOTrainingArgs, train_grpo, evaluate_grpo
from mlx_lm_lora.trainer.sft_trainer import SFTTrainingArgs, train_sft, evaluate_sft
from mlx_lm_lora.trainer.datasets import CacheDataset, GRPODataset, TextDataset

from datasets import load_dataset

from mlx_lm.tuner.utils import print_trainable_parameters
from mlx_lm.tuner.callbacks import TrainingCallback, WandBCallback
from mlx_lm.generate import generate, make_sampler

import mlx.optimizers as optim

import re

---

## Step 2: Configure Model and Training Parameters

Here we define all configuration parameters for our training pipeline:

### Model Configuration
- **Base model**: A pre-trained instruction-tuned model
- **Cold start model**: Intermediate model trained with supervised learning
- **Final model**: GRPO-optimized model with improved reasoning

### Key Parameters

**Sequence Lengths:**
- `cold_start_max_seq_length`: 2048 tokens (for initial SFT training)
- `zero_max_seq_length`: 4096 tokens (for GRPO training with longer reasoning)

**LoRA Configuration:**
- `rank`: 8 (controls adapter size)
- `scale`: 10.0 (LoRA scaling factor)
- `num_layers`: 8 (number of layers to adapt)

**Quantization:**
- `bits`: 8 (8-bit quantization for memory efficiency)
- `group_size`: 128 (quantization granularity)

**Reasoning Format:**
We use special tokens to structure the model's reasoning:
- `<think>...</think>`: Contains the model's working/reasoning
- `<answer>...</answer>`: Contains the final answer

This format helps the model learn to separate reasoning from conclusions!
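
To make the tag convention concrete, here is a minimal sketch of how a response in this format can be split into its reasoning and its answer. The `split_response` helper is hypothetical (it is not part of MLX-LM-LoRA) and uses only the standard library.

```python
import re

# Tag names mirror the configuration used in this tutorial.
REASONING_START, REASONING_END = "<think>", "</think>"
SOLUTION_START, SOLUTION_END = "<answer>", "</answer>"

# Hypothetical helper pattern: reasoning block, optional whitespace, answer block.
_PATTERN = re.compile(
    rf"{re.escape(REASONING_START)}(.*?){re.escape(REASONING_END)}\s*"
    rf"{re.escape(SOLUTION_START)}(.*?){re.escape(SOLUTION_END)}",
    flags=re.DOTALL,
)

def split_response(response):
    """Return (reasoning, answer), or None if the tags are missing or out of order."""
    match = _PATTERN.search(response)
    if match is None:
        return None
    return match.group(1).strip(), match.group(2).strip()

example = "<think>\n2 + 2 = 4, then 4 * 3 = 12.\n</think>\n<answer>12</answer>"
print(split_response(example))
```

Keeping the answer in its own tag is what lets the reward functions later in the notebook check correctness without having to parse the free-form reasoning.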

In [None]:
model_name = "Goekdeniz-Guelmez/Llama-3.2-1B-Instruct-gabliterated"
new_cold_start_model_name = "Llama-3.2-1B-Gabliterated-Zero-Cold_Start"
new_zero_model_name = "Llama-3.2-1B-Gabliterated-Zero"
user_name = "Goekdeniz-Guelmez"

cold_start_max_seq_length = 2048
zero_max_seq_length = 4096

cold_start_adapter_path = f"./{new_cold_start_model_name}"
zero_adapter_path = f"./{new_zero_model_name}"

grpo_dataset_name = "mlx-community/gsm8k"
lora_config = {
    "rank": 8,
    "dropout": 0.0,
    "scale": 10.0,
    "use_dora": False,
    "num_layers": 8
}
quantized_load={
    "bits": 8,
    "group_size": 128
}

reasoning_start = "<think>"
reasoning_end   = "</think>"
solution_start = "<answer>"
solution_end = "</answer>"

---

## Step 3: Load the Base Model for Cold Start Training

### What is Cold Start Training?

Before we can use GRPO (which requires the model to generate responses), we need to teach the model the **basic reasoning format**. This is called "cold start" training and uses standard **Supervised Fine-Tuning (SFT)**.

### Why Cold Start?

- **Problem**: The base model doesn't know how to use `<think>` and `<answer>` tags
- **Solution**: First train it with examples that have the correct format
- **Result**: Model learns the structure before we optimize with GRPO

### Model Loading

We load the model with:
- **LoRA adapters**: Only ~1-2% of parameters are trainable (memory efficient!)
- **8-bit quantization**: Reduces weight memory by roughly 50% compared with 16-bit weights
- **New adapter path**: Creates a fresh adapter for cold start training

The `print_trainable_parameters()` function shows how few parameters we're actually training!
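
As a rough illustration of why LoRA trains so few parameters, the sketch below counts the parameters one adapter adds to a single weight matrix. The layer dimensions are illustrative assumptions, not values read from the model; the overall trainable fraction also depends on how many layers and which projections are adapted.

```python
def lora_param_count(d_in, d_out, rank):
    """Parameters added by one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)

# Illustrative dimensions only; real layer shapes depend on the model.
d_in, d_out, rank = 2048, 2048, 8

full = d_in * d_out
adapter = lora_param_count(d_in, d_out, rank)
print(f"full weight matrix: {full:,} parameters")
print(f"LoRA adapter:       {adapter:,} parameters ({adapter / full:.2%} of the matrix)")
```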

In [None]:
cold_start_model, cold_start_tokenizer, cold_start_adapter_file = from_pretrained(
    model=model_name,
    lora_config=lora_config,
    quantized_load=quantized_load,
    new_adapter_path=cold_start_adapter_path
)
print_trainable_parameters(cold_start_model)

---

## Step 4: Prepare the Cold Start Dataset

### Dataset Formatting

We need to format the dataset to teach the model our reasoning structure. Each example contains:
- **System prompt**: Instructions on how to structure the response
- **User message**: The problem to solve
- **Assistant response**: The correctly formatted answer with reasoning

### Format Structure

```
<think>
[Step-by-step reasoning and working]
</think>
<answer>[Final answer]</answer>
```

### Dataset Source

We use the **GSM8K dataset** (Grade School Math problems) which includes:
- `prompt`: The math problem
- `reasoning`: Step-by-step solution
- `answer`: Final numerical answer

### Why Take Only 1000 Examples?

For cold start, we don't need the full dataset - just enough examples to teach the format. This:
- Reduces training time significantly
- Prevents overfitting to the cold start distribution
- Leaves the model flexible for GRPO optimization

The `TextDataset` wrapper prepares the data for efficient SFT training.

In [None]:
system = f"You are given a problem. Think about the problem and provide your working out. Place it between {reasoning_start} and {reasoning_end}. Then, provide your solution between {solution_start} {solution_end}."

def format_cold_start(sample):
    raw_answer = f"{reasoning_start}\n{sample['reasoning']}\n{reasoning_end}\n{solution_start}{sample['answer']}{solution_end}"

    sample["text"] = cold_start_tokenizer.apply_chat_template(
        conversation=[
            {"role": "system", "content": system},
            {"role": "user", "content": sample["prompt"]},
            {"role": "assistant", "content": raw_answer},
        ],
        add_generation_prompt=False,
        tokenize=False
    )
    return sample

cold_start_train_dataset = load_dataset(grpo_dataset_name)["test"].take(1000).map(format_cold_start)

cold_start_train_set = TextDataset(
    cold_start_train_dataset,
    tokenizer=cold_start_tokenizer,
    text_key="text",
)

In [None]:
print(cold_start_train_dataset[0]["text"])

---

## Step 5: Train the Cold Start Model with Supervised Fine-Tuning

### Training Configuration

**Memory Optimization:**
- `batch_size=1`: Minimal memory usage (increase if you have more RAM)
- `gradient_accumulation_steps=32`: Simulates batch size of 32 by accumulating gradients
- `grad_checkpoint=True`: Trades compute for memory (gradient checkpointing)

**Training Duration:**
- `epochs=1`: We only need one pass through the data to learn the format
- `iters`: Automatically calculated based on dataset size

**Monitoring:**
- `steps_per_report=100`: Log training metrics every 100 steps
- `steps_per_eval=200`: Run validation every 200 steps
- `steps_per_save=400`: Save checkpoint every 400 steps

### What Happens During Training?

The model learns to:
1. Recognize when to use reasoning tags
2. Generate step-by-step working inside `<think>` tags
3. Provide final answers inside `<answer>` tags

### Training Callbacks

You can use:
- `TrainingCallback()`: Basic console logging (default)
- `WandBCallback()`: Advanced logging with Weights & Biases (commented out)

**Training will take 10-30 minutes depending on your Mac's performance.**

After training, we save the merged model (base model + adapters combined).
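
The arithmetic behind the batch settings can be sketched as follows. `iters_for` is a hypothetical stand-in written for this example, not the library's `calculate_iters`, and the sketch assumes each iteration processes one micro-batch with the optimizer stepping once every `gradient_accumulation_steps` iterations.

```python
import math

def iters_for(num_examples, batch_size, epochs):
    """Hypothetical stand-in for an iteration calculator: batches per epoch times epochs."""
    return math.ceil(num_examples / batch_size) * epochs

num_examples = 1000
batch_size = 1
grad_accum = 32

iters = iters_for(num_examples, batch_size, epochs=1)
effective_batch = batch_size * grad_accum   # examples contributing to each update
optimizer_updates = iters // grad_accum     # full accumulation windows in the run

print(iters, effective_batch, optimizer_updates)
```

Under these assumptions, one epoch over 1000 examples yields only a few dozen optimizer updates, which is worth keeping in mind when choosing the reporting and checkpoint intervals.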

In [None]:
opt = optim.AdamW(learning_rate=8e-5)

# Training arguments. Adjust these based on your dataset size, GPU capacity, and how long you want to train.
args = SFTTrainingArgs(
    batch_size=1, # Use batch size of 1 to save RAM, increase if you have more GPU memory
    iters=calculate_iters(cold_start_train_set, batch_size=1, epochs=1), # Only train for 1 epoch since the dataset is small, increase if you want to train longer
    gradient_accumulation_steps=32, # Accumulate gradients over 32 steps to simulate a larger batch size and save RAM. Adjust based on your GPU capacity.
    val_batches=1, # Only use 1 batch for validation to speed it up, since the dataset is small. Remove or increase for better evaluation.
    steps_per_report=100, # Log training progress every 100 steps
    steps_per_eval=200, # Evaluate every 200 steps
    steps_per_save=400, # Save the model every 400 steps
    max_seq_length=cold_start_max_seq_length,
    adapter_file=cold_start_adapter_file,
    grad_checkpoint=True, # Use gradient checkpointing to save RAM at the cost of slightly slower training
)

# Start training
train_sft(
    model=cold_start_model,
    args=args,
    optimizer=opt,
    train_dataset=CacheDataset(cold_start_train_set),
    # training_callback=WandBCallback(
    #     project_name=f"{new_cold_start_model_name}-finetuning",
    #     log_dir=cold_start_adapter_path,
    #     wrapped_callback=TrainingCallback(),
    #     config=None
    # ),
    training_callback=TrainingCallback(), # Logs training progress to the console. To log to Weights & Biases instead, comment out this line and uncomment the WandBCallback block above.
)

In [None]:
save_pretrained_merged(
    model=cold_start_model,
    tokenizer=cold_start_tokenizer,
    save_path=cold_start_adapter_path
)

---

## Step 6: Transition to GRPO Training

### Cleaning Up Memory

First, we delete the cold start resources to free up memory. This is **crucial on Apple Silicon** to avoid running out of RAM during GRPO training.

### Loading Models for GRPO

We need **TWO models** for GRPO:

1. **Policy Model** (model to optimize)
   - Loaded from the cold start checkpoint
   - Has fresh LoRA adapters to train
   - Will be updated during GRPO

2. **Reference Model** (baseline for KL divergence)
   - Loaded from the **original base model** (not cold start!)
   - Frozen (no training)
   - Used to compute KL penalty to prevent the model from drifting too far

### Why Use Base Model as Reference?

Using the **original base model** (instead of cold start) as reference:
- ✅ Provides better KL divergence regularization
- ✅ Prevents the model from overfitting to GRPO rewards
- ✅ Maintains general language capabilities
- ✅ Results in more robust reasoning

### Longer Context

Note we now use `zero_max_seq_length=4096` (double the cold start length) because:
- Reasoning chains can be much longer
- GRPO generates completions for multiple candidates
- We need space for detailed step-by-step thinking
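
To show what the reference model is used for, here is a minimal sketch of a per-token KL penalty between policy and reference. It uses the estimator given in the GRPO paper; whether MLX-LM-LoRA computes the penalty in exactly this way is an assumption, and the log-probabilities are made-up numbers.

```python
import math

def kl_estimate(policy_logprob, ref_logprob):
    """Per-token estimator from the GRPO paper: r - log(r) - 1, with r = pi_ref / pi_theta."""
    log_ratio = ref_logprob - policy_logprob
    return math.exp(log_ratio) - log_ratio - 1.0

# Toy per-token log-probabilities for one completion (made-up numbers).
policy_logprobs = [-0.20, -1.10, -0.05]
ref_logprobs = [-0.25, -0.90, -0.05]

per_token = [kl_estimate(p, r) for p, r in zip(policy_logprobs, ref_logprobs)]
beta = 0.1  # same role as the `beta` training argument
penalty = beta * sum(per_token) / len(per_token)
print(per_token, penalty)
```

The estimate is zero where both models assign the same probability and grows as they disagree, so a reference that stays fixed acts as an anchor for the policy.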

In [None]:
del cold_start_model
del cold_start_tokenizer
del cold_start_adapter_file
del cold_start_train_set
del cold_start_train_dataset
del opt

In [None]:
model, tokenizer, adapter_file = from_pretrained(
    model=cold_start_adapter_path,
    lora_config=lora_config,
    quantized_load=quantized_load,
    new_adapter_path=zero_adapter_path
)

# Use the base model as reference instead of the cold-start model
# This provides better KL divergence regularization during GRPO
ref_model, _, _ = from_pretrained(
    model=model_name,
    quantized_load=quantized_load,
)
print_trainable_parameters(model)

---

## Step 7: Define Custom Reward Functions

### What are Reward Functions?

GRPO optimizes the model by **rewarding good behaviors**. We define multiple reward functions to evaluate different aspects of the model's responses.

### Our Four Reward Functions

#### 1. `match_format_exactly` (Reward: +3.0)
- Checks if the response **perfectly follows** the format
- Must have reasoning ending with `</think>`, followed by answer in `<answer>...</answer>`
- Strictest reward - only given for perfectly structured responses

#### 2. `match_format_approximately` (Reward: ±0.5 per component)
- More forgiving than exact matching
- Checks for presence of key components:
  - Exactly one `</think>` tag → +0.5, otherwise -0.5
  - Exactly one `<answer>` tag → +0.5, otherwise -0.5
  - Exactly one `</answer>` tag → +0.5, otherwise -0.5
  - Any `<think>` tag in the completion → -0.5 (no change if absent)
- Helps guide the model even when format isn't perfect

#### 3. `check_answer` (Reward: -1.0 to +3.0)
- **Most important**: Checks if the answer is correct!
- Exact match → +3.0
- Match after stripping surrounding whitespace → +1.5
- Numerically within 10% of the true answer → +0.5
- Numerically within 20% of the true answer → +0.25
- Numeric but further off → -1.0; not parseable as a number → -0.5
- No answer found in the expected format → 0

#### 4. `check_numbers` (Reward: 0 or +1.5)
- Simpler numerical check
- Extracts first number after `<answer>`
- Exact match → +1.5
- No match or wrong → 0

### How Rewards Work Together

The model receives signals about:
- **Structure** (format rewards) → Learn how to organize output
- **Correctness** (answer rewards) → Learn to solve problems accurately
- **Consistency** (number rewards) → Learn to be consistent

These combined rewards guide the model to generate well-structured AND correct reasoning!
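
One plausible way to combine several reward signals is a weighted sum per completion, sketched below. The two toy reward functions and `combine_rewards` are hypothetical and simplified (they take only the completions); how the trainer actually aggregates rewards internally is an assumption here.

```python
def has_answer_tag(completions):
    """Toy reward: 1.0 if the completion contains an <answer> block."""
    return [1.0 if "<answer>" in c and "</answer>" in c else 0.0 for c in completions]

def is_short(completions, limit=80):
    """Toy reward: 0.5 for completions under `limit` characters."""
    return [0.5 if len(c) < limit else 0.0 for c in completions]

def combine_rewards(completions, reward_funcs, weights):
    """Weighted sum of every reward function's score, per completion."""
    totals = [0.0] * len(completions)
    for func, weight in zip(reward_funcs, weights):
        for i, score in enumerate(func(completions)):
            totals[i] += weight * score
    return totals

completions = [
    "<think>3 * 4 = 12</think>\n<answer>12</answer>",
    "The answer is twelve.",
]
print(combine_rewards(completions, [has_answer_tag, is_short], weights=[2.0, 1.0]))
```

Raising one weight makes that signal dominate the total, which is the lever for prioritizing, say, correctness over formatting.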

In [None]:
match_format = re.compile(
    rf".+?\n{reasoning_end}\n{solution_start}(.+?){solution_end}",
    flags = re.MULTILINE | re.DOTALL
)

def match_format_exactly(prompts, completions, answer, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"] if isinstance(completion, list) else completion
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

def match_format_approximately(prompts, completions, answer, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"] if isinstance(completion, list) else completion
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        score -= 0.5 if response.count(reasoning_start) >= 1 else 0
        scores.append(score)
    return scores

def check_answer(prompts, completions, answer, **kwargs):
    responses = [completion[0]["content"] if isinstance(completion, list) else completion for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        if guess == true_answer:
            score += 3.0
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.25
                else: score -= 1.0
            except (ValueError, ZeroDivisionError):  # unparseable number or zero ground truth
                score -= 0.5
        scores.append(score)
    return scores

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)

def check_numbers(prompts, completions, answer, **kwargs):
    responses = [completion[0]["content"] if isinstance(completion, list) else completion for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            true_answer = float(true_answer.strip())
            guess       = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except (ValueError, AttributeError):  # unparseable number or non-string answer
            scores.append(0)
    return scores

---

## Step 8: Prepare the GRPO Training Dataset

### GRPODataset vs TextDataset

Unlike cold start training where we had pre-formatted text, GRPO needs a **different data format**:

- **TextDataset** (cold start): Complete input-output pairs for supervised learning
- **GRPODataset** (GRPO): Prompts that the model will complete and evaluate

### Dataset Structure

The `GRPODataset` expects:
- `prompt_key="prompt"`: The problem statement
- `answer_key="answer"`: Ground truth for reward calculation
- `system_key="system"`: Optional system message
- `type_key="type"`: Optional task type identifier

### Why Only 100 Examples?

GRPO is **much more compute-intensive** than SFT because:
- Each training step generates **multiple completions** (group_size=4)
- Each completion can be very long (up to 2048 tokens)
- All completions are evaluated by 4 reward functions
- Everything runs on your Mac's unified memory

100 examples × 4 completions per epoch = 400 long generations, which is substantial training!

### Memory Considerations

On Apple Silicon, we're constrained by unified memory. This smaller dataset:
- Keeps training time reasonable (1-3 hours)
- Prevents OOM (Out of Memory) errors
- Still provides strong learning signal with GRPO's group-based optimization

**Note**: You can increase this for better results if you have an M2 Ultra/M3 Max with 64GB+ RAM!
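
A quick back-of-the-envelope calculation makes the cost visible. The numbers are the ones used in this tutorial; the result is an upper bound, since most completions stop before the maximum length.

```python
num_prompts = 100
group_size = 4
max_completion_length = 2048

completions_per_epoch = num_prompts * group_size
max_generated_tokens = completions_per_epoch * max_completion_length

print(f"{completions_per_epoch} completions, up to {max_generated_tokens:,} generated tokens per epoch")
```

Doubling either the dataset size or the group size doubles this bound, so it is a useful sanity check before scaling up.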

In [None]:
train_set = GRPODataset(
    load_dataset(grpo_dataset_name)["train"].take(100),
    tokenizer,
    prompt_key="prompt",
    answer_key="answer",
    system_key="system",
    type_key="type",
)

---

## Step 9: Inspect the Dataset Format

Let's examine what the model will see during training. The `train_set.process()` function converts the raw data into tokenized format, and we decode it back to see the actual prompt text.

### What to Look For

The output should show:
- System instructions about using `<think>` and `<answer>` tags
- The math problem from the dataset
- The chat template used by the model
- A generation prompt (where the model will start generating)

This is the **exact format** the model will complete during GRPO training!

In [None]:
text = tokenizer.decode(train_set.process(train_set[0][0]))
print(text)

---

## Step 10: Test Model BEFORE GRPO Training (Baseline)

### Why Test Before Training?

This generates a **baseline response** so we can compare the model's performance before and after GRPO optimization.

### Generation Parameters

- `max_tokens=2048`: Up to 2048 tokens for reasoning + answer
- `temp=0.6`: Moderate temperature (not too random, not too deterministic)
- `top_p=0.95`: Nucleus sampling for diverse responses
- `top_k=20`: Consider top 20 tokens at each step
- `verbose=True`: Shows generation statistics (tokens/sec, etc.)

### What to Expect

After cold start training, the model should:
- ✅ Use the `<think>` and `<answer>` tags correctly
- ⚠️ May not solve the problem correctly yet
- ⚠️ Reasoning might be shallow or formulaic

**Save this output** - we'll compare it with the post-GRPO output to see the improvement!
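
To illustrate what the sampling parameters do, the sketch below applies temperature, top-k, and top-p to a toy next-token distribution in pure Python. It is a conceptual illustration only: the order of the filters and the exact cutoff rule are assumptions and may differ from what `make_sampler` does internally, and the logits are made up.

```python
import math

def filtered_distribution(logits, temp, top_k, top_p):
    """Apply temperature, then top-k, then top-p; return renormalized {token: prob}."""
    # Temperature-scaled softmax (subtract the peak for numerical stability).
    scaled = {tok: logit / temp for tok, logit in logits.items()}
    peak = max(scaled.values())
    exp = {tok: math.exp(v - peak) for tok, v in scaled.items()}
    total = sum(exp.values())
    probs = sorted(((tok, e / total) for tok, e in exp.items()),
                   key=lambda kv: kv[1], reverse=True)

    # Top-k: keep only the k most likely tokens.
    probs = probs[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break

    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}

logits = {"12": 2.0, "13": 1.0, "twelve": 0.0, "banana": -3.0}
dist = filtered_distribution(logits, temp=0.6, top_k=20, top_p=0.95)
print(dist)
```

With these toy logits, the low-probability tail is cut off by top-p, while top-k has no effect because fewer than 20 candidates exist.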

In [None]:
before_test_output = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=text,
    verbose=True,
    max_tokens=zero_max_seq_length//2,
    sampler=make_sampler(temp=0.6, top_p=0.95, top_k=20)
)

---

## Step 11: Train with GRPO (Group Relative Policy Optimization)

### GRPO Training Configuration

#### Core GRPO Parameters

**Group-Based Optimization:**
- `group_size=4`: Generate 4 completions per prompt
- GRPO scores these 4 and learns from each completion's reward relative to the group average
- More efficient than PPO (no value function needed!)

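**KL Divergence and Clipping Control:**
- `beta=0.1`: Controls KL penalty strength
- `epsilon=1e-3`: Lower clipping margin; the policy probability ratio is clipped below at `1 - epsilon`
- `epsilon_high=2e-3`: Upper clipping margin; the ratio is clipped above at `1 + epsilon_high`
- Together these prevent the model from diverging too far from the reference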
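
The role of `epsilon` and `epsilon_high` is easiest to see in a PPO-style clipped objective, sketched below for a single token. This assumes the two values bound the probability ratio at `1 - epsilon` and `1 + epsilon_high`; the function is a hypothetical illustration, not the library's loss.

```python
def clipped_objective(ratio, advantage, epsilon, epsilon_high):
    """PPO-style clipped surrogate for one token: min(r * A, clip(r) * A)."""
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon_high)
    return min(ratio * advantage, clipped_ratio * advantage)

epsilon, epsilon_high = 1e-3, 2e-3

# A large ratio with a positive advantage is capped at (1 + epsilon_high) * A.
print(clipped_objective(1.05, advantage=1.0, epsilon=epsilon, epsilon_high=epsilon_high))
# A small ratio with a negative advantage is capped at (1 - epsilon) * A.
print(clipped_objective(0.90, advantage=-1.0, epsilon=epsilon, epsilon_high=epsilon_high))
```

Margins this small mean a single update can barely move the policy away from the one that generated the samples, which favors stability over speed.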

#### Generation Parameters

- `temperature=0.6`: Sampling temperature
- `top_p=0.95`: Nucleus sampling threshold
- `top_k=20`: Top-k sampling
- `max_completion_length=2048`: Maximum tokens to generate

#### Training Hyperparameters

- `learning_rate=2e-5`: Conservative learning rate for stable training (set on the optimizer)
- `batch_size=1`: Memory efficiency
- `gradient_accumulation_steps=1`: With group_size=4, effective batch size is 4
- `epochs=1`: One pass through the dataset

#### Reward Configuration

- `reward_weights=[1.0, 1.0, 1.0, 1.0]`: Equal weighting for all 4 reward functions
- You can adjust these to prioritize certain aspects (e.g., [2.0, 1.0, 3.0, 1.0] to emphasize correctness)

#### GRPO Variants

```python
grpo_loss_type="grpo"  # Options:
```
- **grpo**: Original GRPO algorithm
- **bnpo**: Batch Normalized Policy Optimization
- **dr_grpo**: Dr. GRPO ("GRPO Done Right"), a variant that changes how the loss is normalized (experimental)

### What Happens During GRPO?

1. **Sample**: Generate 4 completions for each prompt
2. **Evaluate**: Run all reward functions on completions
3. **Compare**: Score each completion relative to the average reward of its group
4. **Update**: Increase probability of better completions, decrease worse ones
5. **Regularize**: Apply KL penalty to prevent over-optimization
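
The comparison step can be sketched as a group-relative advantage: each completion's reward is standardized against the other completions for the same prompt. This follows the formulation in the GRPO paper; the choice of population standard deviation and the small `eps` are assumptions of this sketch, and the rewards are made-up numbers.

```python
import statistics

def group_advantages(rewards, eps=1e-4):
    """Standardize rewards within one group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Total rewards for the 4 completions sampled from one prompt (made-up numbers).
rewards = [6.0, 1.5, -1.0, 1.5]
advantages = group_advantages(rewards)
print([round(a, 3) for a in advantages])
```

Completions above the group mean get a positive advantage and are reinforced; those below get a negative one. If all four completions score the same, every advantage is zero and the prompt contributes no learning signal.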

### Training Time

Expect **1-3 hours** depending on your Mac:
- M1/M2 base: ~2-3 hours
- M1/M2 Pro/Max: ~1-2 hours  
- M3/M4 chips: ~1 hour

**Pro Tip**: Use `WandBCallback` to track training metrics like:
- Average rewards per function
- KL divergence
- Policy loss
- Learning dynamics

### Monitoring Progress

Watch the console output for:
- Loss decreasing over time
- Rewards increasing
- KL divergence staying stable (not exploding)
- Tokens/second (throughput)

In [None]:
# Define custom reward weights if you want to weight them differently
custom_reward_weights = [
    1.0,  # match_format_exactly
    1.0,  # match_format_approximately
    1.0,  # check_answer
    1.0,  # check_numbers
]

opt = optim.AdamW(learning_rate=2e-5)

args=GRPOTrainingArgs(
    batch_size=1,
    iters=calculate_iters(train_set, batch_size=1, epochs=1),
    val_batches=1,
    steps_per_report=5,
    steps_per_eval=100,
    steps_per_save=200,
    adapter_file=adapter_file,
    max_seq_length=zero_max_seq_length,
    grad_checkpoint=True,
    gradient_accumulation_steps=1,
    beta=0.1,
    group_size=4,
    epsilon=1e-3,
    epsilon_high=2e-3,
    temperature=0.6,
    top_p=0.95,
    top_k=20,
    max_completion_length=zero_max_seq_length//2,
    reward_weights=custom_reward_weights,
    grpo_loss_type="grpo", # grpo, bnpo, or dr_grpo
    # importance_sampling_level="sequence", # token, sequence, None for basic grpo
)

train_grpo(
    model=model,
    ref_model=ref_model.freeze(),
    tokenizer=tokenizer,
    optimizer=opt,
    train_dataset=CacheDataset(train_set),
    args=args,
    training_callback=TrainingCallback(),
    # training_callback=WandBCallback(
    #     project_name=new_zero_model_name,
    #     log_dir=zero_adapter_path,
    #     config=vars(args),
    #     wrapped_callback=TrainingCallback(),
    # ),
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    end_answer_token=solution_end
)

---

## Step 12: Test Model AFTER GRPO Training (Compare Results!)

### Evaluation Time! 🎉

Now we generate with the **GRPO-optimized model** using the exact same:
- Prompt (same test problem)
- Generation parameters (temp=0.6, top_p=0.95, top_k=20)
- Maximum tokens

### What Improvements to Expect

After GRPO training, the model should show:

**✅ Better Structure:**
- More consistent use of reasoning tags
- Well-organized step-by-step thinking
- Clear separation of reasoning and answer

**✅ More Accurate Answers:**
- Higher correctness rate on math problems
- Better numerical accuracy
- Fewer hallucinations or wrong calculations

**✅ Deeper Reasoning:**
- More detailed working out
- Explicit intermediate steps
- Self-correction and validation

### Comparing Before vs After

Look at:
1. **Format compliance**: Does it follow the structure better?
2. **Reasoning quality**: Is the thinking more thorough?
3. **Answer correctness**: Is the final answer right?
4. **Confidence**: Does the reasoning support the conclusion?

The difference can be dramatic! GRPO's reward-based learning often produces much more coherent and accurate reasoning compared to the cold-start model.
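
For a slightly more systematic comparison than reading the two outputs side by side, a small helper can report format compliance and correctness. `summarize` is a hypothetical helper, and the two strings below are made-up stand-ins for `before_test_output` and `after_test_output`, not real model outputs.

```python
import re

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", flags=re.DOTALL)

def summarize(output, expected):
    """Hypothetical helper: report format compliance and correctness for one output."""
    answers = ANSWER_RE.findall(output)
    return {
        "has_reasoning_end": output.count("</think>") == 1,
        "has_single_answer": len(answers) == 1,
        "correct": len(answers) == 1 and answers[0].strip() == expected.strip(),
        "length_chars": len(output),
    }

# Stand-ins for before_test_output / after_test_output (made-up strings).
before = "<think>\n3 + 4 = 8\n</think>\n<answer>8</answer>"
after = "<think>\n3 + 4 = 7\n</think>\n<answer>7</answer>"

for label, text in [("before", before), ("after", after)]:
    print(label, summarize(text, expected="7"))
```

A single sampled prompt is noisy at `temp=0.6`, so running the same helper over a handful of held-out problems gives a more reliable picture.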

In [None]:
after_test_output = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=text,
    verbose=True,
    max_tokens=zero_max_seq_length//2,
    sampler=make_sampler(temp=0.6, top_p=0.95, top_k=20)
)

---

## Step 13: Save the Final GRPO-Optimized Model

### Merging Adapters with Base Model

The `save_pretrained_merged()` function:
- Combines the LoRA adapters with the base model weights
- Saves a complete, standalone model
- Can be loaded like any other MLX model (no special adapter loading needed)

### What Gets Saved

The saved directory contains:
- `model.safetensors` (or split files): Full merged model weights
- `config.json`: Model configuration
- `tokenizer.json` & `tokenizer_config.json`: Tokenizer files
- `chat_template.jinja`: Chat template (if applicable)

### Using the Saved Model

You can now use this model with:

```python
from mlx_lm import load, generate

model, tokenizer = load(new_zero_model_name)
response = generate(model, tokenizer, prompt="Your problem here", ...)
```

No need for adapter loading - it's a complete model!

In [None]:
save_pretrained_merged(
    model=model,
    tokenizer=tokenizer,
    save_path=new_zero_model_name
)

---

## Step 14: Push Model to Hugging Face Hub (Optional)

### Share Your Model! 🤗

The `push_to_hub()` function uploads your trained model to Hugging Face, making it:
- Accessible from anywhere
- Easy to share with others
- Compatible with the MLX ecosystem

### Configuration Parameters

- `model_path`: Local path to the model (with or without adapters)
- `hf_repo`: Your Hugging Face username/repo name
- `api_key`: Your HF token (replace "HF_KEY" with actual token or use HF CLI login)
- `private=False`: Makes the model public (set to `True` for private repos)
- `commit_message`: Description of what you're uploading
- `remove_adapters=False`: Keep adapters in the upload (useful for incremental training)

### Before Running This Cell

1. **Get your HF token**: Go to https://huggingface.co/settings/tokens
2. **Replace "HF_KEY"** with your actual token, or use:
   ```bash
   huggingface-cli login
   ```
3. **Create the repo** on Hugging Face (or set `create_repo=True` if supported)
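
Rather than pasting the token into the notebook, it can be read from an environment variable so it never ends up in a saved or shared file. This is a minimal sketch; `resolve_hf_token` is a hypothetical helper, and `HF_TOKEN` is the variable name assumed here.

```python
import os

def resolve_hf_token(env_var="HF_TOKEN"):
    """Return the token from the environment, or None if it is unset or empty."""
    token = os.environ.get(env_var, "").strip()
    return token or None

token = resolve_hf_token()
if token is None:
    print("HF_TOKEN is not set; export it or log in with the CLI before pushing.")
else:
    print("Token found; pass it as api_key instead of a hardcoded string.")
```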

### What Gets Uploaded

- All model files
- Tokenizer
- Configuration
- Optional: Adapters (if `remove_adapters=False`)

After uploading, others can use your model with:

```python
model, tokenizer = load(f"{user_name}/{new_zero_model_name}")
```

---

## 🎉 Congratulations!

You've successfully trained a reasoning model using GRPO on Apple Silicon! 

### Key Takeaways

✅ Two-stage training (Cold Start SFT → GRPO) produces better results  
✅ Custom reward functions guide specific behaviors  
✅ LoRA + quantization enables efficient training on Mac  
✅ GRPO optimizes based on group-relative rewards (no value function needed)  
✅ Long context (4K tokens) allows detailed reasoning chains  

### Next Steps

1. **Experiment with reward weights**: Adjust `custom_reward_weights` to emphasize different aspects
2. **Try different loss types**: Test `bnpo` or `dr_grpo` variants
3. **Scale up**: Use more training data if you have sufficient RAM
4. **Evaluate systematically**: Test on held-out math problems to measure improvement
5. **Apply to other domains**: Adapt reward functions for coding, logical reasoning, etc.

### Resources

- MLX-LM-LoRA docs: https://github.com/Goekdeniz-Guelmez/mlx-lm-lora
- GRPO paper: https://arxiv.org/abs/2402.03300
- GSM8K dataset: https://github.com/openai/grade-school-math

Happy training! 🚀

In [None]:
push_to_hub(
  model_path=zero_adapter_path,
  hf_repo=f"{user_name}/{new_zero_model_name}",
  api_key="HF_KEY",
  private=False,
  commit_message="Add preference adapters",
  remove_adapters=False
)