# Quantitatively measuring exact image retrieval accuracy

## Imports

In [1]:
import base64
import hashlib
import json
from collections.abc import Iterable
from dataclasses import dataclass, asdict
from functools import lru_cache, partial
from io import BytesIO
from pathlib import Path
from uuid import UUID

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import PIL
import requests
import torch

from IPython.display import HTML
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel

## Utility functionality

### Class to transform images

In [2]:
@dataclass
class ImageAugmenter:
    """Class that takes care of augmenting an image.
    """
    rotation: float
    crop_left: float
    crop_right: float
    crop_top: float
    crop_bottom: float
    width_scale: float
    height_scale: float

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Augment the image, first rotating, then cropping before finally scaling.
        """
        image = image.rotate(self.rotation, resample=PIL.Image.Resampling.BICUBIC, fillcolor="white", expand=True)
        width, height = image.size

        left = round(self.crop_left*width)
        right = round(width - self.crop_right*width)
        upper = round(self.crop_top*height)
        lower = round(height - self.crop_top*height)

        new_width = round((right - left) * self.width_scale)
        new_height = round((lower - upper) * self.height_scale)

        return (
            image
            .crop((left, upper, right, lower))
            .resize((new_width, new_height), resample=PIL.Image.Resampling.LANCZOS)
        )

    @classmethod
    def init_random(
        cls,
        min_rotation: float,
        max_rotation: float,
        max_crop_left: float,
        max_crop_right: float,
        max_crop_top: float,
        max_crop_bottom: float,
        min_scale_width: float,
        max_scale_width: float,
        min_scale_height: float,
        max_scale_height: float,
        rng: np.random.Generator,
    ):
        """Create a random image augmenter.
        """
        return cls(
            rotation=rng.uniform(min_rotation, max_rotation),
            crop_left=rng.uniform(0, max_crop_left),
            crop_right=rng.uniform(0, max_crop_right),
            crop_top=rng.uniform(0, max_crop_top),
            crop_bottom=rng.uniform(0, max_crop_bottom),
            width_scale=rng.uniform(min_scale_width, max_scale_width),
            height_scale=rng.uniform(min_scale_height, max_scale_height),
        )

## Functions that interact with the API

In [3]:
def get_url_uuid(url: str) -> str:
    """Convert the URL to the ID in the Qdrant database.
    """
    sha_digest = hashlib.sha256(url.encode("utf-8")).digest()
    return UUID(bytes=sha_digest[:16]).urn.split(":")[-1]
    

def get_retrieval_test_urls() -> Iterable[str]:
    """Load the labelled dataset, extract the interesting classes and get image IDs.
    """
    label_df = pd.read_json("../data/labelled_data.json")
    interesting_labels = ["Illustration or photograph", "Map", "Mathematical chart"]
    url_mask = label_df[interesting_labels].any(axis=1)
    urls = url_mask[url_mask].index
    return urls

@lru_cache
def get_image_from_url(url: str, image_directory: Path) -> PIL.Image.Image:
    """Download the image given its ID.
    """
    id_ = get_url_uuid(url)
    filename = (image_directory / f"{id_}.jpg")
    if filename.exists():
        return PIL.Image.open(filename)

    image_directory.mkdir(exist_ok=True, parents=True)
    url = f"https://api.nb.no/dhlab/image_search/id?image_id={id_}&limit=1&embedding_type=CLIP"
    response = requests.get(url).json()

    image_url = response[0]["payload"]["image_url"]
    image_resp = requests.get(image_url)
    image_resp.raise_for_status()
    jpg_contents = image_resp.content
    filename.write_bytes(jpg_contents)
    return PIL.Image.open(filename)

@lru_cache
def get_retrieval_score(
    query_vector: tuple[float],
    target_id: str,
    embedding_type: str,
    query_size: int = 10,
) -> int:
    """Query the API and return the position of the target image in query (-1 if not there)
    """
    result = requests.post(
        f"https://api.nb.no/dhlab/image_search/vector?&limit={query_size}",
        json={
        "vector": list(query_vector),
        "embedding_type": embedding_type
        }
    ).json()

    for idx, element in enumerate(result):
        if element["id"] == target_id:
            return idx
    else:
        return -1

def get_base64_image(image: PIL.Image.Image) -> str:
    """Base64-encode a PIL image to embed it in a dataframe.
    """
    with BytesIO() as buffer:
        image.save(buffer, 'jpeg')
        return base64.b64encode(buffer.getvalue()).decode()

def get_image_thumbnail(image: PIL.Image.Image, size: tuple[int, int] = (150, 150)) -> PIL.Image.Image:
    """Get a thumbnail from a PIL image.
    """
    image_thumbnail = image.copy()
    image_thumbnail.thumbnail(size)
    return image_thumbnail

def format_image(image_base64: str) -> str:
    """Add an img-tag for base64 encode JPEG-data.
    """
    return f'<img src="data:image/jpeg;base64,{image_base64}">'


### Functions to compute the embeddings

In [4]:
def get_device() -> torch.device:
    """Get the torch device, use CUDA if available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


@lru_cache
def load_model(model_name: str, device: torch.device) -> tuple[AutoImageProcessor, AutoModel]:
    """Get a pretrained huggingface processor and model. Note that these are cached!
    """
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return processor, model


def calculate_embeddings(image: PIL.Image, model_name: str, multimodal: bool) -> tuple[float, ...]:
    """Compute embedding vectors for an image.
    """
    device = get_device()
    processor, model = load_model(model_name, device)

    processed = processor(images=image, return_tensors="pt")["pixel_values"]
    if multimodal:
        embeddings = model.get_image_features(processed.to(device))
    else:
        model_out = model(processed.to(device), output_hidden_states=True)
        embeddings = model_out.hidden_states[-1][:, 0, :]

    return tuple(embeddings.detach().cpu().numpy().squeeze().tolist())


