In [1]:
import torch

# Pick the compute device once: use CUDA when torch reports a usable GPU,
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the notebook echoes the selected device.
device

'cuda'

In [2]:
from colpali_engine.models import ColModernVBert, ColModernVBertProcessor

# Checkpoint identifier shared by the processor and the model.
model_id = "ModernVBERT/colmodernvbert"

# The processor prepares inputs for the model (see `process_images` usage below).
processor = ColModernVBertProcessor.from_pretrained(model_id)

# Load the weights in full precision, then move the model to the selected device.
# NOTE(review): trust_remote_code=True executes code shipped with the checkpoint;
# only keep it for checkpoints you trust.
model = ColModernVBert.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
model = model.to(device)

In [10]:
import io
import base64
import fitz
from PIL import Image
from tqdm import tqdm

def page_to_pil(page: fitz.Page) -> Image.Image:
    """Render a PDF page to a PIL image.

    The page is rasterized at 300 DPI via ``page.get_pixmap`` and the
    resulting pixmap is converted with its ``pil_image()`` method.

    Args:
        page: The PyMuPDF page to rasterize.

    Returns:
        The rendered page as a PIL image.
    """
    return page.get_pixmap(dpi=300).pil_image()

def image_to_embedding(img: Image.Image) -> torch.Tensor:
    """Embed a single image with the notebook-level ``processor`` and ``model``.

    NOTE: Batch processing would be better. Images are embedded one at a
    time here only so that progress can be followed with tqdm.

    Args:
        img: The image to embed.

    Returns:
        The model output moved to the CPU. When the output is 3-dimensional
        with a leading dimension of size 1, that leading dimension is removed.
    """
    batch = processor.process_images([img])
    on_device = {name: tensor.to(device) for name, tensor in batch.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(**on_device)

    # Drop the batch axis of a single-image result.
    has_unit_batch = output.dim() == 3 and output.shape[0] == 1
    result = output.squeeze(0) if has_unit_batch else output
    return result.cpu()

def pil_to_base64(img: Image.Image) -> str:
    """Encode an image as a base64 string of its JPEG representation.

    Images whose mode is not ``"RGB"`` are converted to RGB before saving.

    Args:
        img: The image to encode.

    Returns:
        The base64-encoded JPEG bytes, decoded as a UTF-8 string.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    with io.BytesIO() as sink:
        rgb.save(sink, format="JPEG")
        # getvalue() returns the whole buffer regardless of the stream position.
        encoded = base64.b64encode(sink.getvalue())
    return encoded.decode("utf-8")

def process_document(path: str) -> tuple[list[torch.Tensor], list[dict[str, str|int]]]:
    """Embed every page of a PDF and build one payload per page.

    Each page is rendered with ``page_to_pil``, embedded with
    ``image_to_embedding`` and base64-encoded with ``pil_to_base64``.

    Args:
        path: Path of the document to open with ``fitz.open``.

    Returns:
        A tuple ``(embeddings, payloads)`` of equal length, in page order.
        Each payload is ``{"page": <1-based page number>, "image": <base64 JPEG>}``.
    """
    embeddings = []
    payloads = []
    # The context manager closes the document even when rendering or
    # embedding a page raises, so the file handle is never leaked.
    with fitz.open(path) as doc:
        # Iterating the document already yields loaded pages; no second
        # load_page() call is needed. Page numbers are 1-based in the payload.
        for page_number, page in tqdm(enumerate(doc, start=1), total=len(doc)):
            img = page_to_pil(page)
            embeddings.append(image_to_embedding(img))
            payloads.append({"page": page_number, "image": pil_to_base64(img)})
    return embeddings, payloads

In [11]:
# Embed every page of the report: `embeddings` holds one tensor per page and
# `payloads` the matching {"page", "image"} dicts, both in page order.
embeddings, payloads = process_document("./data/United-in-Science-2024_en.pdf")

100%|██████████| 48/48 [00:58<00:00,  1.21s/it]


In [8]:
from qdrant_client import QdrantClient, models

# Name of the collection that stores the page embeddings and payloads.
collection_name = "DocumentRetrieval"
# Directory handed to QdrantClient(path=...).
# NOTE(review): presumably qdrant-client's local on-disk mode (no server) — confirm.
path = "./qdrant"
client = QdrantClient(path=path) 

In [14]:
import uuid

# Step 1: make sure the multivector collection exists.
if client.collection_exists(collection_name=collection_name):
    print("Collection already exists!")
else:
    print("Create collection...")
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            # NOTE(review): presumably the per-token embedding width produced
            # by the model — confirm against an embedding's last dimension.
            size=128,
            distance=models.Distance.COSINE,
            # Each point stores a list of vectors, compared with MAX_SIM.
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        ),
    )

# Step 2: upload the pages whenever the collection holds no points.
# Gating on emptiness rather than on "was just created" lets a run that was
# interrupted between create_collection and upsert recover on re-execution.
# A populated collection is left untouched: point ids are random UUIDs, so a
# second upload would duplicate every page.
stored_points = client.count(collection_name=collection_name, exact=True).count
if stored_points == 0:
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            # Qdrant expects plain nested lists; tensors are converted here,
            # anything else is passed through unchanged.
            vector=(
                embedding.cpu().numpy().tolist()
                if isinstance(embedding, torch.Tensor)
                else embedding
            ),
            payload=payload,
        )
        for embedding, payload in zip(embeddings, payloads)
    ]

    # Nothing to send for an empty document.
    if points:
        client.upsert(
            collection_name=collection_name,
            wait=True,
            points=points,
        )


Collection already exists!
