# 通过 LlamaIndex 使用自然语言查询数据库

## 准备数据

In [1]:
%%time
%%capture

# 使用 sqlalchemy 创建 sql 表结构和示例数据
!pip install sqlalchemy

CPU times: user 14.7 ms, sys: 7.64 ms, total: 22.3 ms
Wall time: 2.04 s


In [2]:
%%time

# 建立连接和表
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
)

# In-memory SQLite database shared by every cell below.
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Schema of the school information table: the school name is the primary key
# and the enrolment count must not be NULL.
table_name = "school_info"
school_info_columns = [
    Column("school_name", String(200), primary_key=True),
    Column("students_enrolled", Integer, nullable=False),
]
school_info_table = Table(table_name, metadata_obj, *school_info_columns)

# Create every table registered on the metadata object.
metadata_obj.create_all(engine)

CPU times: user 93.4 ms, sys: 16.4 ms, total: 110 ms
Wall time: 108 ms


In [3]:
%%time

from sqlalchemy import insert

# Sample school records to insert (school name -> number of students enrolled).
rows = [
    {"school_name": "北京市第八十中学", "students_enrolled": 260},
    {"school_name": "北京市陈经纶中学", "students_enrolled": 279},
    {"school_name": "北京市日坛中学", "students_enrolled": 403},
    {"school_name": "中国人民大学附属中学朝阳学校", "students_enrolled": 247},
    {"school_name": "北京工业大学附属中学", "students_enrolled": 418},
    {"school_name": "北京中学", "students_enrolled": 121},
]

# Insert all rows in ONE transaction with a single executemany-style call,
# instead of opening and committing a separate transaction per row.
# Either every row is stored or, on error, none is.
with engine.begin() as connection:
    connection.execute(insert(school_info_table), rows)

CPU times: user 1.35 ms, sys: 3.08 ms, total: 4.44 ms
Wall time: 3.63 ms


In [4]:
%%time

# 通过 pandas 连接数据库展示数据

!pip install pandas

import pandas as pd

df = pd.read_sql_query("SELECT * from school_info", engine)
df.head(10)

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
CPU times: user 476 ms, sys: 299 ms, total: 775 ms
Wall time: 2.33 s


Unnamed: 0,school_name,students_enrolled
0,北京市第八十中学,260
1,北京市陈经纶中学,279
2,北京市日坛中学,403
3,中国人民大学附属中学朝阳学校,247
4,北京工业大学附属中学,418
5,北京中学,121


## 最简示例

In [5]:
%%time
%%capture

# 安装所需的库

!pip install llama-index-core
!pip install llama-index-llms-openai-like
!pip install llama-index-embeddings-ollama

CPU times: user 16.7 ms, sys: 28.5 ms, total: 45.2 ms
Wall time: 6.67 s


In [6]:
%%time

from llama_index.core import SQLDatabase

# Wrap the SQLAlchemy engine for LlamaIndex, exposing only the school_info table.
sql_database = SQLDatabase(
    engine,
    include_tables=["school_info"],
)

CPU times: user 1.67 s, sys: 149 ms, total: 1.82 s
Wall time: 1.35 s


In [7]:
%%time

from llama_index.core import Settings
from llama_index.llms.openai_like import OpenAILike
from llama_index.embeddings.ollama import OllamaEmbedding

# Global default chat model, reached through an OpenAI-compatible endpoint.
# Models that failed at NL-to-SQL: qwen:7b, qwen2:1.5b, yi:6b
# Models that succeeded: qwen2:7b, qwen:14b
default_llm_options = dict(
    model="qwen2",
    api_base="http://monkey:11434/v1",
    api_key="ollama",
    is_chat_model=True,
    temperature=0.1,
    request_timeout=60.0,
)
Settings.llm = OpenAILike(**default_llm_options)

# Global default embedding model, served through Ollama.
embedding_options = dict(
    model_name="quentinz/bge-large-zh-v1.5",
    base_url="http://monkey:11434",
    ollama_additional_kwargs={"mirostat": 0},  # "-mirostat N": use Mirostat sampling
)
Settings.embed_model = OllamaEmbedding(**embedding_options)

CPU times: user 1.34 s, sys: 195 ms, total: 1.54 s
Wall time: 1.58 s


### 成功的示例

In [9]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.llms.openai_like import OpenAILike
from llama_index.embeddings.ollama import OllamaEmbedding

