Skip to content

Commit

Permalink
Merge pull request #59965 from nickitat/express_support
Browse files Browse the repository at this point in the history
S3Express support
  • Loading branch information
alexey-milovidov committed Mar 23, 2024
2 parents 61e74cc + f9d1c57 commit 056c8ce
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/Backups/BackupIO_S3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ namespace
.use_virtual_addressing = s3_uri.is_virtual_hosted_style,
.disable_checksum = local_settings.s3_disable_checksum,
.gcs_issue_compose_request = context->getConfigRef().getBool("s3.gcs_issue_compose_request", false),
.is_s3express_bucket = S3::isS3ExpressEndpoint(s3_uri.endpoint),
};

return S3::ClientFactory::instance().create(
Expand Down
1 change: 1 addition & 0 deletions src/Coordination/KeeperSnapshotManagerS3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void KeeperSnapshotManagerS3::updateS3Configuration(const Poco::Util::AbstractCo
.use_virtual_addressing = new_uri.is_virtual_hosted_style,
.disable_checksum = false,
.gcs_issue_compose_request = false,
.is_s3express_bucket = S3::isS3ExpressEndpoint(new_uri.endpoint),
};

auto client = S3::ClientFactory::instance().create(
Expand Down
17 changes: 14 additions & 3 deletions src/Disks/ObjectStorages/S3/diskSettings.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <Disks/ObjectStorages/S3/diskSettings.h>
#include "IO/S3/Client.h"
#include <IO/S3/Client.h>
#include <Common/Exception.h>

#if USE_AWS_S3

Expand All @@ -10,7 +11,7 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include "Disks/DiskFactory.h"
#include <Disks/DiskFactory.h>

#include <aws/core/client/DefaultRetryStrategy.h>
#include <base/getFQDNOrHostName.h>
Expand All @@ -25,6 +26,11 @@
namespace DB
{

namespace ErrorCodes
{
extern const int NO_ELEMENTS_IN_CONFIG;
}

std::unique_ptr<S3ObjectStorageSettings> getSettings(const Poco::Util::AbstractConfiguration & config, const String & config_prefix, ContextPtr context)
{
const Settings & settings = context->getSettingsRef();
Expand All @@ -47,11 +53,15 @@ std::unique_ptr<S3::Client> getClient(
const Settings & global_settings = context->getGlobalContext()->getSettingsRef();
const Settings & local_settings = context->getSettingsRef();

String endpoint = context->getMacros()->expand(config.getString(config_prefix + ".endpoint"));
const String endpoint = context->getMacros()->expand(config.getString(config_prefix + ".endpoint"));
S3::URI uri(endpoint);
if (!uri.key.ends_with('/'))
uri.key.push_back('/');

if (S3::isS3ExpressEndpoint(endpoint) && !config.has(config_prefix + ".region"))
throw Exception(
ErrorCodes::NO_ELEMENTS_IN_CONFIG, "Region should be explicitly specified for directory buckets ({})", config_prefix);

S3::PocoHTTPClientConfiguration client_configuration = S3::ClientFactory::instance().createClientConfiguration(
config.getString(config_prefix + ".region", ""),
context->getRemoteHostFilter(),
Expand Down Expand Up @@ -93,6 +103,7 @@ std::unique_ptr<S3::Client> getClient(
.use_virtual_addressing = uri.is_virtual_hosted_style,
.disable_checksum = local_settings.s3_disable_checksum,
.gcs_issue_compose_request = config.getBool("s3.gcs_issue_compose_request", false),
.is_s3express_bucket = S3::isS3ExpressEndpoint(endpoint),
};

return S3::ClientFactory::instance().create(
Expand Down
21 changes: 17 additions & 4 deletions src/IO/S3/Client.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <IO/S3/Client.h>
#include <Common/Exception.h>

#if USE_AWS_S3

Expand Down Expand Up @@ -304,6 +305,9 @@ Model::HeadObjectOutcome Client::HeadObject(HeadObjectRequest & request) const

request.setApiMode(api_mode);

if (isS3ExpressBucket())
request.setIsS3ExpressBucket();

addAdditionalAMZHeadersToCanonicalHeadersList(request, client_configuration.extra_headers);

if (auto region = getRegionForBucket(bucket); !region.empty())
Expand Down Expand Up @@ -530,7 +534,11 @@ Client::doRequest(RequestType & request, RequestFn request_fn) const
addAdditionalAMZHeadersToCanonicalHeadersList(request, client_configuration.extra_headers);
const auto & bucket = request.GetBucket();
request.setApiMode(api_mode);
if (client_settings.disable_checksum)

/// Checksums are mandatory for S3Express buckets, so the S3Express check must come first and take precedence over `disable_checksum`.
if (client_settings.is_s3express_bucket)
request.setIsS3ExpressBucket();
else if (client_settings.disable_checksum)
request.disableChecksum();

if (auto region = getRegionForBucket(bucket); !region.empty())
Expand Down Expand Up @@ -915,9 +923,9 @@ std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT
std::move(sse_kms_config),
credentials_provider,
client_configuration, // Client configuration.
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
client_settings
);
client_settings.is_s3express_bucket ? Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::RequestDependent
: Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
client_settings);
}

PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT
Expand Down Expand Up @@ -956,6 +964,11 @@ PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT
return config;
}

/// Heuristically detects whether `endpoint` refers to an S3Express (directory bucket) endpoint.
///
/// Returns true if the string contains the marker "s3express" anywhere. The marker is matched
/// ASCII case-insensitively: host names are case-insensitive and callers may pass the endpoint
/// exactly as it was written in the configuration, so "S3Express-..." must be recognized too.
///
/// On one hand this check isn't 100% reliable, on the other - all it will change is whether we attach checksums to the requests.
bool isS3ExpressEndpoint(const std::string & endpoint)
{
    static constexpr char marker[] = "s3express";
    static constexpr size_t marker_size = sizeof(marker) - 1; /// without the trailing '\0'

    if (endpoint.size() < marker_size)
        return false;

    /// Locale-independent lowering; only ASCII letters matter for the marker.
    const auto to_lower_ascii = [](char c) { return (c >= 'A' && c <= 'Z') ? static_cast<char>(c - 'A' + 'a') : c; };

    for (size_t start = 0; start + marker_size <= endpoint.size(); ++start)
    {
        size_t matched = 0;
        while (matched < marker_size && to_lower_ascii(endpoint[start + matched]) == marker[matched])
            ++matched;

        if (matched == marker_size)
            return true;
    }

    return false;
}
}

}
Expand Down
6 changes: 6 additions & 0 deletions src/IO/S3/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class ClientCacheRegistry
std::unordered_map<ClientCache *, std::weak_ptr<ClientCache>> client_caches;
};

