# L2 - Multi-vector Image Retrieval: ColPali

<p style="background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>

<div style="background-color:#fff6ff; padding:13px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px">
<p> 💻 &nbsp; <b>Access <code>requirements.txt</code> and <code>helper.py</code> files:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Open"</em>.</p>

<p> ⬇ &nbsp; <b>Download Notebooks:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Download as"</em> and select <em>"Notebook (.ipynb)"</em>.</p>

<p> 📒 &nbsp; For more help, please see the <em>"Appendix – Tips, Help, and Download"</em> Lesson.</p>

</div>

The following cell is not in the video and just ensures output later in this notebook will render properly.

In [None]:
from plotly import io as pio

# Render plotly figures with the classic-notebook renderer
pio.renderers.default = "notebook"

: 

#### Loading the ColPali Model

In [None]:
# Expect this cell may take several minutes to finish

import torch

is_cuda_available = torch.cuda.is_available()
# Target device for the model. The model must be moved there explicitly:
# without a device argument, `from_pretrained` does not put the weights on
# the GPU, so the large ColPali model would otherwise run on the CPU even
# when CUDA is available. Later cells send inputs to `model.device`.
device = "cuda" if is_cuda_available else "cpu"

if is_cuda_available:
    # GPU available: use the full ColPali checkpoint
    from colpali_engine.models import ColPaliProcessor, ColPali

    model_name = "vidore/colpali-v1.3"
    processor = ColPaliProcessor.from_pretrained(model_name)
    model = ColPali.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
else:
    # CPU only: fall back to the lighter ColSmol (Idefics3-based) checkpoint
    from colpali_engine.models import ColIdefics3Processor, ColIdefics3

    model_name = "vidore/colSmol-256M"
    processor = ColIdefics3Processor.from_pretrained(model_name)
    model = ColIdefics3.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )

# Place the weights on the chosen device and make inference mode explicit
model = model.to(device).eval()

model_name

#### Loading a Sample Image and Splitting It Into Patches

In [None]:
from IPython.display import Image as JupyterImage

# Show the sample document page inline, scaled to 480 px wide
page_image_path = "ro_shared_data/attention-is-all-you-need/page-2.png"
JupyterImage(filename=page_image_path, width=480)

In [None]:
from PIL import Image

# Open the same page as a PIL image; it is the input for the model and plots
page_image_path = "ro_shared_data/attention-is-all-you-need/page-2.png"
image = Image.open(page_image_path)
image.size

In [None]:
from helper import visualize_image_patches

# ColPali exposes a `patch_size` attribute; ColSmol does not, hence the
# getattr fallback to 0.
patch_size = getattr(model, "patch_size", 0)

# Draw the patch grid the processor uses on top of the page
fig = visualize_image_patches(
    image, processor, patch_size=patch_size, line_color="blue"
)
fig.show()

#### Exploring ColPali Tokenization

In [None]:
# Preprocess the page into model inputs, then move them to the model's device
processed_batch = processor.process_images([image])
batch_images = processed_batch.to(model.device)
batch_images.data.keys()

In [None]:
image_token_count = len(batch_images.input_ids[0])
print("Number of tokens:", image_token_count)

In [None]:
# Print only the head and the tail of the decoded input sequence to keep the
# output short. The tail must be sliced with [-50:]; [:-50] would print
# everything except the last 50 characters.
decoded_tokens = processor.decode(batch_images.input_ids[0])
print(decoded_tokens[:50], "...", decoded_tokens[-50:])

#### Exploring Dimensionality of ColPali Document Vectors

In [None]:
# Inference only: run the forward pass without building an autograd graph
with torch.no_grad():
    image_embeddings = model(**batch_images)

# NOTE(review): presumably (batch, n_tokens, dim); confirm against model docs
image_embeddings.shape

In [None]:
# Keep only the embedding vectors that belong to image tokens
image_mask = processor.get_image_mask(batch_images)
masked_image_embeddings = image_embeddings[image_mask]

masked_image_embeddings.shape

#### Exploring Dimensionality of ColPali Query Vectors

In [None]:
# Natural-language question used to probe the page
query: str = "How does a single transformer layer look like?"

In [None]:
# Preprocess the query, move it next to the model, and embed it
# without gradient tracking
encoded_query_batch = processor.process_queries([query])
batch_queries = encoded_query_batch.to(model.device)
with torch.no_grad():
    query_embeddings = model(**batch_queries)

query_embeddings.shape

In [None]:
# Decode the processed query, then strip padding tokens followed by the
# query augmentation tokens, and trim surrounding whitespace
query_content = processor.decode(batch_queries.input_ids[0])
for filler_token in (
    processor.tokenizer.pad_token,
    processor.query_augmentation_token,
):
    query_content = query_content.replace(filler_token, "")
query_content = query_content.strip()

# Retokenize the cleaned text so each token can be referenced by index
query_tokens = processor.tokenizer.tokenize(query_content)

# Use this cell output to choose a token using its index
print(dict(enumerate(query_tokens)))

In [None]:
# Only ColPali has a `patch_size` attribute; ColSmol does not need one,
# so default to 0 instead of raising AttributeError.
patch_size = getattr(model, "patch_size", 0)

