Import necessary libraries

In [1]:
import json
import pandas as pd
import os
import google.generativeai as genai
import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, DML

  from .autonotebook import tqdm as notebook_tqdm


LOADING SCHEMA AND AI MODEL CONFIGURATION

In [6]:
# Load the schema description and flatten it into one row per (table, column)
# pair so it can be handled as a DataFrame.
with open("simple_enterprise_schema.json", "r") as f:
    data = json.load(f)

# Missing "tables" / "columns" keys are treated as empty, and a missing
# "table_name" / "description" becomes None, because every lookup uses .get().
rows = [
    {
        "table_name": table.get("table_name"),
        "table_description": table.get("description"),
        "column_name": column_name,
        "column_description": column_desc,
    }
    for table in data.get("tables", [])
    for column_name, column_desc in table.get("columns", {}).items()
]

schema_df = pd.DataFrame(rows)
print("Schema loaded successfully.")


# Read the Gemini API key from the environment instead of storing it in the
# notebook: a key written into a cell ends up in version control and in every
# shared copy of the file. Any key that was previously committed here should
# be revoked and replaced.
# GOOGLE_API_KEY is checked first, GEMINI_API_KEY is accepted as a fallback.
api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY")

if not api_key:
    # Define llm_model explicitly so later cells can test for None instead of
    # failing with a NameError.
    llm_model = None
    print("ERROR: API key not set. Please set the GOOGLE_API_KEY environment variable and re-run this cell.")
else:
    genai.configure(api_key=api_key)
    llm_model = genai.GenerativeModel('gemini-1.5-flash')
    print("AI Model configured successfully.")


Schema loaded successfully.
AI Model configured successfully.


CORE LOGIC (GENERATOR AND VALIDATOR)

In [7]:
def create_sql_prompt(schema_df, user_query):
    """Build the text prompt that asks the LLM for a SQL query.

    Args:
        schema_df: DataFrame describing the schema; it is embedded in the
            prompt via ``schema_df.to_string()``.
        user_query: The natural-language question, inserted verbatim between
            double quotes.

    Returns:
        The prompt string with the schema text and the question filled in.
    """
    # Render the schema first; it is substituted into the template as-is.
    schema_text = schema_df.to_string()

    template = """
    You are an expert SQL developer. Your task is to write a SQL query based on the provided database schema and a user's question.
    Follow these rules strictly:
    1. Only use the tables and columns listed in the schema below.
    2. If a user's question involves a name (like "John Doe"), you must perform a lookup on the 'users' table.
    3. The generated SQL must be syntactically correct for a standard SQL database.
    4. If the question cannot be answered with the given schema, respond with "Error: Cannot answer query with the provided schema."
    ---
    DATABASE SCHEMA:
    {schema}
    ---
    USER QUESTION:
    "{query}"
    ---
    SQL QUERY:
    """
    # Only the template is scanned for placeholders, so braces inside the
    # schema text or the question are copied through unchanged.
    return template.format(query=user_query, schema=schema_text)

