In [12]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_ollama import ChatOllama

from dotenv import load_dotenv
load_dotenv()

True

#### Load and test the SQLite database

In [10]:
# Ollama chat model shared by the prototype chain built in the following cells.
llm = ChatOllama(model="qwen2.5:14b")

# Locate the Chinook SQLite file via pyprojroot's here() and open it through a sqlite URI.
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri("sqlite:///" + str(sqldb_directory))

# Sanity checks: SQL dialect, visible table names, then a small sample query.
print(db.dialect, db.get_usable_table_names(), sep="\n")
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

#### Create the SQL agent chain and run a test query

In [None]:

from langchain import hub
from typing_extensions import TypedDict
from typing_extensions import Annotated


# Prompt template used by write_query() to turn a question into SQL.
# Pulled from the LangChain hub by name (presumably a network fetch — TODO confirm caching).
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Template for the final answering step; it is wrapped in a PromptTemplate further down
# and filled with {question}, {query} and {result}.
# NOTE: every line ends with an explicit \n escape *and* a real newline, and the
# continuation lines are indented, so the rendered prompt contains blank lines and
# leading spaces. That text is what the model sees — do not reformat it casually.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """


# Keys carried through the question -> SQL -> result -> answer pipeline.
# Written with TypedDict's functional syntax; equivalent to the class form.
State = TypedDict(
    "State",
    {
        "question": str,  # the user's natural-language question
        "query": str,     # value attached under "query" by the chain
        "result": str,    # value attached under "result" after executing the query
        "answer": str,    # declared for completeness; not populated by the visible chain
    },
)


