<a href="https://colab.research.google.com/github/ad71/ragbot/blob/master/multi_agent_supervisor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -qU langchain langchain_openai langchain_community langchain_experimental langsmith pandas matplotlib langgraph

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m973.7/973.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.5/199.5 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.4/121.4 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.0/13.0 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.3/8.3 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m83.9/83.9 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m308.5/308.5 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━

In [2]:
import os
from google.colab import userdata

os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
os.environ['LANGCHAIN_API_KEY'] = userdata.get('LANGCHAIN_API_KEY')
os.environ['TAVILY_API_KEY'] = userdata.get('TAVILY_API_KEY')

# optional, add tracing in LangSmith
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_PROJECT'] = 'Multi-Agent Collaboration'

In [14]:
import json
import operator
import functools

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    FunctionMessage
)

from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

from langchain_core.tools import tool
from typing import Annotated, List, Sequence, Tuple, TypedDict, Union
from langchain_experimental.tools import PythonREPLTool
from langchain_community.tools.tavily_search import TavilySearchResults

from langchain.agents import create_openai_functions_agent, AgentExecutor
from langchain_openai import ChatOpenAI

from langgraph.graph import StateGraph, END

In [6]:
tavily_tool = TavilySearchResults(max_results=5)
python_tool = PythonREPLTool()

In [22]:
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> AgentExecutor:
    prompt = ChatPromptTemplate.from_messages(
        [
            ('system', system_prompt),
            MessagesPlaceholder(variable_name='messages'),
            MessagesPlaceholder(variable_name='agent_scratchpad')
        ]
    )
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor

In [11]:
def agent_node(state, agent, name):
    # convert agent result which is an AI message into a Human message with a name
    result = agent.invoke(state)
    return {'messages': [HumanMessage(content=result['output'], name=name)]}

## Create Agent Supervisor
Use function calling to choose the next worker node or finish processing

In [13]:
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

members = ['Researcher', 'Coder']
system_prompt = (
    'You are a supervisor tasked with managing a conversation between the '
    'following workers: {members}. Given the following user request, '
    'respond with the worker to act next. Each worker will perform a '
    'task and respond with their results and status. When finished, '
    'respond with FINISH.'
)

# our team supervisor is an LLM node. It just picks the next agent to process and decides when the work in completed
options = ['FINISH'] + members

# using openai function calling can make output parsing easier for us
function_def = {
    'name': 'route',
    'description': 'Select the next role.',
    'parameters': {
        'title': 'routeSchema',
        'type': 'object',
        'properties': {
            'next': {
                'title': 'Next',
                'anyOf': [
                    {'enum': options},
                ]
            }
        },
        'required': ['next']
    }
}

prompt = ChatPromptTemplate.from_messages(
    [
        ('system', system_prompt),
        MessagesPlaceholder(variable_name='messages'),
        (
            'system',
            'Given the conversation above, who should act next? '
            'Or should we FINISH? Select one of: {options}'
        )
    ]
).partial(options=str(options), members=', '.join(members))

llm = ChatOpenAI(model='gpt-4-1106-preview')
supervisor_chain = prompt | llm.bind_functions(functions=[function_def], function_call='route') | JsonOutputFunctionsParser()

In [25]:
# create graph
class AgentState(TypedDict):
    # the annotation tells the graph that new messages will always be added to the current states
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # the next field indicates where to route to next
    next: str

research_agent = create_agent(llm, [tavily_tool], 'You are a web researcher.')
research_node = functools.partial(agent_node, agent=research_agent, name='Researcher')

# WARNING: This performs arbitrary code execution. Proceed with CAUTION
code_agent = create_agent(llm, [python_tool], 'You may generate safe python code to analyze data and generate charts using matplotlib')
code_node = functools.partial(agent_node, agent=code_agent, name='Coder')

workflow = StateGraph(AgentState)
workflow.add_node('Researcher', research_node)
workflow.add_node('Coder', code_node)
workflow.add_node('supervisor', supervisor_chain)

In [26]:
# we want our workers to always report back to the supervisor when done
for member in members:
    workflow.add_edge(member, 'supervisor')

# the supervisor populates the 'next' field in the graph state which routes to a node or finishes
conditional_map = {k: k for k in members}
conditional_map['FINISH'] = END

workflow.add_conditional_edges('supervisor', lambda x: x['next'], conditional_map)
workflow.set_entry_point('supervisor')

graph = workflow.compile()

In [28]:
# invoke the team

for s in graph.stream({'messages': [HumanMessage(content='Code hello world and print it to the terminal')]}):
    if '__end__' not in s:
        print(s)
        print('--------')

{'supervisor': {'next': 'Coder'}}
--------
{'Coder': {'messages': [HumanMessage(content="The code `print('Hello, World!')` was executed and it printed `Hello, World!` to the terminal.", name='Coder')]}}
--------
{'supervisor': {'next': 'FINISH'}}
--------


In [29]:
for s in graph.stream(
    {'messages': [HumanMessage(content='Write a brief research report on pikas.')]},
    {'recursion_limit': 100}):
    if '__end__' not in s:
        print(s)
        print('--------')

{'supervisor': {'next': 'Researcher'}}
--------
--------
{'supervisor': {'next': 'FINISH'}}
--------
