Skip to content

Commit

Permalink
Make calls to endpoint discovery non-blocking for asynchornous clients
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh9391 committed May 9, 2024
1 parent 019159d commit e6099b8
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -429,14 +429,14 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation
AwsClientOption.class)
.addCode(" .resolveIdentity();");

builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {")
builder.addCode("endpointFuture = identityFuture.thenCompose(credentials -> {")
.addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class)
.addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired())
.addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class)
.addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))",
opModel.getInput().getVariableName())
.addCode(" .build();")
.addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);")
.addCode(" return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);")
.addCode("});");

builder.endControlFlow();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ public CompletableFuture<TestDiscoveryIdentifiersRequiredResponse> testDiscovery
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null))
.build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down Expand Up @@ -267,11 +267,11 @@ public CompletableFuture<TestDiscoveryOptionalResponse> testDiscoveryOptional(
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down Expand Up @@ -348,11 +348,11 @@ public CompletableFuture<TestDiscoveryRequiredResponse> testDiscoveryRequired(
.overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider)
.orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER))
.resolveIdentity();
endpointFuture = identityFuture.thenApply(credentials -> {
endpointFuture = identityFuture.thenCompose(credentials -> {
EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true)
.defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT))
.overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build();
return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);
return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import software.amazon.awssdk.annotations.SdkProtectedApi;
Expand All @@ -29,6 +30,8 @@ public final class EndpointDiscoveryRefreshCache {

private final EndpointDiscoveryCacheLoader client;

private String key;

private EndpointDiscoveryRefreshCache(EndpointDiscoveryCacheLoader client) {
this.client = client;
}
Expand Down Expand Up @@ -61,22 +64,21 @@ public URI get(String accessKey, EndpointDiscoveryRequest request) {
if (endpoint == null) {
if (request.required()) {
return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint();
}
EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder()
.endpoint(request.defaultEndpoint())
.expirationTime(Instant.now().plusSeconds(60))
.build();

EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint);
if (previousValue != null) {
// Someone else primed the cache. Use that endpoint (which may be temporary).
return previousValue.endpoint();
} else {
EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder()
.endpoint(request.defaultEndpoint())
.expirationTime(Instant.now().plusSeconds(60))
.build();

EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint);
if (previousValue != null) {
// Someone else primed the cache. Use that endpoint (which may be temporary).
return previousValue.endpoint();
} else {
// We primed the cache with the temporary endpoint. Kick off discovery in the background.
refreshCacheAsync(request, key);
}
return tempEndpoint.endpoint();
// We primed the cache with the temporary endpoint. Kick off discovery in the background.
refreshCacheAsync(request, key);
}
return tempEndpoint.endpoint();
}

if (endpoint.expirationTime().isBefore(Instant.now())) {
Expand All @@ -87,6 +89,64 @@ public URI get(String accessKey, EndpointDiscoveryRequest request) {
return endpoint.endpoint();
}

public CompletableFuture<URI> getAsync(String accessKey, EndpointDiscoveryRequest request) {
key = accessKey;

// Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null.
if (key == null) {
key = "";
}

if (request.cacheKey().isPresent()) {
key = key + ":" + request.cacheKey().get();
}

EndpointDiscoveryEndpoint endpoint = cache.get(key);

if (endpoint == null) {
if (request.required()) {
return discoverEndpoint(request).handle(
(endpointDiscoveryEndpoint, throwable) -> {
if (throwable != null) {
if (throwable instanceof InterruptedException) {
Thread.currentThread().interrupt();
throw EndpointDiscoveryFailedException.create(throwable);
}
if (throwable instanceof ExecutionException
|| throwable instanceof CompletionException) {
throw EndpointDiscoveryFailedException.create(throwable.getCause());
}
throw new RuntimeException("new exception");
}
return cache.computeIfAbsent(
key, k -> endpointDiscoveryEndpoint
).endpoint();
});
}
EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder()
.endpoint(request.defaultEndpoint())
.expirationTime(Instant.now().plusSeconds(60))
.build();

EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint);
if (previousValue != null) {
// Someone else primed the cache. Use that endpoint (which may be temporary).
return CompletableFuture.completedFuture(previousValue.endpoint());
} else {
// We primed the cache with the temporary endpoint. Kick off discovery in the background.
refreshCacheAsync(request, key);
}
return CompletableFuture.completedFuture(tempEndpoint.endpoint());
}

if (endpoint.expirationTime().isBefore(Instant.now())) {
cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build());
refreshCacheAsync(request, key);
}

return CompletableFuture.completedFuture(endpoint.endpoint());
}

private EndpointDiscoveryEndpoint getAndJoin(EndpointDiscoveryRequest request) {
try {
return discoverEndpoint(request).get();
Expand Down

0 comments on commit e6099b8

Please sign in to comment.