In [89]:
import datetime
import re
import uuid
from operator import itemgetter

# ruff: noqa: F401
import pandas as pd
import pymysql
from langchain_community.chat_models import ChatTongyi  #  noqa: F401
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough, RunnableSequence
from langchain_ollama.chat_models import ChatOllama  #  noqa: F401

from citra.mcp.tool import get_connection, sql_outage

# Module-level database connection; execute_sql() opens its cursors on this object.
conn = get_connection()


In [57]:
# Chat model shared by the prompt chains and by consult_database().
model = ChatTongyi(model='qwen-max', temperature=0.7)
# model = ChatOllama(model='qwen2.5:3b', temperature=0.7)
# Tools offered to the model; call_tools() looks them up by name.
tools = [sql_outage]

# The same model with `tools` bound, used by consult_database() for the tool-calling step.
chat_with_tools = model.bind_tools(tools)

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt: choose the output format (markdown, excel, image) for the answer;
# the system text tells the model to default to markdown and to output the format only.
prompt_format = ChatPromptTemplate.from_messages(
    [('system', '根据用户问题，判断数据返回的格式(markdown, excel, image)。如果未说明，默认为markdown。必须只输出格式！'), ('human', '{question}')]
)

# Prompt: classify the question into one of the three categories named in the system text,
# outputting the category only.
prompt_category = ChatPromptTemplate.from_messages(
    [('system', '根据用户问题，判断该数据库查询问题的分类(查询类，操作类，指标分析类)。必须只输出分类！'), ('human', '{question}')]
)
# Prompt: pick the table to query from the two tables listed in the system text.
# NOTE: the system string uses backslash line continuations, so the leading spaces of the
# continuation lines are part of the prompt text sent to the model.
prompt_table = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            '根据用户问题，判断该数据库查询问题的表名。必须只输出表名！\
         t_fault_order_inter,故障工单表；\
         t_event_alarm_inter,停电记录表；',
        ),
        ('human', '{question}'),
    ]
)
# Prompt: write a SQL statement given the `table` name and its `describe` text,
# outputting executable SQL only.
prompt_sql = ChatPromptTemplate.from_messages(
    [
        ('system', '根据表{table}和结构：{describe}，编写sql语句。必须只输出可执行的SQL语句'),
        ('human', '{question}'),
    ]
)


In [155]:
def add_date(question: dict) -> dict:
    """Prefix the question text with today's date.

    Args:
        question: Mapping with a ``'question'`` entry holding the user's text.
            The mapping is updated in place.

    Returns:
        The same mapping object, whose ``'question'`` value now starts with
        today's date in ``YYYY-MM-DD`` form.
    """
    date_text = datetime.date.today().strftime('%Y-%m-%d')
    user_text = question['question']
    question['question'] = f'今天是{date_text},{user_text} '
    return question


def extract_json(message: AIMessage) -> str:
    """Return the body of the first ```sql or ```json fenced block in a model reply.

    Args:
        message: Model reply whose ``content`` is searched for a fenced block.

    Returns:
        The text between the opening ```sql / ```json tag (matched case-insensitively)
        and the closing fence, with surrounding whitespace stripped, or ``''`` when
        the reply contains no such block.
    """
    # DOTALL lets the block span several lines without rewriting its newlines:
    # flattening them to spaces would pull the rest of a query into a `--` line comment.
    fenced = re.search(r'```(?:sql|json)(.*?)```', message.content, re.DOTALL | re.IGNORECASE)
    return fenced.group(1).strip() if fenced else ''


def execute_sql(sql: str) -> pd.DataFrame:
    """Run a SQL statement on the shared connection and return its rows as a DataFrame.

    Args:
        sql: SQL text, executed as-is (no parameter binding is applied).

    Returns:
        A DataFrame with one column per result-set column and missing values
        replaced by ``''``. An empty DataFrame when the statement produces no
        result set.

    Raises:
        ValueError: If pymysql raises ``ProgrammingError`` while executing ``sql``.
    """
    with conn.cursor() as cursor:
        try:
            cursor.execute(sql)
        except pymysql.err.ProgrammingError as e:
            raise ValueError('查询过程出错，请重试') from e
        # A DB-API cursor leaves `description` as None when the statement returned
        # no result set; there are no columns to build a frame from in that case.
        if cursor.description is None:
            return pd.DataFrame()
        results = cursor.fetchall()
        cols = [desc[0] for desc in cursor.description]
    return pd.DataFrame(results, columns=cols).fillna('')


In [None]:
question = '今年5月15日后龙港供电公司的停电信息'


def _describe_table(inputs: dict) -> str:
    """Return the text of ``describe <table>`` for the table name chosen by the model.

    The name is model output that gets interpolated into SQL, so it is trimmed and
    must consist of word characters only before it is used.

    Raises:
        ValueError: If the model's answer is not a bare table name.
    """
    table = inputs['table'].strip().strip('`')
    if not re.fullmatch(r'\w+', table):
        raise ValueError(f'Unexpected table name from model: {table!r}')
    return execute_sql(f'describe {table}').to_string()


# Classify the question (output format and category) in parallel, then prefix the
# question text with today's date via add_date.
chain_query = (
    RunnableParallel(
        question=RunnablePassthrough(),
        format=prompt_format | model | StrOutputParser(),
        category=prompt_category | model | StrOutputParser(),
    )
    | add_date
)


# Pick a table, attach its `describe` text, then ask the model for SQL and keep
# only the fenced code block of its reply.
chain_analyse = RunnableSequence(
    {'question': RunnablePassthrough(), 'table': prompt_table | model | StrOutputParser()},
    RunnablePassthrough.assign(describe=_describe_table),
    RunnablePassthrough.assign(sql=prompt_sql | model | extract_json),
)


