In [1]:
from langgraph.types import Literal,Command
from langgraph.graph import StateGraph,START,END,MessagesState
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.utilities import SQLDatabase
from langchain_deepseek import ChatDeepSeek
from dotenv import load_dotenv
from langchain.chains import create_sql_query_chain
from sqlalchemy import create_engine
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from typing_extensions import TypedDict,List,Optional,Annotated
from pydantic import Field,BaseModel
import asyncio
import os
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
class database(BaseModel):
    query : str = Field(description='Query to be executed')
    result : str = Field(description='Result of the query')

In [3]:
gemini_model = ChatGoogleGenerativeAI(model='gemini-2.0-flash',temperature=0.5)


In [4]:
db_user = 'root'
db_password = 'gokul'
db_host = 'localhost'
db_name = 'retail_sales_db'

engine = create_engine(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

db = SQLDatabase(engine=engine)

In [5]:
@tool
def execute_sql_query(query: str) -> str:
    """
    Executes the SQL query and returns the result.
    """
    try:
        result = db.run(query)
        return str(result)
    except Exception as e:
        return str(e)
    
tools = [execute_sql_query]

In [None]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=gemini_model)
tools = toolkit.get_tools()


In [18]:
tools

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 QuerySQLCheckerTool(description='Use this tool to 

In [None]:
def create_system_prompt():
    tool_names = ""
    tool_descriptions = ""
    for tool in tools:
        tool_names += f"{tool.name}, "
        tool_descriptions += f"{tool.description}, "

    sys_prompt = f""" You are a SQL Database manager, You can write sql queries and return results to the user,
    These are the available tools for you {tool_names} and their descriptions are {tool_descriptions}
    You can use these tools to get the information you need to answer the user's question.
    Always return the query along with the result if you use QueryTool"""
    
    return sys_prompt

In [20]:
tools

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000002AAE0025850>),
 QuerySQLCheckerTool(description='Use this tool to 

In [21]:
gemini_agent = create_react_agent(
    model=gemini_model,
    tools = tools,
    #response_format = database,
    prompt=create_system_prompt(),
)

In [22]:
response=gemini_agent.invoke({'messages':"hello can you tell me what's inside the db retail sales?"})

In [15]:
for msg in response['messages']:
    msg.pretty_print()


hello can you tell me what's inside the db retail sales?
Tool Calls:
  sql_db_list_tables (140f96d1-1a80-4a2d-b8e9-506f8afee703)
 Call ID: 140f96d1-1a80-4a2d-b8e9-506f8afee703
  Args:
Name: sql_db_list_tables

sales_tb, user_details

Okay, the database contains two tables: `sales_tb` and `user_details`. Which table would you like to know more about?


In [17]:
response['structured_response']

database(query='SELECT * FROM sales_tb LIMIT 10', result='10 rows from sales_tb')

In [23]:
response=gemini_agent.invoke({'messages':"Can you tell me which product is the costliest?"})

In [24]:
for msg in response['messages']:
    msg.pretty_print()


Can you tell me which product is the costliest?
Tool Calls:
  sql_db_query (fb7a3ea7-ff8f-4ff7-9dd0-0023a9b4e9d9)
 Call ID: fb7a3ea7-ff8f-4ff7-9dd0-0023a9b4e9d9
  Args:
    query: SELECT ProductName FROM Products ORDER BY Price DESC LIMIT 1
Name: sql_db_query

Error: (pymysql.err.ProgrammingError) (1146, "Table 'retail_sales_db.products' doesn't exist")
[SQL: SELECT ProductName FROM Products ORDER BY Price DESC LIMIT 1]
(Background on this error at: https://sqlalche.me/e/20/f405)

It seems I made a mistake in the table name. I should first check what tables are available.
Tool Calls:
  sql_db_list_tables (6f65e30c-d9be-4a97-ae33-2fafce09654d)
 Call ID: 6f65e30c-d9be-4a97-ae33-2fafce09654d
  Args:
Name: sql_db_list_tables

sales_tb, user_details
Tool Calls:
  sql_db_schema (40bad9ff-4409-491b-abea-2292dd0e1534)
 Call ID: 40bad9ff-4409-491b-abea-2292dd0e1534
  Args:
    table_names: sales_tb
Name: sql_db_schema


CREATE TABLE sales_tb (
	`TransactionID` INTEGER, 
	`Date` DATE, 
	`Custom

In [32]:
print(response['messages'][-1].content)

['The costliest product category is Clothing.', '```sql\nSELECT ProductCategory FROM sales_tb ORDER BY PriceperUnit DESC LIMIT 1\n```']


In [33]:
response

{'messages': [HumanMessage(content='Can you tell me which product is the costliest?', additional_kwargs={}, response_metadata={}, id='430195f0-a1a9-4913-a722-707ef365bbdf'),
  AIMessage(content='', additional_kwargs={'function_call': {'name': 'sql_db_query', 'arguments': '{"query": "SELECT ProductName FROM Products ORDER BY Price DESC LIMIT 1"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': []}, id='run-ba120137-8d44-4456-bfb3-4bf4d412c9de-0', tool_calls=[{'name': 'sql_db_query', 'args': {'query': 'SELECT ProductName FROM Products ORDER BY Price DESC LIMIT 1'}, 'id': 'fb7a3ea7-ff8f-4ff7-9dd0-0023a9b4e9d9', 'type': 'tool_call'}], usage_metadata={'input_tokens': 561, 'output_tokens': 18, 'total_tokens': 579, 'input_token_details': {'cache_read': 0}}),
  ToolMessage(content='Error: (pymysql.err.ProgrammingError) (1146, "Table \'retail_sales_db.products\' doesn\'t exist")\n[SQL: SELECT ProductName FROM Product

In [34]:
print(type(response['messages'][-1].content))

<class 'list'>
