In [None]:
from dotenv import load_dotenv
load_dotenv()
from google import genai
from google.genai import types
from typing import Union, List, Dict
from sqlalchemy import create_engine,inspect,Table,text,MetaData
import os
from dotenv import load_dotenv
from urllib.parse import quote_plus
from chromadb import Documents, EmbeddingFunction, Embeddings
import chromadb
from prompt import SQL_AGENT_PROMPT
from sqlalchemy.schema import CreateTable
import json
# Few-shot example store: a JSON list of {"question": ..., "query": ...} objects,
# read by setup_vector_db() and appended to by insert().
JSON_FILE_PATH = "examples.json"

In [2]:
# Gemini model used by the agent loop and by the SQL planner.
MODEL= "gemini-2.5-flash"
# Shared Gemini client; the API key comes from the `gemini_api` environment
# variable (which load_dotenv() may have populated from a .env file).
CLIENT = genai.Client(api_key=os.getenv("gemini_api"))
# Embedding model used to index and search the few-shot SQL examples.
EMBEDDING_MODEL = "text-embedding-004"

In [3]:
server = os.getenv("SQLSERVER")
database = os.getenv("DATABASE")
username = os.getenv("SERVER_USERNAME")
password = os.getenv("SERVER_PASSWORD")


def _odbc_value(value) -> str:
    """Brace-quote an ODBC connection-string value, doubling any ``}`` inside it."""
    return "{" + str(value).replace("}", "}}") + "}"


# Values are brace-quoted so that spaces, ';' or '=' (common in passwords)
# cannot terminate or corrupt the ODBC attribute list.
params = quote_plus(
    "DRIVER={ODBC Driver 17 for SQL Server};"
    f"SERVER={_odbc_value(server)};"
    f"DATABASE={_odbc_value(database)};"
    f"UID={_odbc_value(username)};"
    f"PWD={_odbc_value(password)};"
)
# SQLAlchemy engine for SQL Server via pyodbc, using the raw ODBC string above.
engine = create_engine(f"mssql+pyodbc:///?odbc_connect={params}")

In [None]:
# semantic search
class GeminiEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that embeds texts with the Gemini embedding model.

    A call with exactly one text is treated as a search query
    (``RETRIEVAL_QUERY``). Any other batch is treated as documents to index
    (``RETRIEVAL_DOCUMENT``, titled "SQL Examples").

    NOTE(review): because of this length heuristic, a store holding a single
    example would be indexed with the query task type.
    """

    def __call__(self, input: Documents) -> Embeddings:
        as_query = len(input) == 1
        embed_config = types.EmbedContentConfig(
            task_type="RETRIEVAL_QUERY" if as_query else "RETRIEVAL_DOCUMENT",
            title=None if as_query else "SQL Examples",
        )
        reply = CLIENT.models.embed_content(
            model=EMBEDDING_MODEL,
            contents=input,
            config=embed_config,
        )
        return [embedding.values for embedding in reply.embeddings]

def setup_vector_db():
    """Load the few-shot examples from ``JSON_FILE_PATH`` into a Chroma collection.

    Each example's ``question`` becomes a document, embedded with
    ``GeminiEmbeddingFunction``. Its ``query`` is stored as ``sql_query``
    metadata. Ids are the examples' positions in the file ("0", "1", ...).

    Returns:
        The ``sql_examples`` collection.
    """
    chroma_client = chromadb.Client()
    collection = chroma_client.get_or_create_collection(
        name="sql_examples",
        embedding_function=GeminiEmbeddingFunction(),
    )
    with open(JSON_FILE_PATH, "r", encoding="utf-8") as f:
        examples = json.load(f)

    if not examples:
        # Chroma rejects an empty id list; with no examples the collection
        # simply stays as it is.
        return collection

    # upsert (not add): get_or_create_collection may hand back a collection that
    # already holds these ids (this function runs for every question), and
    # upsert refreshes them instead of treating them as duplicates.
    collection.upsert(
        ids=[str(i) for i in range(len(examples))],
        documents=[ex["question"] for ex in examples],
        metadatas=[{"sql_query": ex["query"]} for ex in examples],
    )
    return collection

def get_similar_examples(user_query, collection, k=2):
    """Return the ``k`` stored examples closest to ``user_query``.

    Args:
        user_query: natural-language question to search for.
        collection: Chroma collection whose metadata carries ``sql_query``.
        k: number of neighbours to request.

    Returns:
        A list of ``{"similar_question": ..., "suggested_sql": ...}`` dicts, or
        an empty list when the query result has no documents.
    """
    hits = collection.query(query_texts=[user_query], n_results=k)
    if not hits["documents"]:
        return []

    # Results are nested per query text; exactly one query text was sent.
    questions = hits["documents"][0]
    metadatas = hits["metadatas"][0]
    return [
        {"similar_question": text, "suggested_sql": metadatas[idx]["sql_query"]}
        for idx, text in enumerate(questions)
    ]


In [None]:
# Prompt template used by query_checker(): it is filled with the candidate SQL
# ({query}) and the engine's dialect name ({dialect}) via str.format.
query_instruction = """
{query}
Double check the {dialect} query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only.

