#### Question answering over structured data (tables and databases)
In this notebook, we explore how to use LlamaIndex to answer questions over structured data such as databases, tables, and dataframes.

In [1]:
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.service_context import ServiceContext
from llama_index.storage import StorageContext


In [2]:
from sqlalchemy import (
    create_engine, 
    MetaData, 
    Column, 
    Table, 
    String, 
    Integer,
    Select,
    column
)

In [3]:
# Throwaway in-memory SQLite database for the demo (the `:memory:` URL), plus
# the MetaData registry that the table definition below is attached to.
metadata_obj = MetaData()
engine = create_engine("sqlite:///:memory:")


In [4]:
# Demo schema: one row per city, keyed by city name; country is required.
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column(name="city_name", type_=String(16), primary_key=True),
    Column(name="population", type_=Integer),
    Column(name="country", type_=String(16), nullable=False),
)

In [5]:
# Emit CREATE TABLE for every table registered on `metadata_obj` (here, only city_stats).
metadata_obj.create_all(bind=engine)

In [6]:
from sqlalchemy import insert

# Seed data for the demo table.
rows = [
    {"city_name": "Toronto", "population": 2731571, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13929286, "country": "Japan"},
    {"city_name": "Berlin", "population": 600000, "country": "Germany"},
]

# Insert every row in one transaction. Passing a list of parameter dicts to
# `execute` performs an "executemany". `engine.begin()` commits when the block
# exits normally and rolls back if any insert fails, so the table is never
# left half-populated.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

In [7]:
from llama_index.utilities.sql_wrapper import SQLDatabase

# Wrap the SQLAlchemy engine for LlamaIndex, limiting it to the city_stats table.
sql_database = SQLDatabase(
    engine=engine,
    include_tables=[table_name],
)

### Create the natural-language-to-SQL query engine
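With the table defined and wrapped in an `SQLDatabase` (built from our engine and restricted to the tables we want to expose), we can use it to build a query engine that translates natural-language questions into SQL. Further below, we also build an object index over the table schemas so that the relevant table can be retrieved at query time.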

In [8]:
from llama_index.indices.struct_store import NLSQLTableQueryEngine

In [9]:
# Text-to-SQL engine over the single city_stats table. The question is turned
# into a SQL query and run against the database. The answer is written from the
# returned rows; see `sql_query` / `result` in the response metadata below.
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["city_stats"],
)

response = query_engine.query("What city has the highest population?")

In [10]:
from pprint import pprint

In [12]:
# Full Response object: answer text, source nodes, and metadata (incl. the generated SQL).
pprint(object=response)

Response(response='Based on the query results, the city with the highest '
                  'population is Tokyo, with a population of 13,929,286.',
         source_nodes=[NodeWithScore(node=TextNode(id_='afea2294-f067-497d-8c10-c76916148ec5', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text="[('Tokyo', 13929286)]", start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)],
         metadata={'afea2294-f067-497d-8c10-c76916148ec5': {},
                   'col_keys': ['city_name', 'population'],
                   'result': [('Tokyo', 13929286)],
                   'sql_query': 'SELECT city_name, population FROM city_stats '
                                'ORDER BY population DESC LIMIT 1;'})


In [11]:
# Only the synthesized natural-language answer.
pprint(object=response.response)

('Based on the query results, the city with the highest population is Tokyo, '
 'with a population of 13,929,286.')


In [13]:
from llama_index.objects import SQLTableNodeMapping, ObjectIndex, SQLTableSchema

In [14]:
# Mapping between SQLTableSchema objects and the nodes stored in an ObjectIndex.
table_node_mapping = SQLTableNodeMapping(sql_database)
# Peek at its internal state: it just holds a reference to the SQLDatabase wrapper.
vars(table_node_mapping)

{'_sql_database': <llama_index.utilities.sql_wrapper.SQLDatabase at 0x28ecab450>}

In [15]:
# One SQLTableSchema per table the retriever may choose from (only city_stats here).
table_schema_objs = [SQLTableSchema(table_name="city_stats")]

In [16]:
# Index the table schemas with a VectorStoreIndex so the table(s) relevant to a
# question can be looked up by similarity.
obj_index = ObjectIndex.from_objects(
    table_schema_objs, table_node_mapping, VectorStoreIndex
)

In [17]:
# Extra, hand-written context describing the table, attached to its schema below.
# The pieces are joined by implicit string concatenation, so each fragment must
# carry its own trailing space or newline.
city_stats_text = (
    "This table gives information regarding the population and country of a given city.\n"
    "The user will query with codewords, where 'foo' corresponds to population and 'bar' "
    "corresponds to city."
)

table_node_mapping = SQLTableNodeMapping(sql_database)
# Same single-table schema list, now carrying the extra context string
# (presumably surfaced to the LLM alongside the table schema).
table_schema_objs = [
    SQLTableSchema(table_name="city_stats", context_str=city_stats_text),
]

In [18]:
# Rebuild the schema index so it reflects the context-enriched schema objects.
obj_index = ObjectIndex.from_objects(
    table_schema_objs, table_node_mapping, VectorStoreIndex
)

In [19]:
from llama_index.indices.struct_store import SQLTableRetrieverQueryEngine


In [20]:
# Query engine that first retrieves the relevant table schema(s) from
# `obj_index`, then generates and runs SQL against `sql_database`.
# `similarity_top_k=1` keeps only the single best-matching table.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    table_retriever=obj_index.as_retriever(similarity_top_k=1),
)

In [22]:
question = "What is the population of Tokyo?"
response = query_engine.query(question)
# Show only the synthesized answer text.
pprint(object=response.response)

'The population of Tokyo is approximately 13,929,286 people.'
