In [1]:
import sqlite3

In [2]:
# Open a connection to the SQLite database file "mydb.db" (path is relative to the working directory).
connection = sqlite3.connect("mydb.db")

In [3]:
# Show the connection object as the cell output.
connection

<sqlite3.Connection at 0x231dd07a440>

In [4]:
# DDL for the employees table. IF NOT EXISTS keeps the statement safe to re-run
# against an existing database file. hire_date is stored as TEXT.
table_creation_query="""
CREATE TABLE IF NOT EXISTS employees (
    emp_id INTEGER PRIMARY KEY,
    first_name TEXT NOT NULL,
    last_name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    hire_date TEXT NOT NULL,
    salary REAL NOT NULL
);
"""

In [5]:
# DDL for the customers table. customer_id is an AUTOINCREMENT primary key and
# email must be unique; phone is the only nullable column.
table_creation_query2="""
CREATE TABLE IF NOT EXISTS customers (
    customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
    first_name TEXT NOT NULL,
    last_name TEXT NOT NULL,
    email TEXT UNIQUE NOT NULL,
    phone TEXT
);
"""

In [6]:
# DDL for the orders table. Each order references customers.customer_id through
# a foreign key; order_date is stored as TEXT.
table_creation_query3="""
CREATE TABLE IF NOT EXISTS orders (
    order_id INTEGER PRIMARY KEY AUTOINCREMENT,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    amount REAL NOT NULL,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
);

"""

In [7]:
# Cursor used by the cells below to run the DDL, the inserts and the check query.
cursor=connection.cursor()

In [8]:

# Create the three tables (employees, customers, orders) from the DDL strings defined above.
cursor.execute(table_creation_query)
cursor.execute(table_creation_query2)
cursor.execute(table_creation_query3)

<sqlite3.Cursor at 0x231dcd3c9c0>

In [9]:
# Parameterized INSERT statements: one "?" placeholder per listed column, filled
# in by executemany() from the seed tuples.
#
# INSERT OR IGNORE is used because the seed rows carry explicit primary keys and
# the database file persists between kernel sessions: with a plain INSERT, running
# the notebook a second time raises sqlite3.IntegrityError on the duplicate keys.
# With OR IGNORE, rows that would violate a PRIMARY KEY / UNIQUE constraint are
# skipped and the rows already stored are left untouched.
insert_query = """
INSERT OR IGNORE INTO employees (emp_id, first_name, last_name, email, hire_date, salary)
VALUES (?, ?, ?, ?, ?, ?);
"""

insert_query_customers = """
INSERT OR IGNORE INTO customers (customer_id, first_name, last_name, email, phone)
VALUES (?, ?, ?, ?, ?);
"""

insert_query_orders = """
INSERT OR IGNORE INTO orders (order_id, customer_id, order_date, amount)
VALUES (?, ?, ?, ?);
"""

In [10]:
# Seed rows, one tuple per record. Values are listed in the same column order as
# the matching INSERT statement.

# (emp_id, first_name, last_name, email, hire_date, salary)
employee_data = [
    (1, "Sunny", "Savita", "sunny.sv@abc.com", "2023-06-01", 50_000.0),
    (2, "Arhun", "Meheta", "arhun.m@gmail.com", "2022-04-15", 60_000.0),
    (3, "Alice", "Johnson", "alice.johnson@jpg.com", "2021-09-30", 55_000.0),
    (4, "Bob", "Brown", "bob.brown@uio.com", "2020-01-20", 45_000.0),
]

# (customer_id, first_name, last_name, email, phone)
customers_data = [
    (1, "John", "Doe", "john.doe@example.com", "1234567890"),
    (2, "Jane", "Smith", "jane.smith@example.com", "9876543210"),
    (3, "Emily", "Davis", "emily.davis@example.com", "4567891230"),
    (4, "Michael", "Brown", "michael.brown@example.com", "7894561230"),
]

# (order_id, customer_id, order_date, amount)
orders_data = [
    (1, 1, "2023-12-01", 250.75),
    (2, 2, "2023-11-20", 150.5),
    (3, 3, "2023-11-25", 300.0),
    (4, 4, "2023-12-02", 450.0),
]

In [11]:
# Bulk-insert the seed rows, one executemany() call per table.
# Nothing is persisted until connection.commit() runs.
cursor.executemany(insert_query,employee_data)
cursor.executemany(insert_query_customers,customers_data)
cursor.executemany(insert_query_orders,orders_data)

<sqlite3.Cursor at 0x231dcd3c9c0>