bool isS3ExpressEndpoint(const std::string & endpoint);

struct ClientSettings
{
bool use_virtual_addressing;
Expand All @@ -107,6 +109,7 @@ struct ClientSettings
/// Ability to enable it preserved since likely it is required for old
/// files.
bool gcs_issue_compose_request;
bool is_s3express_bucket;
};

/// Client that improves the client from the AWS SDK
Expand Down Expand Up @@ -208,6 +211,9 @@ class Client : private Aws::S3::S3Client
const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const override;

bool supportsMultiPartCopy() const;

bool isS3ExpressBucket() const { return client_settings.is_s3express_bucket; }

private:
friend struct ::MockS3::Client;

Expand Down
44 changes: 41 additions & 3 deletions src/IO/S3/Requests.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,44 @@
#include <aws/s3/model/UploadPartCopyRequest.h>
#include <aws/s3/model/DeleteObjectRequest.h>
#include <aws/s3/model/DeleteObjectsRequest.h>
#include <aws/s3/model/ChecksumAlgorithm.h>
#include <aws/s3/model/CompletedPart.h>
#include <aws/core/utils/HashingUtils.h>

#include <base/defines.h>

namespace DB::S3
{

namespace Model = Aws::S3::Model;

/// Helpers for the CRC32 checksums attached to requests and to completed multipart parts.
/// Used only for S3Express
namespace RequestChecksum
{

/// Stores `checksum` as the CRC32 checksum of a completed multipart part.
inline void setPartChecksum(Model::CompletedPart & part, const std::string & checksum)
{
    part.SetChecksumCRC32(checksum);
}

/// Stores `checksum` as the CRC32 checksum of an UploadPart request.
inline void setRequestChecksum(Model::UploadPartRequest & req, const std::string & checksum)
{
    req.SetChecksumCRC32(checksum);
}

/// Computes the CRC32 of the request body and returns it base64-encoded.
/// The request is expected to already have CRC32 selected as its checksum algorithm.
inline std::string calculateChecksum(Model::UploadPartRequest & req)
{
    chassert(req.GetChecksumAlgorithm() == Model::ChecksumAlgorithm::CRC32);

    using Aws::Utils::HashingUtils;
    auto & body = *req.GetBody();
    const auto digest = HashingUtils::CalculateCRC32(body);
    return HashingUtils::Base64Encode(digest);
}

/// Selects CRC32 as the checksum algorithm for request types that provide
/// `SetChecksumAlgorithm`; for all other request types this is a no-op.
template <typename R>
void setChecksumAlgorithm(R & request)
{
    if constexpr (requires { request.SetChecksumAlgorithm(Model::ChecksumAlgorithm::CRC32); })
        request.SetChecksumAlgorithm(Model::ChecksumAlgorithm::CRC32);
}

}

template <typename BaseRequest>
class ExtendedRequest : public BaseRequest
{
Expand All @@ -49,11 +81,13 @@ class ExtendedRequest : public BaseRequest

Aws::String GetChecksumAlgorithmName() const override
{
chassert(!is_s3express_bucket || checksum);

/// Returning an empty string is enough to disable checksums (see
/// AWSClient::AddChecksumToRequest [1] for more details).
///
/// [1]: https://github.com/aws/aws-sdk-cpp/blob/b0ee1c0d336dbb371c34358b68fba6c56aae2c92/src/aws-cpp-sdk-core/source/client/AWSClient.cpp#L783-L839
if (!checksum)
if (!is_s3express_bucket && !checksum)
return "";
return BaseRequest::GetChecksumAlgorithmName();
}
Expand Down Expand Up @@ -84,16 +118,20 @@ class ExtendedRequest : public BaseRequest
}

/// Disable checksum to avoid extra read of the input stream
void disableChecksum() const
void disableChecksum() const { checksum = false; }

void setIsS3ExpressBucket()
{
checksum = false;
is_s3express_bucket = true;
RequestChecksum::setChecksumAlgorithm(*this);
}

protected:
mutable std::string region_override;
mutable std::optional<S3::URI> uri_override;
mutable ApiMode api_mode{ApiMode::AWS};
mutable bool checksum = true;
bool is_s3express_bucket = false;
};

