# Objective

Given a set of precomputed embeddings, create and deploy a Vertex AI Vector Search (formerly Matching Engine) index that supports streaming updates, query it, and then upsert and remove individual datapoints.

This notebook assumes the embeddings already exist in Cloud Storage (GCS) in one of the input formats supported by Vertex AI Vector Search.

In [None]:
!pip install google-cloud-aiplatform



---

#### ⚠️ After the install finishes, restart the runtime (in Colab, click the "RESTART RUNTIME" button above) so the newly installed packages are loaded.

---

In [33]:
import sys

# In Colab, log the user in interactively so the Google Cloud clients used
# below have credentials; on any other runtime this cell does nothing.
_in_colab = 'google.colab' in sys.modules
if _in_colab:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

In [None]:
# Colab form fields: edit these to point at your own project and data.
PROJECT_ID = 'solutions-2023-mar-107'  # @param {type:"string"}
REGION = 'us-central1'  # @param {type:"string"}
# GCS folder that already holds the embeddings (WHERE EMBEDDINGS ARE STORED);
# it is also passed to aiplatform.init as the staging bucket.
BUCKET_URI = "gs://vector_search_regional/flipkart_multimodal_embeddings"  # @param {type:"string"}
# Regional Vertex AI API host used by the low-level clients below.
ENDPOINT = f"{REGION}-aiplatform.googleapis.com"

In [None]:
# Import the high-level SDK here: on a fresh kernel nothing above has imported
# it yet, so calling aiplatform.init without this line raises NameError.
from google.cloud import aiplatform

aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# Create an index for Streaming Updates

In [None]:
# Parent resource path (project + region) for the index-management calls below.
PARENT = f"projects/{PROJECT_ID}/locations/{REGION}"

In [None]:
from google.cloud import aiplatform_v1
from google.protobuf import struct_pb2

In [None]:
# Low-level client used to create the index and to stream datapoint
# upserts/removals; pointed at the regional API host from the config cell.
index_client = aiplatform_v1.IndexServiceClient(
    client_options={"api_endpoint": ENDPOINT},
)

In [None]:
# Index shape and tree-AH tuning knobs, hoisted into named constants.
DIMENSIONS = 1408  # vector length; presumably the multimodal embedding size — every stored vector must match it
DISPLAY_NAME = "flipkart_streaming"
APPROXIMATE_NEIGHBORS_COUNT = 150
DISTANCE_MEASURE_TYPE = "COSINE_DISTANCE"
LEAF_NODE_EMBEDDING_COUNT = 500
LEAF_NODES_TO_SEARCH_PERCENT = 7

tree_ah_config = struct_pb2.Struct(
    fields={
        "leafNodeEmbeddingCount": struct_pb2.Value(number_value=LEAF_NODE_EMBEDDING_COUNT),
        "leafNodesToSearchPercent": struct_pb2.Value(number_value=LEAF_NODES_TO_SEARCH_PERCENT),
    }
)

algorithm_config = struct_pb2.Struct(
    fields={"treeAhConfig": struct_pb2.Value(struct_value=tree_ah_config)}
)

index_config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=APPROXIMATE_NEIGHBORS_COUNT),
        "distanceMeasureType": struct_pb2.Value(string_value=DISTANCE_MEASURE_TYPE),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithm_config),
    }
)

# contentsDeltaUri points at the GCS folder holding the initial embeddings.
index_metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=index_config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_URI),
    }
)

# The request spec gets its own name; `ann_index` below is the create operation.
index_spec = {
    "display_name": DISPLAY_NAME,
    "description": "Based on ~18K Flipkart product listings with both description and image",
    "metadata": struct_pb2.Value(struct_value=index_metadata),
    # STREAM_UPDATE is what allows the upsert/remove datapoint calls later on.
    "index_update_method": aiplatform_v1.Index.IndexUpdateMethod.STREAM_UPDATE,
}
# create_index returns a long-running operation; .result() waits for the Index.
ann_index = index_client.create_index(parent=PARENT, index=index_spec)

