# LangGraph の基礎


In [None]:
from dotenv import load_dotenv

# Load variables from ../.env into the process environment, letting the file
# win over values that are already set (presumably the OpenAI / Tavily API keys).
load_dotenv("../.env", override=True)

## 単純なチャットボットの実装


In [None]:
from typing import Annotated
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages


class State(TypedDict):
    """Graph state for the chatbot: just the running conversation."""

    # The add_messages reducer merges each node's returned messages into this
    # list instead of overwriting it.
    messages: Annotated[list[BaseMessage], add_messages]

In [None]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")


def llm_node(state: State):
    """Send the whole conversation to the chat model and return its reply.

    The reply is wrapped in a one-element list so the ``messages`` reducer
    adds it to the state.
    """
    conversation = state["messages"]
    reply = llm.invoke(conversation)
    return {"messages": [reply]}

In [None]:
from langgraph.graph import StateGraph, START, END

graph_builder = StateGraph(State)
graph_builder.add_node("llm_node", llm_node)

# Linear flow: START -> llm_node -> END (one model call per invocation).
for edge_source, edge_target in ((START, "llm_node"), ("llm_node", END)):
    graph_builder.add_edge(edge_source, edge_target)

graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph's structure as a Mermaid diagram (PNG) inline.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
from langchain_core.messages import HumanMessage

# State.messages is declared as list[BaseMessage], so pass the first user
# message as a one-element list to match that shape.
initial_state = {"messages": [HumanMessage("こんにちは！")]}

In [None]:
# Run the graph once to completion; the final state is shown as cell output.
graph.invoke(input=initial_state)

## 単純なエージェントの実装


In [None]:
from langchain_community.tools.tavily_search import TavilySearchResults

# A single web-search tool, kept both in a list (for binding/dispatch) and as
# a standalone name.
tools = [TavilySearchResults()]
tool = tools[0]

In [None]:
from typing import Annotated
from typing_extensions import TypedDict

# Imported here too so this cell works on a fresh kernel without the
# chatbot section having been run first.
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages


class State(TypedDict):
    """Graph state for the tool-using agent: the running conversation."""

    # add_messages merges each node's returned messages (AI replies, tool
    # results) into this list instead of overwriting it.
    messages: Annotated[list[BaseMessage], add_messages]

In [None]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")
# Bind the tool definitions so the model can respond with tool calls.
llm_with_tools = llm.bind_tools(tools)


def llm_node(state: State):
    """Ask the tool-aware model for the next step given the conversation.

    The returned AI message may contain ``tool_calls``; it is wrapped in a
    list so the ``messages`` reducer adds it to the state.
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}

In [None]:
import json

from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool


class BasicToolNode:
    """Graph node that runs the tool calls requested by the latest message.

    Each entry in ``tool_calls`` on the last message of ``state["messages"]``
    is dispatched to the tool with the matching name. Its result is returned
    as a ``ToolMessage`` linked to the call via ``tool_call_id``.
    """

    def __init__(self, tools: list[BaseTool]) -> None:
        # Map tool name -> tool so calls can be dispatched by name
        # (a later tool with a duplicate name wins).
        self.tools_by_name = {tool.name: tool for tool in tools}

    def __call__(self, state: State):
        """Execute every tool call on the latest message.

        Returns:
            A state update ``{"messages": [ToolMessage, ...]}`` with one
            message per tool call, in the order the calls were issued.
        """
        latest_message = state["messages"][-1]

        tool_messages = []
        for tool_call in latest_message.tool_calls:
            tool = self.tools_by_name.get(tool_call["name"])
            if tool is None:
                # Report an unknown (e.g. hallucinated) tool name back to the
                # model instead of crashing the graph with a KeyError; every
                # tool_call_id still gets an answering ToolMessage.
                content = f"Error: unknown tool '{tool_call['name']}'"
            else:
                tool_result = tool.invoke(tool_call["args"])
                # ensure_ascii=False keeps non-ASCII text (e.g. Japanese search
                # results) readable instead of turning it into \uXXXX escapes.
                content = json.dumps(tool_result, ensure_ascii=False)
            tool_messages.append(
                ToolMessage(
                    content=content,
                    name=tool_call["name"],
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": tool_messages}


# Use the same `tools` list that was bound to the model, so every tool the
# model can call is also one this node can execute.
tool_node = BasicToolNode(tools=tools)

In [None]:
from langgraph.graph import StateGraph, START, END


graph_builder = StateGraph(State)
# Register the model-calling node and the tool-executing node.
for node_key, node_action in (("llm_node", llm_node), ("tool_node", tool_node)):
    graph_builder.add_node(node_key, node_action)


def route_tools(state: State):
    """Choose what runs after ``llm_node``.

    Returns ``"tool_node"`` when the newest message carries at least one tool
    call, otherwise ``END`` to finish the run.
    """
    newest = state["messages"][-1]
    if not hasattr(newest, "tool_calls"):
        return END
    if len(newest.tool_calls) > 0:
        return "tool_node"
    return END


# Map each value route_tools can return to the node it leads to.
route_targets = {
    "tool_node": "tool_node",
    END: END,
}
graph_builder.add_conditional_edges("llm_node", route_tools, route_targets)
# Tool results go back to the model; the run enters at the model.
graph_builder.add_edge("tool_node", "llm_node")
graph_builder.add_edge(START, "llm_node")
graph = graph_builder.compile()

In [None]:
from IPython.display import Image, display

# Render the agent graph (with its llm_node <-> tool_node loop) as a PNG.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
from langchain_core.messages import HumanMessage

# A plain greeting: the model should answer directly without calling a tool.
# The message is wrapped in a list to match State's list[BaseMessage] field.
initial_state = {"messages": [HumanMessage("こんにちは！")]}
graph.invoke(initial_state)

In [None]:
from langchain_core.messages import HumanMessage

# A question that needs fresh information, so the model is expected to call
# the search tool. The message is wrapped in a list to match State's
# list[BaseMessage] field.
initial_state = {"messages": [HumanMessage("東京の今日の天気は？")]}
graph.invoke(initial_state)

In [None]:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

initial_state = {"messages": [HumanMessage("東京の今日の天気は？")]}

# With stream_mode="updates", each event maps a node name to the state update
# that node just returned.
for event in graph.stream(initial_state, stream_mode="updates"):
    for value in event.values():
        if not value:
            # Nothing to show for a node that returned no update.
            continue
        # A node can emit several messages in one step (the tool node returns
        # one ToolMessage per tool call), so print all of them, not just the last.
        for message in value.get("messages", []):
            if isinstance(message, AIMessage):
                if message.tool_calls:
                    for tool_call in message.tool_calls:
                        print(
                            f"Tool call: name = {tool_call['name']}, args = {tool_call['args']}"
                        )
                else:
                    print(f"AI: {message.content}")
            elif isinstance(message, ToolMessage):
                print(f"Tool result: {message.content}")
            else:
                print(message)
