# Load Data

In [5]:
from datasets import load_dataset


# Page images come from the train split of the `nirantk/finance-pdf-vqa` Hugging Face dataset.
dataset = load_dataset(path="nirantk/finance-pdf-vqa", split="train")
# Pull out just the image column; the next cell iterates over it and takes its len().
images = dataset["image"]

In [6]:
import os

from tqdm import tqdm


# Export every page image to data/<index>.webp so later cells can reload pages by path.
# The imports are local because the shared import cell runs later in the notebook, and
# the output directory does not exist on a fresh checkout (image.save would fail).
os.makedirs("data", exist_ok=True)

rag_data = []
for index, image in tqdm(enumerate(images), total=len(images)):
    file_path = f"data/{index}.webp"
    image.save(file_path)
    rag_data.append({"id": index, "image_path": file_path})

In [None]:
# Reference shell commands for creating the COLPALI Python environment.
# This cell only evaluates a string literal, so nothing below runs here.
# Copy the commands into a terminal instead.
# NOTE(review): transformers and qdrant-client are unpinned while colpali_engine is pinned
# to 0.1.1; confirm compatible versions and pin them. mteb is not imported by any visible cell.
"""
conda create -n colpali python=3.11.4 -y
conda activate colpali
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install transformers
pip install colpali_engine==0.1.1
pip install mteb
pip install qdrant-client
"""

In [7]:
import os


def _page_sort_key(file_name):
    """Sort key: numeric stems in numeric order ("2.webp" before "10.webp"), then other names alphabetically."""
    stem = os.path.splitext(file_name)[0]
    return (0, int(stem), file_name) if stem.isdecimal() else (1, 0, file_name)


# Collect the exported page images from data/.
# os.listdir returns entries in arbitrary order, so sort them to make batching,
# indexing and the embedding-dimension probe reproducible across runs and machines.
files = sorted(
    (f for f in os.listdir("data") if f.endswith(".webp")),
    key=_page_sort_key,
)
files = [f"data/{f}" for f in files]

In [None]:
# Fail fast with a clear message instead of producing an empty index further down.
assert files, "no .webp files found in data/ — run the image export cell first"
len(files)

# Load Model

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
import numpy as np

In [None]:
model_name = "vidore/colpali"

