Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/a2a/compat/v0_3/context_builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Context builders that add v0.3 backwards-compatibility for extensions.

The current spec uses ``A2A-Extensions`` (RFC 6648, no ``X-`` prefix). v0.3
clients still send the old ``X-A2A-Extensions`` name, so the v0.3 compat
adapters wrap the default builders with these classes to recognize both names.
"""

from typing import TYPE_CHECKING, Any

import grpc

from a2a.compat.v0_3.extension_headers import LEGACY_HTTP_EXTENSION_HEADER
from a2a.extensions.common import get_requested_extensions
from a2a.server.context import ServerCallContext


if TYPE_CHECKING:
from starlette.requests import Request

from a2a.server.request_handlers.grpc_handler import (
GrpcServerCallContextBuilder,
)
from a2a.server.routes.common import ServerCallContextBuilder
else:
try:
from starlette.requests import Request
except ImportError:
Request = Any


def _get_legacy_grpc_extensions(
context: grpc.aio.ServicerContext,
) -> list[str]:
md = context.invocation_metadata()
if md is None:
return []
lower_key = LEGACY_HTTP_EXTENSION_HEADER.lower()
return [
e if isinstance(e, str) else e.decode('utf-8')
for k, e in md
if k.lower() == lower_key
]


class V03ServerCallContextBuilder:
"""Wraps a ServerCallContextBuilder to also accept the legacy header.

Recognizes the v0.3 ``X-A2A-Extensions`` HTTP header in addition to the
spec ``A2A-Extensions``.
"""

def __init__(self, inner: 'ServerCallContextBuilder') -> None:
self._inner = inner

def build(self, request: 'Request') -> ServerCallContext:
"""Builds a ServerCallContext, merging legacy extension headers."""
context = self._inner.build(request)
context.requested_extensions |= get_requested_extensions(
request.headers.getlist(LEGACY_HTTP_EXTENSION_HEADER)
)
return context


class V03GrpcServerCallContextBuilder:
"""Wraps a GrpcServerCallContextBuilder to also accept the legacy metadata.

Recognizes the v0.3 ``X-A2A-Extensions`` gRPC metadata key in addition to
the spec ``A2A-Extensions``.
"""

def __init__(self, inner: 'GrpcServerCallContextBuilder') -> None:
self._inner = inner

def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
"""Builds a ServerCallContext, merging legacy extension metadata."""
server_context = self._inner.build(context)
server_context.requested_extensions |= get_requested_extensions(
_get_legacy_grpc_extensions(context)
)
return server_context
27 changes: 27 additions & 0 deletions src/a2a/compat/v0_3/extension_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Shared header name constants for v0.3 extension compatibility.

The current spec uses ``A2A-Extensions``. v0.3 used the ``X-`` prefixed
``X-A2A-Extensions`` form. v0.3 compat servers and clients accept/emit both
names so they can interoperate with peers that only know the legacy one.
"""

from a2a.client.service_parameters import ServiceParameters
from a2a.extensions.common import HTTP_EXTENSION_HEADER


LEGACY_HTTP_EXTENSION_HEADER = f'X-{HTTP_EXTENSION_HEADER}'


