diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 4490e00fb9f4..034ad680ffb8 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -429,14 +429,14 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation AwsClientOption.class) .addCode(" .resolveIdentity();"); - builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {") + builder.addCode("endpointFuture = identityFuture.thenCompose(credentials -> {") .addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class) .addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired()) .addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class) .addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))", opModel.getInput().getVariableName()) .addCode(" .build();") - .addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);") + .addCode(" return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);") .addCode("});"); builder.endControlFlow(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java index 8518ad4db2fc..5618a0538bee 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java @@ -193,12 +193,12 @@ public CompletableFuture testDiscovery .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null)) .build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } @@ -267,11 +267,11 @@ public CompletableFuture testDiscoveryOptional( .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } @@ -348,11 +348,11 @@ public CompletableFuture testDiscoveryRequired( .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java index 982b3427216d..b4ce8e5953dd 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java @@ -19,6 +19,7 @@ import java.time.Instant; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import software.amazon.awssdk.annotations.SdkProtectedApi; @@ -29,6 +30,8 @@ public final class EndpointDiscoveryRefreshCache { private final EndpointDiscoveryCacheLoader client; + private String key; + private EndpointDiscoveryRefreshCache(EndpointDiscoveryCacheLoader client) { this.client = client; } @@ -61,22 +64,21 @@ public URI get(String accessKey, EndpointDiscoveryRequest request) { if (endpoint == null) { if (request.required()) { return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint(); + } + EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder() + .endpoint(request.defaultEndpoint()) + .expirationTime(Instant.now().plusSeconds(60)) + .build(); + + EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint); + if (previousValue != null) { + // Someone else primed the cache. Use that endpoint (which may be temporary). + return previousValue.endpoint(); } else { - EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder() - .endpoint(request.defaultEndpoint()) - .expirationTime(Instant.now().plusSeconds(60)) - .build(); - - EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint); - if (previousValue != null) { - // Someone else primed the cache. Use that endpoint (which may be temporary). - return previousValue.endpoint(); - } else { - // We primed the cache with the temporary endpoint. Kick off discovery in the background. - refreshCacheAsync(request, key); - } - return tempEndpoint.endpoint(); + // We primed the cache with the temporary endpoint. Kick off discovery in the background. + refreshCacheAsync(request, key); } + return tempEndpoint.endpoint(); } if (endpoint.expirationTime().isBefore(Instant.now())) { @@ -87,6 +89,64 @@ public URI get(String accessKey, EndpointDiscoveryRequest request) { return endpoint.endpoint(); } + public CompletableFuture getAsync(String accessKey, EndpointDiscoveryRequest request) { + key = accessKey; + + // Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null. + if (key == null) { + key = ""; + } + + if (request.cacheKey().isPresent()) { + key = key + ":" + request.cacheKey().get(); + } + + EndpointDiscoveryEndpoint endpoint = cache.get(key); + + if (endpoint == null) { + if (request.required()) { + return discoverEndpoint(request).handle( + (endpointDiscoveryEndpoint, throwable) -> { + if (throwable != null) { + if (throwable instanceof InterruptedException) { + Thread.currentThread().interrupt(); + throw EndpointDiscoveryFailedException.create(throwable); + } + if (throwable instanceof ExecutionException + || throwable instanceof CompletionException) { + throw EndpointDiscoveryFailedException.create(throwable.getCause()); + } + throw new RuntimeException("new exception"); + } + return cache.computeIfAbsent( + key, k -> endpointDiscoveryEndpoint + ).endpoint(); + }); + } + EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder() + .endpoint(request.defaultEndpoint()) + .expirationTime(Instant.now().plusSeconds(60)) + .build(); + + EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint); + if (previousValue != null) { + // Someone else primed the cache. Use that endpoint (which may be temporary). + return CompletableFuture.completedFuture(previousValue.endpoint()); + } else { + // We primed the cache with the temporary endpoint. Kick off discovery in the background. + refreshCacheAsync(request, key); + } + return CompletableFuture.completedFuture(tempEndpoint.endpoint()); + } + + if (endpoint.expirationTime().isBefore(Instant.now())) { + cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build()); + refreshCacheAsync(request, key); + } + + return CompletableFuture.completedFuture(endpoint.endpoint()); + } + private EndpointDiscoveryEndpoint getAndJoin(EndpointDiscoveryRequest request) { try { return discoverEndpoint(request).get();