In [35]:
import os
import pathlib
import zipfile
from urllib.request import urlretrieve

import gdown
import torch
from ipywidgets import widgets, interact_manual, fixed
import numpy as np
import matplotlib.gridspec as gridspec
import faiss
from matplotlib import pyplot as plt

In [36]:
%load_ext autoreload
%autoreload 2

from utils import build_index, get_model, get_text_emb, load_image

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
# Folder holding the Flickr8k images; the download/extract cells below use it to
# decide whether work is needed and to select which archive members to extract.
image_dir_name: str = "Flicker8k_Dataset"

In [38]:
url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
out_file = "Flickr8k_Dataset.zip"

# Fetch the archive only when the images are not extracted yet AND no complete
# archive from an earlier run is still on disk (e.g. extraction was interrupted),
# so a re-run does not repeat the large download.
if not os.path.isdir(image_dir_name) and not os.path.isfile(out_file):
    # Download under a temporary name and rename on success, so an interrupted
    # transfer is never mistaken for a complete archive by the check above.
    partial_file = out_file + ".part"
    urlretrieve(url, partial_file)
    os.replace(partial_file, out_file)

In [39]:
if not os.path.isdir(image_dir_name):
    # Only archive entries whose name starts with image_dir_name are extracted;
    # everything else in the zip is skipped. The archive is deleted afterwards.
    with zipfile.ZipFile(out_file, "r") as zip_archive:
        members = [
            member
            for member in zip_archive.infolist()
            if member.filename.startswith(image_dir_name)
        ]
        zip_archive.extractall(members=members)
    os.remove(out_file)

In [40]:
# Path object for the extracted image folder (relative to the working directory).
root_image_dir: pathlib.Path = pathlib.Path(image_dir_name)

In [41]:
# All .jpg files under the image folder (recursively), as sorted path strings so
# that list positions are deterministic; the search cell maps index ids to these
# positions, so the order presumably matches how the index was built — verify.
images = sorted(str(jpg_path) for jpg_path in root_image_dir.rglob("*.jpg"))

In [42]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [43]:
# Load the "ViT-B/32" model and its image preprocessing transform onto `device`.
# NOTE(review): the pre-built index downloaded below presumably was computed
# with this same model; if the model name changes, the index must be rebuilt.
model, img_transformation = get_model("ViT-B/32", device)

In [44]:
# Google Drive file id of the pre-built FAISS index, fetched with gdown below.
file_id: str = "10ZyrIGDrE8nIMSO4jpvultO4TuniJU-y"

In [45]:
# Download the index file by Drive id without a progress bar; the result is the
# local file name (possibly None if the download fails — checked by the next cell).
index_name = gdown.download(quiet=True, id=file_id)

In [46]:
# Use the pre-built index when the download produced a file; otherwise embed
# every image locally. gdown.download may return None on failure, and
# os.path.isfile(None) raises TypeError instead of returning False, so the
# None case must be checked explicitly for the fallback to be reachable.
if index_name is not None and os.path.isfile(index_name):
    index = faiss.read_index(index_name)
else:
    index = build_index(model, img_transformation, device, images)

# Search results are turned into paths via images[id], so the index must hold
# exactly one vector per image or every lookup is silently misaligned.
assert index.ntotal == len(images), (
    f"index has {index.ntotal} vectors but {len(images)} images were found"
)

In [47]:
# Controls for the search UI. `description_width="initial"` lets the long
# descriptions render in full instead of being truncated by the default width.
text_input = widgets.Textarea(
    description="Describe image:",
    placeholder="Cool image",
    value="A black dog",
    style={"description_width": "initial"},
)
top_n_input = widgets.IntSlider(
    min=1,
    max=10,
    value=5,
    description="Number of images to find:",
    style={"description_width": "initial"},
)

In [48]:
def show_most_sim_image(text, model, index, device, image_paths, top_n: int):
    """Search the image index with a text query and plot the best matches.

    The query is embedded with ``get_text_emb`` and searched against ``index``.
    Matched images are drawn one per row in the left column, titled with their
    rank and file name; the top-right cell shows a bar chart of their relevance.

    Args:
        text: Free-form description of the wanted image.
        model: Model passed through to ``get_text_emb``.
        index: FAISS index whose ids are positions in ``image_paths``.
        device: Torch device passed through to ``get_text_emb``.
        image_paths: Sequence of image file paths addressed by index ids.
        top_n: Maximum number of images to show.
    """
    # NOTE(review): assumes get_text_emb returns a (1, dim) array, as
    # index.search expects a 2-D batch of queries — confirm in utils.
    query_emb = get_text_emb(model, text, device).astype(np.float32)
    scores, ids = index.search(query_emb, top_n)
    scores, ids = scores[0], ids[0]

    # FAISS pads missing results with id -1; image_paths[-1] would silently show
    # the last image in the dataset, so drop those slots.
    found = ids >= 0
    scores, ids = scores[found], ids[found]
    total_samples = len(ids)
    if total_samples == 0:
        print(f"No images found for {text!r}.")
        return

    # 1 - 0.5 * (1 - s) == 0.5 * (1 + s): maps a cosine similarity in [-1, 1]
    # to a relevance in [0, 1]. Only meaningful if the index returns cosine
    # similarities (e.g. inner product on normalized embeddings) — confirm.
    relevance = 0.5 * (1 + scores)

    found_images = [load_image(image_paths[image_id]) for image_id in ids]
    labels = [f"Image {rank + 1}" for rank in range(total_samples)]

    # Scale height with the number of rows; 12 per row reproduces the original
    # 20x60 figure at the default top_n=5 without stretching smaller results.
    fig = plt.figure(figsize=(20, 12 * total_samples))
    grid = gridspec.GridSpec(total_samples, 2)

    # barh draws from the bottom up, so reverse the order to put the best
    # match (rank 1) at the top of the chart.
    bar_ax = fig.add_subplot(grid[0, 1])
    y_ticks = tuple(range(total_samples))
    bar_ax.barh(y_ticks, relevance[::-1], height=0.3)
    bar_ax.set_yticks(y_ticks)
    bar_ax.set_yticklabels(labels[::-1])
    bar_ax.grid(True, axis="x")
    bar_ax.set_title("Relevance estimation")

    for rank, (label, image_id, image) in enumerate(zip(labels, ids, found_images)):
        ax = fig.add_subplot(grid[rank, 0])
        ax.imshow(image)
        ax.set_title(f"{label} {os.path.basename(image_paths[image_id])}")
        ax.set_axis_off()
    plt.tight_layout()

In [49]:
# Render the search UI: the text box and slider feed `text` and `top_n`, while
# fixed(...) passes the model, index, device and image paths through unchanged.
# The trailing ";" hides the function repr that interact_manual returns.
interact_manual(
    show_most_sim_image,
    text=text_input,
    model=fixed(model),
    index=fixed(index),
    device=fixed(device),
    image_paths=fixed(images),
    top_n=top_n_input,
);

interactive(children=(Textarea(value='A black dog', description='Describe image:', placeholder='Cool image'), …

<function __main__.show_most_sim_image(text, model, index, device, image_paths, top_n: int)>