In [None]:
! pip install --upgrade --quiet langchain langchain-core langchain-groq langchain-community langchain-openai


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.5/2.5 MB[0m [31m147.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m64.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/55.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.3/55.3 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.9/121.9 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.9/50.9 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import re
import json
import time
import math
import ast
from google.colab import userdata
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI


In [None]:

# 1. Environment Setup

# API Key Management
# Four Groq keys are read from Colab secrets named GROQ_API_KEY_0 .. GROQ_API_KEY_3.
api_keys = [userdata.get(f'GROQ_API_KEY_{i}') for i in range(4)]
current_api_index = 0

# Model settings live in one place so the client can be rebuilt after a key switch.
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
GROQ_TEMPERATURE = 0


def _build_llm():
    """Create a ChatGroq client bound to the currently selected API key.

    The key is exported through GROQ_API_KEY right before construction because
    the client picks the key up when it is created.
    """
    os.environ['GROQ_API_KEY'] = api_keys[current_api_index]
    return ChatGroq(
        temperature=GROQ_TEMPERATURE,
        model_name=GROQ_MODEL_NAME
    )


def switch_api_key():
    """Advance to the next API key and rebuild the global `llm` with it.

    Changing the environment variable alone does not affect a client that has
    already been constructed, so the global `llm` is replaced as well. Objects
    that captured the previous `llm` (for example an agent built earlier) keep
    using the old client.

    Returns:
        bool: True if a new key was activated, False if the last key was
        already in use (nothing is changed in that case).
    """
    global current_api_index, llm
    if current_api_index < len(api_keys) - 1:
        current_api_index += 1
        llm = _build_llm()
        print(f"Switched to API Key {current_api_index}")
    else:
        print("Reached the last API key. No further switching.")
        return False
    return True


# Initialize LLM
llm = _build_llm()


In [None]:
# Chat prompt for text-to-SQL: a fixed system message carrying the SQL-generation
# rules, followed by a human message filled from the `input` template variable.
# The system text is sent to the model verbatim, so edits to it change model behavior.
prompt = ChatPromptTemplate.from_messages([
    ("system", """
    You are an AI assistant specializing in SQL queries for demographic and social data analysis.
    Follow these specific rules when converting natural language to SQL:

    1. Column Naming:
       - CRITICAL: Column names with hyphens MUST be wrapped in double quotes like "Poverty Rate - Marion County"
       - NEVER use underscores to replace hyphens in column names (do NOT use Poverty_Rate_Marion_County)
       - Example correct format: "Poverty Rate - Marion - Black", NOT Poverty_Rate_-_Marion_-_Black
       - When referencing counties, use only the base name in the COUNTY column (e.g., 'Adams' not 'Adams County')
       - For county-specific queries, use COUNTY = 'CountyName' AND STATE = 'StateName'

    2. Percentage Values:
       - Percentage columns store values as whole numbers (e.g., 25.5 means 25.5%)
       - When filtering with percentage thresholds, use the actual number (e.g., > 25, not > 0.25)

    3. Time-based Queries:
       - Always include YEAR column in WHERE clause when years are mentioned
       - CRITICAL: Always include YEAR in the SELECT clause when filtering by years
       - When showing data across years, YEAR should always be the first column in the results
       - For year ranges, use YEAR BETWEEN start_year AND end_year

    4. Column Selection and Result Formatting:
       - Include ONLY the columns specifically mentioned in the query or needed for the answer
       - For "highest" or "top" requests, use ORDER BY column DESC LIMIT n
       - For "lowest" requests, use ORDER BY column ASC LIMIT n
       - When asked for "last n years", order by YEAR DESC LIMIT n
       - When asked for differences between values, ensure proper column name quoting

    5. Database-Specific:
       - Always return exactly what is requested - don't add extra columns unless needed for context
    """),
    ("human", "{input}")
])

# Load Table Metadata
# The file has a .txt extension but is parsed as JSON. query_data() reads it as
# {table_name: {"description": ..., "columns": {column_name: {"description": ...}}}}.
# NOTE(review): absolute Colab path; the notebook only runs where /content exists.
with open("/content/table_metadata.txt", "r") as file:
    table_metadata = json.load(file)

# Connect to SQLite Database
# NOTE(review): "sqlite:///savi_new.db" looks like a path relative to the working
# directory, unlike the absolute metadata path above - confirm both files are present.
db = SQLDatabase.from_uri("sqlite:///savi_new.db")

# Create SQL Agent
# NOTE(review): agent_executor is not referenced by any cell visible here; the
# evaluation below calls the LLM and db.run() directly. It also keeps the `llm`
# object that exists when this cell runs.
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [None]:

# 2. Query Processing Functions

def generate_response(user_input, retry_delay=5):
    """Send the formatted prompt to the LLM, rotating API keys on rate limits.

    The function returns exactly one kind of value, the LLM response object.
    Failures are raised instead of being returned as a ``(None, message)``
    tuple: a tuple is truthy, so callers that check ``if response`` and then
    read ``response.content`` would fail with an unrelated AttributeError.

    Args:
        user_input: Text substituted into the prompt's ``input`` variable.
        retry_delay: Seconds to wait after switching keys before retrying.

    Returns:
        The object returned by ``llm.invoke``.

    Raises:
        RuntimeError: A rate-limit error occurred and no further API key is
            available (the original error is chained as the cause).
        Exception: Any non-rate-limit error from the LLM call is re-raised
            unchanged.
    """
    # NOTE(review): prompt.format() presumably flattens the chat template into a
    # single string, so the system rules are not sent as a separate system message
    # - confirm whether format_messages() was intended.
    formatted_prompt = prompt.format(input=user_input)
    while True:
        try:
            # `llm` is looked up on every attempt so a client replaced during a
            # key switch is the one used for the retry.
            return llm.invoke([formatted_prompt])
        except Exception as e:
            error_message = str(e)
            is_rate_limit = "rate_limit_reached" in error_message or "Limit" in error_message
            if not is_rate_limit:
                raise
            print("Rate limit reached. Switching API key...")
            # The loop is bounded: switch_api_key() returns False once the last
            # key has been reached.
            if not switch_api_key():
                raise RuntimeError("Rate limit reached on last API key. Stopping.") from e
            time.sleep(retry_delay)


def _build_schema_instruction(metadata):
    """Render table and column descriptions as the schema preamble for the model."""
    parts = ["You can choose from the following tables and columns:\n"]
    for table_name, meta in metadata.items():
        parts.append(f"Table: {table_name}\n")
        parts.append(f"Description: {meta['description']}\n")
        parts.append("Columns:\n")
        for column_name, column_meta in meta['columns'].items():
            parts.append(f"  - {column_name}: {column_meta['description']}\n")
        parts.append("\n")
    return "".join(parts)


def _extract_sql(response_text):
    """Return the SQL inside the first fenced code block, or the whole text if unfenced."""
    # Accepts ```sql, ```SQL, ```sqlite or a bare ``` fence, with or without
    # newlines around the statement.
    fenced = re.search(r'```(?:sqlite|sql)?\s*(.*?)```', response_text, re.DOTALL | re.IGNORECASE)
    return (fenced.group(1) if fenced else response_text).strip()


def query_data(user_input):
    """Turn a natural-language question into SQL, run it, and return both.

    The table metadata is prepended to the question, the LLM is asked for a
    query, and the extracted SQL is executed with ``db.run``.

    Args:
        user_input: The natural-language question.

    Returns:
        tuple: ``(sql_query, query_response)`` on success, or
        ``(None, error_message)`` if anything fails (malformed metadata, LLM
        failure, empty model output, or a database error).
    """
    try:
        full_input = f"{_build_schema_instruction(table_metadata)}\n{user_input}"
        response = generate_response(full_input)

        # Defensive: tolerate a (None, message) error tuple from generate_response
        # instead of failing on `.content` with a misleading AttributeError.
        if isinstance(response, tuple):
            return None, response[1]

        response_text = response.content if response else ""
        sql_query = _extract_sql(response_text)
        if not sql_query:
            # Never send an empty statement to the database.
            return None, "Error processing request: model returned no SQL"

        # NOTE(review): the statement is model-generated and executed as-is;
        # nothing here restricts it to read-only queries.
        query_response = db.run(sql_query)
        return sql_query, query_response
    except Exception as e:
        return None, f"Error processing request: {str(e)}"


def process_query(user_query):
    """Run one natural-language query and print the outcome.

    Prints the error text when query_data reports a failure (SQL is None);
    otherwise prints the generated SQL followed by the database response.
    """
    generated_sql, outcome = query_data(user_query)

    # Failure path: query_data signals an error with a None query.
    if generated_sql is None:
        print("Error:", outcome)
        return

    print("SQL Query:", generated_sql)
    print("Query Response:", outcome)



In [None]:
# Load Test Cases
def load_test_cases(file_name):
    """Read test cases from a JSON file.

    Returns the parsed JSON content. If the file cannot be opened or parsed,
    the error is printed and an empty list is returned.
    """
    cases = []
    try:
        with open(file_name, 'r') as handle:
            cases = json.load(handle)
    except Exception as exc:
        # Best effort: report the problem and fall through with no cases.
        print(f"Error loading JSON: {exc}")
    return cases


def normalize_result(result, tolerance=1e-5):
    """Normalize a query result into nested lists with rounded floats.

    Strings are parsed as Python literals when possible (``db.run`` output is
    a string such as ``"[(2015, 29.5)]"``); text that is not a literal is kept
    as-is. Lists and tuples become lists, recursively. Floats are rounded to
    the number of decimals implied by ``tolerance`` (1e-5 -> 5 decimals), so
    the parameter actually controls the precision instead of being ignored.

    Args:
        result: Raw result (string, list, tuple, scalar, ...).
        tolerance: Comparison tolerance; determines the rounding precision.
            Non-positive or non-finite values fall back to 5 decimals.

    Returns:
        The normalized value.
    """
    if isinstance(result, str):
        try:
            result = ast.literal_eval(result)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal (e.g. a plain name or an error message):
            # keep the text so it is compared as a string.
            pass

    if isinstance(result, (list, tuple)):
        return [normalize_result(item, tolerance) for item in result]

    if isinstance(result, float):
        if tolerance > 0 and math.isfinite(tolerance):
            # Smallest decimal count whose step is <= tolerance; the tiny offset
            # guards against log10 landing just above an exact integer.
            digits = max(0, math.ceil(-math.log10(tolerance) - 1e-9))
        else:
            digits = 5
        return round(result, digits)

    return result

def are_results_equal(result1, result2, tolerance=1e-5):
    """Compare two normalized results, allowing a tolerance on numbers.

    Lists are compared element-wise and must have the same length. Two real
    numbers (int or float, excluding bool) are equal when they are exactly
    equal or differ by less than ``tolerance``; this also covers an int
    compared with a float, e.g. 14 vs 14.000001. Everything else uses ``==``.

    Args:
        result1: First value (typically output of normalize_result).
        result2: Second value.
        tolerance: Maximum absolute difference for numbers to count as equal.

    Returns:
        bool: True if the results match.
    """
    def is_real(value):
        # bool is a subclass of int; keep it out of the numeric comparison.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if isinstance(result1, list) and isinstance(result2, list):
        if len(result1) != len(result2):
            return False
        return all(are_results_equal(r1, r2, tolerance) for r1, r2 in zip(result1, result2))

    if is_real(result1) and is_real(result2):
        # The exact match is checked first so identical values are equal for any tolerance.
        return result1 == result2 or abs(result1 - result2) < tolerance

    return result1 == result2


def process_test_case(case, idx, tolerance=1e-5):
    """Run one test case, print the verdict and details, and count successes.

    The generated SQL is executed through query_data and its normalized result
    is compared with the normalized expected result. On a match the global
    ``successful_cases`` counter is incremented.

    Args:
        case: Mapping with "user_query", "expected_sql" and "expected_result".
        idx: Number of the test case, used only in the printed output.
        tolerance: Passed to normalize_result and are_results_equal.
    """
    global successful_cases

    try:
        # The lookups are inside the try so that one malformed case is reported
        # as a failure instead of aborting the whole evaluation run.
        user_query = case["user_query"]
        expected_sql = case["expected_sql"]
        expected_result = case["expected_result"]

        sql_query, result = query_data(user_query)
        normalized_generated_result = normalize_result(result, tolerance)
        normalized_expected_result = normalize_result(expected_result, tolerance)

        if are_results_equal(normalized_generated_result, normalized_expected_result, tolerance):
            successful_cases += 1
            print(f"Test Case {idx}: SUCCESS")
        else:
            print(f"Test Case {idx}: FAIL")
    except Exception as e:
        print(f"Execution Error for Test Case {idx}: {str(e)}")
        print(f"Test Case {idx}: FAIL")
        return

    print("=" * 50)
    print(f"Test Case {idx}")
    print(f"Natural Language Query: {user_query}")
    print(f"Generated SQL Query: {sql_query}")
    print(f"Expected SQL Query: {expected_sql}")
    print(f"Generated Result: {normalized_generated_result}")
    print(f"Expected Result: {normalized_expected_result}")
    print("=" * 50)

In [None]:

# Reset the pass counter that process_test_case increments.
successful_cases = 0

# Ask which JSON test file to evaluate, then load it.
test_file = input("Enter the test case file name: ")
test_cases = load_test_cases(test_file)

# Execute every case, numbering them from 1.
for idx, case in enumerate(test_cases, start=1):
    process_test_case(case, idx)

# Execution accuracy as a percentage capped at 100; 0 when no cases were loaded.
total_cases = len(test_cases)
if total_cases > 0:
    execution_accuracy = min((successful_cases / total_cases) * 100, 100)
else:
    execution_accuracy = 0

# Display Results
print(
    "\n" + "=" * 50,
    f"Successful Cases: {successful_cases}",
    f"Total Cases: {total_cases}",
    f"Execution Accuracy: {execution_accuracy:.2f}%",
    "=" * 50,
    "Final Results:",
    f"Overall Execution Accuracy (EX): {execution_accuracy:.2f}%",
    "=" * 50,
    sep="\n",
)


Enter the test case file name: savi_basic_needs.json
Test Case 1: SUCCESS
Test Case 1
Natural Language Query: What is the average poverty rate from 2010 to 2012?
Generated SQL Query: SELECT AVG("Poverty Rate - Indiana State") 
FROM savi_basic_needs_data 
WHERE YEAR BETWEEN 2010 AND 2012;
Expected SQL Query: SELECT AVG("Poverty Rate - Indiana State") FROM savi_basic_needs_data WHERE "Year" BETWEEN 2010 AND 2012;
Generated Result: [[14.14678]]
Expected Result: [[14.14678]]
Test Case 2: SUCCESS
Test Case 2
Natural Language Query: Find the year with the highest poverty rate for Black individuals in Marion County
Generated SQL Query: SELECT YEAR, "Poverty Rate - Marion - Black" 
FROM savi_basic_needs_data 
ORDER BY "Poverty Rate - Marion - Black" DESC 
LIMIT 1;
Expected SQL Query: SELECT Year, `Poverty Rate - Marion - Black` FROM savi_basic_needs_data ORDER BY `Poverty Rate - Marion - Black` DESC LIMIT 1;
Generated Result: [[2015, 29.51482]]
Expected Result: [[2015, 29.51482]]
Test Case 3: 

In [None]:

# Reset the pass counter that process_test_case increments.
successful_cases = 0

# Ask which JSON test file to evaluate, then load it.
test_file = input("Enter the test case file name: ")
test_cases = load_test_cases(test_file)

# Execute every case, numbering them from 1.
for idx, case in enumerate(test_cases, start=1):
    process_test_case(case, idx)

# Execution accuracy as a percentage capped at 100; 0 when no cases were loaded.
total_cases = len(test_cases)
if total_cases > 0:
    execution_accuracy = min((successful_cases / total_cases) * 100, 100)
else:
    execution_accuracy = 0

# Display Results
print(
    "\n" + "=" * 50,
    f"Successful Cases: {successful_cases}",
    f"Total Cases: {total_cases}",
    f"Execution Accuracy: {execution_accuracy:.2f}%",
    "=" * 50,
    "Final Results:",
    f"Overall Execution Accuracy (EX): {execution_accuracy:.2f}%",
    "=" * 50,
    sep="\n",
)


Enter the test case file name: svi_indiana_in.json
Test Case 1: SUCCESS
Test Case 1
Natural Language Query: List 5 counties that have the highest unemployment rate.
Generated SQL Query: SELECT COUNTY, EP_UNEMP
FROM svi_indiana_in
ORDER BY EP_UNEMP DESC
LIMIT 5;
Expected SQL Query: SELECT COUNTY, EP_UNEMP FROM svi_indiana_in ORDER BY EP_UNEMP DESC LIMIT 5;
Generated Result: [['Lake', 43.6], ['Marion', 35.1], ['Lake', 28.9], ['Lake', 28.6], ['Monroe', 26.2]]
Expected Result: [['Lake', 43.6], ['Marion', 35.1], ['Lake', 28.9], ['Lake', 28.6], ['Monroe', 26.2]]
Test Case 2: SUCCESS
Test Case 2
Natural Language Query: Which counties have a mobile phone penetration rate greater than 70%?
Generated SQL Query: SELECT COUNTY 
FROM svi_indiana_in 
WHERE EP_MOBILE > 70;
Expected SQL Query: SELECT COUNTY FROM svi_indiana_in WHERE EP_MOBILE > 70;
Generated Result: [['Allen'], ['Marion']]
Expected Result: [['Allen'], ['Marion']]
Test Case 3: SUCCESS
Test Case 3
Natural Language Query: Which counties 