# Image retrieval with example queries

## Imports

In [1]:
from functools import lru_cache
from pathlib import Path
from typing import TypeVar, Generator, Iterable, Literal

import numpy as np
import pandas as pd
import PIL
import requests
import torch
from IPython.display import HTML
from transformers import AutoImageProcessor, AutoModel

In [2]:
# Generic type variables used to annotate the nested-dictionary helper in this notebook.
T = TypeVar("T")
K1 = TypeVar("K1")
K2 = TypeVar("K2")
# The embedding type names used as dictionary keys and sent to the search API.
EmbeddingType = Literal["SigLIP", "CLIP", "ViT"]

## Utility functionality

### Functions to compute the embeddings

In [3]:
def get_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)


@lru_cache
def load_model(model_name: str, device: torch.device) -> tuple[AutoImageProcessor, AutoModel]:
    """Load the pretrained huggingface image processor and model for `model_name`.

    The model is moved to `device`. Results are memoised by `lru_cache`, so
    repeated calls with the same arguments return the same objects.
    """
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    pretrained_model = AutoModel.from_pretrained(model_name)
    return image_processor, pretrained_model.to(device)


def calculate_embeddings(image: PIL.Image.Image, model_name: str, multimodal: bool) -> tuple[float, ...]:
    """Compute the embedding vector of an image with a pretrained model.

    Args:
        image: The image to embed.
        model_name: Name of the pretrained huggingface model to load.
        multimodal: If true, the embedding comes from the model's
            `get_image_features` method. Otherwise the model is called with
            `output_hidden_states=True` and the vector at sequence position 0
            of the last hidden state is used (presumably the [CLS] token).

    Returns:
        The embedding as a tuple of floats, so it is hashable.
    """
    device = get_device()
    processor, model = load_model(model_name, device)

    pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    # Pure inference: disable autograd so no computation graph is built and kept in memory.
    with torch.no_grad():
        if multimodal:
            embeddings = model.get_image_features(pixel_values)
        else:
            model_out = model(pixel_values, output_hidden_states=True)
            embeddings = model_out.hidden_states[-1][:, 0, :]

    return tuple(embeddings.detach().cpu().numpy().squeeze().tolist())


def compute_embeddings(img: PIL.Image.Image) -> dict[EmbeddingType, tuple[float, ...]]:
    """Embed an image with every model used in this notebook.

    Returns a dictionary from embedding type name to that model's embedding vector.
    """
    # Embedding type -> (huggingface model name, whether to use `get_image_features`).
    model_specs = {
        "SigLIP": ("google/siglip-base-patch16-256-multilingual", True),
        "CLIP": ("openai/clip-vit-base-patch32", True),
        "ViT": ("google/vit-base-patch16-224", False),
    }

    all_embeddings = {}
    for embedding_type, (model_name, is_multimodal) in model_specs.items():
        all_embeddings[embedding_type] = calculate_embeddings(img, model_name, multimodal=is_multimodal)
    return all_embeddings

### Functions to find similar images

In [4]:
@lru_cache
def find_similar_images_from_embedding(
    query_vector: tuple[float, ...],
    embedding_type: EmbeddingType,
    query_size: int,
) -> list[str]:
    """Query the image search API with an embedding and return the image urls.

    Results are memoised by `lru_cache`, which is why the query vector must be
    a (hashable) tuple rather than a list.

    Args:
        query_vector: The embedding vector to search with.
        embedding_type: Which embedding type the vector belongs to.
        query_size: Passed as the `limit` query parameter of the request.

    Returns:
        The `image_url` of each result's payload, in the order returned by the API.

    Raises:
        requests.HTTPError: If the API responds with an error status code.
    """
    response = requests.post(
        f"https://api.nb.no/dhlab/image_search/vector?&limit={query_size}",
        json={
            "vector": list(query_vector),
            "embedding_type": embedding_type,
        },
    )
    # Fail with a clear HTTP error instead of a confusing KeyError/TypeError
    # when the error body is parsed as if it were a result list.
    response.raise_for_status()

    return [result["payload"]["image_url"] for result in response.json()]

def find_all_similar_images_from_embeddings(
    query_vectors: dict[EmbeddingType, tuple[float, ...]], query_size: int = 10
) -> dict[EmbeddingType, list[str]]:
    """Run one similarity query per embedding type.

    Returns a dictionary from embedding type to the image urls found for its vector.
    """
    urls_by_type = {}
    for emb_type, vector in query_vectors.items():
        urls_by_type[emb_type] = find_similar_images_from_embedding(vector, emb_type, query_size)
    return urls_by_type

### Functions to load the images

In [5]:
def load_images(image_parent: Path) -> dict[Path, PIL.Image.Image]:
    """Open every ".jpg" file directly inside the given directory.

    Returns a dictionary from file path to opened image, with the files in sorted order.
    The images are not explicitly closed here.
    """
    images = {}
    for jpg_file in sorted(image_parent.glob("*.jpg")):
        images[jpg_file] = PIL.Image.open(jpg_file)
    return images

