In [1]:
from IPython.display import HTML
# Render a button that shows/hides every code input cell, so the notebook can be used like an app.
# code_toggle() is also run once when the page is ready; since code_show starts as true, the code
# starts hidden.
# NOTE(review): the script relies on jQuery ($) and on the classic Notebook's div.input / div.cell
# markup; it presumably has no effect in JupyterLab -- TODO confirm.
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
     $('div.input').hide();
     $('div.cell').css({'padding': "0px"})
 } else {
     $('div.input').show();
     $('div.cell').css({'padding': "5px"})
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Click here to toggle on/off the raw code."></form>''')

In [2]:
import logging
import os
import random
import urllib
import urllib.parse
import urllib.request
from io import BytesIO

import functools

########
import extract_image_vectors

import tensorflow as tf

from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np
import faiss
from IPython.display import Image
from ipywidgets import widgets
from IPython.display import clear_output
# Clear the TensorFlow warnings printed while running the imports above from this cell's output.
# Note: none of the solutions suggested at the following link worked here:
# https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information
clear_output()

In [3]:
# Build the feature-extraction model once; it is reused by every query below.
# NOTE(review): init_model comes from the project module extract_image_vectors; presumably it builds
# the VGG16 feature extractor that produced the vectors stored in the faiss index -- TODO confirm.
model = extract_image_vectors.init_model()
# Clear TensorFlow warnings printed while building the model.
clear_output()

In [4]:
def is_filename_url(filename, schemes=('http', 'https')):
    """
    Test whether the string is a url rather than a local filename.

    Args:
        filename: input string, a local path or a url
        schemes:  url schemes that count as urls (compared lowercase); http and https by default
    Returns:
        True if the input is a url using one of `schemes`, False otherwise
    """
    # Parse the scheme instead of testing for an 'http' prefix, so local paths such as
    # 'httpdocs/img.jpg' are not mistaken for urls. urlparse lowercases the scheme,
    # so 'HTTPS://...' is recognised as well.
    return urllib.parse.urlparse(filename).scheme in schemes

In [5]:
def predict_single_image(model, filename, target_size=(224, 224), timeout=30):
    """
    Compute the feature vector for a single image.

    Args:
        model:       Keras model used for prediction
        filename:    local path or url of the input image
        target_size: (height, width) the image is resized to before prediction;
                     the default is the VGG16 default input size
        timeout:     seconds to wait for a remote image before giving up
    Returns:
        numpy array holding a batch of one feature vector, i.e. a leading axis of size 1
        (e.g. shape (1, 4096) for the VGG model)
    """
    if is_filename_url(filename):
        # Without a timeout an unreachable host would block the notebook (and the widget
        # callback that started the search) indefinitely.
        with urllib.request.urlopen(filename, timeout=timeout) as response:
            img = image.load_img(BytesIO(response.read()), target_size=target_size)
    else:
        img = image.load_img(filename, target_size=target_size)
    # The model predicts on batches: add a leading batch axis of size 1.
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = preprocess_input(batch)
    return model.predict(batch, verbose=0)

In [6]:
def read_filename_list(filename):
    """
    Load a text file of filenames into a list, one entry per line.

    Args:
        filename: path of the text file to read
    Returns:
        list of the file's lines with their newline characters removed; blank lines
        are kept as '' so that list positions match line positions in the file
    """
    with open(filename, 'r') as list_file:
        return [entry.strip('\n') for entry in list_file]

In [7]:
# Database image filenames, one per line. Search results are mapped back to files via
# filename_list[id], so a filename's position in this list is its id in the faiss index.
fl = 'ms_coco_file_list.txt'
filename_list = read_filename_list(fl)

In [8]:
# Precomputed feature vectors stored next to the file list.
# NOTE(review): in this notebook `predictions` is only referenced by the commented-out
# flat-index alternative in the next cell; the search itself uses the saved faiss index.
pred_file = fl+'.saved_predictions.npy'
predictions = np.load(pred_file)

In [9]:
index_filename = 'ms_coco_file_list.faiss.index'
index = faiss.read_index(index_filename)
# Search results are turned into filenames via filename_list[id], so the index must hold
# exactly one vector per entry of the file list; fail early instead of showing wrong images.
assert index.ntotal == len(filename_list), (
    '%s holds %d vectors but filename_list has %d entries'
    % (index_filename, index.ntotal, len(filename_list)))

# Alternative for testing: build an exact inner-product index from the saved predictions
# instead of loading the prebuilt index from disk.
#dim = 4096
#index = faiss.IndexFlatIP(dim)
#index.add(predictions)

In [10]:
def search_and_display_similar_images(w_output, output_button_eh, k = 8, filename = None, pred_idx = 0,
                                      images_dir='ms_coco/val2014', output_image_width=200,
                                      images_in_row=4):
    """
    Search the database for images similar to a query image and display the results.

    The query is either a local file / url given by `filename` or, when `filename` is empty,
    the database image at position `pred_idx` of `filename_list`.

    Args:
        w_output:           widgets.Output() used for displaying the query and the results
        output_button_eh:   event handler (function) for the "Search" button under each result;
                            called as output_button_eh(result_id, button)
        k:                  number of similar images requested from the database
        filename:           query filename, local file or url. If empty, pred_idx is used
        pred_idx:           index of the database file used as the query when filename is empty
        images_dir:         directory holding the database images named in filename_list
        output_image_width: display width in pixels of the query and result images
        images_in_row:      number of result images per displayed row
    Returns:
        Nothing
    """
    if filename:
        query_filename = filename
    else:
        query_filename = os.path.join(images_dir, filename_list[pred_idx])

    # Search similar images; normalize_L2 normalises the query rows in place.
    query = predict_single_image(model, query_filename)
    faiss.normalize_L2(query)
    sims, ids = index.search(query, k)

    # Display input image
    with w_output:
        display("Input:")
        if is_filename_url(query_filename):
            display(Image(url=query_filename, width=output_image_width))
        else:
            # Only database queries have an id; a user-supplied local file does not.
            if not filename:
                display('Image id: %d' % pred_idx)
            display(Image(filename=query_filename, width=output_image_width))

    # Display output images. faiss pads the result with id -1 when it cannot return k
    # neighbours; without the filter, filename_list[-1] would be shown as a bogus match.
    result_widgets = [
        _result_widget(images_dir, id_i, sim, output_button_eh, output_image_width)
        for id_i, sim in zip(ids[0], sims[0])
        if id_i >= 0
    ]
    rows = [widgets.HBox(result_widgets[start:start + images_in_row])
            for start in range(0, len(result_widgets), images_in_row)]
    with w_output:
        display("Output:")
        display(widgets.VBox(rows))


def _result_widget(images_dir, id_i, sim, output_button_eh, image_width):
    """Build the widget for one search result: filename, image, score label and "Search" button."""
    # Read the bytes inside a context manager so the file handle is closed right away.
    with open(os.path.join(images_dir, filename_list[id_i]), 'rb') as image_file:
        im = widgets.Image(value=image_file.read(), width=image_width)
    im_button = widgets.Button(description='Search', layout=widgets.Layout(width='auto'))
    # Clicking re-runs the search with this result as the query.
    im_button.on_click(functools.partial(output_button_eh, id_i))
    return widgets.VBox([
        widgets.Label(value=filename_list[id_i]),
        im,
        widgets.HBox([
            widgets.Label(value='distance: %.2f' % sim),
            im_button,
        ]),
    ])

In [11]:
output_images = widgets.Output()
btn = widgets.Button(description='Query random image')
url = 'https://www.visitczechrepublic.com/cms/getmedia/1be239a4-2f42-4b57-88ff-c6ccc1c938a5/shutterstock_1161049588_Brno.jpg?width=768'
url_box = widgets.Text(value=url, placeholder='Insert image url, press enter')
url_btn = widgets.Button(description='Search')


def _run_search(**search_kwargs):
    """
    Clear the result area and run a similarity search, reporting any failure there.

    Args:
        **search_kwargs: forwarded to search_and_display_similar_images (filename or pred_idx)
    """
    output_images.clear_output()
    try:
        search_and_display_similar_images(output_images, output_btn_eh, **search_kwargs)
    except Exception as exc:
        # Errors raised in widget callbacks are not shown in the cell output on some
        # frontends, so a failed search (e.g. an unreachable url) would otherwise look
        # like nothing happened.
        logging.exception('Image search failed')
        with output_images:
            display('Search failed: %s' % exc)


# Callback function for "Search" button under a result: query with that result.
# `index` is the result's position in filename_list (it shadows the faiss `index` only here).
def output_btn_eh(index, obj):
    _run_search(pred_idx=index)

# Callback function for "Query random image" button
def btn_eh(obj):
    _run_search(pred_idx=random.randrange(len(filename_list)))

# Callback function for input box with url -- activated on enter key
def url_box_eh(obj):
    _run_search(filename=obj.value)

# Callback function for "Search" url button; the argument is the clicked button (unused)
def url_btn_eh(eh):
    _run_search(filename=url_box.value)


btn.on_click(btn_eh)
url_box.on_submit(url_box_eh)
url_btn.on_click(url_btn_eh)
display(widgets.HBox([btn, widgets.Label(value='or enter an url:'), url_box, url_btn]))

HBox(children=(Button(description='Query random image', style=ButtonStyle()), Label(value='or enter an url:'),…

In [12]:
# Show the output area that the search callbacks render the query and results into.
display(output_images)

Output()