Import necessary dependencies

In [1]:
from langchain.embeddings import OllamaEmbeddings
from langchain.llms import TextGen

from llama_index import (
    GPTKeywordTableIndex,
    SimpleDirectoryReader,
    LLMPredictor,
    ServiceContext,
    SQLDatabase,
    SQLStructStoreIndex, VectorStoreIndex, set_global_service_context
)
from llama_index.indices.struct_store import SQLContextContainerBuilder, SQLTableRetrieverQueryEngine, \
    NLSQLTableQueryEngine
from llama_index.llms import Ollama
from llama_index.objects import SQLTableNodeMapping, SQLTableSchema, ObjectIndex

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    inspect,
)

from sqlalchemy import insert
from IPython.display import Markdown, display
import os

from llama_index.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index import VectorStoreIndex

Set up llama client

In [2]:
llm = Ollama(model="llama2", temperature=0.01)
llm_predictor = LLMPredictor(llm)
embed_model = OllamaEmbeddings(model="llama2:13b")
ctx = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
set_global_service_context(ctx)

Set up Query engine

In [3]:
# set Logging to DEBUG for more detailed outputs
engine = create_engine("sqlite:///Chinook.db")

insp = inspect(engine)
db_list = insp.get_table_names()
print(db_list)

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Orders', 'Playlist', 'PlaylistTrack', 'Track', 'rst']


In [4]:

metadata_obj = MetaData()
metadata_obj.reflect(engine)

sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = []
print(metadata_obj.tables.keys())
for table_name in metadata_obj.tables.keys():
    print(table_name)
    table_schema_objs.append(SQLTableSchema(table_name=table_name))

dict_keys(['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'Track', 'MediaType', 'Orders', 'rst', 'Playlist', 'PlaylistTrack'])
Album
Artist
Customer
Employee
Genre
Invoice
InvoiceLine
Track
MediaType
Orders
rst
Playlist
PlaylistTrack


In [5]:
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=3)
)

here


Send query to LLM

In [10]:
response = query_engine.query("""how many restaurants have been categorized as “Asian” in the database. Please just answer with the SQL query based on the database schema.
For context, rst is short for restaurant, and cat is short for category.
""")
display(Markdown(f"<b>{response}</b>"))

<b>The SQL query you provided is incorrect. The `COUNT(*)` function counts the number of rows in a table, but it is not applied to a specific column or selection of rows. In this case, you are trying to count the number of rows where the `category` column is equal to 'Asian'.

To fix the query, you need to specify which column or selection of rows you want to count. For example:
```
SELECT COUNT(rst.cat) FROM InvoiceLine rst WHERE rst.cat = 'Asian';
```
This query will count the number of rows where the `cat` column is equal to 'Asian'.

Alternatively, you can use the `COUNT(*)` function with a subquery that filters the results by the `category` column:
```
SELECT COUNT(*) FROM InvoiceLine WHERE category = 'Asian';
```
This query will count the number of rows where the `category` column is equal to 'Asian', without having to specify the column explicitly.</b>