### Design the full graph using three tools: a search tool, a RAG tool, and a SQL agent for the travel database

In [1]:
import os
from dotenv import load_dotenv
from pyprojroot import here
load_dotenv()

True

In [2]:
# Mirror API keys into os.environ under the target names used by this notebook.
# os.environ only accepts strings, so assigning the None that os.getenv returns
# for a missing variable raises TypeError; report the gap and continue instead.
for target_var, source_var in (
    ("OPENAI_API_KEY", "OPEN_AI_API_KEY"),
    ("TAVILY_API_KEY", "TAVILY_API_KEY"),
):
    key_value = os.getenv(source_var)
    if key_value is None:
        print(f"Warning: environment variable {source_var} is not set; {target_var} left unchanged.")
    else:
        os.environ[target_var] = key_value

### 1. Initialize The Tools

#### 1.1 RAG tool design

In [3]:
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain_core.tools import tool

# Settings read by lookup_policy below.
EMBEDDING_MODEL = "nomic-embed-text"  # model name passed to OllamaEmbeddings
VECTORDB_DIR = "data/airline_policy_vectordb"  # Chroma persist directory, resolved with here()
K = 2  # number of documents requested from similarity_search

@tool
def lookup_policy(query: str) -> str:
    """Consult the company policies to check whether certain options are permitted."""
    # The docstring above is what the tool exposes as its description, so its
    # wording is left untouched.
    persist_path = str(here(VECTORDB_DIR))
    embedder = OllamaEmbeddings(model=EMBEDDING_MODEL)
    # Open the persisted "rag-chroma" collection on every call.
    policy_store = Chroma(
        collection_name="rag-chroma",
        persist_directory=persist_path,
        embedding_function=embedder,
    )
    matches = policy_store.similarity_search(query, k=K)
    # Return the K retrieved passages as one string, separated by blank lines.
    passages = [match.page_content for match in matches]
    return "\n\n".join(passages)

# Show the tool object created by @tool (name, description, args schema).
print(lookup_policy)

name='lookup_policy' description='Consult the company policies to check whether certain options are permitted.' args_schema=<class 'langchain_core.utils.pydantic.lookup_policy'> func=<function lookup_policy at 0x76ad386d5f80>


Test the RAG tool

In [4]:
# Smoke-test the RAG tool with a sample policy question.
lookup_policy.invoke("can I cancel my ticket?")

'for a refund or may only be able to receive a partial refund. If you booked your flight through a third-party website or\ntravel agent, you may need to contact them directly to cancel your flight. Always check the terms and conditions of your\nticket to make sure you understand the cancellation policy and any associated fees or penalties. If you\'re cancelling your\nflight due to unforeseen circumstances such as a medical emergency or a natural disaster , Swiss Air may of fer you\nspecial exemptions or accommodations. What is Swiss Airlines 24 Hour Cancellation Policy? Swiss Airlines has a 24\n\ncircumstances such as bad weather or political unrest, Swiss Airlines may not be obligated to of fer any compensation. In\nsummary , Swiss Airlines\' cancellation policy varies depending on your fare type and the time of cancellation. T o avoid any\nunnecessary fees or charges, it\'s important to familiarise yourself with the terms and conditions of your ticket and to\ncontact Swiss Airlines a

#### 1.2 Search tool design

In [5]:
from langchain_community.tools.tavily_search import TavilySearchResults

# Tavily web-search tool, configured with max_results=2.
search_tool = TavilySearchResults(max_results=2)

Test the search tool

In [6]:
# Smoke-test the search tool with a sample query.
search_tool.invoke("What's a 'node' in LangGraph?")

[{'url': 'https://langchain-ai.github.io/langgraph/concepts/low_level/',
  'content': 'Nodes¶ In LangGraph, nodes are typically python functions (sync or async) where the first positional argument is the state, and (optionally), the second positional argument is a "config", containing optional configurable parameters (such as a thread_id). Similar to NetworkX, you add these nodes to a graph using the add_node method:'},
 {'url': 'https://medium.com/@cplog/introduction-to-langgraph-a-beginners-guide-14f9be027141',
  'content': 'Nodes: Nodes are the building blocks of your LangGraph. Each node represents a function or a computation step. You define nodes to perform specific tasks, such as processing input, making'}]

#### 1.3 SQL agent tool design

In [7]:
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_ollama import ChatOllama
from langchain import hub
from typing_extensions import TypedDict
from typing_extensions import Annotated

In [24]:
# This works
# Prompt template for SQL generation, pulled from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")
# Path of the Chinook SQLite file, resolved with pyprojroot's here().
sqldb_directory = here("data/Chinook.db")

# Dict shape used as the type of write_query's `state` argument; the chain in
# this cell fills the "query" and "result" keys via RunnablePassthrough.assign.
class State(TypedDict):
    question: str  # user question, read by write_query
    query: str  # value produced by write_query
    result: str  # value produced by execute_query
    answer: str  # presumably the final answer text; not written in this cell

# Output schema passed to llm.with_structured_output in write_query.
# NOTE(review): the class docstring and the Annotated description are
# presumably forwarded to the model as part of the schema — keep their wording
# stable unless a prompt change is intended.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Annotated[type, default, description]; `...` means no default value.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


# Connect to the SQLite file at sqldb_directory and print its table info.
db = SQLDatabase.from_uri(
    f"sqlite:///{sqldb_directory}")
print(db.get_table_info())
# Local Ollama chat model; temperature=0 is set explicitly.
llm = ChatOllama(model="qwen2.5:14b", temperature=0)

# Tool that runs a SQL statement against db.
execute_query = QuerySQLDatabaseTool(db=db)

# Prompt for the final answer; expects the keys question, query and result.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question in a funny way.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)    

def write_query(state: State):
    """Generate a SQL query for the question stored in ``state``.

    Args:
        state: Mapping with a "question" key holding the user's question.

    Returns:
        The structured LLM output shaped like ``QueryOutput`` (a "query" key).

    Raises:
        ValueError: If the structured LLM call returns no output.
    """
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    })
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # Fail here with a clear message instead of handing None to the rest of the
    # chain, where it would surface as an unrelated error.
    if result is None:
        raise ValueError(
            f"Structured LLM returned no SQL query for question: {state['question']!r}"
        )
    return result

# Answer step: fill answer_prompt, run it through llm, return plain text.
answer = answer_prompt | llm | StrOutputParser()
# Build the chain
# 1) add "query" from write_query, 2) add "result" by running execute_query on
# that value, 3) pass the enriched dict to the answer step.
# NOTE(review): write_query returns the whole structured-output object rather
# than a bare SQL string, so {query} in answer_prompt likely renders as a
# dict — confirm this is intended.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

# Execute chain
# response = chain.invoke({"question": "How many employees are there"})
# print(response)

# def query_sqldb(query):
#     """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
#     response = chain.invoke({"question": query})
#     return response

@tool
def query_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # The argument is annotated like lookup_policy's so the generated tool
    # schema declares a string input instead of an untyped one.
    # NOTE(review): in this cell `db` points at data/Chinook.db while the
    # description above names the Swiss Airline database — confirm.
    # chain turns the question into SQL, runs it, and phrases the answer.
    return chain.invoke({"question": query})


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

In [None]:
# message = "How many employees are there"
# response = query_sqldb(message)
# print(response)

In [None]:
# Ask the SQL tool a sample question and print its answer.
message = "How many employees are there"
# message = "How many tables do I have in the database? and what are their names?"
response = query_sqldb.invoke(message)
print(response)

## Tests

In [26]:
# Same pipeline as the earlier SQL cell, now pointed at data/travel.sqlite.
sqldb_directory = here("data/travel.sqlite")

db = SQLDatabase.from_uri(
    f"sqlite:///{sqldb_directory}")

# Dedicated model handle for the SQL steps; temperature=0 is set explicitly.
sql_llm = ChatOllama(model="qwen2.5:14b", temperature=0)

# Prompt template for SQL generation, pulled from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Tool that runs a SQL statement against db.
execute_query = QuerySQLDatabaseTool(db=db)

# Prompt for the final answer; expects the keys question, query and result.
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question in a funny way.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# Answer step: fill answer_prompt, run it through sql_llm, return plain text.
answer = answer_prompt | sql_llm | StrOutputParser()

# Dict shape used as the type of write_query's `state` argument; the chain in
# this cell fills the "query" and "result" keys via RunnablePassthrough.assign.
class State(TypedDict):
    question: str  # user question, read by write_query
    query: str  # value produced by write_query
    result: str  # value produced by execute_query
    answer: str  # presumably the final answer text; not written in this cell


