# 查询学校信息

## 准备

In [1]:
%%time

# 加载llm和embeddings
%run ../utils2.py

from llama_index.core import Settings

# Settings.llm=get_llm("gpt-3.5-turbo")
Settings.llm=get_llm()
Settings.embed_model = get_embedding()

CPU times: user 3.35 s, sys: 396 ms, total: 3.75 s
Wall time: 3.38 s


In [2]:
%%time

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create city SQL table
table_name = "school_info"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("school_name", String(200), primary_key=True),
    Column("students_enrolled", Integer,nullable=False),
)
metadata_obj.create_all(engine)

CPU times: user 5.03 ms, sys: 3.65 ms, total: 8.68 ms
Wall time: 8.03 ms


In [3]:
%%time

from llama_index.core import SQLDatabase

sql_database = SQLDatabase(engine, include_tables=["school_info"])

CPU times: user 5.17 ms, sys: 142 µs, total: 5.31 ms
Wall time: 4.89 ms


In [4]:
%%time

from sqlalchemy import insert

rows = [
    {"school_name": "北京市第八十中学", "students_enrolled": 260},
    {"school_name": "北京市陈经纶中学", "students_enrolled": 279},
    {"school_name": "北京市日坛中学", "students_enrolled": 403},
    {"school_name": "中国人民大学附属中学朝阳学校", "students_enrolled": 247},
    {"school_name": "北京工业大学附属中学", "students_enrolled": 418},
    {"school_name": "北京中学", "students_enrolled": 121},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

CPU times: user 3.14 ms, sys: 0 ns, total: 3.14 ms
Wall time: 2.79 ms


In [5]:
%%time

import pandas as pd

df = pd.read_sql_query("SELECT * from school_info", engine)
df.head()

CPU times: user 0 ns, sys: 1.89 ms, total: 1.89 ms
Wall time: 1.82 ms


Unnamed: 0,school_name,students_enrolled
0,北京市第八十中学,260
1,北京市陈经纶中学,279
2,北京市日坛中学,403
3,中国人民大学附属中学朝阳学校,247
4,北京工业大学附属中学,418


In [6]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine

query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database, tables=["school_info"],
    # text_to_sql_prompt="school_name 使用 like 查询，比如 SELECT * FROM school_name LIKE '%SCHOOL_NAME_KEYWORD%'"
)
query_str = "招生最多的是哪个学校?"
response = query_engine.query(query_str)

response.response

CPU times: user 347 ms, sys: 19.5 ms, total: 367 ms
Wall time: 2.1 s


'招生人数最多的学校是北京工业大学附属中学，共有418名学生。'

In [7]:
%%time

query_str = "北京市陈经纶中学招多少人?"
response = query_engine.query(query_str)

response.response

CPU times: user 10.2 ms, sys: 3.77 ms, total: 13.9 ms
Wall time: 1.85 s


'北京市陈经纶中学招收的学生人数为279人。'

In [8]:
%%time

query_str = "陈经纶招多少人?"
response = query_engine.query(query_str)

response.response

CPU times: user 13.5 ms, sys: 0 ns, total: 13.5 ms
Wall time: 2.22 s


'对不起，我无法获取关于"陈经纶"学校招收人数的具体信息。可能是因为数据源中没有找到相关记录。建议直接联系该学校以获取最准确的信息。'

In [9]:
%%time

query_str = "名称带陈经纶的学校招多少人?"
response = query_engine.query(query_str)

response.response

CPU times: user 12.8 ms, sys: 0 ns, total: 12.8 ms
Wall time: 1.81 s


'名称中包含“陈经纶”的学校招生人数为279人。'

In [14]:
%%time

from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core import VectorStoreIndex

table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
    (SQLTableSchema(table_name="school_info"))
]  # add a SQLTableSchema for each table

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=1)
)

response = query_engine.query("北京市陈经纶中学招多少人?")
response.response

CPU times: user 21.4 ms, sys: 0 ns, total: 21.4 ms
Wall time: 1.77 s


'北京市陈经纶中学招收的学生人数为279人。'

In [16]:
%%time

response = query_engine.query("陈经纶招多少人?")
response.response

CPU times: user 12.2 ms, sys: 4.28 ms, total: 16.5 ms
Wall time: 1.92 s


'对不起，当前查询结果中无法提供具体的招生人数。请尝试其他搜索条件或稍后再试。'

In [28]:
# Show every prompt template the engine uses; prompts belonging to a
# sub-module carry a key prefix such as "sql_retriever:".
(prompts := query_engine.get_prompts())

{'response_synthesis_prompt': PromptTemplate(metadata={'prompt_type': <PromptType.SQL_RESPONSE_SYNTHESIS_V2: 'sql_response_synthesis_v2'>}, template_vars=['query_str', 'sql_query', 'context_str'], kwargs={}, output_parser=None, template_var_mappings=None, function_mappings=None, template='Given an input question, synthesize a response from the query results.\nQuery: {query_str}\nSQL: {sql_query}\nSQL Response: {context_str}\nResponse: '),
 'sql_retriever:text_to_sql_prompt': PromptTemplate(metadata={'prompt_type': <PromptType.TEXT_TO_SQL: 'text_to_sql'>}, template_vars=['dialect', 'schema', 'query_str'], kwargs={}, output_parser=None, template_var_mappings=None, function_mappings=None, template='Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a sp

In [30]:
# Template that turns the SQL and its result rows into the final answer.
prompts["response_synthesis_prompt"].template

'Given an input question, synthesize a response from the query results.\nQuery: {query_str}\nSQL: {sql_query}\nSQL Response: {context_str}\nResponse: '

In [31]:
# Template the LLM is given to write SQL for a question.
prompts["sql_retriever:text_to_sql_prompt"].template

'Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\n{schema}\n\nQuestion: {query_str}\nSQLQuery: '

In [34]:
%%time

new_prompt = PromptTemplate(
"""\
Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Pay attention to which column is in which table. Also, qualify column names with the table name when needed. You are required to use the following format, each taking one line:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use tables listed below.\n学校名称的查询使用like做模糊查询。\n{schema}\n\nQuestion: {query_str}\nSQLQuery: """
)

query_engine.update_prompts({"sql_retriever:text_to_sql_prompt": new_prompt})

CPU times: user 89 µs, sys: 11 µs, total: 100 µs
Wall time: 102 µs


In [35]:
%%time

response = query_engine.query("陈经纶招多少人?")
response.response

CPU times: user 13.4 ms, sys: 178 µs, total: 13.6 ms
Wall time: 1.58 s


'陈经纶学校招收了279名学生。'

In [36]:
%%time

response = query_engine.query("北京中学招多少人?")
response.response

CPU times: user 12.9 ms, sys: 333 µs, total: 13.2 ms
Wall time: 2.01 s


'北京中学招收的学生人数为121人。'

## 正式查询

In [38]:
%%time

from llama_index.core.retrievers import NLSQLRetriever

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=["school_info"], return_raw=True
)

results = nl_sql_retriever.retrieve(
    "北京中学招多少人?"
)

results

CPU times: user 3.72 ms, sys: 3.48 ms, total: 7.21 ms
Wall time: 1.04 s


[NodeWithScore(node=TextNode(id_='2d1ecac1-1001-4851-97e9-99ba0c8663b9', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='[(121,)]', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [39]:
%%time

from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, streaming=True)

response = query_engine.query(
    "北京中学招多少人?"
)

response.print_response_stream()

The context information does not provide a specific number for how many people Beijing Middle School admits.CPU times: user 38 ms, sys: 2.81 ms, total: 40.8 ms
Wall time: 2.66 s
