<a href="https://colab.research.google.com/github/ad71/ragbot/blob/master/langgraph_chat_agent_executor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the LangChain / LangGraph packages used by the cells below (-q: quiet, -U: upgrade).
# NOTE(review): no versions are pinned, so a fresh run installs whatever is newest. Later cells
# import `ToolExecutor` / `ToolInvocation` from `langgraph.prebuilt` and use
# `langchain_core.pydantic_v1` - confirm the installed releases still provide them, or pin versions.
%pip install -qU langchain langchain_openai langchain_community tavily-python langgraph

In [33]:
import os
from google.colab import userdata

# Copy each secret from Colab `userdata` into an environment variable of the same name.
for secret_name in ('OPENAI_API_KEY', 'TAVILY_API_KEY'):
    os.environ[secret_name] = userdata.get(secret_name)

# Switch LangChain tracing on, then provide the LangChain API key.
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ['LANGCHAIN_API_KEY'] = userdata.get('LANGCHAIN_API_KEY')

In [3]:
# MODIFICATION FOR DYNAMIC RETURN

from langchain_core.pydantic_v1 import BaseModel, Field


# Argument schema handed to the search tool via `args_schema`. Besides the query it
# declares a `return_direct` flag; the return-direct variants of `should_continue` and
# `call_tool` read that flag from the function-call arguments.
class SearchTool(BaseModel):
    # Text of the search to run.
    query: str = Field(description='query to look up online')
    # Flag the model may set to ask for the tool result to be returned as-is; defaults to False.
    return_direct: bool = Field(
        description='Whether or not the result of this should be returned directly to the user without you seeing what it is',
        default=False
    )

from langchain_community.tools.tavily_search import TavilySearchResults
# One Tavily search tool capped at a single result, with SearchTool as its argument
# schema so that `return_direct` is part of the arguments the model can supply.
tools = [TavilySearchResults(max_results=1, args_schema=SearchTool)]

In [34]:
from langchain_community.tools.tavily_search import TavilySearchResults
# One Tavily search tool capped at a single result, with its default argument schema.
# NOTE(review): this rebinds `tools`, replacing the `args_schema=SearchTool` variant from the
# return-direct cell - run only the variant the rest of the notebook is meant to use.
tools = [TavilySearchResults(max_results=1)]

In [35]:
from langgraph.prebuilt import ToolExecutor
# Wrap the tool list in a ToolExecutor; the `call_tool` nodes run tools through
# `tool_executor.invoke(...)`.
tool_executor = ToolExecutor(tools)

In [36]:
from langchain_openai import ChatOpenAI
# Chat model created with temperature 0 and streaming switched on.
# NOTE(review): the visible cells only consume per-node outputs (`app.invoke` / `app.stream`),
# not individual tokens - confirm whether `streaming=True` is actually needed here.
model = ChatOpenAI(temperature=0, streaming=True) # streaming flag: purpose unconfirmed, see note above

In [37]:
from langchain_core.utils.function_calling import convert_to_openai_function

# Convert every tool to an OpenAI function description, then bind those
# descriptions to the chat model.
functions = list(map(convert_to_openai_function, tools))
model = model.bind_functions(functions)

In [7]:
# MODIFICATION for responding in a specific format
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.function_calling import convert_to_openai_function


# Schema for a structured final answer. It is converted to an OpenAI function next to the
# tools, and the matching `should_continue` variant ends the run when the function named
# 'Response' is called.
class Response(BaseModel):
    '''Final response to the user'''
    # Temperature value to report (units are not specified by this schema).
    temperature: float = Field(description='the temperature')
    # Free-form additional remarks about the weather.
    other_notes: str = Field(description='any other notes about the weather')

# Convert the tools followed by the Response schema to OpenAI function
# descriptions, then bind the whole list to the chat model.
functions = [convert_to_openai_function(item) for item in (*tools, Response)]
model = model.bind_functions(functions)

In [38]:
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage
from langgraph.prebuilt import ToolInvocation
from langchain_core.messages import FunctionMessage
import json


# State passed between the graph nodes.
class AgentState(TypedDict):
    # Conversation so far. Nodes return a list under 'messages'; the `operator.add`
    # annotation means that list is added to the existing one rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add] # list of messages