class CopyObjectRequest : public ExtendedRequest<Model::CopyObjectRequest>
Expand Down
18 changes: 7 additions & 11 deletions src/IO/S3/URI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ URI::URI(const std::string & uri_)
/// Case when bucket name represented in domain name of S3 URL.
/// E.g. (https://bucket-name.s3.Region.amazonaws.com/key)
/// https://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html#virtual-hosted-style-access
static const RE2 virtual_hosted_style_pattern(R"((.+)\.(s3|cos|obs|oss|eos)([.\-][a-z0-9\-.:]+))");
static const RE2 virtual_hosted_style_pattern(R"((.+)\.(s3express[\-a-z0-9]+|s3|cos|obs|oss|eos)([.\-][a-z0-9\-.:]+))");

/// Case when bucket name and key represented in path of S3 URL.
/// E.g. (https://s3.Region.amazonaws.com/bucket-name/key)
/// https://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html#path-style-access
static const RE2 path_style_pattern("^/([^/]*)/(.*)");

static constexpr auto S3 = "S3";
static constexpr auto S3EXPRESS = "S3EXPRESS";
static constexpr auto COSN = "COSN";
static constexpr auto COS = "COS";
static constexpr auto OBS = "OBS";
Expand Down Expand Up @@ -115,21 +116,16 @@ URI::URI(const std::string & uri_)
}

boost::to_upper(name);
if (name != S3 && name != COS && name != OBS && name != OSS && name != EOS)
/// For S3Express it will look like s3express-eun1-az1, i.e. contain region and AZ info
if (name != S3 && !name.starts_with(S3EXPRESS) && name != COS && name != OBS && name != OSS && name != EOS)
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Object storage system name is unrecognized in virtual hosted style S3 URI: {}",
quoteString(name));

if (name == S3)
storage_name = name;
else if (name == OBS)
storage_name = OBS;
else if (name == OSS)
storage_name = OSS;
else if (name == EOS)
storage_name = EOS;
else
if (name == COS)
storage_name = COSN;
else
storage_name = name;
}
else if (re2::RE2::PartialMatch(uri.getPath(), path_style_pattern, &bucket, &key))
{
Expand Down
25 changes: 24 additions & 1 deletion src/IO/S3/tests/gtest_aws_s3_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ void testServerSideEncryption(
bool disable_checksum,
String server_side_encryption_customer_key_base64,
DB::S3::ServerSideEncryptionKMSConfig sse_kms_config,
String expected_headers)
String expected_headers,
bool is_s3express_bucket = false)
{
TestPocoHTTPServer http;

Expand Down Expand Up @@ -144,6 +145,7 @@ void testServerSideEncryption(
.use_virtual_addressing = uri.is_virtual_hosted_style,
.disable_checksum = disable_checksum,
.gcs_issue_compose_request = false,
.is_s3express_bucket = is_s3express_bucket,
};

std::shared_ptr<DB::S3::Client> client = DB::S3::ClientFactory::instance().create(
Expand Down Expand Up @@ -295,4 +297,25 @@ TEST(IOTestAwsS3Client, AppendExtraSSEKMSHeadersWrite)
"x-amz-server-side-encryption-context: arn:aws:s3:::bucket_ARN\n");
}