SQL Query: """

# Tool declarations: each dict below is a Gemini function declaration that is
# paired with its Python implementation in `file_tools`.
# get_table_names: lets the model discover which tables exist.
get_table_names_definition = {
    "name": "get_table_names",
    "description": "Returns a list of available tables in the database.",
    "parameters": {
        "type": "object",
        "properties": {
            "question": {
                "type": "string",
                "description": "The original user question to provide context."
            }
        },
        "required": ["question"],
    },
}

def get_table_names(question: str):
    """List the tables visible through ``engine``.

    Args:
        question: the user's question. The tool declaration requires it for
            context, but it is not used here.

    Returns:
        The table names reported by SQLAlchemy's inspector (no schema is
        passed, so the connection's default schema is used).
    """
    return inspect(engine).get_table_names()

# get_schema: the model passes a list of table names and receives their DDL.
get_schema_definition = {
    "name": "get_schema",
    "description": "Takes a list of table names and returns their schema (CREATE TABLE statements).",
    "parameters": {
        "type": "object",
        "properties": {
            "table_names": {
                "type": "array",
                "description": "List of table names to get schema for.",
                "items": {"type": "string"}
            },
        },
        "required": ["table_names"]
    },
}

def get_schema(table_names: list[str]) -> list[str]:
    """Return the ``CREATE TABLE`` DDL for each requested table.

    Args:
        table_names: names of the tables to describe. A single string is also
            accepted and treated as one table name.

    Returns:
        One entry per requested table, in order. Each entry is either the DDL
        compiled for ``engine``'s dialect or an
        ``"Error getting schema for <name>: <exception>"`` message, so one bad
        name does not hide the others.
    """
    if isinstance(table_names, str):
        # Iterating a bare string would reflect one "table" per character.
        table_names = [table_names]

    # Reflect only the requested tables instead of the whole database: a full
    # MetaData.reflect() is slow on large databases and a single unreflectable
    # table would make the entire call fail.
    metadata = MetaData()
    schema = []
    for table_name in table_names or []:
        try:
            table = Table(table_name, metadata, autoload_with=engine)
            schema.append(str(CreateTable(table).compile(engine)))
        except Exception as e:
            schema.append(f"Error getting schema for {table_name}: {e}")
    return schema

# query_checker: asks a second LLM call to review a SQL string before it is run.
# The description must name the execution tool exactly as registered ("run_query").
query_checker_definitions = {
    "name": "query_checker",
    "description": "Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with run_query!",
    "parameters": {
        "type": "object",
        "properties": {
            # The SQL text to review; forwarded as query_checker(query=...).
            "query": {
                "type": "string",
                "description": "The SQL query generated by the model that needs validation."
            }
        },
        "required": ["query"],
    }
}
# SQL dialect name used to fill the {dialect} placeholder in the prompts.
if "engine" in globals():
    dialect = engine.dialect.name
else:
    # Fallback when no engine exists in this namespace. The original bare
    # string expression left `dialect` undefined, which caused a later NameError.
    dialect = "standard SQL"
    
def query_checker(query: str) -> str:
    """Have a separate LLM call review ``query`` for common SQL mistakes.

    The query is inserted into ``query_instruction`` together with the active
    ``dialect`` and sent to ``gemini-2.0-flash``.

    Args:
        query: the SQL text produced by the agent.

    Returns:
        The reviewer model's reply text, which the prompt asks to be just the
        final SQL. NOTE(review): ``.text`` can presumably be None when the
        model returns no text part; confirm.
    """
    checker_prompt = query_instruction.format(query=query, dialect=dialect)
    reply = CLIENT.models.generate_content(contents=checker_prompt, model="gemini-2.0-flash")
    return reply.text

# run_query: executes SQL against the database and returns rows or an error.
# The retry hint quotes SQL Server's wording ("Invalid column name"), because
# that is the error text the model will actually see from this engine.
run_query_definitions = {
    "name": "run_query",
    "description": "Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an error about an unknown or invalid column (e.g. Invalid column name 'xxxx'), use get_schema to query the correct table fields.",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The SQL query to execute.",
            }
        },
        "required": ["query"],
    }
}

def run_query(query: str) -> Union[list[dict], str]:
    """Execute the SQL query provided by the agent.

    Args:
        query: raw SQL text generated by the model.

    Returns:
        A list of ``{column: value}`` dicts, one per result row. If execution
        fails, returns an ``"Error: <exception>"`` string instead, so the model
        can read the message and retry.
    """
    try:
        with engine.connect() as conn:
            # NOTE(review): no commit is issued here; with SQLAlchemy 2.x the
            # connection's transaction is rolled back on exit. Confirm the
            # installed version if write protection matters.
            result = conn.execute(text(query))
            rows = [dict(row) for row in result.mappings()]
    except Exception as e:
        return f"Error: {e}"
    return rows

# semantic search function
def similarity_search(query):
    """Build a few-shot prompt snippet from the examples most similar to ``query``.

    The example collection is rebuilt with ``setup_vector_db`` and the two
    nearest examples are rendered as ``Example Q:`` / ``Example SQL:`` pairs,
    each followed by a blank line. Returns "" when nothing matches.
    """
    collection = setup_vector_db()
    nearest = get_similar_examples(query, collection, k=2)
    return "".join(
        f"Example Q: {example['similar_question']}\nExample SQL: {example['suggested_sql']}\n\n"
        for example in nearest
    )

# Agent system-prompt template. Agent.run fills its {dialect} and {examples}
# placeholders with str.format before each model call.
system_instruction = SQL_AGENT_PROMPT


In [None]:
# planner tool
from pydantic import BaseModel, Field
from typing import List
from google.genai import types

# One step of a planner-generated query plan. QueryPlan (which holds these
# steps) is passed as response_schema in generate_sql_plan, so these Field
# descriptions presumably reach the model as schema text.
class PlanStep(BaseModel):
    step: int = Field(..., description="The step number")
    desc: str = Field(..., description="Short, precise description of the operation")

# Structured-output schema for generate_sql_plan: an ordered list of PlanStep.
class QueryPlan(BaseModel):
    plan: List[PlanStep]


# sql_planner: the model supplies schema text and the user question, which are
# passed as keyword arguments to generate_sql_plan().
planner_definitions = {
    "name": "sql_planner",
    "description": "Decomposes a complex user question into a logical, ordered list of steps. Use this tool first to plan the strategy.",
    "parameters": {
        "type": "object",
        "properties": {
            "schema_context": {
                "type": "string",
                "description": "The database schema (tables, columns, relationships)."
            },
            "user_question": {
                "type": "string",
                "description": "The original question asked by the user."
            }
        },
        "required": ["schema_context", "user_question"],
    }
}

def generate_sql_plan(schema_context: str, user_question: str) -> List[dict]:
    """
    Produce an ordered, step-by-step plan for answering ``user_question``.

    The model (``MODEL``, temperature 0.2) is asked for JSON constrained to the
    ``QueryPlan`` schema, and the reply is validated with pydantic.

    Args:
        schema_context: database schema text supplied by the calling model.
        user_question: the user's original question.

    Returns:
        A list of ``{"step": int, "desc": str}`` dicts. Any exception while
        calling the model or validating its reply is reported as a single
        ``{"step": 0, "desc": "Error generating plan: ..."}`` entry.
    """
    prompt = f"""
    You are a SQL planning assistant. Given a database schema and a user question, 
    output a JSON ordered list of steps describing the decomposition to compute the answer.
    
    Schema:
    {schema_context}

    User question:
    "{user_question}"
    """
    plan_config = types.GenerateContentConfig(
        response_mime_type="application/json",
        response_schema=QueryPlan,
        temperature=0.2,
    )

    try:
        reply = CLIENT.models.generate_content(model=MODEL, contents=prompt, config=plan_config)
        parsed_plan = QueryPlan.model_validate_json(reply.text)
        return [plan_step.model_dump() for plan_step in parsed_plan.plan]
    except Exception as exc:
        return [{"step": 0, "desc": f"Error generating plan: {str(exc)}"}]

In [None]:
# final output generation tool
import json
import os
from typing import Dict, Any

# 1. The Definition
# Declares the terminal tool. Agent.run does not execute a function for it:
# when the model calls final_result, the call's args dict is returned directly.
final_result_definition = {
    "name": "final_result",
    "description": "Must be called to output the final answer to the user.",
    "parameters": {
        "type": "object",
        "properties": {
            "question": {
                "type": "string",
                "description": "Contains the user question"
            },
            "sql_query": {
                "type": "string",
                "description": "The final SQL query generated."
            },
            "explanation": {
                "type": "string",
                "description": "A natural language explanation of the answer."
            }
        },
        "required": ["question", "sql_query", "explanation"]
    }
}

# 2. The Function
# Pure packaging: it performs no file I/O. Saved examples are written separately
# by insert(). Agent.run returns the final_result call's args without calling
# this function, so it mainly exists to complete the file_tools registry.

def final_result_function(question: str, sql_query: str, explanation: str) -> Dict[str, str]:
    """Package the agent's final answer as a structured dict.

    Args:
        question: the user's question.
        sql_query: the final SQL query.
        explanation: natural-language explanation of the answer.

    Returns:
        ``{"question": ..., "sql_query": ..., "explanation": ...}``.
    """
    return {
        "question": question,
        "sql_query": sql_query,
        "explanation": explanation,
    }

In [None]:
# Tool registry handed to the Agent: tool name -> the function declaration sent
# to the model ("definition") and the Python callable that implements it
# ("function"). Insertion order is preserved.
file_tools = {
    name: {"definition": definition, "function": function}
    for name, definition, function in (
        ("get_table_names", get_table_names_definition, get_table_names),
        ("get_schema", get_schema_definition, get_schema),
        ("sql_planner", planner_definitions, generate_sql_plan),
        ("query_checker", query_checker_definitions, query_checker),
        ("run_query", run_query_definitions, run_query),
        ("final_result", final_result_definition, final_result_function),
    )
}

In [None]:
# storing in the example.json file for future references
from typing import Dict
def insert(question: str, sql_query: str):
    """Append a question/SQL pair to the few-shot example file ``JSON_FILE_PATH``.

    Nothing is written when either value is empty, or when the exact pair is
    already stored. Re-running the same question would otherwise pile up
    duplicates that crowd other neighbours out of similarity search. A missing
    file is treated as an empty example list and gets created.

    Args:
        question: the natural-language question.
        sql_query: the SQL that answers it.
    """
    if not (question and sql_query):
        return

    try:
        with open(JSON_FILE_PATH, "r", encoding="utf-8") as f:
            examples = json.load(f)
    except FileNotFoundError:
        examples = []

    entry = {"question": question, "query": sql_query}
    if entry in examples:
        return
    examples.append(entry)

    # Write to a temporary file and swap it in, so an interrupted write cannot
    # leave a truncated example file behind.
    tmp_path = f"{JSON_FILE_PATH}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(examples, f, indent=2)
    os.replace(tmp_path, JSON_FILE_PATH)
    

In [None]:
# agent
class Agent:
    """Tool-calling SQL agent that drives a Gemini model with manual function calling.

    Each ``run`` sends the conversation so far to the model. When the model
    requests tool calls, they are executed locally and their results are fed
    back by recursing into ``run``. This continues until the model calls
    ``final_result``, answers without a tool call, or exhausts ``max_turns``.
    """

    def __init__(self, tools: Dict[str, Dict], system_instruction: str):
        """
        Args:
            tools: mapping of tool name -> {"definition": <function declaration
                dict>, "function": <python callable>}.
            system_instruction: prompt template, formatted with ``dialect`` and
                ``examples`` keyword arguments before each model call.
        """
        self.model = MODEL
        self.contents = []
        self.client = CLIENT
        self.tools = tools
        self.system_instruction = system_instruction
        # Maximum number of executions allowed per tool within one question.
        self.tool_max_use = 5
        # Maximum number of model round-trips per question. This guards against
        # a model that keeps requesting (possibly blocked) tools forever.
        self.max_turns = 30
        # Few-shot examples retrieved for the current question.
        self._examples = ""
        # Model round-trips made for the current question.
        self._turns = 0

    def run(self, contents: Union[str, List[Dict]], call_tracker: Dict[str, int] = None):
        """Send ``contents`` to the model and resolve any tool calls it requests.

        Args:
            contents: a user question (str) or a list of message parts (e.g. the
                ``functionResponse`` parts produced by the previous turn).
            call_tracker: per-tool execution counts for the current question.
                Leave as ``None`` to start a new question; recursive calls pass it on.

        Returns:
            The ``args`` dict of the ``final_result`` call when the model makes
            one. Otherwise returns the raw model response (final text answer,
            empty/blocked response, or the last response when ``max_turns`` is hit).
        """
        new_question = call_tracker is None
        if new_question:
            call_tracker = {}
            self._turns = 0

        parts = contents if isinstance(contents, list) else [{"text": contents}]
        self.contents.append({"role": "user", "parts": parts})

        if new_question:
            # Retrieve few-shot examples once per question, for *this* question.
            # similarity_search re-embeds the example store, so repeating it on
            # every tool round-trip is costly.
            first_part = parts[0] if parts else None
            question = first_part.get("text") if isinstance(first_part, dict) else None
            self._examples = similarity_search(question) if isinstance(question, str) else ""

        tool_declarations = [tool["definition"] for tool in self.tools.values()]
        tools_obj = [types.Tool(function_declarations=tool_declarations)]
        # VALIDATED mode: per the Gemini docs, the model may answer with either
        # text or calls to the allowed functions, and calls are checked
        # against their declared schemas.
        tool_config = types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(
                mode=types.FunctionCallingConfigMode.VALIDATED,
                allowed_function_names=list(self.tools.keys()),
            )
        )
        config = types.GenerateContentConfig(
            automatic_function_calling=types.AutomaticFunctionCallingConfig(disable=True),
            temperature=0,
            system_instruction=self.system_instruction.format(dialect=dialect, examples=self._examples),
            tools=tools_obj,
            tool_config=tool_config,
        )

        response = self.client.models.generate_content(model=self.model, contents=self.contents, config=config)
        self._turns += 1
        if not response.candidates or response.candidates[0].content is None:
            # Nothing usable came back (e.g. a blocked response); hand it to the
            # caller rather than failing with IndexError.
            return response
        self.contents.append(response.candidates[0].content)

        if not response.function_calls:
            return response

        function_response_parts = []
        for tool_call in response.function_calls:
            fn_name = tool_call.name
            print(f"[Function Call] {tool_call}")

            if fn_name == "final_result":
                return tool_call.args

            current_count = call_tracker.get(fn_name, 0)
            if current_count >= self.tool_max_use:
                print(f"!!! LIMIT REACHED for {fn_name}. Blocking execution. !!!")
                result = {
                    "error": f"SYSTEM: You have called '{fn_name}' {current_count} times, which is the limit. Do not use this tool again. Analyze what you have or ask the user for clarification."
                }
            elif fn_name in self.tools:
                call_tracker[fn_name] = current_count + 1
                try:
                    # args may be None when the model calls a tool without arguments.
                    result = {"result": self.tools[fn_name]["function"](**(tool_call.args or {}))}
                except Exception as e:
                    result = {"error": str(e)}
            else:
                result = {"error": "Tool not found"}

            print(f"[Function Response] {result}")
            function_response_parts.append({"functionResponse": {"name": fn_name, "response": result}})

        if self._turns >= self.max_turns:
            print(f"!!! TURN LIMIT REACHED ({self.max_turns}). Stopping. !!!")
            return response
        return self.run(contents=function_response_parts, call_tracker=call_tracker)
    
agent = Agent(tools=file_tools,
              system_instruction=system_instruction)

question =  "List all campaigns along with their total budget. For each campaign, calculate the 'Acquired Customer Revenue'—which is the total spend from customers who actually registered during that campaign's active dates. Order the result by the highest revenue generated."

response = agent.run(contents=question)

if isinstance(response, dict):
    # final_result args: {"question", "sql_query", "explanation"}.
    print("--- Final Answer ---")
    question = response.get('question')
    query = response.get('sql_query')
    if question and query and query != "N/A":
        # Store the SQL verbatim (JSON escapes newlines). Flattening newlines to
        # spaces would swallow everything after a `--` line comment into the
        # comment and corrupt the saved example.
        sql_query = query
        insert(question, sql_query)
    print(f"Question: {response.get('question')}")
    print(f"Explanation: {response.get('explanation')}")
    print(f"SQL: {response.get('sql_query')}")
else:
    # The model answered without calling final_result (or the run was cut short).
    print(response.text)

  embedding_function = GeminiEmbeddingFunction()


[Function Call] id=None args={'question': "List all campaigns along with their total budget. For each campaign, calculate the 'Acquired Customer Revenue'—which is the total spend from customers who actually registered during that campaign's active dates. Order the result by the highest revenue generated."} name='get_table_names'
[Function Response] {'result': ['campaign', 'customer', 'interaction', 'transaction']}
[Function Call] id=None args={'table_names': ['campaign', 'customer', 'interaction', 'transaction']} name='get_schema'
[Function Response] {'result': ['\nCREATE TABLE campaign (\n\t[index] BIGINT NULL, \n\tcampaign_id VARCHAR(max) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, \n\tcampaign_name VARCHAR(max) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, \n\tcampaign_type VARCHAR(max) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, \n\tstart_date VARCHAR(max) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, \n\tend_date VARCHAR(max) COLLATE SQL_Latin1_General_CP1_CI_AS NULL, \n\ttarget_segment V