<center> <img src="https://storage.googleapis.com/arize-assets/arize-logo-white.jpg" width="300"/> </center>

<h1 align="center">Experiments: Text2SQL</h1>

---
Let's work through a Text2SQL use case where we are starting from scratch without a nice and clean dataset of questions, SQL queries, or expected responses.

ℹ️ This notebook requires:
- An OpenAI API key
- An Arize Space ID, API Key, and Developer Key (explained below)

In [None]:
!pip install -q 'arize[Datasets]' openai datasets pyarrow pydantic nest_asyncio 'arize-phoenix[evals]' duckdb

# Setup Config

Copy the Arize Developer Key, API Key, and Space ID from the Datasets page (shown below). The Space ID goes in the `space_id` variable in the setup cell, and you will be prompted for the keys.

<center><img src="https://storage.googleapis.com/arize-assets/fixtures/dataset_api_key.png" width="700"></center>


In [None]:
from uuid import uuid1

dataset_name = "docs-qa-new-" + str(uuid1())[:5]

In [None]:
import os

from arize.experimental.datasets import ArizeDatasetsClient
from arize.experimental.datasets.utils.constants import GENERATIVE

import pandas as pd

import asyncio

import json

Let's make sure we can run async code in the notebook.


In [None]:
import nest_asyncio

nest_asyncio.apply()
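Why is this needed? Jupyter already runs an asyncio event loop, and the experiment task defined later calls `asyncio.run`, which refuses to start inside a running loop. The standalone sketch below reproduces that error with the standard library only; run it as a plain Python script, since it behaves differently once `nest_asyncio.apply()` has patched `asyncio`. The names `inner` and `outer` are illustrative.

```python
import asyncio


async def inner():
    return 42


async def outer():
    # We are now inside a running event loop, similar to a Jupyter cell.
    coro = inner()
    try:
        return asyncio.run(coro)
    except RuntimeError as err:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return str(err)


message = asyncio.run(outer())
print(message)
```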

Lastly, let's set our Space ID and make sure the OpenAI and Arize keys are available as environment variables.


In [None]:
from getpass import getpass

space_id = "U3BhY2U6NjM3MjoyMXJG"  # Replace with your own Arize Space ID

if not os.getenv("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("🔑 Enter your OpenAI API key: ")

if not os.environ.get("ARIZE_API_KEY"):
    os.environ["ARIZE_API_KEY"] = getpass("🔑 Enter your ARIZE_API_KEY: ")

if not os.environ.get("ARIZE_DEVELOPER_KEY"):
    os.environ["ARIZE_DEVELOPER_KEY"] = getpass("🔑 Enter your developer key: ")

# Download Data

We are going to use an NBA dataset that contains information from 2014 to 2018. We will use DuckDB as our database.

In [None]:
import duckdb
from datasets import load_dataset

data = load_dataset("suzyanil/nba-data")["train"]

conn = duckdb.connect(database=":memory:", read_only=False)
conn.register("nba", data.to_pandas())

conn.query("SELECT * FROM nba LIMIT 5").to_df().to_dict(orient="records")[0]

## Implement Text2SQL

Let's start by implementing simple Text2SQL logic: give the model the table schema and ask it to translate a question into SQL.

In [None]:
import os

import openai

client = openai.AsyncClient()

columns = conn.query("DESCRIBE nba").to_df().to_dict(orient="records")

# We will use GPT-4o to start
TASK_MODEL = "gpt-4o"
CONFIG = {"model": TASK_MODEL}


system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n"
    f'{",".join(column["column_name"] + ": " + column["column_type"] for column in columns)}\n'
    "Write a SQL query corresponding to the user's request. Return just the query text, "
    "with no formatting (backticks, markdown, etc.)."
)


async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content
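The system prompt is just a string built from the `DESCRIBE nba` rows. To inspect it without calling OpenAI or DuckDB, here is a minimal sketch that assembles the same prompt and chat messages from hypothetical column metadata. The `example_columns` list is made up for illustration and is not the real NBA schema; `build_system_prompt` and `build_messages` are hypothetical helper names.

```python
# Hypothetical DESCRIBE output; the real nba table has different columns.
example_columns = [
    {"column_name": "team", "column_type": "VARCHAR"},
    {"column_name": "date", "column_type": "DATE"},
    {"column_name": "wins", "column_type": "BIGINT"},
]


def build_system_prompt(columns):
    # Same format as the notebook: "name: TYPE" pairs joined by commas.
    schema = ",".join(c["column_name"] + ": " + c["column_type"] for c in columns)
    return (
        "You are a SQL expert, and you are given a single table named nba with the following columns:\n"
        f"{schema}\n"
        "Write a SQL query corresponding to the user's request. Return just the query text, "
        "with no formatting (backticks, markdown, etc.)."
    )


def build_messages(question, columns):
    # The list passed as messages=... to client.chat.completions.create.
    return [
        {"role": "system", "content": build_system_prompt(columns)},
        {"role": "user", "content": question},
    ]


example_messages = build_messages("Who won the most games?", example_columns)
print(example_messages[0]["content"].splitlines()[1])
```

Printing the prompt this way makes it easy to spot schema details (such as column types) that the model may misread.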

In [None]:
query = await generate_query("Who won the most games?")
print(query)

Awesome, it looks like the LLM is producing SQL! Let's try running the query and see if we get the expected results.

In [None]:
def execute_query(query):
    return conn.query(query).fetchdf().to_dict(orient="records")


execute_query(query)
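Even with the "no formatting" instruction, models sometimes wrap the query in a Markdown code fence, which DuckDB then fails to parse. A small defensive helper like the hypothetical `strip_sql_fences` below (not part of any library) could be applied to the model output before `execute_query`. It is only a sketch and handles just the common case of a fence whose opening line ends with a newline.

```python
import re

# Matches a whole response wrapped in ```sql ... ``` (or a bare ``` ... ```).
_FENCE_RE = re.compile(r"^\s*```[\w-]*[ \t]*\n(.*?)\s*```\s*$", re.DOTALL)


def strip_sql_fences(text):
    """Return the SQL inside a Markdown code fence, or the trimmed text if unfenced."""
    match = _FENCE_RE.match(text)
    return (match.group(1) if match else text).strip()


print(strip_sql_fences("```sql\nSELECT 1;\n```"))
```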

## Evaluation

Evaluation consists of three parts — data, task, and scores. We'll start with data.

In [None]:
questions = [
    "Which team won the most games?",
    "Which team won the most games in 2015?",
    "Who led the league in 3 point shots?",
    "Which team had the biggest difference in records across two consecutive years?",
    "What is the average number of free throws per year?",
]

Let's store the data above as a versioned dataset in Arize.


In [None]:
arize_client = ArizeDatasetsClient(
    developer_key=os.environ.get("ARIZE_DEVELOPER_KEY"),
    api_key=os.environ.get("ARIZE_API_KEY"),
)
# Create a dataset from a DataFrame (add your own questions here)
test_df = pd.DataFrame([{"question": question} for question in questions])
dataset_id = arize_client.create_dataset(
    space_id=space_id,
    dataset_name=dataset_name,
    dataset_type=GENERATIVE,
    data=test_df,
)
dataset_name

Let's now pull down the dataset from Arize in this environment.

In [None]:
dataset = arize_client.get_dataset(space_id=space_id, dataset_id=dataset_id)
dataset.head()

Next, we'll define the task. The task is to generate SQL queries from natural language questions.

In [None]:
async def text2sql(question):
    query = await generate_query(question)
    results = None
    error = None
    try:
        results = execute_query(query)
    except duckdb.Error as e:
        error = str(e)

    r = {
        "query": query,
        "results": results,
        "error": error,
    }
    # default=str stringifies values json can't encode natively (e.g. dates)
    return json.dumps(r, default=str)
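One caveat when serializing the task output with `json.dumps`: if a query selects a date or timestamp column, the result rows can contain date-like objects (for example pandas `Timestamp` or `datetime.date` values), and `json.dumps` raises `TypeError` on them by default. A minimal standard-library sketch of the problem and one possible workaround, `default=str`, which stringifies any object `json` can't encode; the sample row is hypothetical.

```python
import datetime
import json

# A hypothetical result row, like one execute_query might return for a DATE column.
row = {"team": "Example", "game_date": datetime.date(2015, 4, 15)}
payload = {"query": "SELECT ...", "results": [row], "error": None}

try:
    json.dumps(payload)
except TypeError as err:
    print("Fails by default:", err)

# default=str is called for every object json can't encode natively.
serialized = json.dumps(payload, default=str)
print(serialized)
```

Stringifying is lossy (dates come back as strings after `json.loads`), which is usually fine for evaluators that only check for errors or non-empty results.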

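Finally, we'll define the scores. We'll use the following simple scoring functions to check whether the generated SQL queries execute without errors and return results. Note that neither checks whether the results are actually correct.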

In [None]:
# Test if there are no sql execution errors
def no_error(output):
    output = json.loads(output)
    return 1.0 if output.get("error") is None else 0.0


# Test if the query has results
def has_results(output):
    output = json.loads(output)
    results = output.get("results")
    has_results = results is not None and len(results) > 0
    return 1.0 if has_results else 0.0

# Run Experiment

Run the experiment and log the results to Arize.


In [None]:
# Define the task to run text2sql on the input question
def task(dataset_row):
    input = dataset_row
    return asyncio.run(text2sql(input["question"]))


experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results],
    experiment_name="text2sql_test-2",
)
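Conceptually, an experiment combines the three parts from earlier: it runs the task on every dataset row and applies each evaluator to each output. The sketch below mimics that loop locally with a fake task, which can be handy for debugging evaluators before logging to Arize. It is an illustration only, not how `run_experiment` is implemented; `fake_task`, `run_locally`, and the two `*_score` functions are hypothetical names.

```python
import json


# Hypothetical stand-in for text2sql: returns the same JSON shape without calling an LLM.
def fake_task(row):
    if "2015" in row["question"]:
        return json.dumps({"query": "SELECT bad", "results": None, "error": "Binder Error"})
    return json.dumps({"query": "SELECT 1", "results": [{"x": 1}], "error": None})


# Same logic as the no_error and has_results evaluators above.
def no_error_score(output):
    return 1.0 if json.loads(output).get("error") is None else 0.0


def has_results_score(output):
    results = json.loads(output).get("results")
    return 1.0 if results else 0.0


def run_locally(rows, task_fn, evaluators):
    """Run task_fn on each row, then score its output with every evaluator."""
    records = []
    for row in rows:
        output = task_fn(row)
        scores = {ev.__name__: ev(output) for ev in evaluators}
        records.append({"question": row["question"], **scores})
    return records


local_rows = [
    {"question": "Which team won the most games?"},
    {"question": "Which team won the most games in 2015?"},
]
for record in run_locally(local_rows, fake_task, [no_error_score, has_results_score]):
    print(record)
```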

Ok! It looks like 3 of our 5 queries execute without errors.

# Interpreting the Results

Now that we've run the initial evaluation, it looks like two of the queries return valid results, one runs but returns an incorrect result, and two produce SQL errors.

- The incorrect query doesn't seem to get the date format right. Showing the model a sample of the data (e.g., an example value for each column) would probably help.

- There is also a binder error (DuckDB's error for references it can't resolve, such as an unknown column or a function called with mismatched types), which may likewise come from not understanding the data format.

Let's try to improve the prompt by showing the model an example value for each column (a lightweight alternative to full few-shot examples) and see if we get better results.

In [None]:
samples = (
    conn.query("SELECT * FROM nba LIMIT 1").to_df().to_dict(orient="records")[0]
)
sample_rows = "\n".join(
    f"{column['column_name']} | {column['column_type']} | {samples[column['column_name']]}"
    for column in columns
)
system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n\n"
    "Column | Type | Example\n"
    "-------|------|--------\n"
    f"{sample_rows}\n"
    "\n"
    "Write a DuckDB SQL query corresponding to the user's request. "
    "Return just the query text, with no formatting (backticks, markdown, etc.)."
)


async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content


print(await generate_query("Which team won the most games in 2015?"))
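The prompt above adds an example value for each column. A classic few-shot prompt goes further and includes worked question/SQL pairs as earlier conversation turns. A minimal sketch of building such a message list; `build_few_shot_messages` is a hypothetical helper, and the example pair and its column names (`team`, `points`) are placeholders that would need to match the real `nba` schema before use.

```python
def build_few_shot_messages(system_prompt, examples, question):
    """Return chat messages: system prompt, then example Q/SQL turns, then the real question."""
    messages = [{"role": "system", "content": system_prompt}]
    for example_question, example_sql in examples:
        messages.append({"role": "user", "content": example_question})
        messages.append({"role": "assistant", "content": example_sql})
    messages.append({"role": "user", "content": question})
    return messages


# Hypothetical example pair; "team" and "points" are placeholder column names.
few_shot_examples = [
    (
        "Which team scored the most total points?",
        "SELECT team, SUM(points) AS total_points FROM nba "
        "GROUP BY team ORDER BY total_points DESC LIMIT 1",
    ),
]

few_shot_messages = build_few_shot_messages(
    "You are a SQL expert...", few_shot_examples, "Which team won the most games in 2015?"
)
print([m["role"] for m in few_shot_messages])
```

The resulting list could be passed as `messages=` to `client.chat.completions.create` in place of the two-message list used in `generate_query`.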

Looking much better! Finally, let's add an LLM-as-a-judge evaluator that asks a model whether each generated query is valid SQL.

In [None]:
from phoenix.evals.models import OpenAIModel
from phoenix.evals.classify import llm_classify
from arize.experimental.datasets.experiments.types import EvaluationResult


IS_SQL_EVAL_TEMPLATE = """You are a SQL expert. Is the following a valid SQL query that executes without errors? Return the single word "valid" if it is valid, and "invalid" if it is not.

[BEGIN SQL QUERY]
{query}
[END SQL QUERY]
"""


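def check_is_sql(output):
    output = json.loads(output)
    query = output.get("query")
    # llm_classify needs a DataFrame, so handle a missing query up front
    if not query:
        return EvaluationResult(
            score=0.0, label="invalid", explanation="No query was generated."
        )
    df_in = pd.DataFrame({"query": [query]})
    eval_df = llm_classify(
        dataframe=df_in,
        template=IS_SQL_EVAL_TEMPLATE,
        model=OpenAIModel(model="gpt-4o"),
        rails=["valid", "invalid"],
        provide_explanation=True,
    )
    label = eval_df["label"].iloc[0]
    # Return score, label, explanation; score is 1 only when the judge says "valid"
    return EvaluationResult(
        score=1.0 if label == "valid" else 0.0,
        label=label,
        explanation=eval_df["explanation"].iloc[0],
    )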


experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results, check_is_sql],
    experiment_name="text2sql_test_new_prompt_and_eval-6",
)
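Once you have expected results for some questions, a deterministic evaluator can compare rows directly instead of relying on an LLM judge. A sketch, assuming a hypothetical `expected_results` list of row dicts per question (this dataset doesn't include one yet); `results_match` is a hypothetical name, it ignores row order, and it compares values as strings. How to feed expected values into `run_experiment` depends on how your dataset stores them.

```python
import json


def results_match(output, expected_results):
    """Return 1.0 if the task's query results equal expected_results, ignoring row order."""
    results = json.loads(output).get("results")
    if results is None:
        return 0.0

    def normalize(rows):
        # Sort each row's items and then the rows themselves so order doesn't matter;
        # values are stringified so e.g. dates serialized as strings still compare equal.
        return sorted(tuple(sorted((k, str(v)) for k, v in row.items())) for row in rows)

    return 1.0 if normalize(results) == normalize(expected_results) else 0.0


sample_output = json.dumps(
    {"query": "...", "results": [{"team": "B", "wins": 60}, {"team": "A", "wins": 67}], "error": None}
)
print(results_match(sample_output, [{"team": "A", "wins": 67}, {"team": "B", "wins": 60}]))  # 1.0
```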

Amazing. It looks like we removed one of the errors and got a result for the previously incorrect query. Let's try using an LLM as a judge to see how well it can assess the results.