# Base checkpoint in bfloat16, placed on the GPU and switched to inference mode.
model = ColPali.from_pretrained(
    "vidore/colpaligemma-3b-mix-448-base",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
model.eval()
# Load the `vidore/colpali` adapter on top of the base checkpoint; the processor comes from the same repo.
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Index Data

In [9]:
def process_images_v2(processor, original_images, max_length: int = 50):
    """Load image files from disk and build a ColPali document batch.

    Args:
        processor: the model's processor; it is called with ``text``/``images`` and must expose
            ``image_seq_length``.
        original_images: sequence of image file paths.
        max_length: text-token budget added on top of ``processor.image_seq_length``.

    Returns:
        ``(batch_doc, original_images)``. ``batch_doc`` holds the processor's PyTorch tensors,
        and the input paths are passed through unchanged so callers can use them as document ids.
    """
    texts_doc = ["Describe the image."] * len(original_images)

    images = []
    for image_path in original_images:
        # Close each file promptly. convert() loads the pixels and returns an in-memory copy,
        # so the result stays valid after the handle is released.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    # NOTE(review): with padding="longest" and no truncation flag, max_length may have no
    # effect for this processor version — confirm.
    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    return batch_doc, original_images

In [None]:
dataloader = DataLoader(
    files,
    shuffle=False,
    batch_size=8,
    collate_fn=lambda paths: process_images_v2(processor, paths),
)

# Run a single batch through the model only to learn the embedding width,
# which the Qdrant collection configuration needs.
batch_doc_sample, images_sample = next(iter(dataloader))
batch_doc_sample = {name: tensor.to(model.device) for name, tensor in batch_doc_sample.items()}
with torch.no_grad():
    embeddings_doc_sample = model(**batch_doc_sample)
embedding_dim = embeddings_doc_sample.shape[-1]

print(f"Embedding Dimension: {embedding_dim}")

In [3]:
# Connect to the Qdrant server listening locally on its default REST port.
client = QdrantClient(url="http://localhost:6333")

In [None]:
# Start from an empty collection on every run so a re-index never mixes old and new points.
# recreate_collection is deprecated in current qdrant-client; use exists/delete/create instead.
if client.collection_exists(collection_name="rag_test"):
    client.delete_collection(collection_name="rag_test")
client.create_collection(
    collection_name="rag_test",
    vectors_config=rest.VectorParams(size=embedding_dim, distance=rest.Distance.COSINE),
)

In [None]:
from typing import List
from uuid import NAMESPACE_URL, uuid5

from qdrant_client.http import models


# Batch loader over the exported page files. Each batch yields (processor inputs, file paths).
dataloader = DataLoader(
    files,
    batch_size=8,
    shuffle=False,
    collate_fn=lambda x: process_images_v2(processor, x),
)

for batch_doc, original_images in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)

    # Convert to Python floats once per batch rather than once per token vector.
    # NOTE(review): assumes embeddings_doc is (batch, seq_len, dim) aligned position-for-position
    # with attention_mask — confirm for the pinned colpali_engine version.
    token_vectors = embeddings_doc.to("cpu").tolist()
    token_keep = batch_doc["attention_mask"].to("cpu").bool().tolist()

    points: List[models.PointStruct] = []
    # doc_id is the page's file path, passed through by process_images_v2.
    for doc_id, doc_vectors, doc_keep in zip(original_images, token_vectors, token_keep):
        payload = {"doc_id": doc_id}
        for token_index, (vector, keep) in enumerate(zip(doc_vectors, doc_keep)):
            if not keep:
                continue  # padding position, not part of the page
            # Deterministic id: re-running this cell overwrites the same points instead of duplicating them.
            point_id = str(uuid5(NAMESPACE_URL, f"{doc_id}#{token_index}"))
            points.append(models.PointStruct(id=point_id, vector=vector, payload=payload))

    if points:
        client.upsert(
            collection_name="rag_test",
            points=points,
        )

# Query Data

In [3]:
from qdrant_client import QdrantClient


# Fresh client for the query section, pointing at the same local Qdrant server.
client = QdrantClient(url="http://localhost:6333")

In [None]:
queries = ["What is the total profit of E2E Networks?"]

# process_queries takes a placeholder image; a blank 448x448 white page is used here
# (448 per the checkpoint name).
dataloader = DataLoader(
    queries,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
)

# NOTE(review): pages are indexed as one point per token vector, but each query is reduced to a
# single mean vector. This only approximates ColPali's late-interaction (MaxSim) scoring.
results = []
for batch_query in tqdm(dataloader):
    with torch.no_grad():
        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
        # Masked mean: padding positions of shorter queries in a batch must not contribute.
        mask = batch_query["attention_mask"].unsqueeze(-1).to(torch.float32)
        embeddings_query_pooled = (embeddings_query.float() * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    for vector in embeddings_query_pooled.to("cpu"):
        search_result = client.search(
            collection_name="rag_test",
            # Plain float list, matching how vectors were sent during indexing.
            query_vector=vector.tolist(),
            limit=10,
        )
        print("Search Results:")
        # Hits are individual token vectors, so one page can appear many times.
        # Hits come best-first, so keeping the first occurrence keeps each page's best score.
        seen_doc_ids = set()
        for hit in search_result:
            doc_id = hit.payload["doc_id"]
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)
            print(f"Document ID: {doc_id}, Score: {hit.score}")
            results.append(doc_id)

In [None]:
# Display the best-ranked page; fail with a clear message if the search returned nothing.
assert results, "no search hits — has the rag_test collection been indexed?"
Image.open(results[0])