# Evaluation of Single Agent with Langchain Toolkit 

In [1]:
import os
import sys

sys.path.append("/Users/I549593/Documents/Coding/MMDS/LLMsAgents-TextToSQL/src")
sys.path.append("/Users/I549593/Documents/Coding/MMDS/LLMsAgents-TextToSQL/src/agents")
sys.path.append("/Users/I549593/Documents/Coding/MMDS/LLMsAgents-TextToSQL/src/tools")

from dotenv import load_dotenv
from langchain_agent_test import SQLAgent
from database import Database

load_dotenv(override=True)


True

In [None]:
db = Database(name="california_schools")

agent = SQLAgent(db)

sql_answer = agent.generate_query("Please list the zip code of all the charter schools in Fresno County Office of Education.")
print(sql_answer)


SELECT Zip FROM schools WHERE County = 'Fresno' AND Charter = 1


In [3]:
db.execute_query(sql_answer)

[('93726-5309',),
 ('93662',),
 ('93628-9602',),
 ('93706-2611',),
 ('93726-5208',),
 ('93706-2819',),
 ('93609-9710',),
 ('93611-0581',),
 ('93612',),
 ('93705-2611',),
 ('93705-1611',),
 ('93728-3714',),
 ('93727-1510',),
 ('93740',),
 ('93721',),
 ('93704-5240',),
 ('93721-1611',),
 ('93704-4459',),
 ('93721-1104',),
 ('93706-3117',),
 ('93726-6906',),
 ('93726',),
 ('93726-5712',),
 (None,),
 ('93706-3719',),
 ('93701',),
 ('93726',),
 ('93706',),
 ('94111',),
 ('95376',),
 ('93631-1000',),
 ('93631-1701',),
 ('93631',),
 ('93631-2044',),
 ('93631-2100',),
 ('93631-1826',),
 ('93631-1000',),
 ('93631-1701',),
 (None,),
 ('93621',),
 ('93654-2017',),
 ('93706-5615',),
 ('93065-1800',),
 ('91361-6004',),
 ('93726-5309',),
 ('93727-1611',),
 ('93648-2034',),
 ('93727-4641',),
 ('93657-2780',),
 ('93611-4646',),
 ('93657-2711',),
 ('95223',),
 ('92201',),
 ('93706-9042',),
 ('93726-5318',),
 ('93721-1115',),
 ('93624-0398',),
 ('93740-8010',),
 ('93706',),
 ('93706',)]

In [4]:
original_query = "SELECT T2.Zip FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode WHERE T1.`District Name` = 'Fresno County Office of Education' AND T1.`Charter School (Y/N)` = 1"
db.execute_query(original_query)

[('93726-5309',),
 ('93628-9602',),
 ('93706-2611',),
 ('93726-5208',),
 ('93706-2819',)]

In [5]:
import pandas as pd

sample = pd.read_csv("../sample/dev.csv")

sample.head()

