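# How to deal with large databases when doing SQL question-answering
To write valid queries against a database, we need to give the model the table names, the table schemas, and the feature values it can query over. When there are many tables, many columns, or high-cardinality columns, it is not feasible to dump the full database information into every prompt. Instead, we must find ways to dynamically insert only the most relevant information into the prompt.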
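In this guide we demonstrate methods for identifying such relevant information and feeding it into a query-generation step. We will cover:

1. Identifying a relevant subset of tables;
2. Identifying a relevant subset of column values.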

## Setup
First, get the required packages and set environment variables:

In [None]:
%pip install --upgrade --quiet  langchain langchain-community langchain-openai

In [None]:
# Uncomment the below to use LangSmith. Not required.
# import getpass
# import os
# os.environ["LANGSMITH_API_KEY"] = getpass.getpass()
# os.environ["LANGSMITH_TRACING"] = "true"

The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:

* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`
* Run `sqlite3 Chinook.db`
* Run `.read Chinook_Sqlite.sql`
* Test `SELECT * FROM Artist LIMIT 10;`

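
If the `sqlite3` command-line shell is not available, the same steps can be approximated with Python's standard-library `sqlite3` module. This is a minimal sketch: the helper name `build_sqlite_db` and the choice of the `utf-8-sig` encoding are assumptions made for illustration and are not part of the linked instructions.

```python
import sqlite3
from pathlib import Path


def build_sqlite_db(sql_path: str, db_path: str) -> None:
    """Create (or extend) a SQLite file by running every statement in a .sql script."""
    # "utf-8-sig" reads plain UTF-8 and also strips a leading byte-order mark if present.
    script = Path(sql_path).read_text(encoding="utf-8-sig")
    conn = sqlite3.connect(db_path)
    try:
        # Plays the role of `.read <file>` in the sqlite3 shell.
        conn.executescript(script)
        conn.commit()
    finally:
        conn.close()


# Usage, assuming Chinook_Sqlite.sql was saved next to the notebook:
# build_sqlite_db("Chinook_Sqlite.sql", "Chinook.db")
```
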
Now `Chinook.db` is in our directory, and we can interface with it using the SQLAlchemy-driven [SQLDatabase](https://python.langchain.com/api_reference/community/utilities/langchain_community.utilities.sql_database.SQLDatabase.html) class:

In [1]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


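## Many tables
One of the main pieces of information we need to include in the prompt is the schemas of the relevant tables. When we have very many tables, we cannot fit all of their schemas into a single prompt. In such cases we can first extract the names of the tables related to the user input, and then include only those tables' schemas.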
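
To make the "only include the selected schemas" idea concrete, here is a minimal sketch that uses the standard-library `sqlite3` module rather than `SQLDatabase`. The helper name `schemas_for` and the tiny in-memory tables are assumptions made for illustration; `SQLDatabase` renders table info in its own way.

```python
import sqlite3
from typing import List


def schemas_for(conn: sqlite3.Connection, table_names: List[str]) -> str:
    """Return the CREATE TABLE statements for just the named tables."""
    if not table_names:
        return ""
    placeholders = ", ".join("?" for _ in table_names)
    rows = conn.execute(
        f"SELECT sql FROM sqlite_master "
        f"WHERE type = 'table' AND name IN ({placeholders}) ORDER BY name",
        table_names,
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)


# Tiny in-memory stand-in for a database with many tables.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Genre (GenreId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total REAL);
    """
)

# Only the Artist and Genre schemas would be placed in the prompt.
print(schemas_for(conn, ["Artist", "Genre"]))
```

The prompt then grows with the number of selected tables rather than with the size of the whole database.
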
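A simple and reliable way to extract the relevant table names is [tool calling](/docs/how_to/tool_calling). Below, we show how to use this feature to obtain output that conforms to a desired format (in this case, a list of table names). We use the chat model's `.bind_tools` method to bind a tool defined as a Pydantic model, and feed the result into an output parser that reconstructs the object from the model's response.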


In [3]:
# | output: false
# | echo: false

from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

In [4]:
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_names}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Genre')]

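This works reasonably well. However, as we will see below, we actually need a few other tables as well, and it is hard for the model to know this from the user question alone. In this case, we can try to make the model's job easier by grouping the tables together. We simply ask the model to choose between the categories "Music" and "Business", and then take care of selecting all the relevant tables from there ourselves: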

In [5]:
system = """Return the names of any SQL tables that are relevant to the user question.
The tables are:

Music
Business
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

[Table(name='Music'), Table(name='Business')]

In [6]:
from typing import List


def get_tables(categories: List[Table]) -> List[str]:
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "What are all the genres of Alanis Morisette songs"})

['Album',
 'Artist',
 'Genre',
 'MediaType',
 'Playlist',
 'PlaylistTrack',
 'Track',
 'Customer',
 'Employee',
 'Invoice',
 'InvoiceLine']
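
Because `get_tables` appends with `extend`, a category returned twice by the model would yield duplicate table names, and an unexpected category name is silently ignored. A dictionary-based variant makes both behaviours explicit. This is a sketch under assumptions: the names `CATEGORY_TABLES` and `tables_for_categories` are hypothetical, and the function takes plain category names (for example `[c.name for c in categories]`) rather than `Table` objects.

```python
from typing import Dict, Iterable, List

# Hypothetical grouping, mirroring the two categories used above.
CATEGORY_TABLES: Dict[str, List[str]] = {
    "Music": [
        "Album",
        "Artist",
        "Genre",
        "MediaType",
        "Playlist",
        "PlaylistTrack",
        "Track",
    ],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}


def tables_for_categories(category_names: Iterable[str]) -> List[str]:
    """Expand category names into table names, skipping unknowns and duplicates."""
    seen = set()
    tables: List[str] = []
    for name in category_names:
        for table in CATEGORY_TABLES.get(name, []):
            if table not in seen:
                seen.add(table)
                tables.append(table)
    return tables


# A repeated category and an unknown one still give each Music table exactly once.
print(tables_for_categories(["Music", "Music", "Unknown"]))
```
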

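Now that we have a chain that can output the relevant tables for any query, we can combine it with [create_sql_query_chain](https://python.langchain.com/api_reference/langchain/chains/langchain.chains.sql_database.query.create_sql_query_chain.html), which accepts a `table_names_to_use` list that determines which table schemas are included in the prompt: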

In [7]:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain

In [8]:
query = full_chain.invoke(
    {"question": "What are all the genres of Alanis Morisette songs"}
)
print(query)

SELECT DISTINCT "g"."Name"
FROM "Genre" g
JOIN "Track" t ON "g"."GenreId" = "t"."GenreId"
JOIN "Album" a ON "t"."AlbumId" = "a"."AlbumId"
JOIN "Artist" ar ON "a"."ArtistId" = "ar"."ArtistId"
WHERE "ar"."Name" = 'Alanis Morissette'
LIMIT 5;


In [9]:
db.run(query)

"[('Rock',)]"

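We can see the LangSmith trace for this run [here](https://smith.langchain.com/public/4fbad408-3554-4f33-ab47-1e510a1b52a3/r).
We have seen how to dynamically include a subset of table schemas in a prompt within a chain. Another possible approach to this problem is to give an agent a lookup tool and let it decide for itself when to look up tables. You can see an example of this in the [SQL: Agents](/docs/tutorials/agents) guide.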

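## High-cardinality columns
To filter on columns that contain proper nouns such as addresses, song names, or artist names, we first need to double-check the spelling so that the filter matches the data correctly.
One naive strategy is to create a vector store containing all the distinct proper nouns that exist in the database. We can then query that vector store for each user input and inject the most relevant proper nouns into the prompt.
First we need the unique values for each entity we want, so we define a function that parses the query result into a list of elements: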

In [10]:
import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(db, "SELECT Name FROM Artist")
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
len(proper_nouns)
proper_nouns[:5]

['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']
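
To see what each step of `query_as_list` does without a database, the same three operations can be applied to a hand-written string. The sample value of `raw` is hypothetical; it only mimics the string of row tuples that `db.run` returns.

```python
import ast
import re

# Hypothetical sample of what `db.run` returns: a string holding a list of row tuples.
raw = "[('AC/DC',), ('U2',), (None,), ('Greatest Hits 2',)]"

# Parse the string back into a Python list of tuples.
rows = ast.literal_eval(raw)

# Flatten the tuples and drop NULL or empty values.
values = [el for row in rows for el in row if el]

# Remove standalone numbers, then trim surrounding whitespace.
cleaned = [re.sub(r"\b\d+\b", "", value).strip() for value in values]

print(cleaned)  # ['AC/DC', 'U2', 'Greatest Hits']
```

Note that `\b\d+\b` only matches digits that form a whole word, so `U2` is kept intact while the trailing `2` in `Greatest Hits 2` is removed.
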

Now we can embed all of our values and store them in a vector database:

In [11]:
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(proper_nouns, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
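
The retriever's role is to return the stored names closest to the user's text. To illustrate that lookup without an embeddings API key, the sketch below swaps in character-level similarity from the standard-library `difflib` module. This is a stand-in technique, not what FAISS or `OpenAIEmbeddings` do internally, and the helper name `closest_names` is an assumption made for illustration.

```python
import difflib
from typing import List


def closest_names(query: str, names: List[str], k: int = 3) -> List[str]:
    """Return up to k names ranked by character-level similarity to the query."""
    # Compare in lowercase, but return the original spelling.
    by_lower = {name.lower(): name for name in names}
    hits = difflib.get_close_matches(query.lower(), list(by_lower), n=k, cutoff=0.0)
    return [by_lower[hit] for hit in hits]


names = ["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains"]

# A misspelled artist name still resolves to the stored spelling.
print(closest_names("elenis moriset", names, k=1))  # ['Alanis Morissette']
```

Embedding-based retrieval can also match on meaning rather than spelling alone, which is why the guide uses a vector store for the real chain.
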

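Next, we put together a query-construction chain that first retrieves values from the vector store and inserts them into the prompt: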

In [12]:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

system = """You are a SQLite expert. Given an input question, create a syntactically
correct SQLite query to run. Unless otherwise specified, do not return more than
{top_k} rows.

Only return the SQL query with no markup or explanation.

Here is the relevant table info: {table_info}

Here is a non-exhaustive list of possible feature values. If filtering on a feature
value make sure to check its spelling against this list first:

{proper_nouns}
"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain

To try out our chain, let's see what happens when we filter on "elenis moriset", a misspelling of Alanis Morissette, first without retrieval and then with it:

In [13]:
# Without retrieval
query = query_chain.invoke(
    {"question": "What are all the genres of elenis moriset songs", "proper_nouns": ""}
)
print(query)
db.run(query)

SELECT DISTINCT g.Name 
FROM Track t
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
JOIN Genre g ON t.GenreId = g.GenreId
WHERE ar.Name = 'Elenis Moriset';


''

In [14]:
# With retrieval
query = chain.invoke({"question": "What are all the genres of elenis moriset songs"})
print(query)
db.run(query)

SELECT DISTINCT g.Name
FROM Genre g
JOIN Track t ON g.GenreId = t.GenreId
JOIN Album a ON t.AlbumId = a.AlbumId
JOIN Artist ar ON a.ArtistId = ar.ArtistId
WHERE ar.Name = 'Alanis Morissette';


"[('Rock',)]"

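We can see that with retrieval the spelling is corrected from "Elenis Moriset" to "Alanis Morissette", and the query returns a valid result.
Another possible approach to this problem is to let an agent decide for itself when to look up proper nouns. You can see an example of this in the [SQL: Agents](/docs/tutorials/agents) guide.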