In [39]:
# define the nodes
from langgraph.prebuilt import ToolInvocation
from langchain_core.messages import FunctionMessage
import json

def should_continue(state):
    '''Choose the edge to follow after the agent node.

    Returns 'continue' when the most recent message carries a function call
    in its `additional_kwargs`, otherwise 'end'.
    '''
    latest = state['messages'][-1]
    wants_function = 'function_call' in latest.additional_kwargs
    return 'continue' if wants_function else 'end'

In [9]:
# MODIFICATION:
# return direct if required
def should_continue(state):
    '''Choose the edge to follow after the agent node, honouring `return_direct`.

    Returns:
        'end'      - the most recent message has no function call.
        'final'    - the function-call arguments carry a truthy `return_direct`.
        'continue' - any other function call.
    '''
    extras = state['messages'][-1].additional_kwargs

    # No function call requested, so the run is finished.
    if 'function_call' not in extras:
        return 'end'

    # The arguments arrive as a JSON string; decode them to inspect the flag.
    call_args = json.loads(extras['function_call']['arguments'])
    return 'final' if call_args.get('return_direct', False) else 'continue'

In [9]:
# MODIFICATION:
# change the should continue function to check what function was called
# If `Response` was called - i.e. the function that is not a tool but the formatted final response - we should not continue.
def should_continue(state):
    '''Choose the edge to follow after the agent node, by called function name.

    Returns 'end' when the most recent message has no function call or when
    the called function is named 'Response'; returns 'continue' for any other
    function call.
    '''
    extras = state['messages'][-1].additional_kwargs

    if 'function_call' not in extras:
        return 'end'

    # 'Response' is the formatted answer rather than a tool, so it also ends the run.
    called_name = extras['function_call']['name']
    return 'end' if called_name == 'Response' else 'continue'

In [40]:
def call_model(state):
    '''Agent node: invoke the model on the full message history.

    Returns the model's reply as a one-element list under 'messages'; a list is
    used because it gets added to the existing message list.
    '''
    reply = model.invoke(state['messages'])
    return {'messages': [reply]}

In [29]:
# MODIFICATION for managing agent steps
# here we don't pass all messages to the model but rather only pass the five most recent.
# this is a pretty simplistic way to handle messages, and there may be other methods you want to look into depending on your use case.
def call_model(state):
    '''Agent node: invoke the model on only the five most recent messages.

    Returns the model's reply as a one-element list under 'messages'; a list is
    used because it gets added to the existing message list.
    '''
    recent_messages = state['messages'][-5:]
    return {'messages': [model.invoke(recent_messages)]}

In [41]:
# standard call_tool
def call_tool(state):
    '''Action node: run the tool requested by the most recent message.

    Reads the function call's name and JSON-encoded arguments from the message's
    `additional_kwargs`, runs it through `tool_executor`, and returns the
    stringified result as a FunctionMessage in a one-element 'messages' list.
    '''
    call_info = state['messages'][-1].additional_kwargs['function_call']
    action = ToolInvocation(
        tool=call_info['name'],
        tool_input=json.loads(call_info['arguments']),
    )

    result = tool_executor.invoke(action)
    return {'messages': [FunctionMessage(content=str(result), name=action.tool)]}

In [12]:
# MODIFICATION: human in the loop
def call_tool(state):
    '''Action node with a human approval step before the tool runs.

    Builds the tool invocation requested by the most recent message, asks the
    user via `input` whether to proceed, and only then runs the tool through
    `tool_executor`.

    Returns:
        {'messages': [FunctionMessage]} holding the stringified tool result.

    Raises:
        ValueError: the user declined by answering 'n' or 'no'
            (case and surrounding whitespace are ignored).
    '''
    call_info = state['messages'][-1].additional_kwargs['function_call']
    action = ToolInvocation(
        tool=call_info['name'],
        tool_input=json.loads(call_info['arguments'])
    )

    # Pass the prompt positionally: the builtin `input` rejects keyword arguments,
    # so `input(prompt=...)` only works where the kernel swaps in its own `input`.
    answer = input(f'[y/n] continue with {action}?')

    # Normalise the reply so 'N', ' n ' or 'no' also abort the call.
    if answer.strip().lower() in ('n', 'no'):
        raise ValueError(f'Tool call rejected by user: {action}')

    result = tool_executor.invoke(action)
    function_message = FunctionMessage(content=str(result), name=action.tool)
    return {'messages': [function_message]}

