In [None]:
import os
%cd ~
if not os.path.exists("LLaVA-NeXT"):
    !git clone https://github.com/LLaVA-VL/LLaVA-NeXT
%cd /LLaVA-NeXT

In [6]:
import transformers
import torch
import json
from tqdm import tqdm
import os
import re


from llava.model.builder import load_pretrained_model
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.utils import disable_torch_init
from llava.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from typing import Dict
from PIL import Image

# Raise the transformers log level to ERROR so warning/info messages do not flood cell output.
transformers.logging.set_verbosity_error()

In [11]:
def json_image_iterator(root_dir):
    """Yield ``(file_name, entry, image_path)`` for every entry of every JSON file in ``root_dir``.

    Each ``*.json`` file directly under ``root_dir`` is expected to hold a list of dicts
    with an ``"image_path"`` key. Only the last ``/``-separated component of that path is
    used: the image is looked up by file name anywhere under ``root_dir/images``.
    ``file_name`` is the JSON file name without its extension.

    When no image with that name exists, a message is printed and ``image_path`` is
    ``None`` so the caller can decide how to handle the entry.

    The images tree is walked once up front rather than once per entry. If the same file
    name occurs in several directories, the first one reached by ``os.walk`` wins. JSON
    files are visited in sorted order so that runs are reproducible.
    """
    images_dir = os.path.join(root_dir, "images")

    # file name -> full path of its first occurrence in os.walk order
    image_index = {}
    for dirpath, _dirnames, filenames in os.walk(images_dir):
        for name in filenames:
            image_index.setdefault(name, os.path.join(dirpath, name))

    for json_file in sorted(os.listdir(root_dir)):
        if not json_file.endswith(".json"):
            continue

        json_path = os.path.join(root_dir, json_file)
        # Explicit UTF-8: entries may contain non-ASCII names; don't depend on the locale.
        with open(json_path, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        file_name = os.path.splitext(json_file)[0]
        for data in data_list:
            image_filename = data["image_path"].split("/")[-1]
            image_path = image_index.get(image_filename)
            if image_path is None:
                print(f"Image not found for entry in {json_file}: {image_filename}")
            yield file_name, data, image_path

In [12]:
# Captioning prompt. The first line intentionally ends with two spaces before the newline;
# they are part of the exact text sent to the model and are kept verbatim.
en_prompt = (
    "Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  \n"
    "The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.\n"
    "The caption should be concise, consisting of 3 to 5 sentences."
)

# Output file and dataset root. Override them through environment variables to run on
# another machine without editing the notebook; the defaults are the original values.
save_path = os.environ.get("SEAVQA_SAVE_PATH", "pangea_seavqa_en_result.json")
image_root_path = os.environ.get(
    "SEAVQA_IMAGE_ROOT", r"/root/filter_sea_vqa_final/filter_sea_vqa_final"
)

In [None]:
# Load the Pangea-7B checkpoint in float16 with LLaVA-NeXT's builder. This defines the
# globals `tokenizer`, `model` and `image_processor` used by the captioning helpers below.
# NOTE(review): the "qwen" in model_name presumably makes the builder pick its Qwen
# model class — confirm against llava.model.builder.
model_path, model_name = 'neulab/Pangea-7B', 'Pangea-7B-qwen'
args = dict(multimodal=True)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,  # presumably model_base: no separate base model to merge
    model_name,
    torch_dtype="float16",
    **args,
)

