Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -509,11 +509,10 @@ public synchronized int read(byte[] buf, int off, int len)
int bytesRead = invoker.retry("read", pathStr, true,
() -> {
int bytes;
// When exception happens before re-setting wrappedStream in "reopen" called
// by onReadFailure, then wrappedStream will be null. But the **retry** may
// re-execute this block and cause NPE if we don't check wrappedStream
// When exception happens, onReadFailure closes the stream and
// sets wrappedStream to null. In this case, do not reopen the stream.
if (wrappedStream == null) {
reopen("failure recovery", getPos(), 1, false);
return 0;
}
try {
bytes = wrappedStream.read(buf, off, len);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class TestS3AInputStreamRetry extends AbstractS3AMockTest {

@Test
public void testInputStreamReadRetryForException() throws IOException {
S3AInputStream s3AInputStream = getMockedS3AInputStream();
S3AInputStream s3AInputStream = getMockedS3AInputStream(true);

assertEquals("'a' from the test input stream 'ab' should be the first " +
"character being read", INPUT.charAt(0), s3AInputStream.read());
Expand All @@ -64,26 +64,29 @@ public void testInputStreamReadRetryForException() throws IOException {
@Test
public void testInputStreamReadLengthRetryForException() throws IOException {
byte[] result = new byte[INPUT.length()];
S3AInputStream s3AInputStream = getMockedS3AInputStream();
s3AInputStream.read(result, 0, INPUT.length());
S3AInputStream s3AInputStream = getMockedS3AInputStream(true);
int bytesRead = s3AInputStream.read(result, 0, INPUT.length());

assertEquals("Zero bytes should be read on failure", 0, bytesRead);

assertArrayEquals(
"The read result should equals to the test input stream content",
INPUT.getBytes(), result);
"The read result should be empty",
new byte[INPUT.length()], result);
}

@Test
public void testInputStreamReadFullyRetryForException() throws IOException {
byte[] result = new byte[INPUT.length()];
S3AInputStream s3AInputStream = getMockedS3AInputStream();

S3AInputStream s3AInputStream = getMockedS3AInputStream(false);
s3AInputStream.readFully(0, result);

assertArrayEquals(
"The read result should equals to the test input stream content",
INPUT.getBytes(), result);
}

private S3AInputStream getMockedS3AInputStream() {
private S3AInputStream getMockedS3AInputStream(boolean triggerGetObjectFailure) {
Path path = new Path("test-path");
String eTag = "test-etag";
String versionId = "test-version-id";
Expand All @@ -109,15 +112,17 @@ private S3AInputStream getMockedS3AInputStream() {
return new S3AInputStream(
s3AReadOpContext,
s3ObjectAttributes,
getMockedInputStreamCallback());
getMockedInputStreamCallback(triggerGetObjectFailure));
}

/**
* Get mocked InputStreamCallbacks where we return mocked S3Object.
*
* @param triggerGetObjectFailure true when getObject failure is enabled.
* @return mocked object.
*/
private S3AInputStream.InputStreamCallbacks getMockedInputStreamCallback() {
private S3AInputStream.InputStreamCallbacks getMockedInputStreamCallback(
boolean triggerGetObjectFailure) {
return new S3AInputStream.InputStreamCallbacks() {

private final S3Object mockedS3Object = getMockedS3Object();
Expand All @@ -140,7 +145,7 @@ public S3Object getObject(GetObjectRequest request) {
// -> retry(3) -> wrappedStream==null -> reopen -> getObject (mockedS3ObjectIndex=4)
// -> getObjectContent(objectInputStreamGood)-> objectInputStreamGood
// -> wrappedStream.read
if (mockedS3ObjectIndex == 3) {
if (mockedS3ObjectIndex == 3 && triggerGetObjectFailure) {
throw new SdkClientException("Failed to get S3Object");
}
return mockedS3Object;
Expand Down Expand Up @@ -207,20 +212,21 @@ private S3ObjectInputStream getMockedInputStream(boolean triggerFailure) {

@Override
public int read() throws IOException {
int result = super.read();
if (triggerFailure) {
throw exception;
}
return result;

return super.read();
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
int result = super.read(b, off, len);

if (triggerFailure) {
throw exception;
}
return result;

return super.read(b, off, len);
}
};
}
Expand Down