From a17ffb382aa1d777e5e98fdc26c6cc9efd4898c0 Mon Sep 17 00:00:00 2001
From: Brian O'Kelley
Date: Thu, 23 Apr 2026 07:03:37 -0400
Subject: [PATCH] feat(a2a)!: checkpoint/from_checkpoint API, harden context-id retention
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Follow-up to PR 251 addressing a post-merge audit. Highlights:

- **Ordering fix**: `_context_id` and `_active_task_id` now commit only after `_process_task_response` and the idempotency-error check both succeed. Previously, a raise from either would leave the adapter advanced to reflect a response the caller never received, potentially orphaning the in-flight task on retry.
- **Rename**: `pending_task_id` → `active_task_id`. The name now matches the semantics ("server-side task the next call must echo to resume").
- **Checkpoint API**: `ADCPClient.checkpoint()` returns a typed `Checkpoint` (TypedDict with `agent_id`, `context_id`, `active_task_id`). `ADCPClient.from_checkpoint(config, state)` rehydrates both ids. Fixes the advertised Redis-resume story for HITL flows — persisting only `context_id` would orphan the pending task. Restore validates `agent_id` against the target config so a checkpoint minted for Agent A can't leak session tokens to Agent B.
- **Enum coupling**: `_NONTERMINAL_TASK_STATES` is now `frozenset[TaskState]`, so an upstream rename in a2a-sdk becomes a type error instead of a silent retention regression.
- **`unknown` TaskState**: now handled explicitly — clears `active_task_id` and logs a warning so operators notice if a server starts emitting it.
- **Empty-string guard**: `context_id=""` from `os.getenv(...) or ""` patterns is now treated as "not provided" instead of being echoed on the wire.
- **Coverage gaps**: added unit tests for `submitted`/`auth-required` retention, `canceled`/`rejected`/`unknown` clearing, ordering invariant under both raised and converted-to-FAILED exception paths, and checkpoint mis-restore rejection.

