In [76]:
from langchain.embeddings import OllamaEmbeddings
from langchain.llms import TextGen

from llama_index import (
    GPTKeywordTableIndex,
    SimpleDirectoryReader,
    LLMPredictor,
    ServiceContext,
    SQLDatabase,
    SQLStructStoreIndex, VectorStoreIndex, set_global_service_context
)
from llama_index.indices.struct_store import SQLContextContainerBuilder, SQLTableRetrieverQueryEngine, \
    NLSQLTableQueryEngine
from llama_index.llms import Ollama
from llama_index.objects import SQLTableNodeMapping, SQLTableSchema, ObjectIndex

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    inspect,
)

from sqlalchemy import insert

Do imports first, and then set up LLM and LLM Predictor

In [77]:
import os

os.environ["OPENAI_API_KEY"] = "OPENAI_API_KEY"
os.environ["REPLICATE_API_TOKEN"] = "REPLICATE_API_TOKEN"

# currently needed for notebooks
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]

Load documents, build the VectorStoreIndex

In [78]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
)

from IPython.display import Markdown, display

In [79]:

llm = Ollama(model="llama2", temperature=0.01)
llm_predictor = LLMPredictor(llm)
embed_model = OllamaEmbeddings(model="llama2")
ctx = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
set_global_service_context(ctx)

Set up service context

Read database


In [80]:
engine = create_engine("sqlite:///chinook.db")
insp = inspect(engine)
db_list = insp.get_table_names()
print(db_list)
sql_database = SQLDatabase(engine)


['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Orders', 'Playlist', 'PlaylistTrack', 'Track', 'rst']


Dump vector index from table schema information
Example here: https://gpt-index.readthedocs.io/en/v0.6.9/guides/tutorials/sql_guide.html

In [89]:
# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="Album", context_str="describe albums")),  (SQLTableSchema(table_name="Customer", context_str="describe customer")),  (SQLTableSchema(table_name="Employee", context_str="contains employees information: BirthDate is birth date of the employee")),
    (SQLTableSchema(table_name="Orders", context_str="contains orders information")),
      (SQLTableSchema(table_name="rst", context_str="contains restaurant information, rst_cat is restaurant category")),                                                                    
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

Set up query engine

In [83]:
from llama_index import PromptTemplate

qa_prompt_tmpl_str = """\
Context information is below.
---------------------
{my_context}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {my_query}
Answer: \
"""

template_var_mappings = {"context_str": "my_context", "query_str": "my_query"}

prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str, template_var_mappings=template_var_mappings
)
partial_prompt_tmpl = prompt_tmpl.partial_format(tone_name="Shakespeare")

fmt_prompt = partial_prompt_tmpl.format(
    context_str="SELECT * FROM rst WHERE rst_cat = 'asian';",
    query_str="Find the orders made in asian restaurants?",
)


response = query_engine.query(fmt_prompt)
display(Markdown(f"<b>{response}</b>"))

INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'rst' has columns: rstID (INTEGER), rst_name (VARCHAR(255)), rst_cat (VARCHAR(255)), and foreign keys: . The table description is: contains restaurant information, rst_cat is restaurant category

Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(20)), FirstName (NVARCHAR(20)), Title (NVARCHAR(30)), ReportsTo (INTEGER), BirthDate (DATETIME), HireDate (DATETIME), Address (NVARCHAR(70)), City (NVARCHAR(40)), State (NVARCHAR(40)), Country (NVARCHAR(40)), PostalCode (NVARCHAR(10)), Phone (NVARCHAR(24)), Fax (NVARCHAR(24)), Email (NVARCHAR(60)), category (TEXT), rst_id (INTEGER), and foreign keys: ['ReportsTo'] -> Employee.['EmployeeId']. The table description is: contains employees information

Table 'Orders' has columns: id (INTEGER), restaurant (INTEGER), and foreign keys: ['restaurant'] -> rst.['rstID']. The table description is: contains orders information
> Table desc str: Table 'rst' has c

<b>To answer the query "Find the orders made in asian restaurants?", we need to use a subquery in the ON clause of the join operation to filter the results of the Rst table based on the rst_cat column. Here's the corrected SQL response:

SQL: SELECT Orders.* FROM Orders JOIN (SELECT id, rst_cat FROM rst WHERE rst_cat = 'asian') as rst ON Orders.restaurant = rst.id;

Explanation:

* The subquery in the ON clause filters the results of the Rst table to only those with rst_cat = 'asian'.
* The resulting table is then joined with the Orders table to get the orders made in asian restaurants.

The final result set will contain the following columns:

| id | restaurant |
| --- | --- |
| 123 | 1 |
| 456 | 2 |
| 789 | 3 |

Note that the subquery in the ON clause is enclosed in parentheses to indicate that it is a separate query that is being used to filter the results of the Rst table.</b>

Trying complex SQL functions

In [88]:
fmt_prompt = partial_prompt_tmpl.format(
    context_str="SELECT * FROM Employee",
    query_str="Which employee is the youngest based on birthday",
)


response = query_engine.query("Which employee is the youngest?")
display(Markdown(f"<b>{response}</b>"))

INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'rst' has columns: rstID (INTEGER), rst_name (VARCHAR(255)), rst_cat (VARCHAR(255)), and foreign keys: . The table description is: contains restaurant information, rst_cat is restaurant category

Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(20)), FirstName (NVARCHAR(20)), Title (NVARCHAR(30)), ReportsTo (INTEGER), BirthDate (DATETIME), HireDate (DATETIME), Address (NVARCHAR(70)), City (NVARCHAR(40)), State (NVARCHAR(40)), Country (NVARCHAR(40)), PostalCode (NVARCHAR(10)), Phone (NVARCHAR(24)), Fax (NVARCHAR(24)), Email (NVARCHAR(60)), category (TEXT), rst_id (INTEGER), and foreign keys: ['ReportsTo'] -> Employee.['EmployeeId']. The table description is: contains employees information

Table 'Orders' has columns: id (INTEGER), restaurant (INTEGER), and foreign keys: ['restaurant'] -> rst.['rstID']. The table description is: contains orders information
> Table desc str: Table 'rst' has c

<b>Based on the provided SQL query and response, the youngest employee is Michael Mitchell, born in 1973.</b>