In [13]:
# MODIFICATION: return direct if required
def call_tool(state):
    '''Action node that strips `return_direct` before running the search tool.

    Reads the function call's name and JSON-encoded arguments from the most
    recent message, removes `return_direct` when the target is
    'tavily_search_results_json', runs the tool through `tool_executor`, and
    returns the stringified result as a FunctionMessage in a one-element
    'messages' list.
    '''
    call_info = state['messages'][-1].additional_kwargs['function_call']
    tool_name = call_info['name']
    arguments = json.loads(call_info['arguments'])

    # By the time this node runs the flag has served its purpose, so the tool
    # itself does not need to receive it.
    if tool_name == 'tavily_search_results_json' and 'return_direct' in arguments:
        del arguments['return_direct']

    action = ToolInvocation(tool=tool_name, tool_input=arguments)
    result = tool_executor.invoke(action)
    return {'messages': [FunctionMessage(content=str(result), name=action.tool)]}

In [42]:
# MODIFICATION: Here we create a node that returns an AIMessage with a tool call - we will use this at the start to force it to call a tool
from langchain_core.messages import AIMessage
import json

def first_model(state):
    '''Entry node that builds a search request without calling the model.

    Takes the content of the most recent message and wraps it as the `query`
    of a function call to 'tavily_search_results_json', carried by an
    AIMessage with empty content. Returns that message in a one-element
    'messages' list.
    '''
    human_input = state['messages'][-1].content

    forced_call = {
        'name': 'tavily_search_results_json',
        'arguments': json.dumps({'query': human_input}),
    }
    forced_message = AIMessage(
        content='',
        additional_kwargs={'function_call': forced_call},
    )
    return {'messages': [forced_message]}

In [31]:
# define the graph

from langgraph.graph import StateGraph, END

# Basic agent loop: 'agent' calls the model, 'action' runs the requested tool, and the
# tool result is fed back to 'agent' until `should_continue` returns 'end'.
# NOTE(review): `call_model`, `call_tool` and `should_continue` are each defined in several
# cells; the graph uses whichever definitions were executed last before this cell.
workflow = StateGraph(AgentState)

# The two nodes of the loop.
workflow.add_node('agent', call_model)
workflow.add_node('action', call_tool)

# Every run starts with a model call.
workflow.set_entry_point('agent')

# After 'agent', route on the string returned by `should_continue`.
workflow.add_conditional_edges(
    'agent',
    should_continue,
    {
        'continue': 'action',
        'end': END
    }
)

# Tool output always goes back to the model.
workflow.add_edge('action', 'agent')
app = workflow.compile()

In [14]:
# MODIFICATION: return direct graph
from langgraph.graph import StateGraph, END

# Return-direct graph: same loop as the basic agent, plus a 'final' node that runs the
# tool and then ends the run instead of going back to the model.
# NOTE(review): the mapping below expects `should_continue` to be the variant that can
# return 'final'; make sure that definition was the last one executed.
workflow = StateGraph(AgentState)

# 'action' and 'final' run the same `call_tool`; they differ only in their outgoing edge.
workflow.add_node('agent', call_model)
workflow.add_node('action', call_tool)
workflow.add_node('final', call_tool)

# Every run starts with a model call.
workflow.set_entry_point('agent')

# After 'agent', route on the string returned by `should_continue`.
workflow.add_conditional_edges(
    'agent',
    should_continue,
    {
        'continue': 'action', # if `tools`, then we call the tool node
        'final': 'final', # final call
        'end': END # otherwise we finish
    }
)

# 'action' loops back to the model; 'final' finishes the run with the tool result.
workflow.add_edge('action', 'agent')
workflow.add_edge('final', END)

app = workflow.compile()