def compute_embeddings(img: PIL.Image.Image) -> dict[str, tuple[float, ...]]:
    """Compute all relevant embedding vectors for an image.
    """
    siglip_model = "google/siglip-base-patch16-256-multilingual"
    clip_model = "openai/clip-vit-base-patch32"
    vit_model = "google/vit-base-patch16-224"

    return {
        "SigLIP": calculate_embeddings(img, siglip_model, multimodal=True),
        "CLIP": calculate_embeddings(img, clip_model, multimodal=True),
        "ViT": calculate_embeddings(img, vit_model, multimodal=False),
    }

## Run queries based on transformed images and measure accuracy

### Image transformation parameters
Set parameters for image transformations

In [5]:
min_rotation = -10
max_rotation = 10

max_crop_left = 0.15
max_crop_right = 0.15
max_crop_top = 0.15
max_crop_bottom = 0.15

min_scale_width = 0.8
max_scale_width = 1.2

min_scale_height = 0.8
max_scale_height = 1.2

embedding_models = ["ViT", "SigLIP", "CLIP"]
query_size = 100

rng = np.random.default_rng(42)

image_directory = Path("images")

### Run experiment

In [6]:
urls = get_retrieval_test_urls()

results = []

for url in tqdm(urls):
    try:
        image = get_image_from_url(url, image_directory)
    except requests.HTTPError as e:
        print(e)
        continue
    augmenter = ImageAugmenter.init_random(
        min_rotation=min_rotation,
        max_rotation=max_rotation,
        max_crop_left=max_crop_left,
        max_crop_right=max_crop_right,
        max_crop_top=max_crop_top,
        max_crop_bottom=max_crop_bottom,
        min_scale_width=min_scale_width,
        max_scale_width=max_scale_width,
        min_scale_height=min_scale_height,
        max_scale_height=max_scale_height,
        rng=rng,
    )
    augmented_image = augmenter(image)
    embeddings = compute_embeddings(augmented_image)

    retrieval_position = {
        f"{embedding_model} position": get_retrieval_score(
            tuple(embeddings[embedding_model]),
            target_id=get_url_uuid(url),
            embedding_type=embedding_model,
            query_size=query_size,
        )
        for embedding_model in embedding_models
    }
    results.append(
        {
            "url": url,
            "image_thumbnail": get_base64_image(get_image_thumbnail(image)),
            "augmented_image_thumbnail": get_base64_image(get_image_thumbnail(augmented_image)),
            **retrieval_position,
            **asdict(augmenter),
        }
    )

  0%|          | 0/684 [00:00<?, ?it/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
result_df = pd.DataFrame(results)

# Store CSV-file with results
result_df[[
    "url",
    "ViT position",
    "SigLIP position",
    "CLIP position",
    "rotation",
    "crop_left",
    "crop_right",
    "crop_top",
    "crop_bottom",
    "width_scale",
    "height_scale"
]].to_csv("../data/image_retrieval_results.csv")

## Evaluate results

### Melt the results so it's easier to create accuracy-tables

In [8]:
tidy_results = pd.melt(
    result_df,
    id_vars=["id"],
    value_vars=[f"{model} position" for model in embedding_models],
    var_name="embedding_model",
    value_name="position",
)
tidy_results["embedding_model"] = tidy_results["embedding_model"].apply(lambda x:x.removesuffix(" position"))

KeyError: "The following id_vars or value_vars are not present in the DataFrame: ['id']"

### Table of accuracies

#### Create Pandas accuracy table

In [None]:
accuracy_table = (
    tidy_results
    .eval(
        """
        top_1 = position == 0
        top_5 = -1 < position < 5
        top_10 = -1 < position < 10
        top_50 = -1 < position < 50
        """
    )
    .groupby("embedding_model")
    .sum()
    [["top_1", "top_5", "top_10", "top_50"]]
)
accuracy_table.columns.name = "accuracy"
accuracy_table

#### Format dataframe as HTML

In [None]:
def format_count(cnt: int) -> str:
    return f"{cnt} ({cnt / len(result_df):.0%})"

latex_df = accuracy_table.copy()
latex_df.index.name = "Model"
latex_df.columns = latex_df.columns.map(lambda s: s.capitalize().replace("_", " "))
latex_df.columns.name = latex_df.columns.name.capitalize()

HTML(
    latex_df.to_html(
        formatters={col: format_count for col in latex_df.columns},
        escape=False
    )
)

#### Create LaTeX table

In [None]:
def format_count(cnt: int, column: pd.Series) -> str:
    data_str = f"{cnt} \\ \\ ({100 * cnt / len(result_df):.0f} \\ \\%)"
    if cnt == column.max():
        data_str = f"\\mathbf{{{data_str}}}"
    return f"\\({data_str}\\)"
        


print(
    latex_df.to_latex(
        formatters={col: partial(format_count, column=latex_df[col]) for col in latex_df.columns},
    ).replace("{lrrr}", "{@{}l@{\hspace{2em}}r@{\hspace{2em}}r@{\hspace{2em}}r@{}}")
)

### Show tables of images its unable to find for the different embedding types

#### SigLIP

In [None]:
HTML(
    result_df.query("`SigLIP position` == -1 or `SigLIP position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )

#### CLIP

In [None]:
HTML(
    result_df.query("`CLIP position` == -1 or `CLIP position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )

#### ViT

In [None]:
HTML(
    result_df.query("`ViT position` == -1 or `ViT position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )