In [1]:
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
import torch
import soundfile as sf
import os
import matplotlib.pyplot as plt
from PIL import Image

# --- Optional Hugging Face Hub download settings (uncomment to use) ---
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # parallel downloader, usually faster (needs the `hf_transfer` package)
# os.environ["HF_HUB_PROGRESS"] = "1"            # intended to force progress bars; NOTE(review): confirm huggingface_hub recognizes this variable

# Hub repo id shared by the model and the processor so the two always match.
MODEL_ID = "Qwen/Qwen2.5-Omni-3B"

# Load the model and processor.
# FlashAttention-2 only runs in fp16/bf16, and with torch_dtype="auto" transformers
# warned that no torch dtype was specified for Flash Attention 2.0, so request
# bfloat16 explicitly.
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2",
)
# Only text output is used in this notebook (generate() is always called with
# return_audio=False), so the talker is disabled.
model.disable_talker()
processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_ID)

# img_path = "../img/BlueUp1.jpg"
# image = Image.open(img_path).convert("RGB")
# image = image.resize((224, 224))  # 或者 (384, 384)
# image.save("resized.jpg")

# conversation = [
#     {
#         "role": "system",
#         "content": [{"type": "text", "text": "You are a helpful assistant that can understand images and answer questions."}],
#     },
#     {
#         "role": "user",
#        
# "content": [
#             {"type": "image", "image": "resized.jpg"},
#             {"type": "text", "text": "Descirbe objects and their relative locations in details: "}
#         ],
#     },
# ]


# # 如果不涉及音频，设为 False
# USE_AUDIO_IN_VIDEO = False

# # 准备推理输入
# text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
# audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
# inputs = processor(
#     text=text,
#     audio=audios,
#     images=images,
#     videos=videos,
#     return_tensors="pt",
#     padding=True,
#     use_audio_in_video=USE_AUDIO_IN_VIDEO,
# )
# inputs = inputs.to(model.device).to(model.dtype)

# # 推理（仅生成文字，无需语音时可不保存 audio）
# # text_ids, _ = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO)

# text_ids = model.generate(
#     **inputs,
#     use_audio_in_video=False,           # 不启用音频
#     return_audio=False, # 不返回音频
#     max_new_tokens=20,                  # 限制输出长度
#     do_sample=True,                     # 采样生成（非贪心）
#     temperature=0.7,                    # 控制多样性
#     top_p=0.9,                          # nucleus sampling
#     repetition_penalty=1.1             # 减少重复输出
# )


# output_text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# print(output_text)


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