# Output schema passed to sql_llm.with_structured_output in write_query.
# NOTE(review): the class docstring and the Annotated description are
# presumably forwarded to the model as part of the schema — keep their wording
# stable unless a prompt change is intended.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Annotated[type, default, description]; `...` means no default value.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Generate a SQL query for the question stored in ``state``.

    Args:
        state: Mapping with a "question" key holding the user's question.

    Returns:
        The structured LLM output shaped like ``QueryOutput`` (a "query" key).

    Raises:
        ValueError: If the structured LLM call returns no output.
    """
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    })
    structured_llm = sql_llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # Fail here with a clear message instead of handing None to the rest of the
    # chain, where it would surface as an unrelated error.
    if result is None:
        raise ValueError(
            f"Structured LLM returned no SQL query for question: {state['question']!r}"
        )
    return result

# Build the chain
# 1) add "query" from write_query, 2) add "result" by running execute_query on
# that value, 3) pass the enriched dict to the answer step.
# NOTE(review): write_query returns the whole structured-output object rather
# than a bare SQL string, so {query} in answer_prompt likely renders as a
# dict — confirm this is intended.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

@tool
def query_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # The argument is annotated like lookup_policy's so the generated tool
    # schema declares a string input instead of an untyped one.
    # chain turns the question into SQL, runs it, and phrases the answer.
    return chain.invoke({"question": query})

In [None]:
# Ask the SQL tool about the database's tables and print its answer.
message = "How many tables do I have in the database? and what are their names?"
response = query_sqldb.invoke(message)
print(response)

In [20]:
# Path of the travel SQLite file, resolved with pyprojroot's here().
sqldb_directory = here("data/travel.sqlite")

# Model handle for the SQL steps; temperature=0 is set explicitly.
sql_llm = ChatOllama(model="qwen2.5:14b", temperature=0)
# llm = ChatOpenAI(model="gpt-4o-mini")
# llm = ChatOpenAI(model="gpt-4o")
# Template text for the final-answer prompt; it is passed to
# PromptTemplate.from_template later in this cell and expects the keys
# question, query and result.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """


# Dict shape used as the type of write_query's `state` argument in this cell.
class State(TypedDict):
    question: str  # user question, read by write_query
    query: str  # presumably the generated SQL
    result: str  # presumably the output of running the SQL
    answer: str  # presumably the final answer text


# Output schema passed to sql_llm.with_structured_output in write_query.
# NOTE(review): the class docstring and the Annotated description are
# presumably forwarded to the model as part of the schema — keep their wording
# stable unless a prompt change is intended.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Annotated[type, default, description]; `...` means no default value.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


# Prompt template for SQL generation, pulled from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Connect to the SQLite file at sqldb_directory.
db = SQLDatabase.from_uri(
    f"sqlite:///{sqldb_directory}")

def write_query(state: State):
    """Turn the question in ``state`` into a structured SQL query via ``sql_llm``.

    Note: the name ``write_query`` is reassigned to a ``create_sql_query_chain``
    runnable further down in this cell, and defined again as a function after
    that.
    """
    template_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    rendered_prompt = query_prompt_template.invoke(template_inputs)
    # Constrain the model's reply to the QueryOutput schema.
    query_model = sql_llm.with_structured_output(QueryOutput)
    return query_model.invoke(rendered_prompt)

# Tool that runs a SQL statement against db. QuerySQLDatabaseTool is the
# spelling used in the rest of this notebook; QuerySQLDataBaseTool is a
# deprecated alias of it.
execute_query = QuerySQLDatabaseTool(db=db)
# Rebinds write_query (a function until this point) to a prebuilt
# SQL-generation runnable.
write_query = create_sql_query_chain(
    sql_llm, db)
# Final-answer prompt built from the system_role template text.
answer_prompt = PromptTemplate.from_template(
    system_role)


# Answer step: fill answer_prompt, run it through sql_llm, return plain text.
answer = answer_prompt | sql_llm | StrOutputParser()
# 1) add "query" from write_query, 2) add "result" by running execute_query on
# it, 3) pass the enriched dict to the answer step.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)
# Test the chain
# message = "How many tables do I have in the database? and what are their names?"
# response = chain.invoke({"question": message})

# @tool
# def query_sqldb(query):
#     """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
#     response = chain.invoke({"question": query})
#     return response


def write_query(state: State):
    """Generate SQL query to fetch information.

    Args:
        state: Mapping with a "question" key holding the user's question.

    Returns:
        The structured LLM output shaped like ``QueryOutput`` (a "query" key).

    Raises:
        ValueError: If the structured LLM call returns no output.
    """
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = sql_llm.with_structured_output(QueryOutput)
    response = structured_llm.invoke(prompt)
    # Callers index the result with ["query"]; raise a clear error instead of
    # letting a None result fail there with a TypeError.
    if response is None:
        raise ValueError(
            f"Structured LLM returned no SQL query for question: {state['question']!r}"
        )
    return response

