diff --git a/samples/cli.py b/samples/cli.py index 54b68388f..935834dd3 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -10,9 +10,9 @@ import httpx from a2a.client import A2ACardResolver, ClientConfig, create_client +from a2a.helpers import get_artifact_text, get_message_text +from a2a.helpers.agent_card import display_agent_card from a2a.types import Message, Part, Role, SendMessageRequest, TaskState -from a2a.utils import get_artifact_text, get_message_text -from a2a.utils.agent_card import display_agent_card async def _handle_stream( diff --git a/scripts/test_minimal_install.py b/scripts/test_minimal_install.py index 076df4c0f..0b29a48b6 100755 --- a/scripts/test_minimal_install.py +++ b/scripts/test_minimal_install.py @@ -50,14 +50,13 @@ 'a2a.server.tasks', 'a2a.types', 'a2a.utils', - 'a2a.utils.artifact', 'a2a.utils.constants', 'a2a.utils.error_handlers', 'a2a.utils.helpers', - 'a2a.utils.message', - 'a2a.utils.parts', 'a2a.utils.proto_utils', 'a2a.utils.task', + 'a2a.helpers.agent_card', + 'a2a.helpers.proto_helpers', ] diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index c23041f32..d33c09481 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -22,7 +22,6 @@ A2AClientTimeoutError, AgentCardResolutionError, ) -from a2a.client.helpers import create_text_message_object from a2a.client.interceptors import ClientCallInterceptor @@ -41,6 +40,5 @@ 'CredentialService', 'InMemoryContextCredentialStore', 'create_client', - 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/helpers.py b/src/a2a/client/helpers.py index fc7bfdbdf..f8207f03b 100644 --- a/src/a2a/client/helpers.py +++ b/src/a2a/client/helpers.py @@ -1,11 +1,10 @@ """Helper functions for the A2A client.""" from typing import Any -from uuid import uuid4 from google.protobuf.json_format import ParseDict -from a2a.types.a2a_pb2 import AgentCard, Message, Part, Role +from a2a.types.a2a_pb2 import AgentCard def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard: @@ -111,20 +110,3 @@ def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None: new_scheme_wrapper = {mapped_name: scheme.copy()} scheme.clear() scheme.update(new_scheme_wrapper) - - -def create_text_message_object( - role: Role = Role.ROLE_USER, content: str = '' -) -> Message: - """Create a Message object containing a single text Part. - - Args: - role: The role of the message sender (user or agent). Defaults to Role.ROLE_USER. - content: The text content of the message. Defaults to an empty string. - - Returns: - A `Message` object with a new UUID message_id. - """ - return Message( - role=role, parts=[Part(text=content)], message_id=str(uuid4()) - ) diff --git a/src/a2a/helpers/__init__.py b/src/a2a/helpers/__init__.py new file mode 100644 index 000000000..c42429d43 --- /dev/null +++ b/src/a2a/helpers/__init__.py @@ -0,0 +1,34 @@ +"""Helper functions for the A2A Python SDK.""" + +from a2a.helpers.agent_card import display_agent_card +from a2a.helpers.proto_helpers import ( + get_artifact_text, + get_message_text, + get_stream_response_text, + get_text_parts, + new_artifact, + new_message, + new_task, + new_task_from_user_message, + new_text_artifact, + new_text_artifact_update_event, + new_text_message, + new_text_status_update_event, +) + + +__all__ = [ + 'display_agent_card', + 'get_artifact_text', + 'get_message_text', + 'get_stream_response_text', + 'get_text_parts', + 'new_artifact', + 'new_message', + 'new_task', + 'new_task_from_user_message', + 'new_text_artifact', + 'new_text_artifact_update_event', + 'new_text_message', + 'new_text_status_update_event', +] diff --git a/src/a2a/utils/agent_card.py b/src/a2a/helpers/agent_card.py similarity index 100% rename from src/a2a/utils/agent_card.py rename to src/a2a/helpers/agent_card.py diff --git a/src/a2a/helpers/proto_helpers.py b/src/a2a/helpers/proto_helpers.py new file mode 100644 index 000000000..79e1f739d --- /dev/null +++ b/src/a2a/helpers/proto_helpers.py @@ -0,0 +1,214 @@ +"""Unified helper functions for creating and handling A2A types.""" + +import uuid + +from collections.abc import Sequence + +from a2a.types.a2a_pb2 import ( + Artifact, + Message, + Part, + Role, + StreamResponse, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) + + +# --- Message Helpers --- + + +def new_message( + parts: list[Part], + role: Role = Role.ROLE_AGENT, + context_id: str | None = None, + task_id: str | None = None, +) -> Message: + """Creates a new message containing a list of Parts.""" + return Message( + role=role, + parts=parts, + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, + ) + + +def new_text_message( + text: str, + context_id: str | None = None, + task_id: str | None = None, + role: Role = Role.ROLE_AGENT, +) -> Message: + """Creates a new message containing a single text Part.""" + return new_message( + parts=[Part(text=text)], + role=role, + task_id=task_id, + context_id=context_id, + ) + + +def get_message_text(message: Message, delimiter: str = '\n') -> str: + """Extracts and joins all text content from a Message's parts.""" + return delimiter.join(get_text_parts(message.parts)) + + +# --- Artifact Helpers --- + + +def new_artifact( + parts: list[Part], + name: str, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object.""" + return Artifact( + artifact_id=artifact_id or str(uuid.uuid4()), + parts=parts, + name=name, + description=description, + ) + + +def new_text_artifact( + name: str, + text: str, + description: str | None = None, + artifact_id: str | None = None, +) -> Artifact: + """Creates a new Artifact object containing only a single text Part.""" + return new_artifact( + [Part(text=text)], + name, + description, + artifact_id=artifact_id, + ) + + +def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: + """Extracts and joins all text content from an Artifact's parts.""" + return delimiter.join(get_text_parts(artifact.parts)) + + +# --- Task Helpers --- + + +def new_task_from_user_message(user_message: Message) -> Task: + """Creates a new Task object from an initial user message.""" + if user_message.role != Role.ROLE_USER: + raise ValueError('Message must be from a user') + if not user_message.parts: + raise ValueError('Message parts cannot be empty') + for part in user_message.parts: + if part.HasField('text') and not part.text: + raise ValueError('Message.text cannot be empty') + + return Task( + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + id=user_message.task_id or str(uuid.uuid4()), + context_id=user_message.context_id or str(uuid.uuid4()), + history=[user_message], + ) + + +def new_task( + task_id: str, + context_id: str, + state: TaskState, + artifacts: list[Artifact] | None = None, + history: list[Message] | None = None, +) -> Task: + """Creates a Task object with a specified status.""" + if history is None: + history = [] + if artifacts is None: + artifacts = [] + + return Task( + status=TaskStatus(state=state), + id=task_id, + context_id=context_id, + artifacts=artifacts, + history=history, + ) + + +# --- Part Helpers --- + + +def get_text_parts(parts: Sequence[Part]) -> list[str]: + """Extracts text content from all text Parts.""" + return [part.text for part in parts if part.HasField('text')] + + +# --- Event & Stream Helpers --- + + +def new_text_status_update_event( + task_id: str, + context_id: str, + state: TaskState, + text: str, +) -> TaskStatusUpdateEvent: + """Creates a TaskStatusUpdateEvent with a single text message.""" + return TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=state, + message=new_text_message( + text=text, + role=Role.ROLE_AGENT, + context_id=context_id, + task_id=task_id, + ), + ), + ) + + +def new_text_artifact_update_event( # noqa: PLR0913 + task_id: str, + context_id: str, + name: str, + text: str, + append: bool = False, + last_chunk: bool = False, + artifact_id: str | None = None, +) -> TaskArtifactUpdateEvent: + """Creates a TaskArtifactUpdateEvent with a single text artifact.""" + return TaskArtifactUpdateEvent( + task_id=task_id, + context_id=context_id, + artifact=new_text_artifact( + name=name, text=text, artifact_id=artifact_id + ), + append=append, + last_chunk=last_chunk, + ) + + +def get_stream_response_text( + response: StreamResponse, delimiter: str = '\n' +) -> str: + """Extracts text content from a StreamResponse.""" + if response.HasField('message'): + return get_message_text(response.message, delimiter) + if response.HasField('task'): + texts = [ + get_artifact_text(a, delimiter) for a in response.task.artifacts + ] + return delimiter.join(t for t in texts if t) + if response.HasField('status_update'): + if response.status_update.status.HasField('message'): + return get_message_text( + response.status_update.status.message, delimiter + ) + return '' + if response.HasField('artifact_update'): + return get_artifact_text(response.artifact_update.artifact, delimiter) + return '' diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 1feefb1df..8b78c1045 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -1,5 +1,6 @@ from typing import Any +from a2a.helpers.proto_helpers import get_message_text from a2a.server.context import ServerCallContext from a2a.server.id_generator import ( IDGenerator, @@ -12,7 +13,6 @@ SendMessageRequest, Task, ) -from a2a.utils import get_message_text from a2a.utils.errors import InvalidParamsError diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 143413d5b..e5d899c1e 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -4,6 +4,7 @@ from a2a.server.events.event_queue import Event from a2a.server.tasks.task_store import TaskStore from a2a.types.a2a_pb2 import ( + Artifact, Message, Task, TaskArtifactUpdateEvent, @@ -11,13 +12,77 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import append_artifact_to_task from a2a.utils.errors import InvalidParamsError +from a2a.utils.telemetry import trace_function logger = logging.getLogger(__name__) +@trace_function() +def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: + """Helper method for updating a Task object with new artifact data from an event. + + Handles creating the artifacts list if it doesn't exist, adding new artifacts, + and appending parts to existing artifacts based on the `append` flag in the event. + + Args: + task: The `Task` object to modify. + event: The `TaskArtifactUpdateEvent` containing the artifact data. + """ + new_artifact_data: Artifact = event.artifact + artifact_id: str = new_artifact_data.artifact_id + append_parts: bool = event.append + + existing_artifact: Artifact | None = None + existing_artifact_list_index: int | None = None + + # Find existing artifact by its id + for i, art in enumerate(task.artifacts): + if art.artifact_id == artifact_id: + existing_artifact = art + existing_artifact_list_index = i + break + + if not append_parts: + # This represents the first chunk for this artifact index. + if existing_artifact_list_index is not None: + # Replace the existing artifact entirely with the new data + logger.debug( + 'Replacing artifact at id %s for task %s', artifact_id, task.id + ) + task.artifacts[existing_artifact_list_index].CopyFrom( + new_artifact_data + ) + else: + # Append the new artifact since no artifact with this index exists yet + logger.debug( + 'Adding new artifact with id %s for task %s', + artifact_id, + task.id, + ) + task.artifacts.append(new_artifact_data) + elif existing_artifact: + # Append new parts to the existing artifact's part list + logger.debug( + 'Appending parts to artifact id %s for task %s', + artifact_id, + task.id, + ) + existing_artifact.parts.extend(new_artifact_data.parts) + existing_artifact.metadata.update( + dict(new_artifact_data.metadata.items()) + ) + else: + # We received a chunk to append, but we don't have an existing artifact. + # we will ignore this chunk + logger.warning( + 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', + artifact_id, + task.id, + ) + + class TaskManager: """Helps manage a task's lifecycle during execution of a request. diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 1efed5794..04693dd0b 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -1,62 +1,18 @@ """Utility functions for the A2A Python SDK.""" from a2a.utils import proto_utils -from a2a.utils.agent_card import display_agent_card -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) from a2a.utils.constants import ( AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_RPC_URL, TransportProtocol, ) -from a2a.utils.helpers import ( - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - create_task_obj, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) from a2a.utils.proto_utils import to_stream_response -from a2a.utils.task import ( - completed_task, - new_task, -) __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', 'TransportProtocol', - 'append_artifact_to_task', - 'are_modalities_compatible', - 'build_text_artifact', - 'completed_task', - 'create_task_obj', - 'display_agent_card', - 'get_artifact_text', - 'get_data_parts', - 'get_file_parts', - 'get_message_text', - 'get_text_parts', - 'new_agent_parts_message', - 'new_agent_text_message', - 'new_artifact', - 'new_data_artifact', - 'new_task', - 'new_text_artifact', 'proto_utils', 'to_stream_response', ] diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py deleted file mode 100644 index ac14087dc..000000000 --- a/src/a2a/utils/artifact.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Utility functions for creating A2A Artifact objects.""" - -import uuid - -from typing import Any - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import Artifact, Part -from a2a.utils.parts import get_text_parts - - -def new_artifact( - parts: list[Part], - name: str, - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object. - - Args: - parts: The list of `Part` objects forming the artifact's content. - name: The human-readable name of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return Artifact( - artifact_id=str(uuid.uuid4()), - parts=parts, - name=name, - description=description, - ) - - -def new_text_artifact( - name: str, - text: str, - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single text Part. - - Args: - name: The human-readable name of the artifact. - text: The text content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - return new_artifact( - [Part(text=text)], - name, - description, - ) - - -def new_data_artifact( - name: str, - data: dict[str, Any], - description: str | None = None, -) -> Artifact: - """Creates a new Artifact object containing only a single data Part. - - Args: - name: The human-readable name of the artifact. - data: The structured data content of the artifact. - description: An optional description of the artifact. - - Returns: - A new `Artifact` object with a generated artifact_id. - """ - struct_data = Struct() - struct_data.update(data) - return new_artifact( - [Part(data=Value(struct_value=struct_data))], - name, - description, - ) - - -def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: - """Extracts and joins all text content from an Artifact's parts. - - Args: - artifact: The `Artifact` object. - delimiter: The string to use when joining text from multiple TextParts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. - """ - return delimiter.join(get_text_parts(artifact.parts)) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index fe69bf26d..9a974a4c2 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -2,30 +2,16 @@ import functools import inspect -import json import logging from collections.abc import AsyncIterator, Awaitable, Callable from typing import Any, TypeVar, cast -from uuid import uuid4 -from google.protobuf.json_format import MessageToDict from packaging.version import InvalidVersion, Version from a2a.server.context import ServerCallContext -from a2a.types.a2a_pb2 import ( - AgentCard, - Artifact, - Part, - SendMessageRequest, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, -) from a2a.utils import constants from a2a.utils.errors import VersionNotSupportedError -from a2a.utils.telemetry import trace_function T = TypeVar('T') @@ -35,168 +21,6 @@ logger = logging.getLogger(__name__) -@trace_function() -def create_task_obj(message_send_params: SendMessageRequest) -> Task: - """Create a new task object from message send params. - - Generates UUIDs for task and context IDs if they are not already present in the message. - - Args: - message_send_params: The `SendMessageRequest` object containing the initial message. - - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - """ - if not message_send_params.message.context_id: - message_send_params.message.context_id = str(uuid4()) - - task = Task( - id=str(uuid4()), - context_id=message_send_params.message.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - task.history.append(message_send_params.message) - return task - - -@trace_function() -def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: - """Helper method for updating a Task object with new artifact data from an event. - - Handles creating the artifacts list if it doesn't exist, adding new artifacts, - and appending parts to existing artifacts based on the `append` flag in the event. - - Args: - task: The `Task` object to modify. - event: The `TaskArtifactUpdateEvent` containing the artifact data. - """ - new_artifact_data: Artifact = event.artifact - artifact_id: str = new_artifact_data.artifact_id - append_parts: bool = event.append - - existing_artifact: Artifact | None = None - existing_artifact_list_index: int | None = None - - # Find existing artifact by its id - for i, art in enumerate(task.artifacts): - if art.artifact_id == artifact_id: - existing_artifact = art - existing_artifact_list_index = i - break - - if not append_parts: - # This represents the first chunk for this artifact index. - if existing_artifact_list_index is not None: - # Replace the existing artifact entirely with the new data - logger.debug( - 'Replacing artifact at id %s for task %s', artifact_id, task.id - ) - task.artifacts[existing_artifact_list_index].CopyFrom( - new_artifact_data - ) - else: - # Append the new artifact since no artifact with this index exists yet - logger.debug( - 'Adding new artifact with id %s for task %s', - artifact_id, - task.id, - ) - task.artifacts.append(new_artifact_data) - elif existing_artifact: - # Append new parts to the existing artifact's part list - logger.debug( - 'Appending parts to artifact id %s for task %s', - artifact_id, - task.id, - ) - existing_artifact.parts.extend(new_artifact_data.parts) - existing_artifact.metadata.update( - dict(new_artifact_data.metadata.items()) - ) - else: - # We received a chunk to append, but we don't have an existing artifact. - # we will ignore this chunk - logger.warning( - 'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.', - artifact_id, - task.id, - ) - - -def build_text_artifact(text: str, artifact_id: str) -> Artifact: - """Helper to create a text artifact. - - Args: - text: The text content for the artifact. - artifact_id: The ID for the artifact. - - Returns: - An `Artifact` object containing a single text Part. - """ - part = Part(text=text) - return Artifact(parts=[part], artifact_id=artifact_id) - - -def are_modalities_compatible( - server_output_modes: list[str] | None, client_output_modes: list[str] | None -) -> bool: - """Checks if server and client output modalities (MIME types) are compatible. - - Modalities are compatible if: - 1. The client specifies no preferred output modes (client_output_modes is None or empty). - 2. The server specifies no supported output modes (server_output_modes is None or empty). - 3. There is at least one common modality between the server's supported list and the client's preferred list. - - Args: - server_output_modes: A list of MIME types supported by the server/agent for output. - Can be None or empty if the server doesn't specify. - client_output_modes: A list of MIME types preferred by the client for output. - Can be None or empty if the client accepts any. - - Returns: - True if the modalities are compatible, False otherwise. - """ - if client_output_modes is None or len(client_output_modes) == 0: - return True - - if server_output_modes is None or len(server_output_modes) == 0: - return True - - return any(x in server_output_modes for x in client_output_modes) - - -def _clean_empty(d: Any) -> Any: - """Recursively remove empty strings, lists and dicts from a dictionary.""" - if isinstance(d, dict): - cleaned_dict = { - k: cleaned_v - for k, v in d.items() - if (cleaned_v := _clean_empty(v)) is not None - } - return cleaned_dict or None - if isinstance(d, list): - cleaned_list = [ - cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None - ] - return cleaned_list or None - if isinstance(d, str) and not d: - return None - return d - - -def canonicalize_agent_card(agent_card: AgentCard) -> str: - """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" - card_dict = MessageToDict( - agent_card, - ) - # Remove signatures field if present - card_dict.pop('signatures', None) - - # Recursively remove empty values - cleaned_dict = _clean_empty(card_dict) - return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) - - async def maybe_await(value: T | Awaitable[T]) -> T: """Awaits a value if it's awaitable, otherwise simply provides it back.""" if inspect.isawaitable(value): diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py deleted file mode 100644 index 528d952f4..000000000 --- a/src/a2a/utils/message.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Utility functions for creating and handling A2A Message objects.""" - -import uuid - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.parts import get_text_parts - - -def new_agent_text_message( - text: str, - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a single text Part. - - Args: - text: The text content of the message. - context_id: The context ID for the message. - task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=[Part(text=text)], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def new_agent_parts_message( - parts: list[Part], - context_id: str | None = None, - task_id: str | None = None, -) -> Message: - """Creates a new agent message containing a list of Parts. - - Args: - parts: The list of `Part` objects for the message content. - context_id: The context ID for the message. - task_id: The task ID for the message. - - Returns: - A new `Message` object with role 'agent'. - """ - return Message( - role=Role.ROLE_AGENT, - parts=parts, - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - - -def get_message_text(message: Message, delimiter: str = '\n') -> str: - """Extracts and joins all text content from a Message's parts. - - Args: - message: The `Message` object. - delimiter: The string to use when joining text from multiple text Parts. - - Returns: - A single string containing all text content, or an empty string if no text parts are found. - """ - return delimiter.join(get_text_parts(message.parts)) diff --git a/src/a2a/utils/parts.py b/src/a2a/utils/parts.py deleted file mode 100644 index c9b964540..000000000 --- a/src/a2a/utils/parts.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Utility functions for creating and handling A2A Parts objects.""" - -from collections.abc import Sequence -from typing import Any - -from google.protobuf.json_format import MessageToDict - -from a2a.types.a2a_pb2 import ( - Part, -) - - -def get_text_parts(parts: Sequence[Part]) -> list[str]: - """Extracts text content from all text Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of strings containing the text content from any text Parts found. - """ - return [part.text for part in parts if part.HasField('text')] - - -def get_data_parts(parts: Sequence[Part]) -> list[Any]: - """Extracts data from all data Parts in a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of values containing the data from any data Parts found. - """ - return [MessageToDict(part.data) for part in parts if part.HasField('data')] - - -def get_file_parts(parts: Sequence[Part]) -> list[Part]: - """Extracts file parts from a list of Parts. - - Args: - parts: A sequence of `Part` objects. - - Returns: - A list of `Part` objects containing file data (raw or url). - """ - return [part for part in parts if part.raw or part.url] diff --git a/src/a2a/utils/signing.py b/src/a2a/utils/signing.py index 68924c8a0..aa720d159 100644 --- a/src/a2a/utils/signing.py +++ b/src/a2a/utils/signing.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import Any, TypedDict -from a2a.utils.helpers import canonicalize_agent_card +from google.protobuf.json_format import MessageToDict try: @@ -68,7 +68,7 @@ def create_agent_card_signer( def agent_card_signer(agent_card: AgentCard) -> AgentCard: """Signs agent card.""" - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = _canonicalize_agent_card(agent_card) payload_dict = json.loads(canonical_payload) jws_string = jwt.encode( @@ -128,7 +128,7 @@ def signature_verifier( jku = protected_header.get('jku') verification_key = key_provider(kid, jku) - canonical_payload = canonicalize_agent_card(agent_card) + canonical_payload = _canonicalize_agent_card(agent_card) encoded_payload = base64url_encode( canonical_payload.encode('utf-8') ).decode('utf-8') @@ -148,3 +148,35 @@ def signature_verifier( raise InvalidSignaturesError('No valid signature found') return signature_verifier + + +def _clean_empty(d: Any) -> Any: + """Recursively remove empty strings, lists and dicts from a dictionary.""" + if isinstance(d, dict): + cleaned_dict = { + k: cleaned_v + for k, v in d.items() + if (cleaned_v := _clean_empty(v)) is not None + } + return cleaned_dict or None + if isinstance(d, list): + cleaned_list = [ + cleaned_v for v in d if (cleaned_v := _clean_empty(v)) is not None + ] + return cleaned_list or None + if isinstance(d, str) and not d: + return None + return d + + +def _canonicalize_agent_card(agent_card: AgentCard) -> str: + """Canonicalizes the Agent Card JSON according to RFC 8785 (JCS).""" + card_dict = MessageToDict( + agent_card, + ) + # Remove signatures field if present + card_dict.pop('signatures', None) + + # Recursively remove empty values + cleaned_dict = _clean_empty(card_dict) + return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 6ff716a30..4acf54e46 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -1,89 +1,15 @@ """Utility functions for creating A2A Task objects.""" import binascii -import uuid from base64 import b64decode, b64encode from typing import Literal, Protocol, runtime_checkable -from a2a.types.a2a_pb2 import ( - Artifact, - Message, - Task, - TaskState, - TaskStatus, -) +from a2a.types.a2a_pb2 import Task from a2a.utils.constants import MAX_LIST_TASKS_PAGE_SIZE from a2a.utils.errors import InvalidParamsError -def new_task(request: Message) -> Task: - """Creates a new Task object from an initial user message. - - Generates task and context IDs if not provided in the message. - - Args: - request: The initial `Message` object from the user. - - Returns: - A new `Task` object initialized with 'submitted' status and the input message in history. - - Raises: - TypeError: If the message role is None. - ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid. - """ - if not request.role: - raise TypeError('Message role cannot be None') - if not request.parts: - raise ValueError('Message parts cannot be empty') - for part in request.parts: - if part.HasField('text') and not part.text: - raise ValueError('Message.text cannot be empty') - - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - id=request.task_id or str(uuid.uuid4()), - context_id=request.context_id or str(uuid.uuid4()), - history=[request], - ) - - -def completed_task( - task_id: str, - context_id: str, - artifacts: list[Artifact], - history: list[Message] | None = None, -) -> Task: - """Creates a Task object in the 'completed' state. - - Useful for constructing a final Task representation when the agent - finishes and produces artifacts. - - Args: - task_id: The ID of the task. - context_id: The context ID of the task. - artifacts: A list of `Artifact` objects produced by the task. - history: An optional list of `Message` objects representing the task history. - - Returns: - A `Task` object with status set to 'completed'. - """ - if not artifacts or not all(isinstance(a, Artifact) for a in artifacts): - raise ValueError( - 'artifacts must be a non-empty list of Artifact objects' - ) - - if history is None: - history = [] - return Task( - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - - @runtime_checkable class HistoryLengthConfig(Protocol): """Protocol for configuration arguments containing history_length field.""" diff --git a/tests/client/test_client_helpers.py b/tests/client/test_client_helpers.py index 8963eefce..0eb394f43 100644 --- a/tests/client/test_client_helpers.py +++ b/tests/client/test_client_helpers.py @@ -3,7 +3,8 @@ import json from google.protobuf.json_format import MessageToDict -from a2a.client.helpers import create_text_message_object, parse_agent_card +from a2a.client.helpers import parse_agent_card +from a2a.helpers.proto_helpers import new_text_message from a2a.server.request_handlers.response_helpers import agent_card_to_dict from a2a.types.a2a_pb2 import ( APIKeySecurityScheme, @@ -263,7 +264,7 @@ def test_parse_agent_card_security_scheme_unknown_type() -> None: def test_create_text_message_object() -> None: - msg = create_text_message_object(role=Role.ROLE_AGENT, content='Hello') + msg = new_text_message(text='Hello', role=Role.ROLE_AGENT) assert msg.role == Role.ROLE_AGENT assert len(msg.parts) == 1 assert msg.parts[0].text == 'Hello' diff --git a/tests/client/transports/test_grpc_client.py b/tests/client/transports/test_grpc_client.py index 9e81bd71e..95cca9189 100644 --- a/tests/client/transports/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -35,7 +35,7 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import get_text_parts +from a2a.helpers.proto_helpers import get_text_parts @pytest.fixture diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index e7912566e..0c9f7c30a 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -8,7 +8,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from httpx_sse import EventSource, ServerSentEvent -from a2a.client import create_text_message_object +from a2a.helpers.proto_helpers import new_text_message from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError from a2a.client.transports.rest import RestTransport @@ -83,7 +83,7 @@ async def test_send_message_streaming_timeout( url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) mock_event_source.response = MagicMock(spec=httpx.Response) @@ -120,9 +120,7 @@ async def test_rest_mapped_errors( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) mock_build_request = MagicMock( return_value=AsyncMock(spec=httpx.Request) @@ -172,9 +170,7 @@ async def test_send_message_with_timeout_context( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) context = ClientCallContext(timeout=10.0) mock_build_request = MagicMock( @@ -246,9 +242,7 @@ async def test_send_message_with_default_extensions( agent_card=mock_agent_card, url='http://agent.example.com/api', ) - params = SendMessageRequest( - message=create_text_message_object(content='Hello') - ) + params = SendMessageRequest(message=new_text_message(text='Hello')) # Mock the build_request method to capture its inputs mock_build_request = MagicMock( @@ -294,7 +288,7 @@ async def test_send_message_streaming_with_new_extensions( url='http://agent.example.com/api', ) params = SendMessageRequest( - message=create_text_message_object(content='Hello stream') + message=new_text_message(text='Hello stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -343,7 +337,7 @@ async def test_send_message_streaming_server_error_propagates( url='http://agent.example.com/api', ) request = SendMessageRequest( - message=create_text_message_object(content='Error stream') + message=new_text_message(text='Error stream') ) mock_event_source = AsyncMock(spec=EventSource) @@ -524,7 +518,7 @@ class TestRestTransportTenant: 'send_message', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:send', ), @@ -686,7 +680,7 @@ async def test_rest_get_task_prepend_empty_tenant( 'send_message_streaming', SendMessageRequest( tenant='my-tenant', - message=create_text_message_object(content='hi'), + message=new_text_message(text='hi'), ), '/my-tenant/message:stream', ), diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index 106a97cea..bc95f6c37 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -24,9 +24,9 @@ Message, Task, ) -from a2a.utils import ( - new_agent_text_message, - new_task, +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, ) @@ -74,7 +74,7 @@ async def invoke( or not msg.parts[0].HasField('text') ): await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -84,25 +84,23 @@ async def invoke( # Simple request-response flow. if text_message == 'Hello Agent!': await updater.complete( - new_agent_text_message('Hello User!', task.context_id, task.id) + new_text_message('Hello User!', task.context_id, task.id) ) # Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing". elif text_message == 'How are you?': await updater.requires_input( - new_agent_text_message( - 'Good! How are you?', task.context_id, task.id - ) + new_text_message('Good! How are you?', task.context_id, task.id) ) elif text_message == 'Good': await updater.complete( - new_agent_text_message('Amazing', task.context_id, task.id) + new_text_message('Amazing', task.context_id, task.id) ) # Fail for unsupported messages. else: await updater.failed( - new_agent_text_message( + new_text_message( 'Unsupported message.', task.context_id, task.id ) ) @@ -124,7 +122,7 @@ async def execute( task = context.current_task if not task: - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) diff --git a/tests/utils/test_agent_card_display.py b/tests/helpers/test_agent_card_display.py similarity index 99% rename from tests/utils/test_agent_card_display.py rename to tests/helpers/test_agent_card_display.py index 93dc1aad4..e252a52fe 100644 --- a/tests/utils/test_agent_card_display.py +++ b/tests/helpers/test_agent_card_display.py @@ -9,7 +9,7 @@ AgentProvider, AgentSkill, ) -from a2a.utils.agent_card import display_agent_card +from a2a.helpers.agent_card import display_agent_card @pytest.fixture diff --git a/tests/helpers/test_proto_helpers.py b/tests/helpers/test_proto_helpers.py new file mode 100644 index 000000000..a4f6498ab --- /dev/null +++ b/tests/helpers/test_proto_helpers.py @@ -0,0 +1,230 @@ +"""Tests for proto helpers.""" + +import pytest +from a2a.helpers.proto_helpers import ( + new_message, + new_text_message, + get_message_text, + new_artifact, + new_text_artifact, + get_artifact_text, + new_task_from_user_message, + new_task, + get_text_parts, + new_text_status_update_event, + new_text_artifact_update_event, + get_stream_response_text, +) +from a2a.types.a2a_pb2 import ( + Part, + Role, + Message, + Artifact, + Task, + TaskState, + StreamResponse, +) + +# --- Message Helpers Tests --- + + +def test_new_message() -> None: + parts = [Part(text='hello')] + msg = new_message( + parts=parts, role=Role.ROLE_USER, context_id='ctx1', task_id='task1' + ) + assert msg.role == Role.ROLE_USER + assert msg.parts == parts + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_new_text_message() -> None: + msg = new_text_message( + text='hello', context_id='ctx1', task_id='task1', role=Role.ROLE_USER + ) + assert msg.role == Role.ROLE_USER + assert len(msg.parts) == 1 + assert msg.parts[0].text == 'hello' + assert msg.context_id == 'ctx1' + assert msg.task_id == 'task1' + assert msg.message_id != '' + + +def test_get_message_text() -> None: + msg = Message(parts=[Part(text='hello'), Part(text='world')]) + assert get_message_text(msg) == 'hello\nworld' + assert get_message_text(msg, delimiter=' ') == 'hello world' + + +# --- Artifact Helpers Tests --- + + +def test_new_artifact() -> None: + parts = [Part(text='content')] + art = new_artifact(parts=parts, name='test', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert art.parts == parts + assert art.artifact_id != '' + + +def test_new_text_artifact() -> None: + art = new_text_artifact(name='test', text='content', description='desc') + assert art.name == 'test' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id != '' + + +def test_new_text_artifact_with_id() -> None: + art = new_text_artifact( + name='test', text='content', description='desc', artifact_id='art1' + ) + assert art.name == 'test' + assert art.description == 'desc' + assert len(art.parts) == 1 + assert art.parts[0].text == 'content' + assert art.artifact_id == 'art1' + + +def test_get_artifact_text() -> None: + art = Artifact(parts=[Part(text='hello'), Part(text='world')]) + assert get_artifact_text(art) == 'hello\nworld' + assert get_artifact_text(art, delimiter=' ') == 'hello world' + + +# --- Task Helpers Tests --- + + +def test_new_task_from_user_message() -> None: + msg = Message( + role=Role.ROLE_USER, + parts=[Part(text='hello')], + task_id='task1', + context_id='ctx1', + ) + task = new_task_from_user_message(msg) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert task.status.state == TaskState.TASK_STATE_SUBMITTED + assert len(task.history) == 1 + assert task.history[0] == msg + + +def test_new_task_from_user_message_empty_parts() -> None: + msg = Message(role=Role.ROLE_USER, parts=[]) + with pytest.raises(ValueError, match='Message parts cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task_from_user_message_empty_text() -> None: + msg = Message(role=Role.ROLE_USER, parts=[Part(text='')]) + with pytest.raises(ValueError, match='Message.text cannot be empty'): + new_task_from_user_message(msg) + + +def test_new_task() -> None: + task = new_task( + task_id='task1', context_id='ctx1', state=TaskState.TASK_STATE_WORKING + ) + assert task.id == 'task1' + assert task.context_id == 'ctx1' + assert task.status.state == TaskState.TASK_STATE_WORKING + assert len(task.history) == 0 + assert len(task.artifacts) == 0 + + +# --- Part Helpers Tests --- + + +def test_get_text_parts() -> None: + parts = [ + Part(text='hello'), + Part(url='http://example.com'), + Part(text='world'), + ] + assert get_text_parts(parts) == ['hello', 'world'] + + +# --- Event & Stream Helpers Tests --- + + +def test_new_text_status_update_event() -> None: + event = new_text_status_update_event( + task_id='task1', + context_id='ctx1', + state=TaskState.TASK_STATE_WORKING, + text='progress', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.status.state == TaskState.TASK_STATE_WORKING + assert event.status.message.parts[0].text == 'progress' + + +def test_new_text_artifact_update_event() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + append=True, + last_chunk=True, + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.append is True + assert event.last_chunk is True + + +def test_new_text_artifact_update_event_with_id() -> None: + event = new_text_artifact_update_event( + task_id='task1', + context_id='ctx1', + name='test', + text='content', + artifact_id='art1', + ) + assert event.task_id == 'task1' + assert event.context_id == 'ctx1' + assert event.artifact.name == 'test' + assert event.artifact.parts[0].text == 'content' + assert event.artifact.artifact_id == 'art1' + + +def test_get_stream_response_text_message() -> None: + resp = StreamResponse(message=Message(parts=[Part(text='hello')])) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_task() -> None: + resp = StreamResponse( + task=Task(artifacts=[Artifact(parts=[Part(text='hello')])]) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_status_update() -> None: + resp = StreamResponse( + status_update=new_text_status_update_event( + 't', 'c', TaskState.TASK_STATE_WORKING, 'hello' + ) + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_artifact_update() -> None: + resp = StreamResponse( + artifact_update=new_text_artifact_update_event('t', 'c', 'n', 'hello') + ) + assert get_stream_response_text(resp) == 'hello' + + +def test_get_stream_response_text_empty() -> None: + resp = StreamResponse() + assert get_stream_response_text(resp) == '' diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index aea9784ad..b6cddbe4d 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -43,7 +43,8 @@ TaskState, a2a_pb2_grpc, ) -from a2a.utils import TransportProtocol, new_task +from a2a.utils import TransportProtocol +from a2a.helpers.proto_helpers import new_task_from_user_message from a2a.utils.errors import InvalidParamsError @@ -130,7 +131,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): # Task-based response. task = context.current_task if not task: - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) task_updater = TaskUpdater( diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 294f5aefe..5a2bf0446 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -73,7 +73,10 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils import new_agent_text_message, new_task +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) class MockAgentExecutor(AgentExecutor): @@ -254,8 +257,8 @@ async def test_on_list_tasks_applies_history_length(agent_card): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ - new_agent_text_message('Hello 1!'), - new_agent_text_message('Hello 2!'), + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), ] task2 = create_sample_task(task_id='task2') task2.history.extend(history) @@ -957,7 +960,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): assert context.message is not None, ( 'A message is required to create a new task' ) - task = new_task(context.message) # type: ignore + task = new_task_from_user_message(context.message) # type: ignore await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index d48b82461..3e1568b2e 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -54,7 +54,10 @@ TaskState, TaskStatus, ) -from a2a.utils import new_agent_text_message, new_task +from a2a.helpers.proto_helpers import ( + new_text_message, + new_task_from_user_message, +) def create_default_agent_card(): @@ -211,8 +214,8 @@ async def test_on_list_tasks_applies_history_length(): """Test on_list_tasks applies history length filter.""" mock_task_store = AsyncMock(spec=TaskStore) history = [ - new_agent_text_message('Hello 1!'), - new_agent_text_message('Hello 2!'), + new_text_message('Hello 1!'), + new_text_message('Hello 2!'), ] task2 = create_sample_task(task_id='task2') task2.history.extend(history) @@ -274,7 +277,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): assert context.message is not None, ( 'A message is required to create a new task' ) - task = new_task(context.message) + task = new_task_from_user_message(context.message) await event_queue.enqueue_event(task) updater = TaskUpdater(event_queue, task.id, task.context_id) try: diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py deleted file mode 100644 index cbe8e9c91..000000000 --- a/tests/utils/test_artifact.py +++ /dev/null @@ -1,161 +0,0 @@ -import unittest -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct - -from a2a.types.a2a_pb2 import ( - Artifact, - Part, -) -from a2a.utils.artifact import ( - get_artifact_text, - new_artifact, - new_data_artifact, - new_text_artifact, -) - - -class TestArtifact(unittest.TestCase): - @patch('uuid.uuid4') - def test_new_artifact_generates_id(self, mock_uuid4): - mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - artifact = new_artifact(parts=[], name='test_artifact') - self.assertEqual(artifact.artifact_id, str(mock_uuid)) - - def test_new_artifact_assigns_parts_name_description(self): - parts = [Part(text='Sample text')] - name = 'My Artifact' - description = 'This is a test artifact.' - artifact = new_artifact(parts=parts, name=name, description=description) - assert len(artifact.parts) == len(parts) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_artifact_empty_description_if_not_provided(self): - parts = [Part(text='Another sample')] - name = 'Artifact_No_Desc' - artifact = new_artifact(parts=parts, name=name) - self.assertEqual(artifact.description, '') - - def test_new_text_artifact_creates_single_text_part(self): - text = 'This is a text artifact.' - name = 'Text_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('text')) - - def test_new_text_artifact_part_contains_provided_text(self): - text = 'Hello, world!' - name = 'Greeting_Artifact' - artifact = new_text_artifact(text=text, name=name) - self.assertEqual(artifact.parts[0].text, text) - - def test_new_text_artifact_assigns_name_description(self): - text = 'Some content.' - name = 'Named_Text_Artifact' - description = 'Description for text artifact.' - artifact = new_text_artifact( - text=text, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - def test_new_data_artifact_creates_single_data_part(self): - sample_data = {'key': 'value', 'number': 123} - name = 'Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertEqual(len(artifact.parts), 1) - self.assertTrue(artifact.parts[0].HasField('data')) - - def test_new_data_artifact_part_contains_provided_data(self): - sample_data = {'content': 'test_data', 'is_valid': True} - name = 'Structured_Data_Artifact' - artifact = new_data_artifact(data=sample_data, name=name) - self.assertTrue(artifact.parts[0].HasField('data')) - # Compare via MessageToDict for proto Struct - from google.protobuf.json_format import MessageToDict - - self.assertEqual(MessageToDict(artifact.parts[0].data), sample_data) - - def test_new_data_artifact_assigns_name_description(self): - sample_data = {'info': 'some details'} - name = 'Named_Data_Artifact' - description = 'Description for data artifact.' - artifact = new_data_artifact( - data=sample_data, name=name, description=description - ) - self.assertEqual(artifact.name, name) - self.assertEqual(artifact.description, description) - - -class TestGetArtifactText(unittest.TestCase): - def test_get_artifact_text_single_part(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[Part(text='Hello world')], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == 'Hello world' - - def test_get_artifact_text_multiple_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def test_get_artifact_text_custom_delimiter(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_artifact_text_empty_parts(self): - # Setup - artifact = Artifact( - name='test-artifact', - parts=[], - artifact_id='test-artifact-id', - ) - - # Exercise - result = get_artifact_text(artifact) - - # Verify - assert result == '' - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index d8a85fcd9..c2c990c0d 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -22,14 +22,9 @@ TaskStatus, ) from a2a.utils.errors import UnsupportedOperationError -from a2a.utils.helpers import ( - _clean_empty, - append_artifact_to_task, - are_modalities_compatible, - build_text_artifact, - canonicalize_agent_card, - create_task_obj, -) + +from a2a.utils.signing import _clean_empty, _canonicalize_agent_card +from a2a.server.tasks.task_manager import append_artifact_to_task # --- Helper Functions --- @@ -90,62 +85,6 @@ def create_test_task( } -# Test create_task_obj -def test_create_task_obj(): - message = create_test_message() - message.context_id = 'test-context' # Set context_id to test it's preserved - send_params = SendMessageRequest(message=message) - - task = create_task_obj(send_params) - assert task.id is not None - assert task.context_id == message.context_id - assert task.status.state == TaskState.TASK_STATE_SUBMITTED - assert len(task.history) == 1 - assert task.history[0] == message - - -def test_create_task_obj_generates_context_id(): - """Test that create_task_obj generates context_id if not present and uses it for the task.""" - # Message without context_id - message_no_context_id = Message( - role=Role.ROLE_USER, - parts=[Part(text='test')], - message_id='msg-no-ctx', - task_id='task-from-msg', # Provide a task_id to differentiate from generated task.id - ) - send_params = SendMessageRequest(message=message_no_context_id) - - # Ensure message.context_id is empty initially (proto default is empty string) - assert send_params.message.context_id == '' - - known_task_uuid = uuid.UUID('11111111-1111-1111-1111-111111111111') - known_context_uuid = uuid.UUID('22222222-2222-2222-2222-222222222222') - - # Patch uuid.uuid4 to return specific UUIDs in sequence - # The first call will be for message.context_id (if empty), the second for task.id. - with patch( - 'a2a.utils.helpers.uuid4', - side_effect=[known_context_uuid, known_task_uuid], - ) as mock_uuid4: - task = create_task_obj(send_params) - - # Assert that uuid4 was called twice (once for context_id, once for task.id) - assert mock_uuid4.call_count == 2 - - # Assert that message.context_id was set to the first generated UUID - assert send_params.message.context_id == str(known_context_uuid) - - # Assert that task.context_id is the same generated UUID - assert task.context_id == str(known_context_uuid) - - # Assert that task.id is the second generated UUID - assert task.id == str(known_task_uuid) - - # Ensure the original message in history also has the updated context_id - assert len(task.history) == 1 - assert task.history[0].context_id == str(known_context_uuid) - - # Test append_artifact_to_task def test_append_artifact_to_task(): # Prepare base task @@ -243,6 +182,10 @@ def test_append_artifact_to_task(): assert len(task.artifacts[1].parts) == 1 +def build_text_artifact(text: str, artifact_id: str) -> Artifact: + return Artifact(artifact_id=artifact_id, parts=[Part(text=text)]) + + # Test build_text_artifact def test_build_text_artifact(): artifact_id = 'text_artifact' @@ -254,111 +197,6 @@ def test_build_text_artifact(): assert artifact.parts[0].text == text -# Tests for are_modalities_compatible -def test_are_modalities_compatible_client_none(): - assert ( - are_modalities_compatible( - client_output_modes=None, server_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_client_empty(): - assert ( - are_modalities_compatible( - client_output_modes=[], server_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_server_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_server_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], client_output_modes=['text/plain'] - ) - is True - ) - - -def test_are_modalities_compatible_common_mode(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'application/json'], - client_output_modes=['application/json', 'image/png'], - ) - is True - ) - - -def test_are_modalities_compatible_no_common_modes(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['application/json'], - ) - is False - ) - - -def test_are_modalities_compatible_exact_match(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_server_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain', 'image/jpeg'], - client_output_modes=['text/plain'], - ) - is True - ) - - -def test_are_modalities_compatible_client_more_but_common(): - assert ( - are_modalities_compatible( - server_output_modes=['text/plain'], - client_output_modes=['text/plain', 'image/jpeg'], - ) - is True - ) - - -def test_are_modalities_compatible_both_none(): - assert ( - are_modalities_compatible( - server_output_modes=None, client_output_modes=None - ) - is True - ) - - -def test_are_modalities_compatible_both_empty(): - assert ( - are_modalities_compatible( - server_output_modes=[], client_output_modes=[] - ) - is True - ) - - def test_canonicalize_agent_card(): """Test canonicalize_agent_card with defaults, optionals, and exceptions. @@ -375,7 +213,7 @@ def test_canonicalize_agent_card(): '"supportedInterfaces":[{"protocolBinding":"HTTP+JSON","url":"http://localhost"}],' '"version":"1.0.0"}' ) - result = canonicalize_agent_card(agent_card) + result = _canonicalize_agent_card(agent_card) assert result == expected_jcs @@ -390,7 +228,7 @@ def test_canonicalize_agent_card_preserves_false_capability(): ), } ) - result = canonicalize_agent_card(card) + result = _canonicalize_agent_card(card) assert '"streaming":false' in result diff --git a/tests/utils/test_message.py b/tests/utils/test_message.py deleted file mode 100644 index c90d422aa..000000000 --- a/tests/utils/test_message.py +++ /dev/null @@ -1,209 +0,0 @@ -import uuid - -from unittest.mock import patch - -from google.protobuf.struct_pb2 import Struct, Value - -from a2a.types.a2a_pb2 import ( - Message, - Part, - Role, -) -from a2a.utils.message import ( - get_message_text, - new_agent_parts_message, - new_agent_text_message, -) - - -class TestNewAgentTextMessage: - def test_new_agent_text_message_basic(self): - # Setup - text = "Hello, I'm an agent" - - # Exercise - with a fixed uuid for testing - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == 1 - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == '' - assert message.context_id == '' - - def test_new_agent_text_message_with_context_id(self): - # Setup - text = 'Message with context' - context_id = 'test-context-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, context_id=context_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.context_id == context_id - assert message.task_id == '' - - def test_new_agent_text_message_with_task_id(self): - # Setup - text = 'Message with task id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text, task_id=task_id) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.task_id == task_id - assert message.context_id == '' - - def test_new_agent_text_message_with_both_ids(self): - # Setup - text = 'Message with both ids' - context_id = 'test-context-id' - task_id = 'test-task-id' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message( - text, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == text - assert message.message_id == '12345678-1234-5678-1234-567812345678' - assert message.context_id == context_id - assert message.task_id == task_id - - def test_new_agent_text_message_empty_text(self): - # Setup - text = '' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = new_agent_text_message(text) - - # Verify - assert message.role == Role.ROLE_AGENT - assert message.parts[0].text == '' - assert message.message_id == '12345678-1234-5678-1234-567812345678' - - -class TestNewAgentPartsMessage: - def test_new_agent_parts_message(self): - """Test creating an agent message with multiple, mixed parts.""" - # Setup - data = Struct() - data.update({'product_id': 123, 'quantity': 2}) - parts = [ - Part(text='Here is some text.'), - Part(data=Value(struct_value=data)), - ] - context_id = 'ctx-multi-part' - task_id = 'task-multi-part' - - # Exercise - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('abcdefab-cdef-abcd-efab-cdefabcdefab'), - ): - message = new_agent_parts_message( - parts, context_id=context_id, task_id=task_id - ) - - # Verify - assert message.role == Role.ROLE_AGENT - assert len(message.parts) == len(parts) - assert message.context_id == context_id - assert message.task_id == task_id - assert message.message_id == 'abcdefab-cdef-abcd-efab-cdefabcdefab' - - -class TestGetMessageText: - def test_get_message_text_single_part(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hello world')], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == 'Hello world' - - def test_get_message_text_multiple_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First line'), - Part(text='Second line'), - Part(text='Third line'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - default delimiter is newline - assert result == 'First line\nSecond line\nThird line' - - def test_get_message_text_custom_delimiter(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message, delimiter=' | ') - - # Verify - assert result == 'First part | Second part | Third part' - - def test_get_message_text_empty_parts(self): - # Setup - message = Message( - role=Role.ROLE_AGENT, - parts=[], - message_id='test-message-id', - ) - - # Exercise - result = get_message_text(message) - - # Verify - assert result == '' diff --git a/tests/utils/test_parts.py b/tests/utils/test_parts.py deleted file mode 100644 index a7a24e225..000000000 --- a/tests/utils/test_parts.py +++ /dev/null @@ -1,184 +0,0 @@ -from google.protobuf.struct_pb2 import Struct, Value -from a2a.types.a2a_pb2 import ( - Part, -) -from a2a.utils.parts import ( - get_data_parts, - get_file_parts, - get_text_parts, -) - - -class TestGetTextParts: - def test_get_text_parts_single_text_part(self): - # Setup - parts = [Part(text='Hello world')] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['Hello world'] - - def test_get_text_parts_multiple_text_parts(self): - # Setup - parts = [ - Part(text='First part'), - Part(text='Second part'), - Part(text='Third part'), - ] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == ['First part', 'Second part', 'Third part'] - - def test_get_text_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_text_parts(parts) - - # Verify - assert result == [] - - -class TestGetDataParts: - def test_get_data_parts_single_data_part(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [Part(data=Value(struct_value=data))] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key': 'value'}] - - def test_get_data_parts_multiple_data_parts(self): - # Setup - data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_mixed_parts(self): - # Setup - data1 = Struct() - data1.update({'key1': 'value1'}) - data2 = Struct() - data2.update({'key2': 'value2'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data1)), - Part(data=Value(struct_value=data2)), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [{'key1': 'value1'}, {'key2': 'value2'}] - - def test_get_data_parts_no_data_parts(self): - # Setup - parts = [ - Part(text='some text'), - ] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - def test_get_data_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_data_parts(parts) - - # Verify - assert result == [] - - -class TestGetFileParts: - def test_get_file_parts_single_file_part(self): - # Setup - parts = [Part(url='file://path/to/file', media_type='text/plain')] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - assert result[0].media_type == 'text/plain' - - def test_get_file_parts_multiple_file_parts(self): - # Setup - parts = [ - Part(url='file://path/to/file1', media_type='text/plain'), - Part(raw=b'file content', media_type='application/octet-stream'), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 2 - assert result[0].url == 'file://path/to/file1' - assert result[1].raw == b'file content' - - def test_get_file_parts_mixed_parts(self): - # Setup - parts = [ - Part(text='some text'), - Part(url='file://path/to/file', media_type='text/plain'), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert len(result) == 1 - assert result[0].url == 'file://path/to/file' - - def test_get_file_parts_no_file_parts(self): - # Setup - data = Struct() - data.update({'key': 'value'}) - parts = [ - Part(text='some text'), - Part(data=Value(struct_value=data)), - ] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] - - def test_get_file_parts_empty_list(self): - # Setup - parts = [] - - # Exercise - result = get_file_parts(parts) - - # Verify - assert result == [] diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 3e1f3c058..55dc8ed4f 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -14,197 +14,16 @@ GetTaskRequest, SendMessageConfiguration, ) +from a2a.helpers.proto_helpers import new_task from a2a.utils.task import ( apply_history_length, - completed_task, decode_page_token, encode_page_token, - new_task, ) from a2a.utils.errors import InvalidParamsError class TestTask(unittest.TestCase): - def test_new_task_status(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.status.state, TaskState.TASK_STATE_SUBMITTED) - - @patch('uuid.uuid4') - def test_new_task_generates_ids(self, mock_uuid4): - mock_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678') - mock_uuid4.return_value = mock_uuid - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(task.id, str(mock_uuid)) - self.assertEqual(task.context_id, str(mock_uuid)) - - def test_new_task_uses_provided_ids(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - task_id=task_id, - context_id=context_id, - ) - task = new_task(message) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - - def test_new_task_initial_message_in_history(self): - message = Message( - role=Role.ROLE_USER, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - task = new_task(message) - self.assertEqual(len(task.history), 1) - self.assertEqual(task.history[0], message) - - def test_completed_task_status(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.status.state, TaskState.TASK_STATE_COMPLETED) - - def test_completed_task_assigns_ids_and_artifacts(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=[], - ) - self.assertEqual(task.id, task_id) - self.assertEqual(task.context_id, context_id) - self.assertEqual(len(task.artifacts), len(artifacts)) - - def test_completed_task_empty_history_if_not_provided(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - task = completed_task( - task_id=task_id, context_id=context_id, artifacts=artifacts - ) - self.assertEqual(len(task.history), 0) - - def test_completed_task_uses_provided_history(self): - task_id = str(uuid.uuid4()) - context_id = str(uuid.uuid4()) - artifacts = [ - Artifact( - artifact_id='artifact_1', - parts=[Part(text='some content')], - ) - ] - history = [ - Message( - role=Role.ROLE_USER, - parts=[Part(text='Hello')], - message_id=str(uuid.uuid4()), - ), - Message( - role=Role.ROLE_AGENT, - parts=[Part(text='Hi there')], - message_id=str(uuid.uuid4()), - ), - ] - task = completed_task( - task_id=task_id, - context_id=context_id, - artifacts=artifacts, - history=history, - ) - self.assertEqual(len(task.history), len(history)) - - def test_new_task_invalid_message_empty_parts(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_empty_content(self): - with self.assertRaises(ValueError): - new_task( - Message( - role=Role.ROLE_USER, - parts=[Part(text='')], - message_id=str(uuid.uuid4()), - ) - ) - - def test_new_task_invalid_message_none_role(self): - # Proto messages always have a default role (ROLE_UNSPECIFIED = 0) - # Testing with unspecified role - msg = Message( - role=Role.ROLE_UNSPECIFIED, - parts=[Part(text='test message')], - message_id=str(uuid.uuid4()), - ) - with self.assertRaises((TypeError, ValueError)): - new_task(msg) - - def test_completed_task_empty_artifacts(self): - with pytest.raises( - ValueError, - match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=[], - history=[], - ) - - def test_completed_task_invalid_artifact_type(self): - with pytest.raises( - ValueError, - match='artifacts must be a non-empty list of Artifact objects', - ): - completed_task( - task_id='task-123', - context_id='ctx-456', - artifacts=['not an artifact'], # type: ignore[arg-type] - history=[], - ) - page_token = 'd47a95ba-0f39-4459-965b-3923cdd2ff58' encoded_page_token = 'ZDQ3YTk1YmEtMGYzOS00NDU5LTk2NWItMzkyM2NkZDJmZjU4' # base64 for 'd47a95ba-0f39-4459-965b-3923cdd2ff58' @@ -234,9 +53,10 @@ def setUp(self): for i in range(5) ] artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])] - self.task = completed_task( + self.task = new_task( task_id='t1', context_id='c1', + state=TaskState.TASK_STATE_COMPLETED, artifacts=artifacts, history=self.history, )