In [None]:
def call_tools(ai_msg: AIMessage) -> ToolMessage:
    """Invoke the tool requested by the first tool call of ``ai_msg``.

    Only ``ai_msg.tool_calls[0]`` is handled. The tool is looked up by name among
    the module-level ``tools`` and invoked with that tool call; the result of the
    invocation is returned.
    """
    registry = {candidate.name: candidate for candidate in tools}
    first_call = ai_msg.tool_calls[0]
    return registry[first_call['name']].invoke(first_call)


def str_to_gen(s: str, *, msg_type: str = 'msg', chunk_size: int = 10):
    """Yield ``s`` piece by piece as message dicts.

    Args:
        s: Text to split.
        msg_type: Value stored under ``'type'`` in every yielded dict.
        chunk_size: Maximum number of characters per piece.

    Yields:
        ``{'type': msg_type, 'content': piece}`` for each consecutive piece of ``s``.
    """
    offsets = range(0, len(s), chunk_size)
    yield from ({'type': msg_type, 'content': s[start : start + chunk_size]} for start in offsets)


In [None]:
def _render_result(df_res: pd.DataFrame, return_type: str) -> dict:
    """Turn a query result into one stream chunk of the requested format.

    Args:
        df_res: Query result to render.
        return_type: ``'excel'`` or ``'image'``; any other value renders markdown.

    Returns:
        ``{'type': ..., 'content': ...}`` where content is the markdown table, or
        the ``/cache/...`` path of the exported ``.xlsx`` / ``.png`` file.
    """
    if return_type == 'excel':
        tab_name = uuid.uuid4().hex
        df_res.to_excel(f'citra/service/cache/{tab_name}.xlsx', index=False)
        return {'type': 'excel', 'content': f'/cache/{tab_name}.xlsx'}
    if return_type == 'image':
        # Imported here because it is only used for image output.
        import dataframe_image as dfi

        df_name = uuid.uuid4().hex
        dfi.export(df_res, f'citra/service/cache/{df_name}.png')
        return {'type': 'image', 'content': f'/cache/{df_name}.png'}
    return {'type': 'table', 'content': df_res.to_markdown(index=False)}


def consult_database(question: str):
    """Answer a question from the database and stream the answer as chunks.

    The question is classified and date-stamped by ``chain_query``, then sent to
    the tool-calling model. If the model calls a tool, the tool's output is run as
    SQL and the rows are returned in the format chosen during classification.

    Args:
        question: The user's question.

    Yields:
        Dicts with ``'type'`` and ``'content'`` keys: ``'msg'`` chunks of text,
        followed (when rows were found) by one ``'table'``, ``'excel'`` or
        ``'image'`` chunk.
    """
    detail_ques = chain_query.invoke(question)
    dated_question = detail_ques['question']
    messages: list[BaseMessage] = [HumanMessage(content=dated_question)]

    ai_msg = chat_with_tools.invoke(dated_question)
    if not ai_msg.tool_calls:  # type: ignore
        print('没有调用工具')
        yield from str_to_gen(ai_msg.content, chunk_size=5)  # type: ignore
        return

    print('调用工具')
    query: str = call_tools(ai_msg).content  # type: ignore
    print(query)
    df_res = execute_sql(query)
    if df_res.empty:
        yield {'type': 'msg', 'content': '没有查询到数据,请补充问题'}
        return

    # Only the first 100 characters of the result are shown to the model for the summary.
    messages.append(AIMessage(content=df_res.to_string(index=False)[:100]))
    res_describe = model.invoke([SystemMessage('根据用户问题，介绍以下从数据库获得结果的介绍。20字以内')] + messages)
    yield from str_to_gen(res_describe.content)

    # The format comes from chain_query's classification. It is free-form model text,
    # so match it loosely and fall back to markdown for anything unrecognised.
    raw_format = str(detail_ques.get('format') or '').strip().lower()
    if 'excel' in raw_format:
        return_type = 'excel'
    elif 'image' in raw_format:
        return_type = 'image'
    else:
        return_type = 'markdown'
    print(return_type)
    yield _render_result(df_res, return_type)


In [6]:
# Print each streamed chunk's content back to back, without separators.
for chunk in consult_database('今年5月15日后龙港供电公司的停电信息'):
    print(chunk['content'], end='')

调用工具
SELECT equipName , lineName ,faultType , occurTime,endTime ,gdsName ,unitName FROM t_event_alarm_inter WHERE unitName = '龙港供电公司' AND occurTime >= '2025-05-15' AND occurTime < '2025-06-05' ORDER BY occurTime DESC
查询2025年5月15日后龙港供电公司的停电记录。markdown
| equipName                             | lineName   | faultType   | occurTime           | endTime             | gdsName    | unitName     |
|:--------------------------------------|:-----------|:------------|:--------------------|:--------------------|:-----------|:-------------|
| 河尾1#公变                            | 海联S907线 | 公变停电    | 2025-05-26 13:13:00 | NaT                 | 新港供电所 | 龙港供电公司 |
| 月星13#公变                           | 西岩S330线 | 公变停电    | 2025-05-26 12:33:03 | 2025-05-26 13:19:26 | 港城供电所 | 龙港供电公司 |
| 温州浙瓯工业智造有限公司14            | 世丰S227线 | 专变停电    | 2025-05-26 09:22:22 | 2025-05-26 09:41:09 | 新港供电所 | 龙港供电公司 |
| 月星4#公变                            | 西岩S330线 | 公变停电    | 2025-05-26 06:31:00 | 2025-05-26 06:49:44 | 港城供电所 | 龙港供电公司 |