# Objective

Given a list of embeddings with restricts (here, the category namespaces `L0` and `L1`), create and deploy a Vertex AI Vector Search (formerly Matching Engine) batch index, then query it with filters to narrow down the search.

# Install libraries

In [None]:
!pip install google-cloud-aiplatform



---

#### ⚠️ After installing, restart the runtime (click the "RESTART RUNTIME" button above) so the newly installed package is loaded.

---

# Authenticate

In [None]:
import sys

# Only Colab needs an interactive login here; other environments are
# assumed to already provide Google Cloud credentials.
RUNNING_IN_COLAB = 'google.colab' in sys.modules
if RUNNING_IN_COLAB:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

# Create Index

In [None]:
# Colab form parameters — change these for your own project.
PROJECT_ID = 'solutions-2023-mar-107'  # @param {type:"string"}
REGION = 'us-central1'  # @param {type:"string"}
# GCS folder holding the embedding files (with restricts) the index is built from.
BUCKET_URI = 'gs://vector_search_regional/test_filterings'  # @param {type:"string"}

In [None]:
from google.cloud import aiplatform

# Set SDK-wide defaults so later calls need not repeat project, region or bucket.
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=BUCKET_URI,
)

In [None]:
# Must match the length of the multimodalembedding@001 vectors used as queries below.
EMBEDDING_DIMENSIONS = 1408

# Build a Tree-AH index from the embedding files stored under BUCKET_URI.
tree_ah_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name='flipkart_batch_filtering',
    description='Based on ~18K Flipkart product listings with both description and image',
    contents_delta_uri=BUCKET_URI,
    dimensions=EMBEDDING_DIMENSIONS,
    distance_measure_type="COSINE_DISTANCE",
    approximate_neighbors_count=150,
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
)


# Deploy Index

In [None]:
DEPLOYED_INDEX_ID = 'flipkart_batch_filtering'

# A public endpoint lets find_neighbors be called without VPC peering.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name='flipkart_batch_filtering_endpoint',
    description='Endpoint for flipkart_batch_filtering',
    public_endpoint_enabled=True,
)

# Rebind to the endpoint object returned by deploy_index.
my_index_endpoint = my_index_endpoint.deploy_index(
    index=tree_ah_index,
    deployed_index_id=DEPLOYED_INDEX_ID,
)
my_index_endpoint.deployed_indexes

# Query index

In [None]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing
import logging

# Typed NamedTuple pattern, see https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  """Text and/or image embedding returned by EmbeddingPredictionClient.get_embedding.

  Each field is a list of floats when the corresponding input was supplied,
  and None otherwise.
  """
  text_embedding: typing.Optional[typing.Sequence[float]]
  image_embedding: typing.Optional[typing.Sequence[float]]


class EmbeddingPredictionClient:
  """Wrapper around the Vertex AI PredictionServiceClient for multimodal embeddings.

  Sends text and/or an image to the `multimodalembedding@001` publisher model
  and returns the resulting vectors as an EmbeddingResponse.
  """

  _MODEL = "multimodalembedding@001"
  # Longer text is truncated to this many characters (with a warning) before sending.
  _MAX_TEXT_CHARS = 1024

  def __init__(self, project: str,
               location: str = "us-central1",
               api_regional_endpoint: typing.Optional[str] = None):
    """
    Args:
      project: Google Cloud project the prediction requests are made in.
      location: Region of the publisher model, e.g. "us-central1".
      api_regional_endpoint: API host to call. Defaults to
        "<location>-aiplatform.googleapis.com" so it always matches `location`.
    """
    if api_regional_endpoint is None:
      # Derive the host from `location`; a fixed us-central1 host would silently
      # disagree with any other region passed in.
      api_regional_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_regional_endpoint}
    # Created once and reused for every request made through this wrapper.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project

  def get_embedding(self, text: typing.Optional[str] = None,
                    image_path: typing.Optional[str] = None) -> "EmbeddingResponse":
    """Embed `text`, the image at `image_path`, or both in a single request.

    Args:
      text: Text to embed; truncated to 1024 characters if longer.
      image_path: Local file path or GCS URI ("gs://...") of the image to embed.

    Returns:
      EmbeddingResponse whose `text_embedding` / `image_embedding` is a list of
      floats for each input that was provided, and None otherwise.

    Raises:
      ValueError: if neither `text` nor `image_path` is given (or both are empty).
    """
    if not text and not image_path:
      raise ValueError('At least one of text or image_path must be specified.')

    instance = struct_pb2.Struct()
    if text:
      if len(text) > self._MAX_TEXT_CHARS:
        logging.warning('Text must be at most %d characters. Truncating text.',
                        self._MAX_TEXT_CHARS)
        text = text[:self._MAX_TEXT_CHARS]
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        # GCS images are sent by reference instead of being read locally.
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          encoded_content = base64.b64encode(f.read()).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    endpoint = (f"projects/{self.project}/locations/{self.location}"
                f"/publishers/google/models/{self._MODEL}")
    response = self.client.predict(endpoint=endpoint, instances=[instance])
    prediction = response.predictions[0]

    # Only read the embeddings for the inputs that were actually sent.
    return EmbeddingResponse(
      text_embedding=list(prediction['textEmbedding']) if text else None,
      image_embedding=list(prediction['imageEmbedding']) if image_path else None)

@cache
def get_client(project):
  """Return the shared EmbeddingPredictionClient for `project`.

  Memoized, so repeated calls with the same project reuse one client
  instead of constructing a new prediction client each time.
  """
  client = EmbeddingPredictionClient(project)
  return client


def embed(project, text, image_path=None):
  """Embed `text` (and optionally an image) and print how long the call took.

  Args:
    project: Google Cloud project ID used for the prediction request.
    text: Text to embed.
    image_path: Optional local path or GCS URI of an image to embed.

  Returns:
    The EmbeddingResponse produced by the cached per-project client.
  """
  client = get_client(project)
  # perf_counter is monotonic, so wall-clock adjustments can't skew the duration.
  start = time.perf_counter()
  response = client.get_embedding(text=text, image_path=image_path)
  elapsed = time.perf_counter() - start
  print('Embedding Time: ', elapsed)
  return response

In [None]:
# Embed one catalog product (description + image) to use as the search query.
query_text = "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,Specifications of Vishudh Printed Women's Straight Kurta Kurta Details Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach"
query_image_uri = 'gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg'

res = embed(PROJECT_ID, text=query_text, image_path=query_image_uri)

# Peek at the first few components of each vector.
for vector in (res.text_embedding, res.image_embedding):
    print(vector[:5])

Embedding Time:  0.9179873466491699
[-0.0165299587, -0.0692435279, 0.0147972982, 0.0349166, 0.00536285108]
[-0.00627784571, 0.0557938404, -0.0300552715, 0.0268275943, 0.0392336547]


In [None]:
!gcloud ai index-endpoints list \
  --project=solutions-2023-mar-107 \
  --region=us-central1

Using endpoint [https://us-central1-aiplatform.googleapis.com/]
---
createTime: '2023-12-18T08:15:01.086330Z'
description: Endpoint on flipkart
displayName: flipkart_batch
encryptionSpec: {}
etag: AMEw9yMncLfvNLAtX8rSW4pi_mUSXH_R9oNvIGj-SKLerlSS34Prq25jE19M_qsiEp0n
name: projects/411826505131/locations/us-central1/indexEndpoints/6297373683249840128
publicEndpointDomainName: 1154454212.us-central1-411826505131.vdb.vertexai.goog
updateTime: '2023-12-18T08:15:01.784596Z'
---
createTime: '2023-12-08T10:20:13.996148Z'
deployedIndexes:
- createTime: '2023-12-08T10:24:03.327365Z'
  dedicatedResources:
    machineSpec:
      machineType: e2-standard-16
    maxReplicaCount: 2
    minReplicaCount: 2
  deploymentGroup: default
  displayName: flipkart_streaming_filtering_endpoint
  id: flipkart_streaming_filteri_1702030773989
  index: projects/411826505131/locations/us-central1/indexes/6070927064486117376
  indexSyncTime: '2023-12-18T08:23:20.322743Z'
displayName: flipkart_streaming_filtering_endp

In [None]:
# Re-attach to an already-deployed endpoint (e.g. after a kernel restart) without redeploying.
# Passing only the endpoint ID lets the SDK fill in project and location from aiplatform.init(),
# so this cell follows PROJECT_ID / REGION instead of a hard-coded resource path.
DEPLOYED_INDEX_ID = 'flipkart_batch_filtering'
INDEX_ENDPOINT_ID = '4149015923505758208'
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=INDEX_ENDPOINT_ID)



  id: flipkart_batch_filtering
  index: projects/411826505131/locations/us-central1/indexes/4261605914190020608
  
description: Endpoint for flipkart_batch_filtering
displayName: flipkart_batch_filtering_endpoint
name: projects/411826505131/locations/us-central1/indexEndpoints/4149015923505758208


## Without Filters

In [None]:
NUM_NEIGHBORS = 5

# Both embeddings go in one request; results come back as one list per query, in order.
queries = [res.text_embedding, res.image_embedding]
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=queries,
    num_neighbors=NUM_NEIGHBORS,
)

In [None]:
# One list of MatchNeighbor per query: text-embedding matches first, image-embedding matches second.
response

[[MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_T', distance=1.7881393432617188e-07),
  MatchNeighbor(id='0305111c779fe663bd94122bef0f0002_T', distance=0.1887233853340149),
  MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='6ef0a5eb033cd610d455be7102da5685_T', distance=0.3441193103790283),
  MatchNeighbor(id='ee383a337af67ae8ad4f42714d67ddaf_T', distance=0.3441193103790283)],
 [MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_I', distance=0.0),
  MatchNeighbor(id='06ad8323cf9105f1aaae6515cf08a7d6_I', distance=0.03433966636657715),
  MatchNeighbor(id='c9c27aa5dc7df49e82e55e8abb6b4020_I', distance=0.04238319396972656),
  MatchNeighbor(id='169902631b89202f0e2079e9cc09b3c7_I', distance=0.047689199447631836),
  MatchNeighbor(id='5614ccefd0ab9bee5cd28bf3d38fd12f_I', distance=0.06253844499588013)]]

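## Filter at root category level

Only datapoints whose `L0` restrict allows `Clothing` are returned. Note that the query product itself (`3ecb859759e5311cbab6850e98879522`), which was the top match without filters, does not appear in the filtered results. Presumably its `L0` restrict is not `Clothing`; this is worth checking in the source data.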

In [None]:
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import Namespace

# Keep only datapoints whose "L0" (root category) restrict allows "Clothing".
root_category_filter = [Namespace("L0", ["Clothing"])]
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding, res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
    filter=root_category_filter,
)

response

[[MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='37e703b75b9a2c465aa78f8264544afd_T', distance=0.3462928533554077),
  MatchNeighbor(id='0ea5dbf6a8e3885960ea92c56f8b2a03_T', distance=0.3552004098892212),
  MatchNeighbor(id='4925cd62bb5f1501264c40973f8fa3d5_T', distance=0.3700721263885498),
  MatchNeighbor(id='e5739673dd302b6f4386197b3ae0ea06_T', distance=0.38109153509140015)],
 [MatchNeighbor(id='08ea3424aab3ec97281f82a0510c6775_I', distance=0.1566087007522583),
  MatchNeighbor(id='c0739a9ea229dc557a02ccd124534a0a_I', distance=0.18003523349761963),
  MatchNeighbor(id='67e0c9fd0f00b831697ba6fb4215ea7b_I', distance=0.20431339740753174),
  MatchNeighbor(id='dd01bbc58bcbf6fbfe6e1455209b6ffe_I', distance=0.20633137226104736),
  MatchNeighbor(id='d9ae13cdeba5ebf2205046a23c7caffd_I', distance=0.21051812171936035)]]

## Filter at subcategory level

In [None]:
# Keep only datapoints whose "L1" (subcategory) restrict allows "Women's Clothing".
sub_category_filter = [Namespace("L1", ["Women's Clothing"])]
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding, res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
    filter=sub_category_filter,
)

response

[[MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='37e703b75b9a2c465aa78f8264544afd_T', distance=0.3462928533554077),
  MatchNeighbor(id='4925cd62bb5f1501264c40973f8fa3d5_T', distance=0.3700721263885498),
  MatchNeighbor(id='8cc202041fb2ff897fb5ca47203edd0d_T', distance=0.38541924953460693),
  MatchNeighbor(id='08ea3424aab3ec97281f82a0510c6775_T', distance=0.39698123931884766)],
 [MatchNeighbor(id='08ea3424aab3ec97281f82a0510c6775_I', distance=0.1566087007522583),
  MatchNeighbor(id='c0739a9ea229dc557a02ccd124534a0a_I', distance=0.18003523349761963),
  MatchNeighbor(id='67e0c9fd0f00b831697ba6fb4215ea7b_I', distance=0.20431339740753174),
  MatchNeighbor(id='d9ae13cdeba5ebf2205046a23c7caffd_I', distance=0.21051812171936035),
  MatchNeighbor(id='3c2ff53673543e77a049a2765c430a82_I', distance=0.230016827583313)]]

## Filter at multiple categories

In [None]:
# Restricts in different namespaces are combined with AND (per the Vector Search
# filtering docs), so a match must satisfy both the L0 and the L1 constraint.
combined_filter = [
    Namespace("L0", ["Clothing"]),
    Namespace("L1", ["Women's Clothing"]),
]
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding, res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
    filter=combined_filter,
)
response

[[MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='37e703b75b9a2c465aa78f8264544afd_T', distance=0.3462928533554077),
  MatchNeighbor(id='4925cd62bb5f1501264c40973f8fa3d5_T', distance=0.3700721263885498),
  MatchNeighbor(id='8cc202041fb2ff897fb5ca47203edd0d_T', distance=0.38541924953460693),
  MatchNeighbor(id='08ea3424aab3ec97281f82a0510c6775_T', distance=0.39698123931884766)],
 [MatchNeighbor(id='08ea3424aab3ec97281f82a0510c6775_I', distance=0.1566087007522583),
  MatchNeighbor(id='c0739a9ea229dc557a02ccd124534a0a_I', distance=0.18003523349761963),
  MatchNeighbor(id='67e0c9fd0f00b831697ba6fb4215ea7b_I', distance=0.20431339740753174),
  MatchNeighbor(id='d9ae13cdeba5ebf2205046a23c7caffd_I', distance=0.21051812171936035),
  MatchNeighbor(id='3c2ff53673543e77a049a2765c430a82_I', distance=0.230016827583313)]]