Hi,
First, thank you for releasing Audio Flamingo 2 and the detailed code!
While working with the audio_flamingo_2/eval/inference.py script, I noticed that the inference loop (lines 115–158) processes each audio and prompt individually, one at a time. Even though the model, tokenizer, and dataloader all seem to support batching, the call to model.generate() is performed with batch_size=1 inside a for loop.
Relevant code snippet (lines 115–158):
```python
for idx in range(input_ids.shape[0]):
    ...
    # Each sample is unsqueezed back into a batch of size 1 before generation.
    output = model.generate(
        audio_x=audio_clips[idx].unsqueeze(0),
        audio_x_mask=audio_embed_mask[idx].unsqueeze(0),
        lang_x=prompt.unsqueeze(0),
        ...
    )[0]
```
This approach appears to run inference on each sample separately rather than as a batch.
My questions:
- Is there a technical or architectural reason for not supporting fully batched inference in this script?
- Is it possible to run model.generate() on the entire batch at once, using all batched inputs, and are there any known issues or limitations? (A rough sketch of what I have in mind is included below.)
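
For context, this is roughly the batched call I have in mind. It is only a sketch: I am assuming model.generate() accepts stacked tensors for audio_x, audio_x_mask, and lang_x, plus an attention_mask and max_new_tokens in the usual Hugging Face style, and that prompts are left-padded to a common length; apart from the names already in the snippet above, none of this is confirmed by the current code.

```python
# Sketch of a fully batched call (assumes left-padded prompts and that
# generate() accepts attention_mask / max_new_tokens like HF's generate()).
output_ids = model.generate(
    audio_x=audio_clips,            # (B, ...) all audio clips stacked
    audio_x_mask=audio_embed_mask,  # (B, ...) matching audio masks
    lang_x=input_ids,               # (B, T) padded prompt token ids
    attention_mask=attention_mask,  # (B, T) 1 for real tokens, 0 for padding
    max_new_tokens=256,             # illustrative value
)

# Decode every generated sequence in the batch.
responses = [
    tokenizer.decode(seq, skip_special_tokens=True) for seq in output_ids
]
```

If padding the prompts and audio across samples is what makes this tricky and is the reason for the per-sample loop, that alone would answer my question.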
Thanks for your time and for clarifying this design choice!