#### This Notebook demonstrates LlamaIndex's capabilities for working with SQL databases.
We will use a MySQL table to perform a number of queries.
- Create a SQLAlchemy engine object to connect to the MySQL DB (DB name: demo_db, table name: walmart).
- A SQLDatabase object is also created.
- We will query the walmart table, which contains weekly sales volumes and a host of other parameters: CPI, fuel price, temperature, holiday flag, etc.
- We will use three mechanisms to perform queries on the table. 
- <b>Part 1: Use the NLSQLTableQueryEngine, which converts a natural language query into the corresponding SQL query and fetches the result</b>
- <b>Part 2: Use the SQLTableRetrieverQueryEngine to perform the same operation. This method also uses an intermediate VectorStoreIndex to look up the relevant table schema</b>
- <b>Part 3: Use the NLSQLRetriever and plug the retrieved results into RetrieverQueryEngine to synthesize the final answer</b>


In [1]:
# Install the following Libraries if they are not already installed in your environment.

# !pipenv install llama-index pymysql -q
# !pipenv install ipython
# !pipenv install llama-index-embeddings-openai

In [1]:
# SQLAlchemy is the Python SQL toolkit and Object Relational Mapper that gives 
# application developers the full power and flexibility of working with SQL databases and tables.

from sqlalchemy import (
    create_engine,
    text, 
)
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI

In [2]:
# Set up the connection string 

db_user = "root"
db_password = "#####"
db_host = "localhost:3306"
db_name = "demo_db" #sampleDB

connection_string = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
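
If the password contains characters that are reserved in URLs (such as `#` or `@`), interpolating it directly into the connection string can produce a URL that is parsed incorrectly. The following is a minimal sketch, assuming a made-up password `p@ss#word` and hypothetical variable names, that percent-encodes the credentials with the standard library before building the string:

```python
from urllib.parse import quote_plus

# Hypothetical credentials for illustration only; substitute your own.
example_user = "root"
example_password = "p@ss#word"
example_host = "localhost:3306"
example_db = "demo_db"

# Percent-encode the user and password so reserved characters
# such as '@' and '#' do not break URL parsing.
safe_connection_string = (
    f"mysql+pymysql://{quote_plus(example_user)}:{quote_plus(example_password)}"
    f"@{example_host}/{example_db}"
)
print(safe_connection_string)
```

The encoded string can be passed to `create_engine` in the same way as the plain one.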

In [3]:
# Create an engine instance - the Engine is a factory in SQLAlchemy that can create new database connections
engine = create_engine(connection_string)

# Test the connection using raw SQL
print("Printing three rows:")
with engine.connect() as connection:
    result = connection.execute(text("select * from walmart limit 3"))
    for row in result:
        print(row)
        
print("Printing Table structure:")
with engine.connect() as connection:
    result = connection.execute(text("describe walmart"))
    for row in result:
        print(row)        

Printing three rows:
(1, datetime.date(2005, 2, 10), 1643690.0, 0, 42.31, 2.572, 211.096, 8.106)
(1, datetime.date(2012, 2, 10), 1641960.0, 1, 38.51, 2.548, 211.242, 8.106)
(1, datetime.date(2019, 2, 10), 1611970.0, 0, 39.93, 2.514, 211.289, 8.106)
Printing Table structure:
('Store', 'int', 'NO', '', None, '')
('Date', 'date', 'NO', '', None, '')
('Weekly_Sales', 'float', 'YES', '', None, '')
('Holiday_Flag', 'tinyint', 'YES', '', None, '')
('Temperature', 'float', 'YES', '', None, '')
('Fuel_Price', 'float', 'YES', '', None, '')
('CPI', 'float', 'YES', '', None, '')
('Unemployment', 'float', 'YES', '', None, '')


In [4]:
llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")

In [5]:
# here we define our SQLDatabase abstraction (a light wrapper around SQLAlchemy)
sql_database = SQLDatabase(engine, include_tables=["walmart"])

#### Part 1: Text-to-SQL Query Engine
Once we have constructed our SQLDatabase object, we can use the NLSQLTableQueryEngine to translate natural language queries into SQL queries and run them against the database.
Note that we need to specify the tables we want to use with this query engine. If we don't, the query engine will pull in the schema context for all tables,
which could overflow the context window of the LLM.
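
To make the translation step concrete, here is a small sketch of the kind of SQL such an engine might produce for the question "How many unique stores are there?". It runs against an in-memory SQLite table with made-up rows rather than the MySQL walmart table, and the SQL string is an assumption: the statement the LLM actually generates may differ.

```python
import sqlite3

# Toy stand-in for the walmart table; the rows are made up for illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE walmart (Store INTEGER, Weekly_Sales REAL, CPI REAL)")
conn.executemany(
    "INSERT INTO walmart VALUES (?, ?, ?)",
    [(1, 100.0, 211.0), (1, 120.0, 212.0), (2, 90.0, 210.0), (3, 80.0, 220.0)],
)

# SQL an LLM might plausibly generate for "How many unique stores are there?"
generated_sql = "SELECT COUNT(DISTINCT Store) FROM walmart"
unique_stores = conn.execute(generated_sql).fetchone()[0]
print(unique_stores)
conn.close()
```

In recent LlamaIndex versions the generated statement should be available under `response.metadata["sql_query"]`, which is worth checking against your installed version before relying on it.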

In [6]:
from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["walmart"], llm=llm
)
query_str = "How many unique stores are there?"
#query_str = "What is the average CPI of each Store? Order the results by Store number."
response = query_engine.query(query_str)

print(response)

There are 45 unique stores in the dataset.


#### Part 2: Query-Time Retrieval of Tables for Text-to-SQL
If we don't know ahead of time which table we would like to use, and the total size of the table schemas overflows the LLM's context window,
we should store the table schemas in an index so that at query time we can retrieve the right schema.
We can do this with the SQLTableNodeMapping object, which takes in a SQLDatabase and produces a Node object
for each SQLTableSchema object passed into the ObjectIndex constructor.
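
The idea of retrieving a schema at query time can be sketched without LlamaIndex. The example below is only an illustration of the principle: it scores hypothetical table descriptions against the question with a simple bag-of-words cosine similarity, whereas the notebook uses embeddings in a VectorStoreIndex. The `employees` and `products` tables and all helper names are invented for the example.

```python
import math
from collections import Counter

# Hypothetical schema descriptions; only "walmart" exists in the notebook's database.
table_descriptions = {
    "walmart": "store date weekly sales holiday flag temperature fuel price cpi unemployment",
    "employees": "employee name department salary hire date",
    "products": "product name category unit price supplier",
}

def bag_of_words(text):
    """Lowercase the text, strip punctuation, and count word occurrences."""
    cleaned = "".join(ch if ch.isalnum() else " " for ch in text.lower())
    return Counter(cleaned.split())

def cosine_similarity(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[word] * b[word] for word in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve_tables(query, top_k=1):
    """Return the top_k table names whose descriptions best match the query."""
    query_vec = bag_of_words(query)
    ranked = sorted(
        table_descriptions,
        key=lambda name: cosine_similarity(query_vec, bag_of_words(table_descriptions[name])),
        reverse=True,
    )
    return ranked[:top_k]

top_tables = retrieve_tables("What is the average CPI of each Store?")
print(top_tables)
```

Only the schemas of the retrieved tables would then be placed in the prompt, which keeps the context small even when the database has many tables.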

In [7]:
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

In [8]:
table_node_mapping = SQLTableNodeMapping(sql_database)
# Add a SQLTableSchema for our table; you may add more tables here
table_schema_objs = [
    SQLTableSchema(table_name="walmart"),
]

# The ObjectIndex class allows for the indexing of arbitrary Python objects, including SQL database schema objects.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    index_cls=VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=3)
)

In [9]:
query_str = "What is the average CPI of each Store? Order the results by Store number."
response = query_engine.query(query_str)

print(response)

The average CPI for each store, ordered by store number, is as follows:
Store 1: 215.9969
Store 2: 215.6463
Store 3: 219.3915
Store 4: 128.6797
Store 5: 216.5656
Store 6: 217.5532
Store 7: 193.6642
Store 8: 219.4390
Store 9: 219.6267
Store 10: 128.6797
Store 11: 219.3915
Store 12: 128.6797
Store 13: 128.6797
Store 14: 186.2857
Store 15: 135.0926
Store 16: 193.6642
Store 17: 128.6797
Store 18: 135.0926
Store 19: 135.0926
Store 20: 209.0381
Store 21: 215.6463
Store 22: 139.0113
Store 23: 135.0926
Store 24: 135.0926
Store 25: 209.0381
Store 26: 135.0926
Store 27: 139.0113
Store 28: 128.6797
Store 29: 135.0926
Store 30: 215.6463
Store 31: 215.6463
Store 32: 193.6642
Store 33: 128.6797
Store 34: 128.6797
Store 35: 139.0113
Store 36: 214.7291
Store 37: 214.7291
Store 38: 128.6797
Store 39: 214.7291
Store 40: 135.0926
Store 41: 193.6642
Store 42: 128.6797
Store 43: 207.7352
Store 44: 128.6797
Store 45: 186.2857


#### Part 3: Text-to-SQL Retriever
So far our text-to-SQL capability has been packaged in a query engine that performs both retrieval and synthesis.
We can also use the SQL retriever on its own.

In [10]:
from llama_index.core.retrievers import NLSQLRetriever

In [11]:
# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["walmart"], return_raw=True
)


##### Plug into our RetrieverQueryEngine
We compose our SQL Retriever with our standard RetrieverQueryEngine to synthesize a response. The result is roughly similar to our packaged Text-to-SQL query engines.
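
The split between retrieval and synthesis can be sketched as two separate functions. This is an illustration under stated assumptions, not the LlamaIndex implementation: it uses an in-memory SQLite table with made-up rows, a hand-written SQL statement in place of the LLM-generated one, and a plain string formatter in place of the LLM synthesis step. The function names are hypothetical.

```python
import sqlite3

def retrieve_rows(conn, sql):
    """Retrieval step: run the SQL and return the raw rows without interpreting them."""
    return conn.execute(sql).fetchall()

def synthesize(rows):
    """Synthesis step: turn raw rows into a readable answer (an LLM does this in the notebook)."""
    lines = [f"- Store {store}: {avg_cpi:.2f}" for store, avg_cpi in rows]
    return "The average CPI of each Store:\n" + "\n".join(lines)

# Toy stand-in for the walmart table; the rows are made up for illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE walmart (Store INTEGER, CPI REAL)")
conn.executemany(
    "INSERT INTO walmart VALUES (?, ?)",
    [(1, 210.0), (1, 212.0), (2, 128.0), (2, 130.0)],
)

rows = retrieve_rows(
    conn, "SELECT Store, AVG(CPI) FROM walmart GROUP BY Store ORDER BY Store"
)
answer = synthesize(rows)
print(answer)
conn.close()
```

Keeping the two steps separate makes it possible to inspect or post-process the raw rows before any natural language answer is produced.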


In [12]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)

response = query_engine.query(
    "What is the average CPI of each Store? Order the results by Store number."
)

In [14]:
print(str(response))

The average CPI of each Store, ordered by Store number, is as follows:
- Store 1: 215.99689190204327
- Store 2: 215.64631087296493
- Store 3: 219.39153167084382
- Store 4: 128.67966946688566
- Store 5: 216.5655815284569
- Store 6: 217.55319704709353
- Store 7: 193.66424336466756
- Store 8: 219.43902641243034
- Store 9: 219.62668945739318
- Store 10: 128.67966946688566
- Store 11: 219.39153167084382
- Store 12: 128.67966946688566
- Store 13: 128.67966946688566
- Store 14: 186.28567824997268
- Store 15: 135.09260761821187
- Store 16: 193.66424336466756
- Store 17: 128.67966946688566
- Store 18: 135.09260761821187
- Store 19: 135.09260761821187
- Store 20: 209.03813075352383
- Store 21: 215.64631087296493
- Store 22: 139.0112838211593
- Store 23: 135.09260761821187
- Store 24: 135.09260761821187
- Store 25: 209.03813075352383
- Store 26: 135.09260761821187
- Store 27: 139.0112838211593
- Store 28: 128.67966946688566
- Store 29: 135.09260761821187
- Store 30: 215.64631087296493
- Store 31: