### Chatbot trial with LangChain

In [11]:
%%capture --no-stderr
%pip install -U langgraph langchain_openai langchain_community


[notice] A new release of pip is available: 24.3.1 -> 25.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [12]:
%pip install python-dotenv
from dotenv import load_dotenv





[notice] A new release of pip is available: 24.3.1 -> 25.0
[notice] To update, run: python.exe -m pip install --upgrade pip


## Loading the environment variables

In [13]:
# Load variables from a local .env file into the process environment (needed before the os.getenv lookup below)
load_dotenv()

True

## Getting Google Gemini API KEY

In [14]:
import os
load_dotenv()
# Read the Gemini key from the environment; later cells pass this variable to the chat models.
google_api_key= os.getenv('GEMINI_API_KEY')
# Never print the key itself: cell outputs are saved with the notebook, so echoing it
# leaks the secret to anyone the notebook is shared with. Report only whether it was found.
print("GEMINI_API_KEY found:", bool(google_api_key))

[REDACTED: the API key was printed here; rotate this key and do not print secrets in notebook output]


## Downloading the sample Chinook database

In [15]:
import requests

# Sample SQLite database used by the rest of the notebook
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

# requests has no default timeout; without one a stalled connection would hang this cell forever.
response = requests.get(url, timeout=30)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as Chinook.db


### Calling the langchain SQLDatabase function

#### Checking Database tables

In [16]:
from langchain_community.utilities import SQLDatabase

# Open the downloaded SQLite file through LangChain's SQLDatabase wrapper
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Sanity check: print the SQL dialect, then the tables the wrapper exposes
for db_fact in (db.dialect, db.get_usable_table_names()):
    print(db_fact)

# Last expression of the cell, so the query result is shown as the cell output
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [18]:
# Install the Gemini integration for LangChain (this notebook authenticates with GEMINI_API_KEY from .env)

%pip install --upgrade --quiet langchain-google-genai pillow

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.3.1 -> 25.0
[notice] To update, run: python.exe -m pip install --upgrade pip


## Configuring my Large Language Model

#### Using Langchain and Gemini AI

In [28]:
from langchain_google_genai import ChatGoogleGenerativeAI
# Base chat model handed to SQLDatabaseToolkit further down; the later cells
# build their own "gemini-1.5-flash" models instead of reusing this one.
llm = ChatGoogleGenerativeAI(
    model="gemini-pro",
    api_key=google_api_key,
)

### Creating Helper functions for my AI Agent

