# Objective

Given a set of embeddings, create a Vertex AI Vector Search (formerly known as Matching Engine) index and deploy it to an index endpoint.

This notebook assumes the embeddings are already stored in GCS in an input format supported by Vertex AI Vector Search. For how to produce them, see:
1. [Notebook](https://source.corp.google.com/piper///depot/google3/experimental/genaisa/product_catalog/notebooks/0_generate_multimodal_embeddings.ipynb) for generating multimodal embeddings and storing them in BigQuery (BQ).
2. [Saved query](https://pantheon.corp.google.com/bigquery?ws=!1m7!1m6!12m5!1m3!1ssolutions-2023-mar-107!2sus-central1!3s1b67da64-ecbc-42d6-945b-5fe1df4e559f!2e1) used to convert the embeddings stored in BQ into the Vector Search ingestion format. Vector Search does not yet support importing directly from BQ.

In [None]:
# Install the Vertex AI SDK for Python (google-cloud-aiplatform).
# NOTE(review): without `--upgrade`, pip keeps any already-installed (possibly
# older) version of the package; consider `pip install --upgrade` if calls used
# below (e.g. find_neighbors) turn out to be missing.
!pip install google-cloud-aiplatform

---

#### ⚠️ Restart the runtime (click the "RESTART RUNTIME" button above) before running the next cells, so the newly installed package is picked up.

---

In [None]:
import sys

# Authenticate the notebook user only when running inside Google Colab;
# in any other environment this cell does nothing.
_in_colab = 'google.colab' in sys.modules
if _in_colab:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

In [None]:
# GCP project and Vertex AI region used for every SDK call in this notebook.
PROJECT_ID = "solutions-2023-mar-107"
REGION = "us-central1"

# GCS directory where the embeddings to ingest are stored. It is also passed
# to aiplatform.init as the SDK staging bucket.
BUCKET_URI = "gs://vector_search_regional/mercari_multimodal_embeddings/"

In [None]:
from google.cloud import aiplatform

# Register SDK-wide defaults so the index/endpoint calls below do not have to
# repeat the project, region, or staging bucket.
aiplatform.init(
    location=REGION,
    project=PROJECT_ID,
    staging_bucket=BUCKET_URI,
)

# Create Index

In [None]:
# Build a Tree-AH (approximate nearest neighbor) index from the embedding
# files under BUCKET_URI.
# NOTE(review): the Vector Search docs recommend DOT_PRODUCT_DISTANCE together
# with feature_norm_type="UNIT_L2_NORM" over COSINE_DISTANCE (same ranking,
# better-optimized path); switching changes the distance values returned, so
# confirm downstream consumers before changing it.
tree_ah_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
    # Identification.
    display_name='mercari_multimodal_batch_tree_cosine',
    description='Based on ~13K mercari product listings for which we have both a description and image',
    # Data: where the vectors live, and their length (must match every
    # stored embedding).
    contents_delta_uri=BUCKET_URI,
    dimensions=1408,
    distance_measure_type="COSINE_DISTANCE",
    # Search tuning: candidate pool size, vectors per leaf node, and the
    # percentage of leaf nodes probed per query.
    approximate_neighbors_count=150,
    leaf_node_embedding_count=500,
    leaf_nodes_to_search_percent=7,
)

In [None]:
# Remember the full resource name of the new index so it can be looked up
# again later without re-creating it, and show it in the notebook output.
INDEX_RESOURCE_NAME = tree_ah_index.resource_name
print(f"{INDEX_RESOURCE_NAME}")

In [None]:
# Look up the existing index by resource name instead of creating it again.
# Presumably this lets the cells below run in a fresh runtime without redoing
# index creation (INDEX_RESOURCE_NAME must then be set by hand) — confirm.
tree_ah_index = aiplatform.MatchingEngineIndex(index_name=INDEX_RESOURCE_NAME)

# Deploy Index

In [None]:
# Create the index endpoint that the index will be deployed to.
# public_endpoint_enabled=True makes the deployed index queryable through a
# public endpoint.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
    display_name='mercari_Q3',
    description='Endpoint for Q3 development on mercari',
    public_endpoint_enabled=True,
)

In [None]:
# ID of the index deployment on the endpoint; find_neighbors below must pass
# the same ID.
# NOTE(review): 'muiltimodal' looks like a typo of 'multimodal'. Left as-is
# because an existing deployment may already use this exact ID.
DEPLOYED_INDEX_ID = 'muiltimodal_13K'
# Deploy the index; the returned endpoint object is kept for querying below.
my_index_endpoint = my_index_endpoint.deploy_index(
    index=tree_ah_index, deployed_index_id=DEPLOYED_INDEX_ID
)
# Display the endpoint's deployed indexes to confirm the deployment.
my_index_endpoint.deployed_indexes

# Query Index

In [None]:
#TODO move this code to a module
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing


# Typed NamedTuple pattern from
# https://stackoverflow.com/questions/34269772/type-hints-in-namedtuple.
class EmbeddingResponse(typing.NamedTuple):
  """Result of EmbeddingPredictionClient.get_embedding.

  Each field holds the embedding vector for the corresponding input, or None
  when that input was not part of the request.
  """
  text_embedding: typing.Optional[typing.Sequence[float]]
  image_embedding: typing.Optional[typing.Sequence[float]]


