diff --git a/docs/reference/sdk-py.mdx b/docs/reference/sdk-py.mdx index 171593350..134664744 100644 --- a/docs/reference/sdk-py.mdx +++ b/docs/reference/sdk-py.mdx @@ -47,14 +47,17 @@ agent = await relay.gemini.spawn(name="Researcher", model=Models.Gemini.GEMINI_2 **Spawn keyword arguments:** -| Parameter | Type | Description | -| ---------- | --------------- | --------------------------------- | -| `name` | `str` | Agent name (defaults to CLI name) | -| `model` | `str` | Model to use (see Models below) | -| `task` | `str` | Initial task / prompt | -| `channels` | `list[str]` | Channels to join | -| `args` | `list[str]` | Extra CLI arguments | -| `cwd` | `str` | Working directory override | +| Parameter | Type | Description | +| ------------ | ----------- | ---------------------------------------- | +| `name` | `str` | Agent name (defaults to CLI name) | +| `model` | `str` | Model to use (see Models below) | +| `task` | `str` | Initial task / prompt | +| `channels` | `list[str]` | Channels to join | +| `args` | `list[str]` | Extra CLI arguments | +| `cwd` | `str` | Working directory override | +| `on_start` | `Callable` | Sync/async callback before spawn request | +| `on_success` | `Callable` | Sync/async callback after spawn succeeds | +| `on_error` | `Callable` | Sync/async callback when spawn fails | ### `relay.spawn(name, cli, task?, options?)` @@ -94,7 +97,7 @@ class Agent: exit_reason: str | None async def send_message(*, to: str, text: str, thread_id=None, priority=None, data=None) -> Message - async def release(reason=None) -> None + async def release(reason=None, *, on_start=None, on_success=None, on_error=None) -> None async def wait_for_ready(timeout_ms=60_000) -> None async def wait_for_exit(timeout_ms=None) -> str # "exited" | "timeout" | "released" async def wait_for_idle(timeout_ms=None) -> str # "idle" | "timeout" | "exited" diff --git a/docs/reference/sdk.mdx b/docs/reference/sdk.mdx index ffdd3ece9..e1bdee9d7 100644 --- a/docs/reference/sdk.mdx +++ b/docs/reference/sdk.mdx @@ -23,18 +23,18 @@ const relay = new AgentRelay(options?: AgentRelayOptions); ### AgentRelayOptions -| Property | Type | Description | Default | -| -------------------- | --------------------- | ------------------------------------------------------------- | -------------------- | -| `binaryPath` | `string` | Path to the broker binary | Auto-resolved | -| `binaryArgs` | `string[]` | Extra arguments for the broker process | None | -| `brokerName` | `string` | Name for the broker instance | Auto-generated | -| `channels` | `string[]` | Default channels agents are joined to on spawn | `['general']` | -| `cwd` | `string` | Working directory for the broker and spawned agents | `process.cwd()` | -| `env` | `NodeJS.ProcessEnv` | Environment variables passed to the broker | Inherited | -| `requestTimeoutMs` | `number` | Timeout for broker requests | `30000` | -| `shutdownTimeoutMs` | `number` | Timeout when shutting down | `10000` | -| `workspaceName` | `string` | Name for the auto-created Relaycast workspace | Random | -| `relaycastBaseUrl` | `string` | Base URL for the Relaycast API | `https://api.relaycast.dev` | +| Property | Type | Description | Default | +| ------------------- | ------------------- | --------------------------------------------------- | --------------------------- | +| `binaryPath` | `string` | Path to the broker binary | Auto-resolved | +| `binaryArgs` | `string[]` | Extra arguments for the broker process | None | +| `brokerName` | `string` | Name for the broker instance | Auto-generated | +| `channels` | `string[]` | Default channels agents are joined to on spawn | `['general']` | +| `cwd` | `string` | Working directory for the broker and spawned agents | `process.cwd()` | +| `env` | `NodeJS.ProcessEnv` | Environment variables passed to the broker | Inherited | +| `requestTimeoutMs` | `number` | Timeout for broker requests | `30000` | +| `shutdownTimeoutMs` | `number` | Timeout when shutting down | `10000` | +| `workspaceName` | `string` | Name for the auto-created Relaycast workspace | Random | +| `relaycastBaseUrl` | `string` | Base URL for the Relaycast API | `https://api.relaycast.dev` | --- @@ -50,14 +50,17 @@ const agent = await relay.codex.spawn(options?) **Spawn options:** -| Property | Type | Description | -| ---------- | ---------- | ------------------------------------- | -| `name` | `string` | Agent name (defaults to CLI name) | -| `model` | `string` | Model to use (see Models below) | -| `task` | `string` | Initial task / prompt | -| `channels` | `string[]` | Channels to join | -| `args` | `string[]` | Extra CLI arguments | -| `cwd` | `string` | Working directory override | +| Property | Type | Description | +| ----------- | ---------- | ------------------------------------------------ | +| `name` | `string` | Agent name (defaults to CLI name) | +| `model` | `string` | Model to use (see Models below) | +| `task` | `string` | Initial task / prompt | +| `channels` | `string[]` | Channels to join | +| `args` | `string[]` | Extra CLI arguments | +| `cwd` | `string` | Working directory override | +| `onStart` | `function` | Sync/async callback before spawn request is sent | +| `onSuccess` | `function` | Sync/async callback after spawn succeeds | +| `onError` | `function` | Sync/async callback when spawn fails | ### `relay.spawn(name, cli, task?, options?)` @@ -105,7 +108,7 @@ interface Agent { data?: Record; }): Promise; - release(reason?: string): Promise; + release(reasonOrOptions?: string | ReleaseOptions): Promise; waitForReady(timeoutMs?: number): Promise; waitForExit(timeoutMs?: number): Promise<'exited' | 'timeout' | 'released'>; waitForIdle(timeoutMs?: number): Promise<'idle' | 'timeout' | 'exited'>; @@ -113,6 +116,17 @@ interface Agent { } ``` +### `ReleaseOptions` + +`agent.release(...)` accepts either a reason string or a `ReleaseOptions` object: + +| Property | Type | Description | +| ----------- | ---------- | -------------------------------------------------- | +| `reason` | `string` | Optional release reason sent to the broker | +| `onStart` | `function` | Sync/async callback before release request is sent | +| `onSuccess` | `function` | Sync/async callback after release succeeds | +| `onError` | `function` | Sync/async callback when release fails | + --- ## Human Handles @@ -249,16 +263,16 @@ await relay.shutdown(); import { Models } from '@agent-relay/sdk'; // Claude -Models.Claude.OPUS // 'opus' -Models.Claude.SONNET // 'sonnet' -Models.Claude.HAIKU // 'haiku' +Models.Claude.OPUS; // 'opus' +Models.Claude.SONNET; // 'sonnet' +Models.Claude.HAIKU; // 'haiku' // Codex -Models.Codex.GPT_5_3_CODEX // 'gpt-5.3-codex' -Models.Codex.GPT_5_2_CODEX // 'gpt-5.2-codex' (default) -Models.Codex.GPT_5_3_CODEX_SPARK // 'gpt-5.3-codex-spark' -Models.Codex.GPT_5_1_CODEX_MAX // 'gpt-5.1-codex-max' -Models.Codex.GPT_5_1_CODEX_MINI // 'gpt-5.1-codex-mini' +Models.Codex.GPT_5_3_CODEX; // 'gpt-5.3-codex' +Models.Codex.GPT_5_2_CODEX; // 'gpt-5.2-codex' (default) +Models.Codex.GPT_5_3_CODEX_SPARK; // 'gpt-5.3-codex-spark' +Models.Codex.GPT_5_1_CODEX_MAX; // 'gpt-5.1-codex-max' +Models.Codex.GPT_5_1_CODEX_MINI; // 'gpt-5.1-codex-mini' ``` --- diff --git a/packages/sdk-py/README.md b/packages/sdk-py/README.md index 0fde08dab..0db2d8ad7 100644 --- a/packages/sdk-py/README.md +++ b/packages/sdk-py/README.md @@ -74,6 +74,21 @@ Use runtime-specific spawners: await relay.claude.spawn(name="Agent1", model=Models.Claude.SONNET, channels=["dev"], task="...") await relay.codex.spawn(name="Agent2", model=Models.Codex.GPT_5_3_CODEX, channels=["dev"], task="...") await relay.gemini.spawn(name="Agent3", model=Models.Gemini.GEMINI_2_5_PRO, channels=["dev"], task="...") + +worker = await relay.claude.spawn( + name="HookedWorker", + channels=["dev"], + # Lifecycle hooks can be sync or async callables. + on_start=lambda ctx: print(f"spawning {ctx['name']}"), + on_success=lambda ctx: print(f"spawned {ctx['name']} ({ctx['runtime']})"), + on_error=lambda ctx: print(f"failed to spawn {ctx['name']}: {ctx['error']}"), +) + +await worker.release( + "done", + on_start=lambda ctx: print(f"releasing {ctx['name']}"), + on_success=lambda ctx: print(f"released {ctx['name']}"), +) ``` ### Sending Messages @@ -108,7 +123,6 @@ Models.Gemini.GEMINI_2_5_PRO Models.Gemini.GEMINI_2_5_FLASH ``` - ## License Apache-2.0 diff --git a/packages/sdk-py/src/agent_relay/relay.py b/packages/sdk-py/src/agent_relay/relay.py index 0a27fe8df..46beb77c8 100644 --- a/packages/sdk-py/src/agent_relay/relay.py +++ b/packages/sdk-py/src/agent_relay/relay.py @@ -9,10 +9,11 @@ from __future__ import annotations import asyncio +import inspect import os import secrets from dataclasses import dataclass, field -from typing import Any, Callable, Optional +from typing import Any, Awaitable, Callable, Optional from .client import AgentRelayClient from .protocol import AgentRuntime, BrokerEvent @@ -22,6 +23,7 @@ AgentStatus = str # "spawning" | "ready" | "idle" | "exited" EventHook = Optional[Callable[..., None]] +LifecycleHook = Optional[Callable[[dict[str, Any]], None | Awaitable[None]]] @dataclass @@ -49,6 +51,9 @@ class SpawnOptions: shadow_mode: Optional[str] = None idle_threshold_secs: Optional[int] = None restart_policy: Optional[dict[str, Any]] = None + on_start: LifecycleHook = None + on_success: LifecycleHook = None + on_error: LifecycleHook = None # ── Agent handle ────────────────────────────────────────────────────────────── @@ -94,9 +99,41 @@ def status(self) -> AgentStatus: return "ready" return "spawning" - async def release(self, reason: Optional[str] = None) -> None: + async def release( + self, + reason: Optional[str] = None, + *, + on_start: LifecycleHook = None, + on_success: LifecycleHook = None, + on_error: LifecycleHook = None, + ) -> None: + context = { + "name": self._name, + "reason": reason, + } client = await self._relay._ensure_started() - await client.release(self._name, reason) + await self._relay._invoke_lifecycle_hook( + on_start, + context, + f'release("{self._name}") on_start', + ) + try: + await client.release(self._name, reason) + await self._relay._invoke_lifecycle_hook( + on_success, + context, + f'release("{self._name}") on_success', + ) + except Exception as error: + await self._relay._invoke_lifecycle_hook( + on_error, + { + **context, + "error": error, + }, + f'release("{self._name}") on_error', + ) + raise async def wait_for_ready(self, timeout_ms: int = 60_000) -> None: await self._relay.wait_for_agent_ready(self._name, timeout_ms) @@ -267,21 +304,46 @@ async def spawn( task: Optional[str] = None, model: Optional[str] = None, cwd: Optional[str] = None, + on_start: LifecycleHook = None, + on_success: LifecycleHook = None, + on_error: LifecycleHook = None, ) -> Agent: agent_name = name or self._default_name agent_channels = channels or ["general"] + context = { + "name": agent_name, + "cli": self._cli, + "channels": agent_channels, + "task": task, + } client = await self._relay._ensure_started() - - result = await client.spawn_pty( - name=agent_name, - cli=self._cli, - args=args or [], - channels=agent_channels, - task=task, - model=model, - cwd=cwd, + await self._relay._invoke_lifecycle_hook( + on_start, + context, + f'spawn("{agent_name}") on_start', ) + try: + result = await client.spawn_pty( + name=agent_name, + cli=self._cli, + args=args or [], + channels=agent_channels, + task=task, + model=model, + cwd=cwd, + ) + except Exception as error: + await self._relay._invoke_lifecycle_hook( + on_error, + { + **context, + "error": error, + }, + f'spawn("{agent_name}") on_error', + ) + raise + agent = Agent( name=result.get("name", agent_name), runtime=result.get("runtime", "pty"), @@ -289,10 +351,16 @@ async def spawn( relay=self._relay, ) self._relay._known_agents[agent.name] = agent - self._relay._ready_agents.discard(agent.name) - self._relay._message_ready_agents.discard(agent.name) - self._relay._exited_agents.discard(agent.name) - self._relay._idle_agents.discard(agent.name) + self._relay._reset_agent_lifecycle_state(agent.name) + await self._relay._invoke_lifecycle_hook( + on_success, + { + **context, + "name": agent.name, + "runtime": agent.runtime, + }, + f'spawn("{agent_name}") on_success', + ) return agent @@ -418,22 +486,44 @@ async def spawn( client = await self._ensure_started() opts = options or SpawnOptions() channels = opts.channels or ["general"] - - result = await client.spawn_pty( - name=name, - cli=cli, - task=task, - args=opts.args, - channels=channels, - model=opts.model, - cwd=opts.cwd, - team=opts.team, - shadow_of=opts.shadow_of, - shadow_mode=opts.shadow_mode, - idle_threshold_secs=opts.idle_threshold_secs, - restart_policy=opts.restart_policy, + context = { + "name": name, + "cli": cli, + "channels": channels, + "task": task, + } + await self._invoke_lifecycle_hook( + opts.on_start, + context, + f'spawn("{name}") on_start', ) + try: + result = await client.spawn_pty( + name=name, + cli=cli, + task=task, + args=opts.args, + channels=channels, + model=opts.model, + cwd=opts.cwd, + team=opts.team, + shadow_of=opts.shadow_of, + shadow_mode=opts.shadow_mode, + idle_threshold_secs=opts.idle_threshold_secs, + restart_policy=opts.restart_policy, + ) + except Exception as error: + await self._invoke_lifecycle_hook( + opts.on_error, + { + **context, + "error": error, + }, + f'spawn("{name}") on_error', + ) + raise + agent = Agent( name=result.get("name", name), runtime=result.get("runtime", "pty"), @@ -441,10 +531,16 @@ async def spawn( relay=self, ) self._known_agents[agent.name] = agent - self._ready_agents.discard(agent.name) - self._message_ready_agents.discard(agent.name) - self._exited_agents.discard(agent.name) - self._idle_agents.discard(agent.name) + self._reset_agent_lifecycle_state(agent.name) + await self._invoke_lifecycle_hook( + opts.on_success, + { + **context, + "name": agent.name, + "runtime": agent.runtime, + }, + f'spawn("{name}") on_success', + ) return agent async def spawn_and_wait( @@ -622,6 +718,27 @@ async def shutdown(self) -> None: # ── Private helpers ─────────────────────────────────────────────────── + async def _invoke_lifecycle_hook( + self, + hook: LifecycleHook, + context: dict[str, Any], + label: str, + ) -> None: + if hook is None: + return + try: + result = hook(context) + if inspect.isawaitable(result): + await result + except Exception as error: + print(f"[AgentRelay] {label} hook threw: {error}") + + def _reset_agent_lifecycle_state(self, name: str) -> None: + self._ready_agents.discard(name) + self._message_ready_agents.discard(name) + self._exited_agents.discard(name) + self._idle_agents.discard(name) + def _ensure_agent_handle( self, name: str, runtime: AgentRuntime = "pty", channels: Optional[list[str]] = None, ) -> Agent: diff --git a/packages/sdk-py/tests/test_relay_lifecycle_hooks.py b/packages/sdk-py/tests/test_relay_lifecycle_hooks.py new file mode 100644 index 000000000..44d0df2a0 --- /dev/null +++ b/packages/sdk-py/tests/test_relay_lifecycle_hooks.py @@ -0,0 +1,250 @@ +"""Tests for spawn/release lifecycle hooks in the high-level relay facade.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from agent_relay import AgentRelay, SpawnOptions + + +class _FakeRelayClient: + def __init__(self) -> None: + self.spawn_error: Exception | None = None + self.release_error: Exception | None = None + self.spawn_calls: list[dict] = [] + self.release_calls: list[tuple[str, str | None]] = [] + + async def spawn_pty(self, **kwargs): + self.spawn_calls.append(kwargs) + if self.spawn_error: + raise self.spawn_error + return {"name": kwargs["name"], "runtime": "pty"} + + async def release(self, name: str, reason: str | None = None): + self.release_calls.append((name, reason)) + if self.release_error: + raise self.release_error + return {"name": name} + + +@pytest.mark.asyncio +async def test_spawn_lifecycle_hooks_success(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + + events: list[tuple[str, dict]] = [] + options = SpawnOptions( + channels=["general"], + on_start=lambda ctx: events.append(("start", dict(ctx))), + on_success=lambda ctx: events.append(("success", dict(ctx))), + on_error=lambda ctx: events.append(("error", dict(ctx))), + ) + + agent = await relay.spawn("HookWorker", "claude", "Do the work", options) + + assert agent.name == "HookWorker" + assert events[0] == ( + "start", + { + "name": "HookWorker", + "cli": "claude", + "channels": ["general"], + "task": "Do the work", + }, + ) + assert events[1] == ( + "success", + { + "name": "HookWorker", + "cli": "claude", + "channels": ["general"], + "task": "Do the work", + "runtime": "pty", + }, + ) + assert len(events) == 2 + + +@pytest.mark.asyncio +async def test_spawn_lifecycle_hooks_support_async_callbacks(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + + start_done = False + success_done = False + + async def on_start(_ctx): + nonlocal start_done + await asyncio.sleep(0) + start_done = True + + async def on_success(_ctx): + nonlocal success_done + await asyncio.sleep(0) + success_done = True + + options = SpawnOptions( + channels=["general"], + on_start=on_start, + on_success=on_success, + ) + + await relay.spawn("AsyncHookWorker", "claude", "Do the work", options) + + assert start_done is True + assert success_done is True + + +@pytest.mark.asyncio +async def test_spawn_lifecycle_hooks_error(): + relay = AgentRelay() + client = _FakeRelayClient() + client.spawn_error = RuntimeError("spawn failed") + relay._ensure_started = AsyncMock(return_value=client) + + on_error_calls: list[dict] = [] + options = SpawnOptions( + channels=["general"], + on_start=lambda _: None, + on_error=lambda ctx: on_error_calls.append(dict(ctx)), + ) + + with pytest.raises(RuntimeError, match="spawn failed"): + await relay.spawn("HookWorkerFail", "claude", "Do the work", options) + + assert len(on_error_calls) == 1 + error_ctx = on_error_calls[0] + assert error_ctx["name"] == "HookWorkerFail" + assert error_ctx["cli"] == "claude" + assert isinstance(error_ctx["error"], RuntimeError) + + +@pytest.mark.asyncio +async def test_shorthand_spawn_lifecycle_hooks_success(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + + events: list[str] = [] + agent = await relay.claude.spawn( + name="ShorthandWorker", + channels=["general"], + task="Run analysis", + on_start=lambda _: events.append("start"), + on_success=lambda _: events.append("success"), + on_error=lambda _: events.append("error"), + ) + + assert agent.name == "ShorthandWorker" + assert events == ["start", "success"] + + +@pytest.mark.asyncio +async def test_shorthand_spawn_does_not_fire_start_hook_if_broker_startup_fails(): + relay = AgentRelay() + relay._ensure_started = AsyncMock(side_effect=RuntimeError("broker startup failed")) + + start_called = False + error_called = False + + def _mark_called(kind: str) -> None: + nonlocal start_called, error_called + if kind == "start": + start_called = True + else: + error_called = True + + with pytest.raises(RuntimeError, match="broker startup failed"): + await relay.claude.spawn( + name="ShorthandWorkerStartupFail", + channels=["general"], + on_start=lambda _ctx: _mark_called("start"), + on_error=lambda _ctx: _mark_called("error"), + ) + + assert start_called is False + assert error_called is False + + +@pytest.mark.asyncio +async def test_release_lifecycle_hooks_success_and_error(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + + agent = await relay.spawn("ReleaseWorker", "claude") + + success_events: list[str] = [] + await agent.release( + "cleanup", + on_start=lambda _: success_events.append("start"), + on_success=lambda _: success_events.append("success"), + on_error=lambda _: success_events.append("error"), + ) + + assert client.release_calls[-1] == ("ReleaseWorker", "cleanup") + assert success_events == ["start", "success"] + + client.release_error = RuntimeError("release failed") + error_calls: list[dict] = [] + with pytest.raises(RuntimeError, match="release failed"): + await agent.release( + "cleanup-again", + on_error=lambda ctx: error_calls.append(dict(ctx)), + ) + + assert len(error_calls) == 1 + assert error_calls[0]["name"] == "ReleaseWorker" + assert isinstance(error_calls[0]["error"], RuntimeError) + + +@pytest.mark.asyncio +async def test_release_lifecycle_hooks_support_async_callbacks(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + + agent = await relay.spawn("ReleaseAsyncWorker", "claude") + + success_done = False + + async def on_success(_ctx): + nonlocal success_done + await asyncio.sleep(0) + success_done = True + + await agent.release("cleanup", on_success=on_success) + + assert success_done is True + + +@pytest.mark.asyncio +async def test_release_does_not_fire_hooks_if_broker_startup_fails(): + relay = AgentRelay() + client = _FakeRelayClient() + relay._ensure_started = AsyncMock(return_value=client) + agent = await relay.spawn("ReleaseStartupFailWorker", "claude") + + relay._ensure_started = AsyncMock(side_effect=RuntimeError("broker startup failed")) + + start_called = False + error_called = False + + def mark_start(_ctx): + nonlocal start_called + start_called = True + + def mark_error(_ctx): + nonlocal error_called + error_called = True + + with pytest.raises(RuntimeError, match="broker startup failed"): + await agent.release("cleanup", on_start=mark_start, on_error=mark_error) + + assert start_called is False + assert error_called is False diff --git a/packages/sdk/README.md b/packages/sdk/README.md index 295c6f0f7..4f263f7ec 100644 --- a/packages/sdk/README.md +++ b/packages/sdk/README.md @@ -40,7 +40,14 @@ relay.onMessageReceived = (msg) => console.log(`${msg.from}: ${msg.text}`); relay.onAgentIdle = ({ name, idleSecs }) => console.log(`${name} idle for ${idleSecs}s`); // Spawn agents using shorthand spawners -const worker = await relay.claude.spawn({ name: 'Worker1', channels: ['general'] }); +const worker = await relay.claude.spawn({ + name: 'Worker1', + channels: ['general'], + // Lifecycle hooks can be sync or async functions. + onStart: ({ name }) => console.log(`spawning ${name}`), + onSuccess: ({ name, runtime }) => console.log(`spawned ${name} (${runtime})`), + onError: ({ name, error }) => console.error(`failed to spawn ${name}`, error), +}); // Or use the generic spawn method const agent = await relay.spawn('Worker2', 'codex', 'Build the API', { @@ -51,6 +58,13 @@ const agent = await relay.spawn('Worker2', 'codex', 'Build the API', { // Wait for agent to finish (go idle or exit) const result = await agent.waitForIdle(120_000); +// Release with lifecycle hooks +await worker.release({ + reason: 'done', + onStart: ({ name }) => console.log(`releasing ${name}`), + onSuccess: ({ name }) => console.log(`released ${name}`), +}); + // Send messages const human = relay.human({ name: 'Orchestrator' }); await human.sendMessage({ to: 'Worker1', text: 'Start the task' }); diff --git a/packages/sdk/src/__tests__/orchestration-upgrades.test.ts b/packages/sdk/src/__tests__/orchestration-upgrades.test.ts index 4f2dab713..819fd44a8 100644 --- a/packages/sdk/src/__tests__/orchestration-upgrades.test.ts +++ b/packages/sdk/src/__tests__/orchestration-upgrades.test.ts @@ -399,6 +399,112 @@ describe('AgentRelay orchestration handles', () => { } }); + it('spawn lifecycle hooks fire for success', async () => { + const { client } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + + const relay = new AgentRelay(); + const callOrder: string[] = []; + const onStart = vi.fn(() => callOrder.push('start')); + const onSuccess = vi.fn(() => callOrder.push('success')); + const onError = vi.fn(() => callOrder.push('error')); + + try { + const agent = await relay.spawn('hook-agent', 'claude', 'do work', { + channels: ['general'], + onStart, + onSuccess, + onError, + }); + + expect(agent.name).toBe('hook-agent'); + expect(onStart).toHaveBeenCalledWith({ + name: 'hook-agent', + cli: 'claude', + channels: ['general'], + task: 'do work', + }); + expect(onSuccess).toHaveBeenCalledWith({ + name: 'hook-agent', + cli: 'claude', + channels: ['general'], + task: 'do work', + runtime: 'pty', + }); + expect(onError).not.toHaveBeenCalled(); + expect(callOrder).toEqual(['start', 'success']); + } finally { + await relay.shutdown(); + } + }); + + it('spawn lifecycle hooks await async callbacks', async () => { + const { client } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + + const relay = new AgentRelay(); + let startDone = false; + let successDone = false; + + try { + await relay.spawn('async-hook-agent', 'claude', 'do work', { + channels: ['general'], + onStart: async () => { + await new Promise((resolve) => setTimeout(resolve, 5)); + startDone = true; + }, + onSuccess: async () => { + await new Promise((resolve) => setTimeout(resolve, 5)); + successDone = true; + }, + }); + + expect(startDone).toBe(true); + expect(successDone).toBe(true); + } finally { + await relay.shutdown(); + } + }); + + it('spawn lifecycle hooks fire on error', async () => { + const { client, mock } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + mock.spawnPty.mockRejectedValueOnce(new Error('spawn failed')); + + const relay = new AgentRelay(); + const onStart = vi.fn(); + const onError = vi.fn(); + + try { + await expect( + relay.spawnPty({ + name: 'hook-agent-fail', + cli: 'claude', + channels: ['general'], + onStart, + onError, + }) + ).rejects.toThrow('spawn failed'); + + expect(onStart).toHaveBeenCalledWith({ + name: 'hook-agent-fail', + cli: 'claude', + channels: ['general'], + task: undefined, + }); + expect(onError).toHaveBeenCalledTimes(1); + expect(onError.mock.calls[0][0]).toMatchObject({ + name: 'hook-agent-fail', + cli: 'claude', + channels: ['general'], + }); + expect(onError.mock.calls[0][0].error).toBeInstanceOf(Error); + expect((onError.mock.calls[0][0].error as Error).message).toBe('spawn failed'); + } finally { + await relay.shutdown(); + } + }); + it('agent.release passes reason to the broker client', async () => { const { client, mock } = createMockFacadeClient(); vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); @@ -420,6 +526,146 @@ describe('AgentRelay orchestration handles', () => { } }); + it('agent.release lifecycle hooks fire for success', async () => { + const { client, mock } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + + const relay = new AgentRelay(); + const callOrder: string[] = []; + const onStart = vi.fn(() => callOrder.push('start')); + const onSuccess = vi.fn(() => callOrder.push('success')); + const onError = vi.fn(() => callOrder.push('error')); + + try { + const agent = await relay.spawnPty({ + name: 'release-hook-agent', + cli: 'claude', + channels: ['general'], + }); + + await agent.release({ + reason: 'cleanup', + onStart, + onSuccess, + onError, + }); + + expect(mock.release).toHaveBeenCalledWith('release-hook-agent', 'cleanup'); + expect(onStart).toHaveBeenCalledWith({ + name: 'release-hook-agent', + reason: 'cleanup', + }); + expect(onSuccess).toHaveBeenCalledWith({ + name: 'release-hook-agent', + reason: 'cleanup', + }); + expect(onError).not.toHaveBeenCalled(); + expect(callOrder).toEqual(['start', 'success']); + } finally { + await relay.shutdown(); + } + }); + + it('agent.release lifecycle hooks fire on error', async () => { + const { client, mock } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + mock.release.mockRejectedValueOnce(new Error('release failed')); + + const relay = new AgentRelay(); + const onStart = vi.fn(); + const onError = vi.fn(); + + try { + const agent = await relay.spawnPty({ + name: 'release-hook-fail', + cli: 'claude', + channels: ['general'], + }); + + await expect( + agent.release({ + reason: 'cleanup', + onStart, + onError, + }) + ).rejects.toThrow('release failed'); + + expect(onStart).toHaveBeenCalledWith({ + name: 'release-hook-fail', + reason: 'cleanup', + }); + expect(onError).toHaveBeenCalledTimes(1); + expect(onError.mock.calls[0][0]).toMatchObject({ + name: 'release-hook-fail', + reason: 'cleanup', + }); + expect(onError.mock.calls[0][0].error).toBeInstanceOf(Error); + expect((onError.mock.calls[0][0].error as Error).message).toBe('release failed'); + } finally { + await relay.shutdown(); + } + }); + + it('agent.release lifecycle hooks await async callbacks', async () => { + const { client } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + + const relay = new AgentRelay(); + let successDone = false; + + try { + const agent = await relay.spawnPty({ + name: 'release-async-hook-agent', + cli: 'claude', + channels: ['general'], + }); + + await agent.release({ + reason: 'cleanup', + onSuccess: async () => { + await new Promise((resolve) => setTimeout(resolve, 5)); + successDone = true; + }, + }); + + expect(successDone).toBe(true); + } finally { + await relay.shutdown(); + } + }); + + it('agent.release does not fire lifecycle hooks if broker startup fails before release begins', async () => { + const { client } = createMockFacadeClient(); + vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); + + const relay = new AgentRelay(); + const onStart = vi.fn(); + const onError = vi.fn(); + + try { + const agent = await relay.spawnPty({ + name: 'release-startup-fail-agent', + cli: 'claude', + channels: ['general'], + }); + + vi.spyOn(relay as any, 'ensureStarted').mockRejectedValueOnce(new Error('startup failed')); + + await expect( + agent.release({ + reason: 'cleanup', + onStart, + onError, + }) + ).rejects.toThrow('startup failed'); + + expect(onStart).not.toHaveBeenCalled(); + expect(onError).not.toHaveBeenCalled(); + } finally { + await relay.shutdown(); + } + }); + it('system() sends messages from the system identity', async () => { const { client, mock } = createMockFacadeClient(); vi.spyOn(AgentRelayClient, 'start').mockResolvedValue(client); diff --git a/packages/sdk/src/relay.ts b/packages/sdk/src/relay.ts index e1461ec03..865d5cdb5 100644 --- a/packages/sdk/src/relay.ts +++ b/packages/sdk/src/relay.ts @@ -87,7 +87,47 @@ export interface DeliveryState { updatedAt: number; } -export interface SpawnOptions { +export interface SpawnLifecycleContext { + name: string; + cli: string; + channels: string[]; + task?: string; +} + +export interface SpawnLifecycleSuccessContext extends SpawnLifecycleContext { + runtime: AgentRuntime; +} + +export interface SpawnLifecycleErrorContext extends SpawnLifecycleContext { + error: unknown; +} + +export interface SpawnLifecycleHooks { + onStart?: (context: SpawnLifecycleContext) => void | Promise; + onSuccess?: (context: SpawnLifecycleSuccessContext) => void | Promise; + onError?: (context: SpawnLifecycleErrorContext) => void | Promise; +} + +export interface ReleaseLifecycleContext { + name: string; + reason?: string; +} + +export interface ReleaseLifecycleErrorContext extends ReleaseLifecycleContext { + error: unknown; +} + +export interface ReleaseLifecycleHooks { + onStart?: (context: ReleaseLifecycleContext) => void | Promise; + onSuccess?: (context: ReleaseLifecycleContext) => void | Promise; + onError?: (context: ReleaseLifecycleErrorContext) => void | Promise; +} + +export interface ReleaseOptions extends ReleaseLifecycleHooks { + reason?: string; +} + +export interface SpawnOptions extends SpawnLifecycleHooks { args?: string[]; channels?: string[]; model?: string; @@ -119,7 +159,7 @@ export interface Agent { exitSignal?: string; /** Set when the agent requests exit via /exit. Available after `onAgentExitRequested` fires. */ exitReason?: string; - release(reason?: string): Promise; + release(reasonOrOptions?: string | ReleaseOptions): Promise; waitForReady(timeoutMs?: number): Promise; /** Wait for the agent process to exit on its own. * @param timeoutMs — optional timeout in ms. Resolves with `"timeout"` if exceeded, @@ -152,14 +192,16 @@ export interface HumanHandle { } export interface AgentSpawner { - spawn(options?: { - name?: string; - args?: string[]; - channels?: string[]; - task?: string; - model?: string; - cwd?: string; - }): Promise; + spawn(options?: SpawnerSpawnOptions): Promise; +} + +export interface SpawnerSpawnOptions extends SpawnLifecycleHooks { + name?: string; + args?: string[]; + channels?: string[]; + task?: string; + model?: string; + cwd?: string; } export type EventHook = ((value: T) => void) | null; @@ -296,7 +338,7 @@ export class AgentRelay { // ── Spawning ──────────────────────────────────────────────────────────── - async spawnPty(input: SpawnPtyInput): Promise { + async spawnPty(input: SpawnPtyInput & SpawnLifecycleHooks): Promise { const client = await this.ensureStarted(); if (!input.channels || input.channels.length === 0) { console.warn( @@ -305,26 +347,52 @@ export class AgentRelay { ); } const channels = input.channels ?? ['general']; - const result = await client.spawnPty({ + const lifecycleContext: SpawnLifecycleContext = { name: input.name, cli: input.cli, - args: input.args, channels, task: input.task, - model: input.model, - cwd: input.cwd, - team: input.team, - shadowOf: input.shadowOf, - shadowMode: input.shadowMode, - idleThresholdSecs: input.idleThresholdSecs, - restartPolicy: input.restartPolicy, - }); - this.readyAgents.delete(result.name); - this.messageReadyAgents.delete(result.name); - this.exitedAgents.delete(result.name); - this.idleAgents.delete(result.name); + }; + await this.invokeLifecycleHook(input.onStart, lifecycleContext, `spawnPty("${input.name}") onStart`); + let result: { name: string; runtime: AgentRuntime }; + try { + result = await client.spawnPty({ + name: input.name, + cli: input.cli, + args: input.args, + channels, + task: input.task, + model: input.model, + cwd: input.cwd, + team: input.team, + shadowOf: input.shadowOf, + shadowMode: input.shadowMode, + idleThresholdSecs: input.idleThresholdSecs, + restartPolicy: input.restartPolicy, + }); + } catch (error) { + await this.invokeLifecycleHook( + input.onError, + { + ...lifecycleContext, + error, + }, + `spawnPty("${input.name}") onError` + ); + throw error; + } + this.resetAgentLifecycleState(result.name); const agent = this.makeAgent(result.name, result.runtime, channels); this.knownAgents.set(agent.name, agent); + await this.invokeLifecycleHook( + input.onSuccess, + { + ...lifecycleContext, + name: result.name, + runtime: result.runtime, + }, + `spawnPty("${input.name}") onSuccess` + ); return agent; } @@ -342,6 +410,9 @@ export class AgentRelay { shadowMode: options?.shadowMode, idleThresholdSecs: options?.idleThresholdSecs, restartPolicy: options?.restartPolicy, + onStart: options?.onStart, + onSuccess: options?.onSuccess, + onError: options?.onError, }); } @@ -990,9 +1061,32 @@ export class AgentRelay { }, exitCode: undefined, exitSignal: undefined, - async release(reason?: string) { + async release(reasonOrOptions?: string | ReleaseOptions) { + const releaseOptions = relay.normalizeReleaseOptions(reasonOrOptions); + const releaseContext: ReleaseLifecycleContext = { + name, + reason: releaseOptions.reason, + }; const client = await relay.ensureStarted(); - await client.release(name, reason); + await relay.invokeLifecycleHook(releaseOptions.onStart, releaseContext, `release("${name}") onStart`); + try { + await client.release(name, releaseOptions.reason); + await relay.invokeLifecycleHook( + releaseOptions.onSuccess, + releaseContext, + `release("${name}") onSuccess` + ); + } catch (error) { + await relay.invokeLifecycleHook( + releaseOptions.onError, + { + ...releaseContext, + error, + }, + `release("${name}") onError` + ); + throw error; + } }, async waitForReady(timeoutMs = 60_000) { await relay.waitForAgentReady(name, timeoutMs); @@ -1117,38 +1211,99 @@ export class AgentRelay { private createSpawner(cli: string, defaultName: string, runtime: AgentRuntime): AgentSpawner { return { spawn: async (options?) => { - const client = await this.ensureStarted(); const name = options?.name ?? defaultName; const channels = options?.channels ?? ['general']; const args = options?.args ?? []; const task = options?.task; - let result: { name: string; runtime: AgentRuntime }; - if (runtime === 'headless') { - result = await client.spawnProvider({ + if (runtime === 'pty') { + return this.spawnPty({ name, - provider: cli as HeadlessProvider, - transport: 'headless', + cli, args, channels, task, + model: options?.model, + cwd: options?.cwd, + onStart: options?.onStart, + onSuccess: options?.onSuccess, + onError: options?.onError, }); - } else { - result = await client.spawnPty({ + } + + const client = await this.ensureStarted(); + const lifecycleContext: SpawnLifecycleContext = { + name, + cli, + channels, + task, + }; + await this.invokeLifecycleHook(options?.onStart, lifecycleContext, `spawn("${name}") onStart`); + let result: { name: string; runtime: AgentRuntime }; + try { + result = await client.spawnProvider({ name, - cli, + provider: cli as HeadlessProvider, + transport: 'headless', args, channels, task, - model: options?.model, - cwd: options?.cwd, }); + } catch (error) { + await this.invokeLifecycleHook( + options?.onError, + { + ...lifecycleContext, + error, + }, + `spawn("${name}") onError` + ); + throw error; } + this.resetAgentLifecycleState(result.name); const agent = this.makeAgent(result.name, result.runtime, channels); this.knownAgents.set(agent.name, agent); + await this.invokeLifecycleHook( + options?.onSuccess, + { + ...lifecycleContext, + name: result.name, + runtime: result.runtime, + }, + `spawn("${name}") onSuccess` + ); return agent; }, }; } + + private async invokeLifecycleHook( + hook: ((context: T) => void | Promise) | undefined, + context: T, + label: string + ): Promise { + if (!hook) { + return; + } + try { + await hook(context); + } catch (error) { + console.warn(`[AgentRelay] ${label} hook threw`, error); + } + } + + private resetAgentLifecycleState(name: string): void { + this.readyAgents.delete(name); + this.messageReadyAgents.delete(name); + this.exitedAgents.delete(name); + this.idleAgents.delete(name); + } + + private normalizeReleaseOptions(reasonOrOptions?: string | ReleaseOptions): ReleaseOptions { + if (typeof reasonOrOptions === 'string' || reasonOrOptions === undefined) { + return { reason: reasonOrOptions }; + } + return reasonOrOptions; + } }