In [38]:
import os
from dotenv import load_dotenv

In [39]:
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
load_dotenv()

True

In [40]:
from langchain_core.pydantic_v1 import BaseModel, Field


class SearchTool(BaseModel):
    """Look up things online, optionally returning directly"""

    query: str = Field(description="query to look up online")
    return_direct: bool = Field(
        description="Whether or the result of this should be returned directly to the user without you seeing what it is",
        default=False,
    )

In [41]:
from langchain_community.tools.tavily_search import TavilySearchResults
tools = [TavilySearchResults(max_results=1,args_schema=SearchTool)]

In [42]:
from langgraph.prebuilt import ToolExecutor
tool_executor=ToolExecutor(tools)

In [43]:
from langchain_openai import ChatOpenAI
model=ChatOpenAI(temperature=0.7,streaming=True)
model=model.bind_tools(tools)

In [44]:
class Response(BaseModel):
    """Final response to the user"""

    temperature: float = Field(description="the temperature")
    other_notes: str = Field(description="any other notes about the weather")


model = model.bind_tools(tools + [Response])

In [45]:
from typing import TypedDict,Sequence,Annotated
import operator
from langchain_core.messages import BaseMessage

In [46]:
class agentstate(TypedDict):
    messages: Annotated[Sequence[BaseMessage],operator.add]

In [47]:
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import ToolInvocation
from typing import Literal

In [48]:
def should_continue(state) -> Literal["continue", "end"]:
    messages = state["messages"]
    last_message = messages[-1]
    if not last_message.tool_calls:
        return "end"
    if last_message.tool_calls[0]["name"] == "Response":
        return "end"
    return "continue"
    
def call_model(state):
    messages = state["messages"][-5:]
    response = model.invoke(messages)
    return {"messages": [response]}

def call_tool(state):
    messages = state["messages"]
    last_message = messages[-1]
    tool_invocations = []
    for tool_call in last_message.tool_calls:
        action = ToolInvocation(
            tool=tool_call["name"],
            tool_input=tool_call["args"],
        )
        tool_invocations.append(action)

    action = ToolInvocation(
        tool=tool_call["name"],
        tool_input=tool_call["args"],
    )
    responses = tool_executor.batch(tool_invocations, return_exceptions=True)
    tool_messages = [
        ToolMessage(
            content=str(response),
            name=tc["name"],
            tool_call_id=tc["id"],
        )
        for tc, response in zip(last_message.tool_calls, responses)
    ]
    return {"messages": tool_messages}

In [49]:
from langchain_core.messages import AIMessage


def first_model(state: agentstate):
    human_input = state["messages"][-1].content
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "tavily_search_results_json",
                        "args": {
                            "query": human_input,
                        },
                        "id": "tool_abcd123",
                    }
                ],
            )
        ]
    }

In [50]:
from langgraph.checkpoint.sqlite import SqliteSaver

In [54]:
from langgraph.graph import END, StateGraph, START
workflow = StateGraph(agentstate)

workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.add_node("final", call_tool)
workflow.add_edge(START, "first_model")

workflow.add_conditional_edges(
    "first_model",
    should_continue,
    {
        "continue": "action",
        "final": "final",
        "end": END,
    },
)

workflow.add_edge("action", "agent")
workflow.add_edge("final", END)
memory = SqliteSaver.from_conn_string(":memory:")
app = workflow.compile(checkpointer=memory)

ValueError: Found edge starting at unknown node 'first_model'

In [None]:
from langchain_core.messages import HumanMessage
import uuid
thread_id = str(uuid.uuid4())

inputs = {"messages": [HumanMessage(content="what is the weather in Jhelum,punjab,pakistan?")]}
config = {"configurable": {"thread_id": thread_id}}
while True:
    for output in app.stream(inputs, config):
        for key, value in output.items():
            print(f"Output from node '{key}':")
            print("---")
            print(value)
        print("\n---\n")
    snapshot = app.get_state(config)
    if not snapshot.next:
        break
    inputs = None
    response = input(
        "Do you approve the next step? Type y if you do, anything else to stop: "
    )
    if response != "y":
        break

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_U5TR77CYsr5JCh9IvVicntJ5', 'function': {'arguments': '{"query":"weather in Jhelum, Punjab, Pakistan"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'model_name': 'gpt-3.5-turbo-0125'}, id='run-667a4db3-f63e-4c31-88a4-19864a1cdec7-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in Jhelum, Punjab, Pakistan'}, 'id': 'call_U5TR77CYsr5JCh9IvVicntJ5', 'type': 'tool_call'}])]}

---

Output from node 'action':
---
{'messages': [ToolMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'Jhelum\', \'region\': \'Punjab\', \'country\': \'Pakistan\', \'lat\': 32.93, \'lon\': 73.73, \'tz_id\': \'Asia/Karachi\', \'localtime_epoch\': 1721645084, \'localtime\': \'2024-07-22 15:44\'}, \'current\': {\'last_updated_epoch\': 1721644200,