In [None]:
import setup

# Bootstrap Django before the `rag` package is imported in the next cell.
# NOTE(review): presumably `rag` (db/engines/updaters) depends on Django
# settings or models being configured at import time — confirm before reordering.
setup.init_django()

In [None]:
from rag import (
    db as rag_db, 
    engines as rag_engines,
    settings as rag_settings, 
    updaters as rag_updaters,
)

In [None]:
from typing import Optional, Union
from sqlalchemy import create_engine, text

In [None]:
# Initialise the RAG stack. Keep this order: settings are loaded first, then
# the vector DB is initialised, then documents are (re)indexed into it.
# NOTE(review): ordering constraints are inferred from the call names — confirm in rag.*.
rag_settings.init()
rag_db.init_vector_db()
# use_saved_embeddings=True presumably reuses previously computed embeddings
# instead of re-embedding every document — TODO confirm in rag.updaters.
rag_updaters.update_llama_index_documents(use_saved_embeddings=True)

In [None]:
# NOTE(review): `vector_index` is not referenced by any later cell in this
# notebook; kept in case get_semantic_query_index() has required side effects —
# confirm before removing.
vector_index = rag_engines.get_semantic_query_index()
# Wrapped below by `vector_tool` for semantic questions about blog posts.
semantic_query_retriever = rag_engines.get_semantic_query_retriever_engine()
# Wrapped below by `sql_tool` for text-to-SQL questions (posts / page views).
sql_query_engine = rag_engines.get_sql_query_engine()

In [None]:
# Echo which vector database and table this session is configured to use.
print(f"{rag_settings.VECTOR_DB_NAME} {rag_settings.VECTOR_DB_TABLE_NAME}")

In [None]:
from llama_index.core.tools import QueryEngineTool

# Tool for semantic (embedding-based) questions about blog-post content.
# The description text is presumably what the router's selector reads when
# choosing between this tool and the SQL tool, so keep it specific.
vector_tool = QueryEngineTool.from_defaults(
    query_engine=semantic_query_retriever,
    # Plain string: the former f-string had no placeholders.
    description="Useful for answering semantic questions about different blog posts",
)

In [None]:
# Tool for structured questions answered by generating SQL (e.g. rankings by
# page views). As with `vector_tool`, the description presumably steers the
# router's tool selection, so it must read clearly.
sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query over"
        " a table containing blog posts and the page views for each blog post"
    ),
)

In [None]:
from typing import Any, Optional, Union


from llama_index.core.query_engine import SQLAutoVectorQueryEngine
from llama_index.core.query_engine.sql_vector_query_engine import *

class MySQLAutoVectorQueryEngine(SQLAutoVectorQueryEngine):
    """``SQLAutoVectorQueryEngine`` variant that accepts any retriever-backed vector tool.

    The upstream ``SQLAutoVectorQueryEngine`` constructor additionally requires
    ``vector_query_tool.query_engine.retriever`` to be a
    ``VectorIndexAutoRetriever``. This subclass keeps the remaining type checks
    but omits that one, so a ``RetrieverQueryEngine`` built on an ordinary
    semantic retriever can serve as the vector side of the SQL/vector engine.

    The "My" prefix marks this as a local customisation; nothing in the class
    is specific to the MySQL database.
    """

    def __init__(
        self,
        sql_query_tool: QueryEngineTool,
        vector_query_tool: QueryEngineTool,
        selector: Optional[Union[LLMSingleSelector, PydanticSingleSelector]] = None,
        llm: Optional[LLM] = None,
        service_context: Optional[ServiceContext] = None,
        sql_vector_synthesis_prompt: Optional[BasePromptTemplate] = None,
        sql_augment_query_transform: Optional[SQLAugmentQueryTransform] = None,
        use_sql_vector_synthesis: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = True,
    ) -> None:
        """Validate both tools and initialise the underlying SQL-join engine.

        Args:
            sql_query_tool: Tool whose ``query_engine`` must be a
                ``BaseSQLTableQueryEngine`` or ``NLSQLTableQueryEngine``.
            vector_query_tool: Tool whose ``query_engine`` must be a
                ``RetrieverQueryEngine``; its retriever may be of any type.
            selector: Forwarded unchanged to ``SQLJoinQueryEngine.__init__``.
            llm: Forwarded unchanged.
            service_context: Forwarded unchanged.
            sql_vector_synthesis_prompt: Forwarded as
                ``sql_join_synthesis_prompt``; ``None`` is replaced with
                ``DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT``.
            sql_augment_query_transform: Forwarded unchanged.
            use_sql_vector_synthesis: Forwarded as ``use_sql_join_synthesis``.
            callback_manager: Forwarded unchanged.
            verbose: Forwarded unchanged.

        Raises:
            ValueError: If either tool wraps a query engine of the wrong type.
        """
        sql_engine = sql_query_tool.query_engine
        if not isinstance(sql_engine, (BaseSQLTableQueryEngine, NLSQLTableQueryEngine)):
            raise ValueError(
                "sql_query_tool.query_engine must be an instance of "
                "BaseSQLTableQueryEngine or NLSQLTableQueryEngine, "
                f"got {type(sql_engine).__name__}"
            )

        vector_engine = vector_query_tool.query_engine
        if not isinstance(vector_engine, RetrieverQueryEngine):
            raise ValueError(
                "vector_query_tool.query_engine must be an instance of "
                f"RetrieverQueryEngine, got {type(vector_engine).__name__}"
            )
        # Deliberately no check that vector_engine.retriever is a
        # VectorIndexAutoRetriever: relaxing that requirement is the whole
        # reason this subclass exists.

        sql_vector_synthesis_prompt = (
            sql_vector_synthesis_prompt or DEFAULT_SQL_VECTOR_SYNTHESIS_PROMPT
        )

        # Call the grandparent constructor directly. super().__init__ would
        # resolve to SQLAutoVectorQueryEngine.__init__, which (upstream) applies
        # the auto-retriever check skipped above.
        SQLJoinQueryEngine.__init__(
            self,
            sql_query_tool,
            vector_query_tool,
            selector=selector,
            llm=llm,
            service_context=service_context,
            sql_join_synthesis_prompt=sql_vector_synthesis_prompt,
            sql_augment_query_transform=sql_augment_query_transform,
            use_sql_join_synthesis=use_sql_vector_synthesis,
            callback_manager=callback_manager,
            verbose=verbose,
        )

In [None]:
# from llama_index.core.query_engine import SQLAutoVectorQueryEngine

# Combine the SQL tool and the semantic vector tool into one query engine,
# using the relaxed subclass defined above instead of the stock class.
query_engine = MySQLAutoVectorQueryEngine(
    sql_query_tool=sql_tool,
    vector_query_tool=vector_tool,
)

In [None]:
# Content-oriented question: expected to be answerable from post text.
response = query_engine.query("What kind of org is discussed?")

In [None]:
# Show the synthesized answer text of the previous query as the cell output.
response.response

In [None]:
# Mixed question: the ranking needs page-view data (SQL side), while the
# keyword summary needs post content (vector side).
response = query_engine.query(
    "What are the top 5 most viewed blog posts? What keywords do their content have?"
)

In [None]:
from IPython.display import Markdown, display

# Render the answer as Markdown so any lists or emphasis in it display properly.
display(Markdown(data=response.response))

In [None]:
# Ranking question over page views. The answer is rendered as Markdown in the
# next cell, so it is not also printed here (consistent with the other query cells).
response = query_engine.query(
    "What are the top 5 least viewed blog posts from today?"
)

In [None]:
# Render the least-viewed answer as Markdown.
display(Markdown(data=response.response))