<a href="https://colab.research.google.com/github/ad71/ragbot/blob/master/langgraph_chat_agent_executor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -qU langchain langchain_openai tavily-python langgraph

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.6/77.6 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.9/302.9 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.2/121.2 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.6/320.6 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.3/49.3 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
from google.colab import userdata

# Copy the API keys stored in Colab's Secrets panel into environment variables
# so the OpenAI, Tavily and LangChain libraries used below can pick them up,
# and switch on LangChain (LangSmith) tracing.
os.environ.update(
    OPENAI_API_KEY=userdata.get('OPENAI_API_KEY'),
    TAVILY_API_KEY=userdata.get('TAVILY_API_KEY'),
    LANGCHAIN_TRACING_V2='true',
    LANGCHAIN_API_KEY=userdata.get('LANGCHAIN_API_KEY'),
)

In [3]:
from langchain_community.tools.tavily_search import TavilySearchResults

# The agent's only tool: Tavily web search, limited to a single result per query.
tools = [
    TavilySearchResults(max_results=1),
]

In [4]:
from langgraph.prebuilt import ToolExecutor

# Dispatches a ToolInvocation to the tool in `tools` whose name it names.
tool_executor = ToolExecutor(tools=tools)

In [5]:
from langchain_openai import ChatOpenAI

# temperature=0: make the model's replies (and its tool-call decisions) as
# repeatable as the API allows.
# streaming=True: request a streamed response from OpenAI and emit each token
# through LangChain's callback system, so streaming interfaces can surface
# tokens as they arrive. A plain `invoke` still returns the complete,
# aggregated message.
model = ChatOpenAI(temperature=0, streaming=True)

In [7]:
from langchain_core.utils.function_calling import convert_to_openai_function

# Describe every tool as an OpenAI function schema and attach the schemas to the
# chat model, so it can answer with a `function_call` instead of plain text.
# NOTE: this reassigns `model`; re-run the ChatOpenAI cell before re-running this
# one so the functions are bound to the plain model.
functions = list(map(convert_to_openai_function, tools))
model = model.bind_functions(functions)

In [8]:
import operator
from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage


class AgentState(TypedDict):
    """State passed between the graph's nodes.

    `messages` is the conversation history. Its `operator.add` annotation is the
    reducer: the list a node returns is concatenated onto the existing history
    instead of replacing it.
    """

    messages: Annotated[Sequence[BaseMessage], operator.add]

In [9]:
# define the nodes
from langgraph.prebuilt import ToolInvocation
from langchain_core.messages import FunctionMessage
import json

def should_continue(state):
    """Router for the conditional edge that leaves the 'agent' node.

    Returns 'continue' when the newest message carries a `function_call` in its
    `additional_kwargs` (the model asked for a tool), otherwise 'end'.
    """
    latest = state['messages'][-1]
    wants_tool = 'function_call' in latest.additional_kwargs
    return 'continue' if wants_tool else 'end'

def call_model(state):
    """Agent node: run the function-bound chat model over the whole conversation.

    Returns a partial state update containing only the new AI message; the
    reducer on `messages` appends it to the existing history.
    """
    return {'messages': [model.invoke(state['messages'])]}

def call_tool(state):
    """Action node: execute the function call requested by the last AI message.

    Reads `additional_kwargs['function_call']` (a `name` plus JSON-encoded
    `arguments`) from the newest message, runs the matching tool through
    `tool_executor`, and returns the stringified result as a `FunctionMessage`.
    The reducer on `messages` appends it, and the 'agent' node reads it next.

    If the model's arguments are not valid JSON, the tool is not run. Instead, a
    `FunctionMessage` describing the parse error is returned so the model can
    retry on its next turn, rather than the whole graph run raising.
    """
    function_call = state['messages'][-1].additional_kwargs['function_call']
    tool_name = function_call['name']

    try:
        tool_input = json.loads(function_call['arguments'])
    except json.JSONDecodeError as exc:
        # Model-generated function arguments are not guaranteed to be valid JSON.
        error_message = FunctionMessage(
            content=f'Error: could not parse the function arguments as JSON ({exc}). '
                    'Please call the function again with valid JSON arguments.',
            name=tool_name,
        )
        return {'messages': [error_message]}

    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    response = tool_executor.invoke(action)
    function_message = FunctionMessage(content=str(response), name=action.tool)
    # a list, so the reducer concatenates it onto the existing history
    return {'messages': [function_message]}

In [10]:
# define the graph: 'agent' calls the model, 'action' runs the requested tool,
# and control loops back to 'agent' until it replies without a function call

from langgraph.graph import StateGraph, END

workflow = StateGraph(AgentState)

# nodes
workflow.add_node('agent', call_model)
workflow.add_node('action', call_tool)

# every run starts at the model
workflow.set_entry_point('agent')

# after the model: run the tool if it asked for one, otherwise stop
workflow.add_conditional_edges('agent', should_continue, {'continue': 'action', 'end': END})

# tool output always goes back to the model
workflow.add_edge('action', 'agent')

app = workflow.compile()

In [12]:
from langchain_core.messages import HumanMessage

# Ask a single question and display the final graph state: the full message
# history, including any function calls and tool results along the way.
inputs = dict(messages=[HumanMessage(content='what is the weather in SF')])
app.invoke(inputs)

{'messages': [HumanMessage(content='what is the weather in SF'),
  AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in San Francisco"}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'}, id='run-b233d7d5-f271-4aae-9a27-625438b9d1d1-0'),
  FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1715954813, \'localtime\': \'2024-05-17 7:06\'}, \'current\': {\'last_updated_epoch\': 1715954400, \'last_updated\': \'2024-05-17 07:00\', \'temp_c\': 12.2, \'temp_f\': 54.0, \'is_day\': 1, \'condition\': {\'text\': \'Overcast\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/122.png\', \'code\': 1009}, \'wind_mph\': 11.9, \'wind_kph\': 19.1, \'wind_degree\': 250, \'wind_dir\': 