diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h index 1eb6d44c14a..fbbf85248a9 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h @@ -30,30 +30,25 @@ class AwsChunkedStream { size_t BufferedRead(char *dst, size_t amountToRead) { assert(dst != nullptr); - if (dst == nullptr) { - AWS_LOGSTREAM_ERROR("AwsChunkedStream", "dst is null"); - } - // the chunk has ended and cannot be read from - if (m_chunkEnd) { - return 0; - } - - // If we've read all of the underlying stream write the checksum trailing header - // the set that the chunked stream is over. - if (m_stream->eof() && !m_stream->bad() && (m_chunkingStream->eof() || m_chunkingStream->peek() == EOF)) { - return writeTrailer(dst, amountToRead); - } + // only read and write to chunked stream if the underlying stream + // is still in a valid state + if (m_stream->good()) { + // Try to read in a 64K chunk, if we cant we know the stream is over + m_stream->read(m_data.GetUnderlyingData(), DataBufferSize); + size_t bytesRead = static_cast(m_stream->gcount()); + writeChunk(bytesRead); - // Try to read in a 64K chunk, if we cant we know the stream is over - size_t bytesRead = 0; - while (m_stream->good() && bytesRead < DataBufferSize) { - m_stream->read(&m_data[bytesRead], DataBufferSize - bytesRead); - bytesRead += static_cast(m_stream->gcount()); + // if we've read everything from the stream, we want to add the trailer + // to the underlying stream + if ((m_stream->peek() == EOF || m_stream->eof()) && !m_stream->bad()) { + writeTrailerToUnderlyingStream(); + } } - if (bytesRead > 0) { - writeChunk(bytesRead); + // if the underlying stream is empty there is nothing to read + if ((m_chunkingStream->peek() == EOF || m_chunkingStream->eof()) && !m_chunkingStream->bad()) { + return 0; } // Read to destination buffer, return how much was read @@ -62,7 +57,7 @@ class AwsChunkedStream { } private: - size_t writeTrailer(char *dst, size_t amountToRead) { + void writeTrailerToUnderlyingStream() { Aws::StringStream chunkedTrailerStream; chunkedTrailerStream << "0\r\n"; if (m_request->GetRequestHash().second != nullptr) { @@ -71,13 +66,10 @@ class AwsChunkedStream { } chunkedTrailerStream << "\r\n"; const auto chunkedTrailer = chunkedTrailerStream.str(); - auto trailerSize = chunkedTrailer.size(); - // unreferenced param for assert - AWS_UNREFERENCED_PARAM(amountToRead); - assert(amountToRead >= trailerSize); - memcpy(dst, chunkedTrailer.c_str(), trailerSize); - m_chunkEnd = true; - return trailerSize; + if (m_chunkingStream->eof()) { + m_chunkingStream->clear(); + } + *m_chunkingStream << chunkedTrailer; } void writeChunk(size_t bytesRead) { @@ -86,6 +78,9 @@ class AwsChunkedStream { } if (m_chunkingStream != nullptr && !m_chunkingStream->bad()) { + if (m_chunkingStream->eof()) { + m_chunkingStream->clear(); + } *m_chunkingStream << Aws::Utils::StringUtils::ToHexString(bytesRead) << "\r\n"; m_chunkingStream->write(m_data.GetUnderlyingData(), bytesRead); *m_chunkingStream << "\r\n"; @@ -94,7 +89,6 @@ class AwsChunkedStream { Aws::Utils::Array m_data{DataBufferSize}; std::shared_ptr m_chunkingStream; - bool m_chunkEnd{false}; Http::HttpRequest *m_request{nullptr}; std::shared_ptr m_stream; }; diff --git a/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp b/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp index 075b46e4558..ee43ab90734 100644 --- a/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp +++ b/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp @@ -39,3 +39,46 @@ TEST_F(AwsChunkedStreamTest, ChunkedStreamShouldWork) { auto expectedStreamWithChecksum = "A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n"; EXPECT_EQ(expectedStreamWithChecksum, encodedStr); } + +TEST_F(AwsChunkedStreamTest, ShouldNotRequireTwoReadsOnSmallChunk) { + StandardHttpRequest request{"www.clemar.com/strohl", Http::HttpMethod::HTTP_GET}; + auto requestHash = Aws::MakeShared(TEST_LOG_TAG); + request.SetRequestHash("crc32", requestHash); + std::shared_ptr inputStream = Aws::MakeShared(TEST_LOG_TAG, "12345"); + AwsChunkedStream<100> chunkedStream{&request, inputStream}; + Aws::Utils::Array outputBuffer{100}; + Aws::StringStream output; + const auto bufferOffset = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 100); + std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + bufferOffset, std::ostream_iterator(output)); + EXPECT_EQ(46ul, bufferOffset); + const auto encodedStr = output.str(); + auto expectedStreamWithChecksum = "5\r\n12345\r\n0\r\nx-amz-checksum-crc32:y/U6HA==\r\n\r\n"; + EXPECT_EQ(expectedStreamWithChecksum, encodedStr); +} + +TEST_F(AwsChunkedStreamTest, ShouldWorkOnSmallBuffer) { + StandardHttpRequest request{"www.eugief.com/hesimay", Http::HttpMethod::HTTP_GET}; + auto requestHash = Aws::MakeShared(TEST_LOG_TAG); + request.SetRequestHash("crc32", requestHash); + std::shared_ptr inputStream = Aws::MakeShared(TEST_LOG_TAG, "1234567890"); + AwsChunkedStream<5> chunkedStream{&request, inputStream}; + Aws::Utils::Array outputBuffer{100}; + // Read first 5 bytes, we get back ten bytes chunk encoded since it is "5\r\n12345\r\n" + Aws::StringStream firstRead; + auto amountRead = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 100); + std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + amountRead, std::ostream_iterator(firstRead)); + EXPECT_EQ(10ul, amountRead); + auto encodedStr = firstRead.str(); + EXPECT_EQ("5\r\n12345\r\n", encodedStr); + // Read second 5 bytes, we get back 46 bytes because we exhaust the underlying buffer + // abd write the trailer "5\r\n67890\r\n0\r\nx-amz-checksum-crc32:Jh2u5Q==\r\n\r\n" + Aws::StringStream secondRead; + amountRead = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 100); + std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + amountRead, std::ostream_iterator(secondRead)); + EXPECT_EQ(46ul, amountRead); + encodedStr = secondRead.str(); + EXPECT_EQ("5\r\n67890\r\n0\r\nx-amz-checksum-crc32:Jh2u5Q==\r\n\r\n", encodedStr); + // Any subsequent reads will return 0 because all streams are exhausted + amountRead = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 100); + EXPECT_EQ(0ul, amountRead); +} \ No newline at end of file