diff --git a/README.md b/README.md index 43497bc2..4ef7a49b 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,12 @@ To install with gRPC support: uv add "a2a-sdk[grpc]" ``` +To install with Kafka transport support: + +```bash +uv add "a2a-sdk[kafka]" +``` + To install with OpenTelemetry tracing support: ```bash @@ -87,6 +93,12 @@ To install with gRPC support: pip install "a2a-sdk[grpc]" ``` +To install with Kafka transport support: + +```bash +pip install "a2a-sdk[kafka]" +``` + To install with OpenTelemetry tracing support: ```bash diff --git a/pyproject.toml b/pyproject.toml index c1da2323..ccdad4c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"] encryption = ["cryptography>=43.0.0"] grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] +kafka = ["aiokafka>=0.11.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] [project.urls] diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index c568331f..b98312d3 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -25,6 +25,11 @@ except ImportError: GrpcTransport = None # type: ignore # pyright: ignore +try: + from a2a.client.transports.kafka import KafkaClientTransport +except ImportError: + KafkaClientTransport = None # type: ignore # pyright: ignore + logger = logging.getLogger(__name__) @@ -97,6 +102,32 @@ def _register_defaults( TransportProtocol.grpc, GrpcTransport.create, ) + if TransportProtocol.kafka in supported: + if KafkaClientTransport is None: + raise ImportError( + 'To use KafkaClient, its dependencies must be installed. 
def _create_kafka_transport(
    self,
    card: AgentCard,
    url: str,
    config: ClientConfig,
    interceptors: list[ClientCallInterceptor],
) -> ClientTransport:
    """Build a Kafka transport that is flagged to start itself on first use.

    Construction is delegated to ``KafkaClientTransport.create``; the
    returned instance then has its ``_auto_start`` flag switched on.

    Args:
        card: Agent card forwarded to the transport.
        url: Kafka URL forwarded to the transport.
        config: Client configuration forwarded to the transport.
        interceptors: Client interceptors forwarded to the transport.

    Returns:
        The newly created transport with ``_auto_start`` set to ``True``.
    """
    kafka_transport = KafkaClientTransport.create(
        card, url, config, interceptors
    )
    # Opt the instance into lazy start-up.
    kafka_transport._auto_start = True
    return kafka_transport
CorrelationManager +from a2a.client.errors import A2AClientError +from a2a.types import ( + AgentCard, + DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, + ListTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) + +logger = logging.getLogger(__name__) + + +class KafkaClientTransport(ClientTransport): + """Kafka-based client transport for A2A protocol.""" + + def __init__( + self, + agent_card: AgentCard, + bootstrap_servers: str | List[str] = "localhost:9092", + request_topic: str = "a2a-requests", + reply_topic_prefix: str = "a2a-reply", + reply_topic: Optional[str] = None, + consumer_group_id: Optional[str] = None, + **kafka_config: Any, + ) -> None: + """Initialize Kafka client transport. + + Args: + agent_card: The agent card for this client. + bootstrap_servers: Kafka bootstrap servers. + request_topic: Topic where requests are sent. + reply_topic_prefix: Prefix for reply topics. + reply_topic: Explicit reply topic to use. If not provided, it will be generated on start(). + consumer_group_id: Consumer group ID for the reply consumer. + **kafka_config: Additional Kafka configuration. 
def _sanitize_topic_name(self, name: str) -> str:
    """Turn an arbitrary name into a fragment usable inside a Kafka topic name.

    Every character outside ``[a-zA-Z0-9._-]`` becomes an underscore, an
    empty result falls back to ``"unknown_agent"``, and the result is cut
    to at most 200 characters so that prefixes added by the caller still
    fit.

    Args:
        name: The raw name to clean up.

    Returns:
        The cleaned, non-empty name of at most 200 characters.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9._-]', '_', name) or "unknown_agent"
    return cleaned[:200]
async def stop(self) -> None:
    """Stop the Kafka client transport and release its resources.

    Cancels the response-consumer task, cancels every pending request in
    the correlation manager, then stops the producer and the consumer.

    Cleanup also runs when the transport never reached the running state
    but already owns a producer, consumer or consumer task. ``start()``
    calls this method from its failure path before ``_running`` is set,
    so returning early on ``not self._running`` would leave a started
    producer open.

    Calling this on a transport that holds no resources is a no-op.
    """
    owns_resources = any(
        res is not None
        for res in (self._consumer_task, self.producer, self.consumer)
    )
    if not self._running and not owns_resources:
        return

    self._running = False

    # Detach references before awaiting so a second stop() finds nothing to do.
    task, self._consumer_task = self._consumer_task, None
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    # Cancel all pending requests
    await self.correlation_manager.cancel_all()

    producer, self.producer = self.producer, None
    consumer, self.consumer = self.consumer, None
    try:
        if producer is not None:
            await producer.stop()
    finally:
        # Stop the consumer even if stopping the producer raised.
        if consumer is not None:
            await consumer.stop()

    logger.info(f"Kafka client transport stopped for agent {self.agent_card.name}")
