# Load Data

In [None]:
from datasets import load_dataset


# Hugging Face dataset of finance PDF pages; only the "train" split is used.
DATASET_NAME = "nirantk/finance-pdf-vqa"
dataset = load_dataset(DATASET_NAME, split="train")

# Image column of the dataset, iterated by the export loop in the next cell.
images = dataset["image"]

In [None]:
from tqdm import tqdm
import os

# Export every dataset page as a PNG so it can be referenced by file path.
# NOTE(review): the indexing section reads *.webp files from /opt/datasets/data, not
# these PNGs, and `rag_data` is not used later in this notebook — confirm which image
# set is meant to be indexed.
EXPORT_DIR = "/opt/datasets/test"
os.makedirs(EXPORT_DIR, exist_ok=True)  # image.save() fails if the directory is missing

rag_data = []
# total must reflect the real number of images; a hard-coded 10 made the progress bar wrong.
for index, image in tqdm(enumerate(images), total=len(images)):
    file_path = f"{EXPORT_DIR}/{index}.png"
    image.save(file_path)
    rag_data.append({"id": index, "image_path": file_path})

In [None]:
# Python environment setup script for COLPALI
# NOTE: the triple-quoted block below is a plain string, not executed code; run these
# commands in a shell. Because it is the cell's last expression, Jupyter echoes it as output.
# TODO(review): transformers / mteb / qdrant-client are unpinned while colpali_engine is
# pinned to 0.1.1; pin a known-good transformers version once confirmed.
"""
conda create -n colpali python=3.11.4 -y
conda activate colpali
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install transformers
pip install colpali_engine==0.1.1
pip install mteb
pip install qdrant-client
"""

In [None]:
import os

# Collect the .webp page images to index. Sorted because os.listdir() order is arbitrary,
# which would otherwise make batch composition and indexing order vary between runs.
DATA_DIR = "/opt/datasets/data"
files = sorted(
    f"{DATA_DIR}/{name}" for name in os.listdir(DATA_DIR) if name.endswith(".webp")
)
if not files:
    # Fail here with a clear message instead of an opaque StopIteration at the probe batch.
    raise FileNotFoundError(f"No .webp files found in {DATA_DIR}")

In [None]:
# Sanity check: number of .webp page images that will be indexed.
len(files)

# Load Model

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
import numpy as np

In [None]:
# ColPali = PaliGemma backbone + trained adapter + projection head. The
# `colpaligemma-3b-mix-448-base` checkpoint is the untrained base. Per the vidore/colpali
# model card, the retrieval weights must be loaded on top with `load_adapter`. Without
# that step the embeddings come from an untrained projection and rankings are meaningless.
BASE_MODEL_DIR = "/opt/models/base-models/vidore/colpaligemma-3b-mix-448-base"
ADAPTER_DIR = "/opt/models/base-models/vidore/colpali"

model = ColPali.from_pretrained(BASE_MODEL_DIR, torch_dtype=torch.bfloat16, device_map="cuda")
model.load_adapter(ADAPTER_DIR)
# eval() after injecting the adapter so any newly created modules (e.g. dropout) are in
# inference mode too.
model = model.eval()
processor = AutoProcessor.from_pretrained(ADAPTER_DIR)

# Index Data

In [None]:
def process_images_v2(
    processor,
    original_images,
    max_length: int = 50,
    prompt: str = "Describe the image, summarised by content and company",
):
    """Open image files and tokenize them into a ColPali document batch.

    Args:
        processor: PaliGemma/ColPali ``AutoProcessor``; must expose ``image_seq_length``.
        original_images: sequence of image file paths (a DataLoader batch of ``files``).
        max_length: token budget for the text prompt; image tokens are added on top.
        prompt: text paired with every image. NOTE(review): colpali_engine's own
            ``process_images`` uses "Describe the image." (presumably the training
            prompt); confirm this custom default does not hurt retrieval.

    Returns:
        ``(batch_doc, original_images)``: the processor's tensor batch and the input
        paths, unchanged, so embeddings can be mapped back to files.
    """
    texts_doc = [prompt] * len(original_images)

    images = []
    for image_path in original_images:
        # Close the file handle once the RGB copy exists. convert() returns a new,
        # loaded image that stays valid after the source file is closed.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    return batch_doc, original_images

In [None]:
# Probe: push one batch through the model only to learn the width of each token
# embedding; that width sizes the Qdrant vectors in the collection below.
dataloader = DataLoader(
    files,
    shuffle=False,
    batch_size=8,
    collate_fn=lambda paths: process_images_v2(processor, paths),
)

batch_doc_sample, images_sample = next(iter(dataloader))
with torch.no_grad():
    batch_doc_sample = {
        name: tensor.to(model.device) for name, tensor in batch_doc_sample.items()
    }
    embeddings_doc_sample = model(**batch_doc_sample)

embedding_dim = embeddings_doc_sample.size(-1)
print(f"Embedding Dimension: {embedding_dim}")

In [None]:
# Qdrant REST endpoint used for indexing.
QDRANT_URL = "http://localhost:6333"
client = QdrantClient(QDRANT_URL)

In [None]:
# Start from an empty collection so re-running the notebook does not accumulate duplicate
# points. `recreate_collection` is deprecated in qdrant-client, so the drop/create is
# done explicitly. Vector size comes from the probe batch (`embedding_dim`).
COLLECTION_NAME = "rag_test_2"
if client.collection_exists(collection_name=COLLECTION_NAME):
    client.delete_collection(collection_name=COLLECTION_NAME)
