<a href="https://colab.research.google.com/github/Ravi-Teja-konda/AudioInsightsGenerator/blob/main/%F0%9F%96%A5%EF%B8%8F_Gemma_3n_GUI_Finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U torch torchvision torchaudio
!pip install transformers datasets accelerate
!pip install trl peft bitsandbytes
!pip install huggingface_hub pillow
!pip install wandb  # Optional: for experiment tracking
!pip install -U git+https://github.com/huggingface/transformers.git
!pip install -U git+https://github.com/huggingface/pytorch-image-models.git

In [None]:
import gc
import io
import os
import zipfile

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from peft import LoraConfig
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    Gemma3nForConditionalGeneration,
)
from trl import SFTConfig, SFTTrainer


In [None]:
# Model and training parameters
# MODEL_NAME is the identifier passed to AutoProcessor / Gemma3nForConditionalGeneration
# .from_pretrained below; DATASET_NAME is passed to load_dataset; OUTPUT_DIR is used as
# the trainer's output_dir and as the save location for the model and processor.
MODEL_NAME = "google/gemma-3n-E2B-it"
DATASET_NAME = "rootsautomation/ScreenSpot"
OUTPUT_DIR = "gemma-3n-E2B-it-trl-sft-screenspot"

In [None]:
def format_screenspot_data(samples: dict) -> dict:
    """Convert a batch of ScreenSpot rows into chat-style ``messages``.

    Args:
        samples: A column-oriented batch (mapping of column name -> list of
            values). Must contain an ``"image"`` column whose items expose
            ``.convert("RGB")``. The instruction is read from the first
            present column among ``instruction``, ``text``, ``query``; the
            target from the first present column among ``target``,
            ``answer``, ``location``, ``coordinates``, ``bbox``.

    Returns:
        ``{"messages": [...]}`` with one three-turn conversation
        (system, user with image + instruction, assistant with target) per
        input row.
    """
    instruction_keys = ("instruction", "text", "query")
    # "bbox" is the label column the ScreenSpot dataset actually provides; without
    # it every row fell through to the constant fallback sentence below, so the
    # model was never trained on real locations. It is listed last so batches
    # that carry one of the other columns keep their previous behaviour.
    target_keys = ("target", "answer", "location", "coordinates", "bbox")

    # Column presence is a property of the batch, not of the row: resolve once.
    instruction_key = next((key for key in instruction_keys if key in samples), None)
    target_key = next((key for key in target_keys if key in samples), None)

    formatted_samples = {"messages": []}
    for idx in range(len(samples["image"])):
        image = samples["image"][idx].convert("RGB")

        # Stays None when the batch has no recognised instruction column.
        instruction = None
        if instruction_key is not None:
            instruction = samples[instruction_key][idx]

        if target_key is not None:
            target = str(samples[target_key][idx])
        else:
            # No label column at all: fall back to a generic grounding reply.
            target = "I'll help you locate that element in the screenshot."

        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are a helpful assistant that can identify and locate GUI elements in screenshots.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {
                        "type": "text",
                        "text": instruction,
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": target}]},
        ]
        formatted_samples["messages"].append(message)
    return formatted_samples


def process_vision_info(messages: list) -> list:
    """Collect every image referenced in a conversation as an RGB image.

    Each message's ``content`` is scanned (a non-list content is treated as a
    single-element list). A dict element contributes an image when it has an
    ``"image"`` key (that value is used) or when its ``"type"`` is ``"image"``
    (the element itself is used). A candidate that is a dict with a ``"bytes"``
    key is opened with PIL; any other candidate exposing ``.convert`` is used
    directly. Everything else is ignored.
    """
    collected = []
    for message in messages:
        parts = message.get("content", [])
        if not isinstance(parts, list):
            parts = [parts]

        for part in parts:
            if not isinstance(part, dict):
                continue

            # Pick the object that may hold the image, or skip the element.
            if "image" in part:
                candidate = part["image"]
            elif part.get("type") == "image":
                candidate = part
            else:
                continue

            if candidate is None:
                continue

            if isinstance(candidate, dict) and "bytes" in candidate:
                # Raw encoded bytes: decode through PIL first.
                decoded = Image.open(io.BytesIO(candidate["bytes"]))
                collected.append(decoded.convert("RGB"))
            elif hasattr(candidate, "convert"):
                # Already an image-like object.
                collected.append(candidate.convert("RGB"))
    return collected


