In [8]:
import asyncio
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage

from mcp import ClientSession
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client

import dotenv
dotenv.load_dotenv()

True

In [9]:
class State(TypedDict):
    messages: list

In [10]:
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0)

In [11]:
async def connect_to_servers():
    sessions = []

    # stdio servers
    stdio_servers = [
        ["python", "add_server.py"],
        ["python", "multiply_server.py"],
    ]

    for cmd in stdio_servers:
        transport = await stdio_client(cmd)
        session = await ClientSession.create(transport)
        sessions.append(session)

    # streamable-http server
    http_transport = await streamable_http_client(
        "http://127.0.0.1:8000"
    )
    http_session = await ClientSession.create(http_transport)
    sessions.append(http_session)

    return sessions

In [12]:
async def agent_node(state: State):
    sessions = await connect_to_servers()

    tools = []
    for s in sessions:
        tools.extend(await s.list_tools())

    model = llm.bind_tools(tools)

    response = await model.ainvoke(state["messages"])

    return {"messages": state["messages"] + [response]}

In [13]:
builder = StateGraph(State)
builder.add_node("agent", agent_node)
builder.add_edge(START, "agent")
builder.add_edge("agent", END)

graph = builder.compile()

In [18]:
async def main():
    result = await graph.ainvoke(
        {
            "messages": [
                HumanMessage(
                    content="What is 4 + 6? Then multiply result by 5. Also tell weather."
                )
            ]
        }
    )

    for m in result["messages"]:
        print(m)

await main()

TypeError: object _AsyncGeneratorContextManager can't be used in 'await' expression