From 1634895fd66c1afad34b7bd55b5b5c6cc6a51dfa Mon Sep 17 00:00:00 2001 From: Nitesh Kant Date: Fri, 3 Jan 2020 11:53:12 -0800 Subject: [PATCH] gRPC services to support closing (#903) __Motivation__ gRPC implementation does not wire close implementations (by implementing methods in `AsyncCloseable`, `GracefulAutoCloseable` or `AutoCloseable`) defines by services such that server closure will invoke the appropriate close methods on the service. Also, there is no way to specify close semantics while implementing interfaces for each RPC method. __Modification__ - Defined two new interfaces; `Rpc` and `BlockingRpc` that are appropriately implemented by individual RPC methods. - Introduced a `wrap()` method for each existing `Route` interface variant such that a lambda implementing the route can also define a detached close implementation and this method will attach the close implementation to the route implementation by wrapping. - Modified gRPC code generation to utilize the above changes. - Modified gRPC router to use these new functionality and wire close together through various route layers. - Added tests to verify each close implementation variant. - Fixed a bug in blocking -> async service conversions that were not implementing graceful close (required for this change). __Result__ gRPC services can now implement close methods when required. --- .../grpc/api/BlockingGrpcService.java | 4 + .../io/servicetalk/grpc/api/GrpcRouter.java | 313 +++++++++++++----- .../io/servicetalk/grpc/api/GrpcRoutes.java | 266 ++++++++++++++- .../io/servicetalk/grpc/api/GrpcService.java | 7 + .../servicetalk/grpc/netty/ClosureTest.java | 267 +++++++++++++++ .../grpc/netty/ErrorHandlingTest.java | 18 +- .../io/servicetalk/grpc/protoc/Generator.java | 46 ++- .../io/servicetalk/grpc/protoc/Types.java | 5 + .../BlockingStreamingToStreamingService.java | 5 + .../http/api/BlockingToStreamingService.java | 5 + 10 files changed, 818 insertions(+), 118 deletions(-) create mode 100644 servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ClosureTest.java diff --git a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/BlockingGrpcService.java b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/BlockingGrpcService.java index 96b29df84a..96d208f838 100644 --- a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/BlockingGrpcService.java +++ b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/BlockingGrpcService.java @@ -21,4 +21,8 @@ * A blocking gRPC service. */ public interface BlockingGrpcService extends GracefulAutoCloseable { + @Override + default void close() throws Exception { + // noop + } } diff --git a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRouter.java b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRouter.java index 917bb20489..fc6aa85a0a 100644 --- a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRouter.java +++ b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRouter.java @@ -15,6 +15,7 @@ */ package io.servicetalk.grpc.api; +import io.servicetalk.concurrent.BlockingIterable; import io.servicetalk.concurrent.GracefulAutoCloseable; import io.servicetalk.concurrent.api.AsyncCloseable; import io.servicetalk.concurrent.api.AsyncCloseables; @@ -33,14 +34,22 @@ import io.servicetalk.grpc.api.GrpcServiceFactory.ServerBinder; import io.servicetalk.grpc.api.GrpcUtils.GrpcStatusUpdater; import io.servicetalk.http.api.BlockingHttpService; +import io.servicetalk.http.api.BlockingStreamingHttpRequest; +import io.servicetalk.http.api.BlockingStreamingHttpServerResponse; +import io.servicetalk.http.api.BlockingStreamingHttpService; import io.servicetalk.http.api.HttpApiConversions.ServiceAdapterHolder; import io.servicetalk.http.api.HttpDeserializer; import io.servicetalk.http.api.HttpExecutionStrategy; import io.servicetalk.http.api.HttpPayloadWriter; -import io.servicetalk.http.api.HttpRequestMethod; +import io.servicetalk.http.api.HttpRequest; +import io.servicetalk.http.api.HttpResponse; +import io.servicetalk.http.api.HttpResponseFactory; import io.servicetalk.http.api.HttpSerializer; import io.servicetalk.http.api.HttpService; +import io.servicetalk.http.api.HttpServiceContext; +import io.servicetalk.http.api.StreamingHttpRequest; import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.http.api.StreamingHttpResponseFactory; import io.servicetalk.http.api.StreamingHttpService; import io.servicetalk.transport.api.ExecutionContext; import io.servicetalk.transport.api.ServerContext; @@ -62,9 +71,9 @@ import static io.servicetalk.grpc.api.GrpcUtils.newResponse; import static io.servicetalk.grpc.api.GrpcUtils.readGrpcMessageEncoding; import static io.servicetalk.grpc.api.GrpcUtils.setStatus; -import static io.servicetalk.grpc.api.GrpcUtils.uncheckedCast; import static io.servicetalk.http.api.HttpApiConversions.toStreamingHttpService; import static io.servicetalk.http.api.HttpExecutionStrategies.noOffloadsStrategy; +import static io.servicetalk.http.api.HttpRequestMethod.POST; import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; @@ -98,30 +107,48 @@ private GrpcRouter(final Map routes, } Single bind(final ServerBinder binder, final ExecutionContext executionContext) { + CompositeCloseable closeable = AsyncCloseables.newCompositeCloseable(); final Map allRoutes = new HashMap<>(); - populateRoutes(executionContext, allRoutes, routes); - populateRoutes(executionContext, allRoutes, streamingRoutes); - populateRoutes(executionContext, allRoutes, blockingRoutes); - populateRoutes(executionContext, allRoutes, blockingStreamingRoutes); + populateRoutes(executionContext, allRoutes, routes, closeable); + populateRoutes(executionContext, allRoutes, streamingRoutes, closeable); + populateRoutes(executionContext, allRoutes, blockingRoutes, closeable); + populateRoutes(executionContext, allRoutes, blockingStreamingRoutes, closeable); // TODO: Optimize to bind a specific programming model service based on routes - return binder.bindStreaming((ctx, request, responseFactory) -> { - StreamingHttpService service; - if (request.method() != HttpRequestMethod.POST || (service = allRoutes.get(request.path())) == null) { - return notFound.handle(ctx, request, responseFactory); - } else { - return service.handle(ctx, request, responseFactory); + return binder.bindStreaming(new StreamingHttpService() { + @Override + public Single handle(final HttpServiceContext ctx, + final StreamingHttpRequest request, + final StreamingHttpResponseFactory responseFactory) { + final StreamingHttpService service; + if (!POST.equals(request.method()) || (service = allRoutes.get(request.path())) == null) { + return notFound.handle(ctx, request, responseFactory); + } else { + return service.handle(ctx, request, responseFactory); + } + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); } }); } - private void populateRoutes(final ExecutionContext executionContext, - final Map allRoutes, - final Map routes) { + private static void populateRoutes(final ExecutionContext executionContext, + final Map allRoutes, + final Map routes, + final CompositeCloseable closeable) { for (Map.Entry entry : routes.entrySet()) { final ServiceAdapterHolder adapterHolder = entry.getValue().buildRoute(executionContext); + StreamingHttpService route = closeable.append(adapterHolder.adaptor()); allRoutes.put(entry.getKey(), adapterHolder.serviceInvocationStrategy() - .offloadService(executionContext.executor(), adapterHolder.adaptor())); + .offloadService(executionContext.executor(), route)); } } @@ -184,24 +211,38 @@ Builder addRoute( final Route route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { routes.put(path, new RouteProvider(executionContext -> toStreamingHttpService( - (HttpService) (ctx, request, responseFactory) -> { - try { - final GrpcServiceContext serviceContext = - new DefaultGrpcServiceContext(request.path(), ctx); - final HttpDeserializer deserializer = - serializationProvider.deserializerFor(readGrpcMessageEncoding(request), - requestClass); - return route.handle(serviceContext, request.payloadBody(deserializer)) - .map(rawResp -> newResponse(responseFactory, - ctx.executionContext().bufferAllocator()) - .payloadBody(uncheckedCast(rawResp), - serializationProvider.serializerFor(serviceContext, - responseClass))) - .recoverWith(cause -> succeeded(newErrorResponse(responseFactory, cause, - ctx.executionContext().bufferAllocator()))); - } catch (Throwable t) { - return succeeded(newErrorResponse(responseFactory, t, - ctx.executionContext().bufferAllocator())); + new HttpService() { + @Override + public Single handle(final HttpServiceContext ctx, final HttpRequest request, + final HttpResponseFactory responseFactory) { + try { + final GrpcServiceContext serviceContext = + new DefaultGrpcServiceContext(request.path(), ctx); + final HttpDeserializer deserializer = + serializationProvider.deserializerFor(readGrpcMessageEncoding(request), + requestClass); + return route.handle(serviceContext, request.payloadBody(deserializer)) + .map(rawResp -> newResponse(responseFactory, + ctx.executionContext().bufferAllocator()) + .payloadBody(rawResp, + serializationProvider.serializerFor(serviceContext, + responseClass))) + .recoverWith(cause -> succeeded(newErrorResponse(responseFactory, cause, + ctx.executionContext().bufferAllocator()))); + } catch (Throwable t) { + return succeeded(newErrorResponse(responseFactory, t, + ctx.executionContext().bufferAllocator())); + } + } + + @Override + public Completable closeAsync() { + return route.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return route.closeAsyncGracefully(); } }, strategy -> executionStrategy == null ? strategy : executionStrategy), () -> toStreaming(route), () -> toRequestStreamingRoute(route), @@ -214,19 +255,36 @@ Builder addStreamingRoute( final StreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { streamingRoutes.put(path, new RouteProvider(executionContext -> { - StreamingHttpService service = (ctx, request, responseFactory) -> { - try { - final GrpcServiceContext serviceContext = new DefaultGrpcServiceContext(request.path(), ctx); - final HttpDeserializer deserializer = - serializationProvider.deserializerFor(readGrpcMessageEncoding(request), requestClass); - final Publisher response = route.handle(serviceContext, request.payloadBody(deserializer)) - .map(GrpcUtils::uncheckedCast); - return succeeded(newResponse(responseFactory, response, - serializationProvider.serializerFor(serviceContext, responseClass), - ctx.executionContext().bufferAllocator())); - } catch (Throwable t) { - return succeeded(newErrorResponse(responseFactory, t, - ctx.executionContext().bufferAllocator())); + StreamingHttpService service = new StreamingHttpService() { + @Override + public Single handle(final HttpServiceContext ctx, + final StreamingHttpRequest request, + final StreamingHttpResponseFactory responseFactory) { + try { + final GrpcServiceContext serviceContext = + new DefaultGrpcServiceContext(request.path(), ctx); + final HttpDeserializer deserializer = + serializationProvider.deserializerFor(readGrpcMessageEncoding(request), + requestClass); + final Publisher response = route.handle(serviceContext, + request.payloadBody(deserializer)); + return succeeded(newResponse(responseFactory, response, + serializationProvider.serializerFor(serviceContext, responseClass), + ctx.executionContext().bufferAllocator())); + } catch (Throwable t) { + return succeeded(newErrorResponse(responseFactory, t, + ctx.executionContext().bufferAllocator())); + } + } + + @Override + public Completable closeAsync() { + return route.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return route.closeAsyncGracefully(); } }; return new ServiceAdapterHolder() { @@ -250,7 +308,22 @@ Builder addRequestStreamingRoute( final RequestStreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { return addStreamingRoute(path, executionStrategy, - (ctx, request) -> route.handle(ctx, request).toPublisher(), requestClass, responseClass, + new StreamingRoute() { + @Override + public Publisher handle(final GrpcServiceContext ctx, final Publisher request) { + return route.handle(ctx, request).toPublisher(); + } + + @Override + public Completable closeAsync() { + return route.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return route.closeAsyncGracefully(); + } + }, requestClass, responseClass, serializationProvider); } @@ -270,8 +343,23 @@ Builder addResponseStreamingRoute( final String path, @Nullable final GrpcExecutionStrategy executionStrategy, final ResponseStreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { - return addStreamingRoute(path, executionStrategy, (ctx, request) -> request.firstOrError() - .flatMapPublisher(rawReq -> route.handle(ctx, uncheckedCast(rawReq))), + return addStreamingRoute(path, executionStrategy, new StreamingRoute() { + @Override + public Publisher handle(final GrpcServiceContext ctx, final Publisher request) { + return request.firstOrError() + .flatMapPublisher(rawReq -> route.handle(ctx, rawReq)); + } + + @Override + public Completable closeAsync() { + return route.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return route.closeAsyncGracefully(); + } + }, requestClass, responseClass, serializationProvider); } @@ -292,19 +380,33 @@ Builder addBlockingRoute( final BlockingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { blockingRoutes.put(path, new RouteProvider(executionContext -> - toStreamingHttpService((BlockingHttpService) (ctx, request, responseFactory) -> { - try { - final GrpcServiceContext serviceContext = - new DefaultGrpcServiceContext(request.path(), ctx); - final HttpDeserializer deserializer = - serializationProvider.deserializerFor(readGrpcMessageEncoding(request), - requestClass); - final Resp response = route.handle(serviceContext, request.payloadBody(deserializer)); - return newResponse(responseFactory, ctx.executionContext().bufferAllocator()) - .payloadBody(response, - serializationProvider.serializerFor(serviceContext, responseClass)); - } catch (Throwable t) { - return newErrorResponse(responseFactory, t, ctx.executionContext().bufferAllocator()); + toStreamingHttpService(new BlockingHttpService() { + @Override + public HttpResponse handle(final HttpServiceContext ctx, final HttpRequest request, + final HttpResponseFactory responseFactory) { + try { + final GrpcServiceContext serviceContext = + new DefaultGrpcServiceContext(request.path(), ctx); + final HttpDeserializer deserializer = + serializationProvider.deserializerFor(readGrpcMessageEncoding(request), + requestClass); + final Resp response = route.handle(serviceContext, request.payloadBody(deserializer)); + return newResponse(responseFactory, ctx.executionContext().bufferAllocator()) + .payloadBody(response, + serializationProvider.serializerFor(serviceContext, responseClass)); + } catch (Throwable t) { + return newErrorResponse(responseFactory, t, ctx.executionContext().bufferAllocator()); + } + } + + @Override + public void close() throws Exception { + route.close(); + } + + @Override + public void closeGracefully() throws Exception { + route.closeGracefully(); } }, strategy -> executionStrategy == null ? strategy : executionStrategy), () -> toStreaming(route), () -> toRequestStreamingRoute(route), @@ -329,21 +431,37 @@ Builder addBlockingStreamingRoute( final BlockingStreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { blockingRoutes.put(path, new RouteProvider(executionContext -> - toStreamingHttpService((ctx, request, response) -> { - final GrpcServiceContext serviceContext = new DefaultGrpcServiceContext(request.path(), ctx); - final HttpDeserializer deserializer = - serializationProvider.deserializerFor(readGrpcMessageEncoding(request), requestClass); - final HttpSerializer serializer = - serializationProvider.serializerFor(serviceContext, responseClass); - final DefaultGrpcPayloadWriter grpcPayloadWriter = - new DefaultGrpcPayloadWriter<>(response.sendMetaData(serializer)); - try { - route.handle(serviceContext, request.payloadBody(deserializer), grpcPayloadWriter); - } catch (Throwable t) { - final HttpPayloadWriter payloadWriter = grpcPayloadWriter.payloadWriter(); - setStatus(payloadWriter.trailers(), t, ctx.executionContext().bufferAllocator()); - } finally { - grpcPayloadWriter.close(); + toStreamingHttpService(new BlockingStreamingHttpService() { + @Override + public void handle(final HttpServiceContext ctx, final BlockingStreamingHttpRequest request, + final BlockingStreamingHttpServerResponse response) throws Exception { + final GrpcServiceContext serviceContext = + new DefaultGrpcServiceContext(request.path(), ctx); + final HttpDeserializer deserializer = + serializationProvider.deserializerFor(readGrpcMessageEncoding(request), + requestClass); + final HttpSerializer serializer = + serializationProvider.serializerFor(serviceContext, responseClass); + final DefaultGrpcPayloadWriter grpcPayloadWriter = + new DefaultGrpcPayloadWriter<>(response.sendMetaData(serializer)); + try { + route.handle(serviceContext, request.payloadBody(deserializer), grpcPayloadWriter); + } catch (Throwable t) { + final HttpPayloadWriter payloadWriter = grpcPayloadWriter.payloadWriter(); + setStatus(payloadWriter.trailers(), t, ctx.executionContext().bufferAllocator()); + } finally { + grpcPayloadWriter.close(); + } + } + + @Override + public void close() throws Exception { + route.close(); + } + + @Override + public void closeGracefully() throws Exception { + route.closeGracefully(); } }, strategy -> executionStrategy == null ? strategy : executionStrategy), () -> toStreaming(route), () -> toRequestStreamingRoute(route), () -> toResponseStreamingRoute(route), @@ -368,9 +486,23 @@ Builder addBlockingRequestStreamingRoute( final BlockingRequestStreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { - return addBlockingStreamingRoute(path, executionStrategy, (ctx, request, responseWriter) -> { - final Resp resp = route.handle(ctx, request); - responseWriter.write(resp); + return addBlockingStreamingRoute(path, executionStrategy, new BlockingStreamingRoute() { + @Override + public void handle(final GrpcServiceContext ctx, final BlockingIterable request, + final GrpcPayloadWriter responseWriter) throws Exception { + final Resp resp = route.handle(ctx, request); + responseWriter.write(resp); + } + + @Override + public void close() throws Exception { + route.close(); + } + + @Override + public void closeGracefully() throws Exception { + route.closeGracefully(); + } }, requestClass, responseClass, serializationProvider); } @@ -391,8 +523,23 @@ Builder addBlockingResponseStreamingRoute( final String path, @Nullable final GrpcExecutionStrategy executionStrategy, final BlockingResponseStreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { - return addBlockingStreamingRoute(path, executionStrategy, (ctx, request, responseWriter) -> - route.handle(ctx, requireNonNull(request.iterator().next()), responseWriter), + return addBlockingStreamingRoute(path, executionStrategy, new BlockingStreamingRoute() { + @Override + public void handle(final GrpcServiceContext ctx, final BlockingIterable request, + final GrpcPayloadWriter responseWriter) throws Exception { + route.handle(ctx, requireNonNull(request.iterator().next()), responseWriter); + } + + @Override + public void close() throws Exception { + route.close(); + } + + @Override + public void closeGracefully() throws Exception { + route.closeGracefully(); + } + }, requestClass, responseClass, serializationProvider); } diff --git a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRoutes.java b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRoutes.java index 5d255080db..ab19d7db75 100644 --- a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRoutes.java +++ b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcRoutes.java @@ -208,7 +208,7 @@ protected final void addStreamingRoute( final String path, final GrpcExecutionStrategy executionStrategy, final StreamingRoute route, final Class requestClass, final Class responseClass, final GrpcSerializationProvider serializationProvider) { - routeBuilder.addStreamingRoute(path, null, route, requestClass, responseClass, + routeBuilder.addStreamingRoute(path, executionStrategy, route, requestClass, responseClass, serializationProvider); } @@ -446,6 +446,7 @@ protected final void addBlockingResponseStreamingRoute( * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface Route extends AsyncCloseable { /** @@ -461,6 +462,35 @@ protected interface Route extends AsyncCloseable { default Completable closeAsync() { return completed(); } + + /** + * Convenience method to wrap a raw {@link Route} instance with a passed detached close implementation + * of {@link AsyncCloseable}. + * + * @param rawRoute {@link Route} instance that has a detached close implementation. + * @param closeable {@link AsyncCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link Route} that attaches the passed {@code closeable} to the passed {@code rawRoute}. + */ + static Route wrap(final Route rawRoute, final AsyncCloseable closeable) { + return new Route() { + @Override + public Single handle(final GrpcServiceContext ctx, final Req request) { + return rawRoute.handle(ctx, request); + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); + } + }; + } } /** @@ -469,6 +499,7 @@ default Completable closeAsync() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface StreamingRoute extends AsyncCloseable { /** @@ -484,6 +515,37 @@ protected interface StreamingRoute extends AsyncCloseable { default Completable closeAsync() { return completed(); } + + /** + * Convenience method to wrap a raw {@link StreamingRoute} instance with a passed detached close implementation + * of {@link AsyncCloseable}. + * + * @param rawRoute {@link StreamingRoute} instance that has a detached close implementation. + * @param closeable {@link AsyncCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link StreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static StreamingRoute wrap(final StreamingRoute rawRoute, + final AsyncCloseable closeable) { + return new StreamingRoute() { + @Override + public Publisher handle(final GrpcServiceContext ctx, final Publisher request) { + return rawRoute.handle(ctx, request); + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); + } + }; + } } /** @@ -492,6 +554,7 @@ default Completable closeAsync() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface RequestStreamingRoute extends AsyncCloseable { @@ -508,6 +571,37 @@ protected interface RequestStreamingRoute default Completable closeAsync() { return completed(); } + + /** + * Convenience method to wrap a raw {@link RequestStreamingRoute} instance with a passed detached close + * implementation of {@link AsyncCloseable}. + * + * @param rawRoute {@link RequestStreamingRoute} instance that has a detached close implementation. + * @param closeable {@link AsyncCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link RequestStreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static RequestStreamingRoute wrap(final RequestStreamingRoute rawRoute, + final AsyncCloseable closeable) { + return new RequestStreamingRoute() { + @Override + public Single handle(final GrpcServiceContext ctx, final Publisher request) { + return rawRoute.handle(ctx, request); + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); + } + }; + } } /** @@ -516,6 +610,7 @@ default Completable closeAsync() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface ResponseStreamingRoute extends AsyncCloseable { @@ -532,6 +627,37 @@ protected interface ResponseStreamingRoute default Completable closeAsync() { return completed(); } + + /** + * Convenience method to wrap a raw {@link ResponseStreamingRoute} instance with a passed detached close + * implementation of {@link AsyncCloseable}. + * + * @param rawRoute {@link ResponseStreamingRoute} instance that has a detached close implementation. + * @param closeable {@link AsyncCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link ResponseStreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static ResponseStreamingRoute wrap(final ResponseStreamingRoute rawRoute, + final AsyncCloseable closeable) { + return new ResponseStreamingRoute() { + @Override + public Publisher handle(final GrpcServiceContext ctx, final Req request) { + return rawRoute.handle(ctx, request); + } + + @Override + public Completable closeAsync() { + return closeable.closeAsync(); + } + + @Override + public Completable closeAsyncGracefully() { + return closeable.closeAsyncGracefully(); + } + }; + } } /** @@ -540,6 +666,7 @@ default Completable closeAsync() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface BlockingRoute extends GracefulAutoCloseable { /** @@ -553,9 +680,40 @@ protected interface BlockingRoute Resp handle(GrpcServiceContext ctx, Req request) throws Exception; @Override - default void close() { + default void close() throws Exception { // No op } + + /** + * Convenience method to wrap a raw {@link BlockingRoute} instance with a passed detached close + * implementation of {@link GracefulAutoCloseable}. + * + * @param rawRoute {@link BlockingRoute} instance that has a detached close implementation. + * @param closeable {@link GracefulAutoCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link BlockingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static BlockingRoute wrap(final BlockingRoute rawRoute, + final GracefulAutoCloseable closeable) { + return new BlockingRoute() { + @Override + public Resp handle(final GrpcServiceContext ctx, final Req request) throws Exception { + return rawRoute.handle(ctx, request); + } + + @Override + public void close() throws Exception { + closeable.close(); + } + + @Override + public void closeGracefully() throws Exception { + closeable.closeGracefully(); + } + }; + } } /** @@ -564,6 +722,7 @@ default void close() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface BlockingStreamingRoute extends GracefulAutoCloseable { @@ -579,9 +738,41 @@ void handle(GrpcServiceContext ctx, BlockingIterable request, GrpcPayloadWriter responseWriter) throws Exception; @Override - default void close() { + default void close() throws Exception { // No op } + + /** + * Convenience method to wrap a raw {@link BlockingStreamingRoute} instance with a passed detached close + * implementation of {@link GracefulAutoCloseable}. + * + * @param rawRoute {@link BlockingStreamingRoute} instance that has a detached close implementation. + * @param closeable {@link GracefulAutoCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link BlockingStreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static BlockingStreamingRoute wrap(final BlockingStreamingRoute rawRoute, + final GracefulAutoCloseable closeable) { + return new BlockingStreamingRoute() { + @Override + public void handle(final GrpcServiceContext ctx, final BlockingIterable request, + final GrpcPayloadWriter responseWriter) throws Exception { + rawRoute.handle(ctx, request, responseWriter); + } + + @Override + public void close() throws Exception { + closeable.close(); + } + + @Override + public void closeGracefully() throws Exception { + closeable.closeGracefully(); + } + }; + } } /** @@ -590,6 +781,7 @@ default void close() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface BlockingRequestStreamingRoute extends GracefulAutoCloseable { @@ -604,9 +796,40 @@ protected interface BlockingRequestStreamingRoute Resp handle(GrpcServiceContext ctx, BlockingIterable request) throws Exception; @Override - default void close() { + default void close() throws Exception { // No op } + + /** + * Convenience method to wrap a raw {@link BlockingRequestStreamingRoute} instance with a passed detached close + * implementation of {@link GracefulAutoCloseable}. + * + * @param rawRoute {@link BlockingRequestStreamingRoute} instance that has a detached close implementation. + * @param closeable {@link GracefulAutoCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link BlockingRequestStreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static BlockingRequestStreamingRoute wrap( + final BlockingRequestStreamingRoute rawRoute, final GracefulAutoCloseable closeable) { + return new BlockingRequestStreamingRoute() { + @Override + public Resp handle(final GrpcServiceContext ctx, final BlockingIterable request) throws Exception { + return rawRoute.handle(ctx, request); + } + + @Override + public void close() throws Exception { + closeable.close(); + } + + @Override + public void closeGracefully() throws Exception { + closeable.closeGracefully(); + } + }; + } } /** @@ -615,6 +838,7 @@ default void close() { * @param Type of request. * @param Type of response. */ + @FunctionalInterface protected interface BlockingResponseStreamingRoute extends GracefulAutoCloseable { @@ -629,9 +853,41 @@ protected interface BlockingResponseStreamingRoute void handle(GrpcServiceContext ctx, Req request, GrpcPayloadWriter responseWriter) throws Exception; @Override - default void close() { + default void close() throws Exception { // No op } + + /** + * Convenience method to wrap a raw {@link BlockingResponseStreamingRoute} instance with a passed detached close + * implementation of {@link GracefulAutoCloseable}. + * + * @param rawRoute {@link BlockingResponseStreamingRoute} instance that has a detached close implementation. + * @param closeable {@link GracefulAutoCloseable} implementation for the passed {@code rawRoute}. + * @param Type of request. + * @param Type of response. + * @return A new {@link BlockingResponseStreamingRoute} that attaches the passed {@code closeable} to the passed + * {@code rawRoute}. + */ + static BlockingResponseStreamingRoute wrap( + final BlockingResponseStreamingRoute rawRoute, final GracefulAutoCloseable closeable) { + return new BlockingResponseStreamingRoute() { + @Override + public void handle(final GrpcServiceContext ctx, final Req request, + final GrpcPayloadWriter responseWriter) throws Exception { + rawRoute.handle(ctx, request, responseWriter); + } + + @Override + public void close() throws Exception { + closeable.close(); + } + + @Override + public void closeGracefully() throws Exception { + closeable.closeGracefully(); + } + }; + } } /** diff --git a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcService.java b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcService.java index dddbdccd35..26a6f0d422 100644 --- a/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcService.java +++ b/servicetalk-grpc-api/src/main/java/io/servicetalk/grpc/api/GrpcService.java @@ -16,9 +16,16 @@ package io.servicetalk.grpc.api; import io.servicetalk.concurrent.api.AsyncCloseable; +import io.servicetalk.concurrent.api.Completable; + +import static io.servicetalk.concurrent.api.Completable.completed; /** * A gRPC service. */ public interface GrpcService extends AsyncCloseable { + @Override + default Completable closeAsync() { + return completed(); + } } diff --git a/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ClosureTest.java b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ClosureTest.java new file mode 100644 index 0000000000..b0cc4685c8 --- /dev/null +++ b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ClosureTest.java @@ -0,0 +1,267 @@ +/* + * Copyright © 2019 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.grpc.netty; + +import io.servicetalk.concurrent.GracefulAutoCloseable; +import io.servicetalk.concurrent.api.AsyncCloseable; +import io.servicetalk.concurrent.api.Completable; +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTestBiDiStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTestRequestStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTestResponseStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTestRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTesterService; +import io.servicetalk.grpc.netty.TesterProto.Tester.TestBiDiStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.TestRequestStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.TestResponseStreamRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.TestRpc; +import io.servicetalk.grpc.netty.TesterProto.Tester.TesterService; +import io.servicetalk.transport.api.ServerContext; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Collection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.servicetalk.concurrent.api.Completable.completed; +import static io.servicetalk.concurrent.api.Completable.defer; +import static io.servicetalk.grpc.netty.GrpcServers.forAddress; +import static io.servicetalk.grpc.netty.TesterProto.Tester.ServiceFactory; +import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +@RunWith(Parameterized.class) +public class ClosureTest { + @Rule + public final Timeout timeout = new ServiceTalkTestTimeout(); + + private final boolean closeGracefully; + + public ClosureTest(final boolean closeGracefully) { + this.closeGracefully = closeGracefully; + } + + @Parameterized.Parameters(name = "graceful? => {0}") + public static Collection data() { + return asList(true, false); + } + + @Test + public void serviceImplIsClosed() throws Exception { + CloseSignal signal = new CloseSignal(1); + TesterService svc = setupCloseMock(mock(TesterService.class), signal); + startServerAndClose(new ServiceFactory(svc), signal); + verifyClosure(svc, 4 /* 4 rpc methods */); + signal.verifyCloseAtLeastCount(closeGracefully); + } + + @Test + public void blockingServiceImplIsClosed() throws Exception { + CloseSignal signal = new CloseSignal(1); + BlockingTesterService svc = setupBlockingCloseMock(mock(BlockingTesterService.class), signal); + startServerAndClose(new ServiceFactory(svc), signal); + verifyClosure(svc, 4 /* 4 rpc methods */); + signal.verifyCloseAtLeastCount(closeGracefully); + } + + @Test + public void rpcMethodsAreClosed() throws Exception { + CloseSignal signal = new CloseSignal(4); + TestRpc testRpc = setupCloseMock(mock(TestRpc.class), signal); + TestRequestStreamRpc testRequestStreamRpc = setupCloseMock(mock(TestRequestStreamRpc.class), signal); + TestResponseStreamRpc testResponseStreamRpc = setupCloseMock(mock(TestResponseStreamRpc.class), signal); + TestBiDiStreamRpc testBiDiStreamRpc = setupCloseMock(mock(TestBiDiStreamRpc.class), signal); + startServerAndClose(new ServiceFactory.Builder() + .test(testRpc) + .testRequestStream(testRequestStreamRpc) + .testResponseStream(testResponseStreamRpc) + .testBiDiStream(testBiDiStreamRpc) + .build(), signal); + verifyClosure(testRpc); + verifyClosure(testRequestStreamRpc); + verifyClosure(testResponseStreamRpc); + verifyClosure(testBiDiStreamRpc); + signal.verifyClose(closeGracefully); + } + + @Test + public void blockingRpcMethodsAreClosed() throws Exception { + CloseSignal signal = new CloseSignal(4); + BlockingTestRpc testRpc = setupBlockingCloseMock(mock(BlockingTestRpc.class), signal); + BlockingTestRequestStreamRpc testRequestStreamRpc = + setupBlockingCloseMock(mock(BlockingTestRequestStreamRpc.class), signal); + BlockingTestResponseStreamRpc testResponseStreamRpc = + setupBlockingCloseMock(mock(BlockingTestResponseStreamRpc.class), signal); + BlockingTestBiDiStreamRpc testBiDiStreamRpc = + setupBlockingCloseMock(mock(BlockingTestBiDiStreamRpc.class), signal); + startServerAndClose(new ServiceFactory.Builder() + .testBlocking(testRpc) + .testRequestStreamBlocking(testRequestStreamRpc) + .testResponseStreamBlocking(testResponseStreamRpc) + .testBiDiStreamBlocking(testBiDiStreamRpc) + .build(), signal); + verifyClosure(testRpc); + verifyClosure(testRequestStreamRpc); + verifyClosure(testResponseStreamRpc); + verifyClosure(testBiDiStreamRpc); + signal.verifyClose(closeGracefully); + } + + @Test + public void mixedModeRpcMethodsAreClosed() throws Exception { + CloseSignal signal = new CloseSignal(4); + TestRpc testRpc = setupCloseMock(mock(TestRpc.class), signal); + TestRequestStreamRpc testRequestStreamRpc = setupCloseMock(mock(TestRequestStreamRpc.class), signal); + BlockingTestResponseStreamRpc testResponseStreamRpc = + setupBlockingCloseMock(mock(BlockingTestResponseStreamRpc.class), signal); + BlockingTestBiDiStreamRpc testBiDiStreamRpc = + setupBlockingCloseMock(mock(BlockingTestBiDiStreamRpc.class), signal); + startServerAndClose(new ServiceFactory.Builder() + .test(testRpc) + .testRequestStream(testRequestStreamRpc) + .testResponseStreamBlocking(testResponseStreamRpc) + .testBiDiStreamBlocking(testBiDiStreamRpc) + .build(), signal); + verifyClosure(testRpc); + verifyClosure(testRequestStreamRpc); + verifyClosure(testResponseStreamRpc); + verifyClosure(testBiDiStreamRpc); + signal.verifyClose(closeGracefully); + } + + private T setupCloseMock(final T autoCloseable, final AsyncCloseable closeSignal) { + when(autoCloseable.closeAsync()).thenReturn(closeSignal.closeAsync()); + when(autoCloseable.closeAsyncGracefully()).thenReturn(closeSignal.closeAsyncGracefully()); + return autoCloseable; + } + + private T setupBlockingCloseMock(final T autoCloseable, + final AsyncCloseable closeSignal) + throws Exception { + doAnswer(__ -> closeSignal.closeAsync().toFuture().get()).when(autoCloseable).close(); + doAnswer(__ -> closeSignal.closeAsyncGracefully().toFuture().get()).when(autoCloseable).closeGracefully(); + return autoCloseable; + } + + private void verifyClosure(AsyncCloseable closeable) { + verifyClosure(closeable, 1); + } + + private void verifyClosure(AsyncCloseable closeable, int times) { + // Async mode both methods are called but one is subscribed. + verify(closeable, times(times)).closeAsyncGracefully(); + verify(closeable, times(times)).closeAsync(); + verifyNoMoreInteractions(closeable); + } + + private void verifyClosure(GracefulAutoCloseable closeable) throws Exception { + verifyClosure(closeable, 1); + } + + private void verifyClosure(GracefulAutoCloseable closeable, int times) throws Exception { + if (closeGracefully) { + verify(closeable, times(times)).closeGracefully(); + } else { + verify(closeable, times(times)).close(); + } + verifyNoMoreInteractions(closeable); + } + + private void startServerAndClose(final ServiceFactory serviceFactory, final CloseSignal signal) throws Exception { + ServerContext serverContext = forAddress(localAddress(0)) + .listenAndAwait(serviceFactory); + + if (closeGracefully) { + serverContext.closeGracefully(); + } else { + serverContext.close(); + } + + signal.await(); + } + + private static final class CloseSignal implements AsyncCloseable { + private final CountDownLatch latch; + private final int count; + private final AtomicInteger closeCount; + private final AtomicInteger gracefulCloseCount; + private final Completable close; + private final Completable closeGraceful; + + CloseSignal(final int count) { + latch = new CountDownLatch(count); + this.count = count; + closeCount = new AtomicInteger(); + gracefulCloseCount = new AtomicInteger(); + close = defer(() -> { + closeCount.incrementAndGet(); + latch.countDown(); + return completed(); + }); + closeGraceful = defer(() -> { + gracefulCloseCount.incrementAndGet(); + latch.countDown(); + return completed(); + }); + } + + void await() throws Exception { + latch.await(); + } + + void verifyClose(boolean graceful) { + assertThat("Unexpected graceful closures.", gracefulCloseCount.get(), + equalTo(graceful ? count : 0)); + assertThat("Unexpected closures.", closeCount.get(), equalTo(graceful ? 0 : count)); + } + + void verifyCloseAtLeastCount(boolean graceful) { + if (graceful) { + assertThat("Unexpected graceful closures.", gracefulCloseCount.get(), + greaterThanOrEqualTo(count)); + assertThat("Unexpected closures.", closeCount.get(), equalTo(0)); + } else { + assertThat("Unexpected graceful closures.", gracefulCloseCount.get(), + equalTo(0)); + assertThat("Unexpected closures.", closeCount.get(), greaterThanOrEqualTo(count)); + } + } + + @Override + public Completable closeAsync() { + return close; + } + + @Override + public Completable closeAsyncGracefully() { + return closeGraceful; + } + } +} diff --git a/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ErrorHandlingTest.java b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ErrorHandlingTest.java index 810ffed68b..ec73418427 100644 --- a/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ErrorHandlingTest.java +++ b/servicetalk-grpc-netty/src/test/java/io/servicetalk/grpc/netty/ErrorHandlingTest.java @@ -65,6 +65,7 @@ import java.util.concurrent.Future; import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Completable.completed; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; @@ -135,7 +136,7 @@ public ErrorHandlingTest(TestMode testMode) throws Exception { this.testMode = testMode; cannedResponse = TestResponse.newBuilder().setMessage("foo").build(); ServiceFactory serviceFactory; - TesterService filter = mock(TesterService.class); + TesterService filter = mockTesterService(); StreamingHttpServiceFilterFactory serviceFilterFactory = IDENTITY_FILTER; StreamingHttpClientFilterFactory clientFilterFactory = IDENTITY_CLIENT_FILTER; switch (testMode) { @@ -230,7 +231,7 @@ public ErrorHandlingTest(TestMode testMode) throws Exception { private ServiceFactory configureFilter(final TesterService filter) { final ServiceFactory serviceFactory; - final TesterService service = mock(TesterService.class); + final TesterService service = mockTesterService(); serviceFactory = new ServiceFactory(service); serviceFactory.appendServiceFilter(original -> new ErrorSimulatingTesterServiceFilter(original, filter)); @@ -265,7 +266,7 @@ public Single test(final GrpcServiceContext ctx, final TestRequest } private ServiceFactory setupForServiceThrows(final Throwable toThrow) { - final TesterService service = mock(TesterService.class); + final TesterService service = mockTesterService(); setupForServiceThrows(service, toThrow); return new ServiceFactory(service); } @@ -278,7 +279,7 @@ private void setupForServiceThrows(final TesterService service, final Throwable } private ServiceFactory setupForServiceEmitsError(final Throwable toThrow) { - final TesterService service = mock(TesterService.class); + final TesterService service = mockTesterService(); setupForServiceEmitsError(service, toThrow); return new ServiceFactory(service); } @@ -291,7 +292,7 @@ private void setupForServiceEmitsError(final TesterService service, final Throwa } private ServiceFactory setupForServiceEmitsDataThenError(final Throwable toThrow) { - final TesterService service = mock(TesterService.class); + final TesterService service = mockTesterService(); setupForServiceEmitsDataThenError(service, toThrow); return new ServiceFactory(service); } @@ -420,6 +421,13 @@ public void responseStreamingFromBlockingClient() throws Exception { } } + private TesterService mockTesterService() { + TesterService filter = mock(TesterService.class); + when(filter.closeAsync()).thenReturn(completed()); + when(filter.closeAsyncGracefully()).thenReturn(completed()); + return filter; + } + private void verifyStreamingResponse(final BlockingIterator resp) { switch (testMode) { case ServiceEmitsDataThenError: diff --git a/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Generator.java b/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Generator.java index 8e7de4018b..7b6aaa8340 100644 --- a/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Generator.java +++ b/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Generator.java @@ -50,8 +50,12 @@ import static io.servicetalk.grpc.protoc.Types.BlockingGrpcService; import static io.servicetalk.grpc.protoc.Types.BlockingIterable; import static io.servicetalk.grpc.protoc.Types.BlockingRequestStreamingClientCall; +import static io.servicetalk.grpc.protoc.Types.BlockingRequestStreamingRoute; import static io.servicetalk.grpc.protoc.Types.BlockingResponseStreamingClientCall; +import static io.servicetalk.grpc.protoc.Types.BlockingResponseStreamingRoute; +import static io.servicetalk.grpc.protoc.Types.BlockingRoute; import static io.servicetalk.grpc.protoc.Types.BlockingStreamingClientCall; +import static io.servicetalk.grpc.protoc.Types.BlockingStreamingRoute; import static io.servicetalk.grpc.protoc.Types.ClientCall; import static io.servicetalk.grpc.protoc.Types.Completable; import static io.servicetalk.grpc.protoc.Types.DefaultGrpcClientMetadata; @@ -116,7 +120,6 @@ import static java.util.stream.Collectors.joining; import static java.util.stream.Stream.concat; import static javax.lang.model.element.Modifier.ABSTRACT; -import static javax.lang.model.element.Modifier.DEFAULT; import static javax.lang.model.element.Modifier.FINAL; import static javax.lang.model.element.Modifier.PRIVATE; import static javax.lang.model.element.Modifier.PROTECTED; @@ -256,7 +259,8 @@ private void addServiceRpcInterfaces(final State state, final TypeSpec.Builder s .addModifiers(PUBLIC) .addMethod(newRpcMethodSpec(methodProto, blocking ? EnumSet.of(BLOCKING, INTERFACE) : EnumSet.of(INTERFACE), - (__, b) -> b.addModifiers(ABSTRACT).addParameter(GrpcServiceContext, ctx))); + (__, b) -> b.addModifiers(ABSTRACT).addParameter(GrpcServiceContext, ctx))) + .addSuperinterface(blocking ? BlockingGrpcService : GrpcService); if (methodProto.hasOptions() && methodProto.getOptions().getDeprecated()) { interfaceSpecBuilder.addAnnotation(Deprecated.class); @@ -335,15 +339,16 @@ private void addServiceFactory(final State state, final TypeSpec.Builder service final String routeName = routeName(rpcInterface.methodProto); final String methodName = routeName + (rpcInterface.blocking ? Blocking : ""); final String addRouteMethodName = addRouteMethodName(rpcInterface.methodProto, rpcInterface.blocking); + final ClassName routeInterfaceClass = routeInterfaceClass(rpcInterface.methodProto, rpcInterface.blocking); serviceBuilderSpecBuilder .addMethod(methodBuilder(methodName) .addModifiers(PUBLIC) .addParameter(rpcInterface.className, rpc, FINAL) .returns(builderClass) - .addStatement("$L($T.$L.path(), $L::$L, $T.class, $T.class, $L)", addRouteMethodName, - state.rpcPathsEnumClass, routeName, rpc, routeName, inClass, outClass, - serializationProvider) + .addStatement("$L($T.$L.path(), $L.wrap($L::$L, $L), $T.class, $T.class, $L)", + addRouteMethodName, state.rpcPathsEnumClass, routeName, routeInterfaceClass, + rpc, routeName, rpc, inClass, outClass, serializationProvider) .addStatement("return this") .build()) .addMethod(methodBuilder(methodName) @@ -351,9 +356,9 @@ private void addServiceFactory(final State state, final TypeSpec.Builder service .addParameter(GrpcExecutionStrategy, strategy, FINAL) .addParameter(rpcInterface.className, rpc, FINAL) .returns(builderClass) - .addStatement("$L($T.$L.path(), $L, $L::$L, $T.class, $T.class, $L)", addRouteMethodName, - state.rpcPathsEnumClass, routeName, strategy, rpc, routeName, inClass, - outClass, serializationProvider) + .addStatement("$L($T.$L.path(), $L, $L.wrap($L::$L, $L), $T.class, $T.class, $L)", + addRouteMethodName, state.rpcPathsEnumClass, routeName, strategy, + routeInterfaceClass, rpc, routeName, rpc, inClass, outClass, serializationProvider) .addStatement("return this") .build()); }); @@ -857,23 +862,6 @@ private TypeSpec newServiceInterfaceSpec(final State state, final boolean blocki .map(e -> e.className) .forEach(interfaceSpecBuilder::addSuperinterface); - if (blocking) { - interfaceSpecBuilder - .addMethod(methodBuilder(close) - .addModifiers(DEFAULT, PUBLIC) - .addAnnotation(Override.class) - .addComment("noop") - .build()); - } else { - interfaceSpecBuilder - .addMethod(methodBuilder(closeAsync) - .addModifiers(DEFAULT, PUBLIC) - .addAnnotation(Override.class) - .returns(Completable) - .addStatement("return $T.completed()", Completable) - .build()); - } - return interfaceSpecBuilder.build(); } @@ -907,6 +895,14 @@ private static ClassName routeInterfaceClass(final MethodDescriptorProto methodP (methodProto.getServerStreaming() ? ResponseStreamingRoute : Route); } + private static ClassName routeInterfaceClass(final MethodDescriptorProto methodProto, final boolean blocking) { + return methodProto.getClientStreaming() ? + (methodProto.getServerStreaming() ? blocking ? BlockingStreamingRoute : StreamingRoute : + blocking ? BlockingRequestStreamingRoute : RequestStreamingRoute) : + (methodProto.getServerStreaming() ? blocking ? BlockingResponseStreamingRoute : ResponseStreamingRoute + : blocking ? BlockingRoute : Route); + } + private static String routeFactoryMethodName(final MethodDescriptorProto methodProto) { return (methodProto.getClientStreaming() ? (methodProto.getServerStreaming() ? "streamingR" : "requestStreamingR") : diff --git a/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Types.java b/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Types.java index 2fc1085cab..79d6f44cc7 100644 --- a/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Types.java +++ b/servicetalk-grpc-protoc/src/main/java/io/servicetalk/grpc/protoc/Types.java @@ -73,6 +73,11 @@ final class Types { static final ClassName ResponseStreamingRoute = bestGuess(grpcRoutesFqcn + ".ResponseStreamingRoute"); static final ClassName Route = bestGuess(grpcRoutesFqcn + ".Route"); static final ClassName StreamingRoute = bestGuess(grpcRoutesFqcn + ".StreamingRoute"); + static final ClassName BlockingRequestStreamingRoute = bestGuess(grpcRoutesFqcn + ".BlockingRequestStreamingRoute"); + static final ClassName BlockingResponseStreamingRoute = bestGuess(grpcRoutesFqcn + + ".BlockingResponseStreamingRoute"); + static final ClassName BlockingRoute = bestGuess(grpcRoutesFqcn + ".BlockingRoute"); + static final ClassName BlockingStreamingRoute = bestGuess(grpcRoutesFqcn + ".BlockingStreamingRoute"); static final ClassName ProtoBufSerializationProviderBuilder = bestGuess(grpcProtobufPkg + ".ProtoBufSerializationProviderBuilder"); diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingStreamingToStreamingService.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingStreamingToStreamingService.java index 837a35dc38..80dbb0cf7f 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingStreamingToStreamingService.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingStreamingToStreamingService.java @@ -164,6 +164,11 @@ public Completable closeAsync() { return blockingToCompletable(original::close); } + @Override + public Completable closeAsyncGracefully() { + return blockingToCompletable(original::closeGracefully); + } + private static final class BufferHttpPayloadWriter implements HttpPayloadWriter { private static final AtomicIntegerFieldUpdater subscriberCompleteUpdater = diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingToStreamingService.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingToStreamingService.java index 1676bc598d..ef42563101 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingToStreamingService.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/BlockingToStreamingService.java @@ -44,4 +44,9 @@ public Single handle(final HttpServiceContext ctx, public Completable closeAsync() { return blockingToCompletable(original::close); } + + @Override + public Completable closeAsyncGracefully() { + return blockingToCompletable(original::closeGracefully); + } }