In [11]:
import os
import pathlib
import tempfile
import zipfile

from urllib.request import urlretrieve

import torch
from ipywidgets import widgets, interact_manual, fixed
import numpy as np
from matplotlib import pyplot as plt

In [18]:
%load_ext autoreload
%autoreload 2

from utils import build_index, get_model, get_text_emb, load_image

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
out_file = "Flickr8k_Dataset.zip"

# Fetch the archive once. Download to a ".part" name and rename only after
# urlretrieve returns, so an interrupted download cannot leave a truncated zip
# that this existence check would treat as complete on every later run.
if not os.path.exists(out_file):
    partial_file = out_file + ".part"
    urlretrieve(url, partial_file)
    os.replace(partial_file, out_file)

In [4]:
# Top-level image folder inside the archive. The spelling "Flicker8k" differs
# from the "Flickr8k" URL/zip name; presumably it matches the archive contents,
# so do not "fix" it without checking the zip.
image_dir_name = "Flicker8k_Dataset"

# Extract only entries under the image folder. Extraction goes into a scratch
# directory in the working directory, and the folder is moved into place only
# once extraction has finished. An interrupted run therefore cannot leave a
# half-populated image_dir_name that the isdir() check would accept as complete.
if not os.path.isdir(image_dir_name):
    with zipfile.ZipFile(out_file, "r") as zip_archive, tempfile.TemporaryDirectory(dir=".") as staging_dir:
        members = [
            info for info in zip_archive.infolist()
            if info.filename.startswith(image_dir_name + "/")
        ]
        zip_archive.extractall(path=staging_dir, members=members)
        os.replace(os.path.join(staging_dir, image_dir_name), image_dir_name)

In [5]:
# Root of the extracted image tree, as a Path so it can be globbed recursively.
root_image_dir = pathlib.Path(".") / image_dir_name

In [6]:
# All image paths as strings. The search cell maps index hits back through this
# list (image_paths[i]), so its order defines the id -> image mapping. rglob
# yields entries in filesystem-dependent order, so sort to make that mapping
# identical on every run and machine.
images = sorted(map(str, root_image_dir.rglob("*.jpg")))

In [8]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [9]:
# Load the "ViT-B/32" model together with the image transform that is passed
# to build_index (presumably the preprocessing applied before encoding).
model, img_transformation = get_model(
    "ViT-B/32",
    device,
)

In [10]:
# Embed every image in `images` and build the search index over them. This is
# the slow step (the recorded run took about 5.5 minutes). Search hits are
# mapped back through `images`, so this presumably adds images in list order.
index = build_index(
    model,
    img_transformation,
    device,
    images,
)

Encode images: 100%|██████████| 253/253 [05:28<00:00,  1.30s/it]


In [15]:
# Controls for the search UI; interact_manual binds them to the `text` and
# `top_n` parameters of show_most_sim_image.
text_input = widgets.Text()
top_n_input = widgets.IntSlider(value=10, min=1, max=20)

In [19]:
def show_most_sim_image(text, model, index, device, image_paths, top_n: int):
    """Plot the ``top_n`` images whose embeddings best match a text query.

    Args:
        text: Query string, embedded with ``get_text_emb``.
        model: Model from ``get_model``; used to embed ``text``.
        index: Search index built over ``image_paths``; must provide
            ``search(query, k)`` returning ``(scores, ids)`` where ``ids[0]``
            holds the hit ids for the (single) query.
        device: Torch device the model runs on.
        image_paths: Image file paths in the same order they were indexed
            (hit id ``i`` refers to ``image_paths[i]``).
        top_n: Maximum number of matches to show.

    Returns:
        None. The figure is drawn via matplotlib and rendered by the notebook.
    """
    text_emb = get_text_emb(model, text, device).astype(np.float32)
    _, hit_ids = index.search(text_emb, top_n)

    # Skip negative ids: some index backends pad with -1 when fewer than top_n
    # vectors exist, and image_paths[-1] would silently show the wrong image.
    # NOTE(review): confirm the padding convention of the index type used.
    matched_images = [load_image(image_paths[hit]) for hit in hit_ids[0] if hit >= 0]
    if not matched_images:
        return None

    fig = plt.figure(figsize=(20, 60))
    # squeeze=False keeps `axes` a 2-D array even for a single row; with the
    # default squeeze a 1x1 grid returns a bare Axes that has no `.flat`.
    axes = fig.subplots(len(matched_images), 1, squeeze=False)

    for ax, img in zip(axes.flat, matched_images):
        ax.imshow(img)
        ax.set_axis_off()
    return None

In [20]:
# Build the search UI: `text` and `top_n` get the widgets defined above, while
# the `fixed(...)` arguments are passed to the function as-is rather than
# exposed as controls. The search only runs when the button is pressed.
interact_manual(
    show_most_sim_image,
    text=text_input,
    model=fixed(model),
    index=fixed(index),
    device=fixed(device),
    image_paths=fixed(images),
    top_n=top_n_input,
)

interactive(children=(Text(value='', description='text'), IntSlider(value=10, description='top_n', max=20, min…

<function __main__.show_most_sim_image(text, model, index, device, image_paths, top_n: int)>