Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(Files): Add integration tests for Files.copy / read* #252

Merged
merged 7 commits into from
Nov 6, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ public static void putObject(String bucket, String key) {
.doesNotThrowAnyException();
}

public static void putObject(String bucket, String key, String content) {
assertThatCode(() -> {
Container.ExecResult execResultCreateFile = LOCAL_STACK_CONTAINER.execInContainer("sh", "-c", "echo -n '" + content + "' > " + key);
Container.ExecResult execResultPut = LOCAL_STACK_CONTAINER.execInContainer(("awslocal s3api put-object --bucket " + bucket + " --key " + key + " --body " + key).split(" "));

assertThat(execResultCreateFile.getExitCode()).isZero();
assertThat(execResultPut.getExitCode()).withFailMessage("Failed put: %s ", execResultPut.getStderr()).isZero();
}).as("Failed to put object '%s' in bucket '%s'", key, bucket)
.doesNotThrowAnyException();
}

public static String localStackConnectionEndpoint() {
return localStackConnectionEndpoint(null, null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package software.amazon.nio.spi.s3;

import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static org.assertj.core.api.Assertions.assertThat;
import static software.amazon.nio.spi.s3.Containers.localStackConnectionEndpoint;
import static software.amazon.nio.spi.s3.Containers.putObject;

@DisplayName("Files$copy should load file contents from localstack")
public class FilesCopyTest
{
@TempDir
Path tempDir;

@Test
@DisplayName("when doing copy of existing file")
public void fileCopyShouldCopyFileWhenFileFound() throws IOException {
Containers.createBucket("sink");
putObject("sink", "files-copy.txt", "some content");
final Path path = Paths.get(URI.create(localStackConnectionEndpoint() + "/sink/files-copy.txt"));
Path copiedFile = Files.copy(path, tempDir.resolve("sample-file-local.txt"));
assertThat(copiedFile).hasContent("some content");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package software.amazon.nio.spi.s3;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static org.assertj.core.api.BDDAssertions.then;
import static software.amazon.nio.spi.s3.Containers.localStackConnectionEndpoint;
import static software.amazon.nio.spi.s3.Containers.putObject;

@DisplayName("Files$read* should load file contents from localstack")
public class FilesReadTest
{
private final Path path = Paths.get(URI.create(localStackConnectionEndpoint() + "/sink/files-read.txt"));

@BeforeAll
public static void createBucketAndFile(){
Containers.createBucket("sink");
putObject("sink", "files-read.txt", "some content");
}

@Test
@DisplayName("when doing readAllBytes from existing file in s3")
public void fileReadAllBytesShouldReturnFileContentsWhenFileFound() throws IOException {
then(Files.readAllBytes(path)).isEqualTo("some content".getBytes());
}

@Test
@DisplayName("when doing readAllLines from existing file in s3")
public void fileReadAllLinesShouldReturnFileContentWhenFileFound() throws IOException {
then(String.join("", Files.readAllLines(path))).isEqualTo("some content");
}

}
103 changes: 57 additions & 46 deletions src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3BaseClientBuilder;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.s3.S3CrtAsyncClientBuilder;
import software.amazon.awssdk.services.s3.model.HeadBucketResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
Expand All @@ -55,15 +55,15 @@ public class S3ClientProvider {
final protected S3NioSpiConfiguration configuration;

/**
* Default client using the "https://s3.us-east-1.amazonaws.com" endpoint
* Default client using the "<a href="https://s3.us-east-1.amazonaws.com">...</a>" endpoint
*/
private static final S3Client DEFAULT_CLIENT = S3Client.builder()
.endpointOverride(URI.create("https://s3.us-east-1.amazonaws.com"))
.region(Region.US_EAST_1)
.build();

/**
* Default asynchronous client using the "https://s3.us-east-1.amazonaws.com" endpoint
* Default asynchronous client using the "<a href="https://s3.us-east-1.amazonaws.com">...</a>" endpoint
*/
private static final S3AsyncClient DEFAULT_ASYNC_CLIENT = S3AsyncClient.builder()
.endpointOverride(URI.create("https://s3.us-east-1.amazonaws.com"))
Expand Down Expand Up @@ -130,34 +130,32 @@ S3Client universalClient() {
* that can be used by certain S3 operations for discovery
*
* @param async true to return an asynchronous client, false otherwise
* @param <T> type of AwsClient
* @param <T> type of AwsClient
* @return a S3Client not bound to a region
*/
<T extends AwsClient> T universalClient(boolean async) {
return (T)((async) ? DEFAULT_ASYNC_CLIENT : DEFAULT_CLIENT);
return (T) ((async) ? DEFAULT_ASYNC_CLIENT : DEFAULT_CLIENT);
}

/**
* Generates a sync client for the named bucket using the provided location
* discovery client.
*
* @param bucket the named of the bucket to make the client for
*
* @param crt whether to return a CRT async client or not
* @return an S3 client appropriate for the region of the named bucket
*
*/
protected S3AsyncClient generateAsyncClient(String bucket) {
return generateAsyncClient(bucket, universalClient());
protected S3AsyncClient generateAsyncClient(String bucket, boolean crt) {
return generateAsyncClient(bucket, universalClient(), crt);
}

/**
* Generate a client for the named bucket using a provided client to
* determine the location of the named client
*
* @param bucketName the name of the bucket to make the client for
* @param bucketName the name of the bucket to make the client for
* @param locationClient the client used to determine the location of the
* named bucket, recommend using DEFAULT_CLIENT
*
* named bucket, recommend using DEFAULT_CLIENT
* @return an S3 client appropriate for the region of the named bucket
*/
S3Client generateSyncClient(String bucketName, S3Client locationClient) {
Expand All @@ -168,14 +166,14 @@ S3Client generateSyncClient(String bucketName, S3Client locationClient) {
* Generate an async client for the named bucket using a provided client to
* determine the location of the named client
*
* @param bucketName the name of the bucket to make the client for
* @param bucketName the name of the bucket to make the client for
* @param locationClient the client used to determine the location of the
* named bucket, recommend using DEFAULT_CLIENT
*
* named bucket, recommend using DEFAULT_CLIENT
* @param crt whether to return a CRT async client or not
* @return an S3 client appropriate for the region of the named bucket
*/
S3AsyncClient generateAsyncClient (String bucketName, S3Client locationClient) {
return getClientForBucket(bucketName, locationClient, this::asyncClientForRegion);
S3AsyncClient generateAsyncClient(String bucketName, S3Client locationClient, boolean crt) {
return getClientForBucket(bucketName, locationClient, (region) -> asyncClientForRegion(region, crt));
}

private <T extends AwsClient> T getClientForBucket(
Expand All @@ -186,11 +184,8 @@ private <T extends AwsClient> T getClientForBucket(
logger.debug("generating client for bucket: '{}'", bucketName);
T bucketSpecificClient = null;

if ((configuration.getEndpoint() == null) || configuration.getEndpoint().isBlank()) {
//
// we try to locate a bucket only if no endpoint is provided, which
// means we are dealing with AWS S3 buckets
//
if (configuration.endpointURI() == null) {
// we try to locate a bucket only if no endpoint is provided, which means we are dealing with AWS S3 buckets
String bucketLocation = determineBucketLocation(bucketName, locationClient);

if ( bucketLocation != null) {
Expand Down Expand Up @@ -248,48 +243,64 @@ private String getBucketRegionFromResponse(SdkHttpResponse response) {
}

private S3Client clientForRegion(String regionName) {
String endpoint = configuration.getEndpoint();
AwsCredentials credentials = configuration.getCredentials();
Region region = ((regionName == null) || (regionName.trim().isEmpty())) ? Region.US_EAST_1 : Region.of(regionName);
return configureClientForRegion(regionName, S3Client.builder());
}

private S3AsyncClient asyncClientForRegion(String regionName, boolean crt) {
if (!crt) {
return configureClientForRegion(regionName, S3AsyncClient.builder());
}
return configureCrtClientForRegion(regionName);
}

private <ActualClient extends AwsClient, ActualBuilder extends S3BaseClientBuilder<ActualBuilder, ActualClient>> ActualClient configureClientForRegion(
String regionName,
S3BaseClientBuilder<ActualBuilder, ActualClient> builder)
{
Region region = getRegionFromRegionName(regionName);
logger.debug("bucket region is: '{}'", region.id());

S3ClientBuilder clientBuilder = S3Client.builder()
.forcePathStyle(configuration.getForcePathStyle())
.region(region)
.overrideConfiguration(
conf -> conf.retryPolicy(
builder -> builder.retryCondition(retryCondition).backoffStrategy(backoffStrategy)
)
);

if (!endpoint.isBlank()) {
clientBuilder.endpointOverride(URI.create(configuration.getEndpointProtocol() + "://" + endpoint));
builder
.forcePathStyle(configuration.getForcePathStyle())
.region(region)
.overrideConfiguration(
conf -> conf.retryPolicy(
configBuilder -> configBuilder.retryCondition(retryCondition).backoffStrategy(backoffStrategy)
)
);

URI endpointUri = configuration.endpointURI();
if (endpointUri != null) {
builder.endpointOverride(endpointUri);
}

AwsCredentials credentials = configuration.getCredentials();
if (credentials != null) {
clientBuilder.credentialsProvider(() -> credentials);
builder.credentialsProvider(() -> credentials);
}

return clientBuilder.build();
return builder.build();
}

private S3AsyncClient asyncClientForRegion(String regionName) {
String endpoint = configuration.getEndpoint();
AwsCredentials credentials = configuration.getCredentials();

Region region = ((regionName == null) || (regionName.trim().isEmpty())) ? Region.US_EAST_1 : Region.of(regionName);

private S3AsyncClient configureCrtClientForRegion(String regionName) {
Region region = getRegionFromRegionName(regionName);
logger.debug("bucket region is: '{}'", region.id());

if (!endpoint.isBlank()) {
asyncClientBuilder.endpointOverride(URI.create(configuration.getEndpointProtocol() + "://" + endpoint));
URI endpointUri = configuration.endpointURI();
if (endpointUri != null) {
asyncClientBuilder.endpointOverride(endpointUri);
}

AwsCredentials credentials = configuration.getCredentials();
if (credentials != null) {
asyncClientBuilder.credentialsProvider(() -> credentials);
}

return asyncClientBuilder.forcePathStyle(configuration.getForcePathStyle()).region(region).build();
}

private static Region getRegionFromRegionName(String regionName) {
return (regionName == null || regionName.isBlank()) ? Region.US_EAST_1 : Region.of(regionName);
}

}
6 changes: 5 additions & 1 deletion src/main/java/software/amazon/nio/spi/s3/S3FileSystem.java
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,16 @@ public void clientProvider(S3ClientProvider clientProvider) {
*/
S3AsyncClient client() {
if (client == null) {
client = clientProvider.generateAsyncClient(bucketName);
client = clientProvider.generateAsyncClient(bucketName, true);
}

return client;
}

S3AsyncClient readClient() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why you wouldn't just overload client() with client(boolean crtClient)? Either way, it is probably worth documenting why a CRT client is not used in this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client() method saves a reference of that client in the current class. Unlike client(), readClient() creates a new client every time a read needs to be done. This client is later closed when the readChannel is close(). I think it'd be clearer to have them in separate methods for now, one is a general purpose and the other one is only for reads.

return clientProvider.generateAsyncClient(bucketName, false);
}

/**
* Obtain the name of the bucket represented by this <code>FileSystem</code> instance
* @return the bucket name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ public void close() {
open = false;
readAheadBuffersCache.invalidateAll();
readAheadBuffersCache.cleanUp();
client.close();
}

private void clearPriorFragments(int currentFragIndx) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ private S3SeekableByteChannel(S3Path s3Path, S3AsyncClient s3Client, long startA
position = 0L;
} else if (options.contains(StandardOpenOption.READ) || options.isEmpty()) {
LOGGER.debug("using S3ReadAheadByteChannel as read delegate for path '{}'", s3Path.toUri());
readDelegate = new S3ReadAheadByteChannel(s3Path, config.getMaxFragmentSize(), config.getMaxFragmentNumber(), s3Client, this, timeout, timeUnit);
S3AsyncClient readClient = s3Path.getFileSystem().readClient();
readDelegate = new S3ReadAheadByteChannel(s3Path, config.getMaxFragmentSize(), config.getMaxFragmentNumber(), readClient, this, timeout, timeUnit);
writeDelegate = null;
} else {
throw new IOException("Invalid channel mode");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package software.amazon.nio.spi.s3.config;


import java.net.URI;
import java.util.HashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -416,4 +417,12 @@ private int parseIntProperty(String propName, int defaultVal){
return defaultVal;
}
}

public URI endpointURI() {
String endpoint = getEndpoint();
if (endpoint.isBlank()) {
return null;
}
return URI.create(String.format("%s://%s", getEndpointProtocol(), getEndpoint()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ S3Client universalClient() {
}

@Override
protected S3AsyncClient generateAsyncClient(String bucketName) {
protected S3AsyncClient generateAsyncClient(String bucketName, boolean crt) {
return (S3AsyncClient)client;
}

Expand Down
Loading