1. **Query-Time Table Retrieval**: Dynamically retrieve relevant tables in the text-to-SQL prompt.
2. **Query-Time Sample Row Retrieval**: Embed/index each row, and dynamically retrieve example rows for each table in the text-to-SQL prompt.

In [1]:
%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql\\notebooks'

In [2]:
import os

os.chdir("../")

%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql'

In [3]:
import io
import time
import re
import requests
import zipfile
import shutil
import gc
import traceback
import json
import json as pyjson

from pathlib import Path

from typing import List, Dict
from pydantic import BaseModel, Field

import pandas as pd

# setup Arize Phoenix for logging/observability
import phoenix as px

import chromadb


from sqlalchemy import (
    create_engine,
    text,
    inspect,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)

from llama_index.core import (
    Settings, 
    SQLDatabase, 
    VectorStoreIndex, 
    load_index_from_storage,
    set_global_handler,
)
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.llms.ollama import Ollama
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.prompts import PromptTemplate
from llama_index.core.tools import FunctionTool
from llama_index.core.llms import ChatResponse
# from llama_index.core.callbacks import CallbackManager
from llama_index.core.schema import TextNode
from llama_index.core.storage import StorageContext
from llama_index.core.workflow import (
    Workflow, 
    step, 
    StartEvent, 
    StopEvent,
)
from llama_index.core.workflow.events import Event
from llama_index.utils.workflow import (
    draw_all_possible_flows, 
    draw_most_recent_execution,
)

from utils.config import CONFIG
from utils.logger import setup_logger


# configurations
# Every value below is read from the project-level CONFIG mapping (utils.config).
LOG_PATH = Path(CONFIG["LOG_PATH"])

# Source database (passed to prep_summary_engine as the main DB) and the
# directories that hold the generated summaries and the Chroma collections.
CHINOOK_DBEAVER_DB_PATH = Path(CONFIG["CHINOOK_DBEAVER_DB_PATH"])
CHINOOK_TABLE_SUMMARIES_DB_PATH = Path(CONFIG["CHINOOK_TABLE_SUMMARIES_DB_PATH"])
SQLITE_DB_DIR = Path(CONFIG["SQLITE_DB_DIR"])
CHROMA_DB_DIR = Path(CONFIG["CHROMA_DB_DIR"])

WORKFLOW_VISUALIZATION_DIR = Path(CONFIG["WORKFLOW_VISUALIZATION_DIR"])

# Natural-language questions used when exercising the retrieval cells.
QUERY_TEXT = CONFIG["QUERY_TEXT"]
QUERY_TEXT_INITIAL = CONFIG["QUERY_TEXT_INITIAL"]

# TOP_K: similarity_top_k for the table (object) retriever.
# TOP_N: similarity_top_k for the per-table example-row retrievers.
# MAX_RETRIES: LLM attempts per table in generate_table_summary.
TOP_K = CONFIG["TOP_K"]
TOP_N = CONFIG["TOP_N"]
MAX_RETRIES = CONFIG["MAX_RETRIES"]


# setup logging
# LOG_DIR is resolved against the current working directory, which the
# notebook changed to the repository root in an earlier cell.
LOG_DIR = os.path.join(os.getcwd(), LOG_PATH)
os.makedirs(LOG_DIR, exist_ok=True)  # Create the logs directory if it doesn't exist

# comment out line 15 in utils/logger.py -> only for notebooks
# NOTE(review): the line-15 instruction refers to code outside this notebook — verify it is still accurate.
LOG_FILE = os.path.join(LOG_DIR, "db_connect_notebook.log")
logger = setup_logger("db_connect_notebook_logger", LOG_FILE)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from utils.llm.get_prompt_temp import TABLE_INFO_PROMPT
from utils.llm.get_llm_func import get_llm_func


# Structured record for one table summary. It is the output class of the LLM
# summary program and the record type persisted by the dump_summaries_* helpers.
# NOTE: the docstring and Field descriptions are part of the pydantic schema,
# so they are left untouched here; edit them only as a deliberate prompt change.
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Name of the table; the description asks for underscores and no spaces.
    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    # Short caption describing the table's contents.
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


def text_completion_program(table_info_prompt_temp: str):
    """Build the LLM program that produces a ``TableInfo`` for a table sample.

    Args:
        table_info_prompt_temp: Prompt template string handed to the program.

    Returns:
        The object returned by ``LLMTextCompletionProgram.from_defaults``,
        configured with ``TableInfo`` as output class and the LLM returned by
        ``get_llm_func()``.
    """
    program_kwargs = {
        "output_cls": TableInfo,
        "prompt_template_str": table_info_prompt_temp,
        "llm": get_llm_func(),
    }
    return LLMTextCompletionProgram.from_defaults(**program_kwargs)


def extract_first_json_block(text: str):
    """Return the first JSON object embedded in ``text``.

    The text is scanned for ``{`` and a single object is decoded from that
    position with ``JSONDecoder.raw_decode``. Decoding stops at the object's
    own closing brace, so anything after it (extra prose, a second object,
    stray braces) is ignored. If decoding fails at one ``{`` the next ``{`` is
    tried, which also skips brace-like noise before the real object.

    Note: the parameter name shadows the module-level SQLAlchemy ``text``
    import inside this function only; it is kept for caller compatibility.

    Args:
        text: Raw LLM output that should contain a JSON object.

    Returns:
        The decoded JSON object (a ``dict``).

    Raises:
        ValueError: If ``text`` contains no ``{`` at all, or if no ``{`` starts
            a decodable JSON object.
    """
    logger.info("Extracting the first valid JSON object from text, ignoring extra trailing text.")

    start = text.find("{")
    if start == -1:
        logger.error(f"No JSON object found in text: {text}")
        raise ValueError("No JSON object found in output")

    decoder = pyjson.JSONDecoder()
    last_error = None
    while start != -1:
        try:
            # raw_decode returns (object, end_index) and tolerates trailing data.
            parsed, end = decoder.raw_decode(text, start)
        except ValueError as e:  # JSONDecodeError is a ValueError subclass
            last_error = e
            start = text.find("{", start + 1)
        else:
            logger.info(f"Extracted JSON: {text[start:end]}")
            return parsed

    logger.error(f"Failed to parse JSON: {last_error}\nRaw text: {text}")
    raise ValueError(f"Failed to parse JSON: {last_error}\nRaw text: {text}") from last_error


def prep_summary_engine(sqlite_db_dir: Path, main_db_dir: Path):
    """Create the SQLAlchemy engines for the summaries DB and the source DB.

    Args:
        sqlite_db_dir: Directory (created if missing) that holds
            ``table_summaries.db``.
        main_db_dir: Value interpolated into the ``sqlite:///`` URL of the
            existing Chinook database.

    Returns:
        Tuple ``(summary_engine, summary_db_path, engine, inspector)`` where
        ``inspector`` is the SQLAlchemy inspector of the source engine.
    """
    os.makedirs(sqlite_db_dir, exist_ok=True)
    summary_db_path = os.path.join(sqlite_db_dir, "table_summaries.db")
    summary_url = "sqlite:///" + summary_db_path
    main_url = "sqlite:///" + str(main_db_dir)

    logger.info(f"Creating SQLite DB Engine for the new summaries database: {summary_db_path}")
    summary_engine = create_engine(summary_url)

    logger.info(f"Creating SQLite DB Engine for the existing Chinook database at {main_db_dir}")
    engine = create_engine(main_url)

    return summary_engine, summary_db_path, engine, inspect(engine)


