In [None]:
import os
import textwrap
# Pin the notebook to the first GPU. This must run before torch is imported
# (later cell) so the CUDA runtime only ever sees device 0.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
from huggingface_hub import notebook_login

# Prompt interactively for a Hugging Face token instead of hardcoding one;
# the hub checkpoint and dataset used below are fetched with this login.
notebook_login()

In [None]:
from datasets import load_dataset

# Load only the xGQA VQA subset: map the split name "xGQA_vqa" to its single
# parquet shard inside the hwaseem04/Aya-testing repository.
ds = load_dataset(
    path="hwaseem04/Aya-testing",
    data_files=dict(xGQA_vqa="data/xGQA_vqa-00000-of-00001.parquet"),
)

In [None]:
# Peek at the first record of the split; displayed as the cell output.
ds["xGQA_vqa"][0]


In [None]:
from tqdm import tqdm
from PIL import Image
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

# Load the PaliGemma 2 model and processor.
model_id = "google/paligemma2-3b-mix-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Languages to iterate over; each one is looked up as `question_<lang>` /
# `answer_<lang>` in every sample.
languages = ["en", "bn", "de", "ko", "ru", "zh"]

# How many samples to run. The previous loop ended with an unconditional
# `break`, i.e. it only ever processed the first sample; 1 keeps that default.
# Set to None to run over the whole split.
MAX_SAMPLES = 1
MAX_NEW_TOKENS = 100  # generation length cap per answer (greedy decoding)

# Each sample's image is written once to this directory as JPEG and re-read
# with `load_image` before being passed to the processor.
IMAGE_DIR = "temp_images_vqa"
os.makedirs(IMAGE_DIR, exist_ok=True)

dataset = ds['xGQA_vqa']
samples_to_run = (
    dataset
    if MAX_SAMPLES is None
    else dataset.select(range(min(MAX_SAMPLES, len(dataset))))
)


def _load_sample_image(sample):
    """Cache ``sample["image"]`` as ``IMAGE_DIR/<sample_id>.jpg`` (once) and load it back.

    Returns the image produced by ``load_image`` for that path.
    Raises whatever saving/loading raises (e.g. KeyError if ``sample_id`` is absent).
    """
    image_path = f"{IMAGE_DIR}/{sample['sample_id']}.jpg"
    if not os.path.exists(image_path):
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P), so normalise
        # to RGB before saving; otherwise PIL raises on such images.
        sample["image"].convert("RGB").save(image_path)
    return load_image(image_path)


def _answer_question(image, lang, question):
    """Run greedy PaliGemma VQA on one (image, question) pair and return the decoded answer text."""
    # Prompt as per PaliGemma's VQA task format: "answer <lang> <question>".
    formatted_prompt = f"answer {lang} {question}"
    model_inputs = processor(
        text=formatted_prompt,
        images=image,
        return_tensors="pt",
    ).to(torch.bfloat16).to(model.device)
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
    # The output starts with the prompt tokens; keep only the newly generated ones.
    return processor.decode(generation[0][input_len:], skip_special_tokens=True)


for sample in tqdm(samples_to_run, desc="Iterating samples"):
    # `.get` so a row without `sample_id` is still reported instead of raising
    # a second KeyError inside the error handler.
    sample_id = sample.get("sample_id")
    try:
        # Loaded once per sample, not once per language.
        loaded_image = _load_sample_image(sample)
    except Exception as e:
        print(f"Error processing sample {sample_id}: {e}")
        continue

    print(f"\n========== Sample ID: {sample_id} ==========")

    for lang in languages:
        question = sample.get(f"question_{lang}")
        gt_answer = sample.get(f"answer_{lang}")

        # Covers both a missing column and a null/blank value in this row;
        # without this, a None question would be sent as the literal "None".
        if not question or gt_answer is None:
            print(f"[{lang}] Missing data.")
            continue

        try:
            pred_answer = _answer_question(loaded_image, lang, question)
        except Exception as e:
            # Report and continue so one failing language doesn't skip the rest.
            print(f"[{lang}] Error processing sample {sample_id}: {e}")
            continue

        # Display results
        print(f"\n[{lang.upper()}]")
        print(f"Question: {question}")
        print(f"GT: {gt_answer}")
        print(f"Pred: {textwrap.fill(pred_answer, width=80)}")

    print("=" * 100)