Skip to content

Commit

Permalink
Expose protocol name at ConnectionContext API (#955)
Browse files Browse the repository at this point in the history
Motivation:

Users should have a way to determine which application level
protocol is used by each connection.

Modifications:

- Introduce a new `Protocol` interface for `transport-api`;
- Add `ConnectionContext.protocol()` method;
- Add `HttpProtocol` interface that extends `Protocol` for `http-api`;
- `HttpProtocolVersion` implements `HttpProtocol` interface;
- Add `HttpConnectionContext` that returns `HttpProtocol` for
`#protocol()` method;
- Return `HttpConnectionContext` everywhere in `http-api`;
- Add `GrpcProtocol` interface that extends `Protocol` for `grpc-api`;
- `GrpcServiceContext overrides `HttpServiceContext` and returns
`GrpcProtocol` for `#protocol()` method;
- Add tests to verify new API;

Result:

Users can understand the protocol used by each connection.
  • Loading branch information
idelpivnitskiy committed Mar 7, 2020
1 parent 231a641 commit 001f579
Show file tree
Hide file tree
Showing 65 changed files with 829 additions and 235 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.servicetalk.grpc.api;

import io.servicetalk.concurrent.api.Completable;
import io.servicetalk.http.api.HttpConnectionContext.HttpProtocol;
import io.servicetalk.http.api.HttpServiceContext;
import io.servicetalk.transport.api.ConnectionContext;

Expand All @@ -30,11 +31,13 @@ final class DefaultGrpcServiceContext extends DefaultGrpcMetadata implements Grp

private final ConnectionContext connectionContext;
private final GrpcExecutionContext executionContext;
private final GrpcProtocol protocol;

DefaultGrpcServiceContext(final String path, final HttpServiceContext httpServiceContext) {
super(path);
connectionContext = requireNonNull(httpServiceContext);
executionContext = new DefaultGrpcExecutionContext(httpServiceContext.executionContext());
protocol = new DefaultGrpcProtocol(httpServiceContext.protocol());
}

@Override
Expand Down Expand Up @@ -64,6 +67,11 @@ public <T> T socketOption(final SocketOption<T> option) {
return connectionContext.socketOption(option);
}

@Override
public GrpcProtocol protocol() {
return protocol;
}

@Override
public Completable onClose() {
return connectionContext.onClose();
Expand All @@ -78,4 +86,22 @@ public Completable closeAsync() {
public Completable closeAsyncGracefully() {
return connectionContext.closeAsyncGracefully();
}

private static final class DefaultGrpcProtocol implements GrpcProtocol {
private final HttpProtocol httpProtocol;

private DefaultGrpcProtocol(final HttpProtocol httpProtocol) {
this.httpProtocol = requireNonNull(httpProtocol);
}

@Override
public String name() {
return "gRPC";
}

@Override
public HttpProtocol httpProtocol() {
return httpProtocol;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
*/
package io.servicetalk.grpc.api;

import io.servicetalk.http.api.HttpExecutionContext;
import io.servicetalk.transport.api.ExecutionContext;

/**
* An extension of {@link ExecutionContext} for <a href="https://www.grpc.io">gRPC</a>.
*/
public interface GrpcExecutionContext extends ExecutionContext {
public interface GrpcExecutionContext extends HttpExecutionContext {

/**
* Returns the {@link GrpcExecutionStrategy} associated with this context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.servicetalk.grpc.api;

import io.servicetalk.http.api.HttpConnectionContext.HttpProtocol;
import io.servicetalk.transport.api.ConnectionContext;

/**
Expand All @@ -24,4 +25,12 @@ public interface GrpcServiceContext extends ConnectionContext, GrpcMetadata {

@Override
GrpcExecutionContext executionContext();

@Override
GrpcProtocol protocol();

interface GrpcProtocol extends Protocol {

HttpProtocol httpProtocol();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*
* Copyright © 2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.grpc.netty;

import io.servicetalk.concurrent.BlockingIterable;
import io.servicetalk.concurrent.BlockingIterator;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.grpc.api.GrpcPayloadWriter;
import io.servicetalk.grpc.api.GrpcServiceContext;
import io.servicetalk.grpc.netty.TesterProto.TestRequest;
import io.servicetalk.grpc.netty.TesterProto.TestResponse;
import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTesterClient;
import io.servicetalk.grpc.netty.TesterProto.Tester.BlockingTesterService;
import io.servicetalk.grpc.netty.TesterProto.Tester.ClientFactory;
import io.servicetalk.grpc.netty.TesterProto.Tester.ServiceFactory;
import io.servicetalk.grpc.netty.TesterProto.Tester.TesterService;
import io.servicetalk.http.api.HttpConnectionContext.HttpProtocol;
import io.servicetalk.http.api.HttpProtocolConfig;
import io.servicetalk.transport.api.ServerContext;

import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h2Default;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.util.Collections.singleton;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

@RunWith(Parameterized.class)
public class GrpcServiceContextProtocolTest {

@Rule
public final Timeout timeout = new ServiceTalkTestTimeout();

private final String expectedValue;
private final ServerContext serverContext;
private final BlockingTesterClient client;

public GrpcServiceContextProtocolTest(HttpProtocol httpProtocol, boolean streamingService) throws Exception {
expectedValue = "gRPC over " + httpProtocol;

serverContext = GrpcServers.forAddress(localAddress(0))
.protocols(protocolConfig(httpProtocol))
.listenAndAwait(streamingService ?
new ServiceFactory(new TesterServiceImpl()) :
new ServiceFactory(new BlockingTesterServiceImpl()));

client = GrpcClients.forAddress(serverHostAndPort(serverContext))
.protocols(protocolConfig(httpProtocol))
.buildBlocking(new ClientFactory());
}

@Parameters(name = "httpVersion={0} streamingService={0}")
public static Object[] params() {
return new Object[][]{{HTTP_2_0, true}, {HTTP_2_0, false}, {HTTP_1_1, true}, {HTTP_1_1, false}};
}

private static HttpProtocolConfig protocolConfig(HttpProtocol httpProtocol) {
if (httpProtocol == HTTP_2_0) {
return h2Default();
}
if (httpProtocol == HTTP_1_1) {
return h1Default();
}
throw new IllegalArgumentException("Unknown httpProtocol: " + httpProtocol);
}

@After
public void tearDown() throws Exception {
try {
client.close();
} finally {
serverContext.close();
}
}

@Test
public void testAggregated() throws Exception {
assertResponse(client.test(newRequest()));
}

@Test
public void testRequestStream() throws Exception {
assertResponse(client.testRequestStream(singleton(newRequest())));
}

@Test
public void testBiDiStream() throws Exception {
try (BlockingIterator<TestResponse> iterator = client.testBiDiStream(singleton(newRequest())).iterator()) {
assertResponse(iterator.next());
assertThat(iterator.hasNext(), is(false));
}
}

@Test
public void testResponseStream() throws Exception {
try (BlockingIterator<TestResponse> iterator = client.testResponseStream(newRequest()).iterator()) {
assertResponse(iterator.next());
assertThat(iterator.hasNext(), is(false));
}
}

private void assertResponse(@Nullable TestResponse response) {
assertThat(response, is(notNullValue()));
assertThat(response.getMessage(), equalTo(expectedValue));
}

private static TestRequest newRequest() {
return TestRequest.newBuilder().setName("request").build();
}

private static TestResponse newResponse(GrpcServiceContext ctx) {
return TestResponse.newBuilder()
.setMessage(ctx.protocol().name() + " over " + ctx.protocol().httpProtocol())
.build();
}

private static class TesterServiceImpl implements TesterService {

@Override
public Single<TestResponse> test(GrpcServiceContext ctx, TestRequest request) {
return succeeded(newResponse(ctx));
}

@Override
public Single<TestResponse> testRequestStream(GrpcServiceContext ctx, Publisher<TestRequest> request) {
return succeeded(newResponse(ctx));
}

@Override
public Publisher<TestResponse> testBiDiStream(GrpcServiceContext ctx, Publisher<TestRequest> request) {
return from(newResponse(ctx));
}

@Override
public Publisher<TestResponse> testResponseStream(GrpcServiceContext ctx, TestRequest request) {
return from(newResponse(ctx));
}
}

private static class BlockingTesterServiceImpl implements BlockingTesterService {
@Override
public TestResponse test(GrpcServiceContext ctx, TestRequest request) {
return newResponse(ctx);
}

@Override
public TestResponse testRequestStream(GrpcServiceContext ctx,
BlockingIterable<TestRequest> request) {
return newResponse(ctx);
}

@Override
public void testBiDiStream(GrpcServiceContext ctx, BlockingIterable<TestRequest> request,
GrpcPayloadWriter<TestResponse> responseWriter) throws Exception {
responseWriter.write(newResponse(ctx));
}

@Override
public void testResponseStream(GrpcServiceContext ctx, TestRequest request,
GrpcPayloadWriter<TestResponse> responseWriter) throws Exception {
responseWriter.write(newResponse(ctx));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,16 @@ public class SingleRequestOrResponseApiTest {
@Rule
public final Timeout timeout = new ServiceTalkTestTimeout();

private final boolean streamingServer;
private final boolean streamingService;
private final boolean streamingClient;
private final ServerContext serverContext;
private final GrpcClientBuilder<HostAndPort, InetSocketAddress> clientBuilder;

public SingleRequestOrResponseApiTest(boolean streamingServer, boolean streamingClient) throws Exception {
this.streamingServer = streamingServer;
public SingleRequestOrResponseApiTest(boolean streamingService, boolean streamingClient) throws Exception {
this.streamingService = streamingService;
this.streamingClient = streamingClient;

serverContext = GrpcServers.forAddress(localAddress(0)).listenAndAwait(streamingServer ?
serverContext = GrpcServers.forAddress(localAddress(0)).listenAndAwait(streamingService ?
new ServiceFactory(new TesterServiceImpl()) :
new ServiceFactory(new BlockingTesterServiceImpl()));

Expand All @@ -103,7 +103,7 @@ protected Single<StreamingHttpResponse> request(StreamingHttpRequester delegate,
});
}

@Parameters(name = "streamingServer={0}, streamingServer={1}")
@Parameters(name = "streamingService={0}, streamingClient={1}")
public static Object[][] params() {
return new Object[][]{{false, false}, {false, true}, {true, false}};
}
Expand Down Expand Up @@ -148,7 +148,7 @@ public void clientRequestStreamingCallFailsOnSecondResponseItem() throws Excepti

private <T extends Throwable> void clientRequestStreamingCallFailsOnInvalidResponse(
int numberOfResponses, Class<T> exceptionClass) throws Exception {
assumeFalse(streamingServer); // No need to run the test with different server-side
assumeFalse(streamingService); // No need to run the test with different server-side
if (streamingClient) {
try (TesterClient client = newClient()) {
ExecutionException e = assertThrows(ExecutionException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import io.servicetalk.concurrent.BlockingIterable;
import io.servicetalk.concurrent.PublisherSource.Subscriber;
import io.servicetalk.transport.api.ConnectionContext;

/**
* The equivalent of {@link HttpConnection} but with synchronous/blocking APIs instead of asynchronous APIs.
Expand All @@ -33,11 +32,11 @@ public interface BlockingHttpConnection extends BlockingHttpRequester {
HttpResponse request(HttpRequest request) throws Exception;

/**
* Get the {@link ConnectionContext}.
* Get the {@link HttpConnectionContext}.
*
* @return the {@link ConnectionContext}.
* @return the {@link HttpConnectionContext}.
*/
ConnectionContext connectionContext();
HttpConnectionContext connectionContext();

/**
* Returns a {@link BlockingIterable} that gives the current value of the setting as well as subsequent changes to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import io.servicetalk.concurrent.BlockingIterable;
import io.servicetalk.concurrent.PublisherSource;
import io.servicetalk.transport.api.ConnectionContext;

/**
* The equivalent of {@link StreamingHttpConnection} but with synchronous/blocking APIs instead of asynchronous APIs.
Expand All @@ -33,11 +32,11 @@ public interface BlockingStreamingHttpConnection extends BlockingStreamingHttpRe
BlockingStreamingHttpResponse request(BlockingStreamingHttpRequest request) throws Exception;

/**
* Get the {@link ConnectionContext}.
* Get the {@link HttpConnectionContext}.
*
* @return the {@link ConnectionContext}.
* @return the {@link HttpConnectionContext}.
*/
ConnectionContext connectionContext();
HttpConnectionContext connectionContext();

/**
* Returns a {@link BlockingIterable} that gives the current value of the setting as well as subsequent changes to
Expand Down
Loading

0 comments on commit 001f579

Please sign in to comment.