# Vision Language Models with Gemma 3

This notebook demonstrates how to use the Gemma 3 vision language model for text generation, image captioning, and answering questions about one or more images.

In [1]:
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from lxmls.multimodal.gemma3 import config
from lxmls.multimodal.gemma3 import model as gemma3_model
from lxmls.multimodal.gemma3.utils import set_default_tensor_type, display_prompt_and_result, format_prompt

## Setup

First, let's define the arguments. Set `model_dir` to the directory that holds (or will receive) the downloaded model checkpoint, and `image_dir` to the directory containing the example images.

In [2]:
class Args:
    model_dir: str = "../../../data/vlm/gemma3"
    image_dir: str = "../../../data/vlm/images"
    device: str = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    output_len: int = 128
    seed: int = 42
    quant: bool = False

args = Args()
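
The `device` default above chains two conditional expressions, which reads as: prefer CUDA, then Apple's MPS backend, then fall back to the CPU. As a minimal sketch of that priority order, the function below takes the two availability flags as plain booleans. `pick_device` is a hypothetical helper, not part of `lxmls`.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name: CUDA first, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# In the notebook the flags would come from torch.cuda.is_available() and
# torch.backends.mps.is_available(); they are hard-coded here for illustration.
print(pick_device(cuda_available=False, mps_available=True))
```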

In [3]:
from huggingface_hub import snapshot_download

# Download the checkpoint repository into args.model_dir
snapshot_download("rshwndsz/gemma-3-4b-it-ckpt", local_dir=args.model_dir)

'/home/dsouzars/projects/lxmls-toolkit/data/vlm/gemma3'
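
The loading cell further down reads `tokenizer.model` and `model.ckpt` from `model_dir`. Checking that both files are present after the download can catch a wrong path early. The sketch below is one way to do that. `missing_files` is a hypothetical helper, and the file list is taken from the paths used in this notebook.

```python
from pathlib import Path

# File names taken from the paths this notebook reads from model_dir
EXPECTED_FILES = ("tokenizer.model", "model.ckpt")


def missing_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]


# Same relative path as Args.model_dir; adjust if your layout differs
missing = missing_files("../../../data/vlm/gemma3")
if missing:
    print("Missing checkpoint files:", ", ".join(missing))
else:
    print("All expected checkpoint files found")
```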

### Load Model

In [4]:
# Construct the model config
model_config = config.get_model_config()
model_config.dtype = "float32"
model_config.quant = args.quant
model_config.tokenizer = str(Path(args.model_dir) / "tokenizer.model")

# Reproducibility settings
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Instantiate model and load weights
device = torch.device(args.device)
with set_default_tensor_type(model_config.get_dtype()):
    model = gemma3_model.Gemma3ForMultimodalLM(model_config)
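    # Load onto the CPU first so a checkpoint saved on a GPU also loads on CPU/MPS machines
    checkpoint = torch.load(Path(args.model_dir) / "model.ckpt", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])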
    model = model.to(device).eval()
print("Model loading done")

Model loading done
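
Loading in `float32` stores each weight in 4 bytes, so the memory needed for the weights alone scales directly with the parameter count. As a rough back-of-the-envelope sketch, assuming about 4 billion parameters (a guess based on the `4b` in the checkpoint name, not a figure stated in this notebook), the estimate can be computed as follows. `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params: int, bytes_per_param: int) -> float:
    """Estimate memory for the model weights alone, in GiB (2**30 bytes)."""
    return num_params * bytes_per_param / 2**30


# Assumption: parameter count inferred from the "4b" in the checkpoint name
ASSUMED_PARAMS = 4_000_000_000

print(f"4 bytes per weight (float32): {weight_memory_gib(ASSUMED_PARAMS, 4):.1f} GiB")
print(f"2 bytes per weight (for comparison): {weight_memory_gib(ASSUMED_PARAMS, 2):.1f} GiB")
```

Activations and generation caches add to this, so the number is best read as a lower bound.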


## Text-Only Generation

In [None]:
prompts = [
    format_prompt(["Write a poem about a chonky cat."]),
]
results = model.generate(
    prompts,
    device,
    output_len=args.output_len,
)

for prompt, result in zip(prompts, results):
    display_prompt_and_result(prompt, result)

INPUT
<start_of_turn>user
Write a poem about a chonky cat.<end_of_turn>
<start_of_turn>model
 --------------------------------------------------------------------------------
GENERATED
<start_of_turn>user
Write a poem about a chonky cat.<end_of_turn>
<start_of_turn>modelOkay, here's a poem about a chonky cat, aiming for a lighthearted and affectionate tone:

**The Sultan of Softness**

A rumble of warmth, a velvet plea,
A furry mountain, happy to be.
Sir Reginald, or just Reggie, you see,
Is a chonky cat of magnificent glee.

His belly jiggles with a joyful sway,
As he navigates his kingdom, day by day.
A nap upon the sofa, a blissful sigh,
Beneath a fluffy, contented eye.

He sheds a little, it’s true, it’s quite a
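
The `INPUT` block above shows the text that `format_prompt` produced for a single user message: the message is wrapped in `<start_of_turn>user` and `<end_of_turn>` markers, followed by an opening `<start_of_turn>model` marker for the reply. The sketch below reproduces only that printed text for a text-only prompt. `render_user_turn` is a hypothetical function, not the actual `format_prompt` implementation, and the newline after `model` is an assumption.

```python
def render_user_turn(text: str) -> str:
    """Wrap one user message in the turn markers shown in the INPUT block."""
    return (
        "<start_of_turn>user\n"
        f"{text}<end_of_turn>\n"
        "<start_of_turn>model\n"  # assumption: a newline follows the model marker
    )


print(render_user_turn("Write a poem about a chonky cat."))
```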



## Generation with Text and a Single Image as Input

In [None]:
golden_test_image_path = Path(args.image_dir) / "test_image.jpg"
prompts = [
    format_prompt([ Image.open(golden_test_image_path), "Caption this image." ]),
]
results = model.generate(
    prompts,
    device,
    output_len=args.output_len,
)

for prompt, result in zip(prompts, results):
    display_prompt_and_result(prompt, result)

In [None]:
cow_in_beach_path = Path(args.image_dir) / "cow_in_beach.jpg"
prompts = [
    format_prompt([ Image.open(cow_in_beach_path), "The name of the animal in the image is"]),
]
results = model.generate(
    prompts,
    device,
    output_len=args.output_len,
)

for prompt, result in zip(prompts, results):
    display_prompt_and_result(prompt, result)

## Generation with Interleaved Image and Text Input

In [None]:
lilly_path = Path(args.image_dir) / "lilly.jpg"
sunflower_path = Path(args.image_dir) / "sunflower.jpg"
prompts = [
    format_prompt([
        "This image", Image.open(lilly_path),
        "and this image", Image.open(sunflower_path),
        "are similar because? Give me the main reason."
    ]),
]
results = model.generate(
    prompts,
    device,
    output_len=args.output_len,
)

for prompt, result in zip(prompts, results):
    display_prompt_and_result(prompt, result)
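
The prompt above alternates text fragments and images in a single list. When the fragments and images come from separate lists, the same ordering can be built programmatically. The sketch below is one way to do that. `interleave_parts` is a hypothetical helper, and plain strings stand in for the `PIL.Image` objects so the example runs without image files.

```python
def interleave_parts(fragments, images, closing):
    """Alternate text fragments with images, then append a closing text fragment."""
    if len(fragments) != len(images):
        raise ValueError("need exactly one text fragment per image")
    parts = []
    for fragment, image in zip(fragments, images):
        parts.extend([fragment, image])
    parts.append(closing)
    return parts


# Strings stand in for Image.open(...) results in this illustration
parts = interleave_parts(
    ["This image", "and this image"],
    ["<lilly.jpg>", "<sunflower.jpg>"],
    "are similar because? Give me the main reason.",
)
print(parts)
```

The resulting list has the same ordering as the one passed to `format_prompt` in the cell above.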