# Amazon Bedrock Multimodal Workshop
## Content Search -- Indexing and search
In this notebook we are going to populate the vector database and perform search with text and images. 

We will also do comparisons between different embedding sizes. 

### Install and import needed libraries
For this notebook to run correctly, we will need to install, import and initialize the necessary libraries and clients. 

In [None]:
!pip install -q opensearch-py
!pip install -q requests_aws4auth

In [None]:
import os
import io
import json
import time
import boto3
import base64
import datetime
import pandas as pd
from PIL import Image
import concurrent.futures
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from aoss_utils import createEncryptionPolicy, createNetworkPolicy, createAccessPolicy, createCollection, waitForCollectionCreation
# Control-plane client for Amazon OpenSearch Serverless.
# NOTE(review): this name is not referenced again in the visible notebook; a second
# client (`client`) is created in the collection section.
opensearch_client = boto3.client('opensearchserverless')
# Bedrock runtime client used by the embedding helpers to call invoke_model.
bedrock_runtime = boto3.client(service_name="bedrock-runtime")

In [None]:
# Embedding dimension requested from the model; also used as the dimension of the
# index's knn_vector field and as the suffix of the index name.
# Define output vector size – 1,024 (default), 384, 256
outputEmbeddingLength = 1024

In [None]:
# Look up the ARN of the identity running this notebook via STS; it is later passed
# to createAccessPolicy.
session = boto3.Session()
identity_arn = session.client('sts').get_caller_identity()['Arn']
# NOTE(review): the label says "IAM Role", but STS returns the caller's ARN whatever its type.
print("Current IAM Role ARN:", identity_arn)

### Define dataset
We are going to be using the curated dataset from the previous notebook.

In [None]:
# JSON file holding the product records; read by create_dataset_list when populating an index.
dataset_file = "curated_dataset.json"

### Create embeddings from image

In [None]:
def create_embeddings_from_image(image_path, outputEmbeddingLength):
    """Request an embedding for an image file from the Titan multimodal embedding model.

    The image is shrunk (aspect ratio preserved) when its pixel count exceeds
    ``max_width * max_height``, re-encoded in its original format, base64-encoded
    and sent to ``amazon.titan-embed-image-v1`` through the module-level
    ``bedrock_runtime`` client.

    Args:
        image_path: Path of the image file to embed.
        outputEmbeddingLength: Embedding dimension to request from the model.

    Returns:
        dict: The parsed JSON response body. Callers in this notebook read the
        vector from its ``"embedding"`` key.
    """
    max_width = 2048
    max_height = 2048

    with Image.open(image_path) as img:
        # Remember the source format so the re-encode below keeps it.
        source_format = img.format
        width, height = img.size

        if width * height > max_width * max_height:
            # thumbnail() shrinks the image in place; Pillow takes (width, height).
            img.thumbnail((max_width, max_height))

        # Serialize the (possibly shrunk) image into memory.
        buffer = io.BytesIO()
        img.save(buffer, format=source_format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf8')

    request_payload = {
        "inputImage": encoded_image,
        "embeddingConfig": {"outputEmbeddingLength": outputEmbeddingLength},
    }

    response = bedrock_runtime.invoke_model(
        body=json.dumps(request_payload),
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )

    # The response body is a stream; read it fully before decoding the JSON.
    raw_body = response['body'].read()
    return json.loads(raw_body.decode('utf8'))

### Create a vector database using Amazon OpenSearch Serverless

#### Create an Amazon OpenSearch Serverless Collection

In [None]:
# Control-plane client used to create (and later delete) the collection and its policies.
client = boto3.client('opensearchserverless')
service = 'aoss'
# Sign requests for the region the client actually targets. A hard-coded region makes
# SigV4 authentication fail whenever the session is configured for a different region.
region = client.meta.region_name
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)
collection_name = "retail-collection-1"

