From a78f09a5a686fffe2dc48c4838d842658573a321 Mon Sep 17 00:00:00 2001
From: Jorge Esteban Quilcate Otoya
Date: Tue, 13 Jun 2023 18:24:46 +0300
Subject: [PATCH] feat(storage:s3): multi-part upload: upload parts concurrently

Upload each part on a background thread as soon as a part buffer fills up,
instead of blocking write() on every UploadPart call. close() waits for all
in-flight part uploads to finish before completing the multipart upload, and
aborts the upload if any part fails.
---
 checkstyle/suppressions.xml                 |   1 +
 .../storage/s3/S3MultiPartOutputStream.java |  93 +++++++++------
 .../s3/S3MultiPartOutputStreamTest.java     | 108 ++++++++++--------
 3 files changed, 118 insertions(+), 84 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 819088df3..23f5eda00 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -22,6 +22,7 @@
+
diff --git a/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java b/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
index d7d951469..e5172379c 100644
--- a/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
+++ b/storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java
@@ -18,12 +18,16 @@
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
-import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
 
 import com.amazonaws.services.s3.AmazonS3;
 import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -33,7 +37,6 @@
 import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
 import com.amazonaws.services.s3.model.PartETag;
 import com.amazonaws.services.s3.model.UploadPartRequest;
-import com.amazonaws.services.s3.model.UploadPartResult;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,7 +57,9 @@ public class S3MultiPartOutputStream extends OutputStream {
 
     final int partSize;
 
     private final String uploadId;
-    private final List<PartETag> partETags = new ArrayList<>();
+    // a concurrent map is required, as multiple upload threads write to it as their results come back
+    private final Map<Integer, String> partETags = new ConcurrentHashMap<>();
+    private final List<CompletableFuture<Void>> partUploads = new ArrayList<>();
 
     private boolean closed;
@@ -88,13 +93,17 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
         }
         final ByteBuffer source = ByteBuffer.wrap(b, off, len);
         while (source.hasRemaining()) {
-            final int transferred = Math.min(partBuffer.remaining(), source.remaining());
-            final int offset = source.arrayOffset() + source.position();
-            // TODO: get rid of this array copying
-            partBuffer.put(source.array(), offset, transferred);
-            source.position(source.position() + transferred);
+            final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
+            final int positionAfterCopying = source.position() + toCopy;
+            source.limit(positionAfterCopying);
+            partBuffer.put(source.slice());
+            source.limit(off + len); // restore the limit of the written range; clear() would widen it to the whole backing array
+            source.position(positionAfterCopying);
             if (!partBuffer.hasRemaining()) {
-                flushBuffer(0, partSize);
+                partBuffer.position(0);
+                partBuffer.limit(partSize);
+                flushBuffer(partBuffer.slice(), partSize);
+                partBuffer.clear();
             }
         }
     }
