Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed an issue in ChecksumCalculatingAsyncRequestBody where the posit… #4244

Merged
merged 5 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ private static final class SynchronousChunkBuffer {
}

private Iterable<ByteBuffer> buffer(ByteBuffer bytes) {
return chunkBuffer.bufferAndCreateChunks(bytes);
return chunkBuffer.split(bytes);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.builder.SdkBuilder;

Expand All @@ -31,70 +32,115 @@
*/
@SdkInternalApi
public final class ChunkBuffer {
private final AtomicLong remainingBytes;
private static final Logger log = Logger.loggerFor(ChunkBuffer.class);
private final AtomicLong transferredBytes;
private final ByteBuffer currentBuffer;
private final int bufferSize;
private final int chunkSize;
private final long totalBytes;

private ChunkBuffer(Long totalBytes, Integer bufferSize) {
Validate.notNull(totalBytes, "The totalBytes must not be null");

int chunkSize = bufferSize != null ? bufferSize : DEFAULT_ASYNC_CHUNK_SIZE;
this.bufferSize = chunkSize;
this.chunkSize = chunkSize;
this.currentBuffer = ByteBuffer.allocate(chunkSize);
this.remainingBytes = new AtomicLong(totalBytes);
this.totalBytes = totalBytes;
this.transferredBytes = new AtomicLong(0);
}

public static Builder builder() {
return new DefaultBuilder();
}


// currentBuffer and bufferedList can get over written if concurrent Threads calls this method at the same time.
public synchronized Iterable<ByteBuffer> bufferAndCreateChunks(ByteBuffer buffer) {
int startPosition = 0;
List<ByteBuffer> bufferedList = new ArrayList<>();
int currentBytesRead = buffer.remaining();
do {
int bufferedBytes = currentBuffer.position();
int availableToRead = bufferSize - bufferedBytes;
int bytesToMove = Math.min(availableToRead, currentBytesRead - startPosition);
/**
* Split the input {@link ByteBuffer} into multiple smaller {@link ByteBuffer}s, each of which contains {@link #chunkSize}
* worth of bytes. If the last chunk of the input ByteBuffer contains less than {@link #chunkSize} data, the last chunk will
* be buffered.
*/
public synchronized Iterable<ByteBuffer> split(ByteBuffer inputByteBuffer) {

byte[] bytes = BinaryUtils.copyAllBytesFrom(buffer);
if (bufferedBytes == 0) {
currentBuffer.put(bytes, startPosition, bytesToMove);
} else {
currentBuffer.put(bytes, 0, bytesToMove);
if (!inputByteBuffer.hasRemaining()) {
return Collections.singletonList(inputByteBuffer);
}

List<ByteBuffer> byteBuffers = new ArrayList<>();

// If current buffer is not empty, fill the buffer first.
if (currentBuffer.position() != 0) {
fillBuffer(inputByteBuffer);

if (isCurrentBufferFull()) {
addCurrentBufferToIterable(byteBuffers, chunkSize);
}
}

// If the input buffer is not empty, split the input buffer
if (inputByteBuffer.hasRemaining()) {
splitInputBuffer(inputByteBuffer, byteBuffers);
}

// If this is the last chunk, add data buffered to the iterable
if (isLastChunk()) {
int remainingBytesInBuffer = currentBuffer.position();
addCurrentBufferToIterable(byteBuffers, remainingBytesInBuffer);
}
return byteBuffers;
}

private boolean isCurrentBufferFull() {
return currentBuffer.position() == chunkSize;
}

private void splitInputBuffer(ByteBuffer buffer, List<ByteBuffer> byteBuffers) {
zoewangg marked this conversation as resolved.
Show resolved Hide resolved
while (buffer.hasRemaining()) {
ByteBuffer chunkByteBuffer = buffer.asReadOnlyBuffer();
zoewangg marked this conversation as resolved.
Show resolved Hide resolved
if (buffer.remaining() < chunkSize) {
currentBuffer.put(buffer);
break;
}

startPosition = startPosition + bytesToMove;

// Send the data once the buffer is full
if (currentBuffer.position() == bufferSize) {
currentBuffer.position(0);
ByteBuffer bufferToSend = ByteBuffer.allocate(bufferSize);
bufferToSend.put(currentBuffer.array(), 0, bufferSize);
bufferToSend.clear();
currentBuffer.clear();
bufferedList.add(bufferToSend);
remainingBytes.addAndGet(-bufferSize);
int newLimit = chunkByteBuffer.position() + chunkSize;
chunkByteBuffer.limit(newLimit);
buffer.position(newLimit);
byteBuffers.add(chunkByteBuffer);
transferredBytes.addAndGet(chunkSize);
}
}

private boolean isLastChunk() {
long remainingBytes = totalBytes - transferredBytes.get();
return remainingBytes != 0 && remainingBytes == currentBuffer.position();
}

private void addCurrentBufferToIterable(List<ByteBuffer> byteBuffers, int capacity) {
ByteBuffer bufferedChunk = ByteBuffer.allocate(capacity);
currentBuffer.flip();
bufferedChunk.put(currentBuffer);
bufferedChunk.flip();
byteBuffers.add(bufferedChunk);
transferredBytes.addAndGet(bufferedChunk.remaining());
currentBuffer.clear();
}

private void fillBuffer(ByteBuffer inputByteBuffer) {
zoewangg marked this conversation as resolved.
Show resolved Hide resolved
while (currentBuffer.position() < chunkSize) {
if (!inputByteBuffer.hasRemaining()) {
break;
}

int remainingCapacity = chunkSize - currentBuffer.position();

if (inputByteBuffer.remaining() < remainingCapacity) {
currentBuffer.put(inputByteBuffer);
} else {
ByteBuffer remainingChunk = inputByteBuffer.asReadOnlyBuffer();
int newLimit = inputByteBuffer.position() + remainingCapacity;
remainingChunk.limit(newLimit);
inputByteBuffer.position(newLimit);
currentBuffer.put(remainingChunk);
}
} while (startPosition < currentBytesRead);

int remainingBytesInBuffer = currentBuffer.position();

// Send the remaining buffer when
// 1. remainingBytes in buffer are same as the last few bytes to be read.
// 2. If it is a zero byte and the last byte to be read.
if (remainingBytes.get() == remainingBytesInBuffer &&
(buffer.remaining() == 0 || remainingBytesInBuffer > 0)) {
currentBuffer.clear();
ByteBuffer trimmedBuffer = ByteBuffer.allocate(remainingBytesInBuffer);
trimmedBuffer.put(currentBuffer.array(), 0, remainingBytesInBuffer);
trimmedBuffer.clear();
bufferedList.add(trimmedBuffer);
remainingBytes.addAndGet(-remainingBytesInBuffer);
}
return bufferedList;
}

public interface Builder extends SdkBuilder<Builder, ChunkBuffer> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
Expand All @@ -29,8 +31,12 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import software.amazon.awssdk.core.internal.async.ChunkBuffer;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.StringUtils;

class ChunkBufferTest {
Expand All @@ -40,42 +46,38 @@ void builderWithNoTotalSize() {
assertThatThrownBy(() -> ChunkBuffer.builder().build()).isInstanceOf(NullPointerException.class);
}

@Test
void numberOfChunkMultipleOfTotalBytes() {
String inputString = StringUtils.repeat("*", 25);

ChunkBuffer chunkBuffer =
ChunkBuffer.builder().bufferSize(5).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
Iterable<ByteBuffer> byteBuffers =
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));

AtomicInteger iteratedCounts = new AtomicInteger();
byteBuffers.forEach(r -> {
iteratedCounts.getAndIncrement();
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 5).getBytes(StandardCharsets.UTF_8));
});
assertThat(iteratedCounts.get()).isEqualTo(5);
}

@Test
void numberOfChunk_Not_MultipleOfTotalBytes() {
int totalBytes = 23;
@ParameterizedTest
@ValueSource(ints = {1, 6, 10, 23, 25})
void numberOfChunk_Not_MultipleOfTotalBytes(int totalBytes) {
int bufferSize = 5;

String inputString = StringUtils.repeat("*", totalBytes);
String inputString = RandomStringUtils.randomAscii(totalBytes);
ChunkBuffer chunkBuffer =
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(inputString.getBytes(StandardCharsets.UTF_8).length).build();
Iterable<ByteBuffer> byteBuffers =
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));

AtomicInteger index = new AtomicInteger(0);
int count = (int) Math.ceil(totalBytes / (double) bufferSize);
int remainder = totalBytes % bufferSize;

AtomicInteger iteratedCounts = new AtomicInteger();
byteBuffers.forEach(r -> {
iteratedCounts.getAndIncrement();
if (iteratedCounts.get() * bufferSize < totalBytes) {
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", bufferSize).getBytes(StandardCharsets.UTF_8));
} else {
assertThat(r.array()).isEqualTo(StringUtils.repeat("*", 3).getBytes(StandardCharsets.UTF_8));
int i = index.get();

try (ByteArrayInputStream inputStream = new ByteArrayInputStream(inputString.getBytes(StandardCharsets.UTF_8))) {
byte[] expected;
if (i == count - 1 && remainder != 0) {
expected = new byte[remainder];
} else {
expected = new byte[bufferSize];
}
inputStream.skip(i * bufferSize);
inputStream.read(expected);
byte[] actualBytes = BinaryUtils.copyBytesFrom(r);
assertThat(actualBytes).isEqualTo(expected);
index.incrementAndGet();
} catch (IOException e) {
throw new RuntimeException(e);
}
});
}
Expand All @@ -86,7 +88,7 @@ void zeroTotalBytesAsInput_returnsZeroByte() {
ChunkBuffer chunkBuffer =
ChunkBuffer.builder().bufferSize(5).totalBytes(zeroByte.length).build();
Iterable<ByteBuffer> byteBuffers =
chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(zeroByte));
chunkBuffer.split(ByteBuffer.wrap(zeroByte));

AtomicInteger iteratedCounts = new AtomicInteger();
byteBuffers.forEach(r -> {
Expand All @@ -104,16 +106,16 @@ void emptyAllocatedBytes_returnSameNumberOfEmptyBytes() {
ChunkBuffer chunkBuffer =
ChunkBuffer.builder().bufferSize(bufferSize).totalBytes(wrap.remaining()).build();
Iterable<ByteBuffer> byteBuffers =
chunkBuffer.bufferAndCreateChunks(wrap);
chunkBuffer.split(wrap);

AtomicInteger iteratedCounts = new AtomicInteger();
byteBuffers.forEach(r -> {
iteratedCounts.getAndIncrement();
if (iteratedCounts.get() * bufferSize < totalBytes) {
// array of empty bytes
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(bufferSize).array());
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(bufferSize).array());
} else {
assertThat(r.array()).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
assertThat(BinaryUtils.copyBytesFrom(r)).isEqualTo(ByteBuffer.allocate(totalBytes % bufferSize).array());
}
});
assertThat(iteratedCounts.get()).isEqualTo(4);
Expand Down Expand Up @@ -167,7 +169,7 @@ void concurrentTreads_calling_bufferAndCreateChunks() throws ExecutionException,

futures = IntStream.range(0, threads).<Future<Iterable>>mapToObj(t -> service.submit(() -> {
String inputString = StringUtils.repeat(Integer.toString(counter.incrementAndGet()), totalBytes);
return chunkBuffer.bufferAndCreateChunks(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
return chunkBuffer.split(ByteBuffer.wrap(inputString.getBytes(StandardCharsets.UTF_8)));
})).collect(Collectors.toCollection(() -> new ArrayList<>(threads)));

AtomicInteger filledBuffers = new AtomicInteger(0);
Expand Down
Loading
Loading