In [None]:
# Provision the serverless collection. The helpers come from the local aoss_utils module,
# which is not shown here; the order below (policies first, then the collection) is kept as written.
createEncryptionPolicy(client, collection_name)
createNetworkPolicy(client, collection_name)
# identity_arn is the caller ARN looked up earlier — presumably the principal that is
# granted data access; confirm in aoss_utils.
createAccessPolicy(client, collection_name, identity_arn)
createCollection(client, collection_name)
# `host` is used to build the OpenSearch client; `collection_id` is used by the clean-up cell.
host, collection_id = waitForCollectionCreation(client, collection_name)

#### Create a Collection Index

In [None]:
# Build the SigV4 signer for requests to the collection endpoint.
# Take the region from the control-plane client so signing matches the region the
# collection was created in, instead of assuming us-east-1.
region = client.meta.region_name
service = 'aoss'
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)

In [None]:
# OpenSearch client for the collection endpoint (`host`) returned by
# waitForCollectionCreation; every request is signed with `awsauth`.
# All index, bulk and search calls in this notebook go through this client.
OSSclient = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    # NOTE(review): presumably a timeout in seconds — confirm against opensearch-py.
    timeout=300
)

In [None]:
def create_index(index, outputEmbeddingLength):
    """Create a k-NN enabled index for the product records unless it already exists.

    Args:
        index: Name of the index to create.
        outputEmbeddingLength: Dimension of the ``vector_field`` knn_vector mapping.

    When the index already exists nothing is created and nothing is printed;
    otherwise the creation response is printed.
    """
    if OSSclient.indices.exists(index):
        return

    # All metadata fields share the same plain-text mapping.
    text_fields = ("id", "name", "color", "brand", "description", "createtime", "image_path")
    properties = {field: {"type": "text"} for field in text_fields}
    properties["vector_field"] = {
        "type": "knn_vector",
        "dimension": outputEmbeddingLength,
    }

    index_body = {
        "settings": {"index": {"knn": True}},
        "mappings": {"properties": properties},
    }
    creation_response = OSSclient.indices.create(index, body=index_body)
    print(creation_response)

In [None]:
# The index name carries the embedding size so differently sized indexes do not collide.
index_name = f"retail-dataset-{outputEmbeddingLength}"

In [None]:
# Create the index sized for the chosen embedding length (skipped if it already exists).
create_index(index_name, outputEmbeddingLength)

### Populate the index