@@ -102,18 +111,31 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
     @Override
     public void close() throws IOException {
         if (partBuffer.position() > 0) {
-            flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
+            final int actualPartSize = partBuffer.position();
+            partBuffer.position(0);
+            partBuffer.limit(actualPartSize);
+            flushBuffer(partBuffer.slice(), actualPartSize);
         }
         if (Objects.nonNull(uploadId)) {
-            if (!partETags.isEmpty()) {
+            if (!partUploads.isEmpty()) {
                 try {
-                    final CompleteMultipartUploadRequest request =
-                        new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
-                    final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
-                    log.debug("Completed multipart upload {} with result {}", uploadId, result);
-                } catch (final Exception e) {
+                    // wait for all uploads to complete successfully before committing
+                    CompletableFuture.allOf(partUploads.toArray(new CompletableFuture[0]))
+                        .thenAccept(unused -> {
+                            final List<PartETag> tags = partETags.entrySet()
+                                .stream()
+                                .map(entry -> new PartETag(entry.getKey(), entry.getValue()))
+                                .collect(Collectors.toList());
+                            final CompleteMultipartUploadRequest request =
+                                new CompleteMultipartUploadRequest(bucketName, key, uploadId, tags);
+                            final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
+                            log.debug("Completed multipart upload {} with result {}", uploadId, result);
+                        })
+                        .get(); // TODO: maybe set a timeout?
+                } catch (final InterruptedException | ExecutionException e) {
                     log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
                     client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
+                    throw new IOException(e);
                 }
             } else {
                 client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,12 +144,27 @@ public void close() throws IOException {
         closed = true;
     }
 
-    private void flushBuffer(final int offset,
-                             final int actualPartSize) throws IOException {
+    private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) throws IOException {
         try {
-            final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
-            uploadPart(in, actualPartSize);
-            partBuffer.clear();
+            final byte[] array = new byte[actualPartSize];
+            partBuffer.get(array, 0, actualPartSize);
+
+            final UploadPartRequest uploadPartRequest =
+                new UploadPartRequest()
+                    .withBucketName(bucketName)
+                    .withKey(key)
+                    .withUploadId(uploadId)
+                    .withPartSize(actualPartSize)
+                    .withPartNumber(partUploads.size() + 1)
+                    .withInputStream(new ByteArrayInputStream(array));
+
+            // Run request async
+            final CompletableFuture<Void> upload =
+                CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest))
+                    .thenAccept(result ->
+                        partETags.put(result.getPartETag().getPartNumber(), result.getPartETag().getETag()));
+
+            partUploads.add(upload);
         } catch (final Exception e) {
             log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
             client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -135,18 +172,4 @@ private void flushBuffer(final int offset,
             throw new IOException(e);
         }
     }
-
-    private void uploadPart(final InputStream in, final int actualPartSize) {
-        final int partNumber = partETags.size() + 1;
-        final UploadPartRequest uploadPartRequest =
-            new UploadPartRequest()
-                .withBucketName(bucketName)
-                .withKey(key)
-                .withUploadId(uploadId)
-                .withPartSize(actualPartSize)
-                .withPartNumber(partNumber)
-                .withInputStream(in);
-        final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
-        partETags.add(uploadResult.getPartETag());
-    }
 }
diff --git a/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java b/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
index a1bc155b8..157674a3f 100644
--- a/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
+++ b/storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java
@@ -16,11 +16,12 @@
 
 package io.aiven.kafka.tieredstorage.storage.s3;
 
-import java.io.ByteArrayInputStream;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
 
 import com.amazonaws.services.s3.AmazonS3;
 import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
                  new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
                 out.write(new byte[] {1, 2, 3});
             }
-        }).isInstanceOf(IOException.class).hasCause(testException);
+        }).isInstanceOf(IOException.class).hasRootCause(testException);
 
         verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
         verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -135,39 +136,48 @@ void writesOneByte() throws Exception {
         verify(mockedS3).uploadPart(any(UploadPartRequest.class));
         verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
 
+        final UploadPartRequest value = uploadPartRequestCaptor.getValue();
         assertUploadPartRequest(
-            uploadPartRequestCaptor.getValue(),
+            value,
+            value.getInputStream().readAllBytes(),
             1,
             1,
             new byte[] {1});
         assertCompleteMultipartUploadRequest(
             completeMultipartUploadRequestCaptor.getValue(),
-            List.of(new PartETag(1, "SOME_ETAG"))
+            Map.of(1, "SOME_ETAG")
         );
     }
 
     @Test
     void writesMultipleMessages() throws Exception {
         final int bufferSize = 10;
-        final byte[] message = new byte[bufferSize];
 
         when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
             .thenReturn(newInitiateMultipartUploadResult());
+
+        final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+        final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
         when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
-            .thenAnswer(a -> {
-                final UploadPartRequest up = a.getArgument(0);
+            .thenAnswer(answer -> {
+                final UploadPartRequest up = answer.getArgument(0);
+                // emulate the S3 client and read the stream eagerly; otherwise the backing array may be reused
+                uploadPartRequests.put(up.getPartNumber(), up);
+                uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
+
                 return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
             });
         when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
             .thenReturn(new CompleteMultipartUploadResult());
 
-        final List<byte[]> expectedMessagesList = new ArrayList<>();
+        final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
         try (final S3MultiPartOutputStream out =
                  new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
             for (int i = 0; i < 3; i++) {
+                final byte[] message = new byte[bufferSize];
                 random.nextBytes(message);
                 out.write(message, 0, message.length);
-                expectedMessagesList.add(message);
+                expectedMessageParts.put(i + 1, message);
             }
         }
@@ -175,21 +185,20 @@ void writesMultipleMessages() throws Exception {
         verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
         verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
 
-        final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
-        int counter = 0;
-        for (final byte[] expectedMessage : expectedMessagesList) {
+        for (final Integer part : expectedMessageParts.keySet()) {
             assertUploadPartRequest(
-                uploadRequests.get(counter),
+                uploadPartRequests.get(part),
+                uploadPartContents.get(part),
                 bufferSize,
-                counter + 1,
-                expectedMessage);
-            counter++;
+                part,
+                expectedMessageParts.get(part)
+            );
         }
         assertCompleteMultipartUploadRequest(
             completeMultipartUploadRequestCaptor.getValue(),
-            List.of(new PartETag(1, "SOME_TAG#1"),
-                new PartETag(2, "SOME_TAG#2"),
-                new PartETag(3, "SOME_TAG#3"))
+            Map.of(1, "SOME_TAG#1",
+                2, "SOME_TAG#2",
+                3, "SOME_TAG#3")
         );
     }
@@ -197,40 +206,46 @@ void writesTailMessages() throws Exception {
         final int messageSize = 20;
 
-        final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
+        final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+        final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
 
         when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
             .thenReturn(newInitiateMultipartUploadResult());
         when(mockedS3.uploadPart(any(UploadPartRequest.class)))
-            .thenAnswer(a -> {
-                final UploadPartRequest up = a.getArgument(0);
+            .thenAnswer(answer -> {
+                final UploadPartRequest up = answer.getArgument(0);
                 // emulate the S3 client and read the stream eagerly; otherwise the backing array may be reused
-                up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
-                uploadPartRequests.add(up);
+                uploadPartRequests.put(up.getPartNumber(), up);
+                uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
 
                 return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
             });
         when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
             .thenReturn(new CompleteMultipartUploadResult());
 
-        final byte[] message = new byte[messageSize];
         final byte[] expectedFullMessage = new byte[messageSize + 10];
         final byte[] expectedTailMessage = new byte[10];
 
-        final S3MultiPartOutputStream
-            out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
-        random.nextBytes(message);
-        out.write(message);
-        System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
-        random.nextBytes(message);
-        out.write(message);
-        System.arraycopy(message, 0, expectedFullMessage, 20, 10);
-        System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+        final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
+        {
+            final byte[] message = new byte[messageSize];
+            random.nextBytes(message);
+            out.write(message);
+            System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
+        }
+        {
+            final byte[] message = new byte[messageSize];
+            random.nextBytes(message);
+            out.write(message);
+            System.arraycopy(message, 0, expectedFullMessage, 20, 10);
+            System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+        }
         out.close();
 
-        assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
-        assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
+        assertThat(uploadPartRequests).hasSize(2);
+        assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
+        assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);
 
         verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
         verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
     }
 
     private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
+                                                final byte[] bytes,
                                                 final int expectedPartSize,
                                                 final int expectedPartNumber,
                                                 final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRe
         assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
         assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
         assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
-        assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
+        assertThat(bytes).isEqualTo(expectedBytes);
     }
 
     private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
-                                                             final List<PartETag> expectedETags) {
+                                                             final Map<Integer, String> expectedETags) {
         assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
         assertThat(request.getKey()).isEqualTo(FILE_KEY);
         assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
-        assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
-
-        for (int i = 0; i < expectedETags.size(); i++) {
-            final PartETag expectedETag = expectedETags.get(i);
-            final PartETag etag = request.getPartETags().get(i);
-
-            assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
-            assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
-        }
+        final Map<Integer, String> tags = request.getPartETags().stream()
+            .collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
+        assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
     }
 
     private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {