Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 69 additions & 7 deletions flocks/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
self._connected = False
self._transport_type: Optional[str] = None
self._command_queue: asyncio.Queue[_ClientCommand] | None = None
self._owner_loop: asyncio.AbstractEventLoop | None = None
self._owner_task: asyncio.Task[None] | None = None
self._owner_error: BaseException | None = None

Expand All @@ -140,6 +141,7 @@ async def connect(self) -> None:
loop = asyncio.get_running_loop()
startup_future: asyncio.Future[None] = loop.create_future()
self._owner_error = None
self._owner_loop = loop
self._command_queue = asyncio.Queue()

owner_task = asyncio.create_task(
Expand Down Expand Up @@ -225,6 +227,7 @@ def _reset_runtime_state(self, clear_owner_error: bool = False) -> None:
self._connected = False
self._transport_type = None
self._command_queue = None
self._owner_loop = None
if self._owner_task is not None and self._owner_task.done():
self._owner_task = None
if clear_owner_error:
Expand Down Expand Up @@ -628,16 +631,13 @@ async def disconnect(self) -> None:
return

try:
if owner_task is not None and not owner_task.done() and self._command_queue is not None:
response = asyncio.get_running_loop().create_future()
await self._command_queue.put(_ClientCommand(action="disconnect", response=response))
await response
if owner_task is not None and not owner_task.done() and self._connected and self._command_queue is not None:
await self._submit_command("disconnect")
elif owner_task is not None and not owner_task.done():
owner_task.cancel()
self._cancel_task_threadsafe(owner_task)

if owner_task is not None:
with contextlib.suppress(asyncio.CancelledError):
await owner_task
await self._await_task(owner_task)
except Exception as exc:
log.error("mcp.client.disconnect_error", {
"server": self.name,
Expand Down Expand Up @@ -753,6 +753,38 @@ async def _submit_command(self, action: str, **payload: Any) -> Any:
) from self._owner_error
raise RuntimeError(f"Client not connected: {self.name}")

owner_loop = self._owner_loop
if owner_loop is None:
raise RuntimeError(f"Client owner loop not initialized: {self.name}")

current_loop = asyncio.get_running_loop()
if current_loop is owner_loop:
return await self._submit_command_on_owner_loop(action, payload)

if not owner_loop.is_running():
owner_error = self._owner_error or RuntimeError(f"Client not connected: {self.name}")
raise owner_error

future = asyncio.run_coroutine_threadsafe(
self._submit_command_on_owner_loop(action, payload),
owner_loop,
)
return await asyncio.wrap_future(future)

async def _submit_command_on_owner_loop(
self,
action: str,
payload: Dict[str, Any],
) -> Any:
"""Submit a command from the owner loop and await the response."""
owner_task = self._owner_task
if self._command_queue is None or owner_task is None:
if self._owner_error is not None:
raise RuntimeError(
f"Client not connected: {self.name}: {_extract_root_cause(self._owner_error)}"
) from self._owner_error
raise RuntimeError(f"Client not connected: {self.name}")

response = asyncio.get_running_loop().create_future()
command = _ClientCommand(action=action, payload=payload, response=response)
await self._command_queue.put(command)
Expand All @@ -762,6 +794,36 @@ async def _submit_command(self, action: str, **payload: Any) -> Any:
response.set_exception(owner_error)

return await response

async def _await_task(self, task: asyncio.Task[Any]) -> None:
    """Wait for ``task`` to finish, regardless of which loop the caller is on.

    If no owner loop is recorded, or the caller is already running on the
    owner loop, the task is awaited directly and its cancellation is
    swallowed. Otherwise the wait is scheduled on the owner loop and bridged
    back to the calling loop. When the owner loop is not running there is
    nothing that could still drive the task, so the method returns at once.

    Args:
        task: The task to wait for.
    """
    home_loop = self._owner_loop
    calling_loop = asyncio.get_running_loop()

    if home_loop is not None and calling_loop is not home_loop:
        # Cross-loop caller: a task may only be awaited from its own loop,
        # so delegate the wait there and bridge the result back.
        if home_loop.is_running():
            bridge = asyncio.run_coroutine_threadsafe(
                self._await_task_on_owner_loop(task),
                home_loop,
            )
            await asyncio.wrap_future(bridge)
        return

    try:
        await task
    except asyncio.CancelledError:
        # A cancelled task is treated as a finished task.
        pass

async def _await_task_on_owner_loop(self, task: asyncio.Task[Any]) -> None:
    """Await ``task`` to completion, ignoring its cancellation.

    Meant to run on the loop that owns ``task``. Any exception other than
    ``asyncio.CancelledError`` raised by the task propagates to the caller.

    Args:
        task: The task to wait for.
    """
    try:
        await task
    except asyncio.CancelledError:
        # Cancellation is an expected way for the task to end.
        pass

def _cancel_task_threadsafe(self, task: asyncio.Task[Any]) -> None:
    """Request cancellation of ``task`` on the event loop that runs it.

    The cancel call is always scheduled through ``call_soon_threadsafe`` so
    the method is safe to call from the owner loop, another loop, or another
    thread. The request is asynchronous: the task is cancelled on a later
    iteration of its loop, not before this method returns.

    Args:
        task: The task to cancel.
    """
    target_loop = self._owner_loop
    if target_loop is None:
        # The recorded owner loop can be cleared while the owner task is
        # still running. Fall back to the loop the task is bound to, so the
        # cancellation is not silently dropped and a later wait on the task
        # does not hang.
        target_loop = task.get_loop()

    if not target_loop.is_running():
        # A stopped (or closed) loop cannot process the cancel callback.
        return
    target_loop.call_soon_threadsafe(task.cancel)

@property
def is_connected(self) -> bool:
Expand Down
45 changes: 43 additions & 2 deletions flocks/server/routes/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,54 @@ async def update_mcp_server(name: str, request: McpUpdateRequest):
_persist_mcp_server_config(name, clean_config)

status = await MCP.status()
if name in status:
previous_status = status.get(name)
was_connected = (
previous_status is not None
and previous_status.status == McpStatus.CONNECTED
)
should_reconnect = was_connected and clean_config.get("enabled", True) is not False

if previous_status is not None:
await MCP.remove(name)

reconnected = False
reconnect_error: Optional[str] = None
if should_reconnect:
reconnect_timeout_seconds = max(float(clean_config.get("timeout", 30.0) or 30.0), 1.0) + 2.0
try:
reconnected = await asyncio.wait_for(
MCP.connect(name, clean_config),
timeout=reconnect_timeout_seconds,
)
except asyncio.TimeoutError:
reconnect_error = (
f"Connection timed out while reconnecting MCP server: {name}"
)
except Exception as exc:
reconnect_error = str(exc)
if not reconnected and reconnect_error is None:
reconnect_status = (await MCP.status()).get(name)
reconnect_error = (
getattr(reconnect_status, "error", None)
if reconnect_status is not None
else None
) or f"Failed to reconnect MCP server: {name}"

message = f"MCP server '{name}' updated successfully."
if should_reconnect and reconnected:
message = f"MCP server '{name}' updated and reconnected successfully."
elif should_reconnect and reconnect_error:
message = (
f"MCP server '{name}' updated successfully, but reconnect failed: "
f"{reconnect_error}"
)

return {
"success": True,
"message": f"MCP server '{name}' updated successfully.",
"message": message,
"config": _to_frontend_mcp_config(clean_config),
"reconnected": reconnected,
"reconnect_error": reconnect_error,
}
except HTTPException:
raise
Expand Down
49 changes: 46 additions & 3 deletions flocks/workflow/service_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import asyncio
import json
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from flocks.mcp import MCP, get_manager
from flocks.utils.log import Log
from flocks.workflow.runner import RunWorkflowResult, run_workflow

log = Log.create(service="workflow.service_runtime")


class InvokeRequest(BaseModel):
"""Request payload for workflow invoke."""
Expand All @@ -31,22 +37,59 @@ def create_service_app(
release_id: str,
) -> FastAPI:
"""Build service app bound to one workflow snapshot."""
app = FastAPI(title="Flocks Workflow Service", version="0.2.0")
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Bring MCP up for the lifetime of the app and shut it down afterwards.

    The outcome of initialisation is recorded on ``_app.state``:
    ``mcp_ready`` becomes True only when ``MCP.init()`` succeeds, and
    ``mcp_error`` holds the failure text otherwise. Initialisation and
    shutdown failures are logged as warnings rather than raised.
    """
    app_state = _app.state
    app_state.mcp_ready = False
    app_state.mcp_error = None

    try:
        await MCP.init()
    except Exception as init_exc:
        app_state.mcp_error = str(init_exc)
        log.warning("workflow_service.mcp.init_failed", {"error": str(init_exc)})
    else:
        app_state.mcp_ready = True

    try:
        yield
    finally:
        # Shutdown is attempted even when initialisation failed.
        try:
            await get_manager().shutdown()
        except Exception as shutdown_exc:
            log.warning("workflow_service.mcp.shutdown_failed", {"error": str(shutdown_exc)})

app = FastAPI(title="Flocks Workflow Service", version="0.2.0", lifespan=lifespan)
app.state.workflow_json = workflow_json
app.state.workflow_id = workflow_id
app.state.release_id = release_id

@app.get("/health")
async def health() -> Dict[str, Any]:
return {
"ok": True,
payload = {
"ok": bool(app.state.mcp_ready),
"mcp_ready": bool(app.state.mcp_ready),
"mcp_error": app.state.mcp_error,
"workflow_id": app.state.workflow_id,
"release_id": app.state.release_id,
}
if app.state.mcp_ready:
return payload
return JSONResponse(status_code=503, content=payload)

@app.post("/invoke")
async def invoke(req: InvokeRequest) -> Dict[str, Any]:
started = time.time()
if not app.state.mcp_ready:
raise HTTPException(
status_code=503,
detail={
"request_id": req.request_id,
"workflow_id": app.state.workflow_id,
"release_id": app.state.release_id,
"status": "FAILED",
"error": app.state.mcp_error or "MCP subsystem is not ready",
"duration_ms": int((time.time() - started) * 1000),
},
)

try:
result: RunWorkflowResult = await asyncio.to_thread(
run_workflow,
Expand Down
119 changes: 119 additions & 0 deletions tests/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import threading
from contextlib import asynccontextmanager
from types import MethodType
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -176,3 +177,121 @@ async def broken_stdio(self, _server_params, stderr_file):
await client._connect_local(startup_future)

assert fake_stderr.closed is True


class TestMcpClientCrossLoopSubmission:
    """Tests for using an ``McpClient`` from a loop other than its owner loop.

    The first two tests wire the client's private runtime state by hand
    (``_connected``, ``_owner_loop``, ``_command_queue``, ``_owner_task``)
    instead of calling ``connect()``, and use a stand-in owner task that
    answers queued commands. The third test goes through ``connect()`` with
    the remote transport patched out.
    """

    @pytest.mark.asyncio
    async def test_call_tool_from_another_loop_reuses_owner_loop(self):
        """A call made from a foreign loop is served on the owner loop."""
        client = McpClient(
            name="demo",
            server_type="remote",
            url="https://example.com/mcp",
        )
        # The pytest-asyncio loop plays the role of the owner loop.
        owner_loop = asyncio.get_running_loop()
        client._connected = True
        client._owner_loop = owner_loop
        client._command_queue = asyncio.Queue()

        # Filled in by the stand-in owner so the test can inspect what it saw.
        observed: dict[str, object] = {}

        async def owner() -> None:
            # Stand-in owner task: answer each queued command, exit on "disconnect".
            while True:
                command = await client._command_queue.get()
                observed["owner_loop"] = asyncio.get_running_loop()
                observed["action"] = command.action
                observed["payload"] = dict(command.payload)
                if command.action == "disconnect":
                    if command.response is not None and not command.response.done():
                        command.response.set_result(None)
                    return
                if command.response is not None and not command.response.done():
                    command.response.set_result(
                        {
                            "ok": True,
                            "loop_matches": asyncio.get_running_loop() is owner_loop,
                            "payload": dict(command.payload),
                        }
                    )

        client._owner_task = asyncio.create_task(owner())

        # Run call_tool on a brand-new event loop inside a worker thread.
        result = await asyncio.to_thread(
            lambda: asyncio.run(client.call_tool("demo_tool", {"value": "x"}))
        )

        assert result == {
            "ok": True,
            "loop_matches": True,
            "payload": {"name": "demo_tool", "arguments": {"value": "x"}},
        }
        assert observed["owner_loop"] is owner_loop
        assert observed["action"] == "call_tool"

        # Let the stand-in owner exit so no task is left pending.
        await client.disconnect()

    @pytest.mark.asyncio
    async def test_disconnect_from_another_loop_finishes_owner_task(self):
        """Disconnecting from a foreign loop ends the owner task and resets state."""
        client = McpClient(
            name="demo",
            server_type="remote",
            url="https://example.com/mcp",
        )
        client._connected = True
        client._owner_loop = asyncio.get_running_loop()
        client._command_queue = asyncio.Queue()
        # threading.Event rather than asyncio.Event: it is set on this loop
        # while disconnect() runs on another loop in a worker thread.
        disconnect_seen = threading.Event()

        async def owner() -> None:
            # Stand-in owner task: only reacts to the "disconnect" command.
            while True:
                command = await client._command_queue.get()
                if command.action == "disconnect":
                    disconnect_seen.set()
                    if command.response is not None and not command.response.done():
                        command.response.set_result(None)
                    return

        owner_task = asyncio.create_task(owner())
        client._owner_task = owner_task

        # Run disconnect() on a brand-new event loop inside a worker thread.
        await asyncio.to_thread(lambda: asyncio.run(client.disconnect()))

        assert disconnect_seen.is_set() is True
        assert owner_task.done() is True
        assert client._owner_loop is None
        assert client._command_queue is None

    @pytest.mark.asyncio
    async def test_disconnect_while_connecting_cancels_owner_task(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ):
        """Disconnecting during connect cancels the owner task and fails connect()."""
        client = McpClient(
            name="demo",
            server_type="remote",
            url="https://example.com/mcp",
        )
        started = asyncio.Event()
        cancelled = asyncio.Event()

        async def fake_remote(startup_future):
            # Replacement transport: never completes startup, records cancellation.
            del startup_future
            started.set()
            try:
                await asyncio.Future()
            except asyncio.CancelledError:
                cancelled.set()
                raise

        monkeypatch.setattr(client, "_connect_remote", fake_remote)

        connect_task = asyncio.create_task(client.connect())
        # Wait until the patched transport is running before disconnecting.
        await started.wait()

        await client.disconnect()

        with pytest.raises(RuntimeError, match="Connection closed before initialization: demo"):
            await connect_task

        assert cancelled.is_set() is True
        assert client._owner_task is None
        assert client._command_queue is None
Loading