In [13]:
def preprocess_qwen(sources, tokenizer: "transformers.PreTrainedTokenizer", has_image: bool = False, max_len=512, system_message: str = "You are a helpful assistant.") -> torch.Tensor:
    """Tokenise a conversation into Qwen chat-format input ids for generation.

    Produces ``<|im_start|>system\\n{system_message}<|im_end|>\\n``, then one
    ``<|im_start|>{role}\\n{text}<|im_end|>\\n`` segment per turn.

    - When ``has_image`` is true, every ``<image>`` placeholder in a turn is replaced by
      ``IMAGE_TOKEN_INDEX`` followed by the newline token(s).
    - A turn whose ``value`` is ``None`` emits only ``<|im_start|>{role}\\n``. This opens
      that turn (e.g. the assistant's) as the generation prompt.

    Args:
        sources: list of ``{"from": "human" | "gpt", "value": str | None}`` turns.
            A leading turn that is not from ``"human"`` is dropped.
        tokenizer: tokenizer that knows the ``<|im_start|>`` / ``<|im_end|>`` tokens.
        has_image: whether ``<image>`` placeholders are expanded.
        max_len: unused; kept for signature compatibility.
        system_message: text of the system turn.

    Returns:
        A ``torch.long`` tensor of shape ``(1, seq_len)``.
    """
    roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
    # Look the chat delimiters up by name: unpacking ``additional_special_tokens_ids``
    # fails for tokenizers that register more than these two special tokens.
    im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")
    nl_tokens = tokenizer("\n").input_ids

    source = sources
    if source and roles[source[0]["from"]] != roles["human"]:
        source = source[1:]

    input_id = (
        [im_start] + tokenizer("system").input_ids + nl_tokens
        + tokenizer(system_message).input_ids + [im_end] + nl_tokens
    )
    for sentence in source:
        role = roles[sentence["from"]]
        value = sentence["value"]
        header = tokenizer(role).input_ids + nl_tokens
        if has_image and value is not None and "<image>" in value:
            texts = value.split("<image>")
            turn = header
            for i, text in enumerate(texts):
                turn += tokenizer(text).input_ids
                # Each placeholder becomes the image sentinel id plus a newline.
                if i < len(texts) - 1:
                    turn += [IMAGE_TOKEN_INDEX] + nl_tokens
            turn += [im_end] + nl_tokens
            assert turn.count(IMAGE_TOKEN_INDEX) == value.count(DEFAULT_IMAGE_TOKEN)
        elif value is None:
            # Open-ended turn: no content and no <|im_end|>, so the model continues from here.
            turn = header
        else:
            turn = header + tokenizer(value).input_ids + [im_end] + nl_tokens
        input_id += turn
    return torch.tensor([input_id], dtype=torch.long)

In [14]:
def get_caption(image, prompt, max_new_tokens=512):
    """Generate a caption for one image with the globally loaded Pangea model.

    Args:
        image: path of the image file (anything accepted by ``PIL.Image.open``).
        prompt: user instruction; an ``<image>`` placeholder line is prepended to it.
        max_new_tokens: cap on generated tokens (default 512).

    Returns:
        The decoded model output with surrounding whitespace stripped.

    Uses the notebook globals ``tokenizer``, ``model`` and ``image_processor``, and moves
    tensors to CUDA, so a GPU is required. Decoding is greedy (no sampling, one beam).
    """
    prompt = "<image>\n" + prompt

    # The context manager closes the file handle PIL keeps open for lazy loading; without
    # it every captioned image leaks a descriptor. convert("RGB") loads the pixels into a
    # new image and normalises palette / grayscale / RGBA inputs to three channels.
    with Image.open(image) as img:
        rgb_image = img.convert("RGB")

    pixel_values = image_processor.preprocess(rgb_image, return_tensors='pt')['pixel_values']
    image_tensors = [pixel_values.half().cuda()]
    input_ids = preprocess_qwen(
        [{'from': 'human', 'value': prompt}, {'from': 'gpt', 'value': None}],
        tokenizer,
        has_image=True,
    ).cuda()

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensors,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

In [15]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON list stored at ``save_path``.

    If the file does not exist yet, it is created holding the one-element list
    ``[json_obj]``. The whole list is re-read and re-written (indent=4) on every call, so
    the file is always a complete JSON array.

    The new contents are first written to ``save_path + ".tmp"`` and then moved over
    ``save_path`` with ``os.replace``. An interruption mid-write (e.g. a kernel crash
    during a long run) therefore leaves the previous results intact.

    Raises:
        json.JSONDecodeError: if an existing file is not valid JSON.
    """
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        data = []
    data.append(json_obj)

    tmp_path = save_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, save_path)

In [None]:
# Caption every dataset entry and append each result to `save_path` as soon as it is
# produced, so partial results survive an interrupted run.
# NOTE(review): results are appended, so re-running this cell without first deleting
# `save_path` duplicates entries.
items = list(json_image_iterator(image_root_path))

for country, data, image in tqdm(items, desc="Progress"):
    if image is None:
        # json_image_iterator already printed which image is missing; there is nothing to
        # caption, so skip the entry instead of crashing in Image.open(None).
        continue

    caption = get_caption(image, en_prompt)
    # Remove boilerplate prefixes and quotes from the model output, then trim the
    # whitespace those removals can leave behind.
    caption = (
        caption.replace('Caption: ', '')
        .replace('"', '')
        .replace('The caption for this image could be:', '')
        .strip()
    )

    json_obj = {
        "name": data["culture_name"],
        "country": country,
        "image_url": data["image_path"],
        "gt_caption": data["gt_caption"],
        "caption": caption,
    }
    save_to_json(save_path, json_obj)