# Natural-language question answered over the school_info table. No llm is
# passed here, so the engine relies on the globally configured default.
query_str = "招生最多的是哪个学校?"
query_engine = NLSQLTableQueryEngine(sql_database=sql_database, tables=["school_info"])
response = query_engine.query(query_str)

# Last expression of the cell: show the answer text.
response.response

CPU times: user 14.4 ms, sys: 3.7 ms, total: 18.1 ms
Wall time: 1.38 s


'招生最多的是北京工业大学附属中学，共有418名学生。'

### 不成功的示例

In [10]:
%%time

from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.llms.openai_like import OpenAILike
from llama_index.embeddings.ollama import OllamaEmbedding

# Same question as the successful example, but with the "yi" model passed
# explicitly to the query engine.
yi_llm = OpenAILike(
    model="yi",
    api_base="http://monkey:11434/v1",
    api_key="ollama",
    is_chat_model=True,
    temperature=0.1,
    request_timeout=60.0,
)
query_engine = NLSQLTableQueryEngine(
    sql_database=sql_database,
    tables=["school_info"],
    llm=yi_llm,
)

query_str = "招生最多的是哪个学校?"
response = query_engine.query(query_str)

# Last expression of the cell: show the answer text.
response.response

CPU times: user 81.3 ms, sys: 0 ns, total: 81.3 ms
Wall time: 5.16 s


'Based on the given SQL query "SELECT school_name FROM schools ORDER BY students_num DESC LIMIT 1", it seems that there was an error in retrieving the results. However, based on the information available, the school with the most students is likely to be the one with the highest number of enrolled students. To find this out, you can execute the following corrected SQL query:\n\n"SELECT school_name FROM schools ORDER BY students_num DESC LIMIT 1"\n\nThis will return the name of the school that currently has the largest student enrollment.'

## 实践总结

### 回答的流式输出

In [82]:
%%time

from llama_index.core.retrievers import NLSQLRetriever
from llama_index.core.query_engine import RetrieverQueryEngine

# Retriever restricted to the school_info table, created with return_raw=True.
nl_sql_retriever = NLSQLRetriever(sql_database, tables=["school_info"], return_raw=True)

# streaming=True so the answer can be printed while it is being generated.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, streaming=True)

response = query_engine.query("招生最多的前三个学校?")
response.print_response_stream()
print()  # end the streamed line before the %%time report

招生最多的前三个学校是北京工业大学附属中学、北京市日坛中学和北京市陈经纶中学。
CPU times: user 44.3 ms, sys: 19.4 ms, total: 63.7 ms
Wall time: 2.14 s


### 支持模糊查询

#### 默认情况不支持

In [12]:
%%time

# Ask about a school using only part of its name, with the existing query_engine.
response = query_engine.query("陈经纶招多少?")
response.print_response_stream()
print()  # end the streamed line before the %%time report

无法回答这个问题，因为提供的上下文信息是一个空列表，没有包含任何与“陈经纶招多少”相关的内容。需要具体的数字、单位或者上下文来提供一个准确的答案。CPU times: user 95 ms, sys: 2.31 ms, total: 97.3 ms
Wall time: 1.81 s


#### 自定义提示词实现模糊查询

In [32]:
%%time

from llama_index.core import PromptTemplate

import os
from getpass import getpass

# Never hard-code the API key in the notebook: read it from the QWEN_API_KEY
# environment variable, prompting once (and caching it for later cells) if unset.
# NOTE(review): the key that used to be embedded here was exposed in the
# notebook and should be rotated.
if not os.environ.get("QWEN_API_KEY"):
    os.environ["QWEN_API_KEY"] = getpass("QWEN_API_KEY: ")

# Retriever restricted to the school_info table, using the qwen-turbo model
# through an OpenAI-compatible endpoint, created with return_raw=True.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["school_info"],
    return_raw=True,
    llm=OpenAILike(
        model="qwen-turbo",
        api_base="http://ape:3000/v1",
        api_key=os.environ["QWEN_API_KEY"],
        is_chat_model=True,
        temperature=0.1,
        request_timeout=60.0,
    ),
)

# Extend the retriever's current text-to-SQL prompt with an extra instruction
# (in Chinese) asking for fuzzy keyword matching and for the keyword's column
# to be included in the result.
base_template = nl_sql_retriever.get_prompts()["text_to_sql_prompt"].template
fuzzy_instruction = "查询关键字使用模糊查询, 并且查询结果应包含关键字所属的列"
new_prompt = PromptTemplate(base_template + fuzzy_instruction)
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

results = nl_sql_retriever.retrieve("陈经纶招多少?")

# Last expression of the cell: show the retrieved nodes.
results