class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Single-field schema handed to llm.with_structured_output() in write_query().
    # The Annotated metadata (`...` plus the description string) and the docstring
    # above presumably feed the schema shown to the model — treat both as runtime
    # text and confirm against the LangChain docs before rewording them.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Ask the LLM for a SQL query that answers the question held in ``state``.

    Fills the hub prompt template with the database dialect, a ``top_k`` of 10,
    the table info reported by the module-level ``db`` and the user's question,
    then invokes the module-level ``llm`` with output structured as
    ``QueryOutput``.

    Args:
        state: Mapping that must contain the key ``"question"``.

    Returns:
        Whatever the structured-output model returns for the prompt (expected
        to follow ``QueryOutput``; it is not validated here).
    """
    prompt_values = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(prompt_values)
    return llm.with_structured_output(QueryOutput).invoke(filled_prompt)

# Tool that runs a query against `db`, and the sub-chain that words the final answer.
execute_query = QuerySQLDatabaseTool(db=db)
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Build the pipeline stage by stage:
#   1. attach "query"  (output of write_query) to the input mapping,
#   2. attach "result" (that query passed through execute_query),
#   3. hand the enriched mapping to the answer sub-chain.
with_query = RunnablePassthrough.assign(query=write_query)
with_result = with_query.assign(result=itemgetter("query") | execute_query)
chain = with_result | answer

In [11]:
# Smoke test: run the prototype chain end to end on a simple question.
message = "Give me the names of 5 artists from the database"
response = chain.invoke({"question": message})
response  # bare last expression so the notebook displays the answer string

'The names of 5 artists from the database are:\n\n1. AC/DC\n2. Aaron Copland & London Symphony Orchestra\n3. Aaron Goldberg\n4. Academy of St. Martin in the Fields & Sir Neville Marriner\n5. Academy of St. Martin in the Fields Chamber Ensemble & Sir Neville Marriner'

#### SQL-agent tool design

In [33]:
from langchain_core.tools import tool
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_ollama import ChatOllama
from langchain import hub
from typing_extensions import TypedDict
from typing_extensions import Annotated


class SQLAgentTool:
    """
    Answer questions about a digital media store SQL database with the help of an LLM.

    A user question is turned into a SQL query by the language model, the query is
    executed against the configured SQLite database, and the model then phrases a
    final answer from the question, the query and the query result.

    Attributes:
        sql_agent_llm (ChatOllama): Chat model used both to write the SQL query and
            to word the final answer.
        system_role (str): Prompt template for the answering step; it has the
            placeholders ``{question}``, ``{query}`` and ``{result}``.
        db (SQLDatabase): Database wrapper the generated queries run against.
        query_prompt_template: Query-writing prompt pulled from the LangChain hub
            once, at construction time.
        chain: Runnable pipeline; invoke it with ``{"question": "..."}`` to get the
            answer string.

    Nested types:
        State: Keys carried through the pipeline.
        QueryOutput: Structured-output schema for the generated SQL query.
    """

    class State(TypedDict):
        # Keys carried through the pipeline; "answer" is declared but not
        # populated by the chain built in __init__.
        question: str
        query: str
        result: str
        answer: str

    class QueryOutput(TypedDict):
        """Generated SQL query."""
        query: Annotated[str, ..., "Syntactically valid SQL query."]

    def __init__(self, llm: str, sqldb_directory: str, llm_temerature: float) -> None:
        """
        Set up the language model, the database connection and the answering chain.

        Args:
            llm (str): Name of the Ollama model used to write and interpret SQL queries.
            sqldb_directory (str): Path of the SQLite database file. The value is
                interpolated into a ``sqlite:///`` URI, so a path-like object works too.
            llm_temerature (float): Temperature passed to the model. The misspelled
                parameter name is kept so existing keyword callers keep working.
        """
        self.sql_agent_llm = ChatOllama(
            model=llm, temperature=llm_temerature)
        self.system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
            Question: {question}\n
            SQL Query: {query}\n
            SQL Result: {result}\n
            Answer:
            """
        self.db = SQLDatabase.from_uri(
            f"sqlite:///{sqldb_directory}")
        # Kept from the original: shows which tables the agent can see.
        print(self.db.get_usable_table_names())

        # Pull the query-writing prompt once per instance instead of on every question.
        self.query_prompt_template = hub.pull(
            "langchain-ai/sql-query-system-prompt")

        execute_query = QuerySQLDatabaseTool(db=self.db)
        answer_prompt = PromptTemplate.from_template(
            self.system_role)

        answer = answer_prompt | self.sql_agent_llm | StrOutputParser()
        self.chain = (
            RunnablePassthrough.assign(query=self.__write_query).assign(
                result=itemgetter("query") | execute_query
            )
            | answer
        )

    def __write_query(self, state: State):
        """
        Ask the model for a SQL query that answers ``state["question"]``.

        Args:
            state: Mapping that must contain the key ``"question"``.

        Returns:
            Whatever the structured-output model returns for the prompt (expected
            to follow ``QueryOutput``; it is not validated here).
        """
        prompt = self.query_prompt_template.invoke({
            "dialect": self.db.dialect,
            "top_k": 10,
            "table_info": self.db.get_table_info(),
            "input": state["question"],
        })
        # Use the nested schema explicitly: a bare `QueryOutput` inside a method
        # resolves to a module-level global, not to this class's attribute, so the
        # class previously only worked when such a global happened to exist.
        structured_llm = self.sql_agent_llm.with_structured_output(
            self.QueryOutput)
        return structured_llm.invoke(prompt)

In [34]:
# Settings read by query_digital_media_sqldb() each time the tool is invoked.
sqlagent_llm = "qwen2.5:14b"  # model name passed as SQLAgentTool(llm=...)
sqldb_directory = here("data/Chinook.db")  # path of the SQLite file, despite the "directory" name
sqlagent_llm_temperature = 0.0  # temperature passed through to the chat model

@tool
def query_digital_media_sqldb(query: str) -> str:
    """Query the Digital Media Store SQL Database and access all the company's information. Input should be a search query."""
    # The docstring above presumably doubles as the tool description, so it is kept verbatim.
    # A fresh SQLAgentTool is built on every call from the module-level settings.
    sql_agent = SQLAgentTool(
        llm=sqlagent_llm,
        sqldb_directory=sqldb_directory,
        llm_temerature=sqlagent_llm_temperature,
    )
    return sql_agent.chain.invoke({"question": query})

In [35]:
# Smoke test: send a simple question through the @tool wrapper and print the answer.
message = "Give me the names of 5 artists from the database"
response = query_digital_media_sqldb.invoke(message)
print(response)

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
The names of 5 artists from the database are:

1. AC/DC
2. Accept
3. Aerosmith
4. Alanis Morissette
5. Alice In Chains
