diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java index 101aed7b7441..146007927c63 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBody.java @@ -239,7 +239,7 @@ private static final class SynchronousChunkBuffer { } private Iterable buffer(ByteBuffer bytes) { - return chunkBuffer.bufferAndCreateChunks(bytes); + return chunkBuffer.split(bytes); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java index 93d6d09578a6..c171b0787678 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ChunkBuffer.java @@ -19,10 +19,11 @@ import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.builder.SdkBuilder; @@ -31,17 +32,20 @@ */ @SdkInternalApi public final class ChunkBuffer { - private final AtomicLong remainingBytes; + private static final Logger log = Logger.loggerFor(ChunkBuffer.class); + private final AtomicLong transferredBytes; private final ByteBuffer currentBuffer; - private final int bufferSize; + private final int chunkSize; + private final long totalBytes; private ChunkBuffer(Long totalBytes, Integer bufferSize) 
{ Validate.notNull(totalBytes, "The totalBytes must not be null"); int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE; - this.bufferSize = chunkSize; + this.chunkSize = chunkSize; this.currentBuffer = ByteBuffer.allocate(chunkSize); - this.remainingBytes = new AtomicLong(totalBytes); + this.totalBytes = totalBytes; + this.transferredBytes = new AtomicLong(0); } public static Builder builder() { @@ -49,52 +53,97 @@ public static Builder builder() { } - // currentBuffer and bufferedList can get over written if concurrent Threads calls this method at the same time. - public synchronized Iterable bufferAndCreateChunks(ByteBuffer buffer) { - int startPosition = 0; - List bufferedList = new ArrayList<>(); - int currentBytesRead = buffer.remaining(); - do { - int bufferedBytes = currentBuffer.position(); - int availableToRead = bufferSize - bufferedBytes; - int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition); + /** + * Split the input {@link ByteBuffer} into multiple smaller {@link ByteBuffer}s, each of which contains {@link #chunkSize} + * worth of bytes. If the last chunk of the input ByteBuffer contains less than {@link #chunkSize} data, the last chunk will + * be buffered. + */ + public synchronized Iterable split(ByteBuffer inputByteBuffer) { - byte[] bytes = BinaryUtils.copyAllBytesFrom(buffer); - if (bufferedBytes == 0) { - currentBuffer.put(bytes, startPosition, bytesToMove); - } else { - currentBuffer.put(bytes, 0, bytesToMove); + if (!inputByteBuffer.hasRemaining()) { + return Collections.singletonList(inputByteBuffer); + } + + List byteBuffers = new ArrayList<>(); + + // If current buffer is not empty, fill the buffer first. 
+ if (currentBuffer.position() != 0) { + fillCurrentBuffer(inputByteBuffer); + + if (isCurrentBufferFull()) { + addCurrentBufferToIterable(byteBuffers, chunkSize); + } + } + + // If the input buffer is not empty, split the input buffer + if (inputByteBuffer.hasRemaining()) { + splitRemainingInputByteBuffer(inputByteBuffer, byteBuffers); + } + + // If this is the last chunk, add data buffered to the iterable + if (isLastChunk()) { + int remainingBytesInBuffer = currentBuffer.position(); + addCurrentBufferToIterable(byteBuffers, remainingBytesInBuffer); + } + return byteBuffers; + } + + private boolean isCurrentBufferFull() { + return currentBuffer.position() == chunkSize; + } + + /** + * Splits the input ByteBuffer into multiple chunks and adds them to the iterable. + */ + private void splitRemainingInputByteBuffer(ByteBuffer inputByteBuffer, List byteBuffers) { + while (inputByteBuffer.hasRemaining()) { + ByteBuffer inputByteBufferCopy = inputByteBuffer.asReadOnlyBuffer(); + if (inputByteBuffer.remaining() < chunkSize) { + currentBuffer.put(inputByteBuffer); + break; } - startPosition = startPosition + bytesToMove; - - // Send the data once the buffer is full - if (currentBuffer.position() == bufferSize) { - currentBuffer.position(0); - ByteBuffer bufferToSend = ByteBuffer.allocate(bufferSize); - bufferToSend.put(currentBuffer.array(), 0, bufferSize); - bufferToSend.clear(); - currentBuffer.clear(); - bufferedList.add(bufferToSend); - remainingBytes.addAndGet(-bufferSize); + int newLimit = inputByteBufferCopy.position() + chunkSize; + inputByteBufferCopy.limit(newLimit); + inputByteBuffer.position(newLimit); + byteBuffers.add(inputByteBufferCopy); + transferredBytes.addAndGet(chunkSize); + } + } + + private boolean isLastChunk() { + long remainingBytes = totalBytes - transferredBytes.get(); + return remainingBytes != 0 && remainingBytes == currentBuffer.position(); + } + + private void addCurrentBufferToIterable(List byteBuffers, int capacity) { + ByteBuffer 
bufferedChunk = ByteBuffer.allocate(capacity); + currentBuffer.flip(); + bufferedChunk.put(currentBuffer); + bufferedChunk.flip(); + byteBuffers.add(bufferedChunk); + transferredBytes.addAndGet(bufferedChunk.remaining()); + currentBuffer.clear(); + } + + private void fillCurrentBuffer(ByteBuffer inputByteBuffer) { + while (currentBuffer.position() < chunkSize) { + if (!inputByteBuffer.hasRemaining()) { + break; + } + + int remainingCapacity = chunkSize - currentBuffer.position(); + + if (inputByteBuffer.remaining() < remainingCapacity) { + currentBuffer.put(inputByteBuffer); + } else { + ByteBuffer remainingChunk = inputByteBuffer.asReadOnlyBuffer(); + int newLimit = inputByteBuffer.position() + remainingCapacity; + remainingChunk.limit(newLimit); + inputByteBuffer.position(newLimit); + currentBuffer.put(remainingChunk); } - } while (startPosition < currentBytesRead); - - int remainingBytesInBuffer = currentBuffer.position(); - - // Send the remaining buffer when - // 1. remainingBytes in buffer are same as the last few bytes to be read. - // 2. If it is a zero byte and the last byte to be read. 
- if (remainingBytes.get() == remainingBytesInBuffer && - (buffer.remaining() == 0 || remainingBytesInBuffer > 0)) { - currentBuffer.clear(); - ByteBuffer trimmedBuffer = ByteBuffer.allocate(remainingBytesInBuffer); - trimmedBuffer.put(currentBuffer.array(), 0, remainingBytesInBuffer); - trimmedBuffer.clear(); - bufferedList.add(trimmedBuffer); - remainingBytes.addAndGet(-remainingBytesInBuffer); } - return bufferedList; } public interface Builder extends SdkBuilder { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java index 8b73402dc468..a553a55a4536 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/ChunkBufferTest.java @@ -18,6 +18,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -29,8 +31,12 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import software.amazon.awssdk.core.internal.async.ChunkBuffer; +import software.amazon.awssdk.utils.BinaryUtils; import software.amazon.awssdk.utils.StringUtils; class ChunkBufferTest { @@ -40,42 +46,38 @@ void builderWithNoTotalSize() { assertThatThrownBy(() -> ChunkBuffer.builder().build()).isInstanceOf(NullPointerException.class); } - @Test - void numberOfChunkMultipleOfTotalBytes() { - String inputString = StringUtils.repeat("*", 25); - - ChunkBuffer chunkBuffer = - 
ChunkBuffer.builder().bufferSize(5).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build(); - Iterable byteBuffers = - chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); - - AtomicInteger iteratedCounts = new AtomicInteger(); - byteBuffers.forEach(r -> { - iteratedCounts.getAndIncrement(); - assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 5).getBytes(StandardCharsets.UTF_8)); - }); - assertThat(iteratedCounts.get()).isEqualTo(5); - } - - @Test - void numberOfChunk_Not_MultipleOfTotalBytes() { - int totalBytes = 23; + @ParameterizedTest + @ValueSource(ints = {1, 6, 10, 23, 25}) + void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) { int bufferSize = 5; - String inputString = StringUtils.repeat("*", totalBytes); + String inputString = RandomStringUtils.randomAscii(totalBytes); ChunkBuffer chunkBuffer = ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build(); Iterable byteBuffers = - chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); + chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); + + AtomicInteger index = new AtomicInteger(0); + int count = (int) Math.ceil(totalBytes / (double) bufferSize); + int remainder = totalBytes % bufferSize; - AtomicInteger iteratedCounts = new AtomicInteger(); byteBuffers.forEach(r -> { - iteratedCounts.getAndIncrement(); - if (iteratedCounts.get() * bufferSize < totalBytes) { - assertThat(r.array()).isEqualTo(StringUtils.repeat("*", bufferSize).getBytes(StandardCharsets.UTF_8)); - } else { - assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 3).getBytes(StandardCharsets.UTF_8)); + int i = index.get(); + try (ByteArrayInputStream inputStream = new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))) { + byte[] expected; + if (i == count - 1 && remainder != 0) { + expected = new byte[remainder]; + } 
else { + expected = new byte[bufferSize]; + } + inputStream.skip(i * bufferSize); + inputStream.read(expected); + byte[] actualBytes = BinaryUtils.copyBytesFrom(r); + assertThat(actualBytes).isEqualTo(expected); + index.incrementAndGet(); + } catch (IOException e) { + throw new RuntimeException(e); } }); } @@ -86,7 +88,7 @@ void zeroTotalBytesAsInput_returnsZeroByte() { ChunkBuffer chunkBuffer = ChunkBuffer.builder().bufferSize(5).totalBytes(zeroByte.length).build(); Iterable byteBuffers = - chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(zeroByte)); + chunkBuffer.split(ByteBuffer.wrap(zeroByte)); AtomicInteger iteratedCounts = new AtomicInteger(); byteBuffers.forEach(r -> { @@ -104,16 +106,16 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() { ChunkBuffer chunkBuffer = ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining()).build(); Iterable byteBuffers = - chunkBuffer.bufferAndCreateChunks(wrap); + chunkBuffer.split(wrap); AtomicInteger iteratedCounts = new AtomicInteger(); byteBuffers.forEach(r -> { iteratedCounts.getAndIncrement(); if (iteratedCounts.get() * bufferSize < totalBytes) { // array of empty bytes - assertThat(r.array()).isEqualTo(ByteBuffer.allocate(bufferSize).array()); + assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(bufferSize).array()); } else { - assertThat(r.array()).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array()); + assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array()); } }); assertThat(iteratedCounts.get()).isEqualTo(4); @@ -167,7 +169,7 @@ void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException, futures = IntStream.range(0, threads).>mapToObj(t -> service.submit(() -> { String inputString = StringUtils.repeat(Integer.toString(counter.incrementAndGet()), totalBytes); - return chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); + return 
chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8))); })).collect(Collectors.toCollection(() -> new ArrayList<>(threads))); AtomicInteger filledBuffers = new AtomicInteger(0); diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBodyTest.java index 90294bd2767b..39abaffd8f71 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ChecksumCalculatingAsyncRequestBodyTest.java @@ -28,6 +28,7 @@ import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import io.reactivex.Flowable; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.util.Lists; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,11 +53,26 @@ public class ChecksumCalculatingAsyncRequestBodyTest { "x-amz-checksum-crc32:i9aeUg==\r\n\r\n"; private final static Path path; + private final static ByteBuffer positionNonZeroBytebuffer; + + private final static ByteBuffer positionZeroBytebuffer; + static { + byte[] content = testString.getBytes(); + byte[] randomContent = RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8); + positionNonZeroBytebuffer = ByteBuffer.allocate(content.length + randomContent.length); + positionNonZeroBytebuffer.put(randomContent) + .put(content); + positionNonZeroBytebuffer.position(randomContent.length); + + positionZeroBytebuffer = ByteBuffer.allocate(content.length); + positionZeroBytebuffer.put(content); + positionZeroBytebuffer.flip(); + FileSystem fs = Jimfs.newFileSystem(Configuration.unix()); path = fs.getPath("./test"); try { - Files.write(path, testString.getBytes()); + Files.write(path, content); } catch (IOException e) { e.printStackTrace(); } 
@@ -71,16 +87,25 @@ public ChecksumCalculatingAsyncRequestBodyTest(AsyncRequestBody provider) { @Parameterized.Parameters public static AsyncRequestBody[] data() { AsyncRequestBody[] asyncRequestBodies = { - ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromString(testString)) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32").build(), - - ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromFile(path)) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32").build(), - }; + ChecksumCalculatingAsyncRequestBody.builder() + .asyncRequestBody(AsyncRequestBody.fromString(testString)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32").build(), + + ChecksumCalculatingAsyncRequestBody.builder() + .asyncRequestBody(AsyncRequestBody.fromFile(path)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32").build(), + + ChecksumCalculatingAsyncRequestBody.builder() + .asyncRequestBody(AsyncRequestBody.fromRemainingByteBuffer(positionZeroBytebuffer)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32").build(), + ChecksumCalculatingAsyncRequestBody.builder() + .asyncRequestBody(AsyncRequestBody.fromRemainingByteBuffersUnsafe(positionNonZeroBytebuffer)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32").build(), + }; return asyncRequestBodies; } @@ -120,30 +145,30 @@ public void onComplete() { @Test public void stringConstructorHasCorrectContentType() { AsyncRequestBody requestBody = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromString("Hello world")) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32") - .build(); + .asyncRequestBody(AsyncRequestBody.fromString("Hello world")) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32") + .build(); assertThat(requestBody.contentType()).startsWith(Mimetype.MIMETYPE_TEXT_PLAIN); } @Test 
public void fileConstructorHasCorrectContentType() { AsyncRequestBody requestBody = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromFile(path)) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32") - .build(); + .asyncRequestBody(AsyncRequestBody.fromFile(path)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32") + .build(); assertThat(requestBody.contentType()).isEqualTo(Mimetype.MIMETYPE_OCTET_STREAM); } @Test public void bytesArrayConstructorHasCorrectContentType() { AsyncRequestBody requestBody = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromBytes("hello world".getBytes())) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32") - .build(); + .asyncRequestBody(AsyncRequestBody.fromBytes("hello world".getBytes())) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32") + .build(); assertThat(requestBody.contentType()).isEqualTo(Mimetype.MIMETYPE_OCTET_STREAM); } @@ -151,20 +176,20 @@ public void bytesArrayConstructorHasCorrectContentType() { public void bytesBufferConstructorHasCorrectContentType() { ByteBuffer byteBuffer = ByteBuffer.wrap("hello world".getBytes()); AsyncRequestBody requestBody = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromByteBuffer(byteBuffer)) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32") - .build(); + .asyncRequestBody(AsyncRequestBody.fromByteBuffer(byteBuffer)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32") + .build(); assertThat(requestBody.contentType()).isEqualTo(Mimetype.MIMETYPE_OCTET_STREAM); } @Test public void emptyBytesConstructorHasCorrectContentType() { AsyncRequestBody requestBody = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.empty()) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amz-checksum-crc32") - .build(); + 
.asyncRequestBody(AsyncRequestBody.empty()) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amz-checksum-crc32") + .build(); assertThat(requestBody.contentType()).isEqualTo(Mimetype.MIMETYPE_OCTET_STREAM); } @@ -172,8 +197,8 @@ public void emptyBytesConstructorHasCorrectContentType() { public void publisherConstructorThrowsExceptionIfNoContentLength() { List requestBodyStrings = Lists.newArrayList("A", "B", "C"); List bodyBytes = requestBodyStrings.stream() - .map(s -> ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8))) - .collect(Collectors.toList()); + .map(s -> ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8))) + .collect(Collectors.toList()); Publisher bodyPublisher = Flowable.fromIterable(bodyBytes); ChecksumCalculatingAsyncRequestBody.Builder builder = ChecksumCalculatingAsyncRequestBody.builder() @@ -208,16 +233,16 @@ public void fromBytes_byteArrayNotNullChecksumSupplied() { byte[] original = {1, 2, 3, 4}; // Checksum data in byte format. byte[] expected = {52, 13, 10, - 1, 2, 3, 4, 13, 10, - 48, 13, 10, 120, 45, 97, 109, 122, 110, 45, 99, 104, 101, 99, 107, 115, - 117, 109, 45, 99, 114, 99, 51, 50, 58, 116, 106, 122, 55, 122, 81, 61, 61, 13, 10, 13, 10}; + 1, 2, 3, 4, 13, 10, + 48, 13, 10, 120, 45, 97, 109, 122, 110, 45, 99, 104, 101, 99, 107, 115, + 117, 109, 45, 99, 114, 99, 51, 50, 58, 116, 106, 122, 55, 122, 81, 61, 61, 13, 10, 13, 10}; byte[] toModify = new byte[original.length]; System.arraycopy(original, 0, toModify, 0, original.length); AsyncRequestBody body = ChecksumCalculatingAsyncRequestBody.builder() - .asyncRequestBody(AsyncRequestBody.fromBytes(toModify)) - .algorithm(Algorithm.CRC32) - .trailerHeader("x-amzn-checksum-crc32") - .build(); + .asyncRequestBody(AsyncRequestBody.fromBytes(toModify)) + .algorithm(Algorithm.CRC32) + .trailerHeader("x-amzn-checksum-crc32") + .build(); for (int i = 0; i < toModify.length; ++i) { toModify[i]++; }