In [None]:
from llama_index.core import VectorStoreIndex, load_index_from_storage
from sqlalchemy import text
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext
import os
from pathlib import Path
from typing import Dict
import weaviate
from llama_index.vector_stores.weaviate import WeaviateVectorStore


def index_one_table(sql_database: SQLDatabase, table_name: str, table_index_dir: str = "table_index_dir") -> Dict[str, VectorStoreIndex]:
    """Build or reload the Weaviate-backed row index for a single table.

    On the first run for a table, every row is read with ``SELECT *``,
    stringified into a ``TextNode`` and indexed into a Weaviate collection
    named after the table; the index metadata is then persisted under
    ``{table_index_dir}/{table_name}``. On later runs the persisted metadata
    is reloaded instead of re-indexing.

    Args:
        sql_database: Database wrapper exposing ``.engine`` (SQLAlchemy engine).
        table_name: Name of the table to index; also used as the Weaviate
            index name and as the sub-directory name for persisted metadata.
        table_index_dir: Root directory holding one persisted index per table.

    Returns:
        A one-entry dict mapping ``table_name`` to its ``VectorStoreIndex``.

    Note:
        The Weaviate client opened here is closed before returning, so the
        returned index wraps a vector store whose client is already closed.
        NOTE(review): callers that want to query should presumably attach a
        fresh client (as ``get_table_context_and_rows_str`` does) - confirm.
    """
    # Single source of truth for where this table's index metadata lives.
    # The existence check, persist() and the reload must all agree on it;
    # otherwise the cache is never hit and rows are re-inserted on every run.
    persist_dir = f"{table_index_dir}/{table_name}"

    client = weaviate.connect_to_local(
        headers={"X-OpenAI-Api-key": os.getenv("OPENAI_API_KEY")}
    )
    try:
        # construct vector store
        vector_store = WeaviateVectorStore(
            weaviate_client=client, index_name=table_name, text_key="content"
        )

        print(f"Indexing rows in table: {table_name}")
        if not os.path.exists(persist_dir):
            # get all rows from table
            with sql_database.engine.connect() as conn:
                result = conn.execute(text(f'SELECT * FROM "{table_name}"')).fetchall()
                row_tups = [tuple(row) for row in result]

            # one node per row, using the tuple's string form as the text
            nodes = [TextNode(text=str(t)) for t in row_tups]

            # embeddings are stored in the Weaviate collection
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            index = VectorStoreIndex(nodes, storage_context=storage_context)

            # save index metadata so the next run can take the reload branch
            index.set_index_id("vector_index")
            index.storage_context.persist(persist_dir)
        else:
            # rebuild storage context on top of the existing collection
            storage_context = StorageContext.from_defaults(
                persist_dir=persist_dir, vector_store=vector_store
            )
            # load index
            index = load_index_from_storage(
                storage_context, index_id="vector_index"
            )
        vector_index_dict = {table_name: index}
    finally:
        client.close()
    return vector_index_dict

def index_all_tables(
    sql_database: SQLDatabase, table_index_dir: str = "table_index_dir"
) -> Dict[str, VectorStoreIndex]:
    """Build or reload one persisted vector index per usable table.

    For each name returned by ``sql_database.get_usable_table_names()``:
    if ``{table_index_dir}/{name}`` already exists the index is loaded from
    it; otherwise all rows are fetched, turned into ``TextNode`` objects,
    indexed, and persisted to that directory.

    Args:
        sql_database: Database wrapper exposing ``.engine`` and
            ``.get_usable_table_names()``.
        table_index_dir: Root directory for persisted indexes; created if
            it does not exist.

    Returns:
        Dict mapping each table name to its ``VectorStoreIndex``.
    """
    if not Path(table_index_dir).exists():
        os.makedirs(table_index_dir)

    indexes_by_table = {}
    db_engine = sql_database.engine
    for name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {name}")
        table_dir = f"{table_index_dir}/{name}"

        if os.path.exists(table_dir):
            # Already persisted: rebuild the storage context and reload.
            restored_context = StorageContext.from_defaults(persist_dir=table_dir)
            table_index = load_index_from_storage(
                restored_context, index_id="vector_index"
            )
        else:
            # Pull every row of the table while the connection is open.
            with db_engine.connect() as connection:
                fetched = connection.execute(
                    text(f'SELECT * FROM "{name}"')
                ).fetchall()
                row_tuples = [tuple(record) for record in fetched]

            # One node per row; the node text is the row tuple's string form.
            row_nodes = [TextNode(text=str(values)) for values in row_tuples]
            table_index = VectorStoreIndex(row_nodes)

            # Persist so the next call takes the reload branch above.
            table_index.set_index_id("vector_index")
            table_index.storage_context.persist(table_dir)

        indexes_by_table[name] = table_index

    return indexes_by_table