CPU times: user 72.2 ms, sys: 0 ns, total: 72.2 ms
Wall time: 1.7 s


[NodeWithScore(node=TextNode(id_='5baee8fa-73df-4133-952a-9a0e3b1b31ae', embedding=None, metadata={'sql_query': "SELECT school_name, students_enrolled FROM school_info WHERE school_name LIKE '%陈经纶%' ORDER BY students_enrolled DESC LIMIT 1;", 'result': [('北京市陈经纶中学', 279)], 'col_keys': ['school_name', 'students_enrolled']}, excluded_embed_metadata_keys=['sql_query', 'result', 'col_keys'], excluded_llm_metadata_keys=['sql_query', 'result', 'col_keys'], relationships={}, text="[('北京市陈经纶中学', 279)]", mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=None)]

In [83]:
%%time

import os
from getpass import getpass

# Never hard-code the API key in the notebook: read it from the QWEN_API_KEY
# environment variable, prompting once (and caching it for later cells) if unset.
# NOTE(review): the key that used to be embedded here was exposed in the
# notebook and should be rotated.
if not os.environ.get("QWEN_API_KEY"):
    os.environ["QWEN_API_KEY"] = getpass("QWEN_API_KEY: ")

# Retriever restricted to the school_info table, using the qwen-turbo model
# through an OpenAI-compatible endpoint, created with return_raw=False.
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["school_info"],
    return_raw=False,
    llm=OpenAILike(
        model="qwen-turbo",
        api_base="http://ape:3000/v1",
        api_key=os.environ["QWEN_API_KEY"],
        is_chat_model=True,
        temperature=0.1,
        request_timeout=60.0,
    ),
)

# Extend the retriever's current text-to-SQL prompt with an extra instruction
# (in Chinese) asking for fuzzy keyword matching and for the keyword's column
# to be included in the result.
base_template = nl_sql_retriever.get_prompts()["text_to_sql_prompt"].template
fuzzy_instruction = "查询关键字使用模糊查询, 并且查询结果应包含关键字所属的列"
new_prompt = PromptTemplate(base_template + fuzzy_instruction)
nl_sql_retriever.update_prompts({"text_to_sql_prompt": new_prompt})

# Streaming query engine built on the retriever with the customised prompt.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, streaming=True)

response = query_engine.query("陈经纶招多少?")
response.print_response_stream()
print()  # end the streamed line before the %%time report

陈经纶招收279名学生。
CPU times: user 96.2 ms, sys: 5.11 ms, total: 101 ms
Wall time: 3.06 s


In [85]:
%%time

from llama_index.core.prompts import PromptType

# Custom question-answering prompt. The two leading instructions (in Chinese)
# ask the model to use the school's full name (school_name) and to treat the
# given value as the answer without recalculating. Each instruction ends with
# "\n" so that it does not run straight into the following sentence.
# {context_str} and {query_str} are the template placeholders.
my_qa_prompt_template = (
    "回答中要求使用学校的完整名称(school_name)\n"
    "不用再计算，给出的就是答案\n"
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)
# Wrap the template string and tag it as a question-answer prompt.
my_qa_prompt = PromptTemplate(my_qa_prompt_template, prompt_type=PromptType.QUESTION_ANSWER)

# Streaming query engine that uses the custom prompt as its text QA template.
query_engine = RetrieverQueryEngine.from_args(
    nl_sql_retriever, streaming=True, text_qa_template=my_qa_prompt
)

response = query_engine.query("陈经纶招多少?")
response.print_response_stream()
print()  # end the streamed line before the %%time report

北京市陈经纶中学招收的学生人数为279名。
CPU times: user 28.4 ms, sys: 15.1 ms, total: 43.5 ms
Wall time: 1.97 s


In [77]:
%%time

# Another partial-name question, using the existing query_engine.
response = query_engine.query("日坛招多少?")
response.print_response_stream()
print()  # end the streamed line before the %%time report

北京市日坛中学招收的学生人数为403名。
CPU times: user 40.6 ms, sys: 1.93 ms, total: 42.6 ms
Wall time: 1.77 s


In [79]:
%%time

# Ask which school enrols the fewest students and how many, using the existing query_engine.
response = query_engine.query("招生最少的学校是哪个? 招多少？")
response.print_response_stream()
print()  # end the streamed line before the %%time report

招生最少的学校是北京中学，招收了121名学生。
CPU times: user 31.2 ms, sys: 16.8 ms, total: 48 ms
Wall time: 2.57 s