In [12]:
# Persist the inserted rows to the database file.
connection.commit()

In [None]:
# Sanity check: read back every row of the orders table (fetched in the next cell).
cursor.execute("select * from orders;")

<sqlite3.Cursor at 0x231dcd3c9c0>

In [14]:
# Print each remaining result row of the last query on its own line.
for record in cursor:
    print(record)

(1, 1, '2023-12-01', 250.75)
(2, 2, '2023-11-20', 150.5)
(3, 3, '2023-11-25', 300.0)
(4, 4, '2023-12-02', 450.0)


In [16]:
from langchain_community.utilities.sql_database import SQLDatabase

In [19]:
# Wrap the same SQLite file in LangChain's SQLDatabase helper, which the toolkit
# and the custom query tool below use to reach the database.
db = SQLDatabase.from_uri("sqlite:///mydb.db")
db

<langchain_community.utilities.sql_database.SQLDatabase at 0x231df166020>

In [20]:
# Show the SQL dialect reported by the wrapper.
db.dialect

'sqlite'

In [21]:
# List the tables the wrapper can see; the three tables created above are expected.
db.get_usable_table_names()

['customers', 'employees', 'orders']

In [31]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into the process environment.
load_dotenv()

# The Groq API key must come from the environment (for example a git-ignored
# .env file containing GROQ_API_KEY=...). Never write the key into the notebook:
# anything committed here is readable by everyone with access to the file.
if not os.environ.get("GROQ_API_KEY"):
    raise RuntimeError(
        "GROQ_API_KEY is not set. Add it to your .env file or export it "
        "in your shell before running this notebook."
    )

In [32]:
from langchain_groq import ChatGroq
# Chat model used for every LLM call in this notebook.
llm = ChatGroq(model="llama3-70b-8192")

In [33]:
# Smoke test: send a plain prompt and show only the text of the reply.
llm.invoke("Hello how are you").content

"Hello! I'm just a language model, I don't have feelings or emotions like humans do, but I'm functioning properly and ready to assist you with any questions or tasks you may have! How can I help you today?"

In [34]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

In [36]:
# Build the SQL toolkit on top of the database wrapper and the chat model, then
# get its tools. The following cells pick individual tools out of this list by name.
toolkit=SQLDatabaseToolkit(db=db, llm=llm)
tools=toolkit.get_tools()
tools

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000231DF166020>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000231DF166020>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000231DF166020>),
 QuerySQLCheckerTool(description='Use this tool to 

In [37]:
# Print the name of every toolkit tool, one per line.
for sql_tool in tools:
    print(sql_tool.name)

sql_db_query
sql_db_schema
sql_db_list_tables
sql_db_query_checker


In [39]:
# Pick the table-listing tool out of the toolkit by name.
# Falls back to None when no tool with that name exists.
list_tables_tool = next(
    (candidate for candidate in tools if candidate.name == "sql_db_list_tables"),
    None,
)
list_tables_tool

ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000231DF166020>)

In [41]:
# Pick the schema-inspection tool out of the toolkit by name.
# Falls back to None when no tool with that name exists.
get_schema_tool = next(
    (candidate for candidate in tools if candidate.name == "sql_db_schema"),
    None,
)
get_schema_tool

InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000231DF166020>)

In [42]:
# LLM variant whose only bound tool is the schema tool (used by llm_get_schema).
llm_to_get_schema=llm.bind_tools([get_schema_tool])
# Sanity check: invoke the table-listing tool directly with an empty input.
list_tables_tool.invoke("")

'customers, employees, orders'

In [43]:
# Sanity check: print what the schema tool returns for the customers table.
print(get_schema_tool.invoke("customers"))


CREATE TABLE customers (
	customer_id INTEGER, 
	first_name TEXT NOT NULL, 
	last_name TEXT NOT NULL, 
	email TEXT NOT NULL, 
	phone TEXT, 
	PRIMARY KEY (customer_id), 
	UNIQUE (email)
)

/*
3 rows from customers table:
customer_id	first_name	last_name	email	phone
1	John	Doe	john.doe@example.com	1234567890
2	Jane	Smith	jane.smith@example.com	9876543210
3	Emily	Davis	emily.davis@example.com	4567891230
*/


