
Attention mask for multi-image input in gemma3 #38053


Open · 1 of 4 tasks
deval281shah opened this issue May 9, 2025 · 1 comment · May be fixed by #38080

deval281shah commented May 9, 2025

System Info

According to the attention mask example in the Gemma3 blog (https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/gemma3/attention-ascii.png), attention is non-causal (bidirectional) within an image but causal across images, i.e., an image does not attend to a future image. However, when running Gemma3 generation with transformers (v4.51.3), the attention appears to be non-causal across images as well.

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
import torch._dynamo
torch._dynamo.config.suppress_errors = True

ckpt = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    ckpt, device_map="auto", torch_dtype=torch.float32,
)
processor = AutoProcessor.from_pretrained(ckpt)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "First image: "},
            {"type": "image", "path": "img1.jpg"},
            {"type": "text", "text": "Second image:"},
            {"type": "image", "path": "img2.jpg"},
            {"type": "text", "text": "Describe all images in single sentence."}
        ]
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device)

generation = model.generate(
    **inputs, max_new_tokens=1, return_dict_in_generate=True,
    output_attentions=True, do_sample=False,
)

[attached image]
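
Continuing from the script above, a minimal sketch of one way to probe the returned attention weights, assuming eager attention weights are available for the prefill step and that image positions can be located via the config's image token id (the attribute name may be image_token_index or image_token_id depending on the version):

# Hedged sketch: locate the two image blocks and check how much attention
# first-image queries put on second-image keys in the prefill step.
# Assumes generation.attentions[0] holds per-layer weights of shape
# [batch, heads, seq_len, seq_len]; attribute names may differ by version.
image_token_id = getattr(model.config, "image_token_index", None) or model.config.image_token_id
image_positions = (inputs["input_ids"][0] == image_token_id).nonzero().squeeze(-1)

# Split the positions into the two contiguous image blocks.
gap = (image_positions.diff() > 1).nonzero().squeeze(-1)[0]
first_img = image_positions[: gap + 1]
second_img = image_positions[gap + 1:]

prefill_attn = generation.attentions[0][0][0]           # layer 0, batch 0: [heads, seq, seq]
cross = prefill_attn[:, first_img][:, :, second_img]    # image-1 queries -> image-2 keys
print("max attention from image 1 to image 2:", cross.max().item())

With the masking described in the blog figure, the printed value should be (near) zero; a large value would indicate cross-image attention.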

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code attached

Expected behavior

Should attention across images be causal or non-causal?
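
For reference, a minimal sketch (not the transformers implementation) of the masking pattern shown in the blog figure: causal attention everywhere, plus bidirectional attention inside each image block. The token_type_ids and image_block_ids markers below are hypothetical illustrations, not the library's actual inputs:

import torch

# Minimal sketch of the blog-figure pattern: causal everywhere, plus full
# (bidirectional) attention inside each image block.
# token_type_ids: 1 for image tokens, 0 for text tokens (hypothetical markers).
# image_block_ids: same id for tokens of the same image, -1 for text tokens.
def gemma3_style_mask(token_type_ids, image_block_ids):
    seq_len = token_type_ids.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_image = (
        (token_type_ids[:, None] == 1)
        & (token_type_ids[None, :] == 1)
        & (image_block_ids[:, None] == image_block_ids[None, :])
    )
    return causal | same_image

# Toy layout: text, image 1 (3 tokens), text, image 2 (3 tokens).
tt = torch.tensor([0, 1, 1, 1, 0, 1, 1, 1])
blk = torch.tensor([-1, 0, 0, 0, -1, 1, 1, 1])
print(gemma3_style_mask(tt, blk).int())

Under this mask, image-1 rows have zeros in image-2 columns (causal across images), while rows within an image block attend to the whole block.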

zucchini-nlp (Member) commented
Hmm right, I just checked the mask and an image attends to all images, backward and forward. I believe this needs a fix; I'll check against the original implementation once more and make a fix if needed. Thanks for reporting!