In [None]:
# Load processor and model
# The processor is reused later by collate_fn (chat templating + batching) and by
# the prediction cell, so it must exist before either of those runs.
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Pad on the right-hand side of each sequence when batching.
processor.tokenizer.padding_side = "right"

# No explicit dtype is requested here (the bfloat16 override below is commented
# out), so the weights load in whatever dtype from_pretrained picks by default.
model = Gemma3nForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    # torch_dtype=torch.bfloat16, # Removed
    attn_implementation="eager",
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## 6. Dataset Loading and Processing

Load and format the ScreenSpot dataset:

In [None]:
# Load dataset
# Fetch every available split of DATASET_NAME. In the recorded output below the
# result contains a single "test" split (1272 rows), which is what the rest of
# the notebook indexes into.
dataset = load_dataset(DATASET_NAME)
dataset


DatasetDict({
    test: Dataset({
        features: ['file_name', 'bbox', 'instruction', 'data_type', 'data_source', 'image'],
        num_rows: 1272
    })
})

In [None]:
# Format the dataset for training
# format_screenspot_data receives column-oriented batches of 4 rows and returns a
# "messages" column; the work is spread over 4 worker processes.
formatted_dataset = dataset.map(
    format_screenspot_data, batched=True, batch_size=4, num_proc=4
)
print("Dataset formatted for training")


Dataset formatted for training


In [None]:
def collate_fn(examples):
    """Turn a list of dataset rows into a padded model batch with labels.

    For every row the chat template is rendered to text (no generation prompt)
    and its images are gathered: from the row's ``"images"`` field when present,
    otherwise from the message contents via ``process_vision_info``. The global
    ``processor`` then tokenizes/pads the texts together with the images.
    Labels are a copy of ``input_ids`` with padding tokens and, when the
    tokenizer defines them, the image / begin-of-image / end-of-image token ids
    replaced by -100.
    """
    rendered_prompts = []
    per_example_images = []

    for sample in examples:
        prompt = processor.apply_chat_template(
            sample["messages"], tokenize=False, add_generation_prompt=False
        )
        rendered_prompts.append(prompt.strip())

        if "images" in sample:
            # Row carries its images in a dedicated field.
            sample_images = [picture.convert("RGB") for picture in sample["images"]]
        else:
            # Images are embedded inside the message contents.
            sample_images = process_vision_info(sample["messages"])
        per_example_images.append(sample_images)

    batch = processor(
        text=rendered_prompts,
        images=per_example_images,
        return_tensors="pt",
        padding=True,
    )

    tokenizer = processor.tokenizer
    labels = batch["input_ids"].clone()

    # Mask padding first, then whichever image-related special tokens exist.
    labels[labels == tokenizer.pad_token_id] = -100
    for attr_name in ("image_token_id", "boi_token_id", "eoi_token_id"):
        if hasattr(tokenizer, attr_name):
            labels[labels == getattr(tokenizer, attr_name)] = -100

    batch["labels"] = labels
    return batch


In [None]:
# Training configuration
# Batch size 1 with 8 accumulation steps gives 8 samples per optimizer step per
# device. With the 1272-row split shown earlier that is 159 steps per epoch and
# 477 for 3 epochs (assuming a single device), matching the recorded global_step.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=1e-4,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    weight_decay=0.01,
    max_grad_norm=1.0,
    bf16=True,
    # NOTE(review): presumably needed so the "messages" column survives until the
    # custom collate_fn sees it - confirm against the TRL/Trainer documentation.
    remove_unused_columns=False,
    gradient_checkpointing=False,
    dataloader_pin_memory=False,
    # NOTE(review): presumably disables the trainer's own dataset preprocessing
    # because collate_fn does the templating and tokenization - confirm.
    dataset_kwargs={"skip_prepare_dataset": True},
    report_to="none"
)

