In [1]:
from pyprojroot import here
from langchain_community.utilities import SQLDatabase
from langchain_ollama import ChatOllama
from pprint import pprint

#### Load the LLM

In [2]:
# Separating models so they can be changed as needed
sql_agent_llm = ChatOllama(model="qwen2.5:14b", temperature=0)
table_extractor_llm = ChatOllama(model="qwen2.5:14b", temperature=0)

In [3]:
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

#### Prepare Table class

In [4]:
from typing import Optional

from pydantic import BaseModel, Field


class Table(BaseModel):
    """
    Represent a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.

    # Note that:
    # 1. Field is an `optional` -- this allows the model to decline to extract it!
    """
    name: Optional[str] = Field(description="Name of table in SQL database.")

#### Strategy A:

In [5]:
table_names = "\n".join(db.get_usable_table_names())
pprint(table_names)

('Album\n'
 'Artist\n'
 'Customer\n'
 'Employee\n'
 'Genre\n'
 'Invoice\n'
 'InvoiceLine\n'
 'MediaType\n'
 'Playlist\n'
 'PlaylistTrack\n'
 'Track')


In [6]:
from langchain_core.prompts import ChatPromptTemplate


system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt_template = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{input}")
])

structured_llm = table_extractor_llm.with_structured_output(schema=Table)
input = "What are all the genres of Alanis Morisette songs"
structured_llm.invoke(prompt_template.invoke({"input": input}))

Table(name='Track')

#### Strategy B:

In [7]:
from langchain_core.prompts import ChatPromptTemplate


system = f"""You will recieve a question.

If the question is about **Music**, return **ALL** these tables:
  - "Album"
  - "Artist"
  - "Genre"
  - "MediaType"
  - "Playlist"
  - "PlaylistTrack"
  - "Track"

If the question is about **Business**, return **ALL** these tables:
  - "Customer"
  - "Employee"
  - "Invoice"
  - "InvoiceLine"

If you are unsure, return the full list of all available tables for both Music and Business categories."""

prompt_template = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "{question}")
])

structured_llm = table_extractor_llm.with_structured_output(schema=Table)
input = "What are all the genres of Alanis Morisette songs"
structured_llm.invoke(prompt_template.invoke({"question": input}))

Table(name='Track')

#### Final step:

In [10]:
from langchain_core.runnables import RunnablePassthrough
from langchain import hub
from typing_extensions import TypedDict
from typing_extensions import Annotated

class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

def write_query(state: State):
    prompt = query_prompt_template.invoke({
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    })
    structured_llm = table_extractor_llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return result

table_chain = prompt_template | table_extractor_llm.with_structured_output(Table)

# Create full chain
full_chain = (
    RunnablePassthrough.assign(
        table_names_to_use=lambda x: table_chain.invoke({"question": x["question"]})
    ) 
    | write_query
)

# Test the chain
question = "Give me the names of 5 artists from the database"
query = full_chain.invoke({"question": question})
print(query)

{'query': 'SELECT DISTINCT Name FROM Artist LIMIT 5;'}


In [12]:

db.run(query['query'])

"[('AC/DC',), ('Accept',), ('Aerosmith',), ('Alanis Morissette',), ('Alice In Chains',)]"

In [16]:
from langchain.tools import tool


class ChinookSQLAgent:
    """
    A specialized SQL agent that interacts with the Chinook SQL database using an LLM (Large Language Model).

    The agent handles SQL queries by mapping user questions to relevant SQL tables based on categories like "Music"
    and "Business". It uses an extraction chain to determine relevant tables based on the question and then
    executes queries against the database using the appropriate tables.

    Attributes:
        sql_agent_llm (ChatOpenAI): The language model used for interpreting and interacting with the database.
        db (SQLDatabase): The SQL database object, representing the Chinook database.
        full_chain (Runnable): A chain of operations that maps user questions to SQL tables and executes queries.

    Methods:
        __init__: Initializes the agent by setting up the LLM, connecting to the SQL database, and creating query chains.

    Args:
        sqldb_directory (str): The directory where the Chinook SQLite database file is located.
        llm (str): The name of the LLM model to use (e.g., "gpt-3.5-turbo").
        llm_temperature (float): The temperature setting for the LLM, controlling the randomness of responses.
    """

    def __init__(self, sqldb_directory: str, llm: str, llm_temerature: float) -> None:
        """Initializes the ChinookSQLAgent with the LLM and database connection.

        Args:
            sqldb_directory (str): The directory path to the SQLite database file.
            llm (str): The LLM model identifier (e.g., "gpt-3.5-turbo").
            llm_temerature (float): The temperature value for the LLM, determining the randomness of the model's output.
        """
        self.sql_agent_llm = ChatOllama(
            model=llm, temperature=llm_temerature)

        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print(self.db.get_usable_table_names())
        category_chain_system = """Return the names of the SQL tables that are relevant to the user question. \
        The tables are:

        Music
        Business"""     

        prompt_template = ChatPromptTemplate.from_messages([
            ("system", category_chain_system),
            ("human", "{question}")
        ])

        table_chain = prompt_template | table_extractor_llm.with_structured_output(Table)

        # Create full chain
        self.full_chain = (
            RunnablePassthrough.assign(
                table_names_to_use=lambda x: table_chain.invoke({"question": x["question"]})
            ) 
            | self.__write_query
        )

    class State(TypedDict):
        question: str
        query: str
        result: str
        answer: str


    class QueryOutput(TypedDict):
        """Generated SQL query."""
        query: Annotated[str, ..., "Syntactically valid SQL query."]


    def __write_query(self, state: State):
        query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")
        prompt = query_prompt_template.invoke({
            "dialect": self.db.dialect,
            "top_k": 10,
            "table_info": self.db.get_table_info(),
            "input": state["question"],
        })
        structured_llm = self.sql_agent_llm.with_structured_output(QueryOutput)
        result = structured_llm.invoke(prompt)
        return result['query']


@tool
def query_chinook_sqldb(query: str) -> str:
    """Query the Chinook SQL Database. Input should be a search query."""
    # Create an instance of ChinookSQLAgent
    agent = ChinookSQLAgent(
        sqldb_directory=here("data/Chinook.db"),
        llm="qwen2.5:14b",
        llm_temerature=0.5
    )

    query = agent.full_chain.invoke({"question": query})

    return agent.db.run(query)

In [17]:
result = query_chinook_sqldb("What are all the genres of Alanis Morisette songs")
print(result)

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(38,), (39,), (40,), (41,), (42,), (43,), (44,), (45,), (46,), (47,), (48,), (49,), (50,)]