class SQLValidator:
    """Static safety and schema check for a generated SQL string.

    The validator never executes SQL. It parses the text with ``sqlparse`` and
    (1) rejects any statement containing one of ``unsafe_keywords`` and
    (2) checks that referenced tables and columns exist in ``schema_df``.

    Attributes:
        schema_df: The DataFrame passed to the constructor; it must provide
            ``table_name`` and ``column_name`` columns.
        known_tables: Set of distinct ``table_name`` values.
        table_columns: Mapping of table name -> set of its column names.
        unsafe_keywords: Upper-case SQL keywords that cause rejection.
    """

    def __init__(self, schema_df):
        self.schema_df = schema_df
        self.known_tables = set(schema_df['table_name'].unique())
        self.table_columns = {
            table: set(group['column_name'])
            for table, group in schema_df.groupby('table_name')
        }
        self.unsafe_keywords = {'DELETE', 'UPDATE', 'DROP', 'INSERT', 'TRUNCATE', 'GRANT', 'REVOKE'}

    def validate(self, sql_query):
        """Run the safety check and then the schema check.

        Args:
            sql_query: SQL text to validate.

        Returns:
            ``(is_valid, message)``. Never raises: empty input and any parsing
            error are reported as a failed validation.
        """
        # Reject empty input up front; sqlparse returns no statements for it.
        if not isinstance(sql_query, str) or not sql_query.strip():
            return False, "Validation failed: SQL query is empty."
        try:
            is_safe, message = self._is_safe(sql_query)
            if not is_safe:
                return False, message
            is_valid, message = self._are_tables_and_columns_valid(sql_query)
            if not is_valid:
                return False, message
        except Exception as e:
            # Fail closed: a query that cannot be analysed is not approved.
            return False, f"Syntactic Error: Failed to parse query. Details: {e}"
        return True, "Validation successful: Query is safe and aligned with the schema."

    def _is_safe(self, sql_query):
        """Return ``(False, message)`` if any unsafe keyword is present.

        Every statement in the input is inspected, and every leaf token is
        checked, so keywords in a second statement or inside a subquery are
        found. All keyword token types are considered because sqlparse does not
        classify DROP, TRUNCATE, GRANT and REVOKE as DML. String literals are
        not keyword tokens, so a value such as 'DELETE' does not trigger a
        rejection.
        """
        for statement in sqlparse.parse(sql_query):
            for token in statement.flatten():
                if token.ttype in Keyword and token.value.upper() in self.unsafe_keywords:
                    return False, f"Validation failed: Unsafe keyword '{token.value.upper()}' found."
        return True, "Query is safe."

    def _are_tables_and_columns_valid(self, sql_query):
        """Check referenced tables and columns against the schema.

        Rules:
            * Every table named after FROM / JOIN must be in ``known_tables``
              (CTE names and aliased subqueries are exempt).
            * ``alias.column`` must exist in the table the alias resolves to.
            * An unqualified column must exist somewhere in the schema, unless
              it is a select-list alias or the query uses a CTE / derived
              table whose output columns cannot be resolved here.
            * ``*`` is always accepted.

        Table names from all nesting levels are collected into one scope, so
        the check is deliberately lenient for subqueries.
        """
        tables = {}      # name or alias used in the query -> schema table (None = CTE / derived)
        columns = []     # (qualifier or None, column name) in order of appearance
        aliases = set()  # aliases given to output expressions
        for statement in sqlparse.parse(sql_query):
            self._collect_references(statement, tables, columns, aliases)

        for table in tables.values():
            if table is not None and table not in self.known_tables:
                return False, f"Validation failed: Table '{table}' not found in schema."

        has_derived = any(target is None for target in tables.values())
        for qualifier, col in columns:
            if col == '*':
                continue
            if qualifier in tables:
                table = tables[qualifier]
                if table is None or col in self.table_columns.get(table, ()):
                    continue
                return False, f"Validation failed: Column '{col}' could not be associated with any table in the query."
            if col in aliases or has_derived:
                continue
            if any(col in cols for cols in self.table_columns.values()):
                continue
            return False, f"Validation failed: Column '{col}' could not be associated with any table in the query."

        return True, "Tables and columns are valid."

    def _collect_references(self, node, tables, columns, aliases):
        """Walk the grouped token tree under ``node`` and record references.

        ``tables``, ``columns`` and ``aliases`` are filled in place.
        """
        sql = sqlparse.sql
        expecting = None   # 'table' right after FROM/JOIN, 'cte' right after WITH
        seen_dml = False   # FROM only introduces tables inside a SELECT-like clause
        for tok in node.tokens:
            if tok.is_whitespace or isinstance(tok, sql.Comment):
                continue

            if tok.ttype in Keyword:
                word = tok.normalized
                if tok.ttype in DML:
                    seen_dml = True
                if word == 'WITH':
                    expecting = 'cte'
                elif word == 'RECURSIVE' and expecting == 'cte':
                    pass  # WITH RECURSIVE name AS (...): keep waiting for the name
                elif seen_dml and (word == 'FROM' or word.endswith('JOIN')):
                    # seen_dml keeps "EXTRACT(YEAR FROM col)" from being read as a table.
                    expecting = 'table'
                else:
                    expecting = None
                continue

            if expecting is not None and isinstance(tok, (Identifier, IdentifierList)):
                members = tok.get_identifiers() if isinstance(tok, IdentifierList) else [tok]
                for member in members:
                    if isinstance(member, Identifier):
                        self._register_table(member, tables, columns, aliases,
                                             is_cte=(expecting == 'cte'))
                    elif member.is_group:
                        self._collect_references(member, tables, columns, aliases)
                expecting = None
                continue

            expecting = None
            if isinstance(tok, sql.Function):
                # The first child is the function name, not a column; only the
                # argument list (and any later groups) can hold references.
                for child in tok.tokens[1:]:
                    if child.is_group:
                        self._collect_references(child, tables, columns, aliases)
            elif isinstance(tok, Identifier):
                self._register_column(tok, columns, tables, aliases)
            elif tok.is_group:
                self._collect_references(tok, tables, columns, aliases)

    def _register_table(self, ident, tables, columns, aliases, is_cte=False):
        """Record a FROM/JOIN source or a CTE definition found in ``ident``."""
        names = sqlparse.tokens
        first = ident.token_first(skip_cm=True)
        is_plain_name = first is not None and first.ttype in (names.Name, names.String.Symbol)
        name = ident.get_real_name() if is_plain_name else None
        alias = ident.get_alias()

        # A CTE, an aliased subquery, or a name previously introduced by a CTE
        # is "derived": its columns are unknown, so it maps to None.
        if is_cte or name is None or (name in tables and tables[name] is None):
            target = None
        else:
            target = name
        if name is not None:
            tables[name] = target
        if alias:
            tables[alias] = target

        # Descend into a subquery body; nested Identifier children here are
        # alias names and are skipped.
        for child in ident.tokens:
            if child.is_group and not isinstance(child, Identifier):
                self._collect_references(child, tables, columns, aliases)

    def _register_column(self, ident, columns, tables, aliases):
        """Record a column reference, or descend into a wrapped expression."""
        names = sqlparse.tokens
        alias = ident.get_alias()
        if alias:
            aliases.add(alias)
        first = ident.token_first(skip_cm=True)
        if first is not None and first.ttype in (names.Name, names.String.Symbol, names.Wildcard):
            col = ident.get_real_name()
            if col is not None:
                columns.append((ident.get_parent_name(), col))
        else:
            # e.g. "COUNT(*) AS n" or "t.col DESC": the reference is nested.
            self._collect_references(ident, tables, columns, aliases)

