# Generate and store batch embeddings in AlloyDB

This notebook shows you how to generate batch vector embeddings and store them in an AlloyDB database using the [GenWealth Demo App](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/genwealth) as the sample dataset. 

With the steps listed here, you can dynamically build batches of text chunks based on the character length of the source data, so that each inference call returns more embeddings and generation becomes much more efficient. The process then uses the `psycopg2` library to bulk load the generated embeddings into AlloyDB. Together, these techniques can significantly speed up generating and storing large batches of embeddings compared with using AlloyDB's native `embedding()` function (about 6.5x faster based on limited testing).

### Before you Begin
* Download and set up the [GenWealth Demo App](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/genwealth).
* Running the steps in this notebook will incur Google Cloud charges. You may also be billed for Vertex AI API usage.



### Install and Import Necessary Packages

In [None]:
# Install required libraries
%pip install psycopg2-binary --quiet

# Import necessary modules
import psycopg2

import os
import tempfile
import time
import csv

from typing import List, Optional
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


### Define Environment Variables 

Update the variable values in this cell to match your environment. 

> NOTE: This step assumes you have a secret stored in Secret Manager called `alloydb-secret`. You can use an alternate method to define your password if desired.

In [None]:
# Set GCP and AlloyDB configuration variables

# GCP vars
region = "us-central1"
project_id = "YOUR-PROJECT-ID"

# AlloyDB vars
alloydb_ip = "X.X.X.X"
database = "ragdemos"
user = "postgres"
password = !gcloud secrets versions access latest --secret="alloydb-secret"
password = str(password[0])

# Embedding vars
text_embedding_model_name = 'textembedding-gecko@003'
model = TextEmbeddingModel.from_pretrained(text_embedding_model_name)
task = "SEMANTIC_SIMILARITY"
max_tokens = 20000

### Establish a Connection to AlloyDB

In [None]:
# Establish a connection to AlloyDB
def getconn():
    conn = psycopg2.connect(
        host=alloydb_ip,
        database=database,
        user=user,
        password=password,
    )
    return conn

conn = getconn()

### Fetch Source Data to Embed

This step loads the text data you want to embed from AlloyDB (i.e. the `overview` and `analysis` columns in the example below). You can update the SQL query below to match your target environment.  

> IMPORTANT: Make sure to retrieve the primary key along with the columns you want to embed (i.e. `id` in the example below). This key will be used to uniquely identify the rows you are embedding during the bulk load and update process later in this notebook.   

In [None]:
# Store output in array of serializable dictionaries
source_array = []

# Define database query to get primary key plus text data to embed
# Ensure you retrieve the id key to uniquely identify the row you are embedding
sql = """
    SELECT id, overview, analysis FROM investments;
    """

# Run the query
print(f"Running SQL query: {sql}")
with conn.cursor() as cur:
    cur.execute(sql)
    # Read the column names once rather than once per row
    column_names = [col.name for col in cur.description]
    for row in cur.fetchall():
        source_array.append(dict(zip(column_names, row)))

### Define Batching Function

This helper function dynamically builds batches of text chunks to efficiently generate multiple embeddings with each call to the API.

In [None]:
# Function to build batches for embedding based on max tokens/characters
def build_batch_array(source_array, column_to_embed):
    batch_array = []
    current_chars = 0
    # Character budget per batch, derived from the token limit (3 chars per token)
    max_chars = max_tokens * 3

    global index_pointer
    global batch_char_count
    global total_char_count

    while index_pointer < len(source_array):
        text_length = len(source_array[index_pointer][column_to_embed])
        # Stop before exceeding the budget, but always accept at least one row
        # so that a single oversized row cannot stall the loop in embed_objects
        if batch_array and current_chars + text_length > max_chars:
            break
        batch_array.append(source_array[index_pointer])
        current_chars += text_length
        index_pointer += 1

    # Record character counts for every batch, including the final partial one
    batch_char_count = current_chars
    total_char_count += batch_char_count
    return batch_array
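
To see the batching idea in isolation, the sketch below packs a few toy strings under a small character budget. It is a simplified, side-effect-free variant for illustration only: `pack_by_chars`, `toy_rows`, and the 100-character budget are hypothetical and not part of the notebook, and the function builds all batches at once instead of using a global index pointer.

```python
from typing import Dict, List


def pack_by_chars(rows: List[Dict], column: str, max_chars: int) -> List[List[Dict]]:
    """Greedily group consecutive rows so each batch stays within max_chars."""
    batches: List[List[Dict]] = []
    current: List[Dict] = []
    current_chars = 0
    for row in rows:
        length = len(row[column])
        # Start a new batch when adding this row would exceed the budget.
        # A single oversized row still gets a batch of its own.
        if current and current_chars + length > max_chars:
            batches.append(current)
            current, current_chars = [], 0
        current.append(row)
        current_chars += length
    if current:
        batches.append(current)
    return batches


# Toy data (hypothetical): four rows of 40, 50, 30, and 120 characters
toy_rows = [
    {"id": 1, "overview": "a" * 40},
    {"id": 2, "overview": "b" * 50},
    {"id": 3, "overview": "c" * 30},
    {"id": 4, "overview": "d" * 120},
]
batches = pack_by_chars(toy_rows, "overview", max_chars=100)
batch_ids = [[row["id"] for row in batch] for batch in batches]
print(batch_ids)
```

With a 100-character budget, rows 1 and 2 share a batch (90 characters), row 3 starts a new batch, and the oversized row 4 is sent on its own.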

### Define Embedding Functions

This example uses the `textembedding-gecko@003` model to generate embeddings. To use a different embedding model, set the `text_embedding_model_name` variable in the environment variables cell; it is passed to `embed_text` as the `model_name` argument.

If you are using `text-embedding-004` or above, uncomment the commented-out `dimensionality`, `kwargs`, and `embeddings` lines in the cell below, and remove the existing `embeddings = model.get_embeddings(inputs)` line.

In [None]:
# Functions for text embedding and object embedding
def embed_text(
    texts: List[str],
    task: str = "SEMANTIC_SIMILARITY",
    model_name: str = "textembedding-gecko@003",
    #dimensionality: Optional[int] = 768,
) -> List[List[float]]:
    """Embeds texts with a pre-trained, foundational model."""
    model = TextEmbeddingModel.from_pretrained(model_name)
    inputs = [TextEmbeddingInput(text, task) for text in texts]
    #kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
    #embeddings = model.get_embeddings(inputs, **kwargs)
    embeddings = model.get_embeddings(inputs)
    return [embedding.values for embedding in embeddings]

def embed_objects(source_array, column_to_embed):
    source_array_length = len(source_array)

    print(f"Beginning source_array size: {source_array_length}")
    result_array = []

    # Define global variables to track progress and estimate cost
    global index_pointer
    global batch_count
    global batch_char_count
    global total_char_count

    while index_pointer < len(source_array):
        # Get the next batch; build_batch_array advances index_pointer past
        # the rows it returns (source_array itself is not modified)
        batch_array = build_batch_array(source_array, column_to_embed)

        # Stop if no batch could be built, rather than calling the API with no input
        if not batch_array:
            print(f"No batch could be built at row {index_pointer}; stopping.")
            break

        batch_count = batch_count + 1
        print(f"Processing batch {batch_count} with size: {len(batch_array)}. Progress: {index_pointer} / {source_array_length}. Character count (batch): {batch_char_count}. Character count (cumulative): {total_char_count}")

        texts_to_embed = [obj[column_to_embed] for obj in batch_array]
        embeddings = embed_text(texts_to_embed, task=task, model_name=text_embedding_model_name)

        # Attach each embedding to its source row (rows are updated in place)
        for i, obj in enumerate(batch_array):
            obj['embedding'] = embeddings[i]
            result_array.append(obj)

    return result_array

### Define Bulk Load and Update Functions

This step defines functions that update embeddings in the target database using the following process:

1. Create a temp table
2. Bulk load the generated embeddings (along with the primary key of each embedding) into the temp table
3. Update the target table by joining to the temp table based on the primary key. 
4. Drop the temp table.

This method is faster and more efficient than updating rows one by one with multiple round trips to the database.

You can modify these functions to generate embeddings in your own environment by changing the SQL queries to match your table structure. Be sure to also update the primary key name in function `insert_to_temp_table` to match your primary key (the example uses `id` as the PK).

In [None]:
# Functions to manage temporary table and update the target table
def create_temp_table(column_to_embed):
    temp_table_name = f"{column_to_embed}_embeddings_temp"
    # Update the SQL query below to match your environment
    sql = f"""
    DROP TABLE IF EXISTS {temp_table_name};
    CREATE TABLE {temp_table_name} (
        id INTEGER PRIMARY KEY,
        col_name TEXT,
        embedding REAL[]
    );
    """

    with conn.cursor() as cur:
        cur.execute(sql)
    conn.commit()

    return temp_table_name


def insert_to_temp_table(temp_table_name, column_to_embed, object_array):
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:
        writer = csv.writer(temp_file, delimiter='|', quotechar="'", escapechar="'")
        for obj in object_array:
            # Ensure the embedding is represented as an array literal with curly braces
            embedding_str = "{" + ", ".join(map(str, obj['embedding'])) + "}"
            # Update primary key name here (this example uses 'id' as the PK)
            writer.writerow([obj['id'], column_to_embed, embedding_str])

    with conn.cursor() as cur:
        with open(temp_file.name, 'r') as f:
            cur.copy_expert(
                # Use the COPY command for efficient loading of the temp table
                # Update the query below to match your environment
                f"""COPY {temp_table_name} (id, col_name, embedding)
                FROM STDIN
                WITH (FORMAT csv, DELIMITER '|', QUOTE '''', ESCAPE '''')""",
                f
            )
    conn.commit()

    # Cleanup the temporary file
    os.remove(temp_file.name)


def update_target_table(temp_table_name, target_table_name, column_to_embed):
    # Update the SQL query below to match your environment
    sql = f"""
    UPDATE {target_table_name}
    SET {column_to_embed}_embedding = {temp_table_name}.embedding
    FROM {temp_table_name}
    WHERE {target_table_name}.id = {temp_table_name}.id;
    """

    #print(f"Running sql statement: {sql}")
    with conn.cursor() as cur:
        cur.execute(sql)
    conn.commit()


def drop_temp_table(temp_table_name):
    sql = f"""
    DROP TABLE {temp_table_name};
    """

    with conn.cursor() as cur:
        cur.execute(sql)
    conn.commit()
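
As a possible alternative to the temporary file used in `insert_to_temp_table`, the COPY payload can be built in memory. The sketch below only shows the serialization step and needs no database connection. `build_copy_payload` and the sample rows are hypothetical names introduced for illustration, and the sketch assumes the same pipe-delimited layout and curly-brace array literals as the cell above.

```python
import csv
import io
from typing import Dict, List


def build_copy_payload(object_array: List[Dict], column_to_embed: str) -> io.StringIO:
    """Serialize rows as pipe-delimited text with Postgres array literals."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, delimiter="|", quotechar="'", lineterminator="\n")
    for obj in object_array:
        # Postgres array literal, e.g. {0.1,-0.2,0.3}
        embedding_str = "{" + ",".join(map(str, obj["embedding"])) + "}"
        writer.writerow([obj["id"], column_to_embed, embedding_str])
    buffer.seek(0)  # Rewind so a reader starts from the first row
    return buffer


# Sample rows (hypothetical); real embeddings have many more dimensions
sample = [
    {"id": 1, "embedding": [0.1, -0.2, 0.3]},
    {"id": 2, "embedding": [0.0, 1.5, -2.25]},
]
payload = build_copy_payload(sample, "overview")
payload_text = payload.getvalue()
print(payload_text)
```

Because `cur.copy_expert` accepts any readable file-like object, the returned buffer could be passed in place of the open file handle, which avoids writing to and cleaning up a file on disk. For very large result sets, the temporary file approach may still be preferable to keep memory usage down.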


### Run the Embedding Process

This step runs the embedding process using the table structure and functions defined in the cells above. 

To modify this code to run in your environment, update the target table (`investments` in this example) and the source text chunk columns (`analysis` and `overview` in this example) to match your data structure.

In [None]:
# Define table where embeddings will be written and columns to be embedded
target_table_name = 'investments'
columns_to_embed = ['analysis','overview']

# Initialize the counters used to track progress and estimate cost
# (the functions above update these through their `global` declarations)
batch_count = 0
batch_char_count = 0
total_char_count = 0

# Keep track of job timing
start_time = time.time()

for column_to_embed in columns_to_embed:
    # Initialize the index pointer for batch processing
    index_pointer = 0

    print(f"Creating embeddings for column {column_to_embed}...")
    results = embed_objects(source_array, column_to_embed)

    print("Creating temp table to store intermediate results...")
    temp_table_name = create_temp_table(column_to_embed)

    print(f"Inserting embeddings into temp table: {temp_table_name}...")
    insert_to_temp_table(temp_table_name, column_to_embed, results)

    print(f"Merging temp table {temp_table_name} with target table {target_table_name}...")
    update_target_table(temp_table_name, target_table_name, column_to_embed)

    print(f"Dropping temp table {temp_table_name}...")
    drop_temp_table(temp_table_name)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Job started at: {time.ctime(start_time)}")
print(f"Job ended at: {time.ctime(end_time)}")
print(f"Total run time: {elapsed_time:.2f} seconds")
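
The counters tracked during the run (`total_char_count`, `batch_count`, and `elapsed_time`) can be turned into simple throughput figures, which may help when comparing batch sizes or estimating cost. The helper below is a minimal sketch: `summarize_run` is a hypothetical function, and the numbers passed to it are illustrative values, not measurements from the demo dataset.

```python
from typing import Dict


def summarize_run(total_chars: int, batch_count: int, elapsed_seconds: float) -> Dict[str, float]:
    """Derive simple throughput figures from the run counters."""
    return {
        "chars_per_second": total_chars / elapsed_seconds if elapsed_seconds > 0 else 0.0,
        "avg_chars_per_batch": total_chars / batch_count if batch_count else 0.0,
        "avg_seconds_per_batch": elapsed_seconds / batch_count if batch_count else 0.0,
    }


# Illustrative numbers only, not measurements from the demo dataset
summary = summarize_run(total_chars=1_200_000, batch_count=24, elapsed_seconds=60.0)
print(summary)
```

In the notebook itself, the call would be `summarize_run(total_char_count, batch_count, elapsed_time)` after the job finishes.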