Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/a2a/client/card_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import httpx

from google.protobuf.json_format import ParseDict, ParseError
from google.protobuf.json_format import ParseError

from a2a.client.errors import AgentCardResolutionError
from a2a.client.helpers import parse_agent_card
from a2a.types.a2a_pb2 import (
AgentCard,
)
Expand Down Expand Up @@ -85,7 +86,7 @@ async def get_agent_card(
target_url,
agent_card_data,
)
agent_card = ParseDict(agent_card_data, AgentCard())
agent_card = parse_agent_card(agent_card_data)
if signature_verifier:
signature_verifier(agent_card)
except httpx.HTTPStatusError as e:
Expand Down
110 changes: 109 additions & 1 deletion src/a2a/client/helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,116 @@
"""Helper functions for the A2A client."""

from typing import Any
from uuid import uuid4

from a2a.types.a2a_pb2 import Message, Part, Role
from google.protobuf.json_format import ParseDict

from a2a.types.a2a_pb2 import AgentCard, Message, Part, Role


def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard:
"""Parse AgentCard JSON dictionary and handle backward compatibility."""
_handle_extended_card_compatibility(agent_card_data)
_handle_connection_fields_compatibility(agent_card_data)
_handle_security_compatibility(agent_card_data)

return ParseDict(agent_card_data, AgentCard(), ignore_unknown_fields=True)


def _handle_extended_card_compatibility(
agent_card_data: dict[str, Any],
) -> None:
"""Map legacy supportsAuthenticatedExtendedCard to capabilities."""
if agent_card_data.pop('supportsAuthenticatedExtendedCard', None):
capabilities = agent_card_data.setdefault('capabilities', {})
if 'extendedAgentCard' not in capabilities:
capabilities['extendedAgentCard'] = True


def _handle_connection_fields_compatibility(
agent_card_data: dict[str, Any],
) -> None:
"""Map legacy connection and transport fields to supportedInterfaces."""
main_url = agent_card_data.pop('url', None)
main_transport = agent_card_data.pop('preferredTransport', 'JSONRPC')
version = agent_card_data.pop('protocolVersion', '0.3.0')
additional_interfaces = (
agent_card_data.pop('additionalInterfaces', None) or []
)

if 'supportedInterfaces' not in agent_card_data and main_url:
supported_interfaces = []
supported_interfaces.append(
{
'url': main_url,
'protocolBinding': main_transport,
'protocolVersion': version,
}
)
supported_interfaces.extend(
{
'url': iface.get('url'),
'protocolBinding': iface.get('transport'),
'protocolVersion': version,
}
for iface in additional_interfaces
)
agent_card_data['supportedInterfaces'] = supported_interfaces


def _map_legacy_security(
sec_list: list[dict[str, list[str]]],
) -> list[dict[str, Any]]:
"""Convert a legacy security requirement list into the 1.0.0 Protobuf format."""
return [
{
'schemes': {
scheme_name: {'list': scopes}
for scheme_name, scopes in sec_dict.items()
}
}
for sec_dict in sec_list
]


def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None:
"""Map legacy security requirements and schemas to their 1.0.0 Protobuf equivalents."""
legacy_security = agent_card_data.pop('security', None)
if (
'securityRequirements' not in agent_card_data
and legacy_security is not None
):
agent_card_data['securityRequirements'] = _map_legacy_security(
legacy_security
)

for skill in agent_card_data.get('skills', []):
legacy_skill_sec = skill.pop('security', None)
if 'securityRequirements' not in skill and legacy_skill_sec is not None:
skill['securityRequirements'] = _map_legacy_security(
legacy_skill_sec
)

security_schemes = agent_card_data.get('securitySchemes', {})
if security_schemes:
type_mapping = {
'apiKey': 'apiKeySecurityScheme',
'http': 'httpAuthSecurityScheme',
'oauth2': 'oauth2SecurityScheme',
'openIdConnect': 'openIdConnectSecurityScheme',
'mutualTLS': 'mtlsSecurityScheme',
}
for scheme in security_schemes.values():
scheme_type = scheme.pop('type', None)
if scheme_type in type_mapping:
# Map legacy 'in' to modern 'location'
if scheme_type == 'apiKey' and 'in' in scheme:
scheme['location'] = scheme.pop('in')

mapped_name = type_mapping[scheme_type]
new_scheme_wrapper = {mapped_name: scheme.copy()}
scheme.clear()
scheme.update(new_scheme_wrapper)


