# Objective

Given a list of embeddings, create and deploy a Vertex AI Vector Search (fka Matching Engine) index with Streaming Updates

Assumes you already have embeddings in GCS in a format supported by Vertex AI Vector Search. For instructions on that see:
1. [Notebook](https://source.corp.google.com/piper///depot/google3/experimental/genaisa/product_catalog/notebooks/0_generate_multimodal_embeddings.ipynb) for generating multimodal embeddings and storing in BQ  
2. [Saved query](https://pantheon.corp.google.com/bigquery?ws=!1m7!1m6!12m5!1m3!1ssolutions-2023-mar-107!2sus-central1!3s1b67da64-ecbc-42d6-945b-5fe1df4e559f!2e1) used to convert embeddings stored in BQ into vector search ingestion format. The product does not yet support direct import from BQ

In [3]:
!pip install google-cloud-aiplatform

Collecting google-cloud-aiplatform
  Downloading google_cloud_aiplatform-1.36.4-py2.py3-none-any.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: google-cloud-aiplatform
Successfully installed google-cloud-aiplatform-1.36.4


---

#### ⚠️ Do not forget to click the "RESTART RUNTIME" button above.

---

In [1]:
import sys

if 'google.colab' in sys.modules:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

In [2]:
PROJECT_ID = 'solutions-2023-mar-107' # @param {type:"string"}
REGION = 'us-central1' # @param {type:"string"}
#BUCKET_URI = "gs://vector_search_regional/flipkart_multimodal_embeddings" # @param {type:"string"} # WHERE EMBEDDINGS ARE STORED
BUCKET_URI = 'gs://vector_search_regional/test_filterings' # @param {type:"string"} # WHERE EMBEDDINGS ARE STORED
ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)

In [3]:
from google.cloud import aiplatform
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# Create an index for Streaming Updates with REST

In [9]:
import requests
import json
import subprocess


url = "https://{0}/v1/projects/{1}/locations/{2}/indexes".format(ENDPOINT, PROJECT_ID, REGION)

command = subprocess.run('gcloud auth application-default print-access-token', shell=True, capture_output=True, text=True).stdout.strip()

headers = {
    'Authorization': 'Bearer ' + command,
    'Content-Type': 'application/json'
}

data={"displayName": "flipkart_streaming_index_with_filters",
        "description": 'flipkart_streaming_index_with_filters',
        "metadata": {
        "contentsDeltaUri": 'gs://vector_search_regional/test_filterings',
          "config": {
            "dimensions": 1408,
            "approximateNeighborsCount": 150,
            "distanceMeasureType": "COSINE_DISTANCE",
            "algorithmConfig": {"treeAhConfig": {"leafNodeEmbeddingCount": 500, "leafNodesToSearchPercent": 7}}
          },
        },
        "indexUpdateMethod": "STREAM_UPDATE"
      }
data = json.dumps(data)
print(data)

response = requests.post(url, data=data,headers={"Content-Type": "application/json", "Authorization": "Bearer "+command} )
print(response)

{"displayName": "flipkart_streaming_index_with_filters", "description": "flipkart_streaming_index_with_filters", "metadata": {"contentsDeltaUri": "gs://vector_search_regional/test_filterings", "config": {"dimensions": 1408, "approximateNeighborsCount": 150, "distanceMeasureType": "COSINE_DISTANCE", "algorithmConfig": {"treeAhConfig": {"leafNodeEmbeddingCount": 500, "leafNodesToSearchPercent": 7}}}}, "indexUpdateMethod": "STREAM_UPDATE"}
<Response [200]>


In [10]:
print(response.text)
index_detail=response.json()
print(index_detail)

{
  "name": "projects/411826505131/locations/us-central1/indexes/6070927064486117376/operations/4537822407821361152",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.aiplatform.v1.CreateIndexOperationMetadata",
    "genericMetadata": {
      "createTime": "2023-11-29T11:40:40.741945Z",
      "updateTime": "2023-11-29T11:40:40.741945Z"
    }
  }
}

{'name': 'projects/411826505131/locations/us-central1/indexes/6070927064486117376/operations/4537822407821361152', 'metadata': {'@type': 'type.googleapis.com/google.cloud.aiplatform.v1.CreateIndexOperationMetadata', 'genericMetadata': {'createTime': '2023-11-29T11:40:40.741945Z', 'updateTime': '2023-11-29T11:40:40.741945Z'}}}


In [11]:
INDEX_RESOURCE_NAME = index_detail['name']
print(INDEX_RESOURCE_NAME)

projects/411826505131/locations/us-central1/indexes/6070927064486117376/operations/4537822407821361152


# Create an index for Streaming Updates with v1

In [48]:
AUTH_TOKEN = !gcloud auth print-access-token
PROJECT_NUMBER = !gcloud projects list --filter="PROJECT_ID:'{PROJECT_ID}'" --format='value(PROJECT_NUMBER)'
PROJECT_NUMBER = PROJECT_NUMBER[0]

PARENT = "projects/{}/locations/{}".format(PROJECT_ID, REGION)

print("ENDPOINT: {}".format(ENDPOINT))
print("PROJECT_ID: {}".format(PROJECT_ID))
print("REGION: {}".format(REGION))

!gcloud config set project {PROJECT_ID} --quiet
!gcloud config set ai_platform/region {REGION} --quiet


ENDPOINT: us-central1-aiplatform.googleapis.com
PROJECT_ID: solutions-2023-mar-107
REGION: us-central1
Updated property [core/project].
Updated property [ai_platform/region].


In [49]:
from google.cloud import aiplatform_v1
from google.protobuf import struct_pb2