# How many patches an image of this size is split into
n_patches = processor.get_n_patches(
    image_size=image.size,
    patch_size=patch_size,
)

# Mask selecting just the image tokens of the processed batch
image_mask = processor.get_image_mask(batch_images)

#### Generating a Similarity Map for a Document Page and a Query Token

In [None]:
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
)

# Compute, per image, one similarity map for each query token
similarity_maps = get_similarity_maps_from_embeddings(
    image_embeddings=image_embeddings,
    query_embeddings=query_embeddings,
    n_patches=n_patches,
    image_mask=image_mask,
)

# Get the similarity map for the "layer" token
# The first index (0) is an image index, while the second (target_token_idx)
# is the index of the selected token
target_word = "layer"
# Substring match, because tokenizers may prefix tokens with
# word-boundary markers
target_token_idx = next(
    (idx for idx, val in enumerate(query_tokens) if target_word in val),
    None,
)
if target_token_idx is None:
    # A bare next() would surface an opaque StopIteration here
    raise ValueError(
        f"No query token contains {target_word!r}; tokens: {query_tokens}"
    )
layer_similarity_mask = similarity_maps[0][target_token_idx]

In [None]:
from colpali_engine.interpretability import plot_similarity_map

# Overlay the selected token's similarity map on the page. The helper returns
# (figure, axes); unpacking it avoids dumping the tuple's text repr under
# the figure.
fig, ax = plot_similarity_map(
    image=image,
    similarity_map=layer_similarity_mask,
    figsize=(8, 8),
    show_colorbar=False,
)
# Title the figure so it is self-explanatory; ';' hides the Text repr
ax.set_title(
    f"Similarity map for query token {query_tokens[target_token_idx]!r}"
);

#### Creating the Qdrant Collection and Adding Vectors

In [None]:
from qdrant_client import QdrantClient, models

# Names reused by the upsert and search cells
collection_name = "colpali-experiments"
vector_name = "colpali"

# Connect to the local Qdrant server
client = QdrantClient("http://localhost:6333")

# Start from an empty collection: drop any leftover from a previous run
if client.collection_exists(collection_name):
    client.delete_collection(collection_name)

# Each point stores a multivector (one `model.dim`-sized vector per token).
# Points are compared with MaxSim over dot products, and HNSW is disabled
# (m=0) because it won't be used anyway.
colpali_vector_params = models.VectorParams(
    size=model.dim,
    distance=models.Distance.DOT,
    multivector_config=models.MultiVectorConfig(
        comparator=models.MultiVectorComparator.MAX_SIM,
    ),
    hnsw_config=models.HnswConfigDiff(m=0),
)

client.create_collection(
    collection_name,
    vectors_config={vector_name: colpali_vector_params},
)

In [None]:
from helper import load_or_compute_attention_embeddings
from tqdm import tqdm
import uuid

# Load or compute embeddings for all pages
# Set load_precomputed=True to use cached embeddings (fast)
# Set load_precomputed=False to regenerate from model (slow, 2+ min)
embeddings_df = load_or_compute_attention_embeddings(
    load_precomputed=True,
    model_name=model_name,
)

# Upsert embeddings to Qdrant, one page per request
for position, (_, row) in enumerate(
    tqdm(embeddings_df.iterrows(), total=len(embeddings_df))
):
    # ID has to be either integer or UUID-like string. Deriving it
    # deterministically from the row position makes this cell idempotent:
    # re-running it overwrites the same points instead of inserting
    # duplicate pages (random uuid4 IDs would duplicate them).
    point_id = uuid.uuid5(
        uuid.NAMESPACE_URL, f"{collection_name}/{position}"
    ).hex
    client.upsert(
        collection_name,
        points=[
            models.PointStruct(
                id=point_id,
                vector={
                    vector_name: row["image_embedding"],
                },
                payload={
                    "file_path": row["file_path"],
                },
            )
        ],
    )

#### Performing Search with ColPali Vectors

In [None]:
def search(query: str, limit: int = 3) -> list[models.ScoredPoint]:
    """Retrieve the stored pages that best match ``query``.

    The query is embedded with the loaded model, and the resulting per-token
    vectors are sent to the Qdrant collection, which scores pages with the
    MaxSim comparator configured at collection creation.

    Args:
        query: Natural-language search text.
        limit: Maximum number of points to return.

    Returns:
        The scored points, with their payloads included.
    """
    encoded = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        raw_embeddings = model(**encoded)
        # Convert to float32 first: bfloat16 tensors cannot become numpy arrays
        as_float = raw_embeddings.to(dtype=torch.float32)

    query_matrix = as_float[0].cpu().numpy()
    response = client.query_points(
        collection_name,
        query=query_matrix,
        using=vector_name,
        limit=limit,
        with_payload=True,
    )
    return response.points

In [None]:
from helper import display_search_results

# Example query: look for the page describing the overall architecture
results = search(query="model architecture")
display_search_results(results, layout="horizontal")

In [None]:
# Example query: a specific attention mechanism
results = search(query="scaled dot-product attention")
display_search_results(results, layout="horizontal")

In [None]:
# Example query: pages reporting experimental results
results = search(query="experiment results")
display_search_results(results, layout="horizontal")