def create_text_message_object(
Expand Down
10 changes: 8 additions & 2 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,38 @@
import httpx

from google.protobuf import json_format
from google.protobuf.json_format import ParseDict
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response

from a2a.client.errors import A2AClientError
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.http_helpers import (
send_http_request,
send_http_stream_request,
)
from a2a.extensions.common import update_extension_header
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
CreateTaskPushNotificationConfigRequest,
DeleteTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
ListTaskPushNotificationConfigsRequest,
ListTaskPushNotificationConfigsResponse,
ListTasksRequest,
ListTasksResponse,
SendMessageRequest,
SendMessageResponse,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
)
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP
from a2a.utils.telemetry import SpanKind, trace_class

Check notice on line 41 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (12-39)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -413,8 +414,13 @@
json_rpc_response = JSONRPC20Response(**response_data)
if json_rpc_response.error:
raise self._create_jsonrpc_error(json_rpc_response.error)
response: AgentCard = json_format.ParseDict(
json_rpc_response.result, AgentCard()
# Validate type of the response
if not isinstance(json_rpc_response.result, dict):
raise A2AClientError(
f'Invalid response type: {type(json_rpc_response.result)}'
)
response: AgentCard = ParseDict(
cast('dict[str, Any]', json_rpc_response.result), AgentCard()
)
if signature_verifier:
signature_verifier(response)
Expand Down
10 changes: 6 additions & 4 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import TYPE_CHECKING, Any

from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.json_format import ParseDict
from jsonrpc.jsonrpc2 import JSONRPC20Request

from a2a.auth.user import UnauthenticatedUser
Expand All @@ -29,7 +29,10 @@
)
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.request_handlers.response_helpers import build_error_response
from a2a.server.request_handlers.response_helpers import (
agent_card_to_dict,
build_error_response,
)
from a2a.types import A2ARequest
from a2a.types.a2a_pb2 import (
AgentCard,
Expand Down Expand Up @@ -575,9 +578,8 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
card_to_serve = await maybe_await(self.card_modifier(card_to_serve))

return JSONResponse(
MessageToDict(
agent_card_to_dict(
card_to_serve,
preserving_proto_field_name=False,
)
)

Expand Down
28 changes: 28 additions & 0 deletions src/a2a/server/request_handlers/response_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from google.protobuf.message import Message as ProtoMessage
from jsonrpc.jsonrpc2 import JSONRPC20Response

from a2a.compat.v0_3.conversions import to_compat_agent_card
from a2a.server.jsonrpc_models import (
InternalError as JSONRPCInternalError,
)
from a2a.server.jsonrpc_models import (
JSONRPCError,
)
from a2a.types.a2a_pb2 import (
AgentCard,
ListTasksResponse,
Message,
StreamResponse,
Expand Down Expand Up @@ -89,6 +91,32 @@
"""Type alias for possible event types produced by handlers."""


def agent_card_to_dict(card: AgentCard) -> dict[str, Any]:
"""Convert AgentCard to dict and inject backward compatibility fields."""
result = MessageToDict(card)

compat_card = to_compat_agent_card(card)
compat_dict = compat_card.model_dump(exclude_none=True)

# Do not include supportsAuthenticatedExtendedCard if false
if not compat_dict.get('supportsAuthenticatedExtendedCard'):
compat_dict.pop('supportsAuthenticatedExtendedCard', None)

def merge(dict1: dict[str, Any], dict2: dict[str, Any]) -> dict[str, Any]:
for k, v in dict2.items():
if k not in dict1:
dict1[k] = v
elif isinstance(v, dict) and isinstance(dict1[k], dict):
merge(dict1[k], v)
elif isinstance(v, list) and isinstance(dict1[k], list):
for i in range(min(len(dict1[k]), len(v))):
if isinstance(dict1[k][i], dict) and isinstance(v[i], dict):
merge(dict1[k][i], v[i])
return dict1

return merge(result, compat_dict)


def build_error_response(
request_id: str | int | None,
error: A2AError | JSONRPCError,
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_card_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ async def test_get_agent_card_validation_error(
valid_agent_card_data,
):
"""Test A2AClientJSONError is raised on agent card validation error."""
return_json = {'invalid': 'data'}
return_json = {'name': {'invalid': 'type'}}
mock_response.json.return_value = return_json
mock_httpx_client.get.return_value = mock_response
with pytest.raises(AgentCardResolutionError) as exc_info:
Expand Down
Loading
Loading