In [1]:
import sys
import os
import torch
import torchaudio
import numpy as np
from IPython.display import Audio
from chroma.qwen2_5_omni_config import (
    Qwen2_5OmniTextConfig,
    Qwen2_5OmniThinkerConfig,
    Qwen2_5OmniAudioEncoderConfig,
    Qwen2_5OmniVisionEncoderConfig,
)
from chroma.qwen2_5_omni_config import Qwen2_5OmniThinkerConfig
from chroma.qwen2_5_omni_modeling import Qwen2_5OmniThinkerForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


# 1. SGLang Environment Initialization

In [2]:
# Single-process "distributed" setup: supply the env:// rendezvous variables
# (never overriding values already exported by the user), then create a
# process group of world size 1 and pick the compute device.
DIST_ENV_DEFAULTS = {
    "RANK": "0",
    "WORLD_SIZE": "1",
    "LOCAL_RANK": "0",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "29501",
}
for env_name, env_default in DIST_ENV_DEFAULTS.items():
    os.environ.setdefault(env_name, env_default)

cuda_available = torch.cuda.is_available()

# Guard so re-running this cell does not try to initialize the group twice.
if not torch.distributed.is_initialized():
    backend = "nccl" if cuda_available else "gloo"
    torch.distributed.init_process_group(
        backend=backend,
        init_method="env://",
        world_size=1,
        rank=0,
    )

device = torch.device("cuda:0") if cuda_available else torch.device("cpu")
if cuda_available:
    torch.cuda.set_device(0)

# 2. Initialize SGLang DP Attention

In [3]:
import sglang.srt.layers.dp_attention as dp_attn

# Set SGLang's attention tensor-parallel module globals directly for this
# single-process run: one attention TP rank (size 1), and we are rank 0.
vars(dp_attn).update(_ATTN_TP_SIZE=1, _ATTN_TP_RANK=0)

# 3. Initialize SGLang Global Server Args

In [4]:
import sglang.srt.server_args as server_args_module
from sglang.srt.server_args import ServerArgs

# Minimal server configuration for in-notebook use. `model_path` is the literal
# placeholder "dummy"; tp_size=1 matches the single-process setup above.
server_arg_overrides = {
    "model_path": "dummy",
    "tp_size": 1,
    "mm_attention_backend": "sdpa",
}
server_args = ServerArgs(**server_arg_overrides)

# Install as SGLang's module-level global server args (private hook).
setattr(server_args_module, "_global_server_args", server_args)

# 4. Initialize SGLang Tensor Parallel Groups

In [5]:
from sglang.srt.distributed.parallel_state import (
    initialize_model_parallel,
    init_distributed_environment,
)

# Same backend rule as the torch.distributed init: NCCL on GPU, Gloo on CPU.
dist_backend = "gloo"
if torch.cuda.is_available():
    dist_backend = "nccl"
init_distributed_environment(backend=dist_backend)

# Single-process layout: no tensor or pipeline parallelism.
initialize_model_parallel(
    pipeline_model_parallel_size=1,
    tensor_model_parallel_size=1,
)

