
## Interpretability: superimposing the late-interaction heatmap on the original image

This notebook explores the interpretability of the ColPali model, to understand how and why the model makes specific predictions.

We can visualize which image patches match each query term by superimposing the late-interaction heatmap on the original image.

To demonstrate interpretability, we perform the following steps:

* Create the image and query embeddings.
* Get the similarity maps using the API provided by colpali-engine.
* Plot the maps.

Acknowledgement

This notebook uses code and ideas from the following GitHub repository:
https://github.com/illuin-tech/colpali

Installations & imports

In [None]:
!pip install pdf2image
!pip install "colpali-engine[interpretability]"
!sudo apt-get install -y poppler-utils

In [None]:
from PIL import Image
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from colpali_engine.models import ColPali, ColPaliProcessor
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
)

Load model

In [None]:
model_name = "vidore/colpali-v1.3"
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # or "mps" if on Apple Silicon
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

Process the document and identify the page to visualize.

In [None]:
images = convert_from_path("/content/google-alphabet-2024.pdf")
print("Number of pages:", len(images))

In [None]:
image = images[10]  # 11th page of the PDF (the list is 0-indexed)
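Converting every page of a long report only to use one of them is slow. `convert_from_path` also accepts 1-based `first_page` and `last_page` arguments, while `images` above is a 0-based list. A minimal sketch of rendering a single page; `page_bounds` and `load_page` are hypothetical helper names, and `dpi=200` is assumed to match pdf2image's default:

```python
def page_bounds(page_index: int) -> tuple[int, int]:
    """Map a 0-based list index to pdf2image's 1-based (first_page, last_page)."""
    if page_index < 0:
        raise ValueError("page_index must be non-negative")
    page_number = page_index + 1
    return page_number, page_number


def load_page(pdf_path: str, page_index: int, dpi: int = 200):
    """Hypothetical helper: render only one page of the PDF as a PIL image."""
    from pdf2image import convert_from_path  # imported lazily; requires poppler-utils

    first, last = page_bounds(page_index)
    pages = convert_from_path(pdf_path, dpi=dpi, first_page=first, last_page=last)
    return pages[0]


# Usage (same page as images[10] above):
# image = load_page("/content/google-alphabet-2024.pdf", 10)
```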

In [None]:
query = ["Revenue Google Cloud 2023 and 2024 ?"]
query[0]

Process the image and the query with ColPali

In [None]:
batch_images = processor.process_images([image]).to(model.device)
# process_queries expects a list of strings; `query` is already a one-element list
batch_queries = processor.process_queries(query).to(model.device)

with torch.no_grad():
    image_embeddings = model(**batch_images)
    query_embeddings = model(**batch_queries)

In [None]:
# Get the number of image patches
n_patches = processor.get_n_patches(image_size=image.size, patch_size=model.patch_size)
# Get the tensor mask to filter out the embeddings that are not related to the image
image_mask = processor.get_image_mask(batch_images)
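The image-side embeddings include positions that are not page patches, which is why `image_mask` is needed; `n_patches` gives the size of the patch grid. Conceptually, masking keeps only the patch vectors, which can then be laid out on a 2-D grid. A NumPy sketch with toy shapes; `patches_to_grid` is a hypothetical name, and it assumes patches are stored row by row, which may not match the library's internal ordering:

```python
import numpy as np


def patches_to_grid(embeddings, mask, n_rows, n_cols):
    """Keep the masked (image-patch) positions and reshape them to (n_rows, n_cols, dim).

    embeddings: (seq_len, dim) array for one image
    mask:       (seq_len,) boolean array, True for image-patch positions
    """
    patch_vectors = embeddings[mask]
    if patch_vectors.shape[0] != n_rows * n_cols:
        raise ValueError(f"{patch_vectors.shape[0]} patches do not fill a {n_rows}x{n_cols} grid")
    # Row-major layout assumed: patch k goes to (k // n_cols, k % n_cols)
    return patch_vectors.reshape(n_rows, n_cols, embeddings.shape[-1])


# Toy sequence: 6 patch positions (a 2x3 grid) followed by 2 non-image positions, dim=4
rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 4))
mask = np.array([True] * 6 + [False] * 2)
grid = patches_to_grid(emb, mask, n_rows=2, n_cols=3)
print(grid.shape)  # (2, 3, 4)
```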

In [None]:
query_embeddings.shape, image_embeddings.shape
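Both outputs are multi-vector embeddings of shape `(batch, n_tokens, dim)`: one vector per query token and one per image position. ColPali scores a page with late interaction (MaxSim): each query token takes its best dot product over all image vectors, and these per-token maxima are summed. A minimal NumPy sketch of that score for one query and one page, on toy data rather than the model's outputs; in practice colpali-engine's processor provides its own scoring (e.g. `processor.score_multi_vector` in recent versions), so this is only for intuition:

```python
import numpy as np


def maxsim_score(query_emb, doc_emb):
    """Late-interaction (MaxSim) score.

    query_emb: (n_query_tokens, dim)
    doc_emb:   (n_doc_vectors, dim)
    """
    sim = query_emb @ doc_emb.T          # (n_query_tokens, n_doc_vectors) dot products
    return float(sim.max(axis=1).sum())  # best match per query token, summed


q = np.array([[1.0, 0.0], [0.0, 1.0]])
d = np.array([[0.5, 0.1], [0.2, 0.9], [0.3, 0.3]])
print(maxsim_score(q, d))  # ≈ 1.4 (0.5 is the best match for token 1, 0.9 for token 2)
```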

In [None]:
# Generate the similarity maps
batched_similarity_maps = get_similarity_maps_from_embeddings(
    image_embeddings=image_embeddings,
    query_embeddings=query_embeddings,
    n_patches=n_patches,
    image_mask=image_mask,
)
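For each query token, a similarity map is a grid of scores over the image patches: conceptually, the same token-by-patch dot products that MaxSim uses (restricted to the image-patch positions) before the max is taken, laid out spatially. A NumPy sketch of that idea with `np.einsum`, assuming the patch embeddings are already arranged as a `(n_rows, n_cols, dim)` grid; it illustrates the computation and is not `get_similarity_maps_from_embeddings` itself:

```python
import numpy as np


def similarity_maps_sketch(query_emb, patch_grid):
    """Per-token similarity maps.

    query_emb:  (n_query_tokens, dim)
    patch_grid: (n_rows, n_cols, dim)
    returns:    (n_query_tokens, n_rows, n_cols), entry [n, i, j] = query_emb[n] . patch_grid[i, j]
    """
    return np.einsum("nd,ijd->nij", query_emb, patch_grid)


rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))
grid = rng.normal(size=(2, 5, 4))
maps = similarity_maps_sketch(q, grid)
print(maps.shape)  # (3, 2, 5)
```

Taking the maximum of each map gives that token's MaxSim contribution over the patches.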

In [None]:
# Get the similarity map for our (only) input image
similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
# Tokenize the query
query_tokens = processor.tokenizer.tokenize(query[0])


In [None]:
similarity_maps.shape, len(query_tokens)
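This check matters because plotting pairs the i-th token with the i-th map; if the processor adds tokens to the query (for example a prefix or padding), the two counts can differ. A small hypothetical helper that refuses to pair mismatched lengths and reports, for each token, the grid cell with the highest similarity:

```python
import numpy as np


def top_patch_per_token(tokens, maps):
    """Return (token, (row, col), score) for the best-matching patch of each token.

    tokens: list of token strings
    maps:   (n_tokens, n_rows, n_cols) array of similarity scores
    """
    maps = np.asarray(maps)
    if len(tokens) != maps.shape[0]:
        raise ValueError(f"{len(tokens)} tokens but {maps.shape[0]} similarity maps")
    results = []
    for token, sim_map in zip(tokens, maps):
        row, col = np.unravel_index(np.argmax(sim_map), sim_map.shape)
        results.append((token, (int(row), int(col)), float(sim_map[row, col])))
    return results


toy_maps = np.array([[[0.1, 0.7], [0.2, 0.0]],
                     [[0.3, 0.1], [0.9, 0.4]]])
print(top_patch_per_token(["Revenue", "Cloud"], toy_maps))
# [('Revenue', (0, 1), 0.7), ('Cloud', (1, 0), 0.9)]
```

With the notebook's bfloat16 tensor, convert it first, e.g. `top_patch_per_token(query_tokens, similarity_maps.float().cpu().numpy())`.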

Plot similarity maps

In [None]:
# Plot the similarity map of each query token over the page
plots = plot_all_similarity_maps(
    image=image,
    query_tokens=query_tokens,
    similarity_maps=similarity_maps,
)
# Uncomment to save each plot; `plots` holds one (fig, ax) pair per token
# for idx, (fig, ax) in enumerate(plots):
#     fig.savefig(f"similarity_map_{idx}.png")
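`plot_all_similarity_maps` handles the overlay for you. To show what superimposing a heatmap involves, here is a minimal NumPy sketch under simple assumptions: min-max normalization, nearest-neighbor upsampling, a single red channel, and image sides that are exact multiples of the grid size. `overlay_heatmap` is a hypothetical name, and the library's own plotting may interpolate, normalize, and color differently:

```python
import numpy as np


def overlay_heatmap(image_rgb, sim_map, alpha=0.5):
    """Alpha-blend a patch-level similarity map onto an RGB image.

    image_rgb: (H, W, 3) uint8 array
    sim_map:   (n_rows, n_cols) scores; H and W must be multiples of n_rows and n_cols
    """
    h, w, _ = image_rgb.shape
    n_rows, n_cols = sim_map.shape
    if h % n_rows or w % n_cols:
        raise ValueError("image size must be a multiple of the grid size")
    # Min-max normalize the scores to [0, 1]
    lo, hi = sim_map.min(), sim_map.max()
    norm = (sim_map - lo) / (hi - lo) if hi > lo else np.zeros_like(sim_map, dtype=float)
    # Nearest-neighbor upsampling: each grid cell covers a block of pixels
    heat = np.kron(norm, np.ones((h // n_rows, w // n_cols)))
    # Red heatmap layer blended over the image
    heat_rgb = np.zeros((h, w, 3))
    heat_rgb[..., 0] = heat * 255
    blended = (1 - alpha) * image_rgb.astype(float) + alpha * heat_rgb
    return blended.clip(0, 255).astype(np.uint8)


img = np.full((4, 4, 3), 100, dtype=np.uint8)
toy_map = np.array([[0.0, 1.0], [0.5, 0.0]])
out = overlay_heatmap(img, toy_map)
```

To try it on the notebook's page, resize the page so each side is a multiple of the grid size and pass one map converted with `.float().cpu().numpy()`; the maps are documented as `(n_patches_x, n_patches_y)`, so a transpose may be needed to match this sketch's `(rows, cols)` layout.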