In [66]:
import os
import config.env

In [67]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///FlowerShop.db")

In [68]:
def get_schema(_):
    return db.get_table_info()

def run_query(query):
    return db.run(query)

In [69]:
get_schema(None)

'\nCREATE TABLE "Flowers" (\n\t"ID" INTEGER, \n\t"Name" TEXT NOT NULL, \n\t"Type" TEXT NOT NULL, \n\t"Source" TEXT NOT NULL, \n\t"PurchasePrice" REAL, \n\t"SalePrice" REAL, \n\t"StockQuantity" INTEGER, \n\t"SoldQuantity" INTEGER, \n\t"ExpiryDate" DATE, \n\t"Description" TEXT, \n\t"EntryDate" DATE DEFAULT CURRENT_DATE, \n\tPRIMARY KEY ("ID")\n)\n\n/*\n3 rows from Flowers table:\nID\tName\tType\tSource\tPurchasePrice\tSalePrice\tStockQuantity\tSoldQuantity\tExpiryDate\tDescription\tEntryDate\n1\tRose\tFlower\tFrance\t1.2\t2.5\t100\t10\t2023-12-31\tA beautiful red rose\t2024-08-04\n2\tTulip\tFlower\tNetherlands\t0.8\t2.0\t150\t25\t2023-12-31\tA colorful tulip\t2024-08-04\n3\tLily\tFlower\tChina\t1.5\t3.0\t80\t5\t2023-12-31\tAn elegant white lily\t2024-08-04\n*/'

查询sql

In [70]:
from langchain_core.prompts import ChatPromptTemplate

system_template ="""
你是一个数据库专家,根据给定的输入问题，将其转换成SQL查询语句，只负责生成查询语句，禁止生成新增，修改，删除等语句，这是违法的。直接返回查询语句，不需要添加任何标点符号，不需要返回markdown格式的代码
根据下表结构，编写一个SQL查询来回答用户的问题：{schema}

问题：{question}
SQL查询
"""


prompt = ChatPromptTemplate.from_messages([("system",system_template),("human","input"),])

In [71]:
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough

llm = ChatOpenAI(model="deepseek-chat",base_url=os.environ["DEEPSEEK_OPENAI_API_BASE"],api_key=os.environ["DEEPSEEK_OPENAI_API_KEY"])

sql_response = (RunnablePassthrough.assign(schema=get_schema)
| prompt
| llm
|StrOutputParser())

In [72]:
# query =  sql_response.invoke({"question":"有多少种不同的鲜花？"})
# query

In [76]:
db.run("SELECT COUNT(DISTINCT Name) FROM Flowers")

'[(5,)]'

解答

In [73]:
answer_template ="""
给定一个输入问题和SQL语句，将其转成自然语言答案。

表结构：{schema}
问题：{question}
SQL查询：{query}
SQL响应：{response}

答案
"""


prompt_res = ChatPromptTemplate.from_messages([("system",answer_template),])

In [74]:
full_chain = (RunnablePassthrough.assign(query=sql_response)
              | RunnablePassthrough.assign(schema=get_schema,response=lambda x:db.run(x["query"]))
              |prompt_res
              |llm)

In [75]:
full_chain.invoke({"question":"有多少种不同的鲜花？"})

AIMessage(content='根据SQL查询结果，有5种不同的鲜花。', response_metadata={'token_usage': {'completion_tokens': 11, 'prompt_tokens': 357, 'total_tokens': 368, 'prompt_cache_hit_tokens': 0, 'prompt_cache_miss_tokens': 357}, 'model_name': 'deepseek-chat', 'system_fingerprint': 'fp_7e0991cad4', 'finish_reason': 'stop', 'logprobs': None}, id='run-4e20371a-62d2-4d00-a15d-eaec95bc69e0-0', usage_metadata={'input_tokens': 357, 'output_tokens': 11, 'total_tokens': 368})