In [20]:
from langchain.embeddings import OllamaEmbeddings
from langchain.llms import TextGen

from llama_index import (
    GPTKeywordTableIndex,
    SimpleDirectoryReader,
    LLMPredictor,
    ServiceContext,
    SQLDatabase,
    SQLStructStoreIndex, VectorStoreIndex, set_global_service_context
)
from llama_index.indices.struct_store import SQLContextContainerBuilder, SQLTableRetrieverQueryEngine, \
    NLSQLTableQueryEngine
from llama_index.llms import Ollama
from llama_index.objects import SQLTableNodeMapping, SQLTableSchema, ObjectIndex

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    inspect,
)

from sqlalchemy import insert


import os
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
os.environ["OPENAI_API_KEY"] = "OPENAI_API_KEY"
os.environ["REPLICATE_API_TOKEN"] = "REPLICATE_API_TOKEN"

# currently needed for notebooks
import openai

openai.api_key = os.environ["OPENAI_API_KEY"]

import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
)

from IPython.display import Markdown, display

model_id = "NumbersStation/nsql-llama-2-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = AutoModelForCausalLM.from_pretrained(model_id)
llm_predictor = LLMPredictor(llm)
embed_model = OllamaEmbeddings(model="llama2")
ctx = ServiceContext.from_defaults(llm_predictor=llm_predictor, embed_model=embed_model)
set_global_service_context(ctx)

engine = create_engine("sqlite:///../Chinook.db")
insp = inspect(engine)
db_list = insp.get_table_names()
print(db_list)
sql_database = SQLDatabase(engine)




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Orders', 'Playlist', 'PlaylistTrack', 'Track', 'rst']


In [21]:
# set Logging to DEBUG for more detailed outputs
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="Album", context_str="describe albums")),  (SQLTableSchema(table_name="Customer", context_str="describe customer")),  (SQLTableSchema(table_name="Employee", context_str="describe employees, category is employee category")),
      (SQLTableSchema(table_name="rst", context_str="describe restaurant, rst_cat is restaurant category")),                                                                    
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=3)
)

In [22]:
from llama_index import PromptTemplate

qa_prompt_tmpl_str = """\
Given an input question, first create a syntactically correct SQL query to run, then look at the results of the query and return the answer. The question is {my_query}

Use the following format:

Question: "Question here"

SQLQuery: "SQL Query to run"

SQLResult: "Result of the SQLQuery"

Answer: "Final answer here"

Some examples of SQL queries that correspond to questions are:

Question: How many different categories in restaurant?
Answer: SELECT DISTINCT rst_cat FROM rst;
"""

sample_context = """\
Question: How many different categories in restaurant?
Answer: SELECT DISTINCT rst_cat FROM rst;
"""

template_var_mappings = {"query_str": "my_query"}

prompt_tmpl = PromptTemplate(
    qa_prompt_tmpl_str, template_var_mappings=template_var_mappings
)
partial_prompt_tmpl = prompt_tmpl.partial_format(tone_name="Shakespeare")

fmt_prompt = partial_prompt_tmpl.format(
    query_str="How many asian category restaurant do we have?",
)


response = query_engine.query(fmt_prompt)
display(Markdown(f"<b>{response}</b>"))

INFO:llama_index.indices.struct_store.sql_retriever:> Table desc str: Table 'rst' has columns: rstID (INTEGER), rst_name (VARCHAR(255)), rst_cat (VARCHAR(255)), and foreign keys: . The table description is: describe restaurant, rst_cat is restaurant category

Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(20)), FirstName (NVARCHAR(20)), Title (NVARCHAR(30)), ReportsTo (INTEGER), BirthDate (DATETIME), HireDate (DATETIME), Address (NVARCHAR(70)), City (NVARCHAR(40)), State (NVARCHAR(40)), Country (NVARCHAR(40)), PostalCode (NVARCHAR(10)), Phone (NVARCHAR(24)), Fax (NVARCHAR(24)), Email (NVARCHAR(60)), category (TEXT), rst_id (INTEGER), and foreign keys: ['ReportsTo'] -> Employee.['EmployeeId']. The table description is: describe employees, category is employee category

Table 'Album' has columns: AlbumId (INTEGER), Title (NVARCHAR(160)), ArtistId (INTEGER), and foreign keys: ['ArtistId'] -> Artist.['ArtistId']. The table description is: describe albums
> Table des

AttributeError: 'LlamaForCausalLM' object has no attribute 'metadata'