client.create_collection(
    collection_name=COLLECTION_NAME,
    vectors_config=rest.VectorParams(size=embedding_dim, distance=rest.Distance.COSINE),
)

In [None]:
from typing import List
from qdrant_client.http import models
from uuid import uuid4


# Fresh loader over all files for the full indexing pass.
dataloader = DataLoader(
    files,
    batch_size=8,
    shuffle=False,
    collate_fn=lambda x: process_images_v2(processor, x),
)

# Index every document token embedding as its own point. Each point's payload carries the
# source file path as "doc_id", so token-level hits can be grouped back into documents
# at query time.
for batch_doc, original_images in tqdm(dataloader):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    embeddings_cpu = embeddings_doc.to("cpu")
    # Skip padding positions (attention_mask == 0). They carry no content and would store
    # (presumably zero) vectors in a cosine-distance collection.
    # Assumes output seq length == attention_mask length (TODO confirm for colpali_engine 0.1.1).
    token_masks = batch_doc["attention_mask"].to("cpu").bool()

    points: List[models.PointStruct] = []
    for doc_path, doc_vectors, token_mask in zip(original_images, embeddings_cpu, token_masks):
        payload = {"doc_id": doc_path}
        points.extend(
            models.PointStruct(id=str(uuid4()), vector=vec.tolist(), payload=payload)
            for vec in doc_vectors[token_mask]
        )

    client.upsert(
        collection_name="rag_test_2",
        points=points,
    )

# Query Data

In [None]:
from qdrant_client import QdrantClient


# Reconnect to the same Qdrant REST endpoint for the query section.
QDRANT_URL = "http://localhost:6333"
client = QdrantClient(QDRANT_URL)

In [None]:
queries = [
    "What benefits does OpenFaaS serverless compute offer in terms of performance?",
    "How are cash flows reported in the E2E Cloud's financial statements according to the indirect method?",
    "What benefits do Kubernetes Containers provide in cloud computing environments?",
    "What are the cash and cash equivalents reported by E2E Cloud as of March 31, 2024, and March 31, 2023?",
    "Why are Internet Protocol (IP) addresses considered assets with an indefinite useful life?",
    "How do K8s containers boost serverless in cloud?",
    "How does OpenFaaS boost K8s container performance?",
    "How does E2E Cloud's cash flow report adjust for non-cash items and credit issues?",
    "How are foreign exchange gains or losses on financial liabilities denominated in a foreign currency recognized in the financial statements of E2E Networks Limited?",
]

# Query-side settings (tunable).
COLLECTION_NAME = "rag_test_2"
TOKEN_NEIGHBOURS = 100  # nearest document-token vectors fetched per query token
TOP_K = 10              # documents reported per query
# process_queries needs a placeholder image; build it once instead of once per batch.
MOCK_IMAGE = Image.new("RGB", (448, 448), (255, 255, 255))


def late_interaction_search(query_token_vectors, token_neighbours=TOKEN_NEIGHBOURS, top_k=TOP_K):
    """Rank documents for one query with approximate ColPali MaxSim scoring.

    The collection stores one point per document token (payload "doc_id" = file path).
    For each query token, fetch its nearest document tokens, keep each document's best
    score for that token, then sum those maxima across query tokens. A document whose
    tokens do not appear in a query token's neighbour list contributes 0 for that token,
    so scores approximate (lower-bound) exact MaxSim.

    Args:
        query_token_vectors: non-padded query token embeddings, one row per token.
        token_neighbours: neighbours fetched per query token.
        top_k: number of documents returned.

    Returns:
        List of ``(doc_id, score)`` sorted by descending score, at most ``top_k`` long.
    """
    if len(query_token_vectors) == 0:
        return []
    requests = [
        rest.SearchRequest(vector=vec.tolist(), limit=token_neighbours, with_payload=True)
        for vec in query_token_vectors
    ]
    per_token_hits = client.search_batch(collection_name=COLLECTION_NAME, requests=requests)

    doc_scores = {}
    for hits in per_token_hits:
        # MaxSim: for this query token keep only each document's best-matching token.
        best_per_doc = {}
        for hit in hits:
            doc_id = hit.payload["doc_id"]
            best_per_doc[doc_id] = max(hit.score, best_per_doc.get(doc_id, float("-inf")))
        for doc_id, score in best_per_doc.items():
            doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + score
    return sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]


dataloader = DataLoader(
    queries,
    batch_size=4,
    shuffle=False,
    # Return the raw query strings alongside the batch so results can be labelled.
    collate_fn=lambda x: (process_queries(processor, x, MOCK_IMAGE), x),
)

results = []
for batch_query, batch_texts in tqdm(dataloader):
    with torch.no_grad():
        batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    embeddings_cpu = embeddings_query.to("cpu")
    # Drop padded query positions; assumes output seq length == attention_mask length
    # (TODO confirm for colpali_engine 0.1.1).
    token_masks = batch_query["attention_mask"].to("cpu").bool()

    for query_text, query_vectors, token_mask in zip(batch_texts, embeddings_cpu, token_masks):
        ranked = late_interaction_search(query_vectors[token_mask])
        print(f"Search Results for: {query_text}")
        for doc_id, score in ranked:
            print(f"Document ID: {doc_id}, Score: {score}")
        results.append({"query": query_text, "ranking": ranked})

In [None]:
# Visually inspect one indexed page. The file name is hard-coded; presumably it is a
# document ID seen in the search results above, so change it to the hit you want to check.
Image.open("/opt/datasets/data/12.webp")