In [None]:
def create_dataset_list(records_file):
    """Load the dataset records from a JSON file.

    Args:
        records_file: Path to the JSON file.

    Returns:
        The deserialized JSON content (the rest of the notebook treats it as a
        list of record dicts).
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default encoding.
    with open(records_file, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
    
def process_batch(batch, index, outputEmbeddingLength):
    """Embed every record of a batch and index them with a single bulk request.

    Args:
        batch: List of record dicts with the keys ``image_path``, ``item_id``,
            ``item_name``, ``color``, ``brand`` and ``description``.
        index: Name of the target index.
        outputEmbeddingLength: Embedding dimension requested for each image.

    Prints the time spent creating the embeddings and the outcome of the bulk
    request (including the first few item-level errors, if any). Returns None.
    """
    if not batch:
        # Nothing to index; avoid sending an empty bulk request.
        return

    start_time = datetime.datetime.now()
    # The bulk body is NDJSON: an action line followed by a document line per record.
    # json.dumps keeps the action line valid whatever characters the index name holds.
    action_line = json.dumps({"index": {"_index": index}})
    bulk_lines = []
    for entry in batch:
        image_location = "images/{}".format(entry["image_path"])
        vector = create_embeddings_from_image(image_location, outputEmbeddingLength)
        doc = {
            "vector_field": vector["embedding"],
            "createtime": datetime.datetime.now().isoformat(),
            "id": entry["item_id"],
            "name": entry["item_name"],
            "color": entry["color"],
            "brand": entry["brand"],
            "description": entry["description"],
            "image_path": entry["image_path"]
        }
        bulk_lines.append(action_line)
        bulk_lines.append(json.dumps(doc))
    # Every line, including the last one, must end with a newline.
    bulk_data = "\n".join(bulk_lines) + "\n"

    end_time = datetime.datetime.now()
    processing_time = (end_time - start_time).total_seconds() * 1000  # Convert to milliseconds
    print("Processed {} records in {} ms".format(len(batch), processing_time))

    response = OSSclient.bulk(bulk_data)
    if response["errors"] is False:
        print("Sent {} records in {} ms".format(len(response["items"]), response["took"]))
        return

    # Report what failed instead of only that something failed.
    failed_items = [
        item["index"] for item in response["items"]
        if "error" in item.get("index", {})
    ]
    print("Error found: {} of {} records failed".format(len(failed_items), len(response["items"])))
    for failed in failed_items[:5]:
        print("  status {}: {}".format(failed.get("status"), failed.get("error")))

def populate_vector_database(records_file, index, outputEmbeddingLength, batch_size=100):
    """Embed and index every record of a dataset file, one thread-pool task per batch.

    Args:
        records_file: Path of the JSON dataset file.
        index: Name of the target index.
        outputEmbeddingLength: Embedding dimension requested for each image.
        batch_size: Number of records sent per bulk request.

    Raises:
        RuntimeError: If one or more batches raised an exception. It is raised
            only after every batch has finished, so successful batches are
            still indexed.
    """
    dataset_list = create_dataset_list(records_file)
    # Split the dataset into batches
    batches = [dataset_list[i:i + batch_size] for i in range(0, len(dataset_list), batch_size)]

    failed_batches = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_number = {
            executor.submit(process_batch, batch, index, outputEmbeddingLength): number
            for number, batch in enumerate(batches, start=1)
        }
        for future in concurrent.futures.as_completed(future_to_number):
            # Waiting on futures never re-raises worker exceptions; they have to be
            # read from each future, otherwise failed batches vanish silently.
            error = future.exception()
            if error is not None:
                number = future_to_number[future]
                print("Batch {} of {} failed: {!r}".format(number, len(batches), error))
                failed_batches.append(number)

    if failed_batches:
        raise RuntimeError(
            "{} of {} batches failed to index".format(len(failed_batches), len(batches))
        )

In [None]:
# Embed every image of the curated dataset and index the records into `index_name`.
populate_vector_database(dataset_file, index_name, outputEmbeddingLength)


### Query the vector database

#### Search by text

In [None]:
def get_embedding_for_text(text, outputEmbeddingLength):
    """Request an embedding for a text string from the Titan multimodal embedding model.

    Args:
        text: Text to embed.
        outputEmbeddingLength: Embedding dimension to request from the model.

    Returns:
        tuple: ``(response_json, text)`` where ``response_json`` is the parsed
        response body (callers read its ``"embedding"`` key) and ``text`` is
        the input, returned unchanged.
    """
    request_payload = {
        "inputText": text,
        "embeddingConfig": {"outputEmbeddingLength": outputEmbeddingLength},
    }

    response = bedrock_runtime.invoke_model(
        body=json.dumps(request_payload),
        modelId="amazon.titan-embed-image-v1",
        accept="application/json",
        contentType="application/json",
    )

    # The response body is a stream; read it fully before decoding the JSON.
    raw_body = response['body'].read()
    return json.loads(raw_body.decode('utf8')), text

def query_the_database_with_text(text, index, outputEmbeddingLength, k):
    """Run a k-NN search on ``index`` using the embedding of a text query.

    Args:
        text: Query text.
        index: Name of the index to search.
        outputEmbeddingLength: Embedding dimension; must match the index's vector field.
        k: Number of nearest neighbours to retrieve.

    Returns:
        The raw search response from the OpenSearch client.
    """
    o_vector_json, _ = get_embedding_for_text(text, outputEmbeddingLength)
    query = {
        # `k` only bounds the neighbours gathered by the k-NN search; without an
        # explicit `size` the response is capped at the default of 10 hits.
        "size": k,
        "query": {
            "bool": {
                "must": [
                    {
                        "knn": {
                            "vector_field": {
                                "vector": o_vector_json["embedding"],
                                "k": k
                            }
                        }
                    }
                ]
            }
        }
    }

    return OSSclient.search(body=query, index=index)
    
def display_images(image_data):
    """Show the images of a list of search hits side by side with their scores.

    Args:
        image_data: List of hits; each hit provides ``_source.image_path``
            (relative to the ``images/`` folder) and ``_score``.

    Prints a message and draws nothing when the list is empty.
    """
    num_images = len(image_data)
    if num_images == 0:
        # A subplot grid cannot be created with zero columns.
        print("No results to display")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image; by default
    # a bare Axes object is returned in that case and cannot be indexed.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)

    for ax, entry in zip(axes[0], image_data):
        relative_path = entry['_source']['image_path']

        # Load and display the image
        img = mpimg.imread("images/{}".format(relative_path))
        ax.imshow(img)
        ax.axis('off')
        ax.set_title("{}".format(relative_path))
        ax.text(0.5, -0.1, f"Score: {entry['_score']:.4f}", ha='center', transform=ax.transAxes)

    # Adjust layout to prevent clipping of titles
    plt.tight_layout()
    plt.show()

In [None]:
# Search the index with the embedding of a text query, asking for 10 neighbours.
results_text = query_the_database_with_text("A bed", index_name, outputEmbeddingLength, k=10)
# Display the results
display_images(results_text["hits"]["hits"])

#### Search by image

In [None]:
def query_the_database_with_image(image, index, outputEmbeddingLength, k):
    """Run a k-NN search on ``index`` using the embedding of an image file.

    Args:
        image: Path of the query image.
        index: Name of the index to search.
        outputEmbeddingLength: Embedding dimension; must match the index's vector field.
        k: Number of nearest neighbours to retrieve.

    Returns:
        The raw search response from the OpenSearch client.
    """
    o_vector_json = create_embeddings_from_image(image, outputEmbeddingLength)
    query = {
        # `k` only bounds the neighbours gathered by the k-NN search; without an
        # explicit `size` the response is capped at the default of 10 hits.
        "size": k,
        "query": {
            "bool": {
                "must": [
                    {
                        "knn": {
                            "vector_field": {
                                "vector": o_vector_json["embedding"],
                                "k": k
                            }
                        }
                    }
                ]
            }
        }
    }

    return OSSclient.search(body=query, index=index)

In [None]:
# Path of the image used as the query for the image search.
test_image = "test-images/bed.png" # Locate test image

In [None]:
# Search the index with the embedding of the test image, asking for 10 neighbours.
results = query_the_database_with_image(test_image, index_name, outputEmbeddingLength, k=10)

In [None]:
# Display the results: one image per hit, titled with its path and annotated with its score.
display_images(results["hits"]["hits"])

## Compare different vector sizes results

This section shows how the search results differ when different vector sizes are used.

In [None]:
# Embedding sizes to compare; one index named "test-<index_name_compare>-<size>" is built per size.
vector_sizes = [1024, 384, 256]
index_name_compare = "retail-dataset"
# Number of nearest neighbours requested by every comparison query.
k = 10

In [None]:
def vector_comparison_populate(vector_sizes, index_name, dataset_file):
    """Create and populate one test index per embedding size.

    Args:
        vector_sizes: Iterable of embedding dimensions.
        index_name: Base name; each index is called ``test-<index_name>-<size>``.
        dataset_file: Path of the JSON dataset file to index.
    """
    for dimension in vector_sizes:
        target_index = f"test-{index_name}-{dimension}"
        create_index(target_index, dimension)
        populate_vector_database(dataset_file, target_index, dimension)

def text_query_comparison(vector_sizes, index_name, text_query, k):
    """Run the same text query against the test index of every embedding size.

    Args:
        vector_sizes: Iterable of embedding dimensions.
        index_name: Base name of the ``test-<index_name>-<size>`` indexes.
        text_query: Query text.
        k: Number of nearest neighbours to retrieve.

    Returns:
        list[dict]: One ``{"index": ..., "results": ...}`` entry per size, in input order.
    """
    def _search_one(dimension):
        # Query the index that was built with this embedding size.
        target_index = f"test-{index_name}-{dimension}"
        return {
            "index": target_index,
            "results": query_the_database_with_text(text_query, target_index, dimension, k),
        }

    return [_search_one(dimension) for dimension in vector_sizes]

def image_query_comparison(vector_sizes, index_name, image, k):
    """Run the same image query against the test index of every embedding size.

    Args:
        vector_sizes: Iterable of embedding dimensions.
        index_name: Base name of the ``test-<index_name>-<size>`` indexes.
        image: Path of the query image.
        k: Number of nearest neighbours to retrieve.

    Returns:
        list[dict]: One ``{"index": ..., "results": ...}`` entry per size, in input order.
    """
    def _search_one(dimension):
        # Query the index that was built with this embedding size.
        target_index = f"test-{index_name}-{dimension}"
        return {
            "index": target_index,
            "results": query_the_database_with_image(image, target_index, dimension, k),
        }

    return [_search_one(dimension) for dimension in vector_sizes]

def print_results(results):
    """Build a side-by-side table of hit names and scores, one column group per index.

    Args:
        results: List of ``{"index": name, "results": search_response}`` dicts as
            produced by the ``*_query_comparison`` helpers.

    Returns:
        pandas.DataFrame: Columns are a two-level index ``(index name, "title" | "score")``;
        rows are hit ranks. Despite the name, nothing is printed.
    """
    labels = []
    frames = []
    for item in results:
        label = item['index']
        hits = item['results']['hits']['hits']
        frames.append(pd.DataFrame({
            'title': [hit['_source']['name'] for hit in hits],
            'score': [hit['_score'] for hit in hits],
        }))
        labels.append(label)

    # Place the per-index frames next to each other, keyed by index name.
    return pd.concat(frames, axis=1, keys=labels)


In [None]:
# Build and populate one test index per embedding size; every image is embedded once per size.
vector_comparison_populate(vector_sizes, index_name_compare,  dataset_file)

#### Compare Text Search against the different indexes
You might need to wait a few seconds for indexing to finish before running the queries below.

In [None]:
# Text used for the comparison queries (also printed next to the result images further down).
text_query = "A bed"

In [None]:
# Run the text query against each test index; one result entry per embedding size.
text_query_results = text_query_comparison(vector_sizes, index_name_compare, text_query, k)

In [None]:
# Side-by-side table of hit names and scores per index (shown as the cell's output).
print_results(text_query_results)

#### Compare Image Search against the different indexes

In [None]:
# Path of the image used as the query for the image comparison.
test_image = "test-images/bed.png" # Locate test image

In [None]:
# Run the image query against each test index; one result entry per embedding size.
image_query_results = image_query_comparison(vector_sizes, index_name_compare, test_image, k)

In [None]:
# Side-by-side table of hit names and scores per index (shown as the cell's output).
print_results(image_query_results)

In [None]:
# Show the retrieved images for the text query, one row of images per test index.
for result in text_query_results:
    print("Results for {}".format(result["index"]))
    print("Query: {}".format(text_query))
    display_images(result['results']["hits"]["hits"])

In [None]:
# Show the retrieved images for the image query, one row of images per test index.
for result in image_query_results:
    print("Results for {}".format(result["index"]))
    print("Input: {}".format(test_image))
    display_images(result['results']["hits"]["hits"])

### Clean up 
In this section we will delete any resources that may incur unnecessary costs.

In [None]:
# Delete the policies and the collection created earlier.
# NOTE(review): the "<collection_name>-policy" names presumably match what the aoss_utils
# helpers created — confirm there. Each call overwrites `response`, and a failure in one
# call stops the cell before the remaining deletions run.
response = client.delete_security_policy(
    name='{}-policy'.format(collection_name),
    type='encryption'
)

response = client.delete_security_policy(
     name='{}-policy'.format(collection_name),
    type='network'
)

response = client.delete_access_policy(
    name='{}-policy'.format(collection_name),
    type='data'
)

# collection_id comes from waitForCollectionCreation.
response = client.delete_collection(
    id=collection_id
)