diff --git a/src/adcp/decisioning/context.py b/src/adcp/decisioning/context.py index 2a40980a..eaadcc82 100644 --- a/src/adcp/decisioning/context.py +++ b/src/adcp/decisioning/context.py @@ -18,7 +18,7 @@ from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, ClassVar, Generic +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal from typing_extensions import TypeVar @@ -533,6 +533,16 @@ class RequestContext(ToolContext, Generic[TMeta]): * Idempotency scope? → don't touch; the framework owns this. * Logging request provenance? → log all four; they're cheap. + :param transport: The wire protocol that dispatched this call — + ``"mcp"`` or ``"a2a"``. ``None`` when ``RequestContext`` is + constructed in tests without a transport-aware ``ToolContext``, + or when a custom ``context_factory`` omits + ``metadata["transport"]``. Production dispatch always populates + this field. Note: even when the server is started with + ``transport="both"``, individual requests always resolve to + exactly one of ``"mcp"`` or ``"a2a"`` — this field never + carries ``"both"``. For code running outside a handler call + stack, read :data:`adcp.server.current_transport` instead. :param state: Sync reads of framework-owned in-flight workflow state. Default is :class:`adcp.decisioning.state._NotYetWiredStateReader` — returns empty values + emits one-time UserWarning per @@ -560,6 +570,7 @@ class RequestContext(ToolContext, Generic[TMeta]): auth_info: AuthInfo | None = None auth_principal: str | None = None buyer_agent: BuyerAgent | None = None + transport: Literal["mcp", "a2a"] | None = None now: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) state: StateReader = field(default_factory=_make_default_state_reader) resolve: ResourceResolver = field(default_factory=_make_default_resolver) diff --git a/src/adcp/decisioning/dispatch.py b/src/adcp/decisioning/dispatch.py index 8e254dde..785201a9 100644 --- a/src/adcp/decisioning/dispatch.py +++ b/src/adcp/decisioning/dispatch.py @@ -44,7 +44,7 @@ import typing import warnings from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from adcp.decisioning.account_projection import ( strip_credentials_from_wire_result, @@ -1059,6 +1059,7 @@ def _build_request_context( # Local import keeps the layering local — read the bearer ContextVar # without forcing a top-level dep on adcp.server.auth. from adcp.server.auth import current_principal as _current_principal + from adcp.server.auth import current_transport as _current_transport if auth_info is None: bearer_principal = _current_principal.get() @@ -1092,6 +1093,31 @@ def _build_request_context( else: caller_identity = tool_ctx.caller_identity + # Extract transport from metadata. In production paths RequestMetadata + # always populates metadata["transport"] before calling the context + # factory; None here means a test fixture supplied a bare ToolContext. + raw_transport = tool_ctx.metadata.get("transport") + if raw_transport not in ("mcp", "a2a", None): + raise ValueError( + f"metadata['transport'] must be 'mcp', 'a2a', or absent; got {raw_transport!r}" + ) + transport: Literal["mcp", "a2a"] | None = raw_transport + + # Set the ContextVar for code outside the handler call stack (webhook + # services, background helpers) that don't receive a RequestContext. + # No reset token is saved: asyncio tasks each get their own context + # copy, so set() is task-scoped and doesn't bleed across requests. + # Callers that need the previous value must save/restore it themselves + # (the test suite exercises this via asyncio.copy_context() isolation). + _current_transport.set(transport) + + # SDK-owned keys set by auth_context_factory / build_context examples + # ("transport", "tool_name") are framework-internal — strip them from + # the handler-visible metadata so adopters can't accidentally rely on + # undocumented dict paths and ctx.transport is the sole typed surface. + _sdk_metadata_keys = frozenset({"transport", "tool_name"}) + clean_metadata = {k: v for k, v in tool_ctx.metadata.items() if k not in _sdk_metadata_keys} + # Build the RequestContext with the explicit state/resolve kwargs # if provided; otherwise let the dataclass default factories # supply the v6.0 stubs. @@ -1099,7 +1125,8 @@ def _build_request_context( "request_id": tool_ctx.request_id, "caller_identity": caller_identity, "tenant_id": tool_ctx.tenant_id, - "metadata": dict(tool_ctx.metadata), + "metadata": clean_metadata, + "transport": transport, "account": account, "auth_info": auth_info, "auth_principal": auth_principal, diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index 16055573..8521ab93 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -64,6 +64,9 @@ async def get_products(params, context=None): TokenValidator, auth_context_factory, constant_time_token_match, + current_principal, + current_principal_metadata, + current_transport, validator_from_token_map, ) from adcp.server.base import ( @@ -207,6 +210,9 @@ async def get_products(params, context=None): "TokenValidator", "auth_context_factory", "constant_time_token_match", + "current_principal", + "current_principal_metadata", + "current_transport", "validator_from_token_map", # Idempotency middleware (AdCP #2315 seller side) "IdempotencyStore", diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index 4b27c66c..8a8551e1 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -80,7 +80,7 @@ async def validate_token(token: str) -> Principal | None: from collections.abc import Awaitable, Mapping from contextvars import ContextVar from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar _V = TypeVar("_V") @@ -192,6 +192,9 @@ def __call__(self, token: str) -> Awaitable[Principal | None]: ... current_principal_metadata: ContextVar[dict[str, Any] | None] = ContextVar( "adcp_auth_principal_metadata", default=None ) +current_transport: ContextVar[Literal["mcp", "a2a"] | None] = ContextVar( + "adcp_transport", default=None +) class BearerTokenAuthMiddleware(BaseHTTPMiddleware): diff --git a/tests/test_decisioning_dispatch.py b/tests/test_decisioning_dispatch.py index 9ce8af80..8f80e108 100644 --- a/tests/test_decisioning_dispatch.py +++ b/tests/test_decisioning_dispatch.py @@ -410,6 +410,23 @@ def test_build_request_context_threads_account_and_auth() -> None: assert ctx.caller_identity == "caller_x" assert ctx.tenant_id == "tenant_y" assert ctx.metadata == {"foo": "bar"} + # Fixture ToolContext has no "transport" in metadata — transport is None. + assert ctx.transport is None + + +@pytest.mark.parametrize("transport_value", ["mcp", "a2a"]) +def test_build_request_context_extracts_transport_from_metadata(transport_value: str) -> None: + """Transport is lifted from ToolContext.metadata into the typed field and ContextVar.""" + from adcp.server.auth import current_transport + + tool_ctx = ToolContext(metadata={"transport": transport_value, "tool_name": "get_products"}) + account: Account[Any] = Account(id="acct_b") + ctx = _build_request_context(tool_ctx, account, None) + assert ctx.transport == transport_value + assert current_transport.get() == transport_value + # SDK-owned keys are stripped from handler-visible metadata. + assert "transport" not in ctx.metadata + assert "tool_name" not in ctx.metadata def test_build_request_context_uses_composite_key_when_store_supplied() -> None: