diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java index 4490e00fb9f4..034ad680ffb8 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/AsyncClientClass.java @@ -429,14 +429,14 @@ protected MethodSpec.Builder operationBody(MethodSpec.Builder builder, Operation AwsClientOption.class) .addCode(" .resolveIdentity();"); - builder.addCode("endpointFuture = identityFuture.thenApply(credentials -> {") + builder.addCode("endpointFuture = identityFuture.thenCompose(credentials -> {") .addCode(" $1T endpointDiscoveryRequest = $1T.builder()", EndpointDiscoveryRequest.class) .addCode(" .required($L)", opModel.getInputShape().getEndpointDiscovery().isRequired()) .addCode(" .defaultEndpoint(clientConfiguration.option($T.ENDPOINT))", SdkClientOption.class) .addCode(" .overrideConfiguration($N.overrideConfiguration().orElse(null))", opModel.getInput().getVariableName()) .addCode(" .build();") - .addCode(" return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest);") + .addCode(" return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest);") .addCode("});"); builder.endControlFlow(); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java index 8518ad4db2fc..5618a0538bee 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/test-endpoint-discovery-async.java @@ -193,12 +193,12 @@ public CompletableFuture testDiscovery .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryIdentifiersRequiredRequest.overrideConfiguration().orElse(null)) .build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } @@ -267,11 +267,11 @@ public CompletableFuture testDiscoveryOptional( .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(false) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryOptionalRequest.overrideConfiguration().orElse(null)).build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } @@ -348,11 +348,11 @@ public CompletableFuture testDiscoveryRequired( .overrideConfiguration().flatMap(AwsRequestOverrideConfiguration::credentialsIdentityProvider) .orElseGet(() -> clientConfiguration.option(AwsClientOption.CREDENTIALS_IDENTITY_PROVIDER)) .resolveIdentity(); - endpointFuture = identityFuture.thenApply(credentials -> { + endpointFuture = identityFuture.thenCompose(credentials -> { EndpointDiscoveryRequest endpointDiscoveryRequest = EndpointDiscoveryRequest.builder().required(true) .defaultEndpoint(clientConfiguration.option(SdkClientOption.ENDPOINT)) .overrideConfiguration(testDiscoveryRequiredRequest.overrideConfiguration().orElse(null)).build(); - return endpointDiscoveryCache.get(credentials.accessKeyId(), endpointDiscoveryRequest); + return endpointDiscoveryCache.getAsync(credentials.accessKeyId(), endpointDiscoveryRequest); }); } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java index 982b3427216d..988e49ac3ea8 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCache.java @@ -45,46 +45,28 @@ public static EndpointDiscoveryRefreshCache create(EndpointDiscoveryCacheLoader * @return The endpoint to use for this request */ public URI get(String accessKey, EndpointDiscoveryRequest request) { - String key = accessKey; - // Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null. - if (key == null) { - key = ""; - } + String key = getKey(accessKey, request); + EndpointDiscoveryEndpoint endpoint = cache.get(key); - if (request.cacheKey().isPresent()) { - key = key + ":" + request.cacheKey().get(); + if (endpoint == null && request.required()) { + return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint(); } + return returnCachedOrDefaultEndpoint(key, endpoint, request); + } + public CompletableFuture getAsync(String accessKey, EndpointDiscoveryRequest request) { + String key = getKey(accessKey, request); EndpointDiscoveryEndpoint endpoint = cache.get(key); - if (endpoint == null) { - if (request.required()) { - return cache.computeIfAbsent(key, k -> getAndJoin(request)).endpoint(); - } else { - EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder() - .endpoint(request.defaultEndpoint()) - .expirationTime(Instant.now().plusSeconds(60)) - .build(); - - EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint); - if (previousValue != null) { - // Someone else primed the cache. Use that endpoint (which may be temporary). - return previousValue.endpoint(); - } else { - // We primed the cache with the temporary endpoint. Kick off discovery in the background. - refreshCacheAsync(request, key); - } - return tempEndpoint.endpoint(); - } - } - - if (endpoint.expirationTime().isBefore(Instant.now())) { - cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build()); - refreshCacheAsync(request, key); + // If a service call needs to be made to discover endpoint + // a completable future for the service call is returned, unblocking I/O + // and then completed asynchronously + if (endpoint == null && request.required()) { + return discoverEndpointHandler(key, request); } - - return endpoint.endpoint(); + // In the event of a cache hit, i.e. service call not required, defer to the synchronous code path method. + return CompletableFuture.completedFuture(returnCachedOrDefaultEndpoint(key, endpoint, request)); } private EndpointDiscoveryEndpoint getAndJoin(EndpointDiscoveryRequest request) { @@ -109,4 +91,55 @@ public CompletableFuture discoverEndpoint(EndpointDis public void evict(String key) { cache.remove(key); } + + private String getKey(String accessKey, EndpointDiscoveryRequest request) { + String key = accessKey; + + // Support null (anonymous credentials) by mapping to empty-string. The backing cache does not support null. + if (key == null) { + key = ""; + } + + if (request.cacheKey().isPresent()) { + key = key + ":" + request.cacheKey().get(); + } + return key; + } + + private CompletableFuture discoverEndpointHandler(String key, EndpointDiscoveryRequest request) { + return discoverEndpoint(request).handle( + (endpointDiscoveryEndpoint, throwable) -> { + if (throwable != null) { + throw EndpointDiscoveryFailedException.create(throwable.getCause()); + } + return cache.computeIfAbsent( + key, k -> endpointDiscoveryEndpoint + ).endpoint(); + }); + } + + private URI returnCachedOrDefaultEndpoint(String key, EndpointDiscoveryEndpoint endpoint, EndpointDiscoveryRequest request) { + EndpointDiscoveryEndpoint tempEndpoint = EndpointDiscoveryEndpoint.builder() + .endpoint(request.defaultEndpoint()) + .expirationTime(Instant.now().plusSeconds(60)) + .build(); + + if (endpoint == null) { + EndpointDiscoveryEndpoint previousValue = cache.putIfAbsent(key, tempEndpoint); + if (previousValue != null) { + // Someone else primed the cache. Use that endpoint (which may be temporary). + return previousValue.endpoint(); + } + // We primed the cache with the temporary endpoint. Kick off discovery in the background. + refreshCacheAsync(request, key); + return tempEndpoint.endpoint(); + } + + if (endpoint.expirationTime().isBefore(Instant.now())) { + cache.put(key, endpoint.toBuilder().expirationTime(Instant.now().plusSeconds(60)).build()); + refreshCacheAsync(request, key); + } + + return endpoint.endpoint(); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCacheTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCacheTest.java new file mode 100644 index 000000000000..8daa4f0909b4 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/endpointdiscovery/EndpointDiscoveryRefreshCacheTest.java @@ -0,0 +1,92 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.endpointdiscovery; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URI; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class EndpointDiscoveryRefreshCacheTest { + + private EndpointDiscoveryRefreshCache endpointDiscoveryRefreshCache; + private EndpointDiscoveryCacheLoader mockClient; + private static final URI testURI = URI.create("test_endpoint"); + private static final String requestCacheKey = "request_cache_key"; + private static final String accessKey = "access_cache_key"; + + @BeforeEach + public void setup() { + this.mockClient= mock(EndpointDiscoveryCacheLoader.class); + this.endpointDiscoveryRefreshCache = EndpointDiscoveryRefreshCache.create(mockClient); + } + + @Test + public void getAsync_notRequired_returns_CompletedFuture() throws ExecutionException, InterruptedException { + when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>()); + EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder() + .required(false) + .defaultEndpoint(testURI) + .build(); + assertThat(endpointDiscoveryRefreshCache.getAsync("key", request).isDone()).isEqualTo(true); + assertThat(endpointDiscoveryRefreshCache.getAsync("key", request).get()).isEqualTo(testURI); + + } + + @Test + public void getAsync_returns_CompletedFuture() throws ExecutionException, InterruptedException { + + when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>()); + EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder() + .required(true) + .defaultEndpoint(testURI) + .build(); + CompletableFuture future = endpointDiscoveryRefreshCache.getAsync("key", request); + assertThat(future.isDone()).isEqualTo(false); + + future.complete(testURI); + + assertThat(future.isDone()).isEqualTo(true); + assertThat(future.get()).isEqualTo(testURI); + } + + @Test + public void getAsync_future_cancelled() { + + when(mockClient.discoverEndpoint(any())).thenReturn(new CompletableFuture<>()); + EndpointDiscoveryRequest request = EndpointDiscoveryRequest.builder() + .required(true) + .defaultEndpoint(testURI) + .build(); + CompletableFuture future = endpointDiscoveryRefreshCache.getAsync("key", request); + assertThat(future.isDone()).isEqualTo(false); + + future.cancel(true); + assertThat(future.isCancelled()).isEqualTo(true); + assertThatThrownBy(future::get).isInstanceOf(CancellationException.class); + + } + +}