# Evaluate agent for text2sql task

In this lab, you will build a text2sql agent using Strands Agents SDK. Then you can evaluate this agent with BIRD dataset against different models or prompts.

# Build agents with Strands SDK

In [None]:
!uv pip install strands-agents

## Define tools

In [None]:
import time
import boto3
from strands import tool

athena_client = boto3.client('athena'， region_name='us-west-2')
    
@tool
def get_schema(database_name="california_schools"):
    """
    Get schema information for all tables in Athena databases
    """

    sql = f"""
        SELECT
            table_name,
            column_name,
            data_type
        FROM information_schema.columns
        WHERE table_schema = '{database_name}'
        ORDER BY table_name, ordinal_position;
        """
        
    try:
        # Start query execution
        response = athena_client.start_query_execution(
            QueryString=sql,
            QueryExecutionContext={
                'Database': database_name
            }
        )
            
        query_execution_id = response['QueryExecutionId']
            
        def wait_for_query_completion(query_execution_id):
            while True:
                response = athena_client.get_query_execution(
                    QueryExecutionId=query_execution_id
                )
                state = response['QueryExecution']['Status']['State']
                
                if state in ['SUCCEEDED', 'FAILED', 'CANCELLED']:
                    print(f"Query {state}")
                    return state
                    
                print("Waiting for query to complete...")
                time.sleep(2)
            
        # Wait for query completion
        state = wait_for_query_completion(query_execution_id)

        if state == 'SUCCEEDED':
            # Get query results
            results = athena_client.get_query_results(
                QueryExecutionId=query_execution_id
            )
            print("Got query results for schema")
            # Assuming you have a database connection and cursor setup
            # cursor.execute(sql)
            # results = cursor.fetchall()
            
            database_structure = []
            table_dict = {}

            # Skip the header row
            rows = results['ResultSet']['Rows'][1:]

            for row in rows:
                # Extract values from the Data structure
                table_name = row['Data'][0]['VarCharValue']
                column_name = row['Data'][1]['VarCharValue']
                data_type = row['Data'][2]['VarCharValue']
                
                # Initialize table if not exists
                if table_name not in table_dict:
                    table_dict[table_name] = []
                
                # Append column information
                table_dict[table_name].append((column_name, data_type))

            # Convert to the desired format
            for table_name, columns in table_dict.items():
                database_structure.append({
                    "table_name": table_name,
                    "columns": columns
                })

            return database_structure

        else:
            raise Exception(f"Query failed with state: {state}")
    except Exception as e:
            print(f"Error getting schema: {e}")
            raise

@tool
def query_athena(query, database_name='california_schools'):
    """
    Execute a query on Athena
    """
    try:
        # Start query execution
        response = athena_client.start_query_execution(
            QueryString=query,
            QueryExecutionContext={
                'Database': database_name
            }
        )
        
        query_execution_id = response['QueryExecutionId']
        
        def wait_for_query_completion(query_execution_id):
            while True:
                response = athena_client.get_query_execution(
                    QueryExecutionId=query_execution_id
                )
                state = response['QueryExecution']['Status']['State']
                
                if state == 'FAILED':
                    error_message = response['QueryExecution']['Status'].get('StateChangeReason', 'Unknown error')
                    raise Exception(f"Query failed: {error_message}")
                    
                if state == 'CANCELLED':
                    raise Exception("Query was cancelled")
                    
                if state == 'SUCCEEDED':
                    return state
                    
                print("Waiting for query to complete...")
                time.sleep(2)
        
        # Wait for query completion
        state = wait_for_query_completion(query_execution_id)
        print("query complete")
        # Get query results
        print(state)

        if state == 'SUCCEEDED':
            results = athena_client.get_query_results(
                QueryExecutionId=query_execution_id
            )
            print("got query results")
            print(results)
            # Process results
            processed_results = []
            headers = []
            
            # Get headers from first row
            if results['ResultSet']['Rows']:
                headers = [field.get('VarCharValue', '') for field in results['ResultSet']['Rows'][0]['Data']]
            
            # Process data rows
            for row in results['ResultSet']['Rows'][1:]:
                values = [field.get('VarCharValue', '') for field in row['Data']]
                row_dict = dict(zip(headers, values))
                processed_results.append(row_dict)
            
            print(processed_results)
            return processed_results

        else:
            raise Exception(f"Query failed with state: {state}")
        
    except Exception as e:
        print(f"Error executing query: {e}")
        raise

In [2]:
get_schema()

Waiting for query to complete...
Query SUCCEEDED
Got query results for schema


