In [1]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain_openai import AzureChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()

True

**Set the environment variables and load the LLM**

In [1]:
# The Azure OpenAI key is read from the environment (populated from a local
# .env file by the load_dotenv() call in the imports cell) so the secret never
# lives in the notebook. A missing variable raises KeyError right here rather
# than surfacing later as an authentication failure on the first request.
# SECURITY: a key was previously hard-coded in this cell; treat it as leaked
# and rotate it in the Azure portal.
llm = AzureChatOpenAI(
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    azure_deployment="gpt35",
    model="gpt-3.5-turbo",
    azure_endpoint="https://aressgenaisvc.openai.azure.com/",
    api_version="2024-06-01",
)
# Alternative (non-Azure) models that can be swapped in instead:
# llm = ChatOpenAI(model="gpt-4o-mini")
# llm = ChatOpenAI(model="gpt-4o")

NameError: name 'AzureChatOpenAI' is not defined

**Load and test the SQLite database**

In [None]:
# here() (pyprojroot) resolves the database file against the project root.
sqldb_directory = here("data/travel.sqlite")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Sanity check: show the SQL dialect and the usable table names, one per line.
print(db.dialect, db.get_usable_table_names(), sep="\n")
# Last expression of the cell, so the query result is shown as the cell output.
db.run("SELECT * FROM aircrafts_data LIMIT 10;")

# from sqlalchemy import create_engine, inspect
# from sqlalchemy.orm import sessionmaker
# engine = create_engine(db_path)

# # Create a session
# Session = sessionmaker(bind=engine)
# session = Session()

# # Use SQLAlchemy's Inspector to get database information
# inspector = inspect(engine)

# # Get table names
# tables = inspector.get_table_names()
# print("Tables in the database:", tables)
# print(len(tables))

**Create the SQL agent chain and run a test query**

In [17]:
# Prompt for the final step: the model receives the question, the generated
# SQL query and the query result, and is asked to word the answer.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """

# Building blocks: run a SQL string against `db`, and have the LLM write one.
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)

# Final step: fill the prompt, call the LLM, and keep only the text of its reply.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Full pipeline: add "query" to the input dict, then add "result" by executing
# that query, then hand the enriched dict to the answer step.
chain = (
    RunnablePassthrough.assign(query=write_query)
    | RunnablePassthrough.assign(result=itemgetter("query") | execute_query)
    | answer
)

In [None]:
# Smoke-test the chain end to end with a schema-level question.
message = "How many tables do I have in the database? and what are their names?"
response = chain.invoke(dict(question=message))
response

**Travel SQL-agent Tool Design**

In [25]:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_openai import ChatOpenAI


class TravelSQLAgentTool:
    """
    Question-answering pipeline over a travel SQL database.

    A user question is turned into a SQL query by the chat model, the query is
    executed against the database, and the model then words a final answer from
    the question, the query and the query result.

    Attributes:
        sql_agent_llm (ChatOpenAI): Chat model used both to write the SQL query and to phrase the answer.
        system_role (str): Prompt template with `{question}`, `{query}` and `{result}` placeholders.
        db (SQLDatabase): Database handle opened from the `sqlite:///` URI.
        chain: Runnable pipeline; invoke it with `{"question": ...}` to obtain the answer text.
    """

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
        """
        Create the chat model, open the database and assemble the pipeline.

        Args:
            llm (str): Model name passed to `ChatOpenAI`.
            sqldb_directory (str): Path interpolated into the `sqlite:///` URI.
            llm_temerature (float): Temperature passed to `ChatOpenAI`. The spelling
                is part of the keyword interface and is kept for existing callers.
        """
        self.sql_agent_llm = ChatOpenAI(temperature=llm_temerature, model=llm)
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        # Prints the usable table names every time an instance is created.
        print(self.db.get_usable_table_names())

        # Step that runs a SQL string against the database.
        sql_runner = QuerySQLDataBaseTool(db=self.db)
        # Step in which the model writes the SQL query for the question.
        sql_writer = create_sql_query_chain(self.sql_agent_llm, self.db)
        # Step that fills the prompt, calls the model and keeps only the reply text.
        responder = (
            PromptTemplate.from_template(self.system_role)
            | self.sql_agent_llm
            | StrOutputParser()
        )

        # Add "query" to the input dict, then "result" by executing that query,
        # then pass the enriched dict to the answering step.
        self.chain = (
            RunnablePassthrough.assign(query=sql_writer)
            | RunnablePassthrough.assign(result=itemgetter("query") | sql_runner)
            | responder
        )

In [24]:
from functools import lru_cache

from agent_graph.load_tools_config import LoadToolsConfig

TOOLS_CFG = LoadToolsConfig()


@lru_cache(maxsize=1)
def _get_travel_sql_agent() -> "TravelSQLAgentTool":
    """Build the travel SQL agent on first use and reuse it for later calls.

    Constructing TravelSQLAgentTool creates the chat model client, opens the
    database and prints its table names, so doing it once instead of on every
    tool call avoids repeating that work (and that output) per question.

    Returns:
        TravelSQLAgentTool: The shared agent configured from TOOLS_CFG.
    """
    return TravelSQLAgentTool(
        llm=TOOLS_CFG.travel_sqlagent_llm,
        sqldb_directory=TOOLS_CFG.travel_sqldb_directory,
        llm_temerature=TOOLS_CFG.travel_sqlagent_llm_temperature,
    )


@tool
def query_travel_sqldb(query: str) -> str:
    """Query the Swiss Airline SQL Database and access all the company's information. Input should be a search query."""
    # The docstring above is what @tool exposes as the tool description, so it
    # is deliberately kept short and unchanged.
    return _get_travel_sql_agent().chain.invoke({"question": query})