In [None]:
# Handle to the long-running create-index operation (an Operation, not the
# Index itself); the next cell waits for it to finish.
ann_index

<google.api_core.operation.Operation at 0x7d135fa41810>

In [None]:
# Block until index creation finishes and show the resulting Index; its `name`
# (projects/.../indexes/<id>) identifies the index in the cells that follow.
ann_index.result()

name: "projects/411826505131/locations/us-central1/indexes/3849667285773975552"

# Deploy Index

In [None]:
# Create an endpoint with a public address (public_endpoint_enabled=True);
# the index is deployed onto it in the next cells.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name='flipkart_streaming',
    description='Based on 18k flipkart product listings for which we have both a description and image',
    public_endpoint_enabled=True,
)

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Creating MatchingEngineIndexEndpoint
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Create MatchingEngineIndexEndpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592/operations/1536536095016091648
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint created. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:To use this MatchingEngineIndexEndpoint in another session:
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592')


In [None]:
# Reference the index created earlier in this notebook instead of a hard-coded
# ID, so Restart & Run All deploys the freshly created index. To reuse an
# existing index instead, pass its ID or full resource name as index_name.
tree_ah_index = aiplatform.MatchingEngineIndex(index_name=ann_index.result().name)

In [None]:
# ID the index receives on the endpoint; find_neighbors must pass the same value.
deployed_index_name = 'flipkart_streaming'
# deploy_index hands back the endpoint, so rebinding keeps the variable current.
my_index_endpoint = my_index_endpoint.deploy_index(
    index=tree_ah_index,
    deployed_index_id=deployed_index_name,
)
# Show what is now deployed on the endpoint (ID, backing index, replicas, ...).
my_index_endpoint.deployed_indexes

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592/operations/3063256368694689792
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/356562824794734592


