In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/GoogleCloudPlatform/python-docs-samples/blob/main/alloydb/notebooks/generate_batch_embeddings.ipynb)

---
# Introduction

This notebook shows you how to batch generate vector embeddings and store them in an AlloyDB database.

With the steps listed here, you can dynamically build a batch of text chunks to embed based on character length of the source data in order to get more results per inference, leading to much more efficient embeddings generation. The process uses Asyncio to efficiently load the embeddings into AlloyDB after they are generated. These techniques can significantly speed up the process of generating large batches of embeddings and storing them in AlloyDB.

## What you'll need

* A Google Cloud Account and Google Cloud Project

## Setup and Requirements

In the following instructions you will learn to:

1. Install required dependencies for our application
2. Set up authentication for our project
3. Set up an AlloyDB for PostgreSQL instance
4. Import the data used by our application

### Install dependencies

In [None]:
%pip install google-cloud-alloydb-connector[asyncpg]==1.4.0 --quiet

### Authenticate to Google Cloud within Colab
To access your Google Cloud project from this notebook, you will need to authenticate as an IAM user.

In [2]:
from google.colab import auth

auth.authenticate_user()

### Connect Your Google Cloud Project

In [None]:
# @markdown Please fill in the value below with your GCP project ID and then run the cell.

# Please fill in these values.
project_id = ""  # @param {type:"string"}

# Quick input validations.
assert project_id, "⚠️ Please provide a Google Cloud project ID"

# Configure gcloud.
!gcloud config set project {project_id}

### Enable APIs for AlloyDB and Vertex AI within your project

You will need to enable these APIs in order to create an AlloyDB database and utilize Vertex AI as an embeddings service!

In [None]:
# enable GCP services
!gcloud services enable alloydb.googleapis.com aiplatform.googleapis.com

## Set up AlloyDB
You will need a Postgres AlloyDB instance for the following stages of this notebook. Please set the following variables.

In [None]:
# @markdown Please fill in the Google Cloud region and the names of your AlloyDB cluster and instance. Once filled in, run the cell.

# Please fill in these values.
from getpass import getpass

region = ""  # @param {type:"string"}
cluster_name = ""  # @param {type:"string"}
instance_name = ""  # @param {type:"string"}
database_name = "testdb"  # @param {type:"string"}
table_name = "investments"
# Read the password with getpass rather than input() so that it is not echoed
# into the cell output, where it would be saved along with the notebook.
password = getpass("Please provide a password to be used for 'postgres' database user: ")

In [5]:
# Quick input validations: fail fast here, before any gcloud command or
# connection attempt runs with an empty value.
assert region, "⚠️ Please provide a Google Cloud region"
# cluster_name is interpolated into the gcloud commands and the instance URI
# exactly like the other settings, so it has to be checked as well.
assert cluster_name, "⚠️ Please provide the name of your cluster"
assert instance_name, "⚠️ Please provide the name of your instance"
assert database_name, "⚠️ Please provide the name of your database_name"

### Create an AlloyDB Instance
If you have already created an AlloyDB cluster and instance, you can skip these steps and go straight to the **Create a Database** section.

> ⏳ - Creating an AlloyDB cluster may take a few minutes.

In [None]:
# create the AlloyDB Cluster
!gcloud beta alloydb clusters create {cluster_name} --password={password} --region={region}

Create an instance attached to our cluster with the following command.
> ⏳ - Creating an AlloyDB instance may take a few minutes.

In [None]:
!gcloud beta alloydb instances create {instance_name} --instance-type=PRIMARY --cpu-count=2 --region={region} --cluster={cluster_name}

