Error in input expansion for generate with num_return_sequences > 1 for multi-image inputs to AutoModelForImageTextToText #37900

saujasv opened this issue Apr 30, 2025 · 1 comment

saujasv commented Apr 30, 2025

System Info

- `transformers` version: 4.51.3
- Platform: Linux-5.14.0-427.40.1.el9_4.x86_64-x86_64-with-glibc2.34
- Python version: 3.12.7
- Huggingface_hub version: 0.30.2
- Safetensors version: 0.5.3
- Accelerate version: 1.4.0
- Accelerate config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: DEEPSPEED
        - mixed_precision: bf16
        - use_cpu: False
        - debug: False
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'cpu', 'offload_param_device': 'cpu', 'zero3_init_flag': False, 'zero3_save_16bit_model': True, 'zero_stage': 3}
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
- DeepSpeed version: 0.15.1
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: Yes
- GPU type: NVIDIA L40S

Who can help?

@zucchini-nlp @amyeroberts @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I want to generate multiple responses to the same prompt with an image-text-to-text model. One straightforward way to do this is to use the generate function with num_return_sequences > 1 in the GenerationConfig. However, there appears to be an issue with this. I will use google/gemma-3-12b-it as the example model to present the issue, but I have anecdotally observed it with other models as well (mistral-community/pixtral-12b, mistralai/Mistral-Small-3.1-24B-Base-2503, etc.); I am not sure exactly which models are affected, or to what extent the specific model influences this issue.

When using generate with num_return_sequences > 1, the inputs are first expanded and then passed to the sample function.

# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
    input_ids=input_ids,
    expand_size=generation_config.num_return_sequences,
    is_encoder_decoder=self.config.is_encoder_decoder,
    **model_kwargs,
)

I suspect that this expansion does not work as expected for image inputs when multiple images are present, leading to the incorrect behavior below. More details are in the reproduction and expected-behavior sections.
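
For reference, my (unverified) understanding is that this expansion repeats tensor-valued model kwargs along their first dimension with repeat_interleave. If pixel_values for a multi-image Gemma-3 prompt has shape (num_images, channels, height, width) rather than a true batch dimension, which is an assumption on my part, then the repetition groups copies of each image together instead of tiling the whole 4-image set once per returned sequence. A minimal standalone sketch of the difference:

import torch

# Stand-ins for the 4 images A, B, C, D; the (num_images, C, H, W) layout of
# pixel_values is an assumption about this model, not verified.
pixel_values = torch.arange(4).view(4, 1, 1, 1)
num_return_sequences = 8

# What the expansion appears to do: repeat each image 8 times in a row.
interleaved = pixel_values.repeat_interleave(num_return_sequences, dim=0)
print(interleaved.flatten().tolist())
# [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ...] -> the first expanded prompt's four image
# slots are all filled with copies of image A, the third prompt's with image B, etc.

# What a multi-image prompt would need instead: the full 4-image block tiled
# once per returned sequence.
tiled = pixel_values.repeat(num_return_sequences, 1, 1, 1)
print(tiled.flatten().tolist())
# [0, 1, 2, 3, 0, 1, 2, 3, ...] -> every expanded prompt keeps images A, B, C, D in order.

If this is what happens, it would also explain why the outputs below pair up: with 4 images each repeated 8 times, every two consecutive expanded prompts end up seeing four copies of the same image.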

Here is a code snippet that reproduces the behavior in my setting:

from transformers import AutoModelForImageTextToText, AutoProcessor

gemma_processor = AutoProcessor.from_pretrained(
    "google/gemma-3-12b-it", trust_remote_code=True
)
gemma_model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-12b-it",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    device_map="cuda:1",
    torch_dtype="bfloat16",
).eval()

messages = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Generate a message referring to one of the images.",
            }
        ],
    },
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "I will show  4 images labelled as A, B, C, D. I will then mention an image. Describe the image corresponding to the label. Your response should only contain the message. Your message does not need to be a full sentence. Your message should be a fluent description.",
            },
            {"type": "text", "text": "Round 1, "},
            {"type": "text", "text": "\nImage A: "},
            {
                "type": "image",
                "url": "http://images.cocodataset.org/val2014/COCO_val2014_000000166401.jpg",
            },
            {"type": "text", "text": "\nImage B: "},
            {
                "type": "image",
                "url": "http://images.cocodataset.org/val2014/COCO_val2014_000000140076.jpg",
            },
            {"type": "text", "text": "\nImage C: "},
            {
                "type": "image",
                "url": "http://images.cocodataset.org/val2014/COCO_val2014_000000290477.jpg",
            },
            {"type": "text", "text": "\nImage D: "},
            {
                "type": "image",
                "url": "http://images.cocodataset.org/val2014/COCO_val2014_000000213224.jpg",
            },
            {
                "type": "text",
                "text": "Describe Image B. Generate only a message containing a description.",
            },
        ],
    },
]

inputs = gemma_processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

output_tokens = gemma_model.generate(
    **inputs.to(gemma_model.device, gemma_model.dtype),
    do_sample=True,
    max_new_tokens=128,
    temperature=1.0,
    top_p=1.0,
    num_return_sequences=8,
    tokenizer=gemma_processor.tokenizer,
)
outputs = gemma_processor.batch_decode(
    output_tokens[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)

This yields the following list for outputs:

['A luxurious bathroom space with dark cabinetry, a large countertop sink and mirror, a tiled accent wall, and a soaking tub in the corner.',
 'A luxury bathroom with dark cabinetry, a stone countertop, and a large mirror illuminated by a modern light fixture; a potted plant and candles add decorative touches alongside a glimpse of a bathtub and a view out a window.',
 'A vibrant patchwork quilt hangs above a dark wood dining table set with red placemats, adorned with a vase of yellow tulips and a figurine.',
 'A dining room scene with a patchwork wall hanging, dark leather chairs, a wooden table set with red placemats, and vibrant tulips in a glass vase.',
 'A lush, green foliage arrangement bursts from a metallic vase, resting on a vibrant purple cloth atop a wooden altar; flanked by tall candlesticks.',
 'A vibrant, leafy arrangement sits in a decorative bronze vase, centered on a purple runner, flanked by tall candlesticks in a church setting.',
 'A vibrant arrangement of lilies, carnations, and other blossoms overflowing from a clear glass vase, complemented by smaller vases of red flowers on a wooden table.',
 'A vibrant flower arrangement in a clear glass vase, complemented by smaller vases with red blooms, all sitting on a wooden table.']

Note how the first two captions describe the first image, the next two the second image, and so on (the pairing predicted by the sketch above, if my reading of the expansion is right). This should not be the case: the model is capable of describing the correct image, as shown under expected behavior below.

Expected behavior

If, instead of asking for 8 completions of 1 prompt, I ask for 1 completion each from 8 copies of the prompt, the issue goes away.

inputs = gemma_processor.apply_chat_template(
    [messages for _ in range(8)],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

output_tokens = gemma_model.generate(
    **inputs.to(gemma_model.device, gemma_model.dtype),
    do_sample=True,
    max_new_tokens=128,
    temperature=1.0,
    top_p=1.0,
    num_return_sequences=1,
    tokenizer=gemma_processor.tokenizer,
)
outputs = gemma_processor.batch_decode(
    output_tokens[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
)

yields the outputs

['A vibrant patchwork quilt adorns a stark white wall, centered above a dark wood dining table set with cheerful tulips.',
 'A vibrant, patchwork textile hangs on a crisp white wall, complemented by a dark wooden chair and a table set with red placemats and a vase of tulips.',
 'A vibrant, patchwork quilt dominates a white wall, framed by a dark wooden chair and table with red accents and a tulip arrangement.',
 'A vibrant, intricately patched textile hangs on a white wall, complemented by a wooden chair and a dining table set with tulips and place settings.',
 'A vibrant patchwork quilt dominates the wall above a dark wooden table set with black chairs and a vase of tulips.',
 'A vibrant patchwork quilt adorns a white wall, centered above a dark wood dining table set with black chairs and a bouquet of tulips.',
 'A vibrant patchwork textile hangs on a white wall, complemented by a dark wood dining table set with red placemats and a vase of tulips.',
 'A vibrant patchwork wall hanging dominates, framed by a simple white wall, accented by a dark wooden chair and a table set with tulips.']

which is expected behavior.

This suggests a bug in how the inputs are expanded for generation.
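
In the meantime, a possible workaround that avoids re-running the chat template 8 times is to tile the processed tensors manually and call generate with num_return_sequences=1. This is only a sketch under the same assumptions as above (pixel_values has a leading num_images dimension, and every tensor in the processor output can simply be tiled along dim 0); the 8-copy apply_chat_template call above is what I actually tested.

import torch

# `inputs` is the single-prompt BatchFeature from apply_chat_template above.
inputs = inputs.to(gemma_model.device, gemma_model.dtype)

n = 8  # number of completions wanted
expanded = {
    # Tile every tensor n times along its first dimension, so each copy of the
    # prompt keeps its own ordered block of 4 images.
    k: v.repeat(n, *([1] * (v.dim() - 1))) if isinstance(v, torch.Tensor) else v
    for k, v in inputs.items()
}

output_tokens = gemma_model.generate(
    **expanded,
    do_sample=True,
    max_new_tokens=128,
    temperature=1.0,
    top_p=1.0,
    num_return_sequences=1,
    tokenizer=gemma_processor.tokenizer,
)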

saujasv added the bug label on Apr 30, 2025
zucchini-nlp (Member) commented

@saujasv thanks for reporting! We've had similar issues with dataset iteration when training the model, because in the multi-image case the first dimension doesn't reflect the batch size. I believe it is the same issue in generate(); I didn't know it was affecting num_return_sequences. I will investigate and see how it can be fixed.

cc @yonigozlan one more case to take into account when refactoring IDEFICS image processors :)

@zucchini-nlp zucchini-nlp self-assigned this May 1, 2025