In [1]:
import os

import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries


# Hex colour palette. NOTE(review): not referenced by any cell visible in this notebook.
COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]
# Load model: base checkpoint first (bfloat16, on the GPU, eval mode), then the
# adapter published under `model_name` is attached on top of it.
model_name = "vidore/colpali"
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token=token
).eval()
# NOTE(review): the "newly initialized" warning in the cell output is emitted by
# from_pretrained above, i.e. before this adapter is loaded — presumably the adapter
# supplies the custom projection weights; confirm.
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)
# Device reported by the model; the helper functions move their input batches to it.
device = model.device

  from .autonotebook import tqdm as notebook_tqdm
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.17it/s]
Some weights of ColPali were not initialized from the model checkpoint at google/paligemma-3b-mix-448 and are newly initialized: ['custom_text_proj.bias', 'custom_text_proj.weight', 'language_model.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
from PIL import Image, ImageChops
import numpy as np
# Blank white 448x448 RGB image; search() hands it to process_queries together with the query text.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

def search(query:str, ds,k:int=1):
    """Rank the entries of `ds` against `query` and return the top `k` indices.

    The query is run through `process_queries` and the model (without gradient
    tracking), its embedding is moved to the CPU, and `CustomEvaluator` scores
    it against every entry of `ds`.

    Args:
        query: Query text.
        ds: Sequence of document embeddings, as accumulated by `index`.
        k: Number of indices to return.

    Returns:
        Indices into `ds`, ordered from highest to lowest score, at most `k` long.
    """
    with torch.no_grad():
        query_inputs = process_queries(processor, [query], mock_image)
        query_inputs = {name: tensor.to(device) for name, tensor in query_inputs.items()}
        query_embeddings = model(**query_inputs).to("cpu")
        # One tensor per query (here exactly one).
        qs = list(torch.unbind(query_embeddings))

    scores = CustomEvaluator(is_multi_vector=True).evaluate(qs, ds)
    # argsort is ascending; squeeze(0) drops the single-query axis
    # (assumes `scores` has a leading dimension of 1 — one row per query).
    ascending = np.argsort(scores, axis=1).squeeze(0)
    descending = ascending[::-1]
    return descending[:k]


def trim_and_square(im, padding=10, target_size=448):
    """Crop the uniform border off `im` and letterbox it onto a white square.

    The colour of the top-left pixel is treated as the background. Everything
    that differs noticeably from it defines the content bounding box, which is
    grown by `padding` pixels (clamped to the image) before cropping. The crop
    is then scaled so its longer side equals `target_size`, keeping the aspect
    ratio, and pasted centred on a white `target_size` x `target_size` RGB image.

    Args:
        im: PIL image to process.
        padding: Pixels of margin kept around the detected content.
        target_size: Side length of the returned square image, in pixels.

    Returns:
        A new RGB PIL image of size (target_size, target_size).
    """
    # Trim whitespace: diff against a flat image filled with the corner colour.
    bg = Image.new(im.mode, im.size, im.getpixel((0,0)))
    diff = ImageChops.difference(im, bg)
    # (diff + diff) / 2.0 - 100: differences of 100 or less clip to zero, so
    # faint noise near the background colour does not enlarge the bounding box.
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        # Add padding to the bounding box, staying inside the image
        left, top, right, bottom = bbox
        left = max(0, left - padding)
        top = max(0, top - padding)
        right = min(im.width, right + padding)
        bottom = min(im.height, bottom + padding)
        trimmed = im.crop((left, top, right, bottom))
    else:
        # Nothing differs from the background: keep the image as is.
        trimmed = im

    # Scale so the longer side becomes target_size
    w, h = trimmed.size
    if w > h:
        new_w = target_size
        # Clamp to 1: int() truncates to 0 for extremely wide crops, and a
        # zero-sized dimension makes resize() fail.
        new_h = max(1, int(h * (target_size / w)))
    else:
        new_h = target_size
        new_w = max(1, int(w * (target_size / h)))

    resized = trimmed.resize((new_w, new_h), Image.LANCZOS)

    # Create a white square image
    square = Image.new('RGB', (target_size, target_size), (255, 255, 255))

    # Paste the resized image onto the square, centered
    paste_x = (target_size - new_w) // 2
    paste_y = (target_size - new_h) // 2
    square.paste(resized, (paste_x, paste_y))

    return square

def index(file,ds, batch_size=4):
    """Render PDFs to page images, embed every page, and append the embeddings to `ds`.

    Each PDF is converted with `convert_from_path`, every page is normalised
    with `trim_and_square`, and the pages are pushed through the model in
    batches without gradient tracking.

    Args:
        file: Iterable of PDF paths. A single path (str or os.PathLike) is also
            accepted and treated as a one-element list.
        ds: List that is extended in place with one CPU embedding tensor per page.
        batch_size: Number of pages per forward pass (default 4).

    Returns:
        A tuple (message, ds, images): a status string, the same `ds` list that
        was passed in, and the list of preprocessed page images in index order.
    """
    # A bare path would otherwise be iterated character by character.
    paths = [file] if isinstance(file, (str, os.PathLike)) else file

    images = []
    for f in paths:
        print(f)
        for f_img in convert_from_path(f):
            images.append(trim_and_square(f_img))

    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,  # keep page order so positions in `ds` line up with `images`
        collate_fn=lambda x: process_images(processor,x)
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k,v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Split the batch into per-page tensors and store them on the CPU.
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images


In [3]:
# Start from an empty embedding store and index every .pdf in the GOOG-2024 folder.
# NOTE(review): os.listdir returns entries in arbitrary order, so the page positions in
# `images` / `ds` may differ between machines; wrapping it in sorted() would pin them,
# but any hard-coded page index used later would then need re-checking.
ds = []
files = [os.path.join("output/SEC_EDGAR_FILINGS/GOOG-2024",f) for f in os.listdir("output/SEC_EDGAR_FILINGS/GOOG-2024") if f.endswith(".pdf")]
msg, ds, images = index(files,ds)

output/SEC_EDGAR_FILINGS/GOOG-2024/goog-20240630-10-Q2.pdf
output/SEC_EDGAR_FILINGS/GOOG-2024/goog-20240331-10-Q1.pdf


100%|██████████| 25/25 [00:06<00:00,  3.63it/s]


In [51]:
# Alternative query, kept commented out for quick switching:
# query = "What was the operating expense of Google for the year?"
query = "What is the total revenue generated?"

# Indices of the five best-scoring pages for `query`, best first.
best_img_idxs = search(query,ds,k=5)

tensor([84])
Top 1 Accuracy (verif): 0.0


In [52]:
# Display the ranked page indices returned by search() (highest score first).
best_img_idxs

array([84, 86, 11, 36, 31])

In [4]:
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from colpali_engine.interpretability.plot_utils import plot_patches
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

# Presumably the folder where interpretability artefacts are meant to be written.
# NOTE(review): not referenced by any cell visible in this notebook.
OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")


@dataclass
class InterpretabilityInput:
    """Bundle of a query, an image, and a token index range.

    NOTE(review): this class is not instantiated in any cell visible here; the
    field descriptions below are inferred from the names and should be confirmed.
    """
    # Query text.
    query: str
    # PIL image the query is compared against.
    image: Image.Image
    # presumably the first query-token index of interest — TODO confirm
    start_idx_token: int
    # presumably the end of that token range (inclusive vs. exclusive not visible) — TODO confirm
    end_idx_token: int


In [5]:
def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    add_special_prompt_to_doc: bool = True,
) -> "tuple[torch.Tensor, torch.Tensor, list]":
    """Compute per-query-token similarity maps between `query` and the patches of `image`.

    Despite its name, this function draws nothing: it returns the maps so the
    caller can plot them.

    Args:
        model: ColPali model with exactly one active adapter; its
            `config.name_or_path` must be a key of `VIT_CONFIG`.
        processor: Wrapper used to preprocess the query and the image and to
            tokenize the query.
        query: Query text.
        image: Page image to compare the query against.
        add_special_prompt_to_doc: Forwarded to `processor.process_image` as
            `add_special_prompt`. When True, only the first
            `processor.processor.image_seq_length` output positions are kept
            as patch embeddings.

    Returns:
        A tuple `(attention_map_normalized, attention_map, text_tokens)`:
        the map normalised per query token and cast to float32, the raw map,
        both of shape (1, n_text_tokens, n_patch_x, n_patch_y), and the tokens
        produced by `processor.tokenizer.tokenize` for the decoded query ids.

    Raises:
        ValueError: If the model does not have exactly one active adapter, or
            its name is missing from `VIT_CONFIG`.
    """
    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward passes (no gradients needed)
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        # `output_image` has shape:
        # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
        # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
        output_image = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)

    # Unflatten the patch axis into a square grid
    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Unnormalized map: dot product of every text token with every patch
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = attention_map_normalized.float()

    # Text tokens, obtained by re-tokenizing the decoded query input ids
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))

    return attention_map_normalized, attention_map,text_tokens

In [None]:
# Compute the query/patch similarity maps for one indexed page.
# NOTE(review): `query` is the global set in the search cell, and images[34] is a
# hard-coded page that is not among the indices shown in the saved search output
# (84, 86, 11, 36, 31) — confirm this is the intended page.
attention_map_normalized, attention_map,text_tokens = generate_interpretability_plots(
    model,
    ColPaliProcessor(processor=processor),
    query,
    images[34],
)