In [None]:
import torch
from PIL import ImageDraw
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from datasets import load_dataset
from transformers import (
    PaliGemmaForConditionalGeneration,
    AutoProcessor,
    BitsAndBytesConfig,
)
from peft import PeftModel
import torch
from datasets import load_dataset
import pickle

# Geometry used further down to lay out the image-patch grid:
# the input side length in pixels and the side length of one square patch.
PALIGEMMA_IMAGE_SIZE = 896
PALIGEMMA_PATCH_SIZE = 14
# Token id that is matched against `input_ids` to locate the image tokens.
PALIGEMMA_IMAGE_TOKEN_ID = 257152

# Load the dataset once (the same call used to be issued twice).
dataset = load_dataset("arnaudstiegler/v2_synthetic_us_passports_easy")

# 4-bit NF4 quantization config with bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# Single source of truth for where the model and its inputs live.
device = "cuda" if torch.cuda.is_available() else "cpu"

base_model = "google/paligemma-3b-pt-896"
adapter_model = "arnaudstiegler/paligemma-3b-pt-896-us-passports-lora-adapters"

# Quantized base model, wrapped with the LoRA adapters.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    base_model, quantization_config=bnb_config
)
model = PeftModel.from_pretrained(model, adapter_model).to(device)
# Inference only. The model is already on `device`, so no second hard-coded
# `.to("cuda")` is needed (it would also defeat the CPU fallback above).
model = model.eval()

processor = AutoProcessor.from_pretrained(base_model)

# Run the model on the first test sample and keep the per-step attentions.
sample = dataset["test"][0]
image = sample["image"].convert("RGB")

# Keep the raw processor output under its own name so `inputs` is only ever
# the plain dict of tensors that is passed to `generate`.
processed = processor(text="Process ", images=image, return_tensors="pt")
# Move exactly the tensors `generate` needs to the device selected earlier
# (`device`) instead of a hard-coded "cuda".
inputs = {
    name: processed[name].to(device)
    for name in ("input_ids", "attention_mask", "pixel_values")
}

# `return_dict_in_generate=True` gives an output object exposing `.sequences`
# and `.attentions`, which the following cells read.
out = model.generate(
    **inputs,
    output_attentions=True,
    output_hidden_states=True,
    return_dict_in_generate=True,
    max_new_tokens=164,
)

# Index tuple (one tensor per dimension of `input_ids`) of every position
# holding the image token id.
img_tokens = torch.where(inputs["input_ids"] == PALIGEMMA_IMAGE_TOKEN_ID)

# Number of patches along one side of the square input.
patches_per_side = PALIGEMMA_IMAGE_SIZE // PALIGEMMA_PATCH_SIZE

# Map each patch index (row-major: index = row * patches_per_side + col) to its
# pixel box in x1,y1,x2,y2 format.
grid = {
    patch_index: [
        col * PALIGEMMA_PATCH_SIZE,
        row * PALIGEMMA_PATCH_SIZE,
        (col + 1) * PALIGEMMA_PATCH_SIZE,
        (row + 1) * PALIGEMMA_PATCH_SIZE,
    ]
    for patch_index, (row, col) in enumerate(
        divmod(flat_index, patches_per_side)
        for flat_index in range(patches_per_side * patches_per_side)
    )
}

In [None]:
# Number of highest-attention image patches kept for each generation step.
TOP_K_PATCHES = 5

# Span of the image tokens inside the prompt. The end bound of a slice is
# exclusive, so it must be last index + 1; otherwise the final patch is dropped.
# NOTE(review): assumes the image tokens form one contiguous run — confirm.
image_start = int(img_tokens[1][0])
image_end = int(img_tokens[1][-1]) + 1

# One entry per processed generation step; each entry is the list of
# (x1, y1, x2, y2) boxes of that step's top-k patches.
top = []
# We start from 1 because we're skipping the original self-attention across all tokens
# TODO: maybe we could bring it back
for idx in range(1, len(out.attentions)):
    # Stack the per-layer attention tensors along a new trailing dimension
    layer_attentions = torch.stack(list(out.attentions[idx]), dim=-1)

    # Average the attentions across layers
    avg_attention = layer_attentions.mean(dim=-1)
    # Average the attention across heads
    # (assumes dim 1 is the head dimension — TODO confirm)
    avg_attention = avg_attention.mean(dim=1)

    # Only take the attention scores corresponding to the image tokens, so that
    # index 0 of the last dimension is patch 0 of `grid`
    img_attentions = avg_attention[:, :, image_start:image_end]

    # Only look at absolute value of the attention
    att_score = torch.abs(img_attentions)
    # Take top-k image patches
    patch_indices = torch.topk(att_score, k=TOP_K_PATCHES, dim=-1).indices

    # Given a patch index, and a grid, retrieve the location of the different image patches on the image
    top_bbox = [
        tuple(grid[patch_idx.item()]) for patch_idx in patch_indices.flatten()
    ]
    # Keep the boxes grouped per step: the code that consumes `top` iterates
    # over each entry as a list of boxes and indexes bbox[0..3].
    top.append(top_bbox)

# Size of the original (un-resized) image.
width, height = image.size
# Per-axis factors applied below to rescale patch boxes from the
# PALIGEMMA_IMAGE_SIZE grid to the original image dimensions.
# (The earlier unpacking of `normalized_bbox` was removed: that name is never
# defined in this notebook, and the values were overwritten right below.)
factor_width = width / PALIGEMMA_IMAGE_SIZE
factor_height = height / PALIGEMMA_IMAGE_SIZE

# Get coordinates of the bbox that covers all top-k image patches (should be the region of interest)
# Initialised as an inverted box (mins at the maximum, maxes at 0); presumably
# tightened by min/max updates in the loop that follows — verify there.
min_x, min_y, max_x, max_y = PALIGEMMA_IMAGE_SIZE, PALIGEMMA_IMAGE_SIZE, 0, 0
for token, boxes in zip(out.sequences, top):
    for bbox in boxes:
        resized_bbox = [
            factor_width * bbox[0],
            factor_height * bbox[1],
            factor_width * bbox[2],
            factor_height * bbox[3],
        ]