**Reference: https://python.langchain.com/v0.1/docs/use_cases/sql/large_db/**

What happens in this notebook:

### **Table Model Definition**
   - **`Table` Class**: A simple Pydantic model representing a SQL table. It has a single string attribute, `name`, described as "Name of table in SQL database."
     - The extraction step populates this model from the user's question. In this notebook the extracted `name` values are the category labels (`"Music"` or `"Business"`), which `get_tables` then maps to actual table names.

### **Helper Function - `get_tables`**
   - **`get_tables`**: This function takes a list of `Table` objects whose `name` is a category (such as "Music" or "Business") and returns the SQL table names that belong to those categories.
     - For example, if the category is `"Music"`, the tables `"Album"`, `"Artist"`, `"Genre"`, etc., are added to the result.
     - Similarly, for `"Business"`, the corresponding tables like `"Customer"`, `"Employee"`, etc., are included.
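
The category-to-table mapping described above can also be written as a lookup table. The following is a minimal sketch, not the notebook's implementation: the `CATEGORY_TABLES` dictionary and the dataclass version of `Table` are stand-ins introduced here for illustration (the notebook uses a Pydantic model and an `if`/`elif` chain), and skipping repeated tables is an added assumption.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Table:
    """Stand-in for the notebook's Pydantic model; `name` holds a category."""
    name: str


# Hypothetical lookup table: category name -> SQL tables in the Chinook schema.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables(categories: List[Table]) -> List[str]:
    """Return the tables for the given categories, in order and without repeats."""
    tables: List[str] = []
    for category in categories:
        # Unknown categories contribute no tables instead of raising.
        for table in CATEGORY_TABLES.get(category.name, []):
            if table not in tables:
                tables.append(table)
    return tables


print(get_tables([Table(name="Business")]))
```

With a dictionary, adding a new category means adding one entry rather than another `elif` branch.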

### **Designing the agent for the large DB**

- **Step 1: Initialize LLM (`sql_agent_llm`)**: The LLM (named `llm` in the cells below) is instantiated with a given model (e.g., `"gpt-3.5-turbo"`) and temperature. The temperature controls how random the model's responses are; the cells below use `temperature=0`.
- **Step 2: Connect to the SQL Database (`db`)**: The connection to the Chinook SQLite database is established. The database URI is constructed using the `sqldb_directory` provided.
- **Step 3: Define Category Chain (`category_chain`)**: The `category_chain_system` prompt (named `system` in the cells below) is defined. It is a string listing the available categories ("Music" and "Business"). The category chain uses it to decide which categories, and therefore which SQL tables, are relevant to the user query.
- **Step 4: Chain Creation**:
  - **`category_chain`**: This uses the `create_extraction_chain_pydantic` function, which creates an extraction chain that uses the LLM to pull the relevant categories out of the user's question and return them as `Table` objects.
  - **`table_chain`**: A chain formed by piping the output of `category_chain` into the `get_tables` function, so the categories are mapped to the actual SQL table names.
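- **Step 5: Query Chain (`query_chain`)**: `create_sql_query_chain` builds a chain from the LLM and the database (`self.db`) that writes a SQL query for the question, using only the tables passed in `table_names_to_use`. It generates the query text; it does not run it.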
- **Step 6: Table Chain Input Handling**: The `"question"` key from the user input is mapped to the `"input"` key expected by the `table_chain`. This enables the chain to process user queries correctly.
- **Step 7: Full Chain Construction**: Finally, the full chain (`full_chain`) is created by combining:
  1. **`RunnablePassthrough.assign`**: Adds a `table_names_to_use` key to the input, computed by `table_chain`, while passing the original keys through unchanged.
  2. **`query_chain`**: Generates the SQL query from the question and the selected tables. The query is then executed separately with `db.run(query)`.
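
The data flow of Steps 6 and 7 can be sketched in plain Python without calling an LLM. This is only an illustration under stated assumptions: `fake_table_chain`, `fake_query_chain`, and `assign` are hypothetical stand-ins written for this sketch, not LangChain APIs, and the real `table_chain` and `query_chain` would call the model instead of the hard-coded logic shown here.

```python
from operator import itemgetter
from typing import Callable, List


def fake_table_chain(inputs: dict) -> List[str]:
    """Stand-in for the LLM-backed table_chain; expects an "input" key."""
    question = inputs["input"].lower()
    return ["Customer", "Invoice"] if "sales" in question else ["Artist", "Album"]


def fake_query_chain(inputs: dict) -> str:
    """Stand-in for query_chain; a real chain would ask the LLM to write SQL."""
    tables = ", ".join(inputs["table_names_to_use"])
    return f"-- question: {inputs['question']}\n-- tables offered to the LLM: {tables}"


def assign(**steps: Callable) -> Callable:
    """Mimic RunnablePassthrough.assign: keep the input dict and add computed keys."""
    def run(inputs: dict) -> dict:
        return {**inputs, **{key: step(inputs) for key, step in steps.items()}}
    return run


def full_chain(inputs: dict) -> str:
    # Step 6: remap "question" -> "input" before calling the table chain.
    def pick_tables(d: dict) -> List[str]:
        return fake_table_chain({"input": itemgetter("question")(d)})

    # Step 7: add table_names_to_use, then hand everything to the query chain.
    return fake_query_chain(assign(table_names_to_use=pick_tables)(inputs))


print(full_chain({"question": "Total sales per country?"}))
```

The point of the `assign` step is that the original `"question"` key survives alongside the new `table_names_to_use` key, so the query step sees both.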

In [1]:
import os
from dotenv import load_dotenv
from pyprojroot import here
from typing import List
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
from langchain_community.agent_toolkits import create_sql_agent
from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field

load_dotenv()

True

**Set the environment variables and load the LLM**

In [2]:
os.environ['OPENAI_API_KEY'] = os.getenv("OPEN_AI_API_KEY")

llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# llm = ChatOpenAI(model="gpt-4o")

In [3]:
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [4]:
import ast

# db.run returns a string, so parse it back into Python objects before counting.
# Getting the number of rows in the table
num_rows = ast.literal_eval(db.run("SELECT COUNT(*) FROM InvoiceLine;"))[0][0]

# Getting the number of columns in the table (PRAGMA table_info returns one row per column)
columns_info = ast.literal_eval(db.run("PRAGMA table_info(InvoiceLine);"))
num_columns = len(columns_info)

# Printing the shape of the table
print(f"The shape of InvoiceLine is: ({num_rows}, {num_columns})")

The shape of InvoiceLine is: (2240, 5)


In [5]:
class Table(BaseModel):
    """
    Represents a table in the SQL database.

    Attributes:
        name (str): The name of the table in the SQL database.
    """
    name: str = Field(description="Name of table in SQL database.")

In [11]:
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
category_chain = create_extraction_chain_pydantic(pydantic_schemas=Table, llm=llm, system_message=system)

In [12]:
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music'), Table(name='Business')]

In [8]:
def get_tables(categories: List[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables  # noqa
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album',
 'Artist',
 'Genre',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track',
 'Customer',
 'Employee',
 'Invoice',
 'InvoiceLine']

In [9]:
query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

In [10]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SELECT "Name"
FROM "Genre"
WHERE "GenreId" IN (
    SELECT "GenreId"
    FROM "Track"
    WHERE "AlbumId" IN (
        SELECT "AlbumId"
        FROM "Album"
        WHERE "ArtistId" = (
            SELECT "ArtistId"
            FROM "Artist"
            WHERE "Name" = 'Alanis Morissette'
        )
    )
)


In [11]:
db.run(query)

"[('Rock',)]"

In [12]:
query = full_chain.invoke(
    {"question": "list the total sales per country. Which country's customers spent the most?"}
)
print(query)

SELECT c.Country, SUM(i.Total) AS TotalSales
FROM Customer c
JOIN Invoice i ON c.CustomerId = i.CustomerId
GROUP BY c.Country
ORDER BY TotalSales DESC
LIMIT 1;


In [13]:
db.run(query)

"[('USA', 523.06)]"

In [28]:
query = full_chain.invoke(
    {"question": "What is the set of all unique genres of Alanis Morisette songs"}
)
print(query)

SELECT DISTINCT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette'


In [29]:
db.run(query)

"[('Rock',)]"