# Fast Embeddings

AlloyDB provides a very useful [embedding() function](https://cloud.google.com/alloydb/docs/ai/work-with-embeddings#embedding-generation) that creates embeddings directly in the database. However, this function does not always perform well when generating large batches of embeddings.

This notebook walks you through generating Vertex AI embeddings for the AlloyDB database used by the [GenWealth Demo App](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/genwealth). It dynamically builds batches of text chunks to embed, sized by the character length of the source data, so that each inference request returns more embeddings, which makes embedding generation much more efficient. It also uses psycopg2 to bulk-load the embeddings into AlloyDB with `COPY` after they are generated. These techniques can significantly speed up generating large batches of embeddings and storing them in AlloyDB compared with the native embedding() function (about 6.5x faster based on limited testing).

## Setup

1. Install and import necessary packages.

In [None]:
# Install required libraries
!pip install psycopg2-binary --quiet

# Import necessary modules
import psycopg2
from tabulate import tabulate

import os, shutil
import tempfile
import json
import time
import csv

from typing import List, Optional
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


2. Define variables to match your local environment.
  > This step assumes you have a secret named `alloydb-secret` stored in Secret Manager. You can use an alternate method to define your password if desired.

In [None]:
# Set GCP and AlloyDB configuration variables
# Replace the placeholder values below (project ID, AlloyDB IP) before running.

# GCP vars
# NOTE(review): region and project_id are not referenced by any visible cell, and
# vertexai.init() is never called, so the Vertex AI SDK presumably falls back to
# environment defaults — confirm these values actually take effect.
region = "us-central1"
project_id = "YOUR-PROJECT-ID"

# AlloyDB vars
alloydb_ip = "X.X.X.X"
database = "ragdemos"
user = "postgres"
# IPython shell magic: runs gcloud and captures its stdout as a list of lines;
# the first line is used as the database password.
password = !gcloud secrets versions access latest --secret="alloydb-secret"
password = str(password[0])

# Embedding vars
# Model name passed to embed_text() by embed_objects().
text_embedding_model_name = 'textembedding-gecko@003'
# NOTE(review): `model` and `task` are not used by the helpers below — embed_text()
# loads its own model and embed_objects() relies on embed_text()'s `task` default.
model = TextEmbeddingModel.from_pretrained(text_embedding_model_name)
task = "SEMANTIC_SIMILARITY"
# Per-request token budget; build_batch_array() turns it into a character budget
# (max_tokens * 3). Presumably chosen to stay under the Vertex AI per-request
# token limit — confirm against current quotas.
max_tokens = 20000

3. Set up the database connection to AlloyDB.

In [None]:
# Establish a connection to AlloyDB
def getconn():
    """Open and return a new psycopg2 connection to the configured AlloyDB instance.

    Connection settings are read from the module-level ``alloydb_ip``,
    ``database``, ``user`` and ``password`` variables set in the configuration
    cell.
    """
    connection_kwargs = dict(
        host=alloydb_ip,
        database=database,
        user=user,
        password=password,
    )
    return psycopg2.connect(**connection_kwargs)


# Single shared connection used by every helper below.
conn = getconn()

4. Retrieve the text data from AlloyDB that you want to embed.

In [None]:
# Store output in array of serializable dictionaries: one dict per row, keyed by
# column name.
source_array = []

# Define database query to get primary key plus text data to embed.
# Ensure you retrieve the id key to uniquely identify the row you are embedding.
sql = """
    SELECT id, overview, analysis FROM embedding_test;
    """

# Run the query
print(f"Running SQL query: {sql}")
with conn.cursor() as cur:
    cur.execute(sql)
    # Column names are identical for every row, so read them once rather than
    # rebuilding the list for each fetched row.
    column_names = [col.name for col in cur.description]
    source_array.extend(dict(zip(column_names, row)) for row in cur.fetchall())

5. Define a helper function that dynamically builds batches of text chunks, so that each inference request returns multiple embeddings from the API.

In [None]:
# Function to build batches for embedding based on max tokens/characters
def build_batch_array(source_array, column_to_embed, max_chars=None):
    """Return the next batch of rows whose texts fit in one embedding request.

    Starting at the global ``index_pointer``, rows are appended while the
    running character total of ``row[column_to_embed]`` stays within
    ``max_chars``. ``index_pointer`` is advanced past every returned row;
    ``source_array`` itself is not modified.

    A row whose text alone exceeds ``max_chars`` is returned as a one-row
    batch, so the caller always makes progress instead of looping forever
    (the embedding API may still truncate or reject such a text).

    Args:
        source_array: List of row dicts.
        column_to_embed: Key whose string value is measured (and later embedded).
        max_chars: Character budget per batch. Defaults to ``max_tokens * 3``
            using the global ``max_tokens``, i.e. a rough ~3 characters per
            token estimate rather than an exact tokenizer count.

    Returns:
        A list of row dicts; empty once every row has been consumed.

    Side effects:
        Updates the globals ``index_pointer``, ``batch_char_count`` (characters
        in the returned batch) and ``total_char_count`` (running total).
    """
    global index_pointer
    global batch_char_count
    global total_char_count

    if max_chars is None:
        max_chars = max_tokens * 3

    batch_array = []
    batch_char_count = 0

    while index_pointer < len(source_array):
        row = source_array[index_pointer]
        # NOTE(review): a SQL NULL arrives from psycopg2 as None, and len(None)
        # raises TypeError here.
        text_len = len(row[column_to_embed])
        # Stop before exceeding the budget, but always accept the first row so a
        # single oversized text cannot stall the caller.
        if batch_array and batch_char_count + text_len > max_chars:
            break
        batch_array.append(row)
        batch_char_count += text_len
        index_pointer += 1

    total_char_count += batch_char_count
    return batch_array

6. Define helper functions to generate embeddings.

In [None]:
# Functions for text embedding and object embedding
def embed_text(
    texts: List[str],
    task: str = "SEMANTIC_SIMILARITY",
    model_name: str = "textembedding-gecko@003",
    #dimensionality: Optional[int] = 768,
) -> List[List[float]]:
    """Embeds texts with a pre-trained, foundational model."""
    model = TextEmbeddingModel.from_pretrained(model_name)
    inputs = [TextEmbeddingInput(text, task) for text in texts]
    #kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
    #embeddings = model.get_embeddings(inputs, **kwargs)
    embeddings = model.get_embeddings(inputs)
    return [embedding.values for embedding in embeddings]

def embed_objects(source_array, column_to_embed):
    """Generate embeddings for ``column_to_embed`` across the rows of ``source_array``.

    Rows are consumed in batches from ``build_batch_array``, starting at the
    global ``index_pointer`` (the caller resets it before each column). Each
    batch is embedded with one ``embed_text`` call.

    Args:
        source_array: List of row dicts, each containing ``column_to_embed``.
        column_to_embed: Key of the text to embed in each row.

    Returns:
        A list with one new dict per processed row: a shallow copy of the source
        row plus an ``'embedding'`` key. The dicts in ``source_array`` are not
        modified.

    Raises:
        RuntimeError: if ``build_batch_array`` returns no rows while rows remain
            (which would otherwise loop forever), or if the number of embeddings
            returned does not match the batch size.
    """
    source_array_length = len(source_array)

    print(f"Beginning source_array size: {source_array_length}")
    result_array = []

    # Define global variables to track progress and estimate cost
    global index_pointer
    global batch_count
    global batch_char_count
    global total_char_count

    while index_pointer < source_array_length:
        # build_batch_array advances index_pointer past the rows it returns;
        # source_array itself is left intact.
        batch_array = build_batch_array(source_array, column_to_embed)

        # An empty (or None) batch while rows remain means batching cannot make
        # progress; fail loudly instead of spinning or crashing on None.
        if not batch_array:
            raise RuntimeError(
                f"build_batch_array returned no rows for column '{column_to_embed}' "
                f"at row {index_pointer} of {source_array_length}; the text may "
                "exceed the per-request character budget."
            )

        batch_count = batch_count + 1
        print(f"Processing batch {batch_count} with size: {len(batch_array)}. Progress: {index_pointer} / {source_array_length}. Character count (batch): {batch_char_count}. Character count (cumulative): {total_char_count}")

        texts_to_embed = [obj[column_to_embed] for obj in batch_array]
        embeddings = embed_text(texts_to_embed, model_name=text_embedding_model_name)

        # Pairing is positional, so a short response would silently drop rows.
        if len(embeddings) != len(batch_array):
            raise RuntimeError(
                f"Expected {len(batch_array)} embeddings, got {len(embeddings)}."
            )

        # Copy each row instead of mutating it, so the shared source rows are
        # not altered while they are reused for the next column.
        for obj, embedding in zip(batch_array, embeddings):
            result_array.append({**obj, 'embedding': embedding})

    return result_array

7. Define helper functions to update embeddings in AlloyDB by creating a temp table, bulk loading the embeddings (along with the primary key of each embedding), updating the target table, and then dropping the temp table.

In [None]:
# Functions to manage temporary table and update the target table
def create_temp_table(column_to_embed):
    """(Re)create the staging table that holds embeddings for one column.

    Any existing table with the same name is dropped first, so the staging
    table always starts empty. It is created as a regular table (not a
    ``TEMPORARY`` one) and the change is committed immediately.

    Args:
        column_to_embed: Source column name, used to derive the table name.

    Returns:
        The staging table name, ``<column_to_embed>_embeddings_temp``.
    """
    staging_table = f"{column_to_embed}_embeddings_temp"
    ddl = f"""
    DROP TABLE IF EXISTS {staging_table};
    CREATE TABLE {staging_table} (
        id INTEGER PRIMARY KEY,
        col_name TEXT,
        embedding REAL[]
    );
    """

    cursor = conn.cursor()
    with cursor:
        cursor.execute(ddl)
    conn.commit()

    return staging_table


def insert_to_temp_table(temp_table_name, column_to_embed, object_array):
    """Bulk-load embeddings into the staging table with ``COPY ... FROM STDIN``.

    Rows are first written to a named temporary file as ``|``-delimited CSV
    (id, column name, embedding as a ``{v1, v2, ...}`` array literal), then
    streamed to the server with ``copy_expert`` and committed.

    The temporary file is always removed, even on failure. If the COPY fails,
    the transaction is rolled back so the shared connection remains usable.

    Args:
        temp_table_name: Staging table to load into.
        column_to_embed: Source column name, stored in ``col_name`` for each row.
        object_array: Dicts with at least ``'id'`` and ``'embedding'`` keys.
    """
    temp_path = None
    try:
        # newline='' lets the csv module control line endings itself.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, newline='') as temp_file:
            temp_path = temp_file.name
            writer = csv.writer(temp_file, delimiter='|', quotechar="'", escapechar="'")
            for obj in object_array:
                # Ensure the embedding is represented as an array literal with curly braces
                embedding_str = "{" + ", ".join(map(str, obj['embedding'])) + "}"
                writer.writerow([obj['id'], column_to_embed, embedding_str])

        try:
            with conn.cursor() as cur, open(temp_path, 'r') as f:
                cur.copy_expert(
                    f"""COPY {temp_table_name} (id, col_name, embedding)
                    FROM STDIN
                    WITH (FORMAT csv, DELIMITER '|', QUOTE '''', ESCAPE '''')""",
                    f
                )
            conn.commit()
        except Exception:
            # A failed COPY leaves the transaction aborted; roll back so later
            # statements on the shared connection don't all fail, then re-raise.
            conn.rollback()
            raise
    finally:
        # Cleanup the temporary file
        if temp_path is not None:
            os.remove(temp_path)


def update_target_table(temp_table_name, target_table_name, column_to_embed):
    """Copy staged embeddings into ``<column_to_embed>_embedding`` on the target table.

    Target rows are matched to staging rows on ``id``. Target rows with no
    matching staging row are left unchanged. The update is committed.

    Args:
        temp_table_name: Staging table holding ``id`` and ``embedding``.
        target_table_name: Table whose embedding column is updated.
        column_to_embed: Source column name; the destination column is this
            name with an ``_embedding`` suffix.
    """
    embedding_column = f"{column_to_embed}_embedding"
    statement = f"""
    UPDATE {target_table_name}
    SET {embedding_column} = {temp_table_name}.embedding
    FROM {temp_table_name}
    WHERE {target_table_name}.id = {temp_table_name}.id;
    """

    cursor = conn.cursor()
    with cursor:
        cursor.execute(statement)
    conn.commit()


def drop_temp_table(temp_table_name):
    """Drop the given staging table and commit.

    Uses a plain ``DROP TABLE`` (no ``IF EXISTS``), so the database rejects the
    statement if the table does not exist.

    Args:
        temp_table_name: Name of the table to drop.
    """
    statement = f"""
    DROP TABLE {temp_table_name};
    """

    cursor = conn.cursor()
    with cursor:
        cursor.execute(statement)
    conn.commit()


8. Define the target table and the source text chunk columns, then run the embedding process.

In [None]:
# Define table where embeddings will be written and columns to be embedded
target_table_name = 'embedding_test'
columns_to_embed = ['analysis', 'overview']

# Progress/cost tracking globals. build_batch_array() and embed_objects() read
# and update these through their own `global` declarations.
index_pointer = 0
batch_count = 0
batch_char_count = 0
total_char_count = 0

# Keep track of job timing
start_time = time.time()

for column_to_embed in columns_to_embed:
    # Restart batching from the first row for each column.
    index_pointer = 0

    print(f"Creating embeddings for column {column_to_embed}...")
    results = embed_objects(source_array, column_to_embed)

    print("Creating temp table to store intermediate results...")
    temp_table_name = create_temp_table(column_to_embed)

    print(f"Inserting embeddings into temp table: {temp_table_name}...")
    insert_to_temp_table(temp_table_name, column_to_embed, results)

    print(f"Merging temp table {temp_table_name} with target table {target_table_name}...")
    update_target_table(temp_table_name, target_table_name, column_to_embed)

    print(f"Dropping temp table {temp_table_name}...")
    drop_temp_table(temp_table_name)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Job started at: {time.ctime(start_time)}")
print(f"Job ended at: {time.ctime(end_time)}")
print(f"Total run time: {elapsed_time:.2f} seconds")