diff --git a/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddleware.kt b/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddleware.kt new file mode 100644 index 00000000000..2b3cb725564 --- /dev/null +++ b/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddleware.kt @@ -0,0 +1,46 @@ +package aws.sdk.kotlin.runtime.http.middleware + +import aws.smithy.kotlin.runtime.http.middleware.Retry +import aws.smithy.kotlin.runtime.http.operation.* +import aws.smithy.kotlin.runtime.http.request.header +import aws.smithy.kotlin.runtime.io.Handler +import aws.smithy.kotlin.runtime.retries.RetryStrategy +import aws.smithy.kotlin.runtime.retries.policy.RetryPolicy +import aws.smithy.kotlin.runtime.util.InternalApi +import aws.smithy.kotlin.runtime.util.get + +/** + * The per/operation unique client side ID header name. This will match + * the [HttpOperationContext.SdkRequestId] + */ +internal const val AMZ_SDK_INVOCATION_ID_HEADER = "amz-sdk-invocation-id" + +/** + * Details about the current request such as the attempt number, maximum possible attempts, ttl, etc + */ +internal const val AMZ_SDK_REQUEST_HEADER = "amz-sdk-request" + +/** + * Retry requests with the given strategy and policy. This middleware customizes the default [Retry] implementation + * to add AWS specific retry headers + * + * @param strategy the [RetryStrategy] to retry failed requests with + * @param policy the [RetryPolicy] used to determine when to retry + */ +@InternalApi +public class AwsRetryMiddleware( + strategy: RetryStrategy, + policy: RetryPolicy +) : Retry(strategy, policy) { + + override suspend fun > handle(request: SdkHttpRequest, next: H): O { + request.subject.header(AMZ_SDK_INVOCATION_ID_HEADER, request.context[HttpOperationContext.SdkRequestId]) + return super.handle(request, next) + } + + override fun onAttempt(request: SdkHttpRequest, attempt: Int) { + // setting ttl would never be accurate, just set what we know which is attempt and maybe max attempt + val maxAttempts = strategy.options.maxAttempts?.let { "; max=$it" } ?: "" + request.subject.header(AMZ_SDK_REQUEST_HEADER, "attempt=${attempt}$maxAttempts") + } +} diff --git a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddlewareTest.kt b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddlewareTest.kt new file mode 100644 index 00000000000..2101df0e468 --- /dev/null +++ b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/AwsRetryMiddlewareTest.kt @@ -0,0 +1,70 @@ +package aws.sdk.kotlin.runtime.http.middleware + +import aws.sdk.kotlin.runtime.http.retries.AwsDefaultRetryPolicy +import aws.smithy.kotlin.runtime.client.ExecutionContext +import aws.smithy.kotlin.runtime.http.Headers +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.HttpStatusCode +import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase +import aws.smithy.kotlin.runtime.http.operation.* +import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.response.HttpCall +import aws.smithy.kotlin.runtime.http.response.HttpResponse +import aws.smithy.kotlin.runtime.http.sdkHttpClient +import aws.smithy.kotlin.runtime.retries.StandardRetryStrategy +import aws.smithy.kotlin.runtime.retries.StandardRetryStrategyOptions +import aws.smithy.kotlin.runtime.retries.delay.DelayProvider +import aws.smithy.kotlin.runtime.retries.delay.StandardRetryTokenBucket +import aws.smithy.kotlin.runtime.retries.delay.StandardRetryTokenBucketOptions +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.util.get +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +@OptIn(ExperimentalCoroutinesApi::class) +class AwsRetryMiddlewareTest { + + private val mockEngine = object : HttpClientEngineBase("test") { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { + val resp = HttpResponse(HttpStatusCode.OK, Headers.Empty, HttpBody.Empty) + return HttpCall(request, resp, Instant.now(), Instant.now()) + } + } + private val client = sdkHttpClient(mockEngine) + + @Test + fun testItSetsRetryHeaders() = runTest { + // see retry-header SEP + val op = SdkHttpOperation.build { + serializer = UnitSerializer + deserializer = UnitDeserializer + context { + // required operation context + operationName = "TestOperation" + service = "TestService" + } + } + + val delayProvider = DelayProvider { } + val strategy = StandardRetryStrategy( + StandardRetryStrategyOptions.Default, + StandardRetryTokenBucket(StandardRetryTokenBucketOptions.Default), + delayProvider + ) + val maxAttempts = strategy.options.maxAttempts + + op.install(AwsRetryMiddleware(strategy, AwsDefaultRetryPolicy)) + + op.roundTrip(client, Unit) + val calls = op.context.attributes[HttpOperationContext.HttpCallList] + val sdkRequestId = op.context[HttpOperationContext.SdkRequestId] + + assertTrue(calls.all { it.request.headers[AMZ_SDK_INVOCATION_ID_HEADER] == sdkRequestId }) + calls.forEachIndexed { idx, call -> + assertEquals("attempt=${idx + 1}; max=$maxAttempts", call.request.headers[AMZ_SDK_REQUEST_HEADER]) + } + } +} diff --git a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt index f5a497ad5a3..c883e84b77e 100644 --- a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt +++ b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt @@ -38,7 +38,7 @@ class RecursionDetectionTest { } private val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.fromValue(200), Headers.Empty, HttpBody.Empty) val now = Instant.now() return HttpCall(request, resp, now, now) diff --git a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/ResolveAwsEndpointTest.kt b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/ResolveAwsEndpointTest.kt index 40f6cfe9aa4..9a6364e97b8 100644 --- a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/ResolveAwsEndpointTest.kt +++ b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/ResolveAwsEndpointTest.kt @@ -10,6 +10,7 @@ import aws.sdk.kotlin.runtime.endpoint.AwsEndpoint import aws.sdk.kotlin.runtime.endpoint.AwsEndpointResolver import aws.sdk.kotlin.runtime.endpoint.CredentialScope import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAttributes +import aws.smithy.kotlin.runtime.client.ExecutionContext import aws.smithy.kotlin.runtime.http.* import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase import aws.smithy.kotlin.runtime.http.operation.* @@ -27,7 +28,7 @@ import kotlin.test.assertEquals class ResolveAwsEndpointTest { private val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.OK, Headers.Empty, HttpBody.Empty) return HttpCall(request, resp, Instant.now(), Instant.now()) } diff --git a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/UserAgentTest.kt b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/UserAgentTest.kt index 0189951e3ee..917f923a791 100644 --- a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/UserAgentTest.kt +++ b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/UserAgentTest.kt @@ -9,6 +9,7 @@ import aws.sdk.kotlin.runtime.http.ApiMetadata import aws.sdk.kotlin.runtime.http.loadAwsUserAgentMetadataFromEnvironment import aws.sdk.kotlin.runtime.http.operation.customUserAgentMetadata import aws.sdk.kotlin.runtime.testing.TestPlatformProvider +import aws.smithy.kotlin.runtime.client.ExecutionContext import aws.smithy.kotlin.runtime.http.Headers import aws.smithy.kotlin.runtime.http.HttpBody import aws.smithy.kotlin.runtime.http.HttpStatusCode @@ -31,7 +32,7 @@ import kotlin.test.assertTrue @OptIn(ExperimentalCoroutinesApi::class) class UserAgentTest { private val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.fromValue(200), Headers.Empty, HttpBody.Empty) val now = Instant.now() return HttpCall(request, resp, now, now) diff --git a/aws-runtime/protocols/aws-json-protocols/common/test/aws/sdk/kotlin/runtime/protocol/json/AwsJsonProtocolTest.kt b/aws-runtime/protocols/aws-json-protocols/common/test/aws/sdk/kotlin/runtime/protocol/json/AwsJsonProtocolTest.kt index bac9e385d69..6ecaec574bb 100644 --- a/aws-runtime/protocols/aws-json-protocols/common/test/aws/sdk/kotlin/runtime/protocol/json/AwsJsonProtocolTest.kt +++ b/aws-runtime/protocols/aws-json-protocols/common/test/aws/sdk/kotlin/runtime/protocol/json/AwsJsonProtocolTest.kt @@ -27,7 +27,7 @@ class AwsJsonProtocolTest { @Test fun testSetJsonProtocolHeaders() = runTest { val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.OK, Headers.Empty, HttpBody.Empty) val now = Instant.now() return HttpCall(request, resp, now, now) @@ -58,7 +58,7 @@ class AwsJsonProtocolTest { @Test fun testEmptyBody() = runTest { val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.OK, Headers.Empty, HttpBody.Empty) val now = Instant.now() return HttpCall(request, resp, now, now) @@ -86,7 +86,7 @@ class AwsJsonProtocolTest { @Test fun testDoesNotOverride() = runTest { val mockEngine = object : HttpClientEngineBase("test") { - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { val resp = HttpResponse(HttpStatusCode.OK, Headers.Empty, HttpBody.Empty) val now = Instant.now() return HttpCall(request, resp, now, now) diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsDefaultRetryIntegration.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsDefaultRetryIntegration.kt index 93b02a2ca37..b5c7d1a3e77 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsDefaultRetryIntegration.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsDefaultRetryIntegration.kt @@ -6,7 +6,6 @@ package aws.sdk.kotlin.codegen import software.amazon.smithy.kotlin.codegen.core.KotlinWriter -import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware @@ -15,8 +14,8 @@ import software.amazon.smithy.kotlin.codegen.retries.StandardRetryMiddleware import software.amazon.smithy.model.shapes.OperationShape /** - * Adds AWS-specific retry wrappers around operation invocations. This replaces - * [StandardRetryPolicy][aws.smithy.kotlin.runtime.retries.impl] with + * Replace the [StandardRetryMiddleware] with AWS specific retry middleware (AwsRetryMiddleware) + * as well as replace the [StandardRetryPolicy][aws.smithy.kotlin.runtime.retries.impl] with * [AwsDefaultRetryPolicy][aws.sdk.kotlin.runtime.http.retries]. */ class AwsDefaultRetryIntegration : KotlinIntegration { @@ -26,12 +25,14 @@ class AwsDefaultRetryIntegration : KotlinIntegration { ): List = resolved.replace(middleware) { it is StandardRetryMiddleware } private val middleware = object : ProtocolMiddleware { - override val name: String = RuntimeTypes.Http.Middlware.Retry.name + override val name: String = AwsRuntimeTypes.Http.Middleware.AwsRetryMiddleware.name override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { - writer.addImport(RuntimeTypes.Http.Middlware.Retry) - writer.addImport(AwsRuntimeTypes.Http.Retries.AwsDefaultRetryPolicy) - writer.write("op.install(#T(config.retryStrategy, AwsDefaultRetryPolicy))", RuntimeTypes.Http.Middlware.Retry) + writer.write( + "op.install(#T(config.retryStrategy, #T))", + AwsRuntimeTypes.Http.Middleware.AwsRetryMiddleware, + AwsRuntimeTypes.Http.Retries.AwsDefaultRetryPolicy + ) } } } diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt index 08c05cb54f2..f3f491d8dbf 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt @@ -55,6 +55,9 @@ object AwsRuntimeTypes { object Retries { val AwsDefaultRetryPolicy = runtimeSymbol("AwsDefaultRetryPolicy", AwsKotlinDependency.AWS_HTTP, "retries") } + object Middleware { + val AwsRetryMiddleware = runtimeSymbol("AwsRetryMiddleware", AwsKotlinDependency.AWS_HTTP, "middleware") + } } object JsonProtocols { diff --git a/services/sts/common/test/aws/sdk/kotlin/services/sts/StsAuthTests.kt b/services/sts/common/test/aws/sdk/kotlin/services/sts/StsAuthTests.kt index 74839fc6fc7..a65ce2553ec 100644 --- a/services/sts/common/test/aws/sdk/kotlin/services/sts/StsAuthTests.kt +++ b/services/sts/common/test/aws/sdk/kotlin/services/sts/StsAuthTests.kt @@ -7,6 +7,7 @@ package aws.sdk.kotlin.services.sts import aws.sdk.kotlin.runtime.auth.credentials.StaticCredentialsProvider import aws.smithy.kotlin.runtime.auth.awscredentials.Credentials +import aws.smithy.kotlin.runtime.client.ExecutionContext import aws.smithy.kotlin.runtime.http.Headers import aws.smithy.kotlin.runtime.http.HttpBody import aws.smithy.kotlin.runtime.http.HttpStatusCode @@ -32,7 +33,7 @@ class StsAuthTests { private val mockEngine = object : HttpClientEngineBase("mock-engine") { var capturedRequest: HttpRequest? = null - override suspend fun roundTrip(request: HttpRequest): HttpCall { + override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall { capturedRequest = request val callContext = callContext() val now = Instant.now()