In [1]:
import os
import pprint
import time
from uuid import uuid4

import numpy as np
import stamina
import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm

In [2]:
# Collection settings shared by the cells below
collection_name = "colpali_embeddings"
dim = 128  # Dimensionality of ColPali embeddings

# Qdrant client backed by local storage under ./qdrant_storage
client = QdrantClient(path="qdrant_storage")

In [None]:
# Initialize ColPali Model and Processor
model_name = "vidore/colpali-v1.2"

# Pick the best available accelerator instead of hard-coding one backend:
# a CUDA GPU first, then Apple Silicon (MPS), otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Load the weights in bfloat16 and switch to eval mode: the model is only used
# for inference in this notebook.
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

  from .autonotebook import tqdm as notebook_tqdm
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Load the processor for the same checkpoint name as the model; later cells use
# it to turn page images and query text into model inputs.
processor = ColPaliProcessor.from_pretrained(model_name)

In [None]:
# Create the multivector collection: every point holds a list of `dim`-sized
# vectors that are compared with the MaxSim comparator.
client.create_collection(
    collection_name=collection_name,
    on_disk_payload=True,  # store the payload on disk
    vectors_config=models.VectorParams(
        size=dim,  # reuse the embedding dimensionality declared next to the client
        distance=models.Distance.COSINE,
        on_disk=True,  # move original vectors to disk
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM
        ),
        quantization_config=models.BinaryQuantization(
            binary=models.BinaryQuantizationConfig(
                always_ram=True  # keep only quantized vectors in RAM
            ),
        ),
    ),
)

True

In [None]:
@stamina.retry(on=Exception, attempts=3)  # retry up to 3 times if the upsert raises
def upsert_to_qdrant(batch):
    """Upload one batch of points to the Qdrant collection.

    Args:
        batch: list of ``models.PointStruct`` objects to upsert.

    Returns:
        True once the upsert request has been accepted.

    Raises:
        Exception: whatever ``client.upsert`` raised, after all retry attempts
            are exhausted. Errors are deliberately NOT caught here; swallowing
            them would hide the failure from ``stamina`` and no retry would
            ever happen.
    """
    client.upsert(
        collection_name=collection_name,
        points=batch,  # use the argument, not a global that happens to be in scope
        wait=False,
    )
    return True

In [None]:
# Prepare Documents (Images of Pages)
image_dir = "output_images"
image_files = os.listdir(image_dir)
images = [{"image": Image.open(os.path.join(image_dir, name)), "filename": name} for name in image_files]

batch_size = 4  # Adjust based on your GPU memory constraints

# Use tqdm to create a progress bar
with tqdm(total=len(images), desc="Indexing Progress") as pbar:
    for i in range(0, len(images), batch_size):
        batch = images[i: i + batch_size]
        batch_images = [item["image"] for item in batch]
        batch_filenames = [item["filename"] for item in batch]

        # Process and encode images
        with torch.no_grad():
            processed_images = processor.process_images(batch_images).to(model.device)
            image_embeddings = model(**processed_images)

        # Prepare points for Qdrant
        points = []
        for j, (embedding, filename) in enumerate(zip(image_embeddings, batch_filenames)):
            # Convert the embedding to a list of vectors
            multivector = embedding.cpu().float().numpy().tolist()
            points.append(
                models.PointStruct(
                    id=str(uuid4()),
                    vector=multivector,  # List of vectors
                    payload={
                        "filepath": os.path.join(image_dir, filename),
                    },  # Metadata
                )
            )

        # Upload points to Qdrant
        try:
            upsert_to_qdrant(points)
        except Exception as e:
            print(f"Error during upsert: {e}")
            continue

        # Update the progress bar
        pbar.update(batch_size)

print("Indexing complete!")


Indexing Progress: 164it [30:28, 11.15s/it]                         

Indexing complete!





In [None]:
# Set the optimizer's indexing threshold to 10 for this collection.
# `optimizers_config` is the client's declared parameter name
# (`optimizer_config` is a legacy spelling).
# NOTE(review): the recorded output of this cell is False, so in local
# (path=...) mode this setting may not be applied; confirm against a server.
client.update_collection(
    collection_name=collection_name,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=10),
)

