# Basic Multi-agent Collaboration

A single agent can usually operate effectively with a handful of tools within a single domain, but even with powerful models like `gpt-4`, it tends to be less effective when it has to use many tools.

One way to approach complicated tasks is "divide and conquer": create a specialized agent for each task or domain and route tasks to the correct "expert".

This notebook (inspired by the paper [AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation](https://arxiv.org/abs/2308.08155), by Wu et al.) shows one way to do this using LangGraph.

The resulting graph will look something like the following diagram:

![multi_agent diagram](./img/simple_multi_agent_diagram.png)

Before we get started, a quick note: this and other multi-agent notebooks are designed to show _how_ you can implement certain design patterns in LangGraph. If the pattern suits your needs, we recommend combining it with some of the other fundamental patterns described elsewhere in the docs for best performance.

```python
# %pip install -U langchain langchain_openai langchain_community langchain_experimental langgraph langsmith pandas matplotlib tavily-python
```

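```python
# See also: https://python.langchain.com/cookbook

import getpass
import os


def _set_if_undefined(var: str) -> None:
    # Prompt for the key only if it is not already set in the environment
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"Please provide your {var}: ")


_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("LANGCHAIN_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

# Optional: add tracing in LangSmith
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Multi-agent Collaboration"
```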

## Create Agents

The following helper function creates agents. These agents will then become nodes in the graph.

You can skip ahead if you just want to see what the graph looks like.

```python
import json

from langchain.tools.render import format_tool_to_openai_function
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation


def create_agent(llm, tools, system_message: str):
    """Create an agent: a prompt piped into an LLM that can call `tools`."""
    functions = [format_tool_to_openai_function(t) for t in tools]

    # Note: this variant keeps the instructions in a "user" turn and renders the
    # conversation history into a single string via "{messages}". With models whose
    # chat template accepts a system role, ("system", ...) plus
    # MessagesPlaceholder(variable_name="messages") keeps the history as structured messages.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK, another assistant with different tools"
                " will help where you left off. Execute what you can to make progress."
                " Double-check that all the code is complete and runnable; make sure nothing is missing."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}. Use them to gather data.\n{system_message}"
            ),
            ("system", "{messages}"),
        ]
    )
    # Pre-fill the placeholders that do not change between calls
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
    return prompt | llm.bind_functions(functions)
```
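To see what the agent actually receives, the sketch below mimics how `prompt.partial(...)` pre-fills the `{tool_names}` and `{system_message}` placeholders, using plain `str.format` instead of LangChain. The helper name `render_system_prompt` and the shortened template are illustrative assumptions, not part of LangChain's API.

```python
# Hypothetical stand-in for ChatPromptTemplate.partial(): fill the fixed
# placeholders up front; the conversation is supplied later at invoke time.
TEMPLATE = (
    "You are a helpful AI assistant, collaborating with other assistants."
    " If you or any of the other assistants have the final answer or deliverable,"
    " prefix your response with FINAL ANSWER so the team knows to stop."
    " You have access to the following tools: {tool_names}.\n{system_message}"
)


def render_system_prompt(tool_names: list, system_message: str) -> str:
    """Return the instructions with both placeholders filled in."""
    return TEMPLATE.format(
        tool_names=", ".join(tool_names),
        system_message=system_message,
    )


prompt = render_system_prompt(
    ["tavily_search_results_json"],
    "You should provide accurate data for the chart generator to use.",
)
print(prompt)
```

Filling the constant parts once means each agent call only has to supply the message history.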

## Define tools

Next, we define the tools our agents will use.

```python
from typing import Annotated

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL

# Web search tool (requires TAVILY_API_KEY)
tavily_tool = TavilySearchResults(max_results=5)

# Warning: This executes code locally, which can be unsafe when not sandboxed
repl = PythonREPL()


@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."]
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user.
    To display a chart, call `plt.show()`."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    return f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
```
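The `python_repl` tool only reports what the executed code prints to stdout. The sketch below reproduces that execute-and-capture pattern with the standard library (`exec` plus `contextlib.redirect_stdout`); `run_python` is a hypothetical helper, and it is no safer than `PythonREPL`: it is not a sandbox.

```python
import contextlib
import io


def run_python(code: str) -> str:
    """Execute `code` and report its captured stdout or the error (hypothetical helper)."""
    buffer = io.StringIO()
    try:
        # Anything the code prints is redirected into `buffer`
        with contextlib.redirect_stdout(buffer):
            exec(code, {})  # fresh globals dict; NOT a sandbox
    except Exception as e:
        return f"Failed to execute. Error: {e!r}"
    return f"Successfully executed:\n{code}\nStdout: {buffer.getvalue()}"


print(run_python("x = 333\ny = 444\nprint(x + y)"))
```

Only `print(...)` output reaches the agent, which is why the tool's docstring tells the model to print values it wants to see. This sketch catches `Exception` rather than `BaseException`, so `KeyboardInterrupt` still propagates.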

## Create graph

Now that we've defined our tools and made some helper functions, we will create the individual agents below and tell them how to talk to each other using LangGraph.

### Define State

We first define the state of the graph. This is just a list of messages, along with a key to track the most recent sender.

```python
import operator
from typing import Annotated, Sequence

from langchain_core.messages import BaseMessage
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict


# This defines the object that is passed between each node
# in the graph. We will create different nodes for each agent and tool.
class AgentState(TypedDict):
    # operator.add makes each node's new messages get appended, not overwritten
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent that produced the most recent agent message
    sender: str
```
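The `Annotated[..., operator.add]` annotation is what makes node updates append to `messages` instead of replacing it, while `sender` is simply overwritten. The sketch below imitates that merge step in plain Python, reading the reducer from the annotation via `typing.get_type_hints(..., include_extras=True)`. `apply_update` is a hypothetical helper, not LangGraph's internal code, and plain strings stand in for message objects.

```python
import operator
from typing import Annotated, TypedDict, get_type_hints


class AgentState(TypedDict):
    messages: Annotated[list, operator.add]  # reducer: concatenate lists
    sender: str  # no reducer: last write wins


def apply_update(state: dict, update: dict, schema=AgentState) -> dict:
    """Merge a node's partial update into the state (hypothetical helper)."""
    hints = get_type_hints(schema, include_extras=True)
    new_state = dict(state)
    for key, value in update.items():
        metadata = getattr(hints[key], "__metadata__", ())
        if metadata:
            # Annotated[..., reducer]: combine the old and new values
            new_state[key] = metadata[0](state.get(key, []), value)
        else:
            new_state[key] = value
    return new_state


state = {"messages": ["Fetch UK GDP"], "sender": "user"}
state = apply_update(state, {"messages": ["GDP figures"], "sender": "assistant"})
print(state)
```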

### Define Agent Nodes

We now need to define the nodes. First, let's define the nodes for the agents.

```python
import functools


# Helper function to create a node for a given agent
def agent_node(state, agent, name):
    result = agent.invoke(state)
    # We convert the agent output into a format that is suitable to append to the global state
    if not isinstance(result, FunctionMessage):
        result = HumanMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        # Since we have a strict workflow, we can
        # track the sender so we know who to pass to next.
        "sender": name,
    }


# llm = ChatOpenAI(model="gpt-4-1106-preview")
# Hugging Face's OpenAI-compatible endpoint. The call_tool branch only fires if
# the model emits OpenAI-style function calls (none occur in the runs below).
llm = ChatOpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key="your-key",  # Hugging Face access token
    model="google/gemma-2b-it",
    temperature=0.05,
)

# Research agent and node
research_agent = create_agent(
    llm,
    [tavily_tool],
    system_message="You should provide accurate data for the chart generator to use.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="assistant")

# Chart Generator
chart_agent = create_agent(
    llm,
    [python_repl],
    system_message="This is a safe environment, please run the code and show the chart to the user.",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="Chart Generator")
```
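`functools.partial` pre-binds `agent` and `name`, so each node becomes a function of `state` alone, which is the signature the graph calls. The sketch below shows the same wiring with a hypothetical fake agent and plain dicts in place of message objects.

```python
import functools


def agent_node(state: dict, agent, name: str) -> dict:
    """Call `agent` on the state and tag the reply with the sender's name."""
    reply = agent(state)
    return {"messages": [{"name": name, "content": reply}], "sender": name}


def fake_research_agent(state: dict) -> str:
    # Hypothetical stand-in for `prompt | llm`
    return f"Found data for: {state['messages'][-1]['content']}"


research_node = functools.partial(agent_node, agent=fake_research_agent, name="assistant")
update = research_node({"messages": [{"name": "user", "content": "UK GDP"}]})
print(update["sender"], update["messages"][0]["content"])
```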

### Define Tool Node

We now define a node to run the tools.

```python
tools = [tavily_tool, python_repl]
tool_executor = ToolExecutor(tools)


def tool_node(state):
    """Run the tool requested by the last message.

    It reads the function call from the last agent message, invokes that tool,
    and returns the result."""
    messages = state["messages"]
    # The router only sends us here when the last message contains a function call
    last_message = messages[-1]
    # We construct a ToolInvocation from the function_call
    tool_input = json.loads(
        last_message.additional_kwargs["function_call"]["arguments"]
    )
    # We can pass single-arg inputs by value
    if len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = next(iter(tool_input.values()))
    tool_name = last_message.additional_kwargs["function_call"]["name"]
    action = ToolInvocation(
        tool=tool_name,
        tool_input=tool_input,
    )
    # We call the tool_executor and get back a response
    response = tool_executor.invoke(action)
    # We use the response to create a FunctionMessage
    function_message = FunctionMessage(
        content=f"{tool_name} response: {str(response)}", name=action.tool
    )
    # We return a list, because this will get added to the existing list
    return {"messages": [function_message]}
```
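The core of `tool_node` is parsing an OpenAI-style `function_call` (a tool name plus a JSON string of arguments) and dispatching it. The stdlib sketch below follows the same steps, including the `__arg1` unwrapping; the `tools` dict is a hypothetical registry standing in for `ToolExecutor`.

```python
import json


def dispatch_function_call(function_call: dict, tools: dict) -> str:
    """Parse an OpenAI-style function_call and run the matching tool (sketch)."""
    tool_input = json.loads(function_call["arguments"])
    # Single-argument tools may arrive as {"__arg1": value}; unwrap the value
    if len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = next(iter(tool_input.values()))
    name = function_call["name"]
    tool_fn = tools[name]
    # Dict inputs become keyword arguments; a bare value is passed positionally
    result = tool_fn(**tool_input) if isinstance(tool_input, dict) else tool_fn(tool_input)
    return f"{name} response: {result}"


# Hypothetical tool registry standing in for ToolExecutor
tools = {"add": lambda x, y: x + y, "shout": lambda text: text.upper()}
print(dispatch_function_call({"name": "add", "arguments": '{"x": 333, "y": 444}'}, tools))
print(dispatch_function_call({"name": "shout", "arguments": '{"__arg1": "done"}'}, tools))
```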

### Define Edge Logic

Next, we define the edge logic that decides what happens after each agent runs: call a tool, hand off to the other agent, or stop.

```python
# Either agent can decide to end
def router(state):
    # Decide the next step from the most recent message
    messages = state["messages"]
    last_message = messages[-1]
    if "function_call" in last_message.additional_kwargs:
        # The previous agent is invoking a tool
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        # Any agent decided the work is done
        return "end"
    return "continue"
```
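The router only reads two attributes of the last message, so it can be exercised without an LLM. The sketch below copies the router and feeds it stub messages built with `types.SimpleNamespace` (a test convenience, not a LangChain type). Note that the tool-call check runs first, and that the `FINAL ANSWER` check is a substring test, so `**FINAL ANSWER**` also ends the run.

```python
from types import SimpleNamespace


def router(state):
    last_message = state["messages"][-1]
    if "function_call" in last_message.additional_kwargs:
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        return "end"
    return "continue"


def msg(content, **additional_kwargs):
    # Minimal stand-in exposing the two attributes the router reads
    return SimpleNamespace(content=content, additional_kwargs=additional_kwargs)


print(router({"messages": [msg("", function_call={"name": "python_repl", "arguments": "{}"})]}))
print(router({"messages": [msg("**FINAL ANSWER**\n777")]}))
print(router({"messages": [msg("Here is the data...")]}))
```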

### Define the Graph

We can now put it all together and define the graph!

```python
workflow = StateGraph(AgentState)

workflow.add_node("assistant", research_node)
workflow.add_node("Chart Generator", chart_node)
workflow.add_node("call_tool", tool_node)

workflow.add_conditional_edges(
    "assistant",
    router,
    {"continue": "Chart Generator", "call_tool": "call_tool", "end": END},
)
workflow.add_conditional_edges(
    "Chart Generator",
    router,
    {"continue": "assistant", "call_tool": "call_tool", "end": END},
)

workflow.add_conditional_edges(
    "call_tool",
    # Each agent node updates the 'sender' field
    # the tool calling node does not, meaning
    # this edge will route back to the original agent
    # who invoked the tool
    lambda x: x["sender"],
    {
        "assistant": "assistant",
        "Chart Generator": "Chart Generator",
    },
)
workflow.set_entry_point("assistant")
graph = workflow.compile()
```
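To make the control flow concrete, here is a plain-Python sketch (no LangGraph) of the loop the compiled graph runs: call the current node, merge its update into the state, ask the router where to go next, and stop at `END` or when the step budget runs out. `run_graph`, the fake node functions, and the simplified merge are hypothetical stand-ins; the real `StateGraph` also handles the `call_tool` branch and stops with its own error when `recursion_limit` is exceeded.

```python
END = "__end__"


def run_graph(entry, nodes, routes, state, recursion_limit=25):
    """Run node -> router -> node until a route returns END (simplified sketch)."""
    current = entry
    for _ in range(recursion_limit):
        update = nodes[current](state)
        state = {
            # Messages accumulate (like operator.add); sender is kept unless the node sets it
            "messages": state["messages"] + update.get("messages", []),
            "sender": update.get("sender", state.get("sender")),
        }
        current = routes[current](state)
        if current == END:
            return state
    raise RuntimeError(f"Recursion limit of {recursion_limit} reached")


def researcher(state):
    return {"messages": ["GDP figures"], "sender": "assistant"}


def chart_generator(state):
    return {"messages": ["FINAL ANSWER: chart drawn"], "sender": "Chart Generator"}


def hand_off_to(next_node):
    # Router: stop on FINAL ANSWER, otherwise pass control to the other agent
    return lambda s: END if "FINAL ANSWER" in s["messages"][-1] else next_node


nodes = {"assistant": researcher, "Chart Generator": chart_generator}
routes = {
    "assistant": hand_off_to("Chart Generator"),
    "Chart Generator": hand_off_to("assistant"),
}
final = run_graph("assistant", nodes, routes, {"messages": ["Fetch UK GDP"], "sender": ""})
print(final["messages"])
# ['Fetch UK GDP', 'GDP figures', 'FINAL ANSWER: chart drawn']
```

If neither agent ever writes `FINAL ANSWER`, the loop keeps bouncing between them until the limit is hit, which is why the invocation below passes an explicit `recursion_limit`.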

## Invoke

With the graph created, you can invoke it! Let's have it chart some stats for us.

```python
for s in graph.stream(
    {
        "messages": [
            HumanMessage(
                content="Fetch the UK's GDP over the past 5 years,"
                " then draw a line graph of it."
                # "Run the code and show the line graph to the user."
                " Once you code it up, finish."
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 150},
):
    print(s)
    print("----")
```

```
{'assistant': {'messages': [HumanMessage(content='**Using tavily_search_results_json:**\n\n```python\nimport json\n\n# Get the data from the JSON file\ndata = json.load(open("gdp_uk_5years.json"))\n\n# Create a chart object\nchart = tavily_search_results_json.Chart(data)\n\n# Generate the chart\nchart.generate_chart()\n\n# Save the chart as a PNG file\nchart.save_chart("gdp_uk_', name='assistant')], 'sender': 'assistant'}}
----
{'Chart Generator': {'messages': [HumanMessage(content='**FINAL ANSWER**\n\n```python\nimport json\nimport tavily_search_results_json\n\n# Get the data from the JSON file\ndata = json.load(open("gdp_uk_5years.json"))\n\n# Create a chart object\nchart = tavily_search_results_json.Chart(data)\n\n# Generate the chart\nchart.generate_chart()\n\n# Save the chart as a PNG file\nchart.save_chart("gdp', name='Chart Generator')], 'sender': 'Chart Generator'}}
----
{'__end__': {'messages': [HumanMessage(content="Fetch the UK's GDP over the past 5 years, then draw a line g
```

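```python
text = str(s["__end__"]["messages"])
code_start = text.find("```python")
code_end = text.find("sender")
# Caveats: "```python" is 9 characters long, so `code_start + 8` keeps a stray "n",
# and "sender" does not occur in `text`, so find() returns -1 (slice to the last char)
code = text[code_start + 8:code_end].strip()
print(code)
```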

```
n\nimport json\n\n# Get the data from the JSON file\ndata = json.load(open("gdp_uk_5years.json"))\n\n# Create a chart object\nchart = tavily_search_results_json.Chart(data)\n\n# Generate the chart\nchart.generate_chart()\n\n# Save the chart as a PNG file\nchart.save_chart("gdp_uk_', name='assistant'), HumanMessage(content='**FINAL ANSWER**\n\n```python\nimport json\nimport tavily_search_results_json\n\n# Get the data from the JSON file\ndata = json.load(open("gdp_uk_5years.json"))\n\n# Create a chart object\nchart = tavily_search_results_json.Chart(data)\n\n# Generate the chart\nchart.generate_chart()\n\n# Save the chart as a PNG file\nchart.save_chart("gdp', name='Chart Generator')
```
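A sturdier way to pull code out of an agent's reply is to search the message's `.content` string (which keeps real newlines) for a fenced block, instead of slicing `str(...)` of the message list with hand-tuned offsets. The sketch below uses a regular expression; `extract_python_block` is a hypothetical helper. In the run above you would apply it to something like `s["__end__"]["messages"][-1].content`. If the reply has no closing fence (as when the model's output is cut off mid-block, which is what the replies above show), it returns `None` rather than a fragment.

```python
import re
from typing import Optional

# A fenced block tagged python (or untagged); group 1 captures the body
CODE_BLOCK = re.compile(r"```(?:python)?[ \t]*\n(.*?)```", re.DOTALL)


def extract_python_block(content: str) -> Optional[str]:
    """Return the first complete fenced code block in `content`, or None."""
    match = CODE_BLOCK.search(content)
    return match.group(1).strip() if match else None


reply = "FINAL ANSWER\n\n```python\nx = 333\ny = 444\nprint(x + y)\n```"
print(extract_python_block(reply))
```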


```python
for s in graph.stream(
    {
        "messages": [
            HumanMessage(
                content="Give me the sum of two variables, x and y,"
                " x = 333 and y = 444."
                " Once you code it up, finish."
            )
        ],
    },
    # Maximum number of steps to take in the graph
    {"recursion_limit": 150},
):
    print(s)
    print("----")
```

```
{'assistant': {'messages': [HumanMessage(content="```python\nimport tavily_search_results_json\n\n# Get the data from the JSON file\ndata = tavily_search_results_json.get_data()\n\n# Filter the data to only include variables named 'x' and 'y'\nvariables = [variable for variable in data if variable['name'] == 'x' or variable['name'] == 'y']\n\n# Calculate the sum of the 'x' and 'y' variables\nsum = float", name='assistant')], 'sender': 'assistant'}}
----
{'Chart Generator': {'messages': [HumanMessage(content="```python\nimport tavily_search_results_json\n\n# Get the data from the JSON file\ndata = tavily_search_results_json.get_data()\n\n# Filter the data to only include variables named 'x' and 'y'\nvariables = [variable for variable in data if variable['name'] == 'x' or variable['name'] == 'y']\n\n# Calculate the sum of the 'x' and 'y' variables\nsum = float", name='Chart Generator')], 'sender': 'Chart Generator'}}
----
{'assistant': {'messages': [HumanMessage(content="FINAL ANSWER\n\n
```

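```python
text = str(s["__end__"]["messages"][-1])
code_start = text.find("```python")
code_end = text.find("```", code_start + 8)
# Caveats: str() of a message escapes newlines as a literal "\n" (the message's
# .content attribute holds the raw text), and the +8 offset leaves a stray "n"
code = text[code_start + 8:code_end - 15].strip()
print(code)
```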

```
n\nimport tavily_search_results_json\n\n# Get the data from the JSON file\nndata = tavily_search_results_json.get_data()\n\n# Filter the data to only include variables named 'x' and 'y'\nvariables = [variable for variable in data if variable['name'] == 'x' or variable['name'] == 'y']\n\n# Calculate the sum of the 'x' and 'y' variables\n"
```


```python
# Invoke the tool directly with its single `code` argument
python_repl.invoke({"code": code})
```

```
'Succesfully executed:\n```python\nn\\nimport tavily_search_results_json\\n\\n# Get the data from the JSON file\\nndata = tavily_search_results_json.get_data()\\n\\n# Filter the data to only include variables named \'x\' and \'y\'\\nvariables = [variable for variable in data if variable[\'name\'] == \'x\' or variable[\'name\'] == \'y\']\\n\\n# Calculate the sum of the \'x\' and \'y\' variables\\n"\n```\nStdout: SyntaxError(\'unexpected character after line continuation character\', (\'<string>\', 1, 3, \'n\\\\nimport tavily_search_results_json\\\\n\\\\n# Get the data from the JSON file\\\\nndata = tavily_search_results_json.get_data()\\\\n\\\\n# Filter the data to only include variables named \\\'x\\\' and \\\'y\\\'\\\\nvariables = [variable for variable in data if variable[\\\'name\\\'] == \\\'x\\\' or variable[\\\'name\\\'] == \\\'y\\\']\\\\n\\\\n# Calculate the sum of the \\\'x\\\' and \\\'y\\\' variables\\\\n"\\n\'))'
```
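The `SyntaxError` above comes from the extracted text containing a backslash followed by `n` where real newlines should be, which is what you get when code is sliced out of `str(message)` instead of `message.content`. The small stdlib sketch below reproduces the effect with `compile`; it uses `unicode_escape` only to simulate the escaping and makes no LangChain calls.

```python
code = "x = 333\ny = 444\nprint(x + y)"  # real newline characters
# Turn each newline into backslash + "n", as the repr inside str(message) does
escaped = code.encode("unicode_escape").decode()

compile(code, "<llm>", "exec")  # compiles without error
try:
    compile(escaped, "<llm>", "exec")
except SyntaxError as err:
    print("SyntaxError:", err.msg)
```

Passing the message's `.content` (or a block extracted from it) to the REPL avoids this class of failure; it does not fix code that is itself incomplete, as the truncated replies above are.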