# 制作一个自己会选择工具的

##  设置下环境

In [3]:
# 使用 %cd 魔法命令更改工作目录
import os
%cd .\\
# 打印当前工作目录以确认更改
print(f"当前工作目录: {os.getcwd()}")

[WinError 2] 系统找不到指定的文件。: '.\\ # 打印当前工作目录以确认更改'
D:\code\law_llama_system
当前工作目录: D:\code\law_llama_system


In [2]:

import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

In [4]:
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine import TransformQueryEngine
from model import DeepseekAi
from llama_index.core.tools import QueryEngineTool,ToolMetadata
from llama_index.vector_stores.chroma import ChromaVectorStore

In [6]:
from init_model import get_embedding
from config import Config

In [7]:
llm_model = DeepseekAi(**Config.get_openai_like_config())
embedding_model = get_embedding()

In [8]:
Settings.llm = llm_model
Settings.embed_model = embedding_model

In [9]:
import chromadb
db  = chromadb.PersistentClient(""".\\LawDb""")
law_db = db.get_collection("law")
law_collection = ChromaVectorStore(chroma_collection=law_db)
law_index = VectorStoreIndex.from_vector_store(vector_store=law_collection,embed_model=embedding_model)

INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.


In [10]:
l_e = law_index.as_query_engine()

In [11]:
l_e.query("""偷税漏税的定义。""")

INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"


