In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Overview

This notebook demonstrates using the multimodal [Vertex AI Embedding API](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-image-embeddings) to enhance our data with embedded representations.

Input:
- A BigQuery table with a text column and GCS URIs

Output:
- Two embeddings for each product stored in the BigQuery table.
- One based on the product description, and the other based on the first product image


It is important to note that text and image data are not fused into a single embedding. The service provides text-only and image-only embeddings, but they share an embedding space. In other words, the text ‘cat’ and a picture of a cat should return similar embeddings. This is useful because it allows us at inference time to operate with text-only or image-only inputs instead of requiring both.

# Setup

### Install Dependencies (If Needed)

The list `packages` contains tuples of package import names and install names. If the import name is not found, then the install name is used to install the package quietly for the current user.

In [1]:
# tuples of (import name, install name)
packages = [
    ('google.cloud.aiplatform', 'google-cloud-aiplatform'),
]

import importlib
install = False
for package in packages:
    if not importlib.util.find_spec(package[0]):
        print(f'installing package {package[1]}')
        install = True
        !pip install {package[1]} -U -q --user

### Restart Kernel (If Installs Occurred)

After a kernel restart the code submission can start with the next cell after this one.

In [2]:
if install:
    # A package was installed by the dependency cell: restart the kernel so
    # the new package is picked up.
    import IPython
    kernel_app = IPython.Application.instance()
    kernel_app.kernel.do_shutdown(True)  # True asks for a restart after shutdown

### Authenticate

If you are using Colab, you will need to authenticate yourself first. The next cell will check if you are currently using Colab, and will start the authentication process.

In [3]:
import sys

# The Colab check: google.colab is already loaded when running inside Colab.
# Everywhere else this cell does nothing.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import auth as google_auth
    google_auth.authenticate_user()

### Config

Update the below variables to point to the BigQuery table and GCS bucket you created in the notebook 0_EDA_flipkart_dataset.ipynb

In [8]:
# Google Cloud project ID; passed to embed() and to bigquery.Client() in later cells.
PROJECT = 'solutions-2023-mar-107' # @param {type:"string"}
# Region setting.
# NOTE(review): no visible cell reads LOCATION; the embedding client uses its
# own "us-central1" default. Confirm whether it should be passed through.
LOCATION = 'us-central1' # @param {type:"string"}
# BigQuery table that is queried for rows lacking embeddings and then updated.
BQ_TABLE = 'solutions-2023-mar-107.flipkart.18K_no_duplicate' # @param {type:"string"}
# Sample product description embedded by the "Invoke" cell.
TEST_DESCRIPTION = "Key Features of Vishudh Printed Women's Straight Kurta BLACK, GREY Straight,Specifications of Vishudh Printed Women's Straight Kurta Kurta Details Sleeve Sleeveless Number of Contents in Sales Package Pack of 1 Fabric 100% POLYESTER Type Straight Neck ROUND NECK General Details Pattern Printed Occasion Festive Ideal For Women's In the Box Kurta Additional Details Style Code VNKU004374 BLACK::GREY Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach" # @param {type:"string"}
# Sample product image (GCS URI) embedded by the "Invoke" cell.
TEST_IMAGE = 'gs://genai-product-catalog/flipkart_20k_oct26/3ecb859759e5311cbab6850e98879522_0.jpg' # @param {type:"string"}

# Online Embedding API

