In [6]:
# !pip install -r requirements.txt

## 1. Database Schema and Configuration

In [251]:
import os
import google.generativeai as genai
from setup_db import execute_query
from config import GOOGLE_API_KEY, DATABASE_SCHEMA,COT_TEXT2SQL_EXAMPLE, LANGCHAIN_API_KEY


import json
import re
import time
import re
import pandas as pd
from tqdm import tqdm
import numpy as np

from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolExecutor
from langchain.tools import Tool
from langchain.schema import HumanMessage
from typing import TypedDict, Annotated, Sequence, Union
from typing import List, Tuple, Dict, Any

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.schema import SystemMessage, HumanMessage


genai.configure(api_key=GOOGLE_API_KEY)

os.environ["LANGCHAIN_API_KEY"] = LANGCHAIN_API_KEY
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "text2sql"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"


## 2. Core Text-to-SQL Functions


In [252]:
# Load LLMs
llm_sql_generator = ChatGoogleGenerativeAI(model="gemini-2.0-flash-lite-preview-02-05")
llm_sql_validator = ChatGoogleGenerativeAI(model="gemini-2.0-flash-lite-preview-02-05")

# Extract SQL from the model response
def extract_sql(text):
    match = re.search(r"```sql\s*(.*?)\s*```", text, re.DOTALL)
    return match.group(1) if match else text

# Tool 1: Generate SQL Query
def generate_sql(natural_language_query, DATABASE_SCHEMA=DATABASE_SCHEMA,COT_TEXT2SQL_EXAMPLE=COT_TEXT2SQL_EXAMPLE):
    """Agent to generate SQL from a natural language question."""
    prompt = f"""
    Properly use the Database Schema to properly use the table names and column names for respective tables.
    
    Database Schema:
    {DATABASE_SCHEMA}

    **************************
    Properly use the Database Schema to generate the SQL query.
    Answer Repeating the question and evidence, and generating the SQL with a query plan.
    
    <---(Example)--->
    {COT_TEXT2SQL_EXAMPLE}
    
    Only return the SQL query without ``` backticks, no other text. 
    Ensure the table alias is correctly assigned
    
    
    ---------------------------------
    Question: {natural_language_query}

    SQL Query: 
    """
    response = llm_sql_generator.invoke([HumanMessage(content=prompt)])
    return extract_sql(response.content)

# Tool 2: Validate and Fix SQL Query for PostgreSQL
def validate_and_fix_sql(sql_query,DATABASE_SCHEMA=DATABASE_SCHEMA):
    """Agent to validate and correct SQL syntax for PostgreSQL."""
    prompt = f"""
    The following SQL query might have syntax issues. Your task is to analyze it and correct any mistakes 
    so that it works properly in **PostgreSQL**.
    
    Properly use the Database Schema to generate the SQL query.
    Database Schema:
    {DATABASE_SCHEMA}

    Incorrect SQL:
    {sql_query}

    Return only the corrected SQL query without any explanations or ``` backticks.
    
    Make sure to use "ILIKE" instead of "=" for case insensitive matching. If not asked to be case sensitive
    Only while using "ILIKE" make sure to use "::TEXT" to avoid type mismatch errors.
    Don't use "::TEXT" while using "=" or anyother time
    

    Corrected SQL Query:
    """
    response = llm_sql_validator.invoke([HumanMessage(content=prompt)])
    return extract_sql(response.content)

# Define state type
class AgentState(TypedDict):
    input: str
    sql_query: str
    final_query: str
    query_results: str  # Add this field

# Define nodes with updated configuration
def generate_sql_node(state: AgentState) -> AgentState:
    """Generate initial SQL query"""
    try:
        sql_query = generate_sql(
            natural_language_query=state["input"],
            DATABASE_SCHEMA=DATABASE_SCHEMA,
            COT_TEXT2SQL_EXAMPLE=COT_TEXT2SQL_EXAMPLE
        )
        state["sql_query"] = sql_query
        return state
    except Exception as e:
        print(f"Error in generate_sql_node: {str(e)}")
        raise

def validate_sql_node(state: AgentState) -> AgentState:
    """Validate and fix SQL query"""
    try:
        final_query = validate_and_fix_sql(state["sql_query"])
        state["final_query"] = final_query
        return state
    except Exception as e:
        print(f"Error in validate_sql_node: {str(e)}")
        raise

