feat(storage:s3): multi-part upload: upload parts concurrently
jeqo committed Jun 13, 2023
1 parent b4d97d8 commit a78f09a
Showing 3 changed files with 118 additions and 84 deletions.
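For context before the diff: the change makes S3MultiPartOutputStream dispatch each part upload on a background thread as soon as a part buffer fills, record the resulting ETags in a map keyed by part number, and block in close() until every upload has finished before sending the complete-multipart-upload request. A minimal, self-contained sketch of that pattern, assuming a placeholder uploadPart function standing in for the AWS SDK call (none of these names are the repository's):

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

// Sketch of the concurrent part-upload pattern; "uploadPart" is a stand-in
// for the S3 client call and returns the part's ETag.
class ConcurrentPartUploader {
    private final Map<Integer, String> partETags = new ConcurrentHashMap<>();
    private final List<CompletableFuture<Void>> partUploads = new ArrayList<>();
    private final BiFunction<Integer, byte[], String> uploadPart;

    ConcurrentPartUploader(final BiFunction<Integer, byte[], String> uploadPart) {
        this.uploadPart = uploadPart;
    }

    // Called from the single writer thread whenever a part buffer fills:
    // dispatch the upload without blocking the writer.
    void submit(final byte[] part) {
        final int partNumber = partUploads.size() + 1; // S3 part numbers start at 1
        partUploads.add(
            CompletableFuture.supplyAsync(() -> uploadPart.apply(partNumber, part))
                .thenAccept(etag -> partETags.put(partNumber, etag)));
    }

    // Called from close(): block until every part finishes, then return ETags by part number.
    Map<Integer, String> awaitAll() throws Exception {
        CompletableFuture.allOf(partUploads.toArray(new CompletableFuture[0])).get();
        return partETags;
    }
}

Because the thenAccept callbacks run on ForkJoinPool worker threads, several parts may record their ETags at the same time — which is why the commit also switches partETags to a ConcurrentHashMap.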
1 change: 1 addition & 0 deletions checkstyle/suppressions.xml
@@ -22,6 +22,7 @@
<suppress checks="ClassDataAbstractionCoupling" files=".*Test\.java"/>
<suppress checks="ClassFanOutComplexity" files=".*Test\.java"/>
<suppress checks="ClassFanOutComplexity" files="RemoteStorageManager.java"/>
<suppress checks="ClassDataAbstractionCoupling" files="S3MultiPartOutputStream.java"/>
<suppress checks="ClassDataAbstractionCoupling" files="S3StorageConfig.java"/>
<suppress checks="ClassDataAbstractionCoupling" files="RemoteStorageManager.java"/>
</suppressions>
S3MultiPartOutputStream.java
@@ -18,12 +18,16 @@

import java.io.ByteArrayInputStream;
import java.io.IOException;
- import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
+ import java.util.Map;
import java.util.Objects;
+ import java.util.concurrent.CompletableFuture;
+ import java.util.concurrent.ConcurrentHashMap;
+ import java.util.concurrent.ExecutionException;
+ import java.util.stream.Collectors;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -33,7 +37,6 @@
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
import com.amazonaws.services.s3.model.PartETag;
import com.amazonaws.services.s3.model.UploadPartRequest;
- import com.amazonaws.services.s3.model.UploadPartResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -54,7 +57,9 @@ public class S3MultiPartOutputStream extends OutputStream {
final int partSize;

private final String uploadId;
- private final List<PartETag> partETags = new ArrayList<>();
+ // a concurrent map is required, as multiple threads write to it when upload results are returned
+ private final Map<Integer, String> partETags = new ConcurrentHashMap<>();
+ private final List<CompletableFuture<Void>> partUploads = new ArrayList<>();

private boolean closed;

@@ -88,32 +93,49 @@ public void write(final byte[] b, final int off, final int len) throws IOException {
}
final ByteBuffer source = ByteBuffer.wrap(b, off, len);
while (source.hasRemaining()) {
- final int transferred = Math.min(partBuffer.remaining(), source.remaining());
- final int offset = source.arrayOffset() + source.position();
- // TODO: get rid of this array copying
- partBuffer.put(source.array(), offset, transferred);
- source.position(source.position() + transferred);
+ final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
+ final int positionAfterCopying = source.position() + toCopy;
+ source.limit(positionAfterCopying);
+ partBuffer.put(source.slice());
+ source.clear(); // reset limit
+ source.position(positionAfterCopying);
if (!partBuffer.hasRemaining()) {
- flushBuffer(0, partSize);
+ partBuffer.position(0);
+ partBuffer.limit(partSize);
+ flushBuffer(partBuffer.slice(), partSize);
+ partBuffer.clear();
}
}
}

@Override
public void close() throws IOException {
if (partBuffer.position() > 0) {
- flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
+ final int actualPartSize = partBuffer.position();
+ partBuffer.position(0);
+ partBuffer.limit(actualPartSize);
+ flushBuffer(partBuffer.slice(), actualPartSize);
}
if (Objects.nonNull(uploadId)) {
- if (!partETags.isEmpty()) {
+ if (!partUploads.isEmpty()) {
try {
- final CompleteMultipartUploadRequest request =
- new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
- final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
- log.debug("Completed multipart upload {} with result {}", uploadId, result);
- } catch (final Exception e) {
+ // wait for all uploads to complete successfully before committing
+ CompletableFuture.allOf(partUploads.toArray(new CompletableFuture[0]))
+ .thenAccept(unused -> {
+ final List<PartETag> tags = partETags.entrySet()
+ .stream()
+ .map(entry -> new PartETag(entry.getKey(), entry.getValue()))
+ .collect(Collectors.toList());
+ final CompleteMultipartUploadRequest request =
+ new CompleteMultipartUploadRequest(bucketName, key, uploadId, tags);
+ final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
+ log.debug("Completed multipart upload {} with result {}", uploadId, result);
+ })
+ .get(); // TODO: maybe set a timeout?
+ } catch (final InterruptedException | ExecutionException e) {
log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
throw new IOException(e);
}
} else {
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,31 +144,32 @@ public void close() throws IOException {
closed = true;
}

- private void flushBuffer(final int offset,
- final int actualPartSize) throws IOException {
+ private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) throws IOException {
try {
- final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
- uploadPart(in, actualPartSize);
- partBuffer.clear();
+ final byte[] array = new byte[actualPartSize];
+ partBuffer.get(array, 0, actualPartSize);
+
+ final UploadPartRequest uploadPartRequest =
+ new UploadPartRequest()
+ .withBucketName(bucketName)
+ .withKey(key)
+ .withUploadId(uploadId)
+ .withPartSize(actualPartSize)
+ .withPartNumber(partUploads.size() + 1)
+ .withInputStream(new ByteArrayInputStream(array));
+
+ // Run request async
+ final CompletableFuture<Void> upload =
+ CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest))
+ .thenAccept(result ->
+ partETags.put(result.getPartETag().getPartNumber(), result.getPartETag().getETag()));
+
+ partUploads.add(upload);
} catch (final Exception e) {
log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
closed = true;
throw new IOException(e);
}
}

- private void uploadPart(final InputStream in, final int actualPartSize) {
- final int partNumber = partETags.size() + 1;
- final UploadPartRequest uploadPartRequest =
- new UploadPartRequest()
- .withBucketName(bucketName)
- .withKey(key)
- .withUploadId(uploadId)
- .withPartSize(actualPartSize)
- .withPartNumber(partNumber)
- .withInputStream(in);
- final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
- partETags.add(uploadResult.getPartETag());
- }
}
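For orientation, a hedged usage sketch of the rewritten stream; the bucket name, object key, part size, and payload are illustrative values, not from this repository (the constructor order is taken from the tests below):

import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;

final AmazonS3 s3Client = AmazonS3ClientBuilder.defaultClient();
final byte[] data = "hello".getBytes(StandardCharsets.UTF_8);
// Each full 5 MiB part is uploaded in the background; close() waits for all
// parts and then completes the multipart upload, or aborts it on failure.
try (final OutputStream out =
         new S3MultiPartOutputStream("my-bucket", "segment-0001.log", 5 * 1024 * 1024, s3Client)) {
    out.write(data);
}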
S3MultiPartOutputStreamTest.java
@@ -16,11 +16,12 @@

package io.aiven.kafka.tieredstorage.storage.s3;

- import java.io.ByteArrayInputStream;
import java.io.IOException;
- import java.util.ArrayList;
- import java.util.List;
+ import java.util.HashMap;
+ import java.util.Map;
import java.util.Random;
+ import java.util.concurrent.ConcurrentHashMap;
+ import java.util.stream.Collectors;

import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
out.write(new byte[] {1, 2, 3});
}
- }).isInstanceOf(IOException.class).hasCause(testException);
+ }).isInstanceOf(IOException.class).hasRootCause(testException);

verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -135,102 +136,116 @@ void writesOneByte() throws Exception {
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));

+ final UploadPartRequest value = uploadPartRequestCaptor.getValue();
assertUploadPartRequest(
- uploadPartRequestCaptor.getValue(),
+ value,
+ value.getInputStream().readAllBytes(),
1,
1,
new byte[] {1});
assertCompleteMultipartUploadRequest(
completeMultipartUploadRequestCaptor.getValue(),
List.of(new PartETag(1, "SOME_ETAG"))
Map.of(1, "SOME_ETAG")
);
}

@Test
void writesMultipleMessages() throws Exception {
final int bufferSize = 10;
- final byte[] message = new byte[bufferSize];

when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
.thenReturn(newInitiateMultipartUploadResult());

+ final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+ final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
- .thenAnswer(a -> {
- final UploadPartRequest up = a.getArgument(0);
+ .thenAnswer(answer -> {
+ final UploadPartRequest up = answer.getArgument(0);
+ // emulate the S3 client's behavior, otherwise we will get the wrong array in memory
+ uploadPartRequests.put(up.getPartNumber(), up);
+ uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());

return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
});
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
.thenReturn(new CompleteMultipartUploadResult());

- final List<byte[]> expectedMessagesList = new ArrayList<>();
+ final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
try (final S3MultiPartOutputStream out =
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
for (int i = 0; i < 3; i++) {
+ final byte[] message = new byte[bufferSize];
random.nextBytes(message);
out.write(message, 0, message.length);
- expectedMessagesList.add(message);
+ expectedMessageParts.put(i + 1, message);
}
}

verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));

- final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
- int counter = 0;
- for (final byte[] expectedMessage : expectedMessagesList) {
+ for (final Integer part : expectedMessageParts.keySet()) {
assertUploadPartRequest(
- uploadRequests.get(counter),
+ uploadPartRequests.get(part),
+ uploadPartContents.get(part),
bufferSize,
- counter + 1,
- expectedMessage);
- counter++;
+ part,
+ expectedMessageParts.get(part)
+ );
}
assertCompleteMultipartUploadRequest(
completeMultipartUploadRequestCaptor.getValue(),
List.of(new PartETag(1, "SOME_TAG#1"),
new PartETag(2, "SOME_TAG#2"),
new PartETag(3, "SOME_TAG#3"))
Map.of(1, "SOME_TAG#1",
2, "SOME_TAG#2",
3, "SOME_TAG#3")
);
}

@Test
void writesTailMessages() throws Exception {
final int messageSize = 20;

- final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
+ final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
+ final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();

when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
.thenReturn(newInitiateMultipartUploadResult());
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
- .thenAnswer(a -> {
- final UploadPartRequest up = a.getArgument(0);
+ .thenAnswer(answer -> {
+ final UploadPartRequest up = answer.getArgument(0);
// emulate the S3 client's behavior, otherwise we will get the wrong array in memory
- up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
- uploadPartRequests.add(up);
+ uploadPartRequests.put(up.getPartNumber(), up);
+ uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());

return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
});
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
.thenReturn(new CompleteMultipartUploadResult());

- final byte[] message = new byte[messageSize];

final byte[] expectedFullMessage = new byte[messageSize + 10];
final byte[] expectedTailMessage = new byte[10];

- final S3MultiPartOutputStream
- out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
- random.nextBytes(message);
- out.write(message);
- System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
- random.nextBytes(message);
- out.write(message);
- System.arraycopy(message, 0, expectedFullMessage, 20, 10);
- System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+ final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
+ {
+ final byte[] message = new byte[messageSize];
+ random.nextBytes(message);
+ out.write(message);
+ System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
+ }
+ {
+ final byte[] message = new byte[messageSize];
+ random.nextBytes(message);
+ out.write(message);
+ System.arraycopy(message, 0, expectedFullMessage, 20, 10);
+ System.arraycopy(message, 10, expectedTailMessage, 0, 10);
+ }
out.close();

- assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
- assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
+ assertThat(uploadPartRequests).hasSize(2);
+ assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
+ assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);

verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
}

private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
+ final byte[] bytes,
final int expectedPartSize,
final int expectedPartNumber,
final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
- assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
+ assertThat(bytes).isEqualTo(expectedBytes);
}

private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
- final List<PartETag> expectedETags) {
+ final Map<Integer, String> expectedETags) {
assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
assertThat(request.getKey()).isEqualTo(FILE_KEY);
assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
- assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
-
- for (int i = 0; i < expectedETags.size(); i++) {
- final PartETag expectedETag = expectedETags.get(i);
- final PartETag etag = request.getPartETags().get(i);
-
- assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
- assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
- }
+ final Map<Integer, String> tags = request.getPartETags().stream()
+ .collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
+ assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
}

private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {
