In [1]:
from langchain.embeddings import OllamaEmbeddings
from langchain.llms import TextGen

from llama_index import (
    GPTKeywordTableIndex,
    SimpleDirectoryReader,
    LLMPredictor,
    ServiceContext,
    SQLDatabase,
    SQLStructStoreIndex, VectorStoreIndex, set_global_service_context
)
from llama_index.indices.struct_store import SQLContextContainerBuilder, SQLTableRetrieverQueryEngine, \
    NLSQLTableQueryEngine
from llama_index.llms import Ollama
from llama_index.objects import SQLTableNodeMapping, SQLTableSchema, ObjectIndex

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    inspect,
)

from sqlalchemy import insert

import os

os.environ["OPENAI_API_KEY"] = "OPENAI_API_KEY"
os.environ["REPLICATE_API_TOKEN"] = "REPLICATE_API_TOKEN"

# currently needed for notebooks
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
)

from IPython.display import Markdown, display

llm = Ollama(model="llama2", temperature=1)
llm_predictor = LLMPredictor(llm)
embed_model = OllamaEmbeddings(model="llama2:13b")
ctx = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
set_global_service_context(ctx)

engine = create_engine("sqlite:///Chinook.db")
insp = inspect(engine)
db_list = insp.get_table_names()
print(db_list)

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Orders', 'Playlist', 'PlaylistTrack', 'Track', 'rst']


In [2]:

sql_database = SQLDatabase(engine)

# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="Album", context_str="describe albums")),  (SQLTableSchema(table_name="Customer", context_str="describe customer")),  (SQLTableSchema(table_name="Employee", context_str="contains employees information: BirthDate is birth date of the employee")),
    (SQLTableSchema(table_name="Orders", context_str="contains orders information")),
      (SQLTableSchema(table_name="rst", context_str="contains restaurant information, rst_cat is restaurant category")),                                                                    
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=3)
)



DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): localhost:11434
Starting new HTTP connection (1): localhost:11434
DEBUG:urllib3.connectionpool:http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): localhost:11434
Starting new HTTP connection (1): localhost:11434
DEBUG:urllib3.connectionpool:http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): localhost:11434
Starting new HTTP connection (1): localhost:11434
DEBUG:urllib3.connectionpool:http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): localhost:11434
Starting new HTTP connection (1): localhost:11434
DEBUG:urllib3

In [5]:
from llama_index import PromptTemplate

qa_prompt_tmpl_str = """\
Context information is below.
---------------------
{my_context}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {my_query}
Answer: \
"""

template_var_mappings = {"context_str": "my_context", "query_str": "my_query"}

prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str, template_var_mappings=template_var_mappings
)
partial_prompt_tmpl = prompt_tmpl.partial_format(tone_name="Shakespeare")


fmt_prompt = partial_prompt_tmpl.format(
    context_str="SELECT * FROM Employee",
    query_str="Find the lastname the youngest employee?",
)


response = query_engine.query(fmt_prompt)
display(Markdown(f"<b>{response}</b>"))

DEBUG:urllib3.connectionpool:Starting new HTTP connection (1): localhost:11434
Starting new HTTP connection (1): localhost:11434
DEBUG:urllib3.connectionpool:http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
http://localhost:11434 "POST /api/embeddings HTTP/1.1" 200 None
DEBUG:llama_index.indices.utils:> Top 3 nodes:
> [Node 86435f61-4829-4a58-98d7-b238596a4998] [Similarity score:             0.203719] Schema of table rst:
Table 'rst' has columns: rstID (INTEGER), rst_name (VARCHAR(255)), rst_cat (...
> [Node d71dbff9-4a1b-4d5f-98df-4eba1ce7b3d4] [Similarity score:             0.18362] Schema of table Employee:
Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(...
> [Node 51292d4d-b72d-4890-bff4-1105dc11a490] [Similarity score:             0.154699] Schema of table Album:
Table 'Album' has columns: AlbumId (INTEGER), Title (NVARCHAR(160)), Artis...
> Top 3 nodes:
> [Node 86435f61-4829-4a58-98d7-b238596a4998] [Similarity score:             0.203719] Sc

<b>The youngest employee in the company is Callahan.</b>