# 查询学校信息

## 准备

In [13]:
%%time

# 加载llm和embeddings
%run ../utils2.py

from llama_index.core import Settings

# Settings.llm=get_llm('ERNIE-4.0-8K') # SQL报错，无法连接数据库，总之不可用
# Settings.llm=get_llm('glm-4-flash') # 多处错误，不如qwen
Settings.llm=get_llm('qwen-max') 
# Settings.llm=get_llm('qwen-turbo') # 不能很好的遵循提示，sql显示指定的数据列
# Settings.llm=get_llm('qwen:14b') # 不如qwen2-7b，而且慢
# Settings.llm=get_llm('yi:6b') # 大部分情况不如qwen2-7b，但是最后的sql显示指定的数据列是可以的
# Settings.llm=get_llm("gpt-3.5-turbo") # 通过测试，但有一些是英文
# Settings.llm=get_llm() # qwen2-7b 不能很好的遵循提示，sql显示指定的数据列
Settings.embed_model = get_embedding()

CPU times: user 1.36 ms, sys: 0 ns, total: 1.36 ms
Wall time: 1.38 ms


In [40]:
%%time

import logging
import sys

# Send DEBUG-level logs to stdout through exactly one root handler.
# The previous version called basicConfig(stream=...) and then added a second
# StreamHandler, so every record was printed twice (and once more for each
# re-run of the cell). force=True removes any handlers already attached to the
# root logger before installing the new one, which keeps this cell idempotent
# and also makes the configuration take effect when the host environment has
# pre-installed a root handler.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

CPU times: user 40 µs, sys: 5 µs, total: 45 µs
Wall time: 47 µs


In [14]:
%%time

# Create the database connection and the table.
from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Declare the school_info table: school_name is the primary key and
# students_enrolled is a non-nullable integer.
table_name = "school_info"
school_info_columns = (
    Column("school_name", String(200), primary_key=True),
    Column("students_enrolled", Integer, nullable=False),
)
school_info_table = Table(table_name, metadata_obj, *school_info_columns)

# Create every table registered on metadata_obj in the engine's database.
metadata_obj.create_all(engine)

CPU times: user 2.62 ms, sys: 0 ns, total: 2.62 ms
Wall time: 2.23 ms


In [15]:
%%time

from llama_index.core import SQLDatabase

# Wrap the SQLAlchemy engine for LlamaIndex, listing school_info as the
# only included table.
sql_database = SQLDatabase(
    engine,
    include_tables=["school_info"],
)

CPU times: user 0 ns, sys: 3.37 ms, total: 3.37 ms
Wall time: 2.98 ms


In [16]:
%%time

from sqlalchemy import insert

# Seed data: one dict per school, keyed by the school_info column names.
rows = [
    {"school_name": "北京市第八十中学", "students_enrolled": 260},
    {"school_name": "北京市陈经纶中学", "students_enrolled": 279},
    {"school_name": "北京市日坛中学", "students_enrolled": 403},
    {"school_name": "中国人民大学附属中学朝阳学校", "students_enrolled": 247},
    {"school_name": "北京工业大学附属中学", "students_enrolled": 418},
    {"school_name": "北京中学", "students_enrolled": 121},
]

# Insert all rows in ONE transaction instead of opening a transaction per row.
# Passing the list of dicts lets SQLAlchemy run a single executemany, and
# engine.begin() commits on success or rolls back on error, so a failure can
# no longer leave the table partially populated.
with engine.begin() as connection:
    connection.execute(insert(school_info_table), rows)

CPU times: user 2.03 ms, sys: 970 µs, total: 3 ms
Wall time: 2.59 ms


In [17]:
%%time

import pandas as pd

# Read the school_info table back into a DataFrame to check the inserted rows.
select_all_sql = "SELECT * from school_info"
df = pd.read_sql_query(sql=select_all_sql, con=engine)

# Preview the first rows.
df.head()

CPU times: user 0 ns, sys: 2.09 ms, total: 2.09 ms
Wall time: 1.95 ms