def compute_all_embeddings(image_parent: Path) -> dict[Path, dict[EmbeddingType, tuple[float, ...]]]:
    """Compute embeddings for all ".jpg" images in the given directory.

    Returns a dictionary from image file path to that image's embeddings,
    which are themselves keyed by embedding type.
    """
    return {img_file: compute_embeddings(img) for img_file, img in load_images(image_parent).items()}

def find_all_similar_images(
    image_parent: Path, query_size: int = 10
) -> dict[Path, dict[EmbeddingType, list[str]]]:
    """Find similar images for all ".jpg" images in the given directory.

    Args:
        image_parent: Directory containing the query images.
        query_size: Number of results requested per image and embedding type.

    Returns:
        A dictionary from query image path to the result urls, keyed by embedding type.
    """
    return {
        img_file: find_all_similar_images_from_embeddings(emb, query_size)
        for img_file, emb in compute_all_embeddings(image_parent).items()
    }

def download_from_url(url: str, download_path: Path) -> Path:
    """Fetch the content at `url` and write it to `download_path`.

    An HTTP error status raises before anything is written. Returns `download_path`.
    """
    response = requests.get(url)
    response.raise_for_status()

    payload = response.content
    download_path.write_bytes(payload)
    return download_path


def download_similar_images(download_path: Path, similar_image_urls: list[str]) -> list[Path]:
    """Download all files from a list of URLs (assumed to point to JPEG images).

    The files will be named in numerical order based on their location in the list
    ("00.jpg", "01.jpg", ...). The target directory is created if it does not exist.

    Args:
        download_path: Directory to store the downloaded files in.
        similar_image_urls: URLs to download.

    Returns:
        The paths of the downloaded files, in the same order as the URLs.
    """
    download_path.mkdir(parents=True, exist_ok=True)

    return [
        download_from_url(url, download_path / f"{i:02d}.jpg") for i, url in enumerate(similar_image_urls)
    ]

    
def download_all_similar_images(
    download_path: Path, all_similar_image_urls: dict[EmbeddingType, list[str]]
) -> dict[EmbeddingType, list[Path]]:
    """Download the result images of every embedding type.

    The images of each embedding type are stored in a subdirectory of
    `download_path` named after that embedding type.

    Args:
        download_path: Parent directory for the per-embedding-type subdirectories.
        all_similar_image_urls: Result urls keyed by embedding type.

    Returns:
        A dictionary from embedding type to the paths of the downloaded files.
    """
    return {
        emb_type: download_similar_images(download_path / emb_type, similar_image_urls)
        for emb_type, similar_image_urls in all_similar_image_urls.items()
    }
    

def find_and_download_all_similar_images(
    image_parent: Path, destination: Path, query_size: int = 10
) -> dict[Path, dict[EmbeddingType, list[Path]]]:
    """Find and download the similar images for all ".jpg" images in a directory.

    For each query image, the embeddings are computed, the urls of the similar
    images are looked up, and the images are downloaded to a subdirectory of
    `destination` named after the query image's file stem.

    Args:
        image_parent: Directory containing the query images.
        destination: Parent directory for the downloaded results.
        query_size: Number of results requested per image and embedding type.

    Returns:
        A dictionary from query image path to the downloaded file paths,
        keyed by embedding type.
    """
    return {
        img_file: download_all_similar_images(destination / img_file.stem, similar_image_urls)
        for img_file, similar_image_urls in find_all_similar_images(image_parent, query_size).items()
    }


def squeeze_nested_dict(nested_dict: dict[K1, dict[K2, T]]) -> dict[tuple[K1, K2], T]:
    """Flatten a two-level dictionary so `nested[k1][k2]` becomes `flat[k1, k2]`.
    """
    flattened = {}
    for outer_key, inner_dict in nested_dict.items():
        for inner_key, value in inner_dict.items():
            flattened[(outer_key, inner_key)] = value
    return flattened

## Download images and form dataframe

In [6]:
# Directory that is searched for the ".jpg" query images.
data_path = Path("../data/analysis/query_examples")
# Parent directory for the downloaded result images.
download_path = data_path / "query_results"
# Number of results requested per query image and embedding type.
num_similar_images = 5


# The flattened dictionary is keyed by (image path, embedding type) tuples, which
# become the columns of the DataFrame; transposing turns each pair into a row.
df = pd.DataFrame(
    squeeze_nested_dict(find_and_download_all_similar_images(
        data_path,
        download_path,
        num_similar_images,
    ))
).T
df.index.names = ["Image path", "Embedding type"]
# Relabel the 0-based result positions as 1-based strings.
df.columns = [str(i + 1) for i in df.columns]
df.columns.name = "Image position"

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Format dataframe as HTML

