<a href="https://colab.research.google.com/github/LennoxC/gemma3-supplements-small-finetune/blob/main/gemma3_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gemma3-4b fine-tune for food and beverage label OCR-VQA

In [None]:
#%pip install torch tensorboard
#%pip install transformers datasets accelerate evaluate trl protobuf sentencepiece
#%pip install flash-attn # only if GPU supports flashAttention (nvidia Ampere)

In [None]:
from google.colab import userdata

In [None]:
import os

In [None]:
# Read the Hugging Face token from Colab secrets and export it as HF_TOKEN, the
# environment variable huggingface_hub reads, so Hub downloads and the
# push_to_hub at the end of training are authenticated.
hf_token = userdata.get('HF_TOKEN')
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

In [None]:
# Mount Google Drive: the JSONL annotations and label images are read from
# /content/drive/MyDrive further down.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Run configuration.
base_model = "google/gemma-3-4b-it"  # Hub id the processor is loaded from
# NOTE(review): checkpoint_dir and learning_rate are not referenced by any visible cell —
# the SFTConfig cell hard-codes output_dir="gemma-supplements-small" and learning_rate=2e-4.
# The "MyGemmaNPC" folder name looks like a leftover from another project; confirm intent.
checkpoint_dir = "/content/drive/MyDrive/MyGemmaNPC"
learning_rate = 5e-5

In [None]:
# Prompts used to build every training conversation.
# system_message is the system turn; user_prompt is prefixed to each sample's
# question list by OCRVQADataset, which appends the closing </QUESTIONS> tag
# itself — so the prompt intentionally ends with an open <QUESTIONS> tag.
system_message = "You are a quality control robot responsible for monitoring the quality of supplement labels."
user_prompt = """Using primarily the text contained in the attached supplement label image, answer the list of questions in the <QUESTIONS> tags.
Answer concisely in a JSON format with no preamble, allowing the response to easily be parsed. An example response would be:
{
  "brand": "label supplements co",
  "contents": 120
}

<QUESTIONS>
"""

In [None]:
def format_data(sample, system_prompt=None):
    """Convert one OCRVQADataset sample into a chat-style training conversation.

    Args:
        sample: dict with "questions" (the complete user prompt text), "image"
            (passed through unchanged, e.g. a PIL image) and "answers" (the
            target JSON string).
        system_prompt: optional override for the system turn; when None the
            module-level ``system_message`` is used.

    Returns:
        ``{"messages": [system, user, assistant]}``. The user turn holds the
        question text followed by the image entry; the assistant turn holds the
        answers text.
    """
    if system_prompt is None:
        system_prompt = system_message
    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            },
            {
                "role": "user",
                "content": [
                    # "questions" already starts with user_prompt (OCRVQADataset
                    # builds it), so it is used verbatim rather than re-formatted.
                    {"type": "text", "text": sample["questions"]},
                    {"type": "image", "image": sample["image"]},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["answers"]}],
            },
        ],
    }

