diff --git a/docs/discovery.md b/docs/discovery.md index 6d2b8b4..ee33d60 100644 --- a/docs/discovery.md +++ b/docs/discovery.md @@ -96,7 +96,9 @@ function(estop) fleet emergency-st ## Tools -### `discover(selector, offset=0, limit=200)` +### Discovery + +#### `discover(selector, offset=0, limit=200)` Resolves a selector to matched entities. Returns devices, function tuples, or event tuples depending on the selector scope. The response includes a @@ -108,7 +110,7 @@ and switches to a name-and-labels summary above `DEVICE_CONNECT_FUNCTION_THRESHOLD` (default 20). The threshold is configurable via environment variable. -### `discover_labels(key=None, offset=0, limit=50)` +#### `discover_labels(key=None, offset=0, limit=50)` Returns the fleet label vocabulary. Use this first when you do not know which dimensions are available. @@ -118,6 +120,84 @@ which dimensions are available. - With a `key` like `"device.location"` or `"function.direction"`: paginates the full value list for that one key. +### Operations + +Calling a function on devices is one logical operation; the only choice +is whether the caller waits for replies and how they arrive. + +| Tool | Selector resolves to | Reply mode | +| --- | --- | --- | +| `invoke(selector, params)` | exactly one (device, function) tuple | sync, single result | +| `invoke_many(selector, params, timeout=)` | any number of (device, function) tuples | sync, aggregated | +| `broadcast(selector, params, where=, bindings=, fire_at=, on_late=)` | any number of (device, function) tuples | async; correlation-tagged replies stream as events | +| `subscribe(selector)` | events, or `"correlation:"` for broadcast replies | live stream (`Subscription` handle) | +| `await_replies(correlation_id, timeout=, until=)` | replies for one broadcast | sync helper that subscribes, collects, returns | + +`invoke_many` runs every target's call in parallel and returns when each +target has finished or hit its per-target timeout (30 s default). Partial +failures do not abort siblings; the response carries both `results` and +`errors` lists. + +`broadcast` does the same fan-out asynchronously: the caller gets a +`correlation_id` immediately and replies stream back on a per-device +subject keyed by that id. Subscribe with `subscribe("correlation:")` +or block with `await_replies(correlation_id, timeout=...)`. + +### Edge-side `where` predicate + +`broadcast` accepts an optional `where` expression that runs at each +candidate device. The predicate is a CEL (Common Expression Language) +string and sees four variables: + +- `identity` — device-local identity dict (`device_id`, `device_type`, ...) +- `labels` — device labels (the same labels selectors filter on) +- `status` — device status (heartbeat-updated: `location`, `availability`, + `battery`, `online`, ...) +- `bindings` — the shared payload passed to `broadcast` (selection masks, + thresholds, lookup tables) + +```python +broadcast( + "device(category:camera).function(capture_image)", + params={"resolution": "4k"}, + where="status.battery > 50 && labels.location == 'lab-A'", +) +``` + +The `where` predicate is sandboxed by CEL (no I/O, no filesystem). The +predicate evaluator is an optional install: + +``` +pip install device-connect-agent-tools[predicate] +``` + +Without the extra, calling `broadcast(..., where=...)` returns an +`invalid_predicate` error immediately at the dispatcher; calls without a +`where` work unchanged. + +### Synchronized fan-out (`fire_at` + `on_late`) + +`broadcast` accepts an optional `fire_at` (wall-clock epoch seconds). +Each device holds the message and fires from its own clock at the +deadline. `on_late` controls behaviour when a device receives the +message past the deadline: + +- `"skip"` (default) — drop the call to preserve coherence. +- `"fire"` — execute immediately. + +```python +broadcast( + "device(category:phone).function(set_flashlight)", + params={"on": True, "color": "white"}, + fire_at=time.time() + 0.500, # 500 ms in the future + on_late="skip", +) +``` + +With NTP-synced devices the achieved spread is typically 5-10 ms +(clock-sync residual) rather than the 50-150 ms a naive fire-on-receipt +broadcast would produce. + ## Response envelope `discover` returns a stable envelope: @@ -229,3 +309,103 @@ while True: break offset = page["next_offset"] ``` + +### Invoke a single function + +```python +from device_connect_agent_tools import invoke + +result = invoke( + "device(robot-001).function(grip_close)", + {"force_n": 10}, +) +# {"success": True, "device_id": "robot-001", "function": "grip_close", +# "result": {...}} +``` + +### Fan out across every camera in lab-A + +```python +from device_connect_agent_tools import invoke_many + +result = invoke_many( + "device(category:camera, location:lab-A).function(capture_image)", + {"resolution": "4k"}, +) +# {"candidates": 12, "matched": 12, "succeeded": 12, "failed": 0, +# "results": [...], "errors": []} +``` + +### Async fleet emergency stop + +```python +from device_connect_agent_tools import broadcast, await_replies + +result = broadcast("function(estop)") +# {"correlation_id": "br-7f3a91", "candidates": 240, ...} + +replies = await_replies(result["correlation_id"], timeout=5.0) +# list of {device_id, success, result|error, actually_fired_at} +``` + +### Synchronized actuation across a phone fleet + +```python +import time +from device_connect_agent_tools import broadcast + +mask = build_mask_from_scores(threshold=0.8) # caller-side selection +broadcast( + "device(category:phone, location:auditorium-A).function(set_flashlight)", + params={"on": True, "color": "white"}, + where="mask[seat_row][seat_col] == 1 && status.battery > 30", + bindings={"mask": mask}, + fire_at=time.time() + 0.5, + on_late="skip", +) +``` + +### Subscribe to motion events in lab-A + +```python +from device_connect_agent_tools import subscribe + +with subscribe("device(location:lab-A/*).event(modality:motion)") as sub: + for event in sub.iter(timeout=60.0): + handle(event) +``` + +## CLI + +The same selector syntax drives the operator CLIs. Every CLI command +maps to the matching Python tool call. + +``` +# Discovery (devctl) +devctl discover "" [--offset N] [--limit M] +devctl discover-labels [--key K] [--offset N] [--limit M] + +# Operations (statectl) +statectl invoke "" [--param k=v ...] +statectl invoke-many "" [--param k=v ...] [--timeout T] +statectl broadcast "" [--param k=v ...] [--where E] + [--bindings JSON] [--fire-at T] + [--on-late skip|fire] +statectl subscribe "" [--timeout T] [--until N] +statectl await [--timeout T] [--until N] +``` + +`--param k=v` accepts JSON-shaped values (numbers, booleans, arrays, +objects); everything else passes through as a string. So +`--param resolution=4k` and `--param zones='[1,2,3]'` both work +without quoting heroics. + +Each verb exits non-zero on tool-side errors so the verbs compose into +shell pipelines: + +``` +statectl broadcast "device(category:camera).function(capture_image)" \ + --param resolution=4k \ + | jq -r .correlation_id \ + | xargs statectl await --timeout 5 +``` diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py index de79913..1a7c1e0 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/__init__.py @@ -4,19 +4,21 @@ """Device Connect Tools — framework-agnostic SDK for Device Connect IoT. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: - from device_connect_agent_tools import connect, discover, discover_labels + from device_connect_agent_tools import connect, discover, discover_labels, invoke connect() vocab = discover_labels() # fleet vocabulary cams = discover("device(category:camera, location:zone-A/*)") # device roster writes = discover("device(*).function(direction:write)") # function tuples - result = invoke_device("camera-001", "capture_image", {"resolution": "1080p"}) + result = invoke("device(camera-001).function(capture_image)", + {"resolution": "1080p"}) -The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` -trio remains available for one release as advisory-deprecated wrappers -- -prefer ``discover`` / ``discover_labels`` for new code. +The older ``describe_fleet`` / ``list_devices`` / ``get_device_functions`` / +``invoke_device`` family remains available for one release as +advisory-deprecated wrappers -- prefer ``discover`` / ``discover_labels`` / +``invoke`` / ``invoke_many`` for new code. """ from device_connect_agent_tools.agent import DeviceConnectAgent @@ -25,14 +27,22 @@ # Selector-driven discovery (preferred) discover, discover_labels, - # Invocation - invoke_device, + # Selector-driven invocation (preferred) + invoke, + invoke_many, + broadcast, + # Selector-driven subscription + Subscription, + subscribe, + await_replies, + # Other invocation helpers invoke_device_with_fallback, get_device_status, - # Advisory-deprecated discovery wrappers (one-release transition) + # Advisory-deprecated wrappers (one-release transition) describe_fleet, list_devices, get_device_functions, + invoke_device, discover_devices, ) @@ -46,13 +56,21 @@ # Selector-driven discovery (preferred) "discover", "discover_labels", - # Invocation - "invoke_device", + # Selector-driven invocation (preferred) + "invoke", + "invoke_many", + "broadcast", + # Selector-driven subscription + "Subscription", + "subscribe", + "await_replies", + # Other invocation helpers "invoke_device_with_fallback", "get_device_status", - # Advisory-deprecated -- use discover() / discover_labels() instead + # Advisory-deprecated -- use discover / discover_labels / invoke instead "describe_fleet", "list_devices", "get_device_functions", + "invoke_device", "discover_devices", ] diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py index 807abcb..f4a2883 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/claude.py @@ -4,7 +4,7 @@ """Claude Agent SDK adapter — exposes Device Connect tools to claude-agent-sdk. -Selector-driven discovery keeps LLM context small:: +Selector-driven discovery and invocation keep LLM context small:: import anyio from claude_agent_sdk import ClaudeSDKClient, ClaudeAgentOptions, AssistantMessage, TextBlock @@ -45,7 +45,10 @@ async def main(): discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) @@ -101,27 +104,103 @@ async def discover(args: dict[str, Any]) -> dict[str, Any]: ) -# Invocation tools +# Selector-driven invocation tools (recommended) @tool( - "invoke_device", - "Call a function on a Device Connect device. Use discover() with a " - "function-scoped selector first to learn available functions and " - "parameters.", - {"device_id": str, "function": str, "params": dict, "llm_reasoning": str}, + "invoke", + "Call exactly one function on one device. The selector must resolve " + "to a single (device, function) tuple -- use device().function() " + "or function() scope. Returns {success, device_id, function, " + "result|error}. Use invoke_many for fan-out across multiple targets.", + {"selector": str, "params": dict, "llm_reasoning": str}, ) -async def invoke_device(args: dict[str, Any]) -> dict[str, Any]: +async def invoke(args: dict[str, Any]) -> dict[str, Any]: return _text( - _invoke_device( - device_id=args["device_id"], - function=args["function"], + _invoke( + selector=args["selector"], + params=args.get("params"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "invoke_many", + "Fan out a function call over a selector-resolved set of (device, " + "function) tuples in parallel. Partial-failure semantics: per-target " + "results and errors are returned even if some targets fail. Returns " + "{candidates, matched, succeeded, failed, results, errors}. Each " + "target gets a per-call timeout (default 30s).", + { + "selector": str, "params": dict, "timeout": float, + "max_concurrency": int, "llm_reasoning": str, + }, +) +async def invoke_many(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _invoke_many( + selector=args["selector"], params=args.get("params"), + timeout=float(args.get("timeout", 30.0)), + max_concurrency=int(args.get("max_concurrency", 32)), llm_reasoning=args.get("llm_reasoning"), ) ) +@tool( + "broadcast", + "Async selector-driven fan-out. Returns immediately with a " + "correlation_id; replies stream on a per-device subject keyed by id. " + "Each candidate self-elects via the optional CEL `where` predicate " + "(evaluated at the edge against identity/labels/status/bindings) and " + "executes the function. Use fire_at (wall-clock epoch seconds) + " + "on_late (skip|fire) for synchronized fan-out. Pair with " + "await_replies(correlation_id) to collect outcomes.", + { + "selector": str, "params": dict, "where": str, "bindings": dict, + "fire_at": float, "on_late": str, "llm_reasoning": str, + }, +) +async def broadcast(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _broadcast( + selector=args["selector"], + params=args.get("params"), + where=args.get("where"), + bindings=args.get("bindings"), + fire_at=args.get("fire_at"), + on_late=args.get("on_late", "skip"), + llm_reasoning=args.get("llm_reasoning"), + ) + ) + + +@tool( + "await_replies", + "Collect replies for a broadcast() call. Subscribes to the " + "correlation reply subject, drains for up to `timeout` seconds (or " + "until `until` replies have arrived), then returns the list.", + { + "correlation_id": str, "timeout": float, "until": int, + "poll_interval": float, + }, +) +async def await_replies(args: dict[str, Any]) -> dict[str, Any]: + return _text( + _await_replies( + correlation_id=args["correlation_id"], + timeout=float(args.get("timeout", 10.0)), + until=int(args["until"]) if args.get("until") is not None else None, + poll_interval=float(args.get("poll_interval", 0.05)), + ) + ) + + +# Other invocation helpers + + @tool( "invoke_device_with_fallback", "Call a function with automatic fallback across a list of device IDs. " @@ -148,12 +227,12 @@ async def get_device_status(args: dict[str, Any]) -> dict[str, Any]: return _text(_get_device_status(device_id=args["device_id"])) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) @tool( "discover_devices", - "Deprecated — use discover() and discover_labels() instead. Discovers " + "Deprecated -- use discover() and discover_labels() instead. Discovers " "all devices with full function schemas.", {"device_type": str, "refresh": bool}, ) @@ -176,7 +255,10 @@ def create_device_connect_server(name: str = "device-connect"): tools=[ discover_labels, discover, - invoke_device, + invoke, + invoke_many, + broadcast, + await_replies, invoke_device_with_fallback, get_device_status, discover_devices, @@ -187,7 +269,10 @@ def create_device_connect_server(name: str = "device-connect"): __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py index f934024..c18ed7e 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/langchain.py @@ -4,16 +4,16 @@ """LangChain adapter — wraps Device Connect tools as LangChain StructuredTools. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.langchain import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from langgraph.prebuilt import create_react_agent connect() - agent = create_react_agent(model, [discover_labels, discover, invoke_device]) + agent = create_react_agent(model, [discover_labels, discover, invoke, invoke_many]) Requires: pip install device-connect-agent-tools[langchain] """ @@ -24,27 +24,38 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = StructuredTool.from_function(_discover_labels) discover = StructuredTool.from_function(_discover) -# Invocation tools -invoke_device = StructuredTool.from_function(_invoke_device) +# Selector-driven invocation (recommended) +invoke = StructuredTool.from_function(_invoke) +invoke_many = StructuredTool.from_function(_invoke_many) +broadcast = StructuredTool.from_function(_broadcast) +await_replies = StructuredTool.from_function(_await_replies) + +# Other invocation helpers invoke_device_with_fallback = StructuredTool.from_function(_invoke_device_with_fallback) get_device_status = StructuredTool.from_function(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = StructuredTool.from_function(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py index 848f362..b68c16b 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands.py @@ -4,16 +4,16 @@ """Strands adapter — wraps Device Connect tools with @strands.tool. -Selector-driven discovery keeps LLM context small: +Selector-driven discovery and invocation keep LLM context small: from device_connect_agent_tools import connect from device_connect_agent_tools.adapters.strands import ( - discover_labels, discover, invoke_device, + discover_labels, discover, invoke, invoke_many, ) from strands import Agent connect() - agent = Agent(tools=[discover_labels, discover, invoke_device]) + agent = Agent(tools=[discover_labels, discover, invoke, invoke_many]) agent("What devices are online?") Requires: pip install device-connect-agent-tools[strands] @@ -25,27 +25,38 @@ discover as _discover, discover_labels as _discover_labels, discover_devices as _discover_devices, - invoke_device as _invoke_device, + invoke as _invoke, + invoke_many as _invoke_many, + broadcast as _broadcast, + await_replies as _await_replies, invoke_device_with_fallback as _invoke_device_with_fallback, get_device_status as _get_device_status, ) -# Selector-driven discovery tools (recommended) +# Selector-driven discovery (recommended) discover_labels = strands_tool(_discover_labels) discover = strands_tool(_discover) -# Invocation tools -invoke_device = strands_tool(_invoke_device) +# Selector-driven invocation (recommended) +invoke = strands_tool(_invoke) +invoke_many = strands_tool(_invoke_many) +broadcast = strands_tool(_broadcast) +await_replies = strands_tool(_await_replies) + +# Other invocation helpers invoke_device_with_fallback = strands_tool(_invoke_device_with_fallback) get_device_status = strands_tool(_get_device_status) -# Backward-compatible (long-deprecated — prefer discover() / discover_labels()) +# Backward-compatible (long-deprecated -- prefer discover() / invoke()) discover_devices = strands_tool(_discover_devices) __all__ = [ "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py index a3f0cf5..c5f5e67 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/adapters/strands_agent.py @@ -62,7 +62,8 @@ async def prepare(self) -> Dict[str, Any]: from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, + invoke, + invoke_many, invoke_device_with_fallback, get_device_status, ) @@ -74,7 +75,7 @@ async def prepare(self) -> Dict[str, Any]: model=AnthropicModel(model_id=self._model_id, max_tokens=self._max_tokens), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=system_prompt, ) @@ -120,14 +121,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\n" f" device(category:camera, location:zone-A/*)\n" f" device(robot-001).function(direction:write)\n" - f" function(safety:critical)\n" - f" - invoke_device(device_id, function, params) -- call a device function\n\n" + f" function(safety:critical)\n\n" + f"INVOCATION TOOLS:\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\n\n" f"INSTRUCTIONS:\n" f"When you receive device events, you MUST:\n" f"1. Analyze the events\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\n" - f"3. Use invoke_device() to interact with devices\n" + f"3. Use invoke() or invoke_many() to interact with devices\n" f"4. Report what you found and what actions you took\n\n" f"Always provide llm_reasoning when invoking devices to explain your decision.\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py index dae997c..4dce5fb 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/connection.py @@ -409,13 +409,33 @@ async def _async_invoke( # ── Broadcast ──────────────────────────────────────────────────── + def publish_broadcast(self, envelope: Dict[str, Any]) -> None: + """Publish a selector-driven broadcast envelope to the fanout subject. + + The envelope shape is documented in + ``device_connect_edge.device.DeviceRuntime._broadcast_subscription``; + every device subscribed to ``device-connect..broadcast`` + receives the message and self-elects via ``targets`` and + the optional ``where`` predicate. + """ + return self._run(self._async_publish_broadcast(envelope)) + + async def _async_publish_broadcast(self, envelope: Dict[str, Any]) -> None: + subject = f"device-connect.{self.zone}.broadcast" + await self._client.publish(subject, json.dumps(envelope).encode()) + def broadcast( self, function: str, params: Optional[Dict[str, Any]] = None, timeout: float = 5.0, ) -> List[Dict[str, Any]]: - """Invoke a function on all discovered devices and collect results.""" + """Invoke a function on all discovered devices and collect results. + + Sequential sync fan-out (one invoke per device). Predates the + selector-driven broadcast tool; left in place for callers that want + a simple "call this on everyone" without setting up subscriptions. + """ devices = self.list_devices() results = [] for d in devices: diff --git a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py index db71bc2..c99c5bc 100644 --- a/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py +++ b/packages/device-connect-agent-tools/device_connect_agent_tools/tools.py @@ -20,6 +20,7 @@ import logging import os +import time import uuid import warnings from typing import Any @@ -463,6 +464,683 @@ def discover_labels( return out +# ── Selector-driven operations ─────────────────────────────────── + + +# Default per-target timeout for invoke_many fan-out. Configurable per call. +DEFAULT_INVOKE_TIMEOUT = 30.0 + +# Cap on parallel worker threads for invoke_many fan-out. Larger fleets can +# raise this via the ``max_concurrency`` argument; the default keeps thread +# overhead bounded while still parallelising typical 10-100 device fan-outs. +DEFAULT_INVOKE_CONCURRENCY = 32 + + +def _resolve_function_tuples( + selector: str, +) -> tuple[list[dict] | None, dict[str, Any] | None]: + """Resolve a selector to (device_id, function_name) tuples for invocation. + + Walks pagination so callers do not have to. Returns ``(rows, None)`` on + success or ``(None, error_envelope)`` if the selector failed to parse, + used a non-function scope, or the registry was unreachable. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in ( + Scope.DEVICE_FUNCTION.value, Scope.FUNCTION_ONLY.value, + ): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_invoke_scope", + "invoke/invoke_many require a function-scoped selector " + "(device(...).function(...) or function(...)); got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + return rows, None + + +def _shape_invoke_response( + response: dict[str, Any], + device_id: str, + function_name: str, +) -> dict[str, Any]: + """Normalize a JSON-RPC response into a {success, result|error} envelope. + + JSON-RPC error objects arrive as ``{"code": int, "message": str}`` from + the wire; this maps them to the structured ``{code: str, message: str}`` + error shape that the rest of the agent surface uses. + """ + if "error" in response: + err = response["error"] + if isinstance(err, dict): + code = str(err.get("code", "invoke_failed")) + message = str(err.get("message", err)) + else: + code, message = "invoke_failed", str(err) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": {"code": code, "message": message}, + } + return { + "success": True, + "device_id": device_id, + "function": function_name, + "result": response.get("result", {}), + } + + +def invoke( + selector: str, + params: dict[str, Any] | None = None, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to one (device, function) tuple and invoke it. + + Use this when the call is unambiguous -- one device, one function. + The selector must use ``device().function()`` or + ``function()`` scope. + + Args: + selector: Selector expression resolving to exactly one function. + params: Function parameters dict. Do NOT put ``llm_reasoning`` + inside ``params``. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"success": True, "device_id": ..., "function": ..., + "result": ...}``. + On failure: ``{"success": False, "error": {"code": ..., + "message": ...}}``. Codes include the discover() codes plus + ``no_match`` (zero matches), ``ambiguous_match`` (multiple + matches), ``invalid_invoke_scope`` (selector did not target + functions), and ``invoke_failed`` (the device returned an error). + """ + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"success": False, "error": error_envelope["error"]} + + if not rows: + return { + "success": False, + "error": _error( + "no_match", + f"selector matched 0 functions: {selector!r}", + ), + } + if len(rows) > 1: + return { + "success": False, + "error": _error( + "ambiguous_match", + f"selector matched {len(rows)} functions, expected exactly 1: " + f"{selector!r}", + ), + "candidates": [ + {"device_id": r.get("device_id"), "function": r.get("name")} + for r in rows[:10] + ], + } + + row = rows[0] + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + + trace_id = f"trace-{uuid.uuid4().hex[:12]}" + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[%s] [%s::%s] Reason: %s", + trace_id, device_id, function_name, truncated, + ) + + try: + conn = get_connection() + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + response = conn.invoke(device_id, function_name, params=clean) + except Exception as e: + logger.error( + "[%s] %s::%s -> ERROR: %s", + trace_id, device_id, function_name, e, + ) + return { + "success": False, + "device_id": device_id, + "function": function_name, + "error": _error("invoke_failed", str(e)), + } + return _shape_invoke_response(response, device_id, function_name) + + +def invoke_many( + selector: str, + params: dict[str, Any] | None = None, + timeout: float = DEFAULT_INVOKE_TIMEOUT, + max_concurrency: int = DEFAULT_INVOKE_CONCURRENCY, + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Resolve a selector to (device, function) tuples and invoke each in parallel. + + Returns aggregated results with partial-failure semantics: a single + target's failure does not abort the rest. Each target gets ``timeout`` + seconds; the overall call returns once every target has finished or + timed out. + + Args: + selector: Function-scoped selector + (``device(...).function(...)`` or ``function(...)``). + params: Function parameters dict applied to every target. + timeout: Per-target timeout in seconds. + max_concurrency: Cap on parallel worker threads. + llm_reasoning: Decision rationale for observability. + + Returns: + ``{"candidates": N, "matched": N, "succeeded": S, "failed": F, + "results": [{device_id, function, result}, ...], + "errors": [{device_id, function, error}, ...]}``. + + ``candidates`` is the count returned by the selector resolver. + ``matched`` is the same value in this release; once edge-side + ``where`` predicates land, ``matched`` will narrow below + ``candidates`` to reflect post-predicate self-election. + + On selector parse / connection failure the envelope is returned + with all counts at zero plus a top-level ``error`` field. + """ + import concurrent.futures + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return { + "candidates": 0, "matched": 0, "succeeded": 0, "failed": 0, + "results": [], "errors": [], "error": error_envelope["error"], + } + + out: dict[str, Any] = { + "candidates": len(rows), + "matched": len(rows), + "succeeded": 0, + "failed": 0, + "results": [], + "errors": [], + } + if not rows: + return out + + workers = max(1, min(max_concurrency, len(rows))) + clean = {k: v for k, v in (params or {}).items() if k != "llm_reasoning"} + + def call_one(row: dict) -> dict[str, Any]: + device_id = row.get("device_id") or "" + function_name = row.get("name") or "" + try: + conn = get_connection() + response = conn.invoke( + device_id, function_name, params=clean, timeout=timeout, + ) + except Exception as e: + response = {"error": {"code": "invoke_failed", "message": str(e)}} + return _shape_invoke_response(response, device_id, function_name) + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[invoke_many::%d targets] Reason: %s", len(rows), truncated, + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as exe: + futures = [exe.submit(call_one, row) for row in rows] + for future in concurrent.futures.as_completed(futures): + shaped = future.result() + if shaped["success"]: + out["results"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "result": shaped["result"], + }) + out["succeeded"] += 1 + else: + out["errors"].append({ + "device_id": shaped["device_id"], + "function": shaped["function"], + "error": shaped["error"], + }) + out["failed"] += 1 + return out + + +def broadcast( + selector: str, + params: dict[str, Any] | None = None, + where: str | None = None, + bindings: dict[str, Any] | None = None, + fire_at: float | None = None, + on_late: str = "skip", + llm_reasoning: str | None = None, +) -> dict[str, Any]: + """Async selector-driven fan-out. Returns immediately with a correlation id. + + Use ``broadcast`` when the caller does not want to block on the slowest + device. Each candidate self-elects via the optional ``where`` predicate + (CEL, evaluated at the edge against the device's identity, labels, live + status, and the shared ``bindings``) and emits its reply as an event on + a per-device subject keyed by ``correlation_id``:: + + device-connect...event.async_reply. + + Subscribe to those replies via ``subscribe('correlation:')`` or wait + for them with ``await_replies(correlation_id, timeout=...)``. + + Args: + selector: Function-scoped selector. The selector must resolve to a + single function name across the matched devices; if multiple + functions match, an ``ambiguous_function`` error is returned. + params: Function parameters dict applied to every target. + where: Optional CEL predicate evaluated at the edge per candidate + (e.g. ``"status.battery > 50"``, ``"mask[row][col] == 1"``). + Validated at the dispatcher before publication so syntax + errors return immediately rather than reaching the wire. + bindings: Shared payload merged into the predicate context as + ``bindings.``. Keep small (selection masks, thresholds, + top-K rankings); the same bytes ship to every device. + fire_at: Optional wall-clock epoch seconds. Each device holds the + message and fires its function from its own clock at + ``fire_at`` for synchronized fan-out. + on_late: Policy when a device receives a ``fire_at`` message after + the deadline. ``"skip"`` (default) drops the call; ``"fire"`` + executes immediately. + llm_reasoning: Decision rationale for observability. + + Returns: + On success: ``{"correlation_id": "br-...", "candidates": N, + "selector": ..., "function": ...}``. + On failure: ``{"candidates": 0, "error": {"code", "message"}}`` + with codes including the discover() codes, + ``invalid_invoke_scope``, ``ambiguous_function``, + ``invalid_predicate``, and ``invalid_on_late``. + """ + if on_late not in ("skip", "fire"): + return { + "candidates": 0, + "error": _error( + "invalid_on_late", + f"on_late must be 'skip' or 'fire', got {on_late!r}", + ), + } + + rows, error_envelope = _resolve_function_tuples(selector) + if error_envelope is not None: + return {"candidates": 0, "error": error_envelope["error"]} + + if not rows: + # Empty fan-out: still mint a correlation id so callers waiting on + # replies see a clean "no candidates" rather than a hang. + return { + "correlation_id": f"br-{uuid.uuid4().hex[:12]}", + "candidates": 0, + "selector": selector, + } + + # Broadcast assumes one function per call. If the selector resolves to + # multiple distinct functions, surface that as a structured error so + # the caller can either narrow the selector or split into multiple + # broadcasts. + function_names = {row.get("name") for row in rows if row.get("name")} + if len(function_names) != 1: + return { + "candidates": len(rows), + "error": _error( + "ambiguous_function", + f"selector resolved to {len(function_names)} distinct " + "functions; broadcast requires exactly one function per call: " + f"{sorted(function_names)!r}", + ), + } + function_name = next(iter(function_names)) + + # Compile-validate the where predicate before going to the wire so a + # syntax error short-circuits without bothering devices. + if where is not None: + try: + from device_connect_edge.predicate import compile_where + compile_where(where) + except Exception as e: + return { + "candidates": len(rows), + "error": _error("invalid_predicate", str(e)), + } + + correlation_id = f"br-{uuid.uuid4().hex[:12]}" + targets = sorted({ + row.get("device_id") for row in rows if row.get("device_id") + }) + clean_params = { + k: v for k, v in (params or {}).items() if k != "llm_reasoning" + } + + envelope: dict[str, Any] = { + "correlation_id": correlation_id, + "function": function_name, + "params": clean_params, + "targets": targets, + } + if where: + envelope["where"] = where + if bindings: + envelope["bindings"] = bindings + if fire_at is not None: + envelope["fire_at"] = float(fire_at) + envelope["on_late"] = on_late + + if llm_reasoning: + truncated = ( + llm_reasoning[:200] + "..." + if len(llm_reasoning) > 200 else llm_reasoning + ) + logger.info( + "[broadcast::%s::%d targets] Reason: %s", + correlation_id, len(targets), truncated, + ) + + try: + conn = get_connection() + conn.publish_broadcast(envelope) + except Exception as e: + logger.error("broadcast publish failed: %s", e) + return { + "candidates": len(targets), + "error": _error("connection_error", str(e)), + } + + return { + "correlation_id": correlation_id, + "candidates": len(targets), + "selector": selector, + "function": function_name, + } + + +# ── Selector-driven subscription ───────────────────────────────── + + +# Sentinel used to recognise the broadcast-reply form of a subscribe +# selector (``correlation:``). Kept short so the selector reads +# naturally; the parser matches an exact prefix. +_CORRELATION_PREFIX = "correlation:" + + +class Subscription: + """A live subscription handle returned by :func:`subscribe`. + + Two selector forms produce a subscription: + + * ``"correlation:"`` -- replies from a prior :func:`broadcast` call, + keyed by ``correlation_id`` and routed across all devices that fired. + * Event-scoped selectors (``event()`` or + ``device(...).event()``) -- a multiplex of matching events + across the resolved candidate set. + + The handle exposes a sync ``read`` API that drains buffered messages. + Use as a context manager (or call :meth:`close`) to tear the + underlying messaging subscription down deterministically:: + + with subscribe("correlation:" + cid) as sub: + for reply in sub.iter(timeout=5.0): + process(reply) + """ + + def __init__(self, conn: Any, inbox_names: list[str]): + self._conn = conn + self._inbox_names = list(inbox_names) + self._closed = False + self._cursor = 0 # index into the concatenated message stream + + def read(self, max_messages: int | None = None) -> list[dict[str, Any]]: + """Drain currently buffered messages without blocking. + + Returns parsed payload dicts (already JSON-decoded by the + connection's buffered subscription path). Subsequent calls return + only messages that arrived after the previous call. + + Race-safe against the messaging callback that appends to the same + inbox: each inbox is read by snapshotting its current length and + truncating only that prefix, so a message that arrives during + iteration stays buffered for the next ``read``. + """ + if self._closed: + return [] + out: list[dict[str, Any]] = [] + for name in self._inbox_names: + buf = self._conn._inbox.get(name) or [] + # Snapshot the consumed prefix length BEFORE iterating, then + # truncate by exactly that many items. Any message appended by + # the messaging callback between the snapshot and the truncation + # remains buffered for a subsequent ``read``. + n = len(buf) + for subject, payload in buf[:n]: + if not isinstance(payload, dict): + payload = {"raw": payload} + out.append({**payload, "_subject": subject}) + self._conn._inbox[name] = buf[n:] + if max_messages is not None: + out = out[:max_messages] + return out + + def iter(self, timeout: float = 5.0, poll_interval: float = 0.05): + """Yield messages until ``timeout`` elapses with no new arrivals. + + ``timeout`` resets each time at least one message is yielded, so + callers can drain a steady stream without re-parameterising the + wait. Use ``read`` instead for one-shot draining. + """ + deadline = time.monotonic() + timeout + while not self._closed: + new = self.read() + if new: + for msg in new: + yield msg + deadline = time.monotonic() + timeout + continue + if time.monotonic() >= deadline: + return + time.sleep(poll_interval) + + def __iter__(self): + """Allow ``for msg in sub:`` with a default 30-second idle timeout. + + Delegates to :meth:`iter` with sensible defaults so the idiomatic + Python iteration form works. Use ``sub.iter(timeout=...)`` directly + when the default does not fit. + """ + return self.iter(timeout=30.0, poll_interval=0.05) + + def close(self) -> None: + """Tear down the underlying messaging subscriptions.""" + if self._closed: + return + self._closed = True + for name in self._inbox_names: + try: + self._conn.unsubscribe_buffered(name) + except Exception: # pragma: no cover - cleanup best effort + logger.debug("close: unsubscribe %s failed", name, exc_info=True) + + def __enter__(self) -> "Subscription": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +def _correlation_subjects(conn: Any, correlation_id: str) -> list[str]: + """Build the per-device wildcard reply subjects for a correlation id. + + The reply template is ``device-connect...event + .async_reply.``; ```` is single-token wildcarded + so a subscription receives replies from any device that fires the + broadcast without having to enumerate them up-front. + """ + return [ + f"device-connect.{conn.zone}.*.event.async_reply.{correlation_id}", + ] + + +def _event_subjects_for_selector(selector: str) -> tuple[list[str] | None, dict[str, Any] | None]: + """Resolve an event-scoped selector to per-device subjects. + + Returns ``(subjects, None)`` on success or ``(None, error_envelope)`` + if the selector failed to parse or used a non-event scope. + """ + rows: list[dict] = [] + offset = 0 + while True: + page = discover(selector, offset=offset, limit=DISCOVER_HARD_LIMIT) + if "error" in page: + return None, page + if page["scope"] not in (Scope.DEVICE_EVENT.value, Scope.EVENT_ONLY.value): + return None, _empty_envelope( + scope=page["scope"], + error=_error( + "invalid_subscribe_scope", + "subscribe requires an event-scoped selector " + "(device(...).event(...) or event(...)) or " + "'correlation:'; got " + f"scope={page['scope']!r}", + ), + ) + rows.extend(page["results"]) + if page["next_offset"] is None: + break + offset = page["next_offset"] + + conn = get_connection() + subjects: list[str] = [] + seen: set[str] = set() + for row in rows: + device_id = row.get("device_id") or "" + event_name = row.get("name") or "" + if not device_id or not event_name: + continue + subj = f"device-connect.{conn.zone}.{device_id}.event.{event_name}" + if subj not in seen: + seen.add(subj) + subjects.append(subj) + return subjects, None + + +def subscribe(selector: str) -> Subscription: + """Subscribe to events or broadcast replies matching a selector. + + Args: + selector: One of: + - ``"correlation:"`` for broadcast replies of a prior call. + - An event-scoped selector (``event()`` or + ``device(...).event()``) for live event streams. + + Returns: + A :class:`Subscription` handle. Iterate with ``sub.iter(timeout)`` + or drain currently-buffered messages with ``sub.read()``. Always + close (or use ``with``) to tear the underlying subscription down. + + Raises: + ValueError on selector errors. The selector string is checked at + the boundary; downstream subscribe calls are not retried, so a + parse error fails fast. + """ + if not isinstance(selector, str) or not selector.strip(): + raise ValueError("subscribe selector must be a non-empty string") + + conn = get_connection() + if selector.startswith(_CORRELATION_PREFIX): + correlation_id = selector[len(_CORRELATION_PREFIX):].strip() + if not correlation_id: + raise ValueError( + "correlation form must be 'correlation:' with non-empty id" + ) + subjects = _correlation_subjects(conn, correlation_id) + inbox_prefix = f"sub-corr-{correlation_id}-{uuid.uuid4().hex[:8]}" + else: + subjects, error_envelope = _event_subjects_for_selector(selector) + if error_envelope is not None: + err = error_envelope.get("error") + msg = err.get("message", str(err)) if isinstance(err, dict) else str(err) + raise ValueError(msg) + if not subjects: + # Nothing to subscribe to. Return an idle Subscription so the + # caller's ``with subscribe(...) as sub: ...`` pattern still + # works without raising; ``read``/``iter`` will yield nothing. + return Subscription(conn, inbox_names=[]) + inbox_prefix = f"sub-evt-{uuid.uuid4().hex[:8]}" + + inbox_names: list[str] = [] + for i, subj in enumerate(subjects): + name = f"{inbox_prefix}-{i}" + conn.subscribe_buffered(subj, name=name) + inbox_names.append(name) + return Subscription(conn, inbox_names=inbox_names) + + +def await_replies( + correlation_id: str, + timeout: float = 10.0, + until: int | None = None, + poll_interval: float = 0.05, +) -> list[dict[str, Any]]: + """Block until ``timeout`` elapses or ``until`` replies have arrived. + + A sync helper for the common broadcast pattern: caller fires a + :func:`broadcast`, then waits for some replies. Builds a one-shot + subscription on the correlation reply subject, drains it, and tears + down before returning. + + Args: + correlation_id: The id returned by :func:`broadcast`. + timeout: Overall wall-clock limit in seconds. + until: Stop early once this many replies have been collected. + poll_interval: How often the helper polls the subscription buffer. + + Returns: + A list of reply payload dicts, each with at least + ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + """ + if not correlation_id: + return [] + sub = subscribe(f"{_CORRELATION_PREFIX}{correlation_id}") + try: + replies: list[dict[str, Any]] = [] + deadline = time.monotonic() + timeout + while True: + new = sub.read() + replies.extend(new) + if until is not None and len(replies) >= until: + break + if time.monotonic() >= deadline: + break + time.sleep(poll_interval) + return replies + finally: + sub.close() + + # ── Hierarchical discovery tools ───────────────────────────────── @@ -650,22 +1328,20 @@ def invoke_device( params: dict[str, Any] | None = None, llm_reasoning: str | None = None, ) -> dict[str, Any]: - """Call a function on a Device Connect device. + """Call a function on a Device Connect device (deprecated; use invoke()). Args: device_id: Target device ID (e.g., "robot-001", "camera-001"). - function: Function name to call (e.g., "start_cleaning", "capture_image"). - params: Function parameters as a dictionary. Check get_device_functions() for schemas. - Do NOT put llm_reasoning inside params. - llm_reasoning: Why you're calling this function — for observability. - - Example: - result = invoke_device( - device_id="robot-001", function="start_cleaning", - params={"zone": "zone-A"}, - llm_reasoning="Camera detected spill in zone-A" - ) + function: Function name to call. + params: Function parameters as a dictionary. + llm_reasoning: Why you are calling this function (for observability). """ + warnings.warn( + "invoke_device(device_id, function, ...) is deprecated; use " + "invoke('device().function()', params) instead.", + DeprecationWarning, + stacklevel=2, + ) trace_id = f"trace-{uuid.uuid4().hex[:12]}" if llm_reasoning: truncated = llm_reasoning[:200] + "..." if len(llm_reasoning) > 200 else llm_reasoning diff --git a/packages/device-connect-agent-tools/pyproject.toml b/packages/device-connect-agent-tools/pyproject.toml index ec0f198..606073c 100644 --- a/packages/device-connect-agent-tools/pyproject.toml +++ b/packages/device-connect-agent-tools/pyproject.toml @@ -37,6 +37,7 @@ strands = ["strands-agents>=1.0"] langchain = ["langchain-core>=0.2"] claude = ["claude-agent-sdk>=0.1"] mcp = ["fastmcp>=1.0"] +predicate = ["device-connect-edge[predicate]"] dev = [ "pytest>=8.0", "pytest-asyncio>=0.23", diff --git a/packages/device-connect-agent-tools/tests/test_broadcast.py b/packages/device-connect-agent-tools/tests/test_broadcast.py new file mode 100644 index 0000000..e8d8831 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_broadcast.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``broadcast`` tool. + +Uses the same labeled mock fleet (cam-001, cam-002, robot-001, sensor-001) +as the discover/invoke tests so selectors exercise real device, function, +and event names. +""" +import json +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +@pytest.fixture +def mock_conn(): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + # Capture the published envelope for assertions. + published: list[dict] = [] + conn.publish_broadcast.side_effect = lambda env: published.append(env) + conn._published = published + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- broadcast ------------------------------------------------------ + + +class TestBroadcast: + def test_returns_correlation_id_and_candidates(self, mock_conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["correlation_id"].startswith("br-") + assert r["candidates"] == 2 + assert r["function"] == "capture_image" + assert "error" not in r + + def test_envelope_carries_function_and_targets(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k"}, + ) + env = mock_conn._published[0] + assert env["function"] == "capture_image" + assert env["params"] == {"resolution": "4k"} + assert sorted(env["targets"]) == ["cam-001", "cam-002"] + # No optional fields when caller did not set them. + assert "where" not in env + assert "bindings" not in env + assert "fire_at" not in env + assert "on_late" not in env + + def test_where_and_bindings_propagate_to_envelope(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + where="status.battery > 50", + bindings={"threshold": 80}, + ) + env = mock_conn._published[0] + assert env["where"] == "status.battery > 50" + assert env["bindings"] == {"threshold": 80} + + def test_fire_at_propagates_with_default_on_late(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123456789.0, + ) + env = mock_conn._published[0] + assert env["fire_at"] == 123456789.0 + assert env["on_late"] == "skip" + + def test_fire_at_with_explicit_on_late_fire(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + fire_at=123.0, on_late="fire", + ) + env = mock_conn._published[0] + assert env["on_late"] == "fire" + + def test_invalid_on_late_rejected(self, mock_conn): + r = tools_mod.broadcast( + "device(*).function(capture_image)", on_late="bogus", + ) + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_on_late" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_ambiguous_function_rejected(self, mock_conn): + # function(direction:read) resolves to multiple distinct functions + # (get_reading + dispatch_robot's get_status if it had read; here + # it just hits sensor's get_reading and possibly more). With our + # SAMPLE_DEVICES this matches just get_reading, so artificially + # broaden by picking a selector that crosses functions: + r = tools_mod.broadcast("device(*).function(*)") + assert r["candidates"] == 3 + assert r["error"]["code"] == "ambiguous_function" + + def test_zero_matches_returns_correlation_with_zero(self, mock_conn): + r = tools_mod.broadcast("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["correlation_id"].startswith("br-") + # No envelope was published (no targets). + assert mock_conn.publish_broadcast.call_count == 0 + + def test_invalid_scope_rejected(self, mock_conn): + r = tools_mod.broadcast("device(cam-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, mock_conn): + r = tools_mod.broadcast("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_invalid_predicate_rejected_before_publish(self, mock_conn): + # The predicate is compile-validated at the dispatcher; a syntax + # error short-circuits without publishing. + try: + import celpy # noqa: F401 + except ImportError: + pytest.skip("cel-python not installed") + r = tools_mod.broadcast( + "device(*).function(capture_image)", where="a > > b", + ) + assert r["error"]["code"] == "invalid_predicate" + assert mock_conn.publish_broadcast.call_count == 0 + + def test_publish_failure_returns_connection_error(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.zone = "default" + conn.publish_broadcast.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.broadcast("device(*).function(capture_image)") + assert r["error"]["code"] == "connection_error" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, mock_conn): + tools_mod.broadcast( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + env = mock_conn._published[0] + assert "llm_reasoning" not in env["params"] diff --git a/packages/device-connect-agent-tools/tests/test_claude_adapter.py b/packages/device-connect-agent-tools/tests/test_claude_adapter.py index b0e2ac6..4960a49 100644 --- a/packages/device-connect-agent-tools/tests/test_claude_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_claude_adapter.py @@ -68,7 +68,10 @@ def _mock_sdk_and_connection(): "discover_labels", "discover", "discover_devices", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", ) diff --git a/packages/device-connect-agent-tools/tests/test_invoke.py b/packages/device-connect-agent-tools/tests/test_invoke.py new file mode 100644 index 0000000..aae1a83 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_invoke.py @@ -0,0 +1,336 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven ``invoke`` and ``invoke_many`` tools. + +Uses a small labeled fleet (cam-001, cam-002, robot-001, sensor-001) drawn +from the existing DC test driver vocabulary so every selector exercises +real device, function, and event names. +""" +import time +from unittest.mock import MagicMock, patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +# -- Fixtures ------------------------------------------------------- + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "location": "lab-A", + "status": {"state": "online"}, + "identity": {"device_type": "camera"}, + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [ + { + "name": "capture_image", + "parameters": {}, + "labels": {"direction": "write", "modality": "rgb"}, + }, + ], + "events": [], + }, + { + "device_id": "robot-001", + "device_type": "cleaner_robot", + "location": "lab-A", + "status": {"state": "idle"}, + "identity": {"device_type": "cleaner_robot"}, + "labels": {"category": "robot", "location": "lab-A"}, + "functions": [ + { + "name": "dispatch_robot", + "parameters": {}, + "labels": {"direction": "write", "safety": "critical"}, + }, + ], + "events": [], + }, + { + "device_id": "sensor-001", + "device_type": "temperature_sensor", + "location": "lab-B", + "status": {"state": "online"}, + "identity": {"device_type": "temperature_sensor"}, + "labels": {"category": "sensor", "location": "lab-B"}, + "functions": [ + { + "name": "get_reading", + "parameters": {}, + "labels": {"direction": "read"}, + }, + ], + "events": [], + }, +] + + +def _conn_with_invoke(invoke_side_effect): + """Return a mock Connection whose .invoke() applies ``invoke_side_effect``. + + ``invoke_side_effect`` is called with ``(device_id, function_name, + params, timeout)`` and must return a JSON-RPC response dict. + """ + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + + def _invoke(device_id, function_name, params=None, timeout=None): + return invoke_side_effect(device_id, function_name, params, timeout) + + conn.invoke.side_effect = _invoke + return conn + + +@pytest.fixture +def all_succeed_conn(): + def _ok(device_id, function_name, params, timeout): + return {"jsonrpc": "2.0", "id": "1", "result": { + "device_id": device_id, "function": function_name, "params": params, + }} + conn = _conn_with_invoke(_ok) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- invoke --------------------------------------------------------- + + +class TestInvoke: + def test_single_match_returns_success(self, all_succeed_conn): + r = tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p"}, + ) + assert r["success"] is True + assert r["device_id"] == "cam-001" + assert r["function"] == "capture_image" + assert r["result"]["params"] == {"resolution": "1080p"} + + def test_function_only_selector_with_unique_name(self, all_succeed_conn): + r = tools_mod.invoke("function(get_reading)") + assert r["success"] is True + assert r["device_id"] == "sensor-001" + assert r["function"] == "get_reading" + + def test_no_match_returns_no_match_error(self, all_succeed_conn): + r = tools_mod.invoke("device(*).function(does_not_exist)") + assert r["success"] is False + assert r["error"]["code"] == "no_match" + assert "does_not_exist" in r["error"]["message"] + + def test_ambiguous_match_returns_error_with_candidates(self, all_succeed_conn): + # capture_image exists on both cam-001 and cam-002. + r = tools_mod.invoke("function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "ambiguous_match" + assert "expected exactly 1" in r["error"]["message"] + ids = {c["device_id"] for c in r["candidates"]} + assert ids == {"cam-001", "cam-002"} + + def test_device_only_scope_rejected(self, all_succeed_conn): + # Device-only scope cannot resolve to a function. + r = tools_mod.invoke("device(robot-001)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_event_scope_rejected(self, all_succeed_conn): + r = tools_mod.invoke("event(reading)") + assert r["success"] is False + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke("not a selector") + assert r["success"] is False + assert r["error"]["code"] == "selector_parse_error" + + def test_non_string_selector_rejected(self, all_succeed_conn): + r = tools_mod.invoke(None) # type: ignore[arg-type] + assert r["success"] is False + assert r["error"]["code"] == "invalid_selector" + + def test_jsonrpc_error_maps_to_invoke_failed(self): + def _err(device_id, function_name, params, timeout): + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "device busy"}, + } + conn = _conn_with_invoke(_err) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(robot-001).function(dispatch_robot)") + assert r["success"] is False + assert r["error"]["code"] == "-32000" + assert r["error"]["message"] == "device busy" + assert r["device_id"] == "robot-001" + assert r["function"] == "dispatch_robot" + + def test_connection_exception_returns_invoke_failed(self): + conn = MagicMock() + conn.list_devices.return_value = SAMPLE_DEVICES + conn.invoke.side_effect = RuntimeError("messaging down") + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke("device(cam-001).function(capture_image)") + assert r["success"] is False + assert r["error"]["code"] == "invoke_failed" + assert "messaging down" in r["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke( + "device(cam-001).function(capture_image)", + params={"resolution": "1080p", "llm_reasoning": "should not appear"}, + llm_reasoning="caller reasoning", + ) + # Inspect the params actually delivered to the wire: + sent = all_succeed_conn.invoke.call_args.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "1080p" + + +# -- invoke_many ---------------------------------------------------- + + +class TestInvokeMany: + def test_zero_matches_returns_empty_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(does_not_exist)") + assert r["candidates"] == 0 + assert r["matched"] == 0 + assert r["succeeded"] == 0 + assert r["failed"] == 0 + assert r["results"] == [] + assert r["errors"] == [] + assert "error" not in r + + def test_all_succeed(self, all_succeed_conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 2 + assert r["failed"] == 0 + ids = {row["device_id"] for row in r["results"]} + assert ids == {"cam-001", "cam-002"} + # Each result row is shaped {device_id, function, result}. + for row in r["results"]: + assert row["function"] == "capture_image" + assert "result" in row + + def test_partial_failure_shape(self): + def _half_fail(device_id, function_name, params, timeout): + if device_id == "cam-001": + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + return { + "jsonrpc": "2.0", "id": "1", + "error": {"code": -32000, "message": "down"}, + } + conn = _conn_with_invoke(_half_fail) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["candidates"] == 2 + assert r["matched"] == 2 + assert r["succeeded"] == 1 + assert r["failed"] == 1 + assert {row["device_id"] for row in r["results"]} == {"cam-001"} + assert {row["device_id"] for row in r["errors"]} == {"cam-002"} + for row in r["errors"]: + assert row["error"]["code"] == "-32000" + assert row["error"]["message"] == "down" + + def test_invalid_scope_returns_error_envelope(self, all_succeed_conn): + r = tools_mod.invoke_many("device(robot-001)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "invalid_invoke_scope" + + def test_selector_parse_error_propagated(self, all_succeed_conn): + r = tools_mod.invoke_many("widgets(*)") + assert r["candidates"] == 0 + assert r["error"]["code"] == "selector_parse_error" + + def test_per_target_timeout_passed_to_connection(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", timeout=7.5, + ) + # Every conn.invoke call should carry the same timeout. + for call in all_succeed_conn.invoke.call_args_list: + assert call.kwargs["timeout"] == 7.5 + + def test_max_concurrency_caps_thread_pool(self, all_succeed_conn): + # The fan-out group has 3 targets (capture_image x2 + dispatch_robot + # don't share name; pick a selector that resolves to multiple). Use + # function(direction:write) which selects 4 distinct rows. + r = tools_mod.invoke_many( + "function(direction:write)", max_concurrency=1, + ) + assert r["candidates"] >= 2 + assert r["succeeded"] == r["candidates"] + + def test_connection_exception_recorded_per_target(self): + # Mix: cam-001 succeeds, cam-002's call raises locally. + def _mixed(device_id, function_name, params, timeout): + if device_id == "cam-002": + raise RuntimeError("messaging blip") + return {"jsonrpc": "2.0", "id": "1", "result": {"ok": True}} + conn = _conn_with_invoke(_mixed) + with patch.object(tools_mod, "get_connection", return_value=conn): + r = tools_mod.invoke_many("device(*).function(capture_image)") + assert r["succeeded"] == 1 + assert r["failed"] == 1 + cam002_err = next(e for e in r["errors"] if e["device_id"] == "cam-002") + assert cam002_err["error"]["code"] == "invoke_failed" + assert "messaging blip" in cam002_err["error"]["message"] + + def test_llm_reasoning_stripped_from_params(self, all_succeed_conn): + tools_mod.invoke_many( + "device(*).function(capture_image)", + params={"resolution": "4k", "llm_reasoning": "should not appear"}, + ) + for call in all_succeed_conn.invoke.call_args_list: + sent = call.kwargs["params"] + assert "llm_reasoning" not in sent + assert sent["resolution"] == "4k" + + +# -- _resolve_function_tuples --------------------------------------- + + +class TestResolveFunctionTuples: + def test_walks_all_pages(self, all_succeed_conn): + # Use a small DISCOVER_HARD_LIMIT temporarily. + with patch.object(tools_mod, "DISCOVER_HARD_LIMIT", 1): + rows, err = tools_mod._resolve_function_tuples( + "device(*).function(direction:write)" + ) + assert err is None + # 4 distinct (device, function) tuples for direction:write across the + # mock fleet (cam-001, cam-002, robot-001, sensor-001 set_threshold + # and set_location). With limit=1 per page, the resolver had to + # paginate through all of them. + assert len(rows) >= 2 + for row in rows: + assert "device_id" in row + assert "name" in row + + def test_propagates_discover_error(self, all_succeed_conn): + rows, err = tools_mod._resolve_function_tuples("not a selector") + assert rows is None + assert err is not None + assert err["error"]["code"] == "selector_parse_error" diff --git a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py index d647ee3..9aae070 100644 --- a/packages/device-connect-agent-tools/tests/test_langchain_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_langchain_adapter.py @@ -72,7 +72,10 @@ def _mock_langchain_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_strands_adapter.py b/packages/device-connect-agent-tools/tests/test_strands_adapter.py index a40b5ad..4e46ceb 100644 --- a/packages/device-connect-agent-tools/tests/test_strands_adapter.py +++ b/packages/device-connect-agent-tools/tests/test_strands_adapter.py @@ -55,7 +55,10 @@ def _mock_strands_and_connection(): EXPECTED_TOOLS = { "discover_labels", "discover", - "invoke_device", + "invoke", + "invoke_many", + "broadcast", + "await_replies", "invoke_device_with_fallback", "get_device_status", "discover_devices", diff --git a/packages/device-connect-agent-tools/tests/test_subscribe.py b/packages/device-connect-agent-tools/tests/test_subscribe.py new file mode 100644 index 0000000..a8b4be4 --- /dev/null +++ b/packages/device-connect-agent-tools/tests/test_subscribe.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the selector-driven subscribe + await_replies tools. + +The tests stand up a fake Connection that mirrors the buffered-inbox API +the production class exposes (``subscribe_buffered`` / +``unsubscribe_buffered`` / ``get_inbox`` / ``_inbox`` dict). Real +messaging is not exercised here; integration tests cover the wire. +""" +from unittest.mock import patch + +import pytest + +from device_connect_agent_tools import tools as tools_mod + + +SAMPLE_DEVICES = [ + { + "device_id": "cam-001", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, + { + "device_id": "cam-002", + "device_type": "camera", + "labels": {"category": "camera", "location": "lab-A"}, + "functions": [], + "events": [ + {"name": "object_detected", "labels": {"modality": "rgb"}}, + ], + }, +] + + +class FakeConnection: + """Minimal fake of the agent-tools Connection used by Subscription.""" + + def __init__(self, devices=None, zone="default"): + self.zone = zone + self.devices = devices or [] + self._inbox: dict[str, list[tuple]] = {} + self.subscribed_subjects: list[str] = [] + self.unsubscribed_names: list[str] = [] + + def list_devices(self): + return list(self.devices) + + def subscribe_buffered(self, subject: str, name: str | None = None) -> str: + name = name or subject + self._inbox[name] = [] + self.subscribed_subjects.append(subject) + return name + + def unsubscribe_buffered(self, name: str) -> None: + self.unsubscribed_names.append(name) + self._inbox.pop(name, None) + + def get_inbox(self, name: str | None = None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + # Test helper: simulate a message landing on a given subject. + def deliver(self, subject: str, payload: dict): + for name, _ in list(self._inbox.items()): + self._inbox[name].append((subject, payload)) + + +@pytest.fixture +def fake_conn(): + conn = FakeConnection(devices=SAMPLE_DEVICES) + with patch.object(tools_mod, "get_connection", return_value=conn): + yield conn + + +# -- subscribe ------------------------------------------------------ + + +class TestSubscribe: + def test_correlation_form_subscribes_to_reply_subject(self, fake_conn): + sub = tools_mod.subscribe("correlation:abc-123") + assert len(fake_conn.subscribed_subjects) == 1 + subj = fake_conn.subscribed_subjects[0] + assert subj == "device-connect.default.*.event.async_reply.abc-123" + sub.close() + assert fake_conn.unsubscribed_names + + def test_correlation_form_with_empty_id_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("correlation:") + + def test_event_selector_subscribes_per_device(self, fake_conn): + sub = tools_mod.subscribe("device(*).event(object_detected)") + # Two cameras emit object_detected -> two subjects subscribed. + assert len(fake_conn.subscribed_subjects) == 2 + for subj in fake_conn.subscribed_subjects: + assert subj.startswith("device-connect.default.") + assert subj.endswith(".event.object_detected") + sub.close() + + def test_event_selector_zero_matches_returns_idle(self, fake_conn): + sub = tools_mod.subscribe("event(no_such_event)") + assert fake_conn.subscribed_subjects == [] + # Idle subscription: read returns empty, close is a no-op. + assert sub.read() == [] + sub.close() + + def test_non_event_scope_rejected(self, fake_conn): + with pytest.raises(ValueError) as exc: + tools_mod.subscribe("device(cam-001)") + assert "subscribe requires" in str(exc.value) + + def test_empty_or_non_string_rejected(self, fake_conn): + with pytest.raises(ValueError): + tools_mod.subscribe("") + with pytest.raises(ValueError): + tools_mod.subscribe(None) # type: ignore[arg-type] + + +# -- Subscription --------------------------------------------------- + + +class TestSubscriptionHandle: + def test_read_drains_buffered_messages(self, fake_conn): + sub = tools_mod.subscribe("correlation:r1") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r1", + {"correlation_id": "r1", "device_id": "cam-001", "success": True}, + ) + msgs = sub.read() + assert len(msgs) == 1 + assert msgs[0]["device_id"] == "cam-001" + # Subject is stamped onto the payload for source attribution. + assert "_subject" in msgs[0] + # A second read returns nothing -- the buffer is drained. + assert sub.read() == [] + sub.close() + + def test_context_manager_closes(self, fake_conn): + with tools_mod.subscribe("correlation:r2") as sub: + assert sub.read() == [] + assert fake_conn.unsubscribed_names # close() ran + + def test_iter_yields_until_idle_timeout(self, fake_conn): + sub = tools_mod.subscribe("correlation:r3") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r3", + {"correlation_id": "r3", "device_id": "cam-001"}, + ) + # Short timeout; iter() should yield the buffered reply then exit + # once no new messages arrive within the idle window. + msgs = list(sub.iter(timeout=0.1, poll_interval=0.01)) + assert len(msgs) == 1 + sub.close() + + def test_for_loop_protocol_via_dunder_iter(self, fake_conn): + # ``for msg in sub:`` should drive __iter__ which delegates to iter() + # with a sensible default timeout. Break early so the test does not + # block on the 30s default. + sub = tools_mod.subscribe("correlation:r_iter") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_iter", + {"correlation_id": "r_iter", "device_id": "cam-001"}, + ) + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + break # one message is enough to confirm __iter__ wiring + sub.close() + assert len(gathered) == 1 + assert gathered[0]["device_id"] == "cam-001" + + def test_read_does_not_drop_messages_appended_during_iteration(self, fake_conn): + # Race-safety guard: simulate a callback that appends a fresh + # message between the read's snapshot and truncation. The message + # must still be visible on the next read(). + sub = tools_mod.subscribe("correlation:r_race") + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-001", "ordinal": 1}, + ) + first = sub.read() + assert len(first) == 1 + # Now simulate a late-arriving append into the same inbox AFTER + # the previous read drained the prefix. + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r_race", + {"correlation_id": "r_race", "device_id": "cam-002", "ordinal": 2}, + ) + second = sub.read() + assert len(second) == 1 + assert second[0]["device_id"] == "cam-002" + sub.close() + + +# -- await_replies -------------------------------------------------- + + +class TestAwaitReplies: + def test_empty_correlation_id_returns_empty_list(self, fake_conn): + assert tools_mod.await_replies("") == [] + + def test_collects_replies_until_count(self, fake_conn): + # Pre-stage two replies on the to-be-subscribed subject. await_replies + # subscribes (drains nothing yet), then deliver more during the loop. + # We deliver up-front via the fake's deliver hook so the first poll + # picks them up. + def deliver_when_subscribed(subject, name=None): + n = FakeConnection.subscribe_buffered(fake_conn, subject, name) + # Pre-load a couple of replies so the first poll returns them. + fake_conn.deliver( + "device-connect.default.cam-001.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-001"}, + ) + fake_conn.deliver( + "device-connect.default.cam-002.event.async_reply.r4", + {"correlation_id": "r4", "device_id": "cam-002"}, + ) + return n + + with patch.object( + fake_conn, "subscribe_buffered", side_effect=deliver_when_subscribed, + ): + replies = tools_mod.await_replies( + "r4", timeout=2.0, until=2, poll_interval=0.01, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"cam-001", "cam-002"} + + def test_returns_after_timeout_with_partial(self, fake_conn): + # No replies delivered -> after timeout, returns empty list. + replies = tools_mod.await_replies( + "r5", timeout=0.1, poll_interval=0.01, + ) + assert replies == [] diff --git a/packages/device-connect-edge/device_connect_edge/device.py b/packages/device-connect-edge/device_connect_edge/device.py index 40d5c63..c776d1a 100644 --- a/packages/device-connect-edge/device_connect_edge/device.py +++ b/packages/device-connect-edge/device_connect_edge/device.py @@ -1135,6 +1135,197 @@ async def on_msg(data: bytes, reply_subject: Optional[str]): self._logger.info("Subscribed to commands on %s", subj) + async def _broadcast_subscription(self) -> None: + """Subscribe to selector-driven broadcasts and self-elect to handle. + + Broadcast envelope shape (JSON over a fanout subject):: + + { + "correlation_id": "br-abc123", + "function": "capture_image", + "params": {"resolution": "4k"}, + "targets": ["cam-001", "cam-002"], // pre-resolved + "where": "status.battery > 50", // optional CEL + "bindings": {"mask": [[0,1],[1,0]]}, // optional + "fire_at": 1234567890.5, // optional, epoch s + "on_late": "skip" // skip|fire + } + + On match, the device executes the function and emits a reply on + ``device-connect...event.async_reply.`` + with ``{correlation_id, device_id, success, result|error, + actually_fired_at}``. + + The envelope is processed in a tracked task so the subscription + loop does not block on ``fire_at`` sleeps or long-running driver + functions; subsequent broadcasts can continue to land while an + earlier one is in flight. + """ + subj = f"device-connect.{self.tenant}.broadcast" + + async def on_msg(data: bytes, reply_subject: Optional[str]): + try: + envelope = json.loads(data) + except Exception as e: + self._logger.debug("Broadcast: malformed envelope: %s", e) + return + + correlation_id = envelope.get("correlation_id") + if not correlation_id: + return + + # Cheap self-election: target gate (pre-resolved by the dispatcher + # from the selector). When absent or empty, treat as fleet-wide. + targets = envelope.get("targets") or [] + if targets and self.device_id not in targets: + return + + if not envelope.get("function"): + return + + # Hand off to a tracked task. The task owns the where evaluation, + # the fire_at sleep, and the driver call, so this callback returns + # immediately and the messaging subscription stays drained. + self._track_task(asyncio.create_task( + self._handle_broadcast_envelope(envelope, correlation_id) + )) + + await self.messaging.subscribe(subj, callback=on_msg) + self._logger.info("Subscribed to broadcasts on %s", subj) + + + async def _handle_broadcast_envelope( + self, envelope: Dict[str, Any], correlation_id: str, + ) -> None: + """Process one broadcast envelope: evaluate where, honour fire_at, invoke, reply. + + Runs in its own task so a long-held ``fire_at`` or slow driver + function does not block the subscription callback from accepting + subsequent broadcasts. + """ + function_name = envelope.get("function") + params_dict = envelope.get("params", {}) or {} + + # Step 1: where predicate against {identity, labels, status, bindings}. + # A failed compile or eval is treated as fail-closed (do not execute); + # the message is logged at WARNING with the correlation_id so an + # operator can correlate a silent skip with a misspelled label key. + where_expr = envelope.get("where") + if where_expr and not self._evaluate_where( + where_expr, envelope.get("bindings"), correlation_id, + ): + return + + # Step 2: fire_at hold. The on_late policy decides what to do when + # the message arrives past the deadline (skip preserves coherence; + # fire runs anyway). + fire_at = envelope.get("fire_at") + on_late = envelope.get("on_late", "skip") + if fire_at is not None: + delay = float(fire_at) - time.time() + if delay < 0 and on_late == "skip": + self._logger.info( + "Broadcast %s arrived %.3fs late, on_late=skip", + correlation_id, -delay, + ) + return + if delay > 0: + await asyncio.sleep(delay) + + # Step 3: execute and reply. + actually_fired_at = time.time() + reply_subj = ( + f"device-connect.{self.tenant}.{self.device_id}" + f".event.async_reply.{correlation_id}" + ) + try: + if self._driver is None: + raise RuntimeError("no driver configured") + driver_functions = self._driver._get_functions() + if function_name not in driver_functions: + raise RuntimeError(f"unknown function: {function_name}") + result = await self._driver.invoke(function_name, **params_dict) + reply_payload: Dict[str, Any] = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": True, + "result": result, + "actually_fired_at": actually_fired_at, + } + except Exception as e: + self._logger.warning( + "Broadcast %s: function %s failed: %s", + correlation_id, function_name, e, + ) + reply_payload = { + "correlation_id": correlation_id, + "device_id": self.device_id, + "success": False, + "error": {"code": "invoke_failed", "message": str(e)}, + "actually_fired_at": actually_fired_at, + } + try: + await self.messaging.publish( + reply_subj, json.dumps(reply_payload).encode(), + ) + except Exception as e: # pragma: no cover + self._logger.warning( + "Broadcast %s: reply publish failed: %s", correlation_id, e, + ) + + + def _evaluate_where( + self, + where_expr: str, + bindings: Optional[Dict[str, Any]], + correlation_id: str, + ) -> bool: + """Compile and evaluate a where predicate; return True iff it passes. + + Returns False (do not execute) on compile or eval errors, logging + a warning so silent self-deselection is operator-visible. + """ + try: + from device_connect_edge.predicate import compile_where + predicate = compile_where(where_expr) + caps = self._driver.capabilities if self._driver else self.capabilities + status = self._driver.status if self._driver else None + labels = (caps.labels if caps and caps.labels else {}) or {} + status_dict = ( + status.model_dump() if status and hasattr(status, "model_dump") else {} + ) + # Mirror DeviceStatus.location into labels so ``labels.location`` + # works in predicates without the driver having to declare it + # explicitly. Matches the dispatcher-side flatten_device contract. + if "location" not in labels and status_dict.get("location"): + labels = {**labels, "location": status_dict["location"]} + # DeviceIdentity is exposed by the driver, not by DeviceCapabilities; + # they are independent pydantic models. Read identity from the + # driver so extra fields (seat_row, seat_col, x-mhp metadata, ...) + # reach the predicate context. Splice in device_id which lives on + # the runtime so predicates can write + # ``identity.device_id == "..."`` naturally. + identity_dict: Dict[str, Any] = {"device_id": self.device_id} + driver_identity = ( + getattr(self._driver, "identity", None) if self._driver else None + ) + if driver_identity is not None and hasattr(driver_identity, "model_dump"): + identity_dict.update(driver_identity.model_dump()) + context = { + "identity": identity_dict, + "labels": labels, + "status": status_dict, + "bindings": bindings or {}, + } + return bool(predicate.evaluate(context)) + except Exception as e: + self._logger.warning( + "Broadcast %s: where predicate failed (skipping): %s", + correlation_id, e, + ) + return False + + async def _event_dispatch_loop(self) -> None: """Send queued events, retrying on failure.""" @@ -1372,6 +1563,13 @@ async def run(self) -> None: # Subscribe to commands BEFORE capability routines so log order makes sense await self._cmd_subscription() + # Subscribe to fleet broadcasts (best-effort; broadcast is opt-in for + # callers, so failure here should not block command handling). + try: + await self._broadcast_subscription() + except Exception as e: # pragma: no cover - best effort logging + self._logger.warning("Broadcast subscription failed: %s", e) + # Start capability routines if driver supports them (CapabilityDriverMixin) # This must happen after registration so events don't fire before device is registered if hasattr(self._driver, 'start_capability_routines'): diff --git a/packages/device-connect-edge/device_connect_edge/predicate.py b/packages/device-connect-edge/device_connect_edge/predicate.py new file mode 100644 index 0000000..5bf5ff6 --- /dev/null +++ b/packages/device-connect-edge/device_connect_edge/predicate.py @@ -0,0 +1,163 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""CEL ``where`` predicate evaluator for self-election at the edge. + +A ``where`` predicate is a CEL (Common Expression Language) expression that +each candidate device evaluates against its own context to decide whether +to execute a fan-out call. The predicate sees four top-level variables: + + identity device-local identity dict (device_id, device_type, ...) + labels device labels (the same labels selectors filter on) + status device status (heartbeat-updated: location, availability, + battery, online, ...) + bindings shared payload supplied by the caller (selection masks, + thresholds, lookup tables) + +Examples:: + + battery > 50 + labels.category == "camera" && status.battery > 50 + mask[seat_row][seat_col] == 1 + bindings.threshold < status.temperature + +CEL is sandboxed by construction: no I/O, no filesystem, no exec. This +module wraps `cel-python` with lazy import so device-connect-edge does +not require it as a hard dependency. Install with the optional +``[predicate]`` extra:: + + pip install device-connect-edge[predicate] + +The evaluator is shared by the dispatcher (validates the expression +before broadcast) and the device runtime (evaluates per-call to decide +whether to execute the fan-out). +""" + +from __future__ import annotations + +from typing import Any, Mapping + + +class PredicateCompileError(ValueError): + """Raised when a ``where`` expression fails to compile. + + Carries the original cel-python error chained so callers can drill in + if they need the exact parse position. + """ + + +class PredicateEvalError(RuntimeError): + """Raised when an otherwise-valid predicate fails at evaluation time. + + Typical causes: missing context key, type mismatch (e.g. comparing a + string to an int), or arithmetic overflow. + """ + + +# Lazy import: ``cel-python`` is an optional extra. Importers of this module +# pay no cost unless they actually compile a predicate. +def _require_celpy(): + try: + import celpy # type: ignore[import-not-found] + return celpy + except ImportError as e: + raise PredicateCompileError( + "where predicates require the 'cel-python' package; " + "install with the [predicate] extra: " + "pip install 'device-connect-edge[predicate]'" + ) from e + + +def _to_cel(value: Any) -> Any: + """Recursively wrap a Python value as the matching CEL type. + + Native Python ints, strings, dicts, and lists arrive at the boundary + untyped; cel-python's evaluator expects its own typed wrappers + (``IntType``, ``MapType``, ``ListType``, ...). We wrap once at the + top of evaluation rather than asking callers to import celtypes. + """ + celpy = _require_celpy() + ct = celpy.celtypes + if value is None: + return None + if isinstance(value, bool): + return ct.BoolType(value) + if isinstance(value, int): + return ct.IntType(value) + if isinstance(value, float): + return ct.DoubleType(value) + if isinstance(value, str): + return ct.StringType(value) + if isinstance(value, (bytes, bytearray)): + return ct.BytesType(bytes(value)) + if isinstance(value, Mapping): + return ct.MapType({ + ct.StringType(str(k)): _to_cel(v) for k, v in value.items() + }) + if isinstance(value, (list, tuple)): + return ct.ListType([_to_cel(v) for v in value]) + # Fallback: stringify. Rare; happens for custom objects in the context. + return ct.StringType(str(value)) + + +class WherePredicate: + """A compiled ``where`` predicate, ready to evaluate against device context. + + Compile once (typically at the dispatcher when the call comes in or at + the edge when the broadcast envelope is received), then evaluate once + per candidate. Predicates are stateless and safe to reuse across calls. + """ + + __slots__ = ("expression", "_program") + + def __init__(self, expression: str, _program: Any): + self.expression = expression + self._program = _program + + def evaluate(self, context: Mapping[str, Any]) -> bool: + """Return ``True`` if the predicate holds for ``context``. + + ``context`` should be a flat mapping of variable name to Python + value. Common keys: ``identity``, ``labels``, ``status``, + ``bindings``. Missing keys are not auto-defaulted; if the + predicate references one, the call raises PredicateEvalError so + the caller can decide between fail-open and fail-closed. + """ + celpy = _require_celpy() + cel_context = {k: _to_cel(v) for k, v in context.items()} + try: + result = self._program.evaluate(cel_context) + except celpy.CELEvalError as e: + raise PredicateEvalError( + f"failed to evaluate where {self.expression!r}: {e}" + ) from e + return bool(result) + + +def compile_where(expression: str) -> WherePredicate: + """Compile a ``where`` expression into a reusable :class:`WherePredicate`. + + Raises :class:`PredicateCompileError` if cel-python is not installed + or the expression is malformed. + """ + celpy = _require_celpy() + if not isinstance(expression, str): + raise PredicateCompileError( + f"where expression must be a string, got {type(expression).__name__}" + ) + if not expression.strip(): + raise PredicateCompileError("where expression must be non-empty") + env = celpy.Environment() + try: + ast = env.compile(expression) + except Exception as e: + # cel-python surfaces parse errors via several exception classes + # depending on the failure mode (lark.UnexpectedToken, ValueError, + # CELParseError). Catch broadly and rewrap so callers only see + # PredicateCompileError. + raise PredicateCompileError( + f"failed to compile where {expression!r}: {e}" + ) from e + program = env.program(ast) + return WherePredicate(expression=expression, _program=program) diff --git a/packages/device-connect-edge/pyproject.toml b/packages/device-connect-edge/pyproject.toml index 27b5e88..58de4d1 100644 --- a/packages/device-connect-edge/pyproject.toml +++ b/packages/device-connect-edge/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ [project.optional-dependencies] zenoh = [] # Zenoh is now a core dependency; kept for backward compat +predicate = [ + "cel-python>=0.5.0", +] telemetry = [ "opentelemetry-api>=1.30.0", "opentelemetry-sdk>=1.30.0", diff --git a/packages/device-connect-edge/tests/test_predicate.py b/packages/device-connect-edge/tests/test_predicate.py new file mode 100644 index 0000000..dfaff81 --- /dev/null +++ b/packages/device-connect-edge/tests/test_predicate.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the CEL ``where`` predicate evaluator. + +These tests require the ``[predicate]`` extra (cel-python). They are +skipped automatically when cel-python is not installed so the rest of +the edge test suite stays runnable on minimal installs. +""" +from __future__ import annotations + +import pytest + +celpy = pytest.importorskip("celpy") + +from device_connect_edge.predicate import ( + PredicateCompileError, + PredicateEvalError, + WherePredicate, + compile_where, +) + + +# -- compile_where -------------------------------------------------- + + +class TestCompile: + def test_simple_comparison_compiles(self): + p = compile_where("battery > 50") + assert isinstance(p, WherePredicate) + assert p.expression == "battery > 50" + + def test_boolean_combination_compiles(self): + p = compile_where("a > 1 && b < 10 || c == 'x'") + assert isinstance(p, WherePredicate) + + def test_array_indexing_compiles(self): + p = compile_where("mask[row][col] == 1") + assert isinstance(p, WherePredicate) + + def test_label_dot_access_compiles(self): + p = compile_where("labels.category == 'camera'") + assert isinstance(p, WherePredicate) + + def test_empty_expression_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where("") + with pytest.raises(PredicateCompileError): + compile_where(" ") + + def test_non_string_rejected(self): + with pytest.raises(PredicateCompileError): + compile_where(123) # type: ignore[arg-type] + + def test_malformed_expression_rejected(self): + with pytest.raises(PredicateCompileError) as exc: + compile_where("a > > b") + assert "failed to compile" in str(exc.value) + + +# -- evaluate ------------------------------------------------------- + + +class TestEvaluate: + def test_truthy_comparison(self): + p = compile_where("battery > 50") + assert p.evaluate({"battery": 80}) is True + assert p.evaluate({"battery": 30}) is False + + def test_label_match(self): + p = compile_where("labels.category == 'camera'") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_2d_mask_indexing(self): + # The mask-indexing case is the deciding example for picking CEL + # over JSONLogic; keep it as a regression guard. + p = compile_where("mask[row][col] == 1") + ctx = { + "mask": [[0, 1, 0], [1, 0, 0]], + "row": 0, + "col": 1, + } + assert p.evaluate(ctx) is True + ctx["col"] = 0 + assert p.evaluate(ctx) is False + + def test_combined_label_and_status(self): + p = compile_where("labels.category == 'camera' && status.battery > 50") + ctx = { + "labels": {"category": "camera"}, + "status": {"battery": 80}, + } + assert p.evaluate(ctx) is True + ctx["status"]["battery"] = 30 + assert p.evaluate(ctx) is False + ctx["labels"]["category"] = "robot" + ctx["status"]["battery"] = 80 + assert p.evaluate(ctx) is False + + def test_bindings_and_status_compose(self): + p = compile_where("status.temperature > bindings.threshold") + ctx = { + "status": {"temperature": 75.5}, + "bindings": {"threshold": 70.0}, + } + assert p.evaluate(ctx) is True + + def test_string_in_list(self): + p = compile_where("labels.category in ['camera', 'inference']") + assert p.evaluate({"labels": {"category": "camera"}}) is True + assert p.evaluate({"labels": {"category": "robot"}}) is False + + def test_missing_variable_raises_eval_error(self): + p = compile_where("status.battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({}) + + def test_type_mismatch_raises_eval_error(self): + p = compile_where("battery > 50") + with pytest.raises(PredicateEvalError): + p.evaluate({"battery": "not a number"}) + + def test_evaluator_is_reusable(self): + # Compile once, evaluate against many contexts. Reusability is the + # property that lets callers compile broadcast envelopes once at + # the dispatcher and ship them to N targets. + p = compile_where("battery > 50") + results = [p.evaluate({"battery": v}) for v in (10, 50, 51, 100)] + assert results == [False, False, True, True] diff --git a/packages/device-connect-server/device_connect_server/devctl/cli.py b/packages/device-connect-server/device_connect_server/devctl/cli.py index 071b423..f73ec6a 100644 --- a/packages/device-connect-server/device_connect_server/devctl/cli.py +++ b/packages/device-connect-server/device_connect_server/devctl/cli.py @@ -574,9 +574,20 @@ def create_parser() -> argparse.ArgumentParser: p_reg.add_argument("--broker", default=None, help="Broker URL") p_reg.add_argument("--keepalive", action="store_true", help="Start heartbeat loop") - # discover command - p_discover = sub.add_parser("discover", help="Discover uncommissioned devices") - p_discover.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + # mdns-scan: discover uncommissioned devices on the local network. + # Renamed from the historical ``discover`` verb so the selector-driven + # ``discover`` below (which queries the fleet, not the local network) + # can take the natural name. + p_scan = sub.add_parser( + "mdns-scan", help="Discover uncommissioned devices via mDNS", + aliases=["scan"], + ) + p_scan.add_argument("--timeout", type=int, default=5, help="Timeout in seconds") + + # Selector-driven fleet discovery (new). Registers ``discover`` and + # ``discover-labels`` as parser entries. + from device_connect_server.devctl import selector_cli + selector_cli.register_subparsers(sub) # commission command p_commission = sub.add_parser("commission", help="Commission a device with PIN") @@ -617,9 +628,17 @@ def main(argv: Optional[List[str]] = None) -> None: loop.stop() print("\nbye!") - elif args.cmd == "discover": + elif args.cmd in ("mdns-scan", "scan"): asyncio.run(discover_devices(timeout=args.timeout)) + elif args.cmd == "discover": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover(args)) + + elif args.cmd == "discover-labels": + from device_connect_server.devctl import selector_cli + sys.exit(selector_cli.run_discover_labels(args)) + elif args.cmd == "commission": asyncio.run( commission_device( diff --git a/packages/device-connect-server/device_connect_server/devctl/selector_cli.py b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py new file mode 100644 index 0000000..68a6637 --- /dev/null +++ b/packages/device-connect-server/device_connect_server/devctl/selector_cli.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``devctl`` selector-driven discovery verbs. + +Thin wrappers around ``device_connect_agent_tools.discover`` and +``discover_labels`` so operators can drive the same selector grammar +from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Best-effort connect to the messaging backend. + + Reuses ``DEVICE_CONNECT_*`` and ``NATS_URL`` env vars when ``broker`` is + not given. Kept as a thin wrapper so all CLI verbs share the same + connect-or-fail semantics. + """ + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _pretty(data: Any) -> str: + """Render a JSON payload for terminal output.""" + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +def run_discover(args: Any) -> int: + """Execute ``devctl discover ""``.""" + from device_connect_agent_tools import disconnect, discover + + _connect(getattr(args, "broker", None)) + try: + result = discover( + args.selector, + offset=int(args.offset or 0), + limit=int(args.limit or 200), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_discover_labels(args: Any) -> int: + """Execute ``devctl discover-labels [--key K]``.""" + from device_connect_agent_tools import disconnect, discover_labels + + _connect(getattr(args, "broker", None)) + try: + result = discover_labels( + key=args.key, + offset=int(args.offset or 0), + limit=int(args.limit or 50), + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def register_subparsers(sub: Any) -> None: + """Attach the discover / discover-labels subparsers to a devctl parser.""" + p = sub.add_parser( + "discover", + help="Resolve a selector to devices, functions, or events", + ) + p.add_argument("selector", help="Selector expression (e.g. 'device(category:camera)')") + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=200, help="Page size") + + p = sub.add_parser( + "discover-labels", + help="Browse fleet label vocabulary", + ) + p.add_argument( + "--key", default=None, + help="Axis-qualified label key (e.g. 'device.location') for per-key pagination", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + p.add_argument("--offset", type=int, default=0, help="Pagination offset") + p.add_argument("--limit", type=int, default=50, help="Page size") diff --git a/packages/device-connect-server/device_connect_server/portal/views/devices.py b/packages/device-connect-server/device_connect_server/portal/views/devices.py index 3f82309..7f5bf1e 100644 --- a/packages/device-connect-server/device_connect_server/portal/views/devices.py +++ b/packages/device-connect-server/device_connect_server/portal/views/devices.py @@ -320,7 +320,7 @@ async def download_starter_script(request: web.Request): """Device Connect — starter AI agent (Strands + OpenAI). Connects to Device Connect, discovers your fleet, and reacts to device -events by calling tools (discover_labels, discover, invoke_device). +events by calling tools (discover_labels, discover, invoke, invoke_many). LLM inference runs through the Arm internal OpenAI proxy. Usage: @@ -404,7 +404,7 @@ async def prepare(self) -> Dict[str, Any]: from strands.models.openai import OpenAIModel from device_connect_agent_tools.adapters.strands import ( discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ) result = await super().prepare() @@ -417,7 +417,7 @@ async def prepare(self) -> Dict[str, Any]: ), tools=[ discover_labels, discover, - invoke_device, invoke_device_with_fallback, get_device_status, + invoke, invoke_many, invoke_device_with_fallback, get_device_status, ], system_prompt=self._build_system_prompt(), ) @@ -454,14 +454,18 @@ def _build_system_prompt(self) -> str: f"functions, or events. Examples:\\n" f" device(category:camera, location:zone-A/*)\\n" f" device(robot-001).function(direction:write)\\n" - f" function(safety:critical)\\n" - f" - invoke_device(device_id, function, params) -- call a device function\\n\\n" + f" function(safety:critical)\\n\\n" + f"INVOCATION TOOLS:\\n" + f" - invoke(selector, params) -- call exactly one function. " + f"Selector must resolve to one (device, function) tuple.\\n" + f" - invoke_many(selector, params) -- fan out a function call " + f"over a selector-resolved set in parallel.\\n\\n" f"INSTRUCTIONS:\\n" f"When you receive device events, you MUST:\\n" f"1. Analyze the events\\n" f"2. Use discover() with a function-scoped selector to check " f"available functions if needed\\n" - f"3. Use invoke_device() to interact with devices\\n" + f"3. Use invoke() or invoke_many() to interact with devices\\n" f"4. Report what you found and what actions you took\\n\\n" f"Always provide llm_reasoning when invoking devices.\\n" f"Always call at least one tool per batch of events." diff --git a/packages/device-connect-server/device_connect_server/statectl/cli.py b/packages/device-connect-server/device_connect_server/statectl/cli.py index e1a03ef..161afdd 100644 --- a/packages/device-connect-server/device_connect_server/statectl/cli.py +++ b/packages/device-connect-server/device_connect_server/statectl/cli.py @@ -408,6 +408,13 @@ def create_parser() -> argparse.ArgumentParser: # stats sub.add_parser("stats", help="Key counts by namespace") + # Selector-driven operations (invoke / invoke-many / broadcast / + # subscribe / await). These verbs do not touch etcd; they run over + # the messaging fabric. They live under statectl because they all + # change the live state of devices. + from device_connect_server.statectl import operations_cli + operations_cli.register_subparsers(sub) + return parser @@ -430,9 +437,25 @@ async def _run(args) -> None: await handler(client, args) +_OPERATIONS_DISPATCH = { + "invoke": "run_invoke", + "invoke-many": "run_invoke_many", + "broadcast": "run_broadcast", + "subscribe": "run_subscribe", + "await": "run_await", +} + + def main(): parser = create_parser() args = parser.parse_args() + if args.cmd in _OPERATIONS_DISPATCH: + # Operations verbs run over messaging, not etcd. Bypass the etcd + # client setup that the COMMANDS dispatch table assumes. + from device_connect_server.statectl import operations_cli + handler = getattr(operations_cli, _OPERATIONS_DISPATCH[args.cmd]) + sys.exit(handler(args)) + try: asyncio.run(_run(args)) except KeyboardInterrupt: diff --git a/packages/device-connect-server/device_connect_server/statectl/operations_cli.py b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py new file mode 100644 index 0000000..7ddc9ef --- /dev/null +++ b/packages/device-connect-server/device_connect_server/statectl/operations_cli.py @@ -0,0 +1,297 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""``statectl`` selector-driven operations verbs. + +Thin wrappers around the agent-tools ``invoke`` / ``invoke_many`` / +``broadcast`` / ``subscribe`` / ``await_replies`` functions so operators +can fire selector-driven calls from a shell. +""" +from __future__ import annotations + +import json +import os +from typing import Any + + +def _connect(broker: str | None) -> None: + """Connect to the messaging backend using the same env-or-broker rules + as devctl's selector verbs.""" + from device_connect_agent_tools import connect + + if broker: + connect(nats_url=broker) + else: + nats_url = os.getenv("NATS_URL") or os.getenv("DEVICE_CONNECT_NATS_URL") + if nats_url: + connect(nats_url=nats_url) + else: + connect() + + +def _parse_param_kv(values: list[str] | None) -> dict[str, Any]: + """Parse ``--param k=v`` repeated args into a function-params dict. + + Values that look like JSON (``[...]``, ``{...}``, numbers, ``true`` / + ``false`` / ``null``) are decoded; everything else stays a string. This + matches what an operator would expect when typing + ``--param resolution=1080p --param tags='["a","b"]'``. + """ + out: dict[str, Any] = {} + for entry in values or []: + if "=" not in entry: + raise ValueError(f"--param must be 'k=v', got {entry!r}") + k, _, v = entry.partition("=") + k = k.strip() + if not k: + raise ValueError(f"--param has empty key in {entry!r}") + v_stripped = v.strip() + # JSON-decode obvious JSON-shaped values; fall back to raw string. + if ( + v_stripped.startswith(("[", "{", '"')) + or v_stripped in ("true", "false", "null") + or _looks_numeric(v_stripped) + ): + try: + out[k] = json.loads(v_stripped) + continue + except json.JSONDecodeError: + pass + out[k] = v + return out + + +def _looks_numeric(s: str) -> bool: + try: + float(s) + return True + except ValueError: + return False + + +def _pretty(data: Any) -> str: + return json.dumps(data, indent=2, sort_keys=True, default=str) + + +# -- verbs ---------------------------------------------------------- + + +def run_invoke(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke + + _connect(getattr(args, "broker", None)) + try: + result = invoke( + args.selector, + params=_parse_param_kv(args.param), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if result.get("success") else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_invoke_many(args: Any) -> int: + from device_connect_agent_tools import disconnect, invoke_many + + _connect(getattr(args, "broker", None)) + try: + result = invoke_many( + args.selector, + params=_parse_param_kv(args.param), + timeout=float(args.timeout), + max_concurrency=int(args.max_concurrency), + llm_reasoning=args.reason, + ) + print(_pretty(result)) + # Exit non-zero on a top-level error OR when any target failed, so + # shell pipelines can detect partial failure without parsing JSON. + if "error" in result: + return 1 + if result.get("failed", 0) > 0: + return 3 + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_broadcast(args: Any) -> int: + from device_connect_agent_tools import broadcast, disconnect + + bindings = None + if args.bindings: + try: + bindings = json.loads(args.bindings) + except json.JSONDecodeError as e: + print(f"--bindings must be valid JSON: {e}") + return 2 + + _connect(getattr(args, "broker", None)) + try: + result = broadcast( + args.selector, + params=_parse_param_kv(args.param), + where=args.where, + bindings=bindings, + fire_at=float(args.fire_at) if args.fire_at is not None else None, + on_late=args.on_late, + llm_reasoning=args.reason, + ) + print(_pretty(result)) + return 0 if "error" not in result else 1 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_subscribe(args: Any) -> int: + """Stream events / replies for ``args.selector`` to stdout. + + Each message is printed as one JSON line so the output can be piped + into ``jq`` or grep. Runs until ``--timeout`` of idle silence elapses + or ``--until`` messages have been printed (whichever comes first). + Exit codes: + 0 one or more messages were printed + 4 idle-timeout reached with zero messages + 130 interrupted with Ctrl-C + """ + from device_connect_agent_tools import disconnect, subscribe + + _connect(getattr(args, "broker", None)) + count = 0 + try: + with subscribe(args.selector) as sub: + try: + for msg in sub.iter( + timeout=float(args.timeout), poll_interval=0.05, + ): + print(json.dumps(msg, default=str)) + count += 1 + if args.until is not None and count >= int(args.until): + break + except KeyboardInterrupt: + # Clean exit on Ctrl-C: the ``with`` block tears the + # subscription down before this returns. + return 130 + return 0 if count > 0 else 4 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +def run_await(args: Any) -> int: + from device_connect_agent_tools import await_replies, disconnect + + _connect(getattr(args, "broker", None)) + try: + replies = await_replies( + args.correlation_id, + timeout=float(args.timeout), + until=int(args.until) if args.until is not None else None, + ) + print(_pretty(replies)) + return 0 + finally: + try: + disconnect() + except Exception: # pragma: no cover + pass + + +# -- parser wiring -------------------------------------------------- + + +def register_subparsers(sub: Any) -> None: + """Attach the operation subparsers to a statectl parser.""" + p = sub.add_parser("invoke", help="Call exactly one function on one device") + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "invoke-many", help="Fan out a call over a selector-resolved set", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument("--timeout", default=30.0, help="Per-target timeout (s)") + p.add_argument( + "--max-concurrency", default=32, dest="max_concurrency", + help="Parallel worker cap", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "broadcast", + help="Async fan-out; returns correlation_id", + ) + p.add_argument("selector", help="Function-scoped selector") + p.add_argument( + "--param", action="append", default=[], + help="Function param as k=v (repeatable; JSON values decoded)", + ) + p.add_argument( + "--where", default=None, + help="CEL predicate evaluated at the edge per candidate", + ) + p.add_argument( + "--bindings", default=None, + help="JSON-encoded bindings dict (shared payload for the predicate)", + ) + p.add_argument( + "--fire-at", default=None, dest="fire_at", + help="Wall-clock epoch seconds for synchronized fan-out", + ) + p.add_argument( + "--on-late", choices=["skip", "fire"], default="skip", dest="on_late", + help="Policy when fire_at deadline has passed (default: skip)", + ) + p.add_argument("--reason", default=None, help="LLM reasoning") + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "subscribe", help="Stream events or broadcast replies to stdout", + ) + p.add_argument( + "selector", + help="Event selector or 'correlation:' for broadcast replies", + ) + p.add_argument( + "--timeout", default=10.0, + help="Idle-silence timeout per message (s; resets on each arrival)", + ) + p.add_argument( + "--until", default=None, + help="Stop after this many messages are printed", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") + + p = sub.add_parser( + "await", help="Collect replies for a broadcast correlation_id", + ) + p.add_argument("correlation_id", help="Correlation id returned by broadcast") + p.add_argument("--timeout", default=10.0, help="Overall timeout (s)") + p.add_argument( + "--until", default=None, + help="Stop after this many replies have been collected", + ) + p.add_argument("--broker", default=None, help="Messaging broker URL") diff --git a/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py new file mode 100644 index 0000000..2ed2cf8 --- /dev/null +++ b/packages/device-connect-server/tests/device_connect_server/test_selector_cli.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Smoke tests for the selector-driven CLI verbs. + +Argument-parser shape only; the underlying tools (``discover``, +``invoke``, ``broadcast``, etc.) have their own unit and integration +tests. These guards catch parser-config regressions (missing positional, +typoed dest, alias drift). +""" +from __future__ import annotations + +import json + +import pytest + +from device_connect_server.devctl import cli as devctl_cli +from device_connect_server.devctl import selector_cli +from device_connect_server.statectl import cli as statectl_cli +from device_connect_server.statectl import operations_cli + + +# -- devctl --------------------------------------------------------- + + +class TestDevctlSelectorParser: + def test_discover_requires_selector(self): + parser = devctl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["discover"]) + + def test_discover_parses_selector(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover", "device(category:camera)"]) + assert args.cmd == "discover" + assert args.selector == "device(category:camera)" + assert args.offset == 0 + assert args.limit == 200 + + def test_discover_offset_limit_override(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover", "device(*)", "--offset", "100", "--limit", "50"] + ) + assert args.offset == 100 + assert args.limit == 50 + + def test_discover_labels_no_key(self): + parser = devctl_cli.create_parser() + args = parser.parse_args(["discover-labels"]) + assert args.cmd == "discover-labels" + assert args.key is None + assert args.limit == 50 + + def test_discover_labels_key_pagination(self): + parser = devctl_cli.create_parser() + args = parser.parse_args( + ["discover-labels", "--key", "device.location", "--limit", "20"] + ) + assert args.key == "device.location" + assert args.limit == 20 + + def test_legacy_discover_renamed_to_mdns_scan(self): + # The historical "discover" verb (mDNS scan) now lives under + # mdns-scan; the alias "scan" keeps it discoverable. + parser = devctl_cli.create_parser() + for verb in ("mdns-scan", "scan"): + args = parser.parse_args([verb]) + # Both aliases share the same args.cmd + assert args.cmd in ("mdns-scan", "scan") + + +# -- statectl ------------------------------------------------------- + + +class TestStatectlOperationsParser: + def test_invoke_requires_selector(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["invoke"]) + + def test_invoke_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke", "device(robot-001).function(grip_close)", + "--param", "force_n=10", + "--reason", "test", + ] + ) + assert args.cmd == "invoke" + assert args.selector == "device(robot-001).function(grip_close)" + assert args.param == ["force_n=10"] + assert args.reason == "test" + + def test_invoke_many_with_timeout(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "invoke-many", + "function(safety:critical)", + "--timeout", "5", + "--max-concurrency", "8", + ] + ) + assert args.cmd == "invoke-many" + assert float(args.timeout) == 5.0 + assert int(args.max_concurrency) == 8 + + def test_broadcast_full_signature(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + [ + "broadcast", + "device(category:phone).function(set_flashlight)", + "--param", "on=true", + "--param", "color=white", + "--where", "labels.location == 'lab-A'", + "--bindings", '{"mask": [[0,1],[1,0]]}', + "--fire-at", "1700000000.0", + "--on-late", "fire", + ] + ) + assert args.cmd == "broadcast" + assert args.selector.startswith("device(category:phone)") + assert args.where == "labels.location == 'lab-A'" + assert args.on_late == "fire" + + def test_broadcast_rejects_unknown_on_late(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args( + [ + "broadcast", "device(*).function(do)", + "--on-late", "bogus", + ] + ) + + def test_subscribe_parses_correlation_form(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["subscribe", "correlation:br-abc123", "--until", "5"] + ) + assert args.cmd == "subscribe" + assert args.selector == "correlation:br-abc123" + assert int(args.until) == 5 + + def test_await_requires_correlation_id(self): + parser = statectl_cli.create_parser() + with pytest.raises(SystemExit): + parser.parse_args(["await"]) + + def test_await_parses(self): + parser = statectl_cli.create_parser() + args = parser.parse_args( + ["await", "br-abc123", "--timeout", "2.5", "--until", "10"] + ) + assert args.correlation_id == "br-abc123" + assert float(args.timeout) == 2.5 + assert int(args.until) == 10 + + +# -- parameter parsing ---------------------------------------------- + + +class TestParseParamKV: + def test_string_values_default(self): + result = operations_cli._parse_param_kv(["a=hello", "b=world"]) + assert result == {"a": "hello", "b": "world"} + + def test_numbers_decoded(self): + result = operations_cli._parse_param_kv(["count=5", "ratio=0.75"]) + assert result == {"count": 5, "ratio": 0.75} + + def test_booleans_decoded(self): + result = operations_cli._parse_param_kv(["on=true", "off=false"]) + assert result == {"on": True, "off": False} + + def test_json_array_decoded(self): + result = operations_cli._parse_param_kv(["zones=[1,2,3]"]) + assert result == {"zones": [1, 2, 3]} + + def test_json_object_decoded(self): + result = operations_cli._parse_param_kv(['nested={"a":1}']) + assert result == {"nested": {"a": 1}} + + def test_string_with_equals(self): + # The split is on the first '=', so values may contain further '='. + result = operations_cli._parse_param_kv(["query=a=b"]) + assert result == {"query": "a=b"} + + def test_invalid_form_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["no_equals_sign"]) + + def test_empty_key_rejected(self): + with pytest.raises(ValueError): + operations_cli._parse_param_kv(["=value"]) diff --git a/tests/tests/test_tools_broadcast.py b/tests/tests/test_tools_broadcast.py new file mode 100644 index 0000000..975016e --- /dev/null +++ b/tests/tests/test_tools_broadcast.py @@ -0,0 +1,358 @@ +# Copyright (c) 2024-2026, Arm Limited and Contributors. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for selector-driven broadcast + correlation replies. + +End-to-end coverage for the async fan-out path: +- Dispatcher publishes a broadcast envelope on the fanout subject. +- Each device runtime self-elects via target_device_ids and the optional + CEL ``where`` predicate. +- Devices execute the function and emit a reply on the per-device async + reply subject keyed by correlation_id. +- ``await_replies`` collects replies for a bounded window. +""" + +import asyncio +import time + +import pytest + +SETTLE_TIME = 0.4 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_returns_correlation_and_replies_arrive( + device_spawner, messaging_url, +): + """broadcast() returns a correlation_id and matching devices reply on the + per-device async reply subject.""" + await device_spawner.spawn_camera("itest-bc-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bc-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bc-cam-1", "itest-bc-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bc-cam-*).function(capture_image)", + {"resolution": "720p"}, + ) + assert result["correlation_id"].startswith("br-") + assert result["candidates"] == 2 + assert result["function"] == "capture_image" + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=5.0, until=2, + ) + assert len(replies) == 2 + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bc-cam-1", "itest-bc-cam-2"} + for r in replies: + assert r["success"] is True + assert r["correlation_id"] == result["correlation_id"] + assert "actually_fired_at" in r + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_filters_at_edge(device_spawner, messaging_url): + """A CEL where predicate runs at each candidate; only matches reply.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcw-cam-a", location="lab-A") + await device_spawner.spawn_camera("itest-bcw-cam-b", location="lab-B") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcw-cam-a", "itest-bcw-cam-b"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcw-cam-*).function(capture_image)", + {"resolution": "1080p"}, + "labels.location == 'lab-A'", # where predicate + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + # Only cam-a is in lab-A; cam-b silently self-deselects. + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcw-cam-a"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_synchronizes_fan_out( + device_spawner, messaging_url, +): + """fire_at causes each device to fire from its own clock at the deadline.""" + await device_spawner.spawn_camera("itest-bcf-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcf-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcf-cam-1", "itest-bcf-cam-2"}) + try: + # Schedule 0.5s in the future; on_late=skip so any tardy device drops + # the call rather than firing late and breaking the coherence. + scheduled = time.time() + 0.5 + result = await asyncio.to_thread( + broadcast, + "device(itest-bcf-cam-*).function(capture_image)", + None, None, None, + scheduled, # fire_at + "skip", # on_late + ) + assert result["candidates"] == 2 + + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, until=2, + ) + assert len(replies) == 2 + # actually_fired_at should be at-or-after the scheduled time on each. + for r in replies: + assert r["actually_fired_at"] >= scheduled - 0.05 # small slack + # Achieved spread should be tight (well under network jitter). + spread = max(r["actually_fired_at"] for r in replies) - min( + r["actually_fired_at"] for r in replies + ) + assert spread < 0.5, f"fire_at spread too wide: {spread:.3f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_fire_at_late_with_skip_drops( + device_spawner, messaging_url, +): + """A fire_at in the past with on_late=skip yields no replies.""" + await device_spawner.spawn_camera("itest-bcl-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices(messaging_url, {"itest-bcl-cam"}) + try: + past = time.time() - 5.0 # already 5s late + result = await asyncio.to_thread( + broadcast, + "device(itest-bcl-cam).function(capture_image)", + None, None, None, past, "skip", + ) + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=1.5, + ) + assert replies == [] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_broadcast_where_with_bindings(device_spawner, messaging_url): + """A where predicate that reads bindings. self-elects per-target.""" + pytest.importorskip("celpy") + await device_spawner.spawn_camera("itest-bcbnd-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcbnd-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-bcbnd-cam-1", "itest-bcbnd-cam-2"} + ) + try: + # Allowlist sent in bindings; the predicate uses bindings.allow to + # select. Devices not in the allowlist self-deselect silently. + result = await asyncio.to_thread( + broadcast, + "device(itest-bcbnd-cam-*).function(capture_image)", + None, + "identity.device_id in bindings.allow", + {"allow": ["itest-bcbnd-cam-1"]}, + ) + assert result["candidates"] == 2 + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], timeout=3.0, + ) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcbnd-cam-1"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_await_replies_until_stops_early(device_spawner, messaging_url): + """``await_replies`` returns once ``until`` replies have arrived.""" + await device_spawner.spawn_camera("itest-awu-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-2", location="lab-A") + await device_spawner.spawn_camera("itest-awu-cam-3", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import ( + await_replies, broadcast, disconnect, + ) + + await _wait_for_devices( + messaging_url, {"itest-awu-cam-1", "itest-awu-cam-2", "itest-awu-cam-3"} + ) + try: + result = await asyncio.to_thread( + broadcast, "device(itest-awu-cam-*).function(capture_image)", + ) + assert result["candidates"] == 3 + # until=1 should let us return after the first reply arrives even + # though more are coming. + t0 = time.monotonic() + replies = await asyncio.to_thread( + await_replies, result["correlation_id"], + timeout=5.0, until=1, poll_interval=0.02, + ) + elapsed = time.monotonic() - t0 + assert len(replies) >= 1 + # Sanity: returning early should be well under the timeout. + assert elapsed < 2.0, f"await_replies(until=1) took {elapsed:.2f}s" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_iter_protocol(device_spawner, messaging_url): + """``for msg in sub:`` works via Subscription.__iter__.""" + await device_spawner.spawn_camera("itest-subiter-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-subiter-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices( + messaging_url, {"itest-subiter-cam-1", "itest-subiter-cam-2"} + ) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-subiter-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + # Exercise the bare ``for msg in sub:`` form (uses __iter__). + # Break after both expected replies arrive so the test stays + # bounded regardless of the default idle timeout. + with subscribe(f"correlation:{cid}") as sub: + gathered: list[dict] = [] + for msg in sub: + gathered.append(msg) + if len(gathered) >= 2: + break + return gathered + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-subiter-cam-1", "itest-subiter-cam-2"} + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_event_selector_live_stream(device_spawner, messaging_url): + """subscribe(event()) receives live events from matching devices.""" + device, driver = await device_spawner.spawn_camera( + "itest-evsub-cam", location="lab-A", + ) + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-evsub-cam"}) + try: + with subscribe("device(itest-evsub-cam).event(object_detected)") as sub: + await asyncio.sleep(SETTLE_TIME) # let subscription warm up + await driver.trigger_event( + "object_detected", + {"label": "person", "confidence": 0.95}, + ) + msgs = await asyncio.to_thread( + list, sub.iter(timeout=2.0, poll_interval=0.05), + ) + # The event arrives via the JSON-RPC event subject; payload is + # under either ``params`` or top-level depending on transport. + matching = [ + m for m in msgs + if (m.get("params") or {}).get("label") == "person" + or m.get("label") == "person" + ] + assert matching, f"no object_detected events received: {msgs}" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_subscribe_correlation_form(device_spawner, messaging_url): + """subscribe('correlation:') captures replies as they arrive.""" + await device_spawner.spawn_camera("itest-bcs-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-bcs-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import broadcast, disconnect, subscribe + + await _wait_for_devices(messaging_url, {"itest-bcs-cam-1", "itest-bcs-cam-2"}) + try: + result = await asyncio.to_thread( + broadcast, + "device(itest-bcs-cam-*).function(capture_image)", + ) + cid = result["correlation_id"] + + def collect(): + with subscribe(f"correlation:{cid}") as sub: + # Drain over a short window. + return list(sub.iter(timeout=2.0, poll_interval=0.05)) + + replies = await asyncio.to_thread(collect) + ids = {r["device_id"] for r in replies} + assert ids == {"itest-bcs-cam-1", "itest-bcs-cam-2"} + finally: + await asyncio.to_thread(disconnect) diff --git a/tests/tests/test_tools_invoke.py b/tests/tests/test_tools_invoke.py index 447f301..df9878b 100644 --- a/tests/tests/test_tools_invoke.py +++ b/tests/tests/test_tools_invoke.py @@ -2,40 +2,65 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Integration tests for device-connect-agent-tools invoke_device(). +"""Integration tests for selector-driven invocation tools. -Tests that the agent SDK can invoke device RPCs via the messaging backend. +Covers ``invoke()`` and ``invoke_many()`` against real devices registered +via the messaging backend. Exercises single-match, ambiguous-match, +selector-scope rejection, parallel fan-out, and partial-failure semantics +end-to-end. """ import asyncio -import pytest +import time +import pytest SETTLE_TIME = 0.3 +DISCOVERY_TIMEOUT = 5.0 + + +async def _wait_for_devices(messaging_url, expected_ids): + """Connect and poll until all expected ``device_ids`` are visible.""" + from device_connect_agent_tools import connect + from device_connect_agent_tools.connection import get_connection + + await asyncio.to_thread(connect, nats_url=messaging_url) + deadline = time.monotonic() + DISCOVERY_TIMEOUT + while True: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + ids = {d.get("device_id") for d in devices} + if expected_ids.issubset(ids) or time.monotonic() > deadline: + return devices + await asyncio.sleep(0.25) + + +# -- invoke --------------------------------------------------------- @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_sensor_reading(device_spawner, messaging_url): - """invoke_device() should call sensor's get_reading and return result.""" + """invoke() calls sensor.get_reading and returns the reading payload.""" await device_spawner.spawn_sensor( - "itest-tools-invoke-sensor", initial_temp=23.5, initial_humidity=50.0, + "itest-inv-read-sensor", initial_temp=23.5, initial_humidity=50.0, ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-read-sensor"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-sensor", - function="get_reading", - params={"unit": "celsius"}, - llm_reasoning="Testing sensor read", + invoke, + "device(itest-inv-read-sensor).function(get_reading)", + {"unit": "celsius"}, + "Testing sensor read", ) - assert isinstance(result, dict) - assert result.get("success") is True or "temperature" in result.get("result", {}) + assert result["success"] is True + assert result["device_id"] == "itest-inv-read-sensor" + assert result["function"] == "get_reading" + assert "temperature" in result["result"] finally: await asyncio.to_thread(disconnect) @@ -43,25 +68,26 @@ async def test_invoke_sensor_reading(device_spawner, messaging_url): @pytest.mark.asyncio @pytest.mark.integration async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_url): - """invoke_device() should dispatch robot and trigger cleaning.""" + """invoke() dispatches the robot and the cleaning_finished event arrives.""" await device_spawner.spawn_robot( - "itest-tools-invoke-robot", clean_duration=0.3, + "itest-inv-robot", clean_duration=0.3, ) await asyncio.sleep(SETTLE_TIME) - async with event_capture.subscribe("device-connect.*.itest-tools-invoke-robot.event.*") as events: - from device_connect_agent_tools import connect, disconnect, invoke_device + async with event_capture.subscribe( + "device-connect.*.itest-inv-robot.event.*" + ) as events: + from device_connect_agent_tools import disconnect, invoke - await asyncio.to_thread(connect, nats_url=messaging_url) + await _wait_for_devices(messaging_url, {"itest-inv-robot"}) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-robot", - function="dispatch_robot", - params={"zone_id": "zone-tools"}, - llm_reasoning="Testing robot dispatch via tools", + invoke, + "device(itest-inv-robot).function(dispatch_robot)", + {"zone_id": "zone-tools"}, + "Testing robot dispatch", ) - assert isinstance(result, dict) + assert result["success"] is True finally: await asyncio.to_thread(disconnect) @@ -71,42 +97,184 @@ async def test_invoke_robot_dispatch(device_spawner, event_capture, messaging_ur @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_unknown_device(messaging_url): - """invoke_device() on non-existent device should return error.""" - from device_connect_agent_tools import connect, disconnect, invoke_device +async def test_invoke_no_match_returns_no_match(device_spawner, messaging_url): + """A selector that resolves to zero functions returns ``no_match``.""" + await device_spawner.spawn_camera("itest-inv-nomatch-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="nonexistent-device-xyz", - function="ping", - llm_reasoning="Testing error handling", + invoke, + "device(itest-inv-nomatch-cam).function(does_not_exist)", + ) + assert result["success"] is False + assert result["error"]["code"] == "no_match" + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_ambiguous_match_returns_error(device_spawner, messaging_url): + """A selector matching multiple (device, function) tuples returns an error.""" + await device_spawner.spawn_camera("itest-inv-amb-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-amb-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke + + await _wait_for_devices( + messaging_url, {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke, "device(itest-inv-amb-cam-*).function(capture_image)", + ) + assert result["success"] is False + assert result["error"]["code"] == "ambiguous_match" + cand_ids = {c["device_id"] for c in result["candidates"]} + assert {"itest-inv-amb-cam-1", "itest-inv-amb-cam-2"} <= cand_ids + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_device_only_scope_rejected(device_spawner, messaging_url): + """A device-only selector cannot resolve to a function.""" + await device_spawner.spawn_camera("itest-inv-scope-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke + + await asyncio.to_thread(connect, nats_url=messaging_url) + try: + result = await asyncio.to_thread(invoke, "device(itest-inv-scope-cam)") + assert result["success"] is False + assert result["error"]["code"] == "invalid_invoke_scope" + finally: + await asyncio.to_thread(disconnect) + + +# -- invoke_many ---------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_succeeds_across_devices(device_spawner, messaging_url): + """invoke_many() fans out a single function across multiple matching devices.""" + await device_spawner.spawn_camera("itest-inv-many-cam-1", location="lab-A") + await device_spawner.spawn_camera("itest-inv-many-cam-2", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-cam-*).function(capture_image)", + {"resolution": "720p"}, ) - assert isinstance(result, dict) - assert result.get("success") is False + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 2 + assert result["failed"] == 0 + ids = {row["device_id"] for row in result["results"]} + assert ids == {"itest-inv-many-cam-1", "itest-inv-many-cam-2"} finally: await asyncio.to_thread(disconnect) @pytest.mark.asyncio @pytest.mark.integration -async def test_invoke_camera_capture(device_spawner, messaging_url): - """invoke_device() should capture image from camera.""" - await device_spawner.spawn_camera("itest-tools-invoke-cam") +async def test_invoke_many_partial_failure(device_spawner, messaging_url): + """A failing target is recorded in errors while siblings succeed.""" + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-1", location="lab-A", failure_rate=1.0, + ) + await device_spawner.spawn_camera( + "itest-inv-many-pf-cam-2", location="lab-A", + ) await asyncio.sleep(SETTLE_TIME) - from device_connect_agent_tools import connect, disconnect, invoke_device + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, + {"itest-inv-many-pf-cam-1", "itest-inv-many-pf-cam-2"}, + ) + try: + result = await asyncio.to_thread( + invoke_many, + "device(itest-inv-many-pf-cam-*).function(capture_image)", + ) + assert result["candidates"] == 2 + assert result["matched"] == 2 + assert result["succeeded"] == 1 + assert result["failed"] == 1 + success_ids = {row["device_id"] for row in result["results"]} + error_ids = {row["device_id"] for row in result["errors"]} + assert success_ids == {"itest-inv-many-pf-cam-2"} + assert error_ids == {"itest-inv-many-pf-cam-1"} + for row in result["errors"]: + assert "code" in row["error"] + assert "message" in row["error"] + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_zero_candidates(device_spawner, messaging_url): + """No matches yields an empty envelope, not an error.""" + await device_spawner.spawn_camera("itest-inv-many-zero-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import connect, disconnect, invoke_many await asyncio.to_thread(connect, nats_url=messaging_url) try: result = await asyncio.to_thread( - invoke_device, - device_id="itest-tools-invoke-cam", - function="capture_image", - params={"resolution": "720p"}, - llm_reasoning="Testing camera capture via tools", + invoke_many, + "device(itest-no-such-device).function(capture_image)", ) - assert isinstance(result, dict) + assert result["candidates"] == 0 + assert result["matched"] == 0 + assert result["succeeded"] == 0 + assert result["failed"] == 0 + assert result["results"] == [] + assert result["errors"] == [] + assert "error" not in result + finally: + await asyncio.to_thread(disconnect) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_invoke_many_function_only_selector(device_spawner, messaging_url): + """function() selects the function across the whole fleet.""" + await device_spawner.spawn_sensor( + "itest-inv-many-fo-sensor", initial_temp=20.0, + ) + await device_spawner.spawn_camera("itest-inv-many-fo-cam", location="lab-A") + await asyncio.sleep(SETTLE_TIME) + + from device_connect_agent_tools import disconnect, invoke_many + + await _wait_for_devices( + messaging_url, {"itest-inv-many-fo-cam", "itest-inv-many-fo-sensor"} + ) + try: + result = await asyncio.to_thread(invoke_many, "function(get_reading)") + ids = {row["device_id"] for row in result["results"]} + assert "itest-inv-many-fo-sensor" in ids + # Camera does not have get_reading; should not be in results. + assert "itest-inv-many-fo-cam" not in ids finally: await asyncio.to_thread(disconnect)