diff --git a/pom.xml b/pom.xml index 2f1b9f85..1c20a97b 100644 --- a/pom.xml +++ b/pom.xml @@ -155,6 +155,11 @@ software.amazon.awssdk apache-client + + software.amazon.awssdk + s3 + test + diff --git a/python/rpdk/java/templates/init/guided_aws/AbstractTestBase.java b/python/rpdk/java/templates/init/guided_aws/AbstractTestBase.java index 97d0bf76..539637ef 100644 --- a/python/rpdk/java/templates/init/guided_aws/AbstractTestBase.java +++ b/python/rpdk/java/templates/init/guided_aws/AbstractTestBase.java @@ -5,6 +5,8 @@ import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.core.SdkClient; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.pagination.sync.SdkIterable; import software.amazon.cloudformation.proxy.AmazonWebServicesClientProxy; import software.amazon.cloudformation.proxy.Credentials; @@ -43,6 +45,18 @@ static ProxyClient MOCK_PROXY( return proxy.injectCredentialsAndInvokeIterableV2(request, requestFunction); } + @Override + public ResponseInputStream + injectCredentialsAndInvokeV2InputStream(RequestT requestT, Function> function) { + throw new UnsupportedOperationException(); + } + + @Override + public ResponseBytes + injectCredentialsAndInvokeV2Bytes(RequestT requestT, Function> function) { + throw new UnsupportedOperationException(); + } + @Override public SdkClient client() { return sdkClient; diff --git a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java index 265e5185..76fdef3a 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java +++ b/src/main/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxy.java @@ -40,6 +40,8 @@ import software.amazon.awssdk.awscore.AwsResponse; import software.amazon.awssdk.awscore.exception.AwsErrorDetails; import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.SdkClient; import software.amazon.awssdk.core.exception.NonRetryableException; import software.amazon.awssdk.core.exception.RetryableException; @@ -124,6 +126,22 @@ public ProxyClient newProxy(@Nonnull Supplier client return AmazonWebServicesClientProxy.this.injectCredentialsAndInvokeIterableV2(request, requestFunction); } + @Override + public + ResponseInputStream + injectCredentialsAndInvokeV2InputStream(RequestT request, + Function> requestFunction) { + return AmazonWebServicesClientProxy.this.injectCredentialsAndInvokeV2InputStream(request, requestFunction); + } + + @Override + public + ResponseBytes + injectCredentialsAndInvokeV2Bytes(RequestT request, + Function> requestFunction) { + return AmazonWebServicesClientProxy.this.injectCredentialsAndInvokeV2Bytes(request, requestFunction); + } + @Override public ClientT client() { return client.get(); @@ -512,6 +530,44 @@ public final long getRemainingTimeInMillis() { } } + public + ResponseInputStream + injectCredentialsAndInvokeV2InputStream(final RequestT request, + final Function> requestFunction) { + + AwsRequestOverrideConfiguration overrideConfiguration = AwsRequestOverrideConfiguration.builder() + .credentialsProvider(v2CredentialsProvider).build(); + + @SuppressWarnings("unchecked") + RequestT wrappedRequest = (RequestT) request.toBuilder().overrideConfiguration(overrideConfiguration).build(); + + try { + return requestFunction.apply(wrappedRequest); + } catch (final Throwable e) { + loggerProxy.log(String.format("Failed to execute remote function: {%s}", e.getMessage())); + throw e; + } + } + + public + ResponseBytes + injectCredentialsAndInvokeV2Bytes(final RequestT request, + final Function> requestFunction) { + + AwsRequestOverrideConfiguration overrideConfiguration = AwsRequestOverrideConfiguration.builder() + .credentialsProvider(v2CredentialsProvider).build(); + + @SuppressWarnings("unchecked") + RequestT wrappedRequest = (RequestT) request.toBuilder().overrideConfiguration(overrideConfiguration).build(); + + try { + return requestFunction.apply(wrappedRequest); + } catch (final Throwable e) { + loggerProxy.log(String.format("Failed to execute remote function: {%s}", e.getMessage())); + throw e; + } + } + public ProgressEvent defaultHandler(RequestT request, Exception e, ClientT client, ModelT model, CallbackT context) throws Exception { diff --git a/src/main/java/software/amazon/cloudformation/proxy/ProxyClient.java b/src/main/java/software/amazon/cloudformation/proxy/ProxyClient.java index d0c9df74..5985e20e 100644 --- a/src/main/java/software/amazon/cloudformation/proxy/ProxyClient.java +++ b/src/main/java/software/amazon/cloudformation/proxy/ProxyClient.java @@ -18,6 +18,8 @@ import java.util.function.Function; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.AwsResponse; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.pagination.sync.SdkIterable; /** @@ -94,6 +96,45 @@ public interface ProxyClient { IterableT injectCredentialsAndInvokeIterableV2(RequestT request, Function requestFunction); + /** + * This is a synchronous version of making API calls which implement + * ResponseInputStream in the SDKv2 + * + * @param request, the AWS service request that we need to make + * @param requestFunction, this is a Lambda closure that provide the actual API + * that needs to be invoked. + * @param the request type + * @param the response from the request + * @return the response if successful. Else it will propagate all + * {@link software.amazon.awssdk.awscore.exception.AwsServiceException} + * that is thrown or + * {@link software.amazon.awssdk.core.exception.SdkClientException} if + * there is client side problem + */ + + ResponseInputStream + injectCredentialsAndInvokeV2InputStream(RequestT request, + Function> requestFunction); + + /** + * This is a synchronous version of making API calls which implement + * ResponseBytes in the SDKv2 + * + * @param request, the AWS service request that we need to make + * @param requestFunction, this is a Lambda closure that provide the actual API + * that needs to be invoked. + * @param the request type + * @param the response from the request + * @return the response if successful. Else it will propagate all + * {@link software.amazon.awssdk.awscore.exception.AwsServiceException} + * that is thrown or + * {@link software.amazon.awssdk.core.exception.SdkClientException} if + * there is client side problem + */ + + ResponseBytes + injectCredentialsAndInvokeV2Bytes(RequestT request, Function> requestFunction); + /** * @return the actual AWS service client that we need to use to provide the * actual method we are going to call. diff --git a/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java b/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java index 701afaea..d3dc5181 100644 --- a/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java +++ b/src/test/java/software/amazon/cloudformation/proxy/AmazonWebServicesClientProxyTest.java @@ -20,6 +20,8 @@ import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,11 +44,15 @@ import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.exception.AwsErrorDetails; import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; import software.amazon.awssdk.core.exception.NonRetryableException; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.services.cloudformation.CloudFormationAsyncClient; import software.amazon.awssdk.services.cloudformation.CloudFormationClient; import software.amazon.awssdk.services.cloudformation.model.DescribeStackEventsResponse; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.cloudformation.exceptions.ResourceAlreadyExistsException; import software.amazon.cloudformation.exceptions.TerminalException; import software.amazon.cloudformation.proxy.delay.Constant; @@ -229,6 +235,150 @@ public void testInjectCredentialsAndInvokeV2Async_WithException() { } + @Test + public void testInjectCredentialsAndInvokeV2InputStream() { + + final LoggerProxy loggerProxy = mock(LoggerProxy.class); + final Credentials credentials = new Credentials("accessKeyId", "secretAccessKey", "sessionToken"); + final ResponseInputStream responseInputStream = mock(ResponseInputStream.class); + + final AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(loggerProxy, credentials, () -> 1000L); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest wrappedRequest = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder builder = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder.class); + when(builder.overrideConfiguration(any(AwsRequestOverrideConfiguration.class))).thenReturn(builder); + when(builder.build()).thenReturn(wrappedRequest); + final software.amazon.awssdk.services.s3.model.GetObjectRequest request = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + when(request.toBuilder()).thenReturn(builder); + + final S3Client client = mock(S3Client.class); + + doReturn(responseInputStream).when(client) + .getObject(any(software.amazon.awssdk.services.s3.model.GetObjectRequest.class)); + + final ResponseInputStream< + GetObjectResponse> result = proxy.injectCredentialsAndInvokeV2InputStream(request, client::getObject); + + // verify request is rebuilt for injection + verify(request).toBuilder(); + + // verify the wrapped request is sent over the initiate + verify(client).getObject(wrappedRequest); + + // ensure the return type matches + assertThat(result).isEqualTo(responseInputStream); + } + + @Test + public void testInjectCredentialsAndInvokeV2InputStream_Exception() { + + final LoggerProxy loggerProxy = mock(LoggerProxy.class); + final Credentials credentials = new Credentials("accessKeyId", "secretAccessKey", "sessionToken"); + + final AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(loggerProxy, credentials, () -> 1000L); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest wrappedRequest = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder builder = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder.class); + when(builder.overrideConfiguration(any(AwsRequestOverrideConfiguration.class))).thenReturn(builder); + when(builder.build()).thenReturn(wrappedRequest); + final software.amazon.awssdk.services.s3.model.GetObjectRequest request = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + when(request.toBuilder()).thenReturn(builder); + + final S3Client client = mock(S3Client.class); + + doThrow(new TerminalException(new RuntimeException("Sorry"))).when(client) + .getObject(any(software.amazon.awssdk.services.s3.model.GetObjectRequest.class)); + + assertThrows(RuntimeException.class, () -> proxy.injectCredentialsAndInvokeV2InputStream(request, client::getObject), + "Expected Runtime Exception."); + + // verify request is rebuilt for injection + verify(request).toBuilder(); + + // verify the wrapped request is sent over the initiate + verify(client).getObject(wrappedRequest); + } + + @Test + public void testInjectCredentialsAndInvokeV2Bytes() { + + final LoggerProxy loggerProxy = mock(LoggerProxy.class); + final Credentials credentials = new Credentials("accessKeyId", "secretAccessKey", "sessionToken"); + final ResponseBytes responseBytes = mock(ResponseBytes.class); + + final AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(loggerProxy, credentials, () -> 1000L); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest wrappedRequest = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder builder = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder.class); + when(builder.overrideConfiguration(any(AwsRequestOverrideConfiguration.class))).thenReturn(builder); + when(builder.build()).thenReturn(wrappedRequest); + final software.amazon.awssdk.services.s3.model.GetObjectRequest request = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + when(request.toBuilder()).thenReturn(builder); + + final S3Client client = mock(S3Client.class); + + doReturn(responseBytes).when(client) + .getObjectAsBytes(any(software.amazon.awssdk.services.s3.model.GetObjectRequest.class)); + + final ResponseBytes< + GetObjectResponse> result = proxy.injectCredentialsAndInvokeV2Bytes(request, client::getObjectAsBytes); + + // verify request is rebuilt for injection + verify(request).toBuilder(); + + // verify the wrapped request is sent over the initiate + verify(client).getObjectAsBytes(wrappedRequest); + + // ensure the return type matches + assertThat(result).isEqualTo(responseBytes); + } + + @Test + public void testInjectCredentialsAndInvokeV2Bytes_Exception() { + + final LoggerProxy loggerProxy = mock(LoggerProxy.class); + final Credentials credentials = new Credentials("accessKeyId", "secretAccessKey", "sessionToken"); + + final AmazonWebServicesClientProxy proxy = new AmazonWebServicesClientProxy(loggerProxy, credentials, () -> 1000L); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest wrappedRequest = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + + final software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder builder = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.Builder.class); + when(builder.overrideConfiguration(any(AwsRequestOverrideConfiguration.class))).thenReturn(builder); + when(builder.build()).thenReturn(wrappedRequest); + final software.amazon.awssdk.services.s3.model.GetObjectRequest request = mock( + software.amazon.awssdk.services.s3.model.GetObjectRequest.class); + when(request.toBuilder()).thenReturn(builder); + + final S3Client client = mock(S3Client.class); + + doThrow(new TerminalException(new RuntimeException("Sorry"))).when(client) + .getObjectAsBytes(any(software.amazon.awssdk.services.s3.model.GetObjectRequest.class)); + + assertThrows(RuntimeException.class, () -> proxy.injectCredentialsAndInvokeV2Bytes(request, client::getObjectAsBytes), + "Expected Runtime Exception."); + + // verify request is rebuilt for injection + verify(request).toBuilder(); + + // verify the wrapped request is sent over the initiate + verify(client).getObjectAsBytes(wrappedRequest); + } + private final Credentials MOCK = new Credentials("accessKeyId", "secretKey", "token"); @Test