# School information queries

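Preliminary conclusion: NL2SQL is good enough for some use cases.

To get correct answers from a local model, the prompts have to be tuned so that the model reliably produces the right answer.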

## Setup

In [1]:
%%time

# Load the LLM and the embedding model (get_llm and get_embedding come from ../utils2.py)
%run ../utils2.py

from llama_index.core import Settings

Settings.llm=get_llm()
Settings.embed_model = get_embedding()

CPU times: user 3.4 s, sys: 400 ms, total: 3.8 s
Wall time: 3.57 s


In [2]:
%%time

# Create the database connection and the table
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Create the school_info SQL table
table_name = "school_info"
school_info_table = Table(
    table_name,
    metadata_obj,
    Column("school_name", String(200), primary_key=True),
    Column("students_enrolled", Integer, nullable=False),
)
metadata_obj.create_all(engine)

CPU times: user 8.84 ms, sys: 101 µs, total: 8.94 ms
Wall time: 11.2 ms


In [3]:
%%time

from llama_index.core import SQLDatabase

sql_database = SQLDatabase(engine, include_tables=["school_info"])

CPU times: user 5.05 ms, sys: 130 µs, total: 5.18 ms
Wall time: 4.49 ms


In [4]:
%%time

from sqlalchemy import insert

rows = [
    {"school_name": "北京市第八十中学", "students_enrolled": 260},
    {"school_name": "北京市陈经纶中学", "students_enrolled": 279},
    {"school_name": "北京市日坛中学", "students_enrolled": 403},
    {"school_name": "中国人民大学附属中学朝阳学校", "students_enrolled": 247},
    {"school_name": "北京工业大学附属中学", "students_enrolled": 418},
    {"school_name": "北京中学", "students_enrolled": 121},
]
# Insert all rows in a single transaction
with engine.begin() as connection:
    connection.execute(insert(school_info_table), rows)

CPU times: user 3.12 ms, sys: 0 ns, total: 3.12 ms
Wall time: 2.88 ms


In [5]:
%%time

import pandas as pd

df = pd.read_sql_query("SELECT * from school_info", engine)
df.head()  # head() shows only the first 5 of the 6 rows

CPU times: user 1.54 ms, sys: 0 ns, total: 1.54 ms
Wall time: 1.37 ms


|   | school_name | students_enrolled |
|---|---|---|
| 0 | 北京市第八十中学 | 260 |
| 1 | 北京市陈经纶中学 | 279 |
| 2 | 北京市日坛中学 | 403 |
| 3 | 中国人民大学附属中学朝阳学校 | 247 |
| 4 | 北京工业大学附属中学 | 418 |


In [6]:
%%time

import logging
import sys

# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
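# basicConfig already attaches a stdout handler, so this extra handler makes every
# log record print twice (once with the "INFO:logger:" prefix, once without)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))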

CPU times: user 229 µs, sys: 0 ns, total: 229 µs
Wall time: 232 µs


## Basic queries

### Working query examples

In [7]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["school_info"],
)
query_str = "招生最多的是哪个学校?"
response = query_engine.query(query_str)

response.response

INFO:llama_index.core.indices.struct_store.sql_retriever:> Table desc str: Table 'school_info' has columns: school_name (VARCHAR(200)), students_enrolled (INTEGER), and foreign keys: .
> Table desc str: Table 'school_info' has columns: school_name (VARCHAR(200)), students_enrolled (INTEGER), and foreign keys: .
INFO:httpx:HTTP Request: POST http://192.168.0.72:3000/v1/chat/completions "HTTP/1.1 200 OK"
HTTP Request: POST http://192.168.0.72:3000/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://192.168.0.72:3000/v1/chat/completions "HTTP/1.1 200 OK"
HTTP Request: POST http://192.168.0.72:3000/v1/chat/completions "HTTP/1.1 200 OK"
CPU times: user 353 ms, sys: 23 ms, total: 376 ms
Wall time: 4.44 s


'招生人数最多的学校是北京工业大学附属中学，共有418名学生。'

In [8]:
query_str = "北京中学招多少?"
response = query_engine.query(query_str)

response.response

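'北京市第八十中学的招生人数为260人。'

Note that this answer is wrong: according to the inserted data, 北京中学 enrolls 121 students, but the model answered with 北京市第八十中学 (260).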

### Fixing the lack of fuzzy matching: a custom prompt

#### Works with the local model qwen1.5:7b, not with qwen2-7b

In [10]:
query_str = "陈经纶招多少?"
response = query_engine.query(query_str)

response.response

'对不起，当前查询结果中没有提供关于"陈经纶"招收学生的具体数字。可能需要进一步的信息或者确认学校名称的准确性。'
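The query fails: the model reports that it has no enrollment figure for "陈经纶". The generated SQL is not shown here, so this is an assumption, but a likely cause is an exact comparison such as `school_name = '陈经纶'`, which can never match the full name 北京市陈经纶中学. The sketch below uses Python's built-in `sqlite3` with the same rows as `school_info` to contrast an exact match with a full fuzzy `LIKE '%keyword%'` match, which is what the custom prompt asks the model to generate.

```python
import sqlite3

# In-memory copy of the school_info table built above
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE school_info ("
    "school_name VARCHAR(200) PRIMARY KEY, students_enrolled INTEGER NOT NULL)"
)
conn.executemany(
    "INSERT INTO school_info VALUES (?, ?)",
    [
        ("北京市第八十中学", 260),
        ("北京市陈经纶中学", 279),
        ("北京市日坛中学", 403),
        ("中国人民大学附属中学朝阳学校", 247),
        ("北京工业大学附属中学", 418),
        ("北京中学", 121),
    ],
)

keyword = "陈经纶"

# Exact match: the short name is not a full school name, so no row matches
exact = conn.execute(
    "SELECT school_name, students_enrolled FROM school_info WHERE school_name = ?",
    (keyword,),
).fetchall()

# Full fuzzy match: wrap the keyword in % wildcards on both sides
fuzzy = conn.execute(
    "SELECT school_name, students_enrolled FROM school_info WHERE school_name LIKE ?",
    (f"%{keyword}%",),
).fetchall()

print(exact)
print(fuzzy)
```

Passing the pattern as a bound parameter keeps the keyword out of the SQL text; an LLM-generated query will usually inline it instead, as in the prompt's `school_name like '%keyword%'` example.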

In [11]:
%%time

from llama_index.core import PromptTemplate

prompts = query_engine.get_prompts()
prompts['sql_retriever:text_to_sql_prompt'].template

CPU times: user 143 µs, sys: 41 µs, total: 184 µs
Wall time: 189 µs


'Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\n{schema}\n\nQuestion: {query_str}\nSQLQuery: '

In [12]:
%%time

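old_prompt_str=prompts['sql_retriever:text_to_sql_prompt'].template
# Append two hints (in Chinese) to the default template:
#   "use LIKE with a full fuzzy match for keywords, e.g. school_name like '%keyword%'"
#   "use fuzzy matching when querying school names"
# Note: the default template ends with "SQLQuery: ", so the hints land after that prefix.
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    "查询学校名称时，使用模糊查询"
)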
query_engine.update_prompts({"sql_retriever:text_to_sql_prompt": new_prompt})

CPU times: user 96 µs, sys: 0 ns, total: 96 µs
Wall time: 99.2 µs


In [13]:
query_str = "陈经纶招多少?"
response = query_engine.query(query_str)

response.response

'对不起，当前查询结果中无法提供具体的招生人数。请尝试其他关键词或详细信息进行搜索。'
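Even with the hints, this model still fails. One thing worth checking (an assumption, not something this notebook tested) is where the hints end up: the default template printed above ends with `Question: {query_str}\nSQLQuery: `, so text appended via `f"{old_prompt_str}..."` follows the `SQLQuery: ` prefix, exactly where the model is expected to start writing SQL. The sketch below uses plain strings and a hypothetical helper, `insert_hint_before_question`, to place the hint before the question line instead; the resulting string could then be wrapped in `PromptTemplate` and passed to `update_prompts` in the same way as above.

```python
# Abbreviated copy of the tail of the default text-to-SQL template shown above
old_prompt_str = (
    "Only use tables listed below.\n"
    "{schema}\n\n"
    "Question: {query_str}\n"
    "SQLQuery: "
)

hint = (
    "When filtering by school name, use a full fuzzy match, "
    "e.g. school_name LIKE '%keyword%'.\n\n"
)


def insert_hint_before_question(template: str, hint: str) -> str:
    """Hypothetical helper: put `hint` right before the last 'Question:' line."""
    head, sep, tail = template.rpartition("Question: {query_str}")
    if not sep:
        # Marker not found: fall back to prepending the hint
        return hint + template
    return head + hint + sep + tail


new_prompt_str = insert_hint_before_question(old_prompt_str, hint)

# Fill the placeholders with str.format, as a format-style template would
filled = new_prompt_str.format(
    schema="Table 'school_info' has columns: school_name (VARCHAR(200)), ...",
    query_str="陈经纶招多少?",
)
print(filled)
```

The prompt still ends with `SQLQuery: `, so the model's completion begins with the SQL statement rather than after an appended instruction.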

In [14]:
%%time

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["school_info"],
    llm=get_llm('qwen:14b')
)

old_prompt_str=prompts['sql_retriever:text_to_sql_prompt'].template
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    # "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    # Only the school-name hint is kept ("use fuzzy matching when querying school names")
    "查询学校名称时，使用模糊查询"
)
query_engine.update_prompts({"sql_retriever:text_to_sql_prompt": new_prompt})

query_str = "陈经纶招多少?"
response = query_engine.query(query_str)

response.response

CPU times: user 60.6 ms, sys: 0 ns, total: 60.6 ms
Wall time: 12.9 s


'陈经纶学校招收了279名学生。'

#### Some cloud models work (ERNIE Bot / 文心一言 does not)

In [15]:
%%time

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["school_info"],
    llm=get_llm('qwen-turbo')
    # llm=get_llm('glm-4')
    # llm=get_llm('moonshot-v1-8k')
)

old_prompt_str=prompts['sql_retriever:text_to_sql_prompt'].template
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    # "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    # Only the school-name hint is kept ("use fuzzy matching when querying school names")
    "查询学校名称时，使用模糊查询"
)
query_engine.update_prompts({"sql_retriever:text_to_sql_prompt": new_prompt})

query_str = "陈经纶招多少?"
response = query_engine.query(query_str)

response.response

CPU times: user 61.2 ms, sys: 0 ns, total: 61.2 ms
Wall time: 3.8 s


'陈经纶学校的学生人数为279人。'

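### Limitations of the basic query engine

- No streaming output
- Only suitable as a component in a pipeline or in combination with other components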

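## Queries based on a table index

### The problem with using NLSQLTableQueryEngine directly

NLSQLTableQueryEngine puts the schema of every table it is given into the text-to-SQL prompt, so when there are many tables the prompt can exceed the LLM's context limit.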
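SQLTableRetrieverQueryEngine, used below, avoids this by first retrieving only the table schemas relevant to the question (through an embedding-based `ObjectIndex`) and putting just those into the prompt. As a rough illustration of the idea only, the sketch below ranks tables by shared characters instead of embeddings; the `teacher_info` and `exam_scores` tables and all descriptions are hypothetical.

```python
from typing import Dict, List

# Hypothetical table descriptions; only school_info exists in this notebook
TABLE_DESCRIPTIONS: Dict[str, str] = {
    "school_info": "学校 招生 人数 school_name students_enrolled",
    "teacher_info": "教师 姓名 学科 teacher_name subject",
    "exam_scores": "考试 成绩 分数 exam score",
}


def score(question: str, description: str) -> int:
    """Count distinct characters shared by question and description (a crude stand-in for embeddings)."""
    return len(set(question) & set(description))


def retrieve_tables(question: str, top_k: int = 1) -> List[str]:
    """Return the names of the top_k tables ranked by overlap score."""
    ranked = sorted(
        TABLE_DESCRIPTIONS,
        key=lambda name: score(question, TABLE_DESCRIPTIONS[name]),
        reverse=True,
    )
    return ranked[:top_k]


# Only the selected schemas would go into the text-to-SQL prompt
print(retrieve_tables("招生最多的是哪个学校?"))
```

`top_k` plays the same role as `similarity_top_k` in `obj_index.as_retriever(similarity_top_k=1)`: it caps how many schemas reach the prompt.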

### Querying with a table index

In [16]:
%%time

from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

table_node_mapping = SQLTableNodeMapping(sql_database)
# One schema object per table; SQLTableSchema also accepts a context_str describing the table
table_schema_objs = [SQLTableSchema(table_name="school_info")]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, 
    obj_index.as_retriever(similarity_top_k=1)
)

response = query_engine.query("招生最多的是哪个学校, 招多少?")
response.response

CPU times: user 27.2 ms, sys: 0 ns, total: 27.2 ms
Wall time: 6.58 s


'招生最多的是北京工业大学附属中学，招收了418名学生。'

### SQLTableRetrieverQueryEngine does not support streaming output

In [19]:
%%time

from llama_index.core import get_response_synthesizer

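synth = get_response_synthesizer(streaming=True)

# Note: synthesize_response is a boolean flag, so the synthesizer object passed
# here only counts as True; the engine builds its own synthesizer and returns
# a plain (non-streaming) Response, as the output below shows.
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, 
    obj_index.as_retriever(similarity_top_k=1),
    synthesize_response=synth,
)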

streaming_response = query_engine.query("招生最多的是哪个学校, 招多少?")
# streaming_response.print_response_stream()
# for text in streaming_response.response_gen:
#     print(text,end="")
streaming_response

CPU times: user 17.8 ms, sys: 0 ns, total: 17.8 ms
Wall time: 2.92 s


Response(response='招生人数最多的学校是北京工业大学附属中学，招收了418名学生。', source_nodes=[NodeWithScore(node=TextNode(id_='d689eed9-afe7-4d42-aea6-2338f814c86f', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text="[('北京工业大学附属中学', 418)]", start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)], metadata={'d689eed9-afe7-4d42-aea6-2338f814c86f': {}, 'sql_query': 'SELECT school_name, students_enrolled FROM school_info ORDER BY students_enrolled DESC LIMIT 1;', 'result': [('北京工业大学附属中学', 418)], 'col_keys': ['school_name', 'students_enrolled']})

## Using the SQL retriever on its own and inside a query engine

### Using the SQL retriever on its own

#### Basic usage

In [30]:
%%time

from llama_index.core.retrievers import NLSQLRetriever

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["school_info"], return_raw=True
)

results = nl_sql_retriever.retrieve(
    "招生最多的前三个学校?"
)

results[0].text

CPU times: user 6.72 ms, sys: 343 µs, total: 7.06 ms
Wall time: 1.97 s


"[('北京工业大学附属中学', 418), ('北京市日坛中学', 403), ('北京市陈经纶中学', 279)]"

#### Fuzzy matching

In [21]:
%%time

from llama_index.core.retrievers import NLSQLRetriever

# return_raw=False: one node per result row, with the columns stored in node metadata
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["school_info"], 
    # return_raw=True,
    return_raw=False,
    # llm=get_llm("qwen-turbo")
)

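old_prompt_str=nl_sql_retriever.get_prompts()['text_to_sql_prompt'].template
# Hints (in Chinese):
#   "use LIKE with a full fuzzy match for keywords, e.g. school_name like '%keyword%'"
#   "use fuzzy matching for school names, e.g. school_name like '%keyword%',
#    and include the school name in the query result"
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    "查询学校名称时，使用模糊查询，比如 school_name like '%keyword%'，并且查询结果附带学校名称"
)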
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

results = nl_sql_retriever.retrieve(
    "陈经纶招多少?"
)

results[0]

CPU times: user 7.32 ms, sys: 0 ns, total: 7.32 ms
Wall time: 3.07 s


NodeWithScore(node=TextNode(id_='88df7ba5-4b58-4e2a-b438-394ffb6d0128', embedding=None, metadata={'school_name': '北京市陈经纶中学', 'students_enrolled': 279}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)
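The two `return_raw` modes hand the LLM different context. The node above has empty `text`, the columns in `metadata`, `text_template='{metadata_str}\n\n{content}'` and `metadata_template='{key}: {value}'` joined by `'\n'`; with `return_raw=True` the retriever instead returns a single node whose text is the stringified rows. The sketch below mimics both shapes in plain Python. It is an approximation of how the node text is rendered, not llama_index's own code, and `render_row_node` is a hypothetical helper.

```python
from typing import Any, Dict, List, Tuple

# Result rows for "陈经纶招多少?", and column names as in the col_keys metadata shown earlier
rows: List[Tuple[Any, ...]] = [("北京市陈经纶中学", 279)]
col_keys: List[str] = ["school_name", "students_enrolled"]

# return_raw=True: a single node whose text is str() of all result rows
raw_text = str(rows)


def render_row_node(metadata: Dict[str, Any], content: str = "") -> str:
    """Approximate a node's rendered text: one 'key: value' line per metadata
    entry, a blank line, then the (here empty) content."""
    metadata_str = "\n".join(f"{key}: {value}" for key, value in metadata.items())
    return f"{metadata_str}\n\n{content}".strip()


# return_raw=False: one node per row, columns kept as metadata, empty text
row_texts = [render_row_node(dict(zip(col_keys, row))) for row in rows]

print(raw_text)
print(row_texts[0])
```

In the raw form the number 279 carries no column name; in the per-row form it is labelled `students_enrolled`, which gives the answer-synthesis step more to work with.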

### Passing the SQL retriever to a query engine

#### Basic usage

In [22]:
%%time

from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)

response = query_engine.query(
    "招生最多的前三个学校?"
)

response.response

CPU times: user 14.5 ms, sys: 0 ns, total: 14.5 ms
Wall time: 3.95 s


'招生最多的前三个学校是北京工业大学附属中学，它招收了418名学生；其次是北京市日坛中学，招收了403名学生；排在第三位的是北京市陈经纶中学，招收了279名学生。'

#### Streaming output

In [23]:
%%time

query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True
)

response = query_engine.query(
    "招生最多的前三个学校?"
)
response.print_response_stream()

招生最多的前三个学校是北京工业大学附属中学，它招收了418名学生；其次是北京市日坛中学，招收了403名学生；排在第三位的是北京市陈经纶中学，招收了279名学生。
CPU times: user 86.6 ms, sys: 0 ns, total: 86.6 ms
Wall time: 3.97 s


#### Fuzzy matching with a custom prompt

In [24]:
%%time

nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["school_info"], 
    return_raw=True,
)

old_prompt_str=nl_sql_retriever.get_prompts()['text_to_sql_prompt'].template
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    "查询学校名称时，使用模糊查询，比如 school_name like '%keyword%'，并且查询结果附带学校名称"
)
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

results = nl_sql_retriever.retrieve(
    "陈经纶招多少?"
)

results[0].text

CPU times: user 7.87 ms, sys: 0 ns, total: 7.87 ms
Wall time: 2.9 s


"[('北京市陈经纶中学', 279)]"

In [25]:
%%time

# logging.disable(logging.CRITICAL + 1)

query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True,
)

response = query_engine.query(
    "陈经纶招多少?"
)
response.print_response_stream()

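根据提供的上下文信息，无法直接确定“陈经纶招多少”。需要更多具体的信息来解答这个问题。
CPU times: user 50.5 ms, sys: 0 ns, total: 50.5 ms
Wall time: 3.83 s

The answer says the context is not enough to tell how many students 陈经纶 enrolls, even though the retriever found the right row. With return_raw=True the context is only `[('北京市陈经纶中学', 279)]`, without column names, so the model presumably cannot tell what 279 refers to; the next cell addresses this with a one-shot example in the QA prompt.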


In [26]:
%%time

from llama_index.core.prompts import PromptType

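my_qa_prompt_template = (
    # One-shot example (in Chinese): data [('北京市陈经纶中学', 279)] -> answer
    # "北京市陈经纶中学 enrolls 279 students". The trailing \n keeps the example
    # on its own line, separate from the context block.
    "回答的示例，数据：[('北京市陈经纶中学', 279)]，回答：北京市陈经纶中学招收学生数量为279人。\n"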
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)
my_qa_prompt = PromptTemplate(
    my_qa_prompt_template, prompt_type=PromptType.QUESTION_ANSWER
)

query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True,
    text_qa_template=my_qa_prompt,
)

response = query_engine.query(
    "陈经纶招多少?"
)
response.print_response_stream()

北京市陈经纶中学招收学生数量为279人。
CPU times: user 31 ms, sys: 4.68 ms, total: 35.7 ms
Wall time: 3.61 s
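To see what the LLM receives in the cell above, the QA template can be filled with plain `str.format`, which is how `PromptTemplate` fills its `{...}` placeholders (treat this as an approximation of the final prompt). The template below is a copy of `my_qa_prompt_template` with a newline after the one-shot example, and the context is the raw retriever output shown earlier.

```python
# Copy of my_qa_prompt_template, with a newline after the one-shot example
qa_template = (
    "回答的示例，数据：[('北京市陈经纶中学', 279)]，回答：北京市陈经纶中学招收学生数量为279人。\n"
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)

# With return_raw=True the context is just the stringified result rows
prompt = qa_template.format(
    context_str="[('北京市陈经纶中学', 279)]",
    query_str="陈经纶招多少?",
)
print(prompt)
```

The one-shot line shows the model how to read a bare `(name, number)` tuple, which is the information the raw context lacks.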


#### Setting return_raw=False: one less prompt to write

In [29]:
%%time

nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["school_info"], 
    return_raw=False,
)

old_prompt_str=nl_sql_retriever.get_prompts()['text_to_sql_prompt'].template
new_prompt = PromptTemplate(
    f"{old_prompt_str}"
    "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    "查询学校名称时，使用模糊查询，比如 school_name like '%keyword%'，并且查询结果附带学校名称"
)
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

# Optional: call the retriever directly to inspect the nodes (the result is not used below)
results = nl_sql_retriever.retrieve(
    "陈经纶招多少?"
)

# results

query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True,
)

response = query_engine.query(
    "陈经纶招多少?"
)
response.print_response_stream()

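陈经纶招收279名学生。
CPU times: user 35.6 ms, sys: 0 ns, total: 35.6 ms
Wall time: 4.87 s

This works without the custom QA prompt, presumably because with return_raw=False each row reaches the LLM as node metadata that includes the column names (school_name, students_enrolled) rather than as a bare tuple.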
