In [1]:
import warnings
# Ignore all warning categories to keep the notebook output short.
# NOTE(review): a blanket "ignore" filter also masks deprecation notices from the
# libraries used below; consider narrowing it by category or module.
warnings.filterwarnings("ignore")

import os
import re
import yaml
from dotenv import load_dotenv

from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.agents import Tool, AgentExecutor
from langchain.utilities import SQLDatabase
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool

from src.database import create_sql_engine
from src.schema import create_schema
from src.llm import init_llm_local, init_llm_gemini
from src.retriever import init_retriever

# Load variables from a local .env file (if one exists) into os.environ before
# reading the key; `load_dotenv` was imported above but never called.
load_dotenv()
GOOGLE_API = os.getenv("GOOGLE_API")

In [2]:
def build_list_table_tool(db):
    """Wrap LangChain's ``ListSQLDatabaseTool`` as a single-input ``Tool`` for the agent.

    Args:
        db: The ``SQLDatabase`` whose usable table names should be listed.

    Returns:
        A ``Tool`` named ``list_tables``. The agent may call it with an empty
        string; the output is the comma-separated list of table names.
    """
    list_tables = ListSQLDatabaseTool(db=db)

    def _list_tables(tool_input="", callbacks=None):
        # Use the Runnable ``invoke`` API rather than calling the tool instance
        # directly, which goes through the deprecated ``BaseTool.__call__``.
        # ``callbacks`` keeps forwarding the parent run's callbacks to the inner tool.
        return list_tables.invoke(tool_input or "", config={"callbacks": callbacks})

    return Tool(
        name="list_tables",
        func=_list_tables,
        description="Use this tool to list all available table names in the database."
    )

In [3]:
def build_info_table_tool(db):
    """Wrap LangChain's ``InfoSQLDatabaseTool`` as a single-input ``Tool`` for the agent.

    Args:
        db: The ``SQLDatabase`` to introspect.

    Returns:
        A ``Tool`` named ``describe_table`` whose input is a table name and whose
        output is that table's ``CREATE TABLE`` statement followed by sample rows.
    """
    describe_table = InfoSQLDatabaseTool(db=db)

    def _describe_table(table):
        # ``invoke`` replaces the deprecated ``BaseTool.__call__(tool_input=...)`` form.
        return describe_table.invoke(table)

    return Tool(
        name="describe_table",
        func=_describe_table,
        description="Use this tool to describe the schema of a specific table, including column names and data types."
    )

In [4]:
def _extract_sql(llm_output):
    """Return the executable SQL contained in the SQL chain's raw LLM output.

    Assumes the model follows a ``SQLQuery: <sql>`` layout, optionally followed by
    ``SQLResult:`` / ``Answer:`` sections. When a ``SQLQuery:`` label is present,
    only the text after it (up to the next such section or the end of the output)
    is kept; otherwise the whole output is treated as SQL. Markdown code fences
    are stripped in both cases.
    """
    # `\Z` lets the match succeed when the model stops right after the query.
    # Without it, output lacking a trailing "SQLResult:" line would not match, and
    # the "SQLQuery:" label itself would be sent to the database.
    match = re.search(
        r"SQLQuery:\s*(.*?)(?=\n\s*(?:SQLResult|Answer):|\Z)",
        llm_output,
        re.DOTALL,
    )
    sql = match.group(1) if match else llm_output
    return sql.strip().replace("```sql", "").replace("```", "").strip()