[id: "flipkart_streaming"
index: "projects/411826505131/locations/us-central1/indexes/3849667285773975552"
create_time {
  seconds: 1702893052
  nanos: 116157000
}
index_sync_time {
  seconds: 1702893991
  nanos: 762074000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

# Query Index

In [None]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing
import logging

# Inspired from https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  """Embeddings returned by one multimodal embedding request.

  A field is None when the corresponding input (text or image) was not part
  of the request.
  """
  text_embedding: typing.Optional[typing.Sequence[float]]
  image_embedding: typing.Optional[typing.Sequence[float]]


class EmbeddingPredictionClient:
  """Wrapper around Prediction Service Client for the multimodal embedding model.

  Args:
    project: Google Cloud project ID used in the model resource path.
    location: Region hosting the model, e.g. "us-central1".
    api_regional_endpoint: API host to send requests to. When None (default)
      it is derived from `location` as "{location}-aiplatform.googleapis.com",
      so the host always matches the region used in the model path.
    model: Publisher model ID to call.
  """
  def __init__(self, project : str,
    location : str = "us-central1",
    api_regional_endpoint: typing.Optional[str] = None,
    model: str = "multimodalembedding@001"):
    if api_regional_endpoint is None:
      # Previously this defaulted to the us-central1 host even when a different
      # `location` was passed, producing a host/region mismatch.
      api_regional_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_regional_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project
    self.model = model

  def get_embedding(self, text : typing.Optional[str] = None,
                    image_path : typing.Optional[str] = None) -> 'EmbeddingResponse':
    """Embed text, an image, or both in a single request.

    Args:
      text: Text to embed; characters beyond the first 1024 are dropped with a warning.
      image_path: Local file path or a GCS URI (gs://...) of the image to embed.

    Returns:
      EmbeddingResponse whose fields are None for inputs that were not given.

    Raises:
      ValueError: if neither text nor image_path is provided (or both are empty).
    """
    if not text and not image_path:
      raise ValueError('At least one of text or image_path must be specified.')

    instance = struct_pb2.Struct()
    if text:
      if len(text) > 1024:
        logging.warning('Text must be at most 1024 characters. Truncating text.')
        text = text[:1024]
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        # GCS images are passed by reference (gcsUri) instead of being read locally.
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    endpoint = (f"projects/{self.project}/locations/{self.location}"
                f"/publishers/google/models/{self.model}")
    response = self.client.predict(endpoint=endpoint, instances=[instance])
    prediction = response.predictions[0]

    # Only read back the modalities that were actually sent.
    text_embedding = list(prediction['textEmbedding']) if text else None
    image_embedding = list(prediction['imageEmbedding']) if image_path else None

    return EmbeddingResponse(
      text_embedding=text_embedding,
      image_embedding=image_embedding)

@cache
def get_client(project):
  """Return the shared EmbeddingPredictionClient for `project`.

  functools.cache memoises the result, so the client is constructed once per
  project and reused by every later call.
  """
  shared_client = EmbeddingPredictionClient(project)
  return shared_client


def embed(project, text=None, image_path=None):
  """Embed `text` and/or `image_path` using the cached client for `project`.

  Prints how long the prediction request took and returns the
  EmbeddingResponse produced by EmbeddingPredictionClient.get_embedding.
  `text` now defaults to None so image-only queries can be made.
  """
  client = get_client(project)
  # perf_counter is monotonic, so the measured interval is not skewed by
  # system clock adjustments the way time.time() differences can be.
  start = time.perf_counter()
  response = client.get_embedding(text=text, image_path=image_path)
  elapsed = time.perf_counter() - start
  print('Embedding Time: ', elapsed)
  return response

In [None]:
# Embed one known catalogue item (description text + product image) so the
# query below can check that the item's own vectors come back first.
query_text = "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,Specifications of Vishudh Printed Women's Straight Kurta Kurta Details Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach"
query_image_uri = 'gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg'
res = embed(PROJECT_ID, query_text, query_image_uri)

Embedding Time:  0.8291733264923096


In [None]:
NUM_NEIGHBORS = 5
# One query per modality: the first result list answers the text embedding,
# the second answers the image embedding.
neighbor_queries = [res.text_embedding, res.image_embedding]
response = my_index_endpoint.find_neighbors(
    deployed_index_id=deployed_index_name,
    queries=neighbor_queries,
    num_neighbors=NUM_NEIGHBORS,
)

response

[[MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_T', distance=1.1920928955078125e-07),
  MatchNeighbor(id='0305111c779fe663bd94122bef0f0002_T', distance=0.18872332572937012),
  MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='6ef0a5eb033cd610d455be7102da5685_T', distance=0.3441193103790283),
  MatchNeighbor(id='ee383a337af67ae8ad4f42714d67ddaf_T', distance=0.3441193103790283)],
 [MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_I', distance=0.0),
  MatchNeighbor(id='06ad8323cf9105f1aaae6515cf08a7d6_I', distance=0.03433966636657715),
  MatchNeighbor(id='c9c27aa5dc7df49e82e55e8abb6b4020_I', distance=0.04238319396972656),
  MatchNeighbor(id='169902631b89202f0e2079e9cc09b3c7_I', distance=0.047689199447631836),
  MatchNeighbor(id='5614ccefd0ab9bee5cd28bf3d38fd12f_I', distance=0.06253844499588013)]]

# Update Index

In [28]:
# Let's read a few entries from the test dataset and reshape them into one row
# per vector: text embeddings get an "_T" ID suffix, image embeddings "_I".
from google.cloud import bigquery

TEST_TABLE = "solutions-2023-mar-107.flipkart.test_data_for_index_update"

client = bigquery.Client(PROJECT_ID)
query_job = client.query(f"""
  SELECT CONCAT(id, '_T') AS id, text_embedding AS embedding, L0, L1, L2, L3
  FROM `{TEST_TABLE}`

  UNION ALL

  -- UNION ALL matches columns by position and keeps the first SELECT's names;
  -- alias them identically here anyway so the query reads correctly.
  SELECT CONCAT(id, '_I') AS id, image_embedding AS embedding, L0, L1, L2, L3
  FROM `{TEST_TABLE}`
""")

res = query_job.result()  # Wait for the job to complete.

In [29]:
df = res.to_dataframe()  # Convert to pandas dataframe

# Datapoints are upserted by ID and the index was configured with DIMENSIONS,
# so catch duplicate IDs or wrong-length vectors here rather than at the API.
assert df['id'].is_unique, "duplicate datapoint IDs in test data"
assert df['embedding'].map(len).eq(DIMENSIONS).all(), "embedding length does not match index DIMENSIONS"

df

Unnamed: 0,id,embedding,L0,L1,L2,L3
0,dbdac18a8ee5a8a48238b9685c96e90a_T,"[0.0173654296, -0.0533204265, -0.0123991454, 0...",Watches,Wrist Watches,Timewel Wrist Watches,
1,8a771d8dfa97d06278038945dfe6b936_T,"[0.0395723879, -0.046936553, -0.0225308761, 0....",Watches,Wrist Watches,Chappin & Nellson Wrist Watches,
2,894904e26516d491bf1c7711fe800e78_T,"[0.0234976951, -0.0287721325, -0.00390096451, ...",Watches,Wrist Watches,Only Kidz Wrist Watches,
3,138f8455457c6cf87a0b94e132c485a8_T,"[0.0199023429, -0.0493029393, -0.0240550581, 0...",Watches,Wrist Watches,Gift Island Wrist Watches,
4,7c973b8fb2069b2142aea3473b70c213_T,"[0.00237481017, -0.0487776175, 0.0100880247, 0...",Watches,Watch Accessories,Wrist Bands,Sakhi Styles Wrist Bands
5,81d73f4a7add96d46146ac4e192aad92_T,"[0.00450034346, -0.0472053625, 0.0043839491, -...",Clothing,Kids' Clothing,Girls Wear,Innerwear & Sleepwear
6,140225e6d36138c0c79f4b97d42456bd_T,"[-0.0411048084, -0.0358454958, 0.0511916205, 0...",Clothing,Men's Clothing,T-Shirts,Ocean Race T-Shirts
7,9ac56e95bf79b7a4268387b4c8efdd52_T,"[0.00935253594, -0.0527491495, 0.0230338033, 0...",Clothing,Men's Clothing,T-Shirts,Nimya T-Shirts
8,35c289ac8c50c49fae6d06e37ce34d42_T,"[-0.0280462336, -0.0645570904, 0.0218884815, 0...",Clothing,Men's Clothing,Shirts,Casual & Party Wear Shirts
9,fef6a5aa8c590c8029bbc11903cbd554_T,"[0.0218668692, -0.0588314161, 0.00145428523, 0...",Clothing,Men's Clothing,T-Shirts,Nimya T-Shirts


## Insert (upsert) a datapoint

In [40]:
# Stream the first test row into the deployed index.
first_row = df.iloc[0]
dp1_id = first_row['id']
# The BigQuery ARRAY column may arrive as a NumPy array; send plain Python floats.
dp1_embedding = [float(v) for v in first_row['embedding']]
dp1_cat = first_row['L0']

insert_datapoints_payload = aiplatform_v1.IndexDatapoint(
    datapoint_id=dp1_id,
    feature_vector=dp1_embedding,
    # Optional: attach the top-level category as a filtering restrict, e.g.
    # restricts=[{"namespace": "L0", "allow_list": [dp1_cat]}],
)

# Target the index deployed above rather than a hard-coded resource name.
upsert_request = aiplatform_v1.UpsertDatapointsRequest(
    index=tree_ah_index.resource_name,
    datapoints=[insert_datapoints_payload],
)

index_client.upsert_datapoints(request=upsert_request)




## Delete datapoints

In [39]:
# Remove the datapoint added by the upsert cell above (run that cell first:
# it defines dp1_id). The index is resolved from tree_ah_index instead of a
# hard-coded resource name.
remove_request = aiplatform_v1.RemoveDatapointsRequest(
    index=tree_ah_index.resource_name,
    datapoint_ids=[dp1_id],
)

index_client.remove_datapoints(request=remove_request)

