# Tool Agent

In [1]:
!pip install python-dotenv langchain-community wikipedia




[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip





In [2]:
import os

from dotenv import load_dotenv

load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

In [3]:
import enum
import logging
from typing import Optional, List

from app.core.agents.base import AgentBase
from app.core.llm.generator import Generator
from app.core.memory.memory import Memory
from app.core.persona import Persona
from app.core.states.base import StateBase, Transition
from app.core.tools.adapter import ToolAdapter
from app.states.defaults import text_handler, tools_handler
from app.states.final_answer.state import FinalAnswerState
from app.states.rewrite_question.state import RewriteQuestionState

logger = logging.getLogger(__name__)


class States(str, enum.Enum):
    REWRITE_QUESTION = RewriteQuestionState.name
    SEARCH = "search"
    SEARCH_FILTER = "search_filter"
    FINAL_ANSWER = "final_answer"


class ConversationalSearchAgent(AgentBase):
    name: str = "SearchAgent"
    description: str = "Specializes in searching for information."

    def after_generation(self, response: str, memory: Memory, tools: Optional[List[ToolAdapter]]) -> Transition:
        return tools_handler(
            response=response,
            memory=memory,
            tools=tools,
            next_state=States.SEARCH_FILTER,
            save_data_key="search_results"
        )


class SearchState(StateBase):
    name: str = States.SEARCH

    def build_prompt(self, persona: Persona, memory: Memory, tools: Optional[List[ToolAdapter]]) -> str:
        prompt = f'''{persona.prompt()}

Search for information using your tools to help solve the following problem.

Current Problem:
"""
{memory.data.get_current_message().content}
"""

Previous Notes:
"""
{memory.scratch_pad.prompt()}
"""

Here are the schemas for the tools you have access to, pick only one:
"""
{[tool.schema() for tool in tools]}
"""

Respond with the JSON input for the tool of your choice to best solve the problem.'''
        return prompt

    def after_generation(self, response: str, memory: Memory, tools: Optional[List[ToolAdapter]]) -> Transition:
        return tools_handler(
            response=response,
            memory=memory,
            tools=tools,
            next_state=States.SEARCH_FILTER,
            save_data_key="search_results"
        )


class SearchFilterState(StateBase):
    name: str = States.SEARCH_FILTER

    def build_prompt(self, persona: Persona, memory: Memory, _: Optional[List[ToolAdapter]]) -> str:
        search_results = memory.data.pop("search_results")
        prompt = f'''{persona.prompt()}

Given the search results, filter out unrelated information that doesn't directly answer the problem. Return a summarized version of the search results with the most important details for the problem.

Current Problem:
"""
{memory.data.get_current_message().content}
"""

Search Results:
"""
{search_results}
"""

Give your summarized search results with the most important details for the problem.'''
        return prompt

    def after_generation(self, response: str, memory: Memory, _: Optional[List[ToolAdapter]]) -> Transition:
        return text_handler(
            response=response,
            memory=memory,
            next_state=States.FINAL_ANSWER
        )

In [4]:
from app.llm.openai.service import OpenAIService

llm = OpenAIService(api_key=api_key)

In [5]:
from app.states.rewrite_question.state import RewriteQuestionState
import json

from app.core.messages import Query
from app.tools.langchain.wrapper import LangChainToolWrapper
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

user_name = "user"
persona = Persona(description="You are a helpful assistant specializing in searching and filtering information.")

api_wrapper = WikipediaAPIWrapper(
    top_k_results=3,
    doc_content_chars_max=3000,
    load_all_available_meta=True
)
tools = [
    LangChainToolWrapper.create(WikipediaQueryRun(api_wrapper=api_wrapper))
]

rewrite_question_state = RewriteQuestionState(
    next_state=States.SEARCH,
    user_name=user_name,
    generator=Generator(service=llm, use_json_model=False, temperature=0.1),
)
search_state = SearchState(
    generator=Generator(service=llm, use_json_model=False, temperature=0.1),
    tools=tools
)
search_filter_state = SearchFilterState(
    generator=Generator(service=llm, use_json_model=False, temperature=0.1),
    tools=tools
)
final_answer_state = FinalAnswerState(
    generator=Generator(service=llm, temperature=0.3)
)

states = [rewrite_question_state, search_state, search_filter_state, final_answer_state]
agent = ConversationalSearchAgent.start(
    persona=persona,
    memory=Memory(),
    states=states,
    default_initial_state=States.REWRITE_QUESTION,
    clear_scratch_pad_after_answer=False,
    clear_data_after_answer=False,
    step_limit=6,
    step_limit_state_name=States.FINAL_ANSWER
)

future = agent.ask(Query(initial_state=States.REWRITE_QUESTION,
                         goal="Who is the CEO of Microsoft?", from_caller=user_name))
response = future.get(timeout=30)
print(f"Steps: {json.dumps([step.model_dump() for step in response.metadata['steps']], indent=2)}")
print(
    f"Total tokens: {sum((step.token_usage.total_tokens if step.token_usage is not None else 0) for step in response.metadata.get('steps', []))}")
print(response.final_output)

future = agent.ask(Query(initial_state=States.REWRITE_QUESTION,
                         goal="Where are they from?", from_caller=user_name))
response = future.get(timeout=30)
print(f"Steps: {json.dumps([step.model_dump() for step in response.metadata['steps']], indent=2)}")
print(
    f"Total tokens: {sum((step.token_usage.total_tokens if step.token_usage is not None else 0) for step in response.metadata.get('steps', []))}")
print(response.final_output)

agent.stop()

Steps: [
  {
    "state_name": "rewrite_question",
    "next_state": "States.SEARCH",
    "prompt": null,
    "output": null,
    "token_usage": null
  },
  {
    "state_name": "States.SEARCH",
    "next_state": "States.SEARCH_FILTER",
    "prompt": "You are a helpful assistant specializing in searching and filtering information.\n\nSearch for information using your tools to help solve the following problem.\n\nCurrent Problem:\n\"\"\"\nWho is the CEO of Microsoft?\n\"\"\"\n\nPrevious Notes:\n\"\"\"\n\nREWRITE_QUESTION: None\n\"\"\"\n\nHere are the schemas for the tools you have access to, pick only one:\n\"\"\"\n[{'name': 'wikipedia', 'description': 'A wrapper around Wikipedia. Useful for when you need to answer general questions about people, places, companies, facts, historical events, or other subjects. Input should be a search query.', 'title': 'wikipediaSchema', 'type': 'object', 'properties': {'query': {'title': 'Query', 'type': 'string'}, 'tool_name': {'title': 'Name', 'type': 

True

In [6]:
!pip install mermaid-py




[notice] A new release of pip is available: 23.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
import mermaid as md
from mermaid.graph import Graph

graph: Graph = Graph('example', """
graph TD;
    goal --> search
    search --> final_answer
    final_answer --> exit

    classDef stateNode fill:#fff,stroke:#333,stroke-width:2px,color:#000;
""")
graphe: md.Mermaid = md.Mermaid(graph)
graphe