In [29]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap ``tools`` in a ToolNode and attach ``handle_tool_error`` as its fallback.

    The fallback is registered with ``exception_key="error"``, which is the key
    ``handle_tool_error`` reads the failure from, so errors are surfaced to the
    agent.
    """
    node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """
    Build one error ToolMessage for every tool call on the last message.

    Reads the failure from ``state["error"]`` (``None`` if absent) and answers
    each call in ``state["messages"][-1].tool_calls`` under that call's id.

    Returns:
        ``{"messages": [ToolMessage, ...]}``, one entry per tool call.
    """
    error = state.get("error")
    replies = []
    for call in state["messages"][-1].tool_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {error!r}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}

### Working with the tools to get the AI response on the database

In [41]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit


# Build the SQL toolkit from the database and the LLM, then fetch its tools
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# Pick out the two tools the workflow needs, by tool name
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# Smoke-test both tools: list the tables, then show the Artist schema
for probe_tool, probe_input in ((list_tables_tool, ""), (get_schema_tool, "Artist")):
    print(probe_tool.invoke(probe_input))

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


### A tool function that runs the AI agent's queries against the real database

In [31]:
from langchain_core.tools import tool


@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # The docstring above is left word-for-word on purpose: the @tool decorator is
    # documented to use it as the tool description the model sees, so it is runtime text.
    outcome = db.run_no_throw(query)
    # Any falsy outcome is reported as a failure so the agent is prompted to rewrite the query.
    return outcome if outcome else "Error: Query failed. Please rewrite your query and try again."


# Smoke test: run a known-good query through the tool
print(db_query_tool.invoke("SELECT * FROM Artist LIMIT 10;"))

[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


# Prompting the LLM

#### Making sure that the sql will work

#### Function 1: returns the query as a dictionary

In [42]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_core.output_parsers import JsonOutputParser
import re


# System prompt for the JSON-returning validation chain. The doubled braces ({{ }})
# are template escapes, so the model is shown literal JSON braces rather than the
# template treating them as input variables.
query_check_system = """You are a SQL expert that ALWAYS responds using the db_query_tool.
Format your response EXACTLY like this: {{"query": "SQL_STATEMENT"}}
Never add explanations or formatting.
ALWAYS RESPOND WITH RAW JSON FORMAT - NO MARKDOWN CODE BLOCKS.
Example of valid output:{{"query": "SELECT * FROM Table;"}}"""

# Two-message prompt: the system instructions above plus a human turn whose only
# input variable is {user} (the SQL text to validate).
query_check_prompt = ChatPromptTemplate.from_messages([
    ("system", query_check_system),
    ("human", "Validate and correct this SQL: {user}")  # Single input variable
])
class MarkdownJsonParser(JsonOutputParser):
    """JSON output parser that first unwraps a ```json fenced block when the text has one."""

    def parse(self, text: str):
        """
        Parse ``text`` as JSON via the parent parser.

        If ``text`` contains a fenced block opened with ```json on its own line, only
        the block's inner text is parsed; otherwise the whole text is parsed as-is.
        """
        fenced = re.search(r'```json\n(.*?)\n```', text, re.DOTALL)
        payload = text if fenced is None else fenced.group(1)
        return super().parse(payload)

# Add to your chain

# 4. Gemini model at temperature 0, with db_query_tool bound as a callable tool
gemini_model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    api_key=google_api_key,
    temperature=0
).bind_tools([db_query_tool])

# 5. Full chain: prompt -> model -> JSON parser
query_check_chain = query_check_prompt | gemini_model | MarkdownJsonParser()

# 6. Invoke; the dict key has to match the template's {user} variable
sample_sql = "SELECT * FROM Artist LIMIT 10;"
response = query_check_chain.invoke({"user": sample_sql})

# Show the parsed dictionary, then just the SQL it carries
print(response)
print(response['query'])


{'query': 'SELECT * FROM Artist LIMIT 10;'}
SELECT * FROM Artist LIMIT 10;


#### Function 2: returns only the SQL query, in the AI message content

In [43]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage

# System prompt for the plain-text validation chain: the model is asked to reply with
# nothing but the (possibly corrected) SQL statement.
# NOTE(review): this rebinds query_check_system from the "Function 1" cell above.
query_check_system = """You are an expert SQL validator. Your ONLY task is to return a corrected SQL query when errors exist. 

### Absolute Rules:
1. If the query is correct: Return original query EXACTLY as received.
2. If incorrect: Return ONLY the corrected SQL query with:
   - Proper NULL handling in NOT IN clauses
   - UNION ALL instead of UNION where appropriate
   - Inclusive BETWEEN ranges
   - Resolved data type mismatches
   - Valid quotation of identifiers
   - Correct function argument counts
   - Explicit type casting when needed
   - Valid join conditions
   - ORDER BY only when explicitly required

### Output Protocol:
- **STRICTLY FORBIDDEN:** Explanations, markdown, formatting, or text beyond the SQL
- **STRICTLY FORBIDDEN:** Comments like "Corrected query:" or "Here is:"
- **ONLY OUTPUT:** 1 raw SQL query with no encapsulation
- If input is correct: Return input verbatim

