In [None]:
import transformers
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
from transformers.image_utils import load_image
import torch
import json
from tqdm import tqdm
import os
transformers.logging.set_verbosity_error()

In [2]:
def json_image_iterator(root_dir):
    """Yield one ``(file_name, data, image_path)`` tuple per annotation entry.

    Every ``*.json`` file directly inside ``root_dir`` is loaded and iterated.
    Each file is expected to hold a list of dicts that have an
    ``"image_path"`` key; only the last ``/``-separated component of that
    value is used, and it is looked up by file name anywhere under
    ``<root_dir>/images``.

    Args:
        root_dir: Directory containing the JSON files and an ``images``
            subdirectory.

    Yields:
        Tuples ``(file_name, data, image_path)`` where ``file_name`` is the
        JSON file name without its extension, ``data`` is the entry as
        loaded, and ``image_path`` is the path of the first matching image
        file, or ``None`` when no file with that name exists (a message is
        printed in that case).
    """
    images_dir = os.path.join(root_dir, "images")

    # Index the image tree once instead of re-walking it for every entry.
    # setdefault keeps the first hit in walk order, i.e. the same file a
    # per-entry walk would have stopped at.
    image_index = {}
    for dirpath, _dirnames, filenames in os.walk(images_dir):
        for name in filenames:
            image_index.setdefault(name, os.path.join(dirpath, name))

    # Sorted so the iteration order does not depend on the file system.
    for json_file in sorted(os.listdir(root_dir)):
        if not json_file.endswith(".json"):
            continue

        json_path = os.path.join(root_dir, json_file)
        with open(json_path, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        file_name = os.path.splitext(json_file)[0]

        for data in data_list:
            image_filename = data["image_path"].split("/")[-1]
            image_path = image_index.get(image_filename)

            if image_path is None:
                print(f"Image not found for entry in {json_file}: {image_filename}")

            yield file_name, data, image_path

In [3]:
# File that accumulates one JSON record per captioned image.
save_path = "pali_gemma_seavqa_en_result.json"

# Dataset root handed to json_image_iterator.
image_root_path = "/root/filter_sea_vqa_final/filter_sea_vqa_final"

# English captioning instruction sent with every image. Spelled out as
# explicit fragments so the trailing spaces and the newlines that are part of
# the prompt text stay visible and survive editors that strip whitespace.
en_prompt = (
    "Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  \n"
    "The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.\n"
    "The caption should be concise, consisting of 3 to 5 sentences."
)

In [None]:
# Hub identifier of the checkpoint used for both the model and its processor.
model_id = "google/paligemma2-10b-ft-docci-448"

# Load the weights in bfloat16 with device_map="auto", then put the model in
# evaluation mode since it is only used for generation in this notebook.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = model.eval()

# Matching processor: builds model inputs from text + image and decodes
# generated token ids back to text (see get_caption).
processor = PaliGemmaProcessor.from_pretrained(model_id)

In [5]:
def get_caption(image, prompt, max_new_tokens=512):
    """Generate a caption for one image with the notebook-level model.

    Uses the module-level ``model`` and ``processor`` objects.

    Args:
        image: Image reference passed to ``load_image``; this notebook passes
            a local file path.
        prompt: Text prompt given to the processor together with the image.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 512).

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), with special tokens removed.
    """
    image = load_image(image)

    # Cast floating-point inputs to the model's own dtype instead of a
    # hard-coded float16: the model is loaded in bfloat16, and casting to a
    # different half-precision format first only loses precision.
    model_inputs = (
        processor(text=prompt, images=image, return_tensors="pt")
        .to(model.dtype)
        .to(model.device)
    )
    input_len = model_inputs["input_ids"].shape[-1]

    # Greedy decoding (no sampling, single beam) keeps the output deterministic.
    with torch.inference_mode():
        generation = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)

In [6]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON list stored at ``save_path``.

    If the file does not exist it is created with ``[json_obj]``. If it
    exists, its list is loaded, ``json_obj`` is appended, and the whole list
    is written back with ``indent=4``.

    The new content is written to ``<save_path>.tmp`` first and then moved
    over the target with ``os.replace``, so an interruption during the write
    cannot leave a truncated results file behind.

    Args:
        save_path: Path of the JSON file holding the list of records.
        json_obj: JSON-serialisable object to append.

    Raises:
        json.JSONDecodeError: If the existing file is non-empty but is not
            valid JSON.
    """
    records = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            existing = f.read()
        # An existing but empty file (e.g. from a truncated earlier write) is
        # treated as an empty list instead of crashing the whole run.
        if existing.strip():
            records = json.loads(existing)

    records.append(json_obj)

    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)
    # Swap the fully written file into place in a single step.
    os.replace(tmp_path, save_path)

In [None]:
# Materialise the entries up front so tqdm knows the total and can show a
# real progress bar.
iterator = json_image_iterator(image_root_path)
items = list(iterator)

for country, data, image in tqdm(items, desc="Progress"):
    # json_image_iterator yields image=None for entries whose file was not
    # found (it already prints a message); get_caption cannot load None, so
    # skip those entries instead of crashing the whole run.
    if image is None:
        continue

    caption = get_caption(image, en_prompt)

    json_obj = {
        "name": data["culture_name"],
        "country": country,
        "image_url": data["image_path"],
        "gt_caption": data["gt_caption"],
        "caption": caption,
    }

    # Persist after every image so finished captions survive an interruption.
    save_to_json(save_path, json_obj)