Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GCS native copy #48981

Merged
merged 2 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/IO/S3/Client.cpp
Expand Up @@ -124,7 +124,11 @@ Client::Client(
auto * endpoint_provider = dynamic_cast<Aws::S3::Endpoint::S3DefaultEpProviderBase *>(accessEndpointProvider().get());
endpoint_provider->GetBuiltInParameters().GetParameter("Region").GetString(explicit_region);
endpoint_provider->GetBuiltInParameters().GetParameter("Endpoint").GetString(initial_endpoint);
detect_region = explicit_region == Aws::Region::AWS_GLOBAL && initial_endpoint.find(".amazonaws.com") != std::string::npos;

provider_type = getProviderTypeFromURL(initial_endpoint);
LOG_TRACE(log, "Provider type: {}", toString(provider_type));

detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL;

cache = std::make_shared<ClientCache>();
ClientCacheRegistry::instance().registerClient(cache);
Expand All @@ -135,6 +139,7 @@ Client::Client(const Client & other)
, initial_endpoint(other.initial_endpoint)
, explicit_region(other.explicit_region)
, detect_region(other.detect_region)
, provider_type(other.provider_type)
, max_redirects(other.max_redirects)
, log(&Poco::Logger::get("S3Client"))
{
Expand Down Expand Up @@ -177,6 +182,8 @@ Model::HeadObjectOutcome Client::HeadObject(const HeadObjectRequest & request) c
{
const auto & bucket = request.GetBucket();

request.setProviderType(provider_type);

if (auto region = getRegionForBucket(bucket); !region.empty())
{
if (!detect_region)
Expand Down Expand Up @@ -315,6 +322,7 @@ std::invoke_result_t<RequestFn, RequestType>
Client::doRequest(const RequestType & request, RequestFn request_fn) const
{
const auto & bucket = request.GetBucket();
request.setProviderType(provider_type);

if (auto region = getRegionForBucket(bucket); !region.empty())
{
Expand Down Expand Up @@ -387,6 +395,11 @@ Client::doRequest(const RequestType & request, RequestFn request_fn) const
throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, "Too many redirects");
}

ProviderType Client::getProviderType() const
{
return provider_type;
}

std::string Client::getRegionForBucket(const std::string & bucket, bool force_detect) const
{
std::lock_guard lock(cache->region_cache_mutex);
Expand All @@ -396,7 +409,6 @@ std::string Client::getRegionForBucket(const std::string & bucket, bool force_de
if (!force_detect && !detect_region)
return "";


LOG_INFO(log, "Resolving region for bucket {}", bucket);
Aws::S3::Model::HeadBucketRequest req;
req.SetBucket(bucket);
Expand Down
5 changes: 5 additions & 0 deletions src/IO/S3/Client.h
Expand Up @@ -11,6 +11,7 @@
#include <IO/S3/Requests.h>
#include <IO/S3/PocoHTTPClient.h>
#include <IO/S3/Credentials.h>
#include <IO/S3/ProviderType.h>

#include <aws/core/Aws.h>
#include <aws/core/client/DefaultRetryStrategy.h>
Expand Down Expand Up @@ -160,6 +161,8 @@ class Client : private Aws::S3::S3Client

using Aws::S3::S3Client::EnableRequestProcessing;
using Aws::S3::S3Client::DisableRequestProcessing;

ProviderType getProviderType() const;
private:
Client(size_t max_redirects_,
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentials_provider,
Expand Down Expand Up @@ -206,6 +209,8 @@ class Client : private Aws::S3::S3Client
std::string explicit_region;
mutable bool detect_region = true;

ProviderType provider_type{ProviderType::UNKNOWN};

mutable std::shared_ptr<ClientCache> cache;

const size_t max_redirects;
Expand Down
13 changes: 12 additions & 1 deletion src/IO/S3/PocoHTTPClient.cpp
Expand Up @@ -15,6 +15,7 @@
#include <IO/HTTPCommon.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <IO/S3/ProviderType.h>

#include <aws/core/http/HttpRequest.h>
#include <aws/core/http/HttpResponse.h>
Expand Down Expand Up @@ -187,7 +188,7 @@ namespace
bool checkRequestCanReturn2xxAndErrorInBody(Aws::Http::HttpRequest & request)
{
auto query_params = request.GetQueryStringParameters();
if (request.HasHeader("x-amz-copy-source"))
if (request.HasHeader("x-amz-copy-source") || request.HasHeader("x-goog-copy-source"))
{
/// CopyObject https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html
if (query_params.empty())
Expand Down Expand Up @@ -259,6 +260,16 @@ void PocoHTTPClient::makeRequestInternal(
Poco::Logger * log = &Poco::Logger::get("AWSClient");

auto uri = request.GetUri().GetURIString();
auto provider_type = getProviderTypeFromURL(uri);

if (provider_type == ProviderType::GCS)
{
/// some GCS requests don't like S3 specific headers that the client sets
request.DeleteHeader("x-amz-api-version");
request.DeleteHeader("amz-sdk-invocation-id");
request.DeleteHeader("amz-sdk-request");
}

if (enable_s3_requests_logging)
LOG_TEST(log, "Make request to: {}", uri);

Expand Down
43 changes: 43 additions & 0 deletions src/IO/S3/ProviderType.cpp
@@ -0,0 +1,43 @@
#include <IO/S3/ProviderType.h>

#if USE_AWS_S3

#include <string>

namespace DB::S3
{

std::string_view toString(ProviderType provider_type)
{
using enum ProviderType;

switch (provider_type)
{
case AWS:
return "AWS";
case GCS:
return "GCS";
case UNKNOWN:
return "Unknown";
}
}

bool supportsMultiPartCopy(ProviderType provider_type)
{
return provider_type != ProviderType::GCS;
}

ProviderType getProviderTypeFromURL(const std::string & url)
{
if (url.find(".amazonaws.com") != std::string::npos)
return ProviderType::AWS;

if (url.find("storage.googleapis.com") != std::string::npos)
return ProviderType::GCS;

return ProviderType::UNKNOWN;
}

}

#endif
28 changes: 28 additions & 0 deletions src/IO/S3/ProviderType.h
@@ -0,0 +1,28 @@
#pragma once

#include "config.h"

#if USE_AWS_S3

#include <string_view>
#include <cstdint>

namespace DB::S3
{

enum class ProviderType : uint8_t
{
AWS,
GCS,
UNKNOWN
};

std::string_view toString(ProviderType provider_type);

bool supportsMultiPartCopy(ProviderType provider_type);

ProviderType getProviderTypeFromURL(const std::string & url);

}

#endif
55 changes: 55 additions & 0 deletions src/IO/S3/Requests.cpp
@@ -0,0 +1,55 @@
#include <IO/S3/Requests.h>

#if USE_AWS_S3

#include <Common/logger_useful.h>

namespace DB::S3
{

Aws::Http::HeaderValueCollection CopyObjectRequest::GetRequestSpecificHeaders() const
{
auto headers = Model::CopyObjectRequest::GetRequestSpecificHeaders();
if (provider_type != ProviderType::GCS)
return headers;

/// GCS supports same headers as S3 but with a prefix x-goog instead of x-amz
/// we have to replace all the prefixes client set internally
const auto replace_with_gcs_header = [&](const std::string & amz_header, const std::string & gcs_header)
{
if (const auto it = headers.find(amz_header); it != headers.end())
{
auto header_value = std::move(it->second);
headers.erase(it);
headers.emplace(gcs_header, std::move(header_value));
}
};

replace_with_gcs_header("x-amz-copy-source", "x-goog-copy-source");
replace_with_gcs_header("x-amz-metadata-directive", "x-goog-metadata-directive");
replace_with_gcs_header("x-amz-storage-class", "x-goog-storage-class");

/// replace all x-amz-meta- headers
std::vector<std::pair<std::string, std::string>> new_meta_headers;
for (auto it = headers.begin(); it != headers.end();)
{
if (it->first.starts_with("x-amz-meta-"))
{
auto value = std::move(it->second);
auto header = "x-goog" + it->first.substr(/* x-amz */ 5);
new_meta_headers.emplace_back(std::pair{std::move(header), std::move(value)});
it = headers.erase(it);
}
else
++it;
}

for (auto & [header, value] : new_meta_headers)
headers.emplace(std::move(header), std::move(value));

return headers;
}

}

#endif
14 changes: 13 additions & 1 deletion src/IO/S3/Requests.h
Expand Up @@ -5,6 +5,7 @@
#if USE_AWS_S3

#include <IO/S3/URI.h>
#include <IO/S3/ProviderType.h>

#include <aws/core/endpoint/EndpointParameter.h>
#include <aws/s3/model/HeadObjectRequest.h>
Expand Down Expand Up @@ -61,9 +62,21 @@ class ExtendedRequest : public BaseRequest
return uri_override;
}

void setProviderType(ProviderType provider_type_) const
{
provider_type = provider_type_;
}

protected:
mutable std::string region_override;
mutable std::optional<S3::URI> uri_override;
mutable ProviderType provider_type{ProviderType::UNKNOWN};
};

class CopyObjectRequest : public ExtendedRequest<Model::CopyObjectRequest>
{
public:
Aws::Http::HeaderValueCollection GetRequestSpecificHeaders() const override;
};

using HeadObjectRequest = ExtendedRequest<Model::HeadObjectRequest>;
Expand All @@ -78,7 +91,6 @@ using UploadPartRequest = ExtendedRequest<Model::UploadPartRequest>;
using UploadPartCopyRequest = ExtendedRequest<Model::UploadPartCopyRequest>;

using PutObjectRequest = ExtendedRequest<Model::PutObjectRequest>;
using CopyObjectRequest = ExtendedRequest<Model::CopyObjectRequest>;
using DeleteObjectRequest = ExtendedRequest<Model::DeleteObjectRequest>;
using DeleteObjectsRequest = ExtendedRequest<Model::DeleteObjectsRequest>;

Expand Down