diff --git a/.changes/2d52d36d-564b-4e31-a7fa-14d8839d4a96.json b/.changes/2d52d36d-564b-4e31-a7fa-14d8839d4a96.json new file mode 100644 index 00000000000..cec59aa1f0a --- /dev/null +++ b/.changes/2d52d36d-564b-4e31-a7fa-14d8839d4a96.json @@ -0,0 +1,5 @@ +{ + "id": "2d52d36d-564b-4e31-a7fa-14d8839d4a96", + "type": "feature", + "description": "Implement recursion detection middleware." +} \ No newline at end of file diff --git a/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/RecursionDetection.kt b/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/RecursionDetection.kt new file mode 100644 index 00000000000..f05b15f3feb --- /dev/null +++ b/aws-runtime/aws-http/common/src/aws/sdk/kotlin/runtime/http/middleware/RecursionDetection.kt @@ -0,0 +1,56 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.sdk.kotlin.runtime.http.middleware + +import aws.sdk.kotlin.runtime.InternalSdkApi +import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware +import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest +import aws.smithy.kotlin.runtime.util.EnvironmentProvider +import aws.smithy.kotlin.runtime.util.Platform +import aws.smithy.kotlin.runtime.util.text.percentEncodeTo + +internal const val ENV_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME" +internal const val ENV_TRACE_ID = "_X_AMZN_TRACE_ID" +internal const val HEADER_TRACE_ID = "X-Amzn-Trace-Id" + +/** + * HTTP middleware to add the recursion detection header where required. + */ +@InternalSdkApi +public class RecursionDetection( + private val env: EnvironmentProvider = Platform +) : ModifyRequestMiddleware { + override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest { + if (req.subject.headers.contains(HEADER_TRACE_ID)) return req + + val traceId = env.getenv(ENV_TRACE_ID) + if (env.getenv(ENV_FUNCTION_NAME) == null || traceId == null) return req + + req.subject.headers[HEADER_TRACE_ID] = traceId.percentEncode() + return req + } +} + +/** + * Percent-encode ISO control characters for the purposes of this specific header. + * + * The existing `Char::isISOControl` check cannot be used here, because that matches against characters in + * `[0x00, 0x1f] U [0x7f, 0x9f]`. The SEP for recursion detection dictates we should only encode across + * `[0x00, 0x1f]`. + */ +private fun String.percentEncode(): String { + val sb = StringBuilder(this.length) + val data = this.encodeToByteArray() + for (cbyte in data) { + val chr = cbyte.toInt().toChar() + if (chr.code in 0x00..0x1f) { + cbyte.percentEncodeTo(sb) + } else { + sb.append(chr) + } + } + return sb.toString() +} diff --git a/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt new file mode 100644 index 00000000000..f5a497ad5a3 --- /dev/null +++ b/aws-runtime/aws-http/common/test/aws/sdk/kotlin/runtime/http/middleware/RecursionDetectionTest.kt @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.sdk.kotlin.runtime.http.middleware + +import aws.sdk.kotlin.runtime.testing.TestPlatformProvider +import aws.smithy.kotlin.runtime.client.ExecutionContext +import aws.smithy.kotlin.runtime.http.Headers +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.HttpStatusCode +import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase +import aws.smithy.kotlin.runtime.http.operation.* +import aws.smithy.kotlin.runtime.http.request.HttpRequest +import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder +import aws.smithy.kotlin.runtime.http.response.HttpCall +import aws.smithy.kotlin.runtime.http.response.HttpResponse +import aws.smithy.kotlin.runtime.http.sdkHttpClient +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.util.get +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse + +@OptIn(ExperimentalCoroutinesApi::class) +class RecursionDetectionTest { + private class TraceHeaderSerializer( + private val traceHeader: String + ) : HttpSerialize { + override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder { + val builder = HttpRequestBuilder() + builder.headers[HEADER_TRACE_ID] = traceHeader + return builder + } + } + + private val mockEngine = object : HttpClientEngineBase("test") { + override suspend fun roundTrip(request: HttpRequest): HttpCall { + val resp = HttpResponse(HttpStatusCode.fromValue(200), Headers.Empty, HttpBody.Empty) + val now = Instant.now() + return HttpCall(request, resp, now, now) + } + } + + private val client = sdkHttpClient(mockEngine) + + private suspend fun test( + env: Map, + existingTraceHeader: String?, + expectedTraceHeader: String? + ) { + val op = SdkHttpOperation.build { + serializer = if (existingTraceHeader != null) TraceHeaderSerializer(existingTraceHeader) else UnitSerializer + deserializer = IdentityDeserializer + context { + service = "Test Service" + operationName = "testOperation" + } + } + + val provider = TestPlatformProvider(env) + op.install(RecursionDetection(provider)) + op.roundTrip(client, Unit) + + val request = op.context[HttpOperationContext.HttpCallList].last().request + if (expectedTraceHeader != null) { + assertEquals(expectedTraceHeader, request.headers[HEADER_TRACE_ID]) + } else { + assertFalse(request.headers.contains(HEADER_TRACE_ID)) + } + } + + @Test + fun `it noops if env unset`() = runTest { + test( + emptyMap(), + null, + null + ) + } + + @Test + fun `it sets header when both envs are present`() = runTest { + test( + mapOf( + ENV_FUNCTION_NAME to "some-function", + ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2" + ), + null, + "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2" + ) + } + + @Test + fun `it noops if trace env set but no lambda env`() = runTest { + test( + mapOf( + ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2" + ), + null, + null + ) + } + + @Test + fun `it respects existing trace header`() = runTest { + test( + mapOf( + ENV_FUNCTION_NAME to "some-function", + ENV_TRACE_ID to "EnvValue" + ), + "OriginalValue", + "OriginalValue" + ) + } + + @Test + fun `it url encodes new trace header`() = runTest { + test( + mapOf( + ENV_FUNCTION_NAME to "some-function", + ENV_TRACE_ID to "first\nsecond" + ), + null, + "first%0Asecond" + ) + } + + @Test + fun `ignores other chars that are usually percent encoded`() = runTest { + test( + mapOf( + ENV_FUNCTION_NAME to "some-function", + ENV_TRACE_ID to "test123-=;:+&[]{}\"'" + ), + null, + "test123-=;:+&[]{}\"'" + ) + } +} diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt index 53e4be8b4c1..dc76c16da42 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt @@ -8,6 +8,7 @@ import aws.sdk.kotlin.codegen.AwsKotlinDependency import aws.sdk.kotlin.codegen.AwsRuntimeTypes import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamParserGenerator import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamSerializerGenerator +import aws.sdk.kotlin.codegen.protocols.middleware.RecursionDetectionMiddleware import aws.sdk.kotlin.codegen.protocols.middleware.ResolveAwsEndpointMiddleware import aws.sdk.kotlin.codegen.protocols.middleware.UserAgentMiddleware import aws.sdk.kotlin.codegen.protocols.protocoltest.AwsHttpProtocolUnitTestErrorGenerator @@ -48,6 +49,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator() }.toMutableList() middleware.add(UserAgentMiddleware()) + middleware.add(RecursionDetectionMiddleware()) return middleware } diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/middleware/RecursionDetectionMiddleware.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/middleware/RecursionDetectionMiddleware.kt new file mode 100644 index 00000000000..da203ff67ee --- /dev/null +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/middleware/RecursionDetectionMiddleware.kt @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package aws.sdk.kotlin.codegen.protocols.middleware + +import aws.sdk.kotlin.codegen.AwsKotlinDependency +import software.amazon.smithy.kotlin.codegen.core.KotlinWriter +import software.amazon.smithy.kotlin.codegen.model.buildSymbol +import software.amazon.smithy.kotlin.codegen.model.namespace +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware +import software.amazon.smithy.model.shapes.OperationShape + +/** + * HTTP middleware to add the recursion detection header where required. + */ +class RecursionDetectionMiddleware : ProtocolMiddleware { + override val name: String = "RecursionDetection" + override val order: Byte = 30 + + private val middlewareSymbol = buildSymbol { + name = "RecursionDetection" + namespace(AwsKotlinDependency.AWS_HTTP, subpackage = "middleware") + } + + override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { + writer.write("op.install(#T())", middlewareSymbol) + } +}