In [None]:
import os
import textwrap
# Expose only GPU 0 to this kernel. CUDA reads this variable when it initialises,
# so it has to be set before torch touches the GPU (torch is imported in a later cell).
os.environ.update(CUDA_VISIBLE_DEVICES="0")


In [None]:
# Interactive Hugging Face Hub login: the access token is typed into a prompt
# rather than hardcoded. Presumably required because the PaliGemma checkpoint
# loaded below is gated — confirm if running against a different model.
import huggingface_hub

huggingface_hub.notebook_login()

In [None]:
from datasets import load_dataset

# Load only the XM3600 captioning parquet shard; the dict key below becomes the
# split name used to index the result (`ds["xm3600_captioning"]`).
ds = load_dataset(
    path="hwaseem04/Aya-testing",
    data_files=dict(
        xm3600_captioning="data/xm3600_captioning-00000-of-00001.parquet",
    ),
)

In [None]:
# Peek at the first record to see which `prompt_<lang>` / `captions_<lang>` fields exist.
ds["xm3600_captioning"][0]

In [None]:
from tqdm.auto import tqdm
from PIL import Image
import torch
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

# --- Configuration ----------------------------------------------------------
model_id = "google/paligemma2-3b-mix-448"
# Languages to iterate over; each needs `prompt_<lang>` and `captions_<lang>` fields.
languages = ["en", "bn", "de", "ko", "ru", "zh"]
# Per-sample JPEG cache; each image is saved here once and re-read with `load_image`.
# NOTE(review): the processor/`load_image` also accept PIL images directly; the
# on-disk round trip is kept only to match earlier runs (JPEG is lossy).
TEMP_IMAGE_DIR = "temp_images_caption"
# Upper bound on generated caption length, in tokens.
MAX_NEW_TOKENS = 100
# Number of samples to caption. The previous version ended the loop with a
# hard-coded `break` (first sample only); that default is kept here explicitly.
# Set to None to caption the whole split.
N_SAMPLES = 1

# --- Model ------------------------------------------------------------------
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

processor = PaliGemmaProcessor.from_pretrained(model_id)


def generate_caption(image: Image.Image, lang: str, max_new_tokens: int = MAX_NEW_TOKENS) -> str:
    """Greedily decode a caption for `image` with PaliGemma's `caption <lang>` prompt.

    Args:
        image: image handed to the processor (here, the output of `load_image`).
        lang: language code appended to the `caption` task prefix.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The generated caption text with special tokens stripped.
    """
    model_inputs = processor(
        text=f"caption {lang}",
        images=image,
        return_tensors="pt",
    ).to(torch.bfloat16).to(model.device)  # BatchFeature.to casts only floating-point tensors

    # generate() returns prompt + continuation; remember the prompt length to slice it off.
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic captions
        )

    return processor.decode(output_ids[0][input_len:], skip_special_tokens=True)


# --- Captioning loop --------------------------------------------------------
os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)

dataset = ds['xm3600_captioning']
num_samples = len(dataset) if N_SAMPLES is None else min(N_SAMPLES, len(dataset))

# `select` restricts the iteration up front so the progress bar total is accurate.
for sample in tqdm(dataset.select(range(num_samples)), desc="Iterating samples"):
    try:
        sample_id = sample["sample_id"]

        # Save image if not already cached. JPEG cannot store alpha/palette
        # modes, so normalise to RGB first to avoid a save error.
        image_path = os.path.join(TEMP_IMAGE_DIR, f"{sample_id}.jpg")
        if not os.path.exists(image_path):
            sample["image"].convert("RGB").save(image_path)

        # The image is the same for every language: load it once per sample.
        loaded_image = load_image(image_path)

        print(f"\n========== Sample ID: {sample_id} ==========")

        for lang in languages:
            prompt_col = f"prompt_{lang}"
            caption_col = f"captions_{lang}"

            # Ensure expected fields exist in sample
            if prompt_col not in sample or caption_col not in sample:
                print(f"[{lang}] Missing data.")
                continue

            # The dataset prompt is displayed for reference only; the model is
            # driven by PaliGemma's own `caption <lang>` task prefix.
            prompt = sample[prompt_col]
            gt_caption = sample[caption_col]
            pred_caption = generate_caption(loaded_image, lang)

            # Display results
            print(f"\n[{lang.upper()}]")
            print(f"Prompt: {prompt}")
            print(f"GT: {gt_caption}")
            print(f"Pred: {textwrap.fill(pred_caption, width=80)}")

        print("=" * 100)

    except Exception as e:
        # Best-effort: report and move on to the next sample. Use .get so a
        # missing `sample_id` cannot raise a second error inside the handler.
        print(f"Error processing sample {sample.get('sample_id', '<unknown>')}: {type(e).__name__}: {e}")