# Prints a fixed confirmation message when this cell finishes running.
print("Generator and Validator function defined successfully.")

Generator and Validator function defined successfully.


INTEGRATED Text2SQL PIPELINE (The main function that takes a user query and returns a validated SQL query.)

In [None]:
def text_to_sql_pipeline(user_query, schema_df, model):
    """Turn a natural-language question into a validated SQL string.

    Steps: build the prompt, ask the model, strip Markdown code fences from the
    reply, then run ``SQLValidator`` on the result. Progress is printed at each
    step.

    Args:
        user_query: The question to translate.
        schema_df: Schema DataFrame used for both the prompt and validation.
        model: Object exposing ``generate_content(prompt)`` whose result has a
            ``text`` attribute.

    Returns:
        The generated SQL if it passes validation, otherwise ``None``. ``None``
        is also returned when the model yields no text or answers with the
        "Error: ..." refusal that the prompt prescribes.
    """
    print(f"\n1. Received User Query: '{user_query}'")

    # Generate SQL
    prompt = create_sql_prompt(schema_df, user_query)
    response = model.generate_content(prompt)
    try:
        raw_text = response.text
    except ValueError as exc:
        # The SDK raises ValueError when the response holds no usable text part.
        print(f"2. LLM returned no usable text: {exc}")
        print("5. FINAL DECISION: Query is REJECTED.")
        return None

    # Remove code fences first and strip afterwards, so the newlines that sat
    # next to the fences do not survive in the result.
    generated_sql = (
        raw_text.replace("```sql", "").replace("```SQL", "").replace("```", "").strip()
    )
    print(f"2. Generated SQL from LLM:\n{generated_sql}")

    # The prompt tells the model to reply "Error: ..." when the schema cannot
    # answer the question; that reply is not SQL and must not be validated.
    if not generated_sql or generated_sql.lower().startswith("error:"):
        print("4. Validation Result: FAIL")
        print("   Message: Model did not produce a SQL query.")
        print("5. FINAL DECISION: Query is REJECTED.")
        return None

    # Validate SQL
    print("3. Initializing SQL Validator...")
    validator = SQLValidator(schema_df)
    is_valid, message = validator.validate(generated_sql)
    print(f"4. Validation Result: {'PASS' if is_valid else 'FAIL'}")
    print(f"   Message: {message}")

    # Final Output
    if not is_valid:
        print("5. FINAL DECISION: Query is REJECTED.")
        return None

    print("5. FINAL DECISION: Query is APPROVED for execution.")
    return generated_sql

# Demo 1: a read-only question.
approved_sql = text_to_sql_pipeline(
    "List all open incidents reported by John Doe.",
    schema_df,
    llm_model,
)

print("-" * 30)

# Demo 2: a request that asks for a destructive statement.
rejected_sql = text_to_sql_pipeline(
    "Delete all logs related to the user 'admin'.",
    schema_df,
    llm_model,
)


1. Received User Query: 'List all open incidents reported by John Doe.'
2. Generated SQL from LLM:

SELECT i.*
FROM incidents i
JOIN users u ON i.user_id = u.user_id
WHERE u.name = 'John Doe' AND i.status = 'open';

3. Initializing SQL Validator...
4. Validation Result: PASS
   Message: Validation successful: Query is safe and aligned with the schema.
5. FINAL DECISION: Query is APPROVED for execution.
------------------------------

1. Received User Query: 'Delete all logs related to the user 'admin'.'
2. Generated SQL from LLM:

DELETE FROM logs
WHERE user_id = (SELECT user_id FROM users WHERE name = 'admin');

3. Initializing SQL Validator...
4. Validation Result: FAIL
   Message: Validation failed: Unsafe keyword 'DELETE' found.
5. FINAL DECISION: Query is REJECTED.
