diff --git a/aws-runtime/aws-core/build.gradle.kts b/aws-runtime/aws-core/build.gradle.kts index b37439d67a3..5922eee90ae 100644 --- a/aws-runtime/aws-core/build.gradle.kts +++ b/aws-runtime/aws-core/build.gradle.kts @@ -15,6 +15,9 @@ kotlin { commonMain { dependencies { api("aws.smithy.kotlin:runtime-core:$smithyKotlinVersion") + + // FIXME - should we just move these into core and get rid of aws-types at this point? + api(project(":aws-runtime:aws-types")) implementation("aws.smithy.kotlin:logging:$smithyKotlinVersion") } } diff --git a/aws-runtime/aws-core/common/src/aws/sdk/kotlin/runtime/execution/AuthAttributes.kt b/aws-runtime/aws-core/common/src/aws/sdk/kotlin/runtime/execution/AuthAttributes.kt index 358705ac8ae..66e282e3ab2 100644 --- a/aws-runtime/aws-core/common/src/aws/sdk/kotlin/runtime/execution/AuthAttributes.kt +++ b/aws-runtime/aws-core/common/src/aws/sdk/kotlin/runtime/execution/AuthAttributes.kt @@ -5,8 +5,11 @@ package aws.sdk.kotlin.runtime.execution +import aws.sdk.kotlin.runtime.auth.credentials.CredentialsProvider import aws.smithy.kotlin.runtime.client.ClientOption import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.util.AttributeKey +import aws.smithy.kotlin.runtime.util.InternalApi /** * Operation (execution) options related to authorization @@ -34,4 +37,17 @@ public object AuthAttributes { * NOTE: This is not a common option. */ public val SigningDate: ClientOption = ClientOption("SigningDate") + + /** + * The [CredentialsProvider] to complete the signing process with. Defaults to the provider configured + * on the service client. + * NOTE: This is not a common option. + */ + public val CredentialsProvider: ClientOption = ClientOption("CredentialsProvider") + + /** + * The signature of the HTTP request. This will only exist after the request has been signed! + */ + @InternalApi + public val RequestSignature: AttributeKey = AttributeKey("AWS_HTTP_SIGNATURE") } diff --git a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigV4SigningMiddleware.kt b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigV4SigningMiddleware.kt index 7cd236b8beb..fb92e0cf10a 100644 --- a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigV4SigningMiddleware.kt +++ b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigV4SigningMiddleware.kt @@ -144,7 +144,12 @@ public class AwsSigV4SigningMiddleware(private val config: Config) : ModifyReque } } - val signedRequest = AwsSigner.signRequest(signableRequest, opSigningConfig.toCrt()) + val signingResult = AwsSigner.sign(signableRequest, opSigningConfig.toCrt()) + val signedRequest = checkNotNull(signingResult.signedRequest) { "signing result must return a non-null HTTP request" } + + // Add the signature to the request context + req.context[AuthAttributes.RequestSignature] = signingResult.signature + req.subject.update(signedRequest) req.subject.body.resetStream() diff --git a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigning.kt b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigning.kt index 4ef5f2eec3d..cd508a92a63 100644 --- a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigning.kt +++ b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigning.kt @@ -6,6 +6,7 @@ package aws.sdk.kotlin.runtime.auth.signing import aws.sdk.kotlin.crt.auth.signing.AwsSigner +import aws.sdk.kotlin.runtime.InternalSdkApi import aws.sdk.kotlin.runtime.crt.toSignableCrtRequest import aws.sdk.kotlin.runtime.crt.update import aws.smithy.kotlin.runtime.http.request.HttpRequest @@ -61,3 +62,18 @@ public suspend fun sign(request: HttpRequest, config: AwsSigningConfig): Signing val output = builder.build() return SigningResult(output, crtResult.signature) } + +/** + * Sign a body [chunk] using the given signing [config] + * + * @param chunk the body chunk to sign + * @param prevSignature the signature of the previous component of the request (either the initial request signature + * itself for the first chunk or the previous chunk otherwise) + * @param config the signing configuration to use + * @return the signing result + */ +@InternalSdkApi +public suspend fun sign(chunk: ByteArray, prevSignature: ByteArray, config: AwsSigningConfig): SigningResult { + val crtResult = AwsSigner.signChunk(chunk, prevSignature, config.toCrt()) + return SigningResult(Unit, crtResult.signature) +} diff --git a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigningConfig.kt b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigningConfig.kt index 4b481346169..2acf1402246 100644 --- a/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigningConfig.kt +++ b/aws-runtime/aws-signing/common/src/aws/sdk/kotlin/runtime/auth/signing/AwsSigningConfig.kt @@ -5,6 +5,7 @@ package aws.sdk.kotlin.runtime.auth.signing +import aws.sdk.kotlin.runtime.InternalSdkApi import aws.sdk.kotlin.runtime.auth.credentials.Credentials import aws.sdk.kotlin.runtime.auth.credentials.CredentialsProvider import aws.smithy.kotlin.runtime.time.Instant @@ -21,7 +22,7 @@ public typealias ShouldSignHeaderFn = (String) -> Boolean */ public class AwsSigningConfig private constructor(builder: Builder) { public companion object { - public operator fun invoke(block: Builder.() -> Unit): AwsSigningConfig = Builder().apply(block).build() + public inline operator fun invoke(block: Builder.() -> Unit): AwsSigningConfig = Builder().apply(block).build() } /** * The region to sign against @@ -119,6 +120,27 @@ public class AwsSigningConfig private constructor(builder: Builder) { */ public val expiresAfter: Duration? = builder.expiresAfter + @InternalSdkApi + public fun toBuilder(): Builder { + val config = this + return Builder().apply { + region = config.region + service = config.service + date = config.date + algorithm = config.algorithm + shouldSignHeader = config.shouldSignHeader + signatureType = config.signatureType + useDoubleUriEncode = config.useDoubleUriEncode + normalizeUriPath = config.normalizeUriPath + omitSessionToken = config.omitSessionToken + signedBodyValue = config.signedBodyValue + signedBodyHeader = config.signedBodyHeaderType + credentials = config.credentials + credentialsProvider = config.credentialsProvider + expiresAfter = config.expiresAfter + } + } + public class Builder { public var region: String? = null public var service: String? = null @@ -135,7 +157,8 @@ public class AwsSigningConfig private constructor(builder: Builder) { public var credentialsProvider: CredentialsProvider? = null public var expiresAfter: Duration? = null - internal fun build(): AwsSigningConfig = AwsSigningConfig(this) + @InternalSdkApi + public fun build(): AwsSigningConfig = AwsSigningConfig(this) } } diff --git a/aws-runtime/aws-signing/common/test/aws/sdk/kotlin/runtime/auth/signing/AwsSigningTest.kt b/aws-runtime/aws-signing/common/test/aws/sdk/kotlin/runtime/auth/signing/AwsSigningTest.kt index 22e19f535a6..b5b11e3c091 100644 --- a/aws-runtime/aws-signing/common/test/aws/sdk/kotlin/runtime/auth/signing/AwsSigningTest.kt +++ b/aws-runtime/aws-signing/common/test/aws/sdk/kotlin/runtime/auth/signing/AwsSigningTest.kt @@ -5,9 +5,14 @@ package aws.sdk.kotlin.runtime.auth.signing +import aws.sdk.kotlin.runtime.auth.credentials.Credentials import aws.smithy.kotlin.runtime.http.HttpMethod +import aws.smithy.kotlin.runtime.http.Url import aws.smithy.kotlin.runtime.http.content.ByteArrayContent +import aws.smithy.kotlin.runtime.http.request.HttpRequest import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder +import aws.smithy.kotlin.runtime.http.request.headers +import aws.smithy.kotlin.runtime.http.request.url import aws.smithy.kotlin.runtime.time.Instant import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.runTest @@ -83,4 +88,102 @@ class AwsSigningTest { val authHeader = result.output.headers["Authorization"]!! assertTrue(authHeader.contains(expectedPrefix), "Sigv4A auth header: $authHeader") } + + // based on: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html#example-signature-calculations-streaming + private val CHUNKED_ACCESS_KEY_ID = "AKIAIOSFODNN7EXAMPLE" + private val CHUNKED_SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + private val CHUNKED_TEST_CREDENTIALS = Credentials(CHUNKED_ACCESS_KEY_ID, CHUNKED_SECRET_ACCESS_KEY) + private val CHUNKED_TEST_REGION = "us-east-1" + private val CHUNKED_TEST_SERVICE = "s3" + private val CHUNKED_TEST_SIGNING_TIME = "2013-05-24T00:00:00Z" + private val CHUNK1_SIZE = 65536 + private val CHUNK2_SIZE = 1024 + + private val EXPECTED_CHUNK_REQUEST_AUTHORIZATION_HEADER = + "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request, " + + "SignedHeaders=content-encoding;content-length;host;x-amz-content-sha256;x-amz-date;x-amz-decoded-content-length;x-" + + "amz-storage-class, Signature=4f232c4386841ef735655705268965c44a0e4690baa4adea153f7db9fa80a0a9" + + private val EXPECTED_REQUEST_SIGNATURE = "4f232c4386841ef735655705268965c44a0e4690baa4adea153f7db9fa80a0a9" + private val EXPECTED_FIRST_CHUNK_SIGNATURE = "ad80c730a21e5b8d04586a2213dd63b9a0e99e0e2307b0ade35a65485a288648" + private val EXPECTED_SECOND_CHUNK_SIGNATURE = "0055627c9e194cb4542bae2aa5492e3c1575bbb81b612b7d234b86a503ef5497" + private val EXPECTED_FINAL_CHUNK_SIGNATURE = "b6c6ea8a5354eaf15b3cb7646744f4275b71ea724fed81ceb9323e279d449df9" + private val EXPECTED_TRAILING_HEADERS_SIGNATURE = "df5735bd9f3295cd9386572292562fefc93ba94e80a0a1ddcbd652c4e0a75e6c" + + private fun createChunkedRequestSigningConfig(): AwsSigningConfig = AwsSigningConfig { + algorithm = AwsSigningAlgorithm.SIGV4 + signatureType = AwsSignatureType.HTTP_REQUEST_VIA_HEADERS + region = CHUNKED_TEST_REGION + service = CHUNKED_TEST_SERVICE + date = Instant.fromIso8601(CHUNKED_TEST_SIGNING_TIME) + useDoubleUriEncode = false + normalizeUriPath = true + signedBodyHeader = AwsSignedBodyHeaderType.X_AMZ_CONTENT_SHA256 + signedBodyValue = AwsSignedBodyValue.STREAMING_AWS4_HMAC_SHA256_PAYLOAD + credentials = CHUNKED_TEST_CREDENTIALS + } + + private fun createChunkedSigningConfig(): AwsSigningConfig = AwsSigningConfig { + algorithm = AwsSigningAlgorithm.SIGV4 + signatureType = AwsSignatureType.HTTP_REQUEST_CHUNK + region = CHUNKED_TEST_REGION + service = CHUNKED_TEST_SERVICE + date = Instant.fromIso8601(CHUNKED_TEST_SIGNING_TIME) + useDoubleUriEncode = false + normalizeUriPath = true + signedBodyHeader = AwsSignedBodyHeaderType.NONE + credentials = CHUNKED_TEST_CREDENTIALS + } + + private fun createChunkedTestRequest() = HttpRequest { + method = HttpMethod.PUT + url(Url.parse("https://s3.amazonaws.com/examplebucket/chunkObject.txt")) + headers { + set("Host", url.host) + set("x-amz-storage-class", "REDUCED_REDUNDANCY") + set("Content-Encoding", "aws-chunked") + set("x-amz-decoded-content-length", "66560") + set("Content-Length", "66824") + } + } + + private fun chunk1(): ByteArray { + val chunk = ByteArray(CHUNK1_SIZE) + for (i in chunk.indices) { + chunk[i] = 'a'.code.toByte() + } + return chunk + } + + private fun chunk2(): ByteArray { + val chunk = ByteArray(CHUNK2_SIZE) + for (i in chunk.indices) { + chunk[i] = 'a'.code.toByte() + } + return chunk + } + + @Test + fun testSignChunks() = runTest { + val request = createChunkedTestRequest() + val chunkedRequestConfig = createChunkedRequestSigningConfig() + val requestResult = sign(request, chunkedRequestConfig) + assertEquals(EXPECTED_CHUNK_REQUEST_AUTHORIZATION_HEADER, requestResult.output.headers["Authorization"]) + assertEquals(EXPECTED_REQUEST_SIGNATURE, requestResult.signature.decodeToString()) + + var prevSignature = requestResult.signature + + val chunkedSigningConfig = createChunkedSigningConfig() + val chunk1Result = sign(chunk1(), prevSignature, chunkedSigningConfig) + assertEquals(EXPECTED_FIRST_CHUNK_SIGNATURE, chunk1Result.signature.decodeToString()) + prevSignature = chunk1Result.signature + + val chunk2Result = sign(chunk2(), prevSignature, chunkedSigningConfig) + assertEquals(EXPECTED_SECOND_CHUNK_SIGNATURE, chunk2Result.signature.decodeToString()) + prevSignature = chunk2Result.signature + + // TODO - do we want 0 byte data like this or just allow null? + val finalChunkResult = sign(ByteArray(0), prevSignature, chunkedSigningConfig) + assertEquals(EXPECTED_FINAL_CHUNK_SIGNATURE, finalChunkResult.signature.decodeToString()) + } } diff --git a/aws-runtime/protocols/aws-event-stream/build.gradle.kts b/aws-runtime/protocols/aws-event-stream/build.gradle.kts index c0ca4995a42..b1c605fefdc 100644 --- a/aws-runtime/protocols/aws-event-stream/build.gradle.kts +++ b/aws-runtime/protocols/aws-event-stream/build.gradle.kts @@ -7,17 +7,27 @@ description = "Support for the vnd.amazon.event-stream content type" extra["displayName"] = "AWS :: SDK :: Kotlin :: Protocols :: Event Stream" extra["moduleName"] = "aws.sdk.kotlin.runtime.protocol.eventstream" +val smithyKotlinVersion: String by project +val coroutinesVersion: String by project kotlin { sourceSets { commonMain { dependencies { api(project(":aws-runtime:aws-core")) + // exposes Buffer/MutableBuffer and SdkByteReadChannel + api("aws.smithy.kotlin:io:$smithyKotlinVersion") + // exposes Flow + api("org.jetbrains.kotlinx:kotlinx-coroutines-core:$coroutinesVersion") + + // exposes AwsSigningConfig + api(project(":aws-runtime:aws-signing")) } } commonTest { dependencies { implementation(project(":aws-runtime:testing")) + api("org.jetbrains.kotlinx:kotlinx-coroutines-test:$coroutinesVersion") } } diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigning.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigning.kt new file mode 100644 index 00000000000..fc0053486fe --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigning.kt @@ -0,0 +1,99 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.sdk.kotlin.runtime.auth.signing.* +import aws.sdk.kotlin.runtime.execution.AuthAttributes +import aws.smithy.kotlin.runtime.client.ExecutionContext +import aws.smithy.kotlin.runtime.io.SdkByteBuffer +import aws.smithy.kotlin.runtime.io.bytes +import aws.smithy.kotlin.runtime.time.Clock +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.util.InternalApi +import aws.smithy.kotlin.runtime.util.get +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +/** + * Creates a flow that signs each event stream message with the given signing config. + * + * Each message's signature incorporates the signature of the previous message. + * The very first message incorporates the signature of the initial-request for + * both HTTP2 and WebSockets. The initial signature comes from the execution context. + */ +@InternalApi +public fun Flow.sign( + context: ExecutionContext, + config: AwsSigningConfig, +): Flow = flow { + val messages = this@sign + + // NOTE: We need the signature of the initial HTTP request to seed the event stream signatures + // This is a bit of a chicken and egg problem since the event stream is constructed before the request + // is signed. The body of the stream shouldn't start being consumed though until after the entire request + // is built. Thus, by the time we get here the signature will exist in the context. + var prevSignature = context.getOrNull(AuthAttributes.RequestSignature) ?: error("expected initial HTTP signature to be set before message signing commences") + + // signature date is updated per event message + val configBuilder = config.toBuilder() + + messages.collect { message -> + // FIXME - can we get an estimate here on size? + val buffer = SdkByteBuffer(0U) + message.encode(buffer) + + // the entire message is wrapped as the payload of the signed message + val result = signPayload(configBuilder, prevSignature, buffer.bytes()) + prevSignature = result.signature + emit(result.output) + } +} + +internal suspend fun signPayload( + configBuilder: AwsSigningConfig.Builder, + prevSignature: ByteArray, + messagePayload: ByteArray, + clock: Clock = Clock.System +): SigningResult { + val dt = clock.now().truncateSubsecs() + val config = configBuilder.apply { date = dt }.build() + + val result = sign(messagePayload, prevSignature, config) + val signature = result.signature + + val signedMessage = buildMessage { + addHeader(":date", HeaderValue.Timestamp(dt)) + addHeader(":chunk-signature", HeaderValue.ByteArray(signature)) + payload = messagePayload + } + + return SigningResult(signedMessage, signature) +} + +/** + * Truncate the sub-seconds from the current time + */ +private fun Instant.truncateSubsecs(): Instant = Instant.fromEpochSeconds(epochSeconds, 0) + +/** + * Create a new signing config for an event stream using the current context to set the operation/service specific + * configuration (e.g. region, signing service, credentials, etc) + */ +@InternalApi +public fun ExecutionContext.newEventStreamSigningConfig(): AwsSigningConfig = AwsSigningConfig { + algorithm = AwsSigningAlgorithm.SIGV4 + signatureType = AwsSignatureType.HTTP_REQUEST_CHUNK + region = this@newEventStreamSigningConfig[AuthAttributes.SigningRegion] + service = this@newEventStreamSigningConfig[AuthAttributes.SigningService] + credentialsProvider = this@newEventStreamSigningConfig[AuthAttributes.CredentialsProvider] + useDoubleUriEncode = false + normalizeUriPath = true + signedBodyHeader = AwsSignedBodyHeaderType.NONE + + // FIXME - needs to be set on the operation for initial request + // signedBodyHeader = AwsSignedBodyHeaderType.X_AMZ_CONTENT_SHA256 + // signedBodyValue = AwsSignedBodyValue.STREAMING_AWS4_HMAC_SHA256_PAYLOAD +} diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoder.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoder.kt index 20bca8fa1d5..e131b5ae4a3 100644 --- a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoder.kt +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoder.kt @@ -5,48 +5,50 @@ package aws.sdk.kotlin.runtime.protocol.eventstream +import aws.sdk.kotlin.runtime.ClientException import aws.sdk.kotlin.runtime.InternalSdkApi -import aws.smithy.kotlin.runtime.io.Buffer -import aws.smithy.kotlin.runtime.io.SdkByteBuffer -import aws.smithy.kotlin.runtime.io.readFully +import aws.smithy.kotlin.runtime.io.* +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +/** + * Exception thrown when deserializing raw event stream messages off the wire fails for some reason + */ +public class EventStreamFramingException(message: String, cause: Throwable? = null) : ClientException(message, cause) + +/** + * Convert the raw bytes coming off [chan] to a stream of messages + */ @InternalSdkApi -public class FrameDecoder { - private var prelude: Prelude? = null - - /** - * Reset the decoder discarding any intermediate state - */ - public fun reset() { prelude = null } - - private fun isFrameAvailable(buffer: Buffer): Boolean { - val totalLen = prelude?.totalLen ?: return false - val remaining = totalLen - PRELUDE_BYTE_LEN_WITH_CRC - return buffer.readRemaining >= remaining.toULong() - } +public suspend fun decodeFrames(chan: SdkByteReadChannel): Flow = flow { + while (!chan.isClosedForRead) { + // get the prelude to figure out how much is left to read of the message + val preludeBytes = ByteArray(PRELUDE_BYTE_LEN_WITH_CRC) - /** - * Attempt to decode a [Message] from the buffer. This function expects to be called over and over again - * with more data in the buffer each time its called. When there is not enough data to decode this function - * returns null. - * The decoder will consume the prelude when enough data is available. When it is invoked with enough - * data it will consume the remaining message bytes. - */ - public fun decodeFrame(buffer: Buffer): Message? { - if (prelude == null && buffer.readRemaining >= PRELUDE_BYTE_LEN_WITH_CRC.toULong()) { - prelude = Prelude.decode(buffer) + try { + chan.readFully(preludeBytes) + } catch (ex: Exception) { + throw EventStreamFramingException("failed to read message prelude from channel", ex) } - return when (isFrameAvailable(buffer)) { - true -> { - val currPrelude = checkNotNull(prelude) - val messageBuf = SdkByteBuffer(currPrelude.totalLen.toULong()) - currPrelude.encode(messageBuf) - buffer.readFully(messageBuf) - reset() - Message.decode(messageBuf) - } - else -> null + val preludeBuf = SdkByteBuffer.of(preludeBytes).apply { advance(preludeBytes.size.toULong()) } + val prelude = Prelude.decode(preludeBuf) + + // get a buffer with one complete message in it, prelude has already been read though, leave room for it + val messageBytes = ByteArray(prelude.totalLen) + + try { + chan.readFully(messageBytes, offset = PRELUDE_BYTE_LEN_WITH_CRC) + } catch (ex: Exception) { + throw EventStreamFramingException("failed to read message from channel", ex) } + + val messageBuf = SdkByteBuffer.of(messageBytes) + messageBuf.writeFully(preludeBytes) + val remaining = prelude.totalLen - PRELUDE_BYTE_LEN_WITH_CRC + messageBuf.advance(remaining.toULong()) + + val message = Message.decode(messageBuf) + emit(message) } } diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoder.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoder.kt new file mode 100644 index 00000000000..ff8e7ac6bc7 --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoder.kt @@ -0,0 +1,55 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.sdk.kotlin.runtime.InternalSdkApi +import aws.smithy.kotlin.runtime.http.HttpBody +import aws.smithy.kotlin.runtime.http.toHttpBody +import aws.smithy.kotlin.runtime.io.SdkByteBuffer +import aws.smithy.kotlin.runtime.io.SdkByteChannel +import aws.smithy.kotlin.runtime.io.bytes +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.launch +import kotlin.coroutines.coroutineContext + +/** + * Transform the stream of messages into a stream of raw bytes. Each + * element of the resulting flow is the encoded version of the corresponding message + */ +@InternalSdkApi +public fun Flow.encode(): Flow = map { + // TODO - can we figure out the encoded size and directly get a byte array + val buffer = SdkByteBuffer(1024U) + it.encode(buffer) + buffer.bytes() +} + +/** + * Transform a stream of encoded messages into an [HttpBody]. + */ +@InternalSdkApi +public suspend fun Flow.asEventStreamHttpBody(): HttpBody { + val encodedMessages = this + val ch = SdkByteChannel(true) + + // FIXME - we should probably tie this to our own scope (off ExecutionContext) but for now + // tie it to whatever arbitrary scope we are in + val scope = CoroutineScope(coroutineContext) + + val job = scope.launch { + encodedMessages.collect { + ch.writeFully(it) + } + } + + job.invokeOnCompletion { cause -> + ch.close(cause) + } + + return ch.toHttpBody() +} diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValue.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValue.kt index 5c10c001d68..95f1880834c 100644 --- a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValue.kt +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValue.kt @@ -151,3 +151,13 @@ public sealed class HeaderValue { } private fun MutableBuffer.writeHeader(headerType: HeaderType) = writeByte(headerType.value) + +public fun HeaderValue.expectBool(): Boolean = checkNotNull((this as? HeaderValue.Bool)?.value) { "expected HeaderValue.Bool, found: $this" } +public fun HeaderValue.expectByte(): Byte = checkNotNull((this as? HeaderValue.Byte)?.value?.toByte()) { "expected HeaderValue.Byte, found: $this" } +public fun HeaderValue.expectInt16(): Short = checkNotNull((this as? HeaderValue.Int16)?.value) { "expected HeaderValue.Int16, found: $this" } +public fun HeaderValue.expectInt32(): Int = checkNotNull((this as? HeaderValue.Int32)?.value) { "expected HeaderValue.Int32, found: $this" } +public fun HeaderValue.expectInt64(): Long = checkNotNull((this as? HeaderValue.Int64)?.value) { "expected HeaderValue.Int64, found: $this" } +public fun HeaderValue.expectString(): String = checkNotNull((this as? HeaderValue.String)?.value) { "expected HeaderValue.String, found: $this" } +public fun HeaderValue.expectByteArray(): ByteArray = checkNotNull((this as? HeaderValue.ByteArray)?.value) { "expected HeaderValue.ByteArray, found: $this" } +public fun HeaderValue.expectTimestamp(): Instant = checkNotNull((this as? HeaderValue.Timestamp)?.value) { "expected HeaderValue.Bool, found: $this" } +public fun HeaderValue.expectUuid(): Uuid = checkNotNull((this as? HeaderValue.Uuid)?.value) { "expected HeaderValue.Bool, found: $this" } diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeaders.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeaders.kt new file mode 100644 index 00000000000..647d3fefe8a --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeaders.kt @@ -0,0 +1,106 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.sdk.kotlin.runtime.InternalSdkApi + +/** + * Parse the protocol level headers into a concrete [MessageType] + */ +@InternalSdkApi +public fun Message.type(): MessageType { + val headersByName = headers.associateBy { it.name } + val messageType: String = checkNotNull(headersByName[":message-type"]) { "`:message-type` header is required to deserialize an event stream message" }.value.expectString() + val eventType = headersByName[":event-type"]?.value?.expectString() + val exceptionType = headersByName[":exception-type"]?.value?.expectString() + val contentType = headersByName[":content-type"]?.value?.expectString() + + return when (messageType) { + "event" -> MessageType.Event( + checkNotNull(eventType) { "Invalid `event` message: `:event-type` header is missing" }, + contentType + ) + "exception" -> MessageType.Exception( + checkNotNull(exceptionType) { "Invalid `exception` message: `:exception-type` header is missing" }, + contentType + ) + "error" -> { + val errorCode = headersByName[":error-code"]?.value?.expectString() ?: error("Invalid `error` message: `:error-code` header is missing") + val errorMessage = headersByName[":error-message"]?.value?.expectString() + MessageType.Error(errorCode, errorMessage) + } + + else -> MessageType.SdkUnknown(messageType) + } +} + +/** + * Common framework message information parsed from headers + */ +@InternalSdkApi +public sealed class MessageType { + /** + * Corresponds to the `event` message type. All events include the headers: + * + * * `:message-type`: Always set to `event` + * * `:event-type`: (Required) Identifies the event shape from the event stream union. This is the member name from the union. + * * `:content-type`: (Optional) The content type for the payload + * + * ### Example message + * + * ``` + * :message-type: event + * :event-type: MyStruct + * :content-type: application/json + * + * {...} + * ``` + * @param shapeType the event type as identified by the `:event-type` header. + * @param contentType the content type of the payload (if present) + */ + public data class Event(val shapeType: String, val contentType: String? = null) : MessageType() + + /** + * Corresponds to the `exception` message type. + * NOTE: Exceptions are mapped directly to the payload. There is no way to map event headers for exceptions. + * + * ### Example message + * + * ``` + * :message-type: exception + * :exception-type: FooException + * :content-type: application/json + * + * {...} + * ``` + * + * @param shapeType the exception type as identified by the `:exception-type` header. + * @param contentType the content type of the payload (if present) + */ + public data class Exception(val shapeType: String, val contentType: String? = null) : MessageType() + + /** + * Corresponds to the `error` message type. + * Errors are like exceptions, but they are un-modeled and have a fixed set of fields: + * * `:message-type`: Always set to `error` + * * `:error-code`: (Required) UTF-8 string containing name, type, or category of the error. + * * `:error-message`: (Optional) UTF-* string containing an error message + * + * ### Example message + * + * ``` + * :message-type: error + * :error-code: InternalServerError + * :error-message: An error occurred + * ``` + */ + public data class Error(val errorCode: String, val message: String? = null) : MessageType() + + /** + * Catch all for unknown message types outside of `event`, `exception`, or `error` + */ + public data class SdkUnknown(val messageType: String) : MessageType() +} diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigningTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigningTest.kt new file mode 100644 index 00000000000..9fb499f0aa7 --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/EventStreamSigningTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.sdk.kotlin.runtime.auth.credentials.Credentials +import aws.sdk.kotlin.runtime.auth.signing.AwsSignatureType +import aws.sdk.kotlin.runtime.auth.signing.AwsSigningConfig +import aws.smithy.kotlin.runtime.io.SdkByteBuffer +import aws.smithy.kotlin.runtime.io.bytes +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.time.ManualClock +import aws.smithy.kotlin.runtime.util.encodeToHex +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals + +@OptIn(ExperimentalCoroutinesApi::class) +class EventStreamSigningTest { + + @Test + fun testSignPayload() = runTest { + val messageToSign = buildMessage { + addHeader("some-header", HeaderValue.String("value")) + payload = "test payload".encodeToByteArray() + } + + val epoch = Instant.fromEpochSeconds(123_456_789L, 1234) + val testClock = ManualClock(epoch) + val signingConfig = AwsSigningConfig.Builder().apply { + credentials = Credentials("fake access key", "fake secret key") + region = "us-east-1" + service = "testservice" + signatureType = AwsSignatureType.HTTP_REQUEST_CHUNK + } + + val prevSignature = "last message sts".encodeToByteArray() + + val buffer = SdkByteBuffer(0U) + messageToSign.encode(buffer) + val messagePayload = buffer.bytes() + val result = signPayload(signingConfig, prevSignature, messagePayload, testClock) + assertEquals(":date", result.output.headers[0].name) + + val dateHeader = result.output.headers[0].value.expectTimestamp() + assertEquals(epoch.epochSeconds, dateHeader.epochSeconds) + assertEquals(0, dateHeader.nanosecondsOfSecond) + + assertEquals(":chunk-signature", result.output.headers[1].name) + val expectedSignature = result.signature.encodeToHex() + val actualSignature = result.output.headers[1].value.expectByteArray().encodeToHex() + assertEquals(expectedSignature, actualSignature) + } +} diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoderTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoderTest.kt index f44a058c8b8..0b429616b86 100644 --- a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoderTest.kt +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameDecoderTest.kt @@ -5,65 +5,56 @@ package aws.sdk.kotlin.runtime.protocol.eventstream -import aws.smithy.kotlin.runtime.io.SdkByteBuffer -import aws.smithy.kotlin.runtime.io.writeByte -import aws.smithy.kotlin.runtime.io.writeFully -import kotlin.math.min +import aws.smithy.kotlin.runtime.io.* +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.* +import kotlinx.coroutines.test.runTest +import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.fail +@OptIn(ExperimentalCoroutinesApi::class) class FrameDecoderTest { + @Test - fun testSingleStreamingMessage() { + fun testFrameStreamSingleMessage() = runTest { val encoded = validMessageWithAllHeaders() + val expected = Message.decode(SdkByteBuffer.wrapAsReadBuffer(encoded)) + val chan = SdkByteReadChannel(encoded) - val decoder = FrameDecoder() - val buf = SdkByteBuffer(256u) - for (i in 0 until encoded.size - 1) { - buf.writeByte(encoded[i]) - assertEquals(null, decoder.decodeFrame(buf), "incomplete frame shouldn't result in a message") - } - - buf.writeByte(encoded.last()) + val frames = decodeFrames(chan) + val actual = frames.toList() - when (val frame = decoder.decodeFrame(buf)) { - null -> fail("frame should be complete now") - else -> { - val expected = Message.decode(SdkByteBuffer.wrapAsReadBuffer(encoded)) - assertEquals(expected, frame) - } - } + assertEquals(1, actual.size) + assertEquals(expected, actual.first()) } @Test - fun testMultipleStreamingMessagesChunked() { - val encoded = SdkByteBuffer(256u).apply { + fun testFrameStreamMultipleMessagesChunked() = runTest { + val encoded = SdkByteBuffer(0u).apply { writeFully(validMessageWithAllHeaders()) writeFully(validMessageEmptyPayload()) writeFully(validMessageNoHeaders()) - } - - val decoder = FrameDecoder() - val chunkSize = 8 - - val totalChunks = encoded.readRemaining / chunkSize.toULong() - val buffer = SdkByteBuffer(256u) - val decoded = mutableListOf() - for (i in 0..totalChunks.toInt()) { - buffer.writeFully(encoded, min(chunkSize.toULong(), encoded.readRemaining)) - when (val frame = decoder.decodeFrame(buffer)) { - null -> {} - else -> decoded.add(frame) - } - } + }.bytes() val expected1 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageWithAllHeaders())) val expected2 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageEmptyPayload())) val expected3 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageNoHeaders())) - assertEquals(3, decoded.size) - assertEquals(expected1, decoded[0]) - assertEquals(expected2, decoded[1]) - assertEquals(expected3, decoded[2]) + + val chan = SdkByteReadChannel(encoded) + val frames = decodeFrames(chan) + + val actual = frames.toList() + + assertEquals(3, actual.size) + assertEquals(expected1, actual[0]) + assertEquals(expected2, actual[1]) + assertEquals(expected3, actual[2]) + } + + @Ignore + @Test + fun testChannelClosed() = runTest { + TODO("not implemented yet: need to add test for channel closed normally while waiting on prelude") } } diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoderTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoderTest.kt new file mode 100644 index 00000000000..03e5d2de262 --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/FrameEncoderTest.kt @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.smithy.kotlin.runtime.http.readAll +import aws.smithy.kotlin.runtime.io.SdkByteBuffer +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals + +@OptIn(ExperimentalCoroutinesApi::class) +class FrameEncoderTest { + @Test + fun testEncode() = runTest { + val expected = listOf( + validMessageWithAllHeaders(), + validMessageEmptyPayload(), + validMessageNoHeaders() + ) + + val message1 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageWithAllHeaders())) + val message2 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageEmptyPayload())) + val message3 = Message.decode(SdkByteBuffer.wrapAsReadBuffer(validMessageNoHeaders())) + + val messages = flowOf( + message1, + message2, + message3 + ) + + val actual = messages.encode().toList() + + assertEquals(3, actual.size) + assertContentEquals(expected[0], actual[0]) + assertContentEquals(expected[1], actual[1]) + assertContentEquals(expected[2], actual[2]) + } + + @Test + fun testAsEventStreamHttpBody() = runTest { + val messages = flowOf( + "foo", + "bar", + "baz" + ).map { it.encodeToByteArray() } + + val body = messages.asEventStreamHttpBody() + val actual = body.readAll() + val expected = "foobarbaz" + assertEquals(expected, actual?.decodeToString()) + } +} diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValueTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValueTest.kt new file mode 100644 index 00000000000..7bbd3851af4 --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/HeaderValueTest.kt @@ -0,0 +1,36 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import aws.smithy.kotlin.runtime.time.Instant +import aws.smithy.kotlin.runtime.util.Uuid +import io.kotest.matchers.string.shouldContain +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFails + +@OptIn(Uuid.WeakRng::class) +class HeaderValueTest { + @Test + fun testExpectAs() { + assertEquals(true, HeaderValue.Bool(true).expectBool()) + assertEquals(12.toByte(), HeaderValue.Byte(12u).expectByte()) + assertEquals(12.toShort(), HeaderValue.Int16(12).expectInt16()) + assertEquals(12, HeaderValue.Int32(12).expectInt32()) + assertEquals(12L, HeaderValue.Int64(12L).expectInt64()) + assertEquals("foo", HeaderValue.String("foo").expectString()) + assertContentEquals("foo".encodeToByteArray(), HeaderValue.ByteArray("foo".encodeToByteArray()).expectByteArray()) + val ts = Instant.now() + assertEquals(ts, HeaderValue.Timestamp(ts).expectTimestamp()) + val uuid = Uuid.random() + assertEquals(uuid, HeaderValue.Uuid(uuid).expectUuid()) + + assertFails { + HeaderValue.Int32(12).expectString() + }.message.shouldContain("expected HeaderValue.String, found: Int32(value=12)") + } +} diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeadersTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeadersTest.kt new file mode 100644 index 00000000000..acf2fbe2aaa --- /dev/null +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/ResponseHeadersTest.kt @@ -0,0 +1,125 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.runtime.protocol.eventstream + +import kotlin.test.* + +class ResponseHeadersTest { + + @Test + fun testNormalMessage() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":message-type", HeaderValue.String("event")) + addHeader(":event-type", HeaderValue.String("Foo")) + addHeader(":content-type", HeaderValue.String("application/json")) + } + + val actual = message.type() + assertIs(actual) + assertEquals("Foo", actual.shapeType) + assertEquals("application/json", actual.contentType) + } + + @Test + fun testExceptionMessage() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":message-type", HeaderValue.String("exception")) + addHeader(":exception-type", HeaderValue.String("BadRequestException")) + addHeader(":content-type", HeaderValue.String("application/json")) + } + + val actual = message.type() + assertIs(actual) + assertEquals("BadRequestException", actual.shapeType) + assertEquals("application/json", actual.contentType) + } + + @Test + fun testMissingExceptionType() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":message-type", HeaderValue.String("exception")) + addHeader(":content-type", HeaderValue.String("application/json")) + } + + val ex = assertFailsWith { + message.type() + } + + assertEquals(ex.message, "Invalid `exception` message: `:exception-type` header is missing") + } + + @Test + fun testMissingEventType() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":message-type", HeaderValue.String("event")) + addHeader(":content-type", HeaderValue.String("application/json")) + } + + val ex = assertFailsWith { + message.type() + } + + assertEquals(ex.message, "Invalid `event` message: `:event-type` header is missing") + } + + @Test + fun testMissingMessageType() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":event-type", HeaderValue.String("Foo")) + addHeader(":content-type", HeaderValue.String("application/json")) + } + + val ex = assertFailsWith { + message.type() + } + + assertEquals(ex.message, "`:message-type` header is required to deserialize an event stream message") + } + + @Test + fun testMissingContentType() { + val message = buildMessage { + payload = "test".encodeToByteArray() + addHeader(":message-type", HeaderValue.String("event")) + addHeader(":event-type", HeaderValue.String("Foo")) + } + + val actual = message.type() + assertIs(actual) + assertEquals("Foo", actual.shapeType) + assertNull(actual.contentType) + } + + @Test + fun testErrorMessage() { + val message = buildMessage { + addHeader(":message-type", HeaderValue.String("error")) + addHeader(":error-code", HeaderValue.String("InternalError")) + addHeader(":error-message", HeaderValue.String("An internal server error occurred")) + } + + val actual = message.type() + assertIs(actual) + assertEquals("InternalError", actual.errorCode) + assertEquals("An internal server error occurred", actual.message) + } + + @Test + fun testUnknown() { + val message = buildMessage { + addHeader(":message-type", HeaderValue.String("foo")) + } + + val actual = message.type() + assertIs(actual) + assertEquals("foo", actual.messageType) + } +} diff --git a/codegen/sdk/build.gradle.kts b/codegen/sdk/build.gradle.kts index 31ba48e04af..feddcaa09aa 100644 --- a/codegen/sdk/build.gradle.kts +++ b/codegen/sdk/build.gradle.kts @@ -95,8 +95,6 @@ data class AwsService( val disabledServices = setOf( - // transcribe streaming contains exclusively EventStream operations which are not supported - "transcribestreaming", // timestream requires endpoint discovery // https://github.com/awslabs/smithy-kotlin/issues/146 "timestreamwrite", @@ -465,4 +463,4 @@ tasks.register("syncAwsModels") { logger.warn("${source.path} (sdkId=$sdkId) is new to aws-models since the last sync!") } } -} \ No newline at end of file +} diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsKotlinDependency.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsKotlinDependency.kt index 6201b20e590..c36322b28f5 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsKotlinDependency.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsKotlinDependency.kt @@ -43,6 +43,7 @@ object AwsKotlinDependency { val AWS_JSON_PROTOCOLS = KotlinDependency(GradleConfiguration.Implementation, "$AWS_CLIENT_RT_ROOT_NS.protocol.json", AWS_CLIENT_RT_GROUP, "aws-json-protocols", AWS_CLIENT_RT_VERSION) val AWS_XML_PROTOCOLS = KotlinDependency(GradleConfiguration.Implementation, "$AWS_CLIENT_RT_ROOT_NS.protocol.xml", AWS_CLIENT_RT_GROUP, "aws-xml-protocols", AWS_CLIENT_RT_VERSION) val AWS_CRT_HTTP_ENGINE = KotlinDependency(GradleConfiguration.Implementation, "$AWS_CLIENT_RT_ROOT_NS.http.engine.crt", AWS_CLIENT_RT_GROUP, "http-client-engine-crt", AWS_CLIENT_RT_VERSION) + val AWS_EVENT_STREAM = KotlinDependency(GradleConfiguration.Implementation, "$AWS_CLIENT_RT_ROOT_NS.protocol.eventstream", AWS_CLIENT_RT_GROUP, "aws-event-stream", AWS_CLIENT_RT_VERSION) } // remap aws-sdk-kotlin dependencies to project notation @@ -59,6 +60,7 @@ private val sameProjectDeps: Map by lazy { AwsKotlinDependency.AWS_JSON_PROTOCOLS to """project(":aws-runtime:protocols:aws-json-protocols")""", AwsKotlinDependency.AWS_XML_PROTOCOLS to """project(":aws-runtime:protocols:aws-xml-protocols")""", AwsKotlinDependency.AWS_CRT_HTTP_ENGINE to """project(":aws-runtime:http-client-engine-crt")""", + AwsKotlinDependency.AWS_EVENT_STREAM to """project(":aws-runtime:protocols:aws-event-stream")""", ) } diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt index 28991046502..9df16e6a4cb 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/AwsRuntimeTypes.kt @@ -79,6 +79,30 @@ object AwsRuntimeTypes { val parseRestXmlErrorResponse = runtimeSymbol("parseRestXmlErrorResponse", AwsKotlinDependency.AWS_XML_PROTOCOLS) val parseEc2QueryErrorResponse = runtimeSymbol("parseEc2QueryErrorResponse", AwsKotlinDependency.AWS_XML_PROTOCOLS) } + + object AwsEventStream { + val HeaderValue = runtimeSymbol("HeaderValue", AwsKotlinDependency.AWS_EVENT_STREAM) + val Message = runtimeSymbol("Message", AwsKotlinDependency.AWS_EVENT_STREAM) + val MessageType = runtimeSymbol("MessageType", AwsKotlinDependency.AWS_EVENT_STREAM) + val MessageTypeExt = runtimeSymbol("type", AwsKotlinDependency.AWS_EVENT_STREAM) + + val asEventStreamHttpBody = runtimeSymbol("asEventStreamHttpBody", AwsKotlinDependency.AWS_EVENT_STREAM) + val buildMessage = runtimeSymbol("buildMessage", AwsKotlinDependency.AWS_EVENT_STREAM) + val decodeFrames = runtimeSymbol("decodeFrames", AwsKotlinDependency.AWS_EVENT_STREAM) + val encode = runtimeSymbol("encode", AwsKotlinDependency.AWS_EVENT_STREAM) + + val expectBool = runtimeSymbol("expectBool", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectByte = runtimeSymbol("expectByte", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectByteArray = runtimeSymbol("expectByteArray", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectInt16 = runtimeSymbol("expectInt16", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectInt32 = runtimeSymbol("expectInt32", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectInt64 = runtimeSymbol("expectInt64", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectTimestamp = runtimeSymbol("expectTimestamp", AwsKotlinDependency.AWS_EVENT_STREAM) + val expectString = runtimeSymbol("expectString", AwsKotlinDependency.AWS_EVENT_STREAM) + + val sign = runtimeSymbol("sign", AwsKotlinDependency.AWS_EVENT_STREAM) + val newEventStreamSigningConfig = runtimeSymbol("newEventStreamSigningConfig", AwsKotlinDependency.AWS_EVENT_STREAM) + } } private fun runtimeSymbol(name: String, dependency: KotlinDependency, subpackage: String = ""): Symbol = buildSymbol { diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt index c05f9d531bd..c029c0d750b 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt @@ -9,12 +9,10 @@ import software.amazon.smithy.kotlin.codegen.KotlinSettings import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.model.expectShape import software.amazon.smithy.kotlin.codegen.model.findStreamingMember -import software.amazon.smithy.kotlin.codegen.model.hasTrait import software.amazon.smithy.kotlin.codegen.utils.getOrNull import software.amazon.smithy.model.Model import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.StructureShape -import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.model.transform.ModelTransformer import java.util.logging.Logger @@ -33,9 +31,9 @@ class RemoveEventStreamOperations : KotlinIntegration { } else { val ioShapes = listOfNotNull(parentShape.output.getOrNull(), parentShape.input.getOrNull()).map { model.expectShape(it) } val hasEventStream = ioShapes.any { ioShape -> - ioShape.allMembers.values.any { model.getShape(it.target).get().hasTrait() } val streamingMember = ioShape.findStreamingMember(model) - streamingMember?.isUnionShape ?: false + val target = streamingMember?.let { model.expectShape(it.target) } + target?.isUnionShape ?: false } // If a streaming member has a union trait, it is an event stream. Event Streams are not currently supported // by the SDK, so if we generate this API it won't work. diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestJson1.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestJson1.kt index 08fca8aa548..3ff83e18f02 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestJson1.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestJson1.kt @@ -30,7 +30,7 @@ class RestJson1 : JsonHttpBindingProtocolGenerator() { override val protocol: ShapeId = RestJson1Trait.ID override fun getProtocolHttpBindingResolver(model: Model, serviceShape: ServiceShape): HttpBindingResolver = - HttpTraitResolver(model, serviceShape, "application/json") + HttpTraitResolver(model, serviceShape, ProtocolContentTypes.consistent("application/json")) override fun renderSerializeHttpBody( ctx: ProtocolGenerator.GenerationContext, diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestXml.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestXml.kt index 7d38aeeeb36..74da45dee48 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestXml.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestXml.kt @@ -31,7 +31,7 @@ open class RestXml : AwsHttpBindingProtocolGenerator() { // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-restxml-protocol.html#content-type override fun getProtocolHttpBindingResolver(model: Model, serviceShape: ServiceShape): HttpBindingResolver = - HttpTraitResolver(model, serviceShape, "application/xml") + HttpTraitResolver(model, serviceShape, ProtocolContentTypes.consistent("application/xml")) override fun structuredDataParser(ctx: ProtocolGenerator.GenerationContext): StructuredDataParserGenerator = RestXmlParserGenerator(this, defaultTimestampFormat) diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt index 1c4a57d8b80..472a3361ae1 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/core/AwsHttpBindingProtocolGenerator.kt @@ -6,6 +6,8 @@ package aws.sdk.kotlin.codegen.protocols.core import aws.sdk.kotlin.codegen.AwsKotlinDependency import aws.sdk.kotlin.codegen.AwsRuntimeTypes +import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamParserGenerator +import aws.sdk.kotlin.codegen.protocols.eventstream.EventStreamSerializerGenerator import aws.sdk.kotlin.codegen.protocols.middleware.AwsSignatureVersion4 import aws.sdk.kotlin.codegen.protocols.middleware.ResolveAwsEndpointMiddleware import aws.sdk.kotlin.codegen.protocols.middleware.UserAgentMiddleware @@ -112,6 +114,18 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator() */ abstract fun renderDeserializeErrorDetails(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) + override fun eventStreamRequestHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol { + val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service) + val contentType = resolver.determineRequestContentType(op) ?: error("event streams must set a content-type") + val eventStreamSerializerGenerator = EventStreamSerializerGenerator(structuredDataSerializer(ctx), contentType) + return eventStreamSerializerGenerator.requestHandler(ctx, op) + } + + override fun eventStreamResponseHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol { + val eventStreamParserGenerator = EventStreamParserGenerator(ctx, structuredDataParser(ctx)) + return eventStreamParserGenerator.responseHandler(ctx, op) + } + override fun operationErrorHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol = op.errorHandler(ctx.settings) { writer -> writer.withBlock( diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamParserGenerator.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamParserGenerator.kt new file mode 100644 index 00000000000..5f22b115a60 --- /dev/null +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamParserGenerator.kt @@ -0,0 +1,183 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.codegen.protocols.eventstream + +import aws.sdk.kotlin.codegen.AwsRuntimeTypes +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.model.* +import software.amazon.smithy.kotlin.codegen.rendering.ExceptionBaseClassGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializer +import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializerName +import software.amazon.smithy.model.shapes.* +import software.amazon.smithy.model.traits.EventHeaderTrait +import software.amazon.smithy.model.traits.EventPayloadTrait + +/** + * Implements rendering deserialize implementation for event streams implemented using the + * `vnd.amazon.event-stream` content-type + * + * @param sdg the structured data parser generator + */ +class EventStreamParserGenerator( + private val ctx: ProtocolGenerator.GenerationContext, + private val sdg: StructuredDataParserGenerator +) { + /** + * Return the function responsible for deserializing an operation output that targets an event stream + * + * ``` + * private suspend fun deserializeFooOperationBody(builder: Foo.Builder, body: HttpBody) { ... } + * ``` + */ + fun responseHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol = + // FIXME - don't use the body deserializer name since we may need to re-use it (albeit with a different signature, we should still be more explicit than this) + op.bodyDeserializer(ctx.settings) { writer -> + val outputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(op.output.get())) + // we have access to the builder for the output type and the full HttpBody + // members bound via HTTP bindings (e.g. httpHeader, statusCode, etc) are already deserialized via HttpDeserialize impl + // we just need to deserialize the event stream member (and/or the initial response) + writer.withBlock( + // FIXME - revert to private, exposed as internal temporarily while we figure out integration tests + "internal suspend fun #L(builder: #T.Builder, body: #T) {", + "}", + op.bodyDeserializerName(), + outputSymbol, + RuntimeTypes.Http.HttpBody + ) { + renderDeserializeEventStream(ctx, op, writer) + } + } + + private fun renderDeserializeEventStream(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) { + val output = ctx.model.expectShape(op.output.get()) + val streamingMember = output.findStreamingMember(ctx.model) ?: error("expected a streaming member for $output") + val streamShape = ctx.model.expectShape(streamingMember.target) + val streamSymbol = ctx.symbolProvider.toSymbol(streamShape) + + // TODO - handle RPC bound protocol bindings where the initial response is bound to an event stream document + // possibly by decoding the first Message + + val messageTypeSymbol = AwsRuntimeTypes.AwsEventStream.MessageType + val baseExceptionSymbol = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings) + + writer.write("val chan = body.#T() ?: return", RuntimeTypes.Http.toSdkByteReadChannel) + writer.write("val events = #T(chan)", AwsRuntimeTypes.AwsEventStream.decodeFrames) + .indent() + .withBlock(".#T { message ->", "}", RuntimeTypes.KotlinxCoroutines.Flow.map) { + withBlock("when(val mt = message.#T()) {", "}", AwsRuntimeTypes.AwsEventStream.MessageTypeExt) { + withBlock("is #T.Event -> when(mt.shapeType) {", "}", messageTypeSymbol) { + streamShape.filterEventStreamErrors(ctx.model).forEach { member -> + withBlock("#S -> {", "}", member.memberName) { + renderDeserializeEventVariant(ctx, streamSymbol, member, writer) + } + } + write("else -> #T.SdkUnknown", streamSymbol) + } + withBlock("is #T.Exception -> when(mt.shapeType){", "}", messageTypeSymbol) { + // errors are completely bound to payload (at least according to design docs) + val errorMembers = streamShape.members().filter { + val target = ctx.model.expectShape(it.target) + target.isError + } + errorMembers.forEach { member -> + withBlock("#S -> {", "}", member.memberName) { + val payloadDeserializeFn = sdg.payloadDeserializer(ctx, member) + write("val err = #T(message.payload)", payloadDeserializeFn) + write("throw err") + } + } + write("else -> throw #T(#S)", baseExceptionSymbol, "error processing event stream, unrecognized errorType: \${mt.shapeType}") + } + // this is a service exception still, just un-modeled + write("is #T.Error -> throw #T(\"error processing event stream: errorCode=\${mt.errorCode}; message=\${mt.message}\")", messageTypeSymbol, baseExceptionSymbol) + // this is a client exception because we failed to parse it + write("is #T.SdkUnknown -> throw #T(\"unrecognized event stream message `:message-type`: \${mt.messageType}\")", messageTypeSymbol, AwsRuntimeTypes.Core.ClientException) + } + } + .dedent() + .write("builder.#L = events", streamingMember.defaultName()) + } + + private fun renderDeserializeEventVariant(ctx: ProtocolGenerator.GenerationContext, unionSymbol: Symbol, member: MemberShape, writer: KotlinWriter) { + val variant = ctx.model.expectShape(member.target) + + val eventHeaderBindings = variant.members().filter { it.hasTrait() } + val eventPayloadBinding = variant.members().firstOrNull { it.hasTrait() } + + if (eventHeaderBindings.isEmpty() && eventPayloadBinding == null) { + // the entire variant can be deserialized from the payload + val payloadDeserializeFn = sdg.payloadDeserializer(ctx, member) + writer.write("val e = #T(message.payload)", payloadDeserializeFn) + } else { + val variantSymbol = ctx.symbolProvider.toSymbol(variant) + writer.write("val builder = #T.Builder()", variantSymbol) + + // render members bound to header + eventHeaderBindings.forEach { hdrBinding -> + val target = ctx.model.expectShape(hdrBinding.target) + val targetSymbol = ctx.symbolProvider.toSymbol(target) + + // :test(boolean, byte, short, integer, long, blob, string, timestamp)) + val conversionFn = when (target.type) { + ShapeType.BOOLEAN -> AwsRuntimeTypes.AwsEventStream.expectBool + ShapeType.BYTE -> AwsRuntimeTypes.AwsEventStream.expectByte + ShapeType.SHORT -> AwsRuntimeTypes.AwsEventStream.expectInt16 + ShapeType.INTEGER -> AwsRuntimeTypes.AwsEventStream.expectInt32 + ShapeType.LONG -> AwsRuntimeTypes.AwsEventStream.expectInt64 + ShapeType.BLOB -> AwsRuntimeTypes.AwsEventStream.expectByteArray + ShapeType.STRING -> AwsRuntimeTypes.AwsEventStream.expectString + ShapeType.TIMESTAMP -> AwsRuntimeTypes.AwsEventStream.expectTimestamp + else -> throw CodegenException("unsupported eventHeader shape: member=$hdrBinding; targetShape=$target") + } + + val defaultValuePostfix = if (targetSymbol.isNotBoxed && targetSymbol.defaultValue() != null) { + " ?: ${targetSymbol.defaultValue()}" + } else { + "" + } + writer.write("builder.#L = message.headers.find { it.name == #S }?.value?.#T()$defaultValuePostfix", hdrBinding.defaultName(), hdrBinding.memberName, conversionFn) + } + + if (eventPayloadBinding != null) { + renderDeserializeExplicitEventPayloadMember(ctx, eventPayloadBinding, writer) + } else { + val members = variant.members().filterNot { it.hasTrait() } + if (members.isNotEmpty()) { + // all remaining members are bound to payload (but not explicitly bound via @eventPayload) + // use the operation body deserializer + TODO("render unbound event stream payload members") + } + } + + writer.write("val e = builder.build()") + } + + writer.write("#T.#L(e)", unionSymbol, member.unionVariantName()) + } + + private fun renderDeserializeExplicitEventPayloadMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: KotlinWriter + ) { + // FIXME - check content type for blob and string + // structure > :test(member > :test(blob, string, structure, union)) + val target = ctx.model.expectShape(member.target) + when (target.type) { + ShapeType.BLOB -> writer.write("builder.#L = message.payload", member.defaultName()) + ShapeType.STRING -> writer.write("builder.#L = message.payload?.decodeToString()", member.defaultName()) + ShapeType.STRUCTURE, ShapeType.UNION -> { + val payloadDeserializeFn = sdg.payloadDeserializer(ctx, member) + writer.write("builder.#L = #T(message.payload)", member.defaultName(), payloadDeserializeFn) + } + else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member") + } + } +} diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamSerializerGenerator.kt b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamSerializerGenerator.kt new file mode 100644 index 00000000000..147cfc952c2 --- /dev/null +++ b/codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamSerializerGenerator.kt @@ -0,0 +1,176 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package aws.sdk.kotlin.codegen.protocols.eventstream + +import aws.sdk.kotlin.codegen.AwsRuntimeTypes +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.kotlin.codegen.core.* +import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.model.* +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.serde.* +import software.amazon.smithy.model.shapes.* +import software.amazon.smithy.model.traits.EventHeaderTrait +import software.amazon.smithy.model.traits.EventPayloadTrait + +/** + * Implements rendering serialize implementation for event streams implemented using the + * `vnd.amazon.event-stream` content-type + * + * @param sdg the structured data serializer generator + * @param payloadContentType the content-type to use when sending structured data (e.g. `application/json`) + */ +class EventStreamSerializerGenerator( + private val sdg: StructuredDataSerializerGenerator, + private val payloadContentType: String, +) { + + /** + * Return the function responsible for serializing an operation output that targets an event stream + * + * ``` + * private suspend fun serializeFooOperationBody(input: FooInput): HttpBody { ... } + * ``` + */ + fun requestHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol = + // FIXME - don't use the body serializer name since we may need to re-use it (albeit with a different signature, we should still be more explicit than this) + op.bodySerializer(ctx.settings) { writer -> + val inputSymbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(op.input.get())) + writer.withBlock( + // FIXME - revert to private, exposed as internal temporarily while we figure out integration tests + "internal suspend fun #L(context: #T, input: #T): #T {", + "}", + op.bodySerializerName(), + RuntimeTypes.Core.ExecutionContext, + inputSymbol, + RuntimeTypes.Http.HttpBody + ) { + renderSerializeEventStream(ctx, op, writer) + } + } + + private fun renderSerializeEventStream( + ctx: ProtocolGenerator.GenerationContext, + op: OperationShape, + writer: KotlinWriter + ) { + val input = ctx.model.expectShape(op.input.get()) + val streamingMember = input.findStreamingMember(ctx.model) ?: error("expected a streaming member for $input") + val streamShape = ctx.model.expectShape(streamingMember.target) + + writer.write("val stream = input.#L ?: return #T.Empty", streamingMember.defaultName(), RuntimeTypes.Http.HttpBody) + writer.write("val signingConfig = context.#T()", AwsRuntimeTypes.AwsEventStream.newEventStreamSigningConfig) + + val encodeFn = encodeEventStreamMessage(ctx, op, streamShape) + writer.withBlock("val messages = stream", "") { + write(".#T(::#T)", RuntimeTypes.KotlinxCoroutines.Flow.map, encodeFn) + write(".#T(context, signingConfig)", AwsRuntimeTypes.AwsEventStream.sign) + write(".#T()", AwsRuntimeTypes.AwsEventStream.encode) + } + + writer.write("") + writer.write("return messages.#T()", AwsRuntimeTypes.AwsEventStream.asEventStreamHttpBody) + } + + private fun encodeEventStreamMessage( + ctx: ProtocolGenerator.GenerationContext, + op: OperationShape, + streamShape: UnionShape, + ): Symbol = buildSymbol { + val streamSymbol = ctx.symbolProvider.toSymbol(streamShape) + val fnName = "encode${op.capitalizedDefaultName()}${streamSymbol.name}EventMessage" + name = fnName + namespace = ctx.settings.pkg.subpackage("transform") + // place it in same file as the operation serializer + definitionFile = "${op.serializerName()}.kt" + + renderBy = { writer -> + // FIXME - make internal and share across operations? + writer.withBlock( + "private fun #L(input: #T): #T = #T {", "}", + fnName, + streamSymbol, + AwsRuntimeTypes.AwsEventStream.Message, + AwsRuntimeTypes.AwsEventStream.buildMessage + ) { + addStringHeader(":message-type", "event") + + withBlock("when(input) {", "}") { + streamShape.filterEventStreamErrors(ctx.model) + .forEach { member -> + withBlock( + "is #T.#L -> {", "}", + streamSymbol, + member.unionVariantName() + ) { + addStringHeader(":event-type", member.memberName) + val target = ctx.model.expectShape(member.target) + target.members().forEach { targetMember -> + when { + targetMember.hasTrait() -> renderSerializeEventHeader(ctx, targetMember, writer) + targetMember.hasTrait() -> renderSerializeEventPayload(ctx, targetMember, writer) + } + } + } + } + write("is #T.SdkUnknown -> error(#S)", streamSymbol, "cannot serialize the unknown event type!") + } + } + } + } + + private fun renderSerializeEventHeader(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: KotlinWriter) { + val target = ctx.model.expectShape(member.target) + val headerValue = when (target.type) { + ShapeType.BOOLEAN -> "Bool" + ShapeType.BYTE -> "Byte" + ShapeType.SHORT -> "Int16" + ShapeType.INTEGER -> "Int32" + ShapeType.LONG -> "Int64" + ShapeType.BLOB -> "ByteArray" + ShapeType.STRING -> "String" + ShapeType.TIMESTAMP -> "Timestamp" + else -> throw CodegenException("unsupported shape type `${target.type}` for eventHeader member `$member`; target: $target") + } + val conversion = if (target.type == ShapeType.BYTE) ".toUByte()" else "" + + writer.write( + "input.value.#L?.let { addHeader(#S, #T.#L(it$conversion)) }", + member.defaultName(), + member.memberName, + AwsRuntimeTypes.AwsEventStream.HeaderValue, + headerValue + ) + } + + private fun renderSerializeEventPayload(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: KotlinWriter) { + // structure > :test(member > :test(blob, string, structure, union)) + val target = ctx.model.expectShape(member.target) + when (target.type) { + // input is the sealed class, each variant is generated with `value` as the property name of the type being wrapped + ShapeType.BLOB -> { + writer.addStringHeader(":content-type", "application/octet-stream") + writer.write("payload = input.value.#L", member.defaultName()) + } + ShapeType.STRING -> { + writer.addStringHeader(":content-type", "text/plain") + writer.write("payload = input.value.#L?.#T()", member.defaultName(), KotlinTypes.Text.encodeToByteArray) + } + ShapeType.STRUCTURE, ShapeType.UNION -> { + writer.addStringHeader(":content-type", payloadContentType) + // re-use the payload serializer + val serializeFn = sdg.payloadSerializer(ctx, member) + writer.write("payload = input.value.#L?.let { #T(it) }", member.defaultName(), serializeFn) + } + else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member") + } + } + + private fun KotlinWriter.addStringHeader(name: String, value: String) { + write("addHeader(#S, #T.String(#S))", name, AwsRuntimeTypes.AwsEventStream.HeaderValue, value) + } +} diff --git a/codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration index 1582994fe8f..72ebfcaa6b0 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +++ b/codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration @@ -3,7 +3,6 @@ aws.sdk.kotlin.codegen.AwsDefaultRetryIntegration aws.sdk.kotlin.codegen.customization.s3.S3GeneratorSupplier aws.sdk.kotlin.codegen.GradleGenerator aws.sdk.kotlin.codegen.AwsServiceConfigIntegration -aws.sdk.kotlin.codegen.customization.RemoveEventStreamOperations aws.sdk.kotlin.codegen.customization.s3.S3SigningConfig aws.sdk.kotlin.codegen.customization.s3.S3ErrorMetadataIntegration aws.sdk.kotlin.codegen.customization.s3.GetBucketLocationDeserializerIntegration diff --git a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperationsTest.kt b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperationsTest.kt index 0ee9b0c83c8..c87a41cc30e 100644 --- a/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperationsTest.kt +++ b/codegen/smithy-aws-kotlin-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperationsTest.kt @@ -40,6 +40,6 @@ class RemoveEventStreamOperationsTest { val ctx = model.newTestContext() val transformed = RemoveEventStreamOperations().preprocessModel(model, ctx.generationCtx.settings) transformed.expectShape(ShapeId.from("com.test#BlobStream")) - transformed.getShape(ShapeId.from("comm.test#EventStream")).shouldBe(java.util.Optional.empty()) + transformed.getShape(ShapeId.from("com.test#EventStream")).shouldBe(java.util.Optional.empty()) } } diff --git a/gradle.properties b/gradle.properties index 988143034d0..fbb6cddfd87 100644 --- a/gradle.properties +++ b/gradle.properties @@ -12,7 +12,7 @@ sdkVersion=0.13.1-SNAPSHOT smithyVersion=1.17.0 smithyGradleVersion=0.5.3 # smithy-kotlin codegen and runtime are versioned together -smithyKotlinVersion=0.7.8 +smithyKotlinVersion=0.7.9-SNAPSHOT # kotlin kotlinVersion=1.6.10 @@ -28,7 +28,7 @@ kotlinxSerializationVersion=1.3.1 ktorVersion=1.6.7 # crt -crtKotlinVersion=0.5.3 +crtKotlinVersion=0.5.4-SNAPSHOT # testing/utility junitVersion=5.6.2 diff --git a/settings.gradle.kts b/settings.gradle.kts index 4fb9ab34839..6aed9aed72a 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -42,6 +42,8 @@ include(":aws-runtime:protocols:aws-json-protocols") include(":aws-runtime:protocols:aws-xml-protocols") include(":aws-runtime:protocols:aws-event-stream") include(":aws-runtime:crt-util") +// include(":tests") +// include(":tests:codegen:event-stream") // generated services fun File.isServiceDir(): Boolean {