Client code for the [multimodal embedding API](https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-multimodal-embeddings#api-usage). We reproduce it below with minor modifications.

In [9]:
import base64
from google.cloud import aiplatform
from google.protobuf import struct_pb2
from functools import cache
import time
import typing
import logging


class EmbeddingResponse(typing.NamedTuple):
  """Text and image embeddings returned for a single request.

  A field is None when the corresponding input (text or image) was not part
  of the request.
  """
  text_embedding: typing.Optional[typing.Sequence[float]]
  image_embedding: typing.Optional[typing.Sequence[float]]


class EmbeddingPredictionClient:
  """Wrapper around Prediction Service Client.

  Builds requests for the multimodal embedding publisher model and unpacks
  the text and image embeddings from the prediction response.
  """

  # Publisher model queried by get_embedding.
  MODEL = "multimodalembedding@001"
  # Longest text sent to the model; longer inputs are truncated.
  MAX_TEXT_LENGTH = 1024

  def __init__(self, project : str,
    location : str = "us-central1",
    api_regional_endpoint: typing.Optional[str] = None):
    """Create the underlying PredictionServiceClient.

    Args:
      project: Project ID used in the model resource path.
      location: Region used in the model resource path.
      api_regional_endpoint: API host to connect to. When omitted it is
        derived from `location`, so the host and the resource path always
        refer to the same region.
    """
    if api_regional_endpoint is None:
      api_regional_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_regional_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    self.client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
    self.location = location
    self.project = project

  def get_embedding(self, text : typing.Optional[str] = None,
    image_path : typing.Optional[str] = None):
    """Request embeddings for a text, an image, or both.

    Args:
      text: Text to embed. Anything beyond MAX_TEXT_LENGTH characters is
        dropped, with a logged warning.
      image_path: Image to embed; can be a local path or a GCS URI.

    Returns:
      An EmbeddingResponse whose fields are None for inputs that were not
      supplied, or None when the prediction request or response parsing fails.

    Raises:
      ValueError: If neither text nor image_path is given.
      OSError: If a local image_path cannot be opened or read.
    """
    if not text and not image_path:
      raise ValueError('At least one of text or image_path must be specified.')

    instance = struct_pb2.Struct()

    if text:
      if len(text) > self.MAX_TEXT_LENGTH:
        logging.warning(
          'Text must be at most %d characters. Truncating text.',
          self.MAX_TEXT_LENGTH)
        text = text[:self.MAX_TEXT_LENGTH]
      instance.fields['text'].string_value = text

    if image_path:
      image_struct = instance.fields['image'].struct_value
      if image_path.lower().startswith('gs://'):
        # GCS images are passed by reference.
        image_struct.fields['gcsUri'].string_value = image_path
      else:
        # Local images are sent inline as base64 text.
        with open(image_path, "rb") as f:
          image_bytes = f.read()
        encoded_content = base64.b64encode(image_bytes).decode("utf-8")
        image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

    instances = [instance]
    endpoint = (f"projects/{self.project}/locations/{self.location}"
      f"/publishers/google/models/{self.MODEL}")
    try:
      response = self.client.predict(endpoint=endpoint, instances=instances)
      prediction = response.predictions[0]

      # Only read the embeddings that were requested.
      text_embedding = list(prediction['textEmbedding']) if text else None
      image_embedding = list(prediction['imageEmbedding']) if image_path else None

      return EmbeddingResponse(
        text_embedding=text_embedding,
        image_embedding=image_embedding)
    except Exception:
      # Best-effort contract: failures are reported by returning None. Log the
      # full traceback so the cause is not lost.
      logging.exception('Embedding request failed.')
      return None
    
@cache
def get_client(project):
  """Return the EmbeddingPredictionClient for `project`.

  The result is memoized by functools.cache, so repeated calls with the same
  project reuse one client instead of constructing a new one.
  """
  prediction_client = EmbeddingPredictionClient(project)
  return prediction_client


def embed(project,text,image_path=None):
  """Embed `text` and, optionally, the image at `image_path`.

  Prints the wall-clock duration of the request and returns whatever
  get_embedding returned for it.
  """
  embedding_client = get_client(project)
  started_at = time.time()
  result = embedding_client.get_embedding(text=text, image_path=image_path)
  elapsed = time.time() - started_at
  print('Embedding Time: ', elapsed)
  return result


### Invoke

In [11]:
# Embed the sample description and image from the config cell, then show the
# first five values of each embedding.
res = embed(PROJECT, TEST_DESCRIPTION, TEST_IMAGE)
for sample_embedding in (res.text_embedding, res.image_embedding):
    print(sample_embedding[:5])

Embedding Time:  0.4698326587677002
[-0.0165299866, -0.0692435578, 0.0147973252, 0.0349166617, 0.00536282221]
[-0.00627771486, 0.0557949618, -0.0300531555, 0.0268286057, 0.0392316505]


# Add Embeddings to BQ

This is a naive approach which loads BQ rows sequentially and embeds using the online embedding API. We use pagination to avoid OOM errors for large datasets.

We batch BQ row updates for efficiency. The embedding API is rate limited at 120 req/min so this is as fast as we can go.

In the future, pending product updates, there may be more efficient options:
1. Embed directly from BQ. This is supported for textembedding-gecko today (https://cloud.google.com/blog/products/data-analytics/introducing-bigquery-text-embeddings), and if support rolls out for multimodal this would be the ideal choice
2. Once a batch embedding API is available use that instead

In [None]:
from google.cloud import bigquery

# Number of rows fetched per API page while iterating over the results.
PAGE_SIZE = 1000

client = bigquery.Client(PROJECT)
# Only fetch rows with no embedding. Bypass this query to update all rows
query = f"""
SELECT
  id,
  description,
  image_uri
FROM
  `{BQ_TABLE}`
WHERE
  ARRAY_LENGTH(text_embedding) = 0
"""
query_job = client.query(query)
query_job.result()  # Wait for the query to finish before reading its destination.
destination = query_job.destination
# Page through every result row. page_size only bounds the size of each page;
# max_results would instead cap the total number of rows the iterator yields.
rows = client.list_rows(destination, page_size=PAGE_SIZE)
print(rows.total_rows)

In [None]:
BATCH_SIZE=10 # Number of rows written per BigQuery UPDATE statement
# Descriptions are cut to this many characters before embedding: the API claims
# to support up to 1024 chars but in practice errors occur for shorter lengths.
MAX_DESCRIPTION_CHARS = 900
text_embeddings, image_embeddings, ids = [], [], []


def flush_embeddings_to_bq(ids, text_embeddings, image_embeddings):
    """Write one batch of embeddings to BQ_TABLE with a single UPDATE.

    The three lists are parallel: position k holds the id and both embeddings
    of the same row. Does nothing when the batch is empty.
    """
    if not ids:
        return
    print(f'\nBATCHING {len(ids)} UPDATES TO BQ...')
    # NOTE: ids are interpolated into the SQL text as double-quoted literals,
    # which assumes they contain no quotes or backslashes.
    text_cases = ''.join(
        f'WHEN id = "{row_id}" THEN {embedding}\n'
        for row_id, embedding in zip(ids, text_embeddings))
    image_cases = ''.join(
        f'WHEN id = "{row_id}" THEN {embedding}\n'
        for row_id, embedding in zip(ids, image_embeddings))
    id_list = ', '.join(f'"{row_id}"' for row_id in ids)
    query = f"""
    UPDATE
      `{BQ_TABLE}`
    SET
      text_embedding = (
        CASE
          {text_cases}
        END),
      image_embedding = (
        CASE
          {image_cases}
        END)
    WHERE
      id IN ({id_list})
    """
    start = time.time()
    query_job = client.query(query)
    query_job.result()  # Wait for the query to complete.
    end = time.time()
    print('BQ Update Time: ', end - start)


for i,row in enumerate(rows):
    print(f'\n{i+1}: {row["description"]}\nimage_uri:{row["image_uri"]}')
    description = (row["description"] or '')[:MAX_DESCRIPTION_CHARS]
    image_uri = row["image_uri"]
    if not description or not image_uri:
        # Both embeddings are written for every row, so both inputs are required.
        print('Skipping row: description or image_uri is missing')
        continue
    res = embed(PROJECT, description, image_uri)
    if res is None:
        # embed returns None when the embedding request failed.
        print('Skipping row: embedding request failed')
        continue
    print(res.text_embedding[:5])
    print(res.image_embedding[:5])
    text_embeddings.append(res.text_embedding)
    image_embeddings.append(res.image_embedding)
    ids.append(row["id"])
    if len(ids) == BATCH_SIZE:
        flush_embeddings_to_bq(ids, text_embeddings, image_embeddings)
        text_embeddings, image_embeddings, ids = [], [], []

# Write the final partial batch so trailing rows are not dropped when the
# number of embedded rows is not a multiple of BATCH_SIZE.
flush_embeddings_to_bq(ids, text_embeddings, image_embeddings)
text_embeddings, image_embeddings, ids = [], [], []