class EmbeddingPredictionClient:
  """Wrapper around the Vertex AI PredictionServiceClient for calling the
  `multimodalembedding@001` publisher model.

  The underlying client is created once in __init__ and reused for every
  get_embedding call.
  """

  def __init__(self, project: str,
               location: str = "us-central1",
               api_regional_endpoint: typing.Optional[str] = None):
    """Create the prediction client.

    Args:
      project: GCP project ID used in the model's resource path.
      location: Vertex AI region used in the model's resource path.
      api_regional_endpoint: API host to send requests to. When omitted it is
        derived from `location` (e.g. "us-central1" ->
        "us-central1-aiplatform.googleapis.com"), so changing `location` alone
        no longer leaves requests pointed at the us-central1 host.
    """
    if api_regional_endpoint is None:
      api_regional_endpoint = self._regional_api_endpoint(location)
    client_options = {"api_endpoint": api_regional_endpoint}
    # Created once here and reused for all subsequent requests.
    self.client = aiplatform.gapic.PredictionServiceClient(
        client_options=client_options)
    self.location = location
    self.project = project

  @staticmethod
  def _regional_api_endpoint(location: str) -> str:
    """Return the regional Vertex AI API host for `location`."""
    return f"{location}-aiplatform.googleapis.com"

  def get_embedding(self, text: typing.Optional[str] = None,
                    image_path: typing.Optional[str] = None
                    ) -> "EmbeddingResponse":
    """Embed text and/or an image with the multimodal embedding model.

    Args:
      text: Text to embed; ignored when empty or None.
      image_path: Image to embed; ignored when empty or None. A `gs://` URI
        is sent to the model as a GCS reference; anything else is read as a
        local file and sent inline, base64-encoded.

    Returns:
      EmbeddingResponse whose text_embedding / image_embedding are lists of
      floats for the inputs that were supplied, and None otherwise.

    Raises:
      ValueError: if neither `text` nor `image_path` is given.
    """
    if not text and not image_path:
      raise ValueError('At least one of text or image_path must be specified.')

    instance = struct_pb2.Struct()
    if text:
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        # Pass the GCS object by reference instead of uploading its bytes.
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    endpoint = (f"projects/{self.project}/locations/{self.location}"
                "/publishers/google/models/multimodalembedding@001")
    response = self.client.predict(endpoint=endpoint, instances=[instance])
    # A single instance was sent, so the single prediction holds both vectors.
    prediction = response.predictions[0]

    text_embedding = list(prediction['textEmbedding']) if text else None
    image_embedding = (
        list(prediction['imageEmbedding']) if image_path else None)
    return EmbeddingResponse(
        text_embedding=text_embedding,
        image_embedding=image_embedding)

@cache
def get_client(project):
  """Return the EmbeddingPredictionClient for `project`, creating it only once.

  functools.cache memoizes on `project`, so repeated calls for the same
  project reuse one client and its underlying PredictionServiceClient.
  """
  new_client = EmbeddingPredictionClient(project)
  return new_client


def embed(project, text, image_path=None):
  """Embed `text` and/or `image_path` and print how long the request took.

  Args:
    project: GCP project ID; its client is created once and then cached.
    text: Text to embed, or None/"" to embed only the image.
    image_path: Optional local path or GCS URI of an image to embed.

  Returns:
    The EmbeddingResponse produced by EmbeddingPredictionClient.get_embedding.
  """
  client = get_client(project)
  # perf_counter is monotonic and high-resolution, so the measured latency is
  # not skewed by wall-clock adjustments the way time.time() can be.
  start = time.perf_counter()
  response = client.get_embedding(text=text, image_path=image_path)
  elapsed = time.perf_counter() - start
  print('Embedding Time: ', elapsed)
  return response

In [None]:
# Embed a sample product's title and image together in one request.
res = embed(
    PROJECT_ID,
    text="IZOD Women's Light Gray Thigh Length Pull On Golf Shorts",
    image_path='gs://genai-product-catalog/toy_images/shorts.jpg',
)

In [None]:
# Number of nearest neighbors to return for each query vector.
NUM_NEIGHBORS = 5

# Query the deployed index with the text embedding and the image embedding as
# two separate queries. Both must be non-None, i.e. `res` must have been
# computed from both text and an image.
# NOTE(review): per the SDK docs the result should be one neighbor list per
# query, in query order — confirm against the installed SDK version.
response = my_index_endpoint.find_neighbors(
    deployed_index_id=DEPLOYED_INDEX_ID,
    queries=[res.text_embedding,res.image_embedding],
    num_neighbors=NUM_NEIGHBORS,
)

# Display the raw response in the notebook output.
response

# Future Improvements

1. Use streaming updates so that new items become searchable quickly. This does not appear to be exposed in the Python client yet.
2. Take advantage of the filtering option: add filters for top-level category and for embedding type (text vs. image).