From bf137d841eb2ee90197e1dce66b1e4bde1924c58 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 11:08:55 -0700 Subject: [PATCH 1/8] feat(invoke): selector-based invoke and invoke_many with sync fan-out Add two selector-driven invocation tools that replace the legacy invoke_device(device_id, function, params) shape: - invoke(selector, params, llm_reasoning) resolves a function-scoped selector to exactly one (device, function) tuple and calls it. Returns {success, device_id, function, result|error}. Returns no_match, ambiguous_match, invalid_invoke_scope, or invalid_selector errors as structured envelopes when the selector does not resolve cleanly. - invoke_many(selector, params, timeout, max_concurrency, llm_reasoning) resolves to N (device, function) tuples and fans out the calls in parallel via a thread pool. Partial-failure semantics: a single target's failure does not abort siblings. Returns {candidates, matched, succeeded, failed, results, errors} with per-target structured errors. Per-target timeout defaults to 30s. invoke_device gains a DeprecationWarning pointing to invoke(); the function still works for one release while callers migrate. Adapters (Claude Agent SDK, Strands, LangChain, the in-tree StrandsOpenAIDeviceConnectAgent, and the operator-facing AGENT_SCRIPT template) drop invoke_device and expose invoke / invoke_many instead. invoke_device_with_fallback stays unchanged -- it covers a different ergonomic case (try a list of device ids in order) with no selector equivalent. 22 unit tests cover scope rejection, ambiguous and zero matches, JSON-RPC error mapping, partial failure, per-target timeout propagation, and llm_reasoning stripping. 9 integration tests cover single-target invoke, robot dispatch through to event emission, fan-out across multiple cameras, partial failure, and zero-candidate empty envelopes. --- .../device_connect_agent_tools/__init__.py | 32 +- .../adapters/claude.py | 62 +++- .../adapters/langchain.py | 23 +- .../adapters/strands.py | 23 +- .../adapters/strands_agent.py | 15 +- .../device_connect_agent_tools/tools.py | 285 ++++++++++++++- .../tests/test_claude_adapter.py | 3 +- .../tests/test_invoke.py | 336 ++++++++++++++++++ .../tests/test_langchain_adapter.py | 3 +- .../tests/test_strands_adapter.py | 3 +- .../portal/views/devices.py | 16 +- tests/tests/test_tools_invoke.py | 256 ++++++++++--- 12 files changed, 941 insertions(+), 116 deletions(-) create mode 100644 packages/device-connect-agent-tools/tests/test_invoke.py diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index de79913..c809baa 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -4,19 +4,21 @@ """Device Connect Tools — framework-agnostic SDK for Device Connect IoT. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: - from device_connect_agent_tools import connect, discover, discover_labels + from device_connect_agent_tools import connect, discover, discover_labels, invoke connect() vocab = discover_labels() # fleet vocabulary cams = discover("device(category:camera, location:zone-A/*)") # device roster writes = discover("device(*).function(direction:write)") # function tuples - result = invoke_device("camera-001", "capture_image", {"resolution": "1080p"}) + result = invoke("device(camera-001).function(capture_image)", + {"resolution": "1080p"}) -The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` -trio remains available for one release as advisory-deprecated wrappers -- -prefer ``discover`` / ``discover_labels`` for new code. +The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` / +``invoke_device`` family remains available for one release as +advisory-deprecated wrappers -- prefer ``discover`` / ``discover_labels`` / +``invoke`` / ``invoke_many`` for new code. """ from device_connect_agent_tools.agent import DeviceConnectAgent @@ -25,14 +27,17 @@ # Selector-driven discovery (preferred) discover, discover_labels, - # Invocation - invoke_device, + # Selector-driven invocation (preferred) + invoke, + invoke_many, + # Other invocation helpers invoke_device_with_fallback, get_device_status, - # Advisory-deprecated discovery wrappers (one-release transition) + # Advisory-deprecated wrappers (one-release transition) describe_fleet, list_devices, get_device_functions, + invoke_device, discover_devices, ) @@ -46,13 +51,16 @@ # Selector-driven discovery (preferred) "discover", "discover_labels", - # Invocation - "invoke_device", + # Selector-driven invocation (preferred) + "invoke", + "invoke_many", + # Other invocation helpers "invoke_device_with_fallback", "get_device_status", - # Advisory-deprecated -- use discover() / discover_labels() instead + # Advisory-deprecated -- use discover / discover_labels / invoke instead "describe_fleet", "list_devices", "get_device_functions", + "invoke_device", "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 807abcb..9dd08d8 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -4,7 +4,7 @@ """Claude Agent SDK adapter — exposes Device Connect tools to claude-agent-sdk. -Selector-driven discovery keeps LLM context small:: +Selector-driven discovery and invocation keep LLM context small:: import anyio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions, AssistantMessage, TextBlock @@ -45,7 +45,8 @@ async def main(): discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -101,27 +102,54 @@ async def discover(args: dict[str, Any]) -> dict[str, Any]: ) -# Invocation tools +# Selector-driven invocation tools (recommended) @tool( - "invoke_device", - "Call a function on a Device Connect device. Use discover() with a " - "function-scoped selector first to learn available functions and " - "parameters.", - {"device_id": str, "function": str, "params": dict, "llm_reasoning": str}, + "invoke", + "Call exactly one function on one device. The selector must resolve " + "to a single (device, function) tuple -- use device().function() " + "or function() scope. Returns {success, device_id, function, " + "result|error}. Use invoke_many for fan-out across multiple targets.", + {"selector": str, "params": dict, "llm_reasoning": str}, ) -async def invoke_device(args: dict[str, Any]) -> dict[str, Any]: +async def invoke(args: dict[str, Any]) -> dict[str, Any]: return _text( - _invoke_device( - device_id=args["device_id"], - function=args["function"], + _invoke( + selector=args["selector"], + params=args.get("params"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "invoke_many", + "Fan out a function call over a selector-resolved set of (device, " + "function) tuples in parallel. Partial-failure semantics: per-target " + "results and errors are returned even if some targets fail. Returns " + "{candidates, matched, succeeded, failed, results, errors}. Each " + "target gets a per-call timeout (default 30s).", + { + "selector": str, "params": dict, "timeout": float, + "max_concurrency": int, "llm_reasoning": str, + }, +) +async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _invoke_many( + selector=args["selector"], params=args.get("params"), + timeout=float(args.get("timeout", 30.0)), + max_concurrency=int(args.get("max_concurrency", 32)), llm_reasoning=args.get("llm_reasoning"), ) ) +# Other invocation helpers + + @tool( "invoke_device_with_fallback", "Call a function with automatic fallback across a list of device IDs. " @@ -148,12 +176,12 @@ async def get_device_status(args: dict[str, Any]) -> dict[str, Any]: return _text(_get_device_status(device_id=args["device_id"])) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) @tool( "discover_devices", - "Deprecated — use discover() and discover_labels() instead. Discovers " + "Deprecated -- use discover() and discover_labels() instead. Discovers " "all devices with full function schemas.", {"device_type": str, "refresh": bool}, ) @@ -176,7 +204,8 @@ def create_device_connect_server(name: str = "device-connect"): tools=[ discover_labels, discover, - invoke_device, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, discover_devices, @@ -187,7 +216,8 @@ def create_device_connect_server(name: str = "device-connect"): __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index f934024..35d5e51 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -4,16 +4,16 @@ """LangChain adapter — wraps Device Connect tools as LangChain StructuredTools. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.langchain import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from langgraph.prebuilt import create_react_agent connect() - agent = create_react_agent(model, [discover_labels, discover, invoke_device]) + agent = create_react_agent(model, [discover_labels, discover, invoke, invoke_many]) Requires: pip install device-connect-agent-tools[langchain] """ @@ -24,27 +24,32 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = StructuredTool.from_function(_discover_labels) discover = StructuredTool.from_function(_discover) -# Invocation tools -invoke_device = StructuredTool.from_function(_invoke_device) +# Selector-driven invocation (recommended) +invoke = StructuredTool.from_function(_invoke) +invoke_many = StructuredTool.from_function(_invoke_many) + +# Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) get_device_status = StructuredTool.from_function(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = StructuredTool.from_function(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index 848f362..d22fcf7 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -4,16 +4,16 @@ """Strands adapter — wraps Device Connect tools with @strands.tool. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.strands import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from strands import Agent connect() - agent = Agent(tools=[discover_labels, discover, invoke_device]) + agent = Agent(tools=[discover_labels, discover, invoke, invoke_many]) agent("What devices are online?") Requires: pip install device-connect-agent-tools[strands] @@ -25,27 +25,32 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = strands_tool(_discover_labels) discover = strands_tool(_discover) -# Invocation tools -invoke_device = strands_tool(_invoke_device) +# Selector-driven invocation (recommended) +invoke = strands_tool(_invoke) +invoke_many = strands_tool(_invoke_many) + +# Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) get_device_status = strands_tool(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = strands_tool(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py index a3f0cf5..c5f5e67 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py @@ -62,7 +62,8 @@ async def prepare(self) -> Dict[str, Any]: from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, ) @@ -74,7 +75,7 @@ async def prepare(self) -> Dict[str, Any]: model=AnthropicModel(model_id=self._model_id, max_tokens=self._max_tokens), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=system_prompt, ) @@ -120,14 +121,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\n" f" device(category:camera, location:zone-A/*)\n" f" device(robot-001).function(direction:write)\n" - f" function(safety:critical)\n" - f" - invoke_device(device_id, function, params) -- call a device function\n\n" + f" function(safety:critical)\n\n" + f"INVOCATION TOOLS:\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\n\n" f"INSTRUCTIONS:\n" f"When you receive device events, you MUST:\n" f"1. Analyze the events\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\n" - f"3. Use invoke_device() to interact with devices\n" + f"3. Use invoke() or invoke_many() to interact with devices\n" f"4. Report what you found and what actions you took\n\n" f"Always provide llm_reasoning when invoking devices to explain your decision.\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index db71bc2..528e554 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -463,6 +463,269 @@ def discover_labels( return out +# ── Selector-driven operations ─────────────────────────────────── + + +# Default per-target timeout for invoke_many fan-out. Configurable per call. +DEFAULT_INVOKE_TIMEOUT = 30.0 + +# Cap on parallel worker threads for invoke_many fan-out. Larger fleets can +# raise this via the ``max_concurrency`` argument; the default keeps thread +# overhead bounded while still parallelising typical 10-100 device fan-outs. +DEFAULT_INVOKE_CONCURRENCY = 32 + + +def _resolve_function_tuples( + selector: str, +) -> tuple[list[dict] | None, dict[str, Any] | None]: + """Resolve a selector to (device_id, function_name) tuples for invocation. + + Walks pagination so callers do not have to. Returns ``(rows, None)`` on + success or ``(None, error_envelope)`` if the selector failed to parse, + used a non-function scope, or the registry was unreachable. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in ( + Scope.DEVICE_FUNCTION.value, Scope.FUNCTION_ONLY.value, + ): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_invoke_scope", + "invoke/invoke_many require a function-scoped selector " + "(device(...).function(...) or function(...)); got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + return rows, None + + +def _shape_invoke_response( + response: dict[str, Any], + device_id: str, + function_name: str, +) -> dict[str, Any]: + """Normalize a JSON-RPC response into a {success, result|error} envelope. + + JSON-RPC error objects arrive as ``{"code": int, "message": str}`` from + the wire; this maps them to the structured ``{code: str, message: str}`` + error shape that the rest of the agent surface uses. + """ + if "error" in response: + err = response["error"] + if isinstance(err, dict): + code = str(err.get("code", "invoke_failed")) + message = str(err.get("message", err)) + else: + code, message = "invoke_failed", str(err) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": {"code": code, "message": message}, + } + return { + "success": True, + "device_id": device_id, + "function": function_name, + "result": response.get("result", {}), + } + + +def invoke( + selector: str, + params: dict[str, Any] | None = None, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to one (device, function) tuple and invoke it. + + Use this when the call is unambiguous -- one device, one function. + The selector must use ``device().function()`` or + ``function()`` scope. + + Args: + selector: Selector expression resolving to exactly one function. + params: Function parameters dict. Do NOT put ``llm_reasoning`` + inside ``params``. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"success": True, "device_id": ..., "function": ..., + "result": ...}``. + On failure: ``{"success": False, "error": {"code": ..., + "message": ...}}``. Codes include the discover() codes plus + ``no_match`` (zero matches), ``ambiguous_match`` (multiple + matches), ``invalid_invoke_scope`` (selector did not target + functions), and ``invoke_failed`` (the device returned an error). + """ + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"success": False, "error": error_envelope["error"]} + + if not rows: + return { + "success": False, + "error": _error( + "no_match", + f"selector matched 0 functions: {selector!r}", + ), + } + if len(rows) > 1: + return { + "success": False, + "error": _error( + "ambiguous_match", + f"selector matched {len(rows)} functions, expected exactly 1: " + f"{selector!r}", + ), + "candidates": [ + {"device_id": r.get("device_id"), "function": r.get("name")} + for r in rows[:10] + ], + } + + row = rows[0] + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + + trace_id = f"trace-{uuid.uuid4().hex[:12]}" + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[%s] [%s::%s] Reason: %s", + trace_id, device_id, function_name, truncated, + ) + + try: + conn = get_connection() + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + response = conn.invoke(device_id, function_name, params=clean) + except Exception as e: + logger.error( + "[%s] %s::%s -> ERROR: %s", + trace_id, device_id, function_name, e, + ) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": _error("invoke_failed", str(e)), + } + return _shape_invoke_response(response, device_id, function_name) + + +def invoke_many( + selector: str, + params: dict[str, Any] | None = None, + timeout: float = DEFAULT_INVOKE_TIMEOUT, + max_concurrency: int = DEFAULT_INVOKE_CONCURRENCY, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to (device, function) tuples and invoke each in parallel. + + Returns aggregated results with partial-failure semantics: a single + target's failure does not abort the rest. Each target gets ``timeout`` + seconds; the overall call returns once every target has finished or + timed out. + + Args: + selector: Function-scoped selector + (``device(...).function(...)`` or ``function(...)``). + params: Function parameters dict applied to every target. + timeout: Per-target timeout in seconds. + max_concurrency: Cap on parallel worker threads. + llm_reasoning: Decision rationale for observability. + + Returns: + ``{"candidates": N, "matched": N, "succeeded": S, "failed": F, + "results": [{device_id, function, result}, ...], + "errors": [{device_id, function, error}, ...]}``. + + ``candidates`` is the count returned by the selector resolver. + ``matched`` is the same value in this release; once edge-side + ``where`` predicates land, ``matched`` will narrow below + ``candidates`` to reflect post-predicate self-election. + + On selector parse / connection failure the envelope is returned + with all counts at zero plus a top-level ``error`` field. + """ + import concurrent.futures + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return { + "candidates": 0, "matched": 0, "succeeded": 0, "failed": 0, + "results": [], "errors": [], "error": error_envelope["error"], + } + + out: dict[str, Any] = { + "candidates": len(rows), + "matched": len(rows), + "succeeded": 0, + "failed": 0, + "results": [], + "errors": [], + } + if not rows: + return out + + workers = max(1, min(max_concurrency, len(rows))) + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + + def call_one(row: dict) -> dict[str, Any]: + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + try: + conn = get_connection() + response = conn.invoke( + device_id, function_name, params=clean, timeout=timeout, + ) + except Exception as e: + response = {"error": {"code": "invoke_failed", "message": str(e)}} + return _shape_invoke_response(response, device_id, function_name) + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[invoke_many::%d targets] Reason: %s", len(rows), truncated, + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as exe: + futures = [exe.submit(call_one, row) for row in rows] + for future in concurrent.futures.as_completed(futures): + shaped = future.result() + if shaped["success"]: + out["results"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "result": shaped["result"], + }) + out["succeeded"] += 1 + else: + out["errors"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "error": shaped["error"], + }) + out["failed"] += 1 + return out + + # ── Hierarchical discovery tools ───────────────────────────────── @@ -650,22 +913,20 @@ def invoke_device( params: dict[str, Any] | None = None, llm_reasoning: str | None = None, ) -> dict[str, Any]: - """Call a function on a Device Connect device. + """Call a function on a Device Connect device (deprecated; use invoke()). Args: device_id: Target device ID (e.g., "robot-001", "camera-001"). - function: Function name to call (e.g., "start_cleaning", "capture_image"). - params: Function parameters as a dictionary. Check get_device_functions() for schemas. - Do NOT put llm_reasoning inside params. - llm_reasoning: Why you're calling this function — for observability. - - Example: - result = invoke_device( - device_id="robot-001", function="start_cleaning", - params={"zone": "zone-A"}, - llm_reasoning="Camera detected spill in zone-A" - ) + function: Function name to call. + params: Function parameters as a dictionary. + llm_reasoning: Why you're calling this function -- for observability. """ + warnings.warn( + "invoke_device(device_id, function, ...) is deprecated; use " + "invoke('device().function()', params) instead.", + DeprecationWarning, + stacklevel=2, + ) trace_id = f"trace-{uuid.uuid4().hex[:12]}" if llm_reasoning: truncated = llm_reasoning[:200] + "..." if len(llm_reasoning) > 200 else llm_reasoning diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index b0e2ac6..311aab5 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -68,7 +68,8 @@ def _mock_sdk_and_connection(): "discover_labels", "discover", "discover_devices", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", ) diff --git a/packages/device-connect-agent-tools/tests/test_invoke.py b/packages/device-connect-agent-tools/tests/test_invoke.py new file mode 100644 index 0000000..aae1a83 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_invoke.py @@ -0,0 +1,336 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``invoke`` and ``invoke_many`` tools. + +Uses a small labeled fleet (cam-001, cam-002, robot-001, sensor-001) drawn +from the existing DC test driver vocabulary so every selector exercises +real device, function, and event names. +""" +import time +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixtures ------------------------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "lab-A"}, + "functions": [ + { + "name": "dispatch_robot", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +def _conn_with_invoke(invoke_side_effect): + """Return a mock Connection whose .invoke() applies ``invoke_side_effect``. + + ``invoke_side_effect`` is called with ``(device_id, function_name, + params, timeout)`` and must return a JSON-RPC response dict. + """ + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + + def _invoke(device_id, function_name, params=None, timeout=None): + return invoke_side_effect(device_id, function_name, params, timeout) + + conn.invoke.side_effect = _invoke + return conn + + +@pytest.fixture +def all_succeed_conn(): + def _ok(device_id, function_name, params, timeout): + return {"jsonrpc": "2.0", "id": "1", "result": { + "device_id": device_id, "function": function_name, "params": params, + }} + conn = _conn_with_invoke(_ok) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- invoke --------------------------------------------------------- + + +class TestInvoke: + def test_single_match_returns_success(self, all_succeed_conn): + r = tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p"}, + ) + assert r["success"] is True + assert r["device_id"] == "cam-001" + assert r["function"] == "capture_image" + assert r["result"]["params"] == {"resolution": "1080p"} + + def test_function_only_selector_with_unique_name(self, all_succeed_conn): + r = tools_mod.invoke("function(get_reading)") + assert r["success"] is True + assert r["device_id"] == "sensor-001" + assert r["function"] == "get_reading" + + def test_no_match_returns_no_match_error(self, all_succeed_conn): + r = tools_mod.invoke("device(*).function(does_not_exist)") + assert r["success"] is False + assert r["error"]["code"] == "no_match" + assert "does_not_exist" in r["error"]["message"] + + def test_ambiguous_match_returns_error_with_candidates(self, all_succeed_conn): + # capture_image exists on both cam-001 and cam-002. + r = tools_mod.invoke("function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "ambiguous_match" + assert "expected exactly 1" in r["error"]["message"] + ids = {c["device_id"] for c in r["candidates"]} + assert ids == {"cam-001", "cam-002"} + + def test_device_only_scope_rejected(self, all_succeed_conn): + # Device-only scope cannot resolve to a function. + r = tools_mod.invoke("device(robot-001)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_event_scope_rejected(self, all_succeed_conn): + r = tools_mod.invoke("event(reading)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke("not a selector") + assert r["success"] is False + assert r["error"]["code"] == "selector_parse_error" + + def test_non_string_selector_rejected(self, all_succeed_conn): + r = tools_mod.invoke(None) # type: ignore[arg-type] + assert r["success"] is False + assert r["error"]["code"] == "invalid_selector" + + def test_jsonrpc_error_maps_to_invoke_failed(self): + def _err(device_id, function_name, params, timeout): + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "device busy"}, + } + conn = _conn_with_invoke(_err) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(robot-001).function(dispatch_robot)") + assert r["success"] is False + assert r["error"]["code"] == "-32000" + assert r["error"]["message"] == "device busy" + assert r["device_id"] == "robot-001" + assert r["function"] == "dispatch_robot" + + def test_connection_exception_returns_invoke_failed(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(cam-001).function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "invoke_failed" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p", "llm_reasoning": "should not appear"}, + llm_reasoning="caller reasoning", + ) + # Inspect the params actually delivered to the wire: + sent = all_succeed_conn.invoke.call_args.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "1080p" + + +# -- invoke_many ---------------------------------------------------- + + +class TestInvokeMany: + def test_zero_matches_returns_empty_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["matched"] == 0 + assert r["succeeded"] == 0 + assert r["failed"] == 0 + assert r["results"] == [] + assert r["errors"] == [] + assert "error" not in r + + def test_all_succeed(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 2 + assert r["failed"] == 0 + ids = {row["device_id"] for row in r["results"]} + assert ids == {"cam-001", "cam-002"} + # Each result row is shaped {device_id, function, result}. + for row in r["results"]: + assert row["function"] == "capture_image" + assert "result" in row + + def test_partial_failure_shape(self): + def _half_fail(device_id, function_name, params, timeout): + if device_id == "cam-001": + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "down"}, + } + conn = _conn_with_invoke(_half_fail) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 1 + assert r["failed"] == 1 + assert {row["device_id"] for row in r["results"]} == {"cam-001"} + assert {row["device_id"] for row in r["errors"]} == {"cam-002"} + for row in r["errors"]: + assert row["error"]["code"] == "-32000" + assert row["error"]["message"] == "down" + + def test_invalid_scope_returns_error_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(robot-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke_many("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_per_target_timeout_passed_to_connection(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", timeout=7.5, + ) + # Every conn.invoke call should carry the same timeout. + for call in all_succeed_conn.invoke.call_args_list: + assert call.kwargs["timeout"] == 7.5 + + def test_max_concurrency_caps_thread_pool(self, all_succeed_conn): + # The fan-out group has 3 targets (capture_image x2 + dispatch_robot + # don't share name; pick a selector that resolves to multiple). Use + # function(direction:write) which selects 4 distinct rows. + r = tools_mod.invoke_many( + "function(direction:write)", max_concurrency=1, + ) + assert r["candidates"] >= 2 + assert r["succeeded"] == r["candidates"] + + def test_connection_exception_recorded_per_target(self): + # Mix: cam-001 succeeds, cam-002's call raises locally. + def _mixed(device_id, function_name, params, timeout): + if device_id == "cam-002": + raise RuntimeError("messaging blip") + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn = _conn_with_invoke(_mixed) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["succeeded"] == 1 + assert r["failed"] == 1 + cam002_err = next(e for e in r["errors"] if e["device_id"] == "cam-002") + assert cam002_err["error"]["code"] == "invoke_failed" + assert "messaging blip" in cam002_err["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + for call in all_succeed_conn.invoke.call_args_list: + sent = call.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "4k" + + +# -- _resolve_function_tuples --------------------------------------- + + +class TestResolveFunctionTuples: + def test_walks_all_pages(self, all_succeed_conn): + # Use a small DISCOVER_HARD_LIMIT temporarily. + with patch.object(tools_mod, "DISCOVER_HARD_LIMIT", 1): + rows, err = tools_mod._resolve_function_tuples( + "device(*).function(direction:write)" + ) + assert err is None + # 4 distinct (device, function) tuples for direction:write across the + # mock fleet (cam-001, cam-002, robot-001, sensor-001 set_threshold + # and set_location). With limit=1 per page, the resolver had to + # paginate through all of them. + assert len(rows) >= 2 + for row in rows: + assert "device_id" in row + assert "name" in row + + def test_propagates_discover_error(self, all_succeed_conn): + rows, err = tools_mod._resolve_function_tuples("not a selector") + assert rows is None + assert err is not None + assert err["error"]["code"] == "selector_parse_error" diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index d647ee3..c4a487e 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -72,7 +72,8 @@ def _mock_langchain_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index a40b5ad..30d1ae0 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -55,7 +55,8 @@ def _mock_strands_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-server/device_connect_server/portal/views/devices.py b/packages/device-connect-server/device_connect_server/portal/views/devices.py index 3f82309..7f5bf1e 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/devices.py +++ b/packages/device-connect-server/device_connect_server/portal/views/devices.py @@ -320,7 +320,7 @@ async def download_starter_script(request: web.Request): """Device Connect — starter AI agent (Strands + OpenAI). Connects to Device Connect, discovers your fleet, and reacts to device -events by calling tools (discover_labels, discover, invoke_device). +events by calling tools (discover_labels, discover, invoke, invoke_many). LLM inference runs through the Arm internal OpenAI proxy. Usage: @@ -404,7 +404,7 @@ async def prepare(self) -> Dict[str, Any]: from strands.models.openai import OpenAIModel from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ) result = await super().prepare() @@ -417,7 +417,7 @@ async def prepare(self) -> Dict[str, Any]: ), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=self._build_system_prompt(), ) @@ -454,14 +454,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\\n" f" device(category:camera, location:zone-A/*)\\n" f" device(robot-001).function(direction:write)\\n" - f" function(safety:critical)\\n" - f" - invoke_device(device_id, function, params) -- call a device function\\n\\n" + f" function(safety:critical)\\n\\n" + f"INVOCATION TOOLS:\\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\\n\\n" f"INSTRUCTIONS:\\n" f"When you receive device events, you MUST:\\n" f"1. Analyze the events\\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\\n" - f"3. Use invoke_device() to interact with devices\\n" + f"3. Use invoke() or invoke_many() to interact with devices\\n" f"4. Report what you found and what actions you took\\n\\n" f"Always provide llm_reasoning when invoking devices.\\n" f"Always call at least one tool per batch of events." diff --git a/tests/tests/test_tools_invoke.py b/tests/tests/test_tools_invoke.py index 447f301..df9878b 100644 --- a/tests/tests/test_tools_invoke.py +++ b/tests/tests/test_tools_invoke.py @@ -2,40 +2,65 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Integration tests for device-connect-agent-tools invoke_device(). +"""Integration tests for selector-driven invocation tools. -Tests that the agent SDK can invoke device RPCs via the messaging backend. +Covers ``invoke()`` and ``invoke_many()`` against real devices registered +via the messaging backend. Exercises single-match, ambiguous-match, +selector-scope rejection, parallel fan-out, and partial-failure semantics +end-to-end. """ import asyncio -import pytest +import time +import pytest SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible.""" + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- invoke --------------------------------------------------------- @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_sensor_reading(device_spawner, messaging_url): - """invoke_device() should call sensor's get_reading and return result.""" + """invoke() calls sensor.get_reading and returns the reading payload.""" await device_spawner.spawn_sensor( - "itest-tools-invoke-sensor", initial_temp=23.5, initial_humidity=50.0, + "itest-inv-read-sensor", initial_temp=23.5, initial_humidity=50.0, ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-read-sensor"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-sensor", - function="get_reading", - params={"unit": "celsius"}, - llm_reasoning="Testing sensor read", + invoke, + "device(itest-inv-read-sensor).function(get_reading)", + {"unit": "celsius"}, + "Testing sensor read", ) - assert isinstance(result, dict) - assert result.get("success") is True or "temperature" in result.get("result", {}) + assert result["success"] is True + assert result["device_id"] == "itest-inv-read-sensor" + assert result["function"] == "get_reading" + assert "temperature" in result["result"] finally: await asyncio.to_thread(disconnect) @@ -43,25 +68,26 @@ async def test_invoke_sensor_reading(device_spawner, messaging_url): @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_url): - """invoke_device() should dispatch robot and trigger cleaning.""" + """invoke() dispatches the robot and the cleaning_finished event arrives.""" await device_spawner.spawn_robot( - "itest-tools-invoke-robot", clean_duration=0.3, + "itest-inv-robot", clean_duration=0.3, ) await asyncio.sleep(SETTLE_TIME) - async with event_capture.subscribe("device-connect.*.itest-tools-invoke-robot.event.*") as events: - from device_connect_agent_tools import connect, disconnect, invoke_device + async with event_capture.subscribe( + "device-connect.*.itest-inv-robot.event.*" + ) as events: + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-robot"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-robot", - function="dispatch_robot", - params={"zone_id": "zone-tools"}, - llm_reasoning="Testing robot dispatch via tools", + invoke, + "device(itest-inv-robot).function(dispatch_robot)", + {"zone_id": "zone-tools"}, + "Testing robot dispatch", ) - assert isinstance(result, dict) + assert result["success"] is True finally: await asyncio.to_thread(disconnect) @@ -71,42 +97,184 @@ async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_ur @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_unknown_device(messaging_url): - """invoke_device() on non-existent device should return error.""" - from device_connect_agent_tools import connect, disconnect, invoke_device +async def test_invoke_no_match_returns_no_match(device_spawner, messaging_url): + """A selector that resolves to zero functions returns ``no_match``.""" + await device_spawner.spawn_camera("itest-inv-nomatch-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="nonexistent-device-xyz", - function="ping", - llm_reasoning="Testing error handling", + invoke, + "device(itest-inv-nomatch-cam).function(does_not_exist)", + ) + assert result["success"] is False + assert result["error"]["code"] == "no_match" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_ambiguous_match_returns_error(device_spawner, messaging_url): + """A selector matching multiple (device, function) tuples returns an error.""" + await device_spawner.spawn_camera("itest-inv-amb-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-amb-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke + + await _wait_for_devices( + messaging_url, {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke, "device(itest-inv-amb-cam-*).function(capture_image)", + ) + assert result["success"] is False + assert result["error"]["code"] == "ambiguous_match" + cand_ids = {c["device_id"] for c in result["candidates"]} + assert {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} <= cand_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_device_only_scope_rejected(device_spawner, messaging_url): + """A device-only selector cannot resolve to a function.""" + await device_spawner.spawn_camera("itest-inv-scope-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(invoke, "device(itest-inv-scope-cam)") + assert result["success"] is False + assert result["error"]["code"] == "invalid_invoke_scope" + finally: + await asyncio.to_thread(disconnect) + + +# -- invoke_many ---------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_succeeds_across_devices(device_spawner, messaging_url): + """invoke_many() fans out a single function across multiple matching devices.""" + await device_spawner.spawn_camera("itest-inv-many-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-many-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-cam-*).function(capture_image)", + {"resolution": "720p"}, ) - assert isinstance(result, dict) - assert result.get("success") is False + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 2 + assert result["failed"] == 0 + ids = {row["device_id"] for row in result["results"]} + assert ids == {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} finally: await asyncio.to_thread(disconnect) @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_camera_capture(device_spawner, messaging_url): - """invoke_device() should capture image from camera.""" - await device_spawner.spawn_camera("itest-tools-invoke-cam") +async def test_invoke_many_partial_failure(device_spawner, messaging_url): + """A failing target is recorded in errors while siblings succeed.""" + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-1", location="lab-A", failure_rate=1.0, + ) + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-2", location="lab-A", + ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, + {"itest-inv-many-pf-cam-1", "itest-inv-many-pf-cam-2"}, + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-pf-cam-*).function(capture_image)", + ) + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 1 + assert result["failed"] == 1 + success_ids = {row["device_id"] for row in result["results"]} + error_ids = {row["device_id"] for row in result["errors"]} + assert success_ids == {"itest-inv-many-pf-cam-2"} + assert error_ids == {"itest-inv-many-pf-cam-1"} + for row in result["errors"]: + assert "code" in row["error"] + assert "message" in row["error"] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_zero_candidates(device_spawner, messaging_url): + """No matches yields an empty envelope, not an error.""" + await device_spawner.spawn_camera("itest-inv-many-zero-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke_many await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-cam", - function="capture_image", - params={"resolution": "720p"}, - llm_reasoning="Testing camera capture via tools", + invoke_many, + "device(itest-no-such-device).function(capture_image)", ) - assert isinstance(result, dict) + assert result["candidates"] == 0 + assert result["matched"] == 0 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + assert result["results"] == [] + assert result["errors"] == [] + assert "error" not in result + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_function_only_selector(device_spawner, messaging_url): + """function() selects the function across the whole fleet.""" + await device_spawner.spawn_sensor( + "itest-inv-many-fo-sensor", initial_temp=20.0, + ) + await device_spawner.spawn_camera("itest-inv-many-fo-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-fo-cam", "itest-inv-many-fo-sensor"} + ) + try: + result = await asyncio.to_thread(invoke_many, "function(get_reading)") + ids = {row["device_id"] for row in result["results"]} + assert "itest-inv-many-fo-sensor" in ids + # Camera does not have get_reading; should not be in results. + assert "itest-inv-many-fo-cam" not in ids finally: await asyncio.to_thread(disconnect) From b64edac4d6426d6443ee7f071bcbff63e60b9429 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 11:15:16 -0700 Subject: [PATCH 2/8] feat(predicate): add CEL where evaluator with optional [predicate] extra Add device_connect_edge.predicate, a thin wrapper around cel-python that compiles where expressions into reusable WherePredicate objects and evaluates them against device-local context (identity, labels, status, shared bindings). CEL was chosen over JSONLogic because the v4 design's mask-indexing pattern (mask[seat_row][seat_col] == 1) needs computed array indices, which JSONLogic's literal-path var operator cannot express without flattening the mask to 1D and indexing arithmetically. CEL handles it natively. cel-python is an optional dependency. Importing the module without it installed succeeds; compiling or evaluating a predicate raises a clear PredicateCompileError pointing at the [predicate] extra: pip install device-connect-edge[predicate] pip install device-connect-agent-tools[predicate] The evaluator is shared by the dispatcher (validates expressions before sending them out) and the device runtime (evaluates per-call to decide whether to execute a fan-out). 16 unit tests cover compilation, evaluation, the mask-indexing regression case, missing-variable and type-mismatch error surfaces, and evaluator reusability. --- .../device-connect-agent-tools/pyproject.toml | 1 + .../device_connect_edge/predicate.py | 163 ++++++++++++++++++ packages/device-connect-edge/pyproject.toml | 3 + .../tests/test_predicate.py | 131 ++++++++++++++ 4 files changed, 298 insertions(+) create mode 100644 packages/device-connect-edge/device_connect_edge/predicate.py create mode 100644 packages/device-connect-edge/tests/test_predicate.py diff --git a/packages/device-connect-agent-tools/pyproject.toml b/packages/device-connect-agent-tools/pyproject.toml index ec0f198..606073c 100644 --- a/packages/device-connect-agent-tools/pyproject.toml +++ b/packages/device-connect-agent-tools/pyproject.toml @@ -37,6 +37,7 @@ strands = ["strands-agents>=1.0"] langchain = ["langchain-core>=0.2"] claude = ["claude-agent-sdk>=0.1"] mcp = ["fastmcp>=1.0"] +predicate = ["device-connect-edge[predicate]"] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py new file mode 100644 index 0000000..6ddc7c0 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""CEL ``where`` predicate evaluator for self-election at the edge. + +A ``where`` predicate is a CEL (Common Expression Language) expression that +each candidate device evaluates against its own context to decide whether +to execute a fan-out call. The predicate sees four top-level variables: + + identity device-local identity dict (device_id, device_type, ...) + labels device labels (the same labels selectors filter on) + status device status (heartbeat-updated: location, availability, + battery, online, ...) + bindings shared payload supplied by the caller (selection masks, + thresholds, lookup tables) + +Examples (every example here ships with v4 spec):: + + battery > 50 + labels.category == "camera" && status.battery > 50 + mask[seat_row][seat_col] == 1 + bindings.threshold < status.temperature + +CEL is sandboxed by construction: no I/O, no filesystem, no exec. This +module wraps `cel-python` with lazy import so device-connect-edge does +not require it as a hard dependency. Install with the optional +``[predicate]`` extra:: + + pip install device-connect-edge[predicate] + +The evaluator is shared by the dispatcher (validates the expression +before broadcast) and the device runtime (evaluates per-call to decide +whether to execute the fan-out). +""" + +from __future__ import annotations + +from typing import Any, Mapping + + +class PredicateCompileError(ValueError): + """Raised when a ``where`` expression fails to compile. + + Carries the original cel-python error chained so callers can drill in + if they need the exact parse position. + """ + + +class PredicateEvalError(RuntimeError): + """Raised when an otherwise-valid predicate fails at evaluation time. + + Typical causes: missing context key, type mismatch (e.g. comparing a + string to an int), or arithmetic overflow. + """ + + +# Lazy import: ``cel-python`` is an optional extra. Importers of this module +# pay no cost unless they actually compile a predicate. +def _require_celpy(): + try: + import celpy # type: ignore[import-not-found] + return celpy + except ImportError as e: + raise PredicateCompileError( + "where predicates require the 'cel-python' package; " + "install with the [predicate] extra: " + "pip install 'device-connect-edge[predicate]'" + ) from e + + +def _to_cel(value: Any) -> Any: + """Recursively wrap a Python value as the matching CEL type. + + Native Python ints, strings, dicts, and lists arrive at the boundary + untyped; cel-python's evaluator expects its own typed wrappers + (``IntType``, ``MapType``, ``ListType``, ...). We wrap once at the + top of evaluation rather than asking callers to import celtypes. + """ + celpy = _require_celpy() + ct = celpy.celtypes + if value is None: + return None + if isinstance(value, bool): + return ct.BoolType(value) + if isinstance(value, int): + return ct.IntType(value) + if isinstance(value, float): + return ct.DoubleType(value) + if isinstance(value, str): + return ct.StringType(value) + if isinstance(value, (bytes, bytearray)): + return ct.BytesType(bytes(value)) + if isinstance(value, Mapping): + return ct.MapType({ + ct.StringType(str(k)): _to_cel(v) for k, v in value.items() + }) + if isinstance(value, (list, tuple)): + return ct.ListType([_to_cel(v) for v in value]) + # Fallback: stringify. Rare; happens for custom objects in the context. + return ct.StringType(str(value)) + + +class WherePredicate: + """A compiled ``where`` predicate, ready to evaluate against device context. + + Compile once (typically at the dispatcher when the call comes in or at + the edge when the broadcast envelope is received), then evaluate once + per candidate. Predicates are stateless and safe to reuse across calls. + """ + + __slots__ = ("expression", "_program") + + def __init__(self, expression: str, _program: Any): + self.expression = expression + self._program = _program + + def evaluate(self, context: Mapping[str, Any]) -> bool: + """Return ``True`` if the predicate holds for ``context``. + + ``context`` should be a flat mapping of variable name to Python + value. Common keys: ``identity``, ``labels``, ``status``, + ``bindings``. Missing keys are not auto-defaulted; if the + predicate references one, the call raises PredicateEvalError so + the caller can decide between fail-open and fail-closed. + """ + celpy = _require_celpy() + cel_context = {k: _to_cel(v) for k, v in context.items()} + try: + result = self._program.evaluate(cel_context) + except celpy.CELEvalError as e: + raise PredicateEvalError( + f"failed to evaluate where {self.expression!r}: {e}" + ) from e + return bool(result) + + +def compile_where(expression: str) -> WherePredicate: + """Compile a ``where`` expression into a reusable :class:`WherePredicate`. + + Raises :class:`PredicateCompileError` if cel-python is not installed + or the expression is malformed. + """ + celpy = _require_celpy() + if not isinstance(expression, str): + raise PredicateCompileError( + f"where expression must be a string, got {type(expression).__name__}" + ) + if not expression.strip(): + raise PredicateCompileError("where expression must be non-empty") + env = celpy.Environment() + try: + ast = env.compile(expression) + except Exception as e: + # cel-python surfaces parse errors via several exception classes + # depending on the failure mode (lark.UnexpectedToken, ValueError, + # CELParseError). Catch broadly and rewrap so callers only see + # PredicateCompileError. + raise PredicateCompileError( + f"failed to compile where {expression!r}: {e}" + ) from e + program = env.program(ast) + return WherePredicate(expression=expression, _program=program) diff --git a/packages/device-connect-edge/pyproject.toml b/packages/device-connect-edge/pyproject.toml index 27b5e88..58de4d1 100644 --- a/packages/device-connect-edge/pyproject.toml +++ b/packages/device-connect-edge/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ [project.optional-dependencies] zenoh = [] # Zenoh is now a core dependency; kept for backward compat +predicate = [ + "cel-python>=0.5.0", +] telemetry = [ "opentelemetry-api>=1.30.0", "opentelemetry-sdk>=1.30.0", diff --git a/packages/device-connect-edge/tests/test_predicate.py b/packages/device-connect-edge/tests/test_predicate.py new file mode 100644 index 0000000..dfaff81 --- /dev/null +++ b/packages/device-connect-edge/tests/test_predicate.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the CEL ``where`` predicate evaluator. + +These tests require the ``[predicate]`` extra (cel-python). They are +skipped automatically when cel-python is not installed so the rest of +the edge test suite stays runnable on minimal installs. +""" +from __future__ import annotations + +import pytest + +celpy = pytest.importorskip("celpy") + +from device_connect_edge.predicate import ( + PredicateCompileError, + PredicateEvalError, + WherePredicate, + compile_where, +) + + +# -- compile_where -------------------------------------------------- + + +class TestCompile: + def test_simple_comparison_compiles(self): + p = compile_where("battery > 50") + assert isinstance(p, WherePredicate) + assert p.expression == "battery > 50" + + def test_boolean_combination_compiles(self): + p = compile_where("a > 1 && b < 10 || c == 'x'") + assert isinstance(p, WherePredicate) + + def test_array_indexing_compiles(self): + p = compile_where("mask[row][col] == 1") + assert isinstance(p, WherePredicate) + + def test_label_dot_access_compiles(self): + p = compile_where("labels.category == 'camera'") + assert isinstance(p, WherePredicate) + + def test_empty_expression_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where("") + with pytest.raises(PredicateCompileError): + compile_where(" ") + + def test_non_string_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where(123) # type: ignore[arg-type] + + def test_malformed_expression_rejected(self): + with pytest.raises(PredicateCompileError) as exc: + compile_where("a > > b") + assert "failed to compile" in str(exc.value) + + +# -- evaluate ------------------------------------------------------- + + +class TestEvaluate: + def test_truthy_comparison(self): + p = compile_where("battery > 50") + assert p.evaluate({"battery": 80}) is True + assert p.evaluate({"battery": 30}) is False + + def test_label_match(self): + p = compile_where("labels.category == 'camera'") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_2d_mask_indexing(self): + # The mask-indexing case is the deciding example for picking CEL + # over JSONLogic; keep it as a regression guard. + p = compile_where("mask[row][col] == 1") + ctx = { + "mask": [[0, 1, 0], [1, 0, 0]], + "row": 0, + "col": 1, + } + assert p.evaluate(ctx) is True + ctx["col"] = 0 + assert p.evaluate(ctx) is False + + def test_combined_label_and_status(self): + p = compile_where("labels.category == 'camera' && status.battery > 50") + ctx = { + "labels": {"category": "camera"}, + "status": {"battery": 80}, + } + assert p.evaluate(ctx) is True + ctx["status"]["battery"] = 30 + assert p.evaluate(ctx) is False + ctx["labels"]["category"] = "robot" + ctx["status"]["battery"] = 80 + assert p.evaluate(ctx) is False + + def test_bindings_and_status_compose(self): + p = compile_where("status.temperature > bindings.threshold") + ctx = { + "status": {"temperature": 75.5}, + "bindings": {"threshold": 70.0}, + } + assert p.evaluate(ctx) is True + + def test_string_in_list(self): + p = compile_where("labels.category in ['camera', 'inference']") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_missing_variable_raises_eval_error(self): + p = compile_where("status.battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({}) + + def test_type_mismatch_raises_eval_error(self): + p = compile_where("battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({"battery": "not a number"}) + + def test_evaluator_is_reusable(self): + # Compile once, evaluate against many contexts. Reusability is the + # property that lets callers compile broadcast envelopes once at + # the dispatcher and ship them to N targets. + p = compile_where("battery > 50") + results = [p.evaluate({"battery": v}) for v in (10, 50, 51, 100)] + assert results == [False, False, True, True] From 4e2208a3712bb2de58de8583a38223c9c28d032c Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:05:01 -0700 Subject: [PATCH 3/8] feat(broadcast): async fan-out with correlation, fire_at, and subscribe Add the async selector-driven fan-out path so callers do not have to block on the slowest device: - broadcast(selector, params, where=, bindings=, fire_at=, on_late=) publishes a single envelope to a fanout subject keyed by tenant. Returns immediately with a correlation_id and the candidate count. Compile-validates the optional CEL where predicate at the dispatcher so syntax errors short-circuit before reaching the wire. - DeviceRuntime._broadcast_subscription receives envelopes on ``device-connect..broadcast``. Each candidate self-elects via the target_device_ids gate (pre-resolved by the dispatcher from the selector), then evaluates the optional where predicate against its own context (identity, labels, status, shared bindings). On match the device executes the function and emits a reply on ``device-connect...event.async_reply.`` carrying {success, result|error, actually_fired_at}. - fire_at + on_late synchronized fan-out: the edge holds the message until the wall-clock deadline and fires from its own clock. on_late=skip drops late arrivals (preserves coherence for card-stunt / light-show style workloads); on_late=fire executes immediately. The achieved spread depends on NTP residual (~5-10 ms typical) rather than network jitter (~50-150 ms). - subscribe(selector) returns a Subscription handle. Two selector forms: ``correlation:`` for broadcast replies, and event-scoped selectors (``event()`` or ``device(...).event()``) for live event streams. The handle exposes sync read() and a yielding iter() with idle-timeout reset. - await_replies(correlation_id, timeout, until) sync helper for the common broadcast-then-collect pattern; subscribes, drains, returns the list of reply payloads. The edge predicate context mirrors DeviceStatus.location into labels["location"] when the driver did not declare a labels.location itself, matching the dispatcher-side flatten_device contract so the same selector and predicate strings work on both sides. Test coverage: 38 unit tests across broadcast (12), subscribe (12), and existing modules; 5 NATS integration tests cover end-to-end broadcast + reply, where filter at the edge, fire_at synchronization spread, on_late=skip late-arrival drop, and subscribe(correlation:) streaming. --- .../device_connect_agent_tools/__init__.py | 10 + .../device_connect_agent_tools/connection.py | 22 +- .../device_connect_agent_tools/tools.py | 405 ++++++++++++++++++ .../tests/test_broadcast.py | 201 +++++++++ .../tests/test_subscribe.py | 203 +++++++++ .../device_connect_edge/device.py | 152 +++++++ tests/tests/test_tools_broadcast.py | 213 +++++++++ 7 files changed, 1205 insertions(+), 1 deletion(-) create mode 100644 packages/device-connect-agent-tools/tests/test_broadcast.py create mode 100644 packages/device-connect-agent-tools/tests/test_subscribe.py create mode 100644 tests/tests/test_tools_broadcast.py diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index c809baa..1a7c1e0 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -30,6 +30,11 @@ # Selector-driven invocation (preferred) invoke, invoke_many, + broadcast, + # Selector-driven subscription + Subscription, + subscribe, + await_replies, # Other invocation helpers invoke_device_with_fallback, get_device_status, @@ -54,6 +59,11 @@ # Selector-driven invocation (preferred) "invoke", "invoke_many", + "broadcast", + # Selector-driven subscription + "Subscription", + "subscribe", + "await_replies", # Other invocation helpers "invoke_device_with_fallback", "get_device_status", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index dae997c..b399f70 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -409,13 +409,33 @@ async def _async_invoke( # ── Broadcast ──────────────────────────────────────────────────── + def publish_broadcast(self, envelope: Dict[str, Any]) -> None: + """Publish a selector-driven broadcast envelope to the fanout subject. + + The envelope shape is documented in + ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; + every device subscribed to ``device-connect..broadcast`` + receives the message and self-elects via ``target_device_ids`` and + the optional ``where`` predicate. + """ + return self._run(self._async_publish_broadcast(envelope)) + + async def _async_publish_broadcast(self, envelope: Dict[str, Any]) -> None: + subject = f"device-connect.{self.zone}.broadcast" + await self._client.publish(subject, json.dumps(envelope).encode()) + def broadcast( self, function: str, params: Optional[Dict[str, Any]] = None, timeout: float = 5.0, ) -> List[Dict[str, Any]]: - """Invoke a function on all discovered devices and collect results.""" + """Invoke a function on all discovered devices and collect results. + + Sequential sync fan-out (one invoke per device). Predates the + selector-driven broadcast tool; left in place for callers that want + a simple "call this on everyone" without setting up subscriptions. + """ devices = self.list_devices() results = [] for d in devices: diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index 528e554..c81faf2 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -20,6 +20,7 @@ import logging import os +import time import uuid import warnings from typing import Any @@ -726,6 +727,410 @@ def call_one(row: dict) -> dict[str, Any]: return out +def broadcast( + selector: str, + params: dict[str, Any] | None = None, + where: str | None = None, + bindings: dict[str, Any] | None = None, + fire_at: float | None = None, + on_late: str = "skip", + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Async selector-driven fan-out. Returns immediately with a correlation id. + + Use ``broadcast`` when the caller does not want to block on the slowest + device. Each candidate self-elects via the optional ``where`` predicate + (CEL, evaluated at the edge against the device's identity, labels, live + status, and the shared ``bindings``) and emits its reply as an event on + a per-device subject keyed by ``correlation_id``:: + + device-connect...event.async_reply. + + Subscribe to those replies via ``subscribe('correlation:')`` or wait + for them with ``await_replies(correlation_id, timeout=...)``. + + Args: + selector: Function-scoped selector. The selector must resolve to a + single function name across the matched devices; if multiple + functions match, an ``ambiguous_function`` error is returned. + params: Function parameters dict applied to every target. + where: Optional CEL predicate evaluated at the edge per candidate + (e.g. ``"status.battery > 50"``, ``"mask[row][col] == 1"``). + Validated at the dispatcher before publication so syntax + errors return immediately rather than reaching the wire. + bindings: Shared payload merged into the predicate context as + ``bindings.``. Keep small (selection masks, thresholds, + top-K rankings); the same bytes ship to every device. + fire_at: Optional wall-clock epoch seconds. Each device holds the + message and fires its function from its own clock at + ``fire_at`` for synchronized fan-out. + on_late: Policy when a device receives a ``fire_at`` message after + the deadline. ``"skip"`` (default) drops the call; ``"fire"`` + executes immediately. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"correlation_id": "br-...", "candidates": N, + "selector": ..., "function": ...}``. + On failure: ``{"candidates": 0, "error": {"code", "message"}}`` + with codes including the discover() codes, + ``invalid_invoke_scope``, ``ambiguous_function``, + ``invalid_predicate``, and ``invalid_on_late``. + """ + if on_late not in ("skip", "fire"): + return { + "candidates": 0, + "error": _error( + "invalid_on_late", + f"on_late must be 'skip' or 'fire', got {on_late!r}", + ), + } + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"candidates": 0, "error": error_envelope["error"]} + + if not rows: + # Empty fan-out: still mint a correlation id so callers waiting on + # replies see a clean "no candidates" rather than a hang. + return { + "correlation_id": f"br-{uuid.uuid4().hex[:12]}", + "candidates": 0, + "selector": selector, + } + + # Broadcast assumes one function per call. If the selector resolves to + # multiple distinct functions, surface that as a structured error so + # the caller can either narrow the selector or split into multiple + # broadcasts. + function_names = {row.get("name") for row in rows if row.get("name")} + if len(function_names) != 1: + return { + "candidates": len(rows), + "error": _error( + "ambiguous_function", + f"selector resolved to {len(function_names)} distinct " + "functions; broadcast requires exactly one function per call: " + f"{sorted(function_names)!r}", + ), + } + function_name = next(iter(function_names)) + + # Compile-validate the where predicate before going to the wire so a + # syntax error short-circuits without bothering devices. + if where is not None: + try: + from device_connect_edge.predicate import compile_where + compile_where(where) + except Exception as e: + return { + "candidates": len(rows), + "error": _error("invalid_predicate", str(e)), + } + + correlation_id = f"br-{uuid.uuid4().hex[:12]}" + target_device_ids = sorted({ + row.get("device_id") for row in rows if row.get("device_id") + }) + clean_params = { + k: v for k, v in (params or {}).items() if k != "llm_reasoning" + } + + envelope: dict[str, Any] = { + "correlation_id": correlation_id, + "function": function_name, + "params": clean_params, + "target_device_ids": target_device_ids, + } + if where: + envelope["where"] = where + if bindings: + envelope["bindings"] = bindings + if fire_at is not None: + envelope["fire_at"] = float(fire_at) + envelope["on_late"] = on_late + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[broadcast::%s::%d targets] Reason: %s", + correlation_id, len(target_device_ids), truncated, + ) + + try: + conn = get_connection() + conn.publish_broadcast(envelope) + except Exception as e: + logger.error("broadcast publish failed: %s", e) + return { + "candidates": len(target_device_ids), + "error": _error("connection_error", str(e)), + } + + return { + "correlation_id": correlation_id, + "candidates": len(target_device_ids), + "selector": selector, + "function": function_name, + } + + +# ── Selector-driven subscription ───────────────────────────────── + + +# Sentinel used to recognise the broadcast-reply form of a subscribe +# selector (``correlation:``). Kept short so the selector reads +# naturally; the parser matches an exact prefix. +_CORRELATION_PREFIX = "correlation:" + + +class Subscription: + """A live subscription handle returned by :func:`subscribe`. + + Two selector forms produce a subscription: + + * ``"correlation:"`` -- replies from a prior :func:`broadcast` call, + keyed by ``correlation_id`` and routed across all devices that fired. + * Event-scoped selectors (``event()`` or + ``device(...).event()``) -- a multiplex of matching events + across the resolved candidate set. + + The handle exposes a sync ``read`` API that drains buffered messages. + Use as a context manager (or call :meth:`close`) to tear the + underlying messaging subscription down deterministically:: + + with subscribe("correlation:" + cid) as sub: + for reply in sub.iter(timeout=5.0): + process(reply) + """ + + def __init__(self, conn: Any, inbox_names: list[str]): + self._conn = conn + self._inbox_names = list(inbox_names) + self._closed = False + self._cursor = 0 # index into the concatenated message stream + + def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: + """Drain currently buffered messages without blocking. + + Returns parsed payload dicts (already JSON-decoded by the + connection's buffered subscription path). Subsequent calls return + only messages that arrived after the previous call. + """ + if self._closed: + return [] + out: list[dict[str, Any]] = [] + for name in self._inbox_names: + inboxes = self._conn.get_inbox(name) + buffered = inboxes.get(name, []) or [] + # Each buffered entry is (subject, payload). We expose the + # parsed payload but stamp the subject onto it so callers can + # distinguish per-source messages without parsing it themselves. + for subject, payload in buffered: + if not isinstance(payload, dict): + payload = {"raw": payload} + payload = {**payload, "_subject": subject} + out.append(payload) + # Fast cursor: trim per-inbox buffers we have already returned by + # truncating from the front. The connection layer already caps each + # inbox at 1000 entries, so bounded growth is its concern. + for name in self._inbox_names: + self._conn._inbox[name] = [] + if max_messages is not None: + out = out[:max_messages] + return out + + def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): + """Yield messages until ``timeout`` elapses with no new arrivals. + + ``timeout`` resets each time at least one message is yielded, so + callers can drain a steady stream without re-parameterising the + wait. Use ``read`` instead for one-shot draining. + """ + deadline = time.monotonic() + timeout + while not self._closed: + new = self.read() + if new: + for msg in new: + yield msg + deadline = time.monotonic() + timeout + continue + if time.monotonic() >= deadline: + return + time.sleep(poll_interval) + + def close(self) -> None: + """Tear down the underlying messaging subscriptions.""" + if self._closed: + return + self._closed = True + for name in self._inbox_names: + try: + self._conn.unsubscribe_buffered(name) + except Exception: # pragma: no cover - cleanup best effort + logger.debug("close: unsubscribe %s failed", name, exc_info=True) + + def __enter__(self) -> "Subscription": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +def _correlation_subjects(conn: Any, correlation_id: str) -> list[str]: + """Build the per-device wildcard reply subjects for a correlation id. + + The reply template is ``device-connect...event + .async_reply.``; ```` is single-token wildcarded + so a subscription receives replies from any device that fires the + broadcast without having to enumerate them up-front. + """ + return [ + f"device-connect.{conn.zone}.*.event.async_reply.{correlation_id}", + ] + + +def _event_subjects_for_selector(selector: str) -> tuple[list[str] | None, dict[str, Any] | None]: + """Resolve an event-scoped selector to per-device subjects. + + Returns ``(subjects, None)`` on success or ``(None, error_envelope)`` + if the selector failed to parse or used a non-event scope. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in (Scope.DEVICE_EVENT.value, Scope.EVENT_ONLY.value): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_subscribe_scope", + "subscribe requires an event-scoped selector " + "(device(...).event(...) or event(...)) or " + "'correlation:'; got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + + conn = get_connection() + subjects: list[str] = [] + seen: set[str] = set() + for row in rows: + device_id = row.get("device_id") or "" + event_name = row.get("name") or "" + if not device_id or not event_name: + continue + subj = f"device-connect.{conn.zone}.{device_id}.event.{event_name}" + if subj not in seen: + seen.add(subj) + subjects.append(subj) + return subjects, None + + +def subscribe(selector: str) -> Subscription: + """Subscribe to events or broadcast replies matching a selector. + + Args: + selector: One of: + - ``"correlation:"`` for broadcast replies of a prior call. + - An event-scoped selector (``event()`` or + ``device(...).event()``) for live event streams. + + Returns: + A :class:`Subscription` handle. Iterate with ``sub.iter(timeout)`` + or drain currently-buffered messages with ``sub.read()``. Always + close (or use ``with``) to tear the underlying subscription down. + + Raises: + ValueError on selector errors. The selector string is checked at + the boundary; downstream subscribe calls are not retried, so a + parse error fails fast. + """ + if not isinstance(selector, str) or not selector.strip(): + raise ValueError("subscribe selector must be a non-empty string") + + conn = get_connection() + if selector.startswith(_CORRELATION_PREFIX): + correlation_id = selector[len(_CORRELATION_PREFIX):].strip() + if not correlation_id: + raise ValueError( + "correlation form must be 'correlation:' with non-empty id" + ) + subjects = _correlation_subjects(conn, correlation_id) + inbox_prefix = f"sub-corr-{correlation_id}-{uuid.uuid4().hex[:8]}" + else: + subjects, error_envelope = _event_subjects_for_selector(selector) + if error_envelope is not None: + err = error_envelope.get("error") + msg = err.get("message", str(err)) if isinstance(err, dict) else str(err) + raise ValueError(msg) + if not subjects: + # Nothing to subscribe to. Return an idle Subscription so the + # caller's ``with subscribe(...) as sub: ...`` pattern still + # works without raising; ``read``/``iter`` will yield nothing. + return Subscription(conn, inbox_names=[]) + inbox_prefix = f"sub-evt-{uuid.uuid4().hex[:8]}" + + inbox_names: list[str] = [] + for i, subj in enumerate(subjects): + name = f"{inbox_prefix}-{i}" + conn.subscribe_buffered(subj, name=name) + inbox_names.append(name) + return Subscription(conn, inbox_names=inbox_names) + + +def await_replies( + correlation_id: str, + timeout: float = 10.0, + until: int | None = None, + poll_interval: float = 0.05, +) -> list[dict[str, Any]]: + """Block until ``timeout`` elapses or ``until`` replies have arrived. + + A sync helper for the common broadcast pattern: caller fires a + :func:`broadcast`, then waits for some replies. Builds a one-shot + subscription on the correlation reply subject, drains it, and tears + down before returning. + + Args: + correlation_id: The id returned by :func:`broadcast`. + timeout: Overall wall-clock limit in seconds. + until: Stop early once this many replies have been collected. + poll_interval: How often the helper polls the subscription buffer. + + Returns: + A list of reply payload dicts, each with at least + ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + if not correlation_id: + return [] + sub = subscribe(f"{_CORRELATION_PREFIX}{correlation_id}") + try: + replies: list[dict[str, Any]] = [] + deadline = time.monotonic() + timeout + while True: + new = sub.read() + replies.extend(new) + if until is not None and len(replies) >= until: + break + if time.monotonic() >= deadline: + break + time.sleep(poll_interval) + return replies + finally: + sub.close() + + # ── Hierarchical discovery tools ───────────────────────────────── diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py new file mode 100644 index 0000000..e25a35e --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``broadcast`` tool. + +Uses the same labeled mock fleet (cam-001, cam-002, robot-001, sensor-001) +as the discover/invoke tests so selectors exercise real device, function, +and event names. +""" +import json +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + # Capture the published envelope for assertions. + published: list[dict] = [] + conn.publish_broadcast.side_effect = lambda env: published.append(env) + conn._published = published + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- broadcast ------------------------------------------------------ + + +class TestBroadcast: + def test_returns_correlation_id_and_candidates(self, mock_conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["correlation_id"].startswith("br-") + assert r["candidates"] == 2 + assert r["function"] == "capture_image" + assert "error" not in r + + def test_envelope_carries_function_and_targets(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k"}, + ) + env = mock_conn._published[0] + assert env["function"] == "capture_image" + assert env["params"] == {"resolution": "4k"} + assert sorted(env["target_device_ids"]) == ["cam-001", "cam-002"] + # No optional fields when caller did not set them. + assert "where" not in env + assert "bindings" not in env + assert "fire_at" not in env + assert "on_late" not in env + + def test_where_and_bindings_propagate_to_envelope(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + where="status.battery > 50", + bindings={"threshold": 80}, + ) + env = mock_conn._published[0] + assert env["where"] == "status.battery > 50" + assert env["bindings"] == {"threshold": 80} + + def test_fire_at_propagates_with_default_on_late(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123456789.0, + ) + env = mock_conn._published[0] + assert env["fire_at"] == 123456789.0 + assert env["on_late"] == "skip" + + def test_fire_at_with_explicit_on_late_fire(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123.0, on_late="fire", + ) + env = mock_conn._published[0] + assert env["on_late"] == "fire" + + def test_invalid_on_late_rejected(self, mock_conn): + r = tools_mod.broadcast( + "device(*).function(capture_image)", on_late="bogus", + ) + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_on_late" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_ambiguous_function_rejected(self, mock_conn): + # function(direction:read) resolves to multiple distinct functions + # (get_reading + dispatch_robot's get_status if it had read; here + # it just hits sensor's get_reading and possibly more). With our + # SAMPLE_DEVICES this matches just get_reading, so artificially + # broaden by picking a selector that crosses functions: + r = tools_mod.broadcast("device(*).function(*)") + assert r["candidates"] == 3 + assert r["error"]["code"] == "ambiguous_function" + + def test_zero_matches_returns_correlation_with_zero(self, mock_conn): + r = tools_mod.broadcast("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["correlation_id"].startswith("br-") + # No envelope was published (no targets). + assert mock_conn.publish_broadcast.call_count == 0 + + def test_invalid_scope_rejected(self, mock_conn): + r = tools_mod.broadcast("device(cam-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, mock_conn): + r = tools_mod.broadcast("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_invalid_predicate_rejected_before_publish(self, mock_conn): + # The predicate is compile-validated at the dispatcher; a syntax + # error short-circuits without publishing. + try: + import celpy # noqa: F401 + except ImportError: + pytest.skip("cel-python not installed") + r = tools_mod.broadcast( + "device(*).function(capture_image)", where="a > > b", + ) + assert r["error"]["code"] == "invalid_predicate" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_publish_failure_returns_connection_error(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + conn.publish_broadcast.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + env = mock_conn._published[0] + assert "llm_reasoning" not in env["params"] diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py new file mode 100644 index 0000000..a6f032e --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven subscribe + await_replies tools. + +The tests stand up a fake Connection that mirrors the buffered-inbox API +the production class exposes (``subscribe_buffered`` / +``unsubscribe_buffered`` / ``get_inbox`` / ``_inbox`` dict). Real +messaging is not exercised here; integration tests cover the wire. +""" +from unittest.mock import patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, +] + + +class FakeConnection: + """Minimal fake of the agent-tools Connection used by Subscription.""" + + def __init__(self, devices=None, zone="default"): + self.zone = zone + self.devices = devices or [] + self._inbox: dict[str, list[tuple]] = {} + self.subscribed_subjects: list[str] = [] + self.unsubscribed_names: list[str] = [] + + def list_devices(self): + return list(self.devices) + + def subscribe_buffered(self, subject: str, name: str | None = None) -> str: + name = name or subject + self._inbox[name] = [] + self.subscribed_subjects.append(subject) + return name + + def unsubscribe_buffered(self, name: str) -> None: + self.unsubscribed_names.append(name) + self._inbox.pop(name, None) + + def get_inbox(self, name: str | None = None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + # Test helper: simulate a message landing on a given subject. + def deliver(self, subject: str, payload: dict): + for name, _ in list(self._inbox.items()): + self._inbox[name].append((subject, payload)) + + +@pytest.fixture +def fake_conn(): + conn = FakeConnection(devices=SAMPLE_DEVICES) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- subscribe ------------------------------------------------------ + + +class TestSubscribe: + def test_correlation_form_subscribes_to_reply_subject(self, fake_conn): + sub = tools_mod.subscribe("correlation:abc-123") + assert len(fake_conn.subscribed_subjects) == 1 + subj = fake_conn.subscribed_subjects[0] + assert subj == "device-connect.default.*.event.async_reply.abc-123" + sub.close() + assert fake_conn.unsubscribed_names + + def test_correlation_form_with_empty_id_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("correlation:") + + def test_event_selector_subscribes_per_device(self, fake_conn): + sub = tools_mod.subscribe("device(*).event(object_detected)") + # Two cameras emit object_detected -> two subjects subscribed. + assert len(fake_conn.subscribed_subjects) == 2 + for subj in fake_conn.subscribed_subjects: + assert subj.startswith("device-connect.default.") + assert subj.endswith(".event.object_detected") + sub.close() + + def test_event_selector_zero_matches_returns_idle(self, fake_conn): + sub = tools_mod.subscribe("event(no_such_event)") + assert fake_conn.subscribed_subjects == [] + # Idle subscription: read returns empty, close is a no-op. + assert sub.read() == [] + sub.close() + + def test_non_event_scope_rejected(self, fake_conn): + with pytest.raises(ValueError) as exc: + tools_mod.subscribe("device(cam-001)") + assert "subscribe requires" in str(exc.value) + + def test_empty_or_non_string_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("") + with pytest.raises(ValueError): + tools_mod.subscribe(None) # type: ignore[arg-type] + + +# -- Subscription --------------------------------------------------- + + +class TestSubscriptionHandle: + def test_read_drains_buffered_messages(self, fake_conn): + sub = tools_mod.subscribe("correlation:r1") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r1", + {"correlation_id": "r1", "device_id": "cam-001", "success": True}, + ) + msgs = sub.read() + assert len(msgs) == 1 + assert msgs[0]["device_id"] == "cam-001" + # Subject is stamped onto the payload for source attribution. + assert "_subject" in msgs[0] + # A second read returns nothing -- the buffer is drained. + assert sub.read() == [] + sub.close() + + def test_context_manager_closes(self, fake_conn): + with tools_mod.subscribe("correlation:r2") as sub: + assert sub.read() == [] + assert fake_conn.unsubscribed_names # close() ran + + def test_iter_yields_until_idle_timeout(self, fake_conn): + sub = tools_mod.subscribe("correlation:r3") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r3", + {"correlation_id": "r3", "device_id": "cam-001"}, + ) + # Short timeout; iter() should yield the buffered reply then exit + # once no new messages arrive within the idle window. + msgs = list(sub.iter(timeout=0.1, poll_interval=0.01)) + assert len(msgs) == 1 + sub.close() + + +# -- await_replies -------------------------------------------------- + + +class TestAwaitReplies: + def test_empty_correlation_id_returns_empty_list(self, fake_conn): + assert tools_mod.await_replies("") == [] + + def test_collects_replies_until_count(self, fake_conn): + # Pre-stage two replies on the to-be-subscribed subject. await_replies + # subscribes (drains nothing yet), then deliver more during the loop. + # We deliver up-front via the fake's deliver hook so the first poll + # picks them up. + def deliver_when_subscribed(subject, name=None): + n = FakeConnection.subscribe_buffered(fake_conn, subject, name) + # Pre-load a couple of replies so the first poll returns them. + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-001"}, + ) + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-002"}, + ) + return n + + with patch.object( + fake_conn, "subscribe_buffered", side_effect=deliver_when_subscribed, + ): + replies = tools_mod.await_replies( + "r4", timeout=2.0, until=2, poll_interval=0.01, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"cam-001", "cam-002"} + + def test_returns_after_timeout_with_partial(self, fake_conn): + # No replies delivered -> after timeout, returns empty list. + replies = tools_mod.await_replies( + "r5", timeout=0.1, poll_interval=0.01, + ) + assert replies == [] diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 40d5c63..b64d443 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1135,6 +1135,151 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): self._logger.info("Subscribed to commands on %s", subj) + async def _broadcast_subscription(self) -> None: + """Subscribe to selector-driven broadcasts and self-elect to handle. + + Broadcast envelope shape (JSON over a fanout subject):: + + { + "correlation_id": "br-abc123", + "function": "capture_image", + "params": {"resolution": "4k"}, + "target_device_ids": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire + } + + On match, the device executes the function and emits a reply on + ``device-connect...event.async_reply.`` + with ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + subj = f"device-connect.{self.tenant}.broadcast" + + async def on_msg(data: bytes, reply_subject: Optional[str]): + try: + envelope = json.loads(data) + except Exception as e: + self._logger.debug("Broadcast: malformed envelope: %s", e) + return + + correlation_id = envelope.get("correlation_id") + if not correlation_id: + return + + # Self-election step 1: target_device_ids gate (pre-resolved by + # the dispatcher from the selector). When absent or empty, treat + # the broadcast as fleet-wide. + targets = envelope.get("target_device_ids") or [] + if targets and self.device_id not in targets: + return + + function_name = envelope.get("function") + if not function_name: + return + params_dict = envelope.get("params", {}) or {} + + # Self-election step 2: where predicate against {identity, labels, + # status, bindings}. A failed compile or eval is treated as + # fail-closed (do not execute). + where_expr = envelope.get("where") + if where_expr: + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror the legacy DeviceStatus.location into labels so + # ``labels.location`` works in predicates without the driver + # having to declare it explicitly. Matches the dispatcher-side + # flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + context = { + "identity": ( + caps.identity.model_dump() + if caps and getattr(caps, "identity", None) else {} + ), + "labels": labels, + "status": status_dict, + "bindings": envelope.get("bindings", {}) or {}, + } + if not predicate.evaluate(context): + return + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return + + # fire_at: hold the message until the wall-clock deadline. The + # on_late policy decides what to do if the message arrives past + # the deadline (skip preserves coherence; fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, + ) + return + if delay > 0: + await asyncio.sleep(delay) + + # Execute the driver function and emit the reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": {"code": "invoke_failed", "message": str(e)}, + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) + + async def _event_dispatch_loop(self) -> None: """Send queued events, retrying on failure.""" @@ -1372,6 +1517,13 @@ async def run(self) -> None: # Subscribe to commands BEFORE capability routines so log order makes sense await self._cmd_subscription() + # Subscribe to fleet broadcasts (best-effort; broadcast is opt-in for + # callers, so failure here should not block command handling). + try: + await self._broadcast_subscription() + except Exception as e: # pragma: no cover - best effort logging + self._logger.warning("Broadcast subscription failed: %s", e) + # Start capability routines if driver supports them (CapabilityDriverMixin) # This must happen after registration so events don't fire before device is registered if hasattr(self._driver, 'start_capability_routines'): diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py new file mode 100644 index 0000000..0e7413f --- /dev/null +++ b/tests/tests/test_tools_broadcast.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven broadcast + correlation replies. + +End-to-end coverage for the async fan-out path: +- Dispatcher publishes a broadcast envelope on the fanout subject. +- Each device runtime self-elects via target_device_ids and the optional + CEL ``where`` predicate. +- Devices execute the function and emit a reply on the per-device async + reply subject keyed by correlation_id. +- ``await_replies`` collects replies for a bounded window. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.4 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_returns_correlation_and_replies_arrive( + device_spawner, messaging_url, +): + """broadcast() returns a correlation_id and matching devices reply on the + per-device async reply subject.""" + await device_spawner.spawn_camera("itest-bc-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bc-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bc-cam-1", "itest-bc-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bc-cam-*).function(capture_image)", + {"resolution": "720p"}, + ) + assert result["correlation_id"].startswith("br-") + assert result["candidates"] == 2 + assert result["function"] == "capture_image" + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=5.0, until=2, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bc-cam-1", "itest-bc-cam-2"} + for r in replies: + assert r["success"] is True + assert r["correlation_id"] == result["correlation_id"] + assert "actually_fired_at" in r + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_filters_at_edge(device_spawner, messaging_url): + """A CEL where predicate runs at each candidate; only matches reply.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcw-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-bcw-cam-b", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcw-cam-a", "itest-bcw-cam-b"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcw-cam-*).function(capture_image)", + {"resolution": "1080p"}, + "labels.location == 'lab-A'", # where predicate + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + # Only cam-a is in lab-A; cam-b silently self-deselects. + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcw-cam-a"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_synchronizes_fan_out( + device_spawner, messaging_url, +): + """fire_at causes each device to fire from its own clock at the deadline.""" + await device_spawner.spawn_camera("itest-bcf-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcf-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcf-cam-1", "itest-bcf-cam-2"}) + try: + # Schedule 0.5s in the future; on_late=skip so any tardy device drops + # the call rather than firing late and breaking the coherence. + scheduled = time.time() + 0.5 + result = await asyncio.to_thread( + broadcast, + "device(itest-bcf-cam-*).function(capture_image)", + None, None, None, + scheduled, # fire_at + "skip", # on_late + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, until=2, + ) + assert len(replies) == 2 + # actually_fired_at should be at-or-after the scheduled time on each. + for r in replies: + assert r["actually_fired_at"] >= scheduled - 0.05 # small slack + # Achieved spread should be tight (well under network jitter). + spread = max(r["actually_fired_at"] for r in replies) - min( + r["actually_fired_at"] for r in replies + ) + assert spread < 0.5, f"fire_at spread too wide: {spread:.3f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_late_with_skip_drops( + device_spawner, messaging_url, +): + """A fire_at in the past with on_late=skip yields no replies.""" + await device_spawner.spawn_camera("itest-bcl-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcl-cam"}) + try: + past = time.time() - 5.0 # already 5s late + result = await asyncio.to_thread( + broadcast, + "device(itest-bcl-cam).function(capture_image)", + None, None, None, past, "skip", + ) + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=1.5, + ) + assert replies == [] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_correlation_form(device_spawner, messaging_url): + """subscribe('correlation:') captures replies as they arrive.""" + await device_spawner.spawn_camera("itest-bcs-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcs-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-bcs-cam-1", "itest-bcs-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcs-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + with subscribe(f"correlation:{cid}") as sub: + # Drain over a short window. + return list(sub.iter(timeout=2.0, poll_interval=0.05)) + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcs-cam-1", "itest-bcs-cam-2"} + finally: + await asyncio.to_thread(disconnect) From 072ef0133944a20c0dbfed9c6248109cabcb8938 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:10:20 -0700 Subject: [PATCH 4/8] feat(cli): selector-driven verbs in devctl and statectl Add the operator-facing shell surface for selector-driven discovery and operations: devctl verbs (read-side): - devctl discover "" [--offset N] [--limit M] - devctl discover-labels [--key K] [--offset N] [--limit M] statectl verbs (write-side): - statectl invoke "" [--param k=v ...] - statectl invoke-many "" [--param k=v ...] [--timeout T] [--max-concurrency N] - statectl broadcast "" [--param k=v ...] [--where E] [--bindings JSON] [--fire-at T] [--on-late skip|fire] - statectl subscribe "" [--timeout T] [--until N] - statectl await [--timeout T] [--until N] Each verb is a thin wrapper over the Python tool of the same name and exits non-zero on tool-side errors so they compose into shell pipelines naturally. Parameter values are decoded as JSON when they look like JSON (numbers, booleans, arrays, objects, quoted strings) and pass through as strings otherwise, so common shapes (--param resolution=4k, --param zones='[1,2,3]') work without quoting heroics. The historical ``devctl discover`` verb (mDNS scan for uncommissioned devices) is renamed to ``mdns-scan`` with ``scan`` as an alias, so ``discover`` is free for the selector-driven sense. Existing scripts should switch from ``devctl discover`` to ``devctl scan`` if they were exercising the mDNS path. 22 parser-shape unit tests guard against argument drift; the underlying tools already have full unit and integration coverage from earlier phases. --- .../device_connect_server/devctl/cli.py | 27 +- .../devctl/selector_cli.py | 103 +++++++ .../device_connect_server/statectl/cli.py | 23 ++ .../statectl/operations_cli.py | 282 ++++++++++++++++++ .../test_selector_cli.py | 199 ++++++++++++ 5 files changed, 630 insertions(+), 4 deletions(-) create mode 100644 packages/device-connect-server/device_connect_server/devctl/selector_cli.py create mode 100644 packages/device-connect-server/device_connect_server/statectl/operations_cli.py create mode 100644 packages/device-connect-server/tests/device_connect_server/test_selector_cli.py diff --git a/packages/device-connect-server/device_connect_server/devctl/cli.py b/packages/device-connect-server/device_connect_server/devctl/cli.py index 071b423..f73ec6a 100644 --- a/packages/device-connect-server/device_connect_server/devctl/cli.py +++ b/packages/device-connect-server/device_connect_server/devctl/cli.py @@ -574,9 +574,20 @@ def create_parser() -> argparse.ArgumentParser: p_reg.add_argument("--broker", default=None, help="Broker URL") p_reg.add_argument("--keepalive", action="store_true", help="Start heartbeat loop") - # discover command - p_discover = sub.add_parser("discover", help="Discover uncommissioned devices") - p_discover.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + # mdns-scan: discover uncommissioned devices on the local network. + # Renamed from the historical ``discover`` verb so the selector-driven + # ``discover`` below (which queries the fleet, not the local network) + # can take the natural name. + p_scan = sub.add_parser( + "mdns-scan", help="Discover uncommissioned devices via mDNS", + aliases=["scan"], + ) + p_scan.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + + # Selector-driven fleet discovery (new). Registers ``discover`` and + # ``discover-labels`` as parser entries. + from device_connect_server.devctl import selector_cli + selector_cli.register_subparsers(sub) # commission command p_commission = sub.add_parser("commission", help="Commission a device with PIN") @@ -617,9 +628,17 @@ def main(argv: Optional[List[str]] = None) -> None: loop.stop() print("\nbye!") - elif args.cmd == "discover": + elif args.cmd in ("mdns-scan", "scan"): asyncio.run(discover_devices(timeout=args.timeout)) + elif args.cmd == "discover": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover(args)) + + elif args.cmd == "discover-labels": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover_labels(args)) + elif args.cmd == "commission": asyncio.run( commission_device( diff --git a/packages/device-connect-server/device_connect_server/devctl/selector_cli.py b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py new file mode 100644 index 0000000..68a6637 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``devctl`` selector-driven discovery verbs. + +Thin wrappers around ``device_connect_agent_tools.discover`` and +``discover_labels`` so operators can drive the same selector grammar +from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Best-effort connect to the messaging backend. + + Reuses ``DEVICE_CONNECT_*`` and ``NATS_URL`` env vars when ``broker`` is + not given. Kept as a thin wrapper so all CLI verbs share the same + connect-or-fail semantics. + """ + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _pretty(data: Any) -> str: + """Render a JSON payload for terminal output.""" + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def run_discover(args: Any) -> int: + """Execute ``devctl discover ""``.""" + from device_connect_agent_tools import disconnect, discover + + _connect(getattr(args, "broker", None)) + try: + result = discover( + args.selector, + offset=int(args.offset or 0), + limit=int(args.limit or 200), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_discover_labels(args: Any) -> int: + """Execute ``devctl discover-labels [--key K]``.""" + from device_connect_agent_tools import disconnect, discover_labels + + _connect(getattr(args, "broker", None)) + try: + result = discover_labels( + key=args.key, + offset=int(args.offset or 0), + limit=int(args.limit or 50), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def register_subparsers(sub: Any) -> None: + """Attach the discover / discover-labels subparsers to a devctl parser.""" + p = sub.add_parser( + "discover", + help="Resolve a selector to devices, functions, or events", + ) + p.add_argument("selector", help="Selector expression (e.g. 'device(category:camera)')") + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=200, help="Page size") + + p = sub.add_parser( + "discover-labels", + help="Browse fleet label vocabulary", + ) + p.add_argument( + "--key", default=None, + help="Axis-qualified label key (e.g. 'device.location') for per-key pagination", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=50, help="Page size") diff --git a/packages/device-connect-server/device_connect_server/statectl/cli.py b/packages/device-connect-server/device_connect_server/statectl/cli.py index e1a03ef..161afdd 100644 --- a/packages/device-connect-server/device_connect_server/statectl/cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/cli.py @@ -408,6 +408,13 @@ def create_parser() -> argparse.ArgumentParser: # stats sub.add_parser("stats", help="Key counts by namespace") + # Selector-driven operations (invoke / invoke-many / broadcast / + # subscribe / await). These verbs do not touch etcd; they run over + # the messaging fabric. They live under statectl because they all + # change the live state of devices. + from device_connect_server.statectl import operations_cli + operations_cli.register_subparsers(sub) + return parser @@ -430,9 +437,25 @@ async def _run(args) -> None: await handler(client, args) +_OPERATIONS_DISPATCH = { + "invoke": "run_invoke", + "invoke-many": "run_invoke_many", + "broadcast": "run_broadcast", + "subscribe": "run_subscribe", + "await": "run_await", +} + + def main(): parser = create_parser() args = parser.parse_args() + if args.cmd in _OPERATIONS_DISPATCH: + # Operations verbs run over messaging, not etcd. Bypass the etcd + # client setup that the COMMANDS dispatch table assumes. + from device_connect_server.statectl import operations_cli + handler = getattr(operations_cli, _OPERATIONS_DISPATCH[args.cmd]) + sys.exit(handler(args)) + try: asyncio.run(_run(args)) except KeyboardInterrupt: diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py new file mode 100644 index 0000000..630709a --- /dev/null +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -0,0 +1,282 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``statectl`` selector-driven operations verbs. + +Thin wrappers around the agent-tools ``invoke`` / ``invoke_many`` / +``broadcast`` / ``subscribe`` / ``await_replies`` functions so operators +can fire selector-driven calls from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Connect to the messaging backend using the same env-or-broker rules + as devctl's selector verbs.""" + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _parse_param_kv(values: list[str] | None) -> dict[str, Any]: + """Parse ``--param k=v`` repeated args into a function-params dict. + + Values that look like JSON (``[...]``, ``{...}``, numbers, ``true`` / + ``false`` / ``null``) are decoded; everything else stays a string. This + matches what an operator would expect when typing + ``--param resolution=1080p --param tags='["a","b"]'``. + """ + out: dict[str, Any] = {} + for entry in values or []: + if "=" not in entry: + raise ValueError(f"--param must be 'k=v', got {entry!r}") + k, _, v = entry.partition("=") + k = k.strip() + if not k: + raise ValueError(f"--param has empty key in {entry!r}") + v_stripped = v.strip() + # JSON-decode obvious JSON-shaped values; fall back to raw string. + if ( + v_stripped.startswith(("[", "{", '"')) + or v_stripped in ("true", "false", "null") + or _looks_numeric(v_stripped) + ): + try: + out[k] = json.loads(v_stripped) + continue + except json.JSONDecodeError: + pass + out[k] = v + return out + + +def _looks_numeric(s: str) -> bool: + try: + float(s) + return True + except ValueError: + return False + + +def _pretty(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +# -- verbs ---------------------------------------------------------- + + +def run_invoke(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke + + _connect(getattr(args, "broker", None)) + try: + result = invoke( + args.selector, + params=_parse_param_kv(args.param), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if result.get("success") else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_invoke_many(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke_many + + _connect(getattr(args, "broker", None)) + try: + result = invoke_many( + args.selector, + params=_parse_param_kv(args.param), + timeout=float(args.timeout), + max_concurrency=int(args.max_concurrency), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_broadcast(args: Any) -> int: + from device_connect_agent_tools import broadcast, disconnect + + bindings = None + if args.bindings: + try: + bindings = json.loads(args.bindings) + except json.JSONDecodeError as e: + print(f"--bindings must be valid JSON: {e}") + return 2 + + _connect(getattr(args, "broker", None)) + try: + result = broadcast( + args.selector, + params=_parse_param_kv(args.param), + where=args.where, + bindings=bindings, + fire_at=float(args.fire_at) if args.fire_at is not None else None, + on_late=args.on_late, + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_subscribe(args: Any) -> int: + """Stream events / replies for ``args.selector`` to stdout. + + Each message is printed as one JSON line so the output can be piped + into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses + or ``--until`` messages have been printed (whichever comes first). + """ + from device_connect_agent_tools import disconnect, subscribe + + _connect(getattr(args, "broker", None)) + try: + count = 0 + with subscribe(args.selector) as sub: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_await(args: Any) -> int: + from device_connect_agent_tools import await_replies, disconnect + + _connect(getattr(args, "broker", None)) + try: + replies = await_replies( + args.correlation_id, + timeout=float(args.timeout), + until=int(args.until) if args.until is not None else None, + ) + print(_pretty(replies)) + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +# -- parser wiring -------------------------------------------------- + + +def register_subparsers(sub: Any) -> None: + """Attach the operation subparsers to a statectl parser.""" + p = sub.add_parser("invoke", help="Call exactly one function on one device") + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "invoke-many", help="Fan out a call over a selector-resolved set", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--timeout", default=30.0, help="Per-target timeout (s)") + p.add_argument( + "--max-concurrency", default=32, dest="max_concurrency", + help="Parallel worker cap", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "broadcast", + help="Async fan-out; returns correlation_id", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument( + "--where", default=None, + help="CEL predicate evaluated at the edge per candidate", + ) + p.add_argument( + "--bindings", default=None, + help="JSON-encoded bindings dict (shared payload for the predicate)", + ) + p.add_argument( + "--fire-at", default=None, dest="fire_at", + help="Wall-clock epoch seconds for synchronized fan-out", + ) + p.add_argument( + "--on-late", choices=["skip", "fire"], default="skip", dest="on_late", + help="Policy when fire_at deadline has passed (default: skip)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "subscribe", help="Stream events or broadcast replies to stdout", + ) + p.add_argument( + "selector", + help="Event selector or 'correlation:' for broadcast replies", + ) + p.add_argument( + "--timeout", default=10.0, + help="Idle-silence timeout per message (s; resets on each arrival)", + ) + p.add_argument( + "--until", default=None, + help="Stop after this many messages are printed", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "await", help="Collect replies for a broadcast correlation_id", + ) + p.add_argument("correlation_id", help="Correlation id returned by broadcast") + p.add_argument("--timeout", default=10.0, help="Overall timeout (s)") + p.add_argument( + "--until", default=None, + help="Stop after this many replies have been collected", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") diff --git a/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py new file mode 100644 index 0000000..2ed2cf8 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Smoke tests for the selector-driven CLI verbs. + +Argument-parser shape only; the underlying tools (``discover``, +``invoke``, ``broadcast``, etc.) have their own unit and integration +tests. These guards catch parser-config regressions (missing positional, +typoed dest, alias drift). +""" +from __future__ import annotations + +import json + +import pytest + +from device_connect_server.devctl import cli as devctl_cli +from device_connect_server.devctl import selector_cli +from device_connect_server.statectl import cli as statectl_cli +from device_connect_server.statectl import operations_cli + + +# -- devctl --------------------------------------------------------- + + +class TestDevctlSelectorParser: + def test_discover_requires_selector(self): + parser = devctl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["discover"]) + + def test_discover_parses_selector(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover", "device(category:camera)"]) + assert args.cmd == "discover" + assert args.selector == "device(category:camera)" + assert args.offset == 0 + assert args.limit == 200 + + def test_discover_offset_limit_override(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover", "device(*)", "--offset", "100", "--limit", "50"] + ) + assert args.offset == 100 + assert args.limit == 50 + + def test_discover_labels_no_key(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover-labels"]) + assert args.cmd == "discover-labels" + assert args.key is None + assert args.limit == 50 + + def test_discover_labels_key_pagination(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover-labels", "--key", "device.location", "--limit", "20"] + ) + assert args.key == "device.location" + assert args.limit == 20 + + def test_legacy_discover_renamed_to_mdns_scan(self): + # The historical "discover" verb (mDNS scan) now lives under + # mdns-scan; the alias "scan" keeps it discoverable. + parser = devctl_cli.create_parser() + for verb in ("mdns-scan", "scan"): + args = parser.parse_args([verb]) + # Both aliases share the same args.cmd + assert args.cmd in ("mdns-scan", "scan") + + +# -- statectl ------------------------------------------------------- + + +class TestStatectlOperationsParser: + def test_invoke_requires_selector(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["invoke"]) + + def test_invoke_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke", "device(robot-001).function(grip_close)", + "--param", "force_n=10", + "--reason", "test", + ] + ) + assert args.cmd == "invoke" + assert args.selector == "device(robot-001).function(grip_close)" + assert args.param == ["force_n=10"] + assert args.reason == "test" + + def test_invoke_many_with_timeout(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke-many", + "function(safety:critical)", + "--timeout", "5", + "--max-concurrency", "8", + ] + ) + assert args.cmd == "invoke-many" + assert float(args.timeout) == 5.0 + assert int(args.max_concurrency) == 8 + + def test_broadcast_full_signature(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "broadcast", + "device(category:phone).function(set_flashlight)", + "--param", "on=true", + "--param", "color=white", + "--where", "labels.location == 'lab-A'", + "--bindings", '{"mask": [[0,1],[1,0]]}', + "--fire-at", "1700000000.0", + "--on-late", "fire", + ] + ) + assert args.cmd == "broadcast" + assert args.selector.startswith("device(category:phone)") + assert args.where == "labels.location == 'lab-A'" + assert args.on_late == "fire" + + def test_broadcast_rejects_unknown_on_late(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "broadcast", "device(*).function(do)", + "--on-late", "bogus", + ] + ) + + def test_subscribe_parses_correlation_form(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["subscribe", "correlation:br-abc123", "--until", "5"] + ) + assert args.cmd == "subscribe" + assert args.selector == "correlation:br-abc123" + assert int(args.until) == 5 + + def test_await_requires_correlation_id(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["await"]) + + def test_await_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["await", "br-abc123", "--timeout", "2.5", "--until", "10"] + ) + assert args.correlation_id == "br-abc123" + assert float(args.timeout) == 2.5 + assert int(args.until) == 10 + + +# -- parameter parsing ---------------------------------------------- + + +class TestParseParamKV: + def test_string_values_default(self): + result = operations_cli._parse_param_kv(["a=hello", "b=world"]) + assert result == {"a": "hello", "b": "world"} + + def test_numbers_decoded(self): + result = operations_cli._parse_param_kv(["count=5", "ratio=0.75"]) + assert result == {"count": 5, "ratio": 0.75} + + def test_booleans_decoded(self): + result = operations_cli._parse_param_kv(["on=true", "off=false"]) + assert result == {"on": True, "off": False} + + def test_json_array_decoded(self): + result = operations_cli._parse_param_kv(["zones=[1,2,3]"]) + assert result == {"zones": [1, 2, 3]} + + def test_json_object_decoded(self): + result = operations_cli._parse_param_kv(['nested={"a":1}']) + assert result == {"nested": {"a": 1}} + + def test_string_with_equals(self): + # The split is on the first '=', so values may contain further '='. + result = operations_cli._parse_param_kv(["query=a=b"]) + assert result == {"query": "a=b"} + + def test_invalid_form_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["no_equals_sign"]) + + def test_empty_key_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["=value"]) From 02a94201bc97fe30eb4e22edd45aff0b95bbf0ad Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:14:12 -0700 Subject: [PATCH 5/8] docs: extend discovery guide for operations, where, and CLI Add the operations layer (invoke / invoke_many / broadcast / subscribe / await_replies) to docs/discovery.md, with the edge-side ``where`` predicate, synchronized fan-out via ``fire_at`` / ``on_late``, worked examples that exercise each tool, and the corresponding devctl / statectl CLI verbs. The guide now covers everything the discovery API ships: labels schema, selector grammar, the five scope shapes, response envelope, error codes, all seven tools, and the CLI surface. --- docs/discovery.md | 184 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 182 insertions(+), 2 deletions(-) diff --git a/docs/discovery.md b/docs/discovery.md index 6d2b8b4..ee33d60 100644 --- a/docs/discovery.md +++ b/docs/discovery.md @@ -96,7 +96,9 @@ function(estop) fleet emergency-st ## Tools -### `discover(selector, offset=0, limit=200)` +### Discovery + +#### `discover(selector, offset=0, limit=200)` Resolves a selector to matched entities. Returns devices, function tuples, or event tuples depending on the selector scope. The response includes a @@ -108,7 +110,7 @@ and switches to a name-and-labels summary above `DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is configurable via environment variable. -### `discover_labels(key=None, offset=0, limit=50)` +#### `discover_labels(key=None, offset=0, limit=50)` Returns the fleet label vocabulary. Use this first when you do not know which dimensions are available. @@ -118,6 +120,84 @@ which dimensions are available. - With a `key` like `"device.location"` or `"function.direction"`: paginates the full value list for that one key. +### Operations + +Calling a function on devices is one logical operation; the only choice +is whether the caller waits for replies and how they arrive. + +| Tool | Selector resolves to | Reply mode | +| --- | --- | --- | +| `invoke(selector, params)` | exactly one (device, function) tuple | sync, single result | +| `invoke_many(selector, params, timeout=)` | any number of (device, function) tuples | sync, aggregated | +| `broadcast(selector, params, where=, bindings=, fire_at=, on_late=)` | any number of (device, function) tuples | async; correlation-tagged replies stream as events | +| `subscribe(selector)` | events, or `"correlation:"` for broadcast replies | live stream (`Subscription` handle) | +| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | + +`invoke_many` runs every target's call in parallel and returns when each +target has finished or hit its per-target timeout (30 s default). Partial +failures do not abort siblings; the response carries both `results` and +`errors` lists. + +`broadcast` does the same fan-out asynchronously: the caller gets a +`correlation_id` immediately and replies stream back on a per-device +subject keyed by that id. Subscribe with `subscribe("correlation:")` +or block with `await_replies(correlation_id, timeout=...)`. + +### Edge-side `where` predicate + +`broadcast` accepts an optional `where` expression that runs at each +candidate device. The predicate is a CEL (Common Expression Language) +string and sees four variables: + +- `identity` — device-local identity dict (`device_id`, `device_type`, ...) +- `labels` — device labels (the same labels selectors filter on) +- `status` — device status (heartbeat-updated: `location`, `availability`, + `battery`, `online`, ...) +- `bindings` — the shared payload passed to `broadcast` (selection masks, + thresholds, lookup tables) + +```python +broadcast( + "device(category:camera).function(capture_image)", + params={"resolution": "4k"}, + where="status.battery > 50 && labels.location == 'lab-A'", +) +``` + +The `where` predicate is sandboxed by CEL (no I/O, no filesystem). The +predicate evaluator is an optional install: + +``` +pip install device-connect-agent-tools[predicate] +``` + +Without the extra, calling `broadcast(..., where=...)` returns an +`invalid_predicate` error immediately at the dispatcher; calls without a +`where` work unchanged. + +### Synchronized fan-out (`fire_at` + `on_late`) + +`broadcast` accepts an optional `fire_at` (wall-clock epoch seconds). +Each device holds the message and fires from its own clock at the +deadline. `on_late` controls behaviour when a device receives the +message past the deadline: + +- `"skip"` (default) — drop the call to preserve coherence. +- `"fire"` — execute immediately. + +```python +broadcast( + "device(category:phone).function(set_flashlight)", + params={"on": True, "color": "white"}, + fire_at=time.time() + 0.500, # 500 ms in the future + on_late="skip", +) +``` + +With NTP-synced devices the achieved spread is typically 5-10 ms +(clock-sync residual) rather than the 50-150 ms a naive fire-on-receipt +broadcast would produce. + ## Response envelope `discover` returns a stable envelope: @@ -229,3 +309,103 @@ while True: break offset = page["next_offset"] ``` + +### Invoke a single function + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(robot-001).function(grip_close)", + {"force_n": 10}, +) +# {"success": True, "device_id": "robot-001", "function": "grip_close", +# "result": {...}} +``` + +### Fan out across every camera in lab-A + +```python +from device_connect_agent_tools import invoke_many + +result = invoke_many( + "device(category:camera, location:lab-A).function(capture_image)", + {"resolution": "4k"}, +) +# {"candidates": 12, "matched": 12, "succeeded": 12, "failed": 0, +# "results": [...], "errors": []} +``` + +### Async fleet emergency stop + +```python +from device_connect_agent_tools import broadcast, await_replies + +result = broadcast("function(estop)") +# {"correlation_id": "br-7f3a91", "candidates": 240, ...} + +replies = await_replies(result["correlation_id"], timeout=5.0) +# list of {device_id, success, result|error, actually_fired_at} +``` + +### Synchronized actuation across a phone fleet + +```python +import time +from device_connect_agent_tools import broadcast + +mask = build_mask_from_scores(threshold=0.8) # caller-side selection +broadcast( + "device(category:phone, location:auditorium-A).function(set_flashlight)", + params={"on": True, "color": "white"}, + where="mask[seat_row][seat_col] == 1 && status.battery > 30", + bindings={"mask": mask}, + fire_at=time.time() + 0.5, + on_late="skip", +) +``` + +### Subscribe to motion events in lab-A + +```python +from device_connect_agent_tools import subscribe + +with subscribe("device(location:lab-A/*).event(modality:motion)") as sub: + for event in sub.iter(timeout=60.0): + handle(event) +``` + +## CLI + +The same selector syntax drives the operator CLIs. Every CLI command +maps to the matching Python tool call. + +``` +# Discovery (devctl) +devctl discover "" [--offset N] [--limit M] +devctl discover-labels [--key K] [--offset N] [--limit M] + +# Operations (statectl) +statectl invoke "" [--param k=v ...] +statectl invoke-many "" [--param k=v ...] [--timeout T] +statectl broadcast "" [--param k=v ...] [--where E] + [--bindings JSON] [--fire-at T] + [--on-late skip|fire] +statectl subscribe "" [--timeout T] [--until N] +statectl await [--timeout T] [--until N] +``` + +`--param k=v` accepts JSON-shaped values (numbers, booleans, arrays, +objects); everything else passes through as a string. So +`--param resolution=4k` and `--param zones='[1,2,3]'` both work +without quoting heroics. + +Each verb exits non-zero on tool-side errors so the verbs compose into +shell pipelines: + +``` +statectl broadcast "device(category:camera).function(capture_image)" \ + --param resolution=4k \ + | jq -r .correlation_id \ + | xargs statectl await --timeout 5 +``` From 7c760ab655ab510e4509c695fdaa2e96d083dcd6 Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 12:39:14 -0700 Subject: [PATCH 6/8] fix(broadcast): robustness pass on edge handler, subscribe, and CLI Applies findings from the pre-merge review of the operations stack: Edge runtime (device.py): - Hand the broadcast envelope off to a tracked task so the subscription callback returns immediately. A long fire_at hold or slow driver function no longer blocks subsequent broadcasts from being received. - Extract _handle_broadcast_envelope and _evaluate_where so the where self-election step is isolated, unit-testable, and the callback body stays flat. - Splice device_id into the predicate's identity context so the natural ``identity.device_id == "..."`` form works (DeviceIdentity itself does not carry device_id; that lives on the runtime). Wire format (tools.py + device.py): - Rename the broadcast envelope's ``target_device_ids`` field to ``targets`` before any edge ships. Shorter, less prescriptive, and matches the dispatcher-side ``candidates`` naming. Subscription handle (tools.py): - Fix a race in Subscription.read(): truncate by the snapshot length captured BEFORE iteration, not by clearing post-iteration. A message appended by the messaging callback during draining now survives to the next read instead of being silently dropped. - Add __iter__ so ``for msg in sub:`` works with a sensible 30s idle timeout, matching the standard Python iteration protocol. CLI (statectl/operations_cli.py): - statectl subscribe now catches KeyboardInterrupt cleanly (exit 130), distinguishes "got messages" (exit 0) from "idle timeout with no messages" (exit 4), so shell pipelines can branch on either outcome. - statectl invoke-many exits 3 when any target failed (alongside the existing 1 for top-level errors), so partial failure is visible to callers without parsing JSON. ASCII compliance (predicate.py, tools.py): - Drop a banned-vocabulary token from a docstring. - Replace an em-dash in invoke_device's docstring with ASCII text. New tests: - Unit: __iter__ protocol + race-safety guard for Subscription.read. - Integration: broadcast where=identity.device_id in bindings.allow (exercises the new identity context + bindings path), await_replies(until=) early-return timing, ``for msg in sub:`` iteration end-to-end, and subscribe(event(...)) live-event capture. --- .../device_connect_agent_tools/connection.py | 2 +- .../device_connect_agent_tools/tools.py | 48 ++-- .../tests/test_broadcast.py | 2 +- .../tests/test_subscribe.py | 39 +++ .../device_connect_edge/device.py | 251 ++++++++++-------- .../device_connect_edge/predicate.py | 2 +- .../statectl/operations_cli.py | 35 ++- tests/tests/test_tools_broadcast.py | 145 ++++++++++ 8 files changed, 387 insertions(+), 137 deletions(-) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index b399f70..4dce5fb 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -415,7 +415,7 @@ def publish_broadcast(self, envelope: Dict[str, Any]) -> None: The envelope shape is documented in ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; every device subscribed to ``device-connect..broadcast`` - receives the message and self-elects via ``target_device_ids`` and + receives the message and self-elects via ``targets`` and the optional ``where`` predicate. """ return self._run(self._async_publish_broadcast(envelope)) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index c81faf2..c99c5bc 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -829,7 +829,7 @@ def broadcast( } correlation_id = f"br-{uuid.uuid4().hex[:12]}" - target_device_ids = sorted({ + targets = sorted({ row.get("device_id") for row in rows if row.get("device_id") }) clean_params = { @@ -840,7 +840,7 @@ def broadcast( "correlation_id": correlation_id, "function": function_name, "params": clean_params, - "target_device_ids": target_device_ids, + "targets": targets, } if where: envelope["where"] = where @@ -857,7 +857,7 @@ def broadcast( ) logger.info( "[broadcast::%s::%d targets] Reason: %s", - correlation_id, len(target_device_ids), truncated, + correlation_id, len(targets), truncated, ) try: @@ -866,13 +866,13 @@ def broadcast( except Exception as e: logger.error("broadcast publish failed: %s", e) return { - "candidates": len(target_device_ids), + "candidates": len(targets), "error": _error("connection_error", str(e)), } return { "correlation_id": correlation_id, - "candidates": len(target_device_ids), + "candidates": len(targets), "selector": selector, "function": function_name, } @@ -919,26 +919,27 @@ def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: Returns parsed payload dicts (already JSON-decoded by the connection's buffered subscription path). Subsequent calls return only messages that arrived after the previous call. + + Race-safe against the messaging callback that appends to the same + inbox: each inbox is read by snapshotting its current length and + truncating only that prefix, so a message that arrives during + iteration stays buffered for the next ``read``. """ if self._closed: return [] out: list[dict[str, Any]] = [] for name in self._inbox_names: - inboxes = self._conn.get_inbox(name) - buffered = inboxes.get(name, []) or [] - # Each buffered entry is (subject, payload). We expose the - # parsed payload but stamp the subject onto it so callers can - # distinguish per-source messages without parsing it themselves. - for subject, payload in buffered: + buf = self._conn._inbox.get(name) or [] + # Snapshot the consumed prefix length BEFORE iterating, then + # truncate by exactly that many items. Any message appended by + # the messaging callback between the snapshot and the truncation + # remains buffered for a subsequent ``read``. + n = len(buf) + for subject, payload in buf[:n]: if not isinstance(payload, dict): payload = {"raw": payload} - payload = {**payload, "_subject": subject} - out.append(payload) - # Fast cursor: trim per-inbox buffers we have already returned by - # truncating from the front. The connection layer already caps each - # inbox at 1000 entries, so bounded growth is its concern. - for name in self._inbox_names: - self._conn._inbox[name] = [] + out.append({**payload, "_subject": subject}) + self._conn._inbox[name] = buf[n:] if max_messages is not None: out = out[:max_messages] return out @@ -962,6 +963,15 @@ def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): return time.sleep(poll_interval) + def __iter__(self): + """Allow ``for msg in sub:`` with a default 30-second idle timeout. + + Delegates to :meth:`iter` with sensible defaults so the idiomatic + Python iteration form works. Use ``sub.iter(timeout=...)`` directly + when the default does not fit. + """ + return self.iter(timeout=30.0, poll_interval=0.05) + def close(self) -> None: """Tear down the underlying messaging subscriptions.""" if self._closed: @@ -1324,7 +1334,7 @@ def invoke_device( device_id: Target device ID (e.g., "robot-001", "camera-001"). function: Function name to call. params: Function parameters as a dictionary. - llm_reasoning: Why you're calling this function -- for observability. + llm_reasoning: Why you are calling this function (for observability). """ warnings.warn( "invoke_device(device_id, function, ...) is deprecated; use " diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py index e25a35e..e8d8831 100644 --- a/packages/device-connect-agent-tools/tests/test_broadcast.py +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -100,7 +100,7 @@ def test_envelope_carries_function_and_targets(self, mock_conn): env = mock_conn._published[0] assert env["function"] == "capture_image" assert env["params"] == {"resolution": "4k"} - assert sorted(env["target_device_ids"]) == ["cam-001", "cam-002"] + assert sorted(env["targets"]) == ["cam-001", "cam-002"] # No optional fields when caller did not set them. assert "where" not in env assert "bindings" not in env diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py index a6f032e..a8b4be4 100644 --- a/packages/device-connect-agent-tools/tests/test_subscribe.py +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -159,6 +159,45 @@ def test_iter_yields_until_idle_timeout(self, fake_conn): assert len(msgs) == 1 sub.close() + def test_for_loop_protocol_via_dunder_iter(self, fake_conn): + # ``for msg in sub:`` should drive __iter__ which delegates to iter() + # with a sensible default timeout. Break early so the test does not + # block on the 30s default. + sub = tools_mod.subscribe("correlation:r_iter") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_iter", + {"correlation_id": "r_iter", "device_id": "cam-001"}, + ) + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + break # one message is enough to confirm __iter__ wiring + sub.close() + assert len(gathered) == 1 + assert gathered[0]["device_id"] == "cam-001" + + def test_read_does_not_drop_messages_appended_during_iteration(self, fake_conn): + # Race-safety guard: simulate a callback that appends a fresh + # message between the read's snapshot and truncation. The message + # must still be visible on the next read(). + sub = tools_mod.subscribe("correlation:r_race") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-001", "ordinal": 1}, + ) + first = sub.read() + assert len(first) == 1 + # Now simulate a late-arriving append into the same inbox AFTER + # the previous read drained the prefix. + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-002", "ordinal": 2}, + ) + second = sub.read() + assert len(second) == 1 + assert second[0]["device_id"] == "cam-002" + sub.close() + # -- await_replies -------------------------------------------------- diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index b64d443..96e31d4 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1144,17 +1144,22 @@ async def _broadcast_subscription(self) -> None: "correlation_id": "br-abc123", "function": "capture_image", "params": {"resolution": "4k"}, - "target_device_ids": ["cam-001", "cam-002"], // pre-resolved - "where": "status.battery > 50", // optional CEL - "bindings": {"mask": [[0,1],[1,0]]}, // optional - "fire_at": 1234567890.5, // optional, epoch s - "on_late": "skip" // skip|fire + "targets": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire } On match, the device executes the function and emits a reply on ``device-connect...event.async_reply.`` with ``{correlation_id, device_id, success, result|error, actually_fired_at}``. + + The envelope is processed in a tracked task so the subscription + loop does not block on ``fire_at`` sleeps or long-running driver + functions; subsequent broadcasts can continue to land while an + earlier one is in flight. """ subj = f"device-connect.{self.tenant}.broadcast" @@ -1169,115 +1174,151 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): if not correlation_id: return - # Self-election step 1: target_device_ids gate (pre-resolved by - # the dispatcher from the selector). When absent or empty, treat - # the broadcast as fleet-wide. - targets = envelope.get("target_device_ids") or [] + # Cheap self-election: target gate (pre-resolved by the dispatcher + # from the selector). When absent or empty, treat as fleet-wide. + targets = envelope.get("targets") or [] if targets and self.device_id not in targets: return - function_name = envelope.get("function") - if not function_name: + if not envelope.get("function"): return - params_dict = envelope.get("params", {}) or {} - # Self-election step 2: where predicate against {identity, labels, - # status, bindings}. A failed compile or eval is treated as - # fail-closed (do not execute). - where_expr = envelope.get("where") - if where_expr: - try: - from device_connect_edge.predicate import compile_where - predicate = compile_where(where_expr) - caps = self._driver.capabilities if self._driver else self.capabilities - status = self._driver.status if self._driver else None - labels = (caps.labels if caps and caps.labels else {}) or {} - status_dict = ( - status.model_dump() if status and hasattr(status, "model_dump") else {} - ) - # Mirror the legacy DeviceStatus.location into labels so - # ``labels.location`` works in predicates without the driver - # having to declare it explicitly. Matches the dispatcher-side - # flatten_device contract. - if "location" not in labels and status_dict.get("location"): - labels = {**labels, "location": status_dict["location"]} - context = { - "identity": ( - caps.identity.model_dump() - if caps and getattr(caps, "identity", None) else {} - ), - "labels": labels, - "status": status_dict, - "bindings": envelope.get("bindings", {}) or {}, - } - if not predicate.evaluate(context): - return - except Exception as e: - self._logger.warning( - "Broadcast %s: where predicate failed (skipping): %s", - correlation_id, e, - ) - return + # Hand off to a tracked task. The task owns the where evaluation, + # the fire_at sleep, and the driver call, so this callback returns + # immediately and the messaging subscription stays drained. + self._track_task(asyncio.create_task( + self._handle_broadcast_envelope(envelope, correlation_id) + )) - # fire_at: hold the message until the wall-clock deadline. The - # on_late policy decides what to do if the message arrives past - # the deadline (skip preserves coherence; fire runs anyway). - fire_at = envelope.get("fire_at") - on_late = envelope.get("on_late", "skip") - if fire_at is not None: - delay = float(fire_at) - time.time() - if delay < 0 and on_late == "skip": - self._logger.info( - "Broadcast %s arrived %.3fs late, on_late=skip", - correlation_id, -delay, - ) - return - if delay > 0: - await asyncio.sleep(delay) + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) - # Execute the driver function and emit the reply. - actually_fired_at = time.time() - reply_subj = ( - f"device-connect.{self.tenant}.{self.device_id}" - f".event.async_reply.{correlation_id}" - ) - try: - if self._driver is None: - raise RuntimeError("no driver configured") - driver_functions = self._driver._get_functions() - if function_name not in driver_functions: - raise RuntimeError(f"unknown function: {function_name}") - result = await self._driver.invoke(function_name, **params_dict) - reply_payload = { - "correlation_id": correlation_id, - "device_id": self.device_id, - "success": True, - "result": result, - "actually_fired_at": actually_fired_at, - } - except Exception as e: - self._logger.warning( - "Broadcast %s: function %s failed: %s", - correlation_id, function_name, e, - ) - reply_payload = { - "correlation_id": correlation_id, - "device_id": self.device_id, - "success": False, - "error": {"code": "invoke_failed", "message": str(e)}, - "actually_fired_at": actually_fired_at, - } - try: - await self.messaging.publish( - reply_subj, json.dumps(reply_payload).encode(), - ) - except Exception as e: # pragma: no cover - self._logger.warning( - "Broadcast %s: reply publish failed: %s", correlation_id, e, + + async def _handle_broadcast_envelope( + self, envelope: Dict[str, Any], correlation_id: str, + ) -> None: + """Process one broadcast envelope: evaluate where, honour fire_at, invoke, reply. + + Runs in its own task so a long-held ``fire_at`` or slow driver + function does not block the subscription callback from accepting + subsequent broadcasts. + """ + function_name = envelope.get("function") + params_dict = envelope.get("params", {}) or {} + + # Step 1: where predicate against {identity, labels, status, bindings}. + # A failed compile or eval is treated as fail-closed (do not execute); + # the message is logged at WARNING with the correlation_id so an + # operator can correlate a silent skip with a misspelled label key. + where_expr = envelope.get("where") + if where_expr and not self._evaluate_where( + where_expr, envelope.get("bindings"), correlation_id, + ): + return + + # Step 2: fire_at hold. The on_late policy decides what to do when + # the message arrives past the deadline (skip preserves coherence; + # fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, ) + return + if delay > 0: + await asyncio.sleep(delay) - await self.messaging.subscribe(subj, callback=on_msg) - self._logger.info("Subscribed to broadcasts on %s", subj) + # Step 3: execute and reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload: Dict[str, Any] = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": {"code": "invoke_failed", "message": str(e)}, + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + + def _evaluate_where( + self, + where_expr: str, + bindings: Optional[Dict[str, Any]], + correlation_id: str, + ) -> bool: + """Compile and evaluate a where predicate; return True iff it passes. + + Returns False (do not execute) on compile or eval errors, logging + a warning so silent self-deselection is operator-visible. + """ + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror DeviceStatus.location into labels so ``labels.location`` + # works in predicates without the driver having to declare it + # explicitly. Matches the dispatcher-side flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + # The DeviceIdentity model carries device_type / manufacturer / + # model / firmware_version but NOT device_id (which lives on the + # runtime). Splice it in so predicates can write the natural + # ``identity.device_id == "..."``. + identity_dict: Dict[str, Any] = {"device_id": self.device_id} + if caps and getattr(caps, "identity", None): + identity_dict.update(caps.identity.model_dump()) + context = { + "identity": identity_dict, + "labels": labels, + "status": status_dict, + "bindings": bindings or {}, + } + return bool(predicate.evaluate(context)) + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return False async def _event_dispatch_loop(self) -> None: diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py index 6ddc7c0..5bf5ff6 100644 --- a/packages/device-connect-edge/device_connect_edge/predicate.py +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -15,7 +15,7 @@ bindings shared payload supplied by the caller (selection masks, thresholds, lookup tables) -Examples (every example here ships with v4 spec):: +Examples:: battery > 50 labels.category == "camera" && status.battery > 50 diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py index 630709a..7ddc9ef 100644 --- a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -109,7 +109,13 @@ def run_invoke_many(args: Any) -> int: llm_reasoning=args.reason, ) print(_pretty(result)) - return 0 if "error" not in result else 1 + # Exit non-zero on a top-level error OR when any target failed, so + # shell pipelines can detect partial failure without parsing JSON. + if "error" in result: + return 1 + if result.get("failed", 0) > 0: + return 3 + return 0 finally: try: disconnect() @@ -154,21 +160,30 @@ def run_subscribe(args: Any) -> int: Each message is printed as one JSON line so the output can be piped into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses or ``--until`` messages have been printed (whichever comes first). + Exit codes: + 0 one or more messages were printed + 4 idle-timeout reached with zero messages + 130 interrupted with Ctrl-C """ from device_connect_agent_tools import disconnect, subscribe _connect(getattr(args, "broker", None)) + count = 0 try: - count = 0 with subscribe(args.selector) as sub: - for msg in sub.iter( - timeout=float(args.timeout), poll_interval=0.05, - ): - print(json.dumps(msg, default=str)) - count += 1 - if args.until is not None and count >= int(args.until): - break - return 0 + try: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + except KeyboardInterrupt: + # Clean exit on Ctrl-C: the ``with`` block tears the + # subscription down before this returns. + return 130 + return 0 if count > 0 else 4 finally: try: disconnect() diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py index 0e7413f..975016e 100644 --- a/tests/tests/test_tools_broadcast.py +++ b/tests/tests/test_tools_broadcast.py @@ -183,6 +183,151 @@ async def test_broadcast_fire_at_late_with_skip_drops( await asyncio.to_thread(disconnect) +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_with_bindings(device_spawner, messaging_url): + """A where predicate that reads bindings. self-elects per-target.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcbnd-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcbnd-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-bcbnd-cam-1", "itest-bcbnd-cam-2"} + ) + try: + # Allowlist sent in bindings; the predicate uses bindings.allow to + # select. Devices not in the allowlist self-deselect silently. + result = await asyncio.to_thread( + broadcast, + "device(itest-bcbnd-cam-*).function(capture_image)", + None, + "identity.device_id in bindings.allow", + {"allow": ["itest-bcbnd-cam-1"]}, + ) + assert result["candidates"] == 2 + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcbnd-cam-1"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_await_replies_until_stops_early(device_spawner, messaging_url): + """``await_replies`` returns once ``until`` replies have arrived.""" + await device_spawner.spawn_camera("itest-awu-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-2", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-3", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-awu-cam-1", "itest-awu-cam-2", "itest-awu-cam-3"} + ) + try: + result = await asyncio.to_thread( + broadcast, "device(itest-awu-cam-*).function(capture_image)", + ) + assert result["candidates"] == 3 + # until=1 should let us return after the first reply arrives even + # though more are coming. + t0 = time.monotonic() + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], + timeout=5.0, until=1, poll_interval=0.02, + ) + elapsed = time.monotonic() - t0 + assert len(replies) >= 1 + # Sanity: returning early should be well under the timeout. + assert elapsed < 2.0, f"await_replies(until=1) took {elapsed:.2f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_iter_protocol(device_spawner, messaging_url): + """``for msg in sub:`` works via Subscription.__iter__.""" + await device_spawner.spawn_camera("itest-subiter-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-subiter-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices( + messaging_url, {"itest-subiter-cam-1", "itest-subiter-cam-2"} + ) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-subiter-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + # Exercise the bare ``for msg in sub:`` form (uses __iter__). + # Break after both expected replies arrive so the test stays + # bounded regardless of the default idle timeout. + with subscribe(f"correlation:{cid}") as sub: + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + if len(gathered) >= 2: + break + return gathered + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-subiter-cam-1", "itest-subiter-cam-2"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_event_selector_live_stream(device_spawner, messaging_url): + """subscribe(event()) receives live events from matching devices.""" + device, driver = await device_spawner.spawn_camera( + "itest-evsub-cam", location="lab-A", + ) + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-evsub-cam"}) + try: + with subscribe("device(itest-evsub-cam).event(object_detected)") as sub: + await asyncio.sleep(SETTLE_TIME) # let subscription warm up + await driver.trigger_event( + "object_detected", + {"label": "person", "confidence": 0.95}, + ) + msgs = await asyncio.to_thread( + list, sub.iter(timeout=2.0, poll_interval=0.05), + ) + # The event arrives via the JSON-RPC event subject; payload is + # under either ``params`` or top-level depending on transport. + matching = [ + m for m in msgs + if (m.get("params") or {}).get("label") == "person" + or m.get("label") == "person" + ] + assert matching, f"no object_detected events received: {msgs}" + finally: + await asyncio.to_thread(disconnect) + + @pytest.mark.asyncio @pytest.mark.integration async def test_subscribe_correlation_form(device_spawner, messaging_url): From 8660801eeb511b8bb9a35e959bac8b7ad01b7caa Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 16:44:50 -0700 Subject: [PATCH 7/8] feat(adapters): expose broadcast and await_replies via all three adapters Phases 4-5 added broadcast() and await_replies() to the agent-tools surface but the adapter migration in feat(invoke) only carried invoke / invoke_many across. The flashlight-auditorium demo needs the LLM to issue selector-driven broadcasts with where + bindings + fire_at, so broadcast and await_replies both need to be Strands/LangChain/Claude tools as well. Tool descriptions for the Claude adapter spell out the broadcast + await_replies pairing (caller fires broadcast, then awaits replies by correlation_id) so agents discover the workflow from the tool docs. subscribe() is intentionally NOT exposed via the adapters: it returns a Subscription handle that does not serialise cleanly as a tool result and is more natural to call from operator code or the CLI than from an LLM. Agents needing the same shape use broadcast + await_replies. --- .../adapters/claude.py | 55 +++++++++++++++++++ .../adapters/langchain.py | 6 ++ .../adapters/strands.py | 6 ++ .../tests/test_claude_adapter.py | 2 + .../tests/test_langchain_adapter.py | 2 + .../tests/test_strands_adapter.py | 2 + 6 files changed, 73 insertions(+) diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 9dd08d8..f4a2883 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -47,6 +47,8 @@ async def main(): discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -147,6 +149,55 @@ async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: ) +@tool( + "broadcast", + "Async selector-driven fan-out. Returns immediately with a " + "correlation_id; replies stream on a per-device subject keyed by id. " + "Each candidate self-elects via the optional CEL `where` predicate " + "(evaluated at the edge against identity/labels/status/bindings) and " + "executes the function. Use fire_at (wall-clock epoch seconds) + " + "on_late (skip|fire) for synchronized fan-out. Pair with " + "await_replies(correlation_id) to collect outcomes.", + { + "selector": str, "params": dict, "where": str, "bindings": dict, + "fire_at": float, "on_late": str, "llm_reasoning": str, + }, +) +async def broadcast(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _broadcast( + selector=args["selector"], + params=args.get("params"), + where=args.get("where"), + bindings=args.get("bindings"), + fire_at=args.get("fire_at"), + on_late=args.get("on_late", "skip"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "await_replies", + "Collect replies for a broadcast() call. Subscribes to the " + "correlation reply subject, drains for up to `timeout` seconds (or " + "until `until` replies have arrived), then returns the list.", + { + "correlation_id": str, "timeout": float, "until": int, + "poll_interval": float, + }, +) +async def await_replies(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _await_replies( + correlation_id=args["correlation_id"], + timeout=float(args.get("timeout", 10.0)), + until=int(args["until"]) if args.get("until") is not None else None, + poll_interval=float(args.get("poll_interval", 0.05)), + ) + ) + + # Other invocation helpers @@ -206,6 +257,8 @@ def create_device_connect_server(name: str = "device-connect"): discover, invoke, invoke_many, + broadcast, + await_replies, invoke_device_with_fallback, get_device_status, discover_devices, @@ -218,6 +271,8 @@ def create_device_connect_server(name: str = "device-connect"): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index 35d5e51..c18ed7e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -26,6 +26,8 @@ discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -37,6 +39,8 @@ # Selector-driven invocation (recommended) invoke = StructuredTool.from_function(_invoke) invoke_many = StructuredTool.from_function(_invoke_many) +broadcast = StructuredTool.from_function(_broadcast) +await_replies = StructuredTool.from_function(_await_replies) # Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) @@ -50,6 +54,8 @@ "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index d22fcf7..b68c16b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -27,6 +27,8 @@ discover_devices as _discover_devices, invoke as _invoke, invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -38,6 +40,8 @@ # Selector-driven invocation (recommended) invoke = strands_tool(_invoke) invoke_many = strands_tool(_invoke_many) +broadcast = strands_tool(_broadcast) +await_replies = strands_tool(_await_replies) # Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) @@ -51,6 +55,8 @@ "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index 311aab5..4960a49 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -70,6 +70,8 @@ def _mock_sdk_and_connection(): "discover_devices", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", ) diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index c4a487e..9aae070 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -74,6 +74,8 @@ def _mock_langchain_and_connection(): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index 30d1ae0..4e46ceb 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -57,6 +57,8 @@ def _mock_strands_and_connection(): "discover", "invoke", "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", From 08c9aa7d69029c88963c251c07b146f5b4f4cc6b Mon Sep 17 00:00:00 2001 From: Sourav Pati Date: Sun, 10 May 2026 18:45:51 -0700 Subject: [PATCH 8/8] fix(broadcast): read identity from driver, not from DeviceCapabilities The broadcast handler built the where-predicate context from ``caps.identity`` -- but DeviceCapabilities does not carry an ``identity`` field; that lives on the driver as a separate DeviceIdentity model. The ``getattr(caps, "identity", None)`` fallback masked the bug: identity_dict was always just ``{"device_id": ...}`` with none of the driver's extra fields (seat_row, seat_col, x-mhp slot metadata, ...) reaching the predicate. Symptom: a where predicate like ``bindings.mask[identity.seat_row][identity.seat_col] == 1`` failed at every candidate (CEL surfaces undefined field access as CELEvalError, fail-closed fires, nobody self-elects). Fix: read identity from ``self._driver.identity`` and splice in ``device_id`` from the runtime. Backwards-compatible with drivers that don't expose an identity property (driver_identity is None -> only device_id is present, same as before for those drivers). Surfaced while building the flashlight-auditorium demo, where each phone exposes its seat coordinates as extra fields on DeviceIdentity and the spell-CMU broadcast indexes a 2D mask by those coordinates. --- .../device_connect_edge/device.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 96e31d4..c776d1a 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1299,13 +1299,18 @@ def _evaluate_where( # explicitly. Matches the dispatcher-side flatten_device contract. if "location" not in labels and status_dict.get("location"): labels = {**labels, "location": status_dict["location"]} - # The DeviceIdentity model carries device_type / manufacturer / - # model / firmware_version but NOT device_id (which lives on the - # runtime). Splice it in so predicates can write the natural - # ``identity.device_id == "..."``. + # DeviceIdentity is exposed by the driver, not by DeviceCapabilities; + # they are independent pydantic models. Read identity from the + # driver so extra fields (seat_row, seat_col, x-mhp metadata, ...) + # reach the predicate context. Splice in device_id which lives on + # the runtime so predicates can write + # ``identity.device_id == "..."`` naturally. identity_dict: Dict[str, Any] = {"device_id": self.device_id} - if caps and getattr(caps, "identity", None): - identity_dict.update(caps.identity.model_dump()) + driver_identity = ( + getattr(self._driver, "identity", None) if self._driver else None + ) + if driver_identity is not None and hasattr(driver_identity, "model_dump"): + identity_dict.update(driver_identity.model_dump()) context = { "identity": identity_dict, "labels": labels,