/// Verifies that a write request to an S3Express bucket is signed with the CRC32 checksum headers
/// (`x-amz-checksum-crc32`, `x-amz-sdk-checksum-algorithm`) even though `disable_checksum` is set:
/// the S3Express flag is expected to take precedence over the setting.
TEST(IOTestAwsS3Client, ChecksumHeaderIsPresentForS3Express)
{
/// See https://github.com/ClickHouse/ClickHouse/pull/19748
testServerSideEncryption(
doWriteRequest,
/* disable_checksum= */ true,
"",
{},
"authorization: ... SignedHeaders="
"amz-sdk-invocation-id;"
"amz-sdk-request;"
"content-length;"
"content-type;"
"host;"
"x-amz-checksum-crc32;"
"x-amz-content-sha256;"
"x-amz-date;"
"x-amz-sdk-checksum-algorithm, ...\n",
/*is_s3express_bucket=*/true);
}

#endif
15 changes: 12 additions & 3 deletions src/IO/WriteBufferFromS3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
#include <IO/S3/getObjectInfo.h>
#include <IO/S3/BlobStorageLogWriter.h>

#include <aws/s3/model/StorageClass.h>

#include <utility>


Expand Down Expand Up @@ -456,6 +454,14 @@ S3::UploadPartRequest WriteBufferFromS3::getUploadRequest(size_t part_number, Pa
/// If we don't do it, AWS SDK can mistakenly set it to application/xml, see https://github.com/aws/aws-sdk-cpp/issues/1840
req.SetContentType("binary/octet-stream");

/// Checksums need to be provided on CompleteMultipartUpload requests, so we calculate them manually and store them in multipart_checksums
if (client_ptr->isS3ExpressBucket())
{
auto checksum = S3::RequestChecksum::calculateChecksum(req);
S3::RequestChecksum::setRequestChecksum(req, checksum);
multipart_checksums.push_back(std::move(checksum));
}

return req;
}

Expand Down Expand Up @@ -575,7 +581,10 @@ void WriteBufferFromS3::completeMultipartUpload()
for (size_t i = 0; i < multipart_tags.size(); ++i)
{
Aws::S3::Model::CompletedPart part;
multipart_upload.AddParts(part.WithETag(multipart_tags[i]).WithPartNumber(static_cast<int>(i + 1)));
part.WithETag(multipart_tags[i]).WithPartNumber(static_cast<int>(i + 1));
if (!multipart_checksums.empty())
S3::RequestChecksum::setPartChecksum(part, multipart_checksums.at(i));
multipart_upload.AddParts(part);
}

req.SetMultipartUpload(multipart_upload);
Expand Down
1 change: 1 addition & 0 deletions src/IO/WriteBufferFromS3.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class WriteBufferFromS3 final : public WriteBufferFromFileBase
/// We initiate upload, then upload each part and get ETag as a response, and then finalizeImpl() upload with listing all our parts.
String multipart_upload_id;
std::deque<String> multipart_tags;
std::deque<String> multipart_checksums; // if enabled
bool multipart_upload_finished = false;

/// Track that prefinalize() is called only once
Expand Down
8 changes: 8 additions & 0 deletions src/IO/tests/gtest_s3_uri.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ TEST(S3UriTest, validPatterns)
ASSERT_EQ("", uri.version_id);
ASSERT_EQ(false, uri.is_virtual_hosted_style);
}
{
S3::URI uri("https://test-perf-bucket--eun1-az1--x-s3.s3express-eun1-az1.eu-north-1.amazonaws.com/test.csv");
ASSERT_EQ("https://s3express-eun1-az1.eu-north-1.amazonaws.com", uri.endpoint);
ASSERT_EQ("test-perf-bucket--eun1-az1--x-s3", uri.bucket);
ASSERT_EQ("test.csv", uri.key);
ASSERT_EQ("", uri.version_id);
ASSERT_EQ(true, uri.is_virtual_hosted_style);
}
}

TEST_P(S3UriTest, invalidPatterns)
Expand Down

0 comments on commit 056c8ce

Please sign in to comment.