In [None]:
import os
import textwrap
# Expose only GPU index 0 to CUDA. This is set here, ahead of the later cell
# that imports torch, so the variable is already in the environment when CUDA
# is first initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required because the model and/or dataset loaded
# further down need authenticated access — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
from datasets import load_dataset
# Load one parquet file from the "hwaseem04/Aya-testing" repository and expose
# it under the key "xm3600_captioning"; the inference cell reads it back as
# ds['xm3600_captioning'].
ds = load_dataset(
    "hwaseem04/Aya-testing",
    data_files={"xm3600_captioning": "data/xm3600_captioning-00000-of-00001.parquet"}
)

In [None]:
import os
from itertools import islice

import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Language codes to iterate over; each one selects the `prompt_<lang>` and
# `captions_<lang>` columns of a sample.
languages = ["en", "bn", "de", "ko", "ru", "zh"]

# How many samples to run. The loop used to end with a bare `break`, i.e. it
# always processed exactly one sample; that limit is now explicit.
# Set to None to iterate over the whole split.
MAX_SAMPLES = 1

# Upper bound on the number of tokens generated per caption.
MAX_NEW_TOKENS = 100

# Directory to save temp images (the chat template below is given a file path).
TEMP_IMAGE_DIR = "temp_images_caption"
os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

# Choose the dtype BEFORE loading so the weights and the inputs built in the
# loop share it: bfloat16 when CUDA is available, float32 otherwise.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load Gemma model and processor.
model_id = "google/gemma-3-12b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch_dtype,
).eval()
processor = AutoProcessor.from_pretrained(model_id)

# ---------------------------------------------------------------------------
# Inference loop
# ---------------------------------------------------------------------------

# `ds` is created by the dataset-loading cell earlier in the notebook.
dataset = ds['xm3600_captioning']

# Progress-bar total that matches the number of samples actually visited.
sample_total = len(dataset) if MAX_SAMPLES is None else min(MAX_SAMPLES, len(dataset))

for sample in tqdm(islice(dataset, MAX_SAMPLES), total=sample_total, desc="Iterating samples"):
    try:
        image = sample["image"]
        sample_id = sample["sample_id"]

        #### This part will be replaced when attack implementation is ready #####

        image_path = f"{TEMP_IMAGE_DIR}/{sample_id}.jpg"
        if not os.path.exists(image_path):
            # JPEG cannot store alpha/palette modes; convert first so the save
            # does not fail on such images.
            # NOTE(review): assumes `image` is a PIL image — confirm.
            if image.mode != "RGB":
                image = image.convert("RGB")
            image.save(image_path)

        #########################################################################

        print(f"\n========== Sample ID: {sample_id} ==========")

        for lang in languages:
            prompt_col = f"prompt_{lang}"
            caption_col = f"captions_{lang}"

            # Skip the language when either column is absent or holds no value.
            if (
                prompt_col not in sample
                or caption_col not in sample
                or sample[prompt_col] is None
                or sample[caption_col] is None
            ):
                print(f"[{lang}] Missing data.")
                continue

            prompt = sample[prompt_col]
            gt_caption = sample[caption_col]

            #### Run inference using Gemma model ####

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a helpful assistant."}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_path},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]

            inputs = processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(model.device, dtype=torch_dtype)

            # Length of the prompt in tokens; everything after it in the
            # generated sequence is the model's answer.
            input_len = inputs["input_ids"].shape[-1]

            with torch.inference_mode():
                # Greedy decoding (do_sample=False).
                output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
                pred_caption = processor.decode(output[0][input_len:], skip_special_tokens=True)

            ###########################################

            # Display result
            print(f"\n[{lang.upper()}]")
            print(f"Prompt: {prompt}")
            print(f"GT: {gt_caption}")
            print(f"Pred: {textwrap.fill(pred_caption, width=80)}")

        print("=" * 100)

    except Exception as e:
        # Best-effort: report the failure and move on to the next sample.
        # Look the id up defensively so that reporting an error can never
        # raise a second KeyError when "sample_id" itself is what was missing.
        failed_id = sample["sample_id"] if "sample_id" in sample else "<unknown>"
        print(f"Error processing sample {failed_id}: {e}")