diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index ace2c6ae..2293e560 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -78,6 +78,7 @@ def __init__( # noqa: PLR0913 ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB + stream_send_timeout: float | None = None, ) -> None: """Initializes the A2AFastAPIApplication. @@ -97,6 +98,10 @@ def __init__( # noqa: PLR0913 call context. max_content_length: The maximum allowed content length for incoming requests. Defaults to 10MB. Set to None for unbounded maximum. + stream_send_timeout: The timeout in seconds for sending events in + streaming responses. Defaults to `None`, which disables the timeout. + This changes the default behavior from using Starlette's 5-second + default. Set a float value to specify a timeout. """ if not _package_fastapi_installed: raise ImportError( @@ -112,6 +117,7 @@ def __init__( # noqa: PLR0913 card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, max_content_length=max_content_length, + stream_send_timeout=stream_send_timeout, ) def add_routes_to_app( diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 3e7c2854..fdc71b23 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -184,6 +184,7 @@ def __init__( # noqa: PLR0913 ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB + stream_send_timeout: float | None = None, ) -> None: """Initializes the JSONRPCApplication. @@ -203,6 +204,10 @@ def __init__( # noqa: PLR0913 call context. max_content_length: The maximum allowed content length for incoming requests. Defaults to 10MB. Set to None for unbounded maximum. + stream_send_timeout: The timeout in seconds for sending events in + streaming responses. Defaults to `None`, which disables the timeout. + This changes the default behavior from using Starlette's 5-second + default. Set a float value to specify a timeout. """ if not _package_starlette_installed: raise ImportError( @@ -222,6 +227,7 @@ def __init__( # noqa: PLR0913 ) self._context_builder = context_builder or DefaultCallContextBuilder() self._max_content_length = max_content_length + self.stream_send_timeout = stream_send_timeout def _generate_error_response( self, request_id: str | int | None, error: JSONRPCError | A2AError @@ -540,8 +546,14 @@ async def event_generator( async for item in stream: yield {'data': item.root.model_dump_json(exclude_none=True)} + send_timeout = context.state.get( + 'stream_send_timeout', self.stream_send_timeout + ) + return EventSourceResponse( - event_generator(handler_result), headers=headers + event_generator(handler_result), + headers=headers, + send_timeout=send_timeout, ) if isinstance(handler_result, JSONRPCErrorResponse): return JSONResponse( diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index 1effa9d5..ad767ef3 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -60,6 +60,7 @@ def __init__( # noqa: PLR0913 ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB + stream_send_timeout: float | None = None, ) -> None: """Initializes the A2AStarletteApplication. @@ -79,6 +80,10 @@ def __init__( # noqa: PLR0913 call context. max_content_length: The maximum allowed content length for incoming requests. Defaults to 10MB. Set to None for unbounded maximum. + stream_send_timeout: The timeout in seconds for sending events in + streaming responses. Defaults to `None`, which disables the timeout. + This changes the default behavior from using Starlette's 5-second + default. Set a float value to specify a timeout. """ if not _package_starlette_installed: raise ImportError( @@ -94,6 +99,7 @@ def __init__( # noqa: PLR0913 card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, max_content_length=max_content_length, + stream_send_timeout=stream_send_timeout, ) def routes( diff --git a/tests/server/apps/jsonrpc/test_fastapi_app.py b/tests/server/apps/jsonrpc/test_fastapi_app.py index ddb68691..ee1bfc9d 100644 --- a/tests/server/apps/jsonrpc/test_fastapi_app.py +++ b/tests/server/apps/jsonrpc/test_fastapi_app.py @@ -75,6 +75,23 @@ def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror( ): _app = A2AFastAPIApplication(**mock_app_params) + def test_stream_send_timeout_parameter(self, mock_app_params: dict): + try: + app_default = A2AFastAPIApplication(**mock_app_params) + assert app_default.stream_send_timeout is None + + app_custom = A2AFastAPIApplication( + **mock_app_params, stream_send_timeout=30.0 + ) + assert app_custom.stream_send_timeout == 30.0 + + app_none = A2AFastAPIApplication( + **mock_app_params, stream_send_timeout=None + ) + assert app_none.stream_send_timeout is None + except ImportError: + pytest.skip('FastAPI dependencies not available') + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 36309872..3ebf2166 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -184,6 +184,192 @@ def build( ): _app = DummyJSONRPCApp(**mock_app_params) + @pytest.mark.asyncio + async def test_stream_send_timeout_applied_to_event_source_response( + self, mock_app_params: dict + ): + """Test that stream_send_timeout is correctly applied to EventSourceResponse.""" + from unittest.mock import patch + + class DummyJSONRPCApp(JSONRPCApplication): + def __init__(self, **kwargs): + # Skip parent __init__ to avoid package checks + self.stream_send_timeout = kwargs.get( + 'stream_send_timeout', None + ) + self.agent_card = kwargs.get('agent_card') + self.extended_agent_card = kwargs.get('extended_agent_card') + self.card_modifier = kwargs.get('card_modifier') + self.extended_card_modifier = kwargs.get( + 'extended_card_modifier' + ) + self.max_content_length = kwargs.get( + 'max_content_length', 10 * 1024 * 1024 + ) + self.handler = kwargs.get('http_handler') + + def build(self, **kwargs): + return object() + + # Test with app-level timeout + app = DummyJSONRPCApp(stream_send_timeout=30.0, **mock_app_params) + + # Mock context + context = MagicMock(spec=ServerCallContext) + context.state = {} + context.activated_extensions = None + + # Mock streaming handler result + async def mock_generator(): + yield SendMessageResponse( + root=SendMessageSuccessResponse( + message=Message( + message_id='1', + role=Role.assistant, + parts=[Part(TextPart(text='test'))], + ) + ) + ) + + handler_result = mock_generator() + + with patch( + 'a2a.server.apps.jsonrpc.jsonrpc_app.EventSourceResponse' + ) as mock_esr: + mock_esr.return_value = MagicMock() + + # Call the method + response = app._create_response(context, handler_result) + + # Assert EventSourceResponse was called with correct timeout + mock_esr.assert_called_once() + call_args = mock_esr.call_args + assert call_args[1]['send_timeout'] == 30.0 + + @pytest.mark.asyncio + async def test_stream_send_timeout_none_disables_timeout( + self, mock_app_params: dict + ): + """Test that stream_send_timeout=None disables the timeout.""" + from unittest.mock import patch + + class DummyJSONRPCApp(JSONRPCApplication): + def __init__(self, **kwargs): + # Skip parent __init__ to avoid package checks + self.stream_send_timeout = kwargs.get( + 'stream_send_timeout', None + ) + self.agent_card = kwargs.get('agent_card') + self.extended_agent_card = kwargs.get('extended_agent_card') + self.card_modifier = kwargs.get('card_modifier') + self.extended_card_modifier = kwargs.get( + 'extended_card_modifier' + ) + self.max_content_length = kwargs.get( + 'max_content_length', 10 * 1024 * 1024 + ) + self.handler = kwargs.get('http_handler') + + def build(self, **kwargs): + return object() + + # Test with None timeout (default) + app = DummyJSONRPCApp(stream_send_timeout=None, **mock_app_params) + + # Mock context + context = MagicMock(spec=ServerCallContext) + context.state = {} + context.activated_extensions = None + + # Mock streaming handler result + async def mock_generator(): + yield SendMessageResponse( + root=SendMessageSuccessResponse( + message=Message( + message_id='1', + role=Role.assistant, + parts=[Part(TextPart(text='test'))], + ) + ) + ) + + handler_result = mock_generator() + + with patch( + 'a2a.server.apps.jsonrpc.jsonrpc_app.EventSourceResponse' + ) as mock_esr: + mock_esr.return_value = MagicMock() + + # Call the method + response = app._create_response(context, handler_result) + + # Assert EventSourceResponse was called with None (disabled timeout) + mock_esr.assert_called_once() + call_args = mock_esr.call_args + assert call_args[1]['send_timeout'] is None + + @pytest.mark.asyncio + async def test_stream_send_timeout_context_override( + self, mock_app_params: dict + ): + """Test that context.state can override the app-level stream_send_timeout.""" + from unittest.mock import patch + + class DummyJSONRPCApp(JSONRPCApplication): + def __init__(self, **kwargs): + # Skip parent __init__ to avoid package checks + self.stream_send_timeout = kwargs.get( + 'stream_send_timeout', None + ) + self.agent_card = kwargs.get('agent_card') + self.extended_agent_card = kwargs.get('extended_agent_card') + self.card_modifier = kwargs.get('card_modifier') + self.extended_card_modifier = kwargs.get( + 'extended_card_modifier' + ) + self.max_content_length = kwargs.get( + 'max_content_length', 10 * 1024 * 1024 + ) + self.handler = kwargs.get('http_handler') + + def build(self, **kwargs): + return object() + + # Test with app-level timeout + app = DummyJSONRPCApp(stream_send_timeout=30.0, **mock_app_params) + + # Mock context with override + context = MagicMock(spec=ServerCallContext) + context.state = {'stream_send_timeout': 60.0} + context.activated_extensions = None + + # Mock streaming handler result + async def mock_generator(): + yield SendMessageResponse( + root=SendMessageSuccessResponse( + message=Message( + message_id='1', + role=Role.assistant, + parts=[Part(TextPart(text='test'))], + ) + ) + ) + + handler_result = mock_generator() + + with patch( + 'a2a.server.apps.jsonrpc.jsonrpc_app.EventSourceResponse' + ) as mock_esr: + mock_esr.return_value = MagicMock() + + # Call the method + response = app._create_response(context, handler_result) + + # Assert EventSourceResponse was called with context override + mock_esr.assert_called_once() + call_args = mock_esr.call_args + assert call_args[1]['send_timeout'] == 60.0 + class TestJSONRPCExtensions: @pytest.fixture diff --git a/tests/server/apps/jsonrpc/test_starlette_app.py b/tests/server/apps/jsonrpc/test_starlette_app.py index 6a1472c8..bbdf5293 100644 --- a/tests/server/apps/jsonrpc/test_starlette_app.py +++ b/tests/server/apps/jsonrpc/test_starlette_app.py @@ -77,6 +77,23 @@ def test_create_a2a_starlette_app_with_missing_deps_raises_importerror( ): _app = A2AStarletteApplication(**mock_app_params) + def test_stream_send_timeout_parameter(self, mock_app_params: dict): + try: + app_default = A2AStarletteApplication(**mock_app_params) + assert app_default.stream_send_timeout is None + + app_custom = A2AStarletteApplication( + **mock_app_params, stream_send_timeout=30.0 + ) + assert app_custom.stream_send_timeout == 30.0 + + app_none = A2AStarletteApplication( + **mock_app_params, stream_send_timeout=None + ) + assert app_none.stream_send_timeout is None + except ImportError: + pytest.skip('Starlette dependencies not available') + if __name__ == '__main__': pytest.main([__file__])