# Quantitatively measuring exact image retrieval accuracy

## Imports

In [1]:
import base64
import hashlib
import json
from collections.abc import Iterable
from dataclasses import dataclass, asdict
from functools import lru_cache, partial
from io import BytesIO
from pathlib import Path
from uuid import UUID

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import PIL
import requests
import torch

from IPython.display import HTML
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModel

## Utility functionality

### Class to transform images

In [2]:
@dataclass
class ImageAugmenter:
    """Class that takes care of augmenting an image.
    """
    rotation: float
    crop_left: float
    crop_right: float
    crop_top: float
    crop_bottom: float
    width_scale: float
    height_scale: float

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Augment the image, first rotating, then cropping before finally scaling.
        """
        image = image.rotate(self.rotation, resample=PIL.Image.Resampling.BICUBIC, fillcolor="white", expand=True)
        width, height = image.size

        left = round(self.crop_left*width)
        right = round(width - self.crop_right*width)
        upper = round(self.crop_top*height)
        lower = round(height - self.crop_top*height)

        new_width = round((right - left) * self.width_scale)
        new_height = round((lower - upper) * self.height_scale)

        return (
            image
            .crop((left, upper, right, lower))
            .resize((new_width, new_height), resample=PIL.Image.Resampling.LANCZOS)
        )

    @classmethod
    def init_random(
        cls,
        min_rotation: float,
        max_rotation: float,
        max_crop_left: float,
        max_crop_right: float,
        max_crop_top: float,
        max_crop_bottom: float,
        min_scale_width: float,
        max_scale_width: float,
        min_scale_height: float,
        max_scale_height: float,
        rng: np.random.Generator,
    ):
        """Create a random image augmenter.
        """
        return cls(
            rotation=rng.uniform(min_rotation, max_rotation),
            crop_left=rng.uniform(0, max_crop_left),
            crop_right=rng.uniform(0, max_crop_right),
            crop_top=rng.uniform(0, max_crop_top),
            crop_bottom=rng.uniform(0, max_crop_bottom),
            width_scale=rng.uniform(min_scale_width, max_scale_width),
            height_scale=rng.uniform(min_scale_height, max_scale_height),
        )

## Functions that interact with the API

In [3]:
def get_url_uuid(url: str) -> str:
    """Convert the URL to the ID in the Qdrant database.
    """
    sha_digest = hashlib.sha256(url.encode("utf-8")).digest()
    return UUID(bytes=sha_digest[:16]).urn.split(":")[-1]
    

def get_retrieval_test_urls() -> Iterable[str]:
    """Load the labelled dataset, extract the interesting classes and get image IDs.
    """
    label_df = pd.read_json("../data/labelled_data.json")
    interesting_labels = ["Illustration or photograph", "Map", "Mathematical chart"]
    url_mask = label_df[interesting_labels].any(axis=1)
    urls = url_mask[url_mask].index
    return urls

@lru_cache
def get_image_from_url(url: str, image_directory: Path) -> PIL.Image.Image:
    """Download the image given its ID.
    """
    id_ = get_url_uuid(url)
    filename = (image_directory / f"{id_}.jpg")
    if filename.exists():
        return PIL.Image.open(filename)

    image_directory.mkdir(exist_ok=True, parents=True)
    url = f"https://api.nb.no/dhlab/image_search/id?image_id={id_}&limit=1&embedding_type=CLIP"
    response = requests.get(url).json()

    image_url = response[0]["payload"]["image_url"]
    image_resp = requests.get(image_url)
    image_resp.raise_for_status()
    jpg_contents = image_resp.content
    filename.write_bytes(jpg_contents)
    return PIL.Image.open(filename)

@lru_cache
def get_retrieval_score(
    query_vector: tuple[float],
    target_id: str,
    embedding_type: str,
    query_size: int = 10,
) -> int:
    """Query the API and return the position of the target image in query (-1 if not there)
    """
    result = requests.post(
        f"https://api.nb.no/dhlab/image_search/vector?&limit={query_size}",
        json={
        "vector": list(query_vector),
        "embedding_type": embedding_type
        }
    ).json()

    for idx, element in enumerate(result):
        if element["id"] == target_id:
            return idx
    else:
        return -1

def get_base64_image(image: PIL.Image.Image) -> str:
    """Base64-encode a PIL image to embed it in a dataframe.
    """
    with BytesIO() as buffer:
        image.save(buffer, 'jpeg')
        return base64.b64encode(buffer.getvalue()).decode()

def get_image_thumbnail(image: PIL.Image.Image, size: tuple[int, int] = (150, 150)) -> PIL.Image.Image:
    """Get a thumbnail from a PIL image.
    """
    image_thumbnail = image.copy()
    image_thumbnail.thumbnail(size)
    return image_thumbnail

def format_image(image_base64: str) -> str:
    """Add an img-tag for base64 encode JPEG-data.
    """
    return f'<img src="data:image/jpeg;base64,{image_base64}">'


### Functions to compute the embeddings

In [4]:
def get_device() -> torch.device:
    """Get the torch device, use CUDA if available.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


