# NL2SQL

使用 LlamaIndex 将自然语言问题转换为 SQL，并在一个内存中的 SQLite 示例表（`city_stats`）上执行查询。

## 准备

In [1]:
%%time

# Load the LLM and embedding helpers. get_llm / get_embedding are not defined
# in this notebook — presumably they come from ../utils2.py (verify there).
%run ../utils2.py

from llama_index.core import Settings

# Register the models on llama_index's Settings so later cells can build query
# engines/indices without passing an llm or embed_model explicitly.
# Alternative: request a specific model, e.g. get_llm("gpt-3.5-turbo").
# Settings.llm=get_llm("gpt-3.5-turbo")
Settings.llm=get_llm()
Settings.embed_model = get_embedding()

CPU times: user 3.36 s, sys: 374 ms, total: 3.74 s
Wall time: 3.36 s


In [2]:
%%time
%%capture

!pip install SQLAlchemy

CPU times: user 13.3 ms, sys: 4.4 ms, total: 17.7 ms
Wall time: 3.35 s


In [7]:
%%time

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

CPU times: user 2.31 ms, sys: 0 ns, total: 2.31 ms
Wall time: 2.03 ms


In [8]:
%%time

from llama_index.core import SQLDatabase

sql_database = SQLDatabase(engine, include_tables=["city_stats"])

CPU times: user 3.68 ms, sys: 0 ns, total: 3.68 ms
Wall time: 3.14 ms


In [9]:
%%time

from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

CPU times: user 1.97 ms, sys: 0 ns, total: 1.97 ms
Wall time: 1.64 ms


In [28]:
%%time

import pandas as pd

df = pd.read_sql_query("SELECT * from city_stats", engine)
df.head()

CPU times: user 0 ns, sys: 2.31 ms, total: 2.31 ms
Wall time: 2 ms


Unnamed: 0,city_name,population,country
0,Toronto,2930000,Canada
1,Tokyo,13960000,Japan
2,Chicago,2679000,United States
3,Seoul,9776000,South Korea


In [10]:
# view current table: select every column of city_stats explicitly, in
# declaration order, then print the fetched rows
stmt = select(
    *[city_stats_table.c[col] for col in ("city_name", "population", "country")]
).select_from(city_stats_table)

with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()
print(results)

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]


In [11]:

from sqlalchemy import text

# List the city names with a raw SQL string.
# - ORDER BY guarantees the output order; without it SQLite may return rows in
#   whatever order its chosen access path yields.
# - The result is bound to `city_rows` so it does not overwrite the `rows`
#   list of dicts defined in the insert cell.
with engine.connect() as con:
    city_rows = con.execute(text("SELECT city_name from city_stats ORDER BY city_name"))
    for row in city_rows:
        print(row)

('Chicago',)
('Seoul',)
('Tokyo',)
('Toronto',)


## 文本到SQL查询引擎

In [17]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["city_stats"]
)
query_str = "人口最多的城市是哪个, 有多少人口?"
response = query_engine.query(query_str)

response.response

CPU times: user 15.4 ms, sys: 0 ns, total: 15.4 ms
Wall time: 2.35 s


'人口最多的城市是东京，其人口数量为1396万人。'

In [18]:
# Raw rows returned by the generated SQL query, as stored in the response metadata.
response.metadata["result"]

[('Tokyo', 13960000)]

## 查询时表检索

In [20]:
%%time

from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="city_stats"))
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

response = query_engine.query("人口最多的城市是哪个, 有多少人口?")
response.response

CPU times: user 23.2 ms, sys: 656 µs, total: 23.8 ms
Wall time: 3.6 s


'人口最多的城市是东京，其人口数量为1396万人。'

In [21]:
# Raw rows produced by the SQL that was generated over the retrieved table.
response.metadata["result"]

[('Tokyo', 13960000)]

## 文本到SQL检索

In [22]:
%%time

from llama_index.core.retrievers import NLSQLRetriever

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["city_stats"], return_raw=True
)

results = nl_sql_retriever.retrieve(
    "Return the top 5 cities (along with their populations) with the highest population."
)

results

CPU times: user 8.84 ms, sys: 386 µs, total: 9.23 ms
Wall time: 4.12 s


[NodeWithScore(node=TextNode(id_='6b36a203-51de-462b-81da-688cfc92ccc4', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text="[('Tokyo', 13960000), ('Seoul', 9776000), ('Toronto', 2930000), ('Chicago', 2679000)]", start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [26]:
%%time

from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, streaming=True)

response = query_engine.query(
    "Return the top 5 cities (along with their populations) with the highest population."
)

response.print_response_stream()

1. Tokyo - 13,960,000
2. Seoul - 9,776,000
3. Toronto - 2,930,000
4. Chicago - 2,679,000CPU times: user 94.1 ms, sys: 2.57 ms, total: 96.7 ms
Wall time: 3.25 s
