Skip to content
Permalink
Browse files

gRPC-client should fail responses without grpc-status code (#934)

Motivation:

gRPC protocol requires server to send grpc-status code as part
of the response. It may come in headers (when there is no
payload body) or in trailers. Current implementation considers
responses without `grpc-status` as legit responses.

Modifications:

- Ensure that the response object contains `grpc-status` code;
- Tests to verify that client throws an exception when server
does not send `grpc-status`;

Result:

gRPC-client throws an exception when server does not send
`grpc-status`.
  • Loading branch information
idelpivnitskiy committed Feb 12, 2020
1 parent b4413a1 commit 349f2ca17b2a315907fdbd2b5db3cae9ec079eb1
@@ -1,5 +1,5 @@
/*
* Copyright © 2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@
requireNonNull(serializationProvider);
requireNonNull(requestClass);
requireNonNull(responseClass);
HttpClient client = streamingHttpClient.asClient();
final HttpClient client = streamingHttpClient.asClient();
return (metadata, request) -> {
final HttpRequest httpRequest = newAggregatedRequest(metadata, request, client,
serializationProvider, requestClass);
@@ -91,7 +91,7 @@
public <Req, Resp> RequestStreamingClientCall<Req, Resp>
newRequestStreamingCall(final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass, final Class<Resp> responseClass) {
StreamingClientCall<Req, Resp> streamingClientCall =
final StreamingClientCall<Req, Resp> streamingClientCall =
newStreamingCall(serializationProvider, requestClass, responseClass);
return (metadata, request) -> streamingClientCall.request(metadata, request).firstOrError();
}
@@ -100,7 +100,7 @@
public <Req, Resp> ResponseStreamingClientCall<Req, Resp>
newResponseStreamingCall(final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass, final Class<Resp> responseClass) {
StreamingClientCall<Req, Resp> streamingClientCall =
final StreamingClientCall<Req, Resp> streamingClientCall =
newStreamingCall(serializationProvider, requestClass, responseClass);
return (metadata, request) -> streamingClientCall.request(metadata, Publisher.from(request));
}
@@ -112,7 +112,7 @@
requireNonNull(serializationProvider);
requireNonNull(requestClass);
requireNonNull(responseClass);
BlockingHttpClient client = streamingHttpClient.asBlockingClient();
final BlockingHttpClient client = streamingHttpClient.asBlockingClient();
return (metadata, request) -> {
final HttpRequest httpRequest = newAggregatedRequest(metadata, request, client,
serializationProvider, requestClass);
@@ -132,7 +132,7 @@
requireNonNull(serializationProvider);
requireNonNull(requestClass);
requireNonNull(responseClass);
BlockingStreamingHttpClient client = streamingHttpClient.asBlockingStreamingClient();
final BlockingStreamingHttpClient client = streamingHttpClient.asBlockingStreamingClient();
return (metadata, request) -> {
final BlockingStreamingHttpRequest httpRequest = client.post(metadata.path());
initRequest(httpRequest);
@@ -151,7 +151,7 @@
public <Req, Resp> BlockingRequestStreamingClientCall<Req, Resp>
newBlockingRequestStreamingCall(final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass, final Class<Resp> responseClass) {
BlockingStreamingClientCall<Req, Resp> streamingClientCall =
final BlockingStreamingClientCall<Req, Resp> streamingClientCall =
newBlockingStreamingCall(serializationProvider, requestClass, responseClass);
return (metadata, request) -> {
final BlockingIterator<Resp> iterator = streamingClientCall.request(metadata, request).iterator();
@@ -168,7 +168,7 @@
public <Req, Resp> BlockingResponseStreamingClientCall<Req, Resp>
newBlockingResponseStreamingCall(final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass, final Class<Resp> responseClass) {
BlockingStreamingClientCall<Req, Resp> streamingClientCall =
final BlockingStreamingClientCall<Req, Resp> streamingClientCall =
newBlockingStreamingCall(serializationProvider, requestClass, responseClass);
return (metadata, request) -> streamingClientCall.request(metadata, singletonBlockingIterable(request));
}
@@ -198,10 +198,10 @@ private GrpcMessageEncoding getMessageEncoding(final GrpcClientMetadata metadata
return None;
}

private <Req> HttpRequest newAggregatedRequest(final GrpcClientMetadata metadata, final Req rawReq,
final HttpRequestFactory requestFactory,
final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass) {
private static <Req> HttpRequest newAggregatedRequest(final GrpcClientMetadata metadata, final Req rawReq,
final HttpRequestFactory requestFactory,
final GrpcSerializationProvider serializationProvider,
final Class<Req> requestClass) {
final HttpRequest httpRequest = requestFactory.post(metadata.path());
initRequest(httpRequest);
return httpRequest.payloadBody(uncheckedCast(rawReq),
@@ -1,5 +1,5 @@
/*
* Copyright © 2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@

import java.util.HashMap;
import java.util.Map;
import javax.annotation.Nullable;

import static java.lang.Integer.parseInt;
import static java.util.Collections.unmodifiableMap;
@@ -89,10 +88,7 @@
* @param codeValue code value.
* @return status code associated with the code value, or {@link #UNKNOWN}.
*/
public static GrpcStatusCode fromCodeValue(@Nullable CharSequence codeValue) {
if (codeValue == null) {
return UNKNOWN;
}
public static GrpcStatusCode fromCodeValue(CharSequence codeValue) {
try {
return fromCodeValue(parseInt(codeValue.toString()));
} catch (NumberFormatException e) {
@@ -1,5 +1,5 @@
/*
* Copyright © 2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@
import io.servicetalk.http.api.StatelessTrailersTransformer;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.api.StreamingHttpResponseFactory;
import io.servicetalk.http.api.TrailersTransformer;
import io.servicetalk.serialization.api.SerializationException;

import com.google.protobuf.InvalidProtocolBufferException;
@@ -38,7 +39,10 @@
import java.util.function.Supplier;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Publisher.empty;
import static io.servicetalk.concurrent.api.Publisher.failed;
import static io.servicetalk.grpc.api.GrpcMessageEncoding.None;
import static io.servicetalk.grpc.api.GrpcStatusCode.INTERNAL;
import static io.servicetalk.http.api.CharSequences.contentEqualsIgnoreCase;
import static io.servicetalk.http.api.CharSequences.newAsciiString;
import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_TYPE;
@@ -59,6 +63,14 @@
private static final CharSequence IDENTITY = newAsciiString(None.encoding());
private static final CharSequence GRPC_MESSAGE_ENCODING_KEY = newAsciiString("grpc-encoding");
private static final GrpcStatus STATUS_OK = GrpcStatus.fromCodeValue(GrpcStatusCode.OK.value());
private static final TrailersTransformer<Object, Object> ENSURE_GRPC_STATUS_RECEIVED =
new StatelessTrailersTransformer<Object>() {
@Override
protected HttpHeaders payloadComplete(final HttpHeaders trailers) {
ensureGrpcStatusReceived(trailers);
return trailers;
}
};

private GrpcUtils() {
// No instances.
@@ -138,50 +150,56 @@ static void setStatus(final HttpHeaders trailers, final Throwable cause, final B

static <Resp> Publisher<Resp> validateResponseAndGetPayload(final StreamingHttpResponse response,
final HttpDeserializer<Resp> deserializer) {
// In case of server error, gRPC may return only one HEADER frame with endStream=true. Our
// HTTP1-based implementation translates them into response headers so we need to look for
// the status in both headers and trailers. Since this is streaming response and we have the headers now, we
// check for error here first. If we see trailers later in payloadBodyAndTrailers(), we will check for error
// there.
final HttpHeaders respHeaders = response.headers();
GrpcStatusException grpcStatusException = extractGrpcExceptionFromHeaders(respHeaders);
if (grpcStatusException != null) {
return Publisher.failed(grpcStatusException);
// In case of an empty response, gRPC-server may return only one HEADER frame with endStream=true. Our
// HTTP1-based implementation translates them into response headers so we need to look for a grpc-status in both
// headers and trailers. Since this is streaming response and we have the headers now, we check for the
// grpc-status here first. If there is no grpc-status in headers, we look for it in trailers later.
final HttpHeaders headers = response.headers();
final GrpcStatusCode grpcStatusCode = extractGrpcStatusCodeFromHeaders(headers);
if (grpcStatusCode != null) {
final GrpcStatusException grpcStatusException = convertToGrpcStatusException(grpcStatusCode, headers);
return response.payloadBodyAndTrailers().ignoreElements()
.concat(grpcStatusException != null ? failed(grpcStatusException) : empty());
}
return deserializer.deserialize(respHeaders, response.payloadBodyAndTrailers().map(o -> {
if (o instanceof HttpHeaders) {
// We have already checked for error in headers above, now we just check in trailers.
GrpcStatusException ex = extractGrpcExceptionFromHeaders((HttpHeaders) o);
if (ex != null) {
throw ex;
}
} else if (!(o instanceof Buffer)) {
throw new IllegalArgumentException("Unexpected payload type: " + o.getClass());
}
return o;
}).filter(o -> !(o instanceof HttpHeaders)).map(o -> (Buffer) o));

response.transformRaw(ENSURE_GRPC_STATUS_RECEIVED);
return deserializer.deserialize(headers, response.payloadBodyAndTrailers()
.filter(o -> !(o instanceof HttpHeaders)).map(o -> (Buffer) o));
}

static <Resp> Resp validateResponseAndGetPayload(final HttpResponse response,
final HttpDeserializer<Resp> deserializer) {
final HttpHeaders trailers = response.trailers();
final HttpHeaders headers = response.headers();
// In case of server error, gRPC may return only one HEADER frame with endStream=true. Our
// HTTP1-based implementation translates them into response headers so we need to look for
// the status in both headers and trailers.
// In case of an empty response, gRPC-server may return only one HEADER frame with endStream=true. Our
// HTTP1-based implementation translates them into response headers so we need to look for a grpc-status in both
// headers and trailers.

// We will try the trailers first as this is the most likely place to find the GRPC related headers.
GrpcStatusException grpcStatusException = extractGrpcExceptionFromHeaders(trailers);
if (grpcStatusException != null) {
throw grpcStatusException;
// We will try the trailers first as this is the most likely place to find the gRPC-related headers.
final HttpHeaders trailers = response.trailers();
GrpcStatusCode grpcStatusCode = extractGrpcStatusCodeFromHeaders(trailers);
if (grpcStatusCode != null) {
final GrpcStatusException grpcStatusException = convertToGrpcStatusException(grpcStatusCode, trailers);
if (grpcStatusException != null) {
throw grpcStatusException;
}
return response.payloadBody(deserializer);
}

// There was no grpc-status in the trailers, so error may be in the headers.
grpcStatusException = extractGrpcExceptionFromHeaders(headers);
// There was no grpc-status in the trailers, so it must be in headers.
ensureGrpcStatusReceived(response.headers());
return response.payloadBody(deserializer);
}

private static void ensureGrpcStatusReceived(final HttpHeaders headers) {
final GrpcStatusCode statusCode = extractGrpcStatusCodeFromHeaders(headers);
if (statusCode == null) {
// This is a protocol violation as we expect to receive grpc-status.
throw new GrpcStatus(INTERNAL, null, "Response does not contain " + GRPC_STATUS_CODE_TRAILER +
" header or trailer").asException();
}
final GrpcStatusException grpcStatusException = convertToGrpcStatusException(statusCode, headers);
if (grpcStatusException != null) {
throw grpcStatusException;
}
return response.payloadBody(deserializer);
}

static GrpcMessageEncoding readGrpcMessageEncoding(final HttpMetaData httpMetaData) {
@@ -203,17 +221,22 @@ private static void initResponse(final HttpResponseMetaData response) {
}

@Nullable
private static GrpcStatusException extractGrpcExceptionFromHeaders(final HttpHeaders headers) {
private static GrpcStatusCode extractGrpcStatusCodeFromHeaders(final HttpHeaders headers) {
final CharSequence statusCode = headers.get(GRPC_STATUS_CODE_TRAILER);
if (statusCode != null) {
final GrpcStatusCode grpcStatusCode = GrpcStatusCode.fromCodeValue(statusCode);
if (grpcStatusCode.value() != GrpcStatusCode.OK.value()) {
final GrpcStatus grpcStatus =
new GrpcStatus(grpcStatusCode, null, headers.get(GRPC_STATUS_MESSAGE_TRAILER));
return grpcStatus.asException(new StatusSupplier(headers, grpcStatus));
}
if (statusCode == null) {
return null;
}
return GrpcStatusCode.fromCodeValue(statusCode);
}

@Nullable
private static GrpcStatusException convertToGrpcStatusException(final GrpcStatusCode grpcStatusCode,
final HttpHeaders headers) {
if (grpcStatusCode.value() == GrpcStatusCode.OK.value()) {
return null;
}
return null;
final GrpcStatus grpcStatus = new GrpcStatus(grpcStatusCode, null, headers.get(GRPC_STATUS_MESSAGE_TRAILER));
return grpcStatus.asException(new StatusSupplier(headers, grpcStatus));
}

@Nullable

0 comments on commit 349f2ca

Please sign in to comment.
You can’t perform that action at this time.