In [1]:
from IPython.display import HTML

# Render a <script> that hides every `div.input` element once the document is
# ready (`code_show` starts true, so the first `code_toggle()` call hides them),
# plus a form button that calls `code_toggle()` to show/hide them again.
# NOTE(review): depends on jQuery (`$`) and the `div.input` markup of the classic
# Notebook UI — presumably has no effect in JupyterLab; confirm if that matters.
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Click here to toggle on/off the raw code."></form>''')

In [2]:
import logging
import os
import random
import urllib
import urllib.request
from io import BytesIO

########
import extract_image_vectors

import tensorflow as tf

from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
import faiss
from IPython.display import Image
from ipywidgets import widgets
from IPython.display import clear_output
# Clear the TensorFlow warnings printed while importing from this cell's output.
# None of the suppression approaches from the following thread worked here,
# so the cell output is simply wiped instead:
# https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
clear_output()

In [3]:
# Build the model whose `predict` output is used as the query vector for the
# faiss search further down.
model = extract_image_vectors.init_model()
# Clear tensorflow warnings printed while the model is built.
clear_output()

In [4]:
def is_filename_url(filename):
    """Return True if `filename` is an http(s) URL rather than a local path.

    Requires an explicit ``http://`` or ``https://`` scheme (matched
    case-insensitively), so local paths that merely begin with "http"
    (e.g. ``http_images/a.jpg``) are not mistaken for URLs.

    Args:
        filename: a URL or a local file path.

    Returns:
        bool: True for http/https URLs, False otherwise.
    """
    return filename.lower().startswith(('http://', 'https://'))

In [5]:
def predict_single_image(model, filename, target_size=(224, 224), timeout=30):
    """Compute the model's output for a single image.

    Args:
        model: object whose ``predict(batch, verbose=0)`` is applied to the
            preprocessed image.
        filename: local image path, or an http(s) URL whose content is
            downloaded into memory first.
        target_size: (height, width) the image is resized to on load;
            defaults to the (224, 224) the notebook always used.
        timeout: seconds to wait on the URL download before giving up, so a
            stalled server cannot hang the notebook indefinitely.

    Returns:
        Whatever ``model.predict`` returns for a batch of one image
        (presumably a (1, dim) feature array — TODO confirm).

    Raises:
        urllib.error.URLError / socket.timeout: if the URL cannot be fetched.
    """
    if is_filename_url(filename):
        # `urllib.request` is a submodule and must be imported explicitly;
        # `import urllib` alone does not guarantee it is loaded.
        with urllib.request.urlopen(filename, timeout=timeout) as response:
            img = image.load_img(BytesIO(response.read()), target_size=target_size)
    else:
        img = image.load_img(filename, target_size=target_size)

    # Add a leading batch axis of size 1, then apply the VGG16 preprocessing.
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = preprocess_input(batch)
    return model.predict(batch, verbose=0)

In [6]:
def read_filename_list(filename):
    """Read a text file and return its lines as a list of strings.

    Newline characters are stripped from the ends of each line; any other
    whitespace is kept, and blank lines become '' so the list position of
    every entry matches its line position in the file.
    """
    with open(filename, 'r') as handle:
        return [entry.strip('\n') for entry in handle]

In [7]:
# Load the prebuilt faiss index that the search cell queries with model outputs.
# NOTE(review): result id i is used as a position in `filename_list`, so row i of
# this index is assumed to correspond to line i of ms_coco_file_list.txt — confirm.
index_filename = 'ms_coco_file_list.faiss.index'
index = faiss.read_index(index_filename)

In [8]:
# One image filename per line; the search cell maps faiss result ids to entries
# of this list and joins them with the images directory to load thumbnails.
fl = 'ms_coco_file_list.txt'
filename_list = read_filename_list(fl)

In [9]:
# Load the array stored next to the file list (presumably the saved per-image
# model outputs — confirm).
# NOTE(review): `predictions` is not read by any cell shown here; the search
# recomputes the query vector with the model. Drop this load if it is unused.
pred_file = fl+'.saved_predictions.npy'
predictions = np.load(pred_file)

In [10]:
def _read_image_bytes(path):
    """Return the raw bytes of the file at `path`, closing the handle afterwards."""
    with open(path, 'rb') as f:
        return f.read()