Unnamed: 0,question_id,db_id,question,evidence,SQL,difficulty
0,479,card_games,Among the cards with converted mana cost highe...,card set Coldsnap refers to name = 'Coldsnap';...,SELECT SUM(CASE WHEN T1.power LIKE '*' OR T1.p...,moderate
1,1057,european_football_2,Calculate the average home team goal in the 20...,average home team goal = AVG(home_team_goal)= ...,SELECT CAST(SUM(t2.home_team_goal) AS REAL) / ...,moderate
2,1367,student_club,Which college do most of the members go to?,college most members go refers to MAX(COUNT(ma...,SELECT T2.college FROM member AS T1 INNER JOIN...,simple
3,298,toxicology,Calculate the percentage of molecules containi...,hydrogen refers to element = 'h'; label = '+' ...,SELECT CAST(COUNT(CASE WHEN T1.element = 'h' A...,moderate
4,651,codebase_community,"Provide the related post title of ""How to tell...",,SELECT T3.Title FROM postLinks AS T1 INNER JOI...,simple


In [None]:
import re

def validate_query(query: str) -> bool:
    """
    Validate the SQL query for potential issues.
    Returns True if valid, False otherwise.
    """
    # Example regex for basic validation (add more rules if needed)
    invalid_tokens = [r"\\", r";--", r"--", r"\/\*", r"\*\/"]
    for token in invalid_tokens:
        if re.search(token, query):
            return False
    return True

In [None]:
import pandas as pd
from tqdm.notebook import tqdm
import os

model_name = "gpt-4o"
model_type = "single_agent"

# Create directory for saving evaluation results
os.makedirs(f"./runs/{model_name}", exist_ok=True)

# Initialize evaluation results
try:
    # If there is an ongoing run, load it to resume
    ongoing_run = pd.read_csv(f"./runs/{model_name}/{model_type}.csv").to_dict(orient="records")
    evaluation = ongoing_run
except FileNotFoundError:
    evaluation = []

# Initialize counters
total = len(sample)
correct = 0

for row in tqdm(
    sample.itertuples(index=False),
    total=total,
    desc=f"Evaluating {model_name}:",
):
    question = row.question
    ground_truth_query = row.SQL
    db_id = row.db_id
    difficulty = row.difficulty
    question_id = row.question_id

    print(f"Evaluating Question ID {question_id}...")

    # Initialize database and agent
    db = Database(name=db_id)
    agent = SQLAgent(db)

    # Generate query using the agent
    generated_query = agent.generate_query(question)

    if generated_query is None:
        print(f"Failed to generate query for Question: {question}")
        evaluation.append(
            {
                "question_id": question_id,
                "is_correct": False,
                "difficulty": difficulty,
                "ground_truth_results": None,
                "generated_results": None,
                "feedback": "Query generation failed",
            }
        )
        continue

    if not validate_query(generated_query):
        print(f"Invalid query detected: {generated_query}")
        evaluation.append(
            {
                "question_id": question_id,
                "is_correct": False,
                "difficulty": difficulty,
                "ground_truth_results": None,
                "generated_results": None,
                "feedback": "Invalid query generated",
            }
        )
        continue

    try:
        # Execute ground truth and generated queries
        ground_truth_results, execution_time = db.execute_query(ground_truth_query)
        generated_results, execution_time = db.execute_query(generated_query)

        # Compare the results
        is_correct = ground_truth_results == generated_results
        feedback = "Correct" if is_correct else "Results mismatch"

        if is_correct:
            correct += 1
            print(f"CORRECT for Question ID {question_id}")
        else:
            print(f"Mismatch for Question ID {question_id}")
            print(f"Ground Truth Results: {ground_truth_results}")
            print(f"Generated Results: {generated_results}")

    except Exception as e:
        print(f"Error executing queries for Question ID {question_id}: {e}")
        ground_truth_results = None
        generated_results = None
        is_correct = False
        feedback = f"Execution error: {e}"

    # Log the results for this question
    evaluation.append(
        {
            "question_id": question_id,
            "is_correct": is_correct,
            "difficulty": difficulty,
            "ground_truth_results": ground_truth_results,
            "generated_results": generated_results,
            "feedback": feedback,
        }
    )

    # Save results incrementally
    df = pd.DataFrame(evaluation)
    df.to_csv(f"./runs/{model_name}/{model_type}.csv", index=False)

# Calculate and log execution accuracy
execution_accuracy = (correct / total) * 100
print(f"Execution Accuracy (EX): {execution_accuracy:.2f}%")

# Final save of evaluation results
df = pd.DataFrame(evaluation)
df.to_csv(f"./runs/{model_name}/{model_type}.csv", index=False)


Evaluating Question ID 479...
Mismatch for Question: Among the cards with converted mana cost higher than 5 in the set Coldsnap, how many of them have unknown power?
Ground Truth Results: [(6,)]
Generated Results: [(0,)]
Evaluating Question ID 1057...
Mismatch for Question: Calculate the average home team goal in the 2010/2011 season in the country of Poland.
Ground Truth Results: [(1.5041666666666667,)]
Generated Results: [(1.5916666666666666,)]
Evaluating Question ID 1367...
Mismatch for Question: Which college do most of the members go to?
Ground Truth Results: [('College of Education & Human Services',)]
Generated Results: [('College of Agriculture and Applied Sciences', 36)]
Evaluating Question ID 298...
Mismatch for Question: Calculate the percentage of molecules containing carcinogenic compounds that element is hydrogen.
Ground Truth Results: [(17.22094171880145,)]
Generated Results: [(144,)]
Evaluating Question ID 651...
Error generating response: Recursion limit of 25 reached 

OperationalError: unrecognized token: "\"

In [8]:
execution_accuracy = (correct / total) * 100
print(f"Execution Accuracy (EX): {execution_accuracy:.2f}%")

Execution Accuracy (EX): 17.00%
