1. **Query-Time Table Retrieval**: Dynamically retrieve relevant tables in the text-to-SQL prompt.
2. **Query-Time Sample Row retrieval**: Embed/Index each row, and dynamically retrieve example rows for each table in the text-to-SQL prompt.

In [1]:
%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql\\notebooks'

In [2]:
import os

os.chdir("../")

%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql'

In [3]:
from utils.helpers.other_imports import (
    io,
    time,
    re,
    requests,
    zipfile,
    shutil,
    gc,
    traceback,
    json,
    pyjson,
    pd,
    Path,
    List,
    Dict,
    BaseModel,
    Field,
    px,
    chromadb,
)

from utils.helpers.sql_alchemy_imports import (
    create_engine,
    text,
    inspect,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)

from utils.helpers.llama_index_imports import (
    Settings, 
    SQLDatabase, 
    VectorStoreIndex, 
    ChromaVectorStore,
    load_index_from_storage,
    set_global_handler,
    LLMTextCompletionProgram,
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
    SQLRetriever,
    DEFAULT_TEXT_TO_SQL_PROMPT,
    PromptTemplate,
    FunctionTool,
    ChatResponse,
    TextNode,
    StorageContext,
    Workflow,
    step,
    StartEvent,
    StopEvent,
    draw_all_possible_flows,
    draw_most_recent_execution,
)

from utils.config import CONFIG
from utils.logger import setup_logger


# configurations
LOG_PATH = Path(CONFIG["LOG_PATH"])

CHINOOK_DBEAVER_DB_PATH = Path(CONFIG["CHINOOK_DBEAVER_DB_PATH"])
CHINOOK_TABLE_INDEX_DIR = Path(CONFIG["CHINOOK_TABLE_INDEX_DIR"])
SQLITE_DB_DIR = Path(CONFIG["SQLITE_DB_DIR"])
CHROMA_DB_DIR = Path(CONFIG["CHROMA_DB_DIR"])

WORKFLOW_VISUALIZATION_DIR = Path(CONFIG["WORKFLOW_VISUALIZATION_DIR"])

QUERY_1 = CONFIG["QUERY_1"]
QUERY_1_INITIAL = CONFIG["QUERY_1_INITIAL"]
QUERY_2 = CONFIG["QUERY_2"]
QUERY_2_INITIAL = CONFIG["QUERY_2_INITIAL"]

TOP_K = CONFIG["TOP_K"]
TOP_N = CONFIG["TOP_N"]
MAX_RETRIES = CONFIG["MAX_RETRIES"]


# setup logging
LOG_DIR = os.path.join(os.getcwd(), LOG_PATH)
os.makedirs(LOG_DIR, exist_ok=True)  # Create the logs directory if it doesn't exist

# comment out line 15 in utils/logger.py -> only for notebooks
LOG_FILE = os.path.join(LOG_DIR, "db_connect_notebook.log")
logger = setup_logger("db_connect_notebook_logger", LOG_FILE)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from utils.llm.get_prompt_temp import TABLE_INFO_PROMPT
from utils.llm.get_llm_func import get_llm_func


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    prompt_template_str=TABLE_INFO_PROMPT,
    llm=get_llm_func(),
)


def extract_first_json_block(text: str):
    logger.info("Extracting the first valid JSON object from text, ignoring extra trailing text.")
    
    match = re.search(r"\{.*\}", text, re.S)
    if not match:
        logger.error(f"No JSON object found in text: {text}")
        raise ValueError("No JSON object found in output")
    
    try:
        logger.info(f"Extracted JSON: {match.group()}")
        return pyjson.loads(match.group())
    except Exception as e:
        logger.error(f"Failed to parse JSON: {e}\nRaw text: {text}")
        raise ValueError(f"Failed to parse JSON: {e}\nRaw text: {text}")


os.makedirs(SQLITE_DB_DIR, exist_ok=True)
SUMMARY_DB_PATH = os.path.join(SQLITE_DB_DIR, "table_summaries.db")

