# Dynamically Returning Directly
In this example, we will build a chat executor where the LLM can optionally decide to return the result of a tool call as the final answer. This is useful when you have tools that sometimes generate responses that are acceptable as final answers, and you want the LLM to determine when that is the case.

This example builds on the base chat executor. We highly recommend that you learn about that executor before going through this notebook. You can find documentation for that example here.

Any modifications to that example are marked below with **MODIFICATION**, so if you are looking for the differences, you can just search for that.

## Set up

In [3]:
import os
import getpass

os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")
os.environ["TAVILY_API_KEY"] = getpass.getpass("Tavily API Key:")

OpenAI API Key: ········
Tavily API Key: ········


Optionally, we can set an API key for LangSmith tracing, which lets us inspect every step of each run.

In [4]:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangSmith API Key:")

LangSmith API Key: ········


## Set up the tools
We will first define the tools we want to use. For this simple example, we will use a built-in search tool via Tavily. However, it is easy to create your own tools; see the documentation here for how to do that.

**MODIFICATION**
  
We override the default input schema of the search tool to add an extra `return_direct` parameter, which the LLM can set to indicate that the tool's result should be returned directly to the user.

In [36]:
from pydantic import BaseModel, Field

class SearchTool(BaseModel):
    query: str = Field(description="The query passed to the search tool")
    return_direct: bool = Field(
        description="Whether or not the result of the search tool should be returned directly to the user",
        default=False
    )
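Because `return_direct` has `default=False`, the model may include the flag or leave it out entirely. The following is a minimal stdlib sketch of how a raw `function_call` arguments string (in the same shape as the agent outputs later in this notebook) can be parsed with that default applied. `parse_search_args` is a hypothetical helper for illustration, not part of LangChain.

```python
import json

# Hypothetical helper mirroring the SearchTool schema: `query` is required,
# `return_direct` is optional and defaults to False.
def parse_search_args(raw_arguments: str) -> dict:
    """Parse a function_call arguments string and apply the schema default."""
    args = json.loads(raw_arguments)
    if "query" not in args:
        raise ValueError("missing required field: query")
    # Mirror `default=False` from the schema when the model omits the field
    args.setdefault("return_direct", False)
    return args


# The model may include the flag explicitly...
explicit = parse_search_args('{"query":"weather in Lagos, Nigeria","return_direct":true}')
# ...or leave it out, in which case the default applies
implicit = parse_search_args('{"query":"weather in Lagos, Nigeria"}')
print(explicit["return_direct"], implicit["return_direct"])  # True False
```

The routing function below uses `arguments.get("return_direct")`, which treats a missing flag the same way this sketch does.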

In [37]:
from langchain_community.tools.tavily_search import TavilySearchResults

search_tool = TavilySearchResults(max_results=1, args_schema=SearchTool)
tools = [search_tool]

In [38]:
from langgraph.prebuilt import ToolExecutor

tool_executor = ToolExecutor(tools)

## Set up model

In [39]:
from langchain_openai import ChatOpenAI

model = ChatOpenAI(temperature=0, streaming=True)

In [40]:
from langchain_core.utils.function_calling import convert_to_openai_function

functions = [convert_to_openai_function(t) for t in tools]
model_with_tools = model.bind_functions(functions)

## Define the agent state

In [41]:
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
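The `Annotated[..., operator.add]` annotation declares how updates to `messages` are merged: when a node returns `{"messages": [...]}`, the new list is combined with the existing one using `operator.add` instead of replacing it. Here is a minimal stdlib sketch of that merge step, with plain strings standing in for message objects. `SketchState` and `apply_update` are hypothetical names; this is not LangGraph's internal implementation.

```python
import operator
from typing import Annotated, Sequence, TypedDict, get_type_hints


class SketchState(TypedDict):
    # Strings stand in for BaseMessage objects in this sketch
    messages: Annotated[Sequence[str], operator.add]


def apply_update(state: dict, update: dict) -> dict:
    """Merge a node's partial update into the state using each key's reducer."""
    hints = get_type_hints(SketchState, include_extras=True)
    new_state = dict(state)
    for key, value in update.items():
        # The Annotated metadata holds the reducer (here operator.add)
        reducer = hints[key].__metadata__[0]
        new_state[key] = reducer(state[key], value)
    return new_state


state = {"messages": ["user: what is the weather in Lagos?"]}
state = apply_update(state, {"messages": ["ai: function_call(tavily_search_results_json)"]})
state = apply_update(state, {"messages": ["function: <search results>"]})
print(len(state["messages"]))  # 3
```

This is why `call_model` and `call_tool` below return a one-element list: the list is appended to the conversation rather than overwriting it.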

## Define the nodes

In [42]:
from langgraph.prebuilt import ToolInvocation
import json
from langchain_core.messages import FunctionMessage

**MODIFICATION**

We change the `should_continue` function to check whether `return_direct` was set to `True`.

In [49]:
# Define the function that determines whether to continue or not
def should_continue(state):
    messages = state["messages"]
    last_message = messages[-1]
    # If there is no function call, then we finish
    if "function_call" not in last_message.additional_kwargs:
        return "end"
    # Otherwise, if there is, we check whether it's supposed to return directly
    else:
        arguments = json.loads(
            last_message.additional_kwargs["function_call"]["arguments"]
        )
        if arguments.get("return_direct"):
            return "final"
        else:
            return "continue"
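Because `should_continue` only reads `additional_kwargs`, its three branches can be exercised without calling a model. The following minimal sketch re-creates the same routing logic, using a hypothetical `StubMessage` dataclass in place of LangChain's `AIMessage` (`route` and `make_call` are also illustrative names).

```python
import json
from dataclasses import dataclass, field


@dataclass
class StubMessage:
    # Hypothetical stand-in for AIMessage: only `additional_kwargs` is needed
    additional_kwargs: dict = field(default_factory=dict)


def route(message: StubMessage) -> str:
    """Same decision as `should_continue`, applied to a single message."""
    if "function_call" not in message.additional_kwargs:
        return "end"  # plain answer: finish
    arguments = json.loads(message.additional_kwargs["function_call"]["arguments"])
    # A missing flag counts as False, matching the schema default
    return "final" if arguments.get("return_direct") else "continue"


def make_call(arguments: dict) -> StubMessage:
    # Builds a message shaped like the agent outputs shown later
    return StubMessage({"function_call": {
        "name": "tavily_search_results_json",
        "arguments": json.dumps(arguments),
    }})


print(route(StubMessage()))  # end
print(route(make_call({"query": "weather in Lagos", "return_direct": True})))  # final
print(route(make_call({"query": "weather in Lagos", "return_direct": False})))  # continue
print(route(make_call({"query": "weather in Lagos"})))  # continue
```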

In [50]:
# Define the function that calls the model
def call_model(state):
    messages = state["messages"]
    response = model_with_tools.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

**MODIFICATION**

We change the tool-calling function to remove the `return_direct` parameter from the arguments, since it is only used for routing and is not part of the actual tool call.

In [51]:
# Define the function to execute tools
def call_tool(state):
    messages = state["messages"]
    # Based on the continue condition
    # we know the last message involves a function call
    last_message = messages[-1]
    # We construct a ToolInvocation from the function_call
    tool_name = last_message.additional_kwargs["function_call"]["name"]
    arguments = json.loads(last_message.additional_kwargs["function_call"]["arguments"])
    if tool_name == "tavily_search_results_json":
        if "return_direct" in arguments:
            del arguments["return_direct"]
    action = ToolInvocation(
        tool=tool_name,
        tool_input=arguments,
    )
    # We call the tool_executor and get back a response
    response = tool_executor.invoke(action)
    # We use the response to create a FunctionMessage
    function_message = FunctionMessage(content=str(response), name=action.tool)
    # We return a list, because this will get added to the existing list
    return {"messages": [function_message]}
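`call_tool` deletes `return_direct` in place, and only for the Tavily tool. If you prefer a reusable, non-mutating step, the following is a minimal sketch. `CONTROL_ARGS` and `strip_control_args` are hypothetical names, not LangChain APIs.

```python
import json

# Hypothetical: argument names the graph uses for routing, not the tool itself
CONTROL_ARGS = {"return_direct"}


def strip_control_args(raw_arguments: str) -> dict:
    """Parse function_call arguments and drop routing-only keys, returning a new dict."""
    arguments = json.loads(raw_arguments)
    return {k: v for k, v in arguments.items() if k not in CONTROL_ARGS}


tool_input = strip_control_args('{"query":"weather in Lagos, Nigeria","return_direct":true}')
print(tool_input)  # {'query': 'weather in Lagos, Nigeria'}
```

Unlike the notebook's version, this strips the flag for every tool rather than only `tavily_search_results_json`. That difference matters if you later give other tools the same parameter.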

## Define the graph
We can now put it all together and define the graph!

**MODIFICATION**

We add a separate `final` node for tool calls where `return_direct=True`. It runs the same `call_tool` function as the `action` node. It needs its own node because after it runs we want the graph to end, whereas after other tool calls we want to go back to the LLM.

In [52]:
from langgraph.graph import StateGraph, END

# Define a new graph
workflow = StateGraph(AgentState)

# Define the nodes: we cycle between `agent` and `action`,
# while `final` runs a tool once and then ends the graph
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.add_node("final", call_tool)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are strings, and the values are other nodes.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and then the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the `action` (tool) node.
        "continue": "action",
        # If `final`, we call the tool via the `final` node, then finish.
        "final": "final",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` to `agent`.
# This means that after `action` is called, the `agent` node is called next.
workflow.add_edge("action", "agent")
# After `final` runs, the graph ends, so its tool output is the final answer.
workflow.add_edge("final", END)

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()
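To see how the conditional mapping and the two fixed edges combine, here is a stdlib-only sketch that walks the same topology. The agent's decisions are scripted rather than produced by a model. `run_graph` and `SKETCH_END` are hypothetical names, and this is not LangGraph's scheduler; it only mirrors the edges declared above.

```python
SKETCH_END = "__end__"  # local stand-in for the END sentinel in this sketch

# Conditional routing out of `agent`, mirroring the mapping passed to add_conditional_edges
AGENT_ROUTES = {"continue": "action", "final": "final", "end": SKETCH_END}
# Fixed edges, mirroring add_edge("action", "agent") and add_edge("final", END)
FIXED_EDGES = {"action": "agent", "final": SKETCH_END}


def run_graph(agent_decisions: list[str]) -> list[str]:
    """Return the sequence of nodes visited, given scripted agent decisions."""
    decisions = iter(agent_decisions)
    visited = []
    node = "agent"  # entry point
    while node != SKETCH_END:
        visited.append(node)
        if node == "agent":
            # In the real graph, this value comes from should_continue
            node = AGENT_ROUTES[next(decisions)]
        else:
            node = FIXED_EDGES[node]
    return visited


# return_direct=False: the tool result goes back to the model, which then answers
print(run_graph(["continue", "end"]))  # ['agent', 'action', 'agent']
# return_direct=True: the tool result is the final answer
print(run_graph(["final"]))  # ['agent', 'final']
```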

## Use it!
We can now use it! The compiled graph exposes the same interface as all other LangChain runnables.

In [53]:
from langchain_core.messages import HumanMessage

inputs = {"messages": [HumanMessage(content="what is the weather in Lagos, Nigeria? Please set return_direct = False")]}
for output in app.stream(inputs):
    # stream() yields dictionaries with output keyed by node name
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("---")
        print(value)
    print("\n---\n")

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in Lagos, Nigeria","return_direct":false}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'})]}

---

Output from node 'action':
---
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'Lagos\', \'region\': \'Lagos\', \'country\': \'Nigeria\', \'lat\': 6.45, \'lon\': 3.4, \'tz_id\': \'Africa/Lagos\', \'localtime_epoch\': 1712818511, \'localtime\': \'2024-04-11 7:55\'}, \'current\': {\'last_updated_epoch\': 1712817900, \'last_updated\': \'2024-04-11 07:45\', \'temp_c\': 26.0, \'temp_f\': 78.8, \'is_day\': 1, \'condition\': {\'text\': \'Partly cloudy\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/116.png\', \'code\': 1003}, \'wind_mph\': 8.1, \'wind_kph\': 13.0, \'wind_degree\': 320, \'wind_dir\': \'NW\', \'pressure_mb\': 1010.0, \'pr

**Setting `return_direct` to True**

In [54]:
from langchain_core.messages import HumanMessage

inputs = {
    "messages": [
        HumanMessage(
            content="what is the weather in Lagos, Nigeria? Please set return_direct = True"
        )
    ]
}
for output in app.stream(inputs):
    # stream() yields dictionaries with output keyed by node name
    for key, value in output.items():
        print(f"Output from node '{key}':")
        print("---")
        print(value)
    print("\n---\n")

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in Lagos, Nigeria","return_direct":true}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'})]}

---

Output from node 'final':
---
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'Lagos\', \'region\': \'Lagos\', \'country\': \'Nigeria\', \'lat\': 6.45, \'lon\': 3.4, \'tz_id\': \'Africa/Lagos\', \'localtime_epoch\': 1712818511, \'localtime\': \'2024-04-11 7:55\'}, \'current\': {\'last_updated_epoch\': 1712817900, \'last_updated\': \'2024-04-11 07:45\', \'temp_c\': 26.0, \'temp_f\': 78.8, \'is_day\': 1, \'condition\': {\'text\': \'Partly cloudy\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/day/116.png\', \'code\': 1003}, \'wind_mph\': 8.1, \'wind_kph\': 13.0, \'wind_degree\': 320, \'wind_dir\': \'NW\', \'pressure_mb\': 1010.0, \'pres
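Since `stream()` yields dictionaries keyed by node name, you can tell which path a run took by checking the node that produced the last chunk. The following minimal sketch assumes the chunks have been collected into a list. The chunks here are hand-built placeholders with the same shape, not real results. `answered_directly` is a hypothetical helper.

```python
def answered_directly(chunks: list[dict]) -> bool:
    """True if the last node to run was `final`, i.e. the tool output was returned as-is."""
    if not chunks:
        return False
    last_node = list(chunks[-1])[-1]  # node-name key of the last streamed chunk
    return last_node == "final"


# Placeholder chunks shaped like {node_name: state_update}
direct_run = [
    {"agent": {"messages": ["<function_call>"]}},
    {"final": {"messages": ["<search results>"]}},
]
looped_run = [
    {"agent": {"messages": ["<function_call>"]}},
    {"action": {"messages": ["<search results>"]}},
    {"agent": {"messages": ["<answer>"]}},
]
print(answered_directly(direct_run), answered_directly(looped_run))  # True False
```

With the compiled graph, you would pass `list(app.stream(inputs))` instead of the placeholder lists.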