In [None]:
# Python environment setup script for COLPALI.
# The lines below are shell commands (conda / pip), kept inside a string literal
# so that running this cell does not execute them. Run them in a terminal to
# create the `colpali` environment this notebook's imports rely on.
"""
conda create -n colpali python=3.11.4 -y
conda activate colpali
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install transformers
pip install colpali_engine==0.1.1
pip install mteb
"""

In [None]:
import torch
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset

In [None]:
# Load the base checkpoint in bfloat16 on the GPU and put it in evaluation mode,
# then load the adapter and the processor published under `model_name`.
model_name = "vidore/colpali"
model = ColPali.from_pretrained(
    "vidore/colpaligemma-3b-mix-448-base",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
model = model.eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)
print(processor)

In [3]:
# Fetch the training split of the finance PDF VQA dataset and keep only its
# `image` column; these images are the documents embedded further down.
dataset = load_dataset("nirantk/finance-pdf-vqa", split="train")
images = dataset["image"]

In [None]:
# Sanity check: number of loaded images, then the size of the first one.
print(len(images), images[0].size, sep="\n")

In [None]:
# Run inference over the documents: batch the images, preprocess each batch
# with the processor, and collect one embedding tensor per image on the CPU.
dataloader = DataLoader(
    images,
    batch_size=8,
    shuffle=False,
    collate_fn=lambda image_batch: process_images(processor, image_batch),
)

ds = []
for batch_doc in tqdm(dataloader):
    with torch.no_grad():
        # Move every model input onto the model's device before the forward pass.
        batch_doc = {name: tensor.to(model.device) for name, tensor in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    # Split the batch along its first dimension so `ds` holds one tensor per image.
    ds.extend(embeddings_doc.to("cpu").unbind(0))

In [None]:
# Sanity check: number of document embeddings, then the shape of the first one.
print(len(ds), ds[0].shape, sep="\n")

In [None]:
# Run inference over the text queries with the same model.
queries = ["What happened in 30 September?"]

dataloader = DataLoader(
    queries,
    batch_size=4,
    shuffle=False,
    # Each batch is preprocessed together with a blank white 448x448 RGB image.
    collate_fn=lambda query_batch: process_queries(
        processor,
        query_batch,
        Image.new("RGB", (448, 448), (255, 255, 255)),
    ),
)

qs = []
for batch_query in tqdm(dataloader):
    with torch.no_grad():
        # Move every model input onto the model's device before the forward pass.
        batch_query = {name: tensor.to(model.device) for name, tensor in batch_query.items()}
        embeddings_query = model(**batch_query)
    # Split the batch along its first dimension so `qs` holds one tensor per query.
    qs.extend(embeddings_query.to("cpu").unbind(0))


# Sanity check: number of query embeddings, then the shape of the first one.
print(len(qs), qs[0].shape, sep="\n")

In [None]:
def evaluate_colbert(qs, ps, device=None) -> torch.Tensor:
    """Score every query against every document with late interaction (MaxSim).

    For each (query, document) pair, every query token vector is matched with
    the document token vector that gives the largest dot product, and those
    per-token maxima are summed into a single score.

    Args:
        qs: Sequence of 2-D tensors, one per query, each shaped
            (n_query_tokens, dim). Token counts may differ between queries.
        ps: Sequence of 2-D tensors, one per document, each shaped
            (n_doc_tokens, dim). Token counts may differ between documents.
        device: Device on which the scoring runs. When None (the default),
            "cuda" is used if available, otherwise "cpu".

    Returns:
        A CPU tensor of shape (len(qs), len(ps)) where entry [i, j] is the
        score of query i against document j. If either input is empty, an
        all-zero tensor of that (possibly empty) shape is returned.
    """
    # pad_sequence cannot handle an empty list; return an empty score matrix instead.
    if len(qs) == 0 or len(ps) == 0:
        return torch.zeros((len(qs), len(ps)))

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Zero-pad to a common token count so each side becomes one 3-D tensor.
    # NOTE: padded positions are zero vectors, so their dot product with any
    # token is 0. Padded query tokens therefore add 0 to the sum, while padded
    # document tokens put a floor of 0 under the per-token maximum.
    qs_padded = torch.nn.utils.rnn.pad_sequence(qs, batch_first=True, padding_value=0).to(device)
    ps_padded = torch.nn.utils.rnn.pad_sequence(ps, batch_first=True, padding_value=0).to(device)

    # (query, doc, query_token, doc_token) dot products for all pairs at once.
    similarity_matrix = torch.einsum("qnd,psd->qpns", qs_padded, ps_padded)
    # Best-matching document token for each query token ...
    max_similarities, _ = similarity_matrix.max(dim=3)
    # ... summed over the query tokens.
    final_scores = max_similarities.sum(dim=2)
    return final_scores.cpu()


new_scores = evaluate_colbert(qs, ds)


# Index of the highest-scoring document for each query (one entry per row of
# `new_scores`), computed once and reused for both the printout and the loop.
best_matches = new_scores.argmax(dim=1)
print(best_matches)
for index, result in enumerate(best_matches.tolist()):
    print(f"Query: {queries[index]}")
    # display() renders the image and returns None, so wrapping it in print()
    # would only add a stray "None" line to the output.
    display(images[result])