In [1]:
# Import libraries
from dhiret.common.utils import load_model, build_annoy_index
from pathlib import Path
import json
import csv

# Model identifiers; model_name and model_version select which precomputed
# index / name-list files are loaded from the embeddings folder below.
model_name = "clip"
model_version = "ViT-bigG-14"
# OpenCLIP pretrained tag for this model; not used in the file names built here.
clip_dataset_and_epoch = "laion2b_s39b_b160k"
# Passed to build_annoy_index; presumably the dimensionality of the stored
# vectors (must match how the index was built) — TODO confirm in dhiret utils.
embedding_size = 1280

# Load the Annoy index and the list whose position i holds the image path of
# index item i (later cells rely on this to map paths <-> item ids).
embeddings_folder = Path("embeddings")
index_file_path = embeddings_folder / f"{model_name}_{model_version}_index.ann"
image_name_list_file_path = embeddings_folder / f"{model_name}_{model_version}_image_name_list.json"
index = build_annoy_index(embedding_size, index_file_path)
# JSON is UTF-8 by specification; don't depend on the platform default
# encoding (e.g. cp1252 on Windows would mis-decode non-ASCII paths).
with open(image_name_list_file_path, "r", encoding="utf-8") as f:
    image_name_list = json.load(f)

### CLIP ViT-bigG-14
Uses the `laion2b_s39b_b160k` pretrained weights from [OpenCLIP](https://github.com/mlfoundations/open_clip).

- mAP, mean of L1 and L2, i.e. (L1 + L2) / 2: 55.9
- mAP L1 (primary instance): 54.2
- mAP L2 (secondary category): 57.6

In [2]:
##### import ipywidgets as widgets
from IPython.display import display, clear_output
import ipywidgets as widgets
from ipywidgets import IntSlider, Label, HTML
from pathlib import Path
import PIL

def open_image(image_path, width=150, resampling_method=None):
    """Load an image as RGB and scale it to a fixed width, keeping its aspect ratio.

    Args:
        image_path: Path (str or ``pathlib.Path``) of the image file to open.
        width: Target width in pixels of the returned image.
        resampling_method: Pillow resampling filter passed to ``Image.resize``.
            ``None`` (the default) selects ``Image.Resampling.LANCZOS``.

    Returns:
        A new RGB ``PIL.Image.Image`` that is ``width`` pixels wide. Its height
        is scaled from the original aspect ratio and is at least 1 pixel.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image``, so referencing it in a default argument only worked when
    # another import had already loaded it.
    from PIL import Image

    if resampling_method is None:
        resampling_method = Image.Resampling.LANCZOS

    # Close the source file deterministically; convert() returns an
    # independent, fully loaded image that outlives the ``with`` block.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    original_width, original_height = img.size
    aspect_ratio = original_height / original_width
    # Extremely wide images would otherwise truncate to height 0, which
    # ``resize`` rejects.
    new_height = max(1, int(width * aspect_ratio))
    return img.resize((width, new_height), resampling_method)

# Custom Button Class
class PathedButton(widgets.Button):
    """A ``widgets.Button`` that also remembers an image file path.

    ``on_click`` callbacks receive the button instance, so they can read
    ``button.image_path`` to learn which image the button belongs to.
    """

    def __init__(self, image_path, *args, **kwargs):
        # Let the Button base class consume its own arguments (description,
        # layout, ...) first, then attach the path as a plain attribute.
        super(PathedButton, self).__init__(*args, **kwargs)
        self.image_path = image_path

def get_all_images(image_folder, image_extensions=('.jpg', '.jpeg', '.png', '.tif')):
    """Recursively collect image files below ``image_folder``.

    Args:
        image_folder: Root folder (str or ``pathlib.Path``) searched recursively.
        image_extensions: Accepted file suffixes, including the leading dot.
            Matching is case-insensitive.

    Returns:
        Sorted list of matching file paths as strings. Sorting makes the search
        results and preview order deterministic, whereas raw ``rglob`` order
        depends on the filesystem.
    """
    allowed = {ext.lower() for ext in image_extensions}
    return sorted(
        str(file_path)
        for file_path in Path(image_folder).rglob('*')
        # is_file() skips directories that merely have an image-like name.
        if file_path.is_file() and file_path.suffix.lower() in allowed
    )

def on_preview_button_click(button):
    """Show the image chosen in the dropdown and make it the current selection.

    Args:
        button: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    image_path = image_selector.value
    with selected_image_output:
        clear_output()
        # The dropdown has no value until a search has produced matches;
        # opening ``None`` would raise inside the callback.
        if image_path is None:
            display(widgets.Label("No image chosen in the dropdown. Run a search first."))
            return
        # "Query Selected Image" reads the selection from this attribute (the
        # preview tiles' Select buttons set it too). Without this, the image
        # shown here and the image actually queried could differ.
        on_preview_image_click.selected_image_path = image_path
        display(open_image(image_path))

def on_preview_image_click(button):
    """Handle a preview tile's "Select" click.

    Records the tile's image path as the current selection and shows that image
    in ``selected_image_output``.

    Args:
        button: The clicked ``PathedButton``; its ``image_path`` is used.
    """
    chosen_path = button.image_path
    # The selection is kept on this function object so the query cell can read
    # it later without a separate global variable.
    setattr(on_preview_image_click, "selected_image_path", chosen_path)

    with selected_image_output:
        clear_output()
        thumbnail = open_image(chosen_path)
        display(thumbnail)


def _build_preview_tile(image_path):
    """Return one preview tile: the thumbnail stacked above its "Select" button.

    Putting each button in the same VBox as its image keeps them aligned in the
    grid, whatever the number of matches.
    """
    img = open_image(image_path)
    img_widget = widgets.Image(value=img._repr_png_(), width=img.width, height=img.height)
    select_button = PathedButton(image_path=image_path, description="Select")
    select_button.on_click(on_preview_image_click)
    return widgets.VBox([img_widget, select_button])


def on_search_button_click(button, max_previews=5):
    """Filter ``all_image_paths`` by the search-box text and show previews.

    Matching is a case-insensitive substring test against the full path. All
    matches populate ``image_selector``. The first ``max_previews`` are
    rendered as thumbnail tiles, each with its own "Select" button.

    Args:
        button: The clicked ``widgets.Button`` (unused; required by ``on_click``).
        max_previews: Maximum number of preview tiles to render.
    """
    # Use the stripped text for matching as well, so stray leading/trailing
    # spaces don't silently prevent matches.
    search_query = search_box.value.strip()
    if not search_query:
        # Empty the dropdown too, so it doesn't keep offering stale matches.
        image_selector.options = []
        with previews_output:
            clear_output()
            display(widgets.Label("Please enter a search query."))
        num_found_images.value = ""
        return

    needle = search_query.lower()
    matching_images = [image_path for image_path in all_image_paths if needle in image_path.lower()]
    image_selector.options = matching_images
    num_found_images.value = f"Number of found images: {len(matching_images)}"

    with previews_output:
        clear_output()
        # Tiles are built inside the Output context so errors while loading a
        # thumbnail are shown here rather than lost in the callback log.
        tiles = [_build_preview_tile(image_path) for image_path in matching_images[:max_previews]]
        display(widgets.GridBox(tiles, layout=widgets.Layout(grid_template_columns="repeat(5, 1fr)")))

# Root folder scanned (recursively) for images the user can search.
image_folder = "data"
all_image_paths = get_all_images(image_folder)

# --- Widgets: search controls and preview area ----------------------------
search_box = widgets.Text(description='Search:')
search_button = widgets.Button(description="Search Images")
num_found_images = widgets.Label(value="")
previews_output = widgets.Output()

# --- Widgets: dropdown selection and its preview ---------------------------
image_selector = widgets.Dropdown(description='Select image:')
preview_button = widgets.Button(description="Preview Image")
selected_image_output = widgets.Output()

# --- Callbacks --------------------------------------------------------------
search_button.on_click(on_search_button_click)
preview_button.on_click(on_preview_button_click)

# Render everything top to bottom in a single display call.
display(
    search_box,
    search_button,
    num_found_images,
    previews_output,
    image_selector,
    preview_button,
    selected_image_output,
)

# Not referenced by the other cells shown here; kept in case other cells use it.
found_images = []

Text(value='', description='Search:')

Button(description='Search Images', style=ButtonStyle())

Label(value='')

Output()

Dropdown(description='Select image:', options=(), value=None)

Button(description='Preview Image', style=ButtonStyle())

Output()

In [3]:
def on_query_selected_image_click(button, num_nearest_neighbors=10):
    """Query the Annoy index with the selected image and display its neighbours.

    The selected path is read from ``on_preview_image_click.selected_image_path``.
    Its position in ``image_name_list`` is used as the Annoy item id, and the
    vector stored for that item is the query. The ``(path, distance)`` results
    and the query path are cached on this function (``last_retrieval_results``,
    ``query_image_path``) so ``save_retrieval_results_to_csv`` can export them.

    Args:
        button: The clicked ``widgets.Button`` (unused; required by ``on_click``).
        num_nearest_neighbors: Number of neighbours requested from the index.
    """
    # Status messages and any exception are rendered inside the Output widget:
    # print() from a widget callback is not shown in the notebook on some front
    # ends (e.g. JupyterLab).
    with retrieved_images_output:
        clear_output()

        image_path = getattr(on_preview_image_click, "selected_image_path", None)
        if image_path is None:
            print("No image selected.")
            return

        # Linear scan; raises ValueError if this path was never embedded.
        try:
            image_index = image_name_list.index(image_path)
        except ValueError:
            print(f"Image path {image_path} not found in the image_name_list")
            return

        embedding = index.get_item_vector(image_index)
        # With include_distances=True the index returns (item_ids, distances).
        neighbor_ids, distances = index.get_nns_by_vector(
            embedding, num_nearest_neighbors, include_distances=True
        )
        nearest_neighbors_paths_distances = [
            (image_name_list[i], d) for i, d in zip(neighbor_ids, distances)
        ]

        on_query_selected_image_click.last_retrieval_results = nearest_neighbors_paths_distances
        on_query_selected_image_click.query_image_path = image_path

        retrieved_images = []
        for neighbor_path, distance in nearest_neighbors_paths_distances:
            img = open_image(neighbor_path)
            img_widget = widgets.Image(value=img._repr_png_(), width=img.width, height=img.height)
            img_label = widgets.Textarea(
                value=neighbor_path,
                layout=widgets.Layout(width='150px', height='50px', overflow_y='scroll'),
                disabled=True,
            )
            distance_label = widgets.Label(value=f"Distance: {distance:.4f}", layout=widgets.Layout(width='150px'))
            retrieved_images.append(widgets.VBox([img_widget, img_label, distance_label]))

        display(widgets.GridBox(retrieved_images, layout=widgets.Layout(grid_template_columns="repeat(5, 1fr)")))
        
def save_retrieval_results_to_csv(button, results_dir="results"):
    """Export the most recent retrieval results to a CSV file.

    The file is ``<results_dir>/<query path with "/" replaced by "_">.csv`` and
    has the columns ``file_path`` and ``distance``. Results are read from the
    attributes that ``on_query_selected_image_click`` stores on itself.

    Args:
        button: The clicked ``widgets.Button`` (unused; required by ``on_click``).
        results_dir: Output folder (str or ``pathlib.Path``); created if missing.
    """
    # Report through the retrieval Output widget: print() from a widget
    # callback is not shown in the notebook on some front ends (e.g. JupyterLab).
    with retrieved_images_output:
        if not hasattr(on_query_selected_image_click, "last_retrieval_results"):
            print("No retrieval results to save.")
            return

        retrieval_results = on_query_selected_image_click.last_retrieval_results
        query_image_path = Path(on_query_selected_image_click.query_image_path)
        file_name = query_image_path.as_posix().replace("/", "_") + ".csv"

        output_dir = Path(results_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / file_name

        # newline="" is what the csv module requires. Explicit UTF-8 writes
        # non-ASCII image paths identically on every platform.
        with output_file.open("w", newline="", encoding="utf-8") as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(["file_path", "distance"])
            csv_writer.writerows(retrieval_results)

        print(f"Saved retrieval results to {output_file}")

# Widgets for the retrieval step: query trigger, result grid, CSV export.
query_selected_image_button = widgets.Button(description="Query Selected Image")
save_results_button = widgets.Button(description="Save Retrieval Results")
retrieved_images_output = widgets.Output()

# Wire each button to its handler.
query_selected_image_button.on_click(on_query_selected_image_click)
save_results_button.on_click(save_retrieval_results_to_csv)

# Query button first, then the result grid, then the export button beneath it.
display(query_selected_image_button, retrieved_images_output, save_results_button)

Button(description='Query Selected Image', style=ButtonStyle())

Output()

Button(description='Save Retrieval Results', style=ButtonStyle())