# Gemma-3 270M Fine-tuning with Unsloth-MLX

This notebook demonstrates how to fine-tune the Gemma-3 270M model locally on Apple Silicon using `unsloth_mlx`, a library that brings Unsloth's `FastLanguageModel` / `SFTTrainer` workflow to Apple's MLX framework for faster, lower-memory LoRA fine-tuning.

In [1]:
import os
from dotenv import load_dotenv
from unsloth_mlx import FastLanguageModel
import mlx.core as mx

# Load environment variables
load_dotenv()

model_path = os.getenv("MLX_MODEL_PATH", "./gemma3")
print(f"Loading model from: {model_path}")


Loading model from: ./gemma3


## 1. Load Model and Tokenizer

We use `FastLanguageModel.from_pretrained` to load both the weights and the tokenizer. 

### Arguments Explained:
- **model_name**: The local path or HuggingFace ID of the model.
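- **max_seq_length**: The maximum context length (in tokens) the model will handle. Higher values use more memory; on Apple Silicon this is unified memory shared by the CPU and GPU rather than dedicated VRAM.
- **load_in_4bit**: Loads the weights with 4-bit quantization. Compared with 16-bit weights this cuts weight memory by roughly 4x (slightly less once the per-group scales are counted), usually with only a small loss in accuracy. MLX uses its own group-wise quantization format rather than the bitsandbytes **NF4** format used by Unsloth on CUDA.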
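To make the ~4x figure concrete, here is a simplified, pure-Python illustration of group-wise 4-bit quantization: each small group of weights is stored as integer codes 0 to 15 plus one scale and one offset. This sketches the general idea only, not the actual kernel used by MLX or `unsloth_mlx`; `quantize_4bit`, `dequantize_4bit`, and `weight_megabytes` are hypothetical helpers, and the group size of 4 is chosen for readability (real schemes use much larger groups).

```python
def quantize_4bit(values, group_size=4):
    """Group-wise affine 4-bit quantization: integer codes 0..15 plus a scale and offset per group."""
    groups = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / 15 or 1.0  # a constant group would otherwise divide by zero
        codes = [round((v - lo) / scale) for v in group]
        groups.append((codes, scale, lo))
    return groups


def dequantize_4bit(groups):
    """Map each code back to an approximate float: offset + code * scale."""
    restored = []
    for codes, scale, lo in groups:
        restored.extend(lo + c * scale for c in codes)
    return restored


def weight_megabytes(num_params, bits):
    """Raw weight storage in MB, ignoring the small per-group scale/offset overhead."""
    return num_params * bits / 8 / 1e6


weights = [0.12, -0.40, 0.33, 0.05, 1.20, -0.90, 0.00, 0.45]
groups = quantize_4bit(weights)
restored = dequantize_4bit(groups)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(f"max reconstruction error: {max_err:.4f}")
print(f"270M params: {weight_megabytes(270e6, 16):.0f} MB at 16-bit, "
      f"{weight_megabytes(270e6, 4):.0f} MB at 4-bit")
```

Each restored weight is off by at most half its group's scale, which is why groups with a wide value range lose the most precision.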

In [None]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_path,
    max_seq_length = 2048,
    load_in_4bit = True, # Use 4-bit quantization to save memory
)

## 2. Add LoRA Adapters

Instead of updating all of the model's weights (about 270 million parameters here), **LoRA (Low-Rank Adaptation)** freezes them and adds small trainable "adapter" matrices to specific layers. Only these adapters are trained, which makes training faster and much lighter on memory.

### Arguments Explained:
- **r (Rank)**: Controls the size of the adapter matrices. `16` is a standard value that balances performance and efficiency.
- **target_modules**: The layers that receive adapters. Here that is every attention projection (`q_proj`, `k_proj`, `v_proj`, `o_proj`) plus the MLP projections (`gate_proj`, `up_proj`, `down_proj`).
- **lora_alpha**: A scaling factor for the adapter output; the low-rank update is multiplied by `lora_alpha / r`. It is usually set equal to `r` (scale 1) or $2 \times r$ (scale 2).
- **lora_dropout**: Dropout probability applied on the adapter path. `0` disables dropout, a common default for short fine-tuning runs.
- **bias**: Whether to train biases. `none` is recommended for standard LoRA.
- **use_gradient_checkpointing**: Saves memory by recomputing activations during the backward pass instead of storing them all, at the cost of some extra compute.
- **random_state**: Ensures reproducibility of the initialization.
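The adapter math described above can be sketched in a few lines of pure Python: the frozen weight `W` is left untouched, a low-rank pair `A` (r × d_in, random init) and `B` (d_out × r, zero init) is added, and the output is `W x + (lora_alpha / r) · B (A x)`. Because `B` starts at zero, the adapted layer initially behaves exactly like the base layer. This is an illustrative sketch with hypothetical names (`LoRALinear`, `lora_param_count`), not the `unsloth_mlx` implementation; the zero/random initialization follows the convention of the original LoRA formulation.

```python
import random


def matvec(matrix, vec):
    """Plain-Python matrix-vector product."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update: y = W x + (alpha / r) * B (A x)."""

    def __init__(self, weight, r=2, alpha=2, seed=3407):
        rng = random.Random(seed)                    # fixed seed, like random_state above
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight                         # frozen base weights, shape (d_out, d_in)
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]  # (r, d_in)
        self.B = [[0.0] * r for _ in range(d_out)]   # (d_out, r), zero so training starts from W
        self.scale = alpha / r

    def __call__(self, x):
        base = matvec(self.weight, x)
        update = matvec(self.B, matvec(self.A, x))
        return [b + self.scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one (d_in -> d_out) matrix."""
    return r * (d_in + d_out)


W = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
layer = LoRALinear(W, r=2, alpha=2)
x = [1.0, 2.0, 3.0]
print(layer(x))  # B is zero, so this equals W x: [7.0, -1.5]

# Hypothetical 1024 x 1024 projection with r = 16: full weights vs. adapter
print(1024 * 1024, lora_param_count(1024, 1024, 16))  # 1048576 32768
```

With the notebook's `r = 16, lora_alpha = 16`, the scale is 1. The last line shows why LoRA is cheap: a hypothetical 1024 × 1024 projection has about a million weights, while its rank-16 adapter has 32,768.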

In [3]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,
    random_state = 3407,
)

LoRA configuration set: rank=16, alpha=16, modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'], dropout=0


## 3. Prepare Dataset

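We load the JSONL dataset (path taken from `DATASET_PATH`, defaulting to `dataset.jsonl`) and turn each record's `instruction` and `output` fields into a single training string. The template is an Alpaca-style `### Instruction:` / `### Response:` layout rather than Gemma-3's native chat template, so the same layout must be used when prompting the fine-tuned model.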
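`formatting_prompts_func` reads an `instruction` and an `output` string from every record, so it is worth confirming that each line of the file is valid JSON with both fields before training. A minimal pre-flight check using only the standard library might look like this; `check_jsonl_lines` is a hypothetical helper, and treating empty strings as unusable is an assumption about what counts as a good record.

```python
import json

REQUIRED_KEYS = ("instruction", "output")  # the fields formatting_prompts_func reads


def check_jsonl_lines(lines):
    """Return (line_number, problem) tuples; an empty list means every record looks usable."""
    problems = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # blank lines are skipped by this check
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((lineno, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((lineno, "not a JSON object"))
            continue
        for key in REQUIRED_KEYS:
            if not isinstance(record.get(key), str) or not record[key].strip():
                problems.append((lineno, f"missing or empty '{key}'"))
    return problems


# Made-up sample lines for illustration.
sample = [
    '{"instruction": "Explain LoRA in one sentence.", "output": "LoRA trains small low-rank adapters."}',
    '{"instruction": "Explain quantization."}',
    'not json',
]
print(check_jsonl_lines(sample))

# On the real file:
# with open(dataset_file, encoding="utf-8") as f:
#     print(check_jsonl_lines(f))
```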

In [None]:
from datasets import load_dataset

# Path comes from DATASET_PATH in .env (defaults to dataset.jsonl)
dataset_file = os.getenv("DATASET_PATH", "dataset.jsonl")
dataset = load_dataset("json", data_files=dataset_file, split="train")

def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    outputs      = examples["output"]
    texts = []
    for instruction, output in zip(instructions, outputs):
        # The model is trained to generate the output given the instruction.
        # Appending the EOS token teaches the model where a response ends.
        text = f"### Instruction:\n{instruction}\n\n### Response:\n{output}" + tokenizer.eos_token
        texts.append(text)
    return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True,)

print(f"Loaded {len(dataset)} examples from {dataset_file}")

## 4. Training

The `SFTTrainer` runs the optimization loop, and `SFTConfig` gathers the training hyperparameters and logging options in one object.

### Arguments Explained:
- **train_dataset**: The processed data to learn from.
- **dataset_text_field**: The key in our dataset containing the formatted strings.
- **args**: High-level configuration via `SFTConfig`. This controls tracking, learning rates, and more.

### SFTConfig Details:
- **output_dir**: Where checkpoints and logs are stored.
- **learning_rate**: How fast the model adjusts its weights. `2e-4` is standard for LoRA.
- **num_train_epochs**: Number of full passes over the dataset.
- **logging_steps**: Frequency of progress updates.
- **report_to**: Set to `"tensorboard"` to enable visual progress tracking.
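`num_train_epochs` and `logging_steps` together determine how many loss points you will see. The notebook does not set a batch size, so the value below is a hypothetical one, and the step count is an approximation of how Hugging Face-style trainers count optimizer updates (one per effective batch). `estimate_optimizer_steps` and `estimate_log_entries` are illustrative helpers, not part of `SFTConfig`.

```python
import math


def estimate_optimizer_steps(num_examples, batch_size, grad_accum_steps=1, epochs=3):
    """Approximate optimizer updates: ceil(examples / effective batch size) per epoch, times epochs."""
    effective_batch = batch_size * grad_accum_steps
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return steps_per_epoch * epochs


def estimate_log_entries(total_steps, logging_steps=1):
    """A loss value is logged every `logging_steps` optimizer updates."""
    return total_steps // logging_steps


# Hypothetical run: 500 examples, batch size 8 (not set in this notebook), 3 epochs.
steps = estimate_optimizer_steps(500, batch_size=8, epochs=3)
print(steps, estimate_log_entries(steps, logging_steps=1))
```

With `logging_steps=1`, every optimizer step produces a log entry, which gives a detailed but noisy loss curve on small datasets.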

In [None]:
from unsloth_mlx import SFTTrainer
from trl import SFTConfig

# 1. Define the Training Configuration
sft_config = SFTConfig(
    output_dir="./lora_finetuned",
    logging_steps=1,
    num_train_epochs=3,
    max_seq_length=2048,
    learning_rate=2e-4,
    report_to="tensorboard", # Logs progress for visualization
)

# 2. Initialize the Trainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    args = sft_config,
)

print("Starting training...")
trainer.train()
print("Training complete!")

### Visualizing Progress with TensorBoard

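Since we enabled `report_to="tensorboard"`, training loss and other metrics are logged under `./lora_finetuned`. Because `trainer.train()` blocks the notebook kernel while it runs, start TensorBoard before training (or run `tensorboard --logdir ./lora_finetuned` in a separate terminal) to watch the curves update live; running the cell below after training shows the finished run.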

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lora_finetuned

## 5. Saving the Model

We use `model.save_pretrained("lora_model")` to save the results. 

### Saving Format:
- **Files**: The adapter weights are written to **`adapters.safetensors`**, alongside config files such as `adapter_config.json`.
- **Efficiency**: The base model is **not** saved. Only the trained LoRA weights are, which take a few megabytes instead of the roughly half a gigabyte the 270M-parameter base model occupies at 16-bit precision.
- **Usage**: To use this later, you load the base Gemma-3 model and then "attach" these specific adapter files.
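To see where "a few megabytes" comes from, note that each adapted matrix of shape (d_in, d_out) adds `r * (d_in + d_out)` trainable parameters. The sketch below sums that over the seven target modules. The layer shapes, layer count, and 4-byte (float32) storage are illustrative assumptions, not values read from the Gemma-3 270M config, so substitute the real numbers from the model's `config.json`; `lora_adapter_megabytes` is a hypothetical helper.

```python
def lora_adapter_megabytes(module_shapes, num_layers, r, bytes_per_param=4):
    """Estimate adapter size: each adapted (d_in, d_out) matrix adds r * (d_in + d_out) parameters."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in module_shapes.values())
    return per_layer * num_layers * bytes_per_param / 1e6


# Hypothetical (d_in, d_out) shapes for illustration only.
shapes = {
    "q_proj": (640, 1024), "k_proj": (640, 256), "v_proj": (640, 256), "o_proj": (1024, 640),
    "gate_proj": (640, 2048), "up_proj": (640, 2048), "down_proj": (2048, 640),
}
print(f"{lora_adapter_megabytes(shapes, num_layers=18, r=16):.1f} MB")
```

With these hypothetical shapes the estimate is about 15 MB at float32, or half that if the adapters are stored in 16-bit.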

In [None]:
model.save_pretrained("lora_model")
tokenizer.save_pretrained("lora_model")
print("Model and tokenizer saved to './lora_model'")

## 6. Testing (Inference)

Finally, we run the model on a new prompt to see what it learned.

### Why `for_inference`?
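- **FastLanguageModel.for_inference(model)**: Switches the model from training mode into an optimized generation mode (training-only behavior off, key/value caching on, so each new token reuses the attention states of earlier tokens instead of recomputing the whole prefix). This makes token generation noticeably faster.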
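Because the model was trained on the `### Instruction:` / `### Response:` template, the inference prompt should reproduce it exactly, and the decoded text usually needs trimming. A small pair of hypothetical helpers, `build_prompt` and `extract_response`, keeps both sides consistent. Passing `tokenizer.eos_token` as `eos_token` assumes the tokenizer exposes that attribute, as Hugging Face tokenizers do.

```python
INSTRUCTION_MARKER = "### Instruction:"
RESPONSE_MARKER = "### Response:\n"


def build_prompt(instruction):
    """Reproduce the training template up to and including the newline after the response marker."""
    return f"{INSTRUCTION_MARKER}\n{instruction}\n\n{RESPONSE_MARKER}"


def extract_response(decoded, eos_token=None):
    """Return only the answer: text after the first response marker, cut at EOS or a new instruction block."""
    _, marker, answer = decoded.partition(RESPONSE_MARKER)
    if not marker:
        answer = decoded  # the decoded text may contain only the newly generated tokens
    for stop in (eos_token, INSTRUCTION_MARKER):
        if stop:
            answer = answer.split(stop, 1)[0]
    return answer.strip()


prompt = build_prompt("Write a one-line greeting.")
print(repr(prompt))
print(extract_response(prompt + "Hello there!<eos>", eos_token="<eos>"))
```

For example, `prompt = build_prompt(instruction)` could replace the hand-written prompt string in the cell below, and `extract_response(tokenizer.decode(output), tokenizer.eos_token)` would return just the generated answer.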

In [None]:
FastLanguageModel.for_inference(model)

# Same template as training, including the newline after "### Response:"
prompt = "### Instruction:\nWrite the code for the file `index.html` in the codebase project.\n\n### Response:\n"
input_ids = mx.array(tokenizer.encode(prompt))

output = model.generate(input_ids = input_ids, max_new_tokens = 128)
print(tokenizer.decode(output))