In [None]:
import initialize_notebook # noqa

# MCP Server

In [None]:
from mcp.server import FastMCP

# Module-level MCP server instance; the tool functions in this cell are
# decorated with `SERVER.tool()`, and a later cell serves it over HTTP.
SERVER = FastMCP()


@SERVER.tool()
def multiply(a: float, b: float) -> float:
    """Multiplies 2 numbers together."""
    # NOTE(review): the docstring above is presumably published by FastMCP as
    # the tool description shown to the model — keep its wording stable.
    product = a * b
    return product


@SERVER.tool()
def add(a: float, b: float) -> float:
    """Adds 2 numbers together."""
    # NOTE(review): the docstring above is presumably published by FastMCP as
    # the tool description shown to the model — keep its wording stable.
    total = a + b
    return total


In [None]:
import socket
import threading
import time

import uvicorn

PORT = 5000
HOST = "localhost"
# How long to wait for the background server to start accepting connections.
STARTUP_TIMEOUT_SECONDS = 10.0

RUN_ARGS = {
    # `streamable_http_app` is passed uncalled, i.e. as a factory that builds
    # the ASGI app. Declaring that explicitly stops uvicorn from having to
    # guess (and from logging its "ASGI app factory detected" warning).
    "app": SERVER.streamable_http_app,
    "factory": True,
    "port": PORT,
    "host": HOST,
}

# Daemon thread: the server must not keep the kernel process alive once the
# notebook is shut down.
MCP_THREAD = threading.Thread(target=uvicorn.run, kwargs=RUN_ARGS, daemon=True)
MCP_THREAD.start()

# Block until the port accepts TCP connections, so that "Run All" cannot reach
# the client cells before the server is listening.
_startup_deadline = time.monotonic() + STARTUP_TIMEOUT_SECONDS
while True:
    try:
        with socket.create_connection((HOST, PORT), timeout=0.5):
            break
    except OSError:
        if not MCP_THREAD.is_alive():
            raise RuntimeError(
                f"MCP server thread exited before listening on {HOST}:{PORT}"
            ) from None
        if time.monotonic() >= _startup_deadline:
            raise TimeoutError(
                f"MCP server did not start listening on {HOST}:{PORT} "
                f"within {STARTUP_TIMEOUT_SECONDS} seconds"
            ) from None
        time.sleep(0.1)

# MCP Client

In [None]:
import json
import random

from hslu.dlm03.common import backend

# Client object and model identifier used by the chat-completions calls in the
# agent loop further down (`model_client.chat.completions.create(model=model_name, ...)`).
model_client, model_name = backend.Gemini2p5Flash().get_client()

In [None]:
def tool_from_mcp(tool):
    """Translate an MCP tool descriptor into a chat-completions tool entry.

    Args:
        tool: Object exposing `name`, `description` and `inputSchema`
            attributes.

    Returns:
        A dict of the form `{"type": "function", "function": {...}}` whose
        inner dict carries the tool's name, description, its input schema
        under `parameters`, and `strict=True`.
    """
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.inputSchema,
        "strict": True,
    }
    return {"type": "function", "function": function_spec}

def tool_call_result_from_mcp(call_id, content):
    content_type = content.type
    match content_type:
        case "text":
            return dict(
                role="tool",
                tool_call_id=call_id,
                content=content.text,
            )
        case "resource":
            resource = content.resource
            mime_type = resource.mimeType.split(";")[0]
            match mime_type:
                case "text/plain":
                    return dict(
                        role="tool",
                        tool_call_id=call_id,
                        content=resource.text,
                    )
                case _:
                    error_message = f"Unsupported resource mime type: {mime_type}"
                    raise ValueError(error_message)
        case _:
            error_message = f"Invalid content type: {content_type}"
            raise ValueError(error_message)

In [None]:
import mcp
from mcp.client import streamable_http

async with (streamable_http.streamablehttp_client(f"http://{HOST}:{PORT}/mcp") as (read_stream, write_stream, _),
            mcp.ClientSession(read_stream, write_stream) as session):
    await session.initialize()
    mcp_tools = await session.list_tools()
    tools = [tool_from_mcp(tool) for tool in mcp_tools.tools]

In [None]:
QUERY = "What is (2 + 3) * 4?"

async with (streamable_http.streamablehttp_client(f"http://{HOST}:{PORT}/mcp") as (read_stream, write_stream, _),
            mcp.ClientSession(read_stream, write_stream) as session):
    await session.initialize()
    messages = [
        {"role": "user", "content": QUERY},
    ]
    done = False
    while not done:
        response = model_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
        )
        done = True
        message = random.choice(response.choices).message
        messages.append(message)
        print(message)
        if message.content:
            print(message.content)
        if message.tool_calls:
            for tool_call in message.tool_calls:
                done = False
                tool_name = tool_call.function.name
                arguments = json.loads(tool_call.function.arguments)
                tool_call_result = await session.call_tool(tool_name, arguments)
                for content in tool_call_result.content:
                    content = tool_call_result_from_mcp(
                        tool_call.id,
                        content,
                    )
                    messages.append(content)