def _result_tile(image_path, label, score, width):
    """Build one result cell: filename label, thumbnail and formatted score."""
    return widgets.VBox([
        widgets.Label(value=label),
        widgets.Image(value=_read_image_bytes(image_path), width=width),
        widgets.Label(value='score: %.2f' % score),
    ])


def search_and_display_similar_images(w_output, k = 8, filename = None, pred_idx = 0,
                                      images_dir='ms_coco/val2014', images_in_row=4,
                                      output_image_width=200):
    """Search the faiss `index` for images similar to a query and render them.

    The query is `filename` when it is truthy (an http(s) URL or a local
    path); otherwise it is entry `pred_idx` of `filename_list` inside
    `images_dir`. Its vector comes from `predict_single_image(model, ...)`,
    and the ids returned by `index.search` are mapped back to entries of
    `filename_list` under `images_dir`.

    Args:
        w_output: ``widgets.Output`` that the input image and results go into.
        k: number of neighbours requested from the index.
        filename: optional URL or path of the query image.
        pred_idx: position in `filename_list` used when `filename` is falsy.
        images_dir: directory containing the indexed images.
        images_in_row: number of result tiles per row (must be >= 1).
        output_image_width: display width of every thumbnail.

    Raises:
        ValueError: if `images_in_row` is less than 1.
    """
    if images_in_row < 1:
        raise ValueError('images_in_row must be >= 1, got %r' % (images_in_row,))

    if filename:
        query_filename = filename
    else:
        query_filename = os.path.join(images_dir, filename_list[pred_idx])

    query = predict_single_image(model, query_filename)
    # faiss expects a C-contiguous float32 matrix; make that explicit.
    query = np.ascontiguousarray(query, dtype='float32')
    sims, ids = index.search(query, k)

    with w_output:
        display("Input:")
        if is_filename_url(query_filename):
            display(Image(url=query_filename, width=output_image_width))
        else:
            # Only queries taken from filename_list have an id; a
            # user-supplied local path must not be labelled "Image id: 0".
            if not filename:
                display('Image id: %d' % pred_idx)
            display(Image(filename=query_filename, width=output_image_width))

    tiles = [
        _result_tile(os.path.join(images_dir, filename_list[id_i]),
                     filename_list[id_i], sim, output_image_width)
        for id_i, sim in zip(ids[0], sims[0])
        # faiss pads ids with -1 when fewer than k neighbours exist; without
        # this check filename_list[-1] would be shown as a bogus match.
        if id_i >= 0
    ]
    rows = [widgets.HBox(tiles[start:start + images_in_row])
            for start in range(0, len(tiles), images_in_row)]
    with w_output:
        display("Output:")
        display(widgets.VBox(rows))

In [11]:
output_images = widgets.Output()
btn = widgets.Button(description='Query random image')
url = 'https://www.mlprague.com/image/imgproxy.php?w=360&h=480&img=https://cms.mlprague.com/upload/7abfd6f3012d8012cb94ffee6f647dd09ebee4ee.png&crop=true'
url_box = widgets.Text(value=url, placeholder='Insert image url, press enter')
url_btn = widgets.Button(description='Search')

# Event handlers. `Output.capture(clear_output=True)` clears the results area
# before each search and runs the handler inside it. Errors (bad URL, failed
# download, missing file) then show up as a traceback in the results area,
# instead of staying out of sight as unhandled callback exceptions.
@output_images.capture(clear_output=True)
def btn_eh(obj):
    """Button handler: search with a random image from `filename_list`."""
    search_and_display_similar_images(output_images, pred_idx = random.randrange(len(filename_list)))


@output_images.capture(clear_output=True)
def url_box_eh(obj):
    """Text-box submit handler: search with the submitted URL or path."""
    search_and_display_similar_images(output_images, filename=obj.value)


@output_images.capture(clear_output=True)
def url_btn_eh(eh):
    """'Search' button handler: search with the current text-box value."""
    search_and_display_similar_images(output_images, filename=url_box.value)


btn.on_click(btn_eh)
url_box.on_submit(url_box_eh)
url_btn.on_click(url_btn_eh)
display(widgets.HBox([btn, widgets.Label(value='or enter an url:'), url_box, url_btn]))

HBox(children=(Button(description='Query random image', style=ButtonStyle()), Label(value='or enter an url:'),…

In [12]:
# Show the output area that the search handlers defined above write into.
display(output_images)

Output()