From 3f0e12d8d362752181c75836d25d424862acc424 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 15 Aug 2022 12:30:30 +0100 Subject: [PATCH] Replace rusoto with custom implementation for AWS (#2176) (#2352) * Replace rusoto (#2176) * Add integration test for metadata endpoint * Fix WebIdentity * Fix doc * Fix handling of multipart errors * Use separate client for credentials * Include port in Host header canonical request * Fix doc link * Review feedback --- .github/workflows/object_store.yml | 12 +- object_store/Cargo.toml | 11 +- object_store/src/aws.rs | 1343 ------------------------- object_store/src/aws/client.rs | 483 +++++++++ object_store/src/aws/credential.rs | 590 +++++++++++ object_store/src/aws/mod.rs | 646 ++++++++++++ object_store/src/azure.rs | 85 +- object_store/src/client/mod.rs | 2 + object_store/src/client/pagination.rs | 70 ++ object_store/src/client/token.rs | 10 +- object_store/src/gcp.rs | 219 ++-- object_store/src/lib.rs | 14 +- object_store/src/multipart.rs | 59 +- 13 files changed, 1982 insertions(+), 1562 deletions(-) delete mode 100644 object_store/src/aws.rs create mode 100644 object_store/src/aws/client.rs create mode 100644 object_store/src/aws/credential.rs create mode 100644 object_store/src/aws/mod.rs create mode 100644 object_store/src/client/pagination.rs diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml index 6c81604a96a..5da2cb4e6cb 100644 --- a/.github/workflows/object_store.yml +++ b/.github/workflows/object_store.yml @@ -59,6 +59,13 @@ jobs: image: localstack/localstack:0.14.4 ports: - 4566:4566 + ec2-metadata: + image: amazon/amazon-ec2-metadata-mock:v1.9.2 + ports: + - 1338:1338 + env: + # Only allow IMDSv2 + AEMM_IMDSV2: "1" azurite: image: mcr.microsoft.com/azure-storage/azurite ports: @@ -78,6 +85,7 @@ jobs: AWS_ACCESS_KEY_ID: test AWS_SECRET_ACCESS_KEY: test AWS_ENDPOINT: http://localstack:4566 + EC2_METADATA_ENDPOINT: http://ec2-metadata:1338 AZURE_USE_EMULATOR: "1" AZURITE_BLOB_STORAGE_URL: "http://azurite:10000" AZURITE_QUEUE_STORAGE_URL: "http://azurite:10001" @@ -101,8 +109,8 @@ jobs: aws --endpoint-url=http://localstack:4566 s3 mb s3://test-bucket - name: Configure Azurite (Azure emulation) - # the magical connection string is from - # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings + # the magical connection string is from + # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings run: | curl -sL https://aka.ms/InstallAzureCLIDeb | bash az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://azurite:10000/devstoreaccount1;QueueEndpoint=http://azurite:10001/devstoreaccount1;' diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index bb371988aa4..8c713d80b88 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -46,17 +46,8 @@ rustls-pemfile = { version = "1.0", default-features = false, optional = true } ring = { version = "0.16", default-features = false, features = ["std"], optional = true } base64 = { version = "0.13", default-features = false, optional = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } -# for rusoto -hyper = { version = "0.14", 
optional = true, default-features = false } -# for rusoto -hyper-rustls = { version = "0.23.0", optional = true, default-features = false, features = ["webpki-tokio", "http1", "http2", "tls12"] } itertools = "0.10.1" percent-encoding = "2.1" -# rusoto crates are for Amazon S3 integration -rusoto_core = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] } -rusoto_credential = { version = "0.48.0", optional = true, default-features = false } -rusoto_s3 = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] } -rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] } snafu = "0.7" tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] } tracing = { version = "0.1" } @@ -70,7 +61,7 @@ walkdir = "2" azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest", "azure_identity"] azure_test = ["azure", "azure_core/azurite_workaround", "azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"] gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64", "rand", "ring"] -aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper", "hyper-rustls"] +aws = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64", "rand", "ring"] [dev-dependencies] # In alphabetical order dotenv = "0.15.0" diff --git a/object_store/src/aws.rs b/object_store/src/aws.rs deleted file mode 100644 index bcb294c0037..00000000000 --- a/object_store/src/aws.rs +++ /dev/null @@ -1,1343 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! An object store implementation for S3 -//! -//! ## Multi-part uploads -//! -//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method. -//! Data passed to the writer is automatically buffered to meet the minimum size -//! requirements for a part. Multiple parts are uploaded concurrently. -//! -//! If the writer fails for any reason, you may have parts uploaded to AWS but not -//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method -//! to abort the upload and drop those unneeded parts. In addition, you may wish to -//! consider implementing [automatic cleanup] of unused parts that are older than one -//! week. -//! -//! 
[automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ -use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; -use crate::util::format_http_range; -use crate::MultipartId; -use crate::{ - collect_bytes, - path::{Path, DELIMITER}, - util::format_prefix, - GetResult, ListResult, ObjectMeta, ObjectStore, Result, -}; -use async_trait::async_trait; -use bytes::Bytes; -use chrono::{DateTime, Utc}; -use futures::future::BoxFuture; -use futures::{ - stream::{self, BoxStream}, - Future, Stream, StreamExt, TryStreamExt, -}; -use hyper::client::Builder as HyperBuilder; -use percent_encoding::{percent_encode, AsciiSet, NON_ALPHANUMERIC}; -use rusoto_core::ByteStream; -use rusoto_credential::{InstanceMetadataProvider, StaticProvider}; -use rusoto_s3::S3; -use rusoto_sts::WebIdentityProvider; -use snafu::{OptionExt, ResultExt, Snafu}; -use std::io; -use std::ops::Range; -use std::{ - convert::TryFrom, fmt, num::NonZeroUsize, ops::Deref, sync::Arc, time::Duration, -}; -use tokio::io::AsyncWrite; -use tokio::sync::{OwnedSemaphorePermit, Semaphore}; -use tracing::{debug, warn}; - -// Do not URI-encode any of the unreserved characters that RFC 3986 defines: -// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). -const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC - .remove(b'-') - .remove(b'.') - .remove(b'_') - .remove(b'~'); - -/// This struct is used to maintain the URI path encoding -const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/'); - -/// The maximum number of times a request will be retried in the case of an AWS server error -pub const MAX_NUM_RETRIES: u32 = 3; - -/// A specialized `Error` for object store-related errors -#[derive(Debug, Snafu)] -#[allow(missing_docs)] -enum Error { - #[snafu(display( - "Expected streamed data to have length {}, got {}", - expected, - actual - ))] - DataDoesNotMatchLength { expected: usize, actual: usize }, - - #[snafu(display( - "Did not receive any data. Bucket: {}, Location: {}", - bucket, - path - ))] - NoData { bucket: String, path: String }, - - #[snafu(display( - "Unable to DELETE data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToDeleteData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to GET data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToGetData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to HEAD data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToHeadData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to GET part of the data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToGetPieceOfData { - source: std::io::Error, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to PUT data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToPutData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to upload data. 
Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToUploadData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to cleanup multipart data. Bucket: {}, Location: {}, Error: {} ({:?})", - bucket, - path, - source, - source, - ))] - UnableToCleanupMultipartData { - source: rusoto_core::RusotoError, - bucket: String, - path: String, - }, - - #[snafu(display( - "Unable to list data. Bucket: {}, Error: {} ({:?})", - bucket, - source, - source, - ))] - UnableToListData { - source: rusoto_core::RusotoError, - bucket: String, - }, - - #[snafu(display( - "Unable to copy object. Bucket: {}, From: {}, To: {}, Error: {}", - bucket, - from, - to, - source, - ))] - UnableToCopyObject { - source: rusoto_core::RusotoError, - bucket: String, - from: String, - to: String, - }, - - #[snafu(display( - "Unable to parse last modified date. Bucket: {}, Error: {} ({:?})", - bucket, - source, - source, - ))] - UnableToParseLastModified { - source: chrono::ParseError, - bucket: String, - }, - - #[snafu(display( - "Unable to buffer data into temporary file, Error: {} ({:?})", - source, - source, - ))] - UnableToBufferStream { source: std::io::Error }, - - #[snafu(display( - "Could not parse `{}` as an AWS region. Regions should look like `us-east-2`. {} ({:?})", - region, - source, - source, - ))] - InvalidRegion { - region: String, - source: rusoto_core::region::ParseRegionError, - }, - - #[snafu(display( - "Region must be specified for AWS S3. Regions should look like `us-east-2`" - ))] - MissingRegion {}, - - #[snafu(display("Missing bucket name"))] - MissingBucketName {}, - - #[snafu(display("Missing aws-access-key"))] - MissingAccessKey, - - #[snafu(display("Missing aws-secret-access-key"))] - MissingSecretAccessKey, - - NotFound { - path: String, - source: Box, - }, -} - -impl From for super::Error { - fn from(source: Error) -> Self { - match source { - Error::NotFound { path, source } => Self::NotFound { path, source }, - _ => Self::Generic { - store: "S3", - source: Box::new(source), - }, - } - } -} - -/// Interface for [Amazon S3](https://aws.amazon.com/s3/). -pub struct AmazonS3 { - /// S3 client w/o any connection limit. - /// - /// You should normally use [`Self::client`] instead. - client_unrestricted: rusoto_s3::S3Client, - - /// Semaphore that limits the usage of [`client_unrestricted`](Self::client_unrestricted). - connection_semaphore: Arc, - - /// Bucket name used by this object store client. 
- bucket_name: String, -} - -impl fmt::Debug for AmazonS3 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AmazonS3") - .field("client", &"rusoto_s3::S3Client") - .field("bucket_name", &self.bucket_name) - .finish() - } -} - -impl fmt::Display for AmazonS3 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "AmazonS3({})", self.bucket_name) - } -} - -#[async_trait] -impl ObjectStore for AmazonS3 { - async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { - let bucket_name = self.bucket_name.clone(); - let request_factory = move || { - let bytes = bytes.clone(); - - let length = bytes.len(); - let stream_data = Ok(bytes); - let stream = futures::stream::once(async move { stream_data }); - let byte_stream = ByteStream::new_with_size(stream, length); - - rusoto_s3::PutObjectRequest { - bucket: bucket_name.clone(), - key: location.to_string(), - body: Some(byte_stream), - ..Default::default() - } - }; - - let s3 = self.client().await; - - s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - - async move { s3.put_object(request_factory()).await } - }) - .await - .context(UnableToPutDataSnafu { - bucket: &self.bucket_name, - path: location.as_ref(), - })?; - - Ok(()) - } - - async fn put_multipart( - &self, - location: &Path, - ) -> Result<(MultipartId, Box)> { - let bucket_name = self.bucket_name.clone(); - - let request_factory = move || rusoto_s3::CreateMultipartUploadRequest { - bucket: bucket_name.clone(), - key: location.to_string(), - ..Default::default() - }; - - let s3 = self.client().await; - - let data = s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - - async move { s3.create_multipart_upload(request_factory()).await } - }) - .await - .context(UnableToUploadDataSnafu { - bucket: &self.bucket_name, - path: location.as_ref(), - })?; - - let upload_id = data.upload_id.unwrap(); - - let inner = S3MultiPartUpload { - upload_id: upload_id.clone(), - bucket: self.bucket_name.clone(), - key: location.to_string(), - client_unrestricted: self.client_unrestricted.clone(), - connection_semaphore: Arc::clone(&self.connection_semaphore), - }; - - Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8)))) - } - - async fn abort_multipart( - &self, - location: &Path, - multipart_id: &MultipartId, - ) -> Result<()> { - let request_factory = move || rusoto_s3::AbortMultipartUploadRequest { - bucket: self.bucket_name.clone(), - key: location.to_string(), - upload_id: multipart_id.to_string(), - ..Default::default() - }; - - let s3 = self.client().await; - s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory); - - async move { s3.abort_multipart_upload(request_factory()).await } - }) - .await - .context(UnableToCleanupMultipartDataSnafu { - bucket: &self.bucket_name, - path: location.as_ref(), - })?; - - Ok(()) - } - - async fn get(&self, location: &Path) -> Result { - Ok(GetResult::Stream( - self.get_object(location, None).await?.boxed(), - )) - } - - async fn get_range(&self, location: &Path, range: Range) -> Result { - let size_hint = range.end - range.start; - let stream = self.get_object(location, Some(range)).await?; - collect_bytes(stream, Some(size_hint)).await - } - - async fn head(&self, location: &Path) -> Result { - let key = location.to_string(); - let head_request = rusoto_s3::HeadObjectRequest { - bucket: self.bucket_name.clone(), - key: key.clone(), - ..Default::default() - }; - let s = self - .client() - .await 
- .head_object(head_request) - .await - .map_err(|e| match e { - rusoto_core::RusotoError::Service( - rusoto_s3::HeadObjectError::NoSuchKey(_), - ) => Error::NotFound { - path: key.clone(), - source: e.into(), - }, - rusoto_core::RusotoError::Unknown(h) if h.status.as_u16() == 404 => { - Error::NotFound { - path: key.clone(), - source: "resource not found".into(), - } - } - _ => Error::UnableToHeadData { - bucket: self.bucket_name.to_owned(), - path: key.clone(), - source: e, - }, - })?; - - // Note: GetObject and HeadObject return a different date format from ListObjects - // - // S3 List returns timestamps in the form - // 2013-09-17T18:07:53.000Z - // S3 GetObject returns timestamps in the form - // Last-Modified: Sun, 1 Jan 2006 12:00:00 GMT - let last_modified = match s.last_modified { - Some(lm) => DateTime::parse_from_rfc2822(&lm) - .context(UnableToParseLastModifiedSnafu { - bucket: &self.bucket_name, - })? - .with_timezone(&Utc), - None => Utc::now(), - }; - - Ok(ObjectMeta { - last_modified, - location: location.clone(), - size: usize::try_from(s.content_length.unwrap_or(0)) - .expect("unsupported size on this platform"), - }) - } - - async fn delete(&self, location: &Path) -> Result<()> { - let bucket_name = self.bucket_name.clone(); - - let request_factory = move || rusoto_s3::DeleteObjectRequest { - bucket: bucket_name.clone(), - key: location.to_string(), - ..Default::default() - }; - - let s3 = self.client().await; - - s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - - async move { s3.delete_object(request_factory()).await } - }) - .await - .context(UnableToDeleteDataSnafu { - bucket: &self.bucket_name, - path: location.as_ref(), - })?; - - Ok(()) - } - - async fn list( - &self, - prefix: Option<&Path>, - ) -> Result>> { - Ok(self - .list_objects_v2(prefix, None) - .await? - .map_ok(move |list_objects_v2_result| { - let contents = list_objects_v2_result.contents.unwrap_or_default(); - let iter = contents - .into_iter() - .map(|object| convert_object_meta(object, &self.bucket_name)); - - futures::stream::iter(iter) - }) - .try_flatten() - .boxed()) - } - - async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { - Ok(self - .list_objects_v2(prefix, Some(DELIMITER.to_string())) - .await? - .try_fold( - ListResult { - common_prefixes: vec![], - objects: vec![], - }, - |acc, list_objects_v2_result| async move { - let mut res = acc; - let contents = list_objects_v2_result.contents.unwrap_or_default(); - let mut objects = contents - .into_iter() - .map(|object| convert_object_meta(object, &self.bucket_name)) - .collect::>>()?; - - res.objects.append(&mut objects); - - let prefixes = - list_objects_v2_result.common_prefixes.unwrap_or_default(); - res.common_prefixes.reserve(prefixes.len()); - - for p in prefixes { - let prefix = - p.prefix.expect("can't have a prefix without a value"); - res.common_prefixes.push(Path::parse(prefix)?); - } - - Ok(res) - }, - ) - .await?) 
- } - - async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - let from = from.as_ref(); - let to = to.as_ref(); - let bucket_name = self.bucket_name.clone(); - - let copy_source = format!( - "{}/{}", - &bucket_name, - percent_encode(from.as_ref(), &STRICT_PATH_ENCODE_SET) - ); - - let request_factory = move || rusoto_s3::CopyObjectRequest { - bucket: bucket_name.clone(), - copy_source, - key: to.to_string(), - ..Default::default() - }; - - let s3 = self.client().await; - - s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - - async move { s3.copy_object(request_factory()).await } - }) - .await - .context(UnableToCopyObjectSnafu { - bucket: &self.bucket_name, - from, - to, - })?; - - Ok(()) - } - - async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> { - // Will need dynamodb_lock - Err(crate::Error::NotImplemented) - } -} - -fn convert_object_meta(object: rusoto_s3::Object, bucket: &str) -> Result { - let key = object.key.expect("object doesn't exist without a key"); - let location = Path::parse(key)?; - let last_modified = match object.last_modified { - Some(lm) => DateTime::parse_from_rfc3339(&lm) - .context(UnableToParseLastModifiedSnafu { bucket })? - .with_timezone(&Utc), - None => Utc::now(), - }; - let size = usize::try_from(object.size.unwrap_or(0)) - .expect("unsupported size on this platform"); - - Ok(ObjectMeta { - location, - last_modified, - size, - }) -} - -/// Configure a connection to Amazon S3 using the specified credentials in -/// the specified Amazon region and bucket. -/// -/// # Example -/// ``` -/// # let REGION = "foo"; -/// # let BUCKET_NAME = "foo"; -/// # let ACCESS_KEY_ID = "foo"; -/// # let SECRET_KEY = "foo"; -/// # use object_store::aws::AmazonS3Builder; -/// let s3 = AmazonS3Builder::new() -/// .with_region(REGION) -/// .with_bucket_name(BUCKET_NAME) -/// .with_access_key_id(ACCESS_KEY_ID) -/// .with_secret_access_key(SECRET_KEY) -/// .build(); -/// ``` -#[derive(Debug)] -pub struct AmazonS3Builder { - access_key_id: Option, - secret_access_key: Option, - region: Option, - bucket_name: Option, - endpoint: Option, - token: Option, - max_connections: NonZeroUsize, - allow_http: bool, -} - -impl Default for AmazonS3Builder { - fn default() -> Self { - Self { - access_key_id: None, - secret_access_key: None, - region: None, - bucket_name: None, - endpoint: None, - token: None, - max_connections: NonZeroUsize::new(16).unwrap(), - allow_http: false, - } - } -} - -impl AmazonS3Builder { - /// Create a new [`AmazonS3Builder`] with default values. - pub fn new() -> Self { - Default::default() - } - - /// Set the AWS Access Key (required) - pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { - self.access_key_id = Some(access_key_id.into()); - self - } - - /// Set the AWS Secret Access Key (required) - pub fn with_secret_access_key( - mut self, - secret_access_key: impl Into, - ) -> Self { - self.secret_access_key = Some(secret_access_key.into()); - self - } - - /// Set the region (e.g. `us-east-1`) (required) - pub fn with_region(mut self, region: impl Into) -> Self { - self.region = Some(region.into()); - self - } - - /// Set the bucket_name (required) - pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { - self.bucket_name = Some(bucket_name.into()); - self - } - - /// Sets the endpoint for communicating with AWS S3. Default value - /// is based on region. 
- /// - /// For example, this might be set to `"http://localhost:4566:` - /// for testing against a localstack instance. - pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { - self.endpoint = Some(endpoint.into()); - self - } - - /// Set the token to use for requests (passed to underlying provider) - pub fn with_token(mut self, token: impl Into) -> Self { - self.token = Some(token.into()); - self - } - - /// Sets the maximum number of concurrent outstanding - /// connectons. Default is `16`. - #[deprecated(note = "use LimitStore instead")] - pub fn with_max_connections(mut self, max_connections: NonZeroUsize) -> Self { - self.max_connections = max_connections; - self - } - - /// Sets what protocol is allowed. If `allow_http` is : - /// * false (default): Only HTTPS are allowed - /// * true: HTTP and HTTPS are allowed - pub fn with_allow_http(mut self, allow_http: bool) -> Self { - self.allow_http = allow_http; - self - } - - /// Create a [`AmazonS3`] instance from the provided values, - /// consuming `self`. - pub fn build(self) -> Result { - let Self { - access_key_id, - secret_access_key, - region, - bucket_name, - endpoint, - token, - max_connections, - allow_http, - } = self; - - let region = region.ok_or(Error::MissingRegion {})?; - let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?; - - let region: rusoto_core::Region = match endpoint { - None => region.parse().context(InvalidRegionSnafu { region })?, - Some(endpoint) => rusoto_core::Region::Custom { - name: region, - endpoint, - }, - }; - - let mut builder = HyperBuilder::default(); - builder.pool_max_idle_per_host(max_connections.get()); - - let connector = if allow_http { - hyper_rustls::HttpsConnectorBuilder::new() - .with_webpki_roots() - .https_or_http() - .enable_http1() - .enable_http2() - .build() - } else { - hyper_rustls::HttpsConnectorBuilder::new() - .with_webpki_roots() - .https_only() - .enable_http1() - .enable_http2() - .build() - }; - - let http_client = - rusoto_core::request::HttpClient::from_builder(builder, connector); - - let client = match (access_key_id, secret_access_key, token) { - (Some(access_key_id), Some(secret_access_key), Some(token)) => { - let credentials_provider = StaticProvider::new( - access_key_id, - secret_access_key, - Some(token), - None, - ); - rusoto_s3::S3Client::new_with(http_client, credentials_provider, region) - } - (Some(access_key_id), Some(secret_access_key), None) => { - let credentials_provider = - StaticProvider::new_minimal(access_key_id, secret_access_key); - rusoto_s3::S3Client::new_with(http_client, credentials_provider, region) - } - (None, Some(_), _) => return Err(Error::MissingAccessKey.into()), - (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), - _ if std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE").is_some() => { - rusoto_s3::S3Client::new_with( - http_client, - WebIdentityProvider::from_k8s_env(), - region, - ) - } - _ => rusoto_s3::S3Client::new_with( - http_client, - InstanceMetadataProvider::new(), - region, - ), - }; - - Ok(AmazonS3 { - client_unrestricted: client, - connection_semaphore: Arc::new(Semaphore::new(max_connections.get())), - bucket_name, - }) - } -} - -/// S3 client bundled w/ a semaphore permit. -#[derive(Clone)] -struct SemaphoreClient { - /// Permit for this specific use of the client. - /// - /// Note that this field is never read and therefore considered "dead code" by rustc. 
- #[allow(dead_code)] - permit: Arc, - - inner: rusoto_s3::S3Client, -} - -impl Deref for SemaphoreClient { - type Target = rusoto_s3::S3Client; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl AmazonS3 { - /// Get a client according to the current connection limit. - async fn client(&self) -> SemaphoreClient { - let permit = Arc::clone(&self.connection_semaphore) - .acquire_owned() - .await - .expect("semaphore shouldn't be closed yet"); - SemaphoreClient { - permit: Arc::new(permit), - inner: self.client_unrestricted.clone(), - } - } - - async fn get_object( - &self, - location: &Path, - range: Option>, - ) -> Result>> { - let key = location.to_string(); - let get_request = rusoto_s3::GetObjectRequest { - bucket: self.bucket_name.clone(), - key: key.clone(), - range: range.map(format_http_range), - ..Default::default() - }; - let bucket_name = self.bucket_name.clone(); - let stream = self - .client() - .await - .get_object(get_request) - .await - .map_err(|e| match e { - rusoto_core::RusotoError::Service( - rusoto_s3::GetObjectError::NoSuchKey(_), - ) => Error::NotFound { - path: key.clone(), - source: e.into(), - }, - _ => Error::UnableToGetData { - bucket: self.bucket_name.to_owned(), - path: key.clone(), - source: e, - }, - })? - .body - .context(NoDataSnafu { - bucket: self.bucket_name.to_owned(), - path: key.clone(), - })? - .map_err(move |source| Error::UnableToGetPieceOfData { - source, - bucket: bucket_name.clone(), - path: key.clone(), - }) - .err_into(); - - Ok(stream) - } - - async fn list_objects_v2( - &self, - prefix: Option<&Path>, - delimiter: Option, - ) -> Result>> { - enum ListState { - Start, - HasMore(String), - Done, - } - - let prefix = format_prefix(prefix); - let bucket = self.bucket_name.clone(); - - let request_factory = move || rusoto_s3::ListObjectsV2Request { - bucket, - prefix, - delimiter, - ..Default::default() - }; - let s3 = self.client().await; - - Ok(stream::unfold(ListState::Start, move |state| { - let request_factory = request_factory.clone(); - let s3 = s3.clone(); - - async move { - let continuation_token = match &state { - ListState::HasMore(continuation_token) => Some(continuation_token), - ListState::Done => { - return None; - } - // If this is the first request we've made, we don't need to make any - // modifications to the request - ListState::Start => None, - }; - - let resp = s3_request(move || { - let (s3, request_factory, continuation_token) = ( - s3.clone(), - request_factory.clone(), - continuation_token.cloned(), - ); - - async move { - s3.list_objects_v2(rusoto_s3::ListObjectsV2Request { - continuation_token, - ..request_factory() - }) - .await - } - }) - .await; - - let resp = match resp { - Ok(resp) => resp, - Err(e) => return Some((Err(e), state)), - }; - - // The AWS response contains a field named `is_truncated` as well as - // `next_continuation_token`, and we're assuming that `next_continuation_token` - // is only set when `is_truncated` is true (and therefore not - // checking `is_truncated`). - let next_state = if let Some(next_continuation_token) = - &resp.next_continuation_token - { - ListState::HasMore(next_continuation_token.to_string()) - } else { - ListState::Done - }; - - Some((Ok(resp), next_state)) - } - }) - .map_err(move |e| { - Error::UnableToListData { - source: e, - bucket: self.bucket_name.clone(), - } - .into() - }) - .boxed()) - } -} - -/// Handles retrying a request to S3 up to `MAX_NUM_RETRIES` times if S3 returns 5xx server errors. 
-/// -/// The `future_factory` argument is a function `F` that takes no arguments and, when called, will -/// return a `Future` (type `G`) that, when `await`ed, will perform a request to S3 through -/// `rusoto` and return a `Result` that returns some type `R` on success and some -/// `rusoto_core::RusotoError` on error. -/// -/// If the executed `Future` returns success, this function will return that success. -/// If the executed `Future` returns a 5xx server error, this function will wait an amount of -/// time that increases exponentially with the number of times it has retried, get a new `Future` by -/// calling `future_factory` again, and retry the request by `await`ing the `Future` again. -/// The retries will continue until the maximum number of retries has been attempted. In that case, -/// this function will return the last encountered error. -/// -/// Client errors (4xx) will never be retried by this function. -async fn s3_request( - future_factory: F, -) -> Result> -where - E: std::error::Error + Send, - F: Fn() -> G + Send, - G: Future>> + Send, - R: Send, -{ - let mut attempts = 0; - - loop { - let request = future_factory(); - - let result = request.await; - - match result { - Ok(r) => return Ok(r), - Err(error) => { - attempts += 1; - - let should_retry = matches!( - error, - rusoto_core::RusotoError::Unknown(ref response) - if response.status.is_server_error() - ); - - if attempts > MAX_NUM_RETRIES { - warn!( - ?error, - attempts, "maximum number of retries exceeded for AWS S3 request" - ); - return Err(error); - } else if !should_retry { - return Err(error); - } else { - debug!(?error, attempts, "retrying AWS S3 request"); - let wait_time = Duration::from_millis(2u64.pow(attempts) * 50); - tokio::time::sleep(wait_time).await; - } - } - } - } -} - -struct S3MultiPartUpload { - bucket: String, - key: String, - upload_id: String, - client_unrestricted: rusoto_s3::S3Client, - connection_semaphore: Arc, -} - -impl CloudMultiPartUploadImpl for S3MultiPartUpload { - fn put_multipart_part( - &self, - buf: Vec, - part_idx: usize, - ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { - // Get values to move into future; we don't want a reference to Self - let bucket = self.bucket.clone(); - let key = self.key.clone(); - let upload_id = self.upload_id.clone(); - let content_length = buf.len(); - - let request_factory = move || rusoto_s3::UploadPartRequest { - bucket, - key, - upload_id, - // AWS part number is 1-indexed - part_number: (part_idx + 1).try_into().unwrap(), - content_length: Some(content_length.try_into().unwrap()), - body: Some(buf.into()), - ..Default::default() - }; - - let s3 = self.client_unrestricted.clone(); - let connection_semaphore = Arc::clone(&self.connection_semaphore); - - Box::pin(async move { - let _permit = connection_semaphore - .acquire_owned() - .await - .expect("semaphore shouldn't be closed yet"); - - let response = s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - async move { s3.upload_part(request_factory()).await } - }) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - - Ok(( - part_idx, - UploadPart { - content_id: response.e_tag.unwrap(), - }, - )) - }) - } - - fn complete( - &self, - completed_parts: Vec>, - ) -> BoxFuture<'static, Result<(), io::Error>> { - let parts = - completed_parts - .into_iter() - .enumerate() - .map(|(part_number, maybe_part)| match maybe_part { - Some(part) => { - Ok(rusoto_s3::CompletedPart { - e_tag: Some(part.content_id), - 
part_number: Some((part_number + 1).try_into().map_err( - |err| io::Error::new(io::ErrorKind::Other, err), - )?), - }) - } - None => Err(io::Error::new( - io::ErrorKind::Other, - format!("Missing information for upload part {:?}", part_number), - )), - }); - - // Get values to move into future; we don't want a reference to Self - let bucket = self.bucket.clone(); - let key = self.key.clone(); - let upload_id = self.upload_id.clone(); - - let request_factory = move || -> Result<_, io::Error> { - Ok(rusoto_s3::CompleteMultipartUploadRequest { - bucket, - key, - upload_id, - multipart_upload: Some(rusoto_s3::CompletedMultipartUpload { - parts: Some(parts.collect::>()?), - }), - ..Default::default() - }) - }; - - let s3 = self.client_unrestricted.clone(); - let connection_semaphore = Arc::clone(&self.connection_semaphore); - - Box::pin(async move { - let _permit = connection_semaphore - .acquire_owned() - .await - .expect("semaphore shouldn't be closed yet"); - - s3_request(move || { - let (s3, request_factory) = (s3.clone(), request_factory.clone()); - - async move { s3.complete_multipart_upload(request_factory()?).await } - }) - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - - Ok(()) - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - tests::{ - get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, - put_get_delete_list, rename_and_copy, stream_get, - }, - Error as ObjectStoreError, - }; - use bytes::Bytes; - use std::env; - - const NON_EXISTENT_NAME: &str = "nonexistentname"; - - // Helper macro to skip tests if TEST_INTEGRATION and the AWS - // environment variables are not set. Returns a configured - // AmazonS3Builder - macro_rules! maybe_skip_integration { - () => {{ - dotenv::dotenv().ok(); - - let required_vars = [ - "AWS_DEFAULT_REGION", - "OBJECT_STORE_BUCKET", - "AWS_ACCESS_KEY_ID", - "AWS_SECRET_ACCESS_KEY", - ]; - let unset_vars: Vec<_> = required_vars - .iter() - .filter_map(|&name| match env::var(name) { - Ok(_) => None, - Err(_) => Some(name), - }) - .collect(); - let unset_var_names = unset_vars.join(", "); - - let force = env::var("TEST_INTEGRATION"); - - if force.is_ok() && !unset_var_names.is_empty() { - panic!( - "TEST_INTEGRATION is set, \ - but variable(s) {} need to be set", - unset_var_names - ); - } else if force.is_err() { - eprintln!( - "skipping AWS integration test - set {}TEST_INTEGRATION to run", - if unset_var_names.is_empty() { - String::new() - } else { - format!("{} and ", unset_var_names) - } - ); - return; - } else { - let config = AmazonS3Builder::new() - .with_access_key_id( - env::var("AWS_ACCESS_KEY_ID") - .expect("already checked AWS_ACCESS_KEY_ID"), - ) - .with_secret_access_key( - env::var("AWS_SECRET_ACCESS_KEY") - .expect("already checked AWS_SECRET_ACCESS_KEY"), - ) - .with_region( - env::var("AWS_DEFAULT_REGION") - .expect("already checked AWS_DEFAULT_REGION"), - ) - .with_bucket_name( - env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - ) - .with_allow_http(true); - - let config = if let Some(endpoint) = env::var("AWS_ENDPOINT").ok() { - config.with_endpoint(endpoint) - } else { - config - }; - - let config = if let Some(token) = env::var("AWS_SESSION_TOKEN").ok() { - config.with_token(token) - } else { - config - }; - - config - } - }}; - } - - #[tokio::test] - async fn s3_test() { - let config = maybe_skip_integration!(); - let integration = config.build().unwrap(); - - put_get_delete_list(&integration).await; - 
list_uses_directories_correctly(&integration).await; - list_with_delimiter(&integration).await; - rename_and_copy(&integration).await; - stream_get(&integration).await; - } - - #[tokio::test] - async fn s3_test_get_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build().unwrap(); - - let location = Path::from_iter([NON_EXISTENT_NAME]); - - let err = get_nonexistent_object(&integration, Some(location)) - .await - .unwrap_err(); - if let ObjectStoreError::NotFound { path, source } = err { - let source_variant = source.downcast_ref::>(); - assert!( - matches!( - source_variant, - Some(rusoto_core::RusotoError::Service( - rusoto_s3::GetObjectError::NoSuchKey(_) - )), - ), - "got: {:?}", - source_variant - ); - assert_eq!(path, NON_EXISTENT_NAME); - } else { - panic!("unexpected error type: {:?}", err); - } - } - - #[tokio::test] - async fn s3_test_get_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); - let integration = config.build().unwrap(); - - let location = Path::from_iter([NON_EXISTENT_NAME]); - - let err = integration.get(&location).await.unwrap_err().to_string(); - assert!( - err.contains("The specified bucket does not exist"), - "{}", - err - ) - } - - #[tokio::test] - async fn s3_test_put_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); - - let integration = config.build().unwrap(); - - let location = Path::from_iter([NON_EXISTENT_NAME]); - let data = Bytes::from("arbitrary data"); - - let err = integration - .put(&location, data) - .await - .unwrap_err() - .to_string(); - - assert!( - err.contains("The specified bucket does not exist") - && err.contains("Unable to PUT data"), - "{}", - err - ) - } - - #[tokio::test] - async fn s3_test_delete_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build().unwrap(); - - let location = Path::from_iter([NON_EXISTENT_NAME]); - - integration.delete(&location).await.unwrap(); - } - - #[tokio::test] - async fn s3_test_delete_nonexistent_bucket() { - let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); - let integration = config.build().unwrap(); - - let location = Path::from_iter([NON_EXISTENT_NAME]); - - let err = integration.delete(&location).await.unwrap_err().to_string(); - assert!( - err.contains("The specified bucket does not exist") - && err.contains("Unable to DELETE data"), - "{}", - err - ) - } -} diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs new file mode 100644 index 00000000000..36ba9ad126b --- /dev/null +++ b/object_store/src/aws/client.rs @@ -0,0 +1,483 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
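The new client module below talks to S3 directly over HTTP with `reqwest` and decodes the XML responses with `quick_xml`'s serde support, rather than going through rusoto's generated shapes. A minimal, self-contained sketch of that decoding approach (not part of the patch; it assumes `quick_xml` is built with its `serialize` feature and uses a made-up, heavily trimmed ListObjectsV2 body — the real field set is in the `ListResponse`/`ListContents` structs further down):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ListBucketResult {
    // S3 omits <Contents> entirely for an empty listing, hence the default
    #[serde(default)]
    contents: Vec<Contents>,
    #[serde(default)]
    next_continuation_token: Option<String>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct Contents {
    key: String,
    size: usize,
}

fn main() {
    // Illustrative ListObjectsV2 response body (trimmed to two keys)
    let xml = "<ListBucketResult>\
        <Contents><Key>data/part-0.parquet</Key><Size>1024</Size></Contents>\
        <Contents><Key>data/part-1.parquet</Key><Size>2048</Size></Contents>\
        <NextContinuationToken>opaque-token</NextContinuationToken>\
        </ListBucketResult>";

    let parsed: ListBucketResult = quick_xml::de::from_str(xml).unwrap();
    assert_eq!(parsed.contents.len(), 2);
    assert_eq!(parsed.next_continuation_token.as_deref(), Some("opaque-token"));
}
```

The presence of `next_continuation_token` is what drives the pagination helper (`stream_paginated`) used by `list_paginated` later in this file.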
+ +use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; +use crate::client::pagination::stream_paginated; +use crate::client::retry::RetryExt; +use crate::multipart::UploadPart; +use crate::path::DELIMITER; +use crate::util::{format_http_range, format_prefix}; +use crate::{ + BoxStream, ListResult, MultipartId, ObjectMeta, Path, Result, RetryConfig, StreamExt, +}; +use bytes::{Buf, Bytes}; +use chrono::{DateTime, Utc}; +use percent_encoding::{utf8_percent_encode, AsciiSet, PercentEncode, NON_ALPHANUMERIC}; +use reqwest::{Client as ReqwestClient, Method, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use snafu::{ResultExt, Snafu}; +use std::ops::Range; +use std::sync::Arc; + +// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html +// +// Do not URI-encode any of the unreserved characters that RFC 3986 defines: +// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). +const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC + .remove(b'-') + .remove(b'.') + .remove(b'_') + .remove(b'~'); + +/// This struct is used to maintain the URI path encoding +const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/'); + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub(crate) enum Error { + #[snafu(display("Error performing get request {}: {}", path, source))] + GetRequest { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing put request {}: {}", path, source))] + PutRequest { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing delete request {}: {}", path, source))] + DeleteRequest { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing copy request {}: {}", path, source))] + CopyRequest { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing list request: {}", source))] + ListRequest { source: reqwest::Error }, + + #[snafu(display("Error performing create multipart request: {}", source))] + CreateMultipartRequest { source: reqwest::Error }, + + #[snafu(display("Error performing complete multipart request: {}", source))] + CompleteMultipartRequest { source: reqwest::Error }, + + #[snafu(display("Got invalid list response: {}", source))] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Got invalid multipart response: {}", source))] + InvalidMultipartResponse { source: quick_xml::de::DeError }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::CopyRequest { source, path } + | Error::PutRequest { source, path } + if matches!(source.status(), Some(StatusCode::NOT_FOUND)) => + { + Self::NotFound { + path, + source: Box::new(source), + } + } + _ => Self::Generic { + store: "S3", + source: Box::new(err), + }, + } + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListResponse { + #[serde(default)] + pub contents: Vec, + #[serde(default)] + pub common_prefixes: Vec, + #[serde(default)] + pub next_continuation_token: Option, +} + +impl TryFrom for ListResult { + type Error = crate::Error; + + fn try_from(value: ListResponse) -> Result { + let common_prefixes = value + .common_prefixes + .into_iter() + .map(|x| Ok(Path::parse(&x.prefix)?)) + .collect::>()?; + + let objects = value + .contents + .into_iter() + .map(TryFrom::try_from) + 
.collect::>()?; + + Ok(Self { + common_prefixes, + objects, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListPrefix { + pub prefix: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListContents { + pub key: String, + pub size: usize, + pub last_modified: DateTime, +} + +impl TryFrom for ObjectMeta { + type Error = crate::Error; + + fn try_from(value: ListContents) -> Result { + Ok(Self { + location: Path::parse(value.key)?, + last_modified: value.last_modified, + size: value.size, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct InitiateMultipart { + upload_id: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase", rename = "CompleteMultipartUpload")] +struct CompleteMultipart { + part: Vec, +} + +#[derive(Debug, Serialize)] +struct MultipartPart { + #[serde(rename = "$unflatten=ETag")] + e_tag: String, + #[serde(rename = "$unflatten=PartNumber")] + part_number: usize, +} + +#[derive(Debug)] +pub struct S3Config { + pub region: String, + pub endpoint: String, + pub bucket: String, + pub credentials: CredentialProvider, + pub retry_config: RetryConfig, + pub allow_http: bool, +} + +impl S3Config { + fn path_url(&self, path: &Path) -> String { + format!("{}/{}/{}", self.endpoint, self.bucket, encode_path(path)) + } +} + +#[derive(Debug)] +pub(crate) struct S3Client { + config: S3Config, + client: ReqwestClient, +} + +impl S3Client { + pub fn new(config: S3Config) -> Self { + let client = reqwest::ClientBuilder::new() + .https_only(!config.allow_http) + .build() + .unwrap(); + + Self { config, client } + } + + /// Returns the config + pub fn config(&self) -> &S3Config { + &self.config + } + + async fn get_credential(&self) -> Result> { + self.config.credentials.get_credential().await + } + + /// Make an S3 GET request + pub async fn get_request( + &self, + path: &Path, + range: Option>, + head: bool, + ) -> Result { + use reqwest::header::RANGE; + + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + let method = match head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut builder = self.client.request(method, url); + + if let Some(range) = range { + builder = builder.header(RANGE, format_http_range(range)); + } + + let response = builder + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(GetRequestSnafu { + path: path.as_ref(), + })? + .error_for_status() + .context(GetRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an S3 PUT request + pub async fn put_request( + &self, + path: &Path, + bytes: Option, + query: &T, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + let mut builder = self.client.request(Method::PUT, url); + if let Some(bytes) = bytes { + builder = builder.body(bytes) + } + + let response = builder + .query(query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })? 
+ .error_for_status() + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an S3 Delete request + pub async fn delete_request( + &self, + path: &Path, + query: &T, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + self.client + .request(Method::DELETE, url) + .query(query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(DeleteRequestSnafu { + path: path.as_ref(), + })? + .error_for_status() + .context(DeleteRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Make an S3 Copy request + pub async fn copy_request(&self, from: &Path, to: &Path) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(to); + let source = format!("{}/{}", self.config.bucket, encode_path(from)); + + self.client + .request(Method::PUT, url) + .header("x-amz-copy-source", source) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CopyRequestSnafu { + path: from.as_ref(), + })? + .error_for_status() + .context(CopyRequestSnafu { + path: from.as_ref(), + })?; + + Ok(()) + } + + /// Make an S3 List request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + ) -> Result<(ListResult, Option)> { + let credential = self.get_credential().await?; + let url = format!("{}/{}", self.config.endpoint, self.config.bucket); + + let mut query = Vec::with_capacity(4); + + // Note: the order of these matters to ensure the generated URL is canonical + if let Some(token) = token { + query.push(("continuation-token", token)) + } + + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + query.push(("list-type", "2")); + + if let Some(prefix) = prefix { + query.push(("prefix", prefix)) + } + + let response = self + .client + .request(Method::GET, &url) + .query(&query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(ListRequestSnafu)? + .error_for_status() + .context(ListRequestSnafu)? + .bytes() + .await + .context(ListRequestSnafu)?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .context(InvalidListResponseSnafu)?; + let token = response.next_continuation_token.take(); + + Ok((response.try_into()?, token)) + } + + /// Perform a list operation automatically handling pagination + pub fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + ) -> BoxStream<'_, Result> { + let prefix = format_prefix(prefix); + stream_paginated(prefix, move |prefix, token| async move { + let (r, next_token) = self + .list_request(prefix.as_deref(), delimiter, token.as_deref()) + .await?; + Ok((r, prefix, next_token)) + }) + .boxed() + } + + pub async fn create_multipart(&self, location: &Path) -> Result { + let credential = self.get_credential().await?; + let url = format!( + "{}/{}/{}?uploads", + self.config.endpoint, + self.config.bucket, + encode_path(location) + ); + + let response = self + .client + .request(Method::POST, url) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CreateMultipartRequestSnafu)? + .error_for_status() + .context(CreateMultipartRequestSnafu)? 
+ .bytes() + .await + .context(CreateMultipartRequestSnafu)?; + + let response: InitiateMultipart = quick_xml::de::from_reader(response.reader()) + .context(InvalidMultipartResponseSnafu)?; + + Ok(response.upload_id) + } + + pub async fn complete_multipart( + &self, + location: &Path, + upload_id: &str, + parts: Vec, + ) -> Result<()> { + let parts = parts + .into_iter() + .enumerate() + .map(|(part_idx, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_idx + 1, + }) + .collect(); + + let request = CompleteMultipart { part: parts }; + let body = quick_xml::se::to_string(&request).unwrap(); + + let credential = self.get_credential().await?; + let url = self.config.path_url(location); + + self.client + .request(Method::POST, url) + .query(&[("uploadId", upload_id)]) + .body(body) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CompleteMultipartRequestSnafu)? + .error_for_status() + .context(CompleteMultipartRequestSnafu)?; + + Ok(()) + } +} + +fn encode_path(path: &Path) -> PercentEncode<'_> { + utf8_percent_encode(path.as_ref(), &STRICT_PATH_ENCODE_SET) +} diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs new file mode 100644 index 00000000000..b75005975bd --- /dev/null +++ b/object_store/src/aws/credential.rs @@ -0,0 +1,590 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
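This module signs requests with SigV4 by hand on top of `ring` instead of delegating to rusoto's credential machinery. As a reference for the HMAC chain that `AwsCredential::sign` performs below, here is a standalone sketch of the signing-key derivation (dummy values only; `ring` is already pulled in by the `aws` feature):

```rust
use ring::hmac;

// HMAC-SHA256 helper, same shape as the one defined later in this file.
fn hmac_sha256(secret: impl AsRef<[u8]>, bytes: impl AsRef<[u8]>) -> hmac::Tag {
    let key = hmac::Key::new(hmac::HMAC_SHA256, secret.as_ref());
    hmac::sign(&key, bytes.as_ref())
}

// SigV4 signing-key derivation:
//   "AWS4" + secret -> date -> region -> service -> "aws4_request"
fn signing_key(secret_key: &str, date: &str, region: &str, service: &str) -> hmac::Tag {
    let date_key = hmac_sha256(format!("AWS4{}", secret_key), date); // e.g. "20220815"
    let region_key = hmac_sha256(date_key, region);                  // e.g. "us-east-1"
    let service_key = hmac_sha256(region_key, service);              // "s3"
    hmac_sha256(service_key, b"aws4_request")
}

fn main() {
    // Dummy values for illustration only (the secret is the well-known AWS docs example key)
    let key = signing_key(
        "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
        "20220815",
        "us-east-1",
        "s3",
    );
    // The final request signature is HMAC(key, string_to_sign), hex encoded
    let signature = hmac_sha256(key, "AWS4-HMAC-SHA256\n...string to sign...");
    println!("{:02x?}", signature.as_ref());
}
```

The string to sign itself is built from the hashed canonical request, which is where the canonical-headers and canonical-query rules handled later in this file come in.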
+ +use crate::client::retry::RetryExt; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::{Result, RetryConfig}; +use bytes::Buf; +use chrono::{DateTime, Utc}; +use futures::TryFutureExt; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Method, Request, RequestBuilder}; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::sync::Arc; +use std::time::Instant; + +type StdError = Box; + +/// SHA256 hash of empty string +static EMPTY_SHA256_HASH: &str = + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + +#[derive(Debug)] +pub struct AwsCredential { + pub key_id: String, + pub secret_key: String, + pub token: Option, +} + +impl AwsCredential { + /// Signs a string + /// + /// + fn sign( + &self, + to_sign: &str, + date: DateTime, + region: &str, + service: &str, + ) -> String { + let date_string = date.format("%Y%m%d").to_string(); + let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string); + let region_hmac = hmac_sha256(date_hmac, region); + let service_hmac = hmac_sha256(region_hmac, service); + let signing_hmac = hmac_sha256(service_hmac, b"aws4_request"); + hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref()) + } +} + +struct RequestSigner<'a> { + date: DateTime, + credential: &'a AwsCredential, + service: &'a str, + region: &'a str, +} + +const DATE_HEADER: &str = "x-amz-date"; +const HASH_HEADER: &str = "x-amz-content-sha256"; +const TOKEN_HEADER: &str = "x-amz-security-token"; +const AUTH_HEADER: &str = "authorization"; + +const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER]; + +impl<'a> RequestSigner<'a> { + fn sign(&self, request: &mut Request) { + if let Some(ref token) = self.credential.token { + let token_val = HeaderValue::from_str(token).unwrap(); + request.headers_mut().insert(TOKEN_HEADER, token_val); + } + + let host_val = HeaderValue::from_str( + &request.url()[url::Position::BeforeHost..url::Position::AfterPort], + ) + .unwrap(); + request.headers_mut().insert("host", host_val); + + let date_str = self.date.format("%Y%m%dT%H%M%SZ").to_string(); + let date_val = HeaderValue::from_str(&date_str).unwrap(); + request.headers_mut().insert(DATE_HEADER, date_val); + + let digest = match request.body() { + None => EMPTY_SHA256_HASH.to_string(), + Some(body) => hex_digest(body.as_bytes().unwrap()), + }; + + let header_digest = HeaderValue::from_str(&digest).unwrap(); + request.headers_mut().insert(HASH_HEADER, header_digest); + + let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + request.method().as_str(), + request.url().path(), // S3 doesn't percent encode this like other services + request.url().query().unwrap_or(""), // This assumes the query pairs are in order + canonical_headers, + signed_headers, + digest + ); + + let hashed_canonical_request = hex_digest(canonical_request.as_bytes()); + let scope = format!( + "{}/{}/{}/aws4_request", + self.date.format("%Y%m%d"), + self.region, + self.service + ); + + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + self.date.format("%Y%m%dT%H%M%SZ"), + scope, + hashed_canonical_request + ); + + // sign the string + let signature = + self.credential + .sign(&string_to_sign, self.date, self.region, self.service); + + // build the actual auth header + let authorisation = format!( + "AWS4-HMAC-SHA256 
Credential={}/{}, SignedHeaders={}, Signature={}", + self.credential.key_id, scope, signed_headers, signature + ); + + let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); + request.headers_mut().insert(AUTH_HEADER, authorization_val); + } +} + +pub trait CredentialExt { + /// Sign a request + fn with_aws_sigv4( + self, + credential: &AwsCredential, + region: &str, + service: &str, + ) -> Self; +} + +impl CredentialExt for RequestBuilder { + fn with_aws_sigv4( + mut self, + credential: &AwsCredential, + region: &str, + service: &str, + ) -> Self { + // Hack around lack of access to underlying request + // https://github.com/seanmonstar/reqwest/issues/1212 + let mut request = self + .try_clone() + .expect("not stream") + .build() + .expect("request valid"); + + let date = Utc::now(); + let signer = RequestSigner { + date, + credential, + service, + region, + }; + + signer.sign(&mut request); + + for header in ALL_HEADERS { + if let Some(val) = request.headers_mut().remove(*header) { + self = self.header(*header, val) + } + } + self + } +} + +fn hmac_sha256(secret: impl AsRef<[u8]>, bytes: impl AsRef<[u8]>) -> ring::hmac::Tag { + let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_ref()); + ring::hmac::sign(&key, bytes.as_ref()) +} + +/// Computes the SHA256 digest of `body` returned as a hex encoded string +fn hex_digest(bytes: &[u8]) -> String { + let digest = ring::digest::digest(&ring::digest::SHA256, bytes); + hex_encode(digest.as_ref()) +} + +/// Returns `bytes` as a lower-case hex encoded string +fn hex_encode(bytes: &[u8]) -> String { + use std::fmt::Write; + let mut out = String::with_capacity(bytes.len() * 2); + for byte in bytes { + // String writing is infallible + let _ = write!(out, "{:02x}", byte); + } + out +} + +/// Canonicalizes headers into the AWS Canonical Form. 
+/// +/// +fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) { + let mut headers = BTreeMap::<&str, Vec<&str>>::new(); + let mut value_count = 0; + let mut value_bytes = 0; + let mut key_bytes = 0; + + for (key, value) in header_map { + let key = key.as_str(); + if ["authorization", "content-length", "user-agent"].contains(&key) { + continue; + } + + let value = std::str::from_utf8(value.as_bytes()).unwrap(); + key_bytes += key.len(); + value_bytes += value.len(); + value_count += 1; + headers.entry(key).or_default().push(value); + } + + let mut signed_headers = String::with_capacity(key_bytes + headers.len()); + let mut canonical_headers = + String::with_capacity(key_bytes + value_bytes + headers.len() + value_count); + + for (header_idx, (name, values)) in headers.into_iter().enumerate() { + if header_idx != 0 { + signed_headers.push(';'); + } + + signed_headers.push_str(name); + canonical_headers.push_str(name); + canonical_headers.push(':'); + for (value_idx, value) in values.into_iter().enumerate() { + if value_idx != 0 { + canonical_headers.push(','); + } + canonical_headers.push_str(value.trim()); + } + canonical_headers.push('\n'); + } + + (signed_headers, canonical_headers) +} + +/// Provides credentials for use when signing requests +#[derive(Debug)] +pub enum CredentialProvider { + Static(StaticCredentialProvider), + Instance(InstanceCredentialProvider), + WebIdentity(WebIdentityProvider), +} + +impl CredentialProvider { + pub async fn get_credential(&self) -> Result> { + match self { + Self::Static(s) => Ok(Arc::clone(&s.credential)), + Self::Instance(c) => c.get_credential().await, + Self::WebIdentity(c) => c.get_credential().await, + } + } +} + +/// A static set of credentials +#[derive(Debug)] +pub struct StaticCredentialProvider { + pub credential: Arc, +} + +/// Credentials sourced from the instance metadata service +/// +/// +#[derive(Debug)] +pub struct InstanceCredentialProvider { + pub cache: TokenCache>, + pub client: Client, + pub retry_config: RetryConfig, +} + +impl InstanceCredentialProvider { + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| { + const METADATA_ENDPOINT: &str = "http://169.254.169.254"; + instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT) + .map_err(|source| crate::Error::Generic { + store: "S3", + source, + }) + }) + .await + } +} + +/// Credentials sourced using AssumeRoleWithWebIdentity +/// +/// +#[derive(Debug)] +pub struct WebIdentityProvider { + pub cache: TokenCache>, + pub token: String, + pub role_arn: String, + pub session_name: String, + pub endpoint: String, + pub client: Client, + pub retry_config: RetryConfig, +} + +impl WebIdentityProvider { + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| { + web_identity( + &self.client, + &self.retry_config, + &self.token, + &self.role_arn, + &self.session_name, + &self.endpoint, + ) + .map_err(|source| crate::Error::Generic { + store: "S3", + source, + }) + }) + .await + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct InstanceCredentials { + access_key_id: String, + secret_access_key: String, + token: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: InstanceCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.token), + } + } +} + +/// +async fn instance_creds( + client: &Client, + retry_config: &RetryConfig, + endpoint: &str, +) -> Result>, StdError> { + const 
CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials"; + const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token"; + + let token_url = format!("{}/latest/api/token", endpoint); + let token = client + .request(Method::PUT, token_url) + .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL + .send_retry(retry_config) + .await? + .text() + .await?; + + let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH); + let role = client + .request(Method::GET, role_url) + .header(AWS_EC2_METADATA_TOKEN_HEADER, &token) + .send_retry(retry_config) + .await? + .text() + .await?; + + let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role); + let creds: InstanceCredentials = client + .request(Method::GET, creds_url) + .header(AWS_EC2_METADATA_TOKEN_HEADER, &token) + .send_retry(retry_config) + .await? + .json() + .await?; + + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Instant::now() + ttl, + }) +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResponse { + assume_role_with_web_identity_result: AssumeRoleResult, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResult { + credentials: AssumeRoleCredentials, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleCredentials { + session_token: String, + secret_access_key: String, + access_key_id: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: AssumeRoleCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.session_token), + } + } +} + +/// +async fn web_identity( + client: &Client, + retry_config: &RetryConfig, + token: &str, + role_arn: &str, + session_name: &str, + endpoint: &str, +) -> Result>, StdError> { + let bytes = client + .request(Method::POST, endpoint) + .query(&[ + ("Action", "AssumeRoleWithWebIdentity"), + ("DurationSeconds", "3600"), + ("RoleArn", role_arn), + ("RoleSessionName", session_name), + ("Version", "2011-06-15"), + ("WebIdentityToken", token), + ]) + .send_retry(retry_config) + .await? 
+ .bytes() + .await?; + + let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader()) + .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {}", e))?; + + let creds = resp.assume_role_with_web_identity_result.credentials; + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Instant::now() + ttl, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use reqwest::{Client, Method}; + use std::env; + + // Test generated using https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html + #[test] + fn test_sign() { + let client = Client::new(); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .build() + .unwrap(); + + let signer = RequestSigner { + date, + credential: &credential, + service: "ec2", + region: "us-east-1", + }; + + signer.sign(&mut request); + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4") + } + + #[test] + fn test_sign_port() { + let client = Client::new(); + + let credential = AwsCredential { + key_id: "H20ABqCkLZID4rLe".to_string(), + secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "http://localhost:9000/tsm-schemas") + .query(&[ + ("delimiter", "/"), + ("encoding-type", "url"), + ("list-type", "2"), + ("prefix", ""), + ]) + .build() + .unwrap(); + + let signer = RequestSigner { + date, + credential: &credential, + service: "s3", + region: "us-east-1", + }; + + signer.sign(&mut request); + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d") + } + + #[tokio::test] + async fn test_instance_metadata() { + if env::var("TEST_INTEGRATION").is_err() { + eprintln!("skipping AWS integration test"); + } + + // For example https://github.com/aws/amazon-ec2-metadata-mock + let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap(); + let client = Client::new(); + let retry_config = RetryConfig::default(); + + // Verify only allows IMDSv2 + let resp = client + .request(Method::GET, format!("{}/latest/meta-data/ami-id", endpoint)) + .send() + .await + .unwrap(); + + assert_eq!( + resp.status(), + reqwest::StatusCode::UNAUTHORIZED, + "Ensure metadata endpoint is set to only allow IMDSv2" + ); + + let creds = instance_creds(&client, &retry_config, &endpoint) + .await + .unwrap(); + + let id = &creds.token.key_id; + let secret = 
&creds.token.secret_key; + let token = creds.token.token.as_ref().unwrap(); + + assert!(!id.is_empty()); + assert!(!secret.is_empty()); + assert!(!token.is_empty()) + } +} diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs new file mode 100644 index 00000000000..06d20ccc9e7 --- /dev/null +++ b/object_store/src/aws/mod.rs @@ -0,0 +1,646 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for S3 +//! +//! ## Multi-part uploads +//! +//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method. +//! Data passed to the writer is automatically buffered to meet the minimum size +//! requirements for a part. Multiple parts are uploaded concurrently. +//! +//! If the writer fails for any reason, you may have parts uploaded to AWS but not +//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method +//! to abort the upload and drop those unneeded parts. In addition, you may wish to +//! consider implementing [automatic cleanup] of unused parts that are older than one +//! week. +//! +//! 
[automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::stream::BoxStream; +use futures::TryStreamExt; +use reqwest::Client; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::collections::BTreeSet; +use std::ops::Range; +use std::sync::Arc; +use tokio::io::AsyncWrite; +use tracing::info; + +use crate::aws::client::{S3Client, S3Config}; +use crate::aws::credential::{ + AwsCredential, CredentialProvider, InstanceCredentialProvider, + StaticCredentialProvider, WebIdentityProvider, +}; +use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; +use crate::{ + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, + RetryConfig, StreamExt, +}; + +mod client; +mod credential; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Last-Modified Header missing from response"))] + MissingLastModified, + + #[snafu(display("Content-Length Header missing from response"))] + MissingContentLength, + + #[snafu(display("Invalid last modified '{}': {}", last_modified, source))] + InvalidLastModified { + last_modified: String, + source: chrono::ParseError, + }, + + #[snafu(display("Invalid content length '{}': {}", content_length, source))] + InvalidContentLength { + content_length: String, + source: std::num::ParseIntError, + }, + + #[snafu(display("Missing region"))] + MissingRegion, + + #[snafu(display("Missing bucket name"))] + MissingBucketName, + + #[snafu(display("Missing AccessKeyId"))] + MissingAccessKeyId, + + #[snafu(display("Missing SecretAccessKey"))] + MissingSecretAccessKey, + + #[snafu(display("ETag Header missing from response"))] + MissingEtag, + + #[snafu(display("Received header containing non-ASCII data"))] + BadHeader { source: reqwest::header::ToStrError }, + + #[snafu(display("Error reading token file: {}", source))] + ReadTokenFile { source: std::io::Error }, +} + +impl From for super::Error { + fn from(err: Error) -> Self { + Self::Generic { + store: "S3", + source: Box::new(err), + } + } +} + +/// Interface for [Amazon S3](https://aws.amazon.com/s3/). 
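// --- Illustrative aside, not part of this patch: multipart upload usage. ---
// A minimal sketch of the put_multipart / abort_multipart flow described in the module
// docs above, written from a downstream crate. The store, object path and payload are
// assumptions for illustration only.
//
//     use object_store::{path::Path, ObjectStore};
//     use tokio::io::AsyncWriteExt;
//
//     async fn upload_with_cleanup(
//         store: &dyn ObjectStore,
//         data: &[u8],
//     ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
//         let location = Path::from("data/large-file.parquet");
//         let (multipart_id, mut writer) = store.put_multipart(&location).await?;
//
//         let write_result = async {
//             writer.write_all(data).await?;
//             writer.shutdown().await
//         }
//         .await;
//
//         if let Err(e) = write_result {
//             // Abort so the already uploaded parts are not retained (and billed)
//             store.abort_multipart(&location, &multipart_id).await?;
//             return Err(e.into());
//         }
//         Ok(())
//     }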
+#[derive(Debug)] +pub struct AmazonS3 { + client: Arc, +} + +impl std::fmt::Display for AmazonS3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "AmazonS3({})", self.client.config().bucket) + } +} + +#[async_trait] +impl ObjectStore for AmazonS3 { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.client.put_request(location, Some(bytes), &()).await?; + Ok(()) + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let id = self.client.create_multipart(location).await?; + + let upload = S3MultiPartUpload { + location: location.clone(), + upload_id: id.clone(), + client: Arc::clone(&self.client), + }; + + Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8)))) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + self.client + .delete_request(location, &[("uploadId", multipart_id)]) + .await + } + + async fn get(&self, location: &Path) -> Result { + let response = self.client.get_request(location, None, false).await?; + let stream = response + .bytes_stream() + .map_err(|source| crate::Error::Generic { + store: "S3", + source: Box::new(source), + }) + .boxed(); + + Ok(GetResult::Stream(stream)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let bytes = self + .client + .get_request(location, Some(range), false) + .await? + .bytes() + .await + .map_err(|source| client::Error::GetRequest { + source, + path: location.to_string(), + })?; + Ok(bytes) + } + + async fn head(&self, location: &Path) -> Result { + use reqwest::header::{CONTENT_LENGTH, LAST_MODIFIED}; + + // Extract meta from headers + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseSyntax + let response = self.client.get_request(location, None, true).await?; + let headers = response.headers(); + + let last_modified = headers + .get(LAST_MODIFIED) + .context(MissingLastModifiedSnafu)?; + + let content_length = headers + .get(CONTENT_LENGTH) + .context(MissingContentLengthSnafu)?; + + let last_modified = last_modified.to_str().context(BadHeaderSnafu)?; + let last_modified = DateTime::parse_from_rfc2822(last_modified) + .context(InvalidLastModifiedSnafu { last_modified })? 
+ .with_timezone(&Utc); + + let content_length = content_length.to_str().context(BadHeaderSnafu)?; + let content_length = content_length + .parse() + .context(InvalidContentLengthSnafu { content_length })?; + Ok(ObjectMeta { + location: location.clone(), + last_modified, + size: content_length, + }) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location, &()).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let stream = self + .client + .list_paginated(prefix, false) + .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) + .try_flatten() + .boxed(); + + Ok(stream) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let mut stream = self.client.list_paginated(prefix, true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + while let Some(result) = stream.next().await { + let response = result?; + common_prefixes.extend(response.common_prefixes.into_iter()); + objects.extend(response.objects.into_iter()); + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to).await + } + + async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> { + // Will need dynamodb_lock + Err(crate::Error::NotImplemented) + } +} + +struct S3MultiPartUpload { + location: Path, + upload_id: String, + client: Arc, +} + +#[async_trait] +impl CloudMultiPartUploadImpl for S3MultiPartUpload { + async fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> Result { + use reqwest::header::ETAG; + let part = (part_idx + 1).to_string(); + + let response = self + .client + .put_request( + &self.location, + Some(buf.into()), + &[("partNumber", &part), ("uploadId", &self.upload_id)], + ) + .await?; + + let etag = response + .headers() + .get(ETAG) + .context(MissingEtagSnafu) + .map_err(crate::Error::from)?; + + let etag = etag + .to_str() + .context(BadHeaderSnafu) + .map_err(crate::Error::from)?; + + Ok(UploadPart { + content_id: etag.to_string(), + }) + } + + async fn complete( + &self, + completed_parts: Vec, + ) -> Result<(), std::io::Error> { + self.client + .complete_multipart(&self.location, &self.upload_id, completed_parts) + .await?; + Ok(()) + } +} + +/// Configure a connection to Amazon S3 using the specified credentials in +/// the specified Amazon region and bucket. +/// +/// # Example +/// ``` +/// # let REGION = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY_ID = "foo"; +/// # let SECRET_KEY = "foo"; +/// # use object_store::aws::AmazonS3Builder; +/// let s3 = AmazonS3Builder::new() +/// .with_region(REGION) +/// .with_bucket_name(BUCKET_NAME) +/// .with_access_key_id(ACCESS_KEY_ID) +/// .with_secret_access_key(SECRET_KEY) +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct AmazonS3Builder { + access_key_id: Option, + secret_access_key: Option, + region: Option, + bucket_name: Option, + endpoint: Option, + token: Option, + retry_config: RetryConfig, + allow_http: bool, +} + +impl AmazonS3Builder { + /// Create a new [`AmazonS3Builder`] with default values. 
+ pub fn new() -> Self { + Default::default() + } + + /// Set the AWS Access Key (required) + pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = Some(access_key_id.into()); + self + } + + /// Set the AWS Secret Access Key (required) + pub fn with_secret_access_key( + mut self, + secret_access_key: impl Into, + ) -> Self { + self.secret_access_key = Some(secret_access_key.into()); + self + } + + /// Set the region (e.g. `us-east-1`) (required) + pub fn with_region(mut self, region: impl Into) -> Self { + self.region = Some(region.into()); + self + } + + /// Set the bucket_name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Sets the endpoint for communicating with AWS S3. Default value + /// is based on region. + /// + /// For example, this might be set to `"http://localhost:4566:` + /// for testing against a localstack instance. + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = Some(endpoint.into()); + self + } + + /// Set the token to use for requests (passed to underlying provider) + pub fn with_token(mut self, token: impl Into) -> Self { + self.token = Some(token.into()); + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.allow_http = allow_http; + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Create a [`AmazonS3`] instance from the provided values, + /// consuming `self`. + pub fn build(self) -> Result { + let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; + let region = self.region.context(MissingRegionSnafu)?; + + let credentials = match (self.access_key_id, self.secret_access_key, self.token) { + (Some(key_id), Some(secret_key), token) => { + info!("Using Static credential provider"); + CredentialProvider::Static(StaticCredentialProvider { + credential: Arc::new(AwsCredential { + key_id, + secret_key, + token, + }), + }) + } + (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), + (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), + // TODO: Replace with `AmazonS3Builder::credentials_from_env` + _ => match ( + std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE"), + std::env::var("AWS_ROLE_ARN"), + ) { + (Some(token_file), Ok(role_arn)) => { + info!("Using WebIdentity credential provider"); + let token = std::fs::read_to_string(token_file) + .context(ReadTokenFileSnafu)?; + + let session_name = std::env::var("AWS_ROLE_SESSION_NAME") + .unwrap_or_else(|_| "WebIdentitySession".to_string()); + + let endpoint = format!("https://sts.{}.amazonaws.com", region); + + // Disallow non-HTTPs requests + let client = Client::builder().https_only(true).build().unwrap(); + + CredentialProvider::WebIdentity(WebIdentityProvider { + cache: Default::default(), + token, + session_name, + role_arn, + endpoint, + client, + retry_config: self.retry_config.clone(), + }) + } + _ => { + info!("Using Instance credential provider"); + + // The instance metadata endpoint is access over HTTP + let client = Client::builder().https_only(false).build().unwrap(); + + CredentialProvider::Instance(InstanceCredentialProvider { + cache: Default::default(), + client, + retry_config: self.retry_config.clone(), + 
}) + } + }, + }; + + let endpoint = self + .endpoint + .unwrap_or_else(|| format!("https://s3.{}.amazonaws.com", region)); + + let config = S3Config { + region, + endpoint, + bucket, + credentials, + retry_config: self.retry_config, + allow_http: self.allow_http, + }; + + let client = Arc::new(S3Client::new(config)); + + Ok(AmazonS3 { client }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{ + get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, + put_get_delete_list, rename_and_copy, stream_get, + }; + use bytes::Bytes; + use std::env; + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + // Helper macro to skip tests if TEST_INTEGRATION and the AWS + // environment variables are not set. Returns a configured + // AmazonS3Builder + macro_rules! maybe_skip_integration { + () => {{ + dotenv::dotenv().ok(); + + let required_vars = [ + "AWS_DEFAULT_REGION", + "OBJECT_STORE_BUCKET", + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + ]; + let unset_vars: Vec<_> = required_vars + .iter() + .filter_map(|&name| match env::var(name) { + Ok(_) => None, + Err(_) => Some(name), + }) + .collect(); + let unset_var_names = unset_vars.join(", "); + + let force = env::var("TEST_INTEGRATION"); + + if force.is_ok() && !unset_var_names.is_empty() { + panic!( + "TEST_INTEGRATION is set, \ + but variable(s) {} need to be set", + unset_var_names + ); + } else if force.is_err() { + eprintln!( + "skipping AWS integration test - set {}TEST_INTEGRATION to run", + if unset_var_names.is_empty() { + String::new() + } else { + format!("{} and ", unset_var_names) + } + ); + return; + } else { + let config = AmazonS3Builder::new() + .with_access_key_id( + env::var("AWS_ACCESS_KEY_ID") + .expect("already checked AWS_ACCESS_KEY_ID"), + ) + .with_secret_access_key( + env::var("AWS_SECRET_ACCESS_KEY") + .expect("already checked AWS_SECRET_ACCESS_KEY"), + ) + .with_region( + env::var("AWS_DEFAULT_REGION") + .expect("already checked AWS_DEFAULT_REGION"), + ) + .with_bucket_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET"), + ) + .with_allow_http(true); + + let config = if let Some(endpoint) = env::var("AWS_ENDPOINT").ok() { + config.with_endpoint(endpoint) + } else { + config + }; + + let config = if let Some(token) = env::var("AWS_SESSION_TOKEN").ok() { + config.with_token(token) + } else { + config + }; + + config + } + }}; + } + + #[tokio::test] + async fn s3_test() { + let config = maybe_skip_integration!(); + let integration = config.build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn s3_test_get_nonexistent_location() { + let config = maybe_skip_integration!(); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_get_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + } + + #[tokio::test] + async fn s3_test_put_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + let data = Bytes::from("arbitrary data"); + + let err = integration.put(&location, data).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_location() { + let config = maybe_skip_integration!(); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + integration.delete(&location).await.unwrap(); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } +} diff --git a/object_store/src/azure.rs b/object_store/src/azure.rs index 9987c0370df..a9dbc53e22a 100644 --- a/object_store/src/azure.rs +++ b/object_store/src/azure.rs @@ -49,7 +49,7 @@ use azure_storage_blobs::prelude::{ }; use bytes::Bytes; use chrono::{TimeZone, Utc}; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use snafu::{ResultExt, Snafu}; use std::collections::BTreeSet; use std::fmt::{Debug, Formatter}; @@ -765,70 +765,47 @@ impl AzureMultiPartUpload { } } +#[async_trait] impl CloudMultiPartUploadImpl for AzureMultiPartUpload { - fn put_multipart_part( + async fn put_multipart_part( &self, buf: Vec, part_idx: usize, - ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { - let client = Arc::clone(&self.container_client); - let location = self.location.clone(); + ) -> Result { let block_id = self.get_block_id(part_idx); - Box::pin(async move { - client - .blob_client(location.as_ref()) - .put_block(block_id.clone(), buf) - .into_future() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + self.container_client + .blob_client(self.location.as_ref()) + .put_block(block_id.clone(), buf) + .into_future() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - Ok(( - part_idx, - UploadPart { - content_id: block_id, - }, - )) + Ok(UploadPart { + content_id: block_id, }) } - fn complete( - &self, - completed_parts: Vec>, - ) -> BoxFuture<'static, Result<(), io::Error>> { - let parts = - completed_parts - .into_iter() - .enumerate() - .map(|(part_number, maybe_part)| match maybe_part { - Some(part) => { - Ok(azure_storage_blobs::blob::BlobBlockType::Uncommitted( - azure_storage_blobs::prelude::BlockId::new(part.content_id), - )) - } - None => Err(io::Error::new( - io::ErrorKind::Other, - format!("Missing information for upload part {:?}", part_number), - )), - }); - - let client = Arc::clone(&self.container_client); - let location = self.location.clone(); - - Box::pin(async move { - let block_list = azure_storage_blobs::blob::BlockList { - blocks: parts.collect::>()?, - }; - - client - .blob_client(location.as_ref()) - .put_block_list(block_list) - .into_future() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + let blocks = completed_parts + .into_iter() + .map(|part| 
{ + azure_storage_blobs::blob::BlobBlockType::Uncommitted( + azure_storage_blobs::prelude::BlockId::new(part.content_id), + ) + }) + .collect(); - Ok(()) - }) + let block_list = azure_storage_blobs::blob::BlockList { blocks }; + + self.container_client + .blob_client(self.location.as_ref()) + .put_block_list(block_list) + .into_future() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + Ok(()) } } diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index 1166ebe7a52..7241002a0bd 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -18,6 +18,8 @@ //! Generic utilities reqwest based ObjectStore implementations pub mod backoff; +#[cfg(feature = "gcp")] pub mod oauth; +pub mod pagination; pub mod retry; pub mod token; diff --git a/object_store/src/client/pagination.rs b/object_store/src/client/pagination.rs new file mode 100644 index 00000000000..3ab17fe8b5a --- /dev/null +++ b/object_store/src/client/pagination.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use futures::Stream; +use std::future::Future; + +/// Takes a paginated operation `op` that when called with: +/// +/// - A state `S` +/// - An optional next token `Option` +/// +/// Returns +/// +/// - A response value `T` +/// - The next state `S` +/// - The next continuation token `Option` +/// +/// And converts it into a `Stream>` which will first call `op(state, None)`, and yield +/// the returned response `T`. 
If the returned continuation token was `None` the stream will then +/// finish, otherwise it will continue to call `op(state, token)` with the values returned by the +/// previous call to `op`, until a continuation token of `None` is returned +/// +pub fn stream_paginated(state: S, op: F) -> impl Stream> +where + F: Fn(S, Option) -> Fut + Copy, + Fut: Future)>>, +{ + enum PaginationState { + Start(T), + HasMore(T, String), + Done, + } + + futures::stream::unfold(PaginationState::Start(state), move |state| async move { + let (s, page_token) = match state { + PaginationState::Start(s) => (s, None), + PaginationState::HasMore(s, page_token) => (s, Some(page_token)), + PaginationState::Done => { + return None; + } + }; + + let (resp, s, continuation) = match op(s, page_token).await { + Ok(resp) => resp, + Err(e) => return Some((Err(e), PaginationState::Done)), + }; + + let next_state = match continuation { + Some(token) => PaginationState::HasMore(s, token), + None => PaginationState::Done, + }; + + Some((Ok(resp), next_state)) + }) +} diff --git a/object_store/src/client/token.rs b/object_store/src/client/token.rs index a56a29462b1..2ff28616e60 100644 --- a/object_store/src/client/token.rs +++ b/object_store/src/client/token.rs @@ -30,11 +30,19 @@ pub struct TemporaryToken { /// Provides [`TokenCache::get_or_insert_with`] which can be used to cache a /// [`TemporaryToken`] based on its expiry -#[derive(Debug, Default)] +#[derive(Debug)] pub struct TokenCache { cache: Mutex>>, } +impl Default for TokenCache { + fn default() -> Self { + Self { + cache: Default::default(), + } + } +} + impl TokenCache { pub async fn get_or_insert_with(&self, f: F) -> Result where diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs index 0dc5a956ac0..c9bb6335973 100644 --- a/object_store/src/gcp.rs +++ b/object_store/src/gcp.rs @@ -38,7 +38,6 @@ use std::sync::Arc; use async_trait::async_trait; use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; -use futures::future::BoxFuture; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; use reqwest::header::RANGE; @@ -46,6 +45,7 @@ use reqwest::{header, Client, Method, Response, StatusCode}; use snafu::{ResultExt, Snafu}; use tokio::io::AsyncWrite; +use crate::client::pagination::stream_paginated; use crate::client::retry::RetryExt; use crate::{ client::{oauth::OAuthProvider, token::TokenCache}, @@ -476,44 +476,16 @@ impl GoogleCloudStorageClient { &self, prefix: Option<&Path>, delimiter: bool, - ) -> Result>> { + ) -> BoxStream<'_, Result> { let prefix = format_prefix(prefix); - - enum ListState { - Start, - HasMore(String), - Done, - } - - Ok(futures::stream::unfold(ListState::Start, move |state| { - let prefix = prefix.clone(); - - async move { - let page_token = match &state { - ListState::Start => None, - ListState::HasMore(page_token) => Some(page_token.as_str()), - ListState::Done => { - return None; - } - }; - - let resp = match self - .list_request(prefix.as_deref(), delimiter, page_token) - .await - { - Ok(resp) => resp, - Err(e) => return Some((Err(e), state)), - }; - - let next_state = match &resp.next_page_token { - Some(token) => ListState::HasMore(token.clone()), - None => ListState::Done, - }; - - Some((Ok(resp), next_state)) - } + stream_paginated(prefix, move |prefix, token| async move { + let mut r = self + .list_request(prefix.as_deref(), delimiter, token.as_deref()) + .await?; + let next_token = r.next_page_token.take(); + Ok((r, prefix, next_token)) }) - .boxed()) + 
.boxed() } } @@ -544,116 +516,105 @@ struct GCSMultipartUpload { multipart_id: MultipartId, } +#[async_trait] impl CloudMultiPartUploadImpl for GCSMultipartUpload { /// Upload an object part - fn put_multipart_part( + async fn put_multipart_part( &self, buf: Vec, part_idx: usize, - ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { + ) -> Result { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", self.client.base_url, self.client.bucket_name_encoded, self.encoded_path ); - let client = Arc::clone(&self.client); - - Box::pin(async move { - let token = client - .get_token() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - - let response = client - .client - .request(Method::PUT, &url) - .bearer_auth(token) - .query(&[ - ("partNumber", format!("{}", part_idx + 1)), - ("uploadId", upload_id), - ]) - .header(header::CONTENT_TYPE, "application/octet-stream") - .header(header::CONTENT_LENGTH, format!("{}", buf.len())) - .body(buf) - .send_retry(&client.retry_config) - .await - .map_err(reqwest_error_as_io)? - .error_for_status() - .map_err(reqwest_error_as_io)?; - - let content_id = response - .headers() - .get("ETag") - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "response headers missing ETag", - ) - })? - .to_str() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? - .to_string(); - Ok((part_idx, UploadPart { content_id })) - }) + let token = self + .client + .get_token() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let response = self + .client + .client + .request(Method::PUT, &url) + .bearer_auth(token) + .query(&[ + ("partNumber", format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ]) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, format!("{}", buf.len())) + .body(buf) + .send_retry(&self.client.retry_config) + .await + .map_err(reqwest_error_as_io)? + .error_for_status() + .map_err(reqwest_error_as_io)?; + + let content_id = response + .headers() + .get("ETag") + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "response headers missing ETag", + ) + })? + .to_str() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + .to_string(); + + Ok(UploadPart { content_id }) } /// Complete a multipart upload - fn complete( - &self, - completed_parts: Vec>, - ) -> BoxFuture<'static, Result<(), io::Error>> { - let client = Arc::clone(&self.client); + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { let upload_id = self.multipart_id.clone(); let url = format!( "{}/{}/{}", self.client.base_url, self.client.bucket_name_encoded, self.encoded_path ); - Box::pin(async move { - let parts: Vec = completed_parts - .into_iter() - .enumerate() - .map(|(part_number, maybe_part)| match maybe_part { - Some(part) => Ok(MultipartPart { - e_tag: part.content_id, - part_number: part_number + 1, - }), - None => Err(io::Error::new( - io::ErrorKind::Other, - format!("Missing information for upload part {:?}", part_number), - )), - }) - .collect::, io::Error>>()?; - - let token = client - .get_token() - .await - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; - - let upload_info = CompleteMultipartUpload { parts }; - - let data = quick_xml::se::to_string(&upload_info) - .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? 
- // We cannot disable the escaping that transforms "/" to ""e;" :( - // https://github.com/tafia/quick-xml/issues/362 - // https://github.com/tafia/quick-xml/issues/350 - .replace(""", "\""); - - client - .client - .request(Method::POST, &url) - .bearer_auth(token) - .query(&[("uploadId", upload_id)]) - .body(data) - .send_retry(&client.retry_config) - .await - .map_err(reqwest_error_as_io)? - .error_for_status() - .map_err(reqwest_error_as_io)?; - - Ok(()) - }) + let parts = completed_parts + .into_iter() + .enumerate() + .map(|(part_number, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_number + 1, + }) + .collect(); + + let token = self + .client + .get_token() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let upload_info = CompleteMultipartUpload { parts }; + + let data = quick_xml::se::to_string(&upload_info) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? + // We cannot disable the escaping that transforms "/" to ""e;" :( + // https://github.com/tafia/quick-xml/issues/362 + // https://github.com/tafia/quick-xml/issues/350 + .replace(""", "\""); + + self.client + .client + .request(Method::POST, &url) + .bearer_auth(token) + .query(&[("uploadId", upload_id)]) + .body(data) + .send_retry(&self.client.retry_config) + .await + .map_err(reqwest_error_as_io)? + .error_for_status() + .map_err(reqwest_error_as_io)?; + + Ok(()) } } @@ -734,7 +695,7 @@ impl ObjectStore for GoogleCloudStorage { ) -> Result>> { let stream = self .client - .list_paginated(prefix, false)? + .list_paginated(prefix, false) .map_ok(|r| { futures::stream::iter( r.items.into_iter().map(|x| convert_object_meta(&x)), @@ -747,7 +708,7 @@ impl ObjectStore for GoogleCloudStorage { } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { - let mut stream = self.client.list_paginated(prefix, true)?; + let mut stream = self.client.list_paginated(prefix, true); let mut common_prefixes = BTreeSet::new(); let mut objects = Vec::new(); diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index f7adedb2682..374f5592e84 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -165,10 +165,10 @@ pub mod memory; pub mod path; pub mod throttle; -#[cfg(feature = "gcp")] +#[cfg(any(feature = "gcp", feature = "aws"))] mod client; -#[cfg(feature = "gcp")] +#[cfg(any(feature = "gcp", feature = "aws"))] pub use client::{backoff::BackoffConfig, retry::RetryConfig}; #[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] @@ -471,6 +471,16 @@ pub enum Error { OAuth { source: client::oauth::Error }, } +impl From for std::io::Error { + fn from(e: Error) -> Self { + let kind = match &e { + Error::NotFound { .. } => std::io::ErrorKind::NotFound, + _ => std::io::ErrorKind::Other, + }; + Self::new(kind, e) + } +} + #[cfg(test)] mod test_util { use super::*; diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs index c16022d3735..1985d8694e5 100644 --- a/object_store/src/multipart.rs +++ b/object_store/src/multipart.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. 
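// --- Illustrative aside, not part of this patch: the new Error -> io::Error conversion. ---
// The multipart trait below speaks `std::io::Error`, so lib.rs gains a `From` impl that
// maps `Error::NotFound` to `io::ErrorKind::NotFound` and everything else to `Other`,
// letting implementations bubble crate errors up with `?`. A trivial sketch:
//
//     fn to_io_error(e: object_store::Error) -> std::io::Error {
//         // Relies on the `impl From<Error> for std::io::Error` added in lib.rs
//         e.into()
//     }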
-use futures::{future::BoxFuture, stream::FuturesUnordered, Future, StreamExt}; +use async_trait::async_trait; +use futures::{stream::FuturesUnordered, Future, StreamExt}; use std::{io, pin::Pin, sync::Arc, task::Poll}; use tokio::io::AsyncWrite; @@ -26,23 +27,19 @@ type BoxedTryFuture = Pin> + Sen /// A trait that can be implemented by cloud-based object stores /// and used in combination with [`CloudMultiPartUpload`] to provide /// multipart upload support -/// -/// Note: this does not use AsyncTrait as the lifetimes are difficult to manage -pub(crate) trait CloudMultiPartUploadImpl { +#[async_trait] +pub(crate) trait CloudMultiPartUploadImpl: 'static { /// Upload a single part - fn put_multipart_part( + async fn put_multipart_part( &self, buf: Vec, part_idx: usize, - ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>>; + ) -> Result; /// Complete the upload with the provided parts /// /// `completed_parts` is in order of part number - fn complete( - &self, - completed_parts: Vec>, - ) -> BoxFuture<'static, Result<(), io::Error>>; + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error>; } #[derive(Debug, Clone)] @@ -128,10 +125,12 @@ where self.current_buffer.extend_from_slice(buf); let out_buffer = std::mem::take(&mut self.current_buffer); - let task = self - .inner - .put_multipart_part(out_buffer, self.current_part_idx); - self.tasks.push(task); + let inner = Arc::clone(&self.inner); + let part_idx = self.current_part_idx; + self.tasks.push(Box::pin(async move { + let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + Ok((part_idx, upload_part)) + })); self.current_part_idx += 1; // We need to poll immediately after adding to setup waker @@ -157,10 +156,12 @@ where // If current_buffer is not empty, see if it can be submitted if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency { let out_buffer: Vec = std::mem::take(&mut self.current_buffer); - let task = self - .inner - .put_multipart_part(out_buffer, self.current_part_idx); - self.tasks.push(task); + let inner = Arc::clone(&self.inner); + let part_idx = self.current_part_idx; + self.tasks.push(Box::pin(async move { + let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + Ok((part_idx, upload_part)) + })); } self.as_mut().poll_tasks(cx)?; @@ -185,10 +186,26 @@ where // If shutdown task is not set, set it let parts = std::mem::take(&mut self.completed_parts); + let parts = parts + .into_iter() + .enumerate() + .map(|(idx, part)| { + part.ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + format!("Missing information for upload part {}", idx), + ) + }) + }) + .collect::>()?; + let inner = Arc::clone(&self.inner); - let completion_task = self - .completion_task - .get_or_insert_with(|| inner.complete(parts)); + let completion_task = self.completion_task.get_or_insert_with(|| { + Box::pin(async move { + inner.complete(parts).await?; + Ok(()) + }) + }); Pin::new(completion_task).poll(cx) }
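
Editorial note (not part of the patch): with this change, backends implement `CloudMultiPartUploadImpl` as plain `async fn`s and return `UploadPart` directly; the part-index bookkeeping and the `Option` handling move into `CloudMultiPartUpload` itself. The sketch below shows the shape of an implementation under the new trait, using a hypothetical in-memory backend purely for illustration (the trait is `pub(crate)`, so a real implementation would live inside this crate, as the S3, GCS and Azure ones above do).

use async_trait::async_trait;
use std::io;
use std::sync::Mutex;

use crate::multipart::{CloudMultiPartUploadImpl, UploadPart};

/// Hypothetical backend, used only to illustrate the new trait shape.
struct InMemoryMultiPartUpload {
    parts: Mutex<Vec<(usize, Vec<u8>)>>,
}

#[async_trait]
impl CloudMultiPartUploadImpl for InMemoryMultiPartUpload {
    async fn put_multipart_part(
        &self,
        buf: Vec<u8>,
        part_idx: usize,
    ) -> Result<UploadPart, io::Error> {
        // A real backend would issue an upload-part request here and return
        // the identifier the service hands back (e.g. an ETag).
        self.parts.lock().unwrap().push((part_idx, buf));
        Ok(UploadPart {
            content_id: part_idx.to_string(),
        })
    }

    async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error> {
        // `completed_parts` arrives ordered by part number; a real backend would
        // send the final "complete multipart upload" request with these ids.
        let _content_ids: Vec<String> =
            completed_parts.into_iter().map(|p| p.content_id).collect();
        Ok(())
    }
}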