# Fine-Tuning 101: LoRA on Apple Silicon

This notebook walks through fine-tuning **Google Gemma 3 1B** using **LoRA** (Low-Rank Adaptation) on a MacBook Pro with Apple Silicon. By the end, you'll have a model that's been fine-tuned on medical Q&A data, and you'll understand each step well enough to scale up to larger models on GPU.

**What we're doing:**
- Loading a 1B parameter model (~4GB)
- Adding tiny trainable LoRA adapters (~1.5M params, ~0.15% of the model)
- Training on 1,000 medical flashcard Q&A pairs
- Comparing before/after responses

**Time:** ~20 minutes end-to-end on M3 Pro (18GB RAM)

## 1. Environment Check

First, let's verify that MPS (Metal Performance Shaders) is available. This is Apple Silicon's GPU acceleration backend for PyTorch. We also set an environment variable so that any operations not yet supported on MPS fall back to CPU instead of crashing.

In [1]:
import os

# MPS doesn't support every PyTorch op yet. This flag makes unsupported ops
# fall back to CPU automatically instead of raising an error.
# Set it before importing torch so it is in place when PyTorch initializes.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Setting this ratio to 0.0 removes MPS's default memory limit, which helps avoid
# OOM during training (at the cost of potential system slowdown). Comment the
# line out to keep the default limit.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"MPS available: {torch.backends.mps.is_available()}")
print(f"MPS built: {torch.backends.mps.is_built()}")

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

PyTorch version: 2.10.0
MPS available: True
MPS built: True
Using device: mps


## 2. HuggingFace Login

Gemma is a **gated model** — you need to:
1. Have a HuggingFace account
2. Accept Google's terms at https://huggingface.co/google/gemma-3-1b-it
3. Create an access token at https://huggingface.co/settings/tokens

Create a `.env` file in this directory with your token:
```
HF_TOKEN=hf_your_token_here
```

In [2]:
from dotenv import load_dotenv
from huggingface_hub import login, whoami

load_dotenv()

token = os.environ.get("HF_TOKEN")
if not token:
    raise ValueError("HF_TOKEN not found. Create a .env file with: HF_TOKEN=hf_your_token_here")

login(token=token)
info = whoami()
print(f"Logged in as: {info['name']}")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in as: abicyclerider


## 3. Load Base Model

We're using **Gemma 3 1B Instruct** (`google/gemma-3-1b-it`), the smallest model in Google's Gemma family. It shares the same architecture as MedGemma, so what you learn here transfers directly.

**Why float32?** Two Apple Silicon constraints force our hand:
- **bfloat16** support on MPS is incomplete and depends on your macOS and PyTorch versions
- **float16** causes numerical issues with Gemma 3's architecture

So we use float32, which means ~4GB for the model weights. With LoRA overhead and training state, expect ~10-13GB total memory usage.
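
As a rough check on the ~4GB figure, weight memory is just parameter count times bytes per parameter. The sketch below uses the parameter count that the next cell prints (999,885,952). The helper name `weight_memory_gb` is made up for illustration, and the estimate covers weights only, not activations, gradients, or optimizer state.

```python
# Rough weight-memory estimate: parameters x bytes per parameter.
# Covers weights only; activations, gradients, and optimizer state are extra.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Return weight memory in GB (1 GB = 1e9 bytes) for the given dtype."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


NUM_PARAMS = 999_885_952  # value printed by model.num_parameters() in the next cell

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_memory_gb(NUM_PARAMS, dtype):.2f} GB")
```

The gap between this number and the ~10-13GB total comes from everything the estimate leaves out.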

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "google/gemma-3-1b-it"

print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model from {MODEL_ID} in float32...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float32,
    device_map="mps",
)

print(f"\nModel loaded on: {model.device}")
print(f"Total parameters: {model.num_parameters():,}")
print(f"Model dtype: {model.dtype}")

Loading tokenizer from google/gemma-3-1b-it...
Loading model from google/gemma-3-1b-it in float32...


Loading weights:   0%|          | 0/340 [00:00<?, ?it/s]


Model loaded on: mps:0
Total parameters: 999,885,952
Model dtype: torch.float32


Let's peek at the model architecture. The key thing to notice is the repeated `Gemma3DecoderLayer` blocks: each one contains attention layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and feed-forward layers (`gate_proj`, `up_proj`, `down_proj`). LoRA will target the attention projections.

In [4]:
# Show just the first decoder layer to see the structure
print(model.model.layers[0])

Gemma3DecoderLayer(
  (self_attn): Gemma3Attention(
    (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
    (k_proj): Linear(in_features=1152, out_features=256, bias=False)
    (v_proj): Linear(in_features=1152, out_features=256, bias=False)
    (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
    (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
    (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
  )
  (mlp): Gemma3MLP(
    (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
    (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
    (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
    (act_fn): GELUTanh()
  )
  (input_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_attention_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (pre_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
  (post_feedforward_layernorm): Gemma3RMSNorm((1152,), eps=1e-06)
)


## 4. Test Base Model

Before fine-tuning, let's see how the base model responds to medical questions. We'll save these responses and compare them to the fine-tuned model later.

In [5]:
test_questions = [
    "What are the main symptoms of Type 2 diabetes?",
    "What is the mechanism of action of metformin?",
    "What are the risk factors for developing hypertension?",
]


def generate_response(model, tokenizer, question, max_new_tokens=256):
    """Generate a response using the chat template."""
    messages = [{"role": "user", "content": question}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the generated tokens (skip the input)
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return response.strip()


print("=" * 60)
print("BASE MODEL RESPONSES")
print("=" * 60)

base_responses = []
for q in test_questions:
    print(f"\nQ: {q}")
    response = generate_response(model, tokenizer, q)
    base_responses.append(response)
    print(f"A: {response[:500]}")
    print("-" * 40)

BASE MODEL RESPONSES

Q: What are the main symptoms of Type 2 diabetes?
A: Okay, let's break down the main symptoms of Type 2 diabetes. It's really important to remember that many people with Type 2 diabetes have *no* noticeable symptoms at all, especially in the early stages. This is why regular check-ups are so crucial. However, when symptoms do appear, they can be varied and often mimic other conditions.

Here's a breakdown of the common symptoms, grouped by severity:

**1. Early & Subtle Symptoms (Often Overlooked):**

* **Increased Thirst (Polydipsia):** Feeling v
----------------------------------------

Q: What is the mechanism of action of metformin?
A: Metformin is a cornerstone medication for type 2 diabetes, and its mechanism of action is complex and still being actively researched. It’s not a single "magic bullet," but rather a multifaceted effect on glucose metabolism. Here's a breakdown of the current understanding:

**1. Primarily Reduces Glucose Production in the Liver:

## 5. Prepare Dataset

We'll use the **Medical Meadow Medical Flashcards** dataset — 33K medical Q&A pairs from HuggingFace Hub. For this tutorial, we only use 1,000 training examples and 200 for evaluation to keep training fast.

**Data format:** SFTTrainer expects conversations in the chat message format that matches the model's chat template. Each example becomes:
```
[{"role": "user", "content": <question>}, {"role": "assistant", "content": <answer>}]
```

In [6]:
from datasets import load_dataset

print("Loading medical flashcards dataset...")
raw_dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards", split="train")
print(f"Total examples: {len(raw_dataset):,}")

# Peek at the raw format
print(f"\nColumns: {raw_dataset.column_names}")
print(f"\nExample:")
print(f"  input: {raw_dataset[0]['input'][:200]}")
print(f"  output: {raw_dataset[0]['output'][:200]}")

Loading medical flashcards dataset...
Total examples: 33,955

Columns: ['input', 'output', 'instruction']

Example:
  input: What is the relationship between very low Mg2+ levels, PTH levels, and Ca2+ levels?
  output: Very low Mg2+ levels correspond to low PTH levels which in turn results in low Ca2+ levels.


In [7]:
# Format into chat messages for SFTTrainer
def format_to_chat(example):
    """Convert raw Q&A pair to chat message format."""
    messages = [
        {"role": "user", "content": example["input"]},
        {"role": "assistant", "content": example["output"]},
    ]
    return {"messages": messages}


# Shuffle and select subsets
shuffled = raw_dataset.shuffle(seed=42)
train_dataset = shuffled.select(range(1000)).map(format_to_chat)
eval_dataset = shuffled.select(range(1000, 1200)).map(format_to_chat)

print(f"Training examples: {len(train_dataset)}")
print(f"Evaluation examples: {len(eval_dataset)}")
print(f"\nFormatted example:")
print(train_dataset[0]["messages"])

Training examples: 1000
Evaluation examples: 200

Formatted example:
[{'content': 'What type of injury to the arm/elbow most often leads to supracondylar fractures?', 'role': 'user'}, {'content': 'Supracondylar fractures most often occur after hyperextension injuries of the arm/elbow.', 'role': 'assistant'}]


## 6. Configure LoRA

### What is LoRA?

**LoRA (Low-Rank Adaptation)** is a technique that makes fine-tuning practical by freezing all original model weights and injecting small trainable matrices into specific layers.

Instead of updating a weight matrix **W** (e.g., 2048×2048 ≈ 4.2M params), LoRA decomposes the update into two small matrices: **A** (2048×8) and **B** (8×2048). That's only ~32.8K params (2 × 2048 × 8 = 32,768) instead of ~4.2M, a 128x reduction.

The key hyperparameters:
- **rank (`r`)**: Size of the low-rank matrices. Higher = more capacity, more memory. We use 8.
- **alpha**: Scaling factor. Convention is 2× rank. We use 16.
- **target modules**: Which layers get LoRA adapters. We target the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`).
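
To make the roles of `r` and `alpha` concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer, `h = xW + (alpha / r) * (xA)B`. It uses toy sizes and plain lists rather than PyTorch, and the helper names (`matvec`, `lora_forward`) are made up for illustration. Initializing **B** to zero follows the LoRA paper's setup and is an assumption here, not something this notebook configures explicitly.

```python
import random


def matvec(x, M):
    """Row vector x (length n) times matrix M (n x m) -> list of length m."""
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]


def lora_forward(x, W, A, B, alpha, r):
    """h = x W + (alpha / r) * (x A) B, with W frozen and only A, B trainable."""
    base = matvec(x, W)
    update = matvec(matvec(x, A), B)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


random.seed(0)
d, r, alpha = 8, 2, 4  # toy sizes; the notebook uses r=8, alpha=16
W = [[random.gauss(0, 1) for _ in range(d)] for _ in range(d)]     # frozen, d x d
A = [[random.gauss(0, 0.01) for _ in range(r)] for _ in range(d)]  # trainable, d x r
B = [[0.0] * d for _ in range(r)]                                  # trainable, r x d (zero init)
x = [random.gauss(0, 1) for _ in range(d)]

# With B at zero the adapter contributes nothing, so the output equals the base layer.
assert lora_forward(x, W, A, B, alpha, r) == matvec(x, W)

full_params = d * d          # parameters in a full-rank update of W
lora_params = d * r + r * d  # parameters in A plus B
print(f"full update: {full_params} params, LoRA update: {lora_params} params")
```

Because the adapter starts as a no-op, training begins from exactly the base model's behavior and only moves away from it as **A** and **B** are updated.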

In [8]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,                           # Rank of the low-rank matrices
    lora_alpha=16,                 # Scaling factor (convention: 2x rank)
    target_modules=[               # Which layers to add LoRA to
        "q_proj",                  #   Query projection
        "k_proj",                  #   Key projection
        "v_proj",                  #   Value projection
        "o_proj",                  #   Output projection
    ],
    lora_dropout=0.05,             # Small dropout for regularization
    bias="none",                   # Don't train bias terms
    task_type="CAUSAL_LM",         # We're doing causal language modeling
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Show the parameter breakdown
model.print_trainable_parameters()

trainable params: 1,490,944 || all params: 1,001,376,896 || trainable%: 0.1489


The exact numbers depend on the model version, but the key point is the same — we're only training **~0.15%** of the parameters. That's the magic of LoRA.
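
The trainable count can be reproduced by hand from the layer shapes printed in section 3: each adapted `Linear(in, out)` gains `r * (in + out)` parameters. The sketch below assumes 26 decoder layers, a number this notebook never prints; it is the value that makes the arithmetic match the `print_trainable_parameters()` output. The helper name `lora_params` is made up for illustration.

```python
# (in_features, out_features) copied from the printed Gemma3DecoderLayer.
ATTN_SHAPES = {
    "q_proj": (1152, 1024),
    "k_proj": (1152, 256),
    "v_proj": (1152, 256),
    "o_proj": (1024, 1152),
}
NUM_LAYERS = 26  # assumption: not printed anywhere in this notebook
RANK = 8


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """One adapter adds an (in x r) matrix and an (r x out) matrix."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in ATTN_SHAPES.values())
total = per_layer * NUM_LAYERS
print(f"per layer: {per_layer:,}  total: {total:,}")
```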

## 7. Train

We use `SFTTrainer` (Supervised Fine-Tuning Trainer) from the `trl` library. It handles chat template formatting, packing, and the training loop.

**Training config rationale:**
- **batch_size=1** with **gradient_accumulation=2**: Keeps peak memory low while simulating batch_size=2
- **1 epoch**: ~500 steps, enough to see the model learn without overfitting
- **max_length=512**: Covers most flashcard Q&A pairs while keeping memory reasonable
- **learning_rate=1e-4**: Within the range commonly used for LoRA fine-tuning (1e-4 to 2e-4)
- **num_workers=0, pin_memory=False**: Required for MPS compatibility

In [9]:
from trl import SFTTrainer, SFTConfig

training_args = SFTConfig(
    # Output
    output_dir="./output",

    # Training duration
    num_train_epochs=1,
    max_steps=-1,                    # -1 means use num_train_epochs

    # Batch size & memory
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    max_length=512,

    # Memory optimization
    gradient_checkpointing=True,

    # Optimizer
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=25,
    lr_scheduler_type="cosine",

    # Logging & eval
    logging_steps=25,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,

    # MPS compatibility
    bf16=False,
    fp16=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,

    # Misc
    seed=42,
    report_to="none",                # No wandb/tensorboard for this tutorial
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

print(f"Training samples: {len(train_dataset)}")
print(f"Eval samples: {len(eval_dataset)}")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Steps per epoch: ~{len(train_dataset) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)}")
print(f"\nStarting training...")

Training samples: 1000
Eval samples: 200
Effective batch size: 2
Steps per epoch: ~500

Starting training...
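
The `SFTConfig` above combines `warmup_steps=25` with `lr_scheduler_type="cosine"`. The sketch below shows the general shape of such a schedule: linear warmup to the peak rate, then cosine decay to zero. It is a standalone approximation written for illustration, not the library's own code, and the function name `lr_at_step` is made up.

```python
import math


def lr_at_step(step, peak_lr=1e-4, warmup_steps=25, total_steps=500):
    """Linear warmup to peak_lr, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 25, 250, 500):
    print(f"step {step:>3}: lr = {lr_at_step(step):.2e}")
```

The learning rate peaks right after warmup, so the largest updates happen early in the run.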


In [10]:
# This is the actual training cell — takes ~15-20 min on M3 Pro
train_result = trainer.train()

print(f"\nTraining complete!")
print(f"Total steps: {train_result.global_step}")
print(f"Final training loss: {train_result.training_loss:.4f}")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1}.


Step,Training Loss
25,5.182949
50,3.049642
75,1.928152
100,1.739538
125,1.72163
150,1.604375
175,1.670352
200,1.578549
225,1.579519
250,1.522842



Training complete!
Total steps: 500
Final training loss: 1.7733


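You should see the training loss fall sharply during the first ~100 steps and then flatten out. In the run above (the log shows steps 25 to 250 of 500):
- Step 25: loss ~5.2
- Step 100: loss ~1.7
- Step 250: loss ~1.5

Note that `train_result.training_loss` (1.7733 here) is an average over the whole run rather than the loss at the last step, so it sits above the most recent logged values.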

If your loss isn't decreasing, something went wrong with the data formatting or config.
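
One simple way to turn "is the loss decreasing?" into a check is to compare the first few logged values against the last few. The sketch below uses the losses copied from the table above. The helper name `loss_is_decreasing` and the window size are made up for illustration. In a live session the same values could probably be collected from `trainer.state.log_history`.

```python
# Logged training losses copied from the table above (steps 25 to 250).
logged = [5.182949, 3.049642, 1.928152, 1.739538, 1.72163,
          1.604375, 1.670352, 1.578549, 1.579519, 1.522842]


def loss_is_decreasing(losses, window=3):
    """Return True if the mean of the last `window` values is below the first `window`."""
    if len(losses) < 2 * window:
        raise ValueError("need at least 2 * window logged values")
    early = sum(losses[:window]) / window
    late = sum(losses[-window:]) / window
    return late < early


print(loss_is_decreasing(logged))
```

Averaging over a window keeps a single noisy step (such as the small bump at step 175) from flipping the result.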

## 8. Evaluate: Base vs Fine-Tuned

Now let's ask the same 3 questions and compare the responses. The fine-tuned model should give more focused, medical-flashcard-style answers.

In [14]:
# Disable gradient checkpointing for inference — it conflicts with the KV cache
# that generate() uses for autoregressive decoding
model.gradient_checkpointing_disable()

print("=" * 60)
print("COMPARISON: BASE vs FINE-TUNED")
print("=" * 60)

for i, q in enumerate(test_questions):
    print(f"\nQ: {q}")
    print(f"\n[BASE MODEL]:")
    print(f"{base_responses[i][:500]}")
    print(f"\n[FINE-TUNED]:")
    ft_response = generate_response(model, tokenizer, q)
    print(f"{ft_response[:500]}")
    print("=" * 60)

COMPARISON: BASE vs FINE-TUNED

Q: What are the main symptoms of Type 2 diabetes?

[BASE MODEL]:
Okay, let's break down the main symptoms of Type 2 diabetes. It's really important to remember that many people with Type 2 diabetes have *no* noticeable symptoms at all, especially in the early stages. This is why regular check-ups are so crucial. However, when symptoms do appear, they can be varied and often mimic other conditions.

Here's a breakdown of the common symptoms, grouped by severity:

**1. Early & Subtle Symptoms (Often Overlooked):**

* **Increased Thirst (Polydipsia):** Feeling v

[FINE-TUNED]:
Type 2 diabetes can be diagnosed by looking for symptoms like increased thirst, frequent urination, and fatigue. These symptoms can occur in individuals with Type 2 diabetes, as well as in those who have a family history of the condition. Type 2 diabetes is a chronic condition that is characterized by high blood sugar levels, which can lead to complications such as kidney disease, ner

## 9. Save & Load Adapter

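One of the best things about LoRA: the adapter weights are **tiny**. The full model is ~4GB, but the LoRA weights themselves are only ~6 MB (1.49M float32 parameters). The saved directory below is larger (under 40 MB) because the tokenizer files are written to the same path. You can save multiple fine-tuned versions without duplicating the base model.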

The pattern for deployment:
1. Save adapter weights (few MB)
2. Load base model
3. Load adapter on top
4. Optionally merge adapter into base model for faster inference
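
Step 4 is not run in this notebook. A minimal sketch of it, assuming PEFT's `merge_and_unload()` method and a hypothetical output path, might look like the function below. The imports sit inside the function so that defining it does not load any model, and `dtype=torch.float32` mirrors the loading call used in section 3.

```python
def merge_adapter(base_model_id: str, adapter_path: str, output_path: str) -> None:
    """Fold LoRA weights into the base model and save a standalone checkpoint."""
    # Imports are local so that defining this function has no heavy side effects.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(base_model_id, dtype=torch.float32)
    peft_model = PeftModel.from_pretrained(base, adapter_path)
    merged = peft_model.merge_and_unload()  # base model with the LoRA update folded in
    merged.save_pretrained(output_path)


# Hypothetical usage (the output path is an example, not something the notebook creates):
# merge_adapter("google/gemma-3-1b-it", "./output/medical-flashcards-adapter", "./output/merged-model")
```

The merged checkpoint is full-size (~4GB), so merging trades the adapter's small footprint for simpler loading at inference time.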

In [15]:
# Save the LoRA adapter
adapter_path = "./output/medical-flashcards-adapter"
model.save_pretrained(adapter_path)
tokenizer.save_pretrained(adapter_path)

# Show how small the adapter is
import pathlib
adapter_size = sum(
    f.stat().st_size for f in pathlib.Path(adapter_path).rglob("*") if f.is_file()
)
print(f"Adapter saved to: {adapter_path}")
print(f"Adapter size: {adapter_size / 1024 / 1024:.1f} MB")
print(f"(vs ~4,000 MB for the full model)")

Adapter saved to: ./output/medical-flashcards-adapter
Adapter size: 37.6 MB
(vs ~4,000 MB for the full model)


In [16]:
# Demonstrate loading the adapter onto a fresh base model
from peft import PeftModel

print("Loading fresh base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.float32,
    device_map="mps",
)

print("Loading LoRA adapter on top...")
loaded_model = PeftModel.from_pretrained(base_model, adapter_path)

# Verify it works
response = generate_response(loaded_model, tokenizer, test_questions[0])
print(f"\nQ: {test_questions[0]}")
print(f"A: {response[:300]}")
print("\nAdapter loaded and working!")

Loading fresh base model...


Loading weights:   0%|          | 0/340 [00:00<?, ?it/s]

Loading LoRA adapter on top...

Q: What are the main symptoms of Type 2 diabetes?
A: Type 2 diabetes is characterized by high blood sugar levels, as well as the presence of high cholesterol levels.

Adapter loaded and working!


## 10. Next Steps

You've completed a full LoRA fine-tuning loop! Here's how to scale up:

### Immediate improvements
- **More data**: Use all 33K flashcards instead of 1K
- **Higher LoRA rank**: Try `r=16` or `r=32` for more capacity
- **More target modules**: Add `gate_proj`, `up_proj`, `down_proj` (feed-forward layers)
- **Multiple epochs**: 2-3 epochs with early stopping

### Scaling to GPU (RunPod)
- **QLoRA**: Use `bitsandbytes` to load the model in 4-bit, then apply LoRA. Needs CUDA GPU.
- **Larger models**: Gemma 3 4B, 12B, or 27B with QLoRA on an A100/H100
- **MedGemma**: Built on the same Gemma 3 architecture, so this workflow should carry over to `google/medgemma-4b-it` or `google/medgemma-27b-it` on a GPU with enough VRAM, though the model-loading code may need small changes

### For entity resolution
- Fine-tune MedGemma on your entity resolution training pairs
- Format: patient record pairs → match/no-match with reasoning
- The same LoRA approach works — just change the dataset and possibly the target modules
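
As a sketch of what that dataset change could look like, the function below converts one labeled record pair into the same `messages` format used in section 5. The field names (`record_a`, `record_b`, `is_match`, `reasoning`), the prompt wording, and the sample record are all made up for illustration. Adapt them to your actual data.

```python
def format_pair_to_chat(example: dict) -> dict:
    """Turn one labeled record pair into the chat format SFTTrainer expects."""
    prompt = (
        "Do these two patient records refer to the same person?\n\n"
        f"Record A: {example['record_a']}\n"
        f"Record B: {example['record_b']}"
    )
    label = "match" if example["is_match"] else "no-match"
    answer = f"{label}\nReasoning: {example['reasoning']}"
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ]
    }


# Made-up record pair, for illustration only.
sample = {
    "record_a": "Jane Doe, DOB 1980-01-02, 12 Oak St",
    "record_b": "J. Doe, DOB 1980-01-02, 12 Oak Street",
    "is_match": True,
    "reasoning": "Same date of birth and address; the name differs only by abbreviation.",
}
print(format_pair_to_chat(sample)["messages"][1]["content"])
```

Putting the label on its own first line keeps the model's verdict easy to parse, with the reasoning following it.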

### Key config changes for GPU
```python
# QLoRA on CUDA GPU (requires the bitsandbytes package)
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/medgemma-27b-it",
    quantization_config=bnb_config,
    device_map="auto",
)
```