logger.info(f"Creating SQLite DB Engine for the new summaries database: {SUMMARY_DB_PATH}")
summary_engine = create_engine(f"sqlite:///{SUMMARY_DB_PATH}")

logger.info(f" - Ensuring the table exists (id, table_name, table_summary, created_at)")
with summary_engine.begin() as conn:
    conn.execute(text("""
    CREATE TABLE IF NOT EXISTS table_summaries (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        table_name TEXT NOT NULL,
        table_summary TEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
    """))

table_infos = []

logger.info(f"Creating SQLite DB Engine for the existing Chinook database at {CHINOOK_DBEAVER_DB_PATH}")
engine = create_engine(f"sqlite:///{CHINOOK_DBEAVER_DB_PATH}")
inspector = inspect(engine)

logger.info("Generating table summaries...")
with engine.connect() as conn:
    existing_tables = set()
    
    logger.info("Fetching existing summaries from the summaries database...")
    with summary_engine.connect() as summary_conn:
        rows = summary_conn.execute(text("SELECT table_name FROM table_summaries")).fetchall()
        existing_tables = {row[0] for row in rows}
        logger.info(f"Found {len(existing_tables)} existing summaries in DB")
    
    for idx, table in enumerate(inspector.get_table_names()):
        if table in existing_tables:
            logger.info(f" - Skipping table '{table}' — summary already exists.")
            continue
        
        logger.info(f" - Processing new table: {table}")
        df = pd.read_sql(f"SELECT * FROM {table} LIMIT 10;", conn)
        df_str = df.to_csv(index=False)

        table_info = None
        for attempt in range(MAX_RETRIES):
            try:
                raw_output = program(
                    table_str=df_str,
                    exclude_table_name_list=str(list(inspector.get_table_names())),
                )

                logger.info(f"Normalize LLM output")
                if isinstance(raw_output, str):
                    parsed_dict = extract_first_json_block(raw_output)
                elif isinstance(raw_output, dict):
                    parsed_dict = raw_output
                elif isinstance(raw_output, TableInfo):
                    parsed_dict = raw_output.model_dump()
                else:
                    logger.error(f"Unexpected return type: {type(raw_output)}")
                    raise TypeError(f"Unexpected return type: {type(raw_output)}")

                table_info = TableInfo(
                    table_name=table,  # use actual SQLAlchemy inspector name
                    table_summary=parsed_dict["table_summary"],
                )

                logger.info(f"Processed table: {table_info.table_name}")
                break  # success → next table

            except Exception as e:
                logger.error(f"Error with attempt {attempt+1} for {table}: {e}")
                time.sleep(2)

        if table_info:
            table_infos.append(table_info)
            
            try:
                logger.info(f"Saving table summary for {table_info.table_name} immediately to summaries DB")
                with summary_engine.begin() as conn2:
                    conn2.execute(
                        text("INSERT INTO table_summaries (table_name, table_summary) VALUES (:name, :summary)"),
                        {"name": table_info.table_name, "summary": table_info.table_summary},
                    )
            except Exception as e:
                logger.error(f"Failed to save table summary for {table_info.table_name}: {e}")
                continue

logger.info("\n FINAL TABLE SUMMARIES")
for t in table_infos:
    logger.info(f"- {t.table_name}: {t.table_summary}")

logger.debug("JSON dump for testing purposes only")
json_path = os.path.join(SQLITE_DB_DIR, "table_summaries.json")
with open(json_path, "w", encoding="utf-8") as f:
    json.dump([t.model_dump() for t in table_infos], f, indent=2, ensure_ascii=False)


logger.info(f"\nSaved {len(table_infos)} summaries to:")
logger.info(f" - SQLite DB: {SUMMARY_DB_PATH}")
logger.info(f" - JSON backup: {json_path}")

