~~~
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~
<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-health/cxr-foundation/blob/master/notebooks/retrieve_images_by_text.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fgoogle-health%2Fcxr-foundation%2Fmaster%2Fnotebooks%2Fretrieve_images_by_text.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-health/cxr-foundation/blob/master/notebooks/retrieve_images_by_text.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/google/cxr-foundation">
      <img alt="Hugging Face logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on Hugging Face
    </a>
  </td>
</tr></tbody></table>


# Text-Based Image Retrieval with Chest X-Ray Embeddings

This notebook demonstrates how to use pre-computed embeddings from chest X-ray
images for text-based image retrieval. It showcases:

- Loading 2737 pre-computed embeddings and labels derived from a subset of the NIH Chest X-ray14 dataset.
- Performing a text-based search to find matching images using ELIXR-B embeddings.

The embeddings are the *elixr_img_contrastive* embeddings: text-aligned image embeddings taken from the Q-Former output in ELIXR (https://arxiv.org/abs/2308.01317), which can be used for image retrieval.

**NOTE:**  To streamline this Colab demonstration and eliminate the need for lengthy downloads, we've precomputed the embeddings, which are considerably smaller in size (similar to compressed images). You can learn how to generate embeddings using the other [notebooks](https://colab.research.google.com/github/google-health/cxr-foundation/blob/master/notebooks/).

In [None]:
# @title Download precomputed embeddings and labels
# Downloads three archives into the working directory: per-image embeddings,
# per-query text embeddings, and image thumbnails keyed by image ID.
# `-nc` (wget --no-clobber) means re-running this cell skips files that are
# already present instead of downloading them again.
!wget -nc https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/precomputed_image_embeddings.npz https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/precomputed_text_embeddings.npz https://storage.googleapis.com/healthai-us/encoded-data/nih/radiology/cxr/thumbnails_id_to_webp.npz

In [None]:
# @title Data Preparation and Similarity Functions
import numpy as np
import pandas as pd

# Load image embeddings: one row per image ID holding that image's embedding
# array. Using the NpzFile as a context manager guarantees the archive is
# closed even if reading one of its members raises.
with np.load("precomputed_image_embeddings.npz") as embeddings_file:
    image_embeddings_df = pd.DataFrame(
        [(key, embeddings_file[key]) for key in embeddings_file.keys()],
        columns=['image_id', 'embeddings']
    )

# Load text embeddings, keyed by query string. This archive is deliberately
# left open: the query cell reads individual entries from it on each search.
text_embeddings = np.load("precomputed_text_embeddings.npz")
text_embeddings_queries = list(text_embeddings.keys())

# Load image thumbnails, keyed by image ID (also left open for lookups later).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code from a downloaded file. Thumbnails are consumed via
# `.tobytes()`, which suggests plain byte arrays that would not need pickling
# — confirm the archive's dtype and drop allow_pickle if so.
thumbnails = np.load("thumbnails_id_to_webp.npz", allow_pickle=True)

def restructure_embeddings_for_search(df, num_sub_vectors=32, sub_vector_dim=128):
    """Flattens per-image embeddings into one row per sub-vector, with pre-computed norms.

    Each entry of ``df['embeddings']`` is reshaped to
    ``(num_sub_vectors, sub_vector_dim)`` (``(32, 128)`` by default). The
    result has ``num_sub_vectors`` consecutive rows per image, in input order,
    with columns:

    - ``image_id``: the source row's ``image_id``.
    - ``sub_vector``: a 1-D array of length ``sub_vector_dim``.
    - ``norm``: the L2 norm of ``sub_vector``, precomputed so that cosine
      similarity does not have to recompute it for every query.

    Args:
      df: DataFrame with ``image_id`` and ``embeddings`` columns.
      num_sub_vectors: Number of sub-vectors each embedding is split into.
      sub_vector_dim: Length of each sub-vector.

    Returns:
      The flattened DataFrame. When ``df`` has no rows, an empty DataFrame
      with the three columns above.

    Raises:
      ValueError: If an embedding does not have exactly
        ``num_sub_vectors * sub_vector_dim`` elements (raised by ``np.reshape``).
    """
    columns = ['image_id', 'sub_vector', 'norm']
    if len(df) == 0:
        return pd.DataFrame(columns=columns)

    # Stack every embedding once instead of iterating rows and building one
    # dict per sub-vector; shape is (num_images, num_sub_vectors, sub_vector_dim).
    stacked = np.stack([
        np.reshape(emb, (num_sub_vectors, sub_vector_dim))
        for emb in df['embeddings']
    ])
    norms = np.linalg.norm(stacked, axis=2)
    # Row-major flattening keeps each image's sub-vectors contiguous and in
    # order, matching the np.repeat of the image IDs below.
    flat_vectors = stacked.reshape(-1, sub_vector_dim)

    return pd.DataFrame({
        'image_id': np.repeat(df['image_id'].to_numpy(), num_sub_vectors),
        'sub_vector': list(flat_vectors),
        'norm': norms.ravel(),
    }, columns=columns)


# Flatten the image embeddings once, at load time; every query below is then
# scored against this same flattened frame.
preprocessed_image_embeddings_df = restructure_embeddings_for_search(image_embeddings_df)

def find_top_5_similarities_flattened(df_embeddings, txt_emb):
    """Retrieves the top 5 most similar images to the given text embedding.

    Cosine similarity is computed between ``txt_emb`` and every row's
    ``sub_vector`` (using the precomputed ``norm`` column) in a single
    matrix-vector product. The result is then reduced to the best-scoring
    sub-vector per ``image_id``.

    Args:
      df_embeddings: Flattened DataFrame with ``image_id``, ``sub_vector`` and
        ``norm`` columns, as produced by ``restructure_embeddings_for_search``.
        It is not modified.
      txt_emb: 1-D text embedding with the same length as each sub-vector.

    Returns:
      DataFrame with ``image_id`` and ``similarity`` columns and at most 5
      rows, sorted by descending similarity.
    """
    if len(df_embeddings) == 0:
        return pd.DataFrame(columns=['image_id', 'similarity'])

    txt_emb = np.asarray(txt_emb)
    # Loop-invariant: the text norm is the same for every sub-vector.
    txt_norm = np.linalg.norm(txt_emb)
    sub_vectors = np.stack(df_embeddings['sub_vector'].to_numpy())
    similarities = (sub_vectors @ txt_emb) / (df_embeddings['norm'].to_numpy() * txt_norm)

    # Score into a separate frame rather than adding a column to
    # df_embeddings, which is shared across every query in the notebook.
    scored = pd.DataFrame({
        'image_id': df_embeddings['image_id'].to_numpy(),
        'similarity': similarities,
    })

    # An image's score is the max over its sub-vectors.
    max_similarities = scored.groupby('image_id')['similarity'].max().reset_index()

    return max_similarities.sort_values(by='similarity', ascending=False).head(5)

# Image Retrieval Demo

In this demo we take a text query (chosen from a list of precomputed queries), fetch its ELIXR-B text embedding, and use it to measure the similarity between the text embedding and each image embedding.
These embeddings align textual and visual representations, which makes them well suited to cross-modal retrieval tasks. We display up to
5 top-matching images from the subset of 2737 images whose embeddings were precomputed.

Please note that the search database is small and covers several diseases.

In [None]:
# @title Perform Query
from ipywidgets import widgets, Layout
from IPython.display import display, clear_output
import numpy as np
from google.colab import output
# Colab output setting; per its name, this turns off the vertical scroll area
# for this cell's output so the result images are shown in full.
output.no_vertical_scroll()

# Button shown next to the query box once a query has been entered; it resets
# the search (see clear_results).
clear_button = widgets.Button(description="Clear")
def clear_results(button):
  """Click handler for the Clear button: wipe the cell output, then empty the query box.

  Emptying a non-empty query box also fires the box's ``value`` observer,
  which redraws the bare query box.
  """
  del button  # Supplied by Button.on_click; not needed here.
  clear_output()  # Drop the previously displayed results.
  text_input.value = ''

# Clicking Clear wipes the results and resets the query box.
clear_button.on_click(clear_results)

# Create the text input widget with auto-complete.
# Only the precomputed queries can be searched, because each one needs a
# stored text embedding (looked up by query string when a search runs).
text_input = widgets.Combobox(
    placeholder='Type query...',
    description='Query',
    options=text_embeddings_queries,
    ensure_option=True  # Ensures that the typed value is in the options
)

display(text_input)


def on_text_change(change):
  """Redraws the query box and shows the top-5 matching images for the query.

  Registered as an observer on ``text_input``'s ``value`` trait. ``change``
  is the ipywidgets change dict; only ``'change'`` events on the ``'value'``
  trait are handled.
  """
  if change['type'] != 'change' or change['name'] != 'value':
    return

  clear_output(wait=True)
  selected_query = change['new']
  # Offer the Clear button only once something has been entered.
  if selected_query:
    display(widgets.HBox([text_input, clear_button]))
  else:
    display(text_input)

  # If the typed text matches exactly one known query, auto-complete to it.
  filtered_options = [
      option for option in text_embeddings_queries if selected_query in option
  ]
  if len(filtered_options) == 1 and filtered_options[0] != selected_query:
    # Assigning the value re-invokes this observer synchronously with the
    # completed query, and that call renders the results. Stop here so they
    # are not rendered a second time.
    text_input.value = filtered_options[0]
    return

  if selected_query not in text_embeddings_queries:
    return

  print(f"Selected query: {selected_query}")
  # Retrieve the text embedding vector using selected_query as the key.
  out = find_top_5_similarities_flattened(
      preprocessed_image_embeddings_df, text_embeddings[selected_query])

  for _, row in out.iterrows():
    image_id = str(row['image_id'])

    # Thumbnail for this image; presumably an array of encoded WebP bytes,
    # given the .tobytes() / format='webp' below — confirm against the archive.
    image_bytes = thumbnails[image_id]
    image_widget = widgets.Image(value=image_bytes.tobytes(), format='webp')

    # Show the image with its score and ID side by side.
    hbox = widgets.HBox([
        image_widget,
        widgets.VBox([
            widgets.Label(value=f"Similarity Score: {row['similarity']:.4f}"),
            widgets.Label(value=f"Image ID: {row['image_id']}"),
            ], layout=Layout(margin='0px 0px 0px 0px'))])

    display(hbox)


# Run on_text_change whenever the query box's value changes.
text_input.observe(on_text_change, names='value')

# Next steps

Explore the other [notebooks](https://github.com/google-health/cxr-foundation/blob/master/notebooks) to learn what else you can do with the model.