In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Evaluating prompts at scale with Gemini Batch Prediction API

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/evaluation/evaluating_prompts_at_scale_with_gemini_batch_prediction_api.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fevaluation%2Fevaluating_prompts_at_scale_with_gemini_batch_prediction_api.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/evaluation/evaluating_prompts_at_scale_with_gemini_batch_prediction_api.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/evaluation/evaluating_prompts_at_scale_with_gemini_batch_prediction_api.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

| | |
|-|-|
|Author(s) | [Ariel Jassan](https://github.com/arieljassan) |

## Introduction

This tutorial guides you through the process of evaluating the effectiveness of your prompts at scale using the Gemini Batch Prediction API. Even though in this tutorial we will do image classification, it can be extended to other cases as well. One of the benefits of using the Gemini Batch Prediction API is that you can evaluate your prompts and setup in Gemini using hundreds of examples with one single request.

For the purpose of this tutorial, we will execute a prompt to classify images into classes of sports. The data is based on an excerpt of the dataset that can be found at https://www.kaggle.com/datasets/gpiosenka/sports-classification.


## Steps

1. **Prepare the data in BigQuery and GCS**
    * Upload sample images to Google Cloud Storage and create ground truth table in BigQuery.
    
2. **Run Gemini Batch Prediction API**
    * Send prompts to Gemini for batch prediction and get results in BigQuery.

3. **Analyze results in BigQuery and Looker Studio**
    * Present findings, focusing on prompt/dataset strengths and weaknesses.

## Getting started

### Install dependencies

In [1]:
! pip install google-cloud-aiplatform --upgrade -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25h

### Restart Colab

In [2]:
# You will see a notification of Colab crashing. It is the expected behavior.
# NOTE(review): presumably the restart is needed so the packages installed
# above are picked up by the kernel - confirm.
import IPython

# Ask the kernel of the running IPython application to shut down; the `True`
# argument requests a restart (the recorded output shows 'restart': True).
app = IPython.Application.instance()
app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

### Authenticate your notebook environment (Colab only)

In [1]:
import sys

# Run the Colab authentication flow only when the `google.colab` module is
# already loaded; in any other environment this cell does nothing.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

### Define constants

In [80]:
PROJECT_ID = "your-project-id"
LOCATION = "us-central1"

# Generative model.
MODEL_ID = "gemini-1.5-flash-001"

# BigQuery tables. Every name is fully qualified (`project.dataset.table`) so
# that queries and views do not depend on the client's default project.
BQ_DATASET_ID = "text_extraction_3"
BQ_DATASET = f"{PROJECT_ID}.{BQ_DATASET_ID}"
FILES_TABLE = f"{BQ_DATASET}.sports_files"
PROMPTS_TABLE = f"{BQ_DATASET}.temp_prompts"
TEXT_GENERATION_TABLE_PREFIX = f"{BQ_DATASET}.results"

# BigQuery views.
RESULTS_VIEW = f"{BQ_DATASET}.extraction_results"
EVALUATION_VIEW = f"{BQ_DATASET}.evaluation"

# GCS Bucket.
BUCKET_NAME = "your-bucket-name"
SPORTS_FILE = f"gs://{BUCKET_NAME}/sports_files.csv"

# Terminal states of a Vertex AI batch prediction job (values of the `JobState`
# enum). Polling stops as soon as the job reaches one of them; cancelled,
# expired and partially succeeded jobs are included so that polling cannot
# loop forever on a job that will never reach SUCCEEDED or FAILED.
STOP_STATES = (
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
    "JOB_STATE_EXPIRED",
    "JOB_STATE_PARTIALLY_SUCCEEDED",
)

### Import libraries and initialize clients

In [81]:
import base64
import datetime
from io import BytesIO
import json
import time
import requests

import google.auth
import google.auth.transport.requests
from google.cloud import bigquery
from google.cloud import storage
import vertexai
from vertexai.generative_models import GenerativeModel, Part


# BigQuery client.
bq_client = bigquery.Client(project=PROJECT_ID)

# Google Cloud Storage client. The project is passed explicitly, like for the
# other clients, so that creation does not depend on a project being
# discoverable from the environment.
storage_client = storage.Client(project=PROJECT_ID)

# Initialize Vertex AI SDK.
vertexai.init(project=PROJECT_ID, location=LOCATION)

## Data preparation

In this section we will create the bucket with images in Google Cloud Storage, create the dataset in BigQuery, load the table with ground truth, and create the views that will serve for analysis of the results from Gemini and reporting in Looker Studio.

### Get sample images and upload them to a GCS bucket
Images are a subset of the sports classification dataset in https://www.kaggle.com/datasets/gpiosenka/sports-classification

In [None]:
# TODO: update with url
# Download sample data to notebook.
!wget https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/evaluation/data/sports_files.zip
!unzip /content/sports_files.zip

# Create bucket.
!gcloud storage buckets create gs://{BUCKET_NAME} --location={LOCATION}

# Copy images to bucket.
!gcloud storage cp -r -n /content/sports_files/ gs://{BUCKET_NAME}/

### Create BigQuery dataset and load table with ground truth

In [66]:
def create_dataset(dataset_id, project, location):
    """Creates a BigQuery dataset, doing nothing if it already exists.

    Args:
        dataset_id: Dataset to create, either fully qualified
            (`project.dataset`) or a bare dataset ID.
        project: Project used to qualify `dataset_id` when it is a bare ID.
        location: BigQuery location of the dataset.
    """
    # `bigquery.Dataset` expects a fully qualified ID; qualify bare IDs with
    # the given project so that the `project` argument is honored.
    if "." not in dataset_id:
        dataset_id = f"{project}.{dataset_id}"

    dataset = bigquery.Dataset(dataset_id)
    dataset.location = location

    # exists_ok=True keeps the cell re-runnable: a second run returns the
    # existing dataset instead of raising a Conflict error.
    dataset = bq_client.create_dataset(dataset, exists_ok=True, timeout=30)
    # Report the dataset's own project, which can differ from the client's
    # default project. The wording no longer says "Created" because the
    # dataset may have existed already.
    print("Dataset {}.{} is ready".format(dataset.project, dataset.dataset_id))


def load_files_table_from_uri(files_table, uri):
    """Loads the ground truth CSV found at `uri` into the table `files_table`.

    The CSV is read with a two-column STRING schema (`path`, `label`), its
    first row is skipped, and any existing table content is replaced.
    """
    ground_truth_schema = [
        bigquery.SchemaField("path", "STRING"),
        bigquery.SchemaField("label", "STRING"),
    ]
    load_config = bigquery.LoadJobConfig(
        schema=ground_truth_schema,
        skip_leading_rows=1,
        source_format=bigquery.SourceFormat.CSV,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    )

    # Start the load job and block until it has finished.
    bq_client.load_table_from_uri(uri, files_table, job_config=load_config).result()

    row_count = bq_client.get_table(files_table).num_rows
    print(f"Loaded {row_count} rows.")


# Create the dataset first, then load the ground truth CSV from GCS into it.
create_dataset(dataset_id=BQ_DATASET, project=PROJECT_ID, location=LOCATION)
load_files_table_from_uri(files_table=FILES_TABLE, uri=SPORTS_FILE)

Created dataset arielj-argolis-1.text_extraction_3
Loaded 274 rows.


### Test image URIs from BigQuery

In [82]:
def get_filepaths(files_table):
    """Returns the values of the `path` column of `files_table` as a list."""
    query = f"""
      SELECT path
      FROM {files_table}"""

    paths = []
    for row in bq_client.query(query).result():
        # Each row has a single column: the path.
        paths.append(row[0])
    return paths


# Despite the name, these are the `path` values stored in the ground truth
# table; the sample output shows paths such as 'train/tennis/064.jpg'.
images_uri = get_filepaths(files_table=FILES_TABLE)
print(images_uri[:2])

['train/tennis/064.jpg', 'train/tennis/045.jpg']


### Create view of text generation results

Run this only once to create the view

In [101]:
def create_text_generation_view(text_generation_table_prefix, results_view):
    """Creates a view of text extraction results.

    The view reads every table matching `<text_generation_table_prefix>_*` and
    derives a `label` column from the `response` column in six steps:
      t1: extract the JSON-encoded text at `$[0].content.parts[0].text`.
      t2: keep what sits between the opening "json" code fence and the
          closing fence.
      t3: remove literal backslash-n sequences.
      t4: turn backslash-escaped double quotes into plain double quotes.
      t5: read the `$.sport` field of the resulting JSON.
      t6: strip the remaining double quotes.

    Args:
        text_generation_table_prefix: Common prefix of the result tables.
        results_view: ID of the view to create.

    The call fails if the view already exists (`exists_ok=False`).
    """

    view = bigquery.Table(results_view)

    # Raw f-string: the backslashes below reach BigQuery unchanged, while
    # `{text_generation_table_prefix}` is still interpolated.
    view.view_query = rf"""
      WITH t1 AS
      (
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          JSON_EXTRACT(response, '$[0].content.parts[0].text') AS json_data
        FROM `{text_generation_table_prefix}_*`
      ),
      t2 AS (
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          REGEXP_EXTRACT(json_data, r'```json(.*)```') AS f
        FROM t1
      ),
      t3 AS(
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          REPLACE(f, '\\n', '') AS f
        FROM t2
      ),
      t4 AS (
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          REPLACE(f, '\\"', '"') AS f
        FROM t3
      ),
      t5 AS (
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          JSON_QUERY(f, '$.sport') AS f
        FROM t4
      ),
      t6 AS (
        SELECT
          evaluation_id,
          evaluation_ts,
          prompt_text,
          gcs_uri,
          REPLACE(f, '"', '') AS f
        FROM t5
      )

      SELECT
        evaluation_id,
        evaluation_ts,
        prompt_text,
        gcs_uri,
        f AS label
      FROM t6"""

    # Make an API request to create the view.
    # exists_ok=False: a second run raises instead of silently keeping the
    # existing view definition.
    view = bq_client.create_table(view, exists_ok=False)
    print(f"Created {view.table_type}: {str(view.reference)}")


# Create the results view over all tables sharing the results prefix.
create_text_generation_view(
    text_generation_table_prefix=TEXT_GENERATION_TABLE_PREFIX, results_view=RESULTS_VIEW
)

Created VIEW: arielj-argolis-1.text_extraction_3.extraction_results


### Create view of experiment evaluation

Run this only once to create the view.

In [79]:
def create_evaluation_view(evaluation_view, files_table, results_view):
    """Creates the view that compares predicted labels with the ground truth.

    The view joins `files_table` with `results_view` on `path = gcs_uri` and
    exposes a boolean `correct` column telling whether both labels are equal.
    The call fails if the view already exists.
    """
    evaluation_query = f"""
      WITH t1 AS (
        SELECT
          e.evaluation_id,
          e.evaluation_ts,
          e.prompt_text,
          f.path,
          f.label,
          e.gcs_uri,
          f.label = e.label AS correct
        FROM `{files_table}` f
        JOIN `{results_view}` e
          ON f.path = e.gcs_uri
      )

      SELECT
        evaluation_id,
        evaluation_ts,
        prompt_text,
        path,
        label,
        correct
      FROM t1"""

    view_definition = bigquery.Table(evaluation_view)
    view_definition.view_query = evaluation_query

    # Send the definition to BigQuery; an already existing view is an error.
    created_view = bq_client.create_table(view_definition, exists_ok=False)
    print(f"Created {created_view.table_type}: {str(created_view.reference)}")


# Create the evaluation view on top of the ground truth table and results view.
create_evaluation_view(
    evaluation_view=EVALUATION_VIEW, files_table=FILES_TABLE, results_view=RESULTS_VIEW
)

Created VIEW: arielj-argolis-1.text_extraction_3.evaluation


## Define prompt and execute it via Gemini Batch Prediction API

### Define the prompt

In [83]:
# Prompt sent together with every image. The stray "3." that was left in the
# second instruction (apparently a leftover of a numbered list) has been
# removed so the model receives a clean instruction.
prompt = """\
- Classify the sport from the image below in one of the following categories:
* baseball
* basketball
* tennis
* volleyball

- Provide an answer in JSON format. Example response:
'{"sport": "baseball"}'

- Image:
"""

### Classify one image using the Python SDK

In [84]:
def download_blob_into_memory(bucket_name, blob_name):
    """Returns the bytes of the object `blob_name` in the GCS bucket `bucket_name`."""
    return storage_client.bucket(bucket_name).blob(blob_name).download_as_bytes()


def classify_image(model_id, prompt, bucket_name, blob_name, mime_type="image/jpeg"):
    """Classifies one image stored in GCS with a Gemini model.

    Args:
        model_id: ID of the generative model to use.
        prompt: Text prompt sent before the image.
        bucket_name: GCS bucket holding the image.
        blob_name: Name of the image object inside the bucket.
        mime_type: MIME type of the image; defaults to JPEG.

    Returns:
        The response returned by `GenerativeModel.generate_content`.
    """
    model = GenerativeModel(model_id)
    image_bytes = download_blob_into_memory(bucket_name, blob_name)
    # `Part.from_data` receives the raw bytes directly; the former base64
    # encode followed by an immediate decode produced exactly these bytes.
    image_part = Part.from_data(data=image_bytes, mime_type=mime_type)
    return model.generate_content([prompt, image_part])


# Try the prompt on a single image: the second path returned for the ground
# truth table.
blob_name = get_filepaths(files_table=FILES_TABLE)[1]
response = classify_image(
    model_id=MODEL_ID, prompt=prompt, bucket_name=BUCKET_NAME, blob_name=blob_name
)
print(f"blob_name: {blob_name}")
print(f"response: {response.text}")

blob_name: train/tennis/045.jpg
response: ```json
{"sport": "tennis"}
```


### Create a newline-delimited JSON file applying the prompt to each of the images
This section also creates an `evaluation_id` variable that identifies the execution run.

In [85]:
def create_newline_json_file(
    prompt, evaluation_ts, evaluation_id, file_name, bucket_name, images_uri
):
    """Writes one JSON record per image to `file_name`, one record per line.

    Every record carries the run metadata (`evaluation_ts`, `evaluation_id`,
    `prompt_text`, `gcs_uri`) and a `request` made of the prompt text followed
    by the base64-encoded image, which is downloaded from `bucket_name`.
    """
    with open(file_name, "w") as outfile:
        for image_uri in images_uri:
            image_bytes = download_blob_into_memory(bucket_name, image_uri)
            # JSON cannot hold raw bytes, so the image travels as base64 text.
            image_b64 = base64.b64encode(image_bytes).decode("utf-8")

            image_part = {"inlineData": {"mimeType": "image/jpeg", "data": image_b64}}
            user_turn = {"role": "user", "parts": [{"text": prompt}, image_part]}

            record = {
                "evaluation_ts": evaluation_ts,
                "evaluation_id": evaluation_id,
                "prompt_text": prompt,
                "gcs_uri": image_uri,
                "request": {"contents": [user_turn]},
            }
            outfile.write(json.dumps(record) + "\n")


# Identify this run by its start time: the full timestamp is stored with every
# record and a minute-resolution ID names the local file.
now = datetime.datetime.now()
evaluation_ts = str(now)
evaluation_id = "_".join(
    str(part) for part in (now.year, now.month, now.day, now.hour, now.minute)
)
json_file_name = "/tmp/" + evaluation_id + ".json"

# Write one request per image to the local newline delimited JSON file.
create_newline_json_file(
    prompt=prompt,
    evaluation_ts=evaluation_ts,
    evaluation_id=evaluation_id,
    file_name=json_file_name,
    bucket_name=BUCKET_NAME,
    images_uri=images_uri,
)

### Upload the newline delimited JSON file to BigQuery

In [86]:
def upload_newline_json_file(json_file_name, project_id, prompts_table):
    """Uploads a newline delimited JSON file to BigQuery.

    Args:
        json_file_name: Path of the local newline delimited JSON file.
        project_id: Unused; kept so existing callers keep working.
        prompts_table: ID of the destination table. Existing content is
            replaced (WRITE_TRUNCATE).
    """
    job_config = bigquery.LoadJobConfig(
        schema=[
            bigquery.SchemaField("evaluation_ts", "STRING"),
            bigquery.SchemaField("evaluation_id", "STRING"),
            bigquery.SchemaField("prompt_text", "STRING"),
            bigquery.SchemaField("gcs_uri", "STRING"),
            bigquery.SchemaField("request", "JSON"),
        ],
        source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
    )

    with open(json_file_name, "rb") as source_file:
        # Load into the table passed by the caller. The global PROMPTS_TABLE
        # was used here before, so a different `prompts_table` was ignored for
        # the load but still used for the row count below.
        job = bq_client.load_table_from_file(
            source_file, prompts_table, job_config=job_config
        )

    # Wait for the load job to finish before reading the table metadata.
    job.result()
    table = bq_client.get_table(prompts_table)
    print(
        "Loaded {} rows and {} columns to {}".format(
            table.num_rows, len(table.schema), prompts_table
        )
    )


# Load the requests written above into the prompts table.
upload_newline_json_file(
    json_file_name=json_file_name, project_id=PROJECT_ID, prompts_table=PROMPTS_TABLE
)

Loaded 274 rows and 5 columns to arielj-argolis-1.text_extraction_3.temp_prompts


### Launch a Gemini Batch Prediction request

In [87]:
# Define table to store results from Gemini Batch Prediction.
# The `evaluation_id` suffix gives every run its own results table.
text_generation_table = f"{TEXT_GENERATION_TABLE_PREFIX}_{evaluation_id}"


def create_batch_prediction_job(
    project_id,
    model_id,
    prompts_table,
    text_generation_table,
    location="us-central1",
    display_name=None,
):
    """Creates a Gemini batch prediction job through the Vertex AI REST API.

    Args:
        project_id: Google Cloud project that runs the job.
        model_id: ID of the publisher model, e.g. "gemini-1.5-flash-001".
        prompts_table: BigQuery table holding the requests.
        text_generation_table: BigQuery table that receives the results.
        location: Region of the Vertex AI endpoint; defaults to the region
            that used to be hard-coded.
        display_name: Display name of the job. When omitted, the global
            `evaluation_id` of the current run is used, as before.

    Returns:
        A `(job_id, job_state)` tuple taken from the API response.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    if display_name is None:
        display_name = evaluation_id

    gemini_batch_url = (
        f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/"
        f"locations/{location}/batchPredictionJobs"
    )

    # Get Authentication token.
    creds, _ = google.auth.default()
    auth_req = google.auth.transport.requests.Request()
    creds.refresh(auth_req)
    token = creds.token

    # Build request.
    request_data = {
        "displayName": display_name,
        "model": f"publishers/google/models/{model_id}",
        "inputConfig": {
            "instancesFormat": "bigquery",
            "bigquerySource": {"inputUri": f"bq://{prompts_table}"},
        },
        "outputConfig": {
            "predictionsFormat": "bigquery",
            "bigqueryDestination": {"outputUri": f"bq://{text_generation_table}"},
        },
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }

    # A timeout keeps the cell from hanging forever on a stalled connection.
    batch_response = requests.post(
        gemini_batch_url, json=request_data, headers=headers, timeout=60
    )
    print(batch_response.text)
    # Surface API errors as an HTTPError instead of a KeyError on "name" below.
    batch_response.raise_for_status()

    # Get job id of the request.
    batch_response_json = batch_response.json()
    job_name = batch_response_json["name"]
    job_id = job_name.split("/")[-1]

    # Get state of the request.
    job_state = batch_response_json["state"]
    return job_id, job_state


# Submit the batch job; the returned ID and initial state are used for polling.
job_id, job_state = create_batch_prediction_job(
    project_id=PROJECT_ID,
    model_id=MODEL_ID,
    prompts_table=PROMPTS_TABLE,
    text_generation_table=text_generation_table,
)

{
  "name": "projects/743100398377/locations/us-central1/batchPredictionJobs/5373658193635311616",
  "displayName": "2024_7_25_11_19",
  "model": "publishers/google/models/gemini-1.5-flash-001",
  "inputConfig": {
    "instancesFormat": "bigquery",
    "bigquerySource": {
      "inputUri": "bq://arielj-argolis-1.text_extraction_3.temp_prompts"
    }
  },
  "outputConfig": {
    "predictionsFormat": "bigquery",
    "bigqueryDestination": {
      "outputUri": "bq://arielj-argolis-1.text_extraction_3.results_2024_7_25_11_19"
    }
  },
  "state": "JOB_STATE_PENDING",
  "createTime": "2024-07-25T11:21:32.461969Z",
  "updateTime": "2024-07-25T11:21:32.461969Z",
  "modelVersionId": "1"
}



### Check state of the Gemini Batch Prediction request until completed

In [88]:
def check_job_state(
    project_id,
    job_id,
    last_job_state,
    location="us-central1",
    poll_interval_seconds=30,
):
    """Polls a Gemini batch prediction job until it reaches a stop state.

    Args:
        project_id: Google Cloud project that runs the job.
        job_id: ID of the batch prediction job.
        last_job_state: Last known state of the job; no request is sent when
            it already is one of `STOP_STATES`.
        location: Region of the Vertex AI endpoint; defaults to the region
            that used to be hard-coded.
        poll_interval_seconds: Pause between two status requests.

    Returns:
        The final state of the job.

    Raises:
        requests.HTTPError: If a status request answers with an error status.
    """

    get_status_url = (
        f"https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}/"
        f"locations/{location}/batchPredictionJobs/{job_id}"
    )

    creds, _ = google.auth.default()
    auth_req = google.auth.transport.requests.Request()

    job_state = last_job_state
    now = datetime.datetime.now()
    print(f"Status {now}: {job_state}")

    while job_state not in STOP_STATES:
        # Tokens expire, and a batch job can outlive the token fetched at the
        # start, so the credentials are refreshed whenever they are not valid.
        if not creds.valid:
            creds.refresh(auth_req)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {creds.token}",
        }

        state_response = requests.get(get_status_url, headers=headers, timeout=60)
        # Surface API errors as an HTTPError instead of a KeyError on "state".
        state_response.raise_for_status()
        job_state = state_response.json()["state"]
        now = datetime.datetime.now()
        print(f"Status {now}: {job_state}")
        if job_state not in STOP_STATES:
            time.sleep(poll_interval_seconds)

    return job_state


# Block until the batch prediction job reaches one of the STOP_STATES.
check_job_state(project_id=PROJECT_ID, job_id=job_id, last_job_state=job_state)

Status 2024-07-25 11:21:42.653686: JOB_STATE_PENDING
Status 2024-07-25 11:21:43.049581: JOB_STATE_QUEUED
Status 2024-07-25 11:22:13.409234: JOB_STATE_QUEUED
Status 2024-07-25 11:22:43.835615: JOB_STATE_QUEUED
Status 2024-07-25 11:23:14.256927: JOB_STATE_QUEUED
Status 2024-07-25 11:23:44.634046: JOB_STATE_QUEUED
Status 2024-07-25 11:24:15.031125: JOB_STATE_QUEUED
Status 2024-07-25 11:24:45.450595: JOB_STATE_QUEUED
Status 2024-07-25 11:25:15.830552: JOB_STATE_QUEUED
Status 2024-07-25 11:25:46.277329: JOB_STATE_SUCCEEDED


### List sample of text generation results from BigQuery

In [89]:
def print_text_generation_results(text_generation_table):
    """Prints the generated text of up to five rows of the results table.

    Rows without a response are reported with their `status` instead of
    stopping the listing.

    Args:
        text_generation_table: ID of the table written by the batch job.
    """

    job = bq_client.query(
        f"""
      SELECT gcs_uri, response, status, processed_time
      FROM {text_generation_table}
      LIMIT 5"""
    )

    for row in job.result():
        response = row["response"]
        if response is None:
            # NOTE(review): presumably failed predictions leave `response`
            # empty and explain the failure in `status` - confirm.
            print(f"No response for {row['gcs_uri']} (status: {row['status']})")
            continue
        # The column may arrive as JSON text or as an already parsed object;
        # only text needs decoding.
        if isinstance(response, (str, bytes)):
            response = json.loads(response)
        print(response[0]["content"]["parts"][0]["text"])


# Show a few raw model answers from the results table of this run.
print_text_generation_results(text_generation_table=text_generation_table)

```json
{"sport": "volleyball"}
```
```json
{"sport": "volleyball"}
```
```json
{"sport": "tennis"}
```
```json
{"sport": "tennis"}
```
```json
{"sport": "tennis"}
```


## Analyze results in BigQuery and Looker Studio

### Copy a Looker Studio dashboard to analyze results

1. Make a copy of this [Looker Studio dashboard](https://lookerstudio.google.com/reporting/6dd5a7e8-b353-4dde-9bd8-72eb7b501559)
1. Connect dashboard to your view