# Visual Search for Retail with Amazon Bedrock Titan Multimodal and Amazon OpenSearch Vector Databases

This notebook will walk you through the process of building visual search functionality using a Large Language Model (LLM) hosted on [Amazon Bedrock](https://aws.amazon.com/bedrock/). We will use an Embeddings Model hosted on Amazon Bedrock to convert product images and descriptions into vectors, and then store and search them in an [Amazon OpenSearch Serverless](https://aws.amazon.com/opensearch-service/features/serverless/) collection.

<div class="alert alert-block alert-info">
<b>Note:</b>
    <ul>
        <li>This notebook is tested on <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/nbi.html">Amazon SageMaker Notebook instance</a> and within <a href="https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks.html">Amazon SageMaker Studio Notebook</a> and within an AWS Region that supports <a href="https://aws.amazon.com/opensearch-service/features/serverless/">Amazon OpenSearch Serverless</a>.</li>
        <li>At the time of writing this notebook, Amazon Bedrock was only available in <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions">these supported AWS Regions</a>. If you are running this notebook from any other AWS Region, then you have to change the Amazon Bedrock client's region and/or endpoint URL parameters to one of those supported AWS Regions. Follow the guidance in the <i>Organize imports</i> section of this notebook.</li>
        <li>This notebook is recommended to be run with a minimum instance size of <i>ml.m5.xlarge</i> and
            <ul>
                <li>With <i>Amazon Linux 2, Jupyter Lab 3</i> as the platform identifier on an Amazon SageMaker Notebook instance.</li>
                <li>(or)</li>
                <li>With <i>Data Science 3.0</i> as the image on an Amazon SageMaker Studio Notebook.</li>
            </ul>
        </li>
        <li>At the time of this writing, the recommended kernel for running this notebook was:
            <ul>
                <li>On an Amazon SageMaker Notebook instance was <i>conda_python3</i></li>
                <li>On an Amazon SageMaker Studio Notebook was <i>Python 3</i></li>
            </ul>
        </li>
    </ul>
</div>

**Table of Contents:**

1. [Pre-requisites](#pre-requisites)

    1.a. [Check and configure access to the Internet](#1.a)

    1.b. [Install required software libraries](#1.b)
    
    1.c. [Configure logging](#1.c)
    
    1.d. [Organize imports](#1.d)
    
    1.e. [Set AWS Region and boto3 config](#1.e)
    
    1.f. [Check and create an Amazon OpenSearch Serverless collection](#1.f)
    
    1.g. [Enable model access in Amazon Bedrock](#1.g)
    
    1.h. [Check and configure security permissions](#1.h)

    1.i. [Create common objects](#1.i)
    
    1.j. [Create an index in the Amazon OpenSearch Serverless collection](#1.j)
    
 2. [Create the OpenSearch Search Functionality](#2)
 
    2.a. [Prepare to load data into the vector database](#3.a)        
        
    2.b. [Create the embeddings](#3.b)
    
    2.c. [Store the embeddings in the vector database](#3.c)
    
    2.d. [Test searching our vector database](#3.d)
    
    2.e. [Build our search functions](#3.e)
  
 3. [Frequently Asked Questions (FAQs)](#FAQs)

##  1. Pre-requisites <a id ='pre-requisites'> </a>

Check and complete the prerequisites.

###  1.a. Check and configure access to the Internet <a id ='1.a'> </a>
This notebook requires outbound access to the Internet to download the required software updates and to download the dataset. You can either provide direct Internet access (default) or provide Internet access through an [Amazon VPC](https://aws.amazon.com/vpc/). For more information, refer to [this documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/appendix-notebook-and-internet-access.html).

### 1.b. Install required software libraries <a id ='1.b'> </a>
This notebook requires the following libraries:
* [SageMaker Python SDK version 2.x](https://sagemaker.readthedocs.io/en/stable/v2.html)
* [Python 3.10.x](https://www.python.org/downloads/release/python-3100/)
* [Boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html)
* [LangChain](https://www.langchain.com/)
* [OpenSearch Python Client](https://pypi.org/project/opensearch-py/)
* [Tqdm](https://pypi.org/project/tqdm/)
* [Backoff](https://pypi.org/project/backoff/)

Run the following cell to install the required libraries.

<div class="alert alert-block alert-warning">  
    <b>Note:</b> At the end of the installation, the Kernel will be forcefully restarted immediately. Please wait 10 seconds for the kernel to come back before running the next cell.
</div>

In [None]:
%pip install -Uq pip
%pip install -Uq boto3
%pip install -q langchain==0.0.339
%pip install -q opensearch-py==2.4.2
%pip install -q tqdm==4.66.1
%pip install -q backoff

import IPython

IPython.Application.instance().kernel.do_shutdown(True)

### 1.c. Configure logging <a id ='1.c'> </a>

####  a. System logs <a id='Configure%20system%20logs'></a>

System logs refer to the logs generated by the notebook's interactions with the underlying notebook instance. Some examples of these are the logs generated when loading or saving the notebook.

These logs are automatically set up when the notebook instance is launched.

These logs can be accessed through the [Amazon CloudWatch Logs](https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/WhatIsCloudWatchLogs.html) console in the same AWS Region where this notebook is running.
* When running this notebook in an Amazon SageMaker Notebook instance, navigate to the following location:
    * <i>CloudWatch > Log groups > /aws/sagemaker/NotebookInstances > {notebook-instance-name}/jupyter.log</i>
* When running this notebook in an Amazon SageMaker Studio Notebook, navigate to the following locations:
    * <i>CloudWatch > Log groups > /aws/sagemaker/studio > {sagemaker-domain-name}/{user-name}/KernelGateway/{notebook-instance-name}</i>
    * <i>CloudWatch > Log groups > /aws/sagemaker/studio > {sagemaker-domain-name}/{user-name}/JupyterServer/default</i>

Run the following cell to print the name of the underlying notebook instance.

In [None]:
import json

notebook_name = ''
resource_metadata_path = '/opt/ml/metadata/resource-metadata.json'
with open(resource_metadata_path, 'r') as metadata:
    notebook_name = (json.load(metadata))['ResourceName']
print("Notebook instance name: '{}'".format(notebook_name))

####  b. Application logs <a id='Configure%20application%20logs'></a>

Application logs refer to the logs generated by running the various code cells in this notebook. To set this up, instantiate the [Python logging service](https://docs.python.org/3/library/logging.html) by running the following cell. You can configure the default log level and format as required.

By default, this notebook will only print the logs to the corresponding cell's output console.

In [None]:
import logging
import os

# Set the logging level and format
log_level = logging.INFO
log_format = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=log_level, format=log_format)
logging.getLogger('sagemaker.config').setLevel(logging.CRITICAL)

# Save these in the environment variables for use in the helper scripts
os.environ['LOG_LEVEL'] = str(log_level)
os.environ['LOG_FORMAT'] = log_format

###  1.d. Organize imports <a id ='1.d'> </a>

Organize all the library and module imports for later use.

In [None]:
import boto3
import botocore
import langchain
import opensearchpy
import requests
import sagemaker
from sagemaker.predictor import Predictor
import sys
from botocore.config import Config
from tqdm.notebook import tqdm
from tqdm.contrib.concurrent import process_map
from IPython.core.display import HTML
from langchain.vectorstores import OpenSearchVectorSearch
from langchain.embeddings import BedrockEmbeddings
from urllib.parse import urlparse
import base64
import backoff
import secrets
from typing import Tuple, Optional
import numpy as np
from opensearchpy.helpers import parallel_bulk
import re
import unicodedata
import concurrent.futures as cf
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from os.path import splitext
from botocore.exceptions import ClientError
import io


# Import the helper functions from the 'scripts' folder
sys.path.append(os.path.join(os.getcwd(), "scripts"))
#logging.info("Updated sys.path: {}".format(sys.path))
from helper_functions import *

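def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    # Log each retry at DEBUG level so it stays hidden at the default INFO log level
    logging.debug(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details)
    )

# Return a cryptographically secure random integer in the inclusive range [a, b]
def secure_randint(a, b):
    return a + secrets.randbelow(b - a + 1)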

Print the installed versions of some of the important libraries.

In [None]:
logging.info("Python version : {}".format(sys.version))
logging.info("Boto3 version : {}".format(boto3.__version__))
logging.info("SageMaker Python SDK version : {}".format(sagemaker.__version__))
logging.info("LangChain version : {}".format(langchain.__version__))
logging.info("OpenSearch Python Client version : {}".format(opensearchpy.__version__))

###  1.e. Set AWS Region and boto3 config <a id ='1.e'> </a>

Get the current AWS Region (where this notebook is running), a boto3 session, and the IAM role attached to this notebook. These will be used to initialize some of the clients to AWS services using the boto3 APIs.

<div class="alert alert-block alert-info">
    <b>Note:</b> All the AWS services used by this notebook except Amazon Bedrock will use the current AWS Region. For Bedrock, follow the guidance in the next cell.
</div>

<div class="alert alert-block alert-warning">  
<b>Note:</b> At the time of writing this notebook, Amazon Bedrock was only available in <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#bedrock-regions">these supported AWS Regions</a>. If you are running this notebook from any other AWS Region, then you have to change the Amazon Bedrock client's region and/or endpoint URL parameters to one of those supported AWS Regions. In order to do this, this notebook will use the value specified in the environment variable named <mark>AMAZON_BEDROCK_REGION</mark>. If this is not specified, then the notebook will default to <mark>us-west-2 (Oregon)</mark> for Amazon Bedrock.
</div>


In [None]:
# Get the AWS Region, boto3 Session and IAM Role references
my_session = boto3.session.Session()
logging.info("boto3 Session: {}".format(my_session))
my_iam_role = sagemaker.get_execution_role()
logging.info("Notebook IAM Role: {}".format(my_iam_role))
my_region = my_session.region_name
logging.info("Current AWS Region: {}".format(my_region))

# Explicitly set the AWS Region for Amazon Bedrock clients;
# fall back to the default when the environment variable is unset or empty
AMAZON_BEDROCK_DEFAULT_REGION = "us-west-2"
br_region = os.environ.get('AMAZON_BEDROCK_REGION') or AMAZON_BEDROCK_DEFAULT_REGION
logging.info("AWS Region for Amazon Bedrock: {}".format(br_region))

Set the timeout and retry configurations that will be applied to all the boto3 clients used in this notebook.

In [None]:
# Increase the standard timeout limits in the boto3 client from 1 minute to 3 minutes
# and set the retry limits
my_boto3_config = Config(
    connect_timeout = (60 * 3),
    read_timeout = (60 * 3),
    retries = {
        'max_attempts': 600,
        'mode': 'adaptive'
    }
)

###  1.f. Check and create an Amazon OpenSearch Serverless collection <a id ='1.f'> </a>

This notebook uses an [Amazon OpenSearch Serverless (AOSS) collection](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-collections.html) of type [Vector search](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-overview.html#serverless-usecase) as the vector database that will be used by the visual search functionality.

Run the following cells to check and create an AOSS collection if it does not exist.

In [None]:
# Set the flags to identify if an AOSS collection exists and if it is created through this notebook
aoss_collection_exists = False
aoss_collection_created = False

<div class="alert alert-block alert-info">
<b>Note:</b> For the purpose of running this notebook, it is preferable to have an empty collection.
</div>

Run the following code cell to retrieve the details of the first available AOSS collection.

In [None]:
# Create the AOSS client
aoss_client = boto3.client("opensearchserverless", config = my_boto3_config)

# Check and create a collection if none is found
collection_id = ''
collections = aoss_client.list_collections()['collectionSummaries']
if len(collections) == 0:
    aoss_collection_exists = False
    logging.info("No AOSS collections exist.")
else:
    aoss_collection_exists = True
    logging.info("Found an AOSS collection.")
    first_collection = collections[0]
    collection_id = first_collection["id"]
    collection_name = first_collection["name"]

<div class="alert alert-block alert-info">
<b>Note:</b> If you would like to create an AOSS collection through this notebook, run the following cell.
</div>

In [None]:
### Note: It may take 8 to 10 minutes to create the AOSS collection.

# The helper function 'create_aoss_collection' (available through ./scripts/helper_functions.py) creates the specified
# AOSS collection with the following policies:
# Data access policy: provides full access to the IAM role associated with this notebook instance.
# Encryption policy: encrypts with AWS owned key.
# Network policy: provides public network access to the collection.
if aoss_collection_exists:
    logging.info("Skipping AOSS collection creation.")
else:
    collection_name = "vs-collection"
    data_access_policy_name = "vs-collection-dap"
    encryption_policy_name = "vs-collection-ep"
    network_policy_name = "vs-collection-np"
    response = create_aoss_collection(aoss_client, collection_name, data_access_policy_name,
                                      encryption_policy_name, network_policy_name, my_iam_role)
    collection_id = response["id"]
    collection_name = response["name"]
    aoss_collection_created = True

Run the following cell to print the details of the AOSS collection that will be used.

In [None]:
if len(collection_id) == 0:
    aoss_collection_exists = False
    logging.info("No AOSS collections exist.")
else:
    aoss_collection_exists = True
    logging.info("The following AOSS collection will be used:\nCollection id: {}; Collection name: {}"
                 .format(collection_id, collection_name))
    # Print the AWS console URL to the AOSS collection
    collection_aws_console_url = "https://{}.console.aws.amazon.com/aos/home?region={}#opensearch/collections/{}"\
    .format(my_region, my_region, collection_name)
    logging.info("If you would like to take a look at this collection, visit {}".format(collection_aws_console_url))

###  1.g. Enable model access in Amazon Bedrock <a id ='1.g'> </a>

<div class="alert alert-block alert-danger">
    <b>Note:</b> Before invoking any model in Amazon Bedrock, enable access to that model by following the instructions <a href="https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html">here</a>. In addition, for Anthropic models, you need to submit the use case details. Otherwise, you will get an authorization error.
</div>

Run the following cell to print the Amazon Bedrock model access page URL for the AWS Region that was selected earlier.

In [None]:
# Print the Amazon Bedrock model access page URL
logging.info("Amazon Bedrock model access page - https://{}.console.aws.amazon.com/bedrock/home?region={}#/modelaccess"
             .format(br_region, br_region))

<div class="alert alert-block alert-warning">  
<b>Note:</b> You will have to do this manually after reading the End User License Agreement (EULA) for each of the models that you want to enable. Unless you explicitly disable it, this is a one-time setup for each model in an AWS account.
</div>

###  1.h. Check and configure security permissions <a id ='1.h'> </a>
This notebook uses the IAM role attached to the underlying notebook instance.

This IAM role should have the following permissions:

1. Access to invoke the foundation models that you are using on Amazon Bedrock.
2. Full access to read and write to the Amazon OpenSearch Serverless collection created in the previous step.
3. Access to write to Amazon CloudWatch Logs.

In addition, [data access control](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-data-access.html) should be set up on the Amazon OpenSearch Serverless collection to provide create, read and write access to the IAM role associated with this notebook instance.

Run the following cell to print the details of the IAM role attached to the underlying notebook instance.

In [None]:
# Print the IAM role ARN and console URL
logging.info("This notebook's IAM role is '{}'".format(my_iam_role))
arn_parts = my_iam_role.split('/')
logging.info("Details of this IAM role are available at https://{}.console.aws.amazon.com/iamv2/home?region={}#/roles/details/{}?section=permissions"
             .format(my_region, my_region, arn_parts[len(arn_parts) - 1]))

###  1.i. Create common objects <a id='1.i'></a>

To begin with, list all the available models in Amazon Bedrock by running the following cell. This will help you pick an LLM and the Embeddings model within Amazon Bedrock that you will be using in this notebook. By default, both will use the On-Demand pricing model.

In [None]:
# List all the available foundation models in Amazon Bedrock
bedrock_client = boto3.client("bedrock", region_name = br_region, endpoint_url = "https://bedrock.{}.amazonaws.com"
                              .format(br_region), config = my_boto3_config)
response = bedrock_client.list_foundation_models()
model_summaries = response["modelSummaries"]

# Format the model list as a fixed-width table, one model per line
row_format = "{:<15} {:<30} {:<20} {:<20} {:<40}"
separator = "-" * 125
models_info = "\n" + separator + "\n"
models_info += row_format.format("Provider Name", "Model Name", "Input Modalities",
                                 "Output Modalities", "Model Id") + "\n"
models_info += separator
for model_summary in model_summaries:
    models_info += "\n" + row_format.format(model_summary["providerName"],
                                            model_summary["modelName"],
                                            "|".join(model_summary["inputModalities"]),
                                            "|".join(model_summary["outputModalities"]),
                                            model_summary["modelId"])
models_info += "\n" + separator + "\n"
logging.info("Displaying available models in the '{}' Region:".format(br_region) + models_info)
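To narrow the listing above to embedding models only, you can filter the model summaries on their output modalities. The sketch below does this locally on the `modelSummaries` list; the function name `embedding_model_ids` and the `example.text-model-v1` entry are hypothetical. The Amazon Bedrock `list_foundation_models` call also accepts a `byOutputModality` filter (for example, `byOutputModality='EMBEDDING'`) if you prefer to filter on the service side.

```python
from typing import Dict, List


def embedding_model_ids(model_summaries: List[Dict]) -> List[str]:
    """Return the IDs of models whose output modalities include EMBEDDING.

    `model_summaries` is expected to have the shape of the "modelSummaries"
    list returned by the Amazon Bedrock list_foundation_models call.
    """
    return [
        summary["modelId"]
        for summary in model_summaries
        if "EMBEDDING" in summary.get("outputModalities", [])
    ]


# Hand-written list standing in for the real API response (second entry is hypothetical)
sample_summaries = [
    {"modelId": "amazon.titan-embed-image-v1", "outputModalities": ["EMBEDDING"]},
    {"modelId": "example.text-model-v1", "outputModalities": ["TEXT"]},
]
print(embedding_model_ids(sample_summaries))
```

In the notebook, you would pass the real list instead: `embedding_model_ids(model_summaries)`.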

From the results of running the above cell:

1. Pick the model ID that corresponds to the Embeddings model that you want and set it as the value of the `embeddings_model_id` variable in the following cell.
2. (Optional) If you also use an LLM, pick the model ID that corresponds to it and store it in an `llm_model_id` variable.
3. (Optional) Specify the [LLM-specific inference parameters](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html) for that LLM in a `model_kwargs` parameter.

Now, run the following cell to create the common objects to be used in future steps in this notebook.

<div class="alert alert-block alert-info">
<b>Note:</b> This notebook was tested with the following Amazon Bedrock models:
    <ul><li>Embedding model(s): amazon.titan-embed-image-v1</li></ul>
</div>

In [None]:
# Model-id of the Embeddings model to be used to generate embeddings
embeddings_model_id = "amazon.titan-embed-image-v1"

##### LLM related objects
# Function to initialise the Amazon Bedrock runtime client
def initialise_bedrock():
    return boto3.client("bedrock-runtime", region_name = br_region, config = my_boto3_config)

##### Embeddings related objects
# Create the Amazon Bedrock runtime client
bedrock = initialise_bedrock()

# Function to create the Embeddings client using the LangChain BedrockEmbeddings class
def create_embeddings_client():
    return BedrockEmbeddings(client = bedrock, model_id = embeddings_model_id, region_name = br_region)

br_embeddings = create_embeddings_client()


#Function to initialise the Amazon S3 Client
def initialise_s3():
    return boto3.resource('s3')

s3 = initialise_s3()
sage_sess = sagemaker.Session()
default_bucket = sage_sess.default_bucket()

##### Amazon OpenSearch Serverless (AOSS) related objects
# Create the AOSS Python client using the helper function 'auth_opensearch'
# (available through ./scripts/helper_functions.py)
aoss_py_client = auth_opensearch(host = "{}.{}.aoss.amazonaws.com".format(collection_id, my_region),
                                 service = 'aoss', region = my_region)
# Specify the name of the index in the AOSS collection; this will be created later in the notebook
index_name = "product-embeddings-index"

###  1.j. Create an index in the Amazon OpenSearch Serverless collection <a id='1.j'></a>

To create an index in the Amazon OpenSearch Serverless (AOSS) collection, we first need to define a schema for our index. AOSS supports both a simple search index, which uses keyword matching, and vector search, which uses [k-Nearest Neighbor (k-NN) search](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/knn.html). Instead of a keyword or fuzzy matching algorithm, vector search compares the [embeddings](https://en.wikipedia.org/wiki/Word_embedding) of two pieces of content. An embedding is a numerical representation of a piece of information, like text or an image, that we can compare against other embeddings. To learn more about embeddings, take a look at [this blog](https://huggingface.co/blog/getting-started-with-embeddings). In this notebook, each product's image and description are combined into a single 1024-dimensional multimodal embedding, so vector search lets us find products that are semantically similar to a search query rather than only those that share its keywords.
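Vector search ranks results by a similarity measure between embeddings, and the index schema defined below uses cosine similarity (`cosinesimil`). A minimal NumPy sketch of that measure, using made-up 3-dimensional vectors in place of real 1024-dimensional Titan embeddings (the product names are purely illustrative):

```python
import numpy as np


def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two vectors: 1.0 means same direction, 0.0 means orthogonal."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Made-up 3-dimensional vectors standing in for real 1024-dimensional embeddings
red_dress = [0.9, 0.1, 0.0]
crimson_gown = [0.8, 0.2, 0.1]
hiking_boot = [0.0, 0.2, 0.9]

print(round(cosine_similarity(red_dress, crimson_gown), 3))  # similar items score close to 1
print(round(cosine_similarity(red_dress, hiking_boot), 3))   # unrelated items score close to 0
```

Because cosine similarity only looks at the angle between vectors, scaling an embedding up or down does not change its score.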

In [None]:
# Define the schema for the index with a k-NN type vector as the embedding
hnsw_index_body = {
    "settings": {
        "index": {
            "knn": True,
            "knn.algo_param.ef_search": 512
        }
    },
    "mappings": {
        "properties": {
            "product_image_and_description_embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "engine": "nmslib",
                    "space_type": "cosinesimil",
                    "parameters": {
                        "ef_construction": 512,
                        "m": 16
                    }
                }
            },
            "prodId": {
                "type": "text"
            }
        }
    }
}

# Create the index if it does not exist
if aoss_py_client.indices.exists(index = index_name):
    logging.info("AOSS index '{}' already exists.".format(index_name))
else:
    logging.info("Creating AOSS index '{}'...".format(index_name))
    logging.info(aoss_py_client.indices.create(index = index_name, body = hnsw_index_body, ignore = 400))
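Once the index holds embeddings, it can be queried with an OpenSearch `knn` query clause against the `product_image_and_description_embedding` field defined above. A minimal sketch of building such a request body; the helper name `build_knn_query` and its defaults are hypothetical, and the query vector must have the same 1024 dimensions as the mapping.

```python
from typing import Dict, Sequence


def build_knn_query(query_vector: Sequence[float], k: int = 5,
                    field: str = "product_image_and_description_embedding") -> Dict:
    """Build an OpenSearch k-NN search body for the vector field in the index schema."""
    return {
        "size": k,
        # Return only the product ID, not the stored 1024-float vector
        "_source": ["prodId"],
        "query": {
            "knn": {
                field: {
                    "vector": [float(x) for x in query_vector],
                    "k": k,
                }
            }
        },
    }


print(build_knn_query([0.1, 0.2, 0.3], k=3))
```

In the notebook, a body like this could be sent with `aoss_py_client.search(index=index_name, body=build_knn_query(query_embedding, k=5))`, where `query_embedding` is an embedding produced the same way as the indexed ones; the matches appear under `response["hits"]["hits"]`, each with a `_score` and `_source`.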

In [None]:
# Print the AWS console URL to the AOSS index
index_aws_console_url = collection_aws_console_url + "/" + index_name
logging.info("If you would like to take a look at this index, visit {}".format(index_aws_console_url))

In [None]:
# Read the product items from product_items.json
with open('product_items.json', 'r') as json_file:
    product_items = json.load(json_file)

In [None]:
# Function to create a Claude prompt that describes an image, using the Anthropic Claude Messages API format
def generate_claude_search_prompt(image_data):
    system = (
        "You are a merchandiser expert in writing product descriptions. Your task is to generate a "
        "detailed description of input images for effective visual search. Your "
        "description should be engaging, informative, and detailed enough to facilitate "
        "effective multimodal embedding for "
        "similarity search. Aim for clarity and precision to ensure the generated "
        "embeddings capture the essence of the product accurately."
    )
    
    prompt = (
        "Given this input image, generate a detailed and comprehensive description "
        "that includes the following aspects:\n\n"
        "1. Physical Description: Describe the physical attributes of the product, "
        "including color, shape, size, material, and any distinctive features or "
        "design elements. Mention textures, patterns, or any aesthetic "
        "characteristics that stand out.\n\n"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": image_data 
                    }
                }
            ]
        }
    ]
    
    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1000,
        "system": system,
        "messages": messages
    }
    
    return body


## 2. Create the OpenSearch Search Functionality <a id ='2'> </a>

### 2.a Prepare to load data into the vector database <a id='3.a'></a>

An Amazon OpenSearch Serverless (AOSS) collection is a logical grouping of one or more indexes that work together to support a specific workload or use case.

This notebook will use a vector index for indexing documents in the AOSS collection.

####  Initialize the text splitter <a id='Initialize%20the%20text%20splitter'></a>

When we generate embeddings for text descriptions, passing an entire description to the embeddings model at once can overwhelm it, especially for very long descriptions. A best practice is to divide the text into smaller, partially overlapping chunks that are easier to consume.

Let's use LangChain's [RecursiveCharacterTextSplitter](https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter) to create a text splitting object that we will use to split the content.

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 750,
    chunk_overlap  = 100,
    length_function = len,
    is_separator_regex = False,
)
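`RecursiveCharacterTextSplitter` first tries to split on separators such as paragraph breaks, newlines and spaces, so its chunks do not line up with fixed character offsets. The fixed-width splitter below is a hypothetical, simplified sketch and not LangChain's algorithm; it only illustrates how `chunk_size` and `chunk_overlap` interact: each new chunk starts `chunk_size - chunk_overlap` characters after the previous one.

```python
from typing import List


def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split text into fixed-size character windows that overlap by chunk_overlap characters."""
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("require chunk_size > 0 and 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window reaches the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks


print(sliding_window_chunks("abcdefghij", chunk_size=4, chunk_overlap=2))
# → ['abcd', 'cdef', 'efgh', 'ghij']
```

With the settings used above (`chunk_size = 750`, `chunk_overlap = 100`), consecutive windows in this simplified model would start 650 characters apart.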

In [None]:
# function to get product id
def get_id(product_data: dict) -> str:
    return product_data.get('pid', '')

# function to get urls of product images
def get_image_url(product_data: dict) -> str:
    return product_data.get('image_url', '')

# function to get augmented description
def get_augmented_description(product_data: dict) -> str:
    return product_data.get('product_description', '')

# function to get relevant values for embeddings from products
def process_product_item(product_data: dict) -> Tuple:
    return get_id(product_data), get_image_url(product_data), get_augmented_description(product_data)

In [None]:
def create_product_hash_map(product_items):
    """
    Create a hash map mapping the 'id' of each product item to the product item itself.

    Parameters:
    product_items (list): A list of dictionaries, each representing a product item.

    Returns:
    dict: A hash map where each key is an 'id' from the product items and each value is the corresponding product item.
    """
    hash_map = {}
    for item in product_items:
        product_id = item.get('pid')
        if product_id:
            hash_map[product_id] = item
    return hash_map

product_items_map = create_product_hash_map(product_items)

Now we have all our products, with their respective images hosted on Amazon S3, indexed by product ID.

###  2.b. Create the embeddings <a id='3.b'></a>

Let us set up our functions to generate embeddings by calling the Amazon Bedrock API.

In [None]:
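def fetch_image(url: str, s3):
    """Download an image from Amazon S3, shrink it to fit within 1092x1092 pixels and base64-encode it.

    Returns (base64_string, True) on success and (None, False) on any failure.
    """
    try:
        # Parse inside the try block so a malformed URL is skipped instead of crashing
        bucket_name, key = url.replace('s3://', '').split('/', 1)
        obj = s3.Object(bucket_name, key)
        image_data = obj.get()['Body'].read()
        image = Image.open(BytesIO(image_data))
        # Resize in place, preserving the aspect ratio
        image.thumbnail((1092, 1092))

        # Re-encode in the original format (JPEG if unknown) and convert to base64
        buffered = BytesIO()
        image_format = image.format if image.format else 'JPEG'
        image.save(buffered, format=image_format)
        return base64.b64encode(buffered.getvalue()).decode(), True
    except Exception as e:
        # Skip missing, unreadable or malformed images, but record the reason
        logging.warning("Could not fetch image '{}': {}".format(url, e))
        return None, False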
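`fetch_image` splits the S3 URI with plain string operations and does not check that the value actually is an `s3://bucket/key` URI. A minimal sketch of a stricter parser that keeps the same bucket/key split but rejects anything else; the name `parse_s3_uri` is hypothetical.

```python
from typing import Tuple


def parse_s3_uri(uri: str) -> Tuple[str, str]:
    """Split an s3://bucket/key URI into (bucket, key), raising ValueError otherwise."""
    prefix = "s3://"
    if not uri.startswith(prefix):
        raise ValueError("Not an S3 URI: {!r}".format(uri))
    # Everything before the first "/" is the bucket; the rest (slashes included) is the key
    bucket, separator, key = uri[len(prefix):].partition("/")
    if not bucket or not separator or not key:
        raise ValueError("S3 URI must include both a bucket and a key: {!r}".format(uri))
    return bucket, key


print(parse_s3_uri("s3://example-bucket/images/product-001.jpg"))
# → ('example-bucket', 'images/product-001.jpg')
```

Using `str.partition` rather than `urllib.parse.urlparse` keeps characters such as `?` and `#` inside the object key instead of treating them as query or fragment delimiters.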

In [None]:
# utility function to generate text embedding or image embedding in isolation, if you just have text or image as input.
'''def get_image_or_text_embedding(image_data:Optional = None, text_chunks:Optional = None, bedrock=None):
    modelId = "amazon.titan-embed-image-v1"
    contentType = "application/json"
    accept = "application/json"
    
    if not text_chunks:
        body = json.dumps({"inputImage": image_data})
        response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept,contentType=contentType)
        return json.loads(response.get('body').read())['embedding']
    
    chunk_embeddings = []
    
    
    for chunk in text_chunks:
        body = json.dumps({"inputText": chunk})
        response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType)
        chunk_embedding = json.loads(response.get('body').read())['embedding']
        chunk_embeddings.append(chunk_embedding)
        
    return aggregate_embeddings(chunk_embeddings)'''

# Function to generate a multimodal embedding using both text and image.
# It is recommended to use the multimodal embedding for better search results
def get_multi_embedding(image_data, text_chunks, bedrock):
    modelId = "amazon.titan-embed-image-v1"
    contentType = "application/json"
    accept = "application/json"

    # One request per text chunk, each paired with the same image;
    # fall back to an image-only request when there is no text
    request_bodies = [{"inputText": chunk, "inputImage": image_data} for chunk in text_chunks]
    if not request_bodies:
        request_bodies = [{"inputImage": image_data}]

    embeddings = []
    for request_body in request_bodies:
        response = bedrock.invoke_model(body=json.dumps(request_body), modelId=modelId,
                                        accept=accept, contentType=contentType)
        chunk_embedding = json.loads(response.get('body').read())['embedding']
        embeddings.append(chunk_embedding)

    # Average the per-chunk embeddings into a single vector
    return aggregate_embeddings(embeddings)

def normalise_embedding(embedding):
    return embedding / np.linalg.norm(embedding)

def aggregate_embeddings(embeddings):
    averaged_embedding = [sum(x)/len(embeddings) for x in zip(*embeddings)]
    return averaged_embedding
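`aggregate_embeddings` averages the per-chunk vectors element-wise in pure Python, and `normalise_embedding` then rescales the result to unit length. A minimal NumPy sketch of both steps in one hypothetical function, `mean_pool_and_normalise`, with guards for the empty-input and zero-vector cases that the helpers above do not handle:

```python
import numpy as np


def mean_pool_and_normalise(chunk_embeddings):
    """Average per-chunk embeddings element-wise, then scale the result to unit L2 norm."""
    matrix = np.asarray(chunk_embeddings, dtype=float)  # shape: (num_chunks, dim)
    if matrix.ndim != 2 or matrix.shape[0] == 0:
        raise ValueError("expected a non-empty list of equal-length embeddings")
    mean_vector = matrix.mean(axis=0)
    norm = np.linalg.norm(mean_vector)
    if norm == 0:
        raise ValueError("cannot normalise a zero vector")
    return mean_vector / norm


pooled = mean_pool_and_normalise([[1.0, 0.0], [0.0, 1.0]])
print(pooled, np.linalg.norm(pooled))
```

Because the index uses the `cosinesimil` space, which compares only the angle between vectors, rescaling does not change how results rank; unit length becomes important if you switch to a length-sensitive space such as inner product.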

Let's also create a function to sanitise our text descriptions.

In [None]:
# function to sanitise text descriptions
def clean_and_prepare_text(text):
    # Remove special characters (anything that is not a word character or whitespace)
    text = re.sub(r'[^\w\s]', '', text)

    # Unicode normalization
    text = unicodedata.normalize('NFKD', text)

    # Lowercasing
    text = text.lower()

    # Collapse runs of whitespace (including gaps left by removed characters) to a single space
    text = re.sub(r'\s+', ' ', text).strip()

    return text

<div class="alert alert-block alert-warning">  
<b>Note:</b> The next three cells are optional. <br/>
Now that we have our helper functions to compute embeddings, let us test them on a single product and inspect the outputs.
</div>

In [None]:
item_id, image_url, product_description = process_product_item(product_items[0])
image_response = fetch_image(image_url,s3)
image_content = image_response[0]
image_response[1]

In [None]:
process_product_item(product_items[0])

In [None]:
multi_embedding = np.array(get_multi_embedding(image_content,text_splitter.split_text(clean_and_prepare_text(product_description)),bedrock))
multi_embedding

Now, for each product, we generate a multimodal embedding that combines its image and text description. We then wrap each embedding in a load action that we can use to bulk load the embeddings into our OpenSearch vector index.

In [None]:
# function to generate load actions from product items
def process_records_to_actions_partial(items_chunk):
    # Each worker thread creates its own Bedrock/S3 clients and text splitter
    bedrock = initialise_bedrock()
    s3 = initialise_s3()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = 1400,
        chunk_overlap  = 600,
        length_function = len,
        is_separator_regex = False,
        )

    local_actions = []

    for item in items_chunk:
        product_id, image_url, description = process_product_item(item)
        image_response = fetch_image(image_url, s3)
        # Skip products whose image could not be fetched
        if not image_response[1]:
            continue
        try:
            multi_embedding = normalise_embedding(np.array(get_multi_embedding(image_data=image_response[0],
                                                                               text_chunks=text_splitter.split_text(description),
                                                                               bedrock=bedrock)))
            load_action = {
                "_index": "product-embeddings-index",
                "_source": {
                    "prodId": product_id,
                    "product_image_and_description_embedding": multi_embedding.tolist()
                }
            }
            local_actions.append(load_action)
        except Exception as e:
            # Report and skip products whose embedding request fails, instead of failing silently
            print(f"Failed to embed product {product_id}: {e}")

    return local_actions
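The splitter settings above (`chunk_size = 1400`, `chunk_overlap = 600`) mean consecutive chunks can share up to 600 characters. The sketch below applies the same parameters to a fixed-size character window. This is a simplification and not how `RecursiveCharacterTextSplitter` works internally, since that splitter prefers to break on separators such as paragraphs and spaces. It does show how overlap increases the number of chunks, and therefore the number of Bedrock calls per product, because `get_multi_embedding` makes one call per chunk. The function name `sliding_window_chunks` is hypothetical.

```python
def sliding_window_chunks(text, chunk_size=1400, chunk_overlap=600):
    """Split text into fixed-size character windows; each window starts chunk_size - chunk_overlap after the previous one."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    stride = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), stride):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window reaches the end of the text
        if start + chunk_size >= len(text):
            break
    return chunks

description = "x" * 3000  # a hypothetical 3,000-character product description
chunks = sliding_window_chunks(description)
print(len(chunks))               # 3 windows: [0, 1400), [800, 2200), [1600, 3000)
print([len(c) for c in chunks])  # [1400, 1400, 1400]
```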



Now that we have a function to create load actions, let's generate these actions for every product. We parallelise this step across worker threads to speed it up.

In [None]:
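import math

def split_data(data, num_splits):
    # Ceiling division: never produce more than num_splits chunks,
    # and never loop forever when len(data) < num_splits
    chunk_size = max(1, math.ceil(len(data) / num_splits))
    return [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]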


In [None]:
# Function to handle the progress update
def update_progress(future, pbar):
    pbar.update(1)

<div class="alert alert-block alert-info">
    <b>Note:</b> After you execute the cell below, it may take a few minutes to complete (timing depends on the data volume). 
</div>

In [None]:
NUM_PROCESSES = 2  # number of worker threads
items_chunks = split_data(product_items, NUM_PROCESSES)
main_pbar = tqdm(total=len(items_chunks), desc='Overall Progress')

# Using ThreadPoolExecutor to process the chunks concurrently
with cf.ThreadPoolExecutor(max_workers=NUM_PROCESSES) as executor:
    # Submit one task per chunk; keep the mapping so the results can be collected later
    future_to_chunk = {executor.submit(process_records_to_actions_partial, chunk): chunk for chunk in items_chunks}
    # as_completed yields each future once it has finished, so update the progress bar directly
    for future in cf.as_completed(future_to_chunk):
        update_progress(future, main_pbar)

# Ensure the main progress bar is closed
main_pbar.close();
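Each Bedrock and S3 request spends most of its time waiting on the network, so threads speed this step up even under Python's GIL. The pattern in the cell above is: split the items into chunks, process each chunk in a worker, then merge the per-chunk lists. The sketch below reproduces that pattern without any AWS calls; `fake_embed_chunk` is a hypothetical stand-in for `process_records_to_actions_partial`. It uses `executor.map` instead of `as_completed`. `map` returns results in submission order, which keeps the merged output deterministic.

```python
import concurrent.futures as cf
import time

def fake_embed_chunk(items):
    """Hypothetical stand-in for process_records_to_actions_partial: simulates I/O latency per item."""
    actions = []
    for item in items:
        time.sleep(0.01)  # pretend this is a network call to Bedrock/S3
        actions.append({"_index": "product-embeddings-index", "_source": {"prodId": item}})
    return actions

items = [f"prod-{i}" for i in range(10)]
chunks = [items[0:5], items[5:10]]

with cf.ThreadPoolExecutor(max_workers=len(chunks)) as executor:
    # map() yields each chunk's result in the same order the chunks were submitted
    per_chunk_actions = list(executor.map(fake_embed_chunk, chunks))

# Flatten the per-chunk lists into one list of bulk actions
all_actions = [action for chunk_actions in per_chunk_actions for action in chunk_actions]
print(len(all_actions))  # 10
```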

### 2.c Store the embeddings in the vector database <a id='3.c'></a>

Run the following cells to upload the prepared documents into the index of the AOSS collection we created. The upload uses opensearch-py's `parallel_bulk` helper, and the number of parallel worker threads is controlled by its `thread_count` argument.

In [None]:
# Collect the load actions produced by all worker threads into a single list
global_actions = []

for future in future_to_chunk:
    local_actions = future.result()

    global_actions.extend(local_actions)

In [None]:
len(global_actions)

<div class="alert alert-block alert-warning">  
    <b>Note:</b> At the time of writing this notebook, AOSS did not support ingestion with <i>id</i> for <i>Vector search</i> collection type. As a result, running the following cell more than once will result in duplicate documents being created in the AOSS index. This is ok for the purpose of running this notebook.
</div>

In [None]:
# Load documents into our index using opensearch-py's parallel_bulk
for success, info in parallel_bulk(client=aoss_py_client,actions=global_actions, request_timeout=60*60, thread_count = 8):
    if not success:
        print('A document failed:', info)
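`parallel_bulk` accepts any iterable of actions. Instead of first materialising `global_actions` in memory, you could stream actions from a generator. The minimal sketch below assumes the embeddings are available as `(product_id, vector)` pairs. The generator name `iter_load_actions` is hypothetical. It only produces action dictionaries in the same shape used above and does not call OpenSearch.

```python
def iter_load_actions(embedded_products, index_name="product-embeddings-index"):
    """Yield one bulk 'index' action per (product_id, vector) pair, in the same shape as global_actions."""
    for product_id, vector in embedded_products:
        yield {
            "_index": index_name,
            "_source": {
                "prodId": product_id,
                "product_image_and_description_embedding": list(vector),
            },
        }

# Toy input: two products with 3-dimensional vectors
pairs = [("p1", [0.6, 0.8, 0.0]), ("p2", [0.0, 0.6, 0.8])]
actions = iter_load_actions(pairs)  # nothing is built until the generator is consumed
first = next(actions)
print(first["_source"]["prodId"])  # p1
```

Such a generator could be passed as `actions=` to the same `parallel_bulk(...)` call. Note that `parallel_bulk` is itself lazy: its results must be iterated, as the loop above does, before any documents are sent.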

### 2.d Test searching our vector database <a id='3.d'></a>


Now let's test a search on our `product_image_and_description_embedding` field with the following image:

In [None]:
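def resize_image(img_file, size=(1092, 1092)):
    # Downscale in place (preserving aspect ratio) so the image fits within 1092 x 1092 px
    with Image.open(img_file) as img:
        img.thumbnail(size)
        # JPEG has no alpha channel, so convert e.g. RGBA or palette PNGs to RGB before saving
        if img.mode != "RGB":
            img = img.convert("RGB")
        buffer = BytesIO()
        img.save(buffer, format="JPEG")
        return buffer.getvalue()

def encode_image(img_file):
    # Resize, then base64-encode the JPEG bytes into a string for the Bedrock request body
    resized_image = resize_image(img_file)
    return base64.b64encode(resized_image).decode("ascii")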

In [None]:
def get_base64_image(image_path:str)->str:
    # S3 locations are fetched with fetch_image; anything else is treated as a local file
    if 's3' in image_path:
        s3 = initialise_s3()
        return fetch_image(image_path,s3)[0]
    return encode_image(image_path)

In [None]:
@backoff.on_exception(
        backoff.expo,
        (ClientError),
        max_time=secure_randint(500, 1000),
        on_backoff=backoff_hdlr,
        giveup=lambda e: 'ThrottlingException' not in str(e) and 'ReadTimeoutError' not in str(e) and 'ModelTimeoutException' not in str(e)
    ) 
def get_search_description(base64_image:str, bedrock=None)->str:
    modelId = "anthropic.claude-3-sonnet-20240229-v1:0"
    contentType = "application/json"
    accept = "application/json"
    
    body = json.dumps(generate_claude_search_prompt(base64_image))
    response = bedrock.invoke_model(body=body, modelId=modelId, accept=accept,contentType=contentType)
    return json.loads(response.get('body').read())['content'][0]['text']
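`generate_claude_search_prompt` is defined earlier in the notebook and is not shown here. For reference, a request body for Claude 3 models on Bedrock follows the Anthropic Messages format, with the image passed as a base64 content block. `get_search_description` reads the text from the response's `content[0]['text']`. A hypothetical builder could look like the sketch below. The prompt wording, the `max_tokens` value, and the `image/jpeg` media type are assumptions, not the notebook's actual prompt. The media type matches what `resize_image` produces for local files.

```python
import json

def build_claude_image_prompt(base64_image, media_type="image/jpeg", max_tokens=512):
    """Hypothetical helper: a Bedrock Messages API body asking Claude to describe a product image."""
    return {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {"type": "base64", "media_type": media_type, "data": base64_image},
                    },
                    {
                        "type": "text",
                        "text": "Describe the product in this image in one short paragraph, "
                                "covering category, colour, material and style.",
                    },
                ],
            }
        ],
    }

body = json.dumps(build_claude_image_prompt("aGVsbG8="))  # placeholder base64 data, not a real image
print(json.loads(body)["messages"][0]["content"][0]["type"])  # image
```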

In [None]:
# replace /images/test_image.jpeg with your test image path 
query_image_path = "/images/test_image.jpeg"
get_search_description(get_base64_image(image_path=query_image_path),bedrock)

In [None]:
k=30
base64_image = get_base64_image(image_path=query_image_path)
# Embed the query image together with a Claude-generated description, normalised like the indexed vectors
query_vector = normalise_embedding(np.array(get_multi_embedding(image_data=base64_image,
                                                                text_chunks=text_splitter.split_text(clean_and_prepare_text(get_search_description(base64_image,bedrock))),
                                                                bedrock=bedrock))).tolist()
search_query = {
    "size": k,
    "_source": ["prodId"],  # include only the product ID field in the response
    "query": {
        "knn": {
            "product_image_and_description_embedding": {
                "vector": query_vector,
                "k": k,
            },
        }
    }
}

aoss_py_client.search(index="product-embeddings-index", body=search_query)

To visualise our results, we first retrieve the relevant information from the search hits, namely the product IDs and their scores:

In [None]:
results = [(item['_source']['prodId'],item['_score']) for item in aoss_py_client.search(index="product-embeddings-index", body=search_query)["hits"]["hits"]]
results

In [None]:
def aggregate_and_sort_scores(results):
    score_dict = {}
    for prod_id, score in results:
        if prod_id in score_dict:
            score_dict[prod_id] += score
        else:
            score_dict[prod_id] = score

    sorted_prod_ids = sorted(score_dict, key=score_dict.get, reverse=True)

    return sorted_prod_ids
sorted_prod_ids = aggregate_and_sort_scores(results)
print(sorted_prod_ids)
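Summing scores has a side effect. A product indexed more than once (for example, after re-running the ingestion cell, as the AOSS note above warns) gets its score counted once per copy, which can push it up the ranking. If that is a concern, one hypothetical alternative is to keep each product's best score instead:

```python
def max_and_sort_scores(results):
    """Keep the highest score seen for each product ID, then return IDs sorted by that score (descending)."""
    best = {}
    for prod_id, score in results:
        if prod_id not in best or score > best[prod_id]:
            best[prod_id] = score
    return sorted(best, key=best.get, reverse=True)

# Toy hits: "a" appears twice (a duplicate document), "b" once with a higher score
toy_results = [("a", 0.60), ("a", 0.60), ("b", 0.90)]
print(max_and_sort_scores(toy_results))  # ['b', 'a']
# Summing instead would score "a" at 1.20 and rank it above "b"
```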

In [None]:
def display_images_from_s3_with_fetch_image(products, s3):
    for i, result in enumerate(products, 2):
        image_data_base64, success = fetch_image(result[1], s3)
        if not success:
            print(f"Failed to fetch image for ID: {result[0]}")
            continue

        # Decode the base64 image data
        image_data = base64.b64decode(image_data_base64)
        img = Image.open(BytesIO(image_data))

        # Display the image
        plt.subplot(1, len(products) + 1, i)
        plt.imshow(img)
        plt.title(f"ID: {result[0][:8]}")
        plt.axis('off')  # Optional: to hide axes for a cleaner display

    plt.show()

In [None]:
products = []
for product_id in sorted_prod_ids[:5]:
    item = product_items_map.get(product_id)
    images=item['image_url']
    products.append((item['product_name'],images))
    
query_image = Image.open(query_image_path,)  
plt.figure(figsize=(20, 10))
plt.subplot(1, len(products) + 1, 1)
plt.imshow(query_image)
plt.title("Query")
plt.axis('off')
        
# Display Search Results
for i, result in enumerate(products, 2):
    image_data_base64, success = fetch_image(result[1], s3)
    if not success:
        print(f"Failed to fetch image for ID: {result[0]}")
        continue
        
    # Decode the base64 image data
    image_data = base64.b64decode(image_data_base64)
    img = Image.open(BytesIO(image_data))
    
    # Display the image
    plt.subplot(1, len(products) + 1, i)
    plt.imshow(img)
    plt.title(f"ID: {result[0]}")
    plt.axis('off')  

plt.show()

Now that our product embeddings are in our OpenSearch index, let's create some functions that wrap the steps above so we can search the index with any query image.

In [None]:
def search_opensearch_indexes(query_image_url:Optional[str]=None, k:Optional[int]=5):
    """
    Search for products similar to a query image, using the multimodal embedding
    of the image and a Claude-generated description of it.

    Parameters:
    - query_image_url (str): The S3 location or local file path of the query image.
    - k (int, optional): The number of nearest neighbours to return. Default is 5.

    Returns:
    - list: The raw OpenSearch hits for the top k similar items; each hit's '_source' contains 'prodId'.
    """

    def create_search_query(embedding, k):
        # k-NN query on the multimodal vector field, wrapped in a weighted bool/should clause
        return {
            "size": k,
            "query": {
                "bool": {
                    "should": [
                        {
                            "function_score": {
                                "query": {
                                    "knn": {
                                        "product_image_and_description_embedding": {
                                            "vector": embedding,
                                            "k": k,
                                        },
                                    }
                                },
                                "weight": 1,
                            }
                        }
                    ]
                }
            }
        }

    # Describe the query image with Claude, then embed the image and description together
    base64_image = get_base64_image(query_image_url)
    search_description = clean_and_prepare_text(get_search_description(base64_image, bedrock))
    query_multi_embedding_array = normalise_embedding(np.array(get_multi_embedding(image_data=base64_image,
                                                                                   text_chunks=text_splitter.split_text(search_description),
                                                                                   bedrock=bedrock)))
    query_multi_embedding = query_multi_embedding_array.tolist()
    response = aoss_py_client.search(index="product-embeddings-index", body=create_search_query(query_multi_embedding, k))
    return response['hits']['hits']
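The `bool`/`should` wrapper with a `function_score` `weight` in `create_search_query` could be extended to combine several vector fields, each with its own weight. Treat the sketch below as hypothetical only. The index built in this notebook has a single vector field, so the second field name (`product_description_embedding`) is an assumption and would need its own mapping and embeddings.

```python
def create_weighted_knn_query(field_vectors, k=5):
    """
    Build a bool/should query with one weighted k-NN clause per vector field.
    field_vectors maps field name -> (query_vector, weight).
    """
    should_clauses = []
    for field, (vector, weight) in field_vectors.items():
        should_clauses.append({
            "function_score": {
                "query": {"knn": {field: {"vector": vector, "k": k}}},
                "weight": weight,
            }
        })
    return {"size": k, "query": {"bool": {"should": should_clauses}}}

query = create_weighted_knn_query({
    "product_image_and_description_embedding": ([0.6, 0.8, 0.0], 1.0),
    "product_description_embedding": ([0.0, 0.6, 0.8], 0.5),  # hypothetical second field
}, k=10)
print(len(query["query"]["bool"]["should"]))  # 2
```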
    

### 2.e Build our search functions <a id='3.e'></a>

In [None]:

def visual_search(query_image_url:Optional[str]=None, k:Optional[int]=5):
    """
    Search for products similar to a query image and plot the query above the top results.

    Parameters:
    - query_image_url (str): The S3 location or local file path of the query image.
    - k (int, optional): The number of top products to display. Default is 5.
    """
    def open_image(image_url):
        # Images stored in the notebook's SageMaker S3 bucket are fetched from S3
        if 'sagemaker' in image_url:
            image_data_base64, success = fetch_image(image_url, s3)
            if not success:
                print(f"Failed to fetch image: {image_url}")
                return None
            image_data = base64.b64decode(image_data_base64)
            return Image.open(BytesIO(image_data))
        # Anything else is treated as a local file path
        return Image.open(image_url)

    # Retrieve 30 neighbours, aggregate scores per product, and keep the top k products
    search_results = search_opensearch_indexes(query_image_url=query_image_url, k=30)
    results = [(item['_source']['prodId'], item['_score']) for item in search_results]
    sortedlist = aggregate_and_sort_scores(results)
    products = []
    for product_id in sortedlist[:k]:
        item = product_items_map.get(product_id)
        products.append((item['product_name'], item['image_url']))

    total_rows = 2

    # Top row: the query image
    query_image = open_image(query_image_url)
    plt.figure(figsize=(20, 10))
    plt.subplot(total_rows, 1, 1)
    plt.imshow(query_image)
    plt.title("Query")
    plt.axis('off')

    # Bottom row: the search results
    for i, result in enumerate(products, 1):
        img = open_image(result[1])
        if img is None:
            continue
        plt.subplot(total_rows, len(products), len(products) + i)
        plt.imshow(img)
        plt.title(f"ID: {result[0]}")
        plt.axis('off')

    plt.show()

In [None]:
# Test using a sample query image. Replace /query_image.jpeg with your actual image path
query_image_path = "/query_image.jpeg"
visual_search(query_image_url=query_image_path)

In [None]:
# Test with a list of query images. Replace /query_image_1.jpeg and /query_image_2.jpeg with your test images
paths = [
    "/query_image_1.jpeg",
    "/query_image_2.jpeg",
]

In [None]:
for path in paths:
    visual_search(query_image_url=path)

## 3. Frequently Asked Questions (FAQs) <a id='FAQs'></a>

**Q: What AWS services are used in this notebook?**

Amazon Bedrock, Amazon OpenSearch Serverless, AWS Identity and Access Management (IAM), Amazon CloudWatch, and Amazon SageMaker Notebook instance (or) Amazon SageMaker Studio Notebook depending on what you use to run the notebook.

**Q: What is the difference between OpenSearch, Amazon OpenSearch Serverless, and Amazon OpenSearch Service?**

OpenSearch is a fully open-source search and analytics engine for use cases such as log analytics, real-time application monitoring, and clickstream analysis. For more information, see the [OpenSearch documentation](https://opensearch.org/docs/latest/).

Amazon OpenSearch Service provisions all the resources for your OpenSearch cluster and launches it. It also automatically detects and replaces failed OpenSearch Service nodes, reducing the overhead associated with self-managed infrastructures. You can scale your cluster with a single API call or a few clicks in the console.

Amazon OpenSearch Serverless is an on-demand serverless configuration for Amazon OpenSearch Service. Serverless removes the operational complexities of provisioning, configuring, and tuning your OpenSearch clusters. It's a good option for organizations that don't want to self-manage their OpenSearch clusters, or organizations that don't have the dedicated resources or expertise to operate large clusters. With OpenSearch Serverless, you can easily search and analyze a large volume of data without having to worry about the underlying infrastructure and data management.

**Q: How does Amazon OpenSearch Serverless manage capacity?**

With Amazon OpenSearch Serverless, you don't have to manage capacity yourself. OpenSearch Serverless automatically scales compute capacity for your account based on the current workload. Serverless compute capacity is measured in OpenSearch Compute Units (OCUs). Each OCU is a combination of 6 GiB of memory and corresponding virtual CPU (vCPU), as well as data transfer to Amazon S3. For more information about the decoupled architecture in OpenSearch Serverless, see [How it works](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-overview.html#serverless-process).

**Q: Will Amazon Bedrock capture and store my data?**

Amazon Bedrock doesn't use your prompts and continuations to train any AWS models or distribute them to third parties. Your training data isn't used to train the base Amazon Titan models or distributed to third parties. Other usage data, such as usage timestamps, logged account IDs, and other information logged by the service, is also not used to train the models.

Amazon Bedrock uses the fine tuning data you provide only for fine tuning an Amazon Titan model. Amazon Bedrock doesn't use fine tuning data for any other purpose, such as training base foundation models.

Each model provider has an escrow account that they upload their models to. The Amazon Bedrock inference account has permissions to call these models, but the escrow accounts themselves don't have outbound permissions to Amazon Bedrock accounts. Additionally, model providers don't have access to Amazon Bedrock logs or access to customer prompts and continuations.

Amazon Bedrock doesn’t store or log your data in its service logs.

**Q: What models are supported by Amazon Bedrock?**

Go [here](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html#models-supported).

**Q: What is the difference between On-demand and Provisioned Throughput in Amazon Bedrock?**

With the On-Demand mode, you only pay for what you use, with no time-based term commitments. For text generation models, you are charged for every input token processed and every output token generated. For embeddings models, you are charged for every input token processed. A token consists of a few characters and is the basic unit of text that a model uses to interpret user input and prompts and to generate results. For image generation models, you are charged for every image generated.
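Here is a back-of-the-envelope illustration of on-demand charges for the indexing step. `get_multi_embedding` sends one request per text chunk, and each request carries the product image plus that chunk's text. The sketch assumes the embeddings model is billed per 1,000 input tokens plus per input image. The prices are placeholder parameters, not real Amazon Bedrock rates, so check the pricing page before relying on the numbers. The function name `estimate_embedding_cost` is hypothetical.

```python
def estimate_embedding_cost(chunks_per_product, tokens_per_chunk, price_per_1k_tokens, price_per_image):
    """
    Rough on-demand cost estimate for indexing, assuming every request
    (one per text chunk) is billed for its input tokens and for one input image.
    """
    total_requests = sum(chunks_per_product)
    token_cost = total_requests * tokens_per_chunk / 1000 * price_per_1k_tokens
    image_cost = total_requests * price_per_image
    return token_cost + image_cost

# Placeholder numbers only: 3 products split into 1, 2 and 1 chunks, ~100 tokens per chunk
cost = estimate_embedding_cost([1, 2, 1], tokens_per_chunk=100,
                               price_per_1k_tokens=0.001, price_per_image=0.0001)
print(round(cost, 6))  # 0.0008
```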

With the Provisioned Throughput mode, you can purchase model units for a specific base or custom model. The Provisioned Throughput mode is primarily designed for large consistent inference workloads that need guaranteed throughput. Custom models can only be accessed using Provisioned Throughput. A model unit provides a certain throughput, which is measured by the maximum number of input or output tokens processed per minute. With this Provisioned Throughput pricing, charged by the hour, you have the flexibility to choose between 1-month and 6-month commitment terms.

**Q: Where can I find customer references for Amazon Bedrock?**

Go [here](https://aws.amazon.com/bedrock/testimonials/).

**Q: Where can I find pricing information for the AWS services used in this notebook?**

- Amazon Bedrock pricing - go [here](https://aws.amazon.com/bedrock/pricing/).
- Amazon OpenSearch Serverless pricing - go [here](https://aws.amazon.com/opensearch-service/pricing/) and navigate to the <i>Serverless</i> section.
- AWS Identity and Access Management (IAM) pricing - free.
- Amazon CloudWatch pricing - go [here](https://aws.amazon.com/cloudwatch/pricing/).
- Amazon SageMaker Notebook instance (or) Amazon SageMaker Studio Notebook pricing - go [here](https://aws.amazon.com/sagemaker/pricing/).