def build_sql_generation_tool(llm, db, chain_template_path="template/sql_db_chain_v1.1.yml", max_retries=3):
    """Build a ``Tool`` that answers natural-language questions by generating and running SQL.

    The LLM is prompted with the template's ``instruction``, filled with the
    question, ``db.get_table_info()`` and ``db.dialect``. The generated SQL is run
    with ``db.run``. On failure, the model is re-prompted with the failing SQL and
    the database error, for up to ``max_retries`` attempts in total.

    Args:
        llm: LLM passed to ``LLMChain``.
        db: ``SQLDatabase`` used for schema info and query execution.
        chain_template_path: YAML file whose ``instruction`` key holds a prompt
            template using ``{input}``, ``{table_info}`` and ``{dialect}``.
        max_retries: Maximum number of generate-and-execute attempts per question.

    Returns:
        A ``Tool`` named ``sql_query``. It returns a status string containing
        either the query result or the last database error.
    """
    with open(chain_template_path, "r", encoding="utf-8") as f:
        chain_template_dict = yaml.safe_load(f)

    chain_prompt = PromptTemplate(
        input_variables=["input", "table_info", "dialect"],
        template=chain_template_dict['instruction'],
    )

    sql_generation_chain = LLMChain(llm=llm, prompt=chain_prompt, verbose=True)

    def safe_sql_query(question, max_retries=max_retries):
        # Schema info and dialect are fetched on every call.
        chain_input = {
            "input": question,
            "table_info": db.get_table_info(),
            "dialect": db.dialect,
        }

        for i in range(max_retries):
            print(f"--- Attempt {i + 1} of {max_retries} ---")
            llm_output = sql_generation_chain.invoke(chain_input)['text']
            sql_code = _extract_sql(llm_output)

            # NOTE(review): the generated SQL is executed verbatim, with no
            # read-only guard. Restrict DB permissions before using a real database.
            try:
                if not sql_code:
                    # Treat an empty generation as a failed attempt so it is retried
                    # instead of being sent to the database.
                    raise ValueError("The model did not return a SQL query.")
                print(f"Executing SQL: {sql_code}")
                result = db.run(sql_code)
                print("Query Successful!")
                return f"Query executed successfully. Result: {result}"
            except Exception as e:
                error_message = str(e)
                print(f"Query Failed. Error: {error_message}")

                if i == max_retries - 1:
                    return f"Failed to execute SQL after {max_retries} attempts. Last error: {error_message}"

                # Re-ask with the failing SQL and the DB error so the model can self-correct.
                chain_input["input"] = (
                    f"The previous attempt to answer the question '{question}' failed. "
                    f"The generated SQL was:\n{sql_code}\n"
                    f"It produced the following database error:\n{error_message}\n"
                    "Please analyze the error and the database schema to generate a corrected SQL query."
                )

        # Only reached when max_retries <= 0: every loop iteration returns.
        return "Failed to get a valid response from the database after multiple attempts."

    return Tool(
        name="sql_query",
        func=safe_sql_query,
        description=(
            "Use this tool to answer questions about user data, metrics, or reports from the database. "
            "Input should be a complete question in natural language. "
            "The tool will automatically generate, execute, and correct SQL to find the answer."
        )
    )

In [5]:
def build_schema_tool(retriever):
    """Wrap a schema-documentation retriever as a single-input ``Tool`` for the agent.

    Args:
        retriever: Retriever over the table/column definitions (built by ``init_retriever``).

    Returns:
        A ``Tool`` named ``schema_lookup`` that returns the documents the retriever
        finds for the given query.
    """
    return Tool(
        name="schema_lookup",
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        func=lambda query: retriever.invoke(query),
        description=(
            "A tool to retrieve definitions of table or column names. "
            # Trailing space needed so the sentences don't run together when concatenated.
            "Use when the input is a natural language question containing a field or table name that needs clarification. "
            "Input should be a short query or phrase asking about the meaning or definition of a table or column. "
            "Returns the associated schema documentation."
        )
    )

In [6]:
# Build the SQLite engine and expose only the listed tables through a LangChain SQLDatabase.
# The second value returned by create_sql_engine() is discarded.
# NOTE(review): assigning to `_` shadows IPython's last-output variable in the notebook namespace.
engine, _ = create_sql_engine()
db = SQLDatabase(engine, include_tables=["members", "items", "campaigns", "transactions", "transaction_items"])

# Initialize the Gemini LLM, the schema-definition DataFrame, and the retriever built over it.
# The agent tools themselves are built in later cells.
llm = init_llm_gemini(api_key=GOOGLE_API)
df_schema = create_schema()
retriever = init_retriever(df_schema)

