In [None]:
# Environment setup: switch to the home directory, clone the Maya repository
# only when no "maya" entry exists there yet, then make the checkout the
# working directory for the rest of the notebook.
# NOTE(review): presumably the later `llava.*` imports rely on running from
# inside this checkout — confirm.
import os
%cd ~
# Guard the clone so re-running this cell does not fail on an existing checkout.
if not os.path.exists("maya"):
    !git clone https://github.com/nahidalam/maya
%cd maya

In [11]:
import transformers
import torch
import json
from tqdm import tqdm
import os
import re
from datasets import load_dataset

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates
from llava.eval.maya.eval_utils import load_maya_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images
from PIL import Image
# Raise the transformers logging threshold to "error" so that lower-severity
# library messages do not flood the notebook output.
transformers.logging.set_verbosity_error()

In [3]:
def json_image_iterator(root_dir):
    """Yield one ``(file_name, data, image_path)`` triple per annotation entry.

    Every ``*.json`` file directly inside ``root_dir`` is parsed as a list of
    entries. For each entry, the basename of ``data["image_path"]`` is looked
    up anywhere below ``root_dir/images``.

    Args:
        root_dir: Directory holding the JSON annotation files and an
            ``images`` sub-directory (searched recursively).

    Yields:
        tuple: ``(file_name, data, image_path)`` where ``file_name`` is the
        JSON file name without its extension, ``data`` is the entry as loaded
        from JSON, and ``image_path`` is the path of the first file whose name
        matches, or ``None`` when no such file exists (a notice is printed in
        that case).
    """
    images_dir = os.path.join(root_dir, "images")

    # Walk the image tree once and remember the first path seen for each file
    # name. Re-walking the tree for every entry costs
    # O(entries * files); the lookup table gives the same "first match in
    # os.walk order" result in O(entries + files).
    image_index = {}
    for current_dir, _subdirs, files in os.walk(images_dir):
        for name in files:
            image_index.setdefault(name, os.path.join(current_dir, name))

    for json_file in os.listdir(root_dir):
        if not json_file.endswith(".json"):
            continue

        json_path = os.path.join(root_dir, json_file)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        file_name = os.path.splitext(json_file)[0]

        for data in data_list:
            # Only the basename is used; the directory part of the recorded
            # path is ignored.
            image_filename = data["image_path"].split("/")[-1]
            image_path = image_index.get(image_filename)

            if image_path is None:
                print(f"Image not found for entry in {json_file}: {image_filename}")

            yield file_name, data, image_path

In [4]:
# Run configuration: where the dataset lives, where results are written, and
# the English captioning instruction sent to the model.
image_root_path = r"/root/filter_sea_vqa_final/filter_sea_vqa_final"
save_path = "maya_seavqa_en_result.json"

# Three-line prompt; note that the first line ends with two spaces before the
# line break.
en_prompt = (
    "Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  \n"
    "The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.\n"
    "The caption should be concise, consisting of 3 to 5 sentences."
)

In [None]:
# Checkpoint selection for the Maya loader.
model_base = "CohereForAI/aya-23-8B"
model_path = "maya-multimodal/maya"
mode = "finetuned"
projector_path = None

# A projector path is only forwarded to the loader in "pretrained" mode.
if mode == "pretrained":
    projector_arg = projector_path
else:
    projector_arg = None

model, tokenizer, image_processor, _ = load_maya_model(
    model_base, model_path, projector_arg, mode
)

# Half precision on the GPU, then switch to evaluation mode.
model = model.half()
model = model.cuda()
model.eval()

In [8]:
def get_caption(image, prompt):
    """Generate a caption for a single image with the notebook's loaded model.

    Uses the module-level ``model``, ``tokenizer`` and ``image_processor``
    objects.

    Args:
        image: Path (or file object) that ``PIL.Image.open`` can open.
        prompt: Instruction text. The image placeholder token(s) are prepended
            here, so the caller passes the plain instruction only.

    Returns:
        str: The decoded model output with special tokens removed and
        surrounding whitespace stripped.
    """
    # Decode the picture first so an unreadable path fails before any GPU
    # work. ``convert`` returns an independent in-memory copy, so the file
    # handle can be released right away instead of lingering until garbage
    # collection (which leaks descriptors over a long captioning run).
    with Image.open(image) as image_file:
        rgb_image = image_file.convert("RGB")

    # Prepend the image placeholder, optionally wrapped in start/end markers
    # when the model config asks for them.
    if model.config.mm_use_im_start_end:
        prompt = (
            DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_TOKEN
            + DEFAULT_IM_END_TOKEN
            + "\n"
            + prompt
        )
    else:
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt

    # Render a single-turn conversation with an open assistant slot.
    conv = conv_templates["aya"].copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    image_tensor = process_images([rgb_image], image_processor, model.config)[0]

    # Greedy decoding (no sampling, single beam).
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            image_sizes=[rgb_image.size],
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            use_cache=True,
        )

    answer = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return answer

In [9]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON list stored at ``save_path``.

    The file holds a single JSON array. It is created when missing; a
    zero-length file is treated as an empty array.

    The updated array is first written to ``<save_path>.tmp`` and then moved
    over the target with ``os.replace``. Writing in place would truncate the
    file before the new content is complete, so an interrupt or a
    serialisation error part-way through would destroy every result saved so
    far.

    Args:
        save_path: Path of the JSON results file.
        json_obj: JSON-serialisable record to append.

    Raises:
        json.JSONDecodeError: If the existing file is non-empty but not valid
            JSON.
        TypeError: If ``json_obj`` is not JSON-serialisable; the existing file
            is left untouched.
    """
    records = []
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        with open(save_path, "r", encoding="utf-8") as f:
            records = json.load(f)
    records.append(json_obj)

    tmp_path = f"{save_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(records, f, indent=4)
        # Same-directory rename: the target is swapped in one step.
        os.replace(tmp_path, save_path)
    except BaseException:
        # Do not leave a half-written temp file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [None]:
# Caption every dataset entry and append each result to the output file as it
# is produced, so partial progress survives an interrupted run.
iterator = json_image_iterator(image_root_path)
# Materialise the generator so tqdm knows the total and can show a real
# progress bar.
items = list(iterator)

for country, data, image in tqdm(items, desc="Progress"):
    if image is None:
        # json_image_iterator yields None when no image file matched (and has
        # already printed a notice). Passing None on to get_caption would
        # raise inside Image.open and abort the whole run, so skip the entry.
        continue

    caption = get_caption(image, en_prompt)

    # `country` is the annotation file's base name, as yielded by the iterator.
    json_obj = {
        "name": data["culture_name"],
        "country": country,
        "image_url": data["image_path"],
        "gt_caption": data["gt_caption"],
        "caption": caption,
    }

    save_to_json(save_path, json_obj)