In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Objective

Given a set of embeddings, create and deploy a Vertex AI Vector Search (formerly Matching Engine) index.

This notebook assumes the embeddings are already in Cloud Storage (GCS), in a format supported by Vertex AI Vector Search. To get them there:
1. Run the notebook that generates multimodal embeddings and stores them in BigQuery (BQ).  
2. Vector Search does not yet support importing directly from BQ. Run the query below to convert the BQ table into the format Vector Search accepts. Then export the query results as JSON to the GCS folder used as `BUCKET_URI` below.

```sql
SELECT
 CONCAT(id,'_T') as id,
 text_embedding as embedding,
 c0_name as L0,
 c1_name as L1,
 c2_name as L2,
 c3_name as L3

FROM
 `<PROJECT_ID>.<DATASET_ID>.<TABLE_ID>`

UNION ALL

SELECT
 CONCAT(id,'_I') as id,
 image_embedding as embedding,
 c0_name as L0,
 c1_name as L1,
 c2_name as L2,
 c3_name as L3

FROM
 `<PROJECT_ID>.<DATASET_ID>.<TABLE_ID>`;
```

In [None]:
!pip install google-cloud-aiplatform



---

#### ⚠️ Don't forget to restart the runtime (in Colab, the "RESTART RUNTIME" button above) so the newly installed package is loaded.

---

In [None]:
import sys

_RUNNING_IN_COLAB = 'google.colab' in sys.modules

if _RUNNING_IN_COLAB:
    # Colab needs an interactive sign-in. Outside Colab this cell does nothing,
    # so credentials must already be available to the SDK.
    from google.colab import auth as google_auth

    google_auth.authenticate_user()

In [None]:
# Notebook configuration (Colab form fields) -- replace with your own values.
# PROJECT_ID / REGION: where the index and endpoint are created (passed to aiplatform.init).
PROJECT_ID = 'solutions-2023-mar-107' # @param {type:"string"}
REGION = 'us-central1' # @param {type:"string"}
# BUCKET_URI: GCS folder holding the exported embedding files. It is used both as
# the index's contents_delta_uri and as the SDK staging_bucket.
# NOTE(review): SDK staging artifacts could land next to the embedding files;
# consider a separate staging location. Presumably the bucket should be in REGION -- confirm.
BUCKET_URI = "gs://vector_search_regional/flipkart_multimodal_embeddings" # @param {type:"string"} # WHERE EMBEDDINGS ARE STORED

In [None]:
from google.cloud import aiplatform

# Register SDK-wide defaults so later calls need not repeat project/region.
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=BUCKET_URI,
)

# Create Index

In [None]:
# Build a Tree-AH approximate-nearest-neighbour index from the embedding files
# exported under BUCKET_URI.
tree_ah_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name='flipkart_batch',
    description='Based on ~18K Flipkart product listings with both description and image',
    # Source data: every embedding file under this GCS folder.
    contents_delta_uri=BUCKET_URI,
    # Length of each embedding vector; must match what the embedding model emits.
    dimensions=1408,
    distance_measure_type="COSINE_DISTANCE",
    approximate_neighbors_count=150,
    # Tree-AH tuning: vectors per leaf node, and the share of leaves scanned per query.
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
)

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:Creating MatchingEngineIndex
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:Create MatchingEngineIndex backing LRO: projects/411826505131/locations/us-central1/indexes/2594851839597871104/operations/8187649057200537600
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:MatchingEngineIndex created. Resource name: projects/411826505131/locations/us-central1/indexes/2594851839597871104
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:To use this MatchingEngineIndex in another session:
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:index = aiplatform.MatchingEngineIndex('projects/411826505131/locations/us-central1/indexes/2594851839597871104')


In [None]:
# Capture and show the fully qualified index name
# (projects/<number>/locations/<region>/indexes/<id>) for reuse in later sessions.
print(INDEX_RESOURCE_NAME := tree_ah_index.resource_name)

projects/411826505131/locations/us-central1/indexes/2594851839597871104


