In [None]:
import transformers
import torch
import json
from tqdm import tqdm
import os
import re

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

transformers.logging.set_verbosity_error()

In [2]:
def json_image_iterator(root_dir):
    """Yield one ``(file_name, data, image_path)`` triple per metadata entry.

    Every ``*.json`` file directly inside ``root_dir`` is expected to contain a
    list of dicts, each with an ``"image_path"`` key. Only the last
    ``/``-separated component of that value is used; it is looked up by file
    name anywhere below ``<root_dir>/images``.

    Args:
        root_dir: Directory holding the JSON files and an ``images`` subfolder.

    Yields:
        Tuples ``(file_name, data, image_path)`` where ``file_name`` is the
        JSON file name without its extension, ``data`` is the raw entry dict,
        and ``image_path`` is the path of the matching image, or ``None`` when
        no file with that name exists under ``images`` (a message is printed
        in that case).
    """
    images_dir = os.path.join(root_dir, "images")

    # Walk the image tree once and remember the first path seen for each file
    # name, instead of re-walking the whole tree for every single entry.
    image_index = {}
    for current_dir, _dirs, files in os.walk(images_dir):
        for name in files:
            image_index.setdefault(name, os.path.join(current_dir, name))

    # os.listdir order is arbitrary; sorting makes the output order reproducible.
    for json_file in sorted(os.listdir(root_dir)):
        if not json_file.endswith(".json"):
            continue

        json_path = os.path.join(root_dir, json_file)
        with open(json_path, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        file_name = os.path.splitext(json_file)[0]

        for data in data_list:
            image_filename = data["image_path"].split("/")[-1]
            image_path = image_index.get(image_filename)

            if image_path is None:
                print(f"Image not found for entry in {json_file}: {image_filename}")

            yield file_name, data, image_path

In [3]:
# Instruction text sent to the model together with every image.
en_prompt = '''Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  
The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.
The caption should be concise, consisting of 3 to 5 sentences.'''
# JSON file that accumulates one result record per captioned image.
save_path= "qwen_seavqa_en_result.json"
# Dataset root passed to json_image_iterator (which reads the *.json files here
# and looks for images under its "images" subfolder).
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
image_root_path = r"/root/filter_sea_vqa_final/filter_sea_vqa_final"

In [None]:
# Lower/upper pixel budget handed to the processor.
# NOTE(review): presumably bounds the per-image resolution in units of 28x28
# patches; confirm against the Qwen2-VL processor documentation.
min_pixels, max_pixels = 256 * 28 * 28, 1280 * 28 * 28

# Half-precision weights; device_map="auto" leaves device placement to the loader.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Processor for the same checkpoint, configured with the pixel budget above.
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

In [5]:
def concat_message(caption_list, prompt):
    """Build one single-turn chat conversation per entry of ``caption_list``.

    Args:
        caption_list: Iterable of dicts; each must provide an ``"image"`` key.
        prompt: Text placed after the image in every user message.

    Returns:
        A list with one conversation per input entry. Each conversation is a
        list holding a single ``user`` message whose content is the entry's
        image followed by ``prompt``.
    """
    return [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": entry["image"]},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        for entry in caption_list
    ]

In [6]:
def get_caption_batch(caption_list, prompt):
    """Caption several images in one ``model.generate`` call.

    Relies on the module-level ``processor`` and ``model`` objects.

    Args:
        caption_list: List of dicts, each providing an ``"image"`` key (the
            value is forwarded to ``concat_message``).
        prompt: Instruction text sent alongside every image.

    Returns:
        The list of decoded strings returned by ``processor.batch_decode``,
        one per entry of ``caption_list``, containing only newly generated
        tokens.
    """
    # Use the caller's prompt rather than a module-level global.
    messages = concat_message(caption_list, prompt)
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming "cuda".
    inputs = inputs.to(model.device)

    # Batch inference
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; strip the prompt part per sample.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_texts

In [7]:
def get_caption(image, prompt):
    """Generate a caption for a single image.

    Relies on the module-level ``processor`` and ``model`` objects.

    Args:
        image: Image reference placed in the chat message and resolved by
            ``process_vision_info``.
        prompt: Instruction text sent alongside the image.

    Returns:
        The list returned by ``processor.batch_decode`` for this one-sample
        batch (i.e. a list wrapping the caption string, not a bare string).
    """
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming "cuda".
    inputs = inputs.to(model.device)

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text

In [8]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON array stored in ``save_path``.

    The file is created when it does not exist. An existing file that is empty
    (or whitespace only) is treated as an empty array rather than raising a
    decode error. The updated array is first written to a sibling temporary
    file and then moved over ``save_path``, so an interruption during the
    write cannot truncate previously saved results.

    Args:
        save_path: Path of the JSON file holding a list of records.
        json_obj: JSON-serializable record to append.

    Raises:
        json.JSONDecodeError: If the existing file holds invalid JSON.
    """
    data = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            content = f.read()
        if content.strip():
            data = json.loads(content)
    data.append(json_obj)

    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    # Replace the target in one step so readers never see a half-written file.
    os.replace(tmp_path, save_path)

In [None]:
# Materialize the generator so tqdm knows the total and can show real progress.
iterator = json_image_iterator(image_root_path)
items = list(iterator)

for country, data, image in tqdm(items, desc="Progress"):
    if image is None:
        # json_image_iterator yields None (and has already printed a message)
        # when the image file is missing; skip it instead of letting
        # get_caption fail and abort the whole run.
        continue

    # get_caption returns the list produced by processor.batch_decode, so
    # "caption" is stored as a one-element list.
    caption = get_caption(image, en_prompt)

    json_obj = {
        "name": data["culture_name"],
        "country": country,
        "image_url": data["image_path"],
        "gt_caption": data["gt_caption"],
        "caption": caption,
    }

    # Persist after every image so partial results survive an interruption.
    save_to_json(save_path, json_obj)