Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions src/a2a/server/apps/jsonrpc/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI
Expand All @@ -21,6 +19,28 @@
logger = logging.getLogger(__name__)


class A2AFastAPI(FastAPI):
"""A FastAPI application that adds A2A-specific OpenAPI components."""

_a2a_components_added: bool = False

def openapi(self) -> dict[str, Any]:
"""Generates the OpenAPI schema for the application."""
openapi_schema = super().openapi()
if not self._a2a_components_added:
a2a_request_schema = A2ARequest.model_json_schema(
ref_template='#/components/schemas/{model}'
)
defs = a2a_request_schema.pop('$defs', {})
component_schemas = openapi_schema.setdefault(
'components', {}
).setdefault('schemas', {})
component_schemas.update(defs)
component_schemas['A2ARequest'] = a2a_request_schema
self._a2a_components_added = True
return openapi_schema


class A2AFastAPIApplication(JSONRPCApplication):
"""A FastAPI application implementing the A2A protocol server endpoints.

Expand Down Expand Up @@ -92,23 +112,7 @@ def build(
Returns:
A configured FastAPI application instance.
"""

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
a2a_request_schema = A2ARequest.model_json_schema(
ref_template='#/components/schemas/{model}'
)
defs = a2a_request_schema.pop('$defs', {})
openapi_schema = app.openapi()
component_schemas = openapi_schema.setdefault(
'components', {}
).setdefault('schemas', {})
component_schemas.update(defs)
component_schemas['A2ARequest'] = a2a_request_schema

yield

app = FastAPI(lifespan=lifespan, **kwargs)
app = A2AFastAPI(**kwargs)

self.add_routes_to_app(
app, agent_card_url, rpc_url, extended_agent_card_url
Expand Down
19 changes: 19 additions & 0 deletions tests/server/apps/jsonrpc/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock

import pytest
from fastapi import FastAPI

from pydantic import ValidationError
from starlette.testclient import TestClient
Expand Down Expand Up @@ -183,3 +184,21 @@ def test_handle_unicode_characters(agent_card_with_api_key: AgentCard):
data = response.json()
assert 'error' not in data or data['error'] is None
assert data['result']['parts'][0]['text'] == f'Received: {unicode_text}'


def test_fastapi_sub_application(agent_card_with_api_key: AgentCard):
"""
Tests that the A2AFastAPIApplication endpoint correctly passes the url in sub-application.
"""
handler = mock.AsyncMock()
sub_app_instance = A2AFastAPIApplication(agent_card_with_api_key, handler)
app_instance = FastAPI()
app_instance.mount('/a2a', sub_app_instance.build())
client = TestClient(app_instance)

response = client.get('/a2a/openapi.json')
assert response.status_code == 200
response_data = response.json()

assert 'servers' in response_data
assert response_data['servers'] == [{'url': '/a2a'}]
Loading