diff --git a/.changes/next-release/bugfix-AmazonS3-421839e.json b/.changes/next-release/bugfix-AmazonS3-421839e.json
new file mode 100644
index 000000000000..117be49f54bb
--- /dev/null
+++ b/.changes/next-release/bugfix-AmazonS3-421839e.json
@@ -0,0 +1,6 @@
+{
+ "category": "Amazon S3",
+ "contributor": "",
+ "type": "bugfix",
+ "description": "Fixed an issue where checksum validation only considered the first 4 bytes of the 16 byte checksum, creating the potential for corrupted downloads to go undetected."
+}
diff --git a/pom.xml b/pom.xml
index 760e167ec03c..415f8741b8e1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -511,9 +511,8 @@
*.internal.*
-
- software.amazon.awssdk.core.util.json.JacksonUtils
- software.amazon.awssdk.protocols.json.*
+
+ software.amazon.awssdk.services.s3.checksums.ChecksumValidatingInputStream
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java
index ab089377e8cc..36fff298baf0 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStream.java
@@ -17,11 +17,12 @@
import java.io.IOException;
import java.io.InputStream;
-import java.nio.ByteBuffer;
+import java.util.Arrays;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.Abortable;
+import software.amazon.awssdk.utils.BinaryUtils;
@SdkInternalApi
public class ChecksumValidatingInputStream extends InputStream implements Abortable {
@@ -34,7 +35,7 @@ public class ChecksumValidatingInputStream extends InputStream implements Aborta
private long lengthRead = 0;
// Preserve the computed checksum because some InputStream readers (e.g., java.util.Properties) read more than once at the
// end of the stream.
- private Integer computedChecksum;
+ private byte[] computedChecksum;
/**
* Creates an input stream using the specified Checksum, input stream, and length.
@@ -162,26 +163,15 @@ public void close() throws IOException {
inputStream.close();
}
- /**
- * Gets the stream's checksum as an integer.
- *
- * @return checksum.
- */
- public int getStreamChecksum() {
- ByteBuffer bb = ByteBuffer.wrap(streamChecksum);
- return bb.getInt();
- }
-
private void validateAndThrow() {
- int streamChecksumInt = getStreamChecksum();
if (computedChecksum == null) {
- computedChecksum = ByteBuffer.wrap(checkSum.getChecksumBytes()).getInt();
+ computedChecksum = checkSum.getChecksumBytes();
}
- if (streamChecksumInt != computedChecksum) {
+ if (!Arrays.equals(computedChecksum, streamChecksum)) {
throw SdkClientException.builder().message(
- String.format("Data read has a different checksum than expected. Was %d, but expected %d",
- computedChecksum, streamChecksumInt)).build();
+ String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s",
+ BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum))).build();
}
}
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java
index 2c871470d84e..a3310331dd23 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisher.java
@@ -132,12 +132,11 @@ public void onError(Throwable t) {
@Override
public void onComplete() {
if (strippedLength > 0) {
- int streamChecksumInt = ByteBuffer.wrap(streamChecksum).getInt();
- int computedChecksumInt = ByteBuffer.wrap(sdkChecksum.getChecksumBytes()).getInt();
- if (streamChecksumInt != computedChecksumInt) {
+ byte[] computedChecksum = sdkChecksum.getChecksumBytes();
+ if (!Arrays.equals(computedChecksum, streamChecksum)) {
onError(SdkClientException.create(
- String.format("Data read has a different checksum than expected. Was %d, but expected %d",
- computedChecksumInt, streamChecksumInt)));
+ String.format("Data read has a different checksum than expected. Was 0x%s, but expected 0x%s",
+ BinaryUtils.toHex(computedChecksum), BinaryUtils.toHex(streamChecksum))));
return; // Return after onError and not call onComplete below
}
}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java
new file mode 100644
index 000000000000..a83bf45b4155
--- /dev/null
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingInputStreamTest.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.services.s3.checksums;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import software.amazon.awssdk.core.checksums.Md5Checksum;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.utils.IoUtils;
+
+public class ChecksumValidatingInputStreamTest {
+ private static final int TEST_DATA_SIZE = 32;
+ private static final int CHECKSUM_SIZE = 16;
+
+ private static byte[] testData;
+ private static byte[] testDataWithoutChecksum;
+
+ @BeforeClass
+ public static void populateData() {
+ testData = new byte[TEST_DATA_SIZE + CHECKSUM_SIZE];
+ for (int i = 0; i < TEST_DATA_SIZE; i++) {
+ testData[i] = (byte)(i & 0x7f);
+ }
+
+ Md5Checksum checksum = new Md5Checksum();
+ checksum.update(testData, 0, TEST_DATA_SIZE);
+ byte[] checksumBytes = checksum.getChecksumBytes();
+
+ for (int i = 0; i < CHECKSUM_SIZE; i++) {
+ testData[TEST_DATA_SIZE + i] = checksumBytes[i];
+ }
+
+ testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
+ }
+
+ @Test
+ public void validChecksumSucceeds() throws IOException {
+ InputStream validatingInputStream = newValidatingStream(testData);
+ byte[] dataFromValidatingStream = IoUtils.toByteArray(validatingInputStream);
+
+ assertArrayEquals(testDataWithoutChecksum, dataFromValidatingStream);
+ }
+
+ @Test
+ public void invalidChecksumFails() throws IOException {
+ for (int i = 0; i < testData.length; i++) {
+ // Make sure that corruption of any byte in the test data causes a checksum validation failure.
+ byte[] corruptedChecksumData = Arrays.copyOf(testData, testData.length);
+ corruptedChecksumData[i] = (byte) ~corruptedChecksumData[i];
+
+ InputStream validatingInputStream = newValidatingStream(corruptedChecksumData);
+
+ try {
+ IoUtils.toByteArray(validatingInputStream);
+ Assert.fail("Corruption at byte " + i + " was not detected.");
+ } catch (SdkClientException e) {
+ // Expected
+ }
+ }
+ }
+
+ private InputStream newValidatingStream(byte[] dataFromS3) {
+ return new ChecksumValidatingInputStream(new ByteArrayInputStream(dataFromS3),
+ new Md5Checksum(),
+ TEST_DATA_SIZE + CHECKSUM_SIZE);
+ }
+}
\ No newline at end of file
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java
index 935b656d8539..23027a2317fc 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumValidatingPublisherTest.java
@@ -16,11 +16,13 @@
package software.amazon.awssdk.services.s3.checksums;
import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
@@ -31,6 +33,7 @@
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.checksums.Md5Checksum;
+import software.amazon.awssdk.utils.BinaryUtils;
/**
* Unit test for ChecksumValidatingPublisher
@@ -39,6 +42,7 @@ public class ChecksumValidatingPublisherTest {
private static int TEST_DATA_SIZE = 32; // size of the test data, in bytes
private static final int CHECKSUM_SIZE = 16;
private static byte[] testData;
+ private static byte[] testDataWithoutChecksum;
@BeforeClass
public static void populateData() {
@@ -52,27 +56,47 @@ public static void populateData() {
for (int i = 0; i < CHECKSUM_SIZE; i++) {
testData[TEST_DATA_SIZE + i] = checksumBytes[i];
}
+
+ testDataWithoutChecksum = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
}
@Test
public void testSinglePacket() {
final TestPublisher driver = new TestPublisher();
- final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
+ final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);
driver.doOnNext(ByteBuffer.wrap(testData));
driver.doOnComplete();
+ assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
+ @Test
+ public void testLastChecksumByteCorrupted() {
+ TestPublisher driver = new TestPublisher();
+
+ TestSubscriber s = new TestSubscriber();
+ ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
+ p.subscribe(s);
+
+ byte[] incorrectChecksumData = Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE);
+ incorrectChecksumData[TEST_DATA_SIZE - 1] = (byte) ~incorrectChecksumData[TEST_DATA_SIZE - 1];
+ driver.doOnNext(ByteBuffer.wrap(incorrectChecksumData));
+ driver.doOnComplete();
+
+ assertFalse(s.hasCompleted());
+ assertTrue(s.isOnErrorCalled());
+ }
+
@Test
public void testTwoPackets() {
for (int i = 1; i < TEST_DATA_SIZE + CHECKSUM_SIZE - 1; i++) {
final TestPublisher driver = new TestPublisher();
- final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
+ final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);
@@ -80,6 +104,7 @@ public void testTwoPackets() {
driver.doOnNext(ByteBuffer.wrap(testData, i, TEST_DATA_SIZE + CHECKSUM_SIZE - i));
driver.doOnComplete();
+ assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
@@ -89,7 +114,7 @@ public void testTwoPackets() {
public void testTinyPackets() {
for (int packetSize = 1; packetSize < CHECKSUM_SIZE; packetSize++) {
final TestPublisher driver = new TestPublisher();
- final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
+ final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);
int currOffset = 0;
@@ -100,6 +125,7 @@ public void testTinyPackets() {
}
driver.doOnComplete();
+ assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
@@ -109,7 +135,7 @@ public void testTinyPackets() {
public void testUnknownLength() {
// When the length is unknown, the last 16 bytes are treated as a checksum, but are later ignored when completing
final TestPublisher driver = new TestPublisher();
- final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(testData, 0, TEST_DATA_SIZE));
+ final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), 0);
p.subscribe(s);
@@ -122,6 +148,7 @@ public void testUnknownLength() {
driver.doOnNext(ByteBuffer.wrap(randomChecksumData));
driver.doOnComplete();
+ assertArrayEquals(testDataWithoutChecksum, s.receivedData());
assertTrue(s.hasCompleted());
assertFalse(s.isOnErrorCalled());
}
@@ -130,7 +157,7 @@ public void testUnknownLength() {
public void checksumValidationFailure_throwsSdkClientException_NotNPE() {
final byte[] incorrectData = new byte[0];
final TestPublisher driver = new TestPublisher();
- final TestSubscriber s = new TestSubscriber(Arrays.copyOfRange(incorrectData, 0, TEST_DATA_SIZE));
+ final TestSubscriber s = new TestSubscriber();
final ChecksumValidatingPublisher p = new ChecksumValidatingPublisher(driver, new Md5Checksum(), TEST_DATA_SIZE + CHECKSUM_SIZE);
p.subscribe(s);
@@ -142,13 +169,11 @@ public void checksumValidationFailure_throwsSdkClientException_NotNPE() {
}
private class TestSubscriber implements Subscriber {
- final byte[] expected;
final List received;
boolean completed;
boolean onErrorCalled;
- TestSubscriber(byte[] expected) {
- this.expected = expected;
+ TestSubscriber() {
this.received = new ArrayList<>();
this.completed = false;
}
@@ -172,17 +197,21 @@ public void onError(Throwable t) {
@Override
public void onComplete() {
- int matchPos = 0;
- for (ByteBuffer buffer : received) {
- byte[] bufferData = new byte[buffer.limit() - buffer.position()];
- buffer.get(bufferData);
- assertArrayEquals(Arrays.copyOfRange(expected, matchPos, matchPos + bufferData.length), bufferData);
- matchPos += bufferData.length;
- }
- assertEquals(expected.length, matchPos);
completed = true;
}
+ public byte[] receivedData() {
+ try {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ for (ByteBuffer buffer : received) {
+ os.write(BinaryUtils.copyBytesFrom(buffer));
+ }
+ return os.toByteArray();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
public boolean hasCompleted() {
return completed;
}