In [None]:
import transformers
import torch
import json
from tqdm import tqdm
import os
import re

from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

# Only surface errors from transformers (hide its warning/info log messages).
transformers.logging.set_verbosity_error()

In [None]:
from datasets import load_dataset

# Pull the 'main' split of the WorldCuisines food knowledge base from the HF Hub
# (the second positional argument of load_dataset is the config `name`).
dataset = load_dataset(
    'worldcuisines/food-kb',
    name='',
    split='main',
)
dataset

In [None]:
def sea_filter(row):
    """Predicate for ``Dataset.filter``: keep rows tagged with South Eastern Asia.

    Looks at the columns ``region1`` .. ``region5`` in order and returns True as
    soon as one of them equals "South Eastern Asia"; otherwise returns False.
    """
    SEA_REGION = "South Eastern Asia"
    return any(row[f'region{slot}'] == SEA_REGION for slot in range(1, 6))

# Restrict the knowledge base to dishes whose region1..region5 include South Eastern Asia.
dataset = dataset.filter(function=sea_filter)
dataset

In [4]:
# Captioning instruction sent with every image. The three sentences are joined by
# newlines; the first line ends with two spaces, kept verbatim from the original prompt.
en_prompt = (
    "Write a caption in English for an image that may include culturally significant objects or elements from Southeast Asia.  \n"
    "The caption should specifically name Southeast Asian cultural items, such as cuisine, traditions, landmarks, or other related elements if they appear in the image.\n"
    "The caption should be concise, consisting of 3 to 5 sentences."
)
# JSON array file that each generated caption record is appended to.
save_path = "qwen_worldcuisine_en_result.json"

In [None]:
# Pixel budget handed to the processor for image resizing. Expressed in units of
# 28*28 patches — presumably 256..1280 visual tokens per image; confirm against Qwen2-VL docs.
min_pixels, max_pixels = 256 * 28 * 28, 1280 * 28 * 28

# Load the 7B instruct checkpoint in fp16 and let accelerate place it on the available devices.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

In [6]:
def concat_message(caption_list, prompt):
    """Build one single-turn chat conversation per record.

    Each conversation is a list holding a single ``user`` message whose content
    is the record's ``"image"`` entry followed by ``prompt`` as text.

    Args:
        caption_list: iterable of mappings that each have an ``"image"`` key.
        prompt: instruction text attached to every image.

    Returns:
        A list of conversations, in the same order as ``caption_list``.
    """
    return [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": record["image"]},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        for record in caption_list
    ]

In [7]:
def get_caption_batch(caption_list, prompt, max_new_tokens=128):
    """Caption several images with a single batched ``model.generate`` call.

    Args:
        caption_list: records that each carry an ``"image"`` entry (whatever
            ``process_vision_info`` accepts — presumably a PIL image or a
            path/URL; TODO confirm).
        prompt: instruction text sent alongside every image.
        max_new_tokens: generation budget per caption (default 128, as before).

    Returns:
        A list of decoded caption strings, one per record, in input order.
        An empty ``caption_list`` returns ``[]`` without touching the model.
    """
    if not caption_list:
        # e.g. a dataset row whose image slots are all None: avoid calling the
        # processor/model with an empty batch.
        return []

    # Use the caller's prompt (this previously ignored `prompt` and always sent `en_prompt`).
    messages = concat_message(caption_list, prompt)
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]
    image_inputs, video_inputs = process_vision_info(messages)
    # NOTE(review): batched generation with a decoder-only model needs left padding;
    # confirm processor.tokenizer.padding_side == "left" for this checkpoint.
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Each output sequence starts with its (padded) prompt tokens; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

In [8]:
def get_caption(image, prompt):
    """Caption a single image with the Qwen2-VL model.

    Args:
        image: the image passed through to ``process_vision_info``.
        prompt: instruction text sent after the image.

    Returns:
        The output of ``processor.batch_decode`` for a batch of one, i.e. a
        one-element list holding the caption string (not a bare string).
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    vision_images, vision_videos = process_vision_info(conversation)
    model_inputs = processor(
        text=[chat_text],
        images=vision_images,
        videos=vision_videos,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    sequences = model.generate(**model_inputs, max_new_tokens=128)
    # Strip the echoed prompt so only newly generated tokens are decoded.
    continuations = [
        full_seq[len(prompt_seq):]
        for prompt_seq, full_seq in zip(model_inputs.input_ids, sequences)
    ]
    return processor.batch_decode(
        continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

In [9]:
def save_to_json(save_path, json_obj):
    """Append ``json_obj`` to the JSON array stored at ``save_path``.

    Creates the file (as a one-element array) if it does not exist yet. The whole
    array is re-read and re-written on every call, so the cost grows with the
    number of saved records.

    The new content is written to ``<save_path>.tmp`` first and then moved over
    ``save_path`` with ``os.replace``. An interruption during the dump (e.g. a
    KeyboardInterrupt in a long captioning run) therefore leaves the previously
    saved records intact instead of a truncated file.

    Args:
        save_path: path of the JSON file holding a list of records.
        json_obj: JSON-serializable record to append.
    """
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    else:
        data = []
    data.append(json_obj)

    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    # Swap the complete file into place in one step.
    os.replace(tmp_path, save_path)

In [None]:
# Caption every available image (image1..image8) of each filtered dish, one model call
# per image, appending each record to `save_path` as soon as it is produced.
for row in tqdm(dataset, desc="Progress"):
    cuisines = row['cuisines']
    name = row['name']

    # Separate index name so it cannot be confused with the row iteration.
    for img_idx in range(1, 9):
        image_key = f"image{img_idx}"
        image = row[image_key]
        if image is None:
            continue

        # Drop the "?download" flag so the stored URL is the plain image link.
        # NOTE(review): assumes the *_url column is a non-None string whenever the image exists.
        image_url = row[image_key + "_url"].replace("?download", "")
        caption = get_caption(image, en_prompt)

        json_obj = {"name": name,
                    "cuisines": cuisines,
                    "image_url": image_url,
                    "caption": caption}

        save_to_json(save_path, json_obj)

In [None]:
# for i in tqdm(range(len(dataset)), desc="Progess"):
#     row = dataset[i]
#     caption_list=[]
#     cuisines = row['cuisines']
#     name = row['name']
    
#     for i in range(1, 9):
#         image_key = f"image{i}"  
#         image = row[image_key]

#         if image is not None:
#             url_key = image_key + "_url"
#             image_url = row[url_key]
#             image_url = image_url.replace("?download", "")

#             json_obj={"name":name,
#                       "cuisines":cuisines,
#                       "image_url":image_url,
#                       "image":image,
#                       "caption":""}
#             caption_list.append(json_obj)

#     caption_result_list = get_caption_batch(caption_list, en_prompt)

#     for i in range(len(caption_result_list)):
#         del caption_list[i]['image']
#         caption_list[i]['caption']=caption_result_list[i]
#         save_to_json(save_path, caption_list[i])
        