def execute_sql_node(state: AgentState) -> AgentState:
    """Execute the SQL query and store results"""
    try:
        query_results = execute_query(state["final_query"].replace("\n", " "))
        state["query_results"] = query_results
        return state
    except Exception as e:
        print(f"Error in execute_sql_node: {str(e)}")
        raise

# Create workflow
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("validate_sql", validate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)  # Add new node

# Add edges
workflow.add_edge("generate_sql", "validate_sql")
workflow.add_edge("validate_sql", "execute_sql")  # Add edge to new node 
workflow.add_edge("execute_sql", END)  # Update final edge

# Set entry point
workflow.set_entry_point("generate_sql")

# Compile
agent_executor = workflow.compile()

# Update example usage
def process_query(natural_language_query: str, max_retries=5, show_print=False):
    attempt = 0
    error_message = ""

    state = {
        "input": natural_language_query,
        "sql_query": "",
        "final_query": "",
        "query_results": ""
    }

    error_messages = []
    while attempt < max_retries:
        # print(f"Attempt {attempt + 1}:{error_messages}")
        try:
            result = agent_executor.invoke(state)

            # Extract query and results
            sql_query = result["final_query"]
            query_results = result["query_results"]

            if show_print:
                print(f"\nAttempt {attempt + 1}:")
                print("\nNatural Language Query:\n", state["input"])
                print("\nGenerated SQL:\n", sql_query)
                print("\nQuery results:\n", query_results)

            # Check if the result contains an error
            if "ERROR" in query_results.upper():
                # Extract the error message
                error_message = query_results  # Full error message
                error_messages.append(error_message)
                # Prepare the retry prompt
                state["input"] = f"""{natural_language_query}\n\nPrevious Error list: {error_messages}"""
                attempt += 1
                continue  # Retry the process

            # If no error, return the successful result
            return sql_query, query_results

        except Exception as e:
            # print(f"Unexpected error in attempt {attempt + 1}: {str(e)}")
            attempt += 1

    print(f"\nMax retries reached. Last error: {error_message}")
    return None, None  # Return None if max retries are exceeded


# Test the processing
test_query = "Show me the top 5 customers who have rented the most movies"
# test_query = "Show all staff members hired before January 1, 2020."
# test_query = "Show all staff members hired before January 1, 2020., create_date is not in staff table"
sql, results = process_query(test_query,show_print=True)


Attempt 1:

Natural Language Query:
 Show me the top 5 customers who have rented the most movies

Generated SQL:
 SELECT
    C.first_name,
    C.last_name,
    COUNT(R.rental_id) AS rental_count
FROM
    customer AS C
    JOIN rental AS R ON C.customer_id = R.customer_id
GROUP BY
    C.customer_id,
    C.first_name,
    C.last_name
ORDER BY
    rental_count DESC
LIMIT 5;

Query results:
  first_name | last_name | rental_count 
------------+-----------+--------------
 ELEANOR    | HUNT      |           46
 KARL       | SEAL      |           45
 MARCIA     | DEAN      |           42
 CLARA      | SHAW      |           42
 TAMMY      | SANDERS   |           41
(5 rows)




#### Will be used in Gradio App

In [289]:
llm_query_validator = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-lite-preview-02-05"
)

def string_to_dict(text):
    """
    Converts a JSON-like structured string into a dictionary, handling extra formatting.

    Args:
        text (str): Input string containing key-value pairs.

    Returns:
        dict: Dictionary with extracted key-value pairs.
    """
        
    # Remove code block markers like ```python and ```
    cleaned_text = re.sub(r"```[\w]*", "", text).strip()

    # Remove trailing commas before closing braces
    cleaned_text = re.sub(r",\s*}", "}", cleaned_text)

    # Convert to dictionary
    data_dict = json.loads(cleaned_text)

    try:
        return json.loads(cleaned_text)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON format")


# Validate Natural Language Query
def validate_nl_query(natural_language_query, user_instructions=None):
    """Agent to validate and improve natural language query."""
    prompt = f"""
    
    "JUST OUTPUT THE Python DICTIONARY of texts of natural language query"
    You are a helpful assistant that validates natural language queries for a Database.
    Your task is to analyze the query for ambiguity, incompleteness, or incorrectness and improve it if needed.
    You are also allowed to use the database schema to improve the query.

    User Instructions:
    {user_instructions}

    Database Schema:
    {DATABASE_SCHEMA}
    
    Natural Language Query:
    {natural_language_query}
    
    If the query is clear and complete, return it unchanged.
    If the query needs improvement, provide the improved version and explain why.
    Make sure to return the corrected input, the improved query and the feedback.
    Make sure to not alter the original query too much, if there is no typo or discrepancy with database schema. 
    
    
    In the following Format - For example:
    original_query: show moveis with actr smith
    corrected_input: show movies with actor smith
    feedback: Fixed typos in 'movies' and 'actor', added specificity about searching by last name
    
    original_query: show all movies with rating R
    corrected_input: show all movies with rating R
    feedback: Query is clear and well-formed, minor rewording for consistency
    
    original_query: list customer payments
    corrected_input: list customer payments
    feedback: Query is clear and grammatically correct   
    
    Output Format (make sure say why we changed what, if changed): 
    ```python
    {{
        original_query: show moveis with actr smith
        corrected_input: show movies with actor smith
        feedback: Fixed typos in 'movies' and 'actor'
    }}
    ```
    

    """
    response = llm_query_validator.invoke([HumanMessage(content=prompt)])

    print(response.content)
    return string_to_dict(response.content)


In [291]:
validate_nl_query("Find the films that have been rented more times than the average")

```python
{
    "original_query": "Find the films that have been rented more times than the average",
    "corrected_input": "Find the films that have been rented more times than the average",
    "feedback": "The query is clear and grammatically correct"
}
```


{'original_query': 'Find the films that have been rented more times than the average',
 'corrected_input': 'Find the films that have been rented more times than the average',
 'feedback': 'The query is clear and grammatically correct'}

In [293]:
process_query("Find the films that have been rented more times than the average")

