In [1]:
# Testing batch inference of Music Flamingo
from transformers import (
    MusicFlamingoForConditionalGeneration,
    AutoProcessor,
    TextStreamer
)
import torch
import librosa


In [2]:
# One checkpoint directory feeds both the model weights and the processor.
MODEL_DIR = "./music_flamingo_fp8"

# Loader options: place the model on the GPU, let the loader pick the dtype
# ("auto"), and request the SDPA attention implementation.
load_kwargs = dict(
    device_map="cuda",
    dtype="auto",
    attn_implementation="sdpa",
)
model = MusicFlamingoForConditionalGeneration.from_pretrained(MODEL_DIR, **load_kwargs)

processor = AutoProcessor.from_pretrained(MODEL_DIR)

# The streamer is constructed from the processor object itself.
streamer = TextStreamer(processor)


Loading weights:   0%|          | 0/1027 [00:00<?, ?it/s]

In [3]:
# Perf flags: float32 matmul precision set to "high", TF32 allowed for both
# CUDA matmuls and cuDNN.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable KV cache (default config has use_cache=False).
# Every config object that exposes `use_cache` is switched on; the optional
# `text_config` is skipped when the top-level config does not carry one.
text_config = getattr(model.config, "text_config", None)
for cfg in (model.config, model.language_model.config, text_config):
    if cfg is None or not hasattr(cfg, "use_cache"):
        continue
    cfg.use_cache = True

# Same cache settings on both generation configs (wrapper and language model).
for gen_cfg in (model.generation_config, model.language_model.generation_config):
    gen_cfg.use_cache = True
    gen_cfg.cache_implementation = "dynamic"

# Greedy decoding with a 2048-token budget on the wrapper's generation config.
model.generation_config.max_new_tokens = 2048
model.generation_config.do_sample = False


In [4]:
# High-TPS batched inference using SGLang's low-level runner
import os
import sys
import json
from typing import List, Dict, Any

# SGLang's Python sources are expected under <current working dir>/sglang/python.
SGLANG_PYTHON_PATH = os.path.join(os.getcwd(), "sglang", "python")

# Put that directory at the front of sys.path, but only if it is not already
# listed, so re-running the cell does not add duplicates.
if sys.path.count(SGLANG_PYTHON_PATH) == 0:
    sys.path[:0] = [SGLANG_PYTHON_PATH]

# Lazily-built runner; populated on first use by _get_sglang_runner().
_sglang_runner = None


def _get_sglang_runner(attention_backend: str = "triton", tp_size: int = 1):
    """Return the shared SGLang model runner, building it on the first call.

    The runner is stored in the module-level ``_sglang_runner`` and reused by
    every later call. Because only one runner is ever built, the arguments of
    the *first* call decide its configuration; a later call with different
    arguments still gets the cached runner, and a ``RuntimeWarning`` is
    emitted so the mismatch is not silent.

    Args:
        attention_backend: Value forwarded as ``ServerArgs.attention_backend``.
            An empty/falsy value leaves the ServerArgs default in place.
        tp_size: Value forwarded as ``ServerArgs.tp_size``.
            NOTE(review): the model is always loaded with ``gpu_id=0`` and
            ``tp_rank=0``; confirm that values other than 1 are supported.

    Returns:
        The first element of the tuple returned by
        ``sglang.bench_one_batch.load_model``.
    """
    global _sglang_runner

    requested = (attention_backend, tp_size)
    if _sglang_runner is not None:
        # The build arguments are remembered on the function object so this
        # block needs no additional module-level state.
        built_with = getattr(_get_sglang_runner, "_built_with", None)
        if built_with is not None and built_with != requested:
            import warnings

            warnings.warn(
                "SGLang runner already built with "
                f"(attention_backend, tp_size)={built_with!r}; "
                f"ignoring requested {requested!r}.",
                RuntimeWarning,
                stacklevel=2,
            )
        return _sglang_runner

    # Imported lazily: these only resolve once SGLang is importable.
    from sglang.bench_one_batch import load_model
    from sglang.srt.entrypoints.engine import _set_envs_and_config
    from sglang.srt.layers.moe import initialize_moe_config
    from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config
    from sglang.srt.server_args import PortArgs, ServerArgs

    # Override the checkpoint's declared architecture / model_type; passed to
    # ServerArgs as a JSON string.
    model_override = {"architectures": ["MusicFlamingoQwen2ForCausalLM"], "model_type": "qwen2"}
    server_args_kwargs = {
        "model_path": "./music_flamingo_fp8",
        "tokenizer_path": "./music_flamingo_fp8",
        "trust_remote_code": True,
        "tp_size": tp_size,
        "disable_radix_cache": True,
        "json_model_override_args": json.dumps(model_override),
        "fp8_gemm_runner_backend": "cutlass",
        "disable_cuda_graph": True,
        "sampling_backend": "pytorch",
        "grammar_backend": "none",
    }
    if attention_backend:
        server_args_kwargs["attention_backend"] = attention_backend

    server_args = ServerArgs(**server_args_kwargs)
    # Global initialisation steps, each driven by the server args.
    _set_envs_and_config(server_args)
    initialize_moe_config(server_args)
    initialize_fp8_gemm_config(server_args)

    port_args = PortArgs.init_new(server_args)
    # The second element of load_model's return value is not needed here.
    model_runner, _ = load_model(server_args, port_args, gpu_id=0, tp_rank=0)

    _sglang_runner = model_runner
    _get_sglang_runner._built_with = requested
    return _sglang_runner