async def _send_request(
    self,
    method: str,
    params: Any,
    context: ClientCallContext | None = None,
    streaming: bool = False,
) -> str:
    """Publish a request to the request topic and return its correlation ID.

    Args:
        method: Name of the remote method, sent in the ``method`` field.
        params: Request parameters; objects exposing ``model_dump`` are
            dumped, anything else is sent as given.
        context: Optional call context. If it carries a truthy
            ``trace_id`` attribute, that value is added as a header.
        streaming: Value of the ``streaming`` flag in the request payload.

    Returns:
        The correlation ID placed in the ``correlation_id`` header.

    Raises:
        A2AClientError: If the transport is not started or Kafka rejects
            the message.
    """
    await self._ensure_started()

    if not self.producer or not self._running:
        raise A2AClientError("Kafka client transport not started")

    correlation_id = self.correlation_manager.generate_correlation_id()

    # Dump in JSON mode: the producer serializes the payload with
    # json.dumps, which cannot handle arbitrary Python objects.
    request_data = {
        'method': method,
        'params': params.model_dump(mode='json') if hasattr(params, 'model_dump') else params,
        'streaming': streaming,
        'agent_card': self.agent_card.model_dump(mode='json'),
    }

    headers = [
        ('correlation_id', correlation_id.encode('utf-8')),
        ('reply_topic', (self.reply_topic or '').encode('utf-8')),
        ('agent_id', self.agent_card.name.encode('utf-8')),
    ]

    # The context type is not guaranteed to define trace_id, so read it
    # defensively instead of failing the request with AttributeError.
    trace_id = getattr(context, 'trace_id', None) if context else None
    if trace_id:
        headers.append(('trace_id', str(trace_id).encode('utf-8')))

    try:
        await self.producer.send_and_wait(
            self.request_topic,
            value=request_data,
            headers=headers
        )
        return correlation_id
    except KafkaError as e:
        raise A2AClientError(f"Failed to send Kafka message: {e}") from e
async def send_message_streaming(
    self,
    request: MessageSendParams,
    *,
    context: ClientCallContext | None = None,
) -> AsyncGenerator[Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent]:
    """Send a streaming message request and yield responses as they arrive.

    The generator ends once the stream is marked done and every queued
    event has been yielded. If the stream was completed with an
    exception, that exception is raised instead of ending silently.

    Args:
        request: Parameters of the message to send.
        context: Optional call context. If it carries a truthy ``timeout``
            attribute, it is enforced as the maximum number of seconds to
            wait for the next event. Without it the generator waits
            indefinitely.

    Yields:
        Each response event delivered for this request.

    Raises:
        A2AClientError: If the stream fails or the idle timeout expires.
    """
    await self._ensure_started()
    correlation_id = await self._send_request('message_send', request, context, streaming=True)

    # NOTE(review): the stream is registered only after the request was
    # sent, so a reply consumed before this line is not matched to it.
    streaming_future = await self.correlation_manager.register_streaming(correlation_id)

    # The context type is not guaranteed to define timeout; read defensively.
    idle_timeout = getattr(context, 'timeout', None) if context else None
    poll_interval = 5.0

    try:
        loop = asyncio.get_running_loop()
        last_event = loop.time()
        while True:
            # Finish only when the stream is done AND drained, otherwise
            # events queued before the completion signal would be lost.
            if streaming_future.is_done() and streaming_future.empty():
                stream_error = getattr(streaming_future, '_exception', None)
                if stream_error is not None:
                    raise stream_error
                break

            wait = poll_interval
            if idle_timeout:
                remaining = last_event + idle_timeout - loop.time()
                wait = max(0.0, min(poll_interval, remaining))

            try:
                result = await asyncio.wait_for(streaming_future.get(), timeout=wait)
            except asyncio.TimeoutError:
                if idle_timeout and loop.time() - last_event >= idle_timeout:
                    raise A2AClientError(
                        f"No streaming event received for {idle_timeout} seconds"
                    )
                # Poll again so a completion signal is noticed.
                continue

            yield result
            # Restart the idle clock after the caller has handled the event.
            last_event = loop.time()

    except Exception as e:
        await self.correlation_manager.complete_with_exception(
            correlation_id,
            A2AClientError(f"Streaming request failed: {e}")
        )
        raise A2AClientError(f"Streaming request failed: {e}") from e
    finally:
        # Drop the registration even if the caller closes the generator early.
        await self.correlation_manager.complete_streaming(correlation_id)
async def list_task_push_notification_configs(
    self,
    request: ListTaskPushNotificationConfigParams,
    *,
    context: ClientCallContext | None = None,
) -> List[TaskPushNotificationConfig]:
    """List task push notification configurations.

    Args:
        request: Parameters identifying the configurations to list.
        context: Optional call context. If it carries a truthy ``timeout``
            attribute, that replaces the default 30 second wait.

    Returns:
        The list delivered as the response.

    Raises:
        A2AClientError: If no response arrives in time, or the response
            is not a list.
    """
    correlation_id = await self._send_request('task_push_notification_config_list', request, context)
    future = await self.correlation_manager.register(correlation_id)

    # The context type is not guaranteed to define timeout, so read it
    # defensively instead of failing the call with AttributeError.
    timeout = (getattr(context, 'timeout', None) if context else None) or 30.0

    try:
        result = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        await self.correlation_manager.complete_with_exception(
            correlation_id,
            A2AClientError(f"List push notification configs request timed out after {timeout} seconds")
        )
        raise A2AClientError(f"List push notification configs request timed out after {timeout} seconds")

    if isinstance(result, list):
        return result
    raise A2AClientError(f"Expected list, got {type(result)}")
async def resubscribe(
    self,
    request: TaskIdParams,
    *,
    context: ClientCallContext | None = None,
) -> AsyncGenerator[Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent]:
    """Reconnect to a task by fetching its current state once.

    Issues a single ``get_task`` call for the task named in ``request``
    and yields the returned task if there is one.

    Args:
        request: Identifies the task to fetch.
        context: Optional call context forwarded to ``get_task``.

    Yields:
        The task returned by ``get_task``, if it is truthy.
    """
    # TaskIdParams and TaskQueryParams both name the task identifier
    # field ``id``; there is no ``task_id`` field on either model.
    task_request = TaskQueryParams(id=request.id)
    task = await self.get_task(task_request, context=context)
    if task:
        yield task
def set_reply_topic(self, topic: str) -> None:
    """Choose the reply topic explicitly; only effective before ``start()``.

    While the transport is running the call is ignored with a warning,
    and ``reply_topic`` keeps its current value.

    Args:
        topic: Name of the reply topic to use.
    """
    if not self._running:
        self.reply_topic = topic
        return
    logger.warning("set_reply_topic called after start(); ignoring.")
