Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make calls to endpoint discovery non-blocking for asynchornous clients #5205

Merged
merged 19 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,14 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation
AwsClientOption.class)
.addCode(" .resolveIdentity();");

builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {")
builder.addCode("endpointFuture = identityFuture.thenCompose(credentials -> {")
.addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class)
.addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired())
.addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class)
.addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))",
opModel.getInput().getVariableName())
.addCode(" .build();")
.addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);")
.addCode(" return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);")
.addCode("});");

builder.endControlFlow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ public CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> testDiscovery
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we propagate exception cancellation? What happens if endpoint discovery takes longer than the configured client execution timeout? Would the request fail properly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm.. So the cancellation Exception is going to be thrown here

return discoverEndpoint(request).handle(
                (endpointDiscoveryEndpoint, throwable) -> {
                    if (throwable != null) {
                        throw new RuntimeException(throwable);

This will be propogated upwards, unless you mean something else

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wrapping it around an EndpointDiscoveryFailedException

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant if users have client execution timeout enabled and if endpoint discovery takes longer than the configured timeout, the returned future will be cancelled, but will endpoint discovery future also be cancelled?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we just need CompletableFutureUtils.forwardExceptionTo(executeFuture, endpointFuture ); on line 277.

Can we write a test to verify that?

Example: https://github.com/aws/aws-sdk-java-v2/blob/master/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/CopyObjectHelperTest.java#L314

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline. Doing this as a follow up item

EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null))
.build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down Expand Up @@ -267,11 +267,11 @@ public CompletableFuture<TestDiscoveryOptionalResponse> testDiscoveryOptional(
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down Expand Up @@ -348,11 +348,11 @@ public CompletableFuture<TestDiscoveryRequiredResponse> testDiscoveryRequired(
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down
anirudh9391 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -45,46 +45,28 @@ public static EndpointDiscoveryRefreshCache create(EndpointDiscoveryCacheLoader
* @return The endpoint to use for this request
*/
public URI get(String accessKey, EndpointDiscoveryRequest request) {
String key = accessKey;

// Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null.
if (key == null) {
key = "";
}
String key = getKey(accessKey, request);
EndpointDiscoveryEndpoint endpoint = cache.get(key);

if (request.cacheKey().isPresent()) {
key = key + ":" + request.cacheKey().get();
if (endpoint == null && request.required()) {
return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint();
}
return returnCachedOrDefaultEndpoint(key, endpoint, request);
}

public CompletableFuture<URI> getAsync(String accessKey, EndpointDiscoveryRequest request) {
String key = getKey(accessKey, request);
EndpointDiscoveryEndpoint endpoint = cache.get(key);

if (endpoint == null) {
if (request.required()) {
return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint();
} else {
EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder()
.endpoint(request.defaultEndpoint())
.expirationTime(Instant.now().plusSeconds(60))
.build();

EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint);
if (previousValue != null) {
// Someone else primed the cache. Use that endpoint (which may be temporary).
return previousValue.endpoint();
} else {
// We primed the cache with the temporary endpoint. Kick off discovery in the background.
refreshCacheAsync(request, key);
}
return tempEndpoint.endpoint();
}
}

if (endpoint.expirationTime().isBefore(Instant.now())) {
cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build());
refreshCacheAsync(request, key);
// If a service call needs to be made to discover endpoint
// a completable future for the service call is returned, unblocking I/O
// and then completed asynchronously
if (endpoint == null && request.required()) {
return discoverEndpointHandler(key, request);
}

return endpoint.endpoint();
// In the event of a cache hit, i.e. service call not required, defer to the synchronous code path method.
return CompletableFuture.completedFuture(returnCachedOrDefaultEndpoint(key, endpoint, request));
}

private EndpointDiscoveryEndpoint getAndJoin(EndpointDiscoveryRequest request) {
Expand All @@ -109,4 +91,55 @@ public CompletableFuture<EndpointDiscoveryEndpoint> discoverEndpoint(EndpointDis
public void evict(String key) {
cache.remove(key);
}

private String getKey(String accessKey, EndpointDiscoveryRequest request) {
String key = accessKey;

// Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null.
if (key == null) {
key = "";
}

if (request.cacheKey().isPresent()) {
key = key + ":" + request.cacheKey().get();
}
return key;
}

private CompletableFuture<URI> discoverEndpointHandler(String key, EndpointDiscoveryRequest request) {
return discoverEndpoint(request).handle(
(endpointDiscoveryEndpoint, throwable) -> {
if (throwable != null) {
throw EndpointDiscoveryFailedException.create(throwable.getCause());
}
return cache.computeIfAbsent(
key, k -> endpointDiscoveryEndpoint
).endpoint();
});
}

private URI returnCachedOrDefaultEndpoint(String key, EndpointDiscoveryEndpoint endpoint, EndpointDiscoveryRequest request) {
EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder()
.endpoint(request.defaultEndpoint())
.expirationTime(Instant.now().plusSeconds(60))
.build();

if (endpoint == null) {
EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint);
if (previousValue != null) {
// Someone else primed the cache. Use that endpoint (which may be temporary).
return previousValue.endpoint();
}
// We primed the cache with the temporary endpoint. Kick off discovery in the background.
refreshCacheAsync(request, key);
return tempEndpoint.endpoint();
}

if (endpoint.expirationTime().isBefore(Instant.now())) {
cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build());
refreshCacheAsync(request, key);
}

return endpoint.endpoint();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.endpointdiscovery;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.net.URI;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

public class EndpointDiscoveryRefreshCacheTest {

private EndpointDiscoveryRefreshCache endpointDiscoveryRefreshCache;
private EndpointDiscoveryCacheLoader mockClient;
private static final URI testURI = URI.create("test_endpoint");
private static final String requestCacheKey = "request_cache_key";
private static final String accessKey = "access_cache_key";

@BeforeEach
public void setup() {
this.mockClient= mock(EndpointDiscoveryCacheLoader.class);
this.endpointDiscoveryRefreshCache = EndpointDiscoveryRefreshCache.create(mockClient);
}

@Test
public void getAsync_notRequired_returns_CompletedFuture() throws ExecutionException, InterruptedException {
when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>());
EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder()
.required(false)
.defaultEndpoint(testURI)
.build();
assertThat(endpointDiscoveryRefreshCache.getAsync("key", request).isDone()).isEqualTo(true);
assertThat(endpointDiscoveryRefreshCache.getAsync("key", request).get()).isEqualTo(testURI);

}

@Test
public void getAsync_returns_CompletedFuture() throws ExecutionException, InterruptedException {

when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>());
EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder()
.required(true)
.defaultEndpoint(testURI)
.build();
CompletableFuture<URI> future = endpointDiscoveryRefreshCache.getAsync("key", request);
assertThat(future.isDone()).isEqualTo(false);

future.complete(testURI);

assertThat(future.isDone()).isEqualTo(true);
assertThat(future.get()).isEqualTo(testURI);
}

@Test
public void getAsync_future_cancelled() {

when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>());
EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder()
.required(true)
.defaultEndpoint(testURI)
.build();
CompletableFuture<URI> future = endpointDiscoveryRefreshCache.getAsync("key", request);
assertThat(future.isDone()).isEqualTo(false);

future.cancel(true);
assertThat(future.isCancelled()).isEqualTo(true);
assertThrows(CancellationException.class, () -> future.get());
anirudh9391 marked this conversation as resolved.
Show resolved Hide resolved

}

}