In [None]:
'''
AI SQL Agents 
    - Langchain Tutos: https://python.langchain.com/docs/tutorials/sql_qa/
    - Chinook db: https://github.com/lerocha/chinook-database/releases

    1. Convert question to SQL query: Model converts user input to a SQL query.
    2. Execute SQL query: Execute the query.
    3. Answer the question: Model responds to user input using the query results.
'''

In [30]:
# from langchain_ollama.llms import OllamaLLM
from langchain_core.prompts import ChatPromptTemplate
from typing_extensions import Annotated, TypedDict
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool

from dotenv import load_dotenv
import os

load_dotenv()

True

In [None]:
#  =================================== Database Configuration ===================================

In [4]:
db_user = os.getenv('POSTGRESQL_USER')
db_pass = os.getenv('POSTGRESQL_PASSWORD')
db_host = os.getenv('POSTGRESQL_HOST')
db_port = os.getenv('POSTGRESQL_PORT')

db_uri = f"postgresql://{db_user}:{db_pass}@{db_host}:{db_port}/chinook"
db_dialect = "postgresql"

In [5]:
# database instance
db = SQLDatabase.from_uri(db_uri)

In [6]:
# Model config
# model = OllamaLLM(model="llama3.2")

from langchain.chat_models import init_chat_model

model = init_chat_model("llama3.2", model_provider="ollama")

In [7]:
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

In [15]:
# Prompt template
system_message = """
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run,
then look at the results of the query and return the answer. Unless the user
specifies a specific number of examples they wish to obtain, always limit your
query to at most {top_k} results.

You can order the results by a relevant column to return the most interesting
examples in the database. Never query for all the columns from a specific table,
only ask for the relevant columns given the question.

You MUST double check your query before executing it. If you get an error while
executing a query, rewrite the query and try again. Remember the name of the tables are in singular.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the
database. This is extremely important.

To start you should ALWAYS look at the tables in the database to see what you
can query. Do NOT skip this step.

Then you should query the schema of the most relevant tables.
""".format(
    dialect=db_dialect,
    top_k=5,
)

In [None]:
# =================================== Convert question to SQL query ===================================

In [16]:
# question of the human
user_prompt = "Question: {input}"

# query template
query_prompt_template = ChatPromptTemplate(
    [("system", system_message), ("user", user_prompt)]
)

In [None]:
for message in query_prompt_template.messages:
    message.pretty_print()

In [17]:
# converting to query
class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]

In [18]:
# function to write the query that will by execute
def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = model.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

In [19]:
write_query({"question": "How many Employee are there?"})

{'query': 'SELECT COUNT(*) FROM employee'}

In [None]:
# =================================== Execute query ===================================

In [20]:
# exectuted the query obtained
def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

In [28]:
question = "How many Employee are there?"
query = write_query({"question": question})
print(f"query: {query}")
result = execute_query({"query": query})
print(f"result: {result}")

query: {'query': 'SELECT COUNT(*) FROM employee'}
result: {'result': '[(8,)]'}


In [None]:
# =================================== Generate answer =================================== 

In [26]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = model.invoke(prompt)
    return {"answer": response.content}

In [29]:
question = "How many Employee are there?"
query = write_query({"question": question})
#print(f"query: {query}")
result = execute_query({"query": query})
#print(f"result: {result}")
answer = generate_answer({"question": question, "query": query,"result": result})
print(f"answer: {answer}")

answer: {'answer': 'Based on the given information, the answer to the user question "How many Employees are there?" is 8.'}