To connect to your AlloyDB instance from this notebook, you will need to enable public IP on your instance. Alternatively, you can follow [these instructions](https://cloud.google.com/alloydb/docs/connect-external) to connect to an AlloyDB for PostgreSQL instance with Private IP from outside your VPC.

In [None]:
!gcloud beta alloydb instances update {instance_name} --region={region} --cluster={cluster_name} --assign-inbound-public-ip=ASSIGN_IPV4 --database-flags="password.enforce_complexity=on"

### Connect to AlloyDB

This function will create a connection pool to your AlloyDB instance using the [AlloyDB Python connector](https://github.com/GoogleCloudPlatform/alloydb-python-connector). The AlloyDB Python connector will automatically create secure connections to your AlloyDB instance using mTLS.

In [6]:
import asyncpg

import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from google.cloud.alloydb.connector import AsyncConnector, IPTypes

async def init_connection_pool(connector: AsyncConnector, db_name: str, pool_size: int = 5) -> AsyncEngine:
    """Build an async SQLAlchemy engine whose connections come from the AlloyDB connector.

    Args:
        connector: AsyncConnector used to open each asyncpg connection.
        db_name: Name of the database to connect to.
        pool_size: Size of the engine's connection pool; overflow is disabled.

    Returns:
        The AsyncEngine returned by ``create_async_engine``.

    The instance URI and credentials are taken from the notebook-level
    ``project_id``, ``region``, ``cluster_name``, ``instance_name`` and
    ``password`` variables.
    """
    instance_uri = (
        f"projects/{project_id}/locations/{region}"
        f"/clusters/{cluster_name}/instances/{instance_name}"
    )

    async def _open_connection() -> asyncpg.Connection:
        # Handed to the engine below as its async_creator callback.
        return await connector.connect(
            instance_uri,
            "asyncpg",
            user="postgres",
            password=password,
            db=db_name,
            ip_type=IPTypes.PUBLIC,
        )

    return create_async_engine(
        "postgresql+asyncpg://",
        async_creator=_open_connection,
        pool_size=pool_size,
        max_overflow=0,
    )

### Create a Database

Next, you will create a database to store the data using the connection pool. Enabling public IP takes a few minutes, so you may get an error saying that there is no public IP address. If you hit that error, please wait and retry this step!

In [None]:
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy import text, exc

from google.cloud.alloydb.connector import AsyncConnector, IPTypes

async def create_db(database_name):
    """Create ``database_name`` on the AlloyDB instance unless it already exists.

    Connects to the ``postgres`` database through a short-lived connector and
    engine, both of which are released before the function returns.

    Args:
        database_name: Name of the database to create. It is interpolated into
            the statement as an identifier (CREATE DATABASE cannot take a bind
            parameter), so it must come from a trusted source.

    Raises:
        sqlalchemy.exc.ProgrammingError: If CREATE DATABASE fails for any
            reason other than the database already existing.
    """
    connector = AsyncConnector()
    pool = await init_connection_pool(connector, "postgres")
    try:
        async with pool.connect() as conn:
            try:
                # CREATE DATABASE cannot run inside a transaction block, so
                # end the current transaction first.
                await conn.execute(text("COMMIT"))
                await conn.execute(text(f"CREATE DATABASE {database_name}"))
                print(f"Database '{database_name}' created successfully")
            except exc.ProgrammingError as err:
                # Only a duplicate database is an expected outcome; any other
                # programming error must not be reported as "already exists".
                if "already exists" not in str(err):
                    raise
                print(f"Database '{database_name}' already exists")
    finally:
        # Always release the pooled connections and the connector.
        await pool.dispose()
        await connector.close()

await create_db(database_name=database_name)

### Download data

The following code has been prepared to help you insert the CSV data into your AlloyDB for PostgreSQL database.

Download the CSV file:

In [None]:
# TODO: Change cloud bucket
!gsutil cp gs://cloud-samples-data/alloydb/investments_data /content/investments.csv

The download can be verified by the following command or using the "Files" tab.

In [None]:
!ls

In this next step you will:

1. Create the table to store the data
2. And insert the data from the CSV into the database table

### Import data to your database


In [16]:
# Prepare data
import pandas as pd

data = "/content/investments.csv"

df = pd.read_csv(data)
# Map the 't' / 'f' markers in the etf column to real booleans.
df['etf'] = df['etf'].map({'t': True, 'f': False})
# Fill missing ratings *before* casting to str: astype(str) turns NaN into the
# literal string 'nan', after which fillna('') has nothing left to replace.
df['rating'] = df['rating'].fillna('').astype(str)

In [None]:
df.head()

The data consists of the following columns:

* **id**
* **ticker**: A string representing the stock symbol or ticker (e.g., "AAPL" for Apple, "GOOG" for Google).
* **etf**: A boolean value indicating whether the asset is an ETF (True) or not (False).
* **market**:  A string representing the stock exchange where the asset is traded.
* **rating**: Whether to hold, buy or sell a stock.
* **overview**: A text field for a general overview or description of the asset.
* **analysis**: A text field, for a more detailed analysis of the asset.
* **overview_embedding** (empty)
* **analysis_embedding** (empty)

In [None]:
# IF NOT EXISTS keeps this step re-runnable once the table has been created.
create_table_cmd = sqlalchemy.text(
    f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
      id SERIAL PRIMARY KEY,
      ticker VARCHAR(255) NOT NULL UNIQUE,
      etf BOOLEAN,
      market VARCHAR(255),
      rating TEXT,
      overview TEXT,
      overview_embedding VECTOR (768),
      analysis TEXT,
      analysis_embedding VECTOR (768)
    )
    """
)

# ON CONFLICT (id) DO NOTHING skips rows whose id is already present, so
# running the import a second time does not fail on the primary key.
insert_data_cmd = sqlalchemy.text(
    f"""
    INSERT INTO {table_name} (id, ticker, etf, market,
      rating, overview, analysis) VALUES (:id, :ticker, :etf, :market,
      :rating, :overview, :analysis)
    ON CONFLICT (id) DO NOTHING
    """
)

# Columns copied from the DataFrame; the two *_embedding columns stay empty
# until the embeddings workflow fills them in.
insert_columns = ["id", "ticker", "etf", "market", "rating", "overview", "analysis"]

# One parameter dict per CSV row. Missing values (NaN) are sent as None so
# they are stored as NULL instead of being passed to the driver as floats.
parameter_map = [
    {col: (None if pd.isna(row[col]) else row[col]) for col in insert_columns}
    for _, row in df.iterrows()
]

In [None]:
from google.cloud.alloydb.connector import AsyncConnector

connector = AsyncConnector()

# Create table and insert data
async def insert_data(pool):
  """Enable the `vector` extension, create the table and insert the CSV rows.

  Runs the notebook-level `create_table_cmd` and `insert_data_cmd` (the latter
  with `parameter_map`) on one connection and commits once at the end.
  """
  enable_vector_cmd = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector;")
  async with pool.connect() as conn:
    for statement in (enable_vector_cmd, create_table_cmd):
      await conn.execute(statement)
    await conn.execute(insert_data_cmd, parameter_map)
    await conn.commit()

pool = await init_connection_pool(connector, database_name)
await insert_data(pool)
await pool.dispose()

## Create the embeddings workflow

The embeddings workflow contains four major parts:
1. Read the data
2. Batch the data
3. Generate embeddings
4. Update original table


#### Step 0:  Configure Logging

In [7]:
import logging
import sys

# Root logger: emit INFO and above to stdout with a compact timestamped format.
# force=True replaces handlers already attached to the root logger, so
# re-running this cell does not stack up duplicate handlers.
logging.basicConfig(
    format='%(asctime)s[%(levelname)5s][%(name)14s] - %(message)s',
    datefmt='%H:%M:%S',
    stream=sys.stdout,
    level=logging.INFO,
    force=True,
)

#### Step 1: Read the data

This code reads data from a database and yields it for further processing.

In [8]:
from typing import AsyncIterator, List
from sqlalchemy import RowMapping
from sqlalchemy.ext.asyncio import AsyncEngine

async def get_source_data(pool: AsyncEngine, embed_cols: List[str]) -> AsyncIterator[RowMapping]:
  """
  Stream the rows of the source table, one mapping at a time.

  Each yielded item has the form:
      {'id' : 'id1', 'col1': 'val1', 'col2': 'val2'}
  with 'id' plus one key per entry of `embed_cols` (the columns whose text is
  to be embedded). Rows are read from the notebook-level `table_name`.
  """
  logger = logging.getLogger('get_source_data')

  selected_columns = ', '.join(embed_cols)
  sql = f"SELECT id, {selected_columns} FROM {table_name}"
  logger.info(f"Running SQL query: {sql}")
  async with pool.connect() as conn:
    row_stream = await conn.stream(text(sql))
    async for record in row_stream:
      mapping = record._mapping
      logger.debug(f"yielded row: {mapping['id']}")
      # Hand out the mapping view of the row rather than the row itself
      yield mapping

#### Step 2: Batch the data

This code defines a function called `batch_source_data` that takes database rows and groups them into batches based on a character count limit (max_char_count). This batching process is crucial for efficient embedding generation for these reasons:

* **Resource Optimization:**  Instead of sending numerous small requests, batching allows us to send fewer, larger requests. This significantly optimizes resource usage and potentially reduces API costs.

* **Working Within API Limits:**  The max_char_count limit ensures each batch stays within the API's acceptable input size, preventing issues with exceeding the maximum character limit.


In [9]:
from typing import Any, List

async def batch_source_data(
    read_generator: "AsyncIterator[RowMapping]",
    embed_cols: "List[str] | None" = None,
    char_limit: "int | None" = None,
) -> AsyncIterator[List[dict[str, Any]]]:
  """
  Group rows into batches whose combined text length stays within a limit.

  Yields data in the form of:
  [
    {'id' : 'id1', 'col1': 'val1', 'col2': 'val2'},
    ...
  ]
  where col1 and col2 are columns containing data to be embedded.

  Args:
    read_generator: Async iterator of row mappings to batch.
    embed_cols: Columns whose text length counts towards the limit. Defaults
      to the notebook-level `cols_to_embed`.
    char_limit: Maximum number of characters per batch. Defaults to the
      notebook-level `max_char_count`.

  A row that exceeds the limit on its own is still emitted, as a batch of one.
  An empty batch is never yielded.
  """
  logger = logging.getLogger('batch_data')

  # Fall back to the notebook-level settings when no override is given.
  cols = cols_to_embed if embed_cols is None else embed_cols
  limit = max_char_count if char_limit is None else char_limit

  batch = []
  char_count = 0
  batch_num = 0

  async for row in read_generator:
    # Char count in current row; a missing value (None) counts as empty text.
    row_char_count = sum(len(row[col] or "") for col in cols)

    # Flush only a non-empty batch: an oversized first row must not cause an
    # empty batch to be emitted ahead of it.
    if batch and char_count + row_char_count > limit:
      batch_num += 1
      logger.info(f"yielded batch number: {batch_num} with length: {len(batch)}")
      yield batch
      batch, char_count = [], 0

    # Add the current row to the batch
    batch.append(row)
    char_count += row_char_count

  # Emit whatever is left once the source is exhausted.
  if batch:
    batch_num += 1
    logger.info(f"Yielded batch number: {batch_num} with length: {len(batch)}")
    yield batch

#### Step 3: Generate embeddings

This step converts your text data into numerical representations called "embeddings." These embeddings capture the meaning and relationships between words, making them useful for various tasks like search, recommendations, and clustering.

The code uses two functions to efficiently generate embeddings:

**embed_text**

This function takes your text data and sends it to Vertex AI, transforming the text in specific columns into embeddings.

**embed_objects_concurrently**

This function is the orchestrator. It manages the embedding generation process for multiple batches of text concurrently. This function ensures that all batches are processed efficiently without overwhelming the system.

In [27]:
import vertexai
from google.api_core.exceptions import ResourceExhausted
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

async def embed_text(
    logger: logging.Logger,
    batch_data: List[dict[str, Any]],
    model: TextEmbeddingModel,
    cols_to_embed: List[str],
    task_type: str = "SEMANTIC_SIMILARITY",
    retries: int = 100,
    delay: int = 1,
) -> List[dict[str, List[float] | str]]:
  """
  Embed the `cols_to_embed` text of every row in `batch_data` with one request.

  Returns data in the form of:
  [
    {
      'id': 'id1',
      'col1_embedding': [1.0, 1.1, ...],
      'col2_embedding': [2.0, 2.1, ...],
      ...
    },
    ...
  ]
  where col1 and col2 are columns containing data to be embedded. Each
  embedding is stored as the string form of its list of values.

  Args:
    logger: Logger used for retry warnings and the final error.
    batch_data: Rows to embed; each must provide 'id' and every column in
      `cols_to_embed`.
    model: Embedding model whose `get_embeddings_async` is called.
    cols_to_embed: Columns to embed, in the order results are assigned.
    task_type: Task type passed to every `TextEmbeddingInput`.
    retries: Maximum number of attempts when the request is rejected with
      `ResourceExhausted`.
    delay: Seconds to wait between attempts.

  Raises:
    ResourceExhausted: If the last attempt is still rejected.

  Side effect: adds the number of embedded characters to the notebook-level
  `total_char_count`.
  """
  global total_char_count

  # Imported here so the function does not rely on a later cell importing it.
  import asyncio

  # Nothing to embed: skip the API round trip entirely.
  if not batch_data:
    return []

  # Flatten the batch into a single request: for every row, one input per
  # column. The results are matched back up in the same order below.
  inputs = [
    TextEmbeddingInput(data[col], task_type)
    for data in batch_data
    for col in cols_to_embed
  ]

  for attempt in range(retries):  # Retry loop
    try:
      # Get embeddings for the text data
      embeddings = await model.get_embeddings_async(inputs)

      # Increase total char count
      total_char_count += sum(len(item.text) for item in inputs)

      # Group the results together by id, consuming them in request order
      embedding_iter = iter(embeddings)
      results = []
      for row in batch_data:
        r = {'id': row['id']}
        for col in cols_to_embed:
          r[f'{col}_embedding'] = str(next(embedding_iter).values)
        results.append(r)
      return results

    except ResourceExhausted as e:
      if attempt < retries - 1:  # Retry only if attempts are left
        logger.warning(f"Error: {e}. Retrying in {delay} seconds...")
        await asyncio.sleep(delay)  # Wait before retrying
      else:
        logger.error(f"Failed to get embeddings after {retries} attempts.")
        raise  # Raise the error if all retries fail

  # Only reached when `retries` allows no attempt at all.
  return []

async def embed_objects_concurrently(
    cols_to_embed: List[str],
    batch_data: AsyncIterator[List[dict[str, Any]]],
    model: TextEmbeddingModel,
    task_type: str,
    max_concurrency: int = 5,
) -> AsyncIterator[List[dict[str, str | List[float]]]]:
    """
    Embeds text from objects concurrently with a maximum concurrency limit. This
    function processes batches of data concurrently, limiting the number of
    simultaneous embedding tasks to improve efficiency and resource utilization.

    Args:
        cols_to_embed: Columns to embed in every row.
        batch_data: Async iterator of row batches; one `embed_text` task is
            started per batch.
        model: Embedding model passed through to `embed_text`.
        task_type: Task type passed through to `embed_text`.
        max_concurrency: Maximum number of embedding tasks in flight; must be
            at least 1.

    Yields:
        The result of each `embed_text` task as it completes, so the output
        order may differ from the input order.

    Raises:
        ValueError: If `max_concurrency` is less than 1.
    """
    # Imported here so the function does not rely on a later cell importing it.
    import asyncio

    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")

    logger = logging.getLogger('embed_objects')
    # Keep track of pending tasks
    pending: set[asyncio.Task] = set()
    has_next = True
    try:
        while pending or has_next:
            # Top up the in-flight tasks until the limit or the end of input.
            while len(pending) < max_concurrency and has_next:
                try:
                    data = await anext(batch_data)
                except StopAsyncIteration:
                    has_next = False
                else:
                    coro = embed_text(logger, data, model, cols_to_embed, task_type)
                    pending.add(asyncio.ensure_future(coro))

            if not pending:
                # The source ran dry with nothing in flight. asyncio.wait()
                # raises ValueError on an empty set, so stop here instead.
                break

            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )

            for task in done:
                # result() re-raises any exception raised inside the task.
                result = task.result()
                logger.info(f"Embedding task completed: Processed {len(result)} rows.")
                yield result
    finally:
        # On an error or early close, do not leave embedding tasks running.
        for task in pending:
            task.cancel()


#### Step 4: Update original table

After generating embeddings for your text data, you need to store them in your database. This step efficiently updates your original table with the newly created embeddings.

This process uses two functions to manage database updates:

**batch_update_rows**
1. This function takes a batch of data (including the embeddings) and updates the corresponding rows in your database table.
2. It constructs an SQL UPDATE query to modify specific columns with the embedding values.
3. It ensures that the updates are done efficiently and correctly within a database transaction.


**batch_update_rows_concurrently**

1. This function handles the concurrent updating of multiple batches of data.
2. It creates multiple "tasks" that each execute the batch_update_rows function on a separate batch.
3. It limits the number of concurrent tasks to avoid overloading your database and system resources.
4. It manages the execution of these tasks, ensuring that all batches are processed efficiently.

In [28]:
from sqlalchemy import text

async def batch_update_rows(pool: AsyncEngine, logger: logging.Logger, data: List[dict[str, Any]]) -> None:
  """
  Write one batch of embeddings back to the source table.

  Args:
    pool: Engine used to obtain a connection.
    logger: Logger that records how many rows were updated.
    data: One dict per row, holding 'id' and a '<col>_embedding' value for
      every column in the notebook-level `cols_to_embed`.

  All rows of the batch are sent in one execute call and committed together.
  """
  # One "<col>_embedding = :<col>_embedding" assignment per embedded column.
  assignments = ', '.join(
    f'{col}_embedding = :{col}_embedding' for col in cols_to_embed
  )
  update_query = f"""
    UPDATE {table_name}
    SET {assignments}
    WHERE id = :id;
  """

  # With no rows there is nothing to bind, so skip the database round trip.
  if data:
    async with pool.connect() as conn:
      await conn.execute(
          text(update_query),
          # Each dict in data supplies the bind values for one row
          parameters = data,
      )
      await conn.commit()
  logger.info(f"Updated {len(data)} rows in database.")


async def batch_update_rows_concurrently(
    pool: AsyncEngine,
    embed_data: AsyncIterator[List[dict[str, Any]]],
    max_concurrency: int = 5
):
  """
  Run `batch_update_rows` for every batch, with a cap on concurrent updates.

  Args:
    pool: Engine passed through to `batch_update_rows`.
    embed_data: Async iterator of embedded batches to write.
    max_concurrency: Maximum number of update tasks in flight; must be at
      least 1.

  Raises:
    ValueError: If `max_concurrency` is less than 1.
    Exception: Any exception raised by an update task is re-raised here
      instead of being silently dropped.
  """
  # Imported here so the function does not rely on a later cell importing it.
  import asyncio

  if max_concurrency < 1:
    raise ValueError("max_concurrency must be at least 1")

  logger = logging.getLogger('update_rows')
  # Keep track of pending tasks
  pending: set[asyncio.Task] = set()
  has_next = True
  try:
    while pending or has_next:
      # Top up the in-flight tasks until the limit or the end of input.
      while len(pending) < max_concurrency and has_next:
        try:
          data = await anext(embed_data)
        except StopAsyncIteration:
          has_next = False
        else:
          coro = batch_update_rows(pool, logger, data)
          pending.add(asyncio.ensure_future(coro))

      if not pending:
        # The source ran dry with nothing in flight. asyncio.wait() raises
        # ValueError on an empty set, so stop here instead.
        break

      done, pending = await asyncio.wait(
          pending, return_when=asyncio.FIRST_COMPLETED
      )

      for task in done:
        # result() re-raises a failed update, so the completion message
        # below is only logged when every batch was actually written.
        task.result()
  finally:
    # On an error, do not leave update tasks running in the background.
    for task in pending:
      task.cancel()

  logger.info("All database update tasks completed.")

## Run the embeddings workflow



In [29]:
# Upper bound on the number of tokens sent to the embeddings API per request
max_tokens = 20000

# Tokens are converted to characters with a rough rule of thumb: one token is
# about 3-4 characters for some tokenizers and texts. This is a very general
# guideline and can vary significantly, so the lower figure (3) is used.
max_char_count = 3 * max_tokens

# Table columns whose text gets an embedding
cols_to_embed = ["analysis", "overview"]

# Embedding model used to generate the vectors
model_name = "text-embedding-004"

# Task the embeddings are optimised for
# Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types#supported_task_types
task = "SEMANTIC_SIMILARITY"

This runs the complete embeddings workflow:

1. Getting source data
2. Batching source data
3. Generating embeddings for batches
4. Updating data batches in the original table

In [None]:
import vertexai
import time
import asyncio
from vertexai.language_models import TextEmbeddingModel

pool_size = 10
embed_data_concurrency = 20
batch_update_concurrency = 10
total_char_count = 0

# Set up connections to the database
connector = AsyncConnector()
pool = await init_connection_pool(connector, database_name, pool_size=pool_size)

# Initialise VertexAI and the model to be used to generate embeddings
vertexai.init(project=project_id, location=region)
model = TextEmbeddingModel.from_pretrained(model_name)

start_time = time.monotonic()

# Fetch source data from the database
source_data = get_source_data(pool, cols_to_embed)

# Divide the source data into batches for efficient processing
batch_data = batch_source_data(source_data)

# Generate embeddings for the batched data concurrently
embeddings_data = embed_objects_concurrently(cols_to_embed, batch_data, model, task, max_concurrency=embed_data_concurrency)

# Update the database with the generated embeddings concurrently
await batch_update_rows_concurrently(pool, embeddings_data, max_concurrency=batch_update_concurrency)

end_time = time.monotonic()
elapsed_time = end_time - start_time

# Release database connections and close the connector
await pool.dispose()
await connector.close()

print(f"Job started at: {time.ctime(start_time)}")
print(f"Job ended at: {time.ctime(end_time)}")
print(f"Total run time: {elapsed_time:.2f} seconds")
print(f"Total characters embedded: {total_char_count}")