## 0. SQL Crash Course
“SQL (standing for Structured Query Language) is the standard language for
relational database management systems. When it originated back in the
1970s, the domain-specific language was intended to fulfill the need of
conducting a database query that could navigate through a network of pointers to find the desired location. Its application in handling structured data has fostered in the Digital Age. In fact, the powerful database manipulation and definition capabilities of SQL and its intuitive tabular view have become available in some form on virtually every important computer platform in the world.

Some notable features of SQL include the ability to process sets of data as
groups instead of individual units, automatic navigation to data, and the use
of statements that are complex and powerful individually. Used for a variety
of tasks, such as querying data, controlling access to the database and its
objects, guaranteeing database consistency, updating rows in a table, and
creating, replacing, altering and dropping objects, SQL lets users work with
data at the logical level.”

Read more at the ANSI Blog: The SQL Standard – ISO/IEC 9075:2016 https://blog.ansi.org/?p=158690

Basic SQL syntax information: https://www.w3schools.com/sql/sql_syntax.asp

A database is an organized collection of related data, typically stored in one or more files. A database management system (DBMS) is software that stores and manages that data efficiently over long periods of time. Two examples are MySQL (open source) and SQLite (an embedded library in the public domain).
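To make these definitions concrete, here is a minimal sketch using Python's built-in `sqlite3` module. The `pets` table and its rows are made up for illustration and are not part of the Chinook database used later in this notebook.

```python
import sqlite3

# An in-memory SQLite database; nothing is written to disk.
conn = sqlite3.connect(":memory:")

# Data definition: create a table (hypothetical example schema).
conn.execute(
    "CREATE TABLE pets (id INTEGER PRIMARY KEY, name TEXT NOT NULL, species TEXT)"
)

# Data manipulation: insert several rows at once.
conn.executemany(
    "INSERT INTO pets (name, species) VALUES (?, ?)",
    [("Rex", "dog"), ("Tom", "cat"), ("Fido", "dog")],
)

# Query: SQL describes the set of rows we want, not how to walk through them.
dog_names = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM pets WHERE species = 'dog' ORDER BY name"
    )
]
print(dog_names)
```

The `SELECT` statement states which rows are wanted and leaves the navigation to the DBMS, which is the "logical level" the quote above refers to.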

## 1. Agents
The core idea of agents is to use a language model to choose a sequence of actions to take. In chains, a sequence of actions is hardcoded (in code). In agents, a language model is used as a reasoning engine to determine which actions to take and in which order.
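The difference can be sketched in plain Python without any LLM. Everything below is hypothetical: `TOOLS` stands in for real tools and `fake_model` stands in for a language model that picks the next action. It is only meant to show where the control flow lives in each case.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins for real tools; each takes and returns a string.
TOOLS: Dict[str, Callable[[str], str]] = {
    "list_tables": lambda _: "Artist, Album",
    "get_schema": lambda table: f"CREATE TABLE {table} (...)",
}


def run_chain(question: str) -> List[str]:
    """Chain: the order of the steps is fixed in code."""
    return [TOOLS["list_tables"](""), TOOLS["get_schema"]("Artist")]


def fake_model(question: str, history: List[str]) -> str:
    """Stand-in for an LLM: chooses the next action from what it has seen so far."""
    if not history:
        return "list_tables"
    if len(history) == 1:
        return "get_schema"
    return "finish"


def run_agent(question: str, max_steps: int = 5) -> List[str]:
    """Agent: the model decides which tool to call next, until it says 'finish'."""
    history: List[str] = []
    for _ in range(max_steps):
        action = fake_model(question, history)
        if action == "finish":
            break
        arg = "Artist" if action == "get_schema" else ""
        history.append(TOOLS[action](arg))
    return history


print(run_chain("Which artists exist?"))
print(run_agent("Which artists exist?"))
```

Here both functions happen to produce the same steps. The point is that in `run_agent` the sequence could change from question to question, because it is chosen at run time rather than written into the code.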

LangChain has a SQL Agent which provides a more flexible way of interacting with SQL Databases than a chain. The main advantages of using the SQL Agent are:

- It can answer questions based on the database’s schema as well as on its content (for example, describing a specific table).
- It can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly.
- It can query the database as many times as needed to answer the user question.
- It saves tokens by retrieving the schema only from relevant tables.


To initialize the agent we’ll use the create_sql_agent constructor. This agent uses the SQLDatabaseToolkit which contains tools to:

- Create and execute queries
- Check query syntax
- Retrieve table descriptions

For more information: https://python.langchain.com/docs/modules/agents/
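To get a feel for what the three kinds of tools listed above do, here is a rough standard-library analogue built on `sqlite3`. The function names are hypothetical and this is not how the toolkit is implemented. In particular, the syntax check here relies on SQLite's `EXPLAIN`, which is an assumption made for this sketch.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("INSERT INTO Artist (Name) VALUES ('AC/DC')")


def list_tables(conn):
    """Names of all tables in the database."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
    return [r[0] for r in rows]


def table_description(conn, table):
    """The CREATE TABLE statement for one table, or None if it does not exist."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    return row[0] if row else None


def check_syntax(conn, query):
    """Return None if SQLite can compile the query, else the error message."""
    try:
        conn.execute("EXPLAIN " + query)  # compiles the query without running it
        return None
    except sqlite3.Error as exc:
        return str(exc)


def run_query(conn, query):
    """Execute a query and return all rows."""
    return conn.execute(query).fetchall()


print(list_tables(conn))
print(table_description(conn, "Artist"))
print(check_syntax(conn, "SELEC Name FROM Artist"))
print(run_query(conn, "SELECT Name FROM Artist"))
```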

In [1]:
%pip install --upgrade --quiet  langchain langchain-community langchain-openai

Note: you may need to restart the kernel to use updated packages.


In [2]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
import os
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain

In [3]:
# Insert your OpenAI API key (requires an OpenAI account): https://platform.openai.com/docs/models
os.environ["OPENAI_API_KEY"] = ''

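To run the following code, the file Chinook.db must be in the same folder as this Jupyter notebook. To create it, run the commands below from that folder: the first downloads the SQL script, the second opens the SQLite shell on a new Chinook.db, and the third (typed at the sqlite> prompt) loads the script.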

curl -O https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql

sqlite3 Chinook.db

.read Chinook_Sqlite.sql
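If the `sqlite3` command-line shell is not available, the same result can likely be achieved from Python. The sketch below is an alternative to the last two commands, not the method used in this notebook. The helper name `build_sqlite_db` is hypothetical, and reading the script as `utf-8-sig` is an assumption that tolerates a byte-order mark if the file has one.

```python
import sqlite3
from pathlib import Path


def build_sqlite_db(sql_path: str, db_path: str) -> None:
    """Create (or extend) the SQLite file at db_path by running the script at sql_path."""
    # "utf-8-sig" also reads plain UTF-8; it only strips a byte-order mark if present.
    script = Path(sql_path).read_text(encoding="utf-8-sig")
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(script)  # runs every statement in the script
        conn.commit()
    finally:
        conn.close()


# Usage, assuming the script was already downloaded with the curl command above:
# build_sqlite_db("Chinook_Sqlite.sql", "Chinook.db")
```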

In [11]:
# SQLDatabase function loads our .db file: https://python.langchain.com/docs/integrations/toolkits/sql_database
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# there are multiple types of sql products which include different 'dialects' of sql
print(db.dialect)
# multiple tables exist within the database, so we want to see all the tables in the db file
print(db.get_usable_table_names())
# the string passed to .run is a query (in sqlite) to select the first ten rows of the Artist table
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
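Note that the result above is a single string, not a list of tuples. For simple values like these, one possible way to get Python objects back is `ast.literal_eval`, sketched below on the first three rows of that output. This is an assumption about what is convenient, not something the notebook relies on, and it will not work for values whose printed form is not a Python literal.

```python
import ast

# First three rows of the string returned by db.run(...) above.
raw = "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith')]"

# literal_eval safely parses Python literals (lists, tuples, numbers, strings).
rows = ast.literal_eval(raw)
names = [name for _artist_id, name in rows]
print(names)
```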

In [12]:
# initialize llm model using ChatOpenAI function
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
# create_sql_agent documentation: https://api.python.langchain.com/en/latest/agent_toolkits/langchain_community.agent_toolkits.sql.base.create_sql_agent.html

In [13]:
# execute sql agent chain
agent_executor.invoke(
    "List the total sales per country. Which country's customers spent the most?"
)


> Entering new SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with ``


Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice, InvoiceLine'}`


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Em

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The total sales per country are as follows:\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n\nThe country whose customers spent the most is the USA with a total sales amount of $523.06.'}

The output above shows the internal steps the agent takes when we invoke LangChain's prebuilt SQL agent, followed by the final response to the user question.
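The trace above is cut off before the query the agent actually ran, so the following is only a guess at the kind of aggregation involved. It is a minimal sketch on a tiny in-memory database with made-up rows. The two tables mimic the `Customer` and `Invoice` tables but carry only the columns needed here.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
    INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
    INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 5.0), (3, 3, 2.5), (4, 2, 1.0);
    """
)

# Total sales per country, highest first.
query = """
SELECT c.Country, SUM(i.Total) AS TotalSales
FROM Invoice AS i
JOIN Customer AS c ON c.CustomerId = i.CustomerId
GROUP BY c.Country
ORDER BY TotalSales DESC
"""
sales = conn.execute(query).fetchall()
print(sales)
```

The first row of `sales` answers the second half of the question, "Which country's customers spent the most?".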

## 2. Query validation

Perhaps the most error-prone part of any SQL chain or agent is writing valid and safe SQL queries. The example below shows one method to validate the queries that are composed from the user input.

In [22]:
# create a SQL query chain that turns a natural-language question into a SQL query;
# its output is what the validation prompt defined in the next cell will check
# this is the same sql query chain from the RAG notebook
chain = create_sql_query_chain(llm, db)

In [27]:
# these are instructions to the model on how to validate the sql queries from the sql_query_chain
system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only."""

# chatprompttemplate: https://api.python.langchain.com/en/latest/prompts/langchain_core.prompts.chat.ChatPromptTemplate.html
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)

# construct a new chain
validation_chain = prompt | llm | StrOutputParser()

# we are extending our chain so that the output of the sql query chain is validated using the validation chain above
# note that the llm is called in both the 'chain' and 'validation chain'
full_chain = {"query": chain} | validation_chain
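Three of the pitfalls named in the validation prompt are easy to reproduce. The sketch below uses Python's built-in `sqlite3` module and only constant expressions, so it does not depend on the Chinook data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# NOT IN with a NULL in the list: the comparison evaluates to NULL,
# so the WHERE clause filters the row out.
not_in_rows = conn.execute("SELECT 'kept' WHERE 3 NOT IN (1, 2, NULL)").fetchall()

# The same test without the NULL keeps the row.
not_in_clean = conn.execute("SELECT 'kept' WHERE 3 NOT IN (1, 2)").fetchall()

# BETWEEN is inclusive on both ends, so it is wrong for exclusive ranges.
between_upper = conn.execute("SELECT 10 BETWEEN 1 AND 10").fetchone()[0]

# UNION removes duplicate rows, UNION ALL keeps them.
union_rows = conn.execute("SELECT 1 UNION SELECT 1").fetchall()
union_all_rows = conn.execute("SELECT 1 UNION ALL SELECT 1").fetchall()

print(not_in_rows, not_in_clean, between_upper, union_rows, union_all_rows)
```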

In [25]:
query = full_chain.invoke(
    {
        "question": "What's the average Invoice from an American customer whose Fax is missing since 2003 but before 2010"
    }
)
query

"The original query looks correct and does not contain any common mistakes. Here is the reproduced SQL query:\n\n```sql\nSELECT AVG(i.Total) AS Average_Invoice\nFROM Invoice i\nJOIN Customer c ON i.CustomerId = c.CustomerId\nWHERE c.Country = 'USA' \nAND c.Fax IS NULL\nAND i.InvoiceDate >= '2003-01-01' \nAND i.InvoiceDate < '2010-01-01'\n```"
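In the output above, the model wrapped the query in an explanatory sentence and a Markdown code fence even though the prompt asked for the SQL query only. Before passing such a reply to `db.run`, the SQL has to be pulled out. One possible way is sketched below. The helper name `extract_sql` is hypothetical, and the pattern assumes at most one fenced block per reply.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep this example readable

_SQL_BLOCK = re.compile(
    FENCE + r"(?:sql)?\s*\n(.*?)" + FENCE,
    flags=re.DOTALL | re.IGNORECASE,
)


def extract_sql(reply: str) -> str:
    """Return the contents of the first fenced block, or the stripped reply if none."""
    match = _SQL_BLOCK.search(reply)
    return (match.group(1) if match else reply).strip()


reply = (
    "The original query looks correct. Here is the reproduced SQL query:\n\n"
    + FENCE + "sql\nSELECT AVG(i.Total) FROM Invoice i\n" + FENCE
)
print(extract_sql(reply))
```

A function like this could be appended to the chain with `|`, in the same way `parse_final_answer` is used in the next cell.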

The obvious downside of this approach is that we need to make two model calls instead of one to generate our query. To get around this we can try to perform the query generation and query check in a single model invocation:

In [28]:
system = """You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Then double check the {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

Use format:

First draft: <<FIRST_DRAFT_QUERY>>
Final answer: <<FINAL_ANSWER_QUERY>>
"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}")]
).partial(dialect=db.dialect)


def parse_final_answer(output: str) -> str:
    # keep only the text after the last "Final answer:" marker;
    # if the marker is missing, fall back to the whole output instead of raising IndexError
    return output.split("Final answer:")[-1].strip()


chain = create_sql_query_chain(llm, db, prompt=prompt) | parse_final_answer
prompt.pretty_print()


You are a {dialect} expert. Given an input question, create a syntactically correct {dialect} query to run.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per {dialect}. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Write an initial draft of the query. Th

## 3. Large databases

In order to write valid queries against a database, we need to feed the model the table names, table schemas, and feature values for it to query over. When there are many tables, columns, and/or high-cardinality columns, it becomes impossible for us to dump the full information about our database in every prompt. Instead, we must find ways to dynamically insert into the prompt only the most relevant information. Let’s take a look at some techniques for doing this.



### 3A. Many tables
One of the main pieces of information we need to include in our prompt is the schemas of the relevant tables. When we have very many tables, we can’t fit all of the schemas in a single prompt. What we can do in such cases is first extract the names of the tables related to the user input, and then include only their schemas.
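The second half of that idea, including only the schemas of the selected tables, can be sketched with the standard library alone. The helper `schemas_for` and the three toy tables are hypothetical. In the notebook itself this job is handled by the `table_names_to_use` argument shown further down.

```python
import sqlite3
from typing import List


def schemas_for(conn: sqlite3.Connection, table_names: List[str]) -> str:
    """Return the CREATE TABLE statements for just the named tables."""
    placeholders = ", ".join("?" for _ in table_names)
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        table_names,
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL);
    """
)

# Only the schemas of the tables judged relevant end up in the prompt text.
table_info = schemas_for(conn, ["Artist", "Genre"])
print(table_info)
```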

One easy and reliable way to do this is using OpenAI function-calling and Pydantic models. LangChain comes with a built-in create_extraction_chain_pydantic chain that lets us do just this:

In [33]:
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List

In [31]:
# load model
llm = ChatOpenAI(model="gpt-3.5-turbo-1106", temperature=0)

class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())

system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

# create_extraction_chain_pydantic - https://api.python.langchain.com/en/latest/chains/langchain.chains.openai_tools.extraction.create_extraction_chain_pydantic.html
table_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Genre'), Table(name='Artist'), Table(name='Track')]

This works pretty well! Except, as we’ll see below, we actually need a few other tables as well. This would be pretty difficult for the model to know based just on the user question. In this case, we might think to simplify our model’s job by grouping the tables together. We’ll just ask the model to choose between categories “Music” and “Business”, and then take care of selecting all the relevant tables from there:

In [32]:
# simplify the model's job: ask it to pick from two coarse categories instead of individual tables
system = """Return the names of the SQL tables that are relevant to the user question. \
The tables are:

Music
Business"""
category_chain = create_extraction_chain_pydantic(Table, llm, system_message=system)
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music'), Table(name='Business')]

In [34]:
# map each category selected by the model to the list of tables that belong to it
def get_tables(categories: List[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables

table_chain = category_chain | get_tables  
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album',
 'Artist',
 'Genre',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track',
 'Customer',
 'Employee',
 'Invoice',
 'InvoiceLine']
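As a design note, the category-to-table mapping could also live in a dictionary, which keeps the data separate from the logic and makes it easy to drop duplicates if the model returns the same category twice. The sketch below is a hypothetical variant, not the function used in this notebook. It takes plain category names, so with the `Table` objects above it would be called as `tables_for(c.name for c in categories)`.

```python
from typing import Dict, Iterable, List

CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def tables_for(category_names: Iterable[str]) -> List[str]:
    """Tables for the given categories, in first-seen order and without duplicates."""
    seen: Dict[str, None] = {}  # dicts preserve insertion order
    for name in category_names:
        for table in CATEGORY_TABLES.get(name, []):  # unknown categories are ignored
            seen.setdefault(table, None)
    return list(seen)


print(tables_for(["Music", "Music"]))
```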

Now that we have a chain that outputs the relevant tables for any query, we can combine it with create_sql_query_chain, which accepts a list of table_names_to_use to determine which table schemas are included in the prompt:

In [35]:
from operator import itemgetter
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
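The data flow in the cell above can be hard to picture, so here is a plain-Python analogue. All three functions are hypothetical stand-ins: `pick_tables` plays the role of `table_chain`, `write_query` plays the role of `query_chain`, and `assign` imitates the effect of `RunnablePassthrough.assign`, which passes the input dictionary through and adds a computed key to it.

```python
from typing import Any, Callable, Dict, List


def pick_tables(question: str) -> List[str]:
    """Hypothetical stand-in for table_chain."""
    return ["Genre", "Track"] if "genre" in question.lower() else ["Invoice"]


def write_query(inputs: Dict[str, Any]) -> str:
    """Hypothetical stand-in for query_chain; it only reports what it received."""
    tables = ", ".join(inputs["table_names_to_use"])
    return f"-- tables: {tables}\n-- question: {inputs['question']}"


def assign(**steps: Callable[[Dict[str, Any]], Any]):
    """Return a function that copies its input dict and adds one computed key per step."""

    def run(inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {**inputs, **{key: fn(inputs) for key, fn in steps.items()}}

    return run


add_tables = assign(table_names_to_use=lambda d: pick_tables(d["question"]))
enriched = add_tables({"question": "What genres exist?"})
print(enriched)
print(write_query(enriched))
```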

In [42]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SELECT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette'


In [43]:
db.run(query)

"[('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',), ('Rock',)]"

In [44]:
# adding "unique" to our prompt corresponds to the addition of DISTINCT in the SQL query!
query = full_chain.invoke(
    {"question": "What is the set of all unique genres of Alanis Morisette songs"}
)
print(query)
print(db.run(query))

SELECT DISTINCT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette'
[('Rock',)]


### Related Topics

- Optimize SQL agent performance with few-shot prompting: https://python.langchain.com/docs/use_cases/sql/agents
- High cardinality columns using chains: https://python.langchain.com/docs/use_cases/sql/large_db
- High cardinality columns using agents: https://python.langchain.com/docs/use_cases/sql/agents