def _eos_token_ids() -> set:
    """Collect end-of-sequence token ids from the tokenizer and generation config."""
    eos_ids = set()
    candidates = (
        getattr(processor.tokenizer, "eos_token_id", None),
        getattr(model.generation_config, "eos_token_id", None),
    )
    for candidate in candidates:
        if candidate is None:
            continue
        # eos_token_id may be a single id or a collection of ids.
        if isinstance(candidate, (list, tuple, set)):
            eos_ids.update(int(tok) for tok in candidate)
        else:
            eos_ids.add(int(candidate))
    return eos_ids


@torch.inference_mode()
def infer_music_batch(
    batch: List[Dict[str, Any]],
    max_new_tokens: int = 256,
    attention_backend: str = "triton",
):
    """Describe a batch of audio clips using the HF audio stack plus SGLang decoding.

    The audio tower, projector and token embeddings of the notebook-level
    ``model`` produce the prompt embeddings; text generation itself is done by
    the SGLang runner through ``extend``/``decode``.

    Args:
        batch: list of ``{"text": str, "audio": path-or-array}``. String audio
            entries are loaded with librosa at 16 kHz; anything else is passed
            to the processor unchanged.
        max_new_tokens: upper bound on generated tokens per item. At least one
            token is always produced by the prefill step.
        attention_backend: forwarded to ``_get_sglang_runner``.

    Returns:
        ``(music_embeds_list, descriptions_list)``: one CPU tensor of projected
        audio embeddings per row of the audio features, and one decoded string
        per batch item. Each description ends at the first end-of-sequence
        token. An empty ``batch`` yields ``([], [])``.
    """
    if not batch:
        return [], []

    # Build batched prompts + audio. The "<audio>" path is only a placeholder
    # for the chat template; the real waveforms are passed to the processor.
    conversations = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": item["text"]},
                    {"type": "audio", "path": "<audio>"},
                ],
            }
        ]
        for item in batch
    ]
    prompts = processor.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )

    audios = []
    for item in batch:
        audio = item["audio"]
        if isinstance(audio, str):
            audio, _ = librosa.load(audio, sr=16000)
        audios.append(audio)

    inputs = processor(
        text=prompts,
        audio=audios,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    input_ids = inputs["input_ids"]
    input_features = inputs["input_features"]
    input_features_mask = inputs["input_features_mask"]
    audio_times = inputs.get("audio_times")
    attention_mask = inputs.get("attention_mask")

    # Audio tower -> projector (keep per-row embeddings)
    encoder_out = model.audio_tower(
        input_features,
        input_features_mask=input_features_mask,
        audio_times=audio_times,
    )
    audio_embeds_full = model.multi_modal_projector(encoder_out.last_hidden_state)

    # Number of valid projected frames per row, derived from the feature mask.
    post_lengths = (input_features_mask.sum(-1) - 2) // 2 + 1
    valid_mask = (
        torch.arange(audio_embeds_full.shape[1], device=post_lengths.device)[None, :]
        < post_lengths[:, None]
    )
    audio_embeds_flat = audio_embeds_full[valid_mask]

    # Replace the embeddings at audio-token positions with the audio embeddings.
    inputs_embeds = model.get_input_embeddings()(input_ids)
    audio_token_mask = (input_ids == model.config.audio_token_id).unsqueeze(-1)
    inputs_embeds = inputs_embeds.masked_scatter(audio_token_mask, audio_embeds_flat)

    # NOTE(review): this indexes rows of the audio features; if the processor
    # emits more than one feature row per batch item, the list will not line
    # up one-to-one with `batch` - confirm.
    music_embeds = [
        audio_embeds_full[i, : int(post_lengths[i])].detach().cpu()
        for i in range(audio_embeds_full.size(0))
    ]

    # SGLang batched decode (same high-TPS path as bench_inference)
    from sglang.bench_one_batch import decode, extend
    from sglang.srt.managers.schedule_batch import Req
    from sglang.srt.sampling.sampling_params import SamplingParams

    runner = _get_sglang_runner(attention_backend=attention_backend)
    sampling_params = SamplingParams(temperature=0.0, max_new_tokens=max_new_tokens)

    # Each request carries its own sequence, so padding added by the processor
    # must not be sent along as if it were prompt content: keep only positions
    # the attention mask marks as real.
    embeds_cpu = inputs_embeds.detach().cpu().float()
    if attention_mask is not None:
        keep_rows = attention_mask.to(torch.bool).cpu()
        input_embeds_list = [row[keep].tolist() for row, keep in zip(embeds_cpu, keep_rows)]
    else:
        input_embeds_list = embeds_cpu.tolist()

    reqs = []
    for i, emb in enumerate(input_embeds_list):
        # Placeholder ids: the request is driven by `input_embeds`, the ids
        # only provide the sequence length.
        fake_ids = [1] * len(emb)
        req = Req(
            rid=str(i),
            origin_input_text="",
            origin_input_ids=fake_ids,
            sampling_params=sampling_params,
            input_embeds=emb,
        )
        req.fill_ids = req.origin_input_ids
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
        req.logprob_start_len = len(req.origin_input_ids) - 1
        reqs.append(req)

    eos_ids = _eos_token_ids()
    output_ids = [[] for _ in reqs]
    finished = [False] * len(reqs)

    def _record(step_tokens) -> None:
        """Append this step's token to every unfinished row; EOS closes a row."""
        values = step_tokens.tolist() if hasattr(step_tokens, "tolist") else step_tokens
        for idx, tok in enumerate(values):
            if finished[idx]:
                continue
            tok = int(tok)
            if tok in eos_ids:
                finished[idx] = True
            else:
                output_ids[idx].append(tok)

    # Prefill produces the first token of every row.
    step_tokens, _, batch_state = extend(reqs, runner)
    _record(step_tokens)

    for _step in range(max_new_tokens - 1):
        if all(finished):
            break
        # Feed decode() exactly what the previous step returned; the list
        # conversion in _record() is for bookkeeping only.
        step_tokens, _ = decode(step_tokens, batch_state, runner)
        _record(step_tokens)

    descriptions = processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return music_embeds, descriptions


In [5]:
# Single-turn user message: a text instruction plus an audio entry.
conversation = [
    dict(
        role="user",
        content=[
            dict(
                type="text",
                text="Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
            ),
            dict(
                type="audio",
                path="https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
            ),
        ],
    )
]

# Build prompt without loading audio inside apply_chat_template
# (tokenize=False returns the rendered prompt text only).
prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

# The waveform actually fed to the model is this local file, loaded at 16 kHz -
# not the URL listed in the conversation above.
audio, _ = librosa.load(
    "../music/Lorde-Pure_Heroine-24BIT-WEB-FLAC-2013-TVRf/04-lorde-ribs.flac",
    sr=16000,
)

# Tokenize the prompt, featurize the audio, then move the batch to the model's device.
inputs = processor(text=prompt, audio=audio, return_tensors="pt").to(model.device)


In [1]:
# Warm-up: run one full generation before the timed run below.
# Requires `model` and `inputs` from the earlier cells to exist in this kernel.
# NOTE(review): 2056 differs from generation_config.max_new_tokens = 2048 set
# earlier in the notebook - confirm which value is intended.
# The trailing semicolon stops the notebook from echoing the generated token tensor.
model.generate(**inputs, max_new_tokens=2056);

NameError: name 'model' is not defined

In [7]:
import time

# Time one generate() call. CUDA is synchronized before starting the clock
# and again before stopping it.
torch.cuda.synchronize()
start = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=2056)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# New tokens = length of the returned sequences minus the prompt length.
prompt_len = inputs["input_ids"].size(1)
new_tokens = outputs.size(1) - prompt_len

# Throughput in generated tokens per second; shown as the cell result.
toks_per_s = new_tokens / elapsed
toks_per_s


21.17148250660933

In [8]:
# Re-display the throughput (generated tokens per second) computed by the timing cell.
toks_per_s


21.17148250660933