# Build (or reload) the row index for the summary table; the result is a
# one-entry dict keyed by the table name.
vector_index_dict = index_one_table(
    sql_database=sql_database,
    table_name="Construction_Projects_Summary_Table",
)

In [None]:
from llama_index.core.retrievers import SQLRetriever
from typing import List
from llama_index.core.query_pipeline import FnComponent
import weaviate
from llama_index.vector_stores.weaviate import WeaviateVectorStore

# Retriever constructed over sql_database. It is registered below as the
# pipeline's "sql_retriever" module, chained after "sql_output_parser".
sql_retriever = SQLRetriever(sql_database)


def get_table_context_and_rows_str(
    query_str: str, table_schema_objs: List[SQLTableSchema]
) -> str:
    """Get table context string.

    For each table schema object, starts from
    ``sql_database.get_single_table_info(table_name)``, appends the optional
    table description (``context_str``), and then appends up to two rows
    retrieved from the table's Weaviate collection for ``query_str``.

    Args:
        query_str: Query text used to retrieve example rows.
        table_schema_objs: Table schema objects; each must expose
            ``table_name`` and ``context_str``.

    Returns:
        The per-table context blocks joined by blank lines, or ``""`` when
        ``table_schema_objs`` is empty.
    """
    # Nothing to describe: avoid opening a Weaviate connection at all.
    if not table_schema_objs:
        return ""

    client = weaviate.connect_to_local(
        headers={"X-OpenAI-Api-key": os.getenv("OPENAI_API_KEY")}
    )

    try:
        context_strs = []
        for table_schema_obj in table_schema_objs:
            # first append table info + additional context
            table_info = sql_database.get_single_table_info(
                table_schema_obj.table_name
            )
            if table_schema_obj.context_str:
                table_info += (
                    " The table description is: " + table_schema_obj.context_str
                )

            # also lookup vector index to return relevant table rows; the
            # collection name / text key must match what indexing used
            vector_store = WeaviateVectorStore(
                weaviate_client=client,
                index_name=table_schema_obj.table_name,
                text_key="content",
            )
            loaded_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
            vector_retriever = loaded_index.as_retriever(similarity_top_k=2)
            relevant_nodes = vector_retriever.retrieve(query_str)
            if len(relevant_nodes) > 0:
                table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
                for node in relevant_nodes:
                    table_row_context += str(node.get_content()) + "\n"
                table_info += table_row_context

            context_strs.append(table_info)
    finally:
        # always release the connection, even if retrieval fails
        client.close()
    return "\n\n".join(context_strs)


# Wrap the function in a FnComponent so it can be registered as the
# "table_output_parser" module of the query pipeline below.
table_parser_component = FnComponent(fn=get_table_context_and_rows_str)

In [None]:
from llama_index.core.query_pipeline import (
    QueryPipeline as QP,
    Link,
    InputComponent,
    CustomQueryComponent,
)

# Register every pipeline stage under the name that the add_link/add_chain
# calls refer to. Both LLM stages share the same `llm` object.
qp = QP(
    modules=dict(
        input=InputComponent(),
        # table selection and schema/row context
        table_retriever=obj_retriever,
        table_output_parser=table_parser_component,
        # text-to-SQL generation, parsing and execution
        text2sql_prompt=text2sql_prompt,
        text2sql_llm=llm,
        sql_output_parser=sql_parser_component,
        sql_retriever=sql_retriever,
        # final answer synthesis
        response_synthesis_prompt=response_synthesis_prompt,
        response_synthesis_llm=llm,
    ),
    verbose=True,
)

In [None]:
# The pipeline input feeds the table retriever.
qp.add_link("input", "table_retriever")
# table_output_parser (get_table_context_and_rows_str) takes two inputs:
# the pipeline input as query_str and the retriever output as table_schema_objs.
qp.add_link("input", "table_output_parser", dest_key="query_str")
qp.add_link(
    "table_retriever", "table_output_parser", dest_key="table_schema_objs"
)
# The text-to-SQL prompt is filled with the query and the table context string.
qp.add_link("input", "text2sql_prompt", dest_key="query_str")
qp.add_link("table_output_parser", "text2sql_prompt", dest_key="schema")
# Sequential section: prompt -> LLM -> SQL parser -> SQL retriever.
qp.add_chain(
    ["text2sql_prompt", "text2sql_llm", "sql_output_parser", "sql_retriever"]
)
# The response-synthesis prompt receives the parsed SQL (sql_query), the
# sql_retriever output (context_str) and the pipeline input (query_str).
qp.add_link(
    "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
)
qp.add_link(
    "sql_retriever", "response_synthesis_prompt", dest_key="context_str"
)
qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
# The filled synthesis prompt is passed to the response-synthesis LLM.
qp.add_link("response_synthesis_prompt", "response_synthesis_llm")