diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/CrcUtil.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/CrcUtil.kt deleted file mode 100644 index 69a92169c20..00000000000 --- a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/CrcUtil.kt +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws.sdk.kotlin.runtime.protocol.eventstream - -import aws.smithy.kotlin.runtime.hashing.Crc32 -import aws.smithy.kotlin.runtime.io.SdkSink -import aws.smithy.kotlin.runtime.io.SdkSource -import aws.smithy.kotlin.runtime.io.internal.SdkSinkObserver -import aws.smithy.kotlin.runtime.io.internal.SdkSourceObserver - -internal class CrcSource(source: SdkSource) : SdkSourceObserver(source) { - private val _crc = Crc32() - - val crc: UInt - get() = _crc.digestValue() - - override fun observe(data: ByteArray, offset: Int, length: Int) { - _crc.update(data, offset, length) - } -} - -internal class CrcSink(sink: SdkSink) : SdkSinkObserver(sink) { - private val _crc = Crc32() - - val crc: UInt - get() = _crc.digestValue() - - override fun observe(data: ByteArray, offset: Int, length: Int) { - _crc.update(data, offset, length) - } -} diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Message.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Message.kt index 647ce69ca27..0a47e3fa234 100644 --- a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Message.kt +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Message.kt @@ -6,7 +6,9 @@ package aws.sdk.kotlin.runtime.protocol.eventstream import aws.sdk.kotlin.runtime.InternalSdkApi +import aws.smithy.kotlin.runtime.hashing.Crc32 import aws.smithy.kotlin.runtime.io.* +import aws.smithy.kotlin.runtime.util.encodeToHex internal const val MESSAGE_CRC_BYTE_LEN = 4 @@ -50,14 +52,14 @@ public data class Message(val headers: List
, val payload: ByteArray) { check(totalLen <= MAX_MESSAGE_SIZE.toUInt()) { "Invalid Message size: $totalLen" } // Limiting the amount of data read by SdkBufferedSource is tricky and cause incorrect CRC - // if not careful (e.g. creating a buffered source of CrcSource will usually lead to incorrect results + // if not careful (e.g. creating a buffered source of a HashingSource will usually lead to incorrect results // because the entire point SdkBufferedSource (okio.BufferedSource) is to buffer larger chunks internally // to optimize short reads) val messageBuffer = SdkBuffer() val computedCrc = run { - val crcSource = CrcSource(source) + val crcSource = HashingSource(Crc32(), source) crcSource.read(messageBuffer, totalLen.toLong() - MESSAGE_CRC_BYTE_LEN.toLong()) - crcSource.crc + crcSource.digest() } val prelude = Prelude.decode(messageBuffer) @@ -79,9 +81,9 @@ public data class Message(val headers: List
, val payload: ByteArray) { message.payload = messageBuffer.readByteArray(prelude.payloadLen.toLong()) - val expectedCrc = source.readInt().toUInt() - check(computedCrc == expectedCrc) { - "Message checksum mismatch; expected=0x${expectedCrc.toString(16)}; calculated=0x${computedCrc.toString(16)}" + val expectedCrc = source.readByteArray(4) + check(computedCrc.contentEquals(expectedCrc)) { + "Message checksum mismatch; expected=0x${expectedCrc.encodeToHex()}; calculated=0x${computedCrc.encodeToHex()}" } return message.build() } @@ -119,7 +121,7 @@ public data class Message(val headers: List
, val payload: ByteArray) { val prelude = Prelude(messageLen.toInt(), headersLen.toInt()) - val sink = CrcSink(dest) + val sink = HashingSink(Crc32(), dest) val buffer = sink.buffer() prelude.encode(buffer) @@ -127,7 +129,7 @@ public data class Message(val headers: List
, val payload: ByteArray) { buffer.write(payload) buffer.emit() - dest.writeInt(sink.crc.toInt()) + dest.write(sink.digest()) } } diff --git a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Prelude.kt b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Prelude.kt index 4913401fb22..0f355f571e0 100644 --- a/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Prelude.kt +++ b/aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Prelude.kt @@ -6,7 +6,9 @@ package aws.sdk.kotlin.runtime.protocol.eventstream import aws.sdk.kotlin.runtime.InternalSdkApi +import aws.smithy.kotlin.runtime.hashing.Crc32 import aws.smithy.kotlin.runtime.io.* +import aws.smithy.kotlin.runtime.util.encodeToHex internal const val PRELUDE_BYTE_LEN = 8 internal const val PRELUDE_BYTE_LEN_WITH_CRC = PRELUDE_BYTE_LEN + 4 @@ -28,13 +30,13 @@ public data class Prelude(val totalLen: Int, val headersLength: Int) { * Encode the prelude + CRC to [dest] buffer */ public fun encode(dest: SdkBufferedSink) { - val sink = CrcSink(dest) + val sink = HashingSink(Crc32(), dest) val buffer = sink.buffer() buffer.writeInt(totalLen) buffer.writeInt(headersLength) buffer.emit() - dest.writeInt(sink.crc.toInt()) + dest.write(sink.digest()) } public companion object { @@ -43,19 +45,20 @@ public data class Prelude(val totalLen: Int, val headersLength: Int) { */ public fun decode(source: SdkBufferedSource): Prelude { check(source.request(PRELUDE_BYTE_LEN_WITH_CRC.toLong())) { "Invalid message prelude" } - val crcSource = CrcSource(source) + val crcSource = HashingSource(Crc32(), source) val buffer = SdkBuffer() crcSource.read(buffer, PRELUDE_BYTE_LEN.toLong()) - val expectedCrc = source.readInt().toUInt() - val computedCrc = crcSource.crc + + val expectedCrc = source.readByteArray(4) + val computedCrc = crcSource.digest() val totalLen = buffer.readInt() val headerLen = buffer.readInt() check(totalLen <= MAX_MESSAGE_SIZE) { "Invalid Message size: $totalLen" } check(headerLen <= MAX_HEADER_SIZE) { "Invalid Header size: $headerLen" } - check(expectedCrc == computedCrc) { - "Prelude checksum mismatch; expected=0x${expectedCrc.toString(16)}; calculated=0x${computedCrc.toString(16)}" + check(expectedCrc.contentEquals(computedCrc)) { + "Prelude checksum mismatch; expected=0x${expectedCrc.encodeToHex()}; calculated=0x${computedCrc.encodeToHex()}" } return Prelude(totalLen, headerLen) } diff --git a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/MessageTest.kt b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/MessageTest.kt index 363ef19b253..a9778adaac6 100644 --- a/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/MessageTest.kt +++ b/aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/MessageTest.kt @@ -248,7 +248,7 @@ class MessageTest { val buffer = sdkBufferOf(encoded) assertFailsWith { Message.decode(buffer) - }.message.shouldContain("Message checksum mismatch; expected=0xdeadbeef; calculated=0x1a05860") + }.message.shouldContain("Message checksum mismatch; expected=0xdeadbeef; calculated=0x01a05860") } @Test