class StreamingFuture:
    """A future-like object for handling streaming responses.

    Items are buffered in an ``asyncio.Queue``. The stream can be marked
    complete with ``set_done()`` or failed with ``set_exception()``.
    Failing the stream also wakes a consumer that is already blocked in
    ``get()``, so the error reaches it instead of leaving it waiting.
    """

    # Marker pushed onto the queue by set_exception() to wake blocked getters.
    _WAKE = object()

    def __init__(self):
        self.queue: asyncio.Queue[Any] = asyncio.Queue()
        self._done = False
        # BaseException because callers may store asyncio.CancelledError.
        self._exception: Optional[BaseException] = None
        # True once the wake-up marker sits at the tail of the queue.
        self._wake_enqueued = False

    async def put(self, item: Any) -> None:
        """Add an item to the stream; ignored once the stream is done."""
        if not self._done:
            await self.queue.put(item)

    async def get(self) -> Any:
        """Return the next item, waiting until one is available.

        Raises:
            BaseException: The exception passed to ``set_exception()``,
                either immediately or as soon as it is set while waiting.
        """
        if self._exception:
            raise self._exception
        item = await self.queue.get()
        if item is self._WAKE:
            # Put the marker back so any other blocked getter wakes too.
            self.queue.put_nowait(item)
            raise self._exception
        return item

    def set_exception(self, exception: BaseException) -> None:
        """Fail the stream, mark it done and wake blocked consumers."""
        self._exception = exception
        self._done = True
        if not self._wake_enqueued:
            self._wake_enqueued = True
            self.queue.put_nowait(self._WAKE)

    def set_done(self) -> None:
        """Mark the stream as complete."""
        self._done = True

    def is_done(self) -> bool:
        """Check if the stream is complete."""
        return self._done

    def empty(self) -> bool:
        """Check if no items are queued (the wake-up marker is not an item)."""
        return self.queue.qsize() <= (1 if self._wake_enqueued else 0)
async def complete(
    self,
    correlation_id: str,
    result: Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
) -> bool:
    """Deliver ``result`` to whoever is waiting on ``correlation_id``.

    A pending one-shot request is removed from the registry and, if not
    yet done, resolved with the result. Otherwise an open streaming
    request with this ID receives the result as its next item and stays
    registered.

    Returns:
        ``True`` if the result was handed over, ``False`` otherwise.
    """
    async with self._lock:
        # One-shot requests take precedence.
        pending = self._pending_requests.pop(correlation_id, None)
        if pending and not pending.done():
            pending.set_result(result)
            return True

        # Fall back to an open stream.
        stream = self._streaming_requests.get(correlation_id)
        if not stream or stream.is_done():
            return False
        await stream.put(result)
        return True
async def cancel_all(self) -> None:
    """Abort every registered request and empty both registries.

    Pending futures that are not yet done are cancelled. Every streaming
    request is failed with ``asyncio.CancelledError("Request cancelled")``.
    """
    async with self._lock:
        unfinished = [
            pending
            for pending in self._pending_requests.values()
            if not pending.done()
        ]
        for pending in unfinished:
            pending.cancel()
        self._pending_requests.clear()

        for stream in list(self._streaming_requests.values()):
            stream.set_exception(asyncio.CancelledError("Request cancelled"))
        self._streaming_requests.clear()
b/src/a2a/server/apps/kafka/kafka_app.py new file mode 100644 index 00000000..32d32d9f --- /dev/null +++ b/src/a2a/server/apps/kafka/kafka_app.py @@ -0,0 +1,343 @@ +"""Kafka server application for A2A protocol.""" + +import asyncio +import json +import logging +import signal +from typing import Any, Dict, List, Optional + +from aiokafka import AIOKafkaConsumer, AIOKafkaProducer +from aiokafka.errors import KafkaError + +from a2a.server.request_handlers.kafka_handler import KafkaHandler, KafkaMessage +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.utils.errors import ServerError +from a2a.types import ( + Message, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, +) + +logger = logging.getLogger(__name__) + + +class KafkaServerApp: + """Kafka server application that manages the service lifecycle.""" + + def __init__( + self, + request_handler: RequestHandler, + bootstrap_servers: str | List[str] = "localhost:9092", + request_topic: str = "a2a-requests", + consumer_group_id: str = "a2a-server", + **kafka_config: Any, + ) -> None: + """Initialize Kafka server application. + + Args: + request_handler: Business logic handler. + bootstrap_servers: Kafka bootstrap servers. + request_topic: Topic to consume requests from. + consumer_group_id: Consumer group ID for the server. + **kafka_config: Additional Kafka configuration. 
async def stop(self) -> None:
    """Stop the Kafka server application and release its resources.

    Cancels the request-consumer task, then stops the consumer and the
    producer.

    Cleanup also runs when the server never reached the running state
    but already owns a consumer, producer or consumer task. ``start()``
    calls this method from its failure path before ``_running`` is set,
    so returning early on ``not self._running`` would leave a started
    producer open.

    Calling this on a server that holds no resources is a no-op.
    """
    owns_resources = any(
        res is not None
        for res in (self._consumer_task, self.consumer, self.producer)
    )
    if not self._running and not owns_resources:
        return

    self._running = False

    # Detach references before awaiting so an overlapping stop() call
    # (e.g. signal handler plus run()'s finally) finds nothing to do.
    task, self._consumer_task = self._consumer_task, None
    if task is not None:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    consumer, self.consumer = self.consumer, None
    producer, self.producer = self.producer, None
    try:
        if consumer is not None:
            await consumer.stop()
    finally:
        # Stop the producer even if stopping the consumer raised.
        if producer is not None:
            await producer.stop()

    logger.info("Kafka server stopped")
async def _consume_requests(self) -> None:
    """Consume requests from the request topic and hand them to the handler.

    Each consumed message is wrapped in a ``KafkaMessage`` and passed to
    ``self.handler.handle_request``. A failure while handling one message
    is logged and consumption continues.

    On a ``KafkaError`` while the server is still running, the method
    waits five seconds, restarts the consumer and resumes. The retry is
    a loop rather than a recursive call, so repeated Kafka errors do not
    deepen the call stack. A failed restart, cancellation or any other
    unexpected error ends consumption.
    """
    if not self.consumer or not self.handler:
        return

    while True:
        try:
            logger.info("Starting to consume requests...")
            async for message in self.consumer:
                try:
                    # Convert Kafka message to our KafkaMessage format
                    kafka_message = KafkaMessage(
                        headers=message.headers or [],
                        value=message.value
                    )

                    # Handle the request
                    await self.handler.handle_request(kafka_message)

                except Exception as e:
                    logger.error(f"Error processing message: {e}")
                    # Continue processing other messages even if one fails

            # The consumer's iterator finished without error.
            return

        except asyncio.CancelledError:
            logger.debug("Request consumer cancelled")
            return
        except KafkaError as e:
            logger.error(f"Kafka error in consumer: {e}")
            if not self._running:
                return
            # Back off before trying to restart the consumer.
            await asyncio.sleep(5)
            if not self._running:
                return
            logger.info("Attempting to restart consumer...")
            try:
                await self.consumer.stop()
                await self.consumer.start()
            except Exception as restart_error:
                logger.error(f"Failed to restart consumer: {restart_error}")
                return
            # Restart succeeded: loop around and resume consuming.
        except Exception as e:
            logger.error(f"Unexpected error in request consumer: {e}")
            return
ResponseSender implementation + async def send_response( + self, + reply_topic: str, + correlation_id: str, + result: Any, + response_type: str, + ) -> None: + if not self.producer: + logger.error("Producer not available") + return + try: + response_data = { + "type": response_type, + "data": result.model_dump() if hasattr(result, 'model_dump') else result, + } + headers = [ + ("correlation_id", correlation_id.encode("utf-8")), + ] + await self.producer.send_and_wait( + reply_topic, + value=response_data, + headers=headers, + ) + except Exception as e: + logger.error(f"Failed to send response: {e}") + + async def send_stream_complete( + self, + reply_topic: str, + correlation_id: str, + ) -> None: + if not self.producer: + logger.error("Producer not available") + return + try: + response_data = { + "type": "stream_complete", + "data": {}, + } + headers = [ + ("correlation_id", correlation_id.encode("utf-8")), + ] + await self.producer.send_and_wait( + reply_topic, + value=response_data, + headers=headers, + ) + except Exception as e: + logger.error(f"Failed to send stream completion signal: {e}") + + async def send_error_response( + self, + reply_topic: str, + correlation_id: str, + error_message: str, + ) -> None: + if not self.producer: + logger.error("Producer not available") + return + try: + response_data = { + "type": "error", + "data": {"error": error_message}, + } + headers = [ + ("correlation_id", correlation_id.encode("utf-8")), + ] + await self.producer.send_and_wait( + reply_topic, + value=response_data, + headers=headers, + ) + except Exception as e: + logger.error(f"Failed to send error response: {e}") + + async def get_handler(self) -> KafkaHandler: + """Get the Kafka handler instance. + + This can be used to send push notifications. 
+ """ + if not self.handler: + raise ServerError("Kafka handler not initialized") + return self.handler + + async def send_push_notification( + self, + reply_topic: str, + notification: Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent, + ) -> None: + """Send a push notification to a specific client topic.""" + if not self.producer: + logger.error("Producer not available for push notification") + return + try: + if isinstance(notification, Task): + response_type = "task" + elif isinstance(notification, TaskStatusUpdateEvent): + response_type = "task_status_update" + elif isinstance(notification, TaskArtifactUpdateEvent): + response_type = "task_artifact_update" + else: + response_type = "message" + + response_data = { + "type": f"push_{response_type}", + "data": notification.model_dump() if hasattr(notification, 'model_dump') else notification, + } + headers = [ + ("notification_type", b"push"), + ] + await self.producer.send_and_wait( + reply_topic, + value=response_data, + headers=headers, + ) + logger.debug(f"Sent push notification to {reply_topic}") + except Exception as e: + logger.error(f"Failed to send push notification: {e}") + + async def __aenter__(self): + """Async context manager entry.""" + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.stop() + + +async def create_kafka_server( + request_handler: RequestHandler, + bootstrap_servers: str | List[str] = "localhost:9092", + request_topic: str = "a2a-requests", + consumer_group_id: str = "a2a-server", + **kafka_config: Any, +) -> KafkaServerApp: + """Create and return a Kafka server application. + + Args: + request_handler: Business logic handler. + bootstrap_servers: Kafka bootstrap servers. + request_topic: Topic to consume requests from. + consumer_group_id: Consumer group ID for the server. + **kafka_config: Additional Kafka configuration. + + Returns: + Configured KafkaServerApp instance. 
+ """ + return KafkaServerApp( + request_handler=request_handler, + bootstrap_servers=bootstrap_servers, + request_topic=request_topic, + consumer_group_id=consumer_group_id, + **kafka_config + ) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e2..0462654a 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -28,19 +28,32 @@ ) class GrpcHandler: # type: ignore - """Placeholder for GrpcHandler when dependencies are not installed.""" - def __init__(self, *args, **kwargs): raise ImportError( - 'To use GrpcHandler, its dependencies must be installed. ' - 'You can install them with \'pip install "a2a-sdk[grpc]"\'' + 'GrpcHandler requires gRPC dependencies. Install with: pip install a2a-sdk[grpc]' ) from _original_error +try: + from a2a.server.request_handlers.kafka_handler import KafkaHandler +except ImportError as e: + _kafka_error = e + logger.debug( + 'KafkaHandler not loaded. This is expected if Kafka dependencies are not installed. Error: %s', + _kafka_error, + ) + + class KafkaHandler: # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError( + 'KafkaHandler requires Kafka dependencies. 
class KafkaMessage:
    """Represents a Kafka message with headers and value."""

    def __init__(self, headers: List[tuple[str, bytes]], value: Dict[str, Any]):
        """Store the raw header pairs and the deserialized message value.

        Args:
            headers: ``(key, value)`` header pairs as delivered by Kafka.
            value: Deserialized message body.
        """
        self.headers = headers
        self.value = value

    def get_header(self, key: str) -> Optional[str]:
        """Get header value by key.

        Returns the value of the first header named *key*, decoded as UTF-8.

        Args:
            key: Header name to look up.

        Returns:
            The decoded value, or ``None`` if the header is absent or its
            value is null.
        """
        # `headers` may be None when a message was built without headers.
        for header_key, header_value in self.headers or ():
            if header_key != key:
                continue
            if header_value is None:
                # Kafka permits null header values; treat them as "not set".
                return None
            if isinstance(header_value, str):
                return header_value
            return bytes(header_value).decode('utf-8')
        return None


class ResponseSender(Protocol):
    """Protocol for sending responses back to clients."""

    async def send_response(
        self,
        reply_topic: str,
        correlation_id: str,
        result: Any,
        response_type: str,
    ) -> None:
        """Send a successful result tagged with *response_type*."""
        ...

    async def send_error_response(
        self,
        reply_topic: str,
        correlation_id: str,
        error_message: str,
    ) -> None:
        """Send an error message for the given request."""
        ...

    async def send_stream_complete(
        self,
        reply_topic: str,
        correlation_id: str,
    ) -> None:
        """Signal that a streamed response has finished."""
        ...
class KafkaHandler:
    """Protocol adapter that parses requests and delegates to business logic.

    Note: This class is intentionally Kafka-agnostic. It does not manage producers
    or perform network I/O. All message sending is delegated to `response_sender`.
    """

    # Non-streaming methods that follow the plain "validate params, call the
    # handler, return its result" shape:
    # method name -> (params model, RequestHandler method name, response type)
    _UNARY_ROUTES = {
        "task_get": (TaskQueryParams, "on_get_task", "task"),
        "task_cancel": (TaskIdParams, "on_cancel_task", "task"),
        "task_push_notification_config_get": (
            GetTaskPushNotificationConfigParams,
            "on_get_task_push_notification_config",
            "task_push_notification_config",
        ),
        "task_push_notification_config_list": (
            ListTaskPushNotificationConfigParams,
            "on_list_task_push_notification_config",
            "task_push_notification_config_list",
        ),
    }

    def __init__(
        self,
        request_handler: RequestHandler,
        response_sender: ResponseSender,
    ) -> None:
        """Initialize handler.

        Args:
            request_handler: Business logic handler.
            response_sender: Callback provider to send responses.
        """
        self.request_handler = request_handler
        self.response_sender = response_sender

    async def handle_request(self, message: KafkaMessage) -> None:
        """Handle incoming Kafka request message.

        This is the core callback function called by the consumer loop.
        It extracts metadata, processes the request, and uses `response_sender`
        to send the response. It never raises: failures are logged and, when
        the reply address is known, reported to the client as an error.

        Args:
            message: Request with ``reply_topic`` and ``correlation_id``
                headers and a JSON-object body containing ``method`` and
                optionally ``params``, ``streaming`` and ``agent_card``.
        """
        try:
            # Extract metadata from headers
            reply_topic = message.get_header('reply_topic')
            correlation_id = message.get_header('correlation_id')
            agent_id = message.get_header('agent_id')
            trace_id = message.get_header('trace_id')

            if not reply_topic or not correlation_id:
                # Without both headers there is nowhere to send an error.
                logger.error("Missing required headers: reply_topic or correlation_id")
                return

            request_data = message.value
            if not isinstance(request_data, dict):
                # e.g. a null body or a JSON array: tell the client instead
                # of failing silently.
                logger.error("Request payload is not a JSON object")
                await self.response_sender.send_error_response(
                    reply_topic, correlation_id, "Request payload must be a JSON object"
                )
                return

            method = request_data.get('method')
            params = request_data.get('params') or {}
            streaming = request_data.get('streaming', False)
            agent_card_data = request_data.get('agent_card')

            if not method:
                logger.error("Missing method in request")
                await self.response_sender.send_error_response(
                    reply_topic, correlation_id, "Missing method in request"
                )
                return

            # NOTE(review): assumes ServerCallContext exposes a `state`
            # mapping and has no `agent_id` / `trace_id` fields -- confirm.
            context = ServerCallContext(
                state={'agent_id': agent_id, 'trace_id': trace_id},
            )

            # The agent card is validated for diagnostics only; it is not
            # forwarded to the request handler.
            if agent_card_data:
                try:
                    AgentCard.model_validate(agent_card_data)
                except Exception as e:
                    logger.error(f"Invalid agent card: {e}")

            # Route request to appropriate handler method
            route = (
                self._handle_streaming_request
                if streaming
                else self._handle_single_request
            )
            try:
                await route(method, params, reply_topic, correlation_id, context)
            except Exception as e:
                logger.exception(f"Error handling request {method}: {e}")
                await self.response_sender.send_error_response(
                    reply_topic, correlation_id, f"Request processing error: {e}"
                )

        except Exception as e:
            logger.exception(f"Error in handle_request: {e}")

    async def _handle_single_request(
        self,
        method: str,
        params: Dict[str, Any],
        reply_topic: str,
        correlation_id: str,
        context: ServerCallContext,
    ) -> None:
        """Handle a single (non-streaming) request.

        Sends either one response or one error response to *reply_topic*.
        """
        try:
            result, response_type = await self._dispatch_single(
                method, params, context
            )
            await self.response_sender.send_response(
                reply_topic, correlation_id, result, response_type
            )
        except Exception as e:
            logger.exception(f"Error in _handle_single_request for {method}: {e}")
            await self.response_sender.send_error_response(
                reply_topic, correlation_id, str(e)
            )

    async def _dispatch_single(
        self,
        method: str,
        params: Dict[str, Any],
        context: ServerCallContext,
    ) -> tuple[Any, str]:
        """Validate *params*, invoke the matching handler, return its result.

        Returns:
            ``(result, response_type)`` for the response envelope.

        Raises:
            ServerError: If *method* is not a known non-streaming method.
        """
        if method == "message_send":
            request = MessageSendParams.model_validate(params)
            result = await self.request_handler.on_message_send(request, context)
            return result, "task" if isinstance(result, Task) else "message"

        if method == "task_push_notification_config_delete":
            request = DeleteTaskPushNotificationConfigParams.model_validate(params)
            await self.request_handler.on_delete_task_push_notification_config(
                request, context
            )
            # The delete handler's return value is not used, so acknowledge
            # explicitly.
            return {"success": True}, "success"

        route = self._UNARY_ROUTES.get(method)
        if route is None:
            raise ServerError(f"Unknown method: {method}")

        params_model, handler_name, response_type = route
        request = params_model.model_validate(params)
        result = await getattr(self.request_handler, handler_name)(request, context)
        return result, response_type

    @staticmethod
    def _stream_event_type(event: Any) -> str:
        """Return the response type label for a streamed event."""
        if isinstance(event, TaskStatusUpdateEvent):
            return "task_status_update"
        if isinstance(event, TaskArtifactUpdateEvent):
            return "task_artifact_update"
        if isinstance(event, Task):
            return "task"
        return "message"

    async def _handle_streaming_request(
        self,
        method: str,
        params: Dict[str, Any],
        reply_topic: str,
        correlation_id: str,
        context: ServerCallContext,
    ) -> None:
        """Handle a streaming request.

        Only ``message_send`` can be streamed. Every event is sent as its own
        response, followed by a stream-complete signal. On failure an error
        response is sent instead of the completion signal.
        """
        try:
            if method != "message_send":
                raise ServerError(f"Streaming not supported for method: {method}")

            request = MessageSendParams.model_validate(params)
            async for event in self.request_handler.on_message_send_stream(
                request, context
            ):
                await self.response_sender.send_response(
                    reply_topic,
                    correlation_id,
                    event,
                    self._stream_event_type(event),
                )

            # Send stream completion signal
            await self.response_sender.send_stream_complete(reply_topic, correlation_id)

        except Exception as e:
            logger.exception(f"Error in _handle_streaming_request for {method}: {e}")
            await self.response_sender.send_error_response(
                reply_topic, correlation_id, str(e)
            )
def _build_user_params(text: str) -> MessageSendParams:
    """Build ``MessageSendParams`` wrapping a single user text message.

    ``MessageSendParams`` is built around a ``message`` object rather than
    ``content`` / ``role`` keyword arguments.
    """
    from uuid import uuid4

    from a2a.types import Part, Role, TextPart

    return MessageSendParams(
        message=Message(
            message_id=str(uuid4()),
            role=Role.user,
            parts=[Part(root=TextPart(text=text))],
        )
    )


async def test_handler():
    """Check that ExampleRequestHandler can be instantiated and called.

    Returns:
        True if every step succeeded, False as soon as one step fails.
    """
    print("测试 ExampleRequestHandler...")

    # Try to instantiate the handler
    try:
        handler = ExampleRequestHandler()
        print("✓ 成功实例化 ExampleRequestHandler")
    except Exception as e:
        print(f"✗ 实例化失败: {e}")
        return False

    # Non-streaming message send
    try:
        response = await handler.on_message_send(_build_user_params("测试消息"))
        # NOTE(review): the handler's return type is not visible here; fall
        # back to printing the object itself when it has no `content`.
        print(f"✓ on_message_send 正常工作: {getattr(response, 'content', response)}")
    except Exception as e:
        print(f"✗ on_message_send 失败: {e}")
        return False

    # Streaming message send
    try:
        events = []
        async for event in handler.on_message_send_stream(
            _build_user_params("测试流式消息")
        ):
            events.append(event)
            print(f"✓ 收到流式事件: {getattr(event, 'content', event)}")
        print(f"✓ on_message_send_stream 正常工作,收到 {len(events)} 个事件")
    except Exception as e:
        print(f"✗ on_message_send_stream 失败: {e}")
        return False

    print("✓ 所有测试通过!")
    return True


if __name__ == "__main__":
    success = asyncio.run(test_handler())
    sys.exit(0 if success else 1)
async def test_kafka_client():
    """Smoke-test construction of the Kafka client transport.

    Builds an agent card, a ``KafkaClientTransport`` and a
    ``MessageSendParams`` object and prints what was created. No connection
    to Kafka is opened.
    """
    from uuid import uuid4

    from a2a.types import Message, Part, Role, TextPart

    print("测试 Kafka 客户端创建...")

    # Build the agent card
    skill = AgentSkill(
        id="test_skill",
        name="test_skill",
        description="测试技能",
        tags=["test"],
        input_modes=["text/plain"],
        output_modes=["text/plain"],
    )
    agent_card = AgentCard(
        name="测试智能体",
        description="测试智能体",
        url="https://example.com/test-agent",
        version="1.0.0",
        capabilities=AgentCapabilities(),
        default_input_modes=["text/plain"],
        default_output_modes=["text/plain"],
        skills=[skill],
    )

    # Build the Kafka client transport
    transport = KafkaClientTransport(
        agent_card=agent_card,
        bootstrap_servers="localhost:9092",
        request_topic="a2a-requests",
    )

    print("Kafka 客户端创建成功")
    print(f"  回复主题: {transport.reply_topic}")
    print(f"  消费者组: {transport.consumer_group_id}")

    # MessageSendParams wraps a `message`; it has no `content` / `role`
    # keyword arguments.
    message_params = MessageSendParams(
        message=Message(
            message_id=str(uuid4()),
            role=Role.user,
            parts=[Part(root=TextPart(text="测试消息"))],
        )
    )
    print(f"消息参数创建成功: {message_params.message.parts[0].root.text}")

    print("所有测试通过!")


if __name__ == "__main__":
    asyncio.run(test_kafka_client())
def _text_message(message_id: str, role: Role, text: str) -> Message:
    """Build a single-part text message for use in tests."""
    return Message(
        message_id=message_id,
        role=role,
        parts=[Part(root=TextPart(text=text))],
    )


def _user_request() -> MessageSendParams:
    """Build the user request shared by the send-message tests."""
    return MessageSendParams(
        message=_text_message("msg-1", Role.user, "test message")
    )


@pytest.fixture
def agent_card():
    """Create test agent card."""
    return AgentCard(
        name="Test Agent",
        description="Test agent for Kafka transport",
        url="kafka://localhost:9092/test-requests",
        version="1.0.0",
        capabilities=AgentCapabilities(streaming=True),
        skills=[],
        default_input_modes=["text/plain"],
        default_output_modes=["text/plain"],
        preferred_transport=TransportProtocol.kafka,
    )


@pytest.fixture
def kafka_transport(agent_card):
    """Create Kafka transport instance."""
    return KafkaClientTransport(
        agent_card=agent_card,
        bootstrap_servers="localhost:9092",
        request_topic="test-requests",
        reply_topic_prefix="test-reply",
    )


class TestCorrelationManager:
    """Test correlation manager functionality."""

    @pytest.mark.asyncio
    async def test_generate_correlation_id(self):
        """Test correlation ID generation."""
        manager = CorrelationManager()

        first = manager.generate_correlation_id()
        second = manager.generate_correlation_id()

        # IDs must be distinct and non-empty
        assert first != second
        assert len(first) > 0
        assert len(second) > 0

    @pytest.mark.asyncio
    async def test_register_and_complete(self):
        """Test request registration and completion."""
        manager = CorrelationManager()
        correlation_id = manager.generate_correlation_id()

        # Register request
        future = await manager.register(correlation_id)
        assert not future.done()
        assert manager.get_pending_count() == 1

        # Complete request. A2A roles are `user` and `agent`.
        result = _text_message("msg-1", Role.agent, "test response")
        completed = await manager.complete(correlation_id, result)

        assert completed is True
        assert future.done()
        assert await future == result
        assert manager.get_pending_count() == 0

    @pytest.mark.asyncio
    async def test_complete_with_exception(self):
        """Test completing request with exception."""
        manager = CorrelationManager()
        correlation_id = manager.generate_correlation_id()

        future = await manager.register(correlation_id)

        completed = await manager.complete_with_exception(
            correlation_id, Exception("test error")
        )

        assert completed is True
        assert future.done()

        with pytest.raises(Exception) as exc_info:
            await future
        assert str(exc_info.value) == "test error"

    @pytest.mark.asyncio
    async def test_cancel_all(self):
        """Test cancelling all pending requests."""
        manager = CorrelationManager()

        # Register multiple requests
        futures = [
            await manager.register(manager.generate_correlation_id())
            for _ in range(3)
        ]
        assert manager.get_pending_count() == 3

        await manager.cancel_all()

        assert manager.get_pending_count() == 0
        assert all(future.cancelled() for future in futures)


class TestKafkaClientTransport:
    """Test Kafka client transport functionality."""

    def test_initialization(self, kafka_transport, agent_card):
        """Test transport initialization."""
        assert kafka_transport.agent_card == agent_card
        assert kafka_transport.bootstrap_servers == "localhost:9092"
        assert kafka_transport.request_topic == "test-requests"
        assert kafka_transport.reply_topic is None  # Not set until _start()
        assert not kafka_transport._running
        assert not kafka_transport._auto_start

    @pytest.mark.asyncio
    @patch('a2a.client.transports.kafka.AIOKafkaProducer')
    @patch('a2a.client.transports.kafka.AIOKafkaConsumer')
    async def test_internal_start_stop(self, mock_consumer_class, mock_producer_class, kafka_transport):
        """Test internal starting and stopping of the transport."""
        mock_producer = AsyncMock()
        mock_consumer = AsyncMock()
        mock_producer_class.return_value = mock_producer
        mock_consumer_class.return_value = mock_consumer

        # Start transport using internal method
        await kafka_transport._start()

        assert kafka_transport._running is True
        assert kafka_transport.producer == mock_producer
        assert kafka_transport.consumer == mock_consumer
        # After _start, reply_topic should be generated
        assert kafka_transport.reply_topic is not None
        assert kafka_transport.reply_topic.startswith("test-reply-Test_Agent-")
        mock_producer.start.assert_called_once()
        mock_consumer.start.assert_called_once()

        # Stop transport using internal method
        await kafka_transport._stop()

        assert kafka_transport._running is False
        mock_producer.stop.assert_called_once()
        mock_consumer.stop.assert_called_once()

    @pytest.mark.asyncio
    @patch('a2a.client.transports.kafka.AIOKafkaProducer')
    @patch('a2a.client.transports.kafka.AIOKafkaConsumer')
    async def test_send_message(self, mock_consumer_class, mock_producer_class, kafka_transport):
        """Test sending a message."""
        mock_producer = AsyncMock()
        mock_consumer = AsyncMock()
        mock_producer_class.return_value = mock_producer
        mock_consumer_class.return_value = mock_consumer

        await kafka_transport.start()

        correlation = kafka_transport.correlation_manager
        with patch.object(correlation, 'generate_correlation_id') as mock_gen_id, \
             patch.object(correlation, 'register') as mock_register:

            mock_gen_id.return_value = "test-correlation-id"

            # A future that is already resolved with the response
            response = _text_message("msg-1", Role.agent, "test response")
            future = asyncio.Future()
            future.set_result(response)
            mock_register.return_value = future

            result = await kafka_transport.send_message(_user_request())

            assert result == response

            # Verify producer was called
            mock_producer.send_and_wait.assert_called_once()
            args, kwargs = mock_producer.send_and_wait.call_args

            assert args[0] == "test-requests"  # topic
            payload = kwargs['value']
            assert payload['method'] == 'message_send'
            assert 'params' in payload
            # Verify the message structure is properly serialized
            assert 'message' in payload['params']

            # Check headers
            header_dict = {k: v.decode('utf-8') for k, v in kwargs['headers']}
            assert header_dict['correlation_id'] == 'test-correlation-id'
            assert 'reply_topic' in header_dict
            assert header_dict['reply_topic'] is not None

    def test_parse_response(self, kafka_transport):
        """Test response parsing.

        Payloads use the wire shape of the A2A types: roles are ``user`` /
        ``agent`` and parts are discriminated by ``kind`` with no ``root``
        wrapper.
        """
        # Message response
        message_data = {
            'type': 'message',
            'data': {
                'message_id': 'msg-1',
                'role': 'agent',
                'parts': [{'kind': 'text', 'text': 'test response'}],
            },
        }
        result = kafka_transport._parse_response(message_data)
        assert isinstance(result, Message)
        assert result.message_id == 'msg-1'
        assert result.role == Role.agent

        # Task response
        task_data = {
            'type': 'task',
            'data': {
                'id': 'task-123',
                'context_id': 'ctx-456',
                'status': {'state': 'completed'},
            },
        }
        result = kafka_transport._parse_response(task_data)
        assert isinstance(result, Task)
        assert result.id == 'task-123'

        # Default case (no 'type' key should be parsed as a message)
        default_data = {
            'data': {
                'message_id': 'msg-2',
                'role': 'agent',
                'parts': [{'kind': 'text', 'text': 'default response'}],
            },
        }
        result = kafka_transport._parse_response(default_data)
        assert isinstance(result, Message)
        assert result.message_id == 'msg-2'

    @pytest.mark.asyncio
    async def test_context_manager(self, kafka_transport):
        """Test async context manager."""
        with patch.object(kafka_transport, '_start') as mock_start, \
             patch.object(kafka_transport, '_stop') as mock_stop:

            async with kafka_transport:
                mock_start.assert_called_once()

            mock_stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_send_message_timeout(self, kafka_transport):
        """Test send message with timeout."""
        with patch.object(kafka_transport, '_send_request') as mock_send, \
             patch.object(kafka_transport.correlation_manager, 'register') as mock_register:

            # A future that never resolves
            mock_register.return_value = asyncio.Future()
            mock_send.return_value = "test-correlation-id"

            # NOTE(review): this expects the transport's own timeout to fire
            # before the 0.1s outer `wait_for`; otherwise the outer call
            # raises a plain timeout error instead -- confirm.
            with pytest.raises(A2AClientError, match="Request timed out"):
                await asyncio.wait_for(
                    kafka_transport.send_message(_user_request()),
                    timeout=0.1,
                )

    @pytest.mark.asyncio
    @patch('a2a.client.transports.kafka.AIOKafkaProducer')
    @patch('a2a.client.transports.kafka.AIOKafkaConsumer')
    async def test_send_message_streaming(self, mock_consumer_class, mock_producer_class, kafka_transport):
        """Test streaming message sending."""
        mock_producer = AsyncMock()
        mock_consumer = AsyncMock()
        mock_producer_class.return_value = mock_producer
        mock_consumer_class.return_value = mock_consumer

        await kafka_transport.start()

        correlation = kafka_transport.correlation_manager
        with patch.object(correlation, 'generate_correlation_id') as mock_gen_id, \
             patch.object(correlation, 'register_streaming') as mock_register:

            mock_gen_id.return_value = "test-correlation-id"

            from a2a.client.transports.kafka_correlation import StreamingFuture
            streaming_future = StreamingFuture()
            mock_register.return_value = streaming_future

            # Start the streaming request
            stream = kafka_transport.send_message_streaming(_user_request())

            # Simulate receiving responses
            expected = [
                _text_message("msg-2", Role.agent, "response 1"),
                _text_message("msg-3", Role.agent, "response 2"),
            ]
            for item in expected:
                await streaming_future.put(item)
            streaming_future.set_done()

            # Collect responses
            responses = []
            async for response in stream:
                responses.append(response)
                if len(responses) >= 2:  # Prevent infinite loop
                    break

            assert len(responses) == 2
            assert responses[0] == expected[0]
            assert responses[1] == expected[1]

    def test_sanitize_topic_name(self, kafka_transport):
        """Test topic name sanitization."""
        sanitize = kafka_transport._sanitize_topic_name

        # Normal name
        assert sanitize("test-agent") == "test-agent"

        # Name with invalid characters
        assert sanitize("test@agent#123") == "test_agent_123"

        # Empty name
        assert sanitize("") == "unknown_agent"

        # Very long name
        assert len(sanitize("a" * 300)) <= 200

    def test_create_classmethod(self, agent_card):
        """Test the create class method."""
        # Full URL
        transport = KafkaClientTransport.create(
            agent_card=agent_card,
            url="kafka://localhost:9092/custom-topic",
            config=None,
            interceptors=[],
        )
        assert transport.bootstrap_servers == "localhost:9092"
        assert transport.request_topic == "custom-topic"
        assert not transport._auto_start  # Should be False by default

        # URL without topic (should use default)
        transport = KafkaClientTransport.create(
            agent_card=agent_card,
            url="kafka://localhost:9092",
            config=None,
            interceptors=[],
        )
        assert transport.bootstrap_servers == "localhost:9092"
        assert transport.request_topic == "a2a-requests"

        # Invalid URL
        with pytest.raises(ValueError, match="Kafka URL must start with 'kafka://'"):
            KafkaClientTransport.create(
                agent_card=agent_card,
                url="http://localhost:9092",
                config=None,
                interceptors=[],
            )


@pytest.mark.integration
class TestKafkaIntegration:
    """Integration tests for Kafka transport (requires running Kafka)."""

    @pytest.mark.skip(reason="Requires running Kafka instance")
    @pytest.mark.asyncio
    async def test_real_kafka_connection(self, agent_card):
        """Test connection to real Kafka instance."""
        transport = KafkaClientTransport(
            agent_card=agent_card,
            bootstrap_servers="localhost:9092",
        )

        try:
            await transport._start()
            assert transport._running is True
        finally:
            await transport._stop()
            assert transport._running is False