In [None]:
"""
Copyright 2024 Amazon.com, Inc. and its affiliates. All Rights Reserved.

Licensed under the Amazon Software License (the "License").
You may not use this file except in compliance with the License.
A copy of the License is located at

  https://aws.amazon.com/asl/

or in the "license" file accompanying this file. This file is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License.
""";

# Text To SQL Example


# 1.0 Project Setup and Sample Data

This section sets up a sample SQLite database and populates it with employee data. We will use this data as a backend for one of the agent's tools. In the Generative AI Toolkit terminology:

- **Tools**: External functions or APIs the agent can use to retrieve information.
- **Traces**: Records of interactions. Later, we will record the steps the agent takes when handling queries.

We start by creating a `test_db.db` database and inserting sample employee records.


In [None]:
import sqlite3

# Connect to the test database (or create it if it doesn't exist)
conn = sqlite3.connect("test_db.db")
cursor = conn.cursor()

# Rows used to seed the `employees` table: (id, name, department, salary)
sample_data = [
    (1, "John Doe", "Sales", 50000),
    (2, "Jane Smith", "Engineering", 75000),
    (3, "Mike Johnson", "Sales", 60000),
    (4, "Emily Brown", "Engineering", 80000),
    (5, "David Lee", "Marketing", 55000),
]

try:
    # Create a sample table; IF NOT EXISTS keeps this cell safe to re-run
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS employees (
            id INTEGER PRIMARY KEY,
            name TEXT,
            department TEXT,
            salary INTEGER
        )
        """
    )

    # Insert sample data
    try:
        cursor.executemany("INSERT INTO employees VALUES (?, ?, ?, ?)", sample_data)
    except sqlite3.IntegrityError:
        # Records already present in DB (primary key clash on re-run).
        # Roll back so the implicitly opened transaction does not stay pending.
        conn.rollback()
    else:
        conn.commit()
finally:
    # Always release the connection: the agent tool opens its own connection
    # per query, so nothing later in the notebook needs this one.
    conn.close()

# 2.0 Environment and Library Configuration

Here we configure our environment and import the necessary components from the Generative AI Toolkit and supporting libraries. Key concepts:

- **Agents**: The main interface to an LLM-based application.
- **Conversation History**: The agent can maintain context across multiple turns.
- **Metrics, Traces, Cases**: We will set these up for evaluating and testing the agent’s performance later on.


In [None]:
import textwrap

from generative_ai_toolkit.evaluate.interactive import GenerativeAIToolkit, Permute
from generative_ai_toolkit.agent import BedrockConverseAgent
from generative_ai_toolkit.test import Case

# 3.0 Define a Tool for the Agent

We create a tool that allows the agent to run SQL queries against our local database. The agent can invoke this tool when it needs information about employees. By defining this tool, we give the agent the capability to answer user queries that require database lookups.


In [None]:
import sqlite3
from typing import Any, Dict


def execute_query(sql_query: str) -> Dict[str, Any]:
    """
    Executes a provided SQL query on the existing database schema.

    Parameters
    ----------
    sql_query : str
        The SQL query string that will be executed on the database.
        The query should follow standard SQL syntax and can be a SELECT, INSERT, UPDATE, or DELETE query.

    Returns
    -------
    dict
        A dictionary containing the results of the executed query.
        The 'columns' key contains a list of column names.
        The 'rows' key contains the data returned from the query, where each row is represented as a list.
        For statements that produce no result set (e.g. INSERT, UPDATE, DELETE) both lists are empty.

    Raises
    ------
    sqlite3.Error
        If the query cannot be executed. Any pending changes are rolled back.
    """
    conn = sqlite3.connect("test_db.db")
    try:
        # `with conn` commits when the query succeeds and rolls back when it
        # raises, so data-modifying statements are actually persisted.
        with conn:
            cursor = conn.execute(sql_query)
            rows = cursor.fetchall()
            # `description` is None for statements that return no result set
            description = cursor.description
    finally:
        # The context manager above only ends the transaction; the connection
        # itself must be closed explicitly to avoid leaking it on every call.
        conn.close()

    columns = [column[0] for column in description] if description else []

    # Convert each row from a tuple to a list
    results = [list(row) for row in rows]

    return {"columns": columns, "rows": results}


# Tools made available to the agent (passed as the "tools" agent parameter later on)
tools = [execute_query]

# 4.0 System Prompt for the Agent

The system prompt guides the agent. It describes the database schema and instructs the agent to generate SQL queries and then use the `execute_query` tool to retrieve results. The prompt also includes an example to help the agent understand how to behave.


In [None]:
###
# System prompt
###

# The literal starts with a line continuation so that the first line carries the
# same indentation as the rest; otherwise textwrap.dedent() finds no common
# margin and leaves every following line indented by four spaces.
system_prompt = textwrap.dedent(
    """\
    Here is the schema for a database: \n
    TABLE EMPLOYEES (
    id INTEGER,
    name TEXT,
    department TEXT,
    salary INTEGER
    );

    \n\n

    Given this schema, you can use the provided tools to generate and execute SQL queries on the database.
    Please output the SQL query first, and then use the 'execute_query' tool to run the query. The query result 
    should be formatted appropriately based on the output.
    In natural language provide the results for the user. \n\n

    Example:\n
    User Query: List all employees in the Engineering department\n
    SQL Query: SELECT * FROM EMPLOYEES WHERE DEPARTMENT = 'Engineering';
    """
).strip()

# 5.0 Define Test Cases

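**Cases**: Cases are repeatable tests that simulate user queries. In this example, each test definition includes:

- User inputs (prompts), which are what the `Case` object itself is built from
- Expected SQL queries
- Expected responses

The expected SQL queries and responses are stored next to the user inputs in `valid_queries_responses` and are handed to `SqlMetric` in the next section. Together, these allow us to verify that the agent produces correct queries and results consistently.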


In [None]:
# Ground truth for the "Marketing" question: the prompt, a reference query and
# the result that query returns on the sample data.
marketing_case = {
    "user_input": "What are the names and salaries of employees in the Marketing department?",
    "sql_query": "SELECT name, salary FROM EMPLOYEES WHERE department = 'Marketing'",
    "expected_response": {
        "columns": ["name", "salary"],
        "rows": [["David Lee", 55000]],
    },
}

# Ground truth for the "Engineering" question.
engineering_case = {
    "user_input": "List all employees in the Engineering department",
    "sql_query": "SELECT * FROM EMPLOYEES WHERE department = 'Engineering'",
    "expected_response": {
        "columns": ["id", "name", "department", "salary"],
        "rows": [
            [2, "Jane Smith", "Engineering", 75000],
            [4, "Emily Brown", "Engineering", 80000],
        ],
    },
}

# Case name -> prompt, reference SQL and expected result
valid_queries_responses = {
    "sql_case_1": marketing_case,
    "sql_case_2": engineering_case,
}


def generate_sql_cases(valid_queries: dict) -> list[Case]:
    """
    Build one Case per entry of the given mapping.

    Only the "user_input" of each entry is used: it becomes the single user
    input of the resulting Case, and the entry's key becomes the Case name.

    Args:
        valid_queries (dict): A dictionary mapping case names to dictionaries that
            contain at least a "user_input" key.

    Returns:
        list[Case]: The generated Case objects, in the iteration order of `valid_queries`.
    """
    return [
        Case(name=name, user_inputs=[details["user_input"]])
        for name, details in valid_queries.items()
    ]


# Turn every entry of the ground-truth mapping into a runnable Case
cases = generate_sql_cases(valid_queries=valid_queries_responses)

# 6.0 Define and Configure Metrics

**Metrics**: We use metrics to evaluate the agent’s performance. For this example:

- **SqlMetric** will verify that the agent produces correct SQL queries and returns the expected results.
- **CostMetric** will measure the LLM invocation cost.

By measuring performance across multiple runs and configurations, we ensure that our agent meets desired quality and cost criteria.


In [None]:
from generative_ai_toolkit.metrics.modules.cost import CostMetric
from generative_ai_toolkit.metrics.modules.sql import SqlMetric

# Pricing table handed to CostMetric, keyed by Bedrock model ID.
# NOTE(review): the costs are presumably USD per `per_token` tokens (i.e. per
# 1,000 tokens here) — confirm against CostMetric. The values follow the
# published on-demand prices for these models; check the current Amazon Bedrock
# pricing for your region before relying on the absolute numbers.
pricing_config = {
    "anthropic.claude-3-sonnet-20240229-v1:0": {
        "input_cost": 0.003,
        "output_cost": 0.015,
        "per_token": 1000,
    },
    "anthropic.claude-3-haiku-20240307-v1:0": {
        # Claude 3 Haiku is priced well below Sonnet; the previous values
        # (0.002 / 0.01) overstated its cost roughly eightfold.
        "input_cost": 0.00025,
        "output_cost": 0.00125,
        "per_token": 1000,
    },
}
cost_metric = CostMetric(pricing_config)

# SqlMetric receives the ground-truth mapping (reference SQL and expected results)
sql_metric = SqlMetric(valid_queries_responses)
metrics = [sql_metric, cost_metric]

# 7.0 Generate Traces for Evaluation

**Traces**: We run the defined cases against our agent to produce traces. Traces record everything:

- The agent’s requests to the LLM
- The tool invocations and their results
- The final agent responses

We use `GenerativeAIToolkit.generate_traces()` to run each case multiple times and with different model parameters, producing a rich dataset for evaluation.


In [None]:
# System prompts to compare; add further prompt strings to this list to
# evaluate alternative prompts side by side.
prompt_variants = [system_prompt]

# Bedrock model IDs to compare.
model_ids = [
    "anthropic.claude-3-sonnet-20240229-v1:0",
    "anthropic.claude-3-haiku-20240307-v1:0",
]

# Parameters passed to the agent factory. Values wrapped in Permute are the
# ones varied between runs — presumably one configuration per listed value
# (confirm against the Permute documentation).
agent_parameters = {
    "system_prompt": Permute(prompt_variants),
    "temperature": 0.9,
    "tools": tools,
    "model_id": Permute(model_ids),
}

traces = GenerativeAIToolkit.generate_traces(
    agent_factory=BedrockConverseAgent,
    agent_parameters=agent_parameters,
    cases=cases,
    nr_runs_per_case=3,
)

# 8.0 Evaluate the Model Using the Metrics

Now that we have traces, we run the evaluation with `GenerativeAIToolkit.eval()`:

- **SqlMetric** checks the correctness of SQL queries and responses.
- **CostMetric** estimates the cost of these LLM calls.

This gives us a quantitative assessment of correctness and cost.


In [None]:
# Evaluate the traces with the configured metrics (SQL correctness and cost)
results = GenerativeAIToolkit.eval(metrics=metrics, traces=traces)

Calling `summary()` on the results, or otherwise consuming the `results` iterator, will start the actual work:

- Traces will be generated
- Metrics will be calculated based on these traces

`summary()` will return a nice table with averages. All measurements are available with full details in the `results` object.


In [None]:
# Per the notes accompanying this cell, consuming `results` is what starts trace
# generation and metric calculation. The call is the cell's last expression so
# that the returned summary table is rendered by the notebook.
results.summary()

# 9.0 Start the Web UI

Start the local Web UI for conversation debugging. By starting the UI, we can visually inspect traces, debug interactions, and analyze metrics more interactively.


In [None]:
# Start the User Interface on localhost port 7860
# NOTE(review): launch() presumably returns without blocking, since the next
# cell waits for user input before calling results.ui.close() — confirm.
results.ui.launch()

In [None]:
from IPython.display import display, HTML

display(HTML("<b style='color: red;'>⚠️ Press enter to stop the UI</b>"))
try:
    # Block this cell until the user confirms; the entered text itself is not used
    response = input("⚠️ Press enter to stop the UI")
finally:
    # Stop the UI even when the prompt is interrupted (e.g. kernel interrupt) or
    # stdin is unavailable, so the UI is not left running in the background.
    results.ui.close()