Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 127 additions & 15 deletions src/adcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
from collections.abc import Callable, Iterator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict

from a2a.types import Task, TaskStatusUpdateEvent
from pydantic import BaseModel
Expand Down Expand Up @@ -296,6 +296,27 @@
logger = logging.getLogger(__name__)


class Checkpoint(TypedDict):
"""Persistable session-resume state for an A2A ``ADCPClient``.

The minimal set of fields needed to reconnect to an in-flight A2A
conversation after a process restart. Produced by
``ADCPClient.checkpoint()``; consumed by
``ADCPClient.from_checkpoint()``.

- ``agent_id`` — binds the checkpoint to the agent that minted it,
so a restore against the wrong ``AgentConfig`` fails loudly
instead of sending Agent A's ids to Agent B.
- ``context_id`` — the A2A conversation id.
- ``active_task_id`` — the in-flight task the next message must
echo; ``None`` if no task is pending.
"""

agent_id: str
context_id: str | None
active_task_id: str | None


class ADCPClient:
"""Client for interacting with a single AdCP agent."""

Expand Down Expand Up @@ -373,7 +394,12 @@ def __init__(
multiple concurrent briefs with the same agent, construct
one client per brief rather than sharing.

Raises ``ValueError`` if passed with a non-A2A protocol.
For HITL flows that can span a process restart mid-task,
use ``checkpoint()`` / ``from_checkpoint()`` instead of
persisting ``context_id`` alone — full resume state is
both ``context_id`` AND ``active_task_id``.

Raises ``TypeError`` if passed with a non-A2A protocol.
"""
self.agent_config = agent_config
self.webhook_url_template = webhook_url_template
Expand Down Expand Up @@ -414,10 +440,14 @@ def __init__(
# in dev/test, warn in production — see ``ValidationHookConfig`` docs).
self.adapter.configure_validation(validation)

if context_id is not None:
if context_id:
# Empty string is treated as "not provided" — callers using
# ``context_id=os.getenv("...") or ""`` patterns shouldn't
# silently seed an empty id on the wire.
if not isinstance(self.adapter, A2AAdapter):
raise ValueError(
"context_id is only supported for A2A protocol; " f"got {agent_config.protocol}"
raise TypeError(
f"context_id is only supported for A2A protocol; "
f"got {agent_config.protocol}"
)
self.adapter.set_context_id(context_id)

Expand All @@ -442,28 +472,36 @@ def context_id(self) -> str | None:

Not safe for concurrent calls on the same client — the adapter
mutates this on every response. Rule of thumb: one ADCPClient
per A2A conversation. Persist this value (e.g., Redis keyed by
your brief id) to resume across process restarts by passing it
to ``ADCPClient(context_id=...)``.
per A2A conversation.

For simple completed-task resume, persist this value and pass
it to ``ADCPClient(context_id=...)``. For HITL flows that may
restart mid-``input-required``, use ``checkpoint()`` /
``from_checkpoint()`` — full resume state is both this id AND
``active_task_id``.
"""
if isinstance(self.adapter, A2AAdapter):
return self.adapter.context_id
return None

@property
def pending_task_id(self) -> str | None:
"""A2A task_id pending resume, or None if no task is in-flight.
def active_task_id(self) -> str | None:
"""A2A task_id the next send must echo to resume the same task.

Set when the last A2A response was non-terminal
(``input-required``, ``working``, ``submitted``,
``auth-required``). The adapter echoes this id on the next
outbound message so the server resumes the same task. Clears
automatically when the task reaches a terminal state.

Full resume state is *both* ``context_id`` and
``active_task_id`` — persist both (or use ``checkpoint()``) to
survive a process restart mid-HITL without orphaning the task.

Returns ``None`` for non-A2A clients.
"""
if isinstance(self.adapter, A2AAdapter):
return self.adapter.pending_task_id
return self.adapter.active_task_id
return None

def reset_context(self, context_id: str | None = None) -> None:
Expand All @@ -477,18 +515,92 @@ def reset_context(self, context_id: str | None = None) -> None:
client-supplied ids into their own session format; the client
auto-adopts the rewritten id on the next response.

Also clears any pending_task_id — starting a new conversation
Also clears any active_task_id — starting a new conversation
discards any in-flight task on the old one.

Raises ``ValueError`` when called on a non-A2A client.
Raises ``TypeError`` when called on a non-A2A client.
"""
if not isinstance(self.adapter, A2AAdapter):
raise ValueError(
"reset_context is only supported for A2A protocol; "
raise TypeError(
f"reset_context is only supported for A2A protocol; "
f"got {self.agent_config.protocol}"
)
self.adapter.set_context_id(context_id)

def checkpoint(self) -> Checkpoint:
"""Return the minimal state needed to resume this A2A session.

Full resume for HITL / multi-turn flows requires *both*
``context_id`` (which conversation) AND ``active_task_id``
(which in-flight task to echo). Persisting only ``context_id``
reconnects to the right conversation but orphans the pending
task server-side — the next send starts a new task under the
same context, and the original ``input-required`` task is
abandoned.

The returned dict also carries ``agent_id`` so a later
``from_checkpoint`` call against a different ``AgentConfig``
fails loudly instead of sending one agent's session ids to
another.

Pair with ``ADCPClient.from_checkpoint(agent_config, state)``.

Returns a fully-populated ``Checkpoint`` on non-A2A clients
with ``context_id``/``active_task_id`` set to ``None``, so
generic persist-and-restore code can call this without
branching on protocol.
"""
return Checkpoint(
agent_id=self.agent_config.id,
context_id=self.context_id,
active_task_id=self.active_task_id,
)

@classmethod
def from_checkpoint(
cls,
agent_config: AgentConfig,
state: Checkpoint,
**kwargs: Any,
) -> ADCPClient:
"""Rehydrate an ADCPClient from a prior ``checkpoint()``.

Restores both ``context_id`` and ``active_task_id`` so a process
restart mid-``input-required`` can resume the same task, not
orphan it. Accepts the same keyword arguments as ``__init__``
(signing, strict_idempotency, etc.) — the checkpoint only
carries session-resume state; operational config is re-supplied
by the caller.

Raises ``ValueError`` if the checkpoint's ``agent_id`` doesn't
match ``agent_config.id`` — a checkpoint minted for Agent A
must not be restored onto Agent B, or the client will leak
Agent A's opaque session ids to Agent B on the next message.

Raises ``TypeError`` on a non-A2A ``agent_config`` if the
checkpoint carries a non-empty ``context_id`` or
``active_task_id`` — session-resume state on a protocol that
doesn't support it would be silently dropped, masking bugs.
An empty/absent checkpoint round-trips cleanly on any protocol.
"""
saved_agent_id = state.get("agent_id") if state else None
if saved_agent_id and saved_agent_id != agent_config.id:
raise ValueError(
f"checkpoint was minted for agent {saved_agent_id!r}, "
f"cannot restore against {agent_config.id!r}"
)
context_id = state.get("context_id") if state else None
active_task_id = state.get("active_task_id") if state else None
if active_task_id and agent_config.protocol != Protocol.A2A:
raise TypeError(
f"active_task_id in checkpoint is only supported for A2A "
f"protocol; got {agent_config.protocol}"
)
client = cls(agent_config, context_id=context_id, **kwargs)
if active_task_id and isinstance(client.adapter, A2AAdapter):
client.adapter._restore_active_task_id(active_task_id)
return client

async def _ensure_idempotency_capability(self) -> None:
"""Verify the seller positively declares idempotency support in capabilities.

Expand Down
96 changes: 69 additions & 27 deletions src/adcp/protocols/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Role,
SendMessageRequest,
Task,
TaskState,
TextPart,
)

Expand Down Expand Up @@ -48,11 +49,18 @@ class A2AAdapter(ProtocolAdapter):
# in-flight states). While the adapter holds a task_id in one of
# these states, the next outbound Message must echo it back so the
# server resumes the same task rather than orphaning it and starting
# a new one. Terminal states (completed/failed/canceled/rejected)
# clear the retained task_id — subsequent calls in the conversation
# are new tasks.
_NONTERMINAL_TASK_STATES = frozenset(
{"submitted", "working", "input-required", "auth-required"}
# a new one. Everything else — completed/failed/canceled/rejected
# (terminal) and the defensive unknown state — clears the retained
# task_id so subsequent calls start a fresh task. Coupled directly
# to the TaskState enum so a rename upstream is a type error, not a
# silent behavior change.
_NONTERMINAL_TASK_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
TaskState.input_required,
TaskState.auth_required,
}
)

def __init__(self, agent_config: AgentConfig):
Expand All @@ -72,8 +80,8 @@ def __init__(self, agent_config: AgentConfig):
# non-terminal (input-required, working, etc). On terminal states
# this clears to None so the next call starts a new task under
# the same context_id. Without this, resume of an input-required
# task orphans the server-side pending task.
self._pending_task_id: str | None = None
# task orphans the server-side in-flight task.
self._active_task_id: str | None = None

@property
def context_id(self) -> str | None:
Expand All @@ -91,15 +99,17 @@ def context_id(self) -> str | None:
return self._context_id

@property
def pending_task_id(self) -> str | None:
"""A2A task_id retained for resume, or None if no task is pending.
def active_task_id(self) -> str | None:
"""A2A task_id the next send must echo to resume the same task.

Populated when the last response was non-terminal (e.g.
``input-required``). Echoed on the next outbound message so the
server continues the same task. Clears to None on terminal
states (``completed``/``failed``/``canceled``).
``input-required``, ``working``). Echoed on the next outbound
message so the server continues the same task. Clears to None
on terminal states (``completed``/``failed``/``canceled``/
``rejected``) — and defensively on ``unknown`` — so subsequent
calls start a fresh task under the same context.
"""
return self._pending_task_id
return self._active_task_id

def set_context_id(self, context_id: str | None) -> None:
"""Set the A2A context_id for subsequent message sends.
Expand All @@ -113,11 +123,21 @@ def set_context_id(self, context_id: str | None) -> None:
format and return the rewritten value on the next response — at
which point this adapter auto-adopts it.

Also clears any retained ``pending_task_id``: switching context
Also clears any retained ``active_task_id``: switching context
always starts a fresh task under the new context.
"""
self._context_id = context_id
self._pending_task_id = None
self._active_task_id = None

def _restore_active_task_id(self, task_id: str) -> None:
"""Internal: rehydrate ``active_task_id`` from a persisted checkpoint.

Separate from normal in-flight state updates so the checkpoint
restore path is an explicit contract — a rename of the storage
field fails loudly here instead of silently breaking resume.
Intended for ``ADCPClient.from_checkpoint`` only.
"""
self._active_task_id = task_id

async def _get_httpx_client(self) -> httpx.AsyncClient:
"""Get or create the HTTP client with connection pooling."""
Expand Down Expand Up @@ -279,7 +299,7 @@ async def _call_a2a_tool(
role=Role.user,
parts=[Part(root=data_part)],
context_id=self._context_id,
task_id=self._pending_task_id,
task_id=self._active_task_id,
)
else:
# Natural language invocation (flexible)
Expand All @@ -290,7 +310,7 @@ async def _call_a2a_tool(
role=Role.user,
parts=[Part(root=text_part)],
context_id=self._context_id,
task_id=self._pending_task_id,
task_id=self._active_task_id,
)

# Build request params
Expand Down Expand Up @@ -358,26 +378,48 @@ async def _call_a2a_tool(

# Result can be either Task or Message
if isinstance(result, Task):
# Retain the server-assigned context_id so subsequent
# turns continue the same A2A conversation. Task.context_id
# is required by a2a-sdk, so no None-guard needed.
self._context_id = result.context_id
# Retain task_id only while the task is non-terminal.
# On terminal states (completed/failed/canceled/rejected)
# the next send must NOT echo this task_id — it starts a
# fresh task under the same context.
# Compute next-turn state from the response but do NOT
# commit yet — _process_task_response and the idempotency
# check below can raise, and leaving the adapter advanced
# after an exception would orphan the legitimate in-flight
# task on the next retry. Commit only after both succeed.
# Task.context_id is required by a2a-sdk, so no None-guard.
next_context_id = result.context_id
if result.status.state in self._NONTERMINAL_TASK_STATES:
self._pending_task_id = result.id
next_active_task_id: str | None = result.id
else:
self._pending_task_id = None
# Terminal states (completed/failed/canceled/rejected)
# clear the retained task_id — subsequent calls start
# a new task under the same context. The defensive
# unknown state falls here too (don't cling to an
# undefined task); warn so operators notice if a
# server starts emitting it.
next_active_task_id = None
if result.status.state == TaskState.unknown:
logger.warning(
"A2A agent %s returned TaskState.unknown for "
"task_id=%s; clearing active_task_id and "
"starting a fresh task on next call",
self.agent_config.id,
result.id,
)
task_result = self._process_task_response(result, debug_info)
_idempotency.raise_for_idempotency_error(
tool_name, task_result.data, self.agent_config.id
)
# All raise-sites have passed; commit next-turn state so
# the adapter reflects the response the caller is about
# to receive.
self._context_id = next_context_id
self._active_task_id = next_active_task_id
# Post-receive schema validation. Only runs when the task
# carries data (terminal completion); async interim states
# with ``data=None`` skip naturally. Strict mode flips the
# TaskResult to FAILED; warn mode logs and passes through.
# Runs after the state commit — a payload-schema failure
# doesn't invalidate the A2A envelope ids, and the next
# call in the same conversation should still target the
# right session.
if task_result.success and task_result.data is not None:
response_outcome = validate_incoming_response(
tool_name, task_result.data, self.response_validation_mode
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/test_a2a_context_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,13 @@ async def test_task_id_echoed_on_resume_after_input_required():
r1 = await client.adapter.create_media_buy({"budget": 1000})
# After an input-required response the adapter stashed both ids.
assert client.context_id is not None
assert client.pending_task_id is not None
retained_task_id = client.pending_task_id
assert client.active_task_id is not None
retained_task_id = client.active_task_id
retained_context_id = client.context_id

r2 = await client.adapter.create_media_buy({"approval": "yes"})
# Terminal state on turn 2 cleared pending_task_id; context stays.
assert client.pending_task_id is None
# Terminal state on turn 2 cleared active_task_id; context stays.
assert client.active_task_id is None
assert client.context_id == retained_context_id

assert len(executor.observations) == 2
Expand Down
Loading
Loading