2025-07-29 01:06:04,685 | INFO | Database schema created at sqlite:///:memory:.
2025-07-29 01:06:04,686 | INFO | Generating 100 members...
2025-07-29 01:06:04,709 | INFO | Generating 30 items...
2025-07-29 01:06:04,711 | INFO | Generating 5 campaigns...
2025-07-29 01:06:04,711 | INFO | Generating 150 transactions...
2025-07-29 01:06:04,757 | INFO | All generated data added to session and committed.
2025-07-29 01:06:04,770 | INFO | Initializing Online LLM with model: models/gemini-2.0-flash...
2025-07-29 01:06:04,787 | INFO | Initializing Chroma retriever with embedding model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
  embedding = HuggingFaceEmbeddings(model_name=embedding_model_name)
2025-07-29 01:06:06,968 | INFO | Use pytorch device_name: mps
2025-07-29 01:06:06,968 | INFO | Load pretrained SentenceTransformer: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
2025-07-29 01:06:12,164 | INFO | Anonymized telemetry enabled. See                     https://d

In [7]:
list_table_tool = build_list_table_tool(db=db)
print(list_table_tool.invoke(input=""))

campaigns, items, members, transaction_items, transactions


In [8]:
info_table_tool = build_info_table_tool(db=db)
print(info_table_tool.invoke(input="transaction_items"))


CREATE TABLE transaction_items (
	transaction_id INTEGER NOT NULL, 
	item_id INTEGER NOT NULL, 
	quantity INTEGER, 
	unit_price FLOAT, 
	PRIMARY KEY (transaction_id, item_id), 
	FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id), 
	FOREIGN KEY(item_id) REFERENCES items (item_id)
)

/*
3 rows from transaction_items table:
transaction_id	item_id	quantity	unit_price
1	22	5	1553.03
2	8	1	247.49
3	5	1	489.17
*/


  func=lambda table: describe_table(tool_input=table),


In [9]:
sql_generation_tool = build_sql_generation_tool(llm=llm, db=db)
print(sql_generation_tool.invoke(input="How to know all items sold during December 2024?"))

--- Attempt 1 of 3 ---


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mGiven an input question, first create a syntactically correct sqlite query to run, then look at the results of the query and return the answer. 
Unless the user explicitly requests a specific number of examples, limit the query to a maximum of 100 results. Prefer using aggregation functions to reduce the number of output rows. 
You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Use the following commonly supported SQL functions:
- Math: ABS, ROUND, CEIL, FLOOR, MOD, POWER, SQRT
- Date/Time: CURRENT_DATE, NOW(), DATE_ADD, EXTRA

  sql_generation_chain = LLMChain(llm=llm, prompt=chain_prompt, verbose=True)



[1m> Finished chain.[0m
Executing SQL: SELECT DISTINCT
  T1.item_name
FROM items AS T1
INNER JOIN transaction_items AS T2
  ON T1.item_id = T2.item_id
INNER JOIN transactions AS T3
  ON T2.transaction_id = T3.transaction_id
WHERE
  STRFTIME('%Y', T3.transaction_time) = '2024' AND STRFTIME('%m', T3.transaction_time) = '12';
Query Successful!
Query executed successfully. Result: [('Attack Avoid',), ('Seek Suddenly',), ('Product Quite',), ('Speak Order',), ('Be Green',), ('Effect Box',), ('Challenge Commercial',), ('Election Assume',), ('Finally Current',), ('Activity Despite',), ('Early Fund',), ('Play Rock',), ('Boy Condition',), ('Some Difficult',), ('Remain Ok',), ('Local Church',)]


In [10]:
schema_tool = build_schema_tool(retriever=retriever)
print(schema_tool.invoke(input="How to know all items in a transaction?"))

  func=lambda query: retriever.get_relevant_documents(query),
2025-07-29 01:06:16,829 | ERROR | Failed to send telemetry event CollectionQueryEvent: capture() takes 1 positional argument but 3 were given


[Document(metadata={'column_name': '__table__', 'data_type': 'TABLE', 'definition': 'Table listing the specific items purchased in each transaction.', 'table_name': 'transaction_items'}, page_content='transaction_items.__table__: [TABLE] Table listing the specific items purchased in each transaction.'), Document(metadata={'column_name': 'item_id', 'data_type': 'Integer', 'definition': 'Unique ID of the purchased item', 'table_name': 'transaction_items'}, page_content='transaction_items.item_id: [Integer] Unique ID of the purchased item'), Document(metadata={'column_name': 'price', 'data_type': 'Float', 'definition': 'Original price of the item', 'table_name': 'items'}, page_content='items.price: [Float] Original price of the item')]


---

In [11]:
def build_agent(llm, db, retriever, agent_template_path: str = "template/sql_agent_v1.1.yml", max_iterations: int = 15):
    """Assemble a ZeroShotAgent-based SQL agent with table, schema, and query tools.

    Args:
        llm: LLM used both by the agent and by the SQL-generation tool.
        db: ``SQLDatabase`` the tools operate on.
        retriever: Schema-documentation retriever backing the ``schema_lookup`` tool.
        agent_template_path: YAML file with ``prefix``, ``suffix`` and ``instruction``
            keys. ``{table_list}`` in ``instruction`` is replaced by the table names.
        max_iterations: Upper bound on agent reasoning steps. The default of 15
            matches LangChain's ``AgentExecutor`` default.

    Returns:
        A verbose ``AgentExecutor`` that tolerates output-parsing errors.
    """
    with open(agent_template_path, "r", encoding="utf-8") as f:
        agent_template_dict = yaml.safe_load(f)

    # Tools exposed to the agent. The previous unused `table_list` computation
    # read the notebook-global `df_schema` and has been dropped.
    list_table_tool = build_list_table_tool(db)
    info_table_tool = build_info_table_tool(db)
    sql_generation_tool = build_sql_generation_tool(llm, db)
    schema_tool = build_schema_tool(retriever)
    tools = [list_table_tool, info_table_tool, sql_generation_tool, schema_tool]

    # Inject the current table names into the format instructions.
    format_instructions = agent_template_dict['instruction'].replace(
        "{table_list}", str(list_table_tool.invoke(""))
    )

    agent_prompt = ZeroShotAgent.create_prompt(
        tools=tools,
        prefix=agent_template_dict['prefix'],
        suffix=agent_template_dict['suffix'],
        format_instructions=format_instructions,
        input_variables=["input", "agent_scratchpad"],
    )

    # No `stop=` here. ZeroShotAgent has no such constructor field; it derives its
    # stop sequences from its observation prefix. Stopping on "\nFinal Answer:"
    # would also cut off the agent's final answer if it were ever honored.
    agent = ZeroShotAgent(
        llm_chain=LLMChain(llm=llm, prompt=agent_prompt),
        allowed_tools=[t.name for t in tools],
    )

    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        max_iterations=max_iterations,
        early_stopping_method="generate",
        handle_parsing_errors=True,
        verbose=True,
    )

In [12]:
agent = build_agent(
    llm,
    db,
    retriever,
    agent_template_path="template/sql_agent_v1.1.yml",
)

  agent = ZeroShotAgent(


In [13]:
print(agent.invoke(input="List all items purchased in transaction 3."))



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mI need to find out which items were purchased in transaction 3. I need to query the database. I'll start by looking at the tables available to see which ones seem relevant.
Action: list_tables
Action Input: [0m
Observation: [36;1m[1;3mcampaigns, items, members, transaction_items, transactions[0m
Thought:[32;1m[1;3mI see the `transaction_items` table and the `items` table. I think these two tables are relevant to answering the question. I should examine the schema of these tables to understand their columns and relationships.
Action: describe_table
Action Input: transaction_items[0m
Observation: [33;1m[1;3m
CREATE TABLE transaction_items (
	transaction_id INTEGER NOT NULL, 
	item_id INTEGER NOT NULL, 
	quantity INTEGER, 
	unit_price FLOAT, 
	PRIMARY KEY (transaction_id, item_id), 
	FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id), 
	FOREIGN KEY(item_id) REFERENCES items (item_id)
)

/*
3 rows from 