From abb85dd2cc3bb9631c9911cc9a23c4fcd5aa1551 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 10:18:56 +0000 Subject: [PATCH 1/3] fix: propagate activated extensions for REST --- src/a2a/server/routes/rest_dispatcher.py | 67 +++++++++++++++--------- tests/integration/test_end_to_end.py | 54 +++++++++++++++++++ 2 files changed, 95 insertions(+), 26 deletions(-) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 8af384893..76ddb3819 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -6,6 +6,7 @@ from google.protobuf.json_format import MessageToDict, Parse +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ( @@ -99,14 +100,29 @@ def _build_call_context(self, request: Request) -> ServerCallContext: call_context.tenant = request.path_params['tenant'] return call_context + def _extension_headers(self, context: ServerCallContext) -> dict[str, str]: + """Builds response headers carrying the activated extensions, if any.""" + if exts := context.activated_extensions: + return {HTTP_EXTENSION_HEADER: ', '.join(sorted(exts))} + return {} + async def _handle_non_streaming( self, request: Request, handler_func: Callable[[ServerCallContext], Awaitable[TResponse]], - ) -> TResponse: - """Centralized error handling and context management for unary calls.""" + serializer: Callable[[TResponse], Any] = MessageToDict, + ) -> JSONResponse: + """Centralized error handling and context management for unary calls. + + Builds the call context, invokes the handler, and wraps the result in + a `JSONResponse` carrying any activated-extension headers. + """ context = self._build_call_context(request) - return await handler_func(context) + response = await handler_func(context) + return JSONResponse( + content=serializer(response), + headers=self._extension_headers(context), + ) async def _handle_streaming( self, @@ -137,7 +153,9 @@ async def _handle_streaming( try: first_item = await anext(stream) except StopAsyncIteration: - return EventSourceResponse(iter([])) + return EventSourceResponse( + iter([]), headers=self._extension_headers(context) + ) async def event_generator() -> AsyncIterator[ServerSentEvent]: yield ServerSentEvent(data=json.dumps(first_item)) @@ -151,7 +169,9 @@ async def event_generator() -> AsyncIterator[ServerSentEvent]: event='error', ) - return EventSourceResponse(event_generator()) + return EventSourceResponse( + event_generator(), headers=self._extension_headers(context) + ) @rest_error_handler async def on_message_send(self, request: Request) -> Response: @@ -171,8 +191,7 @@ async def _handler( return a2a_pb2.SendMessageResponse(task=task_or_message) return a2a_pb2.SendMessageResponse(message=task_or_message) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_stream_error_handler async def on_message_send_stream( @@ -209,8 +228,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task: return task raise TaskNotFoundError - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_stream_error_handler async def on_subscribe_to_task( @@ -245,8 +263,7 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task: return task raise TaskNotFoundError - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_error_handler async def get_push_notification(self, request: Request) -> Response: @@ -267,8 +284,7 @@ async def _handler( ) ) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_error_handler async def delete_push_notification(self, request: Request) -> Response: @@ -285,8 +301,9 @@ async def _handler(context: ServerCallContext) -> None: params, context ) - await self._handle_non_streaming(request, _handler) - return JSONResponse(content={}) + return await self._handle_non_streaming( + request, _handler, serializer=lambda _: {} + ) @rest_error_handler async def set_push_notification(self, request: Request) -> Response: @@ -304,8 +321,7 @@ async def _handler( params, context ) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_error_handler async def list_push_notifications(self, request: Request) -> Response: @@ -322,8 +338,7 @@ async def _handler( params, context ) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) @rest_error_handler async def list_tasks(self, request: Request) -> Response: @@ -337,11 +352,12 @@ async def _handler( proto_utils.parse_params(request.query_params, params) return await self.request_handler.on_list_tasks(params, context) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse( - content=MessageToDict( - response, always_print_fields_with_no_presence=True - ) + return await self._handle_non_streaming( + request, + _handler, + serializer=lambda r: MessageToDict( + r, always_print_fields_with_no_presence=True + ), ) @rest_error_handler @@ -359,5 +375,4 @@ async def _handler( params, context ) - response = await self._handle_non_streaming(request, _handler) - return JSONResponse(content=MessageToDict(response)) + return await self._handle_non_streaming(request, _handler) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index aea9784ad..f578a9d94 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -15,6 +15,7 @@ ServiceParametersFactory, with_a2a_extensions, ) +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -835,3 +836,56 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): response.message, Role.ROLE_AGENT, 'extensions echoed' ) assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'transport_fixture', + [ + pytest.param('rest_setup', id='REST'), + pytest.param('jsonrpc_setup', id='JSON-RPC'), + ], +) +async def test_end_to_end_extensions_response_header( + request, transport_fixture +): + """Test that activated extensions are returned in the X-A2A-Extensions + response header for HTTP-based transports.""" + setup = request.getfixturevalue(transport_fixture) + client = setup.client + client._config.streaming = False + + captured_headers: list[httpx.Headers] = [] + + async def capture_response(response: httpx.Response) -> None: + captured_headers.append(response.headers) + + client._transport.httpx_client.event_hooks['response'].append( + capture_response + ) + + service_params = ServiceParametersFactory.create( + [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] + ) + context = ClientCallContext(service_parameters=service_params) + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-ext-header', + parts=[Part(text='Extensions: echo')], + ) + + async for _ in client.send_message( + request=SendMessageRequest(message=message_to_send), + context=context, + ): + pass + + assert captured_headers, 'No HTTP response was captured' + response_headers = captured_headers[-1] + assert HTTP_EXTENSION_HEADER in response_headers + returned = { + ext.strip() + for ext in response_headers[HTTP_EXTENSION_HEADER].split(',') + } + assert returned == set(SUPPORTED_EXTENSION_URIS) From 76829dd512cd0d2b8d6a9b43055746d46d7ce6af Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 12:49:05 +0000 Subject: [PATCH 2/3] fix: cleanup extensions --- src/a2a/compat/v0_3/context_builders.py | 80 +++++++++ src/a2a/compat/v0_3/extension_headers.py | 27 +++ src/a2a/compat/v0_3/grpc_handler.py | 19 +-- src/a2a/compat/v0_3/grpc_transport.py | 5 +- src/a2a/compat/v0_3/jsonrpc_adapter.py | 3 +- src/a2a/compat/v0_3/jsonrpc_transport.py | 3 + src/a2a/compat/v0_3/rest_adapter.py | 3 +- src/a2a/compat/v0_3/rest_transport.py | 3 + src/a2a/extensions/common.py | 2 +- src/a2a/server/agent_execution/context.py | 10 +- src/a2a/server/context.py | 1 - .../server/request_handlers/grpc_handler.py | 15 -- src/a2a/server/routes/jsonrpc_dispatcher.py | 12 +- src/a2a/server/routes/rest_dispatcher.py | 24 +-- .../client/transports/test_jsonrpc_client.py | 4 +- tests/client/transports/test_rest_client.py | 6 +- tests/compat/v0_3/test_context_builders.py | 159 ++++++++++++++++++ tests/compat/v0_3/test_extension_headers.py | 39 +++++ tests/compat/v0_3/test_grpc_handler.py | 20 --- tests/compat/v0_3/test_grpc_transport.py | 28 +++ tests/compat/v0_3/test_jsonrpc_transport.py | 26 +++ .../test_client_server_integration.py | 4 +- tests/integration/test_end_to_end.py | 64 +------ tests/server/agent_execution/test_context.py | 8 +- .../request_handlers/test_grpc_handler.py | 38 +---- .../server/routes/test_jsonrpc_dispatcher.py | 25 --- 26 files changed, 400 insertions(+), 228 deletions(-) create mode 100644 src/a2a/compat/v0_3/context_builders.py create mode 100644 src/a2a/compat/v0_3/extension_headers.py create mode 100644 tests/compat/v0_3/test_context_builders.py create mode 100644 tests/compat/v0_3/test_extension_headers.py diff --git a/src/a2a/compat/v0_3/context_builders.py b/src/a2a/compat/v0_3/context_builders.py new file mode 100644 index 000000000..2f2eec362 --- /dev/null +++ b/src/a2a/compat/v0_3/context_builders.py @@ -0,0 +1,80 @@ +"""Context builders that add v0.3 backwards-compatibility for extensions. + +The current spec uses ``A2A-Extensions`` (RFC 6648, no ``X-`` prefix). v0.3 +clients still send the old ``X-A2A-Extensions`` name, so the v0.3 compat +adapters wrap the default builders with these classes to recognize both names. +""" + +from typing import TYPE_CHECKING, Any + +import grpc + +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import get_requested_extensions +from a2a.server.context import ServerCallContext + + +if TYPE_CHECKING: + from starlette.requests import Request + + from a2a.server.request_handlers.grpc_handler import ( + GrpcServerCallContextBuilder, + ) + from a2a.server.routes.common import ServerCallContextBuilder +else: + try: + from starlette.requests import Request + except ImportError: + Request = Any + + +def _get_legacy_grpc_extensions( + context: grpc.aio.ServicerContext, +) -> list[str]: + md = context.invocation_metadata() + if md is None: + return [] + lower_key = LEGACY_HTTP_EXTENSION_HEADER.lower() + return [ + e if isinstance(e, str) else e.decode('utf-8') + for k, e in md + if k.lower() == lower_key + ] + + +class V03ServerCallContextBuilder: + """Wraps a ServerCallContextBuilder to also accept the legacy header. + + Recognizes the v0.3 ``X-A2A-Extensions`` HTTP header in addition to the + spec ``A2A-Extensions``. + """ + + def __init__(self, inner: 'ServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, request: 'Request') -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension headers.""" + context = self._inner.build(request) + context.requested_extensions |= get_requested_extensions( + request.headers.getlist(LEGACY_HTTP_EXTENSION_HEADER) + ) + return context + + +class V03GrpcServerCallContextBuilder: + """Wraps a GrpcServerCallContextBuilder to also accept the legacy metadata. + + Recognizes the v0.3 ``X-A2A-Extensions`` gRPC metadata key in addition to + the spec ``A2A-Extensions``. + """ + + def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None: + self._inner = inner + + def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext, merging legacy extension metadata.""" + server_context = self._inner.build(context) + server_context.requested_extensions |= get_requested_extensions( + _get_legacy_grpc_extensions(context) + ) + return server_context diff --git a/src/a2a/compat/v0_3/extension_headers.py b/src/a2a/compat/v0_3/extension_headers.py new file mode 100644 index 000000000..e1421a0b0 --- /dev/null +++ b/src/a2a/compat/v0_3/extension_headers.py @@ -0,0 +1,27 @@ +"""Shared header name constants for v0.3 extension compatibility. + +The current spec uses ``A2A-Extensions``. v0.3 used the ``X-`` prefixed +``X-A2A-Extensions`` form. v0.3 compat servers and clients accept/emit both +names so they can interoperate with peers that only know the legacy one. +""" + +from a2a.client.service_parameters import ServiceParameters +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +LEGACY_HTTP_EXTENSION_HEADER = f'X-{HTTP_EXTENSION_HEADER}' + + +def add_legacy_extension_header(parameters: ServiceParameters) -> None: + """Mirrors the ``A2A-Extensions`` parameter under its legacy name in-place. + + Used by v0.3 compat client transports so that requests can be understood + by older v0.3 servers that only recognize ``X-A2A-Extensions``. + """ + if ( + HTTP_EXTENSION_HEADER in parameters + and LEGACY_HTTP_EXTENSION_HEADER not in parameters + ): + parameters[LEGACY_HTTP_EXTENSION_HEADER] = parameters[ + HTTP_EXTENSION_HEADER + ] diff --git a/src/a2a/compat/v0_3/grpc_handler.py b/src/a2a/compat/v0_3/grpc_handler.py index 23d1f831d..b7bec26ea 100644 --- a/src/a2a/compat/v0_3/grpc_handler.py +++ b/src/a2a/compat/v0_3/grpc_handler.py @@ -17,8 +17,8 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.context_builders import V03GrpcServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 -from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.grpc_handler import ( _ERROR_CODE_MAP, @@ -51,7 +51,7 @@ def __init__( DefaultCallContextBuilder is used. """ self.handler03 = RequestHandler03(request_handler=request_handler) - self._context_builder = ( + self._context_builder = V03GrpcServerCallContextBuilder( context_builder or DefaultGrpcServerCallContextBuilder() ) @@ -65,7 +65,6 @@ async def _handle_unary( try: server_context = self._context_builder.build(context) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -82,7 +81,6 @@ async def _handle_stream( server_context = self._context_builder.build(context) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -120,19 +118,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - (HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - async def SendMessage( self, request: a2a_v0_3_pb2.SendMessageRequest, diff --git a/src/a2a/compat/v0_3/grpc_transport.py b/src/a2a/compat/v0_3/grpc_transport.py index 32ce7f27b..95314e3f1 100644 --- a/src/a2a/compat/v0_3/grpc_transport.py +++ b/src/a2a/compat/v0_3/grpc_transport.py @@ -30,6 +30,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types import a2a_pb2 from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER from a2a.utils.telemetry import SpanKind, trace_class @@ -361,7 +362,9 @@ def _get_grpc_metadata( metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_0_3)] if context and context.service_parameters: - for key, value in context.service_parameters.items(): + params = dict(context.service_parameters) + add_legacy_extension_header(params) + for key, value in params.items(): metadata.append((key.lower(), value)) return metadata diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index baa2bcda8..8b4b26a79 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -24,6 +24,7 @@ _package_starlette_installed = False from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.request_handler import RequestHandler03 from a2a.server.context import ServerCallContext from a2a.server.jsonrpc_models import ( @@ -70,7 +71,7 @@ def __init__( self.handler = RequestHandler03( request_handler=http_handler, ) - self._context_builder = ( + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 557a63a16..caccd2811 100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -19,6 +19,7 @@ ) from a2a.compat.v0_3 import conversions from a2a.compat.v0_3 import types as types_v03 +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -424,6 +425,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -485,6 +487,7 @@ async def _send_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( 'POST', diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index a2a9b56ee..38687054f 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -31,6 +31,7 @@ _package_starlette_installed = False +from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder from a2a.compat.v0_3.rest_handler import REST03Handler from a2a.server.routes.common import ( DefaultServerCallContextBuilder, @@ -60,7 +61,7 @@ def __init__( context_builder: 'ServerCallContextBuilder | None' = None, ): self.handler = REST03Handler(request_handler=http_handler) - self._context_builder = ( + self._context_builder = V03ServerCallContextBuilder( context_builder or DefaultServerCallContextBuilder() ) diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 0ba38538d..bcaed2949 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -25,6 +25,7 @@ from a2a.compat.v0_3 import ( types as types_v03, ) +from a2a.compat.v0_3.extension_headers import add_legacy_extension_header from a2a.types.a2a_pb2 import ( AgentCard, CancelTaskRequest, @@ -380,6 +381,7 @@ async def _send_stream_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) async for sse_data in send_http_stream_request( self.httpx_client, @@ -414,6 +416,7 @@ async def _execute_request( http_kwargs = get_http_args(context) http_kwargs.setdefault('headers', {}) http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3 + add_legacy_extension_header(http_kwargs['headers']) request = self.httpx_client.build_request( method, diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 0595216ed..06ccf8f40 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,7 +1,7 @@ from a2a.types.a2a_pb2 import AgentCard, AgentExtension -HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +HTTP_EXTENSION_HEADER = 'A2A-Extensions' def get_requested_extensions(values: list[str]) -> set[str]: diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 1feefb1df..910475e90 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -151,14 +151,6 @@ def metadata(self) -> dict[str, Any]: return dict(self._params.metadata) return {} - def add_activated_extension(self, uri: str) -> None: - """Add an extension to the set of activated extensions for this request. - - This causes the extension to be indicated back to the client in the - response. - """ - self._call_context.activated_extensions.add(uri) - @property def tenant(self) -> str: """The tenant associated with this request.""" @@ -166,7 +158,7 @@ def tenant(self) -> str: @property def requested_extensions(self) -> set[str]: - """Extensions that the client requested to activate.""" + """Extensions that the client requested for this interaction.""" return self._call_context.requested_extensions def _check_or_generate_task_id(self) -> None: diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py index 6196a69d6..833ca44c4 100644 --- a/src/a2a/server/context.py +++ b/src/a2a/server/context.py @@ -23,4 +23,3 @@ class ServerCallContext(BaseModel): user: User = Field(default_factory=UnauthenticatedUser) tenant: str = Field(default='') requested_extensions: set[str] = Field(default_factory=set) - activated_extensions: set[str] = Field(default_factory=set) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 2ccfa9bdd..8cd421e93 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -135,7 +135,6 @@ async def _handle_unary( try: server_context = self._build_call_context(context, request) result = await handler_func(server_context) - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) else: @@ -153,7 +152,6 @@ async def _handle_stream( server_context = self._build_call_context(context, request) async for item in handler_func(server_context): yield item - self._set_extension_metadata(context, server_context) except A2AError as e: await self.abort_context(e, context) @@ -422,19 +420,6 @@ async def abort_context( f'Unknown error type: {error}', ) - def _set_extension_metadata( - self, - context: grpc.aio.ServicerContext, - server_context: ServerCallContext, - ) -> None: - if server_context.activated_extensions: - context.set_trailing_metadata( - [ - (HTTP_EXTENSION_HEADER.lower(), e) - for e in sorted(server_context.activated_extensions) - ] - ) - def _build_call_context( self, context: grpc.aio.ServicerContext, diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 60620081a..3dc94488a 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -11,9 +11,6 @@ from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter -from a2a.extensions.common import ( - HTTP_EXTENSION_HEADER, -) from a2a.server.context import ServerCallContext from a2a.server.events import Event from a2a.server.jsonrpc_models import ( @@ -570,9 +567,6 @@ def _create_response( Returns: A Starlette JSONResponse or EventSourceResponse. """ - headers = {} - if exts := context.activated_extensions: - headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts)) if isinstance(handler_result, AsyncGenerator): # Result is a stream of dict objects async def event_generator( @@ -603,9 +597,7 @@ async def event_generator( 'data': json.dumps(error_response), } - return EventSourceResponse( - event_generator(handler_result), headers=headers - ) + return EventSourceResponse(event_generator(handler_result)) # handler_result is a dict (JSON-RPC response) - return JSONResponse(handler_result, headers=headers) + return JSONResponse(handler_result) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index 76ddb3819..a699bf897 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -6,7 +6,6 @@ from google.protobuf.json_format import MessageToDict, Parse -from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.routes.common import ( @@ -100,12 +99,6 @@ def _build_call_context(self, request: Request) -> ServerCallContext: call_context.tenant = request.path_params['tenant'] return call_context - def _extension_headers(self, context: ServerCallContext) -> dict[str, str]: - """Builds response headers carrying the activated extensions, if any.""" - if exts := context.activated_extensions: - return {HTTP_EXTENSION_HEADER: ', '.join(sorted(exts))} - return {} - async def _handle_non_streaming( self, request: Request, @@ -114,15 +107,12 @@ async def _handle_non_streaming( ) -> JSONResponse: """Centralized error handling and context management for unary calls. - Builds the call context, invokes the handler, and wraps the result in - a `JSONResponse` carrying any activated-extension headers. + Builds the call context, invokes the handler, and wraps the serialized + result in a `JSONResponse`. """ context = self._build_call_context(request) response = await handler_func(context) - return JSONResponse( - content=serializer(response), - headers=self._extension_headers(context), - ) + return JSONResponse(content=serializer(response)) async def _handle_streaming( self, @@ -153,9 +143,7 @@ async def _handle_streaming( try: first_item = await anext(stream) except StopAsyncIteration: - return EventSourceResponse( - iter([]), headers=self._extension_headers(context) - ) + return EventSourceResponse(iter([])) async def event_generator() -> AsyncIterator[ServerSentEvent]: yield ServerSentEvent(data=json.dumps(first_item)) @@ -169,9 +157,7 @@ async def event_generator() -> AsyncIterator[ServerSentEvent]: event='error', ) - return EventSourceResponse( - event_generator(), headers=self._extension_headers(context) - ) + return EventSourceResponse(event_generator()) @rest_error_handler async def on_message_send(self, request: Request) -> Response: diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 1339bb8af..b005c2e05 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -545,7 +545,7 @@ async def test_extensions_added_to_request( from a2a.client.client import ClientCallContext context = ClientCallContext( - service_parameters={'X-A2A-Extensions': 'https://example.com/ext1'} + service_parameters={'A2A-Extensions': 'https://example.com/ext1'} ) await transport.send_message(request, context=context) @@ -555,7 +555,7 @@ async def test_extensions_added_to_request( call_args = mock_httpx_client.build_request.call_args # Extensions should be in the kwargs assert ( - call_args[1].get('headers', {}).get('X-A2A-Extensions') + call_args[1].get('headers', {}).get('A2A-Extensions') == 'https://example.com/ext1' ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index e7912566e..5798bf850 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -263,7 +263,7 @@ async def test_send_message_with_default_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' } ) await client.send_message(request=params, context=context) @@ -287,7 +287,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" client = RestTransport( httpx_client=mock_httpx_client, agent_card=mock_agent_card, @@ -309,7 +309,7 @@ async def test_send_message_streaming_with_new_extensions( context = ClientCallContext( service_parameters={ - 'X-A2A-Extensions': 'https://example.com/test-ext/v2' + 'A2A-Extensions': 'https://example.com/test-ext/v2' } ) diff --git a/tests/compat/v0_3/test_context_builders.py b/tests/compat/v0_3/test_context_builders.py new file mode 100644 index 000000000..1b711f52f --- /dev/null +++ b/tests/compat/v0_3/test_context_builders.py @@ -0,0 +1,159 @@ +from unittest.mock import AsyncMock, MagicMock + +import grpc + +from starlette.datastructures import Headers + +from a2a.compat.v0_3.context_builders import ( + V03GrpcServerCallContextBuilder, + V03ServerCallContextBuilder, +) +from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.grpc_handler import ( + DefaultGrpcServerCallContextBuilder, +) +from a2a.server.routes.common import DefaultServerCallContextBuilder + + +def _make_mock_request(headers=None): + request = MagicMock() + request.scope = {} + request.headers = Headers(headers or {}) + return request + + +def _make_mock_grpc_context(metadata: list[tuple[str, str]]) -> AsyncMock: + context = AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = grpc.aio.Metadata(*metadata) + return context + + +class TestV03ServerCallContextBuilder: + def test_legacy_header_only(self): + request = _make_mock_request( + headers={LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_header_only(self): + request = _make_mock_request( + headers={HTTP_EXTENSION_HEADER: 'spec-ext'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_headers_merged(self): + request = _make_mock_request( + headers={ + HTTP_EXTENSION_HEADER: 'spec-ext', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-ext', + } + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_header_comma_separated(self): + request = _make_mock_request( + headers={LEGACY_HTTP_EXTENSION_HEADER: 'foo, bar'} + ) + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + request = _make_mock_request() + builder = V03ServerCallContextBuilder(DefaultServerCallContextBuilder()) + + ctx = builder.build(request) + + assert ctx.requested_extensions == set() + + +class TestV03GrpcServerCallContextBuilder: + def test_legacy_metadata_only(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert isinstance(ctx, ServerCallContext) + assert ctx.requested_extensions == {'legacy-ext'} + + def test_spec_metadata_only(self): + context = _make_mock_grpc_context( + [(HTTP_EXTENSION_HEADER.lower(), 'spec-ext')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext'} + + def test_both_metadata_merged(self): + context = _make_mock_grpc_context( + [ + (HTTP_EXTENSION_HEADER.lower(), 'spec-ext'), + (LEGACY_HTTP_EXTENSION_HEADER.lower(), 'legacy-ext'), + ] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'spec-ext', 'legacy-ext'} + + def test_legacy_metadata_comma_separated(self): + context = _make_mock_grpc_context( + [(LEGACY_HTTP_EXTENSION_HEADER.lower(), 'foo, bar')] + ) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == {'foo', 'bar'} + + def test_no_extensions(self): + context = _make_mock_grpc_context([]) + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() + + def test_no_metadata(self): + context = AsyncMock(spec=grpc.aio.ServicerContext) + context.invocation_metadata.return_value = None + builder = V03GrpcServerCallContextBuilder( + DefaultGrpcServerCallContextBuilder() + ) + + ctx = builder.build(context) + + assert ctx.requested_extensions == set() diff --git a/tests/compat/v0_3/test_extension_headers.py b/tests/compat/v0_3/test_extension_headers.py new file mode 100644 index 000000000..d5abbdfcc --- /dev/null +++ b/tests/compat/v0_3/test_extension_headers.py @@ -0,0 +1,39 @@ +from a2a.compat.v0_3.extension_headers import ( + LEGACY_HTTP_EXTENSION_HEADER, + add_legacy_extension_header, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_legacy_header_constant_value(): + assert LEGACY_HTTP_EXTENSION_HEADER == 'X-A2A-Extensions' + + +def test_mirrors_spec_header_under_legacy_name(): + params = {HTTP_EXTENSION_HEADER: 'foo,bar'} + + add_legacy_extension_header(params) + + assert params == { + HTTP_EXTENSION_HEADER: 'foo,bar', + LEGACY_HTTP_EXTENSION_HEADER: 'foo,bar', + } + + +def test_no_op_when_spec_header_absent(): + params = {'Other': 'value'} + + add_legacy_extension_header(params) + + assert params == {'Other': 'value'} + + +def test_does_not_overwrite_existing_legacy_header(): + params = { + HTTP_EXTENSION_HEADER: 'spec', + LEGACY_HTTP_EXTENSION_HEADER: 'legacy-original', + } + + add_legacy_extension_header(params) + + assert params[LEGACY_HTTP_EXTENSION_HEADER] == 'legacy-original' diff --git a/tests/compat/v0_3/test_grpc_handler.py b/tests/compat/v0_3/test_grpc_handler.py index 75c6421e8..fbd74f29f 100644 --- a/tests/compat/v0_3/test_grpc_handler.py +++ b/tests/compat/v0_3/test_grpc_handler.py @@ -7,8 +7,6 @@ a2a_v0_3_pb2, grpc_handler as compat_grpc_handler, ) -from a2a.extensions.common import HTTP_EXTENSION_HEADER -from a2a.server.context import ServerCallContext from a2a.server.request_handlers import RequestHandler from a2a.types import a2a_pb2 from a2a.utils.errors import TaskNotFoundError, InvalidParamsError @@ -506,21 +504,3 @@ async def test_extract_task_and_config_id_invalid( ): with pytest.raises(InvalidParamsError): handler._extract_task_and_config_id('invalid-name') - - -@pytest.mark.asyncio -async def test_handle_unary_extension_metadata( - handler: compat_grpc_handler.CompatGrpcHandler, - mock_request_handler: AsyncMock, - mock_grpc_context: AsyncMock, -) -> None: - async def mock_func(server_context: ServerCallContext): - server_context.activated_extensions.add('ext-1') - return a2a_pb2.Task() - - await handler._handle_unary(mock_grpc_context, mock_func, a2a_pb2.Task()) - - expected_metadata = [(HTTP_EXTENSION_HEADER.lower(), 'ext-1')] - mock_grpc_context.set_trailing_metadata.assert_called_once_with( - expected_metadata - ) diff --git a/tests/compat/v0_3/test_grpc_transport.py b/tests/compat/v0_3/test_grpc_transport.py index ba1e6af3d..402a57000 100644 --- a/tests/compat/v0_3/test_grpc_transport.py +++ b/tests/compat/v0_3/test_grpc_transport.py @@ -2,6 +2,7 @@ import pytest +from a2a.client.client import ClientCallContext from a2a.client.optionals import Channel from a2a.compat.v0_3 import a2a_v0_3_pb2 from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport @@ -38,3 +39,30 @@ async def test_compat_grpc_transport_send_message_response_msg_parsing(): assert isinstance(response, SendMessageResponse) assert response.HasField('message') assert response.message.message_id == 'msg-123' + + +def test_compat_grpc_transport_mirrors_extension_metadata(): + """Compat gRPC client must also emit the legacy x-a2a-extensions metadata + so that v0.3 servers (which only know that name) understand the request.""" + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + metadata = dict(transport._get_grpc_metadata(context)) + + assert metadata['a2a-extensions'] == 'foo,bar' + assert metadata['x-a2a-extensions'] == 'foo,bar' + + +def test_compat_grpc_transport_no_extension_metadata(): + transport = CompatGrpcTransport( + channel=AsyncMock(spec=Channel), agent_card=None + ) + + metadata = dict(transport._get_grpc_metadata(None)) + + assert 'a2a-extensions' not in metadata + assert 'x-a2a-extensions' not in metadata diff --git a/tests/compat/v0_3/test_jsonrpc_transport.py b/tests/compat/v0_3/test_jsonrpc_transport.py index 50b33e162..70291f005 100644 --- a/tests/compat/v0_3/test_jsonrpc_transport.py +++ b/tests/compat/v0_3/test_jsonrpc_transport.py @@ -539,3 +539,29 @@ async def test_compat_jsonrpc_transport_send_request( mock_send_http_request.assert_called_once_with( transport.httpx_client, mock_request, transport._handle_http_error ) + + +@pytest.mark.asyncio +@patch('a2a.compat.v0_3.jsonrpc_transport.send_http_request') +async def test_compat_jsonrpc_transport_mirrors_extension_header( + mock_send_http_request, transport +): + """Compat client must also emit the legacy X-A2A-Extensions header so + that v0.3 servers (which only know that name) understand the request.""" + from a2a.client.client import ClientCallContext + + mock_send_http_request.return_value = {'result': {'ok': True}} + transport.httpx_client.build_request.return_value = httpx.Request( + 'POST', 'http://example.com' + ) + + context = ClientCallContext( + service_parameters={'A2A-Extensions': 'foo,bar'} + ) + + await transport._send_request({'some': 'data'}, context=context) + + _, kwargs = transport.httpx_client.build_request.call_args + headers = kwargs['headers'] + assert headers['A2A-Extensions'] == 'foo,bar' + assert headers['X-A2A-Extensions'] == 'foo,bar' diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 1ac8a7162..76da2e20f 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -675,9 +675,9 @@ async def test_json_transport_base_client_send_message_with_extensions( call_args[1] if len(call_args) > 1 else call_kwargs.get('context') ) service_params = getattr(called_context, 'service_parameters', {}) - assert 'X-A2A-Extensions' in service_params + assert 'A2A-Extensions' in service_params assert ( - service_params['X-A2A-Extensions'] + service_params['A2A-Extensions'] == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' ) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index f578a9d94..ccabacbc1 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -15,7 +15,6 @@ ServiceParametersFactory, with_a2a_extensions, ) -from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager @@ -100,19 +99,15 @@ class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): user_input = context.get_user_input() - # Extensions echo: activate all requested extensions and report them - # back via the Message.extensions field. + # Extensions echo: report the requested extensions back to the client + # via the Message.extensions field. if user_input.startswith('Extensions:'): - for ext_uri in context.requested_extensions: - context.add_activated_extension(ext_uri) await event_queue.enqueue_event( Message( role=Role.ROLE_AGENT, message_id='ext-reply-1', parts=[Part(text='extensions echoed')], - extensions=sorted( - context.call_context.activated_extensions - ), + extensions=sorted(context.requested_extensions), ) ) return @@ -836,56 +831,3 @@ async def test_end_to_end_extensions_propagation(transport_setups, streaming): response.message, Role.ROLE_AGENT, 'extensions echoed' ) assert set(response.message.extensions) == set(SUPPORTED_EXTENSION_URIS) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - 'transport_fixture', - [ - pytest.param('rest_setup', id='REST'), - pytest.param('jsonrpc_setup', id='JSON-RPC'), - ], -) -async def test_end_to_end_extensions_response_header( - request, transport_fixture -): - """Test that activated extensions are returned in the X-A2A-Extensions - response header for HTTP-based transports.""" - setup = request.getfixturevalue(transport_fixture) - client = setup.client - client._config.streaming = False - - captured_headers: list[httpx.Headers] = [] - - async def capture_response(response: httpx.Response) -> None: - captured_headers.append(response.headers) - - client._transport.httpx_client.event_hooks['response'].append( - capture_response - ) - - service_params = ServiceParametersFactory.create( - [with_a2a_extensions(SUPPORTED_EXTENSION_URIS)] - ) - context = ClientCallContext(service_parameters=service_params) - - message_to_send = Message( - role=Role.ROLE_USER, - message_id='msg-ext-header', - parts=[Part(text='Extensions: echo')], - ) - - async for _ in client.send_message( - request=SendMessageRequest(message=message_to_send), - context=context, - ): - pass - - assert captured_headers, 'No HTTP response was captured' - response_headers = captured_headers[-1] - assert HTTP_EXTENSION_HEADER in response_headers - returned = { - ext.strip() - for ext in response_headers[HTTP_EXTENSION_HEADER].split(',') - } - assert returned == set(SUPPORTED_EXTENSION_URIS) diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 7ec612986..dce780f58 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -322,14 +322,8 @@ def test_init_with_context_id_and_existing_context_id_match( assert context.current_task == mock_task def test_extension_handling(self) -> None: - """Test extension handling in RequestContext.""" + """Test that requested_extensions is exposed via RequestContext.""" call_context = ServerCallContext(requested_extensions={'foo', 'bar'}) context = RequestContext(call_context=call_context) assert context.requested_extensions == {'foo', 'bar'} - - context.add_activated_extension('foo') - assert call_context.activated_extensions == {'foo'} - - context.add_activated_extension('baz') - assert call_context.activated_extensions == {'foo', 'baz'} diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 2b1a37385..d140d3d7b 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -421,19 +421,11 @@ async def test_send_message_with_extensions( (HTTP_EXTENSION_HEADER.lower(), 'foo'), (HTTP_EXTENSION_HEADER.lower(), 'bar'), ) - - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return types.Task( - id='task-1', - context_id='ctx-1', - status=types.TaskStatus( - state=types.TaskState.TASK_STATE_COMPLETED - ), - ) - - mock_request_handler.on_message_send.side_effect = side_effect + mock_request_handler.on_message_send.return_value = types.Task( + id='task-1', + context_id='ctx-1', + status=types.TaskStatus(state=types.TaskState.TASK_STATE_COMPLETED), + ) await grpc_handler.SendMessage( a2a_pb2.SendMessageRequest(), mock_grpc_context @@ -444,15 +436,6 @@ def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - async def test_send_message_with_comma_separated_extensions( self, grpc_handler: GrpcHandler, @@ -490,8 +473,6 @@ async def test_send_streaming_message_with_extensions( ) async def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') yield types.Task( id='task-1', context_id='ctx-1', @@ -517,15 +498,6 @@ async def side_effect(request, context: ServerCallContext): assert isinstance(call_context, ServerCallContext) assert call_context.requested_extensions == {'foo', 'bar'} - mock_grpc_context.set_trailing_metadata.assert_called_once() - called_metadata = ( - mock_grpc_context.set_trailing_metadata.call_args.args[0] - ) - assert set(called_metadata) == { - (HTTP_EXTENSION_HEADER.lower(), 'foo'), - (HTTP_EXTENSION_HEADER.lower(), 'baz'), - } - @pytest.mark.asyncio class TestTenantExtraction: diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 15d3349cd..7ce73eb2e 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -169,31 +169,6 @@ def test_method_added_to_call_context_state(self, client, mock_handler): call_context = mock_handler.on_message_send.call_args[0][1] assert call_context.state['method'] == 'SendMessage' - def test_response_with_activated_extensions(self, client, mock_handler): - def side_effect(request, context: ServerCallContext): - context.activated_extensions.add('foo') - context.activated_extensions.add('baz') - return Message( - message_id='test', - role=Role.ROLE_AGENT, - parts=[Part(text='response message')], - ) - - mock_handler.on_message_send.side_effect = side_effect - - response = client.post( - '/', - json=_make_send_message_request(), - ) - response.raise_for_status() - - assert response.status_code == 200 - assert HTTP_EXTENSION_HEADER in response.headers - assert set(response.headers[HTTP_EXTENSION_HEADER].split(', ')) == { - 'foo', - 'baz', - } - class TestJsonRpcDispatcherTenant: def test_tenant_extraction_from_params(self, client, mock_handler): From 44358a1493cb9b5508ce8f5a876cbdff618d1d12 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 17 Apr 2026 12:53:53 +0000 Subject: [PATCH 3/3] Update --- src/a2a/server/routes/rest_dispatcher.py | 49 ++++++++++++------------ 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py index a699bf897..8af384893 100644 --- a/src/a2a/server/routes/rest_dispatcher.py +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -103,16 +103,10 @@ async def _handle_non_streaming( self, request: Request, handler_func: Callable[[ServerCallContext], Awaitable[TResponse]], - serializer: Callable[[TResponse], Any] = MessageToDict, - ) -> JSONResponse: - """Centralized error handling and context management for unary calls. - - Builds the call context, invokes the handler, and wraps the serialized - result in a `JSONResponse`. - """ + ) -> TResponse: + """Centralized error handling and context management for unary calls.""" context = self._build_call_context(request) - response = await handler_func(context) - return JSONResponse(content=serializer(response)) + return await handler_func(context) async def _handle_streaming( self, @@ -177,7 +171,8 @@ async def _handler( return a2a_pb2.SendMessageResponse(task=task_or_message) return a2a_pb2.SendMessageResponse(message=task_or_message) - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_stream_error_handler async def on_message_send_stream( @@ -214,7 +209,8 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task: return task raise TaskNotFoundError - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_stream_error_handler async def on_subscribe_to_task( @@ -249,7 +245,8 @@ async def _handler(context: ServerCallContext) -> a2a_pb2.Task: return task raise TaskNotFoundError - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler async def get_push_notification(self, request: Request) -> Response: @@ -270,7 +267,8 @@ async def _handler( ) ) - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler async def delete_push_notification(self, request: Request) -> Response: @@ -287,9 +285,8 @@ async def _handler(context: ServerCallContext) -> None: params, context ) - return await self._handle_non_streaming( - request, _handler, serializer=lambda _: {} - ) + await self._handle_non_streaming(request, _handler) + return JSONResponse(content={}) @rest_error_handler async def set_push_notification(self, request: Request) -> Response: @@ -307,7 +304,8 @@ async def _handler( params, context ) - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler async def list_push_notifications(self, request: Request) -> Response: @@ -324,7 +322,8 @@ async def _handler( params, context ) - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) @rest_error_handler async def list_tasks(self, request: Request) -> Response: @@ -338,12 +337,11 @@ async def _handler( proto_utils.parse_params(request.query_params, params) return await self.request_handler.on_list_tasks(params, context) - return await self._handle_non_streaming( - request, - _handler, - serializer=lambda r: MessageToDict( - r, always_print_fields_with_no_presence=True - ), + response = await self._handle_non_streaming(request, _handler) + return JSONResponse( + content=MessageToDict( + response, always_print_fields_with_no_presence=True + ) ) @rest_error_handler @@ -361,4 +359,5 @@ async def _handler( params, context ) - return await self._handle_non_streaming(request, _handler) + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response))