In [7]:
def format_image(image_path: str) -> str:
    """Return an HTML <img> tag with a fixed width that points at the given path"""
    img_template = '<img src="{}" width=120px>'
    return img_template.format(image_path)

def get_index_formatter():
    """Build a formatter that renders each distinct image once.

    pandas' to_html handles multi-indices poorly, so the index is reset to
    ordinary columns and this formatter is applied to the image column: the
    first occurrence of a value is rendered as an image, and every later
    occurrence becomes an empty string.
    """
    already_shown = set()

    def format_index(idx: str) -> str:
        if idx not in already_shown:
            already_shown.add(idx)
            return format_image(idx)
        return ""

    return format_index


In [8]:
# Render the results as an HTML table: every result cell becomes an <img> tag and,
# because the rows are sorted by index, each query image is shown only on the
# first row of its group.
HTML(
    df.sort_index().reset_index().to_html(
        formatters={c: format_image for c in df.columns} | {"Image path": get_index_formatter()},
        # Keep the generated <img> tags as markup instead of escaping them.
        escape=False,
        index=False,
        index_names=False
      )
)

Image path,Embedding type,1,2,3,4,5
,CLIP,,,,,
,SigLIP,,,,,
,ViT,,,,,
,CLIP,,,,,
,SigLIP,,,,,
,ViT,,,,,
,CLIP,,,,,
,SigLIP,,,,,
,ViT,,,,,
,CLIP,,,,,


## Format dataframe as LaTeX table

Images inside tables cause some layout issues, so the generated LaTeX table needs manual tweaking to avoid overfull vboxes, but it forms a good starting point.

In [9]:
def format_latex_image(path: Path) -> str:
    r"""Format an image path as a LaTeX `\includegraphics` command.

    The path is made relative to the parent directory of the module-level
    `data_path` and placed under "images/".

    Args:
        path: Path to an image file located under `data_path.parent`.

    Returns:
        The `\includegraphics` command for the image.

    Raises:
        ValueError: If `path` is not located under `data_path.parent`.
    """
    # LaTeX needs forward slashes; str() of a path yields backslashes on Windows.
    latex_path = (Path("images") / path.relative_to(data_path.parent)).as_posix()
    return r"\includegraphics[width=0.12\textwidth]{PATH}".replace("PATH", latex_path)

# Names of the two index levels in the LaTeX table.
latex_index_names = ["Query image", "Model"]

# `DataFrame.map` is the replacement for the deprecated `DataFrame.applymap`.
latex_df = df.map(format_latex_image)
# Assign a renamed copy of the index: setting `latex_df.index.names` in place can
# also rename the index of `df` when the two frames share the same Index object.
latex_df.index = latex_df.index.set_names(latex_index_names)
# Assigning a plain list of labels also leaves the columns axis without a name.
latex_df.columns = [f"Pos. {i}" for i in latex_df.columns]

# Temporarily turn the index into columns so the query image paths can be
# formatted as well, then restore the two-level index.
latex_df = latex_df.reset_index()
latex_df["Query image"] = latex_df["Query image"].map(format_latex_image)
latex_df = latex_df.set_index(latex_index_names)

  latex_df = df.applymap(format_latex_image)


In [10]:
# Print the table as a LaTeX longtable so it can be copied into a document.
print(latex_df.to_latex(longtable=True))

\begin{longtable}{lllllll}
\toprule
 &  & Pos. 1 & Pos. 2 & Pos. 3 & Pos. 4 & Pos. 5 \\
Query image & Model &  &  &  &  &  \\
\midrule
\endfirsthead
\toprule
 &  & Pos. 1 & Pos. 2 & Pos. 3 & Pos. 4 & Pos. 5 \\
Query image & Model &  &  &  &  &  \\
\midrule
\endhead
\midrule
\multicolumn{7}{r}{Continued on next page} \\
\midrule
\endfoot
\bottomrule
\endlastfoot
\multirow[t]{3}{*}{\includegraphics[width=0.12\textwidth]{images/query_examples/1-ivar_aasen.jpg}} & SigLIP & \includegraphics[width=0.12\textwidth]{images/query_examples/query_results/1-ivar_aasen/SigLIP/00.jpg} & \includegraphics[width=0.12\textwidth]{images/query_examples/query_results/1-ivar_aasen/SigLIP/01.jpg} & \includegraphics[width=0.12\textwidth]{images/query_examples/query_results/1-ivar_aasen/SigLIP/02.jpg} & \includegraphics[width=0.12\textwidth]{images/query_examples/query_results/1-ivar_aasen/SigLIP/03.jpg} & \includegraphics[width=0.12\textwidth]{images/query_examples/query_results/1-ivar_aasen/SigLIP/04.jpg} \\
