# Query Pipeline for Advanced Text-to-SQL

In this guide we show you how to set up a text-to-SQL pipeline over your data with our [query pipeline](https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/root.html) syntax.

This gives you the flexibility to enhance text-to-SQL with additional techniques. We show these in the sections below:
1. **Query-Time Table Retrieval**: Dynamically retrieve relevant tables in the text-to-SQL prompt.
2. **Query-Time Sample Row Retrieval**: Embed/index each row, and dynamically retrieve example rows for each table in the text-to-SQL prompt.

Our out-of-the-box pipelines include `NLSQLTableQueryEngine` and `SQLTableRetrieverQueryEngine` (if you want to check out our text-to-SQL guide using these modules, take a look [here](https://docs.llamaindex.ai/en/stable/examples/index_structs/struct_indices/SQLIndexDemo.html)). This guide implements an advanced version of those modules, giving you the utmost flexibility to apply this to your own setting.

## Load and Ingest Data


### Load Data
We use the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (Pasupat and Liang 2015) as our test dataset.

We go through all the CSV files in one folder and store each in a SQLite database (we will then build an object index over each table schema).

In [1]:
%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql\\notebooks'

In [2]:
import os

os.chdir("../")

%pwd

'c:\\Users\\Hp\\Documents\\GitHub\\rag_text-2-sql'

In [None]:
from utils.helpers.other_imports import (
    io,
    time,
    re,
    requests,
    zipfile,
    shutil,
    gc,
    traceback,
    json,
    pyjson,
    pd,
    Path,
    List,
    Dict,
    BaseModel,
    Field,
    px,
)

from utils.helpers.sql_alchemy_imports import (
    create_engine,
    text,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)

from utils.helpers.llama_index_imports import (
    Settings, 
    SQLDatabase, 
    VectorStoreIndex, 
    load_index_from_storage,
    set_global_handler,
    LLMTextCompletionProgram,
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
    SQLRetriever,
    DEFAULT_TEXT_TO_SQL_PROMPT,
    PromptTemplate,
    FunctionTool,
    ChatResponse,
    TextNode,
    StorageContext,
    Workflow,
    step,
    StartEvent,
    StopEvent,
    draw_all_possible_flows,
    draw_most_recent_execution,
)

from utils.config import CONFIG
from utils.logger import setup_logger


# configurations
LOG_PATH = Path(CONFIG["LOG_PATH"])

LOCAL_EX1_DATA_PATH = CONFIG["LOCAL_EX1_DATA_PATH"]
LOCAL_EX1_DATA_FOLDER_DIR = Path(CONFIG["LOCAL_EX1_DATA_FOLDER_DIR"])
LOCAL_EX1_DATA_TABLE_INFO_DIR = Path(CONFIG["LOCAL_EX1_DATA_TABLE_INFO_DIR"])
LOCAL_EX1_DATA_TABLE_INDEX_DIR = Path(CONFIG["LOCAL_EX1_DATA_TABLE_INDEX_DIR"])

SQLITE_DB_DIR = Path(CONFIG["SQLITE_DB_DIR"])
SQLITE_DB_FILE = CONFIG["SQLITE_DB_FILE"]

WORKFLOW_VISUALIZATION_DIR = Path(CONFIG["WORKFLOW_VISUALIZATION_DIR"])

QUERY_1 = CONFIG["QUERY_1"]
QUERY_1_INITIAL = CONFIG["QUERY_1_INITIAL"]
QUERY_2 = CONFIG["QUERY_2"]
QUERY_2_INITIAL = CONFIG["QUERY_2_INITIAL"]

TOP_K = CONFIG["TOP_K"]
MAX_RETRIES = CONFIG["MAX_RETRIES"]

In [5]:
# setup logging
LOG_DIR = os.path.join(os.getcwd(), LOG_PATH)
os.makedirs(LOG_DIR, exist_ok=True)  # Create the logs directory if it doesn't exist

# comment out line 15 in utils/logger.py -> only for notebooks
LOG_FILE = os.path.join(LOG_DIR, "test_notebook.log")
logger = setup_logger("test_notebook_logger", LOG_FILE)

In [6]:
DATA_DIR = LOCAL_EX1_DATA_FOLDER_DIR
CSV_FILES = sorted([f for f in DATA_DIR.glob("*.csv")])

TABLEINFO_DIR = LOCAL_EX1_DATA_TABLE_INFO_DIR
os.makedirs(TABLEINFO_DIR, exist_ok=True)

dfs = []

for csv_file in CSV_FILES:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

processing file: data\ex1\WikiTableQuestions\csv\200-csv\0.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\1.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\10.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\11.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\12.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\14.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\15.csv
Error parsing data\ex1\WikiTableQuestions\csv\200-csv\15.csv: Error tokenizing data. C error: Expected 4 fields in line 16, saw 5

processing file: data\ex1\WikiTableQuestions\csv\200-csv\17.csv
Error parsing data\ex1\WikiTableQuestions\csv\200-csv\17.csv: Error tokenizing data. C error: Expected 6 fields in line 5, saw 7

processing file: data\ex1\WikiTableQuestions\csv\200-csv\18.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\20.csv
processing file: data\ex1\WikiTableQuestions\csv\200-csv\22.csv
processing file: data\ex1\WikiTableQu
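Two files (`15.csv` and `17.csv`) are dropped above because some of their rows contain more fields than the header. If you would rather keep such tables, one option is to skip only the malformed rows: pandas supports this with `pd.read_csv(path, on_bad_lines="skip")` (pandas 1.3+). The sketch below shows the same idea with the standard library `csv` module so the behavior is easy to inspect; `read_csv_skip_bad_rows` is a hypothetical helper, not part of this notebook's `utils` package. Note that `dfs` is indexed by position, so changing which files load also changes the indices used later to match each table with its saved `TableInfo` file.

```python
import csv
import io
from typing import List, Tuple


def read_csv_skip_bad_rows(text: str) -> Tuple[List[str], List[List[str]], List[int]]:
    """Parse CSV text, keeping only rows whose field count matches the header.

    Returns (header, good_rows, skipped_line_numbers). Line numbers are
    1-based and the header counts as line 1.
    """
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    good_rows: List[List[str]] = []
    skipped: List[int] = []
    for row in reader:
        if len(row) == len(header):
            good_rows.append(row)
        else:
            # reader.line_num is the number of source lines read so far
            skipped.append(reader.line_num)
    return header, good_rows, skipped


sample = "a,b,c\n1,2,3\n4,5,6,7\n8,9,10\n"
header, rows, skipped = read_csv_skip_bad_rows(sample)
print(header, rows, skipped)
```

Skipping rows silently loses data, so it is worth logging `skipped` per file and checking whether the extra field is a real value (for example an unquoted comma) before relying on this.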

### Extract Table Name and Summary from each Table

Here we use the LLM returned by `get_llm_func()` to extract a table name (with underscores) and a summary from each table with our Pydantic program (an `LLMTextCompletionProgram` whose output class is `TableInfo`).

In [7]:
from utils.llm.get_prompt_temp import TABLE_INFO_PROMPT
from utils.llm.get_llm_func import get_llm_func


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )

program = LLMTextCompletionProgram.from_defaults(
    output_cls=TableInfo,
    prompt_template_str=TABLE_INFO_PROMPT,
    llm=get_llm_func(),
)

def extract_first_json_block(text: str):
    match = re.search(r"\{.*\}", text, re.S)  # greedy: from the first "{" to the last "}"
    if not match:
        raise ValueError("No JSON object found in output")
    return pyjson.loads(match.group())


def _get_tableinfo_with_index(idx: int) -> TableInfo | None:
    results_gen = Path(TABLEINFO_DIR).glob(f"{idx}_*")
    results_list = list(results_gen)
    
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        path = results_list[0]
        json_str = path.read_text(encoding="utf-8")
        return TableInfo.model_validate_json(json_str)
    else:
        # results_gen is already exhausted here, so report the materialized list
        raise ValueError(f"More than one file matching index: {results_list}")


table_names = set()
table_infos = []

for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        table_infos.append(table_info)
        continue

    df_str = df.head(10).to_csv()

    for attempt in range(MAX_RETRIES):
        try:
            raw_output = program(
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )

            if isinstance(raw_output, TableInfo):
                table_info = raw_output
            elif isinstance(raw_output, dict):
                table_info = TableInfo(**raw_output)
            elif isinstance(raw_output, str):
                parsed_dict = extract_first_json_block(raw_output)
                table_info = TableInfo(**parsed_dict)
            else:
                raise TypeError(f"Unexpected return type from program(): {type(raw_output)}")

            table_name = table_info.table_name
            print(f"Processed table: {table_name}")

            if table_name in table_names:
                print(f"Table name '{table_name}' already exists, skipping this table.")
                table_info = None  # don’t append duplicate
                break  # skip

            # save table info
            table_names.add(table_name)
            out_file = TABLEINFO_DIR / f"{idx}_{table_name}.json"
            # context manager guarantees the file is flushed and closed
            with open(out_file, "w", encoding="utf-8") as f:
                json.dump(table_info.model_dump(), f)
            break  # move to next table

        except Exception as e:
            print(f"Error with attempt {attempt+1}: {e}")
            time.sleep(2)

    if table_info:
        table_infos.append(table_info)
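`extract_first_json_block` uses the greedy pattern `\{.*\}` with `re.S`, so it matches from the first `{` to the *last* `}` in the output. When the model emits extra text (the responses later in this notebook start with a `<think>` block), that span can cover more than one object, or stray braces, and fail to parse. A minimal stdlib alternative tries `json.JSONDecoder.raw_decode` at each `{` and returns the first object that decodes. `first_json_object` is a hypothetical name, not part of the notebook's `utils` package.

```python
import json
from typing import Any, Dict


def first_json_object(text: str) -> Dict[str, Any]:
    """Return the first substring of `text` that decodes as a JSON object."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start`
            # and ignores whatever follows it
            obj, _end = decoder.raw_decode(text, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass  # not valid JSON from this brace; try the next one
        start = text.find("{", start + 1)
    raise ValueError("No JSON object found in output")


raw = 'Sure! {"table_name": "movie_data", "table_summary": "Films"} and {"x": 1}'
print(first_json_object(raw))
```

On `raw`, the greedy regex would capture both objects plus the text between them, which `json.loads` rejects; scanning brace by brace stops at the first complete object.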

### Put Data in SQL Database

We use `sqlalchemy`, a popular SQL database toolkit, to load all the tables.

In [8]:
# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()


# engine = create_engine("sqlite:///:memory:")
engine = create_engine(f"sqlite:///{SQLITE_DB_FILE}")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    if tableinfo is None:
        print(f"[ERROR] No TableInfo for index {idx}")
        continue  # skip this one or handle it differently
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

Creating table: movie_chart_positions
Creating table: movie_data
Creating table: death_accident_statistics
Creating table: award_data_1972
Creating table: award_data
Creating table: people_info
Creating table: broadcasting_info
Creating table: person_info
Creating table: chart_positions
Creating table: kodachrome_film_info
Creating table: bbc_radio_costs
Creating table: airport_locations
Creating table: party_voters
Creating table: club_performance
Creating table: horse_race_data
Creating table: grammy_awards
Creating table: boxing_matches
Creating table: sports_performance_data
Creating table: district_info
Creating table: party_data
Creating table: award_nominations
Creating table: government_ministers
Creating table: new_municipality_old_municipality_seat
Creating table: team_performance
Creating table: encoding_info
Creating table: temperature_data
Creating table: people_terms
Creating table: new_mexico_governorships
Creating table: weather_statistics
Creating table: drop_event_dat
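`create_table_from_dataframe` issues one `INSERT` per row, and because the engine points at a file (`sqlite:///{SQLITE_DB_FILE}`), re-running the cell appends a second copy of every row: `create_all` skips tables that already exist, but the inserts still run. A minimal stdlib sketch of a rerun-safe load, assuming you are fine rebuilding each table from scratch: drop it, recreate it with sanitized column names, and insert all rows with one `executemany` call. `load_rows` and `sqlite_type` are hypothetical helpers, and the `TEXT`/`INTEGER`/`REAL` mapping is an assumption (the cell above maps every non-`object` dtype, floats included, to `Integer`).

```python
import re
import sqlite3
from typing import Any, List, Sequence


def sanitize_column_name(col_name: Any) -> str:
    # Same rule as the notebook: collapse runs of non-word characters to "_"
    return re.sub(r"\W+", "_", str(col_name))


def sqlite_type(values: Sequence[Any]) -> str:
    """Pick a declared column type from the Python values (assumed mapping)."""
    non_null = [v for v in values if v is not None]
    if non_null and all(isinstance(v, int) and not isinstance(v, bool) for v in non_null):
        return "INTEGER"
    if non_null and all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in non_null):
        return "REAL"
    return "TEXT"


def load_rows(conn: sqlite3.Connection, table: str, columns: List[str], rows: List[tuple]) -> None:
    cols = [sanitize_column_name(c) for c in columns]
    types = [sqlite_type([r[i] for r in rows]) for i in range(len(cols))]
    col_defs = ", ".join(f'"{c}" {t}' for c, t in zip(cols, types))
    placeholders = ", ".join("?" for _ in cols)
    with conn:  # commits on success, rolls back pending inserts on error
        conn.execute(f'DROP TABLE IF EXISTS "{table}"')
        conn.execute(f'CREATE TABLE "{table}" ({col_defs})')
        conn.executemany(f'INSERT INTO "{table}" VALUES ({placeholders})', rows)


conn = sqlite3.connect(":memory:")
rows = [("The Notorious B.I.G", 1993, 5.0), ("Mase", 1996, 2.0)]
load_rows(conn, "people_info", ["Act", "Year signed", "# Albums"], rows)
load_rows(conn, "people_info", ["Act", "Year signed", "# Albums"], rows)  # simulate a re-run
print(conn.execute('SELECT COUNT(*) FROM "people_info"').fetchone())
```

Running the load twice still leaves two rows, which is the property the notebook's file-backed database needs when cells are re-executed.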

In [9]:
px.launch_app()
set_global_handler("arize_phoenix")


🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📖 For more information on how to use Phoenix, check out https://arize.com/docs/phoenix


### Define Modules

Here we define the core modules.
1. Object index + retriever to store table schemas
2. SQLDatabase object to connect to the above tables + SQLRetriever.
3. Text-to-SQL Prompt
4. Response synthesis Prompt
5. LLM
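The object index embeds each `SQLTableSchema` (table name plus the LLM-written summary) and returns the `similarity_top_k` closest tables for a question. To make that mechanic concrete without an embedding model, here is a stdlib sketch that swaps the embeddings for bag-of-words vectors compared by cosine similarity. This is a lexical stand-in chosen for illustration, not what `ObjectIndex` does internally; `retrieve_tables` and the sample summaries are hypothetical.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def bow(text: str) -> Counter:
    # Lowercased word counts: a crude stand-in for an embedding vector
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_tables(query: str, summaries: Dict[str, str], top_k: int = 2) -> List[str]:
    """Rank tables by similarity between the query and 'name + summary'."""
    q = bow(query)
    scored = {
        name: cosine(q, bow(name.replace("_", " ") + " " + summary))
        for name, summary in summaries.items()
    }
    return sorted(scored, key=scored.get, reverse=True)[:top_k]


summaries = {
    "people_info": "Artists signed to Bad Boy Records with year signed and albums released",
    "award_data_1972": "Academy Award nominees and results in 1972",
    "weather_statistics": "Monthly temperature and rainfall statistics",
}
print(retrieve_tables("What year was The Notorious B.I.G signed to Bad Boy?", summaries, top_k=1))
```

The quality of this step depends heavily on the summaries, which is why the notebook asks the LLM for a concise caption per table rather than indexing the raw table name alone.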

In [10]:
from utils.llm.get_llm_func import get_embedding_func
from utils.llm.get_prompt_temp import RESPONSE_SYNTHESIS_PROMPT


# Object index, retriever, SQLDatabase
sql_database = SQLDatabase(engine)
table_node_mapping = SQLTableNodeMapping(sql_database)

table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
    embed_model=get_embedding_func(),
)
obj_retriever = obj_index.as_retriever(similarity_top_k=5)



# SQLRetriever + Table Parser
sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)


table_parser_component = get_table_context_str(table_schema_objs)



# Text-to-SQL Prompt + Output Parser
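def parse_response_to_sql(response: ChatResponse | str) -> str:
    """Parse an LLM response (ChatResponse or plain completion text) to SQL."""
    if not isinstance(response, str):
        response = response.message.content
    # Drop any reasoning block emitted by "thinking" models before parsing
    response = re.sub(r"<think>.*?</think>", "", response, flags=re.S)

    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start + len("SQLQuery:") :]

    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]

    # Remove surrounding code fences, including a leading "sql" language tag
    response = response.strip().strip("`").strip()
    if response.lower().startswith("sql\n"):
        response = response[4:]
    return response.strip()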


sql_parser_component = FunctionTool.from_defaults(fn=parse_response_to_sql)

text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)



# Response Synthesis Prompt
response_synthesis_prompt = PromptTemplate(RESPONSE_SYNTHESIS_PROMPT)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 
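The printed template has three slots: `{dialect}` (already filled with `engine.dialect.name` via `partial_format`), `{schema}` (the table context string) and `{query_str}`. The sketch below reproduces that assembly with the standard library: it builds one schema line per table from SQLite's `PRAGMA table_info`, appends the optional description the way `get_table_context_str` does, and fills an abbreviated copy of the template. The exact wording of the schema line is an assumption; `SQLDatabase.get_single_table_info` produces its own format, and `table_context` is a hypothetical helper.

```python
import sqlite3

# Abbreviated copy of the template printed above (instructions trimmed)
TEMPLATE = (
    "Given an input question, first create a syntactically correct {dialect} query to run...\n"
    "Only use tables listed below.\n{schema}\n\n"
    "Question: {query_str}\nSQLQuery: "
)


def table_context(conn: sqlite3.Connection, table: str, description: str = "") -> str:
    cols = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
    # Each PRAGMA row is (cid, name, type, notnull, default, pk)
    col_str = ", ".join(f"{name} ({ctype})" for _cid, name, ctype, *_ in cols)
    info = f"Table '{table}' has columns: {col_str}."
    if description:
        info += " The table description is: " + description
    return info


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE people_info ("Act" TEXT, "Year_signed" INTEGER)')
schema = "\n\n".join(
    table_context(conn, t, d)
    for t, d in {"people_info": "Artists signed to Bad Boy Records"}.items()
)
prompt = TEMPLATE.format(dialect="sqlite", schema=schema, query_str="When was Mase signed?")
print(prompt)
```

Because the prompt ends with `SQLQuery: `, the model is nudged to continue with the query itself, which is what `parse_response_to_sql` expects to find.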


## Advanced Capability 2: Text-to-SQL with Query-Time Row Retrieval (along with Table Retrieval)

One problem in the previous example is that if the user's query mentions "The Notorious BIG" but the artist is stored as "The Notorious B.I.G", the generated SELECT statement will likely not return any matches.

We can alleviate this problem by fetching a small number of example rows per table. A naive option would be to just take the first k rows. Instead, we embed, index, and retrieve k relevant rows given the user query to give the text-to-SQL LLM the most contextually relevant information for SQL generation.
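To see why retrieved rows help, here is a stdlib sketch: an exact `WHERE` match on the user's spelling returns nothing, while ranking the table's rows against the question surfaces the row with the stored spelling, which can then be shown to the LLM. The ranking uses token overlap after dropping dots, a lexical stand-in for the embedding similarity used in this notebook, chosen only so the example runs without a model; `tokens`, `top_k_rows` and the sample rows are hypothetical.

```python
import re
import sqlite3
from typing import List, Set


def tokens(text: str) -> Set[str]:
    # Lowercase and drop dots so "B.I.G" and "BIG" both become "big"
    return set(re.findall(r"[a-z0-9]+", text.lower().replace(".", "")))


def top_k_rows(query: str, rows: List[tuple], k: int = 2) -> List[tuple]:
    """Rank rows by how many query tokens they share; keep the k best with any overlap."""
    q = tokens(query)
    scored = [(len(q & tokens(str(row))), i) for i, row in enumerate(rows)]
    scored.sort(key=lambda s: (-s[0], s[1]))  # highest overlap first, ties keep table order
    return [rows[i] for score, i in scored[:k] if score > 0]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people_info (Act TEXT, Year_signed INTEGER)")
conn.executemany(
    "INSERT INTO people_info VALUES (?, ?)",
    [("The Notorious B.I.G", 1993), ("Mase", 1996), ("Faith Evans", 1994)],
)
query = "What year was The Notorious BIG signed to Bad Boy?"

# A literal match on the user's spelling finds nothing...
naive = conn.execute(
    "SELECT Year_signed FROM people_info WHERE Act = ?", ("The Notorious BIG",)
).fetchall()

# ...but the retrieved example row exposes the stored spelling to the LLM
rows = conn.execute("SELECT * FROM people_info").fetchall()
examples = top_k_rows(query, rows, k=2)
print(naive, examples)
```

Embeddings generalize this beyond punctuation differences (nicknames, abbreviations), at the cost of embedding every row, which is why the notebook persists one index per table.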

We now extend our query pipeline.

### Index Each Table

We embed/index the rows of each table, resulting in one index per table.
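`index_all_tables` in the next cell builds an index only when `table_index_dir/<table_name>` does not exist and otherwise loads the persisted one, so the embedding cost is paid once per table. The sketch below shows the same build-or-load pattern with the standard library, using a JSON file of stringified rows as a stand-in for a persisted `VectorStoreIndex`. `build_or_load`, `index_all` and the `rows.json` layout are hypothetical; stripping whitespace from the directory name is an added assumption, motivated by the ` ohio_districts` table name (leading space) in the output below.

```python
import json
import sqlite3
import tempfile
from pathlib import Path
from typing import Dict, List


def build_or_load(conn: sqlite3.Connection, table: str, cache_dir: Path) -> List[str]:
    """Return the stringified rows for `table`, computing them only on the first call."""
    path = cache_dir / table.strip() / "rows.json"
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    rows = [str(tuple(r)) for r in conn.execute(f'SELECT * FROM "{table}"')]
    # In the notebook, this is where the rows would be embedded (the expensive part)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(rows), encoding="utf-8")
    return rows


def index_all(conn: sqlite3.Connection, cache_dir: Path) -> Dict[str, List[str]]:
    names = [r[0] for r in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")]
    return {name: build_or_load(conn, name, cache_dir) for name in names}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people_info (Act TEXT, Year_signed INTEGER)")
conn.execute("INSERT INTO people_info VALUES ('Mase', 1996)")
with tempfile.TemporaryDirectory() as tmp:
    first = index_all(conn, Path(tmp))
    conn.execute("DELETE FROM people_info")  # the cached copy is still returned below
    second = index_all(conn, Path(tmp))
print(first == second, first)
```

The trade-off, shared with the notebook's existence check, is that the cache is never invalidated when a table changes; delete the table's directory to force a rebuild.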

In [11]:
def index_all_tables(sql_database, table_index_dir: str = LOCAL_EX1_DATA_TABLE_INDEX_DIR) -> Dict[str, VectorStoreIndex]:
    """Index all tables in the SQL database."""
    Path(table_index_dir).mkdir(parents=True, exist_ok=True)

    vector_index_dict = {}
    engine = sql_database.engine

    for table_name in sql_database.get_usable_table_names():
        print(f"Indexing rows in table: {table_name}")
        table_path = Path(table_index_dir) / table_name

        if not table_path.exists():
            # Fetch all rows from the table
            with engine.connect() as conn:
                result = conn.execute(text(f'SELECT * FROM "{table_name}"'))
                row_tuples = [tuple(row) for row in result.fetchall()]

            # Create TextNode objects from rows
            nodes = [TextNode(text=str(row)) for row in row_tuples]

            # Build the index using current global Settings
            index = VectorStoreIndex(nodes)

            # Save index
            index.set_index_id("vector_index")
            index.storage_context.persist(persist_dir=str(table_path))

        else:
            # Rebuild storage context from saved directory
            storage_context = StorageContext.from_defaults(
                persist_dir=str(table_path)
            )

            # Load existing index
            index = load_index_from_storage(
                storage_context, index_id="vector_index"
            )

        vector_index_dict[table_name] = index

    return vector_index_dict

vector_index_dict = index_all_tables(sql_database)

Indexing rows in table:  ohio_districts
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\ ohio_districts\docstore.json.
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\ ohio_districts\index_store.json.
Indexing rows in table: afrikaans_language_usage
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\afrikaans_language_usage\docstore.json.
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\afrikaans_language_usage\index_store.json.
Indexing rows in table: airport_locations
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\airport_locations\docstore.json.
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\airport_locations\index_store.json.
Indexing rows in table: award_data
Loading llama_index.core.storage.kvstore.simple_kvstore from data\ex1\table_index_dir\award_data\docs

### Define Expanded Table Parser Component

We expand the capability of our `table_parser_component` so that it returns not only the relevant table schemas but also relevant rows for each table schema.

It now takes in both `table_schema_objs` (the output of the table retriever) and the original `query_str`, which is used for vector retrieval of relevant rows.

In [12]:
def get_table_context_and_rows_str(query_str: str, table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        # first append table info + additional context
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        # also lookup vector index to return relevant table rows
        vector_retriever = vector_index_dict[
            table_schema_obj.table_name
        ].as_retriever(similarity_top_k=2)
        relevant_nodes = vector_retriever.retrieve(query_str)
        if len(relevant_nodes) > 0:
            table_row_context = "\nHere are some relevant example rows (values in the same order as columns above)\n"
            for node in relevant_nodes:
                table_row_context += str(node.get_content()) + "\n"
            table_info += table_row_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)

In [12]:
query_str = "What was the year that The Notorious B.I.G was signed to Bad Boy?"
table_parser_component = get_table_context_and_rows_str(query_str, table_schema_objs)

### Define Expanded Query Pipeline

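This looks similar to the pipeline in section 1, but the table-parsing step now calls `get_table_context_and_rows_str`, so the text-to-SQL prompt also contains relevant example rows. It is written as a LlamaIndex `Workflow`: each `@step` method consumes one event type and emits the next, and a failed SQL execution loops back to the prompt step with the error message so the LLM can try again.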
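Stripped of the `Workflow` machinery, these steps implement a generate, execute, repair loop: generate SQL, run it, and on an error feed the error message into the next prompt until `MAX_RETRIES` is exhausted. The plain-Python sketch below shows only that control flow, using an in-memory SQLite table and a scripted fake LLM (`fake_llm_factory` is hypothetical and simply replays canned answers); the `MAX_RETRIES = 2` value and the retry prompt wording are assumptions, not the notebook's `create_t2s_prompt`.

```python
import sqlite3
from typing import Callable, List, Tuple

MAX_RETRIES = 2  # assumed value; the notebook reads it from CONFIG


def run_with_retries(
    conn: sqlite3.Connection, question: str, schema: str, llm: Callable[[str], str]
) -> Tuple[str, List[tuple], int]:
    """Generate SQL, execute it, and retry with the error message on failure.

    Returns (final_sql, rows, attempts_used); raises RuntimeError after
    MAX_RETRIES + 1 failed attempts, the same limit the workflow uses.
    """
    error = ""
    for attempt in range(MAX_RETRIES + 1):
        prompt = f"Schema:\n{schema}\nQuestion: {question}\n"
        if error:
            prompt += f"The previous query failed with: {error}\nFix it.\n"
        sql = llm(prompt).strip()
        try:
            return sql, conn.execute(sql).fetchall(), attempt + 1
        except sqlite3.Error as exc:
            error = str(exc)  # fed into the next prompt
    raise RuntimeError(f"Failed after {MAX_RETRIES + 1} attempts. Final error: {error}")


def fake_llm_factory(answers: List[str]) -> Callable[[str], str]:
    """Hypothetical stand-in for an LLM: returns canned answers in order."""
    seen_prompts: List[str] = []

    def fake_llm(prompt: str) -> str:
        seen_prompts.append(prompt)
        return answers[len(seen_prompts) - 1]

    fake_llm.prompts = seen_prompts  # expose prompts for inspection
    return fake_llm


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE people_info (Act TEXT, Year_signed INTEGER)")
conn.execute("INSERT INTO people_info VALUES ('The Notorious B.I.G', 1993)")
llm = fake_llm_factory([
    "SELECT Year FROM people_info",  # wrong column: raises sqlite3.OperationalError
    "SELECT Year_signed FROM people_info WHERE Act = 'The Notorious B.I.G'",
])
sql, rows, attempts = run_with_retries(conn, "When was Biggie signed?", "people_info(Act, Year_signed)", llm)
print(attempts, rows)
```

Keeping the loop this explicit makes the two exits easy to check: a successful execution returns, and exhausting retries ends the run with the final error instead of waiting for a timeout.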

In [13]:
from utils.t2SQL_workflow.custom_events import (
    TableRetrievedEvent,
    SchemaProcessedEvent,
    SQLPromptReadyEvent,
    SQLGeneratedEvent,
    SQLParsedEvent,
    SQLResultsEvent,
    ResponsePromptReadyEvent,
)
from utils.t2SQL_workflow.custom_fallbacks import (
    extract_sql_from_response,
    analyze_sql_error,
    create_t2s_prompt,
)


class Text2SQLWorkflowRowRetrieval(Workflow):
    @step
    async def input_step(self, ev: StartEvent) -> TableRetrievedEvent:
        """Step 1: Process initial query and retrieve relevant tables"""
        query = ev.query
        tables = obj_retriever.retrieve(query)  # retrieve candidate schemas
        
        return TableRetrievedEvent(
            tables=tables, 
            query_str=query
        )

    @step
    async def table_output_parser_step(self, ev: TableRetrievedEvent) -> SchemaProcessedEvent:
        """Step 2: Parse schemas + retrieve relevant rows"""
        schema_str = get_table_context_and_rows_str(ev.query_str, ev.tables)
        
        return SchemaProcessedEvent(
            table_schema=schema_str, 
            query_str=ev.query_str
        )

    @step
    async def text2sql_prompt_step(self, ev: SchemaProcessedEvent | SQLResultsEvent) -> SQLPromptReadyEvent:
        """Step 3: Create prompt (initial attempt or retry after a failed execution)"""
        if isinstance(ev, SchemaProcessedEvent):
            table_schema = ev.table_schema
            query_str = ev.query_str
            retry_count = 0
            error_message = ""
        else:
            # Successful results go to response synthesis, not back into the loop
            if ev.success:
                return None
            table_schema = getattr(ev, 'table_schema', '')
            query_str = ev.query_str
            # sql_retriever_step has already incremented retry_count
            retry_count = getattr(ev, 'retry_count', 1)
            error_message = getattr(ev, 'error_message', '')

        prompt = create_t2s_prompt(table_schema, query_str, retry_count, error_message)
        
        return SQLPromptReadyEvent(
            t2s_prompt=prompt,
            query_str=query_str,
            table_schema=table_schema,
            retry_count=retry_count,
            error_message=error_message
        )

    @step
    async def text2sql_llm_step(self, ev: SQLPromptReadyEvent) -> SQLGeneratedEvent:
        """Step 4: Run LLM to generate SQL"""
        sql_response = await get_llm_func().acomplete(ev.t2s_prompt)
        
        return SQLGeneratedEvent(
            sql_query=str(sql_response).strip(),
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_output_parser_step(self, ev: SQLGeneratedEvent) -> SQLParsedEvent:
        """Step 5: Parse/clean SQL"""
        try:
            clean_sql = parse_response_to_sql(ev.sql_query)  # primary parser
        except Exception:
            clean_sql = extract_sql_from_response(ev.sql_query)  # fallback
        
        if not clean_sql:
            clean_sql = extract_sql_from_response(ev.sql_query)

        print(f"Attempt #{ev.retry_count + 1}")
        print(f"LLM Response: {ev.sql_query}")
        print(f"Cleaned SQL: {clean_sql}")

        return SQLParsedEvent(
            sql_query=clean_sql,
            query_str=ev.query_str,
            table_schema=ev.table_schema,
            retry_count=ev.retry_count,
            error_message=ev.error_message
        )

    @step
    async def sql_retriever_step(self, ev: SQLParsedEvent) -> SQLResultsEvent | StopEvent:
        """Step 6: Execute SQL; on failure, request a retry or stop"""
        try:
            results = sql_retriever.retrieve(ev.sql_query)
            print(f"[SUCCESS] Executed on attempt #{ev.retry_count + 1}")
            
            return SQLResultsEvent(
                context_str=str(results),
                sql_query=ev.sql_query,
                query_str=ev.query_str,
                success=True
            )
        except Exception as e:
            error_msg = str(e)
            print(f"[ERROR] Execution failed (Attempt #{ev.retry_count + 1}): {error_msg}")

            if ev.retry_count < MAX_RETRIES:
                # The failed event is picked up by text2sql_prompt_step for a retry
                retry_event = SQLResultsEvent(
                    context_str="",
                    sql_query=ev.sql_query,
                    query_str=ev.query_str,
                    success=False
                )
                retry_event.retry_count = ev.retry_count + 1
                retry_event.error_message = analyze_sql_error(error_msg, ev.sql_query, ev.table_schema)
                retry_event.table_schema = ev.table_schema
                
                return retry_event

            # Retries exhausted: end the run with the error instead of waiting for the timeout
            return StopEvent(
                result=f"Failed after {MAX_RETRIES + 1} attempts. Final error: {error_msg}"
            )

    @step
    async def response_synthesis_prompt_step(self, ev: SQLResultsEvent) -> ResponsePromptReadyEvent:
        """Step 7: Prepare final synthesis prompt"""
        if not ev.success:
            return None  # failed executions are handled by text2sql_prompt_step
        prompt = response_synthesis_prompt.format(
            query_str=ev.query_str,
            context_str=ev.context_str,
            sql_query=ev.sql_query
        )
        
        return ResponsePromptReadyEvent(rs_prompt=prompt)

    @step
    async def response_synthesis_llm_step(self, ev: ResponsePromptReadyEvent) -> StopEvent:
        """Step 8: Generate final human-readable answer"""
        answer = await get_llm_func().acomplete(ev.rs_prompt)

        return StopEvent(result=str(answer))


# Runner
async def run_text2sql_workflow_row(query: str):
    workflow = Text2SQLWorkflowRowRetrieval(timeout=480)
    result = await workflow.run(query=query)
    return result

### Visualize

In [14]:
async def visualize_text2sql_workflow(sample_query: str, execution_name: str):
    """
    Function to visualize the Text2SQL workflow both as all possible flows
    and a specific execution example
    """
    output_dir = WORKFLOW_VISUALIZATION_DIR
    os.makedirs(output_dir, exist_ok=True)
    
    sub_dir = os.path.join(output_dir, f"{execution_name}")
    os.makedirs(sub_dir, exist_ok=True)

    # 1. Draw ALL possible flows through your workflow
    print("Drawing all possible flows...")
    all_flows_path = os.path.join(sub_dir, "text2sql_workflow_flow.html")
    draw_all_possible_flows(
        Text2SQLWorkflowRowRetrieval, 
        filename=all_flows_path
    )
    print(f"[SUCCESS] All possible flows saved to: {all_flows_path}")

    # 2. Run workflow + visualize the execution path
    print("Running workflow and drawing execution path...")
    try:
        workflow = Text2SQLWorkflowRowRetrieval(timeout=240)
        result = await workflow.run(query=sample_query)

        # Draw the execution path
        execution_path = os.path.join(sub_dir, "text2sql_workflow_execution.html")
        draw_most_recent_execution(
            workflow,
            filename=execution_path
        )
        print(f"[SUCCESS] Recent execution path saved to: {execution_path}")
        print(f"Workflow result: {result}")
        
    except Exception as e:
        print(f"[ERROR] Error during workflow execution: {e}")
        print("Note: You may need to set up your retriever and LLM settings first")

### Run Some Queries

We can now ask about entries even if the wording in the query doesn't exactly match the value stored in the database.

In [14]:
await visualize_text2sql_workflow(QUERY_1, QUERY_1_INITIAL)

Drawing all possible flows...
outputs\workflow_visualization\BIG_text2sql_workflow_flow.html
[SUCCESS] All possible flows saved to: outputs\workflow_visualization\BIG_text2sql_workflow_flow.html
Running workflow and drawing execution path...
Attempt #1
LLM Response: <think>
Okay, let's see. The user is asking for the year that The Notorious B.I.G was signed to Bad Boy. I need to figure out which table to use here.

First, looking at the tables provided, there's a table called people_info. The columns there are Act, Year_signed, and _Albums_released_under_Bad_Boy. The example rows show that for The Notorious B.I.G, the Year_signed is 1993 and the _Albums_released_under_Bad_Boy is '5'. So the question is about the Year_signed column, which should give the year of signing.

The user's question is straightforward, so I just need to select the Year_signed from the people_info table. There's no need to check other tables like Grammy or drop events because those are unrelated. The rules menti

In [15]:
await visualize_text2sql_workflow(QUERY_2, QUERY_2_INITIAL)

Drawing all possible flows...
outputs\workflow_visualization\Best_Dir_1972_text2sql_workflow_flow.html
[SUCCESS] All possible flows saved to: outputs\workflow_visualization\Best_Dir_1972_text2sql_workflow_flow.html
Running workflow and drawing execution path...
Attempt #1
LLM Response: <think>
Okay, let's see. The user is asking, "Who won best director in the 1972 academy awards?" So I need to find the answer based on the provided tables.

First, I need to figure out which tables contain the relevant data. The user mentioned the tables 'award_nominations', 'award_data_1972', 'grammy_awards', 'award_data', and 'movie_data'. The question is about the 1972 Academy Awards and the Best Director.

Looking at the 'award_data_1972' table, it has columns like Award, Category, Nominee, Result. The example rows show that in 1972, there was an Academy Award for Best Director. So maybe the answer is looking for that specific entry.

But wait, the user is asking for the result of the 1972 Academy Aw