# A simple chain for getting answers from a database

### Load the environment variables needed for the application

In [None]:
from dotenv import load_dotenv

# Copy key=value pairs from a local .env file into the process environment.
# override=False (the explicit form of the default) keeps any variable that is
# already set. NOTE(review): presumably this supplies credentials such as the
# OpenAI API key used by ChatOpenAI later on — confirm.
load_dotenv(override=False)

### Connect to the database to access its data

In [None]:
from langchain_community.utilities import SQLDatabase

# SQLAlchemy-style URI pointing at the local chinook.db SQLite file.
db_uri = "sqlite:///chinook.db"

# LangChain wrapper that the query-writing and query-executing steps use.
db = SQLDatabase.from_uri(db_uri)

# Smoke-test the connection: show the SQL dialect, then the tables the wrapper
# exposes, and finally run a small sample query (last expression = cell output).
dialect_name = db.dialect
print(dialect_name)
table_names = db.get_usable_table_names()
print(table_names)
db.run("SELECT * FROM Artist LIMIT 10;")

### Instantiate the LLM

In [29]:
from langchain_openai import ChatOpenAI

# Chat model used both to write the SQL (write_query) and to phrase the final
# answer (generate_answer).
LLM_MODEL_NAME = "gpt-4o-mini"
llm = ChatOpenAI(model=LLM_MODEL_NAME)

### Define a State class to carry all the required variables across the steps

In [30]:
from typing_extensions import TypedDict


# Graph state shared by all nodes. The caller seeds "question"; each node then
# returns a partial update: write_query sets "query", execute_query sets
# "result" and generate_answer sets "answer".
State = TypedDict(
    "State",
    {
        "question": str,  # natural-language question from the user
        "query": str,  # SQL produced by write_query
        "result": str,  # output of running the SQL via execute_query
        "answer": str,  # final natural-language answer from generate_answer
    },
)

### Pull the prebuilt prompt template for SQL query generation

In [None]:
from langchain import hub

# Prompt template from the LangChain Hub. write_query fills it with the
# "dialect", "top_k", "table_info" and "input" variables.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Sanity-check the template shape before printing its only message. Use an
# explicit check rather than `assert`, which is stripped under `python -O`,
# so that a changed upstream prompt still fails loudly.
message_count = len(query_prompt_template.messages)
if message_count != 1:
    raise ValueError(
        "Expected the hub prompt to contain exactly 1 message, "
        f"got {message_count}."
    )
query_prompt_template.messages[0].pretty_print()

### Create a function for SQL query generation

In [32]:
from typing_extensions import Annotated


# Schema for the structured output requested from the LLM in write_query
# (passed to llm.with_structured_output there).
# NOTE(review): the class docstring and the Annotated description below are
# likely forwarded to the model as part of the output schema — treat them as
# prompt text and confirm before rewording.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # The SQL statement the model must return; the `...` in Annotated presumably
    # marks the field as required — confirm against LangChain's schema handling.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Ask the LLM to translate ``state["question"]`` into SQL.

    The hub prompt is filled with the database dialect, a row limit of 10,
    the table descriptions from ``db`` and the user's question. The model is
    constrained to the ``QueryOutput`` schema.

    Returns:
        A partial state update ``{"query": <generated SQL>}``.
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    messages = query_prompt_template.invoke(prompt_inputs)
    generated = llm.with_structured_output(QueryOutput).invoke(messages)
    return {"query": generated["query"]}

In [None]:
# Smoke-test SQL generation with a sample question; the returned partial state
# (containing the generated query) is the cell output.
sample_question_state = {"question": "How many Employees are there?"}
write_query(sample_question_state)

### Create a function for query execution

In [34]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Run the SQL in ``state["query"]`` against ``db``.

    A QuerySQLDatabaseTool bound to ``db`` is created on each call and
    invoked with the query text.

    Returns:
        A partial state update ``{"result": <tool output>}``.
    """
    sql_tool = QuerySQLDatabaseTool(db=db)
    sql_text = state["query"]
    tool_output = sql_tool.invoke(sql_text)
    return {"result": tool_output}

In [None]:
# Smoke-test query execution with a hand-written COUNT query; the returned
# partial state (containing the tool output) is the cell output.
sample_query_state = {"query": "SELECT COUNT(*) AS EmployeeCount FROM Employee;"}
execute_query(sample_query_state)

### Create a function for answer generation

In [36]:
def generate_answer(state: State):
    """Have the LLM answer the user's question from the SQL and its result.

    The prompt contains an instruction header followed by the question, the
    SQL query and the SQL result, one per line.

    Returns:
        A partial state update ``{"answer": <model reply text>}``.
    """
    header = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
    )
    details = "\n".join(
        (
            f'Question: {state["question"]}',
            f'SQL Query: {state["query"]}',
            f'SQL Result: {state["result"]}',
        )
    )
    reply = llm.invoke(header + details)
    return {"answer": reply.content}

### Build a LangGraph graph (chain) by adding all the steps in sequence

In [37]:
from langgraph.graph import START, StateGraph

# Wire the three nodes into a linear pipeline over the shared State and make
# write_query the entry point.
graph_builder = StateGraph(State)
graph_builder.add_sequence([write_query, execute_query, generate_answer])
graph_builder.add_edge(START, "write_query")

# Compiled graph without persistence (no checkpointer, no interrupts).
graph = graph_builder.compile()

### Visual representation of the graph

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG and show it inline.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

### Invoke the graph by providing a question

In [None]:
# Run the whole pipeline for one question and print each streamed update.
first_question = {"question": "List the tables in the database"}
for node_update in graph.stream(first_question, stream_mode="updates"):
    print(node_update)

### Add human-in-the-loop functionality to review the SQL query and approve it before it runs

In [16]:
from langgraph.checkpoint.memory import MemorySaver

# With persistence enabled, every run needs a thread ID; reusing this config
# lets an interrupted run on the same thread be resumed after review.
config = {"configurable": {"thread_id": "1"}}

# Recompile the same builder with a MemorySaver checkpointer and a breakpoint
# before the "execute_query" node, so the generated SQL can be reviewed first.
memory = MemorySaver()
graph = graph_builder.compile(
    checkpointer=memory,
    interrupt_before=["execute_query"],
)

In [None]:
# Re-render the graph now that it has been recompiled with the checkpointer
# and the interrupt before "execute_query".
hitl_graph_png = graph.get_graph().draw_mermaid_png()
display(Image(hitl_graph_png))

In [None]:
# Start a run on the configured thread; because of interrupt_before it is
# expected to stop before "execute_query" so the SQL can be reviewed.
review_input = {"question": "What is my name?"}
for step in graph.stream(review_input, config, stream_mode="updates"):
    print(step)

In [None]:
# Inspect the saved snapshot for this thread; `.next` is expected to name the
# node(s) that will run on resume (here presumably "execute_query").
state = graph.get_state(config)
pending_nodes = state.next
pending_nodes

In [None]:
# After approving the SQL, resume the paused run: passing None as the input
# continues the same thread instead of starting a new run.
resume_stream = graph.stream(None, config, stream_mode="updates")
for step in resume_stream:
    print(step)

In [None]:
# Start a fresh run on the same thread; it pauses before "execute_query"
# because the graph was compiled with interrupt_before.
for step in graph.stream(
    {"question": "What is my name?"},
    config,
    stream_mode="updates",
):
    print(step)

try:
    user_approval = input("Do you want to go to execute query? (yes/no): ")
except Exception as err:
    # Deliberate best-effort: input() can fail when no interactive stdin is
    # available. Treat that as "not approved", but say why instead of
    # pretending the user declined.
    print(f"Could not read approval ({type(err).__name__}); not executing the query.")
    user_approval = "no"

# Accept "yes"/"y" case-insensitively, ignoring surrounding whitespace
# (e.g. a stray space typed before pressing Enter).
if user_approval.strip().lower() in {"yes", "y"}:
    # Passing None resumes the interrupted run on this thread.
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")