In [2]:
# code - https://python.langchain.com/docs/langgraph

import os
from dotenv import load_dotenv
load_dotenv()  # merge variables from a local .env file into os.environ

# Connection settings for Azure OpenAI and Azure AI Search. Subscripting
# os.environ (rather than calling .get) fails fast with a KeyError naming the
# missing variable instead of leaving a None to fail somewhere later.
azure_openai_endpoint, azure_openai_key = (
    os.environ["AZURE_OPENAI_ENDPOINT"],
    os.environ["AZURE_OPENAI_KEY"],
)
search_endpoint, search_key = (
    os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"],
    os.environ["AZURE_SEARCH_ADMIN_KEY"],
)



In [4]:
from langchain_openai import ChatOpenAI, AzureChatOpenAI

from langchain_core.messages import HumanMessage
from langgraph.graph import END, MessageGraph

# Azure-hosted GPT-4 deployment, authenticated with the endpoint/key read from
# the environment earlier in the notebook.
model = AzureChatOpenAI(
    azure_endpoint=azure_openai_endpoint,
    api_key=azure_openai_key,
    api_version="2023-09-01-preview",
    deployment_name="gpt-4",
)

# Minimal graph: START -> oracle -> END.
# The "oracle" node calls the chat model on the message list; the AIMessage it
# returns is appended to the graph state by LangGraph.
graph = MessageGraph()
graph.add_node("oracle", model)
graph.set_entry_point("oracle")
graph.add_edge("oracle", END)

runnable = graph.compile()

In [5]:
runnable.invoke(HumanMessage(content="What is 1 + 1?"))  # returns [question, model answer]

[HumanMessage(content='What is 1 + 1?', id='f8ddd2a4-da25-442b-bdbb-ea5876e52615'),
 AIMessage(content='2', response_metadata={'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': False, 'severity': 'safe'}}}, id='d0903404-3a22-475c-bd63-eef88c8cd2ee')]

Conditional edges

Now, let's move on to something a little less trivial. Because math can be difficult for LLMs, let's allow the LLM to conditionally call a "multiply" node using tool calling.

We'll recreate our graph with an additional "multiply" node that inspects the most recent message and, if it contains a tool call, calculates the result. We'll also bind the `multiply` function to the Azure OpenAI model as a tool, so the model can optionally call it when it needs to in order to respond to the current state:

In [8]:
import json
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_core.messages.base import BaseMessage
from typing import List


@tool
def multiply(first_number: int, second_number: int):
    """Multiplies two numbers together."""
    # NOTE: @tool turns the docstring above into the tool description sent to
    # the model, and the parameter names/annotations into its argument schema,
    # so changing either changes what the model sees.
    product = first_number * second_number
    return product




# Expose `multiply` to the model in OpenAI "tools" format, letting it reply with
# a tool call instead of plain text.
model_with_tools = model.bind(
    tools=[convert_to_openai_tool(multiply)],
)

graph = MessageGraph()


def invoke_model(state: List[BaseMessage]):
    """Send the whole message history to the tool-enabled model and return its reply."""
    reply = model_with_tools.invoke(state)
    return reply


graph.add_node("oracle", invoke_model)

def invoke_tool(state: List[BaseMessage]):
    """Run the model's ``multiply`` tool call and wrap the product in a ToolMessage.

    Args:
        state: Conversation so far. The last message is expected to be the
            model's reply, carrying OpenAI-style tool calls
            (``{"id": ..., "function": {"name": ..., "arguments": "<json>"}}``)
            under ``additional_kwargs["tool_calls"]``.

    Returns:
        A ToolMessage whose ``tool_call_id`` matches the answered call, so the
        result can be paired with the model's request.

    Raises:
        ValueError: if the last message contains no ``multiply`` tool call.
    """
    tool_calls = state[-1].additional_kwargs.get("tool_calls") or []

    # If the model emitted several multiply calls, the last one wins (only one
    # ToolMessage is produced).
    multiply_call = None
    for tool_call in tool_calls:
        # Guard against a call without a "function" entry instead of crashing
        # with AttributeError on None.
        if (tool_call.get("function") or {}).get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        # ValueError subclasses Exception, so existing `except Exception` code
        # still catches it.
        raise ValueError("No multiply input found.")

    # The model sends its arguments as a JSON string.
    arguments = json.loads(multiply_call["function"]["arguments"])
    result = multiply.invoke(arguments)

    # ToolMessage content is text: convert the integer product explicitly
    # rather than relying on implicit coercion by the message model.
    return ToolMessage(
        tool_call_id=multiply_call.get("id"),
        content=str(result),
    )

graph.add_node("multiply", invoke_tool)

graph.set_entry_point("oracle")  # every run starts by asking the model
graph.add_edge("multiply", END)  # once the tool has run, the graph stops

In [10]:
def router(state: List[BaseMessage]):
    """Choose the next edge label: "multiply" if the model requested a tool, else "end"."""
    pending_calls = state[-1].additional_kwargs.get("tool_calls", [])
    return "multiply" if len(pending_calls) > 0 else "end"

# After "oracle" replies, `router` returns a label that this map translates
# into the next node (or END).
graph.add_conditional_edges(
    "oracle",
    router,
    dict(multiply="multiply", end=END),
)

In [11]:
runnable = graph.compile()

# A multiplication question, expected to be routed through the multiply tool.
runnable.invoke(HumanMessage(content="What is 123 * 456?"))

[HumanMessage(content='What is 123 * 456?', id='47480786-d330-4843-aab0-a0cfb0026251'),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wzkcQtzyfmS3exSXrCKyfPAW', 'function': {'arguments': '{\n  "first_number": 123,\n  "second_number": 456\n}', 'name': 'multiply'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'logprobs': None, 'content_filter_results': {}}, id='1129443e-da83-402f-8447-08d6309dbf99'),
 ToolMessage(content='56088', id='69f2dba1-9537-4f6f-8203-ecaf713fb361', tool_call_id='call_wzkcQtzyfmS3exSXrCKyfPAW')]