In [2]:
def inference_img(img_path,
                  system_prompt="You are a helpful assistant that can understand images and answer questions.",
                  user_prompt="Describe objects and their relative locations in details: ",
                  resize_location="resized.jpg",
                  resieze_size=(224, 224),
                  generate_kwargs=None):
    """Run the loaded Qwen2.5-Omni model on one image and return the generated text.

    The image is converted to RGB, resized, saved to ``resize_location``, and that
    saved file is referenced from a system + user chat conversation. Only the newly
    generated tokens are decoded; the prompt tokens are stripped.

    Relies on the module-level ``model`` and ``processor`` loaded earlier in the
    notebook.

    Args:
        img_path: Path of the source image.
        system_prompt: Text of the system turn.
        user_prompt: Text of the user turn, placed after the image.
        resize_location: Where the resized copy is written. This is the file the
            processor actually reads, and it is overwritten on every call.
        resieze_size: ``(width, height)`` passed to ``PIL.Image.resize``
            (the misspelled name is kept for backward compatibility).
        generate_kwargs: Optional dict of extra or overriding keyword arguments
            for ``model.generate``, e.g. ``{"max_new_tokens": 50}`` or
            ``{"do_sample": False}``. ``use_audio_in_video`` and
            ``return_audio`` are always forced to False, because the trimming
            below expects a text-only token tensor.

    Returns:
        The decoded text of the first (only) sequence in the batch.
    """
    # Use a context manager so the source file handle is released right away
    # instead of lingering until garbage collection.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    image = image.resize(resieze_size)
    image.save(resize_location)

    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": resize_location},
                {"type": "text", "text": user_prompt},
            ],
        },
    ]

    # Image-only input: no audio track is taken from video.
    use_audio_in_video = False

    # Build the model inputs from the chat template and the multimodal content.
    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=use_audio_in_video)
    inputs = processor(
        text=text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=use_audio_in_video,
    )
    inputs = inputs.to(model.device).to(model.dtype)

    # Default sampling settings; any key in ``generate_kwargs`` overrides them.
    gen_config = {
        "do_sample": True,          # sample instead of greedy decoding
        "temperature": 0.5,         # controls diversity
        "top_p": 0.9,               # nucleus sampling
        "repetition_penalty": 1.1,  # discourage repeated output
    }
    if generate_kwargs:
        gen_config.update(generate_kwargs)
    # Forced last so an override cannot change the return type of generate().
    gen_config["use_audio_in_video"] = use_audio_in_video
    gen_config["return_audio"] = False

    text_ids = model.generate(**inputs, **gen_config)

    # The returned ids start with the prompt; keep only the generated continuation.
    new_token_ids = text_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(
        new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

# Few-shot style prompt asking for a single short sentence about objects and their
# relative positions. It is built from explicit line fragments so that no source
# indentation or trailing spaces leak into the text sent to the model (the former
# triple-quoted literal carried 16 leading spaces on every continuation line).
user_prompt = (
    "Describe object(s) and (possible) their relative locations in brief with 1 very short sentence:\n"
    "DO NOT describe background/surface:\n"
    "for example:\n"
    "A red phone is under a wood mouse. ;\n"
    "Only one red cup. ;\n"
    "The image contains no visible objects"
)


# Smoke test on one sample image; the cell displays the returned description.
inference_img("../img/img1.jpg", user_prompt=user_prompt)



'A blue robot is next to a red block.'

In [3]:
# Second smoke test, on an image from ../image_train/ with the same prompt.
inference_img("../image_train/blue_0001.jpg", user_prompt=user_prompt)



'A blue cube is on the floor.'

In [4]:
# Short prompt used for the batch run below (replaces the few-shot prompt above).
user_prompt = "Describe objects' relative locations"


import json
import logging

import transformers
from tqdm.auto import tqdm

# Hide WARNING-and-below log messages (only ERROR and CRITICAL are shown).
# Lowering the root logger alone does not affect transformers, whose library
# logger has its own level, so set transformers' verbosity too.
logging.getLogger().setLevel(logging.ERROR)
transformers.logging.set_verbosity_error()

# --- Batch-run configuration ---
DATASET_DIR = "../"                          # "image" paths in the annotations are relative to this
input_json_file = "../annotations.jsonl"     # input: one JSON object per line
output_json_file = "../prediction.jsonl"     # output: input records plus "prediction" (overwritten)
# Only records whose "text" equals this value are processed; set to None to run on every record.
TARGET_FILTER = "A green cube on top of a blue cube"
SEED = 0

# inference_img samples (do_sample=True), so fix the RNGs for reproducible predictions.
transformers.set_seed(SEED)

with open(input_json_file, "r", encoding="utf-8") as fin:
    total_lines = sum(1 for _ in fin)  # count lines first so tqdm can show a total

with open(input_json_file, "r", encoding="utf-8") as fin, \
     open(output_json_file, "w", encoding="utf-8") as fout:

    for line in tqdm(fin, total=total_lines, desc="Processing"):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. an empty trailing line
        data = json.loads(line)

        if TARGET_FILTER is not None and data["text"] != TARGET_FILTER:
            continue

        img_path = os.path.join(DATASET_DIR, data["image"])
        data["prediction"] = inference_img(img_path=img_path, user_prompt=user_prompt)

        fout.write(json.dumps(data, ensure_ascii=False) + "\n")

Processing:   0%|          | 0/1400 [00:00<?, ?it/s]