('SELECT\n  f.title\nFROM film AS f\nJOIN inventory AS i\n  ON f.film_id = i.film_id\nJOIN rental AS r\n  ON i.inventory_id = r.inventory_id\nGROUP BY\n  f.film_id,\n  f.title\nHAVING\n  COUNT(r.rental_id) > (\n    SELECT\n      AVG(rental_count)\n    FROM (\n      SELECT\n        COUNT(r2.rental_id) AS rental_count\n      FROM inventory AS i2\n      JOIN rental AS r2\n        ON i2.inventory_id = r2.inventory_id\n      GROUP BY\n        i2.film_id\n    ) AS film_rental_counts\n  );',
 '            title            \n-----------------------------\n EFFECT GLADIATOR\n BALLOON HOMEWARD\n VOYAGE LEGALLY\n BIKINI BORROWERS\n GARDEN ISLAND\n CONGENIALITY QUEST\n EXCITEMENT EVE\n PATIENT SISTER\n AMISTAD MIDSUMMER\n EASY GLADIATOR\n HILLS NEIGHBORS\n MALTESE HOPE\n SATURDAY LAMBS\n OPERATION OPERATION\n SEATTLE EXPECATIONS\n SHOW LORD\n MILLION ACE\n BASIC EASY\n AMADEUS HOLY\n CLUB GRAFFITI\n OSCAR GOLD\n DYNAMITE TARZAN\n FELLOWSHIP AUTUMN\n GOLDFINGER SENSIBILITY\n STAR OPERATION\n BOUND 

In [296]:
res = execute_query("""SELECT
  f.title
FROM film AS f
JOIN inventory AS i
  ON f.film_id = i.film_id
JOIN rental AS r
  ON i.inventory_id = r.inventory_id
GROUP BY
  f.film_id
HAVING
  COUNT(r.rental_id) > (
  SELECT
    AVG(rental_count)
  FROM (
    SELECT
      COUNT(r2.rental_id) AS rental_count
    FROM film AS f2
    JOIN inventory AS i2
      ON f2.film_id = i2.film_id
    JOIN rental AS r2
      ON i2.inventory_id = r2.inventory_id
    GROUP BY
      f2.film_id
  ) AS film_rental_counts
);""".replace("\n"," "))

print(res)

            title            
-----------------------------
 EFFECT GLADIATOR
 BALLOON HOMEWARD
 VOYAGE LEGALLY
 BIKINI BORROWERS
 GARDEN ISLAND
 CONGENIALITY QUEST
 EXCITEMENT EVE
 PATIENT SISTER
 AMISTAD MIDSUMMER
 EASY GLADIATOR
 HILLS NEIGHBORS
 MALTESE HOPE
 SATURDAY LAMBS
 OPERATION OPERATION
 SEATTLE EXPECATIONS
 SHOW LORD
 MILLION ACE
 BASIC EASY
 AMADEUS HOLY
 CLUB GRAFFITI
 OSCAR GOLD
 DYNAMITE TARZAN
 FELLOWSHIP AUTUMN
 GOLDFINGER SENSIBILITY
 STAR OPERATION
 BOUND CHEAPER
 SLACKER LIAISONS
 FATAL HAUNTED
 BANGER PINOCCHIO
 CONTACT ANONYMOUS
 ROOM ROMAN
 ALADDIN CALENDAR
 ARACHNOPHOBIA ROLLERCOASTER
 DEER VIRGINIAN
 SOUTH WAIT
 CONFUSED CANDLES
 MOONSHINE CABIN
 ENGLISH BULWORTH
 HALF OUTFIELD
 FURY MURDER
 TRAMP OTHERS
 ROSES TREASURE
 LOUISIANA HARRY
 HOMICIDE PEACH
 OPEN AFRICAN
 ROXANNE REBEL
 DESERT POSEIDON
 SECRETARY ROUGE
 COAST RAINBOW
 UNDEFEATED DALMATIONS
 KISSING DOLLS
 MADNESS ATTACKS
 ALTER VICTORY
 CHILL LUCK
 MASSACRE USUAL
 KICK SAVANNAH
 BEACH HEARTBREAKER

## Loading the Evaluation Dataset

In [152]:
import pandas as pd

file_path = r".\Pagila Evals Dataset(Sheet1).csv"

# Read CSV with an alternate encoding
test_queries = pd.read_csv(file_path, encoding="ISO-8859-1")  # or encoding="latin1"
print(test_queries.head())


   Query Number                             Natural Language Query Difficulty
0             1             List all actors' first and last names.       Easy
1             2      Show the titles of all films in the database.       Easy
2             3                       Get the names of all cities.       Easy
3             4           List all categories available for films.       Easy
4             5  Show the first name and last name of all custo...       Easy


In [153]:
test_queries.head()

Unnamed: 0,Query Number,Natural Language Query,Difficulty
0,1,List all actors' first and last names.,Easy
1,2,Show the titles of all films in the database.,Easy
2,3,Get the names of all cities.,Easy
3,4,List all categories available for films.,Easy
4,5,Show the first name and last name of all custo...,Easy


## Inferencing the agent


In [155]:
# Ensure columns exist before iteration

if "sql_gen_query" not in test_queries.columns:
    test_queries["sql_gen_query"] = np.nan

if "results" not in test_queries.columns:
    test_queries["results"] = np.nan

# Iterate over DataFrame with progress bar
for index, row in tqdm(test_queries.iterrows(), total=len(test_queries), desc="Processing Queries"):
    if pd.isna(row["sql_gen_query"]) or (isinstance(row["sql_gen_query"], str) and "(0 rows)" in row["sql_gen_query"]) or (pd.notna(row["results"]) and isinstance(row["results"], str) and "ERROR" in row["results"]):
        sql, output = process_query(row["Natural Language Query"])
        # print(output, "\n\n\n")

        # Update the DataFrame in place
        test_queries.at[index, "sql_gen_query"] = sql
        test_queries.at[index, "results"] = output
        
        if index%5 == 0:
            print(f"Sleeping for 5 seconds at index {index}") # Scleeping to 
            time.sleep(5)
            

# Save to CSV after updating all rows
test_queries.to_csv("inference_results.csv", index=False)


Processing Queries:  70%|███████   | 28/40 [00:02<00:01, 11.29it/s]

Error running query: ERROR:  function date(unknown, unknown) does not exist
LINE 1: SELECT * FROM customer WHERE create_date >= date('now', '-6 ...
                                                    ^
HINT:  No function matches the given name and argument types. You might need to add explicit type casts.

Error running query: ERROR:  function date(unknown, unknown) does not exist
LINE 1: SELECT * FROM customer WHERE create_date >= DATE('now', '-6 ...
                                                    ^
HINT:  No function matches the given name and argument types. You might need to add explicit type casts.

Error running query: ERROR:  column "create_date" does not exist
LINE 1: SELECT * FROM staff WHERE create_date < '2020-01-01'
                                  ^



Processing Queries:  75%|███████▌  | 30/40 [00:13<00:05,  1.76it/s]

Sleeping for 5 seconds at index 30


Processing Queries:  78%|███████▊  | 31/40 [00:21<00:09,  1.08s/it]

Error running query: ERROR:  column rental.id does not exist
LINE 1: ...film_id JOIN rental   ON inventory.inventory_id = rental.id ...
                                                             ^



Processing Queries:  88%|████████▊ | 35/40 [00:35<00:09,  1.81s/it]

Sleeping for 5 seconds at index 35


Processing Queries: 100%|██████████| 40/40 [00:56<00:00,  1.42s/it]


In [163]:
error_rows = test_queries[test_queries["results"].astype(str).str.contains(r"Error", case=False, na=False)]

In [164]:
error_rows

Unnamed: 0,Query Number,Natural Language Query,Difficulty,sql_gen_query,results


### Since we can see that there are no error_rows, we can consider that the errors in sql syntax is resolved by reattempts by the process_query()

In [165]:
for index, row in error_rows.iterrows():
    print(f"Natural Language Query:\n{row['Natural Language Query']}\n")
    print(f"sql_gen_query:\n{row['sql_gen_query']}\n")
    print(f"results: {row['results']}\n")
    

## Checking the inferenced results


In [300]:

import pandas as pd

file_path = r".\inferenced_results.csv"

# Read CSV with an alternate encoding
inferenced_df = pd.read_csv(file_path, encoding="ISO-8859-1")  # or encoding="latin1"
print(inferenced_df.info())


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40 entries, 0 to 39
Data columns (total 6 columns):
 #   Column                  Non-Null Count  Dtype 
---  ------                  --------------  ----- 
 0   Query Number            40 non-null     int64 
 1   Natural Language Query  40 non-null     object
 2   Difficulty              40 non-null     object
 3   Query                   40 non-null     object
 4   sql_gen_query           40 non-null     object
 5   results                 40 non-null     object
dtypes: int64(1), object(5)
memory usage: 2.0+ KB
None


In [301]:
inferenced_df.head()

Unnamed: 0,Query Number,Natural Language Query,Difficulty,Query,sql_gen_query,results
0,1,List all actors' first and last names.,Easy,List all actors' first and last names.,"SELECT first_name, last_name FROM actor;",first_name | last_name \n-------------+--...
1,2,Show the titles of all films in the database.,Easy,Show the titles of all films in the database.,SELECT title FROM film;,title \n---------------...
2,3,Get the names of all cities.,Easy,Get the names of all cities.,SELECT city FROM city;,city \n----------------...
3,4,List all categories available for films.,Easy,List all categories available for films.,SELECT name FROM category;,name \n-------------\n Action\n Animat...
4,5,Show the first name and last name of all custo...,Easy,Show the first name and last name of all custo...,"SELECT first_name, last_name FROM customer;",first_name | last_name \n-------------+--...


## Checking the inferenced results with `0 rows` output

In [168]:
check = inferenced_df.loc[28,"results"]

print(check)

 customer_id | store_id | first_name | last_name | email | address_id | activebool | create_date | last_update | active 
-------------+----------+------------+-----------+-------+------------+------------+-------------+-------------+--------
(0 rows)




In [169]:
filtered_df = inferenced_df[inferenced_df["results"].str.contains(r"\(0 rows\)", regex=True, na=False)]
filtered_df


Unnamed: 0,Query Number,Natural Language Query,Difficulty,Query,sql_gen_query,results
6,7,"Find all actors with the last name ""Smith.""",Easy,"Find all actors with the last name ""Smith.""","SELECT actor_id, first_name, last_name\nFROM a...",actor_id | first_name | last_name \n---------...
7,8,List all customers who are from the city of Â...,Easy,List all customers who are from the city of Â...,"SELECT\n C.first_name,\n C.last_name\nFROM C...",first_name | last_name \n------------+-------...
16,17,"Show all actors who appeared in the film ""Ince...",Medium,"Show all actors who appeared in the film ""Ince...","SELECT A.first_name, A.last_name\nFROM actor A...",first_name | last_name \n------------+-------...
19,20,"Find all films rented by customer ""John Doe.""",Medium,"Find all films rented by customer ""John Doe.""",SELECT F.title\nFROM film AS F\nJOIN inventory...,title \n-------\n(0 rows)\n\n
25,26,Show all rentals made in the last 7 days.,Medium,Show all rentals made in the last 7 days.,SELECT *\nFROM rental\nWHERE rental_date >= NO...,rental_id | rental_date | inventory_id | cust...
28,29,Find all customers who registered in the last ...,Medium,Find all customers who registered in the last ...,SELECT * FROM customer WHERE create_date >= DA...,customer_id | store_id | first_name | last_na...
31,32,Find the films that have been rented more time...,Hard,Find the films that have been rented more time...,SELECT film_id\nFROM inventory\nGROUP BY film_...,film_id \n---------\n(0 rows)\n\n
35,36,Find customers who have rented more films this...,Hard,Find customers who have rented more films this...,"SELECT c.first_name, c.last_name\nFROM custome...",first_name | last_name \n------------+-------...
37,38,"For each customer, show the number of films re...",Hard,"For each customer, show the number of films re...",WITH MonthlyRentals AS (\n SELECT\n ...,customer_id | rental_month | rental_count | p...
38,39,Show the names of customers who have rented ev...,Hard,Show the names of customers who have rented ev...,"SELECT c.first_name, c.last_name\nFROM custome...",first_name | last_name \n------------+-------...


### Manually checking the 0 rows results - Looks like the SQL query is correct


In [173]:
idx=7

# query = filtered_df.iloc[idx]["Query"]
query = "Find customers who have rented less films this year than last yea"
sql_query = filtered_df.iloc[idx]["sql_gen_query"]
print("Query:", query)
print("SQL:", sql_query, "\n\n\n")


result = execute_query(sql_query.replace("\n", " "))
print(result)

Query: Find customers who have rented less films this year than last yea
SQL: SELECT c.first_name, c.last_name
FROM customer c
JOIN rental r ON c.customer_id = r.customer_id
WHERE EXTRACT(YEAR FROM r.rental_date) = EXTRACT(YEAR FROM CURRENT_DATE)
GROUP BY c.customer_id, c.first_name, c.last_name
HAVING COUNT(DISTINCT r.rental_id) > (
    SELECT COUNT(DISTINCT r2.rental_id)
    FROM rental r2
    WHERE r2.customer_id = c.customer_id
      AND EXTRACT(YEAR FROM r2.rental_date) = EXTRACT(YEAR FROM CURRENT_DATE) - 1
); 



 first_name | last_name 
------------+-----------
(0 rows)




#### Query to sort by rental_date in descending order and limit the result to the top 5 using ORDER BY and LIMIT.

In [178]:
result = execute_query("""
    SELECT c.first_name, c.last_name, r.rental_date
    FROM customer c
    JOIN rental r ON c.customer_id = r.customer_id
    ORDER BY r.rental_date DESC
    LIMIT 5;
""".replace("\n", " "))

print(result)


 first_name | last_name |      rental_date       
------------+-----------+------------------------
 PHILIP     | CAUSEY    | 2022-08-23 21:50:12+00
 GLADYS     | HAMILTON  | 2022-08-23 21:43:07+00
 GRACE      | ELLIS     | 2022-08-23 21:42:48+00
 DENISE     | KELLY     | 2022-08-23 21:26:47+00
 BETTY      | WHITE     | 2022-08-23 21:25:26+00
(5 rows)




#### Last date is `2022-08-23 21:50:12+00`, so the prospect of choosing records from this year - 2025 and previous year - 2024 will always give us 0 rows, with this dataset


## LLM Judge - SQL Evaluation


In [179]:
def judge_sql_logic(nl_query: str, sql_query: str,DATABASE_SCHEMA=DATABASE_SCHEMA) -> str:
    """
    Uses Gemini to verify if the given SQL query correctly implements the natural language query logic.
    
    Args:
        nl_query (str): The natural language query.
        sql_query (str): The SQL query to check.
    
    Returns:
        str: A response indicating if the SQL query is logically correct, with reasoning.
    """
    agent = ChatGoogleGenerativeAI(model="gemini-pro")
    prompt = f""" 
        Check if the following SQL query correctly implements the given natural language request:

        NL Query: {nl_query}
        SQL Query: {sql_query}
        
        Provided Database Schema: 
        {DATABASE_SCHEMA}

        Provide a response indicating if it is logically correct, with reasoning.
        The scoring breakdown could be as follows:
        100 for fully correct queries.
        50 for queries that are logically correct but have minor errors.
        0 for queries that are incorrect or produce the wrong results
        
        The response should be in the following format:
        Score: 100
        Reasoning: The query is fully correct.
        Score: 50
        Reasoning: The query is logically correct but has minor errors. (with proper reasoning and improvements)
        Score: 0
        Reasoning: The query is incorrect or produces the wrong results. (with proper reasoning and improvements)

    """
            
    response = agent([HumanMessage(content=prompt)])
    
    response_lines = response.content.strip().split("\n")
    # print(response.content)
    score = (
        response_lines[0]
        .replace("Score: ", "")
        .strip("\"'")
    )
    reasoning = (
        response_lines[1].replace("Reasoning: ", "").strip("\"'")
        if len(response_lines) > 1
        else ""
    )
    return {
        "score": int(score),
        "reasoning": reasoning,
    }


In [180]:

idx=7

query = inferenced_df.iloc[idx]["Query"]
sql_query = inferenced_df.iloc[idx]["sql_gen_query"]
print("Query:", query)
print("SQL:", sql_query, "\n\n\n")


response = judge_sql_logic([query],[sql_query])

score = response["score"]
reasoning = response["reasoning"]

print("Score:", score)
print("Reasoning:", reasoning)


Query: List all customers who are from the city of ÂNew York.Â
SQL: SELECT
  C.first_name,
  C.last_name
FROM Customer AS C
JOIN Address AS A
  ON C.address_id = A.address_id
JOIN City AS CI
  ON A.city_id = CI.city_id
WHERE
  CI.city = 'New York'; 



Score: 100
Reasoning: The SQL query is fully correct. It performs a join between the Customer, Address, and City tables to retrieve the first name and last name of customers who are from the city of New York. The query uses the city_id column to link the Address table to the City table and the address_id column to link the Customer table to the Address table. The WHERE clause filters the results to include only customers who are from the city of New York.


In [182]:
# Ensure columns exist before iteration
if "reasoning" not in inferenced_df.columns:
    inferenced_df["reasoning"] = np.nan

if "score" not in inferenced_df.columns:
    inferenced_df["score"] = np.nan

# Iterate over the DataFrame with progress bar
for index, row in tqdm(inferenced_df.iterrows(), total=len(inferenced_df), desc="Evaluating SQL Queries"):
    if pd.isna(row["reasoning"]) or pd.isna(row["score"]):
        query = row["Query"]
        sql_query = row["sql_gen_query"]

        response = judge_sql_logic(query, sql_query)

        score = response["score"]
        reasoning = response["reasoning"]

        # Update the DataFrame in place
        inferenced_df.at[index, "score"] = score
        inferenced_df.at[index, "reasoning"] = reasoning
        
        if index%5 == 0:
            print(f"Sleeping for 5 seconds at index {index}") # Scleeping to 
            time.sleep(5)

# Save to CSV after updating all rows
inferenced_df.to_csv("evaluation_results.csv", index=False)


Evaluating SQL Queries:  12%|█▎        | 5/40 [00:12<01:33,  2.68s/it]

Sleeping for 5 seconds at index 5


Evaluating SQL Queries:  25%|██▌       | 10/40 [00:30<01:29,  2.98s/it]

Sleeping for 5 seconds at index 10


Evaluating SQL Queries:  38%|███▊      | 15/40 [00:51<01:27,  3.50s/it]

Sleeping for 5 seconds at index 15


Evaluating SQL Queries:  50%|█████     | 20/40 [01:10<01:02,  3.12s/it]

Sleeping for 5 seconds at index 20


Evaluating SQL Queries:  62%|██████▎   | 25/40 [01:31<00:55,  3.71s/it]

Sleeping for 5 seconds at index 25


Evaluating SQL Queries:  75%|███████▌  | 30/40 [01:51<00:36,  3.61s/it]

Sleeping for 5 seconds at index 30


Evaluating SQL Queries:  88%|████████▊ | 35/40 [02:12<00:17,  3.57s/it]

Sleeping for 5 seconds at index 35


Evaluating SQL Queries: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it]


In [183]:
inferenced_df.head()

Unnamed: 0,Query Number,Natural Language Query,Difficulty,Query,sql_gen_query,results,reasoning,score
0,1,List all actors' first and last names.,Easy,List all actors' first and last names.,"SELECT first_name, last_name FROM actor;",first_name | last_name \n-------------+--...,The query is fully correct. It retrieves the f...,100.0
1,2,Show the titles of all films in the database.,Easy,Show the titles of all films in the database.,SELECT title FROM film;,title \n---------------...,The query is fully correct. It selects the tit...,100.0
2,3,Get the names of all cities.,Easy,Get the names of all cities.,SELECT city FROM city;,city \n----------------...,The query is fully correct. It retrieves the c...,100.0
3,4,List all categories available for films.,Easy,List all categories available for films.,SELECT name FROM category;,name \n-------------\n Action\n Animat...,The query is fully correct. It selects the nam...,100.0
4,5,Show the first name and last name of all custo...,Easy,Show the first name and last name of all custo...,"SELECT first_name, last_name FROM customer;",first_name | last_name \n-------------+--...,The query is fully correct. It selects the fir...,100.0


## Check evaluated results

In [184]:

import pandas as pd

file_path = r".\evaluation_results.csv"

# Read CSV with an alternate encoding
evaluated_df = pd.read_csv(file_path, encoding="ISO-8859-1")  # or encoding="latin1"
print(evaluated_df.info())


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40 entries, 0 to 39
Data columns (total 8 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   Query Number            40 non-null     int64  
 1   Natural Language Query  40 non-null     object 
 2   Difficulty              40 non-null     object 
 3   Query                   40 non-null     object 
 4   sql_gen_query           40 non-null     object 
 5   results                 40 non-null     object 
 6   reasoning               40 non-null     object 
 7   score                   40 non-null     float64
dtypes: float64(1), int64(1), object(6)
memory usage: 2.6+ KB
None


In [297]:
evaluated_df

Unnamed: 0,Query Number,Natural Language Query,Difficulty,Query,sql_gen_query,results,reasoning,score
0,1,List all actors' first and last names.,Easy,List all actors' first and last names.,"SELECT first_name, last_name FROM actor;",first_name | last_name \n-------------+--...,The query is fully correct. It retrieves the f...,100.0
1,2,Show the titles of all films in the database.,Easy,Show the titles of all films in the database.,SELECT title FROM film;,title \n---------------...,The query is fully correct. It selects the tit...,100.0
2,3,Get the names of all cities.,Easy,Get the names of all cities.,SELECT city FROM city;,city \n----------------...,The query is fully correct. It retrieves the c...,100.0
3,4,List all categories available for films.,Easy,List all categories available for films.,SELECT name FROM category;,name \n-------------\n Action\n Animat...,The query is fully correct. It selects the nam...,100.0
4,5,Show the first name and last name of all custo...,Easy,Show the first name and last name of all custo...,"SELECT first_name, last_name FROM customer;",first_name | last_name \n-------------+--...,The query is fully correct. It selects the fir...,100.0
5,6,Show all films released in 2006.,Easy,Show all films released in 2006.,SELECT * FROM film WHERE release_year = 2006;,film_id | title | ...,The query is fully correct. It selects all row...,100.0
6,7,"Find all actors with the last name ""Smith.""",Easy,"Find all actors with the last name ""Smith.""","SELECT actor_id, first_name, last_name\nFROM a...",actor_id | first_name | last_name \n---------...,The SQL query is fully correct. It retrieves a...,100.0
7,8,List all customers who are from the city of Ã...,Easy,List all customers who are from the city of Ã...,"SELECT\n C.first_name,\n C.last_name\nFROM C...",first_name | last_name \n------------+-------...,The query is fully correct. It identifies the ...,100.0
8,9,Get all stores located in the country ÃÂIndi...,Easy,Get all stores located in the country ÃÂIndi...,SELECT S.store_id\nFROM store AS S\nJOIN addre...,store_id \n----------\n 13\n 18\n...,The query is fully correct. It retrieves the s...,100.0
9,10,Show all films with a rental rate greater than...,Easy,Show all films with a rental rate greater than...,"SELECT film_id, title, rental_rate FROM film W...",film_id | title | rental_...,The query is fully correct. It selects all fil...,100.0


In [298]:
mean_score = evaluated_df["score"].mean()

print(f"Mean Score of SQL Evaluation:\n\n{mean_score}")

Mean Score of SQL Evaluation:

100.0


In [187]:

# Filter rows where the score is not 100
filtered_df = evaluated_df[evaluated_df["score"] != 100]

# Display the filtered DataFrame
filtered_df

Unnamed: 0,Query Number,Natural Language Query,Difficulty,Query,sql_gen_query,results,reasoning,score


In [188]:
for index, row in filtered_df.iterrows():
    print(f"Query: {row['Query']}\n")
    print(f"SQL Generated:\n{row['sql_gen_query']}\n")
    print(f"Reasoning:\n{row['reasoning']}\n")
    print("=" * 80)  # Separator for readability
