# ライブラリのインポート

In [None]:
import os
import json
import operator
import pkg_resources
from typing import TypedDict, Annotated, Sequence
# Lang関連
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolExecutor,ToolInvocation,ToolNode
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage,ToolMessage
from langchain_google_genai import ChatGoogleGenerativeAI

In [None]:

# Print the installed versions of the LangChain-related packages so the
# notebook output records the environment it was executed with.
# Each package is reported exactly once.
# NOTE(review): pkg_resources is deprecated by setuptools;
# importlib.metadata.version() is the standard-library replacement.
langgraph_version = pkg_resources.get_distribution("langgraph").version
print(f"langgraph version: {langgraph_version}")

langchain_community_version = pkg_resources.get_distribution("langchain_community").version
print(f"langchain_community version: {langchain_community_version}")

langchain_core_version = pkg_resources.get_distribution("langchain_core").version
print(f"langchain_core version: {langchain_core_version}")

langchain_google_genai_version = pkg_resources.get_distribution("langchain_google_genai").version
print(f"langchain_google_genai version: {langchain_google_genai_version}")

# モデルの定義

In [None]:
# Tool that lets the agent search the web (Tavily), limited to one result
# per query.
tools = [TavilySearchResults(max_results=1)]

# Gemini chat model used by the agent node.
model = ChatGoogleGenerativeAI(
    model="models/gemini-1.5-flash",
)
# Bind the tool schemas to the model so it can request tool calls.
# Rebinding the name is intentional: call_model invokes `model`, which must
# be the tool-aware runnable.
model = model.bind_tools(tools)


# グラフの定義

### 保持すべき状態を定義

In [None]:
class AgentState(TypedDict):
    """State carried between the nodes of the agent graph.

    The only field is ``messages``, the running conversation. Its
    ``operator.add`` annotation acts as the reducer: message lists returned
    by nodes are concatenated onto the existing sequence rather than
    replacing it.
    """

    # Conversation history; nodes return new messages to be appended.
    messages: Annotated[Sequence[BaseMessage], operator.add]

### ノードの定義

In [None]:
# Routing function: decides whether the graph should continue to a tool.
def should_continue(state):
    """Choose the next step after the agent node.

    Args:
        state: graph state whose ``'messages'`` entry is a non-empty sequence
            of messages exposing an ``additional_kwargs`` mapping.

    Returns:
        ``"continue"`` if the newest message carries a ``"function_call"``
        entry in its ``additional_kwargs`` (the model requested a tool),
        otherwise ``"end"``.
    """
    newest = state['messages'][-1]
    wants_tool = "function_call" in newest.additional_kwargs
    return "continue" if wants_tool else "end"


# Agent node: asks the tool-bound model for its next message.
def call_model(state):
    """Send the whole conversation to ``model`` and return its reply.

    Returns a partial state update holding only the new reply; the
    ``operator.add`` reducer on ``AgentState.messages`` appends it to the
    history.
    """
    history = state['messages']
    return {"messages": [model.invoke(history)]}

def call_tool(state):
    """Run the tool requested by the latest message and report its output.

    The last message's ``additional_kwargs["function_call"]`` must hold a
    ``"name"`` and a JSON-encoded ``"arguments"`` string. The first entry of
    the module-level ``tools`` list whose ``name`` matches is executed with
    the decoded arguments.

    Args:
        state: graph state; ``state['messages'][-1]`` must carry a
            ``function_call`` entry.

    Returns:
        ``{"messages": [FunctionMessage]}`` whose content is the stringified
        tool result, or an error string if no tool has the requested name.
    """
    call = state['messages'][-1].additional_kwargs["function_call"]
    requested = call["name"]
    arguments = json.loads(call["arguments"])

    # First matching tool wins, mirroring a linear search with `break`.
    selected = next((candidate for candidate in tools if candidate.name == requested), None)
    if selected is None:
        output = f"Error: Tool '{requested}' not found"
    else:
        output = selected.run(arguments)

    return {"messages": [FunctionMessage(content=str(output), name=requested)]}

### ノードの追加とエッジの定義

In [None]:
# Build the agent loop: agent -> (tool requested?) -> action -> agent ...
workflow = StateGraph(AgentState)

# "agent" queries the model; "action" executes the requested tool.
for node_name, node_fn in (("agent", call_model), ("action", call_tool)):
    workflow.add_node(node_name, node_fn)

# Every run starts at the agent node.
workflow.set_entry_point("agent")

# Translate should_continue's return value into the next node.
route_map = {"continue": "action", "end": END}
workflow.add_conditional_edges("agent", should_continue, route_map)

# After a tool runs, hand its output back to the model.
workflow.add_edge('action', 'agent')

graph = workflow.compile()

### モデルの呼び出し

In [None]:
# Ask the agent a question ("Tell me tomorrow's weather in Tokyo") and run
# the graph until should_continue routes to END; print the final state.
question = HumanMessage(content="明日の東京の天気を教えて")
inputs = {"messages": [question]}
result = graph.invoke(inputs)
print(result)

### resultの中身を確認

In [None]:
# Print every message in the final state, each preceded by a separator line.
for message in result['messages']:
    print("-" * 100)
    # print() calls the message's __str__, rendering its fields rather than
    # a bound-method object.
    print(message)
