# Hands-On Case Study: Natural-Language Database Queries with LlamaIndex Workflows

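## Introduction
In LLM application development, retrieval-augmented generation (RAG) has become a key technique for improving the accuracy of AI answers. This article shows how to use LlamaIndex workflows to build a complete natural-language database query system, so that users can query structured data in plain language without writing complex SQL.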

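## 1. Introduction to LlamaIndex Workflows
A LlamaIndex workflow is an event-driven mechanism that organizes work as follows:

- A workflow consists of a series of steps.
- Each step handles a specific type of event.
- A step emits a new event, which is handed to the next step.
- The workflow ends when a step emits a StopEvent.

The advantage of workflows is that they break a complex task into simple steps, which keeps the code clearly structured and easy to maintain and extend.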
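
The event-driven idea can be illustrated without the library. The sketch below is a toy dispatcher, not LlamaIndex's implementation: the event classes, the `STEPS` table, and `run` are made-up names for illustration, and each step is looked up by the type of event it accepts.

```python
from dataclasses import dataclass


@dataclass
class StartEvent:
    text: str


@dataclass
class UpperEvent:
    text: str


@dataclass
class StopEvent:
    result: str


def to_upper(ev: StartEvent) -> UpperEvent:
    # Step 1: handles StartEvent and emits UpperEvent
    return UpperEvent(text=ev.text.upper())


def add_suffix(ev: UpperEvent) -> StopEvent:
    # Step 2: handles UpperEvent and emits StopEvent, which ends the run
    return StopEvent(result=ev.text + "!")


# Each step is registered under the type of event it accepts
STEPS = {StartEvent: to_upper, UpperEvent: add_suffix}


def run(event):
    # Keep dispatching events until a step produces a StopEvent
    while not isinstance(event, StopEvent):
        event = STEPS[type(event)](event)
    return event.result


print(run(StartEvent(text="hello")))
```

Adding a step only requires a new event type and a new entry in the table; the dispatch loop itself does not change.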

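## 2. Case Requirements: Natural-Language Database Queries
This article implements a system for querying a database in natural language. The main steps are:

1. The user enters a natural-language query.
2. The system retrieves the relevant database tables.
3. A SQL query is generated from the table schemas.
4. The SQL query is executed to fetch the results.
5. A natural-language reply is generated.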
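
As a rough illustration of how these steps fit together, here is a minimal sketch that uses only the standard library. Everything in it is an assumption made for illustration: the `artists` table and its data are made up, and `retrieve_tables`, `generate_sql`, and the final formatting are hard-coded stand-ins for the vector retrieval and LLM calls used in the real system.

```python
import sqlite3

# Toy database with a single table (made-up data)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE artists (act TEXT, year_signed INTEGER)")
conn.executemany(
    "INSERT INTO artists VALUES (?, ?)",
    [("Artist A", 1993), ("Artist B", 1995)],
)

TABLE_SUMMARIES = {"artists": "artists and the year each one was signed"}


def retrieve_tables(question: str) -> list:
    # Step 2 stand-in: keep tables whose summary shares a word with the question
    words = set(question.lower().split())
    return [t for t, s in TABLE_SUMMARIES.items() if words & set(s.split())]


def generate_sql(question: str, tables: list) -> str:
    # Step 3 stand-in: an LLM would write this query from the table schemas
    return "SELECT year_signed FROM artists WHERE act = 'Artist A'"


def answer(question: str) -> str:
    tables = retrieve_tables(question)    # step 2: find relevant tables
    sql = generate_sql(question, tables)  # step 3: text to SQL
    rows = conn.execute(sql).fetchall()   # step 4: run the query
    # Step 5 stand-in: an LLM would phrase the reply from the rows
    return f"Artist A was signed in {rows[0][0]}."


print(answer("What year was Artist A signed?"))
```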

## 3. Environment Setup

First, install the required dependencies:

In [None]:
# Core packages
%pip install llama-index
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-llms-openai
%pip install llama-index-embeddings-openai

# Package for visualizing workflows
%pip install llama-index-utils-workflow

## 4. Data Preparation

This case study uses the WikiTableQuestions dataset, which contains many tables in CSV format:

https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip

In [None]:
import pandas as pd
from pathlib import Path

# Load the CSV files from the dataset
data_dir = Path("./data/WikiTableQuestions-1.0.2-compact/WikiTableQuestions/csv/200-csv")
csv_files = sorted([f for f in data_dir.glob("*.csv")])
dfs = []
for csv_file in csv_files:
    print(f"processing file: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        dfs.append(df)
    except Exception as e:
        print(f"Error parsing {csv_file}: {str(e)}")

To make the tables retrievable, we generate a short text description for each one:

In [4]:
from llama_index.core.prompts import ChatPromptTemplate 
from llama_index.core.bridge.pydantic import BaseModel, Field 
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


prompt_str = """\
Give me a summary of the table with the following JSON format.

- The table name must be unique to the table and describe it while being concise. 
- Do NOT output a generic table name (e.g. table, my_table).

Do NOT make the table name one of the following: {exclude_table_name_list}

Table:
{table_str}

Summary: """
prompt_tmpl = ChatPromptTemplate(
    message_templates=[ChatMessage.from_str(prompt_str, role="user")]
)
llm = OpenAI(model="gpt-4o")

In [5]:
import os

tableinfo_dir = "./data/WikiTableQuestions-1.0.2-compact/WikiTableQuestions_TableInfo"
os.makedirs(tableinfo_dir, exist_ok=True)
print(f"✓ Directory created: {tableinfo_dir}")

✓ Directory created: ./data/WikiTableQuestions-1.0.2-compact/WikiTableQuestions_TableInfo


In [6]:
import json
import random
from typing import Optional


def _get_tableinfo_with_index(idx: int) -> Optional[TableInfo]:
    """Load the cached TableInfo for table `idx`, or return None if there is none."""
    results_list = list(Path(tableinfo_dir).glob(f"{idx}_*"))
    if len(results_list) == 0:
        return None
    elif len(results_list) == 1:
        path = results_list[0]
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
            return TableInfo.model_validate(data)
    else:
        raise ValueError(
            f"More than one file matching index: {results_list}"
        )


table_names = set()
table_infos = []
max_retries = 5  # maximum number of retries per table
for idx, df in enumerate(dfs):
    table_info = _get_tableinfo_with_index(idx)
    if table_info:
        # Cached summary found: reserve its name so later tables cannot reuse it
        table_names.add(table_info.table_name)
    else:
        retry_count = 0  # retries used so far for this table
        while True:
            df_str = df.head(10).to_csv()
            table_info = llm.structured_predict(
                TableInfo,
                prompt_tmpl,
                table_str=df_str,
                exclude_table_name_list=str(list(table_names)),
            )
            table_name = table_info.table_name
            print(f"Processing table: {table_name} (attempt {retry_count + 1})")
            if table_name not in table_names:
                table_names.add(table_name)
                break
            # The name is already taken: retry, up to max_retries times
            retry_count += 1
            if retry_count >= max_retries:
                # Fall back to a random suffix to make a collision unlikely
                random_suffix = f"_{random.randint(1, 999):03d}"
                table_name = f"{table_name}{random_suffix}"
                table_info.table_name = table_name  # keep the saved JSON in sync
                table_names.add(table_name)
                print(f"Reached the retry limit; using suffixed table name: {table_name}")
                break
            print(f"Table name {table_name} already exists, retrying ({retry_count}/{max_retries})")
        out_file = f"{tableinfo_dir}/{idx}_{table_name}.json"
        print(f"Saving table info to: {out_file}")
        try:
            with open(out_file, "w", encoding="utf-8") as f:
                json.dump(table_info.model_dump(), f, ensure_ascii=False, indent=2)
            print(f"✓ Saved table info to: {out_file}")
        except Exception as e:
            print(f"❌ Failed to save table info {out_file}: {str(e)}")
            raise
    table_infos.append(table_info)
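
The random suffix used as a fallback above can, in principle, collide with a name that is already taken. One possible alternative, sketched here under the assumption that a numeric suffix is acceptable in table names, is a counter-based helper. `make_unique_name` is a hypothetical function and is not part of the original notebook.

```python
def make_unique_name(name: str, existing: set) -> str:
    """Return `name`, or `name_2`, `name_3`, ... if it is already taken."""
    candidate = name
    counter = 2
    while candidate in existing:
        candidate = f"{name}_{counter}"
        counter += 1
    # Reserve the chosen name so the next call cannot return it again
    existing.add(candidate)
    return candidate


names = set()
print(make_unique_name("voter_registration_statistics", names))
print(make_unique_name("voter_registration_statistics", names))
print(make_unique_name("voter_registration_statistics", names))
```

Because the loop only stops at a name that is not yet in the set, the result is unique by construction and no retry limit is needed for the fallback.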

## 5. Storing the Tables in a SQLite Database

In [7]:
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
)
import re


# Function to create a sanitized column name
def sanitize_column_name(col_name):
    # Remove special characters and replace spaces with underscores
    return re.sub(r"\W+", "_", col_name)


# Function to create a table from a DataFrame using SQLAlchemy
def create_table_from_dataframe(
    df: pd.DataFrame, table_name: str, engine, metadata_obj
):
    # Sanitize column names
    sanitized_columns = {col: sanitize_column_name(col) for col in df.columns}
    df = df.rename(columns=sanitized_columns)

    # Dynamically create columns based on DataFrame columns and data types
    columns = [
        Column(col, String if dtype == "object" else Integer)
        for col, dtype in zip(df.columns, df.dtypes)
    ]

    # Create a table with the defined columns
    table = Table(table_name, metadata_obj, *columns)

    # Create the table in the database
    metadata_obj.create_all(engine)

    # Insert data from DataFrame into the table
    with engine.connect() as conn:
        for _, row in df.iterrows():
            insert_stmt = table.insert().values(**row.to_dict())
            conn.execute(insert_stmt)
        conn.commit()


# engine = create_engine("sqlite:///:memory:")
engine = create_engine("sqlite:///wiki_table_questions.db")
metadata_obj = MetaData()
for idx, df in enumerate(dfs):
    tableinfo = _get_tableinfo_with_index(idx)
    print(f"Creating table: {tableinfo.table_name}")
    create_table_from_dataframe(df, tableinfo.table_name, engine, metadata_obj)

Creating table: progressive_rock_album_chart_positions
Creating table: filmography_of_diane
Creating table: annual_fatalities_and_accidents
Creating table: academy_awards_and_nominations_1972
Creating table: theatrical_awards_nominations
Creating table: bad_boy_artists_and_albums
Creating table: south_dakota_radio_stations
Creating table: missing_persons_cases_1982
Creating table: chart_performance_of_singles
Creating table: kodachrome_film_types_and_dates
Creating table: bbc_radio_service_costs_2012_2013
Creating table: french_airports_and_aerodromes
Creating table: voter_registration_statistics
Creating table: norwegian_club_performance_statistics
Creating table: triple_crown_winners
Creating table: grammy_awards_nominations_and_wins
Creating table: boxing_fight_records
Creating table: historical_college_football_records
Creating table: yamato_district_population_density
Creating table: voter_registration_statistics_by_party
Creating table: french_film_awards_performance
Creating tab
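
To see what the column-name sanitization does in practice, the snippet below applies the same regular expression to a few headers. The sample headers are made up for illustration and are not taken from the dataset.

```python
import re


def sanitize_column_name(col_name: str) -> str:
    # Replace every run of non-word characters with a single underscore
    return re.sub(r"\W+", "_", col_name)


# Made-up headers for illustration
for header in ["Year signed", "# Albums released", "Work/Artist"]:
    print(f"{header!r} -> {sanitize_column_name(header)!r}")
```

A header that starts with a symbol ends up with a leading underscore, because the symbol and the space after it are replaced together as one run.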

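## 6. Building the Basic Tools

Create a vector index over the table descriptions.

ObjectIndex is a built-in LlamaIndex module that retrieves arbitrary Python objects through an index.
Here we use VectorStoreIndex (vector retrieval) and SQLTableNodeMapping, which maps each text-description node to its database table.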

In [8]:
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import SQLDatabase, VectorStoreIndex

sql_database = SQLDatabase(engine)

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    SQLTableSchema(table_name=t.table_name, context_str=t.table_summary)
    for t in table_infos
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
obj_retriever = obj_index.as_retriever(similarity_top_k=3)
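
The retriever above returns the three tables whose descriptions are most similar to the query. The toy sketch below illustrates the top-k idea only: it swaps in word-count vectors and cosine similarity in place of learned embeddings, and `cosine`, `top_k_tables`, and the sample summaries are made-up names and data, not part of LlamaIndex.

```python
import math
from collections import Counter


def cosine(a: Counter, b: Counter) -> float:
    # Cosine similarity between two word-count vectors
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def top_k_tables(query: str, summaries: dict, k: int = 3) -> list:
    # Rank tables by similarity between the query and each table summary
    q = Counter(query.lower().split())
    scored = [
        (cosine(q, Counter(text.lower().split())), name)
        for name, text in summaries.items()
    ]
    return [name for score, name in sorted(scored, reverse=True)[:k]]


# Made-up summaries in the style of the generated table descriptions
summaries = {
    "bad_boy_artists": "artists signed to bad boy records and year signed",
    "radio_stations": "radio stations in south dakota",
    "triple_crown": "triple crown winners by year",
}
print(top_k_tables("which year was the artist signed to bad boy", summaries, k=2))
```

Word overlap misses synonyms and paraphrases, which is the main reason to use embedding-based retrieval in the real pipeline.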

In [9]:
# Create the SQLRetriever
from llama_index.core.retrievers import SQLRetriever
from typing import List

sql_retriever = SQLRetriever(sql_database)


def get_table_context_str(table_schema_objs: List[SQLTableSchema]):
    """Get table context string."""
    context_strs = []
    for table_schema_obj in table_schema_objs:
        table_info = sql_database.get_single_table_info(
            table_schema_obj.table_name
        )
        if table_schema_obj.context_str:
            table_opt_context = " The table description is: "
            table_opt_context += table_schema_obj.context_str
            table_info += table_opt_context

        context_strs.append(table_info)
    return "\n\n".join(context_strs)

In [10]:
# Create the Text2SQL prompt (the library's default template) and an output parser that extracts the SQL from the generated text
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatResponse

def parse_response_to_sql(chat_response: ChatResponse) -> str:
    """Parse response to SQL."""
    response = chat_response.message.content
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        response = response[sql_query_start:]
        # TODO: move to removeprefix after Python 3.9+
        if response.startswith("SQLQuery:"):
            response = response[len("SQLQuery:") :]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        response = response[:sql_result_start]
    return response.strip().strip("```").strip()


text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
    dialect=engine.dialect.name
)
print(text2sql_prompt.template)

Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use tables listed below.
{schema}

Question: {query_str}
SQLQuery: 
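
Since the template asks the model to answer in a `SQLQuery:` / `SQLResult:` layout, the parser only has to cut out the text between those two markers. The sketch below mirrors that logic on a plain string so it can be tried without an LLM call. `extract_sql` and the sample model output are made up for illustration.

```python
def extract_sql(text: str) -> str:
    """Return the text between 'SQLQuery:' and 'SQLResult:', without backticks."""
    start = text.find("SQLQuery:")
    if start != -1:
        text = text[start + len("SQLQuery:"):]
    end = text.find("SQLResult:")
    if end != -1:
        text = text[:end]
    # Drop surrounding whitespace and any backticks wrapping the query
    return text.strip().strip("`").strip()


# Made-up model output for illustration
raw = (
    "Question: How many rows are there?\n"
    "SQLQuery: SELECT COUNT(*) FROM my_table\n"
    "SQLResult: [(42,)]\n"
    "Answer: 42"
)
print(extract_sql(raw))
```

Note that stripping backticks does not remove a language tag such as `sql` placed right after an opening fence, so a model that emits one would need extra handling.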


In [11]:
# Create the prompt template for the natural-language response
response_synthesis_prompt_str = (
    "Given an input question, synthesize a response from the query results.\n"
    "Query: {query_str}\n"
    "SQL: {sql_query}\n"
    "SQL Response: {context_str}\n"
    "Response: "
)
response_synthesis_prompt = PromptTemplate(
    response_synthesis_prompt_str,
)

## 7. Defining the Workflow

In [12]:
from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    step,
    Context,
    Event,
)

# Event: relevant tables were found in the database
class TableRetrieveEvent(Event):
    """Result of running table retrieval."""

    table_context_str: str
    query: str

# Event: the text was converted to SQL
class TextToSQLEvent(Event):
    """Text-to-SQL event."""

    sql: str
    query: str


class TextToSQLWorkflow1(Workflow):
    """Text-to-SQL Workflow that does query-time table retrieval."""

    def __init__(
        self,
        obj_retriever,
        text2sql_prompt,
        sql_retriever,
        response_synthesis_prompt,
        llm,
        *args,
        **kwargs
    ) -> None:
        """Init params."""
        super().__init__(*args, **kwargs)
        self.obj_retriever = obj_retriever
        self.text2sql_prompt = text2sql_prompt
        self.sql_retriever = sql_retriever
        self.response_synthesis_prompt = response_synthesis_prompt
        self.llm = llm

    @step
    def retrieve_tables(
        self, ctx: Context, ev: StartEvent
    ) -> TableRetrieveEvent:
        """Retrieve tables."""
        table_schema_objs = self.obj_retriever.retrieve(ev.query)
        table_context_str = get_table_context_str(table_schema_objs)
        print("====\n"+table_context_str+"\n====")
        return TableRetrieveEvent(
            table_context_str=table_context_str, query=ev.query
        )

    @step
    def generate_sql(
        self, ctx: Context, ev: TableRetrieveEvent
    ) -> TextToSQLEvent:
        """Generate SQL statement."""
        fmt_messages = self.text2sql_prompt.format_messages(
            query_str=ev.query, schema=ev.table_context_str
        )
        chat_response = self.llm.chat(fmt_messages)
        sql = parse_response_to_sql(chat_response)
        print("====\n"+sql+"\n====")
        return TextToSQLEvent(sql=sql, query=ev.query)

    @step
    def generate_response(self, ctx: Context, ev: TextToSQLEvent) -> StopEvent:
        """Run SQL retrieval and generate response."""
        retrieved_rows = self.sql_retriever.retrieve(ev.sql)
        print("====\n"+str(retrieved_rows)+"\n====")
        fmt_messages = self.response_synthesis_prompt.format_messages(
            sql_query=ev.sql,
            context_str=str(retrieved_rows),
            query_str=ev.query,
        )
        # Use the LLM passed to the workflow rather than the notebook-level global
        chat_response = self.llm.chat(fmt_messages)
        return StopEvent(result=chat_response)

In [13]:
workflow = TextToSQLWorkflow1(
    obj_retriever,
    text2sql_prompt,
    sql_retriever,
    response_synthesis_prompt,
    llm,
    verbose=True,
)

In [14]:
response = await workflow.run(
    query="What was the year that The Notorious B.I.G was signed to Bad Boy?"
)
print(str(response))

Running step retrieve_tables
====
Table 'bad_boy_artists_and_albums' has columns: Act (VARCHAR), Year_signed (INTEGER), _Albums_released_under_Bad_Boy (VARCHAR), . The table description is: List of artists signed to Bad Boy Records along with the year signed and number of albums released.

Table 'grammy_awards_nominations_and_wins' has columns: Year (INTEGER), Award (VARCHAR), Work_Artist (VARCHAR), Result (VARCHAR), . The table description is: Summary of Grammy Award nominations and wins for the album 'A Ghost Is Born' and other works by Wilco.

Table 'progressive_rock_album_chart_positions' has columns: Year (INTEGER), Title (VARCHAR), Chart_Positions_UK (VARCHAR), Chart_Positions_US (VARCHAR), Chart_Positions_NL (VARCHAR), Comments (VARCHAR), . The table description is: Chart positions of progressive rock albums in the UK, US, and NL from 1969 to 1981.
====
Step retrieve_tables produced event TableRetrieveEvent
Running step generate_sql
====
SELECT Year_signed FROM bad_boy_artists_a

## 8. Visualizing the Workflow

In [15]:
from llama_index.utils.workflow import draw_all_possible_flows

draw_all_possible_flows(
    TextToSQLWorkflow1, filename="text_to_sql_table_retrieval.html"
)

<class 'NoneType'>
<class 'llama_index.core.workflow.events.StopEvent'>
<class '__main__.TextToSQLEvent'>
<class '__main__.TableRetrieveEvent'>
text_to_sql_table_retrieval.html