Unnamed: 0,school_name,students_enrolled
0,北京市第八十中学,260
1,北京市陈经纶中学,279
2,北京市日坛中学,403
3,中国人民大学附属中学朝阳学校,247
4,北京工业大学附属中学,418


## 基本查询

In [18]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine

# Natural-language question to be translated into SQL.
query_str = "招生最多的是哪个学校?"

# Query engine over the school_info table only.
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["school_info"],
)

response = query_engine.query(query_str)

# Show the synthesized answer text.
response.response

CPU times: user 56.9 ms, sys: 4.41 ms, total: 61.4 ms
Wall time: 5.91 s


'招生最多的学校是北京工业大学附属中学。'

In [19]:
%%time

# Ask using only part of a school name ("陈经纶"); the table stores the full
# name "北京市陈经纶中学".
question = "陈经纶招多少人?"
response = query_engine.query(question)

response.response

CPU times: user 13.7 ms, sys: 307 µs, total: 14 ms
Wall time: 6.39 s


'对不起，我没有找到关于"陈经纶"学校招收学生人数的相关信息。可能是学校名称有误或者信息暂未录入数据库，请您核实后再次询问。'

## 调整 text_to_sql_prompt

In [20]:
%%time

from llama_index.core import PromptTemplate

prompts = query_engine.get_prompts()

# Extra instruction placed before the base template: keyword lookups should
# use a fully fuzzy LIKE match.
like_hint = "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"

# Base text-to-SQL template with the {dialect}, {schema} and {query_str} slots.
base_template = """\
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\n{schema}\n\nQuestion: {query_str}\nSQLQuery: """

new_prompt = PromptTemplate(like_hint + base_template)

# Replace the text-to-SQL prompt used by the engine's SQL retriever.
prompt_overrides = {"sql_retriever:text_to_sql_prompt": new_prompt}
query_engine.update_prompts(prompt_overrides)

CPU times: user 210 µs, sys: 0 ns, total: 210 µs
Wall time: 213 µs


In [21]:
%%time

# Repeat the partial-name question now that the text-to-SQL prompt was replaced.
question = "陈经纶招多少人?"
response = query_engine.query(question)

response.response

CPU times: user 13.1 ms, sys: 421 µs, total: 13.5 ms
Wall time: 9.26 s


'很抱歉，我没有找到关于"陈经纶"学校招收学生人数的相关信息。可能是学校名称有误或者该数据暂未提供。请确认学校名称或稍后再次尝试。'

## 查询时表检索

In [22]:
%%time

from llama_index.core import VectorStoreIndex
from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine
from llama_index.core.objects import ObjectIndex, SQLTableNodeMapping, SQLTableSchema

table_node_mapping = SQLTableNodeMapping(sql_database)

# One SQLTableSchema per table that should be retrievable at query time.
table_schema_objs = [
    SQLTableSchema(table_name=name) for name in ["school_info"]
]

# Index the table schema objects with a vector store index.
obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

# Retrieve the single most similar table schema for each question.
table_retriever = obj_index.as_retriever(similarity_top_k=1)
query_engine = SQLTableRetrieverQueryEngine(sql_database, table_retriever)

question = "招生最多的是哪个学校, 招多少?"
response = query_engine.query(question)
response.response

CPU times: user 22.1 ms, sys: 0 ns, total: 22.1 ms
Wall time: 7.24 s


'招生最多的学校是北京工业大学附属中学，共招收了418名学生。'

In [23]:
# Ask for the three schools with the highest enrollment.
question = "招生最多的是前三个学校?"
response = query_engine.query(question)

response.response

'招生最多的前三个学校分别是：北京工业大学附属中学（418人），北京市日坛中学（403人），以及北京市陈经纶中学（279人）。'

## 文本到sql检索

In [24]:
%%time

from llama_index.core.retrievers import NLSQLRetriever

# Text-to-SQL retriever over school_info; return_raw=True is passed explicitly.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["school_info"],
    return_raw=True,
)

question = "招生最多的前三个学校?"
results = nl_sql_retriever.retrieve(question)

# Display the retrieved nodes.
results

CPU times: user 9.94 ms, sys: 282 µs, total: 10.2 ms
Wall time: 13.5 s


