In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SQL Code Generation on Vertex AI

**NOTE:** This notebook uses the PaLM generative model, which will reach its [discontinuation date in October 2024](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text#model_versions). Please refer to [this updated notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/c86e9da59015a269894bc5ccf91ff08f33cdee44/language/code/code_generation.ipynb) for a version which uses the latest Gemini model.

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Flanguage%2Fuse-cases%2Fsql-code-generation%2Fsql_code_generation.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/bigquery/import?url=https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg" alt="BigQuery Studio logo"><br> Open in BigQuery Studio
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/language/use-cases/sql-code-generation/sql_code_generation.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>            

| | |
|-|-|
| Author(s): | [Roy Arsan](https://www.linkedin.com/in/arsan) |

## Overview
Large language models can be used for generating code, including SQL. In particular, models can convert natural language text into SQL queries. One common purpose is to let users query data without knowing the table names, the data schema, or the specific SQL dialect or query engine of the underlying data warehouse, such as BigQuery.

This notebook covers prompt engineering best practices for SQL code generation, and puts into practice lessons from [SQL-PaLM: Improve Large Language Model Adaptation for text-to-SQL](https://arxiv.org/pdf/2306.00739.pdf). For example, the BigQuery dataset schema is retrieved and provided dynamically as context to the prompt, to ground the LLM and tailor its output to your data. The notebook also demonstrates simple model evaluation: the generated SQL queries are executed against the BigQuery dataset, then compared with ground truth queries and their corresponding results.

In this notebook, you generate SQL queries to analyze Cloud Audit Logs and answer critical security questions about activity in your own Google Cloud project. While this notebook uses a BigQuery logs dataset, the concepts and approach presented here can be applied to other databases and datasets.

![NL2SQL flow](https://services.google.com/fh/files/misc/nl2sql_for_log_analytics2.png)

### Objective

By the end of the notebook, you should be able to:

* Use a model to generate SQL queries based on natural language questions:
  * Using few-shot prompting
  * Providing custom dataset schemas as context
  * Formatting model output

* Evaluate model-generated queries by:
  * Executing sanitized queries against a live dataset
  * Comparing queries (and their results) to ground truth queries using simple fuzzy string matching
  * Calculating a model accuracy score

In addition, you can use this notebook to answer your own security questions from your own audit logs, such as:

- Any unusually high cloud API usage by any user identity over the last month?
- Any destructive actions by an unapproved identity over the last 7 days?
- Any unusual day-to-day spike in data volume accessed by any user this week?


## Getting Started

### Prerequisite
If you haven't already done so, [upgrade your existing log bucket](https://cloud.google.com/logging/docs/buckets#upgrade-bucket) to use Log Analytics, which provides you with a linked BigQuery dataset containing your own queryable logs data. This is the only prerequisite, and it is a **one-click step without incurring additional costs**. By default, Cloud Audit Admin Activity logs are enabled, ingested, and stored in every project's `_Required` bucket without any charges.

![one click prerequisite](https://services.google.com/fh/files/misc/upgrade_log_bucket.png)

### Install SDKs

In [None]:
# Install Vertex AI SDK to use for model predictions
%pip install google-cloud-aiplatform google-cloud-bigquery --upgrade --user

# Install fuzzy string comparison modules for model output evaluation
%pip install -q python-Levenshtein --upgrade --user
%pip install -q fuzzywuzzy --upgrade --user

Install Python SDK for Google Sheets only if you wish to later save the model-generated SQL queries and their results into a Google Sheet for subsequent troubleshooting or to expand your few-shot examples dataset. This is **not applicable for Vertex AI Workbench or Colab Enterprise** because notebooks in those environments cannot access Google Drive or Google Sheets for security purposes.

In [None]:
# [Optional] Install Python SDK for Google Sheets
%pip install gspread --upgrade --user

**Colab only:** Uncomment the following cell to restart the kernel, or use the button to restart the kernel. For Vertex AI Workbench, you can restart the kernel using the button at the top.

In [None]:
# # Automatically restart kernel after installs so that your environment can access the new packages
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

### Authenticating your notebook environment
* If you are using **Colab** to run this notebook, run the cell below and continue.
* If you are using **Vertex AI Workbench**, check out the setup instructions [here](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/setup-env).

In [None]:
import sys

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

In [None]:
# For debug only
!gcloud config list --format 'value(core.account)'

### Import libraries

**Colab only:** Uncomment the following cell to initialize the Vertex AI SDK. For Vertex AI Workbench, you don't need to run this.  

In [None]:
# PROJECT_ID = "[your-project-id]" # @param {type:"string"}
# LOCATION = "us-central1" # @param {type:"string"}

# from google.cloud import aiplatform
# aiplatform.init(project=PROJECT_ID, location=LOCATION)

In [None]:
import numpy as np
import pandas as pd
from vertexai.language_models import TextGenerationModel

### Set project and datasets for BigQuery

This is the project containing:
 - The linked BigQuery dataset `BQ_LINKED_DATASET` with your raw logs, and,
 - A new BigQuery dataset `BQ_PROCESSED_DATASET` you'll create to store the processed logs.

This project can be the same as, or separate from, the one you're using for Vertex AI.

Make sure you have the **BigQuery Data Viewer** role on the `BQ_LINKED_DATASET` dataset.

In [None]:
BQ_PROJECT_ID = "[bq-project-id]"  # @param {type:"string"}
BQ_LINKED_DATASET = "[linked-bq-dataset]"  # @param {type:"string"}
BQ_PROCESSED_DATASET = "[new-bq-dataset]"  # @param {type:"string"}

from google.cloud import bigquery

client = bigquery.Client(project=BQ_PROJECT_ID)

### Import models

We will interact with the Vertex AI LLM `text-bison`:

In [None]:
MODEL_ID = "text-bison"  # @param {type:"string"}

model = TextGenerationModel.from_pretrained(MODEL_ID)

## Prepare the data

> You can skip this section if your raw logs are already processed and normalized in curated tables using [Dataform as part of Community Security Analytics](https://github.com/GoogleCloudPlatform/security-analytics/tree/main/dataform) (CSA). For more information on CSA and how to automatically and continuously build post-processed tables out of your raw logs, see this [Google Cloud blog post](https://cloud.google.com/blog/products/data-analytics/deploy-community-security-analytics-with-dataform).

As with any other AI/ML project, the first step is to prepare your data, including the datasets for few-shot prompting and subsequent evaluation. You'll preprocess the raw logs that reside in your linked BigQuery dataset into a summary table in your new BigQuery dataset. This table contains the logs in aggregated form, normalized into a simple schema. This allows you to unlock and scale ML analysis:
- From a computation point of view because the dataset is smaller and simpler.
- From a talent point of view because researchers and analysts are not required to be familiar with the complex schema of raw logs ([LogEntry definition](https://cloud.google.com/logging/docs/reference/v2/rest/v2/LogEntry)).


### Create new dataset

In [None]:
!bq --location=US mk --dataset {BQ_PROJECT_ID}:{BQ_PROCESSED_DATASET}

### Build table of user actions

Let's search and process the audit logs to create a table of user actions aggregated by day. This summary table reduces the set to relevant records and simplifies the structure, which in turn simplifies exploratory and advanced analytics. For those interested, we use the same SQL query as defined in the CSA repo, specifically the [`csa_4_01_summary_daily`](https://github.com/GoogleCloudPlatform/security-analytics/blob/main/dataform/definitions/summary/csa_4_01_summary_daily.sqlx) Dataform definition file.

In [None]:
TABLE_NAME = "csa_4_01_summary_daily"
TABLE_ID = f"{BQ_PROJECT_ID}.{BQ_PROCESSED_DATASET}.{TABLE_NAME}"
SUMMARY_LOOKBACK_DAYS = 90

job_config = bigquery.QueryJobConfig(
    destination=TABLE_ID, write_disposition="WRITE_TRUNCATE"
)

sql = f"""
SELECT
  EXTRACT(DATE FROM timestamp) AS day,
  proto_payload.audit_log.authentication_info.principal_email,
  ARRAY_AGG(DISTINCT proto_payload.audit_log.method_name IGNORE NULLS) AS actions,
  COUNT(*) AS counter
FROM `{BQ_PROJECT_ID}.{BQ_LINKED_DATASET}._AllLogs`
WHERE
  timestamp >=  TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {SUMMARY_LOOKBACK_DAYS} DAY)
  AND proto_payload.audit_log.authentication_info.principal_email IS NOT NULL
  AND proto_payload.audit_log.method_name NOT LIKE "storage.%.get"
  AND proto_payload.audit_log.method_name NOT LIKE "v1.compute.%.list"
  AND proto_payload.audit_log.method_name NOT LIKE "beta.compute.%.list"
GROUP BY
  day,
  proto_payload.audit_log.authentication_info.principal_email
"""

# Start the query and save results in new table
query_job = client.query(sql, job_config=job_config)
result = query_job.result()  # Wait for the job to complete.

print(f"{result.total_rows} user action records loaded to table {TABLE_ID}")
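
To make the shape of the summary table concrete, here is a minimal pure-Python sketch of the same grouping applied to a handful of made-up log records. The records, email addresses, and the `summarize_daily` helper are hypothetical, and the sketch omits the method-name exclusion filters; the real aggregation is done by the BigQuery query above.

```python
from collections import defaultdict

# Hypothetical, simplified audit log records (not real data)
records = [
    {"timestamp": "2024-01-01T09:00:00Z", "principal_email": "alice@example.com", "method_name": "SetIamPolicy"},
    {"timestamp": "2024-01-01T10:30:00Z", "principal_email": "alice@example.com", "method_name": "SetIamPolicy"},
    {"timestamp": "2024-01-01T11:00:00Z", "principal_email": "alice@example.com", "method_name": "v1.compute.instances.delete"},
    {"timestamp": "2024-01-02T08:15:00Z", "principal_email": "bob@example.com", "method_name": "storage.buckets.create"},
    {"timestamp": "2024-01-02T08:20:00Z", "principal_email": None, "method_name": "storage.buckets.create"},
]


def summarize_daily(records):
    """Group records by (day, principal_email), collecting distinct actions and a row count."""
    groups = defaultdict(lambda: {"actions": set(), "counter": 0})
    for record in records:
        if record["principal_email"] is None:
            continue  # mirrors the IS NOT NULL filter in the SQL query
        day = record["timestamp"][:10]  # date part of an ISO 8601 timestamp
        group = groups[(day, record["principal_email"])]
        group["actions"].add(record["method_name"])
        group["counter"] += 1
    return [
        {
            "day": day,
            "principal_email": email,
            "actions": sorted(group["actions"]),
            "counter": group["counter"],
        }
        for (day, email), group in sorted(groups.items())
    ]


summary = summarize_daily(records)
for row in summary:
    print(row)
```

Each output row corresponds to one user on one day, which is the granularity the sample questions later in this notebook rely on.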

### Import sample queries

You will now retrieve a list of 15 sample security questions and corresponding SQL queries from a CSV file. These security questions are variations from the open-source [Community Security Analytics](https://github.com/GoogleCloudPlatform/security-analytics). CSA provides a set of security questions and corresponding queries for BigQuery, Log Analytics and Chronicle.

We will use a subset of these queries as few-shot examples as part of the model prompt, and the remaining set for model evaluation.

Run the following to read the CSV file from a GCS bucket and load all records into an in-memory pandas DataFrame:

In [None]:
BUCKET_ID = "csa-datasets-public"  # @param {type:"string"}
FILENAME = "SQL_Generator_Example_Queries.csv"  # @param {type:"string"}

df = pd.read_csv(f"gs://{BUCKET_ID}/{FILENAME}", header=0)
df.head(2)

Retrieve the table name referenced by the sample questions. This should be the same as the table of post-processed logs we created above.

In [None]:
# Retrieve unique table names excluding null values and empty string
BQ_TABLES = df["Qualified table name"].replace("", np.nan).dropna().unique()
print(BQ_TABLES)

### Extract train & eval datasets

Extract train & eval datasets and store in respective dataframes:

In [None]:
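# Copy the slices so that columns can be added later without pandas chained-assignment warnings
train_df = df.loc[df["Dataset"] == "Train", ["Question", "SQL Query"]].copy()
eval_df = df.loc[df["Dataset"] == "Eval", ["Question", "SQL Query"]].copy()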

Take a peek at a few records from each set:

In [None]:
train_df.head(5)

In [None]:
eval_df.head(5)

## Prepare few-shot prompt


### Define prompt template

The model prompt will include the following components:
1. Concise statement to specify the task
1. Schema definition to describe existing dataset
1. A few shot examples of questions in natural language and corresponding SQL statements

This is the template we will later use to generate the prompt using these 3 components.

In [None]:
# This string template takes three arguments:
# - schema definition
# - few shot examples
# - question for which query needs to be generated

prompt_template = """\
This is a task converting text into GoogleSQL statement.
We will first give you the dataset schema and then ask a question in text.
You are asked to generate SQL statement which is valid for BigQuery.
Remove any delimiters around answer such as "```"

BigQuery tables schema definition:
{schema_definition}
Here are a few shot examples:
{few_examples}
Write GoogleSQL query for following question: {question}
Answer: "Query here"
"""

### Build schema definition (compact version)

First, we need to build a concise schema definition of your dataset. As mentioned earlier, we'll use that as part of our prompt's context for grounding the results.

Retrieve table and column definitions from the `INFORMATION_SCHEMA` of your BigQuery dataset.

In [None]:
# Following SQL query will generate schema definition of your dataset
QUERY = f"""\
SELECT
    '[Schema (values)]: ' || '| log_summary | ' || STRING_AGG(table_values, ' | ') || ';' AS tables_definition,
    '[Column names (type)]: ' || STRING_AGG(column_names_types) || ';' AS columns_definition
FROM (
    SELECT
      table_name,
      table_name || ' : ' || STRING_AGG(column_name, ' , ') as table_values,
      STRING_AGG(table_name || ' : ' || column_name || ' (' || data_type || ')', ' | ') as column_names_types
    FROM {BQ_PROJECT_ID}.{BQ_PROCESSED_DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS
    WHERE table_name IN {'(' + ",".join(map(lambda x: f"'{x}'", BQ_TABLES)) + ')'}
    GROUP BY table_name
    ORDER BY table_name
)
"""

# Create query job
query_job = client.query(QUERY)
# Get first row
schema = next(query_job.result())

# Build schema definition
schema_definition = f"""\
{schema.tables_definition}

{schema.columns_definition}
"""

print(schema_definition)
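
For readers without a dataset at hand, the following sketch rebuilds the same compact schema string in plain Python from a hard-coded column list, so you can see the format the prompt receives. The column names come from the summary table created earlier; the data types and the `build_schema_definition` helper are assumptions for illustration, and the actual values come from `INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`.

```python
# Hypothetical column metadata as (table_name, column_name, data_type) tuples
columns = [
    ("csa_4_01_summary_daily", "day", "DATE"),
    ("csa_4_01_summary_daily", "principal_email", "STRING"),
    ("csa_4_01_summary_daily", "actions", "ARRAY<STRING>"),
    ("csa_4_01_summary_daily", "counter", "INT64"),
]


def build_schema_definition(columns):
    """Format column metadata into the two-part compact schema used in the prompt."""
    tables = {}
    for table_name, column_name, data_type in columns:
        tables.setdefault(table_name, []).append((column_name, data_type))

    table_values = []
    column_names_types = []
    for table_name in sorted(tables):
        table_columns = tables[table_name]
        table_values.append(
            f"{table_name} : " + " , ".join(name for name, _ in table_columns)
        )
        column_names_types.append(
            " | ".join(f"{table_name} : {name} ({dtype})" for name, dtype in table_columns)
        )

    tables_definition = (
        "[Schema (values)]: | log_summary | " + " | ".join(table_values) + ";"
    )
    columns_definition = "[Column names (type)]: " + ",".join(column_names_types) + ";"
    return f"{tables_definition}\n\n{columns_definition}\n"


schema_text = build_schema_definition(columns)
print(schema_text)
```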

### Add queries as few-shot examples

This is a helper template to format each example question and its SQL query within the prompt. We will use this template as we add shots to the prompt.

In [None]:
one_shot_template = """
Question: {question}

Answer: {query}
"""

Let's add queries from our dataset as examples to our prompt:

In [None]:
few_examples = ""
for index, row in train_df.iterrows():
    few_examples += one_shot_template.format(
        question=row["Question"], query=row["SQL Query"]
    )

print(f"Added {str(train_df.shape[0])} pairs as few-shot examples")

### Review full prompt

Using the schema definition, the few shot examples, and a sample question to be answered, let's generate a full prompt as an example.

In [None]:
question = "This is a sample question"

prompt = prompt_template.format(
    schema_definition=schema_definition, few_examples=few_examples, question=question
)
print("Prompt:")
print(prompt)

print("Number of input characters: " + str(len(prompt)))
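
Note that `len(prompt)` counts characters, not tokens. If you want a rough token figure, the sketch below divides the character count by four, which is a commonly cited rule of thumb for English text and only an assumption here; the exact count depends on the model's tokenizer. The `estimate_token_count` helper and the sample string are illustrative.

```python
import math

# Assumption: roughly 4 characters per token, a rule of thumb rather than an exact figure
CHARS_PER_TOKEN_ESTIMATE = 4


def estimate_token_count(text: str, chars_per_token: int = CHARS_PER_TOKEN_ESTIMATE) -> int:
    """Return a rough token estimate derived from the character count."""
    return math.ceil(len(text) / chars_per_token)


sample_prompt = "Write GoogleSQL query for following question: List all users"
print(f"Characters: {len(sample_prompt)}")
print(f"Estimated tokens: {estimate_token_count(sample_prompt)}")
```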

## Generate SQL queries

### Define helper function to generate SQL

The following helper function `generate_sql()` is used to retrieve a SQL query from the Vertex AI LLM model using the prompt template we have built thus far.

Notice how `generate_sql()` uses `sanitize_output()` function to strip the response down to the SQL query itself before returning the results. Even though the model prompt includes instructions to tune the model output, there may still be enclosing quotes or code block backticks which need to be stripped out to avoid a subsequent SQL syntax error.

In [None]:
import re


# Strip text to include only the SQL code block, without enclosing backticks or stray quotes
def sanitize_output(text: str) -> str:
    # Strip whitespace and any potential backticks enclosing the code block
    text = text.strip()
    regex = re.compile(r"^\s*```(\w+)?|```\s*$")
    text = regex.sub("", text).strip()

    # Find and remove any trailing quote without corresponding opening quote
    if re.search(r'^[^"]*"$', text):
        text = text[:-1]
    # Find and remove any leading quote without corresponding closing quote
    if re.search(r'^"[^"]*$', text):
        text = text[1:]

    return text


# Call model using prompt and pre-defined parameters
def generate_sql(
    model,
    prompt: str,
    temperature: float = 0.2,
    max_output_tokens: int = 1024,
    top_k: int = 40,
    top_p: float = 0.8,
) -> str:
    print("Generating SQL...")
    print("Number of input characters: " + str(len(prompt)))

    response = model.predict(
        prompt,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_k=top_k,
        top_p=top_p,
    )

    text = response.text
    print("Number of output characters: " + str(len(text)))
    print("Response:")
    print(text)

    # Strip text to include only the SQL code block
    text = sanitize_output(text)
    print("Response stripped:")
    print(text)

    return text
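
To see what the fence-stripping step does to typical responses, here is a standalone sketch that reuses the same regular expression as `sanitize_output()` and applies it to made-up model outputs. The `strip_code_fence` name and the sample strings are hypothetical, and the quote handling is left out for brevity.

```python
import re

# Build the three-backtick fence programmatically so this example stays readable in Markdown
FENCE = "`" * 3
FENCE_PATTERN = re.compile(r"^\s*" + FENCE + r"(\w+)?|" + FENCE + r"\s*$")


def strip_code_fence(text: str) -> str:
    """Remove a leading fence (with optional language tag) and a trailing fence."""
    return FENCE_PATTERN.sub("", text.strip()).strip()


# Hypothetical model responses
samples = [
    FENCE + "sql\nSELECT 1\n" + FENCE,
    FENCE + "\nSELECT day FROM t\n" + FENCE + "  ",
    "SELECT 2",
]

for sample in samples:
    print(repr(strip_code_fence(sample)))
```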

### Define helper function to execute SQL

The following helper function `execute_sql()` executes a SQL query against the live BigQuery dataset and returns the results as a DataFrame. If validation or execution fails, it returns the exception instead of raising it, so that errors can be recorded alongside successful results.

Notice how `execute_sql()` qualifies table names with the project and BigQuery dataset you specified above before executing the SQL query.

In [None]:
# Limit number of bytes processed as a guardrail for cost control
BQ_MAX_BYTES_BILLED = pow(2, 30)  # 1GB


def execute_sql(query: str):
    print("Executing SQL...")

    # Qualify table names with your project and dataset ID
    for table_name in BQ_TABLES:
        query = query.replace(
            table_name, f"{BQ_PROJECT_ID}.{BQ_PROCESSED_DATASET}.{table_name}"
        )

    print("Query:")
    print(query)

    # Validate the query by performing a dry run without incurring a charge
    job_config = bigquery.QueryJobConfig(use_query_cache=False, dry_run=True)
    try:
        response = client.query(query, job_config=job_config)
    except Exception as e:
        print("Error validating query:")
        print(e)
        return e

    print(f"Query will process {response.total_bytes_processed / 1024:.2f} KB.")

    # Execute the query
    job_config = bigquery.QueryJobConfig(
        use_query_cache=False, maximum_bytes_billed=BQ_MAX_BYTES_BILLED
    )
    try:
        # Pass job_config so that the maximum_bytes_billed guardrail is applied
        response = client.query(query, job_config=job_config)
        df = response.to_dataframe()
    except Exception as e:
        print("Error executing query:")
        print(e)
        return e

    return df
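
One caveat with plain `str.replace()` is that it rewrites every occurrence of the table name, including one the model has already fully qualified, which would double the prefix. The sketch below shows a stricter variant that only qualifies bare names. It assumes that `BQ_TABLES` holds bare table names made of word characters; the `qualify_table_names` helper and the project and dataset IDs are hypothetical.

```python
import re


def qualify_table_names(query: str, table_names, project_id: str, dataset_id: str) -> str:
    """Prefix bare table names with project and dataset, leaving qualified names untouched."""
    for table_name in table_names:
        qualified = f"{project_id}.{dataset_id}.{table_name}"
        # Match the name only when it is not preceded by '.' or a word character
        # (already qualified or part of a longer name) and not followed by a word character
        pattern = re.compile(rf"(?<![\w.]){re.escape(table_name)}(?!\w)")
        # Use a function as the replacement so the text is inserted literally
        query = pattern.sub(lambda _match: qualified, query)
    return query


tables = ["csa_4_01_summary_daily"]
bare = "SELECT * FROM csa_4_01_summary_daily"
already_qualified = "SELECT * FROM `my-project.my_dataset.csa_4_01_summary_daily`"

print(qualify_table_names(bare, tables, "my-project", "my_dataset"))
print(qualify_table_names(already_qualified, tables, "my-project", "my_dataset"))
```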

### Example 1

Let's generate the SQL to answer this sample question:

*List all user actions that contains the word 'delete' or 'remove' over the last month. Include the user and the day in the results.*


In [None]:
question = "List all user actions that contains the word 'delete' or 'remove' over the last month. Include the user and the day in the results."

query = generate_sql(
    model,
    prompt_template.format(
        schema_definition=schema_definition,
        few_examples=few_examples,
        question=question,
    ),
)

Let's test the generated query with the live dataset in BigQuery.


In [None]:
# Execute the query
query_result = execute_sql(query)
print(query_result)

### Example 2

Let's generate the SQL to answer this sample question:

*List any action containing IAM case-insensitive by any unapproved user over the last 7 days, where approved user include 'admin@example.com'.*


In [None]:
question = "List any action containing IAM case-insensitive by any unapproved user over the last 7 days, where approved user include 'admin@example.com'"

query = generate_sql(
    model,
    prompt_template.format(
        schema_definition=schema_definition,
        few_examples=few_examples,
        question=question,
    ),
)

Let's test the generated query against your BigQuery dataset:

In [None]:
# Execute the query
query_result = execute_sql(query)
print(query_result)

## Evaluate model

### Run model on evaluation dataset

Let's generate SQL queries for all questions in our evaluation dataset. That dataset includes both the `Question` and the ground truth `SQL Query`. Run the following code to automatically call the model for each question in the dataset and record the response in a new column, `Generated SQL Query`. This may take a few minutes because model calls are made serially.


In [None]:
eval_df["Generated SQL Query"] = eval_df["Question"].apply(
    lambda x: generate_sql(
        model,
        prompt_template.format(
            schema_definition=schema_definition, few_examples=few_examples, question=x
        ),
    )
)

eval_df

In [None]:
len(eval_df)

### Execute ground truth queries

Before we evaluate our generated queries, let's run the "ground truth" queries as part of our dataset.

In [None]:
eval_df["SQL Query Result"] = eval_df["SQL Query"].apply(execute_sql)

Let's peek into the results

In [None]:
eval_df.loc[:, ["SQL Query", "SQL Query Result"]]

### Execute generated queries

In [None]:
eval_df["Generated SQL Query Result"] = eval_df["Generated SQL Query"].apply(
    execute_sql
)

Let's peek into some of the results

In [None]:
eval_df.loc[:, ["Generated SQL Query", "Generated SQL Query Result"]].head()

### Serialize results

 Since the results of each successful query are represented as a nested DataFrame within the evaluation DataFrame, we need to first serialize them so it's easier to compare the results, and to optionally save them in a spreadsheet or a CSV file.

In [None]:
def format_query_result(query_result):
    if isinstance(query_result, pd.DataFrame):
        if query_result.shape[0] == 0:
            return "No results found"
        return query_result.to_csv(index=False)
    elif isinstance(query_result, Exception):
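        # Not every exception type has a .message attribute, so fall back to str()
        return getattr(query_result, "message", str(query_result))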
    else:
        return query_result


eval_df["Generated SQL Query Result Formatted"] = eval_df[
    "Generated SQL Query Result"
].apply(format_query_result)
eval_df["SQL Query Result Formatted"] = eval_df["SQL Query Result"].apply(
    format_query_result
)

Inspect the results

In [None]:
eval_df.head()

### Calculate match score

Let's evaluate the model's accuracy by calculating a match score for each pair of queries. In our case, we'll calculate a score based on fuzzy string matching of each generated query with the corresponding ground truth query.

Let's import and use the `fuzzywuzzy` library, which we already installed along with the `Levenshtein` module for fast computation of the Levenshtein distance between two strings:


In [None]:
from fuzzywuzzy import fuzz


def get_match_of_queries(df):
    return fuzz.partial_ratio(df["SQL Query"], df["Generated SQL Query"])


eval_df["match_score_queries"] = eval_df.apply(get_match_of_queries, axis=1)
eval_df.loc[:, ["SQL Query", "Generated SQL Query", "match_score_queries"]]
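
String-based scores are sensitive to differences in whitespace and letter case that do not change what a query means. The sketch below normalizes both queries before comparing them. To stay dependency-free it uses `difflib.SequenceMatcher` from the standard library as a stand-in for `fuzzywuzzy`, so its scores will not be identical to `fuzz.partial_ratio`. The helper names and sample queries are hypothetical.

```python
import re
from difflib import SequenceMatcher


def normalize_sql(query: str) -> str:
    """Collapse whitespace and lowercase the query (this also lowercases string literals)."""
    return re.sub(r"\s+", " ", query).strip().lower()


def similarity_score(expected: str, generated: str) -> int:
    """Return a 0-100 similarity score between two normalized queries."""
    ratio = SequenceMatcher(None, normalize_sql(expected), normalize_sql(generated)).ratio()
    return round(ratio * 100)


expected = "SELECT day, principal_email\nFROM csa_4_01_summary_daily\nWHERE counter > 100"
generated = "select day, principal_email from csa_4_01_summary_daily where counter > 100"
different = "SELECT COUNT(*) FROM csa_4_01_summary_daily"

print(similarity_score(expected, generated))
print(similarity_score(expected, different))
```

Lowercasing everything is a simplification: it would also hide a real difference between string literals such as `'Admin'` and `'admin'`, so treat the normalized score as an approximation.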

Now calculate the mean score across all questions

In [None]:
print(
    "The average match score based on raw generated queries is: ",
    round(eval_df["match_score_queries"].mean(), 2),
    "%",
)

The generated queries might use a different SQL implementation or different variable names than the 'ground truth' query, yet still yield the correct answer to the security question. So calculating the match score of each query string against its corresponding 'ground truth' query string is not a sufficient evaluation method. Therefore, we'll also calculate the match score between the actual results returned from your dataset.

Let's run the fuzzy match logic on the formatted version of the results. That formatted column is already stringified and ready for string comparison. This may take several minutes to complete, depending on the size of the actual results being compared.

In [None]:
def get_match_of_results(df):
    return fuzz.partial_ratio(
        df["SQL Query Result Formatted"], df["Generated SQL Query Result Formatted"]
    )


# This may take several minutes to complete
eval_df["match_score"] = eval_df.apply(get_match_of_results, axis=1)
eval_df.loc[
    :,
    [
        "SQL Query",
        "Generated SQL Query",
        "SQL Query Result Formatted",
        "Generated SQL Query Result Formatted",
        "match_score_queries",
        "match_score",
    ],
]

Now calculate the mean score across all questions

In [None]:
print(
    "The average match score of all generated queries is: ",
    round(eval_df["match_score"].mean(), 2),
    "%",
)
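
Fuzzy matching on CSV strings is affected by row order and column aliases, neither of which changes the answer. As a complementary and stricter check, the sketch below compares two result sets exactly while ignoring row order and header names, then reports accuracy as the share of exact matches. It assumes inputs are CSV strings like those produced by `format_query_result()`; the helper names and sample data are hypothetical.

```python
import csv
import io


def parse_rows(csv_text: str):
    """Parse CSV text into a sorted list of data rows, skipping the header row."""
    rows = list(csv.reader(io.StringIO(csv_text)))
    return sorted(tuple(row) for row in rows[1:])


def results_match(expected_csv: str, generated_csv: str) -> bool:
    """Return True when both results contain the same rows, in any order."""
    return parse_rows(expected_csv) == parse_rows(generated_csv)


def accuracy(pairs) -> float:
    """Fraction of (expected, generated) pairs whose results match exactly."""
    if not pairs:
        return 0.0
    return sum(results_match(expected, generated) for expected, generated in pairs) / len(pairs)


# Hypothetical (ground truth, generated) result pairs
pairs = [
    ("day,counter\n2024-01-01,3\n2024-01-02,1\n", "day,total\n2024-01-02,1\n2024-01-01,3\n"),
    ("day,counter\n2024-01-01,3\n", "day,counter\n2024-01-01,5\n"),
]

print(f"Exact-match accuracy: {accuracy(pairs):.0%}")
```

Because the header row is skipped, a single-line value such as an error message parses to an empty result, so filter out failed queries before using a check like this. Column order still matters in this sketch.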

## [Optional] Save results in Google Sheets

You may want to save all generated queries, results and scores into a Google Sheet for visual inspection and for future reference like model accuracy tracking.

Skip this section if using **Vertex AI Workbench** or **Colab Enterprise** because notebooks in those environments cannot access Google Drive or Google Sheets for security purposes.

Create a new spreadsheet in Google Sheets (https://sheets.new) and copy over your unique spreadsheet ID into `QUERIES_SHEET_ID` parameter.  You can find your spreadsheet ID in the Google Sheets URL: docs.google.com/spreadsheets/d/*spreadsheetId*/edit#gid=0
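
If you prefer to paste the full URL rather than pick out the ID by hand, a small helper like the following sketch can extract it. It assumes the ID consists of letters, digits, hyphens, and underscores; the `extract_spreadsheet_id` name and the example URL are hypothetical.

```python
import re
from typing import Optional


def extract_spreadsheet_id(url: str) -> Optional[str]:
    """Return the spreadsheet ID found in a Google Sheets URL, or None if absent."""
    match = re.search(r"/spreadsheets/d/([a-zA-Z0-9_-]+)", url)
    return match.group(1) if match else None


# Placeholder URL, not a real spreadsheet
example_url = "https://docs.google.com/spreadsheets/d/EXAMPLE_SHEET_ID_123/edit#gid=0"
print(extract_spreadsheet_id(example_url))
```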


In [None]:
QUERIES_SHEET_ID = ""  # @param {type:"string"}
QUERIES_WORKSHEET_NAME = "Evaluation Dataset"  # @param {type:"string"}

In [None]:
from google.auth import default
import gspread

# Authenticate with Google Sheets
creds, _ = default()
gc = gspread.authorize(creds)

wks_results = gc.open_by_key(QUERIES_SHEET_ID).worksheet(QUERIES_WORKSHEET_NAME)

# Drop Query Result column which may contain non-serializable objects
eval_df_copy = eval_df.drop(columns=["SQL Query Result", "Generated SQL Query Result"])


def limit_cell_length(cell) -> str:
    if len(cell) >= 50000:
        return cell[:49990] + "..."
    return cell


eval_df_copy["Generated SQL Query Result Formatted"] = eval_df_copy[
    "Generated SQL Query Result Formatted"
].apply(limit_cell_length)
eval_df_copy["SQL Query Result Formatted"] = eval_df_copy[
    "SQL Query Result Formatted"
].apply(limit_cell_length)

wks_results.update(
    [eval_df_copy.columns.values.tolist()] + eval_df_copy.values.tolist()
)

## Summary

In this notebook, you were able to:
- Prepare datasets including writing logs into a summary table for easier and faster log analysis.
- Build a prompt template using an existing dataset of Text:SQL pairs as few-shot examples combined with a dynamically retrieved context, that is the database schema, for grounding the model.
- Convert NL question into SQL query using the language model and a few-shot prompt.
- Sanitize and validate model output.
- Evaluate model output by executing generated queries on BigQuery.
- Run model on an entire evaluation dataset.
- Calculate match score based on model output, that is the generated queries compared to the ground truth queries.
- Calculate match score based on actual results of generated queries, compared to the results of the ground truth queries.
- Save results for future reference or for tracking model accuracy.

## Cleanup

To clean up all Google Cloud resources used in this notebook, you can delete the [Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial, namely the BigQuery dataset `BQ_PROCESSED_DATASET` with the processed data: 

In [None]:
# Delete the created BigQuery dataset
!bq rm -r -f {BQ_PROJECT_ID}:{BQ_PROCESSED_DATASET}