In [None]:
import transformers
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
import torch
import json
from tqdm import tqdm
import os
transformers.logging.set_verbosity_error()

In [None]:
from datasets import load_dataset

# Load the "main" split of the WorldCuisines food knowledge base ('' is passed as the config name).
dataset = load_dataset("worldcuisines/food-kb", name="", split="main")
dataset

In [None]:
def sea_filter(row):
    """Return True if any of ``row['region1']`` .. ``row['region5']`` equals "South Eastern Asia".

    Used below as the predicate for ``dataset.filter``. The region slots are
    checked in order and the scan stops at the first match; otherwise False.
    """
    target_region = "South Eastern Asia"
    region_keys = (f"region{slot}" for slot in range(1, 6))
    return any(row[key] == target_region for key in region_keys)

# Keep only rows whose region1..region5 slots include "South Eastern Asia".
dataset = dataset.filter(function=sea_filter)
dataset

In [4]:
# Instruction text passed to the captioning model for every image (three lines).
en_prompt = (
    "Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  \n"
    "The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.\n"
    "The caption should be concise, consisting of 3 to 5 sentences."
)
# JSON file that the captioning loop appends one record per image to.
save_path = "pali_gemma_world_cuisine_en_result.json"

In [None]:
# Checkpoint used for captioning; weights are loaded in bfloat16 and placed on
# the available devices via device_map="auto", then switched to eval mode.
model_id = "google/paligemma2-10b-pt-896"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

In [6]:
def get_caption(image, prompt, max_new_tokens=512):
    """Generate a caption for ``image`` using the notebook-level ``model`` and ``processor``.

    Args:
        image: Image value passed to ``processor`` (taken from a dataset ``imageN`` column).
        prompt: Text prompt given to the processor together with the image.
        max_new_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        The decoded text of the newly generated tokens only (the prompt tokens
        are sliced off), with special tokens removed.
    """
    model_inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Cast floating-point inputs (pixel_values) to the model's own weight dtype.
    # The model is loaded in bfloat16; casting to float16 instead feeds
    # half-precision pixels into bfloat16 layers and fails with a dtype mismatch.
    # Integer tensors such as input_ids are not affected by a dtype cast.
    model_inputs = model_inputs.to(model.dtype).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic captions
            num_beams=1,
            use_cache=True,
        )
    # generate() returns prompt + continuation; keep only the new tokens of the single sample.
    new_tokens = generation[0][input_len:]
    return processor.decode(new_tokens, skip_special_tokens=True)

In [7]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON array stored at ``save_path``.

    If ``save_path`` does not exist it is created containing ``[json_obj]``.
    Otherwise the existing array is loaded, ``json_obj`` is appended, and the
    whole array is written back (indent=4). Cost therefore grows with the
    number of records already saved.

    The new contents are first written to ``<save_path>.tmp`` and then moved
    over ``save_path`` with ``os.replace``. If the kernel is interrupted
    mid-write, the previously saved file stays intact instead of being left
    truncated (which would lose every caption collected so far).
    """
    data = []
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    data.append(json_obj)

    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    os.replace(tmp_path, save_path)

In [None]:
# Caption every available image of every filtered dish and append one record
# per image to `save_path`. Image URLs already present in `save_path` are
# skipped, so re-running this cell resumes an interrupted run instead of
# appending duplicate records.
MAX_IMAGES_PER_ROW = 8  # rows are read through image1 .. image8 (plus matching *_url keys)

done_urls = set()
if os.path.exists(save_path):
    with open(save_path, "r", encoding="utf-8") as f:
        done_urls = {record.get("image_url") for record in json.load(f)}

for row_idx in tqdm(range(len(dataset)), desc="Progress"):
    row = dataset[row_idx]
    cuisines = row["cuisines"]
    name = row["name"]

    # Distinct index name so the outer row index is not shadowed.
    for img_idx in range(1, MAX_IMAGES_PER_ROW + 1):
        image_key = f"image{img_idx}"
        image = row[image_key]
        if image is None:
            continue

        # Drop the "?download" marker from the URL before storing it.
        image_url = row[image_key + "_url"].replace("?download", "")
        if image_url in done_urls:
            continue

        caption = get_caption(image, en_prompt)
        json_obj = {
            "name": name,
            "cuisines": cuisines,
            "image_url": image_url,
            "caption": caption,
        }
        save_to_json(save_path, json_obj)
        done_urls.add(image_url)