# LoRA adapter settings: rank 8, alpha 16, applied to modules named q_proj and
# v_proj, with 5% dropout and no bias parameters.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Initialize trainer
# NOTE: training runs on the "test" split (the only split in the loaded dataset
# output above). The prediction cell later also draws its sample from
# formatted_dataset["test"], so that check is not a held-out evaluation.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=formatted_dataset["test"],
    processing_class=processor.tokenizer,
    peft_config=peft_config,
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Train the model
# Long-running cell: the recorded run below reports 477 steps and a
# train_runtime of roughly 13,128 seconds.
trainer.train()

Step,Training Loss
10,33.3697
20,5.7562
30,4.7187
40,5.3259
50,4.5155
60,4.4695
70,3.9725
80,4.4368
90,4.352
100,4.5198


TrainOutput(global_step=477, training_loss=4.201079038703966, metrics={'train_runtime': 13127.7412, 'train_samples_per_second': 0.291, 'train_steps_per_second': 0.036, 'total_flos': 1.9970850051072e+16, 'train_loss': 4.201079038703966})

In [None]:
# Persist the trained model and its processor side by side in the output directory
save_dir = training_args.output_dir
trainer.save_model(save_dir)
processor.save_pretrained(save_dir)
print(f"Model saved to: {save_dir}")


In [None]:
# Pull the image, question and reference answer out of the first formatted row
test_sample = formatted_dataset["test"][0]
test_image = test_question = expected_answer = None

for turn in test_sample["messages"]:
    role = turn["role"]
    if role == "assistant":
        # The reference answer is the first content element of the assistant turn.
        expected_answer = turn["content"][0]["text"]
        continue
    if role != "user":
        continue
    # User turn: it carries both the screenshot and the instruction text.
    for part in turn["content"]:
        part_type = part["type"]
        if part_type == "image":
            test_image = part["image"]
        elif part_type == "text":
            test_question = part["text"]


In [None]:
# Display test sample
# (matplotlib is imported with the rest of the notebook's imports in the top cell.)
print(f"Question: {test_question}")
print(f"Expected answer: {expected_answer}")

if test_image is not None:
    # Explicit figure/axes objects instead of the implicit pyplot state machine.
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(test_image)
    ax.axis("off")
    ax.set_title("Test Image")
    plt.show()


In [None]:
# Generate prediction
# Builds the same system + user prompt used for training (without the assistant
# turn), generates greedily, and compares the decoded continuation with the
# reference answer.
if test_image is not None:
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant that can identify and locate GUI elements in screenshots.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": test_image},
                {
                    "type": "text",
                    "text": test_question,
                },
            ],
        },
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=text, images=[test_image], return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Number of prompt tokens; everything after this index is newly generated.
    prompt_length = inputs["input_ids"].shape[-1]

    # Training leaves the model in train mode; switch to eval so dropout
    # (including the LoRA dropout) is disabled during generation.
    model.eval()
    with torch.no_grad():
        # Greedy decoding. No temperature is passed: it has no effect when
        # do_sample=False.
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False,
        )

    # Strip the prompt at the token level. Slicing the decoded string by
    # len(text) is wrong because `text` still contains special tokens that
    # skip_special_tokens=True removes from the decoded output.
    generated_ids = outputs[0][prompt_length:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True).strip()

    print(f"Model prediction: {generated_text}")
    print(f"Expected: {expected_answer}")
    print(f"Match: {generated_text.strip() == expected_answer.strip()}")


In [None]:
# Optional: Push to Hugging Face Hub
# hub_model_id = "your-username/gemma-3n-screenspot-finetuned"
# trainer.push_to_hub(hub_model_id)
# processor.push_to_hub(hub_model_id)


In [None]:
# Release the large objects held by the notebook namespace.
# (gc and torch are imported with the rest of the notebook's imports in the top cell.)
# pop() with a default keeps this cell re-runnable: a plain `del` raises
# NameError the second time the cell is executed.
for _name in ("model", "trainer", "processor"):
    globals().pop(_name, None)

# Collect garbage first so the now-unreferenced objects are actually destroyed,
# then return the cached CUDA memory they occupied. Emptying the cache before
# collecting cannot release memory that is still held by uncollected objects.
gc.collect()
torch.cuda.empty_cache()

173