Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/2d52d36d-564b-4e31-a7fa-14d8839d4a96.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "2d52d36d-564b-4e31-a7fa-14d8839d4a96",
"type": "feature",
"description": "Implement recursion detection middleware."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.sdk.kotlin.runtime.http.middleware

import aws.sdk.kotlin.runtime.InternalSdkApi
import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware
import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest
import aws.smithy.kotlin.runtime.util.EnvironmentProvider
import aws.smithy.kotlin.runtime.util.Platform
import aws.smithy.kotlin.runtime.util.text.percentEncodeTo

internal const val ENV_FUNCTION_NAME = "AWS_LAMBDA_FUNCTION_NAME"
internal const val ENV_TRACE_ID = "_X_AMZN_TRACE_ID"
internal const val HEADER_TRACE_ID = "X-Amzn-Trace-Id"

/**
* HTTP middleware to add the recursion detection header where required.
*/
@InternalSdkApi
public class RecursionDetection(
private val env: EnvironmentProvider = Platform
) : ModifyRequestMiddleware {
override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest {
if (req.subject.headers.contains(HEADER_TRACE_ID)) return req

val traceId = env.getenv(ENV_TRACE_ID)
if (env.getenv(ENV_FUNCTION_NAME) == null || traceId == null) return req

req.subject.headers[HEADER_TRACE_ID] = traceId.percentEncode()
return req
}
}

/**
* Percent-encode ISO control characters for the purposes of this specific header.
*
* The existing `Char::isISOControl` check cannot be used here, because that matches against characters in
* `[0x00, 0x1f] U [0x7f, 0x9f]`. The SEP for recursion detection dictates we should only encode across
* `[0x00, 0x1f]`.
*/
private fun String.percentEncode(): String {
val sb = StringBuilder(this.length)
val data = this.encodeToByteArray()
for (cbyte in data) {
val chr = cbyte.toInt().toChar()
if (chr.code in 0x00..0x1f) {
cbyte.percentEncodeTo(sb)
} else {
sb.append(chr)
}
}
return sb.toString()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.sdk.kotlin.runtime.http.middleware

import aws.sdk.kotlin.runtime.testing.TestPlatformProvider
import aws.smithy.kotlin.runtime.client.ExecutionContext
import aws.smithy.kotlin.runtime.http.Headers
import aws.smithy.kotlin.runtime.http.HttpBody
import aws.smithy.kotlin.runtime.http.HttpStatusCode
import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase
import aws.smithy.kotlin.runtime.http.operation.*
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder
import aws.smithy.kotlin.runtime.http.response.HttpCall
import aws.smithy.kotlin.runtime.http.response.HttpResponse
import aws.smithy.kotlin.runtime.http.sdkHttpClient
import aws.smithy.kotlin.runtime.time.Instant
import aws.smithy.kotlin.runtime.util.get
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse

@OptIn(ExperimentalCoroutinesApi::class)
class RecursionDetectionTest {
private class TraceHeaderSerializer(
private val traceHeader: String
) : HttpSerialize<Unit> {
override suspend fun serialize(context: ExecutionContext, input: Unit): HttpRequestBuilder {
val builder = HttpRequestBuilder()
builder.headers[HEADER_TRACE_ID] = traceHeader
return builder
}
}

private val mockEngine = object : HttpClientEngineBase("test") {
override suspend fun roundTrip(request: HttpRequest): HttpCall {
val resp = HttpResponse(HttpStatusCode.fromValue(200), Headers.Empty, HttpBody.Empty)
val now = Instant.now()
return HttpCall(request, resp, now, now)
}
}

private val client = sdkHttpClient(mockEngine)

private suspend fun test(
env: Map<String, String>,
existingTraceHeader: String?,
expectedTraceHeader: String?
) {
val op = SdkHttpOperation.build<Unit, HttpResponse> {
serializer = if (existingTraceHeader != null) TraceHeaderSerializer(existingTraceHeader) else UnitSerializer
deserializer = IdentityDeserializer
context {
service = "Test Service"
operationName = "testOperation"
}
}

val provider = TestPlatformProvider(env)
op.install(RecursionDetection(provider))
op.roundTrip(client, Unit)

val request = op.context[HttpOperationContext.HttpCallList].last().request
if (expectedTraceHeader != null) {
assertEquals(expectedTraceHeader, request.headers[HEADER_TRACE_ID])
} else {
assertFalse(request.headers.contains(HEADER_TRACE_ID))
}
}

@Test
fun `it noops if env unset`() = runTest {
test(
emptyMap(),
null,
null
)
}

@Test
fun `it sets header when both envs are present`() = runTest {
test(
mapOf(
ENV_FUNCTION_NAME to "some-function",
ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
),
null,
"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
)
}

@Test
fun `it noops if trace env set but no lambda env`() = runTest {
test(
mapOf(
ENV_TRACE_ID to "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
),
null,
null
)
}

@Test
fun `it respects existing trace header`() = runTest {
test(
mapOf(
ENV_FUNCTION_NAME to "some-function",
ENV_TRACE_ID to "EnvValue"
),
"OriginalValue",
"OriginalValue"
)
}

@Test
fun `it url encodes new trace header`() = runTest {
test(
mapOf(
ENV_FUNCTION_NAME to "some-function",
ENV_TRACE_ID to "first\nsecond"
),
null,
"first%0Asecond"
)
}

@Test
fun `ignores other chars that are usually percent encoded`() = runTest {
test(
mapOf(
ENV_FUNCTION_NAME to "some-function",
ENV_TRACE_ID to "test123-=;:+&[]{}\"'"
),
null,
"test123-=;:+&[]{}\"'"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import aws.sdk.kotlin.codegen.AwsKotlinDependency
import aws.sdk.kotlin.codegen.AwsRuntimeTypes
import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamParserGenerator
import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamSerializerGenerator
import aws.sdk.kotlin.codegen.protocols.middleware.RecursionDetectionMiddleware
import aws.sdk.kotlin.codegen.protocols.middleware.ResolveAwsEndpointMiddleware
import aws.sdk.kotlin.codegen.protocols.middleware.UserAgentMiddleware
import aws.sdk.kotlin.codegen.protocols.protocoltest.AwsHttpProtocolUnitTestErrorGenerator
Expand Down Expand Up @@ -48,6 +49,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
}.toMutableList()

middleware.add(UserAgentMiddleware())
middleware.add(RecursionDetectionMiddleware())
return middleware
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package aws.sdk.kotlin.codegen.protocols.middleware

import aws.sdk.kotlin.codegen.AwsKotlinDependency
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.namespace
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
import software.amazon.smithy.model.shapes.OperationShape

/**
* HTTP middleware to add the recursion detection header where required.
*/
class RecursionDetectionMiddleware : ProtocolMiddleware {
Comment on lines +16 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: This middleware doesn't override isEnabledFor. Just confirming...we want this to run for every operation in every AWS service?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats a good question, it could be limited to only services that have an integration that causes recursion if we know that set. Otherwise it would have to be every service I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per post-standup discussion, we'll leave it on every request since

  1. defining the set of services that could be part of a recursive loop is non-trivial, and is subject to constant change
  2. the risk of causing a runaway workload for a user (large bill, eroded trust) is not worth selectively including the header

override val name: String = "RecursionDetection"
override val order: Byte = 30

private val middlewareSymbol = buildSymbol {
name = "RecursionDetection"
namespace(AwsKotlinDependency.AWS_HTTP, subpackage = "middleware")
}

override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.write("op.install(#T())", middlewareSymbol)
}
}