This notebook builds a simple hardcoded chain that generates a SQL query from a natural language question, executes that query, and then uses the query result to answer the original question in natural language.

Install the required libraries:

In [1]:
pip install --upgrade --quiet  langchain langchain-community langchain-openai

Note: you may need to restart the kernel to use updated packages.


Import the necessary modules:

In [2]:
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

from langchain.chains import create_sql_query_chain

from langchain_openai import ChatOpenAI

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

Fill in your OpenAI API key in the cell below so that it is available as the OPENAI_API_KEY environment variable.

In [3]:
#for gpt-3.5-turbo:
%env OPENAI_API_KEY=

env: OPENAI_API_KEY=<redacted>
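Typing the key into a cell risks saving it in the notebook together with the cell output. The sketch below shows one possible alternative that asks for the key only when the environment variable is missing. The helper name `ensure_api_key` and its `env` and `ask` parameters are hypothetical and exist only for this illustration; the demo uses a plain dict and a stand-in prompt function so that it runs without user input.

```python
import os
from getpass import getpass


def ensure_api_key(name="OPENAI_API_KEY", env=os.environ, ask=getpass):
    """Return the key stored under `name`, asking for it only when it is missing."""
    if not env.get(name):
        # getpass hides the typed value, so it is not echoed into the cell output.
        env[name] = ask(f"{name}: ")
    return env[name]


# Demo with a plain dict and a stand-in prompt function (both hypothetical),
# so nothing has to be typed and no real key is involved.
demo_env = {}
ensure_api_key(env=demo_env, ask=lambda _prompt: "sk-example")
print(sorted(demo_env))  # shows only the variable name, never the value
```

In a real session, calling `ensure_api_key()` without arguments would prompt once and store the key in `os.environ` for the rest of the kernel's lifetime.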


The cell below checks that we can interact with the wijkpaspoort.sqlite3 database through the SQLAlchemy-driven SQLDatabase class from the langchain-community library. langchain-community is an open source library of third-party integrations that implement the base interfaces defined in LangChain Core, making them ready to use in any LangChain application.

In [4]:
db = SQLDatabase.from_uri("sqlite:///wijkpaspoort.sqlite3")
print(db.dialect)
print(db.get_usable_table_names())
print()
# Printing db.get_table_info() shows every table with its columns and three
# example rows. This table info is later used to build the prompt that is
# given to the LLM.
db.run("SELECT * FROM addresses LIMIT 10;")

sqlite
['addresses', 'associated_fat', 'associated_liquid', 'clients', 'cluster_results', 'consumption_day', 'consumption_recipe', 'consumption_sup_nut', 'districts', 'household_income', 'measurements', 'municipalities', 'neighborhoods', 'participants', 'peoples_income', 'provinces', 'schools', 'sports_facilities', 'user_client', 'users']



"[(1, '9711LX', 51, '00140000', '001400', '0014'), (2, '9712AA', 1, '00140000', '001400', '0014'), (3, '9712AA', 3, '00140000', '001400', '0014'), (4, '9712AA', 5, '00140000', '001400', '0014'), (5, '9712AA', 7, '00140000', '001400', '0014'), (6, '9712AA', 9, '00140000', '001400', '0014'), (7, '9712AA', 11, '00140000', '001400', '0014'), (8, '9712AA', 13, '00140000', '001400', '0014'), (9, '9712AA', 15, '00140000', '001400', '0014'), (10, '9712AA', 17, '00140000', '001400', '0014')]"
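To give an idea of the kind of table info that ends up in the prompt, the sketch below builds a comparable description (the CREATE statement plus a few sample rows) with the standard library `sqlite3` module. It is not LangChain's implementation and its output format is only an approximation. The in-memory table, its column names, and the helper `describe_table` are assumptions made for this illustration; the real addresses table may be defined differently.

```python
import sqlite3


def describe_table(conn, table, sample_rows=3):
    """Return the CREATE statement of `table` followed by a few sample rows."""
    ddl = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table,),
    ).fetchone()[0]
    # Table names cannot be bound as parameters, hence the quoted identifier.
    rows = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,)).fetchall()
    lines = [ddl, f"/* {len(rows)} sample rows from {table} */"]
    lines += ["\t".join(str(value) for value in row) for row in rows]
    return "\n".join(lines)


# Hypothetical stand-in for the real database: column names are assumed.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE addresses "
    "(id INTEGER PRIMARY KEY, postal_code TEXT, house_number INTEGER)"
)
conn.executemany(
    "INSERT INTO addresses VALUES (?, ?, ?)",
    [(1, "9711LX", 51), (2, "9712AA", 1), (3, "9712AA", 3), (4, "9712AA", 5)],
)

info = describe_table(conn, "addresses")
print(info)
```

Because the LLM only sees this text, the column names and the sample rows are its main clues about what each table contains.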

The cell below constructs the prompt that is given to the LLM to steer it toward the right response.

In [5]:
from langchain_core.prompts import PromptTemplate

template = '''Given an input question create a syntactically correct sqlite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
When there are less than {top_k} examples available in the database, limit the results to the number of examples that are available in the database.
You can order the results by a relevant column to return the most interesting examples in the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
Make sure the SQL-query starts with 'SELECT' and does not start with something like: ' ```sql '.
Delete preceding backticks (```) and the 'sql' keyword when constructing a query.
Also make sure the SQL-query ends with a semicolon.
If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the final answer.

Only use the following tables:

{table_info}.

Question: {input}'''
prompt = PromptTemplate.from_template(template)
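The placeholders in curly braces are filled in when the chain runs. The sketch below uses plain `str.format` on an abbreviated stand-in for the template above to show what that substitution looks like, and lists the placeholder names with `string.Formatter`. The shortened template text and the example values are assumptions for illustration only; they are not what the chain sends to the LLM.

```python
from string import Formatter

# Abbreviated stand-in for the full template, keeping the same three placeholders.
short_template = (
    "Limit your query to at most {top_k} results.\n"
    "Only use the following tables:\n\n"
    "{table_info}\n\n"
    "Question: {input}"
)

# Collect the placeholder names, which is a quick way to check that none is misspelled.
fields = {name for _, name, _, _ in Formatter().parse(short_template) if name}
print(sorted(fields))

# Fill the placeholders with example values (assumed, for illustration).
filled = short_template.format(
    top_k=10,
    table_info="CREATE TABLE addresses (id INTEGER)",
    input="How many addresses are there?",
)
print(filled)
```

A misspelled placeholder in the full template would show up in `fields` in the same way, which makes this a cheap check before the prompt is handed to a chain.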

The cell below converts the user input (a natural language question) into a SQL query. LangChain has a built-in chain for this, create_sql_query_chain, which is called here with four arguments: the LLM, the database, the prompt designed above, and k, the maximum number of results the query should return.

In [6]:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db, prompt=prompt, k=10)
response = chain.invoke({"question": "How many addresses are there?"})
response

'SELECT COUNT(*) FROM addresses;'

This SQL query can be run directly against the database:

In [7]:
db.run(response)

'[(7888477,)]'
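The prompt asks the model to leave out Markdown fences and to end the query with a semicolon, but a model may not always comply. The sketch below is one possible defensive step that could be applied to the generated text before it is executed. The helper `clean_sql` is hypothetical and not part of LangChain, and checking for a leading SELECT is only a sanity check, not a security boundary.

```python
import re


def clean_sql(text):
    """Strip Markdown fences from LLM output and accept only SELECT statements."""
    sql = text.strip()
    # Remove an opening fence, optionally followed by the 'sql' language tag.
    sql = re.sub(r"^`{3}(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
    # Remove a closing fence.
    sql = re.sub(r"\s*`{3}$", "", sql)
    if not sql.upper().startswith("SELECT"):
        raise ValueError(f"Refusing to run a non-SELECT statement: {sql[:40]!r}")
    if not sql.endswith(";"):
        sql += ";"
    return sql


fence = "`" * 3  # three backticks, built this way to keep the example readable
raw = f"{fence}sql\nSELECT COUNT(*) FROM addresses\n{fence}"
print(clean_sql(raw))
```

The cleaned string could then be passed to `db.run` instead of the raw response.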

The cell below uses the QuerySQLDataBaseTool integration to add query execution to the chain. Its only required argument is the database set up earlier.

In [8]:
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db, prompt=prompt, k=10)
chain = write_query | execute_query
chain.invoke({"question": "How many addresses are there?"})

'[(7888477,)]'
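The `|` operator connects two steps so that the output of the first becomes the input of the second. The sketch below imitates that behaviour with a small class to make the data flow visible. The `Step` class and the two stand-in steps are hypothetical; this is not how LangChain implements its runnables, and no LLM or database is involved.

```python
class Step:
    """Minimal stand-in for a chain step that supports the | operator."""

    def __init__(self, func):
        self.func = func

    def invoke(self, value):
        return self.func(value)

    def __or__(self, other):
        # Composition: run this step first, then feed its output to the next one.
        return Step(lambda value: other.invoke(self.invoke(value)))


# Stand-ins (assumed) for the query-writing and query-executing steps.
fake_write_query = Step(lambda inputs: "SELECT COUNT(*) FROM addresses;")
fake_execute_query = Step(lambda sql: f"ran: {sql}")

fake_chain = fake_write_query | fake_execute_query
print(fake_chain.invoke({"question": "How many addresses are there?"}))
```

The composed object exposes the same `invoke` method as its parts, which is why a chain can itself be used as a step in a longer chain.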

In the cell below, the natural language question, the SQL query, and the SQL result are passed to the LLM once more to generate a natural language answer.

In [9]:
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

# This chain fills in answer_prompt --> passes it to the LLM --> converts the LLM output to a string.
answer = answer_prompt | llm | StrOutputParser()

# Chain layout: write_query generates the SQL query and stores it under "query"
# --> execute_query (QuerySQLDataBaseTool) runs it and stores the output under
# "result" --> the answer chain above turns question, query and result into
# the final answer.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "List all measurements for the municipality of 'Amsterdam'"})


"The measurements for the municipality of 'Amsterdam' are as follows:\n1. Measurement ID: 922, values: [81.5, 59.7, 60.4, 2.8, 58.4, 38.8, 11.4, 24.3, 39.2, 76.5, 11.6, 6.7, 55.0, 9.8, 26.2]\n2. Measurement ID: 18263, values: [80.9, 58.4, 59.4, 2.6, 57.7, 37.7, 10.8, 23.4, 38.4, 75.8, 10.7, 6.1, 54.2, 9.1, 25.6]\n3. Measurement ID: 35604, values: [82.0, 60.6, 60.8, 3.1, 59.4, 39.6, 11.9, 25.0, 40.2, 77.1, 12.0, 7.3, 55.7, 10.3, 27.1]\n4. Measurement ID: 52945, values: [60.0, 34.8, 33.6, 1.7, 42.9, 55.3, 18.3, 14.2, 55.8, 66.8, 7.2, 9.3, 43.2, 6.4, 9.1]\n5. Measurement ID: 70286, values: [59.3, 34.4, 33.0, 1.6, 42.0, 54.6, 17.6, 13.6, 54.6, 65.8, 6.7, 8.7, 41.9, 5.8, 8.4]\n6. Measurement ID: 87627, values: [61.2, 35.9, 34.7, 1.9, 43.6, 56.3, 18.9, 15.0, 56.5, 67.6, 7.7, 9.9, 44.0, 7.0, 9.9]"
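Each assign step passes the input dictionary through unchanged and adds one key to it, so that the answer prompt finally receives question, query, and result together. The sketch below traces that accumulation with plain dictionaries. The three stand-in functions are hypothetical and return fixed values taken from the earlier cells; no LLM or database is called.

```python
def fake_write_query(state):
    # Stand-in for write_query: receives the whole state, returns a SQL string.
    return "SELECT COUNT(*) FROM addresses;"


def fake_execute_query(sql):
    # Stand-in for execute_query: receives only the SQL string.
    return "[(7888477,)]"


def fake_answer(state):
    # Stand-in for the answer chain: receives question, query and result.
    return f"{state['question']} -> {state['result']}"


state = {"question": "How many addresses are there?"}
state = {**state, "query": fake_write_query(state)}               # .assign(query=...)
state = {**state, "result": fake_execute_query(state["query"])}   # .assign(result=...)
final = fake_answer(state)

print(list(state))
print(final)
```

The `itemgetter("query")` in the real chain plays the role of `state["query"]` here: it picks the SQL string out of the dictionary before handing it to the execution step.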

In [10]:
questions = [
    {"question": "How many addresses are in the database?"},
    {"question": "How many school are there?"},
    {"question": "List all measurements for the municipality of ‘’s-Hertogenbosch’"},
    {"question": "How many indor services are there for korfball?"},
    {"question": "How many measurements are there in the municipality of Staphorst where the percentage of people with excellent to very good health is above 60%?"},
    {"question": "How many measurements are there for the Vondelburt district where more than 1 out of 10 people is a heavy drinker%?"},
    {"question": "Can you show me the top 10 municipalities by their average percentage of weekly sporters?"},
    {"question": "List the top 10 districts by overwait and also list their respective values?"},
    {"question": "Can you find the top 10 municipalities with the highest median standardised income for me? Also show me how many people in that municipality have overweight and how many have a high risk of anxiety or depression."},
    {"question": "Give me the top 20 neighbourhoods in the Netherlands based on the amount of people who had very high stress in the past 4 weeks and have more than 20 percent of heavy drinkers. Also show me in which municipalities these neighborhoods are located."},
    {"question": "Can you show me the municipalities with an average disposable income between 100 and 120 and also show me the respective percentage of weekly sporters in that municipality?"},
    {"question": "Make a list with the top 10 neighborhoods and their municipalities with, on average, the highest percentage of weekly sporters but where the average standardized income of the corresponding municipality is below 15."},
]
    

In [11]:
for i, question in enumerate(questions, start=1):
    print(f"Question {i} : {question}")
    try:
        ans = chain.invoke(question)
        # write_query is invoked a second time only to display the SQL; with
        # temperature=0 it normally matches the query the chain executed.
        query = write_query.invoke(question)
        print("Query : ", query)
        print("Answer : ", ans)
    except Exception as e:
        print("Error : ", e)
    print()

Question 1 : {'question': 'How many addresses are in the database?'}
Query :  SELECT COUNT(*) FROM addresses;
Answer :  There are 7,888,477 addresses in the database.

Question 2 : {'question': 'How many school are there?'}
Query :  SELECT COUNT(*) FROM schools;
Answer :  There are 6166 schools.

Question 3 : {'question': 'List all measurements for the municipality of ‘’s-Hertogenbosch’'}
Query :  SELECT * FROM measurements
JOIN municipalities ON measurements.municipality_code = municipalities.municipality_code
WHERE municipalities.municipality_name = 's-Hertogenbosch';
Answer :  The SQL query provided will list all measurements for the municipality of 's-Hertogenbosch'.

Question 4 : {'question': 'How many indor services are there for korfball?'}
Query :  SELECT total_indoor_services
FROM sports_facilities
WHERE sport = 'korfball';
Answer :  There are a total of 5 indoor services available for korfball.

Question 5 : {'question': 'How many measurements are there in the municipality of