Response(response='偷税漏税是指纳税人故意不缴或少缴应纳税款的行为。这种行为包括但不限于编造虚假计税依据、不进行纳税申报、欠缴应纳税款并采取转移或隐匿财产的手段妨碍税务机关追缴税款等。对于偷税漏税行为，税务机关有权追缴其不缴或少缴的税款、滞纳金，并处以相应罚款；构成犯罪的，还将依法追究刑事责任。', source_nodes=[NodeWithScore(node=TextNode(id_='3b213d42-0671-4e07-83fe-8338761fc958', embedding=None, metadata={'file_path': 'E:\\law_llama_system\\Law-Book\\5-经济法\\税收征收管理法（2015-04-24）.md', 'file_name': '税收征收管理法（2015-04-24）.md', 'file_size': 33879, 'creation_date': '2024-09-12', 'last_modified_date': '2024-01-22'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='3d34aacc-6941-46e1-a703-66da8b661575', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'E:\\law_llama_system\\Law-Book\\5-经济法\\税收征收管理法（2015-04-24）.md', 'file_name': '税收征收管理法（2015

In [12]:
law_engine = law_index.as_query_engine(similarity_top_k=3)

In [13]:
hyde = HyDEQueryTransform(include_original=True)

In [14]:
law_hyde_engine = TransformQueryEngine(law_engine, hyde)

In [15]:
law_hyde_engine.query("""偷税漏税的定义""")

INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"


Response(response='偷税漏税是指纳税人采取欺骗、隐瞒手段进行虚假纳税申报或者不申报，逃避缴纳税款的行为。这种行为可能导致税务机关无法正常征收应纳税款，从而影响国家税收的正常秩序。', source_nodes=[NodeWithScore(node=TextNode(id_='3b213d42-0671-4e07-83fe-8338761fc958', embedding=None, metadata={'file_path': 'E:\\law_llama_system\\Law-Book\\5-经济法\\税收征收管理法（2015-04-24）.md', 'file_name': '税收征收管理法（2015-04-24）.md', 'file_size': 33879, 'creation_date': '2024-09-12', 'last_modified_date': '2024-01-22'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='3d34aacc-6941-46e1-a703-66da8b661575', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'E:\\law_llama_system\\Law-Book\\5-经济法\\税收征收管理法（2015-04-24）.md', 'file_name': '税收征收管理法（2015-04-24）.md', 'file_size': 33879, 'creation_date': '2024-09-

In [16]:
query_engine_tools=[
    QueryEngineTool(
        query_engine=law_hyde_engine,
        metadata = ToolMetadata(
            name="law_tools",
            description="""
             用户询问关于法律的任何信息时，调用此工具。
            """
        )
    ),    
]

In [17]:
from llama_index.core.agent.react.types import (
    ObservationReasoningStep
)
from llama_index.core.agent import Task, AgentChatResponse

from llama_index.core.query_pipeline import (
    AgentInputComponent, AgentFnComponent
)
from typing import Dict, Any, List

In [18]:
def agent_input_fn(task:Task,state:Dict[str,Any]) -> Dict[str,Any]:
    if "current_reasoning" not in state:
        state["current_reasoning"] = []
    reasoning_step = ObservationReasoningStep(observation=task.input)
    state["current_reasoning"].append(reasoning_step)
    return {"input":task.input}

agent_input_component = AgentInputComponent(fn=agent_input_fn)

In [19]:
from llama_index.core.agent import ReActChatFormatter
from llama_index.core.tools import BaseTool

In [20]:
def react_prompt_fn(task:Task,state:Dict[str,Any],input:str,tools:List[BaseTool]) -> Dict[str,Any]:
    chat_formatter= ReActChatFormatter()
    return chat_formatter.format(
        tools,
        chat_history=task.memory.get() + state["memory"].get_all(),
        current_reasoning=state["current_reasoning"],
    )

react_prompt_component = AgentFnComponent(fn=react_prompt_fn)

In [21]:

from llama_index.core.agent.react.output_parser import ReActOutputParser
from llama_index.core.llms import ChatResponse
from llama_index.core.agent.types import Task


In [22]:
def parse_react_output_fn(
    task: Task, state: Dict[str, Any], chat_response: ChatResponse
):
    """Parse ReAct output into a reasoning step."""
    output_parser = ReActOutputParser()
    reasoning_step = output_parser.parse(chat_response.message.content)
    return {"done": reasoning_step.is_done, "reasoning_step": reasoning_step}

parse_react_output = AgentFnComponent(fn=parse_react_output_fn)

In [23]:
parse_react_output

AgentFnComponent(partial_dict={}, fn=<function parse_react_output_fn at 0x0000026F5F554CA0>, async_fn=None)

In [24]:
from llama_index.core.query_pipeline import ToolRunnerComponent
from llama_index.core.agent.react.types import ActionReasoningStep


def run_tool_fn(
        task: Task, state: Dict[str, Any], reasoning_step: ActionReasoningStep
):
    """Run tool and process tool output."""
    tool_runner_component = ToolRunnerComponent(
        query_engine_tools, callback_manager=task.callback_manager
    )
    tool_output = tool_runner_component.run_component(
        tool_name=reasoning_step.action,
        tool_input=reasoning_step.action_input,
    )
    observation_step = ObservationReasoningStep(
        observation=str(tool_output["output"])
    )

    state["current_reasoning"].append(observation_step)
    # TODO: get output

    print(state)

    return {"response_str": observation_step.get_content(), "is_done": False}


run_tool = AgentFnComponent(fn=run_tool_fn)

In [25]:
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.agent.react.types import ResponseReasoningStep


def process_response_fn(
        task: Task, state: Dict[str, Any], response_step: ResponseReasoningStep
):
    """Process response."""
    state["current_reasoning"].append(response_step)
    response_str = response_step.response
    # Now that we're done with this step, put into memory
    state["memory"].put(ChatMessage(content=task.input, role=MessageRole.USER))
    state["memory"].put(
        ChatMessage(content=response_str, role=MessageRole.ASSISTANT)
    )

    return {"response_str": response_str, "is_done": True}


process_response = AgentFnComponent(fn=process_response_fn)

In [26]:
def process_agent_response_fn(
        task: Task, state: Dict[str, Any], response_dict: dict
):
    """Process agent response."""
    return (
        AgentChatResponse(response_dict["response_str"]),
        response_dict["is_done"],
    )


process_agent_response = AgentFnComponent(fn=process_agent_response_fn)

In [27]:
from llama_index.core.query_pipeline import QueryPipeline as QP

qp = QP(verbose=True)

qp.add_modules(
    {
        "agent_input": agent_input_component,
        "react_prompt": react_prompt_component,
        "llm": llm_model,
        "react_output_parser": parse_react_output,
        "run_tool": run_tool,
        "process_response": process_response,
        "process_agent_response": process_agent_response,
    }
)

In [28]:
qp.add_chain(["agent_input", "react_prompt", "llm", "react_output_parser"])

# add conditional link from react output to tool call (if not done)
qp.add_link(
    "react_output_parser",
    "run_tool",
    condition_fn=lambda x: not x["done"],
    input_fn=lambda x: x["reasoning_step"],
)
# add conditional link from react output to final response processing (if done)
qp.add_link(
    "react_output_parser",
    "process_response",
    condition_fn=lambda x: x["done"],
    input_fn=lambda x: x["reasoning_step"],
)

# whether response processing or tool output processing, add link to final agent response
qp.add_link("process_response", "process_agent_response")
qp.add_link("run_tool", "process_agent_response")

In [29]:
from llama_index.core.agent import QueryPipelineAgentWorker
from llama_index.core.callbacks import CallbackManager
agent_worker = QueryPipelineAgentWorker(qp)
agent = agent_worker.as_agent(
    callback_manager=CallbackManager([]), verbose=True
)

  agent_worker = QueryPipelineAgentWorker(qp)


In [30]:
task = agent.create_task(
    "杀人犯什么法律？"
)

In [31]:
step_output = agent.run_step(task.task_id)

> Running step 836a3dde-98bb-4067-bf6f-f5d40bbef71b. Step input: 杀人犯什么法律？
[1;3;38;2;155;135;227m> Running module agent_input with input: 
state: {'sources': [], 'memory': ChatMemoryBuffer(chat_store=SimpleChatStore(store={}), chat_store_key='chat_history', token_limit=3000, tokenizer_fn=functools.partial(<bound method Encoding.encode of <Encod...
task: task_id='272ad52e-a408-442d-b55b-ce28c8dfc168' input='杀人犯什么法律？' memory=ChatMemoryBuffer(chat_store=SimpleChatStore(store={}), chat_store_key='chat_history', token_limit=3000, tokenizer_fn=functools.pa...

[0m

ValueError: Module input keys must have exactly one key if dest_key is not specified. Remaining keys: in module: {'input', 'tools'}

In [76]:
qp.clean_dag

<networkx.classes.multidigraph.MultiDiGraph at 0x2abe394a820>