# How to deal with large databases when doing SQL question-answering

When using LangChain to ask a language model to write SQL queries, the model needs to know your database's structure. This means providing table names, how the tables are organized (schemas), and the types of data stored in the columns.

However, real-world databases can be huge. Directly copying all that information into every prompt is impractical. Instead, we need a smart way to give the model only the essential details it needs for the specific query.

This guide shows you how to use LangChain to:

1. Figure out which tables are actually relevant to the user's question, so we don't overwhelm the model with information about tables it doesn't need.

2. Identify which distinct values in high-cardinality columns (such as artist or album names) are relevant to the question, so the model spells filter values correctly in its query.

In essence, we'll learn how to make LangChain intelligently select and provide only the necessary database information to the language model, so it can generate accurate and efficient SQL queries.

# 1. Install necessary dependencies

In [1]:
!apt-get update
!apt-get install -y sqlite3


In [6]:
%pip install --upgrade --quiet langchain langchain-community langchain-experimental langchain-openai faiss-cpu


# 2. Set up Colab environment

This code snippet sets up an environment in Google Colab to create and populate an SQLite database called Chinook.db. It does this by:

1. Mounting your Google Drive.
2. Creating a specific directory in your Google Drive to store the database.
3. Downloading a SQL script from a GitHub repository.
4. Using the sqlite3 command-line tool to create the database and execute the SQL script, thereby populating the database with data.

This is a typical workflow for setting up a data science environment in Google Colab, where you need to access and create files in your Google Drive and use command-line tools to interact with data.

### Mounts Google Drive

a. from google.colab import drive: Imports the drive module from the google.colab library, which is specific to Google Colaboratory.

b. drive.mount('/content/drive'): This line mounts your Google Drive to the Colab runtime. After executing this, you'll be prompted to authorize Colab to access your Google Drive. Once authorized, your Drive files become accessible within the Colab environment under the /content/drive directory.

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Sets up the directory

a. import os: Imports the os module for interacting with the operating system, specifically for file and directory operations.

b. directory_path = '/content/drive/MyDrive/Colab Notebooks/Chinook': Defines the path to a directory within your Google Drive. This is where the SQLite database will be created.

c. if not os.path.exists(directory_path): os.makedirs(directory_path): This checks if the specified directory exists. If it doesn't, it creates the directory and any necessary parent directories. This is important to ensure that the database file can be created in the correct location.

In [3]:
import os

# Path to the directory within your Google Drive
directory_path = '/content/drive/MyDrive/Colab Notebooks/Chinook'

# Create the directory if it doesn't exist
if not os.path.exists(directory_path):
    os.makedirs(directory_path)

### Downloads and Creates the SQLite Database

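This cell runs a shell pipeline: curl -s silently downloads the Chinook SQL script from GitHub and pipes it into sqlite3, which executes every statement in the script against Chinook.db, creating the database file if it doesn't already exist.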

In [4]:
!curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 '/content/drive/MyDrive/Colab Notebooks/Chinook/Chinook.db'
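If the sqlite3 command-line tool isn't available, the same step can be done with Python's standard library. This is a sketch, not part of the original notebook: fetch_sql_script and build_db_from_script are hypothetical helper names, and sqlite3's executescript plays the role of piping the script into the sqlite3 CLI.

```python
import sqlite3
import urllib.request

# Same URL the curl command above downloads.
CHINOOK_SQL_URL = (
    "https://raw.githubusercontent.com/lerocha/chinook-database/master/"
    "ChinookDatabase/DataSources/Chinook_Sqlite.sql"
)


def fetch_sql_script(url: str = CHINOOK_SQL_URL) -> str:
    """Download a SQL script as text (roughly what `curl -s URL` does)."""
    with urllib.request.urlopen(url) as response:
        # 'utf-8-sig' decodes UTF-8 and drops a leading byte-order mark if there is one.
        return response.read().decode("utf-8-sig")


def build_db_from_script(db_path: str, script: str) -> None:
    """Run every statement in `script` against the SQLite file at `db_path`,
    creating the file if needed (like `... | sqlite3 db_path`)."""
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(script)
        conn.commit()
    finally:
        conn.close()


# Usage in Colab (needs network access):
# build_db_from_script(
#     "/content/drive/MyDrive/Colab Notebooks/Chinook/Chinook.db",
#     fetch_sql_script(),
# )
```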

### Testing the connection

This cell connects to the SQLite database (Chinook.db) using LangChain's SQLDatabase utility, prints the database dialect and the usable table names, and then runs a simple SQL query that returns the first 10 rows of the Artist table. The sample_rows_in_table_info=3 argument controls how many example rows per table are included in the table description that LangChain later passes to the model.

In [8]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:////content/drive/MyDrive/Colab Notebooks/Chinook/Chinook.db", sample_rows_in_table_info=3)
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
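To demystify the two calls above, here is a rough standard-library approximation. It is a sketch, not SQLDatabase's implementation (which goes through SQLAlchemy): table names come from SQLite's sqlite_master catalog, and query results come back as the string form of a list of tuples, the same shape db.run printed above, or an empty string when there are no rows. usable_table_names and run_as_string are hypothetical names, and the two-table database is toy data.

```python
import sqlite3


def usable_table_names(conn: sqlite3.Connection) -> list[str]:
    """List user tables in alphabetical order, skipping SQLite's internal sqlite_* tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
    ).fetchall()
    return sorted(name for (name,) in rows)


def run_as_string(conn: sqlite3.Connection, query: str) -> str:
    """Execute `query`; return str(list of row tuples), or '' when no rows match."""
    rows = conn.execute(query).fetchall()
    return str(rows) if rows else ""


# Toy database with two Chinook-like tables.
demo = sqlite3.connect(":memory:")
demo.executescript(
    "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);"
    "CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);"
    "INSERT INTO Artist VALUES (1, 'AC/DC'), (2, 'Accept');"
)
print(usable_table_names(demo))  # ['Album', 'Artist']
print(run_as_string(demo, "SELECT * FROM Artist LIMIT 10;"))  # [(1, 'AC/DC'), (2, 'Accept')]
```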


# 3. Handling many tables

When using LangChain to generate SQL queries, we need to tell the language model about the structure of our database tables (schemas). With many tables, the full set of schemas may not fit in the model's context window, and even when it does, it wastes tokens and buries the relevant details among irrelevant ones. Instead, we need a smarter way to provide only the relevant table schemas.

LangChain supports this through tool calling, a capability of chat models such as the OpenAI model used here. We first ask the language model to identify the tables the user's question actually needs, and then provide only the schemas of those specific tables.

Here's how we do it:

1. We use LangChain's .bind_tools method to give the language model a tool whose only input is a table name. The tool's schema is defined with a Pydantic model, so every call has a predictable shape, and the model calls the tool once for each table it considers relevant.

2. We use an output parser to turn the model's tool calls into a list of Table objects we can work with in LangChain.

In simpler terms, we're teaching LangChain to ask the language model: "Which tables do I need to know about?" Then, LangChain only provides the schemas of those specific tables, making the process more efficient and accurate.
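The payoff of selecting tables first is that only their schemas reach the prompt. A minimal standard-library sketch of that idea, assuming a plain sqlite3 connection: it fetches the CREATE TABLE statements for just the chosen tables from sqlite_master. In LangChain, SQLDatabase.get_table_info(table_names=[...]) fills this role and also appends sample rows; schema_for_tables below is a hypothetical helper, and the three tables are toy data.

```python
import sqlite3


def schema_for_tables(conn: sqlite3.Connection, table_names: list[str]) -> str:
    """Return the CREATE TABLE statements for only the requested tables."""
    if not table_names:
        return ""
    # One "?" placeholder per name keeps the table names out of the SQL text.
    placeholders = ", ".join("?" for _ in table_names)
    rows = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' "
        f"AND name IN ({placeholders}) ORDER BY name",
        list(table_names),
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);"
    "CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);"
    "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL);"
)
# Only the two music tables end up in the prompt context; Invoice is left out.
print(schema_for_tables(conn, ["Artist", "Genre"]))
```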

### Select chat model

In [30]:
%pip install -qU "langchain[openai]"

In [31]:
import getpass
import os

if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("gpt-4o-mini", model_provider="openai")

### LangChain, Pydantic, and OpenAI tools to identify relevant SQL tables for a given user question

1. Import necessary dependencies

2. Table Pydantic Model

Table(BaseModel): Defines a Pydantic model called Table.

name: str = Field(...): Specifies that the model has a single field named name, which is a string.

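description: Describes the field. When Table is bound as a tool, the class docstring ("Table in SQL database.") becomes the tool's description and this text describes its name argument; together they tell the LLM what the tool is for and what to pass.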

3. Preparing Table Names

db.get_usable_table_names(): db is the SQLDatabase object created earlier. This method returns the list of tables available for querying (the 11 Chinook tables printed above).

"\n".join(...): This joins the list of table names into a single string, with each table name separated by a newline character.

4. Creating System Message

This creates a system message that will be used to instruct the LLM.

It tells the LLM to identify all potentially relevant tables for a user question.

It provides the list of available table names.

It emphasizes that the LLM should include all potentially relevant tables, even if it's unsure.

5. Creating Prompt Template

ChatPromptTemplate.from_messages(...): Creates a chat prompt template with a system message and a human message.

("system", system): Adds the system message created earlier.

("human", "{input}"): Adds a human message with a placeholder for the user's input.

6. Binding Tools and Creating Output Parser

llm_with_tools = llm.bind_tools([Table]): Binds the Table Pydantic model to the LLM as a tool. The LLM can then respond with one or more tool calls whose arguments match the Table schema (a table name).

output_parser = PydanticToolsParser(tools=[Table]): Creates an output parser that converts those tool calls into a list of Table objects, one per call.

7. Creating the Chain

Creates a LangChain runnable chain:

prompt: The prompt template.

llm_with_tools: The LLM with the Table tool bound.

output_parser: The output parser.

8. Invoking the Chain

table_chain.invoke(...): Invokes the chain with the user's question.

{"input": "What are all the genres of Alanis Morisette songs"}: Provides the user's question as input.



In [32]:
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Artist'),
 Table(name='Genre'),
 Table(name='Album'),
 Table(name='Track')]

How it works (simplified):

1. Prompt Generation: The prompt template is filled in with the system message and the user's question.
2. LLM Invocation: The formatted prompt is sent to the LLM.
3. Tool Usage: The LLM calls the Table tool once for each table it considers relevant, passing the table name as the argument.
4. Output Parsing: The PydanticToolsParser converts those tool calls into a list of Table objects.
5. Result: The chain returns a list of Table objects, where each object's name attribute is a table the LLM believes is relevant to the user's question.
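The parsing step can be pictured without LangChain. In this sketch each tool call is a dict with name, args and id keys (the shape of the entries in a LangChain AIMessage's tool_calls list), and the hypothetical parse_tool_calls keeps only calls to the Table tool and turns each one into an object. The ParsedTable dataclass is a stand-in for the Pydantic Table model (named differently so it doesn't shadow it), and unlike PydanticToolsParser it does not validate the arguments. The tool calls themselves are illustrative, not real model output.

```python
from dataclasses import dataclass


@dataclass
class ParsedTable:
    """Stand-in for the Pydantic Table model: one SQL table name."""
    name: str


def parse_tool_calls(tool_calls: list[dict]) -> list[ParsedTable]:
    """Turn each call to the 'Table' tool into a ParsedTable, preserving order."""
    return [ParsedTable(**call["args"]) for call in tool_calls if call["name"] == "Table"]


# Illustrative tool calls a model might emit for the Alanis Morissette question.
calls = [
    {"name": "Table", "args": {"name": "Artist"}, "id": "call_1"},
    {"name": "Table", "args": {"name": "Genre"}, "id": "call_2"},
]
print(parse_tool_calls(calls))  # [ParsedTable(name='Artist'), ParsedTable(name='Genre')]
```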

Summary

This variant asks the LLM to pick table categories rather than individual tables. The system message lists just two categories, Music and Business, and the same Table tool and output parser are reused, so each chosen category comes back wrapped in a Table object. With hundreds of tables, this keeps the first prompt short; a second step then expands each category into its tables. The same pattern can also route questions to different parts of an application.

In [33]:
system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music')]

### LangChain chain to identify relevant SQL tables based on the identified categories

1. get_tables Function

Purpose: This function takes the list of Table objects returned by category_chain, where each Table object's name is a category, and returns the list of actual table names belonging to those categories.

Logic:

It initializes an empty list called tables.

It iterates through the categories list.

For each category:

If the category.name is "Music", it extends the tables list with the names of tables related to music (e.g., "Album", "Artist", "Genre", etc.).

If the category.name is "Business", it extends the tables list with the names of tables related to business (e.g., "Customer", "Employee", "Invoice", etc.).

Finally, it returns the tables list.

2. Creating the Chain

category_chain: The chain created above, which identifies the relevant categories for the user's question.

|: This is the LangChain "pipe" operator, which allows you to chain together LangChain runnables.

get_tables: This is the custom function defined earlier.

This line creates a new chain, table_chain, by piping category_chain into get_tables. LangChain automatically wraps the plain get_tables function in a RunnableLambda, so the output of category_chain (a list of Table objects) is passed as input to get_tables.

3. Invoking the Chain

table_chain.invoke(...) runs the chain on the same question. The model selects the Music category, so the result is the list of music-related tables.

Summary

This code identifies the relevant SQL tables in two steps: a chain that picks categories, followed by a custom function that expands those categories into table names. This lets you decide dynamically which tables a question needs without listing every table in the prompt.

In [51]:
from typing import List


def get_tables(categories: List[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
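One design choice worth noting: get_tables hard-codes the mapping in if/elif branches and, if the model ever returns the same category twice, repeats its tables. A table-driven variant keeps the mapping in a single dict and de-duplicates while preserving order. This is a sketch; CATEGORY_TABLES and get_tables_from_map are hypothetical names, and a small Category dataclass stands in for the Table objects (only .name is used) so the block runs on its own.

```python
from dataclasses import dataclass


@dataclass
class Category:
    """Stand-in for the Table objects returned by category_chain (only .name is used)."""
    name: str


# Category -> tables, mirroring the branches in get_tables above.
CATEGORY_TABLES = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def get_tables_from_map(categories: list[Category]) -> list[str]:
    """Expand categories into table names, ignoring unknown categories and duplicates."""
    # dict.fromkeys keeps the first occurrence of each table, in order.
    return list(
        dict.fromkeys(
            table
            for category in categories
            for table in CATEGORY_TABLES.get(category.name, [])
        )
    )


print(get_tables_from_map([Category("Music"), Category("Music")]))
# ['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
```

With this layout, adding a category means editing one dict entry rather than adding another branch.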

### Define the LangChain chain

1. Import necessary dependencies

2. query_chain Creation

This line creates a query_chain using the create_sql_query_chain function.

llm: The chat model initialized earlier.

db: The SQLDatabase object created earlier.

This chain takes a natural language question and generates an SQL query as a string. It does not execute the query; to get results, pass the query to db.run.

3. table_chain Modification

table_chain: The chain created in the previous step, which identifies relevant tables based on the user's question.

{"input": itemgetter("question")}: This creates a dictionary that maps the "question" key in the input to the "input" key expected by the table_chain.

itemgetter("question"): This creates a callable object that retrieves the value associated with the "question" key from the input dictionary.

|: This is the LangChain "pipe" operator, which chains together runnables.

This line modifies the table_chain so that it expects the input question to be under the "question" key instead of the "input" key.

4. full_chain Creation

RunnablePassthrough.assign(...): Passes the input dictionary through unchanged and adds new keys, each computed by running the given runnable on that input.

table_names_to_use=table_chain: Adds the output of table_chain (a list of table names) under the "table_names_to_use" key. create_sql_query_chain reads this optional key and includes only those tables' schemas in the prompt.

| query_chain: This chains the RunnablePassthrough with the query_chain.

Summary

This code creates a LangChain chain that dynamically selects the relevant tables based on the user's question, and then generates an SQL query from a prompt containing only those tables' schemas. This allows the system to handle questions that involve different tables or categories of data.

In [52]:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

In [53]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SQLQuery: 
```sql
SELECT DISTINCT "Genre"."Name"
FROM "Track"
JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId"
JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId"
JOIN "Genre" ON "Track"."GenreId" = "Genre"."GenreId"
WHERE "Artist"."Name" = 'Alanis Morissette'
LIMIT 5;
```
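Note that the printed output is not bare SQL: with the default prompt, the model prefixed it with "SQLQuery:" and wrapped it in a Markdown code fence, so passing it straight to db.run would fail. (The custom prompt in section 4 avoids this by asking for the query with no markup.) A minimal cleanup sketch, assuming the output has the shape shown above; strip_sql_markup is a hypothetical helper, not a LangChain function.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to avoid a literal fence here


def strip_sql_markup(text: str) -> str:
    """Remove a leading 'SQLQuery:' label and surrounding Markdown code fences."""
    text = text.strip()
    text = re.sub(r"^SQLQuery:\s*", "", text)
    if text.startswith(FENCE):
        # Drop the whole opening fence line, including a language tag such as 'sql'.
        text = text.split("\n", 1)[1] if "\n" in text else ""
    if text.endswith(FENCE):
        text = text[: -len(FENCE)]
    return text.strip()


raw = "SQLQuery: \n" + FENCE + 'sql\nSELECT DISTINCT "Genre"."Name" FROM "Genre" LIMIT 5;\n' + FENCE
print(strip_sql_markup(raw))  # SELECT DISTINCT "Genre"."Name" FROM "Genre" LIMIT 5;
```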


# 4. Handling High-cardinality columns

In order to filter on columns that contain proper nouns, such as addresses, song names, or artists, we first need to double-check the spelling; otherwise a filter like WHERE Name = 'Alanis Morisette' silently matches nothing.

One naive strategy is to create a vector store with all the distinct proper nouns that exist in the database. We can then query that vector store for each user input and inject the most relevant proper nouns into the prompt.
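To see the spelling check itself in isolation, the sketch below swaps the vector store for a purely lexical technique, Python's difflib.get_close_matches, which maps a misspelled user term onto known proper nouns by string similarity. It illustrates the idea under that substitution and is not the approach this guide uses; closest_proper_nouns is a hypothetical helper, and the short noun list is a sample taken from the Artist output shown earlier.

```python
import difflib


def closest_proper_nouns(term: str, proper_nouns: list[str], n: int = 3) -> list[str]:
    """Return up to n known proper nouns that look most like `term` (case-insensitive)."""
    # Compare lowercase forms, but return the original spelling.
    by_lower = {noun.lower(): noun for noun in proper_nouns}
    matches = difflib.get_close_matches(term.lower(), by_lower.keys(), n=n, cutoff=0.6)
    return [by_lower[m] for m in matches]


nouns = ["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains"]
print(closest_proper_nouns("elenis moriset", nouns, n=1))  # ['Alanis Morissette']
```

String similarity only catches character-level typos; the vector-store approach below ranks candidates by embedding similarity instead.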

### Create the list of proper nouns from Artist , Album and Genre

1. Import necessary dependencies

2. query_as_list Function

Purpose: This function takes a database object (db) and an SQL query (query) as input and returns a cleaned list of strings from the query results.

Logic:

res = db.run(query): Executes the SQL query against the database and stores the result in res. db.run returns the rows as a string, such as "[('AC/DC',), ('Accept',)]", which represents a list of tuples.

res = [el for sub in ast.literal_eval(res) for el in sub if el]:

ast.literal_eval(res): Safely evaluates the string res as a Python literal.
This converts the string representation of the rows into an actual Python list of tuples.

[el for sub in ... for el in sub if el]: Flattens the list of tuples into a single list. It iterates through each row and then through each element (el) in the row, keeping el only if it is truthy (so None and empty strings are skipped).

res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]:

re.sub(r"\b\d+\b", "", string): Uses a regular expression to remove standalone numbers (whole numbers) from each string in the list.

r"\b\d+\b": The regular expression pattern.

\b: Matches a word boundary.

\d+: Matches one or more digits.

\b: Matches another word boundary.

"": Replaces the matched numbers with an empty string.

.strip(): Removes leading and trailing whitespace from each string.

return res: Returns the cleaned list of strings.

3. Retrieving Proper Nouns

This code retrieves proper nouns from three different tables in the database: Artist, Album, and Genre.

For each table:

It calls the query_as_list function to execute a SELECT query and clean the results.

It appends the cleaned results to the proper_nouns list.

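4. Length and First 5 Elements

len(proper_nouns): Computes the number of proper nouns retrieved. Because a notebook cell only displays its last expression, this value is not shown in the same cell; it is displayed in the next cell.

proper_nouns[:5]: Displays the first 5 elements of the proper_nouns list.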

Summary:

This code snippet retrieves data from the database, specifically names of artists, titles of albums, and names of genres. It cleans the data by removing standalone numbers and leading/trailing whitespace, and then combines the results into a single list of proper nouns.

In [63]:
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
len(proper_nouns)
proper_nouns[:5]

['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']

In [69]:
len(proper_nouns)

647
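One edge case in query_as_list: when a query returns no rows, db.run returns an empty string (as in the last cell of this guide), and ast.literal_eval('') raises SyntaxError. Below is a hedged variant of the parsing step that guards against this and works on the string directly, so it can be tried without a database; parse_run_output is a hypothetical name.

```python
import ast
import re


def parse_run_output(res: str) -> list[str]:
    """Flatten a db.run-style string like "[('AC/DC',), ('Accept',)]" into cleaned strings."""
    if not res:  # no rows: db.run returns '' and literal_eval('') would raise
        return []
    # Flatten the list of row tuples, skipping None and empty strings.
    values = [el for row in ast.literal_eval(res) for el in row if el]
    # Remove standalone numbers (e.g. "No. 5" -> "No.") and surrounding whitespace.
    return [re.sub(r"\b\d+\b", "", str(v)).strip() for v in values]


print(parse_run_output("[('AC/DC',), ('Symphony No. 5',), (None,)]"))  # ['AC/DC', 'Symphony No.']
print(parse_run_output(""))  # []
```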

### Creating embeddings from proper noun list and storing in Vector store

1. Import necessary dependencies

2. Creating the Vector Database

FAISS.from_texts(...): A class method of FAISS that embeds a list of texts and builds a FAISS vector store from them.

proper_nouns: This is the list of proper nouns (artist names, album titles, genre names) that we retrieved and cleaned in the previous explanation.

OpenAIEmbeddings(): This creates an instance of the OpenAIEmbeddings class, which will be used to generate embeddings for the proper nouns.

3. Creating the Retriever

vector_db.as_retriever(...): This method creates a retriever object from the FAISS vector database.

search_kwargs={"k": 15}: This specifies the search parameters for the retriever.

k=15: This means that when the retriever is used to search for similar texts, it will return the top 15 most similar results.



In [64]:
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
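To see what a retriever with k=15 does, here is a toy, dependency-free stand-in for the FAISS store. It "embeds" each string as a bag of character trigrams (a swapped-in technique, not OpenAI embeddings) and returns the k entries with the highest cosine similarity to the query. ToyVectorStore, embed and cosine are hypothetical names; real embeddings capture much more than character overlap.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: counts of lowercase character trigrams, padded with spaces."""
    padded = f" {text.lower()} "
    return Counter(padded[i : i + 3] for i in range(len(padded) - 2))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[g] * b[g] for g in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class ToyVectorStore:
    """Stores (text, vector) pairs and returns the k texts most similar to a query."""

    def __init__(self, texts: list[str]):
        self.texts = list(texts)
        self.vectors = [embed(t) for t in self.texts]

    def similarity_search(self, query: str, k: int = 4) -> list[str]:
        q = embed(query)
        ranked = sorted(
            zip(self.texts, self.vectors),
            key=lambda pair: cosine(q, pair[1]),
            reverse=True,
        )
        return [text for text, _ in ranked[:k]]


store = ToyVectorStore(["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains", "Rock"])
print(store.similarity_search("elenis moriset", k=1))  # ['Alanis Morissette']
```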

### Define the final chain

1. Import necessary dependencies

2. System Message

This defines the system message for the LLM.

It instructs the LLM to generate syntactically correct SQL queries.

It specifies that the LLM should only return the SQL query, without any additional explanations.

It includes placeholders for top_k, table_info, and proper_nouns.

Key Addition: It now includes a list of proper_nouns that the LLM should use to check the spelling of feature values when filtering. This helps to improve the accuracy of the generated queries by reducing spelling errors.

3. Prompt Template

This creates a ChatPromptTemplate with the system message and a placeholder for the user's input.

4. query_chain Creation

This creates a query_chain using the create_sql_query_chain function.

llm: The chat model initialized earlier.

db: The SQLDatabase object created earlier.

prompt: The prompt template created earlier.

This chain takes a natural language question and generates an SQL query as a string; it does not execute the query.

5. retriever_chain Creation

This creates a retriever_chain that retrieves relevant proper nouns from the vector database.

itemgetter("question"): Retrieves the "question" from the input dictionary.

| retriever: Passes the question to the retriever created in the previous step, which returns the 15 most similar proper nouns as documents.

| (lambda docs: "\n".join(doc.page_content for doc in docs)): Takes the retrieved documents and concatenates their page_content into a single string, separated by newlines.

6. chain Creation

This creates the main chain by combining the retriever_chain and the query_chain.

RunnablePassthrough.assign(proper_nouns=retriever_chain): This assigns the output of the retriever_chain (the concatenated proper nouns) to the "proper_nouns" key in the input dictionary.

| query_chain: This passes the updated input dictionary to the query_chain.

How it works (full chain)

Input: The chain receives a dictionary with a "question" key containing the natural language question.

Proper Noun Retrieval: The retriever_chain retrieves relevant proper nouns from the vector database and concatenates them into a single string.

Proper Noun Assignment: The RunnablePassthrough assigns the concatenated proper nouns to the "proper_nouns" key in the input dictionary.

Query Generation: The query_chain is executed. The value under the "proper_nouns" key fills the {proper_nouns} placeholder in the system message, and the chain generates an SQL query based on the question.

Result: The chain returns the SQL query as a string. Running it against the database is a separate step (db.run(query)), as shown below.

In [65]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

system = """You are a SQLite expert. Given an input question, create a syntactically
correct SQLite query to run. Unless otherwise specified, do not return more than
{top_k} rows.

Only return the SQL query with no markup or explanation.

Here is the relevant table info: {table_info}

Here is a non-exhaustive list of possible feature values. If filtering on a feature
value make sure to check its spelling against this list first:

{proper_nouns}
"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain
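In plain Python, RunnablePassthrough.assign(proper_nouns=retriever_chain) behaves roughly like the hypothetical assign helper below: each new key is computed from the whole input dict, and the result is the input plus the new keys. The fake retriever and the prompt fragment are assumptions for illustration (the fake results are two names from the Artist table); only the data flow mirrors the chain above.

```python
from typing import Any, Callable


def assign(inputs: dict, **steps: Callable[[dict], Any]) -> dict:
    """Return a copy of `inputs` with each extra key set to step(inputs)."""
    return {**inputs, **{key: step(inputs) for key, step in steps.items()}}


def fake_retriever(question: str) -> list[str]:
    # Stand-in for `retriever`: pretend these are the nearest proper nouns.
    return ["Alanis Morissette", "Alice In Chains"]


def toy_retriever_chain(inputs: dict) -> str:
    # Same steps as the LCEL version: pick the question, retrieve, join with newlines.
    return "\n".join(fake_retriever(inputs["question"]))


state = assign(
    {"question": "What are all the genres of elenis moriset songs"},
    proper_nouns=toy_retriever_chain,
)
system_fragment = "Check its spelling against this list first:\n\n{proper_nouns}"
print(system_fragment.format(**state))
```

Because the original keys pass through, query_chain still receives question alongside the new proper_nouns value.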

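### Invoke the chain

The two cells below generate a query for a deliberately misspelled artist name ("elenis moriset"): first by calling query_chain directly with an empty proper_nouns list (no retrieval), then by calling the full chain, which fills proper_nouns from the retriever.

In [67]:
# Without retrieval: proper_nouns is left empty
query = query_chain.invoke(
    {"question": "What are all the genres of elenis moriset songs", "proper_nouns": ""}
)
print(query)
db.run(query)

SELECT DISTINCT g.Name 
FROM Genre g 
JOIN Track t ON g.GenreId = t.GenreId 
JOIN Album a ON t.AlbumId = a.AlbumId 
JOIN Artist ar ON a.ArtistId = ar.ArtistId 
WHERE ar.Name = 'Alanis Morissette';


"[('Rock',)]"

In [68]:
# With retrieval: proper_nouns is filled by retriever_chain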
query = chain.invoke({"question": "What are all the genres of elenis moriset songs"})
print(query)
db.run(query)

SELECT DISTINCT Genre.Name 
FROM Genre 
JOIN Track ON Genre.GenreId = Track.GenreId 
WHERE Track.Name LIKE '%Alanis Morissette%';


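''

In this run, the query_chain call with an empty proper_nouns list corrected the misspelled name on its own and found the Rock genre, while the full chain with retrieval filtered on Track.Name (song titles) instead of Artist.Name and returned no rows. LLM output can vary between runs, so it is worth inspecting the generated SQL rather than assuming that retrieval always fixes the filter.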
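The difference between the two results is easy to reproduce. The sketch below replays both generated queries, copied verbatim from the outputs above, against a tiny in-memory database with the same table and column names. The rows are made-up toy data, not Chinook's; the point is only that the artist's name lives in Artist.Name, so a LIKE filter on Track.Name (song titles) finds nothing.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);
    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, AlbumId INTEGER, GenreId INTEGER);
    INSERT INTO Artist VALUES (1, 'Alanis Morissette');
    INSERT INTO Album VALUES (1, 'Toy Album', 1);
    INSERT INTO Genre VALUES (1, 'Rock');
    INSERT INTO Track VALUES (1, 'Toy Song', 1, 1);
    """
)

# Query generated by query_chain with an empty proper_nouns list: filters on the artist.
by_artist = """
SELECT DISTINCT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette';
"""

# Query generated by the full chain (with retrieval): searches song titles instead.
by_track_title = """
SELECT DISTINCT Genre.Name
FROM Genre
JOIN Track ON Genre.GenreId = Track.GenreId
WHERE Track.Name LIKE '%Alanis Morissette%';
"""

print(conn.execute(by_artist).fetchall())       # [('Rock',)]
print(conn.execute(by_track_title).fetchall())  # []
```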