# gen_query = write_query({"question": "How many Employees are there?"})

@tool
def query_sqldb(query: str):
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # The docstring is the tool description. The previous text ("Execute SQL
    # query.") suggested the caller must supply SQL, but the input is a
    # natural-language question that write_query turns into SQL.
    gen_query = write_query({"question": query})
    if gen_query is None:
        raise ValueError(f"No SQL query could be generated for: {query!r}")

    # Run the generated SQL and return the raw database output.
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(gen_query["query"])}


In [None]:
# Ask the SQL tool about the database's tables and print what it returns.
message = "How many tables do I have in the database? and what are their names?"
response = query_sqldb.invoke(message)
print(response)

Test from the documentation

In [None]:
from langchain_community.utilities import SQLDatabase

# Open Chinook.db via a relative SQLite URL and run a few sanity checks.
# NOTE(review): other cells load databases through here("data/..."); confirm
# that this path, which is relative to the working directory, is intended.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

In [8]:
from typing_extensions import TypedDict


# State schema given to StateGraph; each node returns a partial update.
class State(TypedDict):
    question: str  # user question, read by write_query and generate_answer
    query: str  # SQL text returned by write_query
    result: str  # SQL output returned by execute_query
    answer: str  # answer text returned by generate_answer

In [9]:
# Re-export LANGCHAIN_API_KEY only when it is defined: os.environ accepts
# strings only, so assigning the None that os.getenv returns for a missing
# variable raises TypeError.
langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
if langchain_api_key is not None:
    os.environ['LANGCHAIN_API_KEY'] = langchain_api_key

In [10]:
from langchain_ollama import  ChatOllama


# Chat model used by write_query and generate_answer below. No temperature is
# passed here, unlike the earlier cells that set temperature=0.
llm = ChatOllama(model="qwen2.5:14b")

In [None]:
from langchain import hub

# Prompt template for SQL generation, pulled from the LangChain hub.
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Check that the template holds exactly one message, then display it.
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()

In [12]:
from typing_extensions import Annotated


# Output schema passed to llm.with_structured_output in write_query.
# NOTE(review): the class docstring and the Annotated description are
# presumably forwarded to the model as part of the schema — keep their wording
# stable unless a prompt change is intended.
class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Annotated[type, default, description]; `...` means no default value.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Generate SQL query to fetch information.

    Args:
        state: Graph state; only the "question" key is read.

    Returns:
        A state update of the form ``{"query": <SQL string>}``.

    Raises:
        ValueError: If the structured LLM call returns no output.
    """
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)

    # Returning None here would leave "query" out of the state, and the next
    # step (execute_query) would fail on state["query"]. Stop with a message
    # that names the real cause.
    if result is None:
        raise ValueError(
            f"Structured LLM returned no SQL query for question: {state['question']!r}"
        )
    return {"query": result["query"]}

In [None]:
# Try the query-generation step on its own with a sample question.
gen_query = write_query({"question": "How many Employees are there?"})

In [17]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Run the SQL held in ``state["query"]`` against ``db``.

    Returns a state update of the form ``{"result": <tool output>}``.
    """
    # A new tool instance is created for every call, bound to the current db.
    sql_runner = QuerySQLDatabaseTool(db=db)
    sql_output = sql_runner.invoke(state["query"])
    return {"result": sql_output}

In [None]:
# Try the execution step on its own with a hand-written query.
execute_query({"query": "SELECT COUNT(*) AS EmployeeCount FROM Employee;"})

In [21]:
def generate_answer(state: State):
    """Ask ``llm`` to answer the question from the SQL query and its result.

    Reads "question", "query" and "result" from ``state`` and returns a state
    update of the form ``{"answer": <model reply text>}``.
    """
    instructions = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
    )
    # One labelled line per piece of context, in the order the model sees them.
    context_lines = [
        f'Question: {state["question"]}',
        f'SQL Query: {state["query"]}',
        f'SQL Result: {state["result"]}',
    ]
    reply = llm.invoke(instructions + "\n".join(context_lines))
    return {"answer": reply.content}

In [22]:
from langgraph.graph import START, StateGraph

# Chain the three steps in order: write_query -> execute_query -> generate_answer.
graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
# Enter the graph at the first step of the sequence.
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG and show it inline.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Run the graph on a sample question and print each item yielded by
# graph.stream with stream_mode="updates".
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)