In [1]:

import os
import openai
from dotenv import load_dotenv
# Load environment variables from a local .env file. The bare call is the
# cell's last expression, so its return value is displayed as the cell output.
load_dotenv()


True

In [2]:
# Look up the OpenAI key in the process environment; the value is None when
# the variable is not set.
openai_api_key = os.environ.get("OPENAI_API_KEY")


## Create Database Schema

In [3]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

In [4]:
# Schema registry that table definitions are attached to.
metadata_obj = MetaData()

# SQLite database that lives purely in memory for the duration of the kernel.
engine = create_engine("sqlite:///:memory:")

In [5]:
# Declare the city table's schema, register it on metadata_obj, and create it.
table_name = "city_stats"
city_stats_columns = [
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
]
city_stats_table = Table(table_name, metadata_obj, *city_stats_columns)

# Create every table registered on metadata_obj in the target engine.
metadata_obj.create_all(engine)

## Define SQL Database

We first define our SQLDatabase abstraction (a light wrapper around SQLAlchemy).

In [6]:
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI

In [16]:
# LLM handle used by the text-to-SQL query engine; the model snapshot is pinned
# and the temperature kept low.
llm = OpenAI(
    model="gpt-4o-2024-08-06",
    temperature=0.1,
)

In [8]:
# Wrap the SQLAlchemy engine, restricting the wrapper to the city_stats table.
sql_database = SQLDatabase(
    engine,
    include_tables=["city_stats"],
)

We add some testing data to our SQL database.

In [9]:
from sqlalchemy import insert

# Demo rows for city_stats. Every dict carries the same keys, so the whole list
# can be sent to the database as a single multi-row insert.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Insert all rows inside ONE transaction: engine.begin() commits when the block
# exits cleanly and rolls back on error, so a failure part-way through can never
# leave the table half-populated (which would make a re-run hit primary-key
# conflicts on the rows that did get in).
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

In [10]:
# Read back every row of the table and print the result list.
columns = city_stats_table.c
stmt = select(
    columns.city_name,
    columns.population,
    columns.country,
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()

# The rows are already materialised by fetchall(), so printing can happen after
# the connection has been released.
print(results)

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]


## Query Index

We first show how to execute a raw SQL query directly against the table.

In [11]:
from sqlalchemy import text

# Hand-written SQL, wrapped in text() so SQLAlchemy can execute it.
raw_query = text("SELECT city_name from city_stats")

with engine.connect() as con:
    rows = con.execute(raw_query)
    # Iterate the result while the connection is still open; each row prints as
    # a one-element tuple.
    for row in rows:
        print(row)

('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)



## Part 1: Text-to-SQL Query Engine

In [17]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

# Natural-language question to be answered over the city_stats table.
query_str = "Which city has the highest population?"

# Text-to-SQL engine limited to the city_stats table and driven by the
# configured LLM.
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"],
    llm=llm,
)
response = query_engine.query(query_str)

In [19]:
# Show the response rendered as a plain string.
display(str(response))

'The city with the highest population is Tokyo, with a population of 13,960,000.'

## Part 2: Query-Time Retrieval of Tables for Text-to-SQL

In [20]:
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

# Build an object index over table schemas so the table to query is retrieved
# at query time rather than being named up front.
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name="city_stats"),
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

# Pass the notebook's configured LLM explicitly so this engine uses the same
# model as the Part 1 engine; when llm is omitted, LlamaIndex uses its globally
# configured default model instead. The retriever is asked for its single top
# match (similarity_top_k=1).
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=1),
    llm=llm,
)

In [22]:
# Ask the same question through the table-retrieving engine and show the
# answer as a plain string.
question = "Which city has the highest population?"
response = query_engine.query(question)
display(str(response))

'Tokyo has the highest population among all cities, with a population of 13,960,000.'

In [23]:
# The raw SQLAlchemy result is also available: it is stored in the response
# metadata under the "result" key. As the cell's last expression it is displayed.
response.metadata["result"]

[('Tokyo', 13960000)]

## Part 3: Text-to-SQL Retriever

In [25]:
from llama_index.core.retrievers import NLSQLRetriever

# Default retrieval mode (return_raw=True): the retrieved node's text is the
# raw SQL result. The notebook's configured LLM is passed explicitly so the
# SQL is generated by the same model as in Part 1 rather than by LlamaIndex's
# global default.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["city_stats"],
    llm=llm,
    return_raw=True,
)

In [26]:
# Natural-language request handed to the text-to-SQL retriever.
top_cities_query = "Return the top 5 cities (along with their populations) with the highest population."
results = nl_sql_retriever.retrieve(top_cities_query)

In [28]:
from llama_index.core.response.notebook_utils import display_source_node

# Render each retrieved node with the notebook display helper.
for node in results:
    display_source_node(node)

**Node ID:** ec561faf-1c23-4604-8256-2d4164b10bbf<br>**Similarity:** None<br>**Text:** [('Tokyo', 13960000), ('Seoul', 9776000), ('Toronto', 2930000), ('Chicago', 2679000)]<br>

In [29]:
# Non-default retrieval mode (return_raw=False): results come back as one node
# per row, with the row's values stored in the node metadata rather than in its
# text. The notebook's configured LLM is passed explicitly so the SQL is
# generated by the same model as in Part 1.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["city_stats"],
    llm=llm,
    return_raw=False,
)

In [30]:
# Same natural-language request, now answered by the per-row retriever.
top_cities_query = "Return the top 5 cities (along with their populations) with the highest population."
results = nl_sql_retriever.retrieve(top_cities_query)

In [31]:
# NOTE: node text is empty in this mode; the row values live in the metadata,
# so ask the helper to render it.
for node in results:
    display_source_node(node, show_source_metadata=True)

**Node ID:** 56aa2d70-c48b-498c-8eb5-6d5eb0b4a999<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Tokyo', 'population': 13960000}<br>

**Node ID:** a1db0c0c-3d72-477e-ba37-b3fad54ac303<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Seoul', 'population': 9776000}<br>

**Node ID:** 7519ac63-6691-48ab-9c85-25755aeaa9c3<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Toronto', 'population': 2930000}<br>

**Node ID:** aff32b8e-3784-4ae1-a5b9-f9e5d86895c3<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Chicago', 'population': 2679000}<br>

In [32]:
from llama_index.core.query_engine import RetrieverQueryEngine

# Wrap the text-to-SQL retriever in a query engine so the retrieved rows are
# synthesized into a natural-language answer. The notebook's configured LLM is
# passed explicitly so synthesis uses the same model as Part 1 instead of
# LlamaIndex's global default.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)

In [33]:
# Run the natural-language request through the retriever-backed query engine.
top_cities_query = "Return the top 5 cities (along with their populations) with the highest population."
response = query_engine.query(top_cities_query)

In [34]:
# print() applies str() to its argument, so the response's text form is written.
print(response)

Tokyo: 13,960,000
Seoul: 9,776,000
Toronto: 2,930,000
Chicago: 2,679,000