def create_summaries_table(summary_engine):
    """Create the ``table_summaries`` table if it is not there yet.

    Any failure is logged and swallowed; nothing is returned.

    Args:
        summary_engine: SQLAlchemy engine of the summaries database.
    """
    try:
        create_stmt = text("""
                CREATE TABLE IF NOT EXISTS table_summaries (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    table_name TEXT NOT NULL UNIQUE,
                    table_summary TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
        logger.info(" - Ensuring the table exists (id, table_name, table_summary, created_at)")
        # begin() commits when the block exits without an exception.
        with summary_engine.begin() as connection:
            connection.execute(create_stmt)
        logger.info("Table 'table_summaries' ensured/created.")
    except Exception as err:
        logger.error(f"Failed to ensure table_summaries exists: {err}")


def dump_summaries_sqlite(summary_engine, table_infos):
    """Insert table summaries into the ``table_summaries`` table.

    Rows whose ``table_name`` already exists are left untouched
    (``ON CONFLICT ... DO NOTHING``). All inserts share one transaction.
    Failures are logged and swallowed.

    Args:
        summary_engine: SQLAlchemy engine of the summaries database.
        table_infos: Iterable of objects exposing ``table_name`` and
            ``table_summary``.
    """
    try:
        insert_stmt = text("""
                        INSERT INTO table_summaries (table_name, table_summary) 
                        VALUES (:name, :summary)
                        ON CONFLICT(table_name) DO NOTHING
                    """)
        with summary_engine.begin() as connection:
            for info in table_infos:
                params = {"name": info.table_name, "summary": info.table_summary}
                connection.execute(insert_stmt, params)
        logger.info(f"Saved {len(table_infos)} summaries to SQLite")
    except Exception as err:
        logger.error(f"Failed to save summaries to SQLite: {err}")


def dump_summaries_json(sqlite_db_dir: Path, table_infos):
    """Merge table summaries into ``table_summaries.json`` and rewrite it.

    Existing entries are keyed by ``table_name``; an incoming summary with the
    same name replaces the stored entry, new names are appended. Failures are
    logged and swallowed.

    Args:
        sqlite_db_dir: Directory containing (or receiving) the JSON file.
        table_infos: Iterable of objects exposing ``table_name`` and
            ``model_dump()``.

    Returns:
        Path of the JSON file, returned even when the dump failed.
    """
    json_path = os.path.join(sqlite_db_dir, "table_summaries.json")

    try:
        previous = []
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as src:
                previous = json.load(src)

        # Dict insertion order keeps old entries first, new tables appended.
        by_name = {entry["table_name"]: entry for entry in previous}
        by_name.update((info.table_name, info.model_dump()) for info in table_infos)

        with open(json_path, "w", encoding="utf-8") as dst:
            json.dump(list(by_name.values()), dst, indent=2, ensure_ascii=False)

        logger.info(f"JSON dump updated at {json_path} with {len(table_infos)} new/updated entries")
    except Exception as err:
        logger.error(f"Failed to dump JSON summaries: {err}")

    return json_path


def generate_table_summary(program, summary_engine, summary_db_path, engine, inspector, sqlite_db_dir: Path, max_retries: int):
    """Generate and persist an LLM summary for every table that lacks one.

    For each table reported by ``inspector`` that has no row in
    ``table_summaries`` yet, the first 10 rows are rendered as CSV and passed
    to ``program``. The program output (str, dict or ``TableInfo``) is
    normalized and stored as a ``TableInfo`` whose name is the real table name
    from the inspector. Each new summary is persisted right away to SQLite and
    to the JSON backup.

    Args:
        program: Callable invoked as
            ``program(table_str=..., exclude_table_name_list=...)``.
        summary_engine: Engine of the summaries database.
        summary_db_path: Path of the summaries database, returned unchanged.
        engine: Engine of the source database whose tables are summarized.
        inspector: SQLAlchemy inspector of ``engine``.
        sqlite_db_dir: Directory holding ``table_summaries.json``.
        max_retries: Maximum number of LLM attempts per table.

    Returns:
        Tuple ``(table_infos, summary_db_path, json_path)`` where
        ``table_infos`` holds only the summaries created during this call.
    """
    table_infos = []
    # Bound up front: when every table already has a summary (or every attempt
    # fails) no dump happens, and the return below must still work. This is the
    # same path dump_summaries_json computes and returns.
    json_path = os.path.join(sqlite_db_dir, "table_summaries.json")
    create_summaries_table(summary_engine)

    logger.info("Generating table summaries...")

    logger.info("Fetching existing summaries from the summaries database...")
    with summary_engine.connect() as summary_conn:
        rows = summary_conn.execute(text("SELECT table_name FROM table_summaries")).fetchall()
    existing_tables = {row[0] for row in rows}
    logger.info(f"Found {len(existing_tables)} existing summaries in DB")

    # The table list does not change during the run; fetch it once.
    all_tables = list(inspector.get_table_names())
    exclude_table_name_list = str(all_tables)

    with engine.connect() as conn:
        for table in all_tables:
            if table in existing_tables:
                logger.info(f" - Skipping table '{table}' — summary already exists.")
                continue

            logger.info(f" - Processing new table: {table}")
            # Quote the identifier (as the Chroma indexing cell does) so table
            # names with spaces or reserved words do not break the query.
            df = pd.read_sql(f'SELECT * FROM "{table}" LIMIT 10;', conn)
            df_str = df.to_csv(index=False)

            table_info = None
            for attempt in range(max_retries):
                try:
                    raw_output = program(
                        table_str=df_str,
                        exclude_table_name_list=exclude_table_name_list,
                    )

                    logger.info("Normalize LLM output")
                    if isinstance(raw_output, str):
                        parsed_dict = extract_first_json_block(raw_output)
                    elif isinstance(raw_output, dict):
                        parsed_dict = raw_output
                    elif isinstance(raw_output, TableInfo):
                        parsed_dict = raw_output.model_dump()
                    else:
                        logger.error(f"Unexpected return type: {type(raw_output)}")
                        raise TypeError(f"Unexpected return type: {type(raw_output)}")

                    table_info = TableInfo(
                        table_name=table,  # use actual SQLAlchemy inspector name
                        table_summary=parsed_dict["table_summary"],
                    )

                    logger.info(f"Processed table: {table_info.table_name}")
                    break  # success → next table

                except Exception as e:
                    logger.error(f"Error with attempt {attempt+1} for {table}: {e}")
                    # Back off only when another attempt will follow.
                    if attempt + 1 < max_retries:
                        time.sleep(2)

            if table_info is None:
                logger.warning(f"No summary generated for table '{table}' after {max_retries} attempts")
                continue

            table_infos.append(table_info)

            try:
                logger.info(f"Saving table summary for {table_info.table_name} immediately to summaries DB")
                # The cumulative list is passed on purpose: already-stored rows
                # are no-ops in SQLite and identical overwrites in JSON, and a
                # summary whose earlier dump failed gets another chance.
                dump_summaries_sqlite(summary_engine, table_infos)
                json_path = dump_summaries_json(sqlite_db_dir, table_infos)
            except Exception as e:
                logger.error(f"Failed to save table summary for {table_info.table_name}: {e}")

    return table_infos, summary_db_path, json_path

In [None]:
# Build the summary program, open both databases, then summarize every table
# that does not yet have a stored summary.
program = text_completion_program(TABLE_INFO_PROMPT)

summary_engine, summary_db_path, engine, inspector = prep_summary_engine(SQLITE_DB_DIR, CHINOOK_DBEAVER_DB_PATH)

table_infos, summary_db_path, json_path = generate_table_summary(program, summary_engine, summary_db_path, engine, inspector, SQLITE_DB_DIR, MAX_RETRIES)

# table_infos holds only the summaries created in this run; report them and
# the two places they were persisted.
logger.info("\n FINAL TABLE SUMMARIES")
for t in table_infos:
    logger.info(f"- {t.table_name}: {t.table_summary}")

logger.info(f"\nSaved {len(table_infos)} summaries to:")
logger.info(f" - SQLite DB: {summary_db_path}")
logger.info(f" - JSON backup: {json_path}")

2025-08-22 23:56:22,262 [INFO] Creating SQLite DB Engine for the new summaries database: db\Chinook\sqlite\table_summaries.db
2025-08-22 23:56:22,264 [INFO] Creating SQLite DB Engine for the existing Chinook database at C:\Users\Hp\AppData\Roaming\DBeaverData\workspace6\.metadata\sample-database-sqlite-1\Chinook.db
2025-08-22 23:56:22,269 [INFO]  - Ensuring the table exists (id, table_name, table_summary, created_at)
2025-08-22 23:56:22,279 [INFO] Table 'table_summaries' ensured/created.
2025-08-22 23:56:22,280 [INFO] Generating table summaries...
2025-08-22 23:56:22,281 [INFO] Fetching existing summaries from the summaries database...
2025-08-22 23:56:22,283 [INFO] Found 0 existing summaries in DB
2025-08-22 23:56:22,285 [INFO]  - Processing new table: Album
2025-08-22 23:57:01,160 [INFO] Normalize LLM output
2025-08-22 23:57:01,167 [INFO] Processed table: Album
2025-08-22 23:57:01,171 [INFO] Saving table summary for Album immediately to summaries DB
2025-08-22 23:57:01,190 [INFO] Sav

In [6]:
def get_table_schema(table_name: str):
    """Return ``{column_name: column_type_as_str}`` for an existing table.

    Reads the columns through the notebook-level ``inspector``.
    """
    schema = {}
    for column in inspector.get_columns(table_name):
        schema[column["name"]] = str(column["type"])
    return schema


# Log the column names and types of every table in the source database.
logger.info("Table Schemas")
for table_name in inspector.get_table_names():
    schema = get_table_schema(table_name)
    logger.info(f"\nTable: {table_name}")
    for col, dtype in schema.items():
        logger.info(f"  {col}: {dtype}")

2025-08-23 00:01:07,558 [INFO] Table Schemas
2025-08-23 00:01:07,568 [INFO] 
Table: Album
2025-08-23 00:01:07,569 [INFO]   AlbumId: INTEGER
2025-08-23 00:01:07,569 [INFO]   Title: NVARCHAR(160)
2025-08-23 00:01:07,570 [INFO]   ArtistId: INTEGER
2025-08-23 00:01:07,572 [INFO] 
Table: Artist
2025-08-23 00:01:07,572 [INFO]   ArtistId: INTEGER
2025-08-23 00:01:07,573 [INFO]   Name: NVARCHAR(120)
2025-08-23 00:01:07,574 [INFO] 
Table: Customer
2025-08-23 00:01:07,575 [INFO]   CustomerId: INTEGER
2025-08-23 00:01:07,576 [INFO]   FirstName: NVARCHAR(40)
2025-08-23 00:01:07,577 [INFO]   LastName: NVARCHAR(20)
2025-08-23 00:01:07,577 [INFO]   Company: NVARCHAR(80)
2025-08-23 00:01:07,578 [INFO]   Address: NVARCHAR(70)
2025-08-23 00:01:07,579 [INFO]   City: NVARCHAR(40)
2025-08-23 00:01:07,580 [INFO]   State: NVARCHAR(40)
2025-08-23 00:01:07,581 [INFO]   Country: NVARCHAR(40)
2025-08-23 00:01:07,581 [INFO]   PostalCode: NVARCHAR(10)
2025-08-23 00:01:07,582 [INFO]   Phone: NVARCHAR(24)
2025-08-23

In [7]:
# Start the local Arize Phoenix app and register it as LlamaIndex's global
# handler for logging/observability.
px.launch_app()
set_global_handler("arize_phoenix")
logger.info("Phoenix launched and global handler set.")

  next(self.gen)
  next(self.gen)


🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📖 For more information on how to use Phoenix, check out https://arize.com/docs/phoenix
2025-08-23 00:01:19,947 [INFO] Phoenix launched and global handler set.


1. Object index, retriever, SQLDatabase

In [None]:
from utils.llm.get_prompt_temp import RESPONSE_SYNTHESIS_PROMPT
from utils.llm.get_llm_func import get_embedding_func


def wrap_sql_engine(engine):
    """Wrap a SQLAlchemy engine for LlamaIndex.

    Returns:
        Tuple ``(sql_database, table_node_mapping)``: the ``SQLDatabase`` built
        from ``engine`` and the ``SQLTableNodeMapping`` built on top of it.
    """
    logger.info("Wrapping engine into LlamaIndex SQLDatabase")
    database = SQLDatabase(engine)

    logger.info("Creating table node mapping, i.e. mapping from SQL tables -> nodes")
    return database, SQLTableNodeMapping(database)

def load_summaries_from_sqlite(summary_engine):
    """Fetch every ``(table_name, table_summary)`` row from the summaries DB.

    Returns:
        The list of rows returned by ``fetchall()``.
    """
    logger.info("Loading all existing summaries from SQLite DB")
    query = text("SELECT table_name, table_summary FROM table_summaries")
    with summary_engine.connect() as connection:
        return connection.execute(query).fetchall()

def filter_valid_summaries(rows, engine):
    """Turn stored summary rows into ``SQLTableSchema`` objects.

    A row is kept only when its table still exists in the database behind
    ``engine`` and its summary is non-empty; every other row is logged as
    skipped.

    Args:
        rows: Rows exposing ``table_name`` and ``table_summary`` attributes.
        engine: Engine of the main database.

    Returns:
        List of ``SQLTableSchema`` objects, in the order of ``rows``.
    """
    logger.info("Filtering out only valid tables from loaded summaries that exist in the db")

    with engine.connect() as connection:
        existing_tables = inspect(connection).get_table_names()

    valid_schemas = []
    for row in rows:
        if row.table_name not in existing_tables or not row.table_summary:
            logger.warning(f"Skipping missing/unextracted table: {row.table_name}")
            continue
        valid_schemas.append(
            SQLTableSchema(table_name=row.table_name, context_str=row.table_summary)
        )
        logger.info(f"Adding table: {row.table_name} with summary: {row.table_summary}")

    return valid_schemas

def load_summaries_from_json(sqlite_db_dir: Path):
    """Read the summaries stored in ``table_summaries.json``.

    The directory is created when missing. If the file does not exist a
    warning is logged and an empty list is returned.

    Returns:
        The parsed JSON content, expected to be
        ``list[{"table_name": str, "table_summary": str}]``.
    """
    os.makedirs(sqlite_db_dir, exist_ok=True)
    json_file = os.path.join(sqlite_db_dir, "table_summaries.json")

    if os.path.exists(json_file):
        logger.info(f"Loading summaries from JSON at {json_file}")
        with open(json_file, "r", encoding="utf-8") as handle:
            return json.load(handle)

    logger.warning(f"No summary JSON found at {json_file}")
    return []


def filter_valid_summaries_from_json(summaries, engine):
    """Turn JSON summary dicts into ``SQLTableSchema`` objects.

    An entry is kept only when its ``table_name`` still exists in the database
    behind ``engine`` and its ``table_summary`` is non-empty; every other
    entry is logged as skipped.

    Args:
        summaries: Iterable of dicts with ``table_name`` / ``table_summary``.
        engine: Engine of the main database.

    Returns:
        List of ``SQLTableSchema`` objects, in the order of ``summaries``.
    """
    logger.info("Filtering JSON summaries to only valid tables")

    with engine.connect() as connection:
        existing_tables = inspect(connection).get_table_names()

    valid_schemas = []
    for entry in summaries:
        table_name = entry.get("table_name")
        table_summary = entry.get("table_summary")
        if not (table_name in existing_tables and table_summary):
            logger.warning(f"Skipping missing/unextracted table: {table_name}")
            continue
        valid_schemas.append(
            SQLTableSchema(table_name=table_name, context_str=table_summary)
        )
        logger.info(f"Adding table: {table_name} with summary: {table_summary}")

    return valid_schemas

def build_object_index(table_schema_objs, table_node_mapping):
    """Create the ``ObjectIndex`` used to retrieve tables by their summaries.

    The index is built with ``VectorStoreIndex`` and the embedding model
    returned by ``get_embedding_func()``.
    """
    logger.info("Building object index for table retrieval")
    return ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        VectorStoreIndex,
        embed_model=get_embedding_func(),
    )


def create_retrievers(sql_database, obj_index, top_k: int):
    """Build the table retriever and the SQL retriever.

    Args:
        sql_database: LlamaIndex ``SQLDatabase`` to run SQL against.
        obj_index: Object index over the table summaries.
        top_k: ``similarity_top_k`` for the table retriever.

    Returns:
        Tuple ``(obj_retriever, sql_retriever)``.
    """
    logger.info("Creating retrievers for query execution")
    return (
        obj_index.as_retriever(similarity_top_k=top_k),
        SQLRetriever(sql_database),
    )

In [None]:
# Wrap the source engine for LlamaIndex, load the stored summaries, and build
# the table (object) index plus both retrievers.
sql_database, table_node_mapping = wrap_sql_engine(engine)

rows = load_summaries_from_sqlite(summary_engine)
table_schema_objs = filter_valid_summaries(rows, engine)

# Alternative source: load the same summaries from the JSON backup instead.
# rows = load_summaries_from_json(SQLITE_DB_DIR)
# table_schema_objs = filter_valid_summaries_from_json(rows, engine)

obj_index = build_object_index(table_schema_objs, table_node_mapping)

obj_retriever, sql_retriever = create_retrievers(sql_database, obj_index, top_k=TOP_K)

2025-08-23 00:09:17,026 [INFO] Wrapping engine into LlamaIndex SQLDatabase
2025-08-23 00:09:17,058 [INFO] Creating table node mapping, i.e. mapping from SQL tables -> nodes
2025-08-23 00:09:17,059 [INFO] Loading all existing summaries from SQLite DB
2025-08-23 00:09:17,061 [INFO] Filtering out only valid tables from loaded summaries that exist in the db
2025-08-23 00:09:17,063 [INFO] Adding table: Album with summary: Summary of album and artist data
2025-08-23 00:09:17,063 [INFO] Adding table: Artist with summary: Summary of artist information
2025-08-23 00:09:17,064 [INFO] Adding table: Customer with summary: Summary of customer data
2025-08-23 00:09:17,065 [INFO] Adding table: Employee with summary: Summary of employee information
2025-08-23 00:09:17,065 [INFO] Adding table: Genre with summary: Summary of genre data
2025-08-23 00:09:17,067 [INFO] Adding table: Invoice with summary: Invoice Information
2025-08-23 00:09:17,069 [INFO] Adding table: InvoiceLine with summary: Summary of i

In [11]:
# Table Context String
def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Render schema plus summary for each table as one string.

    For every schema object the table info is read from the notebook-level
    ``sql_database``; a non-empty ``context_str`` is appended as the table
    description. Tables that raise are logged and left out.

    Returns:
        The per-table strings joined by a blank line.
    """
    pieces = []
    for schema_obj in table_schema_objs:
        try:
            # pull schema directly from DB
            info = sql_database.get_single_table_info(schema_obj.table_name)
            description = schema_obj.context_str
            if description:
                info = info + " The table description is: " + description
        except Exception as err:
            logger.error(f"Skipping table {schema_obj.table_name}: {err}")
            continue
        pieces.append(info)
    return "\n\n".join(pieces)


# Render the schema + summary context for every summarized table and log it.
table_parser_component = get_table_context_str(table_schema_objs)
logger.info(f"Table Context: {table_parser_component}")

2025-08-23 00:28:59,488 [INFO] Table Context: Table 'Album' has columns: AlbumId (INTEGER), Title (NVARCHAR(160)), ArtistId (INTEGER),  and foreign keys: ['ArtistId'] -> Artist.['ArtistId']. The table description is: Summary of album and artist data

Table 'Artist' has columns: ArtistId (INTEGER), Name (NVARCHAR(120)), . The table description is: Summary of artist information

Table 'Customer' has columns: CustomerId (INTEGER), FirstName (NVARCHAR(40)), LastName (NVARCHAR(20)), Company (NVARCHAR(80)), Address (NVARCHAR(70)), City (NVARCHAR(40)), State (NVARCHAR(40)), Country (NVARCHAR(40)), PostalCode (NVARCHAR(10)), Phone (NVARCHAR(24)), Fax (NVARCHAR(24)), Email (NVARCHAR(60)), SupportRepId (INTEGER),  and foreign keys: ['SupportRepId'] -> Employee.['EmployeeId']. The table description is: Summary of customer data

Table 'Employee' has columns: EmployeeId (INTEGER), LastName (NVARCHAR(20)), FirstName (NVARCHAR(20)), Title (NVARCHAR(30)), ReportsTo (INTEGER), BirthDate (DATETIME), Hir

In [12]:
# SQL Output Parser
def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse an LLM chat response into a clean SQL string.

    Steps:
      1. Keep only what follows the first ``SQLQuery:`` marker, if present.
      2. Drop everything from the first ``SQLResult:`` marker on, if present.
      3. Remove a surrounding Markdown code fence, including an opening fence
         line that carries a SQL language tag (e.g. a line with three
         backticks followed by ``sql``), so the tag does not end up glued to
         the front of the query.

    Args:
        response: Chat response whose ``message.content`` holds the LLM text.

    Returns:
        The SQL text with markers, fences and surrounding whitespace removed.
        An empty string when the response has no content.
    """
    # Content may be None for an empty completion; treat it as empty text.
    content = response.message.content or ""

    marker = "SQLQuery:"
    sql_query_start = content.find(marker)
    if sql_query_start != -1:
        content = content[sql_query_start + len(marker):]

    sql_result_start = content.find("SQLResult:")
    if sql_result_start != -1:
        content = content[:sql_result_start]

    sql = content.strip()

    # Opening fence line: three backticks, an optional SQL language tag, then a
    # newline. Only known SQL tags are accepted so that a query whose first
    # line is a bare keyword (e.g. "SELECT") is never mistaken for a tag.
    opening_fence = re.match(
        r"```[ \t]*(?:sql|sqlite3?|mysql|postgres(?:ql)?)?[ \t]*\r?\n",
        sql,
        re.IGNORECASE,
    )
    if opening_fence:
        sql = sql[opening_fence.end():]

    # str.strip takes a character set, so "`" covers any run of backticks.
    return sql.strip().strip("`").strip()


# Expose the SQL parser as a LlamaIndex FunctionTool.
sql_parser_component = FunctionTool.from_defaults(fn=parse_response_to_sql)


# Prompts
# Pre-fill the {dialect} placeholder of the default text-to-SQL prompt with the
# dialect name of the source engine; {schema} and {query_str} stay open.
text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
logger.info(f"\n Text-to-SQL Prompt: {text2sql_prompt.template}")


# Prompt template built from the project's response-synthesis prompt string.
response_synthesis_prompt = PromptTemplate(RESPONSE_SYNTHESIS_PROMPT)

2025-08-23 00:29:58,766 [INFO] 
 Text-to-SQL Prompt: Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 


### Index Each Table

We embed/index the rows of each table, resulting in one index per table.

In [13]:
def index_all_tables_with_chroma(sql_database, chroma_db_dir: str) -> Dict[str, VectorStoreIndex]:
    """Build (or reopen) one row-level vector index per table, backed by ChromaDB.

    Each usable table gets a Chroma collection named ``table_<table_name>``.
    A non-empty collection is reused as is. For an empty collection every row
    is rendered as ``col=value | col=value ...`` text, wrapped in a
    ``TextNode`` with id ``<table_name>_row_<idx>`` and indexed. Tables
    without rows are skipped and do not appear in the result.

    Args:
        sql_database: SQLDatabase instance
        chroma_db_dir: Directory for ChromaDB persistence (created if missing)

    Returns:
        Dict mapping table names to VectorStoreIndex instances

    Raises:
        Exception: Any error while indexing a table is logged and re-raised.
    """
    os.makedirs(chroma_db_dir, exist_ok=True)

    db_engine = sql_database.engine
    indexes_by_table = {}

    logger.info(f" [00] Creating persistent Chroma client at: {chroma_db_dir}")
    client = chromadb.PersistentClient(path=chroma_db_dir)

    for table_name in sql_database.get_usable_table_names():
        logger.info(f"[01] Processing table: {table_name}")

        try:
            # Chroma persists the collection itself; get_or_create is idempotent.
            collection = client.get_or_create_collection(name=f"table_{table_name}")

            if collection.count() != 0:
                logger.info(f"[03] Reusing existing collection with {collection.count()} items: {table_name}")

                # Reopen the index on top of the already populated collection.
                store = ChromaVectorStore(chroma_collection=collection)
                reuse_context = StorageContext.from_defaults(vector_store=store)
                index = VectorStoreIndex.from_vector_store(
                    vector_store=store,
                    storage_context=reuse_context
                )
            else:
                logger.info(f"[02] Building new index for empty collection: {table_name}")

                # Read the whole table; the identifier is double-quoted.
                with db_engine.connect() as connection:
                    result = connection.execute(text(f'SELECT * FROM "{table_name}"'))
                    columns = list(result.keys())
                    records = result.fetchall()

                if not records:
                    logger.warning(f"[02.1] Table {table_name} is empty, skipping...")
                    continue

                logger.info(f"[02.2] Converting {len(records)} rows to structured text")
                nodes = []
                for position, record in enumerate(records):
                    cells = [f"{col}={val}" for col, val in zip(columns, record)]
                    nodes.append(
                        TextNode(
                            text=" | ".join(cells),
                            id_=f"{table_name}_row_{position}"
                        )
                    )

                logger.info(f"[02.3] Creating vector store for table: {table_name}")
                store = ChromaVectorStore(chroma_collection=collection)

                # Building the index writes the nodes into the Chroma collection.
                logger.info(f"[02.4] Building vector index with {len(nodes)} nodes")
                build_context = StorageContext.from_defaults(vector_store=store)
                index = VectorStoreIndex(nodes, storage_context=build_context)

                logger.info(f"[02.5] Index created successfully for table: {table_name}")

            indexes_by_table[table_name] = index
            logger.info(f"[04] Successfully indexed table: {table_name}")

        except Exception as err:
            logger.error(f"[ERROR] Failed to index table {table_name}: {str(err)}")
            raise

    logger.info(f"[05] Successfully indexed {len(indexes_by_table)} tables")
    return indexes_by_table


# Build vector indexes for all tables using ChromaDB
# (result maps table name -> VectorStoreIndex; collections persist under CHROMA_DB_DIR)
vector_index_dict = index_all_tables_with_chroma(sql_database, CHROMA_DB_DIR)

2025-08-23 00:58:29,690 [INFO]  [00] Creating persistent Chroma client at: db\Chinook\chromadb
2025-08-23 00:58:30,866 [INFO] [01] Processing table: Album
2025-08-23 00:58:30,930 [INFO] [02] Building new index for empty collection: Album
2025-08-23 00:58:30,935 [INFO] [02.2] Converting 347 rows to structured text
2025-08-23 00:58:30,939 [INFO] [02.3] Creating vector store for table: Album
2025-08-23 00:58:30,941 [INFO] [02.4] Building vector index with 347 nodes
2025-08-23 00:58:36,539 [INFO] [02.5] Index created successfully for table: Album
2025-08-23 00:58:36,543 [INFO] [04] Successfully indexed table: Album
2025-08-23 00:58:36,544 [INFO] [01] Processing table: Artist
2025-08-23 00:58:36,592 [INFO] [02] Building new index for empty collection: Artist
2025-08-23 00:58:36,596 [INFO] [02.2] Converting 275 rows to structured text
2025-08-23 00:58:36,600 [INFO] [02.3] Creating vector store for table: Artist
2025-08-23 00:58:36,601 [INFO] [02.4] Building vector index with 275 nodes
2025-0

In [14]:
def get_table_context_and_rows_str(query_str: str, table_schema_objs: List[TableInfo]) -> str:
    """Render each table's schema followed by example rows relevant to a query.

    The schema comes from the notebook-level ``sql_database``; example rows
    come from the table's entry in ``vector_index_dict`` (top ``TOP_N`` hits).
    Tables without a vector index contribute their schema only. If anything
    fails for a table, its schema alone is added when it can still be read.

    Args:
        query_str: Query string for similarity search
        table_schema_objs: Objects exposing a ``table_name`` attribute

    Returns:
        Combined context string for all tables, separated by blank lines
    """
    sections = []

    for schema_obj in table_schema_objs:
        table_name = schema_obj.table_name

        try:
            logger.info(f"[01] Getting schema for table: {table_name}")
            table_info = sql_database.get_single_table_info(table_name)

            if table_name in vector_index_dict:
                logger.info(f"[02] Retrieving example rows for table: {table_name}")
                row_retriever = vector_index_dict[table_name].as_retriever(
                    similarity_top_k=TOP_N
                )

                hits = row_retriever.retrieve(query_str)
                logger.info(f"[03] Retrieved {len(hits)} relevant nodes for table: {table_name}")

                if hits:
                    bullet_lines = "".join(f"- {hit.get_content()}\n" for hit in hits)
                    table_info += "\nHere are some relevant example rows (column=value):\n" + bullet_lines
                else:
                    logger.info(f"[03.1] No relevant rows found for query in table: {table_name}")
            else:
                # No row index for this table: fall through with the schema only.
                logger.warning(f"[02] No vector index found for table: {table_name}")

            sections.append(table_info)

        except Exception as err:
            logger.error(f"[ERROR] Failed to get context for table {table_name}: {str(err)}")
            # Still add basic table info even if retrieval fails
            try:
                sections.append(sql_database.get_single_table_info(table_name))
            except Exception as schema_error:
                logger.error(f"[ERROR] Failed to get schema for table {table_name}: {str(schema_error)}")

    return "\n\n".join(sections)

# Rebuild the table context for QUERY_TEXT, this time with example rows, and log it.
table_parser_component = get_table_context_and_rows_str(QUERY_TEXT, table_schema_objs)
logger.info(f"Updated table context with rows:\n{table_parser_component}")

2025-08-23 01:44:27,736 [INFO] [01] Getting schema for table: Album
2025-08-23 01:44:27,738 [INFO] [02] Retrieving example rows for table: Album
2025-08-23 01:44:27,877 [INFO] [03] Retrieved 2 relevant nodes for table: Album
2025-08-23 01:44:27,879 [INFO] [01] Getting schema for table: Artist
2025-08-23 01:44:27,880 [INFO] [02] Retrieving example rows for table: Artist
2025-08-23 01:44:27,939 [INFO] [03] Retrieved 2 relevant nodes for table: Artist
2025-08-23 01:44:27,941 [INFO] [01] Getting schema for table: Customer
2025-08-23 01:44:27,942 [INFO] [02] Retrieving example rows for table: Customer
2025-08-23 01:44:27,995 [INFO] [03] Retrieved 2 relevant nodes for table: Customer
2025-08-23 01:44:27,996 [INFO] [01] Getting schema for table: Employee
2025-08-23 01:44:27,998 [INFO] [02] Retrieving example rows for table: Employee
2025-08-23 01:44:28,099 [INFO] [03] Retrieved 2 relevant nodes for table: Employee
2025-08-23 01:44:28,100 [INFO] [01] Getting schema for table: Genre
2025-08-23 

In [None]:
def get_table_context_str(table_schema_objs: List[SQLTableSchema]) -> Dict[str, str]:
    """Collect the schema description of every given table.

    Each table's definition is read from ``sql_database``; when the schema
    object carries a non-empty ``context_str`` it is appended as a
    description sentence. A table whose lookup raises is logged and left out
    of the result instead of aborting the whole call.

    Args:
        table_schema_objs: Schema objects naming the tables to describe.

    Returns:
        Mapping of table name to its schema context string.
    """
    contexts_by_table = {}

    for schema_obj in table_schema_objs:
        try:
            # The schema text comes straight from the database.
            description = sql_database.get_single_table_info(schema_obj.table_name)

            summary = schema_obj.context_str
            if summary:
                description = description + f" The table description is: {summary}"

            contexts_by_table[schema_obj.table_name] = description
        except Exception as exc:
            logger.error(f"Skipping table {schema_obj.table_name}: {exc}")

    return contexts_by_table

def get_table_context_and_rows_str(
    query_str: str,
    table_schema_objs: List[SQLTableSchema],
    top_n: int = 2,
) -> str:
    """Get table context string (schema + relevant example rows).

    For every table, the schema context from ``get_table_context_str`` is
    extended with the rows most similar to ``query_str``, retrieved from that
    table's entry in ``vector_index_dict``. Tables without a vector index, or
    whose row retrieval fails, contribute their schema context only.

    Args:
        query_str: Natural-language query used to retrieve example rows.
        table_schema_objs: Schema objects of the candidate tables.
        top_n: Number of example rows to retrieve per table
            (``similarity_top_k``). Defaults to 2 so callers that pass only
            the query and the tables keep working.

    Returns:
        The per-table context strings joined by blank lines.
    """
    base_contexts = get_table_context_str(table_schema_objs)
    context_strs = []

    for table_name, schema_context in base_contexts.items():
        # No index for this table: there are no rows to retrieve, keep the schema.
        if table_name not in vector_index_dict:
            logger.warning(f"No vector index found for table: {table_name}")
            context_strs.append(schema_context)
            continue

        try:
            logger.info(f"[01] Retrieving example rows for table: {table_name}")
            vector_retriever = vector_index_dict[table_name].as_retriever(
                similarity_top_k=top_n
            )

            relevant_nodes = vector_retriever.retrieve(query_str)
            logger.info(f"[02] Retrieved {len(relevant_nodes)} relevant nodes for table: {table_name}")

            if relevant_nodes:
                row_context = "\nHere are some relevant example rows (column=value):\n"
                row_context += "\n".join(f"- {node.get_content()}" for node in relevant_nodes)
                schema_context += "\n" + row_context
            else:
                logger.info(f"[02.1] No relevant rows found for query in table: {table_name}")

        except Exception as e:
            # Best effort: a retrieval failure must not drop the table from the prompt.
            logger.error(f"Failed to enrich context for table {table_name}: {str(e)}")

        # Enriched on success, schema-only when retrieval failed or found nothing.
        context_strs.append(schema_context)

    return "\n\n".join(context_strs)

# Smoke-test the context builder on the sample query and log the result.
# top_n is the number of example rows retrieved per table.
table_parser_component = get_table_context_and_rows_str(QUERY_TEXT, table_schema_objs, top_n=2)
logger.info(f"Updated table context with rows:\n{table_parser_component}")

### Define Workflow

In [16]:
from utils.workflow.custom_events import (
    TableRetrievedEvent,
    SchemaProcessedEvent,
    SQLPromptReadyEvent,
    SQLGeneratedEvent,
    SQLParsedEvent,
    SQLResultsEvent,
    ResponsePromptReadyEvent,
)
from utils.workflow.custom_fallbacks import (
    extract_sql_from_response,
    analyze_sql_error,
    create_t2s_prompt,
)


class Text2SQLWorkflowRowRetrieval(Workflow):
    """Text-to-SQL workflow that adds retrieved example rows to the table context.

    Main path:
        StartEvent -> TableRetrievedEvent -> SchemaProcessedEvent
        -> SQLPromptReadyEvent -> SQLGeneratedEvent -> SQLParsedEvent
        -> SQLResultsEvent

    A ``SQLResultsEvent`` is delivered to three steps. Each one acts on
    exactly one outcome and returns ``None`` for the others, so a single
    result never triggers more than one continuation:
        * success                    -> ``response_synthesis_prompt_step``
        * failure, retries remaining -> ``text2sql_prompt_step`` (new prompt)
        * failure, retries exhausted -> ``retry_handler_step`` (``StopEvent``)

    The steps rely on module-level objects defined elsewhere in the notebook:
    ``obj_retriever``, ``sql_retriever``, ``response_synthesis_prompt``,
    ``parse_response_to_sql``, ``MAX_RETRIES``, ``logger`` and ``Settings.llm``.
    """

    @staticmethod
    def _retries_exhausted(ev: SQLResultsEvent) -> bool:
        """Return True when a failed result carries no remaining retry budget.

        ``sql_retriever_step`` stamps retryable failures with a retry count of
        at most ``MAX_RETRIES`` and terminal failures with a larger one.
        """
        return getattr(ev, "retry_count", 0) > MAX_RETRIES

    @step
    async def input_step(self, ev: StartEvent) -> TableRetrievedEvent:
        """Retrieve the candidate table schemas for the incoming query."""
        logger.info(f"[Step 01] Process initial query and retrieve relevant tables")
        query = ev.query

        logger.info(f" - Use object retriever built from your table summaries")
        tables = obj_retriever.retrieve(query)  # candidate schemas
        logger.info(f" - Retrieved {len(tables)} candidate tables for query: {query}")

        return TableRetrievedEvent(
            tables=tables,
            query_str=query
        )

    @step
    async def table_output_parser_step(self, ev: TableRetrievedEvent) -> SchemaProcessedEvent:
        """Turn the retrieved tables into a schema + example-rows context string."""
        logger.info(f"[Step 02] Parsing schemas and retrieving relevant rows for query: {ev.query_str}")

        logger.info(f" - Enriching context function with vector row retrieval for tables: {ev.tables}")
        schema_str = get_table_context_and_rows_str(ev.query_str, ev.tables)

        return SchemaProcessedEvent(
            table_schema=schema_str,
            query_str=ev.query_str
        )

    @step
    async def text2sql_prompt_step(
        self, ev: SchemaProcessedEvent | SQLResultsEvent
    ) -> SQLPromptReadyEvent | None:
        """Build the text-to-SQL prompt for the first attempt or for a retry.

        A ``SchemaProcessedEvent`` starts attempt 0. A ``SQLResultsEvent`` only
        produces a prompt when it is a failure that still has retries left;
        successful or exhausted results are ignored here (returns ``None``).
        """
        if isinstance(ev, SchemaProcessedEvent):
            table_schema = ev.table_schema
            retry_count = 0
            error_message = ""
        else:
            if ev.success or self._retries_exhausted(ev):
                return None
            table_schema = getattr(ev, 'table_schema', '')
            # sql_retriever_step already advanced the counter; do not bump it again.
            retry_count = getattr(ev, 'retry_count', 0)
            error_message = getattr(ev, 'error_message', '')

        query_str = ev.query_str
        logger.info(f"[Step 03] Creating SQL prompt for query: {query_str}")
        prompt = create_t2s_prompt(table_schema, query_str, retry_count, error_message)

        return SQLPromptReadyEvent(
            t2s_prompt=prompt,
            query_str=query_str,
            table_schema=table_schema,
            retry_count=retry_count,
            error_message=error_message
        )

    @step
    async def text2sql_llm_step(self, ev: SQLPromptReadyEvent) -> SQLGeneratedEvent:
        """Ask the configured LLM to generate SQL from the prepared prompt."""
        logger.info(f"[Step 04] Running LLM to generate SQL for query: {ev.query_str}")
        sql_response = await Settings.llm.acomplete(ev.t2s_prompt)

        return SQLGeneratedEvent(
            sql_query=str(sql_response).strip(),
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_output_parser_step(self, ev: SQLGeneratedEvent) -> SQLParsedEvent:
        """Extract a clean SQL statement from the raw LLM response."""
        logger.info(f"[Step 05] Parsing LLM response to extract clean SQL for query: {ev.query_str}")
        try:
            clean_sql = parse_response_to_sql(ev.sql_query)  # primary parser
        except Exception:
            clean_sql = None

        # Fall back when the primary parser raised or produced nothing.
        if not clean_sql:
            clean_sql = extract_sql_from_response(ev.sql_query, logger)

        logger.info(f"Attempt #{ev.retry_count + 1}")
        logger.info(f"LLM Response: {ev.sql_query}")
        logger.info(f"Cleaned SQL: {clean_sql}")

        return SQLParsedEvent(
            sql_query=clean_sql,
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_retriever_step(self, ev: SQLParsedEvent) -> SQLResultsEvent:
        """Execute the SQL and report success, a retryable failure, or a final failure.

        Failures with ``ev.retry_count < MAX_RETRIES`` carry the analysed error
        and the table schema so a corrected prompt can be built. Once the budget
        is used up, the event carries the final error text in ``context_str``
        and a retry count greater than ``MAX_RETRIES``.
        """
        logger.info(f"[Step 06] Executing SQL for query: {ev.query_str}")
        try:
            results = sql_retriever.retrieve(ev.sql_query)
            logger.info(f"[SUCCESS] Executed on Attempt #{ev.retry_count + 1}")

            return SQLResultsEvent(
                context_str=str(results),
                sql_query=ev.sql_query,
                query_str=ev.query_str,
                success=True
            )
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Execution failed (Attempt #{ev.retry_count + 1}): {error_msg}")

            if ev.retry_count < MAX_RETRIES:
                retry_event = SQLResultsEvent(
                    context_str="",
                    sql_query=ev.sql_query,
                    query_str=ev.query_str,
                    success=False,
                    retry_count=ev.retry_count + 1,
                )
                # Attached after construction; text2sql_prompt_step reads them via getattr.
                retry_event.error_message = analyze_sql_error(error_msg, ev.sql_query, ev.table_schema, logger)
                retry_event.table_schema = ev.table_schema

                return retry_event

            return SQLResultsEvent(
                context_str=(f"Failed after {MAX_RETRIES+1} attempts. Final error: {error_msg}"),
                sql_query=ev.sql_query,
                query_str=ev.query_str,
                success=False,
                retry_count=ev.retry_count + 1,
            )

    @step
    async def retry_handler_step(self, ev: SQLResultsEvent) -> StopEvent | None:
        """End the workflow once every retry has failed.

        Retryable failures are re-prompted by ``text2sql_prompt_step`` and
        successes go to response synthesis, so this step only acts on the
        terminal failure; without it the run would idle until the workflow
        timeout. The failure text from ``sql_retriever_step`` becomes the result.
        """
        if ev.success or not self._retries_exhausted(ev):
            return None

        logger.error(f"[Step 07] Retries exhausted for query: {ev.query_str}")
        return StopEvent(result=ev.context_str)

    @step
    async def response_synthesis_prompt_step(self, ev: SQLResultsEvent) -> ResponsePromptReadyEvent | None:
        """Build the answer-synthesis prompt from a successful SQL result."""
        if not ev.success:
            return None

        logger.info(f"[Step 08] Preparing synthesis prompt for query: {ev.query_str}")
        prompt = response_synthesis_prompt.format(
            query_str=ev.query_str,
            context_str=ev.context_str,
            sql_query=ev.sql_query
        )

        return ResponsePromptReadyEvent(
            query_str=ev.query_str,
            rs_prompt=prompt
        )

    @step
    async def response_synthesis_llm_step(self, ev: ResponsePromptReadyEvent) -> StopEvent:
        """Generate the final natural-language answer and stop the workflow."""
        logger.info(f"[Step 09] Generating final answer for query: {ev.query_str}")
        answer = await Settings.llm.acomplete(ev.rs_prompt)

        return StopEvent(result=str(answer))

# Runner
async def run_text2sql_workflow_row(query: str):
    """Run a fresh row-retrieval Text2SQL workflow for ``query``.

    A new ``Text2SQLWorkflowRowRetrieval`` is created for every call with a
    timeout of 480, and whatever the run yields is returned unchanged.
    """
    return await Text2SQLWorkflowRowRetrieval(timeout=480).run(query=query)

In [17]:
# Run the row-retrieval workflow on the sample query and print what it returns.
# NOTE(review): top-level `await` presumably relies on the notebook's running event loop.
result = await run_text2sql_workflow_row(QUERY_TEXT)
print(result)

2025-08-23 01:46:17,928 [INFO] [Step 01] Process initial query and retrieve relevant tables
2025-08-23 01:46:17,929 [INFO]  - Use object retriever built from your table summaries
2025-08-23 01:46:17,991 [INFO]  - Retrieved 5 candidate tables for query: What is the billing city of Leonie Köhler?
2025-08-23 01:46:18,000 [INFO] [Step 02] Parsing schemas and retrieving relevant rows for query: What is the billing city of Leonie Köhler?
2025-08-23 01:46:18,002 [INFO]  - Enriching context function with vector row retrieval for tables: [SQLTableSchema(table_name='Invoice', context_str='Invoice Information'), SQLTableSchema(table_name='Customer', context_str='Summary of customer data'), SQLTableSchema(table_name='Employee', context_str='Summary of employee information'), SQLTableSchema(table_name='Artist', context_str='Summary of artist information'), SQLTableSchema(table_name='InvoiceLine', context_str='Summary of invoice data')]
2025-08-23 01:46:18,002 [INFO] [01] Getting schema for table: I

Some tasks did not clean up within timeout


<think>
Okay, let's see. The user is asking about the billing city of Leonie Köhler, and they provided an SQL query that selects BillingCity from Invoice where CustomerId is 2. The SQL response shows that there are multiple entries with Stuttgart as the billing city. 

First, I need to parse the SQL response. The metadata part lists the result as [('Stuttgart',), ...], which means there are multiple entries, each with Stuttgart. The score is None, which might indicate that the query didn't find any results, but the user is asking about Leonie, so maybe there's a mistake in the query or the data. However, the key point here is that the billing city is consistently Stuttgart.

So, the answer should directly state that Leonie Köhler's billing city is Stuttgart. Even though there are multiple entries, the city is the same across all. Therefore, the response is straightforward.
</think>

The billing city of Leonie Köhler is **Stuttgart**.


Visualize

In [None]:
async def visualize_text2sql_workflow_row(sample_query: str, execution_name: str, output_dir: str = WORKFLOW_VISUALIZATION_DIR):
    """Render the row-retrieval Text2SQL workflow as HTML diagrams.

    Two files are written into ``output_dir``:
    - ``{execution_name}_text2sql_workflow_flow.html``: every possible flow
    - ``{execution_name}_text2sql_workflow_execution.html``: the path taken
      by an actual run on ``sample_query``

    The workflow is run on the same instance that is handed to
    ``draw_most_recent_execution``; an instance that never ran has no
    execution to draw.

    Args:
        sample_query: Natural-language query used for the traced run.
        execution_name: Prefix for the generated file names.
        output_dir: Directory for the HTML files; created if missing.

    Errors raised while running or drawing are logged, not propagated.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger.info("[01] Drawing all possible flows...")
    all_flows_path = os.path.join(output_dir, f"{execution_name}_text2sql_workflow_flow.html")
    draw_all_possible_flows(
        Text2SQLWorkflowRowRetrieval,
        filename=all_flows_path
    )
    logger.info(f"[SUCCESS] All possible flows saved to: {all_flows_path}")

    logger.info("[02] Running workflow and drawing execution path...")
    try:
        logger.info(" - Running the workflow on the instance that will be drawn")
        # Same timeout as run_text2sql_workflow_row.
        workflow = Text2SQLWorkflowRowRetrieval(timeout=480)
        result = await workflow.run(query=sample_query)

        execution_path = os.path.join(output_dir, f"{execution_name}_text2sql_workflow_execution.html")
        draw_most_recent_execution(
            workflow,
            filename=execution_path
        )
        logger.info(f"[SUCCESS] Recent execution path saved to: {execution_path}")
        # `run` already yields the StopEvent payload, so log it directly.
        logger.info(f"Workflow result: {result}")

    except Exception as e:
        logger.error(f"Error during workflow execution: {e}")
        logger.info("Note: Ensure retrievers + LLM configs are initialized correctly")

In [None]:
# Draw the workflow diagrams for the sample query; QUERY_TEXT_INITIAL is used as the output file-name prefix.
await visualize_text2sql_workflow_row(QUERY_TEXT, QUERY_TEXT_INITIAL)