In [1]:
import ipywidgets
import panel as pn
import pandas
import numpy
import tiledb.cloud
from tiledb.cloud.compute import Delayed
from tiledb.cloud import groups
from PIL import Image
import tiledb.vector_search as vs
from efficientnet.preprocessing import center_crop_and_resize
from typing import List, Optional, Union, Dict

In [2]:
# Load Panel's ipywidgets support; components stretch to fill their container.
pn.extension("ipywidgets", sizing_mode="stretch_both")

In [3]:
from bokeh.settings import settings

# Have Bokeh load its JS/CSS resources from the CDN.
settings.resources = "CDN"

In [4]:
def check_groups_for_vector_search(group_browser_data: "tiledb.cloud.rest_api.models.GroupBrowserData") -> "List[tiledb.cloud.rest_api.models.group_info.GroupInfo]":
    """Return the groups on one page of a GroupsApi listing as a new list.

    No filtering happens here: the vector-search restriction is the
    ``search="vector_search"`` argument the caller passes to the listing call.

    Args:
        group_browser_data: one page of results; only its ``groups`` attribute is read.

    Returns:
        A new list with the page's groups; empty if ``groups`` is unset (``None``).
    """
    # NOTE(review): guard in case the API leaves `groups` as None on an empty page.
    return list(group_browser_data.groups or [])

In [5]:
def list_vector_search_datasets(namespace: str) -> "List[tiledb.cloud.rest_api.models.group_info.GroupInfo]":
    """List every group matching ``"vector_search"`` that is owned by or shared with ``namespace``.

    Results are paged (10 per page) by the GroupsApi; all pages of both the owned
    and the shared listing are fetched. Order: first owned page, first shared page,
    then the remaining owned pages, then the remaining shared pages.

    Args:
        namespace: TileDB Cloud namespace to list groups for.

    Returns:
        The combined list of groups from ``list_owned_groups`` and ``list_shared_groups``.
    """
    group_api = tiledb.cloud.client.build(tiledb.cloud.rest_api.GroupsApi)
    query = dict(namespace=namespace, search="vector_search", per_page=10)
    # Each listing must be paged with its *own* API call (owned vs. shared).
    listings = (group_api.list_owned_groups, group_api.list_shared_groups)

    group_list = []
    page_counts = []
    for list_groups in listings:
        first_page = list_groups(**query)
        group_list += check_groups_for_vector_search(first_page)
        page_counts.append(first_page.pagination_metadata.total_pages)

    for list_groups, total_pages in zip(listings, page_counts):
        if total_pages is None:
            # No pagination metadata: only the first page is available.
            continue
        for page in range(2, int(total_pages) + 1):
            group_list += check_groups_for_vector_search(list_groups(**query, page=page))

    return group_list

In [6]:
def group_list_to_display(group_list: "List[tiledb.cloud.rest_api.models.group_info.GroupInfo]", name_filter: "Optional[str]" = "plant") -> "Dict[str, str]":
    """Map group names to their TileDB URIs for use as dropdown options.

    Args:
        group_list: groups to consider; each must expose ``name`` and ``tiledb_uri``.
        name_filter: keep only groups whose name contains this substring
            (default ``"plant"``, the demo datasets); ``None`` keeps every group.

    Returns:
        Dict of ``{group.name: group.tiledb_uri}`` in input order.
    """
    return {
        group.name: group.tiledb_uri
        for group in group_list
        if name_filter is None or name_filter in group.name
    }

In [7]:
# Array holding the dataset's images; search results are read from it by on_upload.
image_array_uri = "tiledb://seth/4d06a507-317a-4c61-9288-e2e5fcebf4c2"

# Upload/search state, filled in by on_upload.
file = image_to_search = image_to_search_embeddings = None

# Single-image picker. `accept` uses the HTML <input accept> syntax
# (https://developer.mozilla.org/en-US/docs/Web/HTML/Element/input#attr-accept).
file_upload = ipywidgets.FileUpload(accept='image/*', multiple=False, disabled=False)

In [8]:
def _show_loading_status(message):
    """Replace ``output_widget``'s contents with a bold ``message`` above the ``loading`` spinner."""
    output_widget.clear_output()
    with output_widget:
        details = ipywidgets.HTML(f"<p><b>{message}</b></p>")
        spinner_area = ipywidgets.Output()
        display(ipywidgets.VBox([details, spinner_area]))
        with spinner_area:
            display(loading)


def on_upload(change):
    """Embed an uploaded image, search the current index and display the matches.

    Registered below as an observer of ``file_upload.value``. The code assumes the
    ipywidgets 7-style value ``{filename: {"metadata": {...}, "content": bytes}}``
    and reads the private ``file_upload._counter`` (TODO confirm if ipywidgets is
    upgraded). Only the first uploaded file is used.

    Side effects: sets the globals ``file``, ``image_to_search`` and
    ``image_to_search_embeddings``; updates ``file_name``, ``file_size``,
    ``thumbnail_search_widget`` and ``output_widget``.
    """
    global file
    global image_to_search
    global image_to_search_embeddings

    if change.new == "":
        # Upload cleared: forget the file; there is nothing to search.
        file = None
        return
    if file_upload._counter == 0:
        # Not a fresh upload (the counter is reset to 0 after each processed upload).
        return
    uploaded = next(iter(change.new.values()), None)
    if uploaded is None:
        # Empty selection: nothing to search.
        return
    file = uploaded

    _show_loading_status("Uploading Image for search")

    file_name.value = f"<p><b>File: {file['metadata']['name']}</b></p>"
    file_size.value = f"<p><b>Size: {file['metadata']['size']}</b></p>"

    # Reset the widget in place (presumably so the next upload is seen as new).
    file_upload.value.clear()
    file_upload._counter = 0

    image_to_search = image_from_bytes(file['content'])

    thumbnail_search_widget.clear_output()
    with thumbnail_search_widget:
        display(Image.fromarray(image_to_search))

    if index is None:
        # No dataset could be loaded (e.g. the namespace has none): nothing to search.
        output_widget.clear_output()
        with output_widget:
            display(ipywidgets.HTML("<p><b>No vector search dataset selected</b></p>"))
        return

    _show_loading_status("Processing uploaded image for search")
    image_to_search_embeddings = search_embedding(image_to_search)

    _show_loading_status("Performing Image search")
    results = search(index, image_to_search_embeddings, topk, nprobe)

    # Show each matching image, read by id from the image array.
    output_widget.clear_output()
    with output_widget:
        with tiledb.open(image_array_uri, mode='r') as images:
            for result_id in results:
                display(Image.fromarray(images[result_id]["value"]))


file_upload.observe(on_upload, names=["value"])

In [9]:
def calculate_resnet(x: numpy.ndarray) -> numpy.ndarray:
    """Compute one ResNet50V2 feature vector per image in the batch ``x``.

    TensorFlow is imported inside the function and all warnings are silenced
    process-wide (``warnings.filterwarnings("ignore")``), as in the original.

    Args:
        x: batch of images with the batch on axis 0 (``search_embedding`` passes a
           batch of one). Presumably (n, H, W, 3) RGB; TODO confirm.

    Returns:
        Array of shape (n, channels). If the feature map is spatially 1x1 it is
        reshaped; otherwise it is averaged over the two spatial axes.
    """
    import warnings
    warnings.filterwarnings("ignore")
    import tensorflow as tf
    from tensorflow.keras.applications.resnet_v2 import preprocess_input

    # NOTE(review): the model is rebuilt on every call, which is expensive; consider caching.
    model = tf.keras.applications.ResNet50V2(include_top=False)
    maps = model.predict(preprocess_input(x))
    if numpy.prod(maps.shape) == maps.shape[-1] * len(x):
        # 1x1 spatial map: keep the batch axis (squeeze would drop it when n == 1).
        return maps.reshape(len(x), -1)
    return maps.mean(axis=1).mean(axis=1)

In [10]:
def search(index, query_embedding, top_n, nprobe):
    """Return the ids of the ``top_n`` nearest neighbours for the first query.

    Args:
        index: a ``vs.FlatIndex`` or ``vs.IVFFlatIndex``.
        query_embedding: query vectors passed unchanged to ``index.query``; only
            the first query's result row is returned.
        top_n: number of neighbours (``k``) to retrieve.
        nprobe: number of partitions to probe; only used for ``vs.IVFFlatIndex``.

    Returns:
        ``result_i[0]``, the neighbour ids of the first query.

    Raises:
        TypeError: if ``index`` is neither a ``vs.FlatIndex`` nor a ``vs.IVFFlatIndex``.
    """
    if isinstance(index, vs.FlatIndex):
        _, result_i = index.query(query_embedding, k=top_n)
    elif isinstance(index, vs.IVFFlatIndex):
        _, result_i = index.query(query_embedding, k=top_n, nprobe=nprobe)
    else:
        raise TypeError(f"Unsupported index type for search: {type(index).__name__}")
    return result_i[0]

In [11]:
def search_embedding(image):
    """Embed a single image by running it through ``calculate_resnet`` as a batch of one."""
    batch = numpy.expand_dims(image, 0)
    return calculate_resnet(batch)

In [12]:
def format_image(image, size: int = 224):
    """Center-crop and resize an image to ``size`` x ``size`` and cast it to uint8.

    Args:
        image: anything ``numpy.array`` turns into a 3-D (height, width, channels) array.
        size: output side length passed to ``center_crop_and_resize`` (default 224).

    Returns:
        The resized ``numpy.uint8`` array. Single-channel inputs lose their channel
        axis and come back 2-D.

    Raises:
        ValueError: if the image is not 3-dimensional.
    """
    image = numpy.array(image)
    if image.ndim != 3:
        raise ValueError(
            f"Expected a 3-dimensional (height, width, channels) image, got shape {image.shape}.")
    if image.shape[-1] == 1:
        # Drop the singleton channel axis of single-channel images.
        image = image.reshape(image.shape[:2])
    return center_crop_and_resize(image, size).astype(numpy.uint8)
    
def image_from_bytes(content):
    """Decode encoded image bytes into the uint8 array produced by ``format_image``.

    The image is converted to 3-channel RGB first. Grayscale, palette, RGBA and CMYK
    files would otherwise produce arrays that ``format_image`` rejects (2-D) or that
    do not have the 3 channels the ResNet embedding expects.
    """
    from PIL import Image
    import io
    with Image.open(io.BytesIO(content)) as uploaded:
        return format_image(uploaded.convert("RGB"))

In [13]:
# Spinner displayed in the output areas while datasets load and searches run.
loading = pn.indicators.LoadingSpinner(
    value=True,
    width=100,
    height=100,
    color="success",
)

In [14]:
# Number of nearest neighbours (k) returned for each uploaded image.
topk = 3

# The minimum is 1 (not 0): k = 0 is not a meaningful neighbour count. This also
# matches the Nprobe slider's minimum.
topk_selection = ipywidgets.IntSlider(
    value=topk,
    min=1,
    max=100,
    step=1,
    description='TopK:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

def on_topk_selection(change):
    """Store the TopK slider's new value in the module-level ``topk`` (``""`` becomes ``None``)."""
    global topk
    topk = None if change.new == "" else change.new


# Keep `topk` in sync with the slider.
topk_selection.observe(on_topk_selection, names=["value"])

In [15]:
# Partitions probed per query; `search` only passes it to IVF-flat indexes.
nprobe = 1

nprobe_selection = ipywidgets.IntSlider(
    value=nprobe, min=1, max=100, step=1,
    description='Nprobe:', disabled=False,
    continuous_update=False, orientation='horizontal',
    readout=True, readout_format='d',
)

def on_nprobe_selection(change):
    """Copy the Nprobe slider's new value into the module-level ``nprobe`` (``""`` maps to ``None``)."""
    global nprobe
    new_value = change.new
    if new_value == "":
        new_value = None
    nprobe = new_value


# Keep `nprobe` in sync with the slider.
nprobe_selection.observe(on_nprobe_selection, names=["value"])

In [16]:
# Namespaces offered in the namespace dropdown (e.g. ["seth", "TileDB-Inc"]);
# the first one is listed on start-up.
organizations = ["TileDB-Inc"]
namespace = organizations[0]
datasets = list_vector_search_datasets(namespace=namespace)
dataset_list = group_list_to_display(datasets)
uri, index = None, None

In [17]:
namespace_selection = ipywidgets.Dropdown(
    options=organizations,
    description="Namespace",
)

# NOTE(review): this resets the namespace chosen in the previous cell to None even though
# the dropdown initially shows organizations[0]; on_namespace_selection_change reassigns it.
namespace = None

def on_namespace_selection_change(change):
    """Re-list the vector search datasets for the newly selected namespace.

    Refreshes the globals ``namespace``, ``datasets`` and ``dataset_list`` and the
    ``dataset_selection`` options. If any dataset exists, the first one is selected and
    its index is loaded into the global ``index``. A URI containing ``"ivf_flat"`` is
    opened as an ``IVFFlatIndex``; anything else is opened as a ``FlatIndex``. The
    loading spinner is shown in ``output_widget`` meanwhile.

    NOTE(review): assigning ``dataset_selection.options``/``value`` may itself trigger
    ``on_dataset_selection_change``, which loads the same index again.
    """
    global namespace
    global datasets
    global dataset_list
    global uri
    global index  # previously missing: the loaded index was only a discarded local

    output_widget.clear_output()
    with output_widget:
        display(loading)
    namespace = change.new

    datasets = list_vector_search_datasets(namespace=namespace)
    dataset_list = group_list_to_display(datasets)

    if dataset_list:
        dataset_selection.options = dataset_list
        dataset_selection.value = next(iter(dataset_list.values()))
        uri = dataset_selection.value
        if "ivf_flat" in uri:
            index = vs.IVFFlatIndex(uri)
        else:
            index = vs.FlatIndex(uri)
    else:
        dataset_selection.options = {}
        dataset_selection.value = None

    output_widget.clear_output()


# Reload the dataset list whenever a different namespace is picked.
namespace_selection.observe(on_namespace_selection_change, names=["value"])


In [18]:
# Dataset picker: option labels are group names, option values their TileDB URIs.
dataset_selection = ipywidgets.Dropdown(
    options=dataset_list,
    description="Dataset",
)

def on_dataset_selection_change(change):
    """Load the vector index for the dataset URI just selected in ``dataset_selection``.

    Empty selections (``None`` or ``""``) are ignored. A URI containing ``"ivf_flat"``
    is opened as an ``IVFFlatIndex``; anything else is opened as a ``FlatIndex``. The
    result goes to the globals ``uri`` and ``index``, with the loading spinner shown
    meanwhile.
    """
    global uri
    global index

    selected = change.new
    if selected is None or selected == "":
        return

    output_widget.clear_output()
    with output_widget:
        display(loading)

    uri = selected
    index_cls = vs.IVFFlatIndex if "ivf_flat" in uri else vs.FlatIndex
    index = index_cls(uri)
    output_widget.clear_output()
    
# Reload the index whenever a different dataset is picked.
dataset_selection.observe(on_dataset_selection_change, names=["value"])

In [19]:
# Test datasets
def accuracy(result, gt):
    """Fraction of returned ids that also occur in the matching ground-truth row.

    ``result[k]`` is compared with ``gt[k]``. Raises ZeroDivisionError if no ids
    were returned at all.
    """
    per_row = [(len(row), len(numpy.intersect1d(row, gt[k]))) for k, row in enumerate(result)]
    returned = sum(count for count, _ in per_row)
    hits = sum(matched for _, matched in per_row)
    return hits / returned

def dataset_accuracy(index, query_embeding, ground_truth=None, top_n=10, nprobe=10):
    """Score ``index`` on one query against a known list of nearest-neighbour ids.

    Args:
        index: vector index passed to ``search``.
        query_embeding: a single query vector; it is wrapped into a 1-row float32 batch.
        ground_truth: list holding one list of expected neighbour ids. Defaults to the
            ids recorded for the demo dataset's query vector 42.
            NOTE(review): the default is only valid for that dataset and query.
        top_n: number of neighbours to retrieve (default 10).
        nprobe: partitions to probe for IVF-flat indexes (default 10).

    Returns:
        ``accuracy`` of the retrieved ids against ``ground_truth``.
    """
    if ground_truth is None:
        ground_truth = [[42, 3440, 1160, 3406, 393, 1233, 2111, 1947, 2436, 972]]
    queries = numpy.array([query_embeding], dtype=numpy.float32)
    results = search(index, queries, top_n=top_n, nprobe=nprobe)
    return accuracy([results], ground_truth)

query_embeding = None
for dataset_name, dataset_uri in dataset_list.items():
    if "ivf_flat" not in dataset_name:
        flat_index = vs.FlatIndex(dataset_uri, config=tiledb.cloud.Config().dict())
        with tiledb.open(flat_index.db_uri) as vectors:
            query_embeding = vectors[:, 42]["values"]
        break

if query_embeding is not None:
    for dataset_name, dataset_uri in dataset_list.items():
        if "ivf_flat" in dataset_name:
            test_index = vs.IVFFlatIndex(dataset_uri, config=tiledb.cloud.Config().dict())
            ac = dataset_accuracy(test_index, query_embeding)
            assert ac >= 0.95, f"{dataset_name}: accuracy {ac} < 0.95"
        else:
            test_index = vs.FlatIndex(dataset_uri, config=tiledb.cloud.Config().dict())
            ac = dataset_accuracy(test_index, query_embeding)
            assert ac == 1.0, f"{dataset_name}: accuracy {ac} != 1.0"
elif dataset_list:
    # Only IVF-flat datasets: there is no flat index to read the query vector from.
    print("Skipping dataset accuracy checks: no non-IVF dataset to read the query vector from.")

In [20]:
# Load initial dataset: select the first listed dataset and open its index so a
# search works before the user touches the dropdowns.
if dataset_list:
    dataset_selection.options = dataset_list

    uri = next(iter(dataset_list.values()))
    dataset_selection.value = uri
    index = (vs.IVFFlatIndex if "ivf_flat" in uri else vs.FlatIndex)(uri)

In [21]:
# --- Configuration accordion (no section expanded initially: selected_index=None) ---
namespace_data_box = ipywidgets.VBox([namespace_selection, dataset_selection])
configuration = ipywidgets.Accordion(
    children=[ipywidgets.VBox([topk_selection, nprobe_selection, namespace_data_box])],
    selected_index=None,
)
configuration.set_title(0, 'Configuration')

# --- Output areas written to by the upload and selection callbacks ---
output_widget = ipywidgets.Output()
thumbnail_search_widget = ipywidgets.Output()

# --- Details of the uploaded file (filled in by on_upload) ---
file_title = ipywidgets.HTML("<p><b>Search File Details:</b></p>")
file_size = ipywidgets.HTML("<p><b></b></p>")
file_name = ipywidgets.HTML("<p><b></b></p>")
file_details = ipywidgets.VBox([file_title, file_name, file_size, thumbnail_search_widget])
details_box = ipywidgets.HBox([file_details])
output_box = ipywidgets.VBox([output_widget])

# --- Sidebar: settings, upload details, upload button ---
left_side_panel = ipywidgets.VBox([configuration, details_box, file_upload])

# TileDB Vector Image Search

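This dashboard searches for visually similar images: upload an image, and the most similar images in the selected vector search dataset are displayed. The **Configuration** panel sets the namespace and dataset, the number of results (TopK) and, for IVF-flat indexes, the number of partitions probed per query (Nprobe).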

In [22]:
# Sidebar on the left, search results in the centre; no right sidebar or footer.
app = ipywidgets.AppLayout(
    left_sidebar=left_side_panel,
    center=output_box,
    right_sidebar=None,
    footer=None,
)

app

AppLayout(children=(VBox(children=(Accordion(children=(VBox(children=(IntSlider(value=3, continuous_update=Fal…