In [50]:
index_client = aiplatform_v1.IndexServiceClient(
    client_options=dict(api_endpoint=ENDPOINT)
)

In [None]:
DIMENSIONS = 1408
DISPLAY_NAME = "stream_test"

treeAhConfig = struct_pb2.Struct(
    fields={
        "leafNodeEmbeddingCount": struct_pb2.Value(number_value=500),
        "leafNodesToSearchPercent": struct_pb2.Value(number_value=7),
    }
)

algorithmConfig = struct_pb2.Struct(
    fields={"treeAhConfig": struct_pb2.Value(struct_value=treeAhConfig)}
)

config = struct_pb2.Struct(
    fields={
        "dimensions": struct_pb2.Value(number_value=DIMENSIONS),
        "approximateNeighborsCount": struct_pb2.Value(number_value=150),
        "distanceMeasureType": struct_pb2.Value(string_value="COSINE_DISTANCE"),
        "algorithmConfig": struct_pb2.Value(struct_value=algorithmConfig),
    }
)

metadata = struct_pb2.Struct(
    fields={
        "config": struct_pb2.Value(struct_value=config),
        "contentsDeltaUri": struct_pb2.Value(string_value=BUCKET_URI),
    }
)

ann_index = {
    "display_name": DISPLAY_NAME,
    "description": "Glove 100 ANN index",
    "metadata": struct_pb2.Value(struct_value=metadata),
    "index_update_method": aiplatform_v1.Index.IndexUpdateMethod.STREAM_UPDATE,
}
ann_index = index_client.create_index(parent=PARENT, index=ann_index)
#ann_index.result()

In [57]:
ann_index

<google.api_core.operation.Operation at 0x7891b35320b0>

# Deploy Index

In [None]:
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name='flipkart_multimodal_streaming_tree_cosine',
    description='Based on 18k flipkart product listings for which we have both a description and image',
    public_endpoint_enabled=True,
)

In [62]:
tree_ah_index = aiplatform.MatchingEngineIndex(index_name='6070927064486117376')

In [63]:
deployed_index_name = 'flipkart_multimodal_streaming'
my_index_endpoint = my_index_endpoint.deploy_index(index=tree_ah_index, deployed_index_id=deployed_index_name)
my_index_endpoint.deployed_indexes

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploying index MatchingEngineIndexEndpoint index_endpoint: projects/411826505131/locations/us-central1/indexEndpoints/4967334049547812864
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploy index MatchingEngineIndexEndpoint index_endpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/4967334049547812864/operations/8471804643259711488
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint index_endpoint Deployed index. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/4967334049547812864


[id: "flipkart_multimodal_streaming"
index: "projects/411826505131/locations/us-central1/indexes/9179325601046069248"
create_time {
  seconds: 1701227838
  nanos: 150044000
}
index_sync_time {
  seconds: 1701229381
  nanos: 491033000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

# Query Index

In [None]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing


# Inspired from https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  text_embedding: typing.Sequence[float]
  image_embedding: typing.Sequence[float]


class EmbeddingPredictionClient:
  """Wrapper around Prediction Service Client."""
  def __init__(self, project : str,
    location : str = "us-central1",
    api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com"):
    client_options = {"api_endpoint": api_regional_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project

  def get_embedding(self, text : str = None, image_path : str = None):
    """image_path can be a local path or a GCS URI."""
    if not text and not image_path:
      raise ValueError('At least one of text or image_bytes must be specified.')

    instance = struct_pb2.Struct()
    if text:
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    instances = [instance]
    endpoint = (f"projects/{self.project}/locations/{self.location}"
      "/publishers/google/models/multimodalembedding@001")
    response = self.client.predict(endpoint=endpoint, instances=instances)

    text_embedding = None
    if text:
      text_emb_value = response.predictions[0]['textEmbedding']
      text_embedding = [v for v in text_emb_value]

    image_embedding = None
    if image_path:
      image_emb_value = response.predictions[0]['imageEmbedding']
      image_embedding = [v for v in image_emb_value]

    return EmbeddingResponse(
      text_embedding=text_embedding,
      image_embedding=image_embedding)

@cache
def get_client(project):
  return EmbeddingPredictionClient(project)


def embed(project,text,image_path=None):
  client = get_client(project)
  start = time.time()
  response = client.get_embedding(text=text, image_path=image_path)
  end = time.time()
  print('Embedding Time: ', end - start)
  return response

In [None]:
res = embed(PROJECT_ID,
            "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,Specifications of Vishudh Printed Women's Straight Kurta Kurta Details Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach",
            'gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg')

Embedding Time:  0.8610877990722656


In [None]:
NUM_NEIGHBORS = 5
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding,res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
)

response

[[MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_T', distance=0.0),
  MatchNeighbor(id='0305111c779fe663bd94122bef0f0002_T', distance=0.18872332572937012),
  MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.2593117952346802),
  MatchNeighbor(id='6ef0a5eb033cd610d455be7102da5685_T', distance=0.34411925077438354),
  MatchNeighbor(id='ee383a337af67ae8ad4f42714d67ddaf_T', distance=0.34411925077438354)],
 [MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_I', distance=0.0),
  MatchNeighbor(id='06ad8323cf9105f1aaae6515cf08a7d6_I', distance=0.03433966636657715),
  MatchNeighbor(id='c9c27aa5dc7df49e82e55e8abb6b4020_I', distance=0.04238319396972656),
  MatchNeighbor(id='169902631b89202f0e2079e9cc09b3c7_I', distance=0.047689199447631836),
  MatchNeighbor(id='5614ccefd0ab9bee5cd28bf3d38fd12f_I', distance=0.06253844499588013)]]