# VLM Notebook 2: Fine-tuning Vision-Language Models with LoRA


## 1. Introduction to LoRA

### What is LoRA?

**LoRA (Low-Rank Adaptation)** is a parameter-efficient fine-tuning method that:

- Freezes the pre-trained model weights
- Injects trainable low-rank decomposition matrices into each layer
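- Cuts the number of trainable parameters dramatically (the original LoRA paper reports up to 10,000× fewer for GPT-3 175B) while matching full fine-tuning quality on many tasks
- Makes fine-tuning feasible on consumer GPUs, especially when combined with a 4-bit quantized base model (QLoRA)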

### Mathematical Foundation

For a pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$:

$$W = W_0 + \Delta W = W_0 + BA$$

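Where:
- $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, so $BA$ has the same $d \times k$ shape as $W_0$
- $r \ll \min(d, k)$ (the rank is much smaller than either dimension)
- Only $B$ and $A$ are trained; $W_0$ stays frozen
- In practice the update is scaled by $\alpha / r$, giving $W = W_0 + \frac{\alpha}{r} BA$, where $\alpha$ is the `lora_alpha` hyperparameter
- $B$ is initialized to zero (and $A$ randomly), so $\Delta W = 0$ at the start and training begins from the pre-trained model's behavior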
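To make the decomposition concrete, here is a minimal NumPy sketch of a single LoRA-adapted linear layer. The shapes, the $\alpha / r$ scaling, and the zero initialization of $B$ follow the LoRA formulation; the class name `LoRALinear` and all sizes are illustrative assumptions, not PEFT's actual implementation.

```python
import numpy as np


class LoRALinear:
    """Illustrative LoRA layer: y = W0 x + (alpha / r) * B A x, with W0 frozen."""

    def __init__(self, d, k, r, alpha, seed=0):
        rng = np.random.default_rng(seed)
        self.W0 = rng.standard_normal((d, k))        # frozen pre-trained weight (d x k)
        self.A = rng.standard_normal((r, k)) * 0.01  # trainable, small random init (r x k)
        self.B = np.zeros((d, r))                    # trainable, zero init => delta W = 0
        self.scaling = alpha / r

    def delta_w(self):
        # Low-rank update with the same shape as W0
        return self.scaling * (self.B @ self.A)

    def forward(self, x):
        # Two thin matmuls instead of materializing the full d x k update
        return self.W0 @ x + self.scaling * (self.B @ (self.A @ x))

    def num_trainable(self):
        return self.A.size + self.B.size


d, k, r = 1024, 1024, 8
layer = LoRALinear(d, k, r, alpha=16)
x = np.random.default_rng(1).standard_normal(k)

# At initialization B = 0, so the adapted layer reproduces the frozen layer exactly
print(np.allclose(layer.forward(x), layer.W0 @ x))  # True

# Trainable parameters: r * (d + k) versus d * k for the full matrix
print(layer.num_trainable(), layer.W0.size)  # 16384 1048576
```

Computing `B @ (A @ x)` costs on the order of $r(d + k)$ operations per input vector, which is why the adapter adds little compute on top of the frozen matmul.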

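### Why LoRA for VLMs?

The figures below are rough, order-of-magnitude comparisons for a multi-billion-parameter VLM; actual numbers depend on model size, quantization, batch size, and sequence length.

| Aspect | Full Fine-tuning | LoRA |
|--------|-----------------|------|
| **Trainable parameters** | Billions (all weights) | Millions (adapter weights) |
| **GPU memory** | 80GB+ | ~16GB (e.g. with a 4-bit base model) |
| **Training time** | Days | Hours |
| **Storage** | GBs per task | MBs per task (adapter weights only) |
| **Task performance** | Baseline | Often close to full fine-tuning |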

---

## 2. Setup and Installation

### Install Required Libraries

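```python
# Optional: only needed if PyTorch is missing or built for a different CUDA version;
# pick the index URL that matches your CUDA toolkit
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Recent transformers/TRL/PEFT releases are needed for Gemma 3 and the APIs used below
!pip install -q -U transformers accelerate trl bitsandbytes peft datasets tensorboard
!pip install -q pillow requests matplotlib
```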

### Import Dependencies

```python
import torch
import numpy as np
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os
os.environ["WANDB_DISABLED"] = "true"

# Transformers and PEFT
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import LoraConfig

# Datasets
from datasets import load_dataset

# TRL for SFTTrainer
from trl import SFTTrainer, SFTConfig

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
```

---

## 3. Dataset Preparation


### Load Dataset

```python
from collections import Counter

# Load the first 1% of the VQAv2 validation split (a small subset for a quick run)
dataset = load_dataset("lmms-lab/VQAv2", split="validation[:1%]")

# Helper to get the most frequent annotator answer
def get_best_answer(example):
    answers = [ans["answer"] for ans in example["answers"]]
    return Counter(answers).most_common(1)[0][0]

# Format as chat: user (image + question) → assistant (answer)
def format_chat(example):
    best_answer = get_best_answer(example)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": example["question"]}
            ]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": best_answer}]
        }
    ]
    # Return the messages and keep the original image column in the dataset
    return {"messages": messages, "image": example["image"]}

# Apply formatting
formatted_dataset = dataset.map(format_chat, remove_columns=dataset.column_names)

# Split into train/eval (80/20) with the datasets library's train_test_split method
split_dataset = formatted_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")
print(f"Sample train data keys: {train_dataset[0].keys()}")
```
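`get_best_answer` picks the most frequent annotator answer. One subtlety: when answers tie, `Counter.most_common` returns them in the order they were first encountered, so ties resolve to whichever answer appears first in the list. Strings are also compared verbatim, so "2" and "two" count as different answers. A small stand-alone check, using hand-written toy examples rather than real VQAv2 rows:

```python
from collections import Counter


def get_best_answer(example):
    # Same logic as the notebook: the most frequent answer string wins
    answers = [ans["answer"] for ans in example["answers"]]
    return Counter(answers).most_common(1)[0][0]


# Toy examples in the VQAv2 "answers" layout (hypothetical data)
clear_majority = {"answers": [{"answer": "2"}, {"answer": "two"}, {"answer": "2"}]}
tie = {"answers": [{"answer": "red"}, {"answer": "maroon"},
                   {"answer": "maroon"}, {"answer": "red"}]}

print(get_best_answer(clear_majority))  # 2
print(get_best_answer(tie))             # red (first-seen answer wins the tie)
```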

### Visualize Dataset Samples

```python
from collections import Counter

# Take the first raw (unformatted) example
sample = dataset[0]

# Display the image
image = sample["image"]  # Already a PIL Image in lmms-lab/VQAv2
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.axis("off")
plt.title(f"Question: {sample['question']}", fontsize=14, pad=20)
plt.show()

# Print the question and each distinct annotator answer with its count
print(f"Question: {sample['question']}")
print(f"Question ID: {sample['question_id']}")
print("\nAnswers:")
answer_counts = Counter(ans["answer"] for ans in sample["answers"])
for answer, count in answer_counts.most_common():
    print(f"  • {answer} ({count})")
```

---

## 4. Model Preparation

### Load Model and Processor

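```python
model_id = "google/gemma-3-4b-it"  # Gemma 3 1B is text-only, so 4B is the smallest multimodal size
# Gemma is a gated model: accept the license on the Hugging Face Hub and log in first

# 4-bit NF4 quantization (QLoRA-style) for the frozen base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store base weights in 4 bits
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # de-quantize to bf16 for matmuls
    bnb_4bit_quant_storage=torch.bfloat16,  # dtype used to store the packed 4-bit weights
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.bfloat16,  # older transformers versions name this argument torch_dtype
    attn_implementation="eager",
    quantization_config=bnb_config,
)

processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"  # right-pad sequences for training
```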
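A back-of-the-envelope estimate shows why 4-bit loading matters: weight memory is roughly parameters × bits per parameter / 8. This sketch ignores activations, optimizer state, the LoRA adapters, and quantization overheads, and it assumes exactly 4 billion parameters, which only approximates `gemma-3-4b-it`. The helper name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


n = 4e9  # approximate parameter count (assumption)
for label, bits in [("fp32", 32), ("bf16", 16), ("nf4 (4-bit)", 4)]:
    print(f"{label:>12}: ~{weight_memory_gb(n, bits):.1f} GB")
```

This gives roughly 16, 8, and 2 GB respectively. Real NF4 usage is a bit higher because each block of weights also stores quantization constants, which `bnb_4bit_use_double_quant=True` shrinks by quantizing those constants too.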

### Configure LoRA

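```python
# LoRA configuration
peft_config = LoraConfig(
    r=16,                         # rank of the A/B matrices
    lora_alpha=16,                # update is scaled by lora_alpha / r = 1.0
    lora_dropout=0.05,            # dropout on the adapter input
    bias="none",                  # do not train bias terms
    target_modules="all-linear",  # adapters on every linear layer (PEFT excludes the output layer)
    task_type="CAUSAL_LM",
    # These layers are trained as full copies, not low-rank adapters; with Gemma 3's
    # large vocabulary they add far more trainable parameters than the adapters do
    modules_to_save=["lm_head", "embed_tokens"],
    ensure_weight_tying=True,     # keep the tied lm_head / embed_tokens copies in sync
)
```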

### Understanding LoRA Parameters

| Parameter | Description | Typical Values |
|-----------|-------------|----------------|
| **r** | Rank of the adaptation matrices $A$ and $B$ | 8, 16, 32, 64 |
| **lora_alpha** | Scaling numerator: the update is multiplied by `lora_alpha / r` (often set to r or 2×r) | 16, 32, 64 |
| **target_modules** | Which layers receive adapters | `q_proj`, `k_proj`, `v_proj`, or `"all-linear"` |
| **lora_dropout** | Dropout applied to the adapter input (regularization) | 0.05, 0.1 |

**Trade-offs:**
- Higher `r` → more adapter capacity, but trainable parameters grow linearly with `r`
- More `target_modules` → often better quality, but more trainable parameters, memory, and compute per step
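Because each adapted `d_out × d_in` matrix adds `r × (d_out + d_in)` trainable parameters, the total scales linearly with both `r` and the number of targeted layers. The sketch below tallies this for one hypothetical decoder block; the projection names mirror common naming, but the shapes are made up and are not Gemma 3's actual dimensions. It also leaves out anything in `modules_to_save`, which is trained in full.

```python
def lora_params(shapes, r):
    """Trainable LoRA parameters for an iterable of (d_out, d_in) weight shapes."""
    return sum(r * (d_out + d_in) for d_out, d_in in shapes)


# Hypothetical projection shapes (d_out, d_in) for one transformer block
attn_qv = {"q_proj": (2048, 2048), "v_proj": (512, 2048)}
all_linear = {
    **attn_qv,
    "k_proj": (512, 2048),
    "o_proj": (2048, 2048),
    "gate_proj": (8192, 2048),
    "up_proj": (8192, 2048),
    "down_proj": (2048, 8192),
}

for r in (8, 16, 32):
    qv = lora_params(attn_qv.values(), r)
    full = lora_params(all_linear.values(), r)
    print(f"r={r:>2}: q/v only = {qv:,}  all-linear = {full:,}")
```

Doubling `r` doubles both counts, while switching from q/v-only to all linear layers multiplies the count by the number and size of the extra projections.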

---

## 5. Training Setup

### Define Collate Function

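```python
def collate_fn(examples):
    texts = []
    images = []

    for ex in examples:
        # Apply the chat template to the full conversation (user + assistant turns for SFT)
        text = processor.apply_chat_template(
            ex["messages"],
            tokenize=False,
            add_generation_prompt=False,  # No trailing generation prompt during training
        ).strip()
        texts.append(text)
        # One list of images per sample, so batches larger than 1 pair up correctly
        images.append([ex["image"].convert("RGB")])

    # The processor tokenizes the text and expands each image into image tokens
    batch = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,  # Tune for your VRAM; must be long enough that image tokens are never cut
    )

    # Labels: copy input_ids, then mask padding and image tokens with -100.
    # Image tokens are NOT ignored automatically, so they must be masked explicitly.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Gemma 3 image special tokens (token strings assumed from the Gemma 3 tokenizer)
    for token in ("<start_of_image>", "<image_soft_token>", "<end_of_image>"):
        token_id = processor.tokenizer.convert_tokens_to_ids(token)
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch
```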
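The label-masking idea can be illustrated without a model or processor. This plain-Python sketch uses made-up token IDs: it masks padding and image tokens with `-100` (the default `ignore_index` of PyTorch's cross-entropy loss) and can optionally mask everything before the assistant's answer so the loss covers only the response. That second option is a common variant (completion-only training), not something `collate_fn` above does; the helper `build_labels` and all IDs are hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(input_ids, pad_id, ignore_ids, response_start=None):
    """Copy input_ids, masking pads, ignored tokens, and (optionally) the prompt."""
    labels = []
    for pos, tok in enumerate(input_ids):
        if tok == pad_id or tok in ignore_ids:
            labels.append(IGNORE_INDEX)
        elif response_start is not None and pos < response_start:
            labels.append(IGNORE_INDEX)  # prompt token: excluded in completion-only mode
        else:
            labels.append(tok)
    return labels


# Toy sequence: [bos, img, img, img, q1, q2, ans1, ans2, pad, pad]
PAD, IMG = 0, 99
ids = [2, IMG, IMG, IMG, 11, 12, 21, 22, PAD, PAD]

print(build_labels(ids, PAD, {IMG}))
# [2, -100, -100, -100, 11, 12, 21, 22, -100, -100]
print(build_labels(ids, PAD, {IMG}, response_start=6))
# [-100, -100, -100, -100, -100, -100, 21, 22, -100, -100]
```

In the real collate function the same idea is applied to tensors with boolean indexing (`labels[labels == token_id] = -100`); a completion-only variant would additionally need to locate where each sample's assistant turn begins.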

### Configure Training Arguments

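```python
training_args = SFTConfig(
    output_dir="gemma-3-4b-vqa-finetuned",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,   # Keep eval batches small to fit in memory
    gradient_accumulation_steps=4,  # Effective batch size = 1 x 4 per GPU
    gradient_checkpointing=True,    # Trade extra compute for lower activation memory
    gradient_checkpointing_kwargs={"use_reentrant": False},
    optim="adamw_torch_fused",
    learning_rate=2e-05,            # Conservative for LoRA; 1e-4 to 2e-4 is also common
    bf16=True,
    save_strategy="epoch",
    eval_strategy="epoch",          # Actually evaluate on eval_dataset at the end of each epoch
    logging_steps=10,               # Log training loss every 10 optimizer steps
    report_to="none",
    dataset_kwargs={"skip_prepare_dataset": True},  # collate_fn handles tokenization
    remove_unused_columns=False,    # Keep "messages" and "image" for collate_fn
)
```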
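With gradient accumulation, the optimizer steps once every `per_device_train_batch_size × gradient_accumulation_steps × number_of_GPUs` examples. A quick sketch of the resulting step count follows; the training-set size is a placeholder (substitute `len(train_dataset)`), the helper name is hypothetical, and the Trainer's own count may differ slightly in how it rounds at epoch boundaries.

```python
import math


def optimizer_steps_per_epoch(num_examples, per_device_batch, grad_accum, num_gpus=1):
    """Approximate (effective batch size, optimizer steps per epoch)."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    return effective_batch, math.ceil(num_examples / effective_batch)


# Batch settings from the SFTConfig above; num_examples is a placeholder
effective, steps = optimizer_steps_per_epoch(num_examples=1000, per_device_batch=1, grad_accum=4)
print(f"effective batch = {effective}, optimizer steps/epoch ≈ {steps}")
```

With `logging_steps=10`, that placeholder size would yield roughly 25 loss log entries per epoch.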

### Set Up SFTTrainer

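```python
# SFTTrainer wraps the quantized model with the LoRA adapters defined in peft_config
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=processor,
)

# Inspect how many parameters are trainable (adapters plus modules_to_save)
trainer.model.print_trainable_parameters()
```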

---

## 6. Train the Model

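```python
trainer.train()

# Save the trained adapter (plus modules_to_save) to training_args.output_dir
trainer.save_model()
```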