[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0


# 5. Load Chroma

In [6]:
import sys
import torch
from torch.nn import Parameter as OriginalParameter

_original_parameter_new = OriginalParameter.__new__

def patched_parameter_new(cls, *args, **kwargs):
    sglang_attrs = ['input_dim', 'output_dim', 'weight_loader', 'weight_loader_v2']
    for attr in sglang_attrs:
        kwargs.pop(attr, None)
    return _original_parameter_new(cls, *args, **kwargs)

OriginalParameter.__new__ = staticmethod(patched_parameter_new)

In [None]:
import os
from safetensors.torch import safe_open
import torch
from chroma.qwen2_5_omni_config import Qwen2_5OmniConfig
from chroma.qwen2_5_omni_modeling import Qwen2_5OmniModel
from chroma.modeling_chroma import ChromaForConditionalGeneration
from chroma.processing_chroma import ChromaProcessor
from chroma.configuration_chroma import ChromaConfig
import logging
import warnings

# Suppress transformers warnings: the shard / unused-weight messages filtered
# here are emitted while loading the checkpoint below.
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
warnings.filterwarnings("ignore", message=".*were not sharded.*")
warnings.filterwarnings("ignore", message=".*were not used.*")

# Checkpoint directory. Set the CHROMA_MODEL_PATH environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
chroma_model_path = os.environ.get(
    "CHROMA_MODEL_PATH", "/models/Qwencsm/Chroma/checkpoints/chroma_1121"
)

# Load config directly from the chroma model path and build the SGLang thinker
# model from its `thinker_config` (weights are loaded into it afterwards).
chroma_config = ChromaConfig.from_pretrained(chroma_model_path)
sgl_cfg = chroma_config.thinker_config
sgl_model = Qwen2_5OmniModel(sgl_cfg, quant_config=None)

# Full Chroma model in bf16 on the device selected during environment setup.
chroma_model = ChromaForConditionalGeneration.from_pretrained(
    chroma_model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

chroma_model = chroma_model.to(device)
chroma_model.eval()

def iter_thinker_weights_from_chroma_ckpt(ckpt_dir: str, prefix: str = "thinker."):
    """Yield ``(name, tensor)`` pairs for the thinker weights of a Chroma checkpoint.

    Lists ``ckpt_dir`` (non-recursively) for ``*.safetensors`` shards, opens them in
    sorted path order on CPU, and yields every tensor whose key starts with
    ``prefix``. Keys are yielded unchanged (prefix included).

    Args:
        ckpt_dir: Directory containing the checkpoint's ``.safetensors`` files.
        prefix: Key prefix selecting which tensors to yield. Defaults to
            ``"thinker."``.

    Yields:
        ``(key, tensor)`` tuples, shard by shard, in the key order reported by
        ``safe_open(...).keys()``.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` has no ``.safetensors`` files (or does
            not exist). Without this check a wrong directory would silently yield
            nothing and the consumer would load no weights at all. Because this is
            a generator, the error surfaces on the first iteration.
    """
    weight_files = sorted(
        os.path.join(ckpt_dir, fn)
        for fn in os.listdir(ckpt_dir)
        if fn.endswith(".safetensors")
    )
    if not weight_files:
        raise FileNotFoundError(f"No .safetensors files found in {ckpt_dir!r}")

    for wf in weight_files:
        with safe_open(wf, framework="pt", device="cpu") as f:
            for k in f.keys():
                if k.startswith(prefix):
                    yield (k, f.get_tensor(k))

# Stream the checkpoint's "thinker." tensors into the SGLang model, then move it
# to the same device as `chroma_model` in bf16. Using `device` (not a hard-coded
# "cuda:0") keeps this consistent with the CPU fallback chosen for `device`.
sgl_model.load_weights(iter_thinker_weights_from_chroma_ckpt(chroma_model_path))
sgl_model = sgl_model.to(device=device, dtype=torch.bfloat16).eval()

# Replace Chroma's own thinker with the SGLang implementation.
chroma_model.thinker = sgl_model.thinker

print("Model loaded successfully")

In [8]:
# Processor stored alongside the checkpoint; `tokenizer` is a convenience alias
# to the text tokenizer it bundles.
processor = ChromaProcessor.from_pretrained(chroma_model_path)
tokenizer = getattr(processor, "tokenizer")

In [9]:
# Make the processor use the tokenizer's chat template. Only copy a real
# template: Hugging Face tokenizers typically define `chat_template` even when
# none is set (value None), so a bare `hasattr` check is effectively always
# true and could overwrite the processor's own template with None.
tokenizer_chat_template = getattr(processor.tokenizer, "chat_template", None)
if tokenizer_chat_template is not None:
    processor.chat_template = tokenizer_chat_template

# 6. Inference

In [10]:
# Prepare conversation: a system instruction followed by one user turn that
# contains only an audio clip.
default_prompt = "Please ensure your responses are concise and under 88 words."

system_turn = {
    "role": "system",
    "content": [{"type": "text", "text": default_prompt}],
}
user_turn = {
    "role": "user",
    "content": [{"type": "audio", "audio": "assets/question_audio.wav"}],
}
conversation = [system_turn, user_turn]

# Reference audio clip and its transcript; both are passed to the processor
# as `prompt_audio` / `prompt_text`.
prompt_audio = "assets/ref_audio.wav"
prompt_text = "I have not... I'm so exhausted, I haven't slept in a very long time. It could be because... Well, I used our... Uh, I'm, I just use... This is what I use every day. I use our cleanser every day, I use serum in the morning and then the moistu- daily moisturizer. That's what I use every morning."

for label, value in (("Prompt text", prompt_text), ("Prompt audio", prompt_audio)):
    print(f"{label}: {value}")

Prompt text: I have not... I'm so exhausted, I haven't slept in a very long time. It could be because... Well, I used our... Uh, I'm, I just use... This is what I use every day. I use our cleanser every day, I use serum in the morning and then the moistu- daily moisturizer. That's what I use every morning.
Prompt audio: assets/ref_audio.wav


In [11]:
# Process inputs: build model inputs from the conversation plus the reference
# audio and its transcript.
inputs = processor(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
    prompt_audio=prompt_audio,
    prompt_text=prompt_text,
)

# Cast float32/float64 tensors to bf16 (to match the model weights) and move
# every tensor to the model device; non-tensor entries are left untouched.
for key in list(inputs.keys()):
    value = inputs[key]
    if not isinstance(value, torch.Tensor):
        continue
    if value.dtype in (torch.float32, torch.float64):
        value = value.to(torch.bfloat16)
    inputs[key] = value.to(device)

print("Inputs prepared")



Inputs prepared


In [12]:
# Generate. Sampling is enabled (do_sample=True), so results vary between runs
# unless a random seed is set beforehand.
import time

print("Generating...")

start_time = time.perf_counter()
with torch.no_grad():
    generated = chroma_model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=True,
        use_cache=True,
        output_attentions=False,
        output_hidden_states=False,
    )
# CUDA work is asynchronous; wait for it so the timing covers the actual work.
if torch.cuda.is_available():
    torch.cuda.synchronize()
generation_time = time.perf_counter() - start_time
print(f"Generation time: {generation_time}")

Generating...
Generation time: 6.773123741149902


In [13]:
# Decode audio: turn the generated codec tokens back into a waveform.
CODEC_SAMPLE_RATE = 24000  # Hz; used for both the duration print and playback
CODEC_MAX_TOKEN_ID = 2047  # NOTE(review): presumably codebook size - 1 — confirm against the codec config

# Inference only: without no_grad the decode would record an autograd graph
# (which is why the original needed `.detach()` before `.numpy()`).
with torch.no_grad():
    # Swap the last two axes before decoding; assumes `generated` is
    # (batch, time, codebooks) and the codec wants (batch, codebooks, time) — TODO confirm.
    generated_d = generated.permute(0, 2, 1).to(device)
    generated_d = generated_d.clamp(min=0, max=CODEC_MAX_TOKEN_ID)
    output = chroma_model.codec_model.decode(generated_d)
wav = output.squeeze(0).squeeze(0)

print(f"Audio generated: {wav.shape[0] / CODEC_SAMPLE_RATE:.2f}s")

Audio(wav.cpu().float().detach().numpy(), rate=CODEC_SAMPLE_RATE)

Audio generated: 12.24s