@lru_cache
def load_model(model_name: str, device: torch.device) -> tuple[AutoImageProcessor, AutoModel]:
    """Get a pretrained huggingface processor and model. Note that these are cached!
    """
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return processor, model


def calculate_embeddings(image: PIL.Image, model_name: str, multimodal: bool) -> tuple[float, ...]:
    """Compute embedding vectors for an image.
    """
    device = get_device()
    processor, model = load_model(model_name, device)

    processed = processor(images=image, return_tensors="pt")["pixel_values"]
    if multimodal:
        embeddings = model.get_image_features(processed.to(device))
    else:
        model_out = model(processed.to(device), output_hidden_states=True)
        embeddings = model_out.hidden_states[-1][:, 0, :]

    return tuple(embeddings.detach().cpu().numpy().squeeze().tolist())


def compute_embeddings(img: PIL.Image.Image) -> dict[str, tuple[float, ...]]:
    """Compute all relevant embedding vectors for an image.
    """
    siglip_model = "google/siglip-base-patch16-256-multilingual"
    clip_model = "openai/clip-vit-base-patch32"
    vit_model = "google/vit-base-patch16-224"

    return {
        "SigLIP": calculate_embeddings(img, siglip_model, multimodal=True),
        "CLIP": calculate_embeddings(img, clip_model, multimodal=True),
        "ViT": calculate_embeddings(img, vit_model, multimodal=False),
    }

## Run queries based on transformed images and measure accuracy

### Image transformation parameters
Set parameters for image transformations

In [5]:
min_rotation = -10
max_rotation = 10

max_crop_left = 0.15
max_crop_right = 0.15
max_crop_top = 0.15
max_crop_bottom = 0.15

min_scale_width = 0.8
max_scale_width = 1.2

min_scale_height = 0.8
max_scale_height = 1.2

embedding_models = ["ViT", "SigLIP", "CLIP"]
query_size = 100

rng = np.random.default_rng(42)

image_directory = Path("images")

### Run experiment

In [6]:
urls = get_retrieval_test_urls()

results = []

