# MusicFlamingo FP8 LLM + Multimodal Inference

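This notebook builds **MusicFlamingoForConditionalGeneration** from the `nvidia/music-flamingo-2601-hf` config, loads the audio tower and multimodal projector weights from the original checkpoint, and swaps in the **FP8‑quantized LLM** from `models/music-flamingo-2601-llm-fp8`.
It then runs a single audio + text inference to verify that the multimodal path works end to end.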
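One quick way to see what the FP8 swap buys is weight-memory arithmetic: FP8 stores one byte per parameter, bf16 two, and fp32 four. The sketch below counts weights only (no activations, KV cache, or per-tensor scale factors), and the 7-billion parameter count is a hypothetical placeholder, not the actual size of the Music Flamingo LLM.

```python
# Bytes needed to store one parameter in each storage format.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp8": 1}


def weight_memory_gib(num_params: int, fmt: str) -> float:
    """Approximate weight memory in GiB (weights only; no activations or KV cache)."""
    return num_params * BYTES_PER_PARAM[fmt] / 2**30


if __name__ == "__main__":
    n = 7_000_000_000  # hypothetical parameter count, for illustration only
    for fmt in ("bf16", "fp8"):
        print(f"{fmt}: {weight_memory_gib(n, fmt):.1f} GiB")
```

Halving the bytes per parameter halves the weight footprint, which is what lets the full multimodal model fit more comfortably on a single GPU.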


In [1]:
from pathlib import Path
import torch
from safetensors import safe_open
from huggingface_hub import hf_hub_download
from transformers import (
    MusicFlamingoConfig,
    MusicFlamingoForConditionalGeneration,
    AutoProcessor,
    AutoModelForCausalLM,
)


model_id = "nvidia/music-flamingo-2601-hf"
llm_dir = "models/music-flamingo-2601-llm-fp8"
full_model_path = "models/full-music-flamingo.safetensor"
mmproj_model = "models/mmproj-music-flamingo.safetensor"


if not Path(llm_dir).exists():
    raise FileNotFoundError(f"Missing FP8 LLM dir: {llm_dir}")

if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for FP8 inference.")
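The cell above only checks that `llm_dir` exists. A slightly stricter preflight check is sketched below, assuming the directory follows the layout `save_pretrained` writes with safetensors serialization (a `config.json` plus one or more `*.safetensors` files). `missing_checkpoint_files` is a hypothetical helper, not part of transformers.

```python
from pathlib import Path


def missing_checkpoint_files(ckpt_dir: str) -> list[str]:
    """Return what an HF-style checkpoint directory appears to lack (empty list = looks complete).

    Assumes the layout written by `save_pretrained`: a `config.json` plus one or
    more `*.safetensors` weight files. Only presence is checked, not contents.
    """
    path = Path(ckpt_dir)
    if not path.is_dir():
        return [f"directory {ckpt_dir}"]
    problems = []
    if not (path / "config.json").is_file():
        problems.append("config.json")
    # Matches both single-file and sharded (model-0000x-of-0000y) checkpoints.
    if not any(path.glob("*.safetensors")):
        problems.append("*.safetensors weights")
    return problems
```

It could be called as `problems = missing_checkpoint_files(llm_dir)`, raising `FileNotFoundError` when `problems` is non-empty, so that an incomplete export fails before the slow model construction below.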


In [4]:
print("Loading base config...")
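config = MusicFlamingoConfig.from_pretrained(model_id)
# Randomly initialised skeleton: only the audio tower and projector receive real weights
# below, and the language model is replaced by the FP8 checkpoint.
model = MusicFlamingoForConditionalGeneration(config)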
model.eval()

print("Loading audio_tower + multi_modal_projector weights...")
# Assumes the original checkpoint is a single, unsharded model.safetensors file.
weights_path = hf_hub_download(model_id, filename="model.safetensors")
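state_dict = {}
with safe_open(weights_path, framework="pt", device="cpu") as f:
    for k in f.keys():
        # Keep only the audio encoder and projector; the LLM comes from the FP8 checkpoint.
        if k.startswith(("audio_tower.", "multi_modal_projector.")):
            state_dict[k] = f.get_tensor(k)
# strict=False because the language_model weights are intentionally absent here.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if unexpected:
    print(f"Warning: {len(unexpected)} unexpected keys, e.g. {unexpected[:3]}")
print(f"Loaded {len(state_dict)} tensors for audio + projector")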

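print("Loading FP8 LLM...")
llm = AutoModelForCausalLM.from_pretrained(
    llm_dir,
    device_map="cuda",
    dtype="auto",  # let transformers pick the dtype from the checkpoint
)
# Swap the FP8 LLM into the multimodal wrapper and keep vocab size and config in sync with it.
model.language_model = llm
model.vocab_size = llm.config.vocab_size
model.config.text_config = llm.config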

print("Moving audio_tower and projector to GPU (bf16)...")
model.audio_tower = model.audio_tower.to(device="cuda", dtype=torch.bfloat16)
model.multi_modal_projector = model.multi_modal_projector.to(device="cuda", dtype=torch.bfloat16)

print("Loading processor...")
processor = AutoProcessor.from_pretrained(model_id)


Loading base config...
Loading audio_tower + multi_modal_projector weights...
Loaded 492 tensors for audio + projector
Loading FP8 LLM...
Moving audio_tower and projector to GPU (bf16)...
Loading processor...
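Because `load_state_dict(..., strict=False)` does not raise on missing or unexpected keys, a wrong key prefix would silently leave the audio tower randomly initialised. The sketch below, using hypothetical helper names, mirrors the prefix filter from the cell above and summarises the `missing`/`unexpected` lists it returns. Missing keys under other prefixes (e.g. `language_model.*`, which is replaced by the FP8 checkpoint) are expected; missing keys under the loaded prefixes, or any unexpected keys, would point to a naming mismatch.

```python
# Prefixes of the submodules taken from the original checkpoint (as in the cell above).
LOADED_PREFIXES = ("audio_tower.", "multi_modal_projector.")


def select_keys(keys, prefixes=LOADED_PREFIXES):
    """Keep only checkpoint keys that belong to the given submodules."""
    return [k for k in keys if k.startswith(prefixes)]


def partial_load_problems(missing, unexpected, prefixes=LOADED_PREFIXES):
    """Summarise a strict=False load_state_dict result.

    Missing keys outside `prefixes` are expected (those modules are loaded
    elsewhere); missing keys inside `prefixes` or any unexpected keys are not.
    """
    return {
        "missing_in_loaded_modules": [k for k in missing if k.startswith(prefixes)],
        "unexpected": list(unexpected),
    }
```

For example, `partial_load_problems(missing, unexpected)` right after the `load_state_dict` call; when the audio tower and projector loaded cleanly, both lists in the result are empty.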


In [5]:
conversation = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": (
                    "Describe this track in full detail - tell me the genre, tempo, and key, then "
                    "dive into the instruments, production style, and overall mood it creates."
                ),
            },
            {
                "type": "audio",
                "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
            },
        ],
    }
]

inputs = processor.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
)

# Move all inputs to GPU (audio + text)
for key, value in list(inputs.items()):
    if torch.is_tensor(value):
        inputs[key] = value.to("cuda")

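with torch.no_grad():
    # 128 new tokens can cut a detailed description short; raise this for longer answers.
    outputs = model.generate(**inputs, max_new_tokens=128)

# generate() returns prompt + completion here, so slice off the prompt tokens before decoding.
decoded = processor.batch_decode(
    outputs[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)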
print("\n=== Output ===\n")
print(decoded[0])



=== Output ===

This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, hook‑laden sensibility of mainstream pop with the driving, club‑ready pulse of classic Eurodance.  The duration of the piece is 163.59 seconds.
Tempo & key – The song moves at a brisk 150 BPM and is rooted in E major.
Instrumentation & production – A polished, high‑fidelity production frames the arrangement. The rhythm foundation is built on a four‑on‑the‑floor electronic drum kit with crisp kick,
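The response above stops mid-sentence ("crisp kick,"), which is consistent with generation reaching `max_new_tokens=128` rather than emitting an end-of-sequence token. The sketch below (a hypothetical helper operating on a plain list of token ids for a single, unpadded sequence, assuming a single EOS id) separates the completion from the prompt, as the slicing in the cell above does, and flags whether the token budget was exhausted.

```python
def split_generation(sequence, prompt_len, max_new_tokens, eos_token_id):
    """Split one generated sequence into its completion and a 'hit the cap' flag.

    generate() returns prompt + completion, so the completion starts at index
    prompt_len. If the completion used the full max_new_tokens budget and does
    not end with EOS, generation most likely stopped because of the budget
    rather than because the model finished its answer.
    """
    completion = list(sequence[prompt_len:])
    hit_cap = len(completion) >= max_new_tokens and (
        not completion or completion[-1] != eos_token_id
    )
    return completion, hit_cap
```

It could be applied as `split_generation(outputs[0].tolist(), inputs["input_ids"].shape[1], 128, processor.tokenizer.eos_token_id)`. If `model.generation_config.eos_token_id` is a list, the last token would need to be checked against each id instead.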