2025-08-20 17:10:06,376 [INFO] Creating SQLite DB Engine for the new summaries database: db\Chinook\sqlite\table_summaries.db
2025-08-20 17:10:06,379 [INFO]  - Ensuring the table exists (id, table_name, table_summary, created_at)
2025-08-20 17:10:06,392 [INFO] Creating SQLite DB Engine for the existing Chinook database at C:\Users\Hp\AppData\Roaming\DBeaverData\workspace6\.metadata\sample-database-sqlite-1\Chinook.db
2025-08-20 17:10:06,397 [INFO] Generating table summaries...
2025-08-20 17:10:06,399 [INFO] Fetching existing summaries from the summaries database...
2025-08-20 17:10:06,402 [INFO] Found 0 existing summaries in DB
2025-08-20 17:10:06,404 [INFO]  - Processing new table: Album
2025-08-20 17:10:10,876 [ERROR] Error with attempt 1 for Album: Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download
2025-08-20 17:10:43,536 [INFO] Normalize LLM output
2025-08-20 17:10:43,539 [INFO] Processed table: Album
2025-08-20 

In [5]:
def get_table_schema(table_name: str):
    """Fetch column names and types for an existing table."""
    columns = inspector.get_columns(table_name)
    schema = {col["name"]: str(col["type"]) for col in columns}
    return schema


logger.info("Table Schemas")
for table_name in inspector.get_table_names():
    schema = get_table_schema(table_name)
    logger.info(f"\nTable: {table_name}")
    for col, dtype in schema.items():
        logger.info(f"  {col}: {dtype}")

2025-08-20 17:15:57,472 [INFO] Table Schemas
2025-08-20 17:15:57,478 [INFO] 
Table: Album
2025-08-20 17:15:57,479 [INFO]   AlbumId: INTEGER
2025-08-20 17:15:57,480 [INFO]   Title: NVARCHAR(160)
2025-08-20 17:15:57,481 [INFO]   ArtistId: INTEGER
2025-08-20 17:15:57,483 [INFO] 
Table: Artist
2025-08-20 17:15:57,484 [INFO]   ArtistId: INTEGER
2025-08-20 17:15:57,484 [INFO]   Name: NVARCHAR(120)
2025-08-20 17:15:57,486 [INFO] 
Table: Customer
2025-08-20 17:15:57,487 [INFO]   CustomerId: INTEGER
2025-08-20 17:15:57,487 [INFO]   FirstName: NVARCHAR(40)
2025-08-20 17:15:57,488 [INFO]   LastName: NVARCHAR(20)
2025-08-20 17:15:57,489 [INFO]   Company: NVARCHAR(80)
2025-08-20 17:15:57,491 [INFO]   Address: NVARCHAR(70)
2025-08-20 17:15:57,493 [INFO]   City: NVARCHAR(40)
2025-08-20 17:15:57,495 [INFO]   State: NVARCHAR(40)
2025-08-20 17:15:57,496 [INFO]   Country: NVARCHAR(40)
2025-08-20 17:15:57,497 [INFO]   PostalCode: NVARCHAR(10)
2025-08-20 17:15:57,497 [INFO]   Phone: NVARCHAR(24)
2025-08-20

In [6]:
px.launch_app()
set_global_handler("arize_phoenix")

# logger.info("🌍 To view the Phoenix app in your browser, visit http://localhost:6006/")
# logger.info("📖 For more information on how to use Phoenix, check out https://arize.com/docs/phoenix")

logger.info("Phoenix launched and global handler set.")

  next(self.gen)
  next(self.gen)


🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📖 For more information on how to use Phoenix, check out https://arize.com/docs/phoenix
2025-08-20 17:16:12,522 [INFO] Phoenix launched and global handler set.


Unknown span: U3Bhbjox