BREAKING CHANGE: `ADCPClient.pending_task_id` is now `ADCPClient.active_task_id` (same for `A2AAdapter`). The constructor's `context_id=` kwarg and `reset_context()` now raise `TypeError` (was `ValueError`) on non-A2A protocols — the string value is fine, the operation doesn't apply to MCP. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/adcp/client.py | 142 ++++++++-- src/adcp/protocols/a2a.py | 96 +++++-- tests/integration/test_a2a_context_id.py | 8 +- tests/test_protocols.py | 322 +++++++++++++++++++++-- 4 files changed, 507 insertions(+), 61 deletions(-) diff --git a/src/adcp/client.py b/src/adcp/client.py index 2dea9f78..ee8c93fe 100644 --- a/src/adcp/client.py +++ b/src/adcp/client.py @@ -11,7 +11,7 @@ import time from collections.abc import Callable, Iterator from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict from a2a.types import Task, TaskStatusUpdateEvent from pydantic import BaseModel @@ -296,6 +296,27 @@ logger = logging.getLogger(__name__) +class Checkpoint(TypedDict): + """Persistable session-resume state for an A2A ``ADCPClient``. + + The minimal set of fields needed to reconnect to an in-flight A2A + conversation after a process restart. Produced by + ``ADCPClient.checkpoint()``; consumed by + ``ADCPClient.from_checkpoint()``. + + - ``agent_id`` — binds the checkpoint to the agent that minted it, + so a restore against the wrong ``AgentConfig`` fails loudly + instead of sending Agent A's ids to Agent B. + - ``context_id`` — the A2A conversation id. + - ``active_task_id`` — the in-flight task the next message must + echo; ``None`` if no task is pending. + """ + + agent_id: str + context_id: str | None + active_task_id: str | None + + class ADCPClient: """Client for interacting with a single AdCP agent.""" @@ -373,7 +394,12 @@ def __init__( multiple concurrent briefs with the same agent, construct one client per brief rather than sharing. 
- Raises ``ValueError`` if passed with a non-A2A protocol. + For HITL flows that can span a process restart mid-task, + use ``checkpoint()`` / ``from_checkpoint()`` instead of + persisting ``context_id`` alone — full resume state is + both ``context_id`` AND ``active_task_id``. + + Raises ``TypeError`` if passed with a non-A2A protocol. """ self.agent_config = agent_config self.webhook_url_template = webhook_url_template @@ -414,10 +440,14 @@ def __init__( # in dev/test, warn in production — see ``ValidationHookConfig`` docs). self.adapter.configure_validation(validation) - if context_id is not None: + if context_id: + # Empty string is treated as "not provided" — callers using + # ``context_id=os.getenv("...") or ""`` patterns shouldn't + # silently seed an empty id on the wire. if not isinstance(self.adapter, A2AAdapter): - raise ValueError( - "context_id is only supported for A2A protocol; " f"got {agent_config.protocol}" + raise TypeError( + f"context_id is only supported for A2A protocol; " + f"got {agent_config.protocol}" ) self.adapter.set_context_id(context_id) @@ -442,17 +472,21 @@ def context_id(self) -> str | None: Not safe for concurrent calls on the same client — the adapter mutates this on every response. Rule of thumb: one ADCPClient - per A2A conversation. Persist this value (e.g., Redis keyed by - your brief id) to resume across process restarts by passing it - to ``ADCPClient(context_id=...)``. + per A2A conversation. + + For simple completed-task resume, persist this value and pass + it to ``ADCPClient(context_id=...)``. For HITL flows that may + restart mid-``input-required``, use ``checkpoint()`` / + ``from_checkpoint()`` — full resume state is both this id AND + ``active_task_id``. """ if isinstance(self.adapter, A2AAdapter): return self.adapter.context_id return None @property - def pending_task_id(self) -> str | None: - """A2A task_id pending resume, or None if no task is in-flight. 
+ def active_task_id(self) -> str | None: + """A2A task_id the next send must echo to resume the same task. Set when the last A2A response was non-terminal (``input-required``, ``working``, ``submitted``, @@ -460,10 +494,14 @@ def pending_task_id(self) -> str | None: outbound message so the server resumes the same task. Clears automatically when the task reaches a terminal state. + Full resume state is *both* ``context_id`` and + ``active_task_id`` — persist both (or use ``checkpoint()``) to + survive a process restart mid-HITL without orphaning the task. + Returns ``None`` for non-A2A clients. """ if isinstance(self.adapter, A2AAdapter): - return self.adapter.pending_task_id + return self.adapter.active_task_id return None def reset_context(self, context_id: str | None = None) -> None: @@ -477,18 +515,92 @@ def reset_context(self, context_id: str | None = None) -> None: client-supplied ids into their own session format; the client auto-adopts the rewritten id on the next response. - Also clears any pending_task_id — starting a new conversation + Also clears any active_task_id — starting a new conversation discards any in-flight task on the old one. - Raises ``ValueError`` when called on a non-A2A client. + Raises ``TypeError`` when called on a non-A2A client. """ if not isinstance(self.adapter, A2AAdapter): - raise ValueError( - "reset_context is only supported for A2A protocol; " + raise TypeError( + f"reset_context is only supported for A2A protocol; " f"got {self.agent_config.protocol}" ) self.adapter.set_context_id(context_id) + def checkpoint(self) -> Checkpoint: + """Return the minimal state needed to resume this A2A session. + + Full resume for HITL / multi-turn flows requires *both* + ``context_id`` (which conversation) AND ``active_task_id`` + (which in-flight task to echo). 
Persisting only ``context_id`` + reconnects to the right conversation but orphans the pending + task server-side — the next send starts a new task under the + same context, and the original ``input-required`` task is + abandoned. + + The returned dict also carries ``agent_id`` so a later + ``from_checkpoint`` call against a different ``AgentConfig`` + fails loudly instead of sending one agent's session ids to + another. + + Pair with ``ADCPClient.from_checkpoint(agent_config, state)``. + + Returns a fully-populated ``Checkpoint`` on non-A2A clients + with ``context_id``/``active_task_id`` set to ``None``, so + generic persist-and-restore code can call this without + branching on protocol. + """ + return Checkpoint( + agent_id=self.agent_config.id, + context_id=self.context_id, + active_task_id=self.active_task_id, + ) + + @classmethod + def from_checkpoint( + cls, + agent_config: AgentConfig, + state: Checkpoint, + **kwargs: Any, + ) -> ADCPClient: + """Rehydrate an ADCPClient from a prior ``checkpoint()``. + + Restores both ``context_id`` and ``active_task_id`` so a process + restart mid-``input-required`` can resume the same task, not + orphan it. Accepts the same keyword arguments as ``__init__`` + (signing, strict_idempotency, etc.) — the checkpoint only + carries session-resume state; operational config is re-supplied + by the caller. + + Raises ``ValueError`` if the checkpoint's ``agent_id`` doesn't + match ``agent_config.id`` — a checkpoint minted for Agent A + must not be restored onto Agent B, or the client will leak + Agent A's opaque session ids to Agent B on the next message. + + Raises ``TypeError`` on a non-A2A ``agent_config`` if the + checkpoint carries a non-empty ``context_id`` or + ``active_task_id`` — session-resume state on a protocol that + doesn't support it would be silently dropped, masking bugs. + An empty/absent checkpoint round-trips cleanly on any protocol. 
+ """ + saved_agent_id = state.get("agent_id") if state else None + if saved_agent_id and saved_agent_id != agent_config.id: + raise ValueError( + f"checkpoint was minted for agent {saved_agent_id!r}, " + f"cannot restore against {agent_config.id!r}" + ) + context_id = state.get("context_id") if state else None + active_task_id = state.get("active_task_id") if state else None + if active_task_id and agent_config.protocol != Protocol.A2A: + raise TypeError( + f"active_task_id in checkpoint is only supported for A2A " + f"protocol; got {agent_config.protocol}" + ) + client = cls(agent_config, context_id=context_id, **kwargs) + if active_task_id and isinstance(client.adapter, A2AAdapter): + client.adapter._restore_active_task_id(active_task_id) + return client + async def _ensure_idempotency_capability(self) -> None: """Verify the seller positively declares idempotency support in capabilities. diff --git a/src/adcp/protocols/a2a.py b/src/adcp/protocols/a2a.py index 72aea1c3..3fdcbbb5 100644 --- a/src/adcp/protocols/a2a.py +++ b/src/adcp/protocols/a2a.py @@ -17,6 +17,7 @@ Role, SendMessageRequest, Task, + TaskState, TextPart, ) @@ -48,11 +49,18 @@ class A2AAdapter(ProtocolAdapter): # in-flight states). While the adapter holds a task_id in one of # these states, the next outbound Message must echo it back so the # server resumes the same task rather than orphaning it and starting - # a new one. Terminal states (completed/failed/canceled/rejected) - # clear the retained task_id — subsequent calls in the conversation - # are new tasks. - _NONTERMINAL_TASK_STATES = frozenset( - {"submitted", "working", "input-required", "auth-required"} + # a new one. Everything else — completed/failed/canceled/rejected + # (terminal) and the defensive unknown state — clears the retained + # task_id so subsequent calls start a fresh task. Coupled directly + # to the TaskState enum so a rename upstream is a type error, not a + # silent behavior change. 
+ _NONTERMINAL_TASK_STATES: frozenset[TaskState] = frozenset( + { + TaskState.submitted, + TaskState.working, + TaskState.input_required, + TaskState.auth_required, + } ) def __init__(self, agent_config: AgentConfig): @@ -72,8 +80,8 @@ def __init__(self, agent_config: AgentConfig): # non-terminal (input-required, working, etc). On terminal states # this clears to None so the next call starts a new task under # the same context_id. Without this, resume of an input-required - # task orphans the server-side pending task. - self._pending_task_id: str | None = None + # task orphans the server-side in-flight task. + self._active_task_id: str | None = None @property def context_id(self) -> str | None: @@ -91,15 +99,17 @@ def context_id(self) -> str | None: return self._context_id @property - def pending_task_id(self) -> str | None: - """A2A task_id retained for resume, or None if no task is pending. + def active_task_id(self) -> str | None: + """A2A task_id the next send must echo to resume the same task. Populated when the last response was non-terminal (e.g. - ``input-required``). Echoed on the next outbound message so the - server continues the same task. Clears to None on terminal - states (``completed``/``failed``/``canceled``). + ``input-required``, ``working``). Echoed on the next outbound + message so the server continues the same task. Clears to None + on terminal states (``completed``/``failed``/``canceled``/ + ``rejected``) — and defensively on ``unknown`` — so subsequent + calls start a fresh task under the same context. """ - return self._pending_task_id + return self._active_task_id def set_context_id(self, context_id: str | None) -> None: """Set the A2A context_id for subsequent message sends. @@ -113,11 +123,21 @@ def set_context_id(self, context_id: str | None) -> None: format and return the rewritten value on the next response — at which point this adapter auto-adopts it. 
- Also clears any retained ``pending_task_id``: switching context + Also clears any retained ``active_task_id``: switching context always starts a fresh task under the new context. """ self._context_id = context_id - self._pending_task_id = None + self._active_task_id = None + + def _restore_active_task_id(self, task_id: str) -> None: + """Internal: rehydrate ``active_task_id`` from a persisted checkpoint. + + Separate from normal in-flight state updates so the checkpoint + restore path is an explicit contract — a rename of the storage + field fails loudly here instead of silently breaking resume. + Intended for ``ADCPClient.from_checkpoint`` only. + """ + self._active_task_id = task_id async def _get_httpx_client(self) -> httpx.AsyncClient: """Get or create the HTTP client with connection pooling.""" @@ -279,7 +299,7 @@ async def _call_a2a_tool( role=Role.user, parts=[Part(root=data_part)], context_id=self._context_id, - task_id=self._pending_task_id, + task_id=self._active_task_id, ) else: # Natural language invocation (flexible) @@ -290,7 +310,7 @@ async def _call_a2a_tool( role=Role.user, parts=[Part(root=text_part)], context_id=self._context_id, - task_id=self._pending_task_id, + task_id=self._active_task_id, ) # Build request params @@ -358,26 +378,48 @@ async def _call_a2a_tool( # Result can be either Task or Message if isinstance(result, Task): - # Retain the server-assigned context_id so subsequent - # turns continue the same A2A conversation. Task.context_id - # is required by a2a-sdk, so no None-guard needed. - self._context_id = result.context_id - # Retain task_id only while the task is non-terminal. - # On terminal states (completed/failed/canceled/rejected) - # the next send must NOT echo this task_id — it starts a - # fresh task under the same context. 
+ # Compute next-turn state from the response but do NOT + # commit yet — _process_task_response and the idempotency + # check below can raise, and leaving the adapter advanced + # after an exception would orphan the legitimate in-flight + # task on the next retry. Commit only after both succeed. + # Task.context_id is required by a2a-sdk, so no None-guard. + next_context_id = result.context_id if result.status.state in self._NONTERMINAL_TASK_STATES: - self._pending_task_id = result.id + next_active_task_id: str | None = result.id else: - self._pending_task_id = None + # Terminal states (completed/failed/canceled/rejected) + # clear the retained task_id — subsequent calls start + # a new task under the same context. The defensive + # unknown state falls here too (don't cling to an + # undefined task); warn so operators notice if a + # server starts emitting it. + next_active_task_id = None + if result.status.state == TaskState.unknown: + logger.warning( + "A2A agent %s returned TaskState.unknown for " + "task_id=%s; clearing active_task_id and " + "starting a fresh task on next call", + self.agent_config.id, + result.id, + ) task_result = self._process_task_response(result, debug_info) _idempotency.raise_for_idempotency_error( tool_name, task_result.data, self.agent_config.id ) + # All raise-sites have passed; commit next-turn state so + # the adapter reflects the response the caller is about + # to receive. + self._context_id = next_context_id + self._active_task_id = next_active_task_id # Post-receive schema validation. Only runs when the task # carries data (terminal completion); async interim states # with ``data=None`` skip naturally. Strict mode flips the # TaskResult to FAILED; warn mode logs and passes through. + # Runs after the state commit — a payload-schema failure + # doesn't invalidate the A2A envelope ids, and the next + # call in the same conversation should still target the + # right session. 
if task_result.success and task_result.data is not None: response_outcome = validate_incoming_response( tool_name, task_result.data, self.response_validation_mode diff --git a/tests/integration/test_a2a_context_id.py b/tests/integration/test_a2a_context_id.py index 957bafbd..d0db46de 100644 --- a/tests/integration/test_a2a_context_id.py +++ b/tests/integration/test_a2a_context_id.py @@ -370,13 +370,13 @@ async def test_task_id_echoed_on_resume_after_input_required(): r1 = await client.adapter.create_media_buy({"budget": 1000}) # After an input-required response the adapter stashed both ids. assert client.context_id is not None - assert client.pending_task_id is not None - retained_task_id = client.pending_task_id + assert client.active_task_id is not None + retained_task_id = client.active_task_id retained_context_id = client.context_id r2 = await client.adapter.create_media_buy({"approval": "yes"}) - # Terminal state on turn 2 cleared pending_task_id; context stays. - assert client.pending_task_id is None + # Terminal state on turn 2 cleared active_task_id; context stays. + assert client.active_task_id is None assert client.context_id == retained_context_id assert len(executor.observations) == 2 diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 4e5d7a4a..aae502d9 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -583,14 +583,14 @@ async def test_task_id_retained_when_state_is_input_required(self, a2a_config): with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): await adapter._call_a2a_tool("create_media_buy", {}) - assert adapter.pending_task_id == "task-hitl-1" + assert adapter.active_task_id == "task-hitl-1" await adapter._call_a2a_tool("create_media_buy", {"approval": "yes"}) assert self._captured_task_id(mock_a2a_client.send_message, 0) is None assert self._captured_task_id(mock_a2a_client.send_message, 1) == "task-hitl-1" # Terminal state clears the pending task. 
- assert adapter.pending_task_id is None + assert adapter.active_task_id is None @pytest.mark.asyncio async def test_task_id_cleared_on_completed_state(self, a2a_config): @@ -618,7 +618,7 @@ async def test_task_id_cleared_on_completed_state(self, a2a_config): with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): await adapter._call_a2a_tool("get_products", {}) - assert adapter.pending_task_id is None + assert adapter.active_task_id is None await adapter._call_a2a_tool("create_media_buy", {}) @@ -645,7 +645,7 @@ async def test_task_id_cleared_on_failed_state(self, a2a_config): with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): await adapter._call_a2a_tool("get_products", {}) - assert adapter.pending_task_id is None + assert adapter.active_task_id is None @pytest.mark.asyncio async def test_task_id_retained_on_working_state(self, a2a_config): @@ -667,7 +667,7 @@ async def test_task_id_retained_on_working_state(self, a2a_config): with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): await adapter._call_a2a_tool("create_media_buy", {}) - assert adapter.pending_task_id == "task-in-progress" + assert adapter.active_task_id == "task-in-progress" @pytest.mark.asyncio async def test_set_context_id_clears_pending_task(self, a2a_config): @@ -675,12 +675,12 @@ async def test_set_context_id_clears_pending_task(self, a2a_config): conversation shouldn't try to resume a task from the old one.""" adapter = A2AAdapter(a2a_config) adapter._context_id = "old-ctx" - adapter._pending_task_id = "old-task" + adapter._active_task_id = "old-task" adapter.set_context_id("new-ctx") assert adapter.context_id == "new-ctx" - assert adapter.pending_task_id is None + assert adapter.active_task_id is None @pytest.mark.asyncio async def test_server_rebinding_context_id_is_honored(self, a2a_config): @@ -702,6 +702,185 @@ async def test_server_rebinding_context_id_is_honored(self, a2a_config): assert 
self._captured_context_id(mock_a2a_client.send_message) == "buyer-proposed" assert adapter.context_id == "server-overrode" + @pytest.mark.asyncio + async def test_task_id_retained_on_submitted_state(self, a2a_config): + """'submitted' is non-terminal — server has accepted the task but + not started processing. Adapter must retain task_id so the next + call lands on the same queued task instead of stacking a duplicate. + """ + adapter = A2AAdapter(a2a_config) + + submitted = create_mock_a2a_task( + task_id="task-queued", + context_id="ctx", + state="submitted", + parts=[TextPart(text="accepted")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=submitted) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.active_task_id == "task-queued" + + @pytest.mark.asyncio + async def test_task_id_retained_on_auth_required_state(self, a2a_config): + """'auth-required' is non-terminal — server is blocked pending + buyer-side auth. 
Adapter must retain task_id so the resubmit with + credentials lands on the same task.""" + adapter = A2AAdapter(a2a_config) + + auth_required = create_mock_a2a_task( + task_id="task-needs-auth", + context_id="ctx", + state="auth-required", + parts=[TextPart(text="authenticate and retry")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=auth_required) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.active_task_id == "task-needs-auth" + + @pytest.mark.asyncio + async def test_task_id_cleared_on_canceled_state(self, a2a_config): + """'canceled' is terminal — adapter must clear task_id so the + next call starts fresh instead of echoing a dead task.""" + adapter = A2AAdapter(a2a_config) + + canceled = create_mock_a2a_task( + task_id="task-canceled", + context_id="ctx", + state="canceled", + parts=[TextPart(text="canceled by buyer")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=canceled) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.active_task_id is None + + @pytest.mark.asyncio + async def test_task_id_cleared_on_rejected_state(self, a2a_config): + """'rejected' is terminal — adapter must clear task_id.""" + adapter = A2AAdapter(a2a_config) + + rejected = create_mock_a2a_task( + task_id="task-rejected", + context_id="ctx", + state="rejected", + parts=[TextPart(text="rejected by agent")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=rejected) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert 
adapter.active_task_id is None + + @pytest.mark.asyncio + async def test_task_id_cleared_on_unknown_state(self, a2a_config): + """'unknown' is treated as terminal — don't cling to a task in + an undefined state. Adapter should clear and warn.""" + adapter = A2AAdapter(a2a_config) + + unknown = create_mock_a2a_task( + task_id="task-mystery", + context_id="ctx", + state="unknown", + parts=[TextPart(text="???")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=unknown) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.active_task_id is None + + @pytest.mark.asyncio + async def test_state_not_committed_when_post_processing_raises(self, a2a_config): + """If _process_task_response raises, the adapter must NOT advance + its state — otherwise a retry echoes a task_id the caller never + saw a response for. Uses IdempotencyConflictError because it's in + the adapter's allow-list to propagate (most exceptions get caught + and converted to TaskResult(FAILED), but typed idempotency errors + bubble out). Either way the invariant is the same: pre-call state + must survive a raise from post-processing. 
+ """ + from adcp.exceptions import IdempotencyConflictError + + adapter = A2AAdapter(a2a_config) + adapter._context_id = "prior-ctx" + adapter._active_task_id = "prior-task" + + response_task = create_mock_a2a_task( + task_id="server-new-task", + context_id="server-new-ctx", + state="input-required", + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=response_task) + ) + + boom = IdempotencyConflictError("create_media_buy", errors=[]) + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + with patch.object(adapter, "_process_task_response", side_effect=boom): + with pytest.raises(IdempotencyConflictError): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.context_id == "prior-ctx" + assert adapter.active_task_id == "prior-task" + + @pytest.mark.asyncio + async def test_state_not_committed_when_exception_converts_to_failed(self, a2a_config): + """Mirror of the IdempotencyConflictError test for the generic + exception path. Most exceptions in post-processing get caught by + the broad ``except Exception`` at the end of _call_a2a_tool and + converted to ``TaskResult(FAILED)`` — the caller never sees the + exception, but the adapter still must not have advanced state, + or the next call echoes a task_id the caller never saw succeed. 
+ """ + adapter = A2AAdapter(a2a_config) + adapter._context_id = "prior-ctx" + adapter._active_task_id = "prior-task" + + response_task = create_mock_a2a_task( + task_id="server-new-task", + context_id="server-new-ctx", + state="input-required", + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=response_task) + ) + + boom = RuntimeError("post-processing blew up") + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + with patch.object(adapter, "_process_task_response", side_effect=boom): + result = await adapter._call_a2a_tool("create_media_buy", {}) + + assert result.status == TaskStatus.FAILED + assert adapter.context_id == "prior-ctx" + assert adapter.active_task_id == "prior-task" + class TestADCPClientContextId: """Tests for the ADCPClient-level contextId surface.""" @@ -737,14 +916,14 @@ def test_reset_context_with_new_id(self, a2a_config): def test_constructor_rejects_context_id_on_non_a2a(self, mcp_config): from adcp.client import ADCPClient - with pytest.raises(ValueError, match="only supported for A2A"): + with pytest.raises(TypeError, match="only supported for A2A"): ADCPClient(mcp_config, context_id="nope") def test_reset_context_rejects_on_non_a2a(self, mcp_config): from adcp.client import ADCPClient client = ADCPClient(mcp_config) - with pytest.raises(ValueError, match="only supported for A2A"): + with pytest.raises(TypeError, match="only supported for A2A"): client.reset_context("anything") def test_context_id_property_returns_none_on_non_a2a(self, mcp_config): @@ -753,20 +932,133 @@ def test_context_id_property_returns_none_on_non_a2a(self, mcp_config): client = ADCPClient(mcp_config) assert client.context_id is None - def test_pending_task_id_property_exposes_adapter_state(self, a2a_config): + def test_active_task_id_property_exposes_adapter_state(self, a2a_config): from adcp.client import ADCPClient client = ADCPClient(a2a_config) - assert 
client.pending_task_id is None + assert client.active_task_id is None assert isinstance(client.adapter, A2AAdapter) - client.adapter._pending_task_id = "task-mid-flight" - assert client.pending_task_id == "task-mid-flight" + client.adapter._active_task_id = "task-mid-flight" + assert client.active_task_id == "task-mid-flight" - def test_pending_task_id_returns_none_on_non_a2a(self, mcp_config): + def test_active_task_id_returns_none_on_non_a2a(self, mcp_config): from adcp.client import ADCPClient client = ADCPClient(mcp_config) - assert client.pending_task_id is None + assert client.active_task_id is None + + def test_empty_string_context_id_is_not_seeded(self, a2a_config): + """``context_id=""`` from ``os.getenv(...) or ""`` patterns must + not silently seed an empty id on the wire.""" + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config, context_id="") + assert client.context_id is None + + def test_empty_string_context_id_ok_on_non_a2a(self, mcp_config): + """Empty context_id should be treated as 'not provided' on any + protocol — no TypeError on MCP, same as passing None.""" + from adcp.client import ADCPClient + + client = ADCPClient(mcp_config, context_id="") + assert client.context_id is None + + def test_checkpoint_returns_all_fields(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config, context_id="ctx-123") + assert isinstance(client.adapter, A2AAdapter) + client.adapter._active_task_id = "task-in-flight" + + state = client.checkpoint() + assert state == { + "agent_id": a2a_config.id, + "context_id": "ctx-123", + "active_task_id": "task-in-flight", + } + + def test_checkpoint_on_non_a2a_carries_agent_id_and_nones(self, mcp_config): + from adcp.client import ADCPClient + + client = ADCPClient(mcp_config) + assert client.checkpoint() == { + "agent_id": mcp_config.id, + "context_id": None, + "active_task_id": None, + } + + def test_from_checkpoint_restores_both_ids(self, a2a_config): + """Full resume 
requires both ids — persisting only context_id + orphans the pending task server-side.""" + from adcp.client import ADCPClient + + state = { + "agent_id": a2a_config.id, + "context_id": "ctx-resume", + "active_task_id": "task-hitl", + } + client = ADCPClient.from_checkpoint(a2a_config, state) + + assert client.context_id == "ctx-resume" + assert client.active_task_id == "task-hitl" + + def test_from_checkpoint_with_empty_state_is_fresh_client(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient.from_checkpoint(a2a_config, {}) + assert client.context_id is None + assert client.active_task_id is None + + def test_from_checkpoint_roundtrips(self, a2a_config): + from adcp.client import ADCPClient + + original = ADCPClient(a2a_config, context_id="ctx-orig") + assert isinstance(original.adapter, A2AAdapter) + original.adapter._active_task_id = "task-orig" + + restored = ADCPClient.from_checkpoint(a2a_config, original.checkpoint()) + assert restored.context_id == original.context_id + assert restored.active_task_id == original.active_task_id + + def test_from_checkpoint_rejects_mismatched_agent_id(self, a2a_config): + """A checkpoint minted for Agent A must not be restored onto + Agent B — that would leak Agent A's opaque session ids to a + different vendor on the next message.""" + from adcp.client import ADCPClient + + state = { + "agent_id": "other-agent", + "context_id": "ctx-from-other", + "active_task_id": "task-from-other", + } + with pytest.raises(ValueError, match="minted for agent"): + ADCPClient.from_checkpoint(a2a_config, state) + + def test_from_checkpoint_raises_on_non_a2a_with_active_task(self, mcp_config): + """Silently dropping active_task_id on a non-A2A restore would + mask bugs — raise instead.""" + from adcp.client import ADCPClient + + state = { + "agent_id": mcp_config.id, + "context_id": None, + "active_task_id": "task-x", + } + with pytest.raises(TypeError, match="active_task_id"): + 
ADCPClient.from_checkpoint(mcp_config, state) + + def test_from_checkpoint_empty_on_mcp_is_fine(self, mcp_config): + """Empty/None checkpoint must round-trip on any protocol.""" + from adcp.client import ADCPClient + + state = { + "agent_id": mcp_config.id, + "context_id": None, + "active_task_id": None, + } + client = ADCPClient.from_checkpoint(mcp_config, state) + assert client.context_id is None + assert client.active_task_id is None class TestMCPAdapter: