diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index c18cde3d..4918eafc 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -1,7 +1,5 @@ import logging -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager from typing import Any from fastapi import FastAPI @@ -21,6 +19,28 @@ logger = logging.getLogger(__name__) +class A2AFastAPI(FastAPI): + """A FastAPI application that adds A2A-specific OpenAPI components.""" + + _a2a_components_added: bool = False + + def openapi(self) -> dict[str, Any]: + """Generates the OpenAPI schema for the application.""" + openapi_schema = super().openapi() + if not self._a2a_components_added: + a2a_request_schema = A2ARequest.model_json_schema( + ref_template='#/components/schemas/{model}' + ) + defs = a2a_request_schema.pop('$defs', {}) + component_schemas = openapi_schema.setdefault( + 'components', {} + ).setdefault('schemas', {}) + component_schemas.update(defs) + component_schemas['A2ARequest'] = a2a_request_schema + self._a2a_components_added = True + return openapi_schema + + class A2AFastAPIApplication(JSONRPCApplication): """A FastAPI application implementing the A2A protocol server endpoints. @@ -92,23 +112,7 @@ def build( Returns: A configured FastAPI application instance. """ - - @asynccontextmanager - async def lifespan(app: FastAPI) -> AsyncIterator[None]: - a2a_request_schema = A2ARequest.model_json_schema( - ref_template='#/components/schemas/{model}' - ) - defs = a2a_request_schema.pop('$defs', {}) - openapi_schema = app.openapi() - component_schemas = openapi_schema.setdefault( - 'components', {} - ).setdefault('schemas', {}) - component_schemas.update(defs) - component_schemas['A2ARequest'] = a2a_request_schema - - yield - - app = FastAPI(lifespan=lifespan, **kwargs) + app = A2AFastAPI(**kwargs) self.add_routes_to_app( app, agent_card_url, rpc_url, extended_agent_card_url diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py index 663de5e0..1670bb96 100644 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ b/tests/server/apps/jsonrpc/test_serialization.py @@ -1,6 +1,7 @@ from unittest import mock import pytest +from fastapi import FastAPI from pydantic import ValidationError from starlette.testclient import TestClient @@ -183,3 +184,21 @@ def test_handle_unicode_characters(agent_card_with_api_key: AgentCard): data = response.json() assert 'error' not in data or data['error'] is None assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}' + + +def test_fastapi_sub_application(agent_card_with_api_key: AgentCard): + """ + Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application. + """ + handler = mock.AsyncMock() + sub_app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler) + app_instance = FastAPI() + app_instance.mount('/a2a', sub_app_instance.build()) + client = TestClient(app_instance) + + response = client.get('/a2a/openapi.json') + assert response.status_code == 200 + response_data = response.json() + + assert 'servers' in response_data + assert response_data['servers'] == [{'url': '/a2a'}]