Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/adcp/decisioning/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Generic
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal

from typing_extensions import TypeVar

Expand Down Expand Up @@ -533,6 +533,16 @@ class RequestContext(ToolContext, Generic[TMeta]):
* Idempotency scope? → don't touch; the framework owns this.
* Logging request provenance? → log all four; they're cheap.

:param transport: The wire protocol that dispatched this call —
``"mcp"`` or ``"a2a"``. ``None`` when ``RequestContext`` is
constructed in tests without a transport-aware ``ToolContext``,
or when a custom ``context_factory`` omits
``metadata["transport"]``. Production dispatch always populates
this field. Note: even when the server is started with
``transport="both"``, individual requests always resolve to
exactly one of ``"mcp"`` or ``"a2a"`` — this field never
carries ``"both"``. For code running outside a handler call
stack, read :data:`adcp.server.current_transport` instead.
:param state: Sync reads of framework-owned in-flight workflow
state. Default is :class:`adcp.decisioning.state._NotYetWiredStateReader`
— returns empty values + emits one-time UserWarning per
Expand Down Expand Up @@ -560,6 +570,7 @@ class RequestContext(ToolContext, Generic[TMeta]):
auth_info: AuthInfo | None = None
auth_principal: str | None = None
buyer_agent: BuyerAgent | None = None
transport: Literal["mcp", "a2a"] | None = None
now: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
state: StateReader = field(default_factory=_make_default_state_reader)
resolve: ResourceResolver = field(default_factory=_make_default_resolver)
Expand Down
31 changes: 29 additions & 2 deletions src/adcp/decisioning/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
import typing
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

from adcp.decisioning.account_projection import (
strip_credentials_from_wire_result,
Expand Down Expand Up @@ -1059,6 +1059,7 @@ def _build_request_context(
# Local import keeps the layering local — read the bearer ContextVar
# without forcing a top-level dep on adcp.server.auth.
from adcp.server.auth import current_principal as _current_principal
from adcp.server.auth import current_transport as _current_transport

if auth_info is None:
bearer_principal = _current_principal.get()
Expand Down Expand Up @@ -1092,14 +1093,40 @@ def _build_request_context(
else:
caller_identity = tool_ctx.caller_identity

# Extract transport from metadata. In production paths RequestMetadata
# always populates metadata["transport"] before calling the context
# factory; None here means a test fixture supplied a bare ToolContext.
raw_transport = tool_ctx.metadata.get("transport")
if raw_transport not in ("mcp", "a2a", None):
raise ValueError(
f"metadata['transport'] must be 'mcp', 'a2a', or absent; got {raw_transport!r}"
)
transport: Literal["mcp", "a2a"] | None = raw_transport

# Set the ContextVar for code outside the handler call stack (webhook
# services, background helpers) that don't receive a RequestContext.
# No reset token is saved: asyncio tasks each get their own context
# copy, so set() is task-scoped and doesn't bleed across requests.
# Callers that need the previous value must save/restore it themselves
# (the test suite exercises this via asyncio.copy_context() isolation).
_current_transport.set(transport)

# SDK-owned keys set by auth_context_factory / build_context examples
# ("transport", "tool_name") are framework-internal — strip them from
# the handler-visible metadata so adopters can't accidentally rely on
# undocumented dict paths and ctx.transport is the sole typed surface.
_sdk_metadata_keys = frozenset({"transport", "tool_name"})
clean_metadata = {k: v for k, v in tool_ctx.metadata.items() if k not in _sdk_metadata_keys}

# Build the RequestContext with the explicit state/resolve kwargs
# if provided; otherwise let the dataclass default factories
# supply the v6.0 stubs.
ctx_kwargs: dict[str, Any] = {
"request_id": tool_ctx.request_id,
"caller_identity": caller_identity,
"tenant_id": tool_ctx.tenant_id,
"metadata": dict(tool_ctx.metadata),
"metadata": clean_metadata,
"transport": transport,
"account": account,
"auth_info": auth_info,
"auth_principal": auth_principal,
Expand Down
6 changes: 6 additions & 0 deletions src/adcp/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ async def get_products(params, context=None):
TokenValidator,
auth_context_factory,
constant_time_token_match,
current_principal,
current_principal_metadata,
current_transport,
validator_from_token_map,
)
from adcp.server.base import (
Expand Down Expand Up @@ -207,6 +210,9 @@ async def get_products(params, context=None):
"TokenValidator",
"auth_context_factory",
"constant_time_token_match",
"current_principal",
"current_principal_metadata",
"current_transport",
"validator_from_token_map",
# Idempotency middleware (AdCP #2315 seller side)
"IdempotencyStore",
Expand Down
5 changes: 4 additions & 1 deletion src/adcp/server/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def validate_token(token: str) -> Principal | None:
from collections.abc import Awaitable, Mapping
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar

_V = TypeVar("_V")

Expand Down Expand Up @@ -192,6 +192,9 @@ def __call__(self, token: str) -> Awaitable[Principal | None]: ...
current_principal_metadata: ContextVar[dict[str, Any] | None] = ContextVar(
"adcp_auth_principal_metadata", default=None
)
current_transport: ContextVar[Literal["mcp", "a2a"] | None] = ContextVar(
"adcp_transport", default=None
)


class BearerTokenAuthMiddleware(BaseHTTPMiddleware):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_decisioning_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,23 @@ def test_build_request_context_threads_account_and_auth() -> None:
assert ctx.caller_identity == "caller_x"
assert ctx.tenant_id == "tenant_y"
assert ctx.metadata == {"foo": "bar"}
# Fixture ToolContext has no "transport" in metadata — transport is None.
assert ctx.transport is None


@pytest.mark.parametrize("transport_value", ["mcp", "a2a"])
def test_build_request_context_extracts_transport_from_metadata(transport_value: str) -> None:
"""Transport is lifted from ToolContext.metadata into the typed field and ContextVar."""
from adcp.server.auth import current_transport

tool_ctx = ToolContext(metadata={"transport": transport_value, "tool_name": "get_products"})
account: Account[Any] = Account(id="acct_b")
ctx = _build_request_context(tool_ctx, account, None)
assert ctx.transport == transport_value
assert current_transport.get() == transport_value
# SDK-owned keys are stripped from handler-visible metadata.
assert "transport" not in ctx.metadata
assert "tool_name" not in ctx.metadata


def test_build_request_context_uses_composite_key_when_store_supplied() -> None:
Expand Down
Loading