# A simple chain for getting answers from a database

### Load the environment variables needed for the application

In [None]:
from dotenv import load_dotenv

# Load environment variables from the .env file so that later cells can read
# their configuration from the process environment.
# NOTE(review): presumably this is where the OpenAI API key used by ChatOpenAI
# comes from -- confirm the .env file defines it.
load_dotenv()

### Connect to the database to get access to the data

In [None]:
from langchain_community.utilities import SQLDatabase

# Define the database URI.
# NOTE(review): with the sqlite:/// form the file path is relative, so
# chinook.db is presumably resolved against the notebook's working directory.
db_uri = "sqlite:///chinook.db"

# Create a database object; `db` is reused by the SQL toolkit in a later cell
db = SQLDatabase.from_uri(db_uri)

# Smoke-test the connection: print the SQL dialect, list the table names and
# run a small sample query against the Artist table.
print(db.dialect)
print(db.get_usable_table_names())
# Last expression of the cell, so its return value is shown as the cell output
db.run("SELECT * FROM Artist LIMIT 10;")

### Instantiate the LLM

In [15]:
from langchain_openai import ChatOpenAI

# Chat model shared by the SQL toolkit and the tool-calling node defined in
# later cells.
llm = ChatOpenAI(model="gpt-4o-mini")

### Load the tools

In [16]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Load the SQL tools for the AI agent to use. The toolkit is built from the
# database connection and the LLM created in the previous cells.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
# `tools` is later bound to the LLM and wrapped in the graph's ToolNode
tools = toolkit.get_tools()

### Import the prebuilt `MessagesState` class to preserve all the required variables across the steps

In [17]:
from langgraph.graph import MessagesState

### Define the system prompt for SQL query generation

In [18]:
# System prompt for the SQL agent. It instructs the model to list the tables
# and inspect the relevant schemas first, to cap queries at 5 results unless
# the user asks for more, to select only the relevant columns, to retry on
# query errors, and never to issue data-modifying statements.
# NOTE: these are instructions to the model only; nothing in this cell
# enforces them on the database connection itself.
message_template = """
    You are an agent designed to interact with a SQL database.
    Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
    Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
    You can order the results by a relevant column to return the most interesting examples in the database.
    Never query for all the columns from a specific table, only ask for the relevant columns given the question.
    You have access to tools for interacting with the database.
    Only use the below tools. Only use the information returned by the below tools to construct your final answer.
    You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

    DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

    To start you should ALWAYS look at the tables in the database to see what you can query.
    Do NOT skip this step.
    Then you should query the schema of the most relevant tables.
    """

from langchain_core.messages import HumanMessage, SystemMessage

# System message wrapping the prompt above; the tool-calling node prepends it
# to the conversation on every LLM call.
sys_msg = SystemMessage(content=message_template)

### Bind tools to the LLM

In [19]:
# Bind the SQL toolkit's tools to the chat model; the tool-calling node
# invokes this bound model rather than the bare `llm`.
llm_with_tools = llm.bind_tools(tools)

In [20]:
# Graph node: a single LLM turn
def tool_calling_llm(state: MessagesState):
    """Run the tool-enabled LLM over the conversation held in ``state``.

    The system message is placed in front of the messages accumulated so far,
    so the model receives its instructions on every call.

    Args:
        state: Graph state; only its ``"messages"`` entry is read.

    Returns:
        A state update of the form ``{"messages": [reply]}`` containing the
        model's single response message.
    """
    conversation = [sys_msg] + state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

### Build a LangGraph graph (router)

In [21]:
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode, tools_condition
# from langgraph.checkpoint.memory import MemorySaver

# Optional conversation memory: uncomment the import above, the two lines
# below and the checkpointer variant of compile() further down.
# # Initialize in-memory checkpointer
# memory = MemorySaver()
# config = {"configurable": {"thread_id": "1"}}

# Build graph: one LLM node and one tool-execution node over MessagesState
builder = StateGraph(MessagesState)
builder.add_node("tool_calling_llm", tool_calling_llm)
# NOTE(review): the node name "tools" is presumably required, since it is the
# destination tools_condition routes to by default -- confirm before renaming.
builder.add_node("tools", ToolNode(tools))
# Every run starts with an LLM turn
builder.add_edge(START, "tool_calling_llm")
builder.add_conditional_edges(
    "tool_calling_llm",
    # If the latest message (result) from the assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from the assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
# After the tools run, hand control back to the LLM so it can read the tool
# output; this edge closes the LLM <-> tools loop.
builder.add_edge("tools", "tool_calling_llm")

# Compile graph (no checkpointer, so no state is kept between invocations)
graph = builder.compile()
# graph = builder.compile(checkpointer=memory)

### Visual representation of the graph

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid diagram and show the PNG inline.
# NOTE(review): draw_mermaid_png() may need network access to a remote
# rendering service by default -- confirm if running offline.
display(Image(graph.get_graph().draw_mermaid_png()))

### Invoke the graph by providing a question

In [None]:
# Run the graph once with a single question; the result is the final state.
messages = graph.invoke({"messages": "How many users are there?"})

# Print every message in the final state, in order.
for message in messages["messages"]:
    message.pretty_print()

# messages = graph.invoke({"messages": "Hello, my name is Julius"}, config=config)
# for m in messages['messages']:
#     m.pretty_print()