In [None]:
import os
from pprint import pprint
from typing import List, Optional

import httpx
from dotenv import load_dotenv
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from pyprojroot import here

# Pull settings from the project's .env file; variables already set in the
# environment are left untouched (override=False is python-dotenv's default).
load_dotenv(override=False)

True

**Set the environment variables and load the LLMs**

In [None]:
# The models are served from GitHub Models' OpenAI-compatible inference endpoint;
# the token is read from OPEN_AI_API_KEY (loaded from .env above).
# Read it once and fail with a clear message if it is missing: assigning None
# into os.environ would otherwise raise an unhelpful TypeError.
openai_api_key = os.getenv("OPEN_AI_API_KEY")
if not openai_api_key:
    raise RuntimeError("OPEN_AI_API_KEY is not set; add it to your .env file.")
# Also export it under the standard name for code that builds ChatOpenAI without api_key.
os.environ['OPENAI_API_KEY'] = openai_api_key

GITHUB_MODELS_BASE_URL = "https://models.github.ai/inference"

# Model that writes the SQL query.
sql_agent_llm = ChatOpenAI(
    model="openai/gpt-3.5-turbo",
    temperature=0,
    base_url=GITHUB_MODELS_BASE_URL,
    api_key=openai_api_key,
)
# Model that picks the relevant tables / categories in the strategies below.
table_extractor_llm = ChatOpenAI(
    model="openai/gpt-4o-mini",
    temperature=0,
    base_url=GITHUB_MODELS_BASE_URL,
    api_key=openai_api_key,
)


In [15]:
# Chinook is a single SQLite file under data/; SQLDatabase.from_uri expects a SQLAlchemy URI.
sqldb_directory = here("data/Chinook.db")
chinook_uri = f"sqlite:///{sqldb_directory}"
db = SQLDatabase.from_uri(chinook_uri)

# Sanity checks: dialect, visible tables, and a sample of rows.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

**Prepare the `Table` class**

In [16]:
# NOTE(review): langchain_core.pydantic_v1 is a compatibility shim that LangChain 0.3+
# deprecates in favour of importing from `pydantic` directly; revisit when upgrading.
from langchain_core.pydantic_v1 import BaseModel, Field

# Extraction schema handed to create_extraction_chain_pydantic. The class docstring and
# the field description end up in the tool schema the LLM sees, so edit them with care.
# In Strategy C the same schema carries a category label ("Music"/"Business") instead.
class Table(BaseModel):
    """
    Represents a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.
    """
    name: str = Field(..., description="Name of table in SQL database.")

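### **Strategy A:**

List every usable table name in the system prompt and let the LLM pick all the potentially relevant ones.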

In [None]:
# One table name per line, ready to be pasted into the Strategy A prompt.
usable_tables = db.get_usable_table_names()
table_names = "\n".join(usable_tables)
pprint(table_names)

In [None]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# Strategy A prompt: show the LLM every table name and ask for all plausible matches.
system = (
    "Return the names of ALL the SQL tables that MIGHT be relevant to the user question. "
    "The tables are:\n"
    "\n"
    f"{table_names}\n"
    "\n"
    "Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."
)
table_chain = create_extraction_chain_pydantic(
    pydantic_schemas=Table,
    llm=table_extractor_llm,
    system_message=system,
)
table_chain.invoke(dict(input="What are all the genres of Alanis Morisette songs"))

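### **Strategy B:**

Group the tables into two categories and instruct the LLM to return every table of the category the question belongs to: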

Music:

- "Album"
- "Artist"
- "Genre"
- "MediaType"
- "Playlist"
- "PlaylistTrack"
- "Track"

Business:

- "Customer"
- "Employee"
- "Invoice"
- "InvoiceLine"

In [None]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# Strategy B prompt: route the question to a whole category of tables.
# The prompt has no placeholders, so a plain (non-f) string is used.
system = """You will receive a question.

If the question is about **Music**, return **ALL** these tables:
  - "Album"
  - "Artist"
  - "Genre"
  - "MediaType"
  - "Playlist"
  - "PlaylistTrack"
  - "Track"

If the question is about **Business**, return **ALL** these tables:
  - "Customer"
  - "Employee"
  - "Invoice"
  - "InvoiceLine"

If you are unsure, return the full list of all available tables for both Music and Business categories."""
table_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=table_extractor_llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

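### **Strategy C:**

Ask the LLM only for the category (Music or Business), then map that category to its tables in plain Python.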

- **Step 1: Define the category**

In [17]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic

# The "tables" offered to the LLM here are really the two categories;
# get_tables (next step) maps each category back to concrete tables.
system = (
    "Return the names of the SQL tables that are relevant to the user question. "
    "The tables are:\n"
    "\n"
    "Music\n"
    "Business"
)
category_chain = create_extraction_chain_pydantic(
    pydantic_schemas=Table,
    llm=table_extractor_llm,
    system_message=system,
)

In [19]:
category_chain.invoke(dict(input="What are all the genres of Alanis Morisette songs"))

[Table(name='Music')]

- **Step 2: Map the category to its tables with a Python function**

In [20]:
# Category -> Chinook tables used by Strategy C. Keys are lower-case so the
# category label returned by the LLM can be matched case-insensitively.
CATEGORY_TABLES = {
    "music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables(categories: "List[Table]") -> List[str]:
    """Maps category names to corresponding SQL table names.

    Matching ignores case and surrounding whitespace (the LLM may answer "music"
    or " Music"), unknown categories are skipped, and each table is listed at most
    once even if a category is returned more than once.

    Args:
        categories (List[Table]): Objects with a ``name`` attribute holding a category
            label, as produced by ``category_chain``.

    Returns:
        List[str]: SQL table names for every recognised category, in first-seen order.
        Empty if no category is recognised.
    """
    tables: List[str] = []
    for category in categories:
        tables.extend(CATEGORY_TABLES.get(category.name.strip().lower(), []))
    # dict.fromkeys drops repeats while keeping first-seen order.
    return list(dict.fromkeys(tables))


# Chain the LLM category extraction into the pure-Python category -> tables mapping.
table_chain = category_chain.pipe(get_tables)
table_chain.invoke(dict(input="What are all the genres of Alanis Morisette songs"))

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']

### **Final step:**

**Attach the desired strategy to your SQL agent**

In [21]:
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import create_sql_query_chain
from operator import itemgetter

query_chain = create_sql_query_chain(sql_agent_llm, db)
# Adapt the {"question": ...} payload to the {"input": ...} key table_chain expects.
# Bound to a NEW name: rebinding table_chain to itself would wrap it again on every
# re-run of this cell, and the second itemgetter("question") would then raise KeyError.
question_table_chain = {"input": itemgetter("question")} | table_chain
# Add table_names_to_use (the tables chosen by the strategy) before generating SQL.
full_chain = RunnablePassthrough.assign(table_names_to_use=question_table_chain) | query_chain

**Test the agent**

In [22]:
# Generate (but do not yet execute) the SQL for a sample question.
query = full_chain.invoke(dict(question="What are all the genres of Alanis Morisette songs"))
print(query)

SELECT DISTINCT "Genre"."Name"
FROM "Genre"
JOIN "Track" ON "Genre"."GenreId" = "Track"."GenreId"
JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId"
JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId"
WHERE "Artist"."Name" = 'Alanis Morissette'
LIMIT 5;


In [23]:
# Executes the LLM-generated SQL as-is; review it (printed above) before running on real data.
db.run(query, fetch="all")

"[('Rock',)]"

: 

**Prepare the tool (don't run the following cell: it reads its settings from `TOOLS_CFG`, which is not defined in this notebook)**

In [None]:
class ChinookSQLAgent:
    """
    Text-to-SQL helper for the Chinook SQLite database.

    A first LLM call classifies the question as "Music" and/or "Business"; `get_tables`
    maps those categories to Chinook tables, and only those tables are passed as
    `table_names_to_use` to `create_sql_query_chain`, which writes the SQL. `full_chain`
    returns the SQL string; executing it is left to the caller (e.g. `self.db.run(...)`).

    Attributes:
        sql_agent_llm (ChatOpenAI): Model used for both category extraction and SQL generation.
        db (SQLDatabase): Connection to the Chinook SQLite file.
        full_chain (Runnable): Takes {"question": str} and produces a SQL query string.
    """

    def __init__(
        self,
        sqldb_directory: str,
        llm: str,
        llm_temerature: float,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """Initializes the ChinookSQLAgent with the LLM and database connection.

        Args:
            sqldb_directory (str): Path to the SQLite database file (a file path,
                despite the parameter name).
            llm (str): Model identifier passed to ChatOpenAI (e.g. "openai/gpt-3.5-turbo").
            llm_temerature (float): Sampling temperature for the model. The misspelled
                name is kept because callers pass it by keyword.
            base_url (Optional[str]): OpenAI-compatible endpoint, e.g.
                "https://models.github.ai/inference". When omitted, ChatOpenAI's own
                default is used.
            api_key (Optional[str]): API key for that endpoint. When omitted, ChatOpenAI
                resolves the key itself (presumably from OPENAI_API_KEY — confirm).
        """
        # Forward endpoint settings only when given so ChatOpenAI keeps its defaults otherwise.
        llm_kwargs = {"model": llm, "temperature": llm_temerature}
        if base_url is not None:
            llm_kwargs["base_url"] = base_url
        if api_key is not None:
            llm_kwargs["api_key"] = api_key
        self.sql_agent_llm = ChatOpenAI(**llm_kwargs)

        self.db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
        print(self.db.get_usable_table_names())
        # Built without an indented triple-quoted string so no stray leading
        # whitespace is embedded in the prompt sent to the LLM.
        category_chain_system = (
            "Return the names of the SQL tables that are relevant to the user question. "
            "The tables are:\n"
            "\n"
            "Music\n"
            "Business"
        )
        category_chain = create_extraction_chain_pydantic(
            Table, self.sql_agent_llm, system_message=category_chain_system)
        table_chain = category_chain | get_tables  # noqa
        query_chain = create_sql_query_chain(self.sql_agent_llm, self.db)
        # Convert "question" key to the "input" key expected by current table_chain.
        table_chain = {"input": itemgetter("question")} | table_chain
        # Set table_names_to_use using table_chain.
        self.full_chain = RunnablePassthrough.assign(
            table_names_to_use=table_chain) | query_chain


@tool
def query_chinook_sqldb(query: str) -> str:
    """Query the Chinook SQL Database. Input should be a search query."""
    # The docstring above is the tool description the calling LLM sees (via @tool),
    # so it stays short; implementation notes live in these comments.
    # TOOLS_CFG is not defined in this notebook — presumably the project's tools
    # config object; confirm it exposes the three attributes read below.
    # NOTE(review): a new agent (database connection + LLM client) is built on every
    # call; consider caching it if the tool is invoked frequently.
    agent = ChinookSQLAgent(
        sqldb_directory=TOOLS_CFG.chinook_sqldb_directory,
        llm=TOOLS_CFG.chinook_sqlagent_llm,
        llm_temerature=TOOLS_CFG.chinook_sqlagent_llm_temperature,
    )

    # Separate name so the user's question and the generated SQL are not conflated.
    sql_query = agent.full_chain.invoke({"question": query})

    # NOTE(review): this executes LLM-generated SQL without review; consider opening
    # the database read-only.
    return agent.db.run(sql_query)