for url in tqdm(urls):
    try:
        image = get_image_from_url(url, image_directory)
    except requests.HTTPError as e:
        print(e)
        continue
    augmenter = ImageAugmenter.init_random(
        min_rotation=min_rotation,
        max_rotation=max_rotation,
        max_crop_left=max_crop_left,
        max_crop_right=max_crop_right,
        max_crop_top=max_crop_top,
        max_crop_bottom=max_crop_bottom,
        min_scale_width=min_scale_width,
        max_scale_width=max_scale_width,
        min_scale_height=min_scale_height,
        max_scale_height=max_scale_height,
        rng=rng,
    )
    augmented_image = augmenter(image)
    embeddings = compute_embeddings(augmented_image)

    retrieval_position = {
        f"{embedding_model} position": get_retrieval_score(
            tuple(embeddings[embedding_model]),
            target_id=get_url_uuid(url),
            embedding_type=embedding_model,
            query_size=query_size,
        )
        for embedding_model in embedding_models
    }
    results.append(
        {
            "url": url,
            "image_thumbnail": get_base64_image(get_image_thumbnail(image)),
            "augmented_image_thumbnail": get_base64_image(get_image_thumbnail(augmented_image)),
            **retrieval_position,
            **asdict(augmenter),
        }
    )

  0%|          | 0/684 [00:00<?, ?it/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
result_df = pd.DataFrame(results)

# Store CSV-file with results
result_df[[
    "url",
    "ViT position",
    "SigLIP position",
    "CLIP position",
    "rotation",
    "crop_left",
    "crop_right",
    "crop_top",
    "crop_bottom",
    "width_scale",
    "height_scale"
]].to_csv("../data/image_retrieval_results.csv")

## Evaluate results

### Melt the results so it's easier to create accuracy-tables

In [8]:
tidy_results = pd.melt(
    result_df,
    id_vars=["url"],
    value_vars=[f"{model} position" for model in embedding_models],
    var_name="embedding_model",
    value_name="position",
)
tidy_results["embedding_model"] = tidy_results["embedding_model"].apply(lambda x:x.removesuffix(" position"))

### Table of accuracies

#### Create Pandas accuracy table

In [9]:
accuracy_table = (
    tidy_results
    .eval(
        """
        top_1 = position == 0
        top_5 = -1 < position < 5
        top_10 = -1 < position < 10
        top_50 = -1 < position < 50
        """
    )
    .groupby("embedding_model")
    .sum()
    [["top_1", "top_5", "top_10", "top_50"]]
)
accuracy_table.columns.name = "accuracy"
accuracy_table

accuracy,top_1,top_5,top_10,top_50
embedding_model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CLIP,492,596,613,638
SigLIP,529,633,645,665
ViT,529,582,597,612


#### Format dataframe as HTML

In [10]:
def format_count(cnt: int) -> str:
    return f"{cnt} ({cnt / len(result_df):.0%})"

latex_df = accuracy_table.copy()
latex_df.index.name = "Model"
latex_df.columns = latex_df.columns.map(lambda s: s.capitalize().replace("_", " "))
latex_df.columns.name = latex_df.columns.name.capitalize()

HTML(
    latex_df.to_html(
        formatters={col: format_count for col in latex_df.columns},
        escape=False
    )
)

Accuracy,Top 1,Top 5,Top 10,Top 50
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CLIP,492 (72%),596 (87%),613 (90%),638 (93%)
SigLIP,529 (77%),633 (93%),645 (94%),665 (97%)
ViT,529 (77%),582 (85%),597 (87%),612 (89%)


#### Create LaTeX table

In [11]:
def format_count(cnt: int, column: pd.Series) -> str:
    data_str = f"{cnt} \\ \\ ({100 * cnt / len(result_df):.0f} \\ \\%)"
    if cnt == column.max():
        data_str = f"\\mathbf{{{data_str}}}"
    return f"\\({data_str}\\)"
        


print(
    latex_df.to_latex(
        formatters={col: partial(format_count, column=latex_df[col]) for col in latex_df.columns},
    ).replace("{lrrr}", "{@{}l@{\hspace{2em}}r@{\hspace{2em}}r@{\hspace{2em}}r@{}}")
)

\begin{tabular}{lrrrr}
\toprule
Accuracy & Top 1 & Top 5 & Top 10 & Top 50 \\
Model &  &  &  &  \\
\midrule
CLIP & \(492 \ \ (72 \ \%)\) & \(596 \ \ (87 \ \%)\) & \(613 \ \ (90 \ \%)\) & \(638 \ \ (93 \ \%)\) \\
SigLIP & \(\mathbf{529 \ \ (77 \ \%)}\) & \(\mathbf{633 \ \ (93 \ \%)}\) & \(\mathbf{645 \ \ (94 \ \%)}\) & \(\mathbf{665 \ \ (97 \ \%)}\) \\
ViT & \(\mathbf{529 \ \ (77 \ \%)}\) & \(582 \ \ (85 \ \%)\) & \(597 \ \ (87 \ \%)\) & \(612 \ \ (89 \ \%)\) \\
\bottomrule
\end{tabular}



### Show tables of images its unable to find for the different embedding types

#### SigLIP

In [12]:
HTML(
    result_df.query("`SigLIP position` == -1 or `SigLIP position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )

Unnamed: 0.1,Unnamed: 0,url,ViT position,SigLIP position,CLIP position,rotation,crop_left,crop_right,crop_top,crop_bottom,width_scale,height_scale
11,11,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2020022048012_0058/276,295,1775,2128/443,/0/default.jpg",0,-1,0,3.23833,0.083555,0.117585,0.099647,0.060958,1.125608,0.866789
12,12,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2013121924010_0306/732,1678,1246,356/1246,356/0/default.jpg",0,-1,-1,-9.545759,0.013507,0.108354,0.069282,0.024191,1.000418,0.860925
37,37,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2018040928002_0110/1543,943,7677,6095/479,380/0/default.jpg",0,-1,15,1.888673,0.021926,0.1237,0.04655,0.021581,1.168388,0.866213
50,50,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012041724028_0019/90,2043,1302,542/650,271/0/default.jpg",0,-1,0,2.323314,0.025694,0.084743,0.085865,0.069898,1.009053,1.105569
55,55,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012091808030_0084/197,2399,450,107/450,107/0/default.jpg",-1,55,-1,9.732631,0.107664,0.142679,0.017772,0.12758,1.05483,0.848769
187,187,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2010070208002_0448/1136,1179,1276,455/1276,455/0/default.jpg",-1,-1,0,8.910398,0.058873,0.138585,0.086847,0.000693,0.815417,1.072168
190,190,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2008060910006_0156/619,487,526,718/263,/0/default.jpg",0,-1,6,5.498158,0.074489,0.080473,0.144812,0.144537,1.142257,0.875044
281,281,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2015062329001_0234/966,1431,2072,2597/258,/0/default.jpg",0,-1,0,6.979568,0.07087,0.076475,0.046832,0.145419,1.09325,0.947595
284,284,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009020200022_0078/1313,1456,165,616/165,/0/default.jpg",5,-1,58,9.746252,0.02259,0.01326,0.101129,0.05096,0.828498,0.990122
309,309,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2017101626019_0038/9,37,989,1425/494,/0/default.jpg",0,-1,0,-1.497947,0.131912,0.040551,0.134061,0.063896,0.931123,1.001005


#### CLIP

In [13]:
HTML(
    result_df.query("`CLIP position` == -1 or `CLIP position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )

Unnamed: 0.1,Unnamed: 0,url,ViT position,SigLIP position,CLIP position,rotation,crop_left,crop_right,crop_top,crop_bottom,width_scale,height_scale
3,3,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012070408119_0026/448,423,1405,2217/351,/0/default.jpg",0,1,-1,-2.909481,0.145605,0.133968,0.116758,0.029196,0.986688,0.817522
12,12,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2013121924010_0306/732,1678,1246,356/1246,356/0/default.jpg",0,-1,-1,-9.545759,0.013507,0.108354,0.069282,0.024191,1.000418,0.860925
23,23,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2014031428001_0533/158,444,903,1088/451,/0/default.jpg",-1,4,-1,8.320237,0.034532,0.005612,0.083228,0.055638,1.131916,1.123301
26,26,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2014031808067_0010/200,2026,834,794/416,396/0/default.jpg",0,6,-1,0.377167,0.047389,0.115802,0.099249,0.056049,0.837787,1.098716
33,33,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2015062908059_0186/794,1069,1221,1603/305,/0/default.jpg",4,0,-1,-1.617714,0.144935,0.089406,0.139953,0.120654,0.986953,1.113905
55,55,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012091808030_0084/197,2399,450,107/450,107/0/default.jpg",-1,55,-1,9.732631,0.107664,0.142679,0.017772,0.12758,1.05483,0.848769
62,62,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009052510003_0023/56,1397,1848,485/1848,485/0/default.jpg",-1,0,-1,-9.196974,0.030267,0.018739,0.07568,0.111778,1.052005,1.140452
64,64,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009042203002_0134/216,1045,659,688/329,/0/default.jpg",0,0,-1,-8.909994,0.092446,0.006353,0.132622,0.106437,0.869251,0.836688
65,65,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2010052720010_0072/1255,1493,763,738/381,369/0/default.jpg",0,0,-1,-6.329335,0.147004,0.068784,0.117612,0.095461,1.028965,0.858052
67,67,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009052810001_0045/399,103,2763,3873/344,/0/default.jpg",0,0,84,0.167055,0.060605,0.071125,0.017883,0.020114,0.91123,0.921882


#### ViT

In [14]:
HTML(
    result_df.query("`ViT position` == -1 or `ViT position` >= 50").to_html(
        formatters={
            "image_thumbnail": format_image,
            "augmented_image_thumbnail": format_image
          },
        escape=False
      )
  )

Unnamed: 0.1,Unnamed: 0,url,ViT position,SigLIP position,CLIP position,rotation,crop_left,crop_right,crop_top,crop_bottom,width_scale,height_scale
22,22,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009103010001_0140/686,2621,1129,1337/282,/0/default.jpg",-1,5,1,-7.828485,0.100836,0.042185,0.098913,0.109049,1.107459,0.843096
23,23,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2014031428001_0533/158,444,903,1088/451,/0/default.jpg",-1,4,-1,8.320237,0.034532,0.005612,0.083228,0.055638,1.131916,1.123301
25,25,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009030403037_0187/487,677,2688,3669/335,/0/default.jpg",-1,0,6,-9.101788,0.065265,0.148856,0.133752,0.112291,1.156317,1.157379
44,44,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009041503003_0031/1183,1205,1813,758/906,378/0/default.jpg",-1,0,0,-4.670734,0.020965,0.071682,0.062533,0.034885,0.947005,0.946557
48,48,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2010011103007_0039/1699,1474,615,981/307,/0/default.jpg",-1,0,0,-8.098085,0.108857,0.012674,0.140391,0.020611,1.183552,1.120354
51,51,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2008010713001_0744/241,461,1842,2467/460,/0/default.jpg",-1,0,5,5.984894,0.073823,0.089939,0.139685,0.01796,0.846841,0.835084
55,55,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012091808030_0084/197,2399,450,107/450,107/0/default.jpg",-1,55,-1,9.732631,0.107664,0.142679,0.017772,0.12758,1.05483,0.848769
59,59,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2012071012001_0463/60,285,3338,5026/417,/0/default.jpg",-1,0,0,7.940668,0.00435,0.036124,0.021453,0.116515,0.879282,1.164255
62,62,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2009052510003_0023/56,1397,1848,485/1848,485/0/default.jpg",-1,0,-1,-9.196974,0.030267,0.018739,0.07568,0.111778,1.052005,1.140452
72,72,"https://www.nb.no/services/image/resolver/URN:NBN:no-nb_digibok_2016122229002_0434/723,273,1968,1025/491,256/0/default.jpg",-1,1,0,0.405502,0.119981,0.047168,0.125607,0.074121,0.846343,0.828824
