diff --git a/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/S3AInputStream.java b/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/S3AInputStream.java index 3a0b669543edf..a31dd2087c3bd 100644 --- a/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/S3AInputStream.java +++ b/hadoop-tools/hadoop-aws/src/main/java/org/apache/hadoop/fs/s3a/S3AInputStream.java @@ -509,11 +509,10 @@ public synchronized int read(byte[] buf, int off, int len) int bytesRead = invoker.retry("read", pathStr, true, () -> { int bytes; - // When exception happens before re-setting wrappedStream in "reopen" called - // by onReadFailure, then wrappedStream will be null. But the **retry** may - // re-execute this block and cause NPE if we don't check wrappedStream + // When exception happens, onReadFailure closes the stream and + // sets wrappedStream to null. In this case, do not reopen the stream. if (wrappedStream == null) { - reopen("failure recovery", getPos(), 1, false); + return 0; } try { bytes = wrappedStream.read(buf, off, len); diff --git a/hadoop-tools/hadoop-aws/src/test/java/org/apache/hadoop/fs/s3a/TestS3AInputStreamRetry.java b/hadoop-tools/hadoop-aws/src/test/java/org/apache/hadoop/fs/s3a/TestS3AInputStreamRetry.java index db5b5b56851ea..32515ad534966 100644 --- a/hadoop-tools/hadoop-aws/src/test/java/org/apache/hadoop/fs/s3a/TestS3AInputStreamRetry.java +++ b/hadoop-tools/hadoop-aws/src/test/java/org/apache/hadoop/fs/s3a/TestS3AInputStreamRetry.java @@ -53,7 +53,7 @@ public class TestS3AInputStreamRetry extends AbstractS3AMockTest { @Test public void testInputStreamReadRetryForException() throws IOException { - S3AInputStream s3AInputStream = getMockedS3AInputStream(); + S3AInputStream s3AInputStream = getMockedS3AInputStream(true); assertEquals("'a' from the test input stream 'ab' should be the first " + "character being read", INPUT.charAt(0), s3AInputStream.read()); @@ -64,18 +64,21 @@ public void testInputStreamReadRetryForException() throws IOException { @Test public void testInputStreamReadLengthRetryForException() throws IOException { byte[] result = new byte[INPUT.length()]; - S3AInputStream s3AInputStream = getMockedS3AInputStream(); - s3AInputStream.read(result, 0, INPUT.length()); + S3AInputStream s3AInputStream = getMockedS3AInputStream(true); + int bytesRead = s3AInputStream.read(result, 0, INPUT.length()); + + assertEquals("Zero bytes should be read on failure", 0, bytesRead); assertArrayEquals( - "The read result should equals to the test input stream content", - INPUT.getBytes(), result); + "The read result should be empty", + new byte[INPUT.length()], result); } @Test public void testInputStreamReadFullyRetryForException() throws IOException { byte[] result = new byte[INPUT.length()]; - S3AInputStream s3AInputStream = getMockedS3AInputStream(); + + S3AInputStream s3AInputStream = getMockedS3AInputStream(false); s3AInputStream.readFully(0, result); assertArrayEquals( @@ -83,7 +86,7 @@ public void testInputStreamReadFullyRetryForException() throws IOException { INPUT.getBytes(), result); } - private S3AInputStream getMockedS3AInputStream() { + private S3AInputStream getMockedS3AInputStream(boolean triggerGetObjectFailure) { Path path = new Path("test-path"); String eTag = "test-etag"; String versionId = "test-version-id"; @@ -109,15 +112,17 @@ private S3AInputStream getMockedS3AInputStream() { return new S3AInputStream( s3AReadOpContext, s3ObjectAttributes, - getMockedInputStreamCallback()); + getMockedInputStreamCallback(triggerGetObjectFailure)); } /** * Get mocked InputStreamCallbacks where we return mocked S3Object. * + * @param triggerGetObjectFailure true when getObject failure is enabled. * @return mocked object. */ - private S3AInputStream.InputStreamCallbacks getMockedInputStreamCallback() { + private S3AInputStream.InputStreamCallbacks getMockedInputStreamCallback( + boolean triggerGetObjectFailure) { return new S3AInputStream.InputStreamCallbacks() { private final S3Object mockedS3Object = getMockedS3Object(); @@ -140,7 +145,7 @@ public S3Object getObject(GetObjectRequest request) { // -> retry(3) -> wrappedStream==null -> reopen -> getObject (mockedS3ObjectIndex=4) // -> getObjectContent(objectInputStreamGood)-> objectInputStreamGood // -> wrappedStream.read - if (mockedS3ObjectIndex == 3) { + if (mockedS3ObjectIndex == 3 && triggerGetObjectFailure) { throw new SdkClientException("Failed to get S3Object"); } return mockedS3Object; @@ -207,20 +212,21 @@ private S3ObjectInputStream getMockedInputStream(boolean triggerFailure) { @Override public int read() throws IOException { - int result = super.read(); if (triggerFailure) { throw exception; } - return result; + + return super.read(); } @Override public int read(byte[] b, int off, int len) throws IOException { - int result = super.read(b, off, len); + if (triggerFailure) { throw exception; } - return result; + + return super.read(b, off, len); } }; }