In [None]:
# Re-attach to the existing index by resource name. In a fresh session, set
# INDEX_RESOURCE_NAME to the value printed above instead of re-creating the index.
tree_ah_index = aiplatform.MatchingEngineIndex(INDEX_RESOURCE_NAME)

# Deploy Index

In [None]:
# Create an endpoint to host the index. It is public, and no VPC network is
# passed, so the notebook can query it directly.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    public_endpoint_enabled=True,
    display_name='flipkart_batch',
    description='Endpoint on flipkart',
)

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Creating MatchingEngineIndexEndpoint
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Create MatchingEngineIndexEndpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128/operations/2224320200608579584
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint created. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:To use this MatchingEngineIndexEndpoint in another session:
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128')


In [None]:
DEPLOYED_INDEX_ID = 'flipkart_multimodal_18K'

# A deployed_index_id must be unique per endpoint, so re-deploying the same ID
# is expected to fail. Only deploy when the endpoint does not already serve it,
# which keeps this cell safe to re-run.
if not any(deployed.id == DEPLOYED_INDEX_ID
           for deployed in my_index_endpoint.deployed_indexes):
    # deploy_index returns the endpoint; rebinding it means deployed_indexes
    # below reflects the newly deployed index.
    my_index_endpoint = my_index_endpoint.deploy_index(
        index=tree_ah_index, deployed_index_id=DEPLOYED_INDEX_ID
    )
my_index_endpoint.deployed_indexes

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128/operations/2792618178587394048
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128


