diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index 6cbd71c9..a5b6fe3a 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -55,7 +55,9 @@ async def get_products(params, context=None): from adcp.capabilities import validate_capabilities from adcp.server.a2a_server import ADCPAgentExecutor, MessageParser, create_a2a_server from adcp.server.auth import ( + A2ABearerAuthMiddleware, AsyncTokenValidator, + BearerTokenAuth, BearerTokenAuthMiddleware, Principal, SyncTokenValidator, @@ -194,7 +196,9 @@ async def get_products(params, context=None): "SkillMiddleware", "create_a2a_server", # Bearer-token auth middleware (seller-facing recipe) + "A2ABearerAuthMiddleware", "AsyncTokenValidator", + "BearerTokenAuth", "BearerTokenAuthMiddleware", "Principal", "SyncTokenValidator", diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index 9e4f1126..74b266fb 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -650,6 +650,7 @@ def create_a2a_server( message_parser: MessageParser | None = None, advertise_all: bool = False, validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION, + context_builder: Any | None = None, ) -> Any: """Create an A2A Starlette application from an ADCP handler. @@ -780,12 +781,25 @@ def create_a2a_server( # ``enable_v0_3_compat=True`` is load-bearing: it makes the server # dual-serve 0.3 and 1.0 wire formats on the same endpoint so existing # 0.3 buyer clients keep working unchanged. Do not disable. + # + # ``context_builder`` is the a2a-sdk seam for customising the + # :class:`ServerCallContext` each handler receives. We thread it + # through verbatim when supplied — bearer-token auth is wired + # separately via :class:`A2ABearerAuthMiddleware` at the ASGI + # layer (see ``serve.py:_wrap_a2a_with_auth``) because the v0.3 + # compat adapter swallows builder-raised ``HTTPException``s. The + # builder kwarg remains for adopters customising the + # ``ServerCallContext`` shape (e.g. surfacing additional + # ``state`` fields from the request). + jsonrpc_kwargs: dict[str, Any] = { + "request_handler": request_handler, + "rpc_url": "/", + "enable_v0_3_compat": True, + } + if context_builder is not None: + jsonrpc_kwargs["context_builder"] = context_builder routes = list(create_agent_card_routes(agent_card=agent_card)) + list( - create_jsonrpc_routes( - request_handler=request_handler, - rpc_url="/", - enable_v0_3_compat=True, - ) + create_jsonrpc_routes(**jsonrpc_kwargs) ) app = Starlette(routes=routes) diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index bb84d33d..9099b838 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -59,10 +59,15 @@ async def validate_token(token: str) -> Principal | None: * **Authorization.** The middleware answers "who is this?", not "can they do X?". Authorization checks run on the authenticated principal inside your handlers or as :data:`~adcp.server.SkillMiddleware`. -* **A2A auth.** A2A uses a different transport; wire a2a-sdk's - ``ServerCallContext.user`` via a2a-sdk auth middleware on that side. - The ``Principal`` / ``ToolContext`` shape is the same, so handlers - work unchanged across transports. +* **A2A auth.** A2A uses a different transport; the same + :class:`BearerTokenAuth` config object drives both legs when wired + via :func:`adcp.server.serve`'s ``auth=`` kwarg. The A2A side is + authenticated by a :class:`BearerTokenContextBuilder` plumbed into + ``a2a-sdk``'s ``create_jsonrpc_routes(context_builder=...)`` seam, + not by a Starlette middleware — that placement bypasses the + ``/.well-known/agent-card.json`` route automatically (which is + registered separately and never invokes the builder), satisfying + A2A spec §4.1's mandate that the agent card be publicly accessible. """ from __future__ import annotations @@ -480,3 +485,268 @@ def _validate(token: str) -> Principal | None: return constant_time_token_match(token, stored_hashes) return _validate + + +# --------------------------------------------------------------------------- +# Cross-transport auth config — drives both MCP middleware and A2A builder +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class BearerTokenAuth: + """Cross-transport bearer-token auth config for :func:`adcp.server.serve`. + + Single source of truth that wires the same ``validate_token`` + callback (and ``header_name`` / ``bearer_prefix_required`` knobs) + into both the MCP-side :class:`BearerTokenAuthMiddleware` and the + A2A-side :class:`BearerTokenContextBuilder`. Pass via + ``serve(auth=BearerTokenAuth(...))`` and both legs are + authenticated against the same token store with no per-leg + drift:: + + from adcp.server import serve + from adcp.server.auth import BearerTokenAuth, validator_from_token_map + + serve( + handler, + transport="both", + auth=BearerTokenAuth( + validate_token=validator_from_token_map({ + "secret-token": Principal(caller_identity="p", tenant_id="acme"), + }), + ), + ) + + On MCP, requests without a valid token receive a JSON ``401`` + body. On A2A, requests without a valid token receive an HTTP + ``401`` from Starlette's :class:`HTTPException`. Discovery + bypasses are transport-specific: + + * **MCP**: ``initialize`` / ``tools/list`` / ``notifications/initialized`` + / ``get_adcp_capabilities`` (JSON-RPC method-level bypass). + * **A2A**: ``/.well-known/agent-card.json`` (route-level — the + agent-card route is created separately and never invokes the + builder, so no path-based exemption is needed in the + :class:`BearerTokenContextBuilder` itself). + + Knobs mirror :class:`BearerTokenAuthMiddleware` exactly: pass + ``header_name="x-adcp-auth"`` and ``bearer_prefix_required=False`` + for non-OAuth custom-header schemes. + """ + + validate_token: TokenValidator + header_name: str = "authorization" + bearer_prefix_required: bool = True + unauthenticated_response: dict[str, Any] | None = None + + +# --------------------------------------------------------------------------- +# A2A: ASGI middleware that gates JSON-RPC requests, exempts agent-card +# --------------------------------------------------------------------------- +# +# Why an ASGI middleware (not a ServerCallContextBuilder)? +# The a2a-sdk v0.3 compat adapter wraps the entire dispatch in +# ``except Exception`` and converts any error — including a builder- +# raised :class:`HTTPException(401)` — into a 200 OK with a JSON-RPC +# error body. That breaks the spec-canonical HTTP 401 contract and +# leaks the auth path as a 200. Authenticating outside the dispatcher, +# at the ASGI layer, returns proper HTTP 401 every time. +# +# A2A discovery (``/.well-known/agent-card.json``) is exempted by URL +# path here because the agent-card route happens to live in the same +# Starlette app — the middleware can't rely on the route topology +# alone. Path-exemption keeps the spec §4.1 public-discovery mandate +# satisfied even if a future a2a-sdk refactor merges the routes. + + +# Canonical 1.0 path is sourced from a2a-sdk's own constant — if a +# future a2a-sdk release renames the well-known URI, the import-time +# reference here lifts to the new value automatically and +# ``test_discovery_paths_match_a2a_sdk_routes`` verifies that the +# frozenset still covers every route ``create_agent_card_routes`` +# actually registers. Hardcoding the string would silently leak auth +# on the renamed route until someone notices. +from a2a.utils.constants import ( # noqa: E402 (intentional placement after BearerTokenAuth definition) + AGENT_CARD_WELL_KNOWN_PATH as _A2A_AGENT_CARD_PATH, +) + +_A2A_DISCOVERY_PATHS: frozenset[str] = frozenset( + { + _A2A_AGENT_CARD_PATH, # 1.0 canonical: ``/.well-known/agent-card.json``. + "/.well-known/agent.json", # Legacy 0.3 alias retained by enable_v0_3_compat=True. + } +) + + +class A2ABearerAuthMiddleware: + """Pure-ASGI middleware that gates A2A JSON-RPC on a bearer token. + + Wrap the Starlette app produced by + :func:`adcp.server.a2a_server.create_a2a_server` with this + middleware to require a valid bearer header on every JSON-RPC + request, while leaving the spec-mandated public discovery + surface (``/.well-known/agent-card.json``) accessible. + + Designed to compose with a2a-sdk's + :class:`DefaultServerCallContextBuilder`: on auth success the + middleware writes a duck-typed user object into + ``scope['user']`` and the principal into ``scope['auth']``, + matching Starlette's :class:`AuthenticationMiddleware` contract. + The default builder reads ``scope['user']`` and adapts it via + :class:`a2a.server.routes.common.StarletteUser`, so downstream + handlers see ``ServerCallContext.user.user_name`` populated with + the principal's ``caller_identity`` without a custom builder. + + Composition order matters when ``transport="both"`` is in play: + wrap the per-leg apps before any outer dispatcher closes over + them. See ``serve.py:_build_mcp_and_a2a_app`` for the wiring. + """ + + def __init__(self, app: Any, config: BearerTokenAuth) -> None: + self._app = app + self._config = config + self._header_name = config.header_name.lower() + + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + # Lifespan + websocket pass through unchanged. Auth applies to + # HTTP requests only. + if scope.get("type") != "http": + await self._app(scope, receive, send) + return + + # CORS preflight is part of the public surface — browser-origin + # clients send ``OPTIONS`` before any auth'd POST. Returning 401 + # here breaks the preflight and the buyer never gets a chance to + # retry with a token. Pass through; let the inner app's CORS + # handler (or operator-supplied ``asgi_middleware``) respond. + if scope.get("method") == "OPTIONS": + await self._app(scope, receive, send) + return + + path = scope.get("path", "") + if path in _A2A_DISCOVERY_PATHS: + await self._app(scope, receive, send) + return + + principal = self._authenticate_scope(scope) + if principal is None: + await self._send_unauthenticated(send) + return + + # Stash both the duck-typed user (for DefaultServerCallContextBuilder) + # and the raw Principal (for downstream code reading scope['auth']). + # Mutating the scope dict before delegating propagates state to + # nested apps without copying. + scope["user"] = _A2AAuthenticatedUser( + display_name=principal.caller_identity, + tenant_id=principal.tenant_id, + principal_metadata=dict(principal.metadata) if principal.metadata else None, + ) + scope["auth"] = principal + await self._app(scope, receive, send) + + def _authenticate_scope(self, scope: Any) -> Principal | None: + """Read + validate the bearer header off raw ASGI scope. + + Validator exceptions are projected to :data:`None` (logged for + operators) so a buggy validator never leaks 500-level stack + traces or signals path existence to unauthenticated callers. + Auth-rejection branches log at INFO with a coarse reason code + so SOC dashboards can detect scanning without bloating logs. + """ + # ASGI ``headers`` is a list of ``(bytes_lower, bytes)`` tuples. + target = self._header_name.encode("latin-1") + raw_value: bytes | None = None + for name, value in scope.get("headers", ()): + if name == target: + raw_value = value + break + + if raw_value is None: + logger.info("a2a auth rejected", extra={"reason": "missing_header"}) + return None + + try: + raw_header = raw_value.decode("latin-1") + except UnicodeDecodeError: + logger.info("a2a auth rejected", extra={"reason": "header_decode"}) + return None + + if self._config.bearer_prefix_required: + bearer = _parse_bearer_header(raw_header) + else: + stripped = raw_header.strip() + bearer = stripped or None + if not bearer: + logger.info("a2a auth rejected", extra={"reason": "wrong_scheme"}) + return None + + try: + raw = self._config.validate_token(bearer) + except Exception: + logger.exception("token validator raised on A2A request") + return None + + if inspect.isawaitable(raw): + # Should be unreachable — :func:`_assert_sync_validator` at + # config time rejects async validators before any traffic + # lands. This branch is the in-depth catch in case an + # adopter swaps in an async validator at runtime via a + # closure that conditionally awaits. + logger.error( + "a2a auth rejected: validator returned awaitable at request " + "time. Async validators are not supported on the A2A leg; " + "wrap with a sync bridge." + ) + return None + + if raw is None: + logger.info("a2a auth rejected", extra={"reason": "invalid_token"}) + return None + return raw + + async def _send_unauthenticated(self, send: Any) -> None: + body_obj = self._config.unauthenticated_response or { + "error": "invalid_token", + "error_description": "Bearer token missing or invalid", + } + body = json.dumps(body_obj).encode("utf-8") + # RFC 6750 §3 + RFC 7235 §3.1 require ``WWW-Authenticate: Bearer`` + # on every 401. Without it, RFC-compliant clients (including + # browsers and many HTTP libraries) won't surface the auth + # challenge to the user — they treat the 401 as a generic + # error. Always emit; even when the operator overrides + # ``unauthenticated_response``, the header stays for protocol + # compliance. + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode("latin-1")), + (b"www-authenticate", b'Bearer realm="a2a", error="invalid_token"'), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + + +@dataclass(frozen=True) +class _A2AAuthenticatedUser: + """Minimal Starlette-BaseUser-shaped object for :class:`StarletteUser`. + + a2a-sdk's :class:`StarletteUser` adapter wants ``is_authenticated`` + (bool) and ``display_name`` (str). It doesn't import Starlette's + :class:`BaseUser` directly — duck-typing works. We synthesize a + frozen dataclass so the principal's identity flows through with no + Starlette dependency on the auth side. + """ + + display_name: str + tenant_id: str | None = None + principal_metadata: dict[str, Any] | None = None + + @property + def is_authenticated(self) -> bool: + return True diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 4216378f..abbbf035 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -56,6 +56,7 @@ async def get_adcp_capabilities(self, params, context=None): from a2a.server.tasks.task_store import TaskStore from adcp.server.a2a_server import MessageParser + from adcp.server.auth import BearerTokenAuth from adcp.server.test_controller import TestControllerStore @@ -448,6 +449,7 @@ def serve( allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, enable_dns_rebinding_protection: bool | None = None, + auth: BearerTokenAuth | None = None, ) -> None: """Start an MCP or A2A server from an ADCP handler or server builder. @@ -614,6 +616,17 @@ class of bug that shipped the ``pricing_options`` before forwarding to the endpoint. Without authentication, MCP exposes tools/list and A2A exposes /.well-known/agent.json, both of which reveal the agent's full capability surface. + auth: Optional :class:`~adcp.server.auth.BearerTokenAuth` config + applied to MCP, A2A, and ``transport="both"`` legs from the + same source of truth. Drives MCP's + :class:`~adcp.server.auth.BearerTokenAuthMiddleware` and + A2A's :class:`~adcp.server.auth.BearerTokenContextBuilder`. + On A2A, ``/.well-known/agent-card.json`` stays publicly + accessible per A2A spec §4.1 — the agent-card route is + registered separately and never invokes the builder. On + stdio, ``auth`` is ignored with a warning (no HTTP layer). + For non-bearer schemes (mTLS, signed-request derivation), + wire your own middleware via ``asgi_middleware=`` instead. Example (MCP): from adcp.server import ADCPHandler, serve @@ -676,6 +689,7 @@ async def force_account_status(self, account_id, status): base_url=base_url, specialisms=specialisms, description=description, + auth=auth, ) elif transport in ("streamable-http", "sse", "stdio"): _serve_mcp( @@ -700,6 +714,7 @@ async def force_account_status(self, account_id, status): allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, enable_dns_rebinding_protection=enable_dns_rebinding_protection, + auth=auth, ) elif transport == "both": _serve_mcp_and_a2a( @@ -726,6 +741,7 @@ async def force_account_status(self, account_id, status): allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, enable_dns_rebinding_protection=enable_dns_rebinding_protection, + auth=auth, ) else: valid = ", ".join(sorted(("a2a", "both", "streamable-http", "sse", "stdio"))) @@ -795,6 +811,91 @@ def _apply_asgi_middleware( return app +def _wrap_mcp_with_auth(app: Any, auth: BearerTokenAuth | None) -> Any: + """Wrap the FastMCP HTTP app with :class:`BearerTokenAuthMiddleware`. + + No-op when ``auth`` is ``None``. Expects a + :class:`~adcp.server.auth.BearerTokenAuth` config; raises + :class:`TypeError` for anything else so misconfiguration is loud at + boot, not silent at runtime. + + The middleware is applied *innermost* so its body-peek for the + JSON-RPC discovery bypass sees the payload before the path + normalizer / discovery wrapper / operator's ``asgi_middleware`` + layer reshape the request. + """ + if auth is None: + return app + from adcp.server.auth import BearerTokenAuth, BearerTokenAuthMiddleware + + if not isinstance(auth, BearerTokenAuth): + raise TypeError( + f"serve(auth=...) expects BearerTokenAuth, got {type(auth).__name__}. " + "Import from adcp.server.auth.BearerTokenAuth." + ) + + # FastMCP's ``streamable_http_app()`` returns a Starlette instance; + # ``add_middleware`` wraps the inner app in place and preserves + # FastMCP's lifespan + routing without a parallel Starlette. + app.add_middleware( + BearerTokenAuthMiddleware, + validate_token=auth.validate_token, + unauthenticated_response=auth.unauthenticated_response, + header_name=auth.header_name, + bearer_prefix_required=auth.bearer_prefix_required, + ) + return app + + +def _wrap_a2a_with_auth(app: Any, auth: BearerTokenAuth | None) -> Any: + """Wrap an A2A Starlette app with :class:`A2ABearerAuthMiddleware`. + + No-op when ``auth`` is ``None``. Returns the original app + untouched, so the A2A side falls back to a2a-sdk's default + (unauthenticated, agent-card publicly accessible) without any + middleware overhead. + + The middleware is wrapped at the ASGI layer (not via + ``Starlette.add_middleware``) so it sees the request before + a2a-sdk's JsonRpcDispatcher and v0.3 compat adapter — which + catch every exception including ``HTTPException`` and convert + them to JSON-RPC errors with HTTP 200. ASGI-layer wrapping + returns proper HTTP 401 every time. + + Same type guard as :func:`_wrap_mcp_with_auth` — a misconfig + that passes a dict / lambda / wrong type is loud at boot. + + Async validators are rejected at boot because the A2A leg's + middleware path is sync (the MCP middleware awaits async + validators transparently — A2A can't without restructuring + a2a-sdk's dispatcher). Catching the misuse at ``serve()`` time + instead of on the first request prevents production deployments + from shipping with silently-failing auth. + """ + if auth is None: + return app + import inspect as _inspect + + from adcp.server.auth import A2ABearerAuthMiddleware, BearerTokenAuth + + if not isinstance(auth, BearerTokenAuth): + raise TypeError( + f"serve(auth=...) expects BearerTokenAuth, got {type(auth).__name__}. " + "Import from adcp.server.auth.BearerTokenAuth." + ) + if _inspect.iscoroutinefunction(auth.validate_token): + raise TypeError( + "BearerTokenAuth.validate_token is async, which the A2A leg " + "cannot call directly — a2a-sdk's middleware path is sync. " + "Wrap your async validator with a sync bridge " + "(e.g. `lambda t: anyio.from_thread.run(my_async_validate, t)`) " + "before passing it to BearerTokenAuth, or use transport=" + "'streamable-http' (MCP middleware awaits async validators " + "transparently)." + ) + return A2ABearerAuthMiddleware(app, auth) + + def _wrap_with_discovery( app: Any, *, @@ -1013,6 +1114,7 @@ def _serve_mcp( allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, enable_dns_rebinding_protection: bool | None = None, + auth: BearerTokenAuth | None = None, ) -> None: """Start an MCP server.""" mcp = create_mcp_server( @@ -1052,9 +1154,18 @@ def _serve_mcp( discovery_base_url=base_url, discovery_specialisms=specialisms, discovery_description=description, + auth=auth, ) else: - # stdio — no listening socket, nothing to configure. + # stdio — no listening socket, no HTTP layer to authenticate. Auth + # over stdio doesn't apply (no Authorization header). Warn loudly + # rather than silently ignore so adopters notice the misconfig. + if auth is not None: + logger.warning( + "auth=BearerTokenAuth ignored on transport='stdio' — stdio " + "has no HTTP layer for bearer-token validation. Wire your " + "own out-of-band auth or use transport='streamable-http'." + ) if asgi_middleware: logger.warning( "asgi_middleware is ignored on transport='stdio'; " "ASGI middleware will not run" @@ -1072,6 +1183,7 @@ def _run_mcp_http( discovery_base_url: str | None = None, discovery_specialisms: list[str] | None = None, discovery_description: str | None = None, + auth: BearerTokenAuth | None = None, ) -> None: """Run FastMCP's HTTP transports with a pre-bound SO_REUSEADDR socket. @@ -1097,6 +1209,11 @@ def _run_mcp_http( resolved_base_url = resolve_base_url(host, port, discovery_base_url) + # Auth wraps innermost so the spec-mandated MCP discovery bypass + # (initialize / tools/list / get_adcp_capabilities) sees the + # JSON-RPC body before the path-normalizer / discovery wrapper / + # operator-supplied asgi_middleware get a turn. + app = _wrap_mcp_with_auth(app, auth) app = _wrap_with_path_normalize(app) app = _wrap_with_discovery( app, @@ -1154,6 +1271,7 @@ def _serve_a2a( base_url: str | None = None, specialisms: list[str] | None = None, description: str | None = None, + auth: BearerTokenAuth | None = None, ) -> None: """Start an A2A server using uvicorn.""" import uvicorn @@ -1178,6 +1296,11 @@ def _serve_a2a( advertise_all=advertise_all, validation=validation, ) + # Auth wraps the A2A app innermost (closer to the inner Starlette + # router than the discovery + size-limit + asgi_middleware + # wrappers) so bad tokens 401 before the request hits any + # operator-supplied layer. + app = _wrap_a2a_with_auth(app, auth) app = _wrap_with_discovery( app, name=name, @@ -1230,6 +1353,7 @@ def _build_mcp_and_a2a_app( allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, enable_dns_rebinding_protection: bool | None = None, + auth: BearerTokenAuth | None = None, ) -> Any: """Build the unified MCP+A2A ASGI app without starting a server. @@ -1278,6 +1402,18 @@ def _build_mcp_and_a2a_app( account_resolver=test_controller_account_resolver, ) mcp_inner = mcp.streamable_http_app() + # Auth wraps the FastMCP Starlette app *before* the path + # normalizer / dispatcher capture references. Wiring auth after + # ``mcp_app`` is captured by ``_dispatch`` would silently bypass + # the middleware on the MCP leg — the closure would already point + # at the unwrapped reference. + # + # Reassigning the return value (rather than relying on + # ``add_middleware``'s in-place mutation) future-proofs the call + # site: if a future refactor changes ``_wrap_mcp_with_auth`` to + # return a fresh ASGI callable, this line keeps wiring auth + # instead of silently dropping it. + mcp_inner = _wrap_mcp_with_auth(mcp_inner, auth) # Wrap with the standard trailing-slash normalizer so ``/mcp/`` # and ``/mcp`` resolve to the same FastMCP endpoint. Keep the # unwrapped ``mcp_inner`` reference so the lifespan composer @@ -1287,7 +1423,12 @@ def _build_mcp_and_a2a_app( # A2A app — built via the a2a-sdk wrapper. It mounts at the root # of its own app and handles ``/.well-known/agent.json``, ``/``, # and the message / push-notif endpoints. - a2a_app = create_a2a_server( + # + # Keep the unwrapped ``a2a_inner`` reference so the lifespan + # composer below can reach ``.router.lifespan_context``; wrap the + # dispatch reference separately so requests flow through auth on + # their way to the inner Starlette app. + a2a_inner = create_a2a_server( handler, name=name, port=port, @@ -1301,6 +1442,13 @@ def _build_mcp_and_a2a_app( advertise_all=advertise_all, validation=validation, ) + # Auth wraps both legs *before* ``_dispatch`` captures references — + # otherwise the closure points at unwrapped apps and auth is + # silently bypassed on whichever leg hadn't been wrapped yet. The + # MCP wrap above used ``add_middleware`` so it mutates in place; + # the A2A wrap returns a new ASGI callable layered on + # ``a2a_inner``. + a2a_app = _wrap_a2a_with_auth(a2a_inner, auth) # Lifespan composition: FastMCP's session manager initializes a # task group on startup; a2a-sdk's stores have their own init. @@ -1310,7 +1458,7 @@ def _build_mcp_and_a2a_app( @contextlib.asynccontextmanager async def _composed_lifespan(_app): # type: ignore[no-untyped-def] async with mcp_inner.router.lifespan_context(mcp_inner): - async with a2a_app.router.lifespan_context(a2a_app): + async with a2a_inner.router.lifespan_context(a2a_inner): yield parent = Starlette(lifespan=_composed_lifespan) @@ -1380,6 +1528,7 @@ def _serve_mcp_and_a2a( allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, enable_dns_rebinding_protection: bool | None = None, + auth: BearerTokenAuth | None = None, ) -> None: """Serve MCP and A2A on a single port via path dispatch. @@ -1425,6 +1574,7 @@ def _serve_mcp_and_a2a( allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, enable_dns_rebinding_protection=enable_dns_rebinding_protection, + auth=auth, ) app = _apply_asgi_middleware(app, asgi_middleware) diff --git a/tests/test_serve_auth_both.py b/tests/test_serve_auth_both.py new file mode 100644 index 00000000..ceb91d54 --- /dev/null +++ b/tests/test_serve_auth_both.py @@ -0,0 +1,678 @@ +"""Cross-transport auth coverage for ``serve(auth=BearerTokenAuth(...))``. + +Closes the regression filed in #558: bearer-token auth applied via the +existing ``BearerTokenAuthMiddleware`` to ``serve(transport="both")`` +left the A2A leg unauthenticated. The fix wires ``auth=`` into both +the MCP middleware and an A2A +:class:`~adcp.server.auth.A2ABearerAuthMiddleware`, with the agent +card publicly accessible per A2A spec §4.1. + +Three layers of coverage: + +1. **Unit** — :class:`A2ABearerAuthMiddleware` accepts/rejects via + the same shapes as :class:`BearerTokenAuthMiddleware`. +2. **A2A through ASGI** — full route-level test with + ``httpx.AsyncClient`` against the a2a-sdk-built Starlette app + wrapped in our auth middleware. +3. **transport="both"** — the regression case: hit MCP and A2A on + the same binary; both legs require auth, agent-card and MCP + discovery are exempt. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import httpx +import pytest +from asgi_lifespan import LifespanManager + +from adcp.server import ADCPHandler +from adcp.server.auth import ( + A2ABearerAuthMiddleware, + BearerTokenAuth, + Principal, + validator_from_token_map, +) + +# --------------------------------------------------------------------------- +# Handlers +# --------------------------------------------------------------------------- + + +class _OkHandler(ADCPHandler): + """Minimal handler returning structured success on get_products.""" + + async def get_adcp_capabilities(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]} + + async def get_products(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"products": [{"id": "p1", "name": "Display"}], "sandbox": True} + + +def _auth() -> BearerTokenAuth: + return BearerTokenAuth( + validate_token=validator_from_token_map( + {"good-token": Principal(caller_identity="p-acme", tenant_id="acme")} + ) + ) + + +# =========================================================================== +# Unit: A2ABearerAuthMiddleware against raw ASGI scope +# =========================================================================== + + +class TestA2ABearerAuthMiddlewareUnit: + """Middleware logic verified against raw ASGI scope dicts.""" + + def _scope(self, path: str = "/", headers: list[tuple[bytes, bytes]] | None = None) -> dict: + return { + "type": "http", + "method": "POST", + "path": path, + "headers": list(headers or []), + } + + @pytest.mark.asyncio + async def test_valid_token_passes_through_and_populates_scope_user(self): + inner_calls: list[dict] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + scope = self._scope(headers=[(b"authorization", b"Bearer good-token")]) + await mw(scope, lambda: None, lambda _: None) + + assert len(inner_calls) == 1 + passed_scope = inner_calls[0] + assert "user" in passed_scope + assert passed_scope["user"].is_authenticated is True + assert passed_scope["user"].display_name == "p-acme" + assert "auth" in passed_scope + assert passed_scope["auth"].caller_identity == "p-acme" + + @pytest.mark.asyncio + async def test_missing_header_returns_401(self): + sent: list[dict] = [] + inner_called = False + + async def inner(_scope: Any, _receive: Any, _send: Any) -> None: + nonlocal inner_called + inner_called = True + + async def send(msg: dict) -> None: + sent.append(msg) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + await mw(self._scope(), lambda: None, send) + + assert not inner_called # Auth failure short-circuits. + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 401 + # RFC 6750 default body shape — error code is ``invalid_token``. + assert b"invalid_token" in sent[1]["body"] + + @pytest.mark.asyncio + async def test_invalid_token_returns_401(self): + sent: list[dict] = [] + + async def inner(_scope: Any, _receive: Any, _send: Any) -> None: ... + + async def send(msg: dict) -> None: + sent.append(msg) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + await mw( + self._scope(headers=[(b"authorization", b"Bearer bad")]), + lambda: None, + send, + ) + assert sent[0]["status"] == 401 + + @pytest.mark.asyncio + async def test_validator_exception_projects_to_401_not_500(self): + """A buggy validator must not leak 500 stacks. We log + 401.""" + + def boom(_token: str) -> Principal | None: + raise RuntimeError("token store down") + + sent: list[dict] = [] + + async def inner(_scope: Any, _receive: Any, _send: Any) -> None: ... + + async def send(msg: dict) -> None: + sent.append(msg) + + mw = A2ABearerAuthMiddleware(inner, BearerTokenAuth(validate_token=boom)) + await mw( + self._scope(headers=[(b"authorization", b"Bearer x")]), + lambda: None, + send, + ) + assert sent[0]["status"] == 401 + + @pytest.mark.asyncio + async def test_agent_card_path_publicly_accessible(self): + """A2A spec §4.1 — ``/.well-known/agent-card.json`` MUST be + public regardless of auth config.""" + inner_calls: list[dict] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + # No Authorization header — MUST still pass through. + await mw(self._scope(path="/.well-known/agent-card.json"), lambda: None, lambda _: None) + assert len(inner_calls) == 1 + assert "user" not in inner_calls[0] # No principal injected on public route. + + @pytest.mark.asyncio + async def test_legacy_agent_json_path_also_exempt(self): + """``/.well-known/agent.json`` is the 0.3 alias retained by the + compat shim — exempt for the same spec reason.""" + inner_calls: list[dict] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + await mw(self._scope(path="/.well-known/agent.json"), lambda: None, lambda _: None) + assert len(inner_calls) == 1 + + @pytest.mark.asyncio + async def test_lifespan_scope_passes_through(self): + """Lifespan events bypass auth entirely — they're not HTTP.""" + inner_calls: list[Any] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + await mw({"type": "lifespan"}, lambda: None, lambda _: None) + assert len(inner_calls) == 1 + + @pytest.mark.asyncio + async def test_custom_header_name(self): + cfg = BearerTokenAuth( + validate_token=validator_from_token_map({"raw-key": Principal(caller_identity="p1")}), + header_name="x-adcp-auth", + bearer_prefix_required=False, + ) + inner_calls: list[dict] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, cfg) + await mw( + self._scope(headers=[(b"x-adcp-auth", b"raw-key")]), + lambda: None, + lambda _: None, + ) + assert inner_calls[0]["user"].display_name == "p1" + + +# =========================================================================== +# A2A through ASGI: full Starlette stack +# =========================================================================== + + +@pytest.mark.asyncio +async def test_a2a_agent_card_publicly_accessible_with_auth() -> None: + """End-to-end: ``/.well-known/agent-card.json`` MUST be public + even when auth is configured. Path-based exemption inside + :class:`A2ABearerAuthMiddleware` lets the request through to the + a2a-sdk's agent-card route.""" + from adcp.server.a2a_server import create_a2a_server + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.get("/.well-known/agent-card.json") + assert response.status_code == 200 + body = response.json() + assert "name" in body # Full agent card came back, not a 401 body. + + +@pytest.mark.asyncio +async def test_a2a_jsonrpc_unauthenticated_returns_http_401() -> None: + """No Authorization header on a JSON-RPC request → middleware + short-circuits with HTTP 401. Critical for spec compliance: + earlier designs that raised HTTPException from inside the + a2a-sdk dispatcher were swallowed by the v0.3 compat adapter + and projected to HTTP 200 with a JSON-RPC error body.""" + from adcp.server.a2a_server import create_a2a_server + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post("/", json=body) + assert response.status_code == 401 + body = response.json() + assert body["error"] == "invalid_token" + + +@pytest.mark.asyncio +async def test_a2a_jsonrpc_authenticated_passes_through() -> None: + """Valid bearer header → request reaches the handler.""" + from adcp.server.a2a_server import create_a2a_server + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/", json=body, headers={"Authorization": "Bearer good-token"} + ) + assert response.status_code == 200 + + +# =========================================================================== +# transport="both": the regression case from #558 +# =========================================================================== + + +def _build_both_app(auth: Any | None = None) -> Any: + """Build the unified MCP+A2A ASGI app via the same path + ``serve(transport="both")`` uses, but without uvicorn so we can + drive it through ``httpx.AsyncClient``.""" + from adcp.server.serve import _build_mcp_and_a2a_app + + return _build_mcp_and_a2a_app( + _OkHandler(), + name="test-agent", + port=0, + host="127.0.0.1", + instructions=None, + test_controller=None, + validation=None, + auth=auth, + ) + + +@pytest.mark.asyncio +async def test_both_transport_a2a_leg_requires_auth_when_configured() -> None: + """The original bug: ``serve(transport="both", auth=...)`` was + expected to gate both legs but didn't. This test asserts the A2A + leg now rejects unauthenticated JSON-RPC under the unified + binary.""" + app = _build_both_app(_auth()) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post("/", json=body) + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_both_transport_a2a_leg_accepts_valid_token() -> None: + """Auth is configured AND token is valid → A2A leg succeeds.""" + app = _build_both_app(_auth()) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/", json=body, headers={"Authorization": "Bearer good-token"} + ) + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_both_transport_agent_card_publicly_accessible() -> None: + """A2A discovery (``/.well-known/agent-card.json``) MUST be public + even with ``auth=`` configured.""" + app = _build_both_app(_auth()) + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.get("/.well-known/agent-card.json") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_both_transport_mcp_leg_requires_auth_when_configured() -> None: + """MCP leg under the unified binary still gates non-discovery + requests on a bearer token. Discovery methods (``initialize`` / + ``tools/list`` / ``get_adcp_capabilities``) bypass per + :class:`BearerTokenAuthMiddleware`'s body-peek logic.""" + app = _build_both_app(_auth()) + # tools/call without a token → 401 from the MCP middleware. + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "tools/call", + "params": {"name": "get_products", "arguments": {}}, + } + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://test", + follow_redirects=True, + ) as client: + response = await client.post( + "/mcp", + json=body, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 401 + assert "unauthenticated" in response.text + + +@pytest.mark.asyncio +async def test_both_transport_no_auth_runs_unauthenticated() -> None: + """Without ``auth=``, both legs accept everything (preserves the + pre-fix unauthenticated default — turning auth on is opt-in).""" + app = _build_both_app(auth=None) + body = { + "jsonrpc": "2.0", + "id": "1", + "method": "message/send", + "params": { + "message": { + "messageId": "m1", + "role": "user", + "parts": [{"kind": "data", "data": {"skill": "get_products", "parameters": {}}}], + } + }, + } + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post("/", json=body) + # No auth configured → A2A serves the request without checking. + assert response.status_code == 200 + + +# =========================================================================== +# Type-guard: serve(auth=...) rejects non-BearerTokenAuth +# =========================================================================== + + +def test_serve_auth_rejects_wrong_type_mcp() -> None: + from adcp.server.serve import _wrap_mcp_with_auth + + with pytest.raises(TypeError, match="BearerTokenAuth"): + _wrap_mcp_with_auth(MagicMock(), {"validate_token": lambda t: None}) + + +def test_serve_auth_rejects_wrong_type_a2a() -> None: + from adcp.server.serve import _wrap_a2a_with_auth + + with pytest.raises(TypeError, match="BearerTokenAuth"): + _wrap_a2a_with_auth(MagicMock(), "not-a-config") + + +def test_serve_auth_none_is_noop() -> None: + from adcp.server.serve import _wrap_a2a_with_auth, _wrap_mcp_with_auth + + sentinel = MagicMock() + assert _wrap_mcp_with_auth(sentinel, None) is sentinel + assert _wrap_a2a_with_auth(sentinel, None) is sentinel + + +def test_public_exports_include_new_symbols() -> None: + import adcp.server as srv + + assert hasattr(srv, "BearerTokenAuth") + assert hasattr(srv, "A2ABearerAuthMiddleware") + assert "BearerTokenAuth" in srv.__all__ + assert "A2ABearerAuthMiddleware" in srv.__all__ + + +# =========================================================================== +# Structural drift guard: a2a-sdk well-known route renames break loud +# =========================================================================== + + +def test_discovery_paths_match_a2a_sdk_routes() -> None: + """Catch silent drift between :data:`_A2A_DISCOVERY_PATHS` and + a2a-sdk's actual agent-card routes. If a future a2a-sdk release + renames ``/.well-known/agent-card.json`` (or removes the v0.3 + alias), the frozenset would leave the renamed route + unauthenticated until someone noticed. This test fails first. + + Walks ``create_agent_card_routes`` against a real ``AgentCard`` + and asserts every registered path is in the frozenset. + """ + from a2a.server.routes import create_agent_card_routes + + from adcp.server.a2a_server import _build_agent_card + from adcp.server.auth import _A2A_DISCOVERY_PATHS + + handler = _OkHandler() + agent_card = _build_agent_card( + handler, + name="drift-guard", + port=0, + description=None, + version="1.0.0", + extra_skills=None, + advertise_all=False, + push_notifications_supported=False, + ) + routes = create_agent_card_routes(agent_card=agent_card) + + registered_paths = [r.path for r in routes] + assert registered_paths, "a2a-sdk returned no agent-card routes" + + missing = [p for p in registered_paths if p not in _A2A_DISCOVERY_PATHS] + assert not missing, ( + f"a2a-sdk registers agent-card route(s) {missing!r} that are NOT in " + f"_A2A_DISCOVERY_PATHS={_A2A_DISCOVERY_PATHS!r}. Update the frozenset " + f"in adcp.server.auth to include the new path(s) — otherwise A2A " + f"discovery silently 401s on the renamed/added route." + ) + + +def test_a2a_agent_card_constant_referenced_directly() -> None: + """The 1.0 path uses ``a2a.utils.constants.AGENT_CARD_WELL_KNOWN_PATH`` + rather than a string literal. If a2a-sdk changes the constant, + our frozenset rebases without code changes. This test pins the + indirection so a future maintainer doesn't accidentally inline + the string.""" + from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + from adcp.server.auth import _A2A_DISCOVERY_PATHS + + assert AGENT_CARD_WELL_KNOWN_PATH in _A2A_DISCOVERY_PATHS + + +# =========================================================================== +# RFC 6750 / RFC 7235 compliance: 401 must carry WWW-Authenticate +# =========================================================================== + + +@pytest.mark.asyncio +async def test_401_includes_www_authenticate_header() -> None: + """RFC 7235 §3.1 + RFC 6750 §3 mandate ``WWW-Authenticate: Bearer`` + on 401 responses. Without it RFC-compliant clients (including + browsers) won't surface the auth challenge to the user.""" + from adcp.server.a2a_server import create_a2a_server + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + body = {"jsonrpc": "2.0", "id": "1", "method": "message/send", "params": {}} + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post("/", json=body) + assert response.status_code == 401 + challenge = response.headers.get("www-authenticate", "") + assert challenge.lower().startswith("bearer") + assert "realm" in challenge.lower() + + +@pytest.mark.asyncio +async def test_401_body_uses_rfc6750_error_codes() -> None: + """RFC 6750 §3.1 defines ``invalid_token`` / ``invalid_request`` / + ``insufficient_scope``. Default body uses ``invalid_token`` so + OAuth-aware tooling parses it correctly.""" + from adcp.server.a2a_server import create_a2a_server + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, _auth()) + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/", json={"jsonrpc": "2.0", "id": "1", "method": "message/send", "params": {}} + ) + assert response.status_code == 401 + body = response.json() + assert body.get("error") == "invalid_token" + + +# =========================================================================== +# CORS preflight: OPTIONS must bypass auth +# =========================================================================== + + +@pytest.mark.asyncio +async def test_options_preflight_bypasses_auth() -> None: + """Browser-origin clients send ``OPTIONS`` before any authenticated + POST. Returning 401 on the preflight breaks CORS — the buyer + never gets a chance to retry with a token. The middleware must + pass OPTIONS through unauthenticated.""" + inner_calls: list[dict] = [] + + async def inner(scope: Any, _receive: Any, _send: Any) -> None: + inner_calls.append(scope) + + mw = A2ABearerAuthMiddleware(inner, _auth()) + scope = { + "type": "http", + "method": "OPTIONS", + "path": "/", + "headers": [], + } + await mw(scope, lambda: None, lambda _: None) + assert len(inner_calls) == 1 # Inner reached. + assert "user" not in inner_calls[0] # No principal injected on preflight. + + +# =========================================================================== +# Async-validator rejection at boot, not at request time +# =========================================================================== + + +def test_async_validator_rejected_at_serve_boot_time() -> None: + """Async validators on A2A fail at config time so production + deployments don't ship with silently-failing auth that only + surfaces on first traffic. MCP middleware awaits async + validators transparently; A2A's middleware path is sync.""" + from adcp.server.serve import _wrap_a2a_with_auth + + async def async_validator(_token: str) -> Principal | None: + return Principal(caller_identity="p1") + + cfg = BearerTokenAuth(validate_token=async_validator) + with pytest.raises(TypeError, match="async"): + _wrap_a2a_with_auth(MagicMock(), cfg) + + +def test_sync_lambda_validator_passes_boot_check() -> None: + """Sync lambda / function validators are accepted unchanged.""" + from adcp.server.serve import _wrap_a2a_with_auth + + cfg = BearerTokenAuth(validate_token=lambda t: None) + # No exception — the wrap returns an A2ABearerAuthMiddleware instance. + wrapped = _wrap_a2a_with_auth(MagicMock(), cfg) + assert isinstance(wrapped, A2ABearerAuthMiddleware) + + +# =========================================================================== +# Validator-exception suppression survives the full ASGI stack +# =========================================================================== + + +@pytest.mark.asyncio +async def test_validator_exception_returns_401_through_full_stack() -> None: + """The unit-level test asserts the middleware short-circuits with + 401 when the validator raises. This test asserts the same shape + survives the full Starlette / a2a-sdk stack — i.e., the 500 + suppression isn't an artifact of the unit harness.""" + from adcp.server.a2a_server import create_a2a_server + + def boom(_token: str) -> Principal | None: + raise RuntimeError("token store down") + + inner = create_a2a_server(_OkHandler(), name="test-agent", validation=None) + app = A2ABearerAuthMiddleware(inner, BearerTokenAuth(validate_token=boom)) + async with LifespanManager(inner): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), base_url="http://test" + ) as client: + response = await client.post( + "/", + json={"jsonrpc": "2.0", "id": "1", "method": "message/send", "params": {}}, + headers={"Authorization": "Bearer x"}, + ) + assert response.status_code == 401 # Not 500.