Failure to follow these instructions will result in incorrect output."""

# Unused alternative kept for reference: a prompt that takes the whole message list
# through a "placeholder" slot instead of the single {user} variable used below.
# query_check_prompt = ChatPromptTemplate.from_messages(
#     [("system", query_check_system), ("placeholder", "{messages}")]
# )

# Plain Gemini chat model at temperature 0; no tools are bound for this chain
gemini_model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    api_key=google_api_key,
    temperature=0
)

# Prompt with a single {user} slot, piped into the model
query_check_template = ChatPromptTemplate.from_messages([
    ("system", query_check_system),
    ("human", "{user}")  # Simplified placeholder
])
query_check = query_check_template | gemini_model

# Try the chain on a query that is already correct
response = query_check.invoke({"user": "SELECT * FROM Artist LIMIT 10;"})

# The reply is a message object; its .content holds the SQL text
print(response.content)


SELECT * FROM Artist LIMIT 10;


# Let's define the workflow

In [72]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap ``tools`` in a ToolNode with ``handle_tool_error`` as its fallback,
    printing a progress line at each step.

    The fallback is registered with ``exception_key="error"``, the key
    ``handle_tool_error`` reads the failure from.
    """
    print(f"Creating ToolNode with tools: {tools}")
    node = ToolNode(tools)
    print("ToolNode created successfully.")

    error_fallback = RunnableLambda(handle_tool_error)
    guarded_node = node.with_fallbacks([error_fallback], exception_key="error")
    print("Fallbacks added to ToolNode.")

    return guarded_node



def handle_tool_error(state) -> dict:
    """
    Answer every tool call on the last message with an error ToolMessage.

    The failure is read from ``state["error"]`` (``None`` if absent); each reply
    is tagged with the id of the tool call it responds to.

    Returns:
        ``{"messages": [ToolMessage, ...]}``, one entry per tool call.
    """
    error = state.get("error")
    pending_calls = state["messages"][-1].tool_calls

    def _as_error_reply(call):
        # One reply per call, carrying the repr of the captured error
        return ToolMessage(
            content=f"Error: {error!r}\n please fix your mistakes.",
            tool_call_id=call["id"],
        )

    return {"messages": [_as_error_reply(call) for call in pending_calls]}

In [77]:
from typing import Annotated, Literal

from langchain_core.messages import AIMessage

from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages


# Define the state for the agent
class State(TypedDict):
    """Graph state: a single ``messages`` key holding the conversation so far."""

    # add_messages is attached as the annotation LangGraph uses to merge node
    # updates into this list (presumably by appending; confirm in the LangGraph docs).
    messages: Annotated[list[AnyMessage], add_messages]


# Define a new graph whose nodes read and update State
workflow = StateGraph(State)
print(State)


# Add a node for the first tool call

def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """
    Start the run with a hard-coded AI message that calls ``sql_db_list_tables``.

    ``state`` is not read: the same message (fixed id ``tool_abcd123``, no
    arguments) is returned every time, after printing it for debugging.
    """
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    seeded = {"messages": [AIMessage(content="", tool_calls=[list_tables_call])]}

    print("DEBUG: first_tool_call output:", seeded)  # Print for debugging
    return seeded



def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Double-check the latest generated SQL and hand it to the query executor.

    The text of the last message in ``state["messages"]`` is sent through the
    ``query_check`` chain. The chain's reply is then wrapped in an ``AIMessage``
    that carries a ``db_query_tool`` tool call, because the ``execute_query``
    node that runs next is a tool node built from ``db_query_tool`` and has
    nothing to run unless the last message carries a tool call.

    Returns:
        ``{"messages": [AIMessage]}``. On success the message holds a single
        ``db_query_tool`` call with the checked SQL as its ``query`` argument.
        If the check raises, the message instead holds ``"Error: ..."`` text
        and no tool call.
    """
    import uuid  # local import: only needed to mint a unique tool-call id

    last_message = state["messages"][-1]

    try:
        # The prompt's "{user}" slot is a plain-text human turn, so pass the SQL text
        # itself rather than a list of message objects.
        candidate_sql = last_message.content
        if not isinstance(candidate_sql, str):
            candidate_sql = str(candidate_sql)

        checked = query_check.invoke({"user": candidate_sql})

        # The chain returns a message object, not a dict: read .content, not ["content"].
        checked_sql = str(checked.content).strip()

        return {
            "messages": [
                AIMessage(
                    content="",
                    tool_calls=[
                        {
                            "name": db_query_tool.name,
                            "args": {"query": checked_sql},
                            "id": f"tool_{uuid.uuid4().hex}",
                        }
                    ],
                )
            ]
        }
    except Exception as e:
        # Deliberately broad: any failure is turned into an "Error:" message so the
        # graph keeps running instead of crashing inside this node.
        print(f"Error during query_check.invoke: {e}")
        return {
            "messages": [
                AIMessage(content=f"Error: {str(e)}")
            ]
        }





# Entry node: emits the hard-coded sql_db_list_tables call
workflow.add_node("first_tool_call", first_tool_call)

# Tool nodes (with the error fallback) for listing tables and fetching schemas
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Plain Gemini model; tools are bound per use below rather than on this object
gemini_model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    api_key=google_api_key,
    temperature=0
)

# Add a node for a model to choose the relevant tables based on the question and available tables
model_get_schema = gemini_model.bind_tools(
    [get_schema_tool]
)
print("Modal Schema:",model_get_schema)
workflow.add_node(
    "model_get_schema",
    # The node passes the whole message history to the model and returns its reply as the update
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)


# Describe a tool to represent the end state.
# This model is bound to the query-generation model as its only tool; a call to it
# is what should_continue treats as the signal to finish.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    # Required field (no default) holding the answer text
    final_answer: str = Field(..., description="The final answer to the user")


# System prompt for the query-generation model: write the SQL as plain text (no tool
# call), and only call SubmitFinalAnswer once the results are enough to answer.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""

# The {messages} variable sits inside a single "human" turn, so whatever is supplied
# for it is formatted into that one message as text.
query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("human", "{messages}")]
)
print("Query Gen Prompt:",query_gen_prompt)

# Query-generation chain: the prompt piped into the model, with SubmitFinalAnswer as its only tool
query_gen = query_gen_prompt | gemini_model.bind_tools(
    [SubmitFinalAnswer]
)


def query_gen_node(state: State):
    """
    Ask the query-generation chain for the next message.

    The model's reply is returned first. For every tool call on that reply whose
    name is not ``SubmitFinalAnswer``, an error ToolMessage answering that call
    is appended after it, so a wrong tool choice is reported back to the model.
    """
    print("State",state)
    reply = query_gen.invoke(state)

    # Sometimes the LLM hallucinates and calls the wrong tool; answer each such call with an error.
    wrong_tool_errors = [
        ToolMessage(
            content=f"Error: The wrong tool was called: {call['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
            tool_call_id=call["id"],
        )
        for call in (reply.tool_calls or [])
        if call["name"] != "SubmitFinalAnswer"
    ]
    return {"messages": [reply] + wrong_tool_errors}


# Node that asks the model for a query or a final answer
workflow.add_node("query_gen", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add node for executing the query: a tool node (with the error fallback) built from db_query_tool
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))


# Define a conditional edge to decide whether to continue or end the workflow
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """
    Route after ``query_gen`` based on the last message in the state.

    Returns ``END`` when that message carries any tool call, ``"query_gen"``
    when its content starts with ``"Error:"``, and ``"correct_query"`` otherwise.
    """
    latest = state["messages"][-1]
    # A tool call on the last message means we are done
    if getattr(latest, "tool_calls", None):
        return END
    return "query_gen" if latest.content.startswith("Error:") else "correct_query"


# Specify the edges between the nodes
# Fixed path up to query generation:
# START -> first_tool_call -> list_tables_tool -> model_get_schema -> get_schema_tool -> query_gen
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
# After query_gen, should_continue picks END, "correct_query" or "query_gen"
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
# A checked query is executed, and the result goes back to query_gen
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable
app = workflow.compile()

<class '__main__.State'>
Creating ToolNode with tools: [ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001369325B050>)]
ToolNode created successfully.
Fallbacks added to ToolNode.
Creating ToolNode with tools: [InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x000001369325B050>)]
ToolNode created successfully.
Fallbacks added to ToolNode.
Modal Schema: bound=ChatGoogleGenerativeAI(model='models/gemini-1.5-flash', google_api_key=SecretStr('**********'), temperature=0.0, client=<google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient object at 0x000001369771A5D0>, default_metadata=()) kwargs={'tools': [{'type': 'function', 'fun

### Trying the AI workflow

In [76]:
# Run the compiled graph on a single user question
question = "Which are the songs in the database with the artist Aerosmith?"
messages = app.invoke({"messages": [("user", question)]})

# The answer is read from the first tool call on the last message (its final_answer argument)
final_message = messages["messages"][-1]
json_str = final_message.tool_calls[0]["args"]["final_answer"]
json_str

DEBUG: first_tool_call output: {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={}, tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}])]}
State {'messages': [HumanMessage(content='Which are the songs in the database with the artist Aerosmith?', additional_kwargs={}, response_metadata={}, id='dab2151e-2ac4-4afd-b1bd-34799169d88c'), AIMessage(content='', additional_kwargs={}, response_metadata={}, id='982ea654-a666-4de9-9135-7984e59b45a4', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}]), ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', id='0388a23d-90ac-4582-952f-9bd410bf74b5', tool_call_id='tool_abcd123'), AIMessage(content='I need more information about the database schema to answer your question.  Specifically, I need to know the names of the tables and c

KeyboardInterrupt: 

In [None]:
from langchain_core.messages import AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field
from typing import List

# Define the state for the agent
# NOTE(review): this rebinds the name State, which the LangGraph cell above defined as a TypedDict.
class State(BaseModel):
    """State for the simplified pipeline below: just a list of AI messages."""

    # Defaults to an empty list when no messages are supplied
    messages: List[AIMessage] = []

# Initialize Gemini model (no tools bound); query_gen_node below calls it directly
gemini_model = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    api_key=google_api_key,
    temperature=0
)

# Describe a tool to represent the end state
# NOTE(review): defined here but not referenced by any function in this cell.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""
    # Required field (no default) holding the answer text
    final_answer: str = Field(..., description="The final answer to the user")

# Query generation system prompt
# NOTE(review): none of the functions in this cell pass this prompt to the model.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set. 
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""

# Query generation function
def query_gen_node(state: State):
    """
    Send the text of the last message in ``state`` to the Gemini model.

    Returns:
        ``{"messages": [AIMessage]}``. The message holds the model's reply
        content, or an ``"Error: ..."`` text when ``state.messages`` is empty.
    """
    if not state.messages:
        return {"messages": [AIMessage(content="Error: No input provided for query generation.")]}

    # NOTE(review): only the raw text of the last message is sent; query_gen_system is not included.
    message = gemini_model.invoke(state.messages[-1].content)
    # invoke() already returns a message object. AIMessage content must be text, so copy
    # the reply's content instead of nesting the whole message inside a new one.
    return {"messages": [AIMessage(content=message.content)]}

# Model to check query before execution
def model_check_query(state: State):
    """
    Placeholder validation step: it inspects nothing beyond whether messages exist.

    Returns a one-message update saying "Query is valid." when ``state.messages``
    is non-empty, and an "Error: ..." message otherwise.
    """
    verdict = "Query is valid." if state.messages else "Error: No query provided for validation."
    return {"messages": [AIMessage(content=verdict)]}

# Execution function
def execute_query(state: State):
    """
    Placeholder execution step: no query is actually run.

    Returns a one-message update reporting success when ``state.messages`` is
    non-empty, and an "Error: ..." message otherwise.
    """
    if state.messages:
        outcome = AIMessage(content="Query executed successfully.")
    else:
        outcome = AIMessage(content="Error: No query to execute.")
    return {"messages": [outcome]}

# Workflow execution
state = State(messages=[AIMessage(content="Fetch the latest 5 records from users table.")])

# Every step returns a plain {"messages": [...]} dict, while the next step reads
# state.messages. Re-wrap each result in State so the attribute access keeps working.
for step in (query_gen_node, model_check_query, execute_query):
    state = State(**step(state))

print(state)