[id: "flipkart_multimodal_18K"
index: "projects/411826505131/locations/us-central1/indexes/2594851839597871104"
create_time {
  seconds: 1702887352
  nanos: 664772000
}
index_sync_time {
  seconds: 1702887352
  nanos: 664772000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

# Query Index

In [None]:
#TODO move this code to a module
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing
import logging

# Inspired from https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  """Embeddings returned by EmbeddingPredictionClient.get_embedding.

  Each field is a list of floats, or None when that modality (text or image)
  was not part of the request.
  """
  text_embedding: typing.Optional[typing.Sequence[float]]
  image_embedding: typing.Optional[typing.Sequence[float]]


class EmbeddingPredictionClient:
  """Wrapper around Prediction Service Client for the multimodal embedding model.

  Args:
    project: Google Cloud project ID used in the model endpoint path.
    location: Region hosting the model, e.g. "us-central1".
    api_regional_endpoint: API host to send requests to. When omitted it is
      derived from `location` ("<location>-aiplatform.googleapis.com"), so the
      host and the model path always refer to the same region.
  """

  # Text longer than this is truncated client-side (with a warning) before sending.
  MAX_TEXT_CHARS = 1024
  # Publisher model queried by get_embedding.
  _MODEL = "multimodalembedding@001"

  def __init__(self, project: str,
               location: str = "us-central1",
               api_regional_endpoint: typing.Optional[str] = None):
    if api_regional_endpoint is None:
      api_regional_endpoint = self._regional_endpoint(location)
    client_options = {"api_endpoint": api_regional_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project

  @staticmethod
  def _regional_endpoint(location: str) -> str:
    """Returns the regional Vertex AI API host for `location`."""
    return f"{location}-aiplatform.googleapis.com"

  def get_embedding(self, text: typing.Optional[str] = None,
                    image_path: typing.Optional[str] = None) -> "EmbeddingResponse":
    """Embeds text and/or an image with the multimodal embedding model.

    Args:
      text: Text to embed. Longer than MAX_TEXT_CHARS is truncated with a warning.
      image_path: Local file path or GCS URI ("gs://...") of an image to embed.

    Returns:
      EmbeddingResponse whose text_embedding / image_embedding are lists of
      floats, or None for a modality that was not requested.

    Raises:
      ValueError: if neither text nor image_path is given (empty strings count
        as not given).
    """
    if not text and not image_path:
      raise ValueError('At least one of text or image_path must be specified.')

    instance = struct_pb2.Struct()
    if text:
      if len(text) > self.MAX_TEXT_CHARS:
        logging.warning('Text must be at most %d characters. Truncating text.',
                        self.MAX_TEXT_CHARS)
        text = text[:self.MAX_TEXT_CHARS]
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        # GCS URIs are passed by reference rather than read locally.
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    endpoint = (f"projects/{self.project}/locations/{self.location}"
                f"/publishers/google/models/{self._MODEL}")
    response = self.client.predict(endpoint=endpoint, instances=[instance])
    prediction = response.predictions[0]

    # Only read back the modalities that were actually sent.
    text_embedding = list(prediction['textEmbedding']) if text else None
    image_embedding = list(prediction['imageEmbedding']) if image_path else None

    return EmbeddingResponse(
      text_embedding=text_embedding,
      image_embedding=image_embedding)

@cache
def get_client(project):
  """Returns a memoized EmbeddingPredictionClient for `project`.

  Thanks to @cache, every call with the same project reuses a single client
  instead of constructing a new one.
  """
  return EmbeddingPredictionClient(project=project)


def embed(project, text, image_path=None):
  """Embeds `text` and optionally an image, printing how long the call took.

  Args:
    project: Google Cloud project ID; clients are cached per project.
    text: Text to embed (may be None if image_path is given).
    image_path: Optional local path or GCS URI of an image to embed.

  Returns:
    The EmbeddingResponse produced by EmbeddingPredictionClient.get_embedding.
  """
  client = get_client(project)
  # perf_counter is monotonic and high-resolution. time.time() can jump if the
  # wall clock is adjusted mid-call, which would corrupt the measurement.
  start = time.perf_counter()
  response = client.get_embedding(text=text, image_path=image_path)
  elapsed = time.perf_counter() - start
  print('Embedding Time: ', elapsed)
  return response

In [None]:
# Embed a listing that is itself in the index (id 3ecb8597...). Its own text
# and image vectors should therefore come back as the nearest neighbours.
res = embed(
    PROJECT_ID,
    text=(
        "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,"
        "Specifications of Vishudh Printed Women's Straight Kurta Kurta Details "
        "Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 "
        "Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details "
        "Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta "
        "Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care "
        "Gentle Machine Wash in Lukewarm Water, Do Not Bleach"
    ),
    image_path='gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg',
)

Embedding Time:  0.983769416809082


In [None]:
NUM_NEIGHBORS = 5

# Two queries, one per modality: the first result list answers the text
# embedding, the second answers the image embedding.
response = my_index_endpoint.find_neighbors(
    num_neighbors=NUM_NEIGHBORS,
    queries=[
        res.text_embedding,
        res.image_embedding,
    ],
    deployed_index_id=DEPLOYED_INDEX_ID,
)

response

[[MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_T', distance=1.1920928955078125e-07),
  MatchNeighbor(id='0305111c779fe663bd94122bef0f0002_T', distance=0.18872332572937012),
  MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.2593120336532593),
  MatchNeighbor(id='6ef0a5eb033cd610d455be7102da5685_T', distance=0.3441193103790283),
  MatchNeighbor(id='ee383a337af67ae8ad4f42714d67ddaf_T', distance=0.3441193103790283)],
 [MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_I', distance=0.0),
  MatchNeighbor(id='06ad8323cf9105f1aaae6515cf08a7d6_I', distance=0.03433966636657715),
  MatchNeighbor(id='c9c27aa5dc7df49e82e55e8abb6b4020_I', distance=0.04238331317901611),
  MatchNeighbor(id='169902631b89202f0e2079e9cc09b3c7_I', distance=0.04768931865692139),
  MatchNeighbor(id='5614ccefd0ab9bee5cd28bf3d38fd12f_I', distance=0.0625385046005249)]]

# Future Improvements

1. Use streaming updates to make index updates faster. At the time of writing, this did not appear to be exposed in the Python client.
2. Take advantage of the filtering option: add filters for the top-level category and for the embedding type (text vs. image).