def add_legacy_extension_header(parameters: ServiceParameters) -> None:
"""Mirrors the ``A2A-Extensions`` parameter under its legacy name in-place.

Used by v0.3 compat client transports so that requests can be understood
by older v0.3 servers that only recognize ``X-A2A-Extensions``.
"""
if (
HTTP_EXTENSION_HEADER in parameters
and LEGACY_HTTP_EXTENSION_HEADER not in parameters
):
parameters[LEGACY_HTTP_EXTENSION_HEADER] = parameters[
HTTP_EXTENSION_HEADER
]
19 changes: 2 additions & 17 deletions src/a2a/compat/v0_3/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from a2a.compat.v0_3 import (
types as types_v03,
)
from a2a.compat.v0_3.context_builders import V03GrpcServerCallContextBuilder
from a2a.compat.v0_3.request_handler import RequestHandler03
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.grpc_handler import (
_ERROR_CODE_MAP,
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
DefaultCallContextBuilder is used.
"""
self.handler03 = RequestHandler03(request_handler=request_handler)
self._context_builder = (
self._context_builder = V03GrpcServerCallContextBuilder(
context_builder or DefaultGrpcServerCallContextBuilder()
)

Expand All @@ -65,7 +65,6 @@ async def _handle_unary(
try:
server_context = self._context_builder.build(context)
result = await handler_func(server_context)
self._set_extension_metadata(context, server_context)
except A2AError as e:
await self.abort_context(e, context)
else:
Expand All @@ -82,7 +81,6 @@ async def _handle_stream(
server_context = self._context_builder.build(context)
async for item in handler_func(server_context):
yield item
self._set_extension_metadata(context, server_context)
except A2AError as e:
await self.abort_context(e, context)

Expand Down Expand Up @@ -120,19 +118,6 @@ async def abort_context(
f'Unknown error type: {error}',
)

def _set_extension_metadata(
self,
context: grpc.aio.ServicerContext,
server_context: ServerCallContext,
) -> None:
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER.lower(), e)
for e in sorted(server_context.activated_extensions)
]
)

async def SendMessage(
self,
request: a2a_v0_3_pb2.SendMessageRequest,
Expand Down
5 changes: 4 additions & 1 deletion src/a2a/compat/v0_3/grpc_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from a2a.compat.v0_3 import (
types as types_v03,
)
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
from a2a.types import a2a_pb2
from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER
from a2a.utils.telemetry import SpanKind, trace_class
Expand Down Expand Up @@ -361,7 +362,9 @@ def _get_grpc_metadata(
metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_0_3)]

if context and context.service_parameters:
for key, value in context.service_parameters.items():
params = dict(context.service_parameters)
add_legacy_extension_header(params)
for key, value in params.items():
metadata.append((key.lower(), value))

return metadata
3 changes: 2 additions & 1 deletion src/a2a/compat/v0_3/jsonrpc_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
_package_starlette_installed = False

from a2a.compat.v0_3 import types as types_v03
from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder
from a2a.compat.v0_3.request_handler import RequestHandler03
from a2a.server.context import ServerCallContext
from a2a.server.jsonrpc_models import (
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
self.handler = RequestHandler03(
request_handler=http_handler,
)
self._context_builder = (
self._context_builder = V03ServerCallContextBuilder(
context_builder or DefaultServerCallContextBuilder()
)

Expand Down
3 changes: 3 additions & 0 deletions src/a2a/compat/v0_3/jsonrpc_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from a2a.compat.v0_3 import conversions
from a2a.compat.v0_3 import types as types_v03
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
Expand Down Expand Up @@ -424,6 +425,7 @@ async def _send_stream_request(
http_kwargs = get_http_args(context)
http_kwargs.setdefault('headers', {})
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
add_legacy_extension_header(http_kwargs['headers'])

async for sse_data in send_http_stream_request(
self.httpx_client,
Expand Down Expand Up @@ -485,6 +487,7 @@ async def _send_request(
http_kwargs = get_http_args(context)
http_kwargs.setdefault('headers', {})
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
add_legacy_extension_header(http_kwargs['headers'])

request = self.httpx_client.build_request(
'POST',
Expand Down
3 changes: 2 additions & 1 deletion src/a2a/compat/v0_3/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
_package_starlette_installed = False


from a2a.compat.v0_3.context_builders import V03ServerCallContextBuilder
from a2a.compat.v0_3.rest_handler import REST03Handler
from a2a.server.routes.common import (
DefaultServerCallContextBuilder,
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(
context_builder: 'ServerCallContextBuilder | None' = None,
):
self.handler = REST03Handler(request_handler=http_handler)
self._context_builder = (
self._context_builder = V03ServerCallContextBuilder(
context_builder or DefaultServerCallContextBuilder()
)

Expand Down
3 changes: 3 additions & 0 deletions src/a2a/compat/v0_3/rest_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from a2a.compat.v0_3 import (
types as types_v03,
)
from a2a.compat.v0_3.extension_headers import add_legacy_extension_header
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
Expand Down Expand Up @@ -380,6 +381,7 @@ async def _send_stream_request(
http_kwargs = get_http_args(context)
http_kwargs.setdefault('headers', {})
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
add_legacy_extension_header(http_kwargs['headers'])

async for sse_data in send_http_stream_request(
self.httpx_client,
Expand Down Expand Up @@ -414,6 +416,7 @@ async def _execute_request(
http_kwargs = get_http_args(context)
http_kwargs.setdefault('headers', {})
http_kwargs['headers'][VERSION_HEADER.lower()] = PROTOCOL_VERSION_0_3
add_legacy_extension_header(http_kwargs['headers'])

request = self.httpx_client.build_request(
method,
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/extensions/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from a2a.types.a2a_pb2 import AgentCard, AgentExtension


HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
HTTP_EXTENSION_HEADER = 'A2A-Extensions'


def get_requested_extensions(values: list[str]) -> set[str]:
Expand Down
10 changes: 1 addition & 9 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,14 @@ def metadata(self) -> dict[str, Any]:
return dict(self._params.metadata)
return {}

def add_activated_extension(self, uri: str) -> None:
"""Add an extension to the set of activated extensions for this request.

This causes the extension to be indicated back to the client in the
response.
"""
self._call_context.activated_extensions.add(uri)

@property
def tenant(self) -> str:
"""The tenant associated with this request."""
return self._call_context.tenant

@property
def requested_extensions(self) -> set[str]:
"""Extensions that the client requested to activate."""
"""Extensions that the client requested for this interaction."""
return self._call_context.requested_extensions

def _check_or_generate_task_id(self) -> None:
Expand Down
1 change: 0 additions & 1 deletion src/a2a/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ class ServerCallContext(BaseModel):
user: User = Field(default_factory=UnauthenticatedUser)
tenant: str = Field(default='')
requested_extensions: set[str] = Field(default_factory=set)
activated_extensions: set[str] = Field(default_factory=set)
15 changes: 0 additions & 15 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ async def _handle_unary(
try:
server_context = self._build_call_context(context, request)
result = await handler_func(server_context)
self._set_extension_metadata(context, server_context)
except A2AError as e:
await self.abort_context(e, context)
else:
Expand All @@ -153,7 +152,6 @@ async def _handle_stream(
server_context = self._build_call_context(context, request)
async for item in handler_func(server_context):
yield item
self._set_extension_metadata(context, server_context)
except A2AError as e:
await self.abort_context(e, context)

Expand Down Expand Up @@ -422,19 +420,6 @@ async def abort_context(
f'Unknown error type: {error}',
)

def _set_extension_metadata(
self,
context: grpc.aio.ServicerContext,
server_context: ServerCallContext,
) -> None:
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER.lower(), e)
for e in sorted(server_context.activated_extensions)
]
)

def _build_call_context(
self,
context: grpc.aio.ServicerContext,
Expand Down
12 changes: 2 additions & 10 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response

from a2a.compat.v0_3.jsonrpc_adapter import JSONRPC03Adapter
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
from a2a.server.context import ServerCallContext
from a2a.server.events import Event
from a2a.server.jsonrpc_models import (
Expand Down Expand Up @@ -570,9 +567,6 @@ def _create_response(
Returns:
A Starlette JSONResponse or EventSourceResponse.
"""
headers = {}
if exts := context.activated_extensions:
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
if isinstance(handler_result, AsyncGenerator):
# Result is a stream of dict objects
async def event_generator(
Expand Down Expand Up @@ -603,9 +597,7 @@ async def event_generator(
'data': json.dumps(error_response),
}

return EventSourceResponse(
event_generator(handler_result), headers=headers
)
return EventSourceResponse(event_generator(handler_result))

# handler_result is a dict (JSON-RPC response)
return JSONResponse(handler_result, headers=headers)
return JSONResponse(handler_result)
4 changes: 2 additions & 2 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ async def test_extensions_added_to_request(
from a2a.client.client import ClientCallContext

context = ClientCallContext(
service_parameters={'X-A2A-Extensions': 'https://example.com/ext1'}
service_parameters={'A2A-Extensions': 'https://example.com/ext1'}
)

await transport.send_message(request, context=context)
Expand All @@ -555,7 +555,7 @@ async def test_extensions_added_to_request(
call_args = mock_httpx_client.build_request.call_args
# Extensions should be in the kwargs
assert (
call_args[1].get('headers', {}).get('X-A2A-Extensions')
call_args[1].get('headers', {}).get('A2A-Extensions')
== 'https://example.com/ext1'
)

Expand Down
Loading
Loading