diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 819088df3..23f5eda00 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -22,6 +22,7 @@
+
diff --git a/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java b/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
index d7d951469..e5172379c 100644
--- a/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
+++ b/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
@@ -18,12 +18,16 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
-import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -33,7 +37,6 @@
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
import com.amazonaws.services.s3.model.PartETag;
import com.amazonaws.services.s3.model.UploadPartRequest;
-import com.amazonaws.services.s3.model.UploadPartResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,7 +57,9 @@ public class S3MultiPartOutputStream extends OutputStream {
final int partSize;
private final String uploadId;
- private final List<PartETag> partETags = new ArrayList<>();
+ // a concurrent map is required, as multiple upload threads write to it when their upload results complete
+ private final Map<Integer, String> partETags = new ConcurrentHashMap<>();
+ private final List<CompletableFuture<Void>> partUploads = new ArrayList<>();
private boolean closed;
@@ -88,13 +93,17 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
}
final ByteBuffer source = ByteBuffer.wrap(b, off, len);
while (source.hasRemaining()) {
- final int transferred = Math.min(partBuffer.remaining(), source.remaining());
- final int offset = source.arrayOffset() + source.position();
- // TODO: get rid of this array copying
- partBuffer.put(source.array(), offset, transferred);
- source.position(source.position() + transferred);
+ final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
+ final int positionAfterCopying = source.position() + toCopy;
+ source.limit(positionAfterCopying);
+ partBuffer.put(source.slice());
+ source.clear(); // reset limit
+ source.position(positionAfterCopying);
if (!partBuffer.hasRemaining()) {
- flushBuffer(0, partSize);
+ partBuffer.position(0);
+ partBuffer.limit(partSize);
+ flushBuffer(partBuffer.slice(), partSize);
+ partBuffer.clear();
}
}
}
@@ -102,18 +111,31 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
@Override
public void close() throws IOException {
if (partBuffer.position() > 0) {
- flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
+ final int actualPartSize = partBuffer.position();
+ partBuffer.position(0);
+ partBuffer.limit(actualPartSize);
+ flushBuffer(partBuffer.slice(), actualPartSize);
}
if (Objects.nonNull(uploadId)) {
- if (!partETags.isEmpty()) {
+ if (!partUploads.isEmpty()) {
try {
- final CompleteMultipartUploadRequest request =
- new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
- final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
- log.debug("Completed multipart upload {} with result {}", uploadId, result);
- } catch (final Exception e) {
+ // wait for all uploads to complete successfully before committing
+ CompletableFuture.allOf(partUploads.toArray(new CompletableFuture[0]))
+ .thenAccept(unused -> {
+ final List<PartETag> tags = partETags.entrySet()
+ .stream()
+ .map(entry -> new PartETag(entry.getKey(), entry.getValue()))
+ .collect(Collectors.toList());
+ final CompleteMultipartUploadRequest request =
+ new CompleteMultipartUploadRequest(bucketName, key, uploadId, tags);
+ final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
+ log.debug("Completed multipart upload {} with result {}", uploadId, result);
+ })
+ .get(); // TODO: maybe set a timeout?
+ } catch (final InterruptedException | ExecutionException e) {
log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
+ throw new IOException(e);
}
} else {
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,12 +144,27 @@ public void close() throws IOException {
closed = true;
}
- private void flushBuffer(final int offset,
- final int actualPartSize) throws IOException {
+ private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) throws IOException {
try {
- final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
- uploadPart(in, actualPartSize);
- partBuffer.clear();
+ final byte[] array = new byte[actualPartSize];
+ partBuffer.get(array, 0, actualPartSize);
+
+ final UploadPartRequest uploadPartRequest =
+ new UploadPartRequest()
+ .withBucketName(bucketName)
+ .withKey(key)
+ .withUploadId(uploadId)
+ .withPartSize(actualPartSize)
+ .withPartNumber(partUploads.size() + 1)
+ .withInputStream(new ByteArrayInputStream(array));
+
+ // Run request async
+ final CompletableFuture<Void> upload =
+ CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest))
+ .thenAccept(result ->
+ partETags.put(result.getPartETag().getPartNumber(), result.getPartETag().getETag()));
+
+ partUploads.add(upload);
} catch (final Exception e) {
log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -135,18 +172,4 @@ private void flushBuffer(final int offset,
throw new IOException(e);
}
}
-
- private void uploadPart(final InputStream in, final int actualPartSize) {
- final int partNumber = partETags.size() + 1;
- final UploadPartRequest uploadPartRequest =
- new UploadPartRequest()
- .withBucketName(bucketName)
- .withKey(key)
- .withUploadId(uploadId)
- .withPartSize(actualPartSize)
- .withPartNumber(partNumber)
- .withInputStream(in);
- final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
- partETags.add(uploadResult.getPartETag());
- }
}
diff --git a/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java b/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
index a1bc155b8..157674a3f 100644
--- a/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
+++ b/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
@@ -16,11 +16,12 @@
package io.aiven.kafka.tieredstorage.storage.s3;
-import java.io.ByteArrayInputStream;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
out.write(new byte[] {1, 2, 3});
}
- }).isInstanceOf(IOException.class).hasCause(testException);
+ }).isInstanceOf(IOException.class).hasRootCause(testException);
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -135,39 +136,48 @@ void writesOneByte() throws Exception {
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
+ final UploadPartRequest value = uploadPartRequestCaptor.getValue();
assertUploadPartRequest(
- uploadPartRequestCaptor.getValue(),
+ value,
+ value.getInputStream().readAllBytes(),
1,
1,
new byte[] {1});
assertCompleteMultipartUploadRequest(
completeMultipartUploadRequestCaptor.getValue(),
- List.of(new PartETag(1, "SOME_ETAG"))
+ Map.of(1, "SOME_ETAG")
);
}
@Test
void writesMultipleMessages() throws Exception {
final int bufferSize = 10;
- final byte[] message = new byte[bufferSize];
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
.thenReturn(newInitiateMultipartUploadResult());
+
+ final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+ final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
- .thenAnswer(a -> {
- final UploadPartRequest up = a.getArgument(0);
+ .thenAnswer(answer -> {
+ final UploadPartRequest up = answer.getArgument(0);
+ // emulate the behavior of the S3 client; otherwise we would read the wrong array from memory
+ uploadPartRequests.put(up.getPartNumber(), up);
+ uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
+
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
});
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
.thenReturn(new CompleteMultipartUploadResult());
- final List<byte[]> expectedMessagesList = new ArrayList<>();
+ final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
try (final S3MultiPartOutputStream out =
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
for (int i = 0; i < 3; i++) {
+ final byte[] message = new byte[bufferSize];
random.nextBytes(message);
out.write(message, 0, message.length);
- expectedMessagesList.add(message);
+ expectedMessageParts.put(i + 1, message);
}
}
@@ -175,21 +185,20 @@ void writesMultipleMessages() throws Exception {
verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
- final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
- int counter = 0;
- for (final byte[] expectedMessage : expectedMessagesList) {
+ for (final Integer part : expectedMessageParts.keySet()) {
assertUploadPartRequest(
- uploadRequests.get(counter),
+ uploadPartRequests.get(part),
+ uploadPartContents.get(part),
bufferSize,
- counter + 1,
- expectedMessage);
- counter++;
+ part,
+ expectedMessageParts.get(part)
+ );
}
assertCompleteMultipartUploadRequest(
completeMultipartUploadRequestCaptor.getValue(),
- List.of(new PartETag(1, "SOME_TAG#1"),
- new PartETag(2, "SOME_TAG#2"),
- new PartETag(3, "SOME_TAG#3"))
+ Map.of(1, "SOME_TAG#1",
+ 2, "SOME_TAG#2",
+ 3, "SOME_TAG#3")
);
}
@@ -197,40 +206,46 @@ void writesMultipleMessages() throws Exception {
void writesTailMessages() throws Exception {
final int messageSize = 20;
- final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
+ final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+ final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
.thenReturn(newInitiateMultipartUploadResult());
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
- .thenAnswer(a -> {
- final UploadPartRequest up = a.getArgument(0);
+ .thenAnswer(answer -> {
+ final UploadPartRequest up = answer.getArgument(0);
//emulate behave of S3 client otherwise we will get wrong array in the memory
- up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
- uploadPartRequests.add(up);
+ uploadPartRequests.put(up.getPartNumber(), up);
+ uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
});
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
.thenReturn(new CompleteMultipartUploadResult());
- final byte[] message = new byte[messageSize];
final byte[] expectedFullMessage = new byte[messageSize + 10];
final byte[] expectedTailMessage = new byte[10];
- final S3MultiPartOutputStream
- out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
- random.nextBytes(message);
- out.write(message);
- System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
- random.nextBytes(message);
- out.write(message);
- System.arraycopy(message, 0, expectedFullMessage, 20, 10);
- System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+ final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
+ {
+ final byte[] message = new byte[messageSize];
+ random.nextBytes(message);
+ out.write(message);
+ System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
+ }
+ {
+ final byte[] message = new byte[messageSize];
+ random.nextBytes(message);
+ out.write(message);
+ System.arraycopy(message, 0, expectedFullMessage, 20, 10);
+ System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+ }
out.close();
- assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
- assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
+ assertThat(uploadPartRequests).hasSize(2);
+ assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
+ assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
}
private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
+ final byte[] bytes,
final int expectedPartSize,
final int expectedPartNumber,
final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRequest(final UploadPartRequest uploadPartRe
assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
- assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
+ assertThat(bytes).isEqualTo(expectedBytes);
}
private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
- final List<PartETag> expectedETags) {
+ final Map<Integer, String> expectedETags) {
assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
assertThat(request.getKey()).isEqualTo(FILE_KEY);
assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
- assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
-
- for (int i = 0; i < expectedETags.size(); i++) {
- final PartETag expectedETag = expectedETags.get(i);
- final PartETag etag = request.getPartETags().get(i);
-
- assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
- assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
- }
+ final Map<Integer, String> tags = request.getPartETags().stream()
+ .collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
+ assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
}
private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {