In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Monitor batch prediction with Gemini API

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fbatch-prediction%2Fmonitor_batch_prediction_gemini_api.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/batch-prediction/monitor_batch_prediction_gemini_api.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>

| | |
|-|-|
| Author(s) |  [Ivan Nardini](https://github.com/inardini) |

## Overview

While the Gemini API in Vertex AI supports asynchronous batch predictions that write results to Cloud Storage or BigQuery, it doesn't currently notify you when a job completes. This notebook fills that gap by using Vertex AI Pipelines to run the batch prediction job, track its status, and email you when the run finishes.


### Objectives

This tutorial demonstrates how to orchestrate and monitor Gemini batch prediction jobs using Vertex AI Pipelines.

Specifically, you will learn how to:

1. **Prepare Batch Inputs and Output Location:** Point to a BigQuery table that holds your batch requests and designate a BigQuery table for the model's output.
2. **Build a Vertex AI Pipeline for Batch Prediction:** Define a pipeline that encapsulates the batch prediction job.
3. **Submit a Vertex AI Pipeline Job:** Execute the defined pipeline, which runs the batch prediction job on the Gemini model and sends an email notification when the run finishes.
4. **Retrieve Batch Prediction Results:** Access and process the predictions generated by the Gemini model once the pipeline completes.

## Get started

### Install Vertex AI SDK and other required packages


In [None]:
# Install the Vertex AI SDK, the BigQuery client, KFP, and Google Cloud Pipeline Components used below.
%pip install --upgrade --user --quiet google-cloud-aiplatform google-cloud-bigquery kfp google-cloud-pipeline-components

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

# Shut down and restart the current kernel so newly installed packages are picked up.
IPython.Application.instance().kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says "Your session crashed for an unknown reason." This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
import sys

# Authenticate the user only when this notebook is running inside Google Colab.
running_in_colab = "google.colab" in sys.modules
if running_in_colab:
    from google.colab import auth as colab_auth

    colab_auth.authenticate_user()

### Requirements

#### Set Project ID and Location

To get started using Vertex AI, you must have an existing Google Cloud project and [enable these APIs](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,artifactregistry.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# Use the environment variable if the user doesn't provide Project ID.
import os

PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}

if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    PROJECT_ID = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))

PROJECT_NUMBER = !gcloud projects describe {PROJECT_ID} --format="get(projectNumber)"[0]
PROJECT_NUMBER = PROJECT_NUMBER[0]

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

#### Set and create a Cloud Storage bucket

Create a Cloud Storage bucket to stage Vertex AI SDK files and store the pipeline's intermediate artifacts.

In [None]:
BUCKET_NAME = "your-bucket-name-{PROJECT_ID}-unique"  # @param {type:"string"}
# The default above is a plain template string (not an f-string), and bucket names
# cannot contain braces: substitute the project ID so the default is usable.
BUCKET_NAME = BUCKET_NAME.replace("{PROJECT_ID}", PROJECT_ID)

BUCKET_URI = f"gs://{BUCKET_NAME}"  # @param {type:"string"}

In [None]:
# Create the bucket in LOCATION, under project PROJECT_ID.
! gsutil mb -l {LOCATION} -p {PROJECT_ID} {BUCKET_URI}

#### Set Service Account and permissions

Grant the following IAM roles to the service account that runs the pipeline (the project's default Compute Engine service account, configured below):

- Vertex AI User (roles/aiplatform.user)
- BigQuery Data Editor (roles/bigquery.dataEditor)
- Storage Object Admin (roles/storage.objectAdmin)

For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).


> If you're using Vertex AI Workbench, run the following commands directly in the terminal.


In [None]:
# Default Compute Engine service account of the project, derived from its project number.
SERVICE_ACCOUNT = "{}-compute@developer.gserviceaccount.com".format(PROJECT_NUMBER)

In [None]:
# Grant each listed role to SERVICE_ACCOUNT at the project level.
# `--condition=None` adds the binding without an IAM condition.
for role in ['aiplatform.user', 'storage.objectAdmin', 'bigquery.dataEditor']:

    ! gcloud projects add-iam-policy-binding {PROJECT_ID} \
      --member=serviceAccount:{SERVICE_ACCOUNT} \
      --role=roles/{role} --condition=None

### Set and create a BigQuery table

Create a BigQuery dataset and generate a unique, timestamped table URI. The batch prediction job writes its predictions to that table.

In [None]:
from datetime import datetime

from google.cloud import bigquery


def create_bq_table(
    dataset_id: str,
    project_id: str = PROJECT_ID,
    location: str = LOCATION,
) -> str:
    """Creates a BigQuery dataset if needed and returns a fresh table URI for predictions.

    Only the dataset is created here; the table itself is not. The table name
    carries a timestamp suffix (second resolution), so calls made in different
    seconds yield distinct table names.

    Args:
        dataset_id: ID of the dataset to create or reuse.
        project_id: Google Cloud project that owns the dataset.
        location: BigQuery location for the client and the dataset.

    Returns:
        A URI of the form
        ``bq://<project_id>.<dataset_id>.prediction_result_<YYYYmmddHHMMSS>``.
    """
    bq_client = bigquery.Client(project=project_id, location=location)

    dataset = bigquery.Dataset(f"{project_id}.{dataset_id}")
    dataset.location = location
    # exists_ok=True makes re-running this cell safe when the dataset already exists.
    bq_client.create_dataset(dataset, exists_ok=True, timeout=30)

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    table_id = f"prediction_result_{timestamp}"
    return f"bq://{project_id}.{dataset_id}.{table_id}"

In [None]:
BQ_DATASET = "gen_ai_batch_prediction"  # @param {type:"string"}
# Timestamped table URI inside BQ_DATASET where the batch job will write predictions.
OUTPUT_TABLE_URI = create_bq_table(BQ_DATASET)

### Initiate Vertex AI SDK

In [None]:
import vertexai

# Set the default project, region, and staging bucket for later Vertex AI SDK calls.
vertexai.init(
    location=LOCATION,
    project=PROJECT_ID,
    staging_bucket=BUCKET_URI,
)

### Import libraries

In [None]:
from typing import NamedTuple

from google.cloud import aiplatform
from google_cloud_pipeline_components.types.artifact_types import VertexDataset
from google_cloud_pipeline_components.v1.dataset import TabularDatasetCreateOp
from google_cloud_pipeline_components.v1.vertex_notification_email import (
    VertexNotificationEmailOp,
)
from kfp import compiler, dsl
from kfp.dsl import Markdown, Output, component

### Set constants

In [None]:
# BigQuery table that holds the batch prediction requests.
INPUT_TABLE_URI = "bq://storage-samples.generative_ai.batch_requests_for_multimodal_input_2"
MODEL_ID = "gemini-1.5-pro-002"  # @param {type:"string", isTemplate: true}
RECIPIENTS = ["your-email@provider.com"]  # @param {type: "string", placeholder: "[your-email@provider.com]", isTemplate: true}
# Pipeline artifacts are stored under this prefix of the staging bucket.
PIPELINE_ROOT = BUCKET_URI + "/genai-prediction-pipeline"

### Build the Batch prediction component

Define a lightweight custom Kubeflow Pipelines component for running batch prediction jobs using Vertex AI's Generative Models.

It takes an input BigQuery table, submits it to a specified Generative Model for batch prediction, and outputs the resulting predictions to a specified output BigQuery table location.

The component monitors the job's progress and logs relevant information. Upon successful completion, it returns the URI of the output dataset.


In [None]:
@component(
    base_image="python:3.10",
    packages_to_install=["google-cloud-aiplatform", "google_cloud_pipeline_components"],
)
def GenAIModelBatchPredictOp(
    input_bq_table: str,
    output_bq_table: str,
    model_id: str,
    project: str,
    location: str,
    output_dataset_artifact: Output[VertexDataset],
) -> NamedTuple("outputs", dataset_uri=str):
    """Submits a Gemini batch prediction job and blocks until it ends.

    Args:
        input_bq_table: BigQuery URI of the table holding the requests.
        output_bq_table: BigQuery URI the predictions are written to.
        model_id: Gemini model name passed to ``BatchPredictionJob.submit``.
        project: Google Cloud project ID.
        location: Vertex AI region.
        output_dataset_artifact: Declared output artifact. This component does
            not write to it (NOTE(review): consider recording the output table).

    Returns:
        ``dataset_uri``: the finished job's ``output_location``.

    Raises:
        RuntimeError: If the job ends without succeeding.
    """
    # Lightweight KFP components run in their own container, so all imports
    # must be inside the function body.
    import logging
    import time

    import vertexai
    from vertexai.batch_prediction import BatchPredictionJob

    # Seconds to wait between job status refreshes.
    poll_interval_seconds = 60

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(__name__)

    # Initiate Vertex AI session
    logger.info(
        f"Initializing Vertex AI session with project: {project}, location: {location}"
    )
    vertexai.init(project=project, location=location)

    # submit() takes the model name directly; no GenerativeModel instance is needed.
    logger.info(f"Using model: {model_id}")
    logger.info(f"Submitting batch prediction job - Input table: {input_bq_table}")
    logger.info(f"Output will be stored at: {output_bq_table}")
    job = BatchPredictionJob.submit(
        source_model=model_id,
        input_dataset=input_bq_table,
        output_uri_prefix=output_bq_table,
    )

    # Poll until the job reaches a terminal state.
    start_time = time.time()
    logger.info("Starting job monitoring...")
    while not job.has_ended:
        elapsed_time = time.time() - start_time
        logger.info(f"Job running... Elapsed time: {elapsed_time:.2f} seconds")
        time.sleep(poll_interval_seconds)
        job.refresh()

    total_time = time.time() - start_time
    if not job.has_succeeded:
        logger.error(f"Job failed with error: {job.error}")
        raise RuntimeError(
            f"Batch prediction job failed after {total_time:.2f} seconds: {job.error}"
        )

    # Report the location the job actually wrote to, not just the requested prefix.
    dataset_uri = job.output_location
    logger.info(f"Job completed successfully in {total_time:.2f} seconds!")
    logger.info(f"Output dataset available at: {dataset_uri}")

    component_outputs = NamedTuple("outputs", dataset_uri=str)
    logger.info(f"Returning component output with dataset_uri: {dataset_uri}")
    return component_outputs(dataset_uri)

### Build a component to visualize the prediction table

Define a lightweight custom Kubeflow Pipelines component for visualizing a prediction sample.

The component takes the BigQuery table name, sample size, project, and location as inputs, and writes a Markdown table artifact. It queries a sample of rows with the BigQuery Python client, extracts the request and response text with pandas, and logs each step.


In [None]:
@component(
    base_image="python:3.10",
    packages_to_install=[
        "google-cloud-bigquery[pandas]",
        "google_cloud_pipeline_components",
    ],
)
def VisualizeBatchPredictionTable(
    output_bq_table: str,
    sample_size: int,
    project: str,
    location: str,
    output_markdown_table: Output[Markdown],
):
    """Writes a sample of batch prediction results as a Markdown table artifact.

    Args:
        output_bq_table: Prediction table, with or without the ``bq://`` prefix.
        sample_size: Maximum number of rows to read (SQL ``LIMIT``).
        project: Google Cloud project ID for the BigQuery client.
        location: Location for the BigQuery client.
        output_markdown_table: Markdown artifact that receives the table.

    Raises:
        RuntimeError: If the query returns no rows.
    """
    # Lightweight KFP components run in their own container, so all imports
    # must be inside the function body.
    import json
    import logging

    from google.cloud import bigquery
    import pandas as pd

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger(__name__)

    def _as_object(value):
        """Returns ``value`` decoded from JSON when it is a string, else unchanged."""
        # Accept either an already-parsed object or a JSON-encoded string for
        # the request/response columns.
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return None
        return value

    def _join_text(parts):
        """Joins the ``text`` of every part, skipping parts that carry no text."""
        # Multimodal content may include non-text parts, so reading only
        # parts[0]["text"] can drop the prompt entirely.
        candidates = (part.get("text") if isinstance(part, dict) else None for part in parts)
        return "\n".join(text for text in candidates if isinstance(text, str))

    def extract_text(record):
        """Returns the request and response text of one prediction row."""
        # Extract each side independently so a missing response does not hide
        # the request text (and vice versa).
        extracted = {"Request": "", "Response": ""}
        try:
            request = _as_object(record.get("request"))
            extracted["Request"] = _join_text(request["contents"][0]["parts"])
        except (KeyError, IndexError, TypeError) as e:
            logger.warning(f"Could not extract request text from record: {e}")
        try:
            response = _as_object(record.get("response"))
            extracted["Response"] = _join_text(
                response["candidates"][0]["content"]["parts"]
            )
        except (KeyError, IndexError, TypeError) as e:
            logger.warning(f"Could not extract response text from record: {e}")
        return extracted

    def escape_cell(val):
        """Makes a value safe to place inside a single Markdown table cell."""
        if val is None:
            return ""
        # Escape pipes so they don't split the cell, then turn line breaks into <br>.
        val_str = str(val).replace("|", "\\|")
        return val_str.replace("\r\n", "\n").replace("\n", "<br>")

    # Initialize BigQuery client
    logger.info(f"Initializing BigQuery client for project: {project} in {location}")
    client = bigquery.Client(project=project, location=location)

    # SQL needs `project.dataset.table`, not the bq:// URI form.
    table_ref = output_bq_table.replace("bq://", "")
    # int() ensures only a number is interpolated into the LIMIT clause.
    query = f"""
        SELECT *
        FROM `{table_ref}`
        LIMIT {int(sample_size)}
    """
    logger.info(f"Executing query on dataset: {table_ref}")
    logger.info(f"Sampling {sample_size} rows")
    df = client.query(query).to_dataframe()
    logger.info(f"Query returned {len(df)} rows and {len(df.columns)} columns")
    if df.empty:
        logger.error("No data found in table")
        raise RuntimeError(f"No data found in table {table_ref}")

    # Process DataFrame to extract texts
    logger.info("Extracting request and response texts")
    processed_df = pd.DataFrame(
        [extract_text(record) for record in df.to_dict("records")]
    )

    # Format markdown table with proper escaping
    logger.info("Converting DataFrame to markdown format")
    header = "|" + "|".join(str(col) for col in processed_df.columns) + "|"
    separator = "|" + "|".join("---" for _ in processed_df.columns) + "|"
    rows = [
        "|" + "|".join(escape_cell(val) for val in row) + "|"
        for row in processed_df.itertuples(index=False)
    ]

    # Combine all parts and write to file; model output may contain non-ASCII text.
    markdown_table = "\n".join([header, separator, *rows])
    logger.info(f"Writing markdown table to: {output_markdown_table.path}")
    with open(output_markdown_table.path, "w", encoding="utf-8") as f:
        f.write(markdown_table)

### Define your workflow using Kubeflow Pipelines DSL package

The `kfp.dsl` package contains the domain-specific language (DSL) that you use to build the pipeline for the Gen AI batch prediction workflow.

In [None]:
@dsl.pipeline(name="genai-batch-prediction-pipeline")
def pipeline(
    input_dataset_name: str,
    input_bq_table: str,
    output_bq_table: str,
    model_id: str,
    sample_size: int,
    project: str = PROJECT_ID,
    location: str = LOCATION,
    recipients: list = RECIPIENTS,
):
    """Creates an input dataset, runs Gemini batch prediction, and previews results.

    All work tasks run inside a ``dsl.ExitHandler`` whose exit task emails
    ``recipients``, so the notification covers every step of the run.
    """
    notify_email_task = VertexNotificationEmailOp(recipients=recipients)

    # Tasks placed outside the exit handler are not covered by it: a failure there
    # would end the run without sending the notification.
    with dsl.ExitHandler(notify_email_task, name="Notification handler"):

        create_input_dataset_task = TabularDatasetCreateOp(
            display_name=input_dataset_name,
            bq_source=input_bq_table,
            project=project,
            location=location,
        ).set_display_name("Create input dataset")

        run_batch_prediction_task = (
            GenAIModelBatchPredictOp(
                input_bq_table=input_bq_table,
                output_bq_table=output_bq_table,
                model_id=model_id,
                project=project,
                location=location,
            )
            .after(create_input_dataset_task)
            .set_display_name("Run Gen AI Batch Prediction job")
        )

        visualize_prediction_task = (
            VisualizeBatchPredictionTable(
                output_bq_table=output_bq_table,
                sample_size=sample_size,
                project=project,
                location=location,
            )
            .after(run_batch_prediction_task)
            .set_display_name("Visualize Gen AI Predictions")
        )

### Compile your pipeline into a YAML file

After the workflow of your pipeline is defined, compile the pipeline into YAML format for executing your pipeline on Vertex AI Pipelines.

In [None]:
# Compile the pipeline function into a pipeline spec that Vertex AI Pipelines can run.
compiler.Compiler().compile(pipeline, "pipeline.yaml")

#### Submit your pipeline run

After compiling your pipeline, use the Vertex AI Python client to submit and run your pipeline.

In [None]:
# Runtime values for the pipeline's parameters.
parameter_values = {
    "input_dataset_name": "genai_input_prediction_dataset",
    "input_bq_table": INPUT_TABLE_URI,
    "output_bq_table": OUTPUT_TABLE_URI,
    "model_id": MODEL_ID,
    "sample_size": 10,
    "project": PROJECT_ID,
    "location": LOCATION,
    "recipients": RECIPIENTS,
}

# The display name identifies this run in the Vertex AI Pipelines console; it
# matches the name given to @dsl.pipeline instead of an unrelated leftover.
job = aiplatform.PipelineJob(
    display_name="genai-batch-prediction-pipeline",
    parameter_values=parameter_values,
    template_path="pipeline.yaml",
    pipeline_root=PIPELINE_ROOT,
)

# Submit the run and wait for it to finish.
job.run()

## Cleaning up


In [None]:
# Set any flag to False to keep the corresponding resource.
delete_pipeline_job = True
delete_bigquery_dataset = True
delete_bucket = True

# Delete the pipeline run submitted earlier in this notebook.
if delete_pipeline_job:
    job.delete()

# Delete the Cloud Storage bucket and the objects in it.
if delete_bucket:
    ! gsutil -m rm -r {BUCKET_URI}

# Delete the BigQuery dataset and its tables (-r recursive, -f no confirmation prompt).
if delete_bigquery_dataset:
    ! bq rm -r -f -d {PROJECT_ID}:{BQ_DATASET}