In [None]:
import json, random
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class OCRVQADataset(Dataset):
    """OCR-VQA samples for supplement labels, read from a JSONL annotation file.

    Each non-blank JSONL line is an object with an ``"image"`` file name (joined
    onto ``image_dir``) and a ``"qas"`` list of dicts with keys ``"q"`` (question
    text), ``"k"`` (JSON key for the answer) and ``"a"`` (answer). Every
    ``__getitem__`` call draws a random subset of the QA pairs and returns
    ``{"image": PIL RGB image, "questions": prompt string, "answers": JSON string}``.
    """

    def __init__(
        self,
        jsonl_file,
        transform=None,
        min_q=1,
        max_q=4,
        image_dir="/content/drive/MyDrive/datasets/small-supplements-ocrvqa/images/",
        seed=None,
    ):
        """Load all annotations into memory.

        Args:
            jsonl_file: path to the JSONL file (read as UTF-8; blank lines skipped).
            transform: stored as ``self.transform`` (defaults to ``ToTensor()``).
                NOTE(review): ``__getitem__`` never applies it — images stay PIL.
            min_q: minimum number of QA pairs drawn per item (clamped to the number
                of pairs the sample actually has).
            max_q: maximum number of QA pairs drawn per item.
            image_dir: directory containing the label images.
            seed: if not None, question sampling uses a private
                ``random.Random(seed)`` so draws are reproducible; otherwise the
                global ``random`` module is used.
        """
        with open(jsonl_file, "r", encoding="utf-8") as f:
            self.samples = [json.loads(line) for line in f if line.strip()]
        self.transform = transform or transforms.ToTensor()
        self.min_q = min_q
        self.max_q = max_q
        self.image_dir = image_dir
        # `random` module and `random.Random` expose the same randint/sample API.
        self._rng = random.Random(seed) if seed is not None else random

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample with a freshly drawn subset of its questions.

        Raises:
            ValueError: if the sample has no QA pairs.
        """
        sample = self.samples[idx]

        # Choose the questions before touching the image so malformed samples
        # fail fast with a clear message.
        qa_pairs = sample["qas"]
        if not qa_pairs:
            raise ValueError(f"Sample {idx} (image {sample.get('image')!r}) has no QA pairs")
        upper = min(self.max_q, len(qa_pairs))
        # Clamp the lower bound too: with fewer pairs than min_q, randint() would
        # otherwise be called with an empty range.
        lower = min(self.min_q, upper)
        k = self._rng.randint(lower, upper)
        chosen_pairs = self._rng.sample(qa_pairs, k)

        questions_str = (
            user_prompt
            + "; ".join(f"Question: {p['q']} This corresponds to JSON key {p['k']}" for p in chosen_pairs)
            + "</QUESTIONS>"
        )
        answers_dict = {p['k']: p['a'] for p in chosen_pairs}
        answers_str = json.dumps(answers_dict, ensure_ascii=False)

        # Keep as PIL so process_vision_info works later. The context manager
        # closes the file handle once the (fully loaded) RGB copy exists.
        image_path = os.path.join(self.image_dir, sample["image"])
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        return {
            "image": image,  # PIL
            "questions": questions_str,
            "answers": answers_str,
        }

In [None]:
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    """Collect, in order, every image referenced by a chat-style message list.

    A content element counts as an image entry when it is a dict that either
    has an "image" key or declares ``"type": "image"``. String values are
    treated as file paths and opened as RGB; PIL images are passed through.

    Raises:
        ValueError: if an image entry holds anything other than a path or a PIL image.
    """
    collected: list[Image.Image] = []
    for message in messages:
        parts = message.get("content", [])
        if not isinstance(parts, list):
            parts = [parts]

        image_parts = (
            part
            for part in parts
            if isinstance(part, dict) and ("image" in part or part.get("type") == "image")
        )
        for part in image_parts:
            candidate = part.get("image", part)
            if isinstance(candidate, str):
                collected.append(Image.open(candidate).convert("RGB"))
            elif isinstance(candidate, Image.Image):
                collected.append(candidate)
            else:
                raise ValueError(f"Unsupported image type: {type(candidate)}")
    return collected


In [None]:
# Annotation file: one JSON object per line (image file name + QA pairs).
dataset_path = "/content/drive/MyDrive/datasets/small-supplements-ocrvqa/output.jsonl"
dataset_obj = OCRVQADataset(jsonl_file=dataset_path)

In [None]:
import torch
from torch.utils.data import random_split

# Split sizes: 80% train / 10% validation / remainder test. The test split
# absorbs the int() truncation so the three sizes always sum to the dataset length.
SEED = 42
total_size = len(dataset_obj)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

# OCRVQADataset.__getitem__ (by default) draws its question subsets with Python's
# global `random`, so seed it as well; otherwise the formatted training examples
# differ from run to run even though the split itself is fixed.
random.seed(SEED)

# Use a generator with a manual seed for a reproducible split
generator = torch.Generator().manual_seed(SEED)

In [None]:
train_dataset, val_dataset, test_dataset = random_split(
    dataset_obj, [train_size, val_size, test_size], generator=generator
)

# Materialise the chat-formatted training set once. Each Subset access runs
# OCRVQADataset.__getitem__ (image load + random question draw), so the examples,
# including which questions each one asks, are fixed for the whole training run.
# NOTE(review): val_dataset / test_dataset are not used by any visible cell.
train_dataset_fmt = [format_data(sample) for sample in train_dataset]
assert train_dataset_fmt, "training split is empty; check the dataset file"

# Preview one example; clamp the index so small datasets (<= 100 training rows) don't raise.
preview_idx = min(100, len(train_dataset_fmt) - 1)
print(train_dataset_fmt[preview_idx])

In [None]:
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Hugging Face model id. Reuse the configured base model so the model and the
# processor below are always loaded from the same checkpoint.
# Other options: `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`.
model_id = base_model

# SFTConfig trains with bf16=True, so load the weights in bfloat16 as well when the
# GPU supports it; loading float16 weights would mismatch the training precision
# (and the fully-trained lm_head / embed_tokens copies would be updated in fp16).
# float16 is only a fallback for GPUs without bf16 support.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model_dtype = torch.bfloat16
else:
    model_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=model_dtype,
    device_map="auto",            # Let accelerate place the model on the available device(s)
)

# Optional BitsAndBytesConfig int-4 (QLoRA) config — uncomment to load the model in 4-bit.
#model_kwargs["quantization_config"] = BitsAndBytesConfig(
#    load_in_4bit=True,
#    bnb_4bit_use_double_quant=True,
#    bnb_4bit_quant_type="nf4",
#    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
#    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
#)

# Load model and processor. The token is passed explicitly so authenticated
# downloads do not depend on how the environment is configured.
model = AutoModelForImageTextToText.from_pretrained(model_id, token=hf_token, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

In [None]:
# NOTE(review): installed mid-notebook and unpinned; consider moving this to the setup
# cell at the top (next to the other %pip lines) and pinning a version.
%pip install peft

In [None]:
from peft import LoraConfig

# LoRA adapters (rank 16, alpha 16) on every linear layer. lm_head and
# embed_tokens are listed in modules_to_save, so they are trained and saved
# alongside the adapter weights.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
from trl import SFTConfig

# NOTE(review): the `learning_rate` / `checkpoint_dir` values from the config cell are
# not used here; this cell hard-codes its own. Confirm which ones are intended.
args = SFTConfig(
    output_dir="gemma-supplements-small",       # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # optimizer step every 4 batches (effective batch 4 per device)
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    # "constant" ignores warmup_ratio entirely; "constant_with_warmup" ramps up over the
    # first 3% of steps and then holds the learning rate constant.
    lr_scheduler_type="constant_with_warmup",
    push_to_hub=True,                           # push model to hub
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use non-reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # the custom collator does all tokenisation
)
# Keep the "messages" field that collate_fn reads instead of letting the Trainer drop it.
args.remove_unused_columns = False

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    """Tokenise a batch of chat-formatted examples for supervised fine-tuning.

    Args:
        examples: list of dicts with a "messages" key (as produced by format_data).

    Returns:
        The processor's batch for the rendered chat texts and their images, plus
        "labels": a copy of input_ids in which padding and image-related token ids
        are set to -100 so the loss ignores them.
    """
    texts = []
    images = []
    for example in examples:
        images.append(process_vision_info(example["messages"]))
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    tokenizer = processor.tokenizer
    # Begin-of-image token, as a scalar id (comparing a tensor against a
    # one-element list only worked through broadcasting).
    boi_token_id = tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map["boi_token"])
    # Image placeholder token id hard-coded by this notebook.
    # NOTE(review): presumably Gemma 3's image soft token; confirm against the tokenizer.
    image_soft_token_id = 262144

    labels = batch["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    labels[labels == image_soft_token_id] = -100
    # NOTE(review): system and user-prompt tokens are still included in the loss;
    # consider masking everything before the assistant turn (completion-only training).

    batch["labels"] = labels
    return batch

In [None]:
from trl import SFTTrainer

# Assemble the trainer: peft_config wraps `model` with the LoRA adapters, and since
# SFTConfig sets skip_prepare_dataset, collate_fn performs all tokenisation.
trainer = SFTTrainer(
    model=model,
    args=args,
    peft_config=peft_config,
    train_dataset=train_dataset_fmt,
    data_collator=collate_fn,
    processing_class=processor,
)

In [None]:
# Run training; because SFTConfig has push_to_hub=True, saved checkpoints are
# also sent to the Hugging Face Hub as well as the output directory.
trainer.train()

# Write the final model to output_dir once more (and push it, since push_to_hub=True).
trainer.save_model(trainer.args.output_dir)