# Gemma-3 270m Fine-tuning with Unsloth-MLX

This notebook demonstrates how to fine-tune the Gemma-3 270m model locally on Apple Silicon using `unsloth_mlx`. The library builds on Apple's MLX framework to make LoRA fine-tuning faster and less memory-hungry.

In [1]:
import os
from dotenv import load_dotenv
from unsloth_mlx import FastLanguageModel
import mlx.core as mx

# Load environment variables
load_dotenv()

model_path = os.getenv("MLX_MODEL_PATH", "./gemma3")
print(f"Loading model from: {model_path}")


Loading model from: ./gemma3


## 1. Load Model and Tokenizer

We use `FastLanguageModel.from_pretrained` to load both the weights and the tokenizer. 

### Arguments Explained:
- **model_name**: The local path or HuggingFace ID of the model.
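- **max_seq_length**: The maximum context window (number of tokens) the model will handle. Higher values use more memory; on Apple Silicon this is the unified memory shared by the CPU and GPU rather than dedicated VRAM.
- **load_in_4bit**: Loads the weights in 4-bit quantized form. **4-bit NormalFloat (NF4)** is the scheme Unsloth uses on CUDA through bitsandbytes; MLX's built-in quantization is group-wise affine, so the exact scheme applied here depends on `unsloth_mlx`. Either way, 4-bit weights take roughly a quarter of the memory of 16-bit weights, usually at a small cost in accuracy.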
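To put the memory saving in numbers, the back-of-the-envelope sketch below estimates weight storage for a 270M-parameter model at 16-bit and at 4-bit. It assumes group-wise quantization with 64 weights per group and a 16-bit scale plus a 16-bit offset per group. These values are illustrative assumptions, not settings read from `unsloth_mlx`, and the helper `weight_memory_mb` is hypothetical.

```python
def weight_memory_mb(n_params, bits_per_weight, group_size=None, overhead_bits_per_group=0):
    """Estimate weight storage in megabytes (10**6 bytes)."""
    total_bits = n_params * bits_per_weight
    if group_size is not None:
        # Each group of weights also stores its own scale/offset metadata.
        n_groups = -(-n_params // group_size)  # ceiling division
        total_bits += n_groups * overhead_bits_per_group
    return total_bits / 8 / 1e6


N_PARAMS = 270_000_000  # nominal parameter count of Gemma-3 270m

fp16_mb = weight_memory_mb(N_PARAMS, 16)
# Assumed layout: 4-bit weights, groups of 64, 32 bits of metadata per group.
q4_mb = weight_memory_mb(N_PARAMS, 4, group_size=64, overhead_bits_per_group=32)

print(f"16-bit weights: {fp16_mb:.1f} MB")
print(f"4-bit weights:  {q4_mb:.1f} MB")
print(f"Reduction:      {fp16_mb / q4_mb:.2f}x")
```

Under these assumptions the saving is about 3.6x rather than a full 4x, because each group also stores its scale and offset. Activations, LoRA weights, and optimizer state are not included in the estimate.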

In [2]:
max_seq_length = 2150
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_path,
    max_seq_length = max_seq_length,
    load_in_4bit = True, # Use 4-bit quantization to save memory
)

## 2. Add LoRA Adapters

Instead of updating all of the model's roughly 270 million parameters, **LoRA (Low-Rank Adaptation)** freezes the base weights and adds small trainable "adapters" to specific layers. Only the adapters are trained.
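As a rough illustration of why this is cheap: for a linear layer whose weight has shape `d_out x d_in`, LoRA trains two small matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and adds `(alpha / r) * B @ A` to the frozen weight. The sketch below counts parameters for one such layer. The layer dimensions are made-up example values, not Gemma-3's actual shapes, and both helper functions are hypothetical.

```python
def full_param_count(d_in, d_out):
    """Parameters in a dense weight matrix of shape (d_out, d_in)."""
    return d_in * d_out


def lora_param_count(d_in, d_out, r):
    """Parameters in the LoRA pair A (r x d_in) and B (d_out x r)."""
    return r * d_in + d_out * r


d_in, d_out = 640, 2048  # illustrative layer size (assumption)
r, alpha = 16, 16        # same rank and alpha as the cell below

full = full_param_count(d_in, d_out)
lora = lora_param_count(d_in, d_out, r)
scale = alpha / r        # multiplier applied to the low-rank update

print(f"Full layer:   {full:,} parameters")
print(f"LoRA adapter: {lora:,} parameters ({100 * lora / full:.2f}% of the layer)")
print(f"Scale:        {scale}")
```

For this example layer the adapter holds a little over 3% of the parameters of the dense weight, and with `r = 16` and `lora_alpha = 16` the scaling factor is 1.0.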


In [3]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,
    random_state = 3407,
)

LoRA configuration set: rank=16, alpha=16, modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], dropout=0


## 3. Prepare Dataset

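We load our `dataset.jsonl` file and format each record into a single prompt string with an `### Instruction:` section followed by a `### Response:` section. The trainer reads this string from the `text` column.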

In [4]:
from datasets import load_dataset

# Load the locally generated dataset from .env path
dataset_file = os.getenv("DATASET_PATH", "dataset.jsonl")
dataset = load_dataset("json", data_files=dataset_file, split="train")

def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    outputs      = examples["output"]
    texts = []
    for instruction, output in zip(instructions, outputs):
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
        texts.append(text)
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True,)
print(f"Loaded {len(dataset)} examples from {dataset_file}")

Loaded 18 examples from dataset.jsonl
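For reference, here is a sketch of the record shape this cell expects and the string it produces. The field names `instruction` and `output` come from the code above; the sample values and the helper `format_example` are invented for illustration.

```python
import json

# Same template as formatting_prompts_func above.
TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n{output}"


def format_example(record):
    """Turn one JSONL record into the training string."""
    return TEMPLATE.format(instruction=record["instruction"], output=record["output"])


# One line of a hypothetical dataset.jsonl file.
line = '{"instruction": "Say hello.", "output": "Hello!"}'

record = json.loads(line)
text = format_example(record)
print(text)
```

Each line of the JSONL file is a standalone JSON object, so a record with a missing `instruction` or `output` key would raise a `KeyError` in this sketch.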


## 4. Initialization & LIVE Monitoring

**IMPORTANT**: Run the TensorBoard cell below **BEFORE** you start the training in the next step. 

This allows the dashboard to initialize. While the training is running (which blocks your notebook), you can simply click the **Refresh** button (top right of the TensorBoard UI) to see live progress.

In [5]:
%load_ext tensorboard
%tensorboard --logdir ./lora_finetuned

Reusing TensorBoard on port 6006 (pid 65029), started 0:12:31 ago. (Use '!kill 65029' to kill it.)

In [None]:
from unsloth_mlx import SFTTrainer
from trl import SFTConfig

# 1. Define the Training Configuration
sft_config = SFTConfig(
    output_dir="./lora_finetuned",
    logging_steps=1,
    num_train_epochs=3,
    learning_rate=2e-4,
    report_to="tensorboard", # Logs progress for visualization
)

# 2. Initialize the Trainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    args = sft_config,
)

print("Starting training... (Check the TensorBoard cell above for live updates!)")
trainer.train()
print("Training complete!")

mx.metal.device_info is deprecated and will be removed in a future version. Use mx.device_info instead.


Trainer initialized:
  Output dir: lora_finetuned
  Adapter path: lora_finetuned/adapters
  Learning rate: 0.0002
  Iterations: 6
  Batch size: 8
  LoRA r=16, alpha=16
  Native training: True
  LR scheduler: SchedulerType.LINEAR
  Grad checkpoint: False
Starting training... (Check the TensorBoard cell above for live updates!)
Starting Fine-Tuning

[Using Native MLX Training]

Applying LoRA adapters...
Applying LoRA to 18 layers: {'rank': 16, 'scale': 1.0, 'dropout': 0, 'keys': ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj', 'mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj']}
✓ LoRA applied successfully to 18 layers
  Trainable LoRA parameters: 252
Preparing training data...
  Detected format: text
✓ Prepared 18 training samples
  Saved to: lora_finetuned/train.jsonl
✓ Created validation set (copied from train)

Training configuration:
  Iterations: 6
  Batch size: 8
  Learning rate: 0.0002
  LR scheduler: SchedulerType.LINEAR
  Grad checkpoint: True
  Ada

Calculating loss...: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.17it/s]

Iter 1: Val loss 3.174, Val took 1.718s
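The `Iterations: 6` figure in the log is consistent with dropping the final partial batch: 18 examples at batch size 8 give 2 full batches per epoch, and 3 epochs give 6 steps. The sketch below reproduces that arithmetic. The floor-division rule is inferred from the log, not taken from the `unsloth_mlx` documentation, and the helper `estimate_iterations` is hypothetical.

```python
def estimate_iterations(n_examples, batch_size, epochs, drop_last=True):
    """Estimate optimizer steps for a run (assumed rule, see the note above)."""
    if drop_last:
        steps_per_epoch = n_examples // batch_size        # full batches only
    else:
        steps_per_epoch = -(-n_examples // batch_size)    # ceiling division
    return steps_per_epoch * epochs


iters_drop_last = estimate_iterations(18, 8, 3)
iters_keep_last = estimate_iterations(18, 8, 3, drop_last=False)

print(f"Dropping the partial batch: {iters_drop_last} iterations")
print(f"Keeping the partial batch:  {iters_keep_last} iterations")
```

The same rule would also explain the two validation batches shown in the progress bar, since the validation set is a copy of the 18 training examples.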





## 5. Saving the Model


In [None]:
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")
print("Model and tokenizer saved to './lora_model'")

## 6. Testing (Inference)


In [None]:
FastLanguageModel.for_inference(model)

prompt = "### Instruction:\nWrite the code for the file `index.html` in the codebase project.\n\n### Response:"
input_ids = mx.array(tokenizer.encode(prompt))

output = model.generate(input_ids = input_ids, max_new_tokens = 128)
print(tokenizer.decode(output))
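Depending on how `generate` returns tokens, the decoded string may contain the prompt as well as the completion. Below is a minimal sketch for isolating the answer, assuming the decoded text still contains the `### Response:` marker. The helper `extract_response` is hypothetical and not part of `unsloth_mlx`.

```python
RESPONSE_MARKER = "### Response:"


def extract_response(decoded_text, marker=RESPONSE_MARKER):
    """Return the text after the first response marker, or all of it if absent."""
    _, found, tail = decoded_text.partition(marker)
    return tail.strip() if found else decoded_text.strip()


# Stand-in for tokenizer.decode(output); the content is invented.
sample = "### Instruction:\nSay hello.\n\n### Response:\nHello!"

answer = extract_response(sample)
print(answer)
```

Splitting on the first marker keeps any later occurrence of `### Response:` that the model itself generates, which makes runaway generations easier to spot.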