False

In [None]:
# Embed the text query with the same model that embedded the page images
query_text = "Steven Murphy's head?"
with torch.no_grad():
    query_inputs = processor.process_queries([query_text])
    batch_query = query_inputs.to(model.device)
    query_embedding = model(**batch_query)
query_embedding

tensor([[[ 0.1699, -0.0190,  0.1055,  ..., -0.0247, -0.0903, -0.0442],
         [ 0.0598, -0.1147,  0.0564,  ...,  0.0791, -0.0249,  0.0280],
         [ 0.0078,  0.0332,  0.1182,  ..., -0.0322, -0.0889, -0.0036],
         ...,
         [-0.1846,  0.0884,  0.1035,  ..., -0.0295,  0.0640, -0.1138],
         [-0.1816,  0.1147,  0.1079,  ..., -0.0432,  0.0693, -0.1162],
         [-0.0356,  0.1221,  0.1621,  ..., -0.0045, -0.0444, -0.0625]]],
       device='mps:0', dtype=torch.bfloat16)

In [None]:
# Take the only query in the batch, bring it to the CPU as float32, then turn it
# into plain nested lists for the Qdrant client.
query_token_vectors = query_embedding[0].cpu().float()
multivector_query = query_token_vectors.numpy().tolist()

## Step 8: Searching and Retrieving the Documents

In this step, we perform a search to retrieve the 10 results closest to our query multivector.

We apply rescoring to adjust and refine the initial search results by reevaluating the most relevant candidates with a more precise scoring algorithm. Oversampling is used to improve search accuracy by retrieving a larger pool of candidate results than the final number required. Finally, we measure and display how long the search process takes.

In [None]:
# Search in Qdrant, timing only the query itself with a monotonic clock
start_time = time.perf_counter()
search_result = client.query_points(
    collection_name=collection_name,
    query=multivector_query,
    limit=10,
    timeout=100,
    search_params=models.SearchParams(
        quantization=models.QuantizationSearchParams(
            ignore=False,  # use the quantized vectors for the first pass
            rescore=True,  # rescore the candidates for the final ranking
            oversampling=2.0,  # fetch a larger candidate pool than `limit`
        )
    ),
)
end_time = time.perf_counter()
elapsed_time = end_time - start_time

# Extract and display the search results
if search_result.points:
    print(f"Top {len(search_result.points)} results:")
    for idx, point in enumerate(search_result.points, start=1):
        print(f"Result {idx}:")
        pprint.pprint({
            "id": point.id,
            "score": point.score,  # Similarity score
            "payload": point.payload,  # Metadata associated with the point
        })
else:
    print("No results found.")

print(f"Search completed in {elapsed_time:.4f} seconds")

Top 10 results:
Result 1:
{'id': '110538ed-261c-464e-94dd-7242bf3eb330',
 'payload': {'filepath': 'output_images/The-Psychedelics-as-Medicine-Report_page_4.png'},
 'score': 11.731613663642815}
Result 2:
{'id': '8382fe4c-900c-4dfb-ab81-aa340169e50b',
 'payload': {'filepath': 'output_images/PSFC_Report_Public_Release_Web_12.16.21_page_67.png'},
 'score': 9.268593605883533}
Result 3:
{'id': '7dc6b978-9538-4658-a7c8-f50c9104d686',
 'payload': {'filepath': 'output_images/PSFC_Report_Public_Release_Web_12.16.21_page_68.png'},
 'score': 8.921506498231984}
Result 4:
{'id': '7600bf1a-e348-41cc-b3e0-a38dc5ba5cf7',
 'payload': {'filepath': 'output_images/COMPASS_Pathways_plc_-_Q2_2024_page_1.png'},
 'score': 8.795488833108616}
Result 5:
{'id': 'aa8a3e9a-6875-471e-a11f-4b3922a7092c',
 'payload': {'filepath': 'output_images/The-Psychedelics-as-Medicine-Report_page_11.png'},
 'score': 8.515884776977774}
Result 6:
{'id': '98301eab-10a6-4c11-a39c-da72df583a80',
 'payload': {'filepath': 'output_images/