Skip to content

Commit

Permalink
fix(236): Use non-crt async client for S3ReadAheadByteChannel
Browse files Browse the repository at this point in the history
  • Loading branch information
guicamest committed Nov 4, 2023
1 parent 0c0deab commit e4a9d68
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 14 deletions.
15 changes: 10 additions & 5 deletions src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ <T extends AwsClient> T universalClient(boolean async) {
* discovery client.
*
* @param bucket the named of the bucket to make the client for
* @param crt whether to return a CRT async client or not
* @return an S3 client appropriate for the region of the named bucket
*/
protected S3AsyncClient generateAsyncClient(String bucket) {
return generateAsyncClient(bucket, universalClient());
protected S3AsyncClient generateAsyncClient(String bucket, boolean crt) {
return generateAsyncClient(bucket, universalClient(), crt);
}

/**
Expand All @@ -165,10 +166,11 @@ S3Client generateSyncClient(String bucketName, S3Client locationClient) {
* @param bucketName the name of the bucket to make the client for
* @param locationClient the client used to determine the location of the
* named bucket, recommend using DEFAULT_CLIENT
* @param crt whether to return a CRT async client or not
* @return an S3 client appropriate for the region of the named bucket
*/
S3AsyncClient generateAsyncClient(String bucketName, S3Client locationClient) {
return getClientForBucket(bucketName, locationClient, this::asyncClientForRegion);
S3AsyncClient generateAsyncClient(String bucketName, S3Client locationClient, boolean crt) {
return getClientForBucket(bucketName, locationClient, (region) -> asyncClientForRegion(region, crt));
}

private <T extends AwsClient> T getClientForBucket(
Expand Down Expand Up @@ -244,7 +246,10 @@ private S3Client clientForRegion(String regionName) {
return configureClientForRegion(regionName, S3Client.builder());
}

private S3AsyncClient asyncClientForRegion(String regionName) {
private S3AsyncClient asyncClientForRegion(String regionName, boolean crt) {
if (!crt) {
return configureClientForRegion(regionName, S3AsyncClient.builder());
}
Region region = getRegionFromRegionName(regionName);
logger.debug("bucket region is: '{}'", region.id());

Expand Down
6 changes: 5 additions & 1 deletion src/main/java/software/amazon/nio/spi/s3/S3FileSystem.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,16 @@ public void clientProvider(S3ClientProvider clientProvider) {
*/
S3AsyncClient client() {
if (client == null) {
client = clientProvider.generateAsyncClient(bucketName);
client = clientProvider.generateAsyncClient(bucketName, true);
}

return client;
}

S3AsyncClient readClient() {
return clientProvider.generateAsyncClient(bucketName, false);
}

/**
* Obtain the name of the bucket represented by this <code>FileSystem</code> instance
* @return the bucket name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public void close() {
open = false;
readAheadBuffersCache.invalidateAll();
readAheadBuffersCache.cleanUp();
client.close();
}

private void clearPriorFragments(int currentFragIndx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ private S3SeekableByteChannel(S3Path s3Path, S3AsyncClient s3Client, long startA
position = 0L;
} else if (options.contains(StandardOpenOption.READ) || options.isEmpty()) {
LOGGER.debug("using S3ReadAheadByteChannel as read delegate for path '{}'", s3Path.toUri());
readDelegate = new S3ReadAheadByteChannel(s3Path, config.getMaxFragmentSize(), config.getMaxFragmentNumber(), s3Client, this, timeout, timeUnit);
S3AsyncClient readClient = s3Path.getFileSystem().readClient();
readDelegate = new S3ReadAheadByteChannel(s3Path, config.getMaxFragmentSize(), config.getMaxFragmentNumber(), readClient, this, timeout, timeUnit);
writeDelegate = null;
} else {
throw new IOException("Invalid channel mode");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ S3Client universalClient() {
}

@Override
protected S3AsyncClient generateAsyncClient(String bucketName) {
protected S3AsyncClient generateAsyncClient(String bucketName, boolean crt) {
return (S3AsyncClient)client;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public void initialization() {
public void testGenerateAsyncClientWithNoErrors() {
when(mockClient.getBucketLocation(anyConsumer()))
.thenReturn(GetBucketLocationResponse.builder().locationConstraint("us-west-2").build());
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient);
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient, true);
assertNotNull(s3Client);
}

Expand Down Expand Up @@ -107,7 +107,7 @@ public void testGenerateAsyncClientWith403Response() {
.build());

// which should get you a client
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient);
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient, true);
assertNotNull(s3Client);

final InOrder inOrder = inOrder(mockClient);
Expand Down Expand Up @@ -135,7 +135,7 @@ public void testGenerateAsyncClientWith403Then301Responses(){
);

// then you should be able to get a client as long as the error response header contains the region
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient);
final S3AsyncClient s3Client = provider.generateAsyncClient("test-bucket", mockClient, true);
assertNotNull(s3Client);

final InOrder inOrder = inOrder(mockClient);
Expand Down Expand Up @@ -189,7 +189,7 @@ public void testGenerateAsyncClientWith403Then301ResponsesNoHeader(){
);

// then you should get a NoSuchElement exception when you try to get the header
assertThrows(NoSuchElementException.class, () -> provider.generateAsyncClient("test-bucket", mockClient));
assertThrows(NoSuchElementException.class, () -> provider.generateAsyncClient("test-bucket", mockClient, true));

final InOrder inOrder = inOrder(mockClient);
inOrder.verify(mockClient).getBucketLocation(anyConsumer());
Expand All @@ -203,12 +203,12 @@ public void generateAsyncClientByEndpointBucketCredentials() {
provider.asyncClientBuilder = BUILDER;

provider.configuration.withEndpoint("endpoint1:1010");
provider.generateAsyncClient("bucket1");
provider.generateAsyncClient("bucket1", true);
then(BUILDER.endpointOverride.toString()).isEqualTo("https://endpoint1:1010");
then(BUILDER.region).isEqualTo(Region.US_EAST_1); // just a default in the case not provide

provider.configuration.withEndpoint("endpoint2:2020");
provider.generateAsyncClient("bucket2");
provider.generateAsyncClient("bucket2", true);
then(BUILDER.endpointOverride.toString()).isEqualTo("https://endpoint2:2020");
then(BUILDER.region).isEqualTo(Region.US_EAST_1); // just a default in the case not provide
}
Expand Down

0 comments on commit e4a9d68

Please sign in to comment.