From 2db740e075ceedf3b365f22b91de4d27e1807b80 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Mon, 22 Sep 2025 17:19:19 +0100 Subject: [PATCH 1/2] feat: ServerCallContext ports from Python implementation Port server-side extension support and implement complete gRPC context access equivalent to Python's ServicerContext. Adds client/server metadata support, rich context information, and interceptor infrastructure while maintaining backward compatibility. --- .../client/transport/grpc/GrpcTransport.java | 36 ++++- .../main/java/io/a2a/common/A2AHeaders.java | 17 +++ reference/grpc/pom.xml | 4 + .../quarkus/A2AExtensionsInterceptor.java | 70 +++++++++ .../grpc/quarkus/QuarkusGrpcHandler.java | 2 + .../server/apps/quarkus/A2AServerRoutes.java | 9 +- .../server/rest/quarkus/A2AServerRoutes.java | 10 +- .../java/io/a2a/server/ServerCallContext.java | 32 ++++- .../a2a/server/extensions/A2AExtensions.java | 52 +++++++ .../DefaultRequestHandler.java | 9 +- .../io/a2a/server/ServerCallContextTest.java | 133 ++++++++++++++++++ .../server/extensions/A2AExtensionsTest.java | 127 +++++++++++++++++ .../DefaultRequestHandlerBackgroundTest.java | 18 +-- .../grpc/context/GrpcContextKeys.java | 45 ++++++ .../transport/grpc/handler/GrpcHandler.java | 123 ++++++++++++++-- .../jsonrpc/handler/JSONRPCHandlerTest.java | 3 +- .../rest/handler/RestHandlerTest.java | 3 +- 17 files changed, 659 insertions(+), 34 deletions(-) create mode 100644 common/src/main/java/io/a2a/common/A2AHeaders.java create mode 100644 reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/A2AExtensionsInterceptor.java create mode 100644 server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java create mode 100644 server-common/src/test/java/io/a2a/server/ServerCallContextTest.java create mode 100644 server-common/src/test/java/io/a2a/server/extensions/A2AExtensionsTest.java create mode 100644 transport/grpc/src/main/java/io/a2a/transport/grpc/context/GrpcContextKeys.java diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java index b313fb43a..3d14ffe4a 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java @@ -12,6 +12,7 @@ import io.a2a.client.transport.spi.interceptors.ClientCallContext; import io.a2a.client.transport.spi.ClientTransport; +import io.a2a.common.A2AHeaders; import io.a2a.grpc.A2AServiceGrpc; import io.a2a.grpc.CancelTaskRequest; import io.a2a.grpc.CreateTaskPushNotificationConfigRequest; @@ -37,8 +38,9 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.grpc.Channel; - +import io.grpc.Metadata; import io.grpc.StatusRuntimeException; +import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; public class GrpcTransport implements ClientTransport { @@ -59,9 +61,12 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context); + Metadata metadata = createGrpcMetadata(context); try { - SendMessageResponse response = blockingStub.sendMessage(sendMessageRequest); + // Create a stub with metadata attached + A2AServiceBlockingV2Stub stubWithMetadata = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest); if (response.hasMsg()) { return FromProto.message(response.getMsg()); } else if (response.hasTask()) { @@ -80,10 +85,13 @@ public void sendMessageStreaming(MessageSendParams request, Consumer streamObserver = new EventStreamObserver(eventConsumer, errorConsumer); try { - asyncStub.sendStreamingMessage(grpcRequest, streamObserver); + // Create a stub with metadata attached + A2AServiceStub stubWithMetadata = asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: "); } @@ -234,6 +242,28 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag return builder.build(); } + /** + * Creates gRPC metadata from ClientCallContext headers. + * Extracts headers like X-A2A-Extensions and sets them as gRPC metadata. + */ + private Metadata createGrpcMetadata(ClientCallContext context) { + Metadata metadata = new Metadata(); + + if (context != null && context.getHeaders() != null) { + // Set X-A2A-Extensions header if present + String extensionsHeader = context.getHeaders().get(A2AHeaders.X_A2A_EXTENSIONS); + if (extensionsHeader != null) { + Metadata.Key extensionsKey = Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER); + metadata.put(extensionsKey, extensionsHeader); + } + + // Add other headers as needed in the future + // For now, we only handle X-A2A-Extensions + } + + return metadata; + } + private String getTaskPushNotificationConfigName(GetTaskPushNotificationConfigParams params) { return getTaskPushNotificationConfigName(params.id(), params.pushNotificationConfigId()); } diff --git a/common/src/main/java/io/a2a/common/A2AHeaders.java b/common/src/main/java/io/a2a/common/A2AHeaders.java new file mode 100644 index 000000000..5118a4365 --- /dev/null +++ b/common/src/main/java/io/a2a/common/A2AHeaders.java @@ -0,0 +1,17 @@ +package io.a2a.common; + +/** + * Common A2A protocol headers and constants. + */ +public final class A2AHeaders { + + /** + * HTTP header name for A2A extensions. + * Used to communicate which extensions are requested by the client. + */ + public static final String X_A2A_EXTENSIONS = "X-A2A-Extensions"; + + private A2AHeaders() { + // Utility class + } +} diff --git a/reference/grpc/pom.xml b/reference/grpc/pom.xml index 8623b4154..4d8a08e9d 100644 --- a/reference/grpc/pom.xml +++ b/reference/grpc/pom.xml @@ -19,6 +19,10 @@ ${project.groupId} a2a-java-sdk-reference-common + + ${project.groupId} + a2a-java-sdk-common + ${project.groupId} a2a-java-sdk-transport-grpc diff --git a/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/A2AExtensionsInterceptor.java b/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/A2AExtensionsInterceptor.java new file mode 100644 index 000000000..9f0559cdb --- /dev/null +++ b/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/A2AExtensionsInterceptor.java @@ -0,0 +1,70 @@ +package io.a2a.server.grpc.quarkus; + +import jakarta.enterprise.context.ApplicationScoped; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.a2a.common.A2AHeaders; +import io.a2a.transport.grpc.context.GrpcContextKeys; + +/** + * gRPC server interceptor that captures request metadata and context information, + * providing equivalent functionality to Python's grpc.aio.ServicerContext. + * + * This interceptor: + * - Extracts A2A extension headers from incoming requests + * - Captures ServerCall and Metadata for rich context access + * - Stores context information in gRPC Context for service method access + * - Provides proper equivalence to Python's ServicerContext + */ +@ApplicationScoped +public class A2AExtensionsInterceptor implements ServerInterceptor { + + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler serverCallHandler) { + + // Extract A2A extensions header + Metadata.Key extensionsKey = + Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER); + String extensions = metadata.get(extensionsKey); + + // Create enhanced context with rich information (equivalent to Python's ServicerContext) + Context context = Context.current() + // Store complete metadata for full header access + .withValue(GrpcContextKeys.METADATA_KEY, metadata) + // Store method name (equivalent to Python's context.method()) + .withValue(GrpcContextKeys.METHOD_NAME_KEY, serverCall.getMethodDescriptor().getFullMethodName()) + // Store peer information for client connection details + .withValue(GrpcContextKeys.PEER_INFO_KEY, getPeerInfo(serverCall)); + + // Store A2A extensions if present + if (extensions != null) { + context = context.withValue(GrpcContextKeys.EXTENSIONS_HEADER_KEY, extensions); + } + + // Proceed with the call in the enhanced context + return Contexts.interceptCall(context, serverCall, metadata, serverCallHandler); + } + + /** + * Safely extracts peer information from the ServerCall. + * + * @param serverCall the gRPC ServerCall + * @return peer information string, or "unknown" if not available + */ + private String getPeerInfo(ServerCall serverCall) { + try { + Object remoteAddr = serverCall.getAttributes().get(io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + return remoteAddr != null ? remoteAddr.toString() : "unknown"; + } catch (Exception e) { + return "unknown"; + } + } +} diff --git a/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/QuarkusGrpcHandler.java b/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/QuarkusGrpcHandler.java index 9cfba609a..355416441 100644 --- a/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/QuarkusGrpcHandler.java +++ b/reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/QuarkusGrpcHandler.java @@ -9,8 +9,10 @@ import io.a2a.transport.grpc.handler.CallContextFactory; import io.a2a.transport.grpc.handler.GrpcHandler; import io.quarkus.grpc.GrpcService; +import io.quarkus.grpc.RegisterInterceptor; @GrpcService +@RegisterInterceptor(A2AExtensionsInterceptor.class) public class QuarkusGrpcHandler extends GrpcHandler { private final AgentCard agentCard; diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 29791d653..05030c632 100644 --- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -4,6 +4,8 @@ import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Executor; @@ -21,6 +23,7 @@ import com.fasterxml.jackson.databind.JsonNode; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.extensions.A2AExtensions; import io.a2a.server.auth.User; import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; @@ -241,7 +244,11 @@ public String getUsername() { headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name))); state.put("headers", headers); - return new ServerCallContext(user, state); + // Extract requested extensions from X-A2A-Extensions header + List extensionHeaderValues = rc.request().headers().getAll(A2AExtensions.HTTP_EXTENSION_HEADER); + Set requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues); + + return new ServerCallContext(user, state, requestedExtensions); } else { CallContextFactory builder = callContextFactory.get(); return builder.build(rc); diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java index 1a90cda75..b1b1a88bf 100644 --- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java @@ -34,9 +34,13 @@ import io.vertx.core.http.HttpServerResponse; import io.vertx.ext.web.RoutingContext; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import io.a2a.server.extensions.A2AExtensions; + @Singleton public class A2AServerRoutes { @@ -308,7 +312,11 @@ public String getUsername() { headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name))); state.put("headers", headers); - return new ServerCallContext(user, state); + // Extract requested extensions from X-A2A-Extensions header + List extensionHeaderValues = rc.request().headers().getAll(A2AExtensions.HTTP_EXTENSION_HEADER); + Set requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues); + + return new ServerCallContext(user, state, requestedExtensions); } else { CallContextFactory builder = callContextFactory.get(); return builder.build(rc); diff --git a/server-common/src/main/java/io/a2a/server/ServerCallContext.java b/server-common/src/main/java/io/a2a/server/ServerCallContext.java index 558f01eda..cef84700e 100644 --- a/server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -1,6 +1,8 @@ package io.a2a.server; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import io.a2a.server.auth.User; @@ -10,10 +12,14 @@ public class ServerCallContext { private final Map modelConfig = new ConcurrentHashMap<>(); private final Map state; private final User user; + private final Set requestedExtensions; + private final Set activatedExtensions; - public ServerCallContext(User user, Map state) { + public ServerCallContext(User user, Map state, Set requestedExtensions) { this.user = user; this.state = state; + this.requestedExtensions = new HashSet<>(requestedExtensions); + this.activatedExtensions = new HashSet<>(); // Always starts empty, populated later by application code } public Map getState() { @@ -23,4 +29,28 @@ public Map getState() { public User getUser() { return user; } + + public Set getRequestedExtensions() { + return new HashSet<>(requestedExtensions); + } + + public Set getActivatedExtensions() { + return new HashSet<>(activatedExtensions); + } + + public void activateExtension(String extensionUri) { + activatedExtensions.add(extensionUri); + } + + public void deactivateExtension(String extensionUri) { + activatedExtensions.remove(extensionUri); + } + + public boolean isExtensionActivated(String extensionUri) { + return activatedExtensions.contains(extensionUri); + } + + public boolean isExtensionRequested(String extensionUri) { + return requestedExtensions.contains(extensionUri); + } } diff --git a/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java b/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java new file mode 100644 index 000000000..8f63b34d6 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java @@ -0,0 +1,52 @@ +package io.a2a.server.extensions; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import io.a2a.common.A2AHeaders; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentExtension; + +public class A2AExtensions { + /** + * HTTP header name for A2A extensions. + * @deprecated Use {@link A2AHeaders#X_A2A_EXTENSIONS} instead + */ + @Deprecated + public static final String HTTP_EXTENSION_HEADER = A2AHeaders.X_A2A_EXTENSIONS; + + public static Set getRequestedExtensions(List values) { + Set extensions = new HashSet<>(); + if (values == null) { + return extensions; + } + + for (String value : values) { + if (value != null) { + // Split by comma and trim whitespace + String[] parts = value.split(","); + for (String part : parts) { + String trimmed = part.trim(); + if (!trimmed.isEmpty()) { + extensions.add(trimmed); + } + } + } + } + + return extensions; + } + + public static AgentExtension findExtensionByUri(AgentCard card, String uri) { + if (card.capabilities() == null || card.capabilities().extensions() == null) { + return null; + } + for (AgentExtension extension : card.capabilities().extensions()) { + if (extension.uri().equals(uri)) { + return extension; + } + } + return null; + } +} diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 7a1e5b1f5..8bace7be2 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -15,8 +16,9 @@ import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; -import java.util.Set; -import java.util.concurrent.TimeUnit; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; import io.a2a.server.ServerCallContext; import io.a2a.server.agentexecution.AgentExecutor; @@ -50,9 +52,6 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.UnsupportedOperationError; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.inject.Inject; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/server-common/src/test/java/io/a2a/server/ServerCallContextTest.java b/server-common/src/test/java/io/a2a/server/ServerCallContextTest.java new file mode 100644 index 000000000..c12a48e27 --- /dev/null +++ b/server-common/src/test/java/io/a2a/server/ServerCallContextTest.java @@ -0,0 +1,133 @@ +package io.a2a.server; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import io.a2a.server.auth.User; + +class ServerCallContextTest { + + @Test + void testDefaultConstructor() { + User user = new TestUser(); + Map state = new HashMap<>(); + + ServerCallContext context = new ServerCallContext(user, state, new HashSet<>()); + + assertEquals(user, context.getUser()); + assertEquals(state, context.getState()); + assertTrue(context.getRequestedExtensions().isEmpty()); + assertTrue(context.getActivatedExtensions().isEmpty()); + } + + @Test + void testConstructorWithRequestedExtensions() { + User user = new TestUser(); + Map state = new HashMap<>(); + Set requestedExtensions = Set.of("foo", "bar"); + + ServerCallContext context = new ServerCallContext(user, state, requestedExtensions); + + assertEquals(user, context.getUser()); + assertEquals(state, context.getState()); + assertEquals(requestedExtensions, context.getRequestedExtensions()); + assertTrue(context.getActivatedExtensions().isEmpty()); + } + + @Test + void testConstructorWithRequestedAndActivatedExtensions() { + User user = new TestUser(); + Map state = new HashMap<>(); + Set requestedExtensions = Set.of("foo", "bar"); + ServerCallContext context = new ServerCallContext(user, state, requestedExtensions); + + // Manually activate extensions since they start empty + context.activateExtension("foo"); + + assertEquals(user, context.getUser()); + assertEquals(state, context.getState()); + assertEquals(requestedExtensions, context.getRequestedExtensions()); + assertEquals(Set.of("foo"), context.getActivatedExtensions()); + } + + @Test + void testExtensionActivation() { + User user = new TestUser(); + Map state = new HashMap<>(); + Set requestedExtensions = Set.of("foo", "bar"); + + ServerCallContext context = new ServerCallContext(user, state, requestedExtensions); + + // Initially no extensions are activated + assertFalse(context.isExtensionActivated("foo")); + assertFalse(context.isExtensionActivated("bar")); + + // Activate an extension + context.activateExtension("foo"); + assertTrue(context.isExtensionActivated("foo")); + assertFalse(context.isExtensionActivated("bar")); + + // Activate another extension + context.activateExtension("bar"); + assertTrue(context.isExtensionActivated("foo")); + assertTrue(context.isExtensionActivated("bar")); + + // Deactivate an extension + context.deactivateExtension("foo"); + assertFalse(context.isExtensionActivated("foo")); + assertTrue(context.isExtensionActivated("bar")); + } + + @Test + void testExtensionRequested() { + User user = new TestUser(); + Map state = new HashMap<>(); + Set requestedExtensions = Set.of("foo", "bar"); + + ServerCallContext context = new ServerCallContext(user, state, requestedExtensions); + + assertTrue(context.isExtensionRequested("foo")); + assertTrue(context.isExtensionRequested("bar")); + assertFalse(context.isExtensionRequested("baz")); + } + + @Test + void testExtensionCollectionsAreDefensiveCopies() { + User user = new TestUser(); + Map state = new HashMap<>(); + Set requestedExtensions = Set.of("foo", "bar"); + + ServerCallContext context = new ServerCallContext(user, state, requestedExtensions); + + // Modifying returned sets should not affect the context + Set returnedRequested = context.getRequestedExtensions(); + returnedRequested.add("baz"); + assertFalse(context.isExtensionRequested("baz")); + + context.activateExtension("foo"); + Set returnedActivated = context.getActivatedExtensions(); + returnedActivated.add("bar"); + assertFalse(context.isExtensionActivated("bar")); + } + + // Simple test implementation of User interface + private static class TestUser implements User { + @Override + public boolean isAuthenticated() { + return true; + } + + @Override + public String getUsername() { + return "test-user"; + } + } +} diff --git a/server-common/src/test/java/io/a2a/server/extensions/A2AExtensionsTest.java b/server-common/src/test/java/io/a2a/server/extensions/A2AExtensionsTest.java new file mode 100644 index 000000000..1fc00ded3 --- /dev/null +++ b/server-common/src/test/java/io/a2a/server/extensions/A2AExtensionsTest.java @@ -0,0 +1,127 @@ +package io.a2a.server.extensions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentExtension; + +class A2AExtensionsTest { + + @Test + void testGetRequestedExtensions() { + // Test empty list + Set result = A2AExtensions.getRequestedExtensions(Collections.emptyList()); + assertTrue(result.isEmpty()); + + // Test single extension + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo")); + assertEquals(Set.of("foo"), result); + + // Test multiple extensions in separate values + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo", "bar")); + assertEquals(Set.of("foo", "bar"), result); + + // Test comma-separated extensions with space + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo, bar")); + assertEquals(Set.of("foo", "bar"), result); + + // Test comma-separated extensions without space + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo,bar")); + assertEquals(Set.of("foo", "bar"), result); + + // Test mixed format + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo", "bar,baz")); + assertEquals(Set.of("foo", "bar", "baz"), result); + + // Test with empty values and extra spaces + result = A2AExtensions.getRequestedExtensions(Arrays.asList("foo,, bar", "baz")); + assertEquals(Set.of("foo", "bar", "baz"), result); + + // Test with leading/trailing spaces + result = A2AExtensions.getRequestedExtensions(Arrays.asList(" foo , bar ", "baz")); + assertEquals(Set.of("foo", "bar", "baz"), result); + + // Test null list + result = A2AExtensions.getRequestedExtensions(null); + assertTrue(result.isEmpty()); + + // Test list with null values + List listWithNulls = Arrays.asList("foo", null, "bar"); + result = A2AExtensions.getRequestedExtensions(listWithNulls); + assertEquals(Set.of("foo", "bar"), result); + } + + @Test + void testFindExtensionByUri() { + AgentExtension ext1 = new AgentExtension.Builder() + .uri("foo") + .description("The Foo extension") + .build(); + AgentExtension ext2 = new AgentExtension.Builder() + .uri("bar") + .description("The Bar extension") + .build(); + + AgentCard card = new AgentCard.Builder() + .name("Test Agent") + .description("Test Agent Description") + .version("1.0") + .url("http://test.com") + .skills(Collections.emptyList()) + .defaultInputModes(Arrays.asList("text/plain")) + .defaultOutputModes(Arrays.asList("text/plain")) + .capabilities(new AgentCapabilities.Builder() + .extensions(Arrays.asList(ext1, ext2)) + .build()) + .build(); + + assertEquals(ext1, A2AExtensions.findExtensionByUri(card, "foo")); + assertEquals(ext2, A2AExtensions.findExtensionByUri(card, "bar")); + assertNull(A2AExtensions.findExtensionByUri(card, "baz")); + } + + @Test + void testFindExtensionByUriNoExtensions() { + AgentCard card = new AgentCard.Builder() + .name("Test Agent") + .description("Test Agent Description") + .version("1.0") + .url("http://test.com") + .skills(Collections.emptyList()) + .defaultInputModes(Arrays.asList("text/plain")) + .defaultOutputModes(Arrays.asList("text/plain")) + .capabilities(new AgentCapabilities.Builder() + .extensions(null) + .build()) + .build(); + + assertNull(A2AExtensions.findExtensionByUri(card, "foo")); + } + + @Test + void testFindExtensionByUriNoCapabilities() { + // Test with empty capabilities (no extensions list) + AgentCard card = new AgentCard.Builder() + .name("Test Agent") + .description("Test Agent Description") + .version("1.0") + .url("http://test.com") + .skills(Collections.emptyList()) + .defaultInputModes(Arrays.asList("text/plain")) + .defaultOutputModes(Arrays.asList("text/plain")) + .capabilities(new AgentCapabilities.Builder().build()) + .build(); + + assertNull(A2AExtensions.findExtensionByUri(card, "foo")); + } +} diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerBackgroundTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerBackgroundTest.java index 412f1a8cf..284465882 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerBackgroundTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerBackgroundTest.java @@ -1,26 +1,17 @@ package io.a2a.server.requesthandlers; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import java.util.Map; +import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; import io.a2a.server.ServerCallContext; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; -import io.a2a.server.agentexecution.SimpleRequestContextBuilder; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventQueue; import io.a2a.server.events.InMemoryQueueManager; @@ -32,6 +23,9 @@ import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; /** * Tests for DefaultRequestHandler background cleanup and task tracking functionality, @@ -60,7 +54,7 @@ void setUp() { Executors.newCachedThreadPool() ); - serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, java.util.Map.of()); + serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of(), Set.of()); } /** diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/context/GrpcContextKeys.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/context/GrpcContextKeys.java new file mode 100644 index 000000000..483daf7e8 --- /dev/null +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/context/GrpcContextKeys.java @@ -0,0 +1,45 @@ +package io.a2a.transport.grpc.context; + +import io.grpc.Context; + +/** + * Shared gRPC context keys for A2A protocol data. + * + * These keys provide access to gRPC context information similar to + * Python's grpc.aio.ServicerContext, enabling rich context access + * in service method implementations. + */ +public final class GrpcContextKeys { + + /** + * Context key for storing the X-A2A-Extensions header value. + * Set by server interceptors and accessed by service handlers. + */ + public static final Context.Key EXTENSIONS_HEADER_KEY = + Context.key("x-a2a-extensions"); + + /** + * Context key for storing the complete gRPC Metadata object. + * Provides access to all request headers and metadata. + */ + public static final Context.Key METADATA_KEY = + Context.key("grpc-metadata"); + + /** + * Context key for storing the method name being called. + * Equivalent to Python's context.method() functionality. + */ + public static final Context.Key METHOD_NAME_KEY = + Context.key("grpc-method-name"); + + /** + * Context key for storing the peer information. + * Provides access to client connection details. + */ + public static final Context.Key PEER_INFO_KEY = + Context.key("grpc-peer-info"); + + private GrpcContextKeys() { + // Utility class + } +} diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index e19ba80ef..6c4a47414 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -6,9 +6,13 @@ import jakarta.enterprise.inject.Vetoed; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; + +import io.grpc.Context; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; @@ -19,6 +23,8 @@ import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.auth.User; +import io.a2a.server.extensions.A2AExtensions; +import io.a2a.transport.grpc.context.GrpcContextKeys; import io.a2a.server.requesthandlers.RequestHandler; import io.a2a.spec.AgentCard; import io.a2a.spec.ContentTypeNotSupportedError; @@ -312,17 +318,53 @@ private ServerCallContext createCallContext(StreamObserver responseObserv User user = UnauthenticatedUser.INSTANCE; Map state = new HashMap<>(); - // TODO: ARCHITECTURAL LIMITATION - StreamObserver is NOT equivalent to Python's grpc.aio.ServicerContext - // In Python: context parameter provides metadata, peer info, cancellation, etc. - // In Java: the proper equivalent would be ServerCall + Metadata from ServerInterceptor - // Current compromise: Store StreamObserver for basic functionality, but it lacks rich context - // - // FUTURE ENHANCEMENT: Implement ServerInterceptor to capture ServerCall + Metadata, - // store in gRPC Context using Context.Key, then access via Context.current().get(key) - // This would provide proper equivalence to Python's ServicerContext + // Enhanced gRPC context access - equivalent to Python's grpc.aio.ServicerContext + // The A2AExtensionsInterceptor captures ServerCall + Metadata and stores them in gRPC Context + // This provides proper equivalence to Python's ServicerContext for metadata access + // Note: StreamObserver is still stored for response handling state.put("grpc_response_observer", responseObserver); - return new ServerCallContext(user, state); + // Add rich gRPC context information if available (set by interceptor) + // This provides equivalent functionality to Python's grpc.aio.ServicerContext + try { + Context currentContext = Context.current(); + if (currentContext != null) { + state.put("grpc_context", currentContext); + + // Add specific context information for easy access + io.grpc.Metadata grpcMetadata = GrpcContextKeys.METADATA_KEY.get(currentContext); + if (grpcMetadata != null) { + state.put("grpc_metadata", grpcMetadata); + } + + String methodName = GrpcContextKeys.METHOD_NAME_KEY.get(currentContext); + if (methodName != null) { + state.put("grpc_method_name", methodName); + } + + String peerInfo = GrpcContextKeys.PEER_INFO_KEY.get(currentContext); + if (peerInfo != null) { + state.put("grpc_peer_info", peerInfo); + } + } + } catch (Exception e) { + // Context not available - continue without it + } + + // Extract requested extensions from gRPC context (set by interceptor) + Set requestedExtensions = new HashSet<>(); + try { + // Try to get extensions from gRPC context (available when interceptor is used) + String extensionsHeader = getExtensionsFromContext(); + if (extensionsHeader != null) { + requestedExtensions = A2AExtensions.getRequestedExtensions(List.of(extensionsHeader)); + } + } catch (Exception e) { + // If context access fails (e.g., no interceptor), continue with empty set + // This maintains backward compatibility + } + + return new ServerCallContext(user, state, requestedExtensions); } else { // TODO: CallContextFactory interface expects ServerCall + Metadata, but we only have StreamObserver // This is another manifestation of the architectural limitation mentioned above @@ -418,4 +460,67 @@ public static void setStreamingSubscribedRunnable(Runnable runnable) { protected abstract AgentCard getAgentCard(); protected abstract CallContextFactory getCallContextFactory(); + + /** + * Attempts to extract the X-A2A-Extensions header from the current gRPC context. + * This will only work if a server interceptor has been configured to capture + * the metadata and store it in the context. + * + * @return the extensions header value, or null if not available + */ + private String getExtensionsFromContext() { + try { + return GrpcContextKeys.EXTENSIONS_HEADER_KEY.get(); + } catch (Exception e) { + // Context not available or key not set + return null; + } + } + + /** + * Utility methods for accessing gRPC context information. + * These provide equivalent functionality to Python's grpc.aio.ServicerContext methods. + */ + + /** + * Gets the complete gRPC metadata from the current context. + * Equivalent to Python's context.invocation_metadata. + * + * @return the gRPC Metadata object, or null if not available + */ + protected static io.grpc.Metadata getCurrentMetadata() { + try { + return GrpcContextKeys.METADATA_KEY.get(); + } catch (Exception e) { + return null; + } + } + + /** + * Gets the current gRPC method name. + * Equivalent to Python's context.method(). + * + * @return the method name, or null if not available + */ + protected static String getCurrentMethodName() { + try { + return GrpcContextKeys.METHOD_NAME_KEY.get(); + } catch (Exception e) { + return null; + } + } + + /** + * Gets the peer information for the current gRPC call. + * Equivalent to Python's context.peer(). + * + * @return the peer information, or null if not available + */ + protected static String getCurrentPeerInfo() { + try { + return GrpcContextKeys.PEER_INFO_KEY.get(); + } catch (Exception e) { + return null; + } + } } diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index 19137ab01..b37724e56 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.HashSet; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; @@ -76,7 +77,7 @@ public class JSONRPCHandlerTest extends AbstractA2ARequestHandlerTest { - private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar")); + private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>()); @Test public void testOnGetTaskSuccess() throws Exception { diff --git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java index d8b355dc3..409998f6b 100644 --- a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java +++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java @@ -1,6 +1,7 @@ package io.a2a.transport.rest.handler; import com.google.protobuf.InvalidProtocolBufferException; +import java.util.HashSet; import java.util.Map; import io.a2a.server.ServerCallContext; @@ -18,7 +19,7 @@ public class RestHandlerTest extends AbstractA2ARequestHandlerTest { - private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar")); + private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar"), new HashSet<>()); @Test public void testGetTaskSuccess() { From 13399a18d812d6971f0769f51bde83e8a9223713 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Tue, 23 Sep 2025 10:51:00 +0100 Subject: [PATCH 2/2] Code review fixes --- .../client/transport/grpc/GrpcTransport.java | 51 ++++++++++++++----- .../server/apps/quarkus/A2AServerRoutes.java | 6 +-- .../server/rest/quarkus/A2AServerRoutes.java | 3 +- .../a2a/server/extensions/A2AExtensions.java | 6 --- .../transport/grpc/handler/GrpcHandler.java | 49 +++++++++--------- 5 files changed, 68 insertions(+), 47 deletions(-) diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java index 3d14ffe4a..87de05ecd 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java @@ -61,11 +61,9 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex checkNotNullParam("request", request); SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context); - Metadata metadata = createGrpcMetadata(context); try { - // Create a stub with metadata attached - A2AServiceBlockingV2Stub stubWithMetadata = blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest); if (response.hasMsg()) { return FromProto.message(response.getMsg()); @@ -85,12 +83,10 @@ public void sendMessageStreaming(MessageSendParams request, Consumer streamObserver = new EventStreamObserver(eventConsumer, errorConsumer); try { - // Create a stub with metadata attached - A2AServiceStub stubWithMetadata = asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context); stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: "); @@ -109,7 +105,8 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A GetTaskRequest getTaskRequest = requestBuilder.build(); try { - return FromProto.task(blockingStub.getTask(getTaskRequest)); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + return FromProto.task(stubWithMetadata.getTask(getTaskRequest)); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: "); } @@ -124,7 +121,8 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A .build(); try { - return FromProto.task(blockingStub.cancelTask(cancelTaskRequest)); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest)); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: "); } @@ -143,7 +141,8 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN .build(); try { - return FromProto.taskPushNotificationConfig(blockingStub.createTaskPushNotificationConfig(grpcRequest)); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest)); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: "); } @@ -160,7 +159,8 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration( .build(); try { - return FromProto.taskPushNotificationConfig(blockingStub.getTaskPushNotificationConfig(grpcRequest)); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest)); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: "); } @@ -177,7 +177,8 @@ public List listTaskPushNotificationConfigurations( .build(); try { - return blockingStub.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream() + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream() .map(FromProto::taskPushNotificationConfig) .collect(Collectors.toList()); } catch (StatusRuntimeException e) { @@ -195,7 +196,8 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC .build(); try { - blockingStub.deleteTaskPushNotificationConfig(grpcRequest); + A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context); + stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: "); } @@ -214,7 +216,8 @@ public void resubscribe(TaskIdParams request, Consumer event StreamObserver streamObserver = new EventStreamObserver(eventConsumer, errorConsumer); try { - asyncStub.taskSubscription(grpcRequest, streamObserver); + A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context); + stubWithMetadata.taskSubscription(grpcRequest, streamObserver); } catch (StatusRuntimeException e) { throw GrpcErrorMapper.mapGrpcError(e, "Failed to resubscribe task push notification config: "); } @@ -264,6 +267,28 @@ private Metadata createGrpcMetadata(ClientCallContext context) { return metadata; } + /** + * Creates a blocking stub with metadata attached from the ClientCallContext. + * + * @param context the client call context + * @return blocking stub with metadata interceptor + */ + private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context) { + Metadata metadata = createGrpcMetadata(context); + return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + } + + /** + * Creates an async stub with metadata attached from the ClientCallContext. + * + * @param context the client call context + * @return async stub with metadata interceptor + */ + private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context) { + Metadata metadata = createGrpcMetadata(context); + return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)); + } + private String getTaskPushNotificationConfigName(GetTaskPushNotificationConfigParams params) { return getTaskPushNotificationConfigName(params.id(), params.pushNotificationConfigId()); } diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 05030c632..3958d8eed 100644 --- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -4,7 +4,6 @@ import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -21,10 +20,11 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.io.JsonEOFException; import com.fasterxml.jackson.databind.JsonNode; +import io.a2a.common.A2AHeaders; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; -import io.a2a.server.extensions.A2AExtensions; import io.a2a.server.auth.User; +import io.a2a.server.extensions.A2AExtensions; import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; @@ -245,7 +245,7 @@ public String getUsername() { state.put("headers", headers); // Extract requested extensions from X-A2A-Extensions header - List extensionHeaderValues = rc.request().headers().getAll(A2AExtensions.HTTP_EXTENSION_HEADER); + List extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS); Set requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues); return new ServerCallContext(user, state, requestedExtensions); diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java index b1b1a88bf..5e0f3e0f8 100644 --- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java @@ -12,6 +12,7 @@ import jakarta.inject.Inject; import jakarta.inject.Singleton; +import io.a2a.common.A2AHeaders; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.auth.User; @@ -313,7 +314,7 @@ public String getUsername() { state.put("headers", headers); // Extract requested extensions from X-A2A-Extensions header - List extensionHeaderValues = rc.request().headers().getAll(A2AExtensions.HTTP_EXTENSION_HEADER); + List extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS); Set requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues); return new ServerCallContext(user, state, requestedExtensions); diff --git a/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java b/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java index 8f63b34d6..fec151366 100644 --- a/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java +++ b/server-common/src/main/java/io/a2a/server/extensions/A2AExtensions.java @@ -9,12 +9,6 @@ import io.a2a.spec.AgentExtension; public class A2AExtensions { - /** - * HTTP header name for A2A extensions. - * @deprecated Use {@link A2AHeaders#X_A2A_EXTENSIONS} instead - */ - @Deprecated - public static final String HTTP_EXTENSION_HEADER = A2AHeaders.X_A2A_EXTENSIONS; public static Set getRequestedExtensions(List values) { Set extensions = new HashSet<>(); diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index 6c4a47414..b259e91ab 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -15,6 +15,7 @@ import io.grpc.Context; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Logger; import com.google.protobuf.Empty; import io.a2a.grpc.A2AServiceGrpc; @@ -61,6 +62,8 @@ public abstract class GrpcHandler extends A2AServiceGrpc.A2AServiceImplBase { private AtomicBoolean initialised = new AtomicBoolean(false); + private static final Logger LOGGER = Logger.getLogger(GrpcHandler.class.getName()); + public GrpcHandler() { } @@ -349,19 +352,14 @@ private ServerCallContext createCallContext(StreamObserver responseObserv } } catch (Exception e) { // Context not available - continue without it + LOGGER.fine(() -> "Error getting data from current context" + e); } // Extract requested extensions from gRPC context (set by interceptor) Set requestedExtensions = new HashSet<>(); - try { - // Try to get extensions from gRPC context (available when interceptor is used) - String extensionsHeader = getExtensionsFromContext(); - if (extensionsHeader != null) { - requestedExtensions = A2AExtensions.getRequestedExtensions(List.of(extensionsHeader)); - } - } catch (Exception e) { - // If context access fails (e.g., no interceptor), continue with empty set - // This maintains backward compatibility + String extensionsHeader = getExtensionsFromContext(); + if (extensionsHeader != null) { + requestedExtensions = A2AExtensions.getRequestedExtensions(List.of(extensionsHeader)); } return new ServerCallContext(user, state, requestedExtensions); @@ -483,19 +481,30 @@ private String getExtensionsFromContext() { */ /** - * Gets the complete gRPC metadata from the current context. - * Equivalent to Python's context.invocation_metadata. + * Generic helper method to safely access gRPC context values. * - * @return the gRPC Metadata object, or null if not available + * @param key the context key to retrieve + * @return the context value, or null if not available */ - protected static io.grpc.Metadata getCurrentMetadata() { + private static T getFromContext(Context.Key key) { try { - return GrpcContextKeys.METADATA_KEY.get(); + return key.get(); } catch (Exception e) { + // Context not available or key not set return null; } } + /** + * Gets the complete gRPC metadata from the current context. + * Equivalent to Python's context.invocation_metadata. + * + * @return the gRPC Metadata object, or null if not available + */ + protected static io.grpc.Metadata getCurrentMetadata() { + return getFromContext(GrpcContextKeys.METADATA_KEY); + } + /** * Gets the current gRPC method name. * Equivalent to Python's context.method(). @@ -503,11 +512,7 @@ protected static io.grpc.Metadata getCurrentMetadata() { * @return the method name, or null if not available */ protected static String getCurrentMethodName() { - try { - return GrpcContextKeys.METHOD_NAME_KEY.get(); - } catch (Exception e) { - return null; - } + return getFromContext(GrpcContextKeys.METHOD_NAME_KEY); } /** @@ -517,10 +522,6 @@ protected static String getCurrentMethodName() { * @return the peer information, or null if not available */ protected static String getCurrentPeerInfo() { - try { - return GrpcContextKeys.PEER_INFO_KEY.get(); - } catch (Exception e) { - return null; - } + return getFromContext(GrpcContextKeys.PEER_INFO_KEY); } }