In [None]:
from llama_index import SimpleDirectoryReader, WikipediaReader
from IPython.display import Markdown, display
from langchain import HuggingFaceHub, OpenAI

In [None]:
# TODO: import API tokens from .env — this cell is currently a placeholder and loads nothing.

In [None]:
# Fetch the Wikipedia articles that serve as the unstructured input documents.
city_pages = ['Toronto', 'Berlin', 'Tokyo']
wiki_docs = WikipediaReader().load_data(pages=city_pages)

# Echo the loaded documents as the cell output.
wiki_docs

# Create db schema

In [None]:
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column

In [None]:
# Registry that collects the table definitions declared in later cells.
metadata_obj = MetaData()
# SQLite engine pointed at an in-memory database (":memory:" URL).
engine = create_engine("sqlite:///:memory:")

In [None]:
# Declare the city_stats table on metadata_obj, then create it through the engine.
table_name = "city_stats"
city_columns = [
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
]
city_stats_table = Table(table_name, metadata_obj, *city_columns)
metadata_obj.create_all(engine)

In [None]:
# Pick the LLM backend. Exactly one `the_llm` assignment should be active;
# the commented lines are alternative backends.

# the_llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.2, "max_length": 512})

# The keys in model_kwargs are plain strings, so a misspelled key is not
# caught here; "temperature" must be spelled exactly for the 0.5 setting
# to be passed under the name the model recognises.
the_llm = HuggingFaceHub(
    repo_id="EleutherAI/gpt-neox-20b",
    model_kwargs={"temperature": 0.5, "max_length": 512},
)

# the_llm = OpenAI(temperature=0.9, model_name="text-davinci-003")

In [None]:
from llama_index import SQLDatabase, ServiceContext, GPTSQLStructStoreIndex
from llama_index import LLMPredictor

In [None]:
# Wrap the selected LLM in a predictor and bundle it into a service context.
chunk_limit = 900
llm_predictor = LLMPredictor(llm=the_llm)
# NOTE(review): chunk_size and chunk_size_limit receive the same value —
# confirm whether both are required by the installed llama_index version.
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    chunk_size=chunk_limit,
    chunk_size_limit=chunk_limit,
)

In [None]:
# Wrap the engine for llama_index, restricted to the city_stats table.
sql_database = SQLDatabase(
    engine,
    include_tables=["city_stats"],
)

In [None]:
# Show the table info that sql_database reports, as the cell output.
sql_database.table_info

In [None]:
# Build the SQL struct-store index from the Wikipedia documents.
# NOTE: table_name names the SQL table that receives the data
# extracted from the unstructured documents.
index = GPTSQLStructStoreIndex.from_documents(
    wiki_docs,
    service_context=service_context,
    sql_database=sql_database,
    table_name="city_stats",
)

# Alternative: construct the index from an empty document list.
# index = GPTSQLStructStoreIndex(
#     [],
#     sql_database=sql_database,
#     table_name="city_stats",
# )

In [None]:
# Peek at the rows currently stored in city_stats.
selected_columns = city_stats_table.c["city_name", "population", "country"]
stmt = select(selected_columns).select_from(city_stats_table)

# fetchall() materialises the rows before the connection is released,
# so they can be printed after the `with` block exits.
with engine.connect() as connection:
    results = connection.execute(stmt).fetchall()

print(results)

In [None]:
# Send a raw SQL statement through a query engine created with query_mode="sql".
sql_text = "SELECT city_name from city_stats"
query_engine = index.as_query_engine(query_mode="sql")
response = query_engine.query(sql_text)

In [None]:
# Render the response text in bold as the cell output.
answer_html = f"<b>{response}</b>"
display(Markdown(answer_html))

In [None]:
# Ask a natural-language question through a query engine created with query_mode="nl".
# Tip: set logging to DEBUG for more detailed outputs.
question = "Which city has the highest population?"
query_engine = index.as_query_engine(query_mode="nl")
response = query_engine.query(question)

In [None]:
# Render the response text in bold as the cell output.
answer_html = f"<b>{response}</b>"
display(Markdown(answer_html))

In [None]:
# The response also carries the raw SQLAlchemy result, stored under the
# "result" key of its extra_info mapping; show it as the cell output.
response.extra_info["result"]