In [1]:
from transformers import AutoModelForImageTextToText, AutoProcessor, DataCollatorForSeq2Seq
from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling
from processor.color_enhance_collator import ColorSensitiveCollator
from processor.color_simulate_collator import ColorSimulateCollator
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
import torchvision
from matplotlib import pyplot as plt
from datasets import Dataset
from PIL import Image
import pandas as pd
import os
from torchvision import transforms
from tqdm import tqdm

# ----- Paths -----
# Alternative checkpoint (or any multimodal model supported by swift):
# MODEL_NAME = "/root/autodl-tmp/model"
MODEL_NAME = "/root/sentence_estimator/output_color_sensitive/deu80-checkpoint-500/"
# SFT data; records contain {"instruction": ..., "output": ...}
DATA_PATH = "/root/color_150k.json"
OUTPUT_DIR = "./output_color_sensitive"

# ----- Load the model and its processor -----
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_NAME,
    dtype="auto",
    device_map="auto",
)
dtype, device = model.dtype, model.device
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Collator built on the processor, configured with cvd_type="deutan_80".
collator = ColorSimulateCollator(processor, cvd_type="deutan_80")

# ----- Prompt and color vocabulary -----
# Written as adjacent literals so the leading/trailing spaces and the embedded
# newline of the prompt are visible; the resulting string is unchanged.
SYS_PROMPT = (
    " You are a color blind with limited perception on image. \n"
    "However, you can still guess the color from your experience. "
)
# Color names searched for in the model's answers.
COLOR_WORDS = [
    "red", "green", "blue", "yellow",
    "orange", "purple", "pink", "brown",
    "gray", "black", "white", "chocolate",
]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


In [2]:
def predict(img_path:str) -> str:
    """Ask the model which color is highlighted in a test image.

    Builds a one-sample batch (system prompt, user question and the image),
    passes it through the module-level ``collator``, generates with the
    module-level ``model`` and decodes with the module-level ``processor``.

    Args:
        img_path: Path of the image to query.

    Returns:
        The decoded reply only; the echoed prompt tokens are stripped before
        decoding.
    """
    question = "Identify the color in the bounding box, which is also shown on the right side. (pay attention to color especially)"
    # Single-sample batch in the format passed to the collator. The image path
    # is given both at sample level and inside the user turn.
    sample_message = [
        {
            "image": img_path,
            "messages": [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": SYS_PROMPT}],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image", "image": img_path},
                    ],
                },
            ],
        }
    ]
    # Follow the model's own placement (it is loaded with device_map="auto")
    # instead of assuming a device literally named "cuda".
    out = collator(sample_message).to(model.device)
    # NOTE(review): temperature / min_p only take effect when sampling is
    # enabled (e.g. by the checkpoint's generation config), and sampling makes
    # repeated evaluation runs non-deterministic -- confirm this is intended.
    generated_ids = model.generate(**out, max_new_tokens = 256,
                    use_cache = True, temperature = 1.5, min_p = 0.1)
    # For decoder-only models generate() returns prompt + continuation, so
    # drop the prompt part to return just the reply.
    # Assumes the collator output exposes `input_ids` -- TODO confirm.
    prompt_len = out["input_ids"].shape[1]
    reply_ids = generated_ids[:, prompt_len:]
    generated_text = processor.batch_decode(reply_ids, skip_special_tokens = True)
    return generated_text[0]


In [4]:
def extract_color_words(sentence, vocabulary):
    """Return the entries of ``vocabulary`` that begin some word of ``sentence``.

    Matching is case-insensitive and anchored at word starts, so "reddish"
    counts as "red" while "colored" or "bordered" does not (a plain substring
    test would report "red" for those).

    Args:
        sentence: Text produced by the model.
        vocabulary: Lower-case color names to look for.

    Returns:
        Matching vocabulary entries, in vocabulary order.
    """
    # Turn every non-letter into a space so punctuation and hyphens split words.
    tokens = "".join(ch if ch.isalpha() else " " for ch in sentence.lower()).split()
    return [word for word in vocabulary if any(tok.startswith(word) for tok in tokens)]


# Collect one record per image and build the frame once at the end instead of
# growing a DataFrame row by row.
records = []
for root, dirs, files in os.walk('./dataset/eval_images'):
    dirs.sort()  # os.walk order is filesystem-dependent; sort for reproducible output
    for f in tqdm(sorted(files)):
        parts = f.split('_')
        if len(parts) < 2:
            # No '_'-separated label in the name (e.g. a stray non-dataset file):
            # skip it rather than crash and lose the results collected so far.
            tqdm.write(f"Skipping {os.path.join(root, f)}: no label in file name")
            continue
        # File names are parsed as <prefix>_<label>_..._<suffix>: the ground
        # truth is the second field, the reported name drops first and last field.
        ori_f_name = '_'.join(parts[1:-1])
        label = parts[1]
        result_sentence = predict(os.path.join(root, f))
        # Pull the color words out of the generated sentence.
        color_result = "_".join(extract_color_words(result_sentence, COLOR_WORDS))
        records.append({'Filename': ori_f_name, 'GT': label, 'Color': color_result})
df_result = pd.DataFrame(records, columns=['Filename', 'GT', 'Color'])
df_result.to_csv('LLM_color_eval_result.csv', index=False)


100%|██████████| 110/110 [00:44<00:00,  2.48it/s]