[{'table_name': 'frpm',
  'columns': [('cdscode', 'varchar'),
   ('academic year', 'varchar'),
   ('county code', 'varchar'),
   ('district code', 'integer'),
   ('school code', 'varchar'),
   ('county name', 'varchar'),
   ('district name', 'varchar'),
   ('school name', 'varchar'),
   ('district type', 'varchar'),
   ('school type', 'varchar'),
   ('educational option type', 'varchar'),
   ('nslp provision status', 'varchar'),
   ('charter school (y/n)', 'double'),
   ('charter school number', 'varchar'),
   ('charter funding type', 'varchar'),
   ('irc', 'double'),
   ('low grade', 'varchar'),
   ('high grade', 'varchar'),
   ('enrollment (k-12)', 'double'),
   ('free meal count (k-12)', 'double'),
   ('percent (%) eligible free (k-12)', 'double'),
   ('frpm count (k-12)', 'double'),
   ('percent (%) eligible frpm (k-12)', 'double'),
   ('enrollment (ages 5-17)', 'double'),
   ('free meal count (ages 5-17)', 'double'),
   ('percent (%) eligible free (ages 5-17)', 'double'),
   ('frp

In [3]:
query_athena("select * from satscores limit 10")

Waiting for query to complete...
query complete
SUCCEEDED
got query results
{'UpdateCount': 0, 'ResultSet': {'Rows': [{'Data': [{'VarCharValue': 'cds'}, {'VarCharValue': 'rtype'}, {'VarCharValue': 'sname'}, {'VarCharValue': 'dname'}, {'VarCharValue': 'cname'}, {'VarCharValue': 'enroll12'}, {'VarCharValue': 'numtsttakr'}, {'VarCharValue': 'avgscrread'}, {'VarCharValue': 'avgscrmath'}, {'VarCharValue': 'avgscrwrite'}, {'VarCharValue': 'numge1500'}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}]}, {'Data': [{}, {}, {}, {}, {}, {}, {}, {}, {

[{'cds': '',
  'rtype': '',
  'sname': '',
  'dname': '',
  'cname': '',
  'enroll12': '',
  'numtsttakr': '',
  'avgscrread': '',
  'avgscrmath': '',
  'avgscrwrite': '',
  'numge1500': ''},
 {'cds': '',
  'rtype': '',
  'sname': '',
  'dname': '',
  'cname': '',
  'enroll12': '',
  'numtsttakr': '',
  'avgscrread': '',
  'avgscrmath': '',
  'avgscrwrite': '',
  'numge1500': ''},
 {'cds': '',
  'rtype': '',
  'sname': '',
  'dname': '',
  'cname': '',
  'enroll12': '',
  'numtsttakr': '',
  'avgscrread': '',
  'avgscrmath': '',
  'avgscrwrite': '',
  'numge1500': ''},
 {'cds': '',
  'rtype': '',
  'sname': '',
  'dname': '',
  'cname': '',
  'enroll12': '',
  'numtsttakr': '',
  'avgscrread': '',
  'avgscrmath': '',
  'avgscrwrite': '',
  'numge1500': ''},
 {'cds': '',
  'rtype': '',
  'sname': '',
  'dname': '',
  'cname': '',
  'enroll12': '',
  'numtsttakr': '',
  'avgscrread': '',
  'avgscrmath': '',
  'avgscrwrite': '',
  'numge1500': ''},
 {'cds': '',
  'rtype': '',
  'sname': '

## Define and test agent

In [None]:
import os
from strands import Agent
from strands.models.litellm import LiteLLMModel

os.environ['OPENAI_BASE_URL'] = "http://localhost:4000"
os.environ["OPENAI_API_KEY"] = "sk-12341234"

# model_id = "bedrock/us.amazon.nova-lite-v1:0"
# model_id = "openai/deepseek-r1-ds"
model_id = "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0"
litellm_model = LiteLLMModel(
    model_id=model_id, 
    # params={"max_tokens": 1000, "temperature": 0.7}
)

system_prompt = """
You are an AI Agent specialized in generating SQL queries for Amazon Athena against Amazon S3 .parquet files. 
Your primary task is to interpret user queries, generate appropriate SQL queries, and provide the executed sql 
query as well as relevant answers based on the data. Follow these instructions carefully: 
1. Before generating any SQL query, use the /getschema tool to familiarize yourself with the data structure. 
a. Use the default database name 'california_schools' unless specified otherwise.
b. The schema will be provided in a structured format with table names and their respective columns.
c. Always refer to the schema for accurate column names and data types.

2. When generating an SQL query: 
a. Write the query as a single line, removing all newline characters. 
b. Column names must be exactly as they appear in the schema, including spaces. Do not replace spaces with underscores. 
c. Always enclose column names that contain spaces in double quotes ("). 
d. Be extra careful with column names containing special characters or spaces. 

3. Column name handling: 
a. Never modify column names. Use them exactly as they appear in the schema. 
b. If a column name contains spaces or special characters, always enclose it in double quotes ("). 
c. Do not use underscores in place of spaces in column names. 

4. Query output format: 
a. Always include the exact query that was run in your response. Start your response with 
"Executed SQL Query:" followed by the exact query that was run. 
b. Format the SQL query in a code block using three backticks (```). 
c. After the query, provide your explanation and analysis. 

5. When providing your response: 
a. Start with the executed SQL query as specified in step 4. 
b. Double-check that all column names in your generated query match the schema exactly. 
c. Ask for clarifications from the user if required. 

6. Error handling: 
a. If a query fails due to column name issues: - Review the schema and correct any mismatched column names. - 
Ensure all column names with spaces are enclosed in double quotes. - Regenerate the query with corrected column names. - 
Display both the failed query and the corrected query. 
b. Implement retry logic with up to 3 attempts for failed queries. 

Here are a few examples of generating SQL queries based on a question: 
Question: What is the highest eligible free rate for K-12 students in the schools in Alameda County? 
Executed SQL Query: "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' 
ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1" 

Question: Please list the zip code of all the charter schools in Fresno County Office of Education. 
Executed SQL Query: "SELECT T2.Zip FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T1.`District Name` = 'Fresno County Office of Education' 
AND T1.`Charter School (Y/N)` = 1" 

Question: Consider the average difference between K-12 enrollment and 15-17 enrollment 
of schools that are locally funded, list the names and DOC type of schools which has a difference above this average. 
Executed SQL Query: "SELECT T2.School, T2.DOC FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode 
WHERE T2.FundingType = 'Locally funded' AND (T1.`Enrollment (K-12)` - T1.`Enrollment (Ages 5-17)`) > 
(SELECT AVG(T3.`Enrollment (K-12)` - T3.`Enrollment (Ages 5-17)`) FROM frpm AS T3 INNER JOIN schools AS T4 ON T3.CDSCode = 
T4.CDSCode WHERE T4.FundingType = 'Locally funded')"
"""

agent = Agent(
    model=litellm_model,
    system_prompt=system_prompt,
    tools=[
        get_schema,
        query_athena
    ],
    trace_attributes={
        "session.id": "abc-1234",
        "user.id": "user-email-example@domain.com",
        "langfuse.tags": [
            "Agent-SDK",
            "text2sql",
        ]
    }
)

In [3]:
res = agent("Please list the zip code of all the charter schools in Fresno County Office of Education.")

I'll help you find the zip codes of all charter schools in Fresno County Office of Education. First, let me check the schema to understand the data structure.
Tool #1: get_schema
Waiting for query to complete...
Query SUCCEEDED
Got query results for schema
Now that I have the schema, I'll create a query to find the zip codes of all charter schools in Fresno County Office of Education. Based on the schema:
- The `frpm` table contains the "Charter School (Y/N)" column and "District Name" information
- The `schools` table contains the "Zip" information
- I need to join these tables using the CDSCode column

Let me execute the query:
Tool #2: query_athena
Waiting for query to complete...
query complete
SUCCEEDED
got query results
{'UpdateCount': 0, 'ResultSet': {'Rows': [{'Data': [{'VarCharValue': 'Zip'}]}, {'Data': [{'VarCharValue': '93726-5309'}]}, {'Data': [{}]}, {'Data': [{'VarCharValue': '93628-9602'}]}, {'Data': [{}]}, {'Data': [{'VarCharValue': '93706-2611'}]}, {'Data': [{}]}, {'Dat

In [6]:
res = agent("What is the unabbreviated mailing street address of the school with the highest FRPM count for K-12 students?")

I'll help you find the unabbreviated mailing street address of the school with the highest FRPM count for K-12 students. First, let me examine the relevant columns in the schema.
Tool #3: get_schema
Based on the schema, I need to:
1. Find the school with the highest "frpm count (k-12)" from the frpm table
2. Join with the schools table to get the unabbreviated mailing street address, which is in the "mailstreet" column

Let me construct and execute the query:
Tool #4: query_athena
Waiting for query to complete...
query complete
SUCCEEDED
got query results
{'UpdateCount': 0, 'ResultSet': {'Rows': [{'Data': [{'VarCharValue': 'mailstreet'}, {'VarCharValue': 'school name'}, {'VarCharValue': 'district name'}, {'VarCharValue': 'frpm count (k-12)'}]}, {'Data': [{'VarCharValue': '14429 South Downey Avenue'}, {'VarCharValue': 'Paramount High'}, {'VarCharValue': 'Paramount Unified'}, {'VarCharValue': '4419.0'}]}], 'ResultSetMetadata': {'ColumnInfo': [{'CatalogName': 'hive', 'SchemaName': '', 'Ta

In [7]:
res = agent("What is the number of SAT test takers of the schools with the highest FRPM count for K-12 students?")

I need to find the number of SAT test takers for the school with the highest FRPM count for K-12 students. Let me complete this in steps:

1. First, I'll identify the school with the highest FRPM count
2. Then, I'll look up that school's SAT test taker count from the satscores table
Tool #5: query_athena
Now that I've identified Paramount High as the school with the highest FRPM count (4,419), I'll look up its SAT test taker count:
Tool #6: query_athena
Waiting for query to complete...
query complete
SUCCEEDED
got query results
{'UpdateCount': 0, 'ResultSet': {'Rows': [{'Data': [{'VarCharValue': 'numtsttakr'}]}, {'Data': [{'VarCharValue': '459'}]}], 'ResultSetMetadata': {'ColumnInfo': [{'CatalogName': 'hive', 'SchemaName': '', 'TableName': '', 'Name': 'numtsttakr', 'Label': 'numtsttakr', 'Type': 'integer', 'Precision': 10, 'Scale': 0, 'Nullable': 'UNKNOWN', 'CaseSensitive': False}]}}, 'ResponseMetadata': {'RequestId': '984a1921-ad60-42a6-8be6-44c55882bd98', 'HTTPStatusCode': 200, 'HTTP

In [8]:
res = agent("How many schools with an average score in Math greater than 400 in the SAT test are exclusively virtual?")

I need to find how many schools with an average SAT Math score greater than 400 are exclusively virtual. Let me break down this query:

1. I need to identify schools with average SAT Math scores greater than 400 from the `satscores` table
2. I need to check which of those schools are virtual (exclusively) from the `schools` table
3. Then I'll count the number of matches

Let's first look at the specific column names in the schema:
Tool #7: get_schema
Now I'll create a query to find virtual schools with average SAT Math scores over 400. I can see in the schema:
- `avgscrmath` in the `satscores` table is the average Math score
- `virtual` in the `schools` table indicates if a school is virtual
- I need to join these tables based on the CDS code

Let me create and run the query:
Tool #8: query_athena
Waiting for query to complete...
query complete
SUCCEEDED
got query results
{'UpdateCount': 0, 'ResultSet': {'Rows': [{'Data': [{'VarCharValue': 'virtual_schools_count'}]}, {'Data': [{'VarCha

In [9]:
res.message['content'][0]['text']

"Executed SQL Query:\n```sql\nSELECT COUNT(*) as virtual_schools_count FROM satscores s JOIN schools sc ON s.cds = sc.cdscode WHERE s.avgscrmath > 400 AND sc.virtual = 'F'\n```\n\nBased on the results, there are **4 schools** that meet both criteria:\n1. They have an average SAT Math score greater than 400\n2. They are exclusively virtual (virtual = 'F', where 'F' indicates fully virtual)\n\nThese 4 schools represent the intersection of high-performing SAT math schools that operate in a completely virtual environment."

# Evaluate the agent

## Prepare the dataset

In [10]:
import json
with open("text2sql_data_file_auto.json", 'r') as f:
        data_dict = json.load(f)

In [11]:
data_dict

{'Trajectory1': [{'question_id': 0,
   'question': 'What is the highest eligible free rate for K-12 students in the schools in Alameda County?',
   'question_type': 'TEXT2SQL',
   'ground_truth': {'ground_truth_sql_query': "SELECT `Free Meal Count (K-12)` / `Enrollment (K-12)` FROM frpm WHERE `County Name` = 'Alameda' ORDER BY (CAST(`Free Meal Count (K-12)` AS REAL) / `Enrollment (K-12)`) DESC LIMIT 1",
    'ground_truth_sql_context': "[{'table_name': 'frpm', 'columns': [('cdscode', 'varchar'), ('academic year', 'varchar'), ('county code', 'varchar'), ('district code', 'integer'), ('school code', 'varchar'), ('county name', 'varchar'), ('district name', 'varchar'), ('school name', 'varchar'), ('district type', 'varchar'), ('school type', 'varchar'), ('educational option type', 'varchar'), ('nslp provision status', 'varchar'), ('charter school (y/n)', 'double'), ('charter school number', 'varchar'), ('charter funding type', 'varchar'), ('irc', 'double'), ('low grade', 'varchar'), ('high

In [12]:
def load_my_data(file_path):
    """
    Load data from a JSON file and return it as a dictionary.
    """
    result = []
    with open(file_path, 'r') as f:
        data_dict = json.load(f)
    
    for trajectoryID, questions in data_dict.items():
        # print(f"Trajectory ID: {trajectoryID}")
        for question in questions:
            # print(f"Question: {question['question']}")
            if 'ground_truth' in question:
                # print(f"Answer: {question['ground_truth']['ground_truth_answer']}")
                item = {
                    "trajectory_id": trajectoryID,
                    "question": question['question'],
                    "ground_truth": question['ground_truth']['ground_truth_answer']
                }
                result.append(item)
            else:
                print("No answer provided.")
            # print("-" * 40)  # Separator for readability
    
    return result

## Define the LLM judger

In [None]:
import json
import os
from litellm import completion

def model_call(model_id, prompt) -> str:
    response = completion(
        model=model_id,
        max_tokens=1024*8,
        messages=[{ "content": prompt,"role": "user"}]
    )
    return response["choices"][0]["message"]["content"]

class LLM_judger:
    def __init__(self):
        """Initialize the LLM_judger class."""
        pass
    
    def __call__(self, ground_truth: str, agent_output: str) -> dict:
        """Score the agent's output by comparing it with the ground truth."""
        query = f"""
        You are a LLM judger tasked with evaluating the correctness of an AI agent's answer to a question based on a ground truth answer.
        The AI agent's answer is provided below, along with the ground truth answer. Your task is to determine if the AI agent's answer is correct or not.
        AI agent's answer: {agent_output}
        Ground truth answer: {ground_truth}        
        
        Your task:
        1. State the model's predicted answer (answer only).
        2. State the ground truth (answer only).
        3. Determine if the model's final answer is correct (ignore formatting differences, etc.). RESPOND with the predicted and ground truth answer, followed with a JSON object containing the correctness encapsulated within the following delimiters:
           ```json
           {{ "correctness": true/false }}
           ```
        """
        
        # claude 3.7 sonnet as the default judge model
        # response = model_call("openai/claude-3.5-sonnet", query)
        response = model_call("bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", query)
        
        if response is None:
            return {"correctness": False, "reasoning": "Inference failed."}
        
        try:
            # Extract correctness JSON object from the response
            json_start = response.index("```json") + 7
            json_end = response.index("```", json_start)
            correctness = json.loads(response[json_start:json_end].strip()).get("correctness", False)
        except (ValueError, IndexError):
            correctness = False

        print("correctness: ", correctness)
        return {"correctness": correctness, "reasoning": response}


## Evaluate the agent

In [14]:
def evaluator(agent: Agent, dataset: list[dict], scorers: list[callable]) -> dict:
    """Evaluate the dataset using the provided scorers."""
    results = []
    for item in dataset:
        question = item["question"]
        grount_truth = item["ground_truth"]
        agent_output = agent(question)
        answer= agent_output.message['content'][0]['text']
        
        for scorer in scorers:
            score = scorer(grount_truth, answer)
            results.append({
                "trajectory_id": item["trajectory_id"],
                "question": question,
                "ground_truth": grount_truth,
                "answer": answer,
                "score": score
            })
        time.sleep(30)  # Sleep to avoid rate limiting issues with the LLM API
    
    return results

In [None]:
ds = load_my_data("text2sql_data_file_auto.json")

eval_results = evaluator(
    agent,
    ds,
    scorers=[LLM_judger()]
)

In [None]:
eval_results

In [None]:
# Calculate the percentage of 'correctness' being True in the evaluation results.
correct_count = sum(1 for item in eval_results if item['score'].get('correctness', False))
accuracy = correct_count / len(eval_results) if eval_results else 0
print(f"Accuracy: {accuracy:.2%} ({correct_count}/{len(eval_results)})")

# Summary

Agent's brain is LLM, we evaluate Strands agents with different LLMs. Here is the summary:

| Agent             | Text2SQL Accuracy    | Memo  |
| :----------------:| :------: | :------: |
| Claude 3.5 Haiku        |     |   |
| Claude 3.5 Sonnet       |   80%   |   |
| Claude 3.7 Sonnet       |     |   |
| Claude 3.7 Sonnet thinking |     |    |
| DeepSeek R1-0528  | 40% |   |
| DeepSeek R1       |  |  |
| Nova Lite         | 0% |   |
| Nova Pro          |   |   |
| Gemini Flash 2.5  | 40% | |