[NodeWithScore(node=TextNode(id_='b6adbfc4-b5bb-4066-9f4f-ebe477a698f2', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text="[('北京工业大学附属中学', 418), ('北京市日坛中学', 403), ('北京市陈经纶中学', 279)]", start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [25]:
# Retrieve by the full school name and ask for the name to be included in the result.
question = "北京市陈经纶中学招生人数是多少? 结果要包含学校名称"
results = nl_sql_retriever.retrieve(question)

results

[NodeWithScore(node=TextNode(id_='9c14136e-c88e-49fa-a4ba-7613ee6e2cbb', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text="[('北京市陈经纶中学', 279)]", start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [26]:
%%time

from llama_index.core.query_engine import RetrieverQueryEngine

# Build a streaming query engine on top of the text-to-SQL retriever.
query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True,
)

question = "招生最多的三个学校? 要包括招生人数"
response = query_engine.query(question)

# Print the answer as it streams in.
response.print_response_stream()

招生最多的三个学校分别是：北京工业大学附属中学，招生人数为418人；北京市日坛中学，招生人数为403人；以及北京市陈经纶中学，招生人数为279人。CPU times: user 23.4 ms, sys: 3.79 ms, total: 27.1 ms
Wall time: 17.4 s


In [27]:
# Ask for the enrollment of one school by its full name.
question = "北京市日坛中学招生人数是多少"
response = query_engine.query(question)

response.print_response_stream()

对不起，根据我所了解的信息，无法提供北京市日坛中学具体的招生人数。请您直接联系学校获取最新和准确的招生信息。

In [32]:
%%time

# Override the retriever's text-to-SQL prompt.
# Each extra instruction now ends with "\n". Previously the second and third
# fragments had no separator, so implicit string concatenation fused them with
# each other and with "Given an input question, ..." into one run-on line.
# The earlier bare get_prompts() call was dropped: its return value was
# discarded, so it had no visible effect in this cell.
new_prompt = PromptTemplate(
    "关键字的查询使用like做全模糊查询，比如 school_name like '%keyword%'\n"
    "查询结果中始终包括学校名称（school_name）\n"
    "查询结果要包括表的所有列\n"
    """\
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nQuery for all the columns from a specific table.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\n{schema}\n\nQuestion: {query_str}\nSQLQuery: """
)
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

# Retry the partial-name lookup with the new prompt.
results = nl_sql_retriever.retrieve(
    "陈经纶招生人数是多少? "
)

results

CPU times: user 7.51 ms, sys: 0 ns, total: 7.51 ms
Wall time: 4.14 s


[NodeWithScore(node=TextNode(id_='30210c4f-f35b-401a-803d-0c871f67b757', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='[]', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [29]:
%%time

# Rebuild the streaming engine from the retriever and ask with the partial name.
query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever,
    streaming=True,
)

question = "陈经纶招生人数是多少"
response = query_engine.query(question)

response.print_response_stream()

对不起，我没有找到关于陈经纶招生人数的具体信息。建议您直接咨询陈经纶学校或访问其官方网站获取最新和准确的招生信息。CPU times: user 19.1 ms, sys: 7.05 ms, total: 26.1 ms
Wall time: 7.16 s


In [30]:
%%time


# Run the partial-name lookup directly through the retriever.
question = "陈经纶招生人数是多少"
results = nl_sql_retriever.retrieve(question)

results

CPU times: user 7.34 ms, sys: 0 ns, total: 7.34 ms
Wall time: 4.58 s


[NodeWithScore(node=TextNode(id_='e129aaec-827f-4fe9-89f5-3bd72660e205', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='[]', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [31]:
%%time

# Same streaming query as before: recreate the engine and ask with the partial name.
streaming_options = {"streaming": True}
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, **streaming_options)

question = "陈经纶招生人数是多少"
response = query_engine.query(question)

response.print_response_stream()

对不起，我没有找到关于陈经纶招生人数的具体信息。建议您直接咨询陈经纶学校或访问其官方网站获取最新和准确的招生信息。CPU times: user 25 ms, sys: 278 µs, total: 25.3 ms
Wall time: 7.59 s
