# Setup

In [2]:
!pip install google-cloud-aiplatform

Collecting google-cloud-aiplatform
  Downloading google_cloud_aiplatform-1.36.4-py2.py3-none-any.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: google-cloud-aiplatform
Successfully installed google-cloud-aiplatform-1.36.4


In [1]:
import sys

if 'google.colab' in sys.modules:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

In [2]:
PROJECT_ID = 'solutions-2023-mar-107' # @param {type:"string"}
REGION = 'us-central1' # @param {type:"string"}
#BUCKET_URI = "gs://vector_search_regional/flipkart_multimodal_embeddings" # @param {type:"string"} # WHERE EMBEDDINGS ARE STORED
BUCKET_URI = 'gs://vector_search_regional/test_filterings'# @param {type:"string"}
ENDPOINT = "{}-aiplatform.googleapis.com".format(REGION)

In [3]:
from google.cloud import aiplatform
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# Prepare Data

Sample data in json format to add filters in category level:


{"id": "43_T”, "embedding": [0.6, 1.0],
"restricts": [
{"namespace": “L0”, "allow": c0_name },
{“namespace": “L1”, "allow": c1_name},
{“namespace": “L2”, "allow": c2_name},
{“namespace": “L3”, "allow": c3_name}
]
}


{"id": "43_I”, "embedding": [0.6, 1.0],
"restricts": [
{"namespace": “L0”, "allow": c0_name },
{“namespace": “L1”, "allow": c1_name},
{“namespace": “L2”, "allow": c2_name},
{“namespace": “L3”, "allow": c3_name}
]
}

In [4]:
from google.cloud import bigquery
client = bigquery.Client(PROJECT_ID)
query_job = client.query("""
   SELECT *
   FROM `flipkart.training_embeddings_with_cat_for_filtering`;""")

results = query_job.result() # Wait for the job to complete.

In [5]:
data = results.to_dataframe()

In [6]:
data

Unnamed: 0,id,embedding,L0,L1,L2,L3
0,9f56d6e1481d35f9677a69a983aa81ee_T,"[0.0305967834, -0.0117346784, -0.0113393087, -...",Furniture,Pet Furniture,,
1,efb5a934cefac9456baef772f5d97f52_T,"[0.017123403, -0.0647764653, 0.0103280917, 0.0...",Footwear,Women's Footwear,REMSON INDIA Women Flats,
2,abd4482126e006bd6f3ce9825e9449bc_T,"[0.0181311239, -0.0240292437, 0.00982811209, -...",Mobiles & Accessories,Tablet Accessories,Cases & Covers,MannMohh Cases & Covers
3,635df3139a893577a21554cd4ace1f35_T,"[0.0425712056, -0.0244968068, 0.00395903178, 0...",Mobiles & Accessories,Tablet Accessories,Cases & Covers,kasemantra Cases & Covers
4,135524f18e0ac6e3d008fa81b1576dfa_T,"[-0.0492946468, -0.0495145917, 0.0119116539, -...",Clothing,Women's Clothing,Western Wear,"Shirts, Tops & Tunics"
...,...,...,...,...,...,...
36369,dad7943c3791dfb2d669942e61dfc25d_I,"[-0.00362257822, 0.0448172726, 0.0305924937, 0...",Home Decor & Festive Needs,Showpieces,Ona'S Showpieces,
36370,906b3a5912453ffced2dbd0fc4bd495c_I,"[-0.013985265, 0.0400636345, 0.0220041461, -0....",Footwear,Women's Footwear,Casual Shoes,Boots
36371,2fa79e6a06305fa2ea23a343841b78c3_I,"[-0.0203494355, 0.0481191799, 0.00561492844, -...",Clothing,Women's Clothing,Western Wear,"Shirts, Tops & Tunics"
36372,084ae0b12e0672abfc7f9d125bd1e15b_I,"[-0.0286370851, -0.0289862268, -0.00473337155,...",Computers,Network Components,Routers,Onnet Routers


In [69]:
def adding_filters(data):

  restricts = []

  def add_restricts(col_name):
    namespace = col_name
    allow = []
    allow.append(data[col_name])
    restricts.append({'namespace': namespace, 'allow': allow})

  if data['L0']:
    add_restricts('L0')
    if data['L1']:
      add_restricts('L1')
      if data['L2']:
        add_restricts('L2')
        if data['L3']:
          add_restricts('L3')

  data['restricts']= restricts
  return data


data = data.apply(adding_filters, axis=1)

In [70]:
import json

data[['id','embedding','restricts']].to_json('sample.json', orient='records', lines=True)

In [71]:
#Upload json into GCS
!gsutil cp -r sample.json gs://vector_search_regional/test_filterings

Copying file://sample.json [Content-Type=application/json]...
/ [0 files][    0.0 B/658.0 MiB]                                                ==> NOTE: You are uploading one or more large file(s), which would run
significantly faster if you enable parallel composite uploads. This
feature can be enabled by editing the
"parallel_composite_upload_threshold" value in your .boto
configuration file. However, note that if you do this large files will
be uploaded as `composite objects
<https://cloud.google.com/storage/docs/composite-objects>`_,which
means that any user who downloads such objects will need to have a
compiled crcmod installed (see "gsutil help crcmod"). This is because
without a compiled crcmod, computing checksums on composite objects is
so slow that gsutil disables downloads of composite objects.

- [1 files][658.0 MiB/658.0 MiB]   46.1 MiB/s                                   
Operation completed over 1 objects/658.0 MiB.                                    


In [72]:
BUCKET_URI

'gs://vector_search_regional/test_filterings'

# Create Batch index with sample data

In [73]:
tree_ah_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    display_name='flipkart_batch_filtering',
    contents_delta_uri=BUCKET_URI,
    dimensions=1408,
    approximate_neighbors_count=150,
    distance_measure_type="COSINE_DISTANCE",
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
    description='batch_filtering_test',
)

INDEX_RESOURCE_NAME = tree_ah_index.resource_name
print(INDEX_RESOURCE_NAME)

tree_ah_index = aiplatform.MatchingEngineIndex(index_name=INDEX_RESOURCE_NAME)

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:Creating MatchingEngineIndex
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:Create MatchingEngineIndex backing LRO: projects/411826505131/locations/us-central1/indexes/4261605914190020608/operations/2181313902799749120
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:MatchingEngineIndex created. Resource name: projects/411826505131/locations/us-central1/indexes/4261605914190020608
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:To use this MatchingEngineIndex in another session:
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index:index = aiplatform.MatchingEngineIndex('projects/411826505131/locations/us-central1/indexes/4261605914190020608')


projects/411826505131/locations/us-central1/indexes/4261605914190020608


In [None]:
tree_ah_index = aiplatform.MatchingEngineIndex(index_name=INDEX_RESOURCE_NAME)

In [9]:
INDEX_RESOURCE_NAME = 'flipkart_batch_filtering'
tree_ah_index = aiplatform.MatchingEngineIndex(index_name='4261605914190020608')

##Deploy Batch index

In [10]:
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name='flipkart_batch_filtering_endpoint',
    description='Endpoint for flipkart_batch_filtering',
    public_endpoint_enabled=True,
)


DEPLOYED_INDEX_ID = 'flipkart_batch_filtering'
my_index_endpoint = my_index_endpoint.deploy_index(
    index=tree_ah_index, deployed_index_id=DEPLOYED_INDEX_ID
)
my_index_endpoint.deployed_indexes

INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Creating MatchingEngineIndexEndpoint
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Create MatchingEngineIndexEndpoint backing LRO: projects/411826505131/locations/us-central1/indexEndpoints/4149015923505758208/operations/5059888120875450368
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:MatchingEngineIndexEndpoint created. Resource name: projects/411826505131/locations/us-central1/indexEndpoints/4149015923505758208
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:To use this MatchingEngineIndexEndpoint in another session:
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:index_endpoint = aiplatform.MatchingEngineIndexEndpoint('projects/411826505131/locations/us-central1/indexEndpoints/4149015923505758208')
INFO:google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint:Deploying index Matc

[id: "flipkart_batch_filtering"
index: "projects/411826505131/locations/us-central1/indexes/4261605914190020608"
create_time {
  seconds: 1701274228
  nanos: 60348000
}
index_sync_time {
  seconds: 1701275740
  nanos: 50557000
}
automatic_resources {
  min_replica_count: 2
  max_replica_count: 2
}
deployment_group: "default"
]

## Query Batch index

In [11]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing


# Inspired from https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  text_embedding: typing.Sequence[float]
  image_embedding: typing.Sequence[float]


class EmbeddingPredictionClient:
  """Wrapper around Prediction Service Client."""
  def __init__(self, project : str,
    location : str = "us-central1",
    api_regional_endpoint: str = "us-central1-aiplatform.googleapis.com"):
    client_options = {"api_endpoint": api_regional_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project

  def get_embedding(self, text : str = None, image_path : str = None):
    """image_path can be a local path or a GCS URI."""
    if not text and not image_path:
      raise ValueError('At least one of text or image_bytes must be specified.')

    instance = struct_pb2.Struct()
    if text:
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    instances = [instance]
    endpoint = (f"projects/{self.project}/locations/{self.location}"
      "/publishers/google/models/multimodalembedding@001")
    response = self.client.predict(endpoint=endpoint, instances=instances)

    text_embedding = None
    if text:
      text_emb_value = response.predictions[0]['textEmbedding']
      text_embedding = [v for v in text_emb_value]

    image_embedding = None
    if image_path:
      image_emb_value = response.predictions[0]['imageEmbedding']
      image_embedding = [v for v in image_emb_value]

    return EmbeddingResponse(
      text_embedding=text_embedding,
      image_embedding=image_embedding)

@cache
def get_client(project):
  return EmbeddingPredictionClient(project)


def embed(project,text,image_path=None):
  client = get_client(project)
  start = time.time()
  response = client.get_embedding(text=text, image_path=image_path)
  end = time.time()
  print('Embedding Time: ', end - start)
  return response

In [12]:
res = embed(PROJECT_ID,
            "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,Specifications of Vishudh Printed Women's Straight Kurta Kurta Details Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach",
            'gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg')

Embedding Time:  0.639094352722168


In [13]:
NUM_NEIGHBORS = 5
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding,res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
)

response

[[MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_T', distance=1.7881393432617188e-07),
  MatchNeighbor(id='0305111c779fe663bd94122bef0f0002_T', distance=0.1887233853340149),
  MatchNeighbor(id='ba8163913f5e384d17a8202b1f8b91b3_T', distance=0.25931215286254883),
  MatchNeighbor(id='6ef0a5eb033cd610d455be7102da5685_T', distance=0.3441193103790283),
  MatchNeighbor(id='ee383a337af67ae8ad4f42714d67ddaf_T', distance=0.3441193103790283)],
 [MatchNeighbor(id='3ecb859759e5311cbab6850e98879522_I', distance=0.0),
  MatchNeighbor(id='06ad8323cf9105f1aaae6515cf08a7d6_I', distance=0.03433966636657715),
  MatchNeighbor(id='c9c27aa5dc7df49e82e55e8abb6b4020_I', distance=0.04238331317901611),
  MatchNeighbor(id='169902631b89202f0e2079e9cc09b3c7_I', distance=0.04768931865692139),
  MatchNeighbor(id='5614ccefd0ab9bee5cd28bf3d38fd12f_I', distance=0.0625385046005249)]]

# Create Streaming index with sample data
https://colab.sandbox.google.com/drive/1grzY8Idzq1kMDnFPqp5E7d8RWuuGactA?resourcekey=0-kUMWFahjwGNzSFrZkVHU8g#scrollTo=QRCnFT4B0jnm

# TODO: Update Batch Index

In [None]:
#Save this json as request.json and use in next cell
request = {
  "metadata": {
    "contentsDeltaUri": "INPUT_DIR",
    "isCompleteOverwrite": true
  }
}

In [None]:
curl -X PATCH \
    -H "Authorization: Bearer $(gcloud auth print-access-token)" \
    -H "Content-Type: application/json; charset=utf-8" \
    -d @request.json \
    "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/indexes/INDEX_ID"


# TODO: Update Streaming Index

In [None]:
#Upsert Data Point
DATAPOINT_ID_1=
DATAPOINT_ID_2=
curl -H "Content-Type: application/json" -H "Authorization: Bearer `gcloud auth print-access-token`" https://${LOCATION}-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/indexes/${INDEX_ID}:upsertDatapoints \
-d '{datapoints: [{datapoint_id: "'${DATAPOINT_ID_1}'", feature_vector: [...]},
{datapoint_id: "'${DATAPOINT_ID_2}'", feature_vector: [...]}]}'