In [44]:
# MODIFICATION: force calling a tool first
# i.e. skip the initial LLM call and call the tool directly to fetch something
from langgraph.graph import StateGraph, END

# Tool-first graph: the run starts at 'first_agent', which builds a search call from the
# user's message, goes straight to 'action', and only then enters the agent/action loop.
workflow = StateGraph(AgentState)

# new entry point
workflow.add_node('first_agent', first_model)

# The usual loop nodes.
workflow.add_node('agent', call_model)
workflow.add_node('action', call_tool)

# Start at the node that forces the tool call instead of at the model.
workflow.set_entry_point('first_agent')

# After 'agent', route on the string returned by `should_continue`.
workflow.add_conditional_edges(
    'agent',
    should_continue,
    {
        'continue': 'action',
        'end': END
    }
)

# Tool output always goes back to the model.
workflow.add_edge('action', 'agent')

# after we call the first agent, we know we want to go to action
workflow.add_edge('first_agent', 'action')

app = workflow.compile()

In [16]:
from langchain_core.messages import HumanMessage

# Run the compiled graph on a single question. `app.invoke` is the cell's last expression,
# so its return value (a dict with the accumulated 'messages') is displayed as the output.
inputs = {'messages': [HumanMessage(content='what is the weather in SF')]}
app.invoke(inputs)

{'messages': [HumanMessage(content='what is the weather in SF'),
  AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in San Francisco"}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'}, id='run-7f10633a-6fe1-4b5c-b091-971786f342d0-0'),
  FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1716331448, \'localtime\': \'2024-05-21 15:44\'}, \'current\': {\'last_updated_epoch\': 1716330600, \'last_updated\': \'2024-05-21 15:30\', \'temp_c\': 22.2, \'temp_f\': 72.0, \'is_day\': 1, \'condition\': {\'text\': \'Partly cloudy\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/116.png\', \'code\': 1003}, \'wind_mph\': 12.5, \'wind_kph\': 20.2, \'wind_degree\': 300, \'wind_d

In [45]:
from langchain_core.messages import HumanMessage

inputs = {'messages': [HumanMessage(content='what is the weather in sf?')]}

# Stream the run: each item maps a node name to what that node produced.
for output in app.stream(inputs):
    for key, value in output.items():
        # Header, divider and payload, one per line.
        print(f'Output from node {key}: ', '----', value, sep='\n')

    print('\n----\n')

Output from node first_agent: 
----
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'name': 'tavily_search_results_json', 'arguments': '{"query": "what is the weather in sf?"}'}})]}

----

Output from node action: 
----
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1716335266, \'localtime\': \'2024-05-21 16:47\'}, \'current\': {\'last_updated_epoch\': 1716335100, \'last_updated\': \'2024-05-21 16:45\', \'temp_c\': 22.2, \'temp_f\': 72.0, \'is_day\': 1, \'condition\': {\'text\': \'Partly cloudy\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/116.png\', \'code\': 1003}, \'wind_mph\': 18.6, \'wind_kph\': 29.9, \'wind_degree\': 290, \'wind_dir\': \'WNW\', \'pressure_mb\': 1014.0, \'pressure_in\': 29.94, \'

In [17]:
# for human in the loop, use streaming
# for dynamically returning output directly
inputs = {'messages': [HumanMessage(content='what is the weather in sf? return this result directly by setting return_direct = True')]}

# Stream the run: each item maps a node name to what that node produced.
for output in app.stream(inputs):
    for key, value in output.items():
        # Header, divider and payload, one per line.
        print(f'Output from node {key}: ', '----', value, sep='\n')

    print('\n----\n')

Output from node agent: 
----
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in San Francisco","return_direct":true}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'}, id='run-8c5b6f79-58d0-4319-9dcb-e665448be99e-0')]}

----

Output from node final: 
----
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1716331448, \'localtime\': \'2024-05-21 15:44\'}, \'current\': {\'last_updated_epoch\': 1716330600, \'last_updated\': \'2024-05-21 15:30\', \'temp_c\': 22.2, \'temp_f\': 72.0, \'is_day\': 1, \'condition\': {\'text\': \'Partly cloudy\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/116.png\', \'code\': 1003}, \'wind_mph\': 12.5, \'