In [44]:
from langchain_core.tools import tool
@tool
def query_to_database(query:str)->str:
    """
    Execute a SQL query against the database and return the result.
    If the query is invalid or returns no result, an error message will be returned.
    In case of an error, the user is advised to rewrite the query and try again.
    """
    # NOTE(review): the docstring is kept word for word because @tool presumably
    # exposes it to the model as the tool description — confirm before rewording.
    outcome = db.run_no_throw(query)
    # A falsy outcome (e.g. an empty string) is replaced by the generic retry
    # message; anything else is handed back unchanged.
    return outcome if outcome else "Error: Query failed. Please rewrite your query and try again."

In [None]:
# Call the custom tool directly, without the LLM, to check it returns the employee rows.
query_to_database.invoke("SELECT * FROM Employees;")

"[(1, 'Sunny', 'Savita', 'sunny.sv@abc.com', '2023-06-01', 50000.0), (2, 'Arhun', 'Meheta', 'arhun.m@gmail.com', '2022-04-15', 60000.0), (3, 'Alice', 'Johnson', 'alice.johnson@jpg.com', '2021-09-30', 55000.0), (4, 'Bob', 'Brown', 'bob.brown@uio.com', '2020-01-20', 45000.0)]"

In [None]:
# LLM variant whose only bound tool is query_to_database (used by the query-checking chain).
llm_with_tools=llm.bind_tools([query_to_database])

In [None]:
# Send raw SQL as the prompt and inspect the reply, including any tool call the model makes.
llm_with_tools.invoke("SELECT * FROM Employees;")

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'n61rptesd', 'function': {'arguments': '{"query":"SELECT * FROM Employees;"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 48, 'prompt_tokens': 924, 'total_tokens': 972, 'completion_time': 0.182048421, 'prompt_time': 0.03453552, 'queue_time': 0.28326916, 'total_time': 0.216583941}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--8a2b97a7-4d2d-4e6a-80bb-b05790d984a6-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': 'SELECT * FROM Employees;'}, 'id': 'n61rptesd', 'type': 'tool_call'}], usage_metadata={'input_tokens': 924, 'output_tokens': 48, 'total_tokens': 972})

In [48]:
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): a blanket "ignore" also hides deprecation warnings from the
# libraries used below — consider filtering by category or module instead.
warnings.filterwarnings("ignore")

In [53]:
from typing import Annotated, Literal
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
# from langgraph.prebuilt import ToolNode

In [58]:
from langchain_core.prompts import ChatPromptTemplate

# System prompt for the query-checking step: the model reviews a SQL query for
# the listed mistakes and either rewrites it or repeats it unchanged.
query_check_system = """You are a SQL expert. Carefully review the SQL query for common mistakes, including:

Issues with NULL handling (e.g., NOT IN with NULLs)
Improper use of UNION instead of UNION ALL
Incorrect use of BETWEEN for exclusive ranges
Data type mismatches or incorrect casting
Quoting identifiers improperly
Incorrect number of arguments in functions
Errors in JOIN conditions

If you find any mistakes, rewrite the query to fix them. If it's correct, reproduce it as is."""

# Prompt = system instructions followed by whatever messages the caller passes in.
query_check_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_check_system),
        ("placeholder", "{messages}"),
    ]
)

# The checker runs on the LLM that has query_to_database bound as its tool.
check_generated_query = query_check_prompt | llm_with_tools

In [59]:
# Checker demo: the table name is misspelled ("Emploees"), presumably on purpose,
# to see whether the model corrects it.
check_generated_query.invoke({"messages": [("user", "SELECT * FROM Emploees LIMIT 5;")]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'sr1jk7hwz', 'function': {'arguments': '{"query":"SELECT * FROM Employees LIMIT 5;"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 73, 'prompt_tokens': 1027, 'total_tokens': 1100, 'completion_time': 0.255508558, 'prompt_time': 0.041625691, 'queue_time': 0.271002229, 'total_time': 0.297134249}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--3189488f-48a9-49a5-b8f3-59ce3ae99719-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': 'SELECT * FROM Employees LIMIT 5;'}, 'id': 'sr1jk7hwz', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1027, 'output_tokens': 73, 'total_tokens': 1100})

In [60]:
# Checker demo: invalid column list ("+++") and misspelled keyword ("LIMITs").
check_generated_query.invoke({"messages": [("user", "SELECT +++ FROM Employees LIMITs 5;")]})

AIMessage(content="It looks like there's a mistake in the SQL query. \n\nThe correct query should be:\n\n```\n", additional_kwargs={'tool_calls': [{'id': 'mtsy5dwfk', 'function': {'arguments': '{"query":"SELECT * FROM Employees LIMIT 5;"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 95, 'prompt_tokens': 1026, 'total_tokens': 1121, 'completion_time': 0.361963992, 'prompt_time': 0.042036105, 'queue_time': 0.31590454, 'total_time': 0.404000097}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--98b631c1-aacb-43f7-b332-31c4a69c61d0-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': 'SELECT * FROM Employees LIMIT 5;'}, 'id': 'mtsy5dwfk', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1026, 'output_tokens': 95, 'total_tokens': 1121})

In [61]:
# Checker demo: "everything" in place of a column list, plus the misspelled keyword "LIMITs".
check_generated_query.invoke({"messages": [("user", "SELECT everything FROM Employees LIMITs 5;")]})

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'cykz4rpmy', 'function': {'arguments': '{"query":"SELECT * FROM Employees LIMIT 5;"}', 'name': 'query_to_database'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 73, 'prompt_tokens': 1026, 'total_tokens': 1099, 'completion_time': 0.242856558, 'prompt_time': 0.041590402, 'queue_time': 0.270680598, 'total_time': 0.28444696}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--f126ea1d-8f83-49d1-af1f-f72983257fb6-0', tool_calls=[{'name': 'query_to_database', 'args': {'query': 'SELECT * FROM Employees LIMIT 5;'}, 'id': 'cykz4rpmy', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1026, 'output_tokens': 73, 'total_tokens': 1099})

In [62]:
# Schema bound to the LLM as a tool. generation_query accepts a tool call with
# this name and answers any other tool call with an error message.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""
    # Answer text; the last notebook cell reads it from the tool-call args.
    final_answer: str = Field(..., description="The final answer to the user")
    
# LLM variant whose only bound tool is SubmitFinalAnswer.
llm_with_final_answer=llm.bind_tools([SubmitFinalAnswer])

In [63]:
# System prompt for the query-generation step. Every line of this string is sent
# to the model, so it must contain only instructions meant for the model:
# notes about editing the prompt itself do not belong in it.
query_gen_system_prompt = """You are a SQL expert with a strong attention to detail.Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

1. DO NOT call any tool besides SubmitFinalAnswer to submit the final answer. When generating the query:

2. Output the SQL query that answers the input question without a tool call.

3. Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.

4. You can order the results by a relevant column to return the most interesting examples in the database.

5. Never query for all the columns from a specific table, only ask for the relevant columns given the question.

6. If you get an error while executing a query, rewrite the query and try again.

7. If you get an empty result set, you should try to rewrite the query to get a non-empty result set.

8. NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

9. If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

10. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""

# Prompt = query-generation system instructions followed by the conversation so far.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_gen_system_prompt),
        ("placeholder", "{messages}"),
    ]
)

# The generator runs on the LLM that only has SubmitFinalAnswer bound.
query_generator = query_gen_prompt | llm_with_final_answer

In [64]:
# Show the assembled prompt template as the cell output.
query_gen_prompt

ChatPromptTemplate(input_variables=[], optional_variables=['messages'], input_types={'messages': list[typing.Annotated[typing.Union[typing.Annotated[langchain_core.messages.ai.AIMessage, Tag(tag='ai')], typing.Annotated[langchain_core.messages.human.HumanMessage, Tag(tag='human')], typing.Annotated[langchain_core.messages.chat.ChatMessage, Tag(tag='chat')], typing.Annotated[langchain_core.messages.system.SystemMessage, Tag(tag='system')], typing.Annotated[langchain_core.messages.function.FunctionMessage, Tag(tag='function')], typing.Annotated[langchain_core.messages.tool.ToolMessage, Tag(tag='tool')], typing.Annotated[langchain_core.messages.ai.AIMessageChunk, Tag(tag='AIMessageChunk')], typing.Annotated[langchain_core.messages.human.HumanMessageChunk, Tag(tag='HumanMessageChunk')], typing.Annotated[langchain_core.messages.chat.ChatMessageChunk, Tag(tag='ChatMessageChunk')], typing.Annotated[langchain_core.messages.system.SystemMessageChunk, Tag(tag='SystemMessageChunk')], typing.Annot

In [65]:
# Generator demo. The parentheses around the question do not create a tuple, so
# the message is passed as a bare string rather than a (role, text) pair.
query_generator.invoke({"messages": [("can you fetch the data from employee table?")]})

AIMessage(content="Here is the SQL query to fetch data from the employee table:\n\n```\nSELECT * FROM employee LIMIT 5;\n```\n\nPlease provide the result of this query, and I'll be happy to assist you further.", additional_kwargs={}, response_metadata={'token_usage': {'completion_tokens': 44, 'prompt_tokens': 1201, 'total_tokens': 1245, 'completion_time': 0.201639696, 'prompt_time': 0.046957782, 'queue_time': 0.274164628, 'total_time': 0.248597478}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'stop', 'logprobs': None}, id='run--e555a370-bc71-4007-8ee1-d7a3d6cb0ea9-0', usage_metadata={'input_tokens': 1201, 'output_tokens': 44, 'total_tokens': 1245})

In [66]:
class State(TypedDict):
    """State passed between the nodes of the workflow graph."""
    # Message list, annotated with the add_messages function from langgraph.graph.message.
    messages: Annotated[list[AnyMessage], add_messages]

In [67]:
def first_tool_call(state:State)->dict[str,list[AIMessage]]:
    """Build the first message of the run: a hard-coded call to the table-listing tool.

    The incoming state is not read. The returned dict holds one empty-content
    AIMessage whose single tool call targets "sql_db_list_tables" with no
    arguments and a fixed call id.
    """
    list_tables_call = {"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}
    seed_message = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [seed_message]}

In [68]:
def handle_tool_error(state:State):
    """Turn a failed tool execution into error messages for the model.

    Reads the exception stored under the "error" key of the state and returns
    one ToolMessage per tool call found on the last message, each carrying the
    error text and the id of the call it answers.
    """
    error = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

In [69]:
def create_node_from_tool_with_fallback(tools:list)-> RunnableWithFallbacks[Any, dict]:
    """Wrap tools in a ToolNode that reports failures instead of raising.

    Args:
        tools: The tools the node should be able to execute.

    Returns:
        The ToolNode with handle_tool_error attached as its fallback. The caught
        exception is passed to the fallback under the "error" key.
    """
    # Imported here because the notebook-level import of ToolNode is commented
    # out; without it this function fails with a NameError.
    from langgraph.prebuilt import ToolNode

    return ToolNode(tools).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

In [70]:
def check_the_given_query(state:State):
    """Run the query-checking chain on the most recent message only.

    Returns a dict whose "messages" list holds the checker's reply.
    """
    latest_message = state["messages"][-1]
    checked = check_generated_query.invoke({"messages": [latest_message]})
    return {"messages": [checked]}

In [71]:
def generation_query(state:State):
    """Ask the query generator for its next message and police its tool calls.

    Returns the model's message, followed by one error ToolMessage for every
    tool call that targets anything other than SubmitFinalAnswer.
    """
    message = query_generator.invoke(state)

    # The model sometimes calls a tool it was not given. Answer each such call
    # with an error so it can correct itself on the next turn.
    wrong_tool_errors = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=tc["id"],
        )
        for tc in (message.tool_calls or [])
        if tc["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [message] + wrong_tool_errors}

In [72]:
def should_continue(state:State):
    """Choose the step that follows query generation, based on the last message.

    Returns:
        END if the last message carries tool calls, "query_gen" if its content
        starts with "Error:", and "correct_query" otherwise.
    """
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return END
    return "query_gen" if last_message.content.startswith("Error:") else "correct_query"

In [None]:
# list_tables=create_node_from_tool_with_fallback([list_tables_tool])

In [None]:
# get_schema=create_node_from_tool_with_fallback([get_schema_tool])

In [None]:
# get_schema_tool.invoke(AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_s40m', 'function': {'arguments': '{"table_names":"orders"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 35, 'prompt_tokens': 1059, 'total_tokens': 1094, 'completion_time': 0.104650804, 'prompt_time': 0.056007903, 'queue_time': 0.018336399000000003, 'total_time': 0.160658707}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_753a4aecf6', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6b65ca0f-8c99-4918-9742-f77db5402121-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'orders'}, 'id': 'call_s40m', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1059, 'output_tokens': 35, 'total_tokens': 1094}))

In [None]:
# get_schema.invoke(AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_s40m', 'function': {'arguments': '{"table_names":"orders"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 35, 'prompt_tokens': 1059, 'total_tokens': 1094, 'completion_time': 0.104650804, 'prompt_time': 0.056007903, 'queue_time': 0.018336399000000003, 'total_time': 0.160658707}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_753a4aecf6', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6b65ca0f-8c99-4918-9742-f77db5402121-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'orders'}, 'id': 'call_s40m', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1059, 'output_tokens': 35, 'total_tokens': 1094}))

In [None]:

# AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_s40m', 'function': {'arguments': '{"table_names":"orders"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 35, 'prompt_tokens': 1059, 'total_tokens': 1094, 'completion_time': 0.104650804, 'prompt_time': 0.056007903, 'queue_time': 0.018336399000000003, 'total_time': 0.160658707}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_753a4aecf6', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-6b65ca0f-8c99-4918-9742-f77db5402121-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'orders'}, 'id': 'call_s40m', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1059, 'output_tokens': 35, 'total_tokens': 1094})

In [None]:
# query_database=create_node_from_tool_with_fallback([query_to_database])

In [81]:
def llm_get_schema(state:State):
    """Let the schema-bound LLM decide which table schemas to fetch.

    Sends the conversation so far to llm_to_get_schema and returns its reply
    under the "messages" key. The reply must be returned, not just built,
    otherwise the node yields None and the reply never reaches the state.
    """
    response = llm_to_get_schema.invoke(state["messages"])
    return {"messages": [response]}

In [None]:
# [HumanMessage(content='how many order are there which is more than 300 rupees?', additional_kwargs={}, response_metadata={}, id='0571873a-df8f-4673-942c-ea73a052f7d6'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='bcaf622f-b11e-4502-b5ef-0b784bc5bdf1', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}]), ToolMessage(content='customers, employees, orders', name='sql_db_list_tables', id='a48f6154-23de-47f0-887c-2888405ced08', tool_call_id='tool_abcd123')]

In [83]:
from langchain_core.messages import BaseMessage, HumanMessage

In [84]:
# Manual check of the schema step: replay a recorded conversation (user question,
# the seeded list-tables tool call and its result) to see which schema the model requests.
llm_to_get_schema.invoke([HumanMessage(content='how many order are there which is more than 300 rupees?', additional_kwargs={}, response_metadata={}, id='0571873a-df8f-4673-942c-ea73a052f7d6'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='bcaf622f-b11e-4502-b5ef-0b784bc5bdf1', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}]), ToolMessage(content='customers, employees, orders', name='sql_db_list_tables', id='a48f6154-23de-47f0-887c-2888405ced08', tool_call_id='tool_abcd123')])

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '5cdnwbrf4', 'function': {'arguments': '{"table_names":"orders"}', 'name': 'sql_db_schema'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 35, 'prompt_tokens': 1036, 'total_tokens': 1071, 'completion_time': 0.194720686, 'prompt_time': 0.041899651, 'queue_time': 0.271709809, 'total_time': 0.236620337}, 'model_name': 'llama3-70b-8192', 'system_fingerprint': 'fp_bf16903a67', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run--37a3a4c8-b2f3-46a5-a89b-935bff29f8d3-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'orders'}, 'id': '5cdnwbrf4', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1036, 'output_tokens': 35, 'total_tokens': 1071})

In [None]:
# Tool-executing nodes. These are defined here because the graph below refers to
# them by name and nothing else in the notebook defines them.
list_tables = create_node_from_tool_with_fallback([list_tables_tool])
get_schema = create_node_from_tool_with_fallback([get_schema_tool])
query_database = create_node_from_tool_with_fallback([query_to_database])

# Register every step of the pipeline as a node on a graph that carries State.
workflow = StateGraph(State)
workflow.add_node("first_tool_call",first_tool_call)
workflow.add_node("list_tables_tool", list_tables)
workflow.add_node("get_schema_tool", get_schema)
workflow.add_node("model_get_schema", llm_get_schema)
workflow.add_node("query_gen", generation_query)
workflow.add_node("correct_query", check_the_given_query)
workflow.add_node("execute_query", query_database)

In [None]:
# Fixed part of the pipeline: seed the list-tables call, run it, let the model
# choose tables, fetch their schema, then generate a query.
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")

# should_continue can return END, "query_gen" or "correct_query". Every value it
# can return needs an entry here; "query_gen" is the retry path taken after a
# wrong-tool error message.
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
    {
        END: END,
        "query_gen": "query_gen",
        "correct_query": "correct_query",
    },
)

# A generated query is checked, executed, and its result fed back to the generator.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

In [None]:
# Compile the graph definition into a runnable application.
app=workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG (using the API draw method) and show it inline.
graph_png = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))

In [None]:
# Input for the graph: a single user message holding the natural-language question.
query={"messages": [("user", "how many order are there which is more than 300 rupees?")]}

In [None]:
# Run the whole workflow on the question.
response=app.invoke(query)

In [None]:
# Read the answer from the first tool call on the last message.
# Assumes that call is SubmitFinalAnswer, whose args carry "final_answer".
response["messages"][-1].tool_calls[0]["args"]["final_answer"]