GraphQL request:4:3
3 | ) {
4 |   span: node(id: $id) {
  |   ^
5 |     __typename
Traceback (most recent call last):
  File "c:\Users\Hp\Documents\GitHub\rag_text-2-sql\.venv\Lib\site-packages\graphql\execution\execute.py", line 530, in await_result
    return_type, field_nodes, info, path, await result
                                          ^^^^^^^^^^^^
  File "c:\Users\Hp\Documents\GitHub\rag_text-2-sql\.venv\Lib\site-packages\strawberry\schema\schema_converter.py", line 788, in _async_resolver
    return await await_maybe(
           ^^^^^^^^^^^^^^^^^^
    ...<5 lines>...
    )
    ^
  File "c:\Users\Hp\Documents\GitHub\rag_text-2-sql\.venv\Lib\site-packages\strawberry\utils\await_maybe.py", line 13, in await_maybe
    return await value
           ^^^^^^^^^^^
  File "c:\Users\Hp\Documents\GitHub\rag_text-2-sql\.venv\Lib\site-packages\phoenix\server\api\queries.py", line 902, in node
    raise NotFound(f"Unknown span: {id}")
phoenix.server.api.exceptions.

1. Object index, retriever, SQLDatabase

In [7]:
from utils.llm.get_llm_func import get_embedding_func
from utils.llm.get_prompt_temp import RESPONSE_SYNTHESIS_PROMPT


logger.info("Wrapping engine into LlamaIndex SQLDatabase")
sql_database = SQLDatabase(engine)

logger.info("Creating table node mapping, i.e. mapping from SQL tables -> nodes")
table_node_mapping = SQLTableNodeMapping(sql_database)

logger.info("Loading all existing summaries from SQLite DB")
with summary_engine.connect() as conn:
    rows = conn.execute(text("SELECT table_name, table_summary FROM table_summaries")).fetchall()

logger.info("Filtering out only valid tables from loaded summaries that exist in the db")
table_schema_objs = []

with engine.connect() as conn:
    inspector = inspect(conn)
    existing_tables = inspector.get_table_names()

# for t in table_infos:
#     if t.table_name in existing_tables and t.table_summary:
#         table_schema_objs.append(
#             SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
#         )
#         logger.info(f"Adding table: {t.table_name} with summary: {t.table_summary}")
#     else:
#         logger.warning(f"Skipping missing/unextracted table: {t.table_name}")

for row in rows:
    if row.table_name in existing_tables and row.table_summary:
        table_schema_objs.append(
            SQLTableSchema(table_name=row.table_name, context_str=row.table_summary)
        )
        logger.info(f"Adding table: {row.table_name} with summary: {row.table_summary}")
    else:
        logger.warning(f"Skipping missing/unextracted table: {row.table_name}")

logger.info("Building object index for table retrieval")
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    embed_model=get_embedding_func(),
)
obj_retriever = obj_index.as_retriever(similarity_top_k=TOP_K)

logger.info("Creating SQL retriever for query execution")
sql_retriever = SQLRetriever(sql_database)

2025-08-20 17:17:04,989 [INFO] Wrapping engine into LlamaIndex SQLDatabase
2025-08-20 17:17:05,015 [INFO] Creating table node mapping, i.e. mapping from SQL tables -> nodes
2025-08-20 17:17:05,016 [INFO] Loading all existing summaries from SQLite DB
2025-08-20 17:17:05,018 [INFO] Filtering out only valid tables from loaded summaries that exist in the db
2025-08-20 17:17:05,020 [INFO] Adding table: Album with summary: Summary of album data
2025-08-20 17:17:05,021 [INFO] Adding table: Artist with summary: Summary of artist data
2025-08-20 17:17:05,021 [INFO] Adding table: Customer with summary: Summary of customer information
2025-08-20 17:17:05,022 [INFO] Adding table: Employee with summary: Employee information table
2025-08-20 17:17:05,023 [INFO] Adding table: Genre with summary: Summary of music genres
2025-08-20 17:17:05,023 [INFO] Adding table: Invoice with summary: Invoice Information
2025-08-20 17:17:05,024 [INFO] Adding table: InvoiceLine with summary: Summary of invoice data
20

In [8]:
# Table Context String
def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string (schema + summary)."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        try:
            # pull schema directly from DB
            table_info = sql_database.get_single_table_info(
                table_schema_obj.table_name
            )
            if table_schema_obj.context_str:
                table_opt_context = " The table description is: "
                table_opt_context += table_schema_obj.context_str
                table_info += table_opt_context
            context_strs.append(table_info)
        except Exception as e:
            logger.error(f"Skipping table {table_schema_obj.table_name}: {e}")
    return "\n\n".join(context_strs)


table_parser_component = get_table_context_str(table_schema_objs)
logger.info(f"Table Context: {table_parser_component}")

2025-08-20 17:18:08,379 [INFO] Table Context: Table 'Album' has columns: AlbumId (INTEGER), Title (NVARCHAR(160)), ArtistId (INTEGER),  and foreign keys: ['ArtistId'] -> Artist.['ArtistId']. The table description is: Summary of album data

Table 'Artist' has columns: ArtistId (INTEGER), Name (NVARCHAR(120)), . The table description is: Summary of artist data

Table 'Customer' has columns: CustomerId (INTEGER), FirstName (NVARCHAR(40)), LastName (NVARCHAR(20)), Company (NVARCHAR(80)), Address (NVARCHAR(70)), City (NVARCHAR(40)), State (NVARCHAR(40)), Country (NVARCHAR(40)), PostalCode (NVARCHAR(10)), Phone (NVARCHAR(24)), Fax (NVARCHAR(24)), Email (NVARCHAR(60)), SupportRepId (INTEGER),  and foreign keys: ['SupportRepId'] -> Employee.['EmployeeId']. The table description is: Summary of customer information

Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(20)), FirstName (NVARCHAR(20)), Title (NVARCHAR(30)), ReportsTo (INTEGER), BirthDate (DATETIME), HireDate (DATE

In [9]:
# SQL Output Parser
def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response into a clean SQL string."""
    response = response.message.content

    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]

    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]

    return response.strip().strip("```").strip()


sql_parser_component = FunctionTool.from_defaults(fn=parse_response_to_sql)


# Prompts
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
logger.info(f"\n Text-to-SQL Prompt: {text2sql_prompt.template}")


response_synthesis_prompt = PromptTemplate(RESPONSE_SYNTHESIS_PROMPT)

2025-08-20 17:18:23,905 [INFO] 
 Text-to-SQL Prompt: Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 


### Index Each Table

We embed/index the rows of each table, resulting in one index per table.

In [None]:
def index_all_tables_with_chroma(sql_database, chroma_db_dir: str = CHROMA_DB_DIR) -> Dict[str, VectorStoreIndex]:
    """Index all tables in the SQL database using ChromaDB as the backend."""
    os.makedirs(chroma_db_dir, exist_ok=True)

    vector_index_dict = {}
    engine = sql_database.engine

    logger.info(f" [00] Creating persistent Chroma client at: {chroma_db_dir}")
    chroma_client = chromadb.PersistentClient(path=chroma_db_dir)

    for table_name in sql_database.get_usable_table_names():
        logger.info(f" [01] Indexing rows in table: {table_name}")

        # Each table = separate Chroma collection
        collection = chroma_client.get_or_create_collection(name=f"table_{table_name}")
        vector_store = ChromaVectorStore(chroma_collection=collection)
        persist_dir = os.path.join(chroma_db_dir, f"table_{table_name}")

        logger.info(f" [02] Fetching all rows from table: {table_name}")
        
        if collection.count() == 0:
            logger.info(f"  [02.1.1] No existing index found → building new index for: {table_name}")
            
            with engine.connect() as conn:
                result = conn.execute(text(f'SELECT * FROM "{table_name}"'))
                
                # row_tuples = [tuple(row) for row in result.fetchall()]
                
                logger.info(f" - Converting rows to structured strings with col=value format")
                col_names = result.keys()
                row_texts = [
                    " | ".join([f"{col}={val}" for col, val in zip(col_names, row)])
                    for row in result.fetchall()
                ]

            logger.info(f"  [02.1.2] Converting rows to text nodes for table: {table_name}")
            # nodes = [TextNode(text=str(row)) for row in row_tuples]
            nodes = [TextNode(text=row_text) for row_text in row_texts]

            logger.info(f"  [02.1.3] Building vector index for table: {table_name}")
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            index = VectorStoreIndex(nodes, storage_context=storage_context)

            logger.info(f"  [02.1.4] Persisting index to: {persist_dir}")
            storage_context.persist(persist_dir=persist_dir)

        else:
            logger.info(f"  [02.2] Reloading existing Chroma index for table: {table_name}")
            storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
            index = load_index_from_storage(storage_context)
            
        vector_index_dict[table_name] = index

    return vector_index_dict

# Build vector indexes for all tables using ChromaDB
vector_index_dict = index_all_tables_with_chroma(sql_database)

2025-08-20 17:18:33,190 [INFO]  [00] Creating persistent Chroma client at: db\Chinook\chromadb
2025-08-20 17:18:33,958 [INFO]  [01] Indexing rows in table: Album
2025-08-20 17:18:34,007 [INFO]  [02] Fetching all rows from table: Album
2025-08-20 17:18:34,019 [INFO]   [02.1.1] No existing index found → building new index for: Album
2025-08-20 17:18:34,020 [INFO]  - Converting rows to structured strings with col=value format
2025-08-20 17:18:34,023 [INFO]   [02.1.2] Converting rows to text nodes for table: Album
2025-08-20 17:18:34,026 [INFO]   [02.1.3] Building vector index for table: Album
2025-08-20 17:18:36,481 [INFO]   [02.1.4] Persisting index to: db\Chinook\chromadb\table_Album
2025-08-20 17:18:36,489 [INFO]  [01] Indexing rows in table: Artist
2025-08-20 17:18:36,522 [INFO]  [02] Fetching all rows from table: Artist
2025-08-20 17:18:36,525 [INFO]   [02.1.1] No existing index found → building new index for: Artist
2025-08-20 17:18:36,527 [INFO]  - Converting rows to structured str

In [11]:
def get_table_context_and_rows_str(query_str: str, table_schema_objs: List[TableInfo]):
    """Get table context string for your TableInfo objects."""
    context_strs = []

    for table_info_obj in table_schema_objs:
        logger.info("[01] Getting schema for table (use .table_name instead of .name)")
        table_info = sql_database.get_single_table_info(table_info_obj.table_name)

        logger.info("[02] Retrieving example rows for table")
        vector_retriever = vector_index_dict[table_info_obj.table_name].as_retriever(similarity_top_k=TOP_N)
        relevant_nodes = vector_retriever.retrieve(query_str) # this will return the TextNodes we stored as vector indexes 
        logger.info(f"Retrieved {len(relevant_nodes)} relevant nodes for table: {table_info_obj.table_name}")

        if len(relevant_nodes) > 0:
            table_row_context = ("\nHere are some relevant example rows (column=value):\n")
            for node in relevant_nodes:
                table_row_context += str(node.get_content()) + "\n"
            table_info += table_row_context

        context_strs.append(table_info)
        
        logger.error(f"No vector index found for {table_info_obj.table_name}")
        continue

    return "\n\n".join(context_strs)

table_parser_component = get_table_context_and_rows_str(QUERY_1, table_schema_objs)
logger.info(f"Updated table context with rows:\n{table_parser_component}")

2025-08-20 17:21:03,237 [INFO] [01] Getting schema for table (use .table_name instead of .name)
2025-08-20 17:21:03,240 [INFO] [02] Retrieving example rows for table
2025-08-20 17:21:03,344 [INFO] Retrieved 2 relevant nodes for table: Album
2025-08-20 17:21:03,345 [ERROR] No vector index found for Album
2025-08-20 17:21:03,348 [INFO] [01] Getting schema for table (use .table_name instead of .name)
2025-08-20 17:21:03,351 [INFO] [02] Retrieving example rows for table
2025-08-20 17:21:03,490 [INFO] Retrieved 2 relevant nodes for table: Artist
2025-08-20 17:21:03,491 [ERROR] No vector index found for Artist
2025-08-20 17:21:03,493 [INFO] [01] Getting schema for table (use .table_name instead of .name)
2025-08-20 17:21:03,494 [INFO] [02] Retrieving example rows for table


InternalError: Error executing plan: Internal error: Error creating hnsw segment reader: Nothing found on disk

### Define Workflow

In [None]:
from utils.t2SQL_workflow.custom_events import (
    TableRetrievedEvent,
    SchemaProcessedEvent,
    SQLPromptReadyEvent,
    SQLGeneratedEvent,
    SQLParsedEvent,
    SQLResultsEvent,
    ResponsePromptReadyEvent,
)
from utils.t2SQL_workflow.custom_fallbacks import (
    extract_sql_from_response,
    analyze_sql_error,
    create_t2s_prompt,
)


class Text2SQLWorkflowRowRetrieval(Workflow):
    @step
    async def input_step(self, ev: StartEvent) -> TableRetrievedEvent:
        logger.info(f"[Step 01] Process initial query and retrieve relevant tables")
        query = ev.query

        logger.info(f" - Use object retriever built from your table summaries")
        tables = obj_retriever.retrieve(query)  # candidate schemas
        logger.info(f" - Retrieved {len(tables)} candidate tables for query: {query}")
        
        return TableRetrievedEvent(
            tables=tables, 
            query_str=query
        )

    @step
    async def table_output_parser_step(self, ev: TableRetrievedEvent) -> SchemaProcessedEvent:
        logger.info(f"[Step 02] Parsing schemas and retrieving relevant rows for query: {ev.query_str}")

        logger.info(f" - Enriching context function with vector row retrieval for tables: {ev.tables}")
        schema_str = get_table_context_and_rows_str(ev.query_str, ev.tables)
        
        return SchemaProcessedEvent(
            table_schema=schema_str, 
            query_str=ev.query_str
        )

    @step
    async def text2sql_prompt_step(self, ev: SchemaProcessedEvent | SQLResultsEvent) -> SQLPromptReadyEvent:
        logger.info(f"[Step 03] Creating SQL prompt for query: {ev.query_str}")
        if isinstance(ev, SchemaProcessedEvent):
            table_schema = ev.table_schema
            query_str = ev.query_str
            retry_count = 0
            error_message = ""
        else:
            table_schema = getattr(ev, 'table_schema', '')
            query_str = ev.query_str
            retry_count = getattr(ev, 'retry_count', 0) + 1
            error_message = getattr(ev, 'error_message', '')

        prompt = create_t2s_prompt(table_schema, query_str, retry_count, error_message)
        
        return SQLPromptReadyEvent(
            t2s_prompt=prompt,
            query_str=query_str,
            table_schema=table_schema,
            retry_count=retry_count,
            error_message=error_message
        )

    @step
    async def text2sql_llm_step(self, ev: SQLPromptReadyEvent) -> SQLGeneratedEvent:
        logger.info(f"[Step 04] Running LLM to generate SQL for query: {ev.query_str}")
        sql_response = await Settings.llm.acomplete(ev.t2s_prompt)
        
        return SQLGeneratedEvent(
            sql_query=str(sql_response).strip(),
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_output_parser_step(self, ev: SQLGeneratedEvent) -> SQLParsedEvent:
        logger.info(f"[Step 05] Parsing LLM response to extract clean SQL for query: {ev.query_str}")
        try:
            clean_sql = parse_response_to_sql(ev.sql_query)  # primary parser
        except Exception:
            clean_sql = extract_sql_from_response(ev.sql_query, logger)  # fallback
        
        if not clean_sql:
            clean_sql = extract_sql_from_response(ev.sql_query, logger)

        logger.info(f"Attempt #{ev.retry_count + 1}")
        logger.info(f"LLM Response: {ev.sql_query}")
        logger.info(f"Cleaned SQL: {clean_sql}")

        return SQLParsedEvent(
            sql_query=clean_sql,
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_retriever_step(self, ev: SQLParsedEvent) -> SQLResultsEvent:
        logger.info(f"[Step 06] Executing SQL for query: {ev.query_str}")
        try:
            results = sql_retriever.retrieve(ev.sql_query)
            logger.info(f"[SUCCESS] Executed on attempt #{ev.retry_count + 1}")

            return SQLResultsEvent(
                context_str=str(results),
                sql_query=ev.sql_query,
                query_str=ev.query_str,
                success=True
            )
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Execution failed (Attempt #{ev.retry_count + 1}): {error_msg}")

            if ev.retry_count < MAX_RETRIES:
                retry_event = SQLResultsEvent(
                    context_str="",
                    sql_query=ev.sql_query,
                    query_str=ev.query_str,
                    success=False
                )
                retry_event.retry_count = ev.retry_count + 1
                retry_event.error_message = analyze_sql_error(error_msg, ev.sql_query, ev.table_schema, logger)
                retry_event.table_schema = ev.table_schema
                
                return retry_event
            else:
                return SQLResultsEvent(
                    context_str=(f"Failed after {MAX_RETRIES+1} attempts. Final error: {error_msg}"),
                    sql_query=ev.sql_query,
                    query_str=ev.query_str,
                    success=False
                )

    @step
    async def retry_handler_step(self, ev: SQLResultsEvent) -> SQLPromptReadyEvent:
        logger.info(f"[Step 07] Handling retry for query: {ev.query_str}")
        if ev.success:
            return None
        
        return SQLPromptReadyEvent(
            t2s_prompt="",  # regenerated later
            query_str=ev.query_str,
            table_schema=getattr(ev, 'table_schema', ''),
            retry_count=ev.retry_count,
            error_message=getattr(ev, 'error_message', 'Unknown error')
        )

    @step
    async def response_synthesis_prompt_step(self, ev: SQLResultsEvent) -> ResponsePromptReadyEvent:
        logger.info(f"[Step 08] Preparing synthesis prompt for query: {ev.query_str}")
        if not ev.success:
            return None
        prompt = response_synthesis_prompt.format(
            query_str=ev.query_str,
            context_str=ev.context_str,
            sql_query=ev.sql_query
        )
        
        return ResponsePromptReadyEvent(rs_prompt=prompt)

    @step
    async def response_synthesis_llm_step(self, ev: ResponsePromptReadyEvent) -> StopEvent:
        logger.info(f"[Step 09] Generating final answer for query: {ev.query_str}")
        answer = await Settings.llm.acomplete(ev.rs_prompt)
        
        return StopEvent(result=str(answer))


# Runner
async def run_text2sql_workflow_row(query: str):
    workflow = Text2SQLWorkflowRowRetrieval(timeout=480)
    result = await workflow.run(query=query)
    return result

Visualize

In [None]:
async def visualize_text2sql_workflow_row(sample_query: str, execution_name: str, output_dir: str = WORKFLOW_VISUALIZATION_DIR):
    """
    Function to visualize the Text2SQL workflow in your version:
    - Draws all possible flows
    - Runs your row-retrieval Text2SQL workflow
    - Draws execution path of the actual run
    """
    os.makedirs(output_dir, exist_ok=True)
    
    logger.info("[01] Drawing all possible flows...")
    all_flows_path = os.path.join(output_dir, f"{execution_name}_text2sql_workflow_flow.html")
    draw_all_possible_flows(
        Text2SQLWorkflowRowRetrieval,
        filename=all_flows_path
    )
    logger.info(f"[SUCCESS] All possible flows saved to: {all_flows_path}")

    logger.info("[02] Running workflow and drawing execution path...")
    try:
        logger.info(" - wrapper function instead of manual instantiation")
        result = await run_text2sql_workflow_row(sample_query)

        logger.info(" - Recreating workflow object for execution path drawing")
        workflow = Text2SQLWorkflowRowRetrieval(timeout=240)

        execution_path = os.path.join(output_dir, f"{execution_name}_text2sql_workflow_execution.html")
        draw_most_recent_execution(
            workflow,
            filename=execution_path
        )
        logger.info(f"[SUCCESS] Recent execution path saved to: {execution_path}")
        logger.info(f"Workflow result: {result.result}")  
        logger.debug("this `.result` holds final answer")

    except Exception as e:
        logger.error(f"Error during workflow execution: {e}")
        logger.info("Note: Ensure retrievers + LLM configs are initialized correctly")

### Run Some Queries

We can now ask about relevant entries even if it doesn't exactly match the entry in the database.

In [None]:
result = await run_text2sql_workflow_row(QUERY_1)
print(result)

In [None]:
await visualize_text2sql_workflow_row(QUERY_1, QUERY_1_INITIAL)