<a href="https://colab.research.google.com/github/ad71/ragbot/blob/master/multi_agent_workflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Multi-agent collaboration
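Multiple agents work on a task by operating on a state made up of messages. That state can be:

- Shared between all agents
- Or independent / siloed per agent

Either way, the agents then share their results with each other.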

In [11]:
%pip install -qU langchain langchain_openai langchain_community langchain_experimental langsmith pandas matplotlib langgraph

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/199.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━[0m [32m163.8/199.5 kB[0m [31m4.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.5/199.5 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
from google.colab import userdata

# Copy each API key from Colab's userdata store into the process environment
# under the same name, in the order OpenAI -> LangChain -> Tavily.
for _secret_name in ('OPENAI_API_KEY', 'LANGCHAIN_API_KEY', 'TAVILY_API_KEY'):
    os.environ[_secret_name] = userdata.get(_secret_name)
del _secret_name  # keep the loop variable out of the notebook namespace

# optional, add tracing in LangSmith
os.environ.update(
    LANGCHAIN_TRACING_V2='true',
    LANGCHAIN_PROJECT='Multi-Agent Collaboration',
)

In [17]:
import json
import operator
import functools

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    FunctionMessage
)

from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_core.tools import tool
from typing import Annotated, List, Sequence, Tuple, TypedDict, Union
from langchain_experimental.utilities import PythonREPL
from langchain_community.tools.tavily_search import TavilySearchResults

from langchain.agents import create_openai_functions_agent
from langchain_openai import ChatOpenAI

In [8]:
def create_agent(llm, tools, system_message: str):
    '''Build an agent: a chat prompt piped into ``llm`` with ``tools`` bound as functions.

    Args:
        llm: Chat model exposing ``bind_functions``.
        tools: Tools the agent may call; each must have a ``name`` attribute and be
            convertible with ``convert_to_openai_function``.
        system_message: Agent-specific instructions appended to the shared system prompt.

    Returns:
        The result of ``prompt | llm.bind_functions(...)``.
    '''
    # Convert the tools first, exactly as many times and in the same order as given.
    openai_functions = list(map(convert_to_openai_function, tools))

    # Shared collaboration instructions; {tool_names} and {system_message} are
    # filled in below via partial().
    shared_instructions = (
        'You are a helpful AI assistant, collaborating with other assistants. '
        'Use the provided tools to progress towards answering the question. '
        "If you are unable to fully answer, that's okay, another assistant with different tools "
        'will help you where you left off. Execute what you can do to make progress. '
        'If you or any of the other assistants have the final answer or deliverable, '
        'prefix your response with FINAL ANSWER so the team knows to stop. '
        'You have access to the following tools: {tool_names}.\n{system_message}'
    )
    template = ChatPromptTemplate.from_messages(
        [
            ('system', shared_instructions),
            MessagesPlaceholder(variable_name='messages'),
        ]
    )

    joined_tool_names = ', '.join(each_tool.name for each_tool in tools)
    template = template.partial(
        system_message=system_message,
        tool_names=joined_tool_names,
    )
    return template | llm.bind_functions(openai_functions)

In [13]:
# Web-search tool; max_results=5 is handed to TavilySearchResults
# (presumably the number of hits returned per query — confirm in the Tavily tool docs).
tavily = TavilySearchResults(max_results=5)
# REPL instance that the python_repl tool below runs its code through.
repl = PythonREPL() # executes code locally, which can be unsafe when not sandboxed

# NOTE(review): @tool presumably exposes the docstring and the Annotated text
# to the model as the tool description, so both are kept verbatim — confirm
# before rewording either of them.
@tool
def python_repl(code: Annotated[str, 'The python code to execute to generate your chart']):
    '''Use this to execute python code. If you want to see the output of a value, you should print it out with print(...). This is visible to the user'''
    try:
        stdout = repl.run(code)
    except BaseException as exc:
        # Every exception type (BaseException, not just Exception) is turned
        # into a text report for the caller instead of propagating.
        failure_report = f'Failed to execute. Error: {repr(exc)}'
        return failure_report
    else:
        # Echo the executed code together with whatever repl.run returned.
        return f'Successfully executed:\n```python\n{code}\n```\nStdout: {stdout}'

In [16]:
class AgentState(TypedDict):
    '''Shape of the state that agent_node and tool_node read and update.'''
    # Message history. The operator.add annotation is presumably used by
    # LangGraph as the reducer, so the single-element lists returned by the
    # nodes get appended rather than replacing the history — TODO confirm
    # against the installed langgraph version.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the node that produced the latest update (agent_node stores its
    # `name` argument here; tool_node does not write this key).
    sender: str

In [18]:
def agent_node(state, agent, name):
    '''Run ``agent`` on the current state and package its reply as a state update.

    Args:
        state: Graph state; passed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke(state)`` that returns a message.
        name: Label of this node. Stored verbatim under ``'sender'`` and, in
            sanitised form, used as the ``name`` of the re-wrapped message.

    Returns:
        dict with ``'messages'`` (a one-element list holding the reply) and
        ``'sender'`` (the unmodified ``name``).
    '''
    import re  # local import so this cell does not depend on edits to the import cell

    result = agent.invoke(state)

    # A function message is passed on as it is. Anything else (an AI message)
    # is re-wrapped as a human message so the next node treats it as input to
    # act upon.
    if not isinstance(result, FunctionMessage):
        # The OpenAI chat API rejects message names that do not match
        # ^[a-zA-Z0-9_-]{1,64}$, so a label such as 'Chart Generator' would
        # make the next model call fail. Only the message name is sanitised;
        # 'sender' keeps the original label because other code may route on it.
        message_name = re.sub(r'[^a-zA-Z0-9_-]', '_', name)[:64]
        result = HumanMessage(**result.dict(exclude={'type', 'name'}), name=message_name)

    return {
        'messages': [result],
        'sender': name
    }

In [19]:
# Single chat model instance; it is passed to create_agent for both agents below.
llm = ChatOpenAI(model='gpt-4-1106-preview')

In [20]:
# Researcher: the only agent given the Tavily search tool.
research_agent = create_agent(llm, [tavily], system_message='You should provide accurate data for the chart generator to use')
# Graph node for the researcher: agent_node with the agent and its label pre-bound.
research_node = functools.partial(agent_node, agent=research_agent, name='Researcher')

# Chart generator: the only agent given the python_repl tool.
chart_agent = create_agent(llm, [python_repl], system_message='Any charts you display will be visible by the user')
# Graph node for the chart generator, built the same way as research_node.
chart_node = functools.partial(agent_node, agent=chart_agent, name='Chart Generator')

In [21]:
# define tool node
# Both tools are registered with one executor; tool_node hands it a
# ToolInvocation carrying the tool name requested in the function call.
tools = [tavily, python_repl]
tool_executor = ToolExecutor(tools)

def tool_node(state):
    '''Run the tool requested by the last message in the state.

    Reads the ``function_call`` entry from the last message's
    ``additional_kwargs``, invokes the named tool through ``tool_executor``
    and returns the outcome as a ``FunctionMessage``.

    Args:
        state: Graph state; ``state['messages'][-1]`` must carry a
            ``function_call`` with ``name`` and JSON-encoded ``arguments``.

    Returns:
        dict with ``'messages'`` holding one ``FunctionMessage``. If the
        arguments are not valid JSON the message describes the parse error
        instead of a tool result, and no tool is run.
    '''
    last_message = state['messages'][-1]
    function_call = last_message.additional_kwargs['function_call']
    tool_name = function_call['name']

    try:
        tool_input = json.loads(function_call['arguments'])
    except json.JSONDecodeError as e:
        # The arguments are generated by the model and may be malformed.
        # Report the problem back as a message (as python_repl does for
        # execution errors) rather than aborting the whole graph run.
        parse_error_message = FunctionMessage(
            content=f'{tool_name} response: Failed to parse tool arguments as JSON. Error: {repr(e)}',
            name=tool_name
        )
        return {'messages': [parse_error_message]}

    # A lone '__arg1' key wraps a single positional argument; unwrap it so the
    # tool receives the bare value. The dict check keeps non-object JSON
    # payloads (lists, numbers, ...) from breaking the len()/.values() calls.
    if isinstance(tool_input, dict) and len(tool_input) == 1 and '__arg1' in tool_input:
        tool_input = next(iter(tool_input.values()))

    action = ToolInvocation(
        tool=tool_name,
        tool_input=tool_input
    )
    response = tool_executor.invoke(action)
    function_message = FunctionMessage(content=f'{tool_name} response: {str(response)}', name=action.tool)
    return {'messages': [function_message]}