From 19fd8858991bcf8e654c221e6956ce6a8b5a86e1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 26 Jul 2022 10:20:27 -0400 Subject: [PATCH] Port Add stream upload (multi-part upload) (#2147) * feat: Add stream upload (multi-part upload) (#20) * feat: Implement multi-part upload Co-authored-by: Raphael Taylor-Davies * chore: simplify local file implementation * chore: Remove pin-project * feat: make cleanup_upload() top-level * docs: Add some docs for upload * chore: fix linting issue * fix: rename to put_multipart * feat: Implement multi-part upload for GCP * fix: Get GCS test to pass * chore: remove more upload language * fix: Add guard to test so we don't run with fake gcs server * chore: small tweaks * fix: apply suggestions from code review Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * feat: switch to quick-xml * feat: remove throttle implementation of multipart * fix: rename from cleanup to abort * feat: enforce upload not readable until shutdown * fix: ensure we close files before moving them * chore: fix lint issue Co-authored-by: Raphael Taylor-Davies Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> * fmt * RAT multipart * Fix build * fix: merge issue Co-authored-by: Will Jones Co-authored-by: Raphael Taylor-Davies Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- object_store/Cargo.toml | 5 +- object_store/src/aws.rs | 231 +++++++++++++++++++++- object_store/src/azure.rs | 125 +++++++++++- object_store/src/gcp.rs | 340 +++++++++++++++++++++++++++++--- object_store/src/lib.rs | 105 +++++++++- object_store/src/local.rs | 361 +++++++++++++++++++++++++++++++--- object_store/src/memory.rs | 69 ++++++- object_store/src/multipart.rs | 195 ++++++++++++++++++ object_store/src/throttle.rs | 17 ++ 9 files changed, 1392 insertions(+), 56 deletions(-) create mode 100644 object_store/src/multipart.rs diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index 613b6ab2ecf..74153989159 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -44,6 +44,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] } futures = "0.3" serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } serde_json = { version = "1.0", default-features = false, optional = true } +quick-xml = { version = "0.23.0", features = ["serialize"], optional = true } rustls-pemfile = { version = "1.0", default-features = false, optional = true } ring = { version = "0.16", default-features = false, features = ["std"] } base64 = { version = "0.13", default-features = false, optional = true } @@ -59,7 +60,7 @@ rusoto_credential = { version = "0.48.0", optional = true, default-features = fa rusoto_s3 = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] } rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] } snafu = "0.7" -tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time"] } +tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] } tracing = { version = "0.1" } reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls"] } parking_lot = { version = "0.12" } @@ -70,7 +71,7 @@ walkdir = "2" [features] azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest"] azure_test = ["azure", 
"azure_core/azurite_workaround", "azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"] -gcp = ["serde", "serde_json", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"] +gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"] aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper", "hyper-rustls"] [dev-dependencies] # In alphabetical order diff --git a/object_store/src/aws.rs b/object_store/src/aws.rs index 7ebcc2a8841..3606a3806f9 100644 --- a/object_store/src/aws.rs +++ b/object_store/src/aws.rs @@ -16,7 +16,23 @@ // under the License. //! An object store implementation for S3 +//! +//! ## Multi-part uploads +//! +//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method. +//! Data passed to the writer is automatically buffered to meet the minimum size +//! requirements for a part. Multiple parts are uploaded concurrently. +//! +//! If the writer fails for any reason, you may have parts uploaded to AWS but not +//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method +//! to abort the upload and drop those unneeded parts. In addition, you may wish to +//! consider implementing [automatic cleanup] of unused parts that are older than one +//! week. +//! +//! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ +use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; use crate::util::format_http_range; +use crate::MultipartId; use crate::{ collect_bytes, path::{Path, DELIMITER}, @@ -26,6 +42,7 @@ use crate::{ use async_trait::async_trait; use bytes::Bytes; use chrono::{DateTime, Utc}; +use futures::future::BoxFuture; use futures::{ stream::{self, BoxStream}, Future, Stream, StreamExt, TryStreamExt, @@ -36,10 +53,12 @@ use rusoto_credential::{InstanceMetadataProvider, StaticProvider}; use rusoto_s3::S3; use rusoto_sts::WebIdentityProvider; use snafu::{OptionExt, ResultExt, Snafu}; +use std::io; use std::ops::Range; use std::{ convert::TryFrom, fmt, num::NonZeroUsize, ops::Deref, sync::Arc, time::Duration, }; +use tokio::io::AsyncWrite; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tracing::{debug, warn}; @@ -129,6 +148,32 @@ enum Error { path: String, }, + #[snafu(display( + "Unable to upload data. Bucket: {}, Location: {}, Error: {} ({:?})", + bucket, + path, + source, + source, + ))] + UnableToUploadData { + source: rusoto_core::RusotoError, + bucket: String, + path: String, + }, + + #[snafu(display( + "Unable to cleanup multipart data. Bucket: {}, Location: {}, Error: {} ({:?})", + bucket, + path, + source, + source, + ))] + UnableToCleanupMultipartData { + source: rusoto_core::RusotoError, + bucket: String, + path: String, + }, + #[snafu(display( "Unable to list data. 
Bucket: {}, Error: {} ({:?})", bucket, @@ -272,6 +317,71 @@ impl ObjectStore for AmazonS3 { Ok(()) } + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let bucket_name = self.bucket_name.clone(); + + let request_factory = move || rusoto_s3::CreateMultipartUploadRequest { + bucket: bucket_name.clone(), + key: location.to_string(), + ..Default::default() + }; + + let s3 = self.client().await; + + let data = s3_request(move || { + let (s3, request_factory) = (s3.clone(), request_factory.clone()); + + async move { s3.create_multipart_upload(request_factory()).await } + }) + .await + .context(UnableToUploadDataSnafu { + bucket: &self.bucket_name, + path: location.as_ref(), + })?; + + let upload_id = data.upload_id.unwrap(); + + let inner = S3MultiPartUpload { + upload_id: upload_id.clone(), + bucket: self.bucket_name.clone(), + key: location.to_string(), + client_unrestricted: self.client_unrestricted.clone(), + connection_semaphore: Arc::clone(&self.connection_semaphore), + }; + + Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8)))) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + let request_factory = move || rusoto_s3::AbortMultipartUploadRequest { + bucket: self.bucket_name.clone(), + key: location.to_string(), + upload_id: multipart_id.to_string(), + ..Default::default() + }; + + let s3 = self.client().await; + s3_request(move || { + let (s3, request_factory) = (s3.clone(), request_factory); + + async move { s3.abort_multipart_upload(request_factory()).await } + }) + .await + .context(UnableToCleanupMultipartDataSnafu { + bucket: &self.bucket_name, + path: location.as_ref(), + })?; + + Ok(()) + } + async fn get(&self, location: &Path) -> Result { Ok(GetResult::Stream( self.get_object(location, None).await?.boxed(), @@ -821,13 +931,131 @@ impl Error { } } +struct S3MultiPartUpload { + bucket: String, + key: String, + upload_id: String, + client_unrestricted: rusoto_s3::S3Client, + connection_semaphore: Arc, +} + +impl CloudMultiPartUploadImpl for S3MultiPartUpload { + fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { + // Get values to move into future; we don't want a reference to Self + let bucket = self.bucket.clone(); + let key = self.key.clone(); + let upload_id = self.upload_id.clone(); + let content_length = buf.len(); + + let request_factory = move || rusoto_s3::UploadPartRequest { + bucket, + key, + upload_id, + // AWS part number is 1-indexed + part_number: (part_idx + 1).try_into().unwrap(), + content_length: Some(content_length.try_into().unwrap()), + body: Some(buf.into()), + ..Default::default() + }; + + let s3 = self.client_unrestricted.clone(); + let connection_semaphore = Arc::clone(&self.connection_semaphore); + + Box::pin(async move { + let _permit = connection_semaphore + .acquire_owned() + .await + .expect("semaphore shouldn't be closed yet"); + + let response = s3_request(move || { + let (s3, request_factory) = (s3.clone(), request_factory.clone()); + async move { s3.upload_part(request_factory()).await } + }) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + Ok(( + part_idx, + UploadPart { + content_id: response.e_tag.unwrap(), + }, + )) + }) + } + + fn complete( + &self, + completed_parts: Vec>, + ) -> BoxFuture<'static, Result<(), io::Error>> { + let parts = + completed_parts + .into_iter() + .enumerate() + .map(|(part_number, maybe_part)| match 
maybe_part { + Some(part) => { + Ok(rusoto_s3::CompletedPart { + e_tag: Some(part.content_id), + part_number: Some((part_number + 1).try_into().map_err( + |err| io::Error::new(io::ErrorKind::Other, err), + )?), + }) + } + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("Missing information for upload part {:?}", part_number), + )), + }); + + // Get values to move into future; we don't want a reference to Self + let bucket = self.bucket.clone(); + let key = self.key.clone(); + let upload_id = self.upload_id.clone(); + + let request_factory = move || -> Result<_, io::Error> { + Ok(rusoto_s3::CompleteMultipartUploadRequest { + bucket, + key, + upload_id, + multipart_upload: Some(rusoto_s3::CompletedMultipartUpload { + parts: Some(parts.collect::>()?), + }), + ..Default::default() + }) + }; + + let s3 = self.client_unrestricted.clone(); + let connection_semaphore = Arc::clone(&self.connection_semaphore); + + Box::pin(async move { + let _permit = connection_semaphore + .acquire_owned() + .await + .expect("semaphore shouldn't be closed yet"); + + s3_request(move || { + let (s3, request_factory) = (s3.clone(), request_factory.clone()); + + async move { s3.complete_multipart_upload(request_factory()?).await } + }) + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + Ok(()) + }) + } +} + #[cfg(test)] mod tests { use super::*; use crate::{ tests::{ get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, - put_get_delete_list, rename_and_copy, + put_get_delete_list, rename_and_copy, stream_get, }, Error as ObjectStoreError, ObjectStore, }; @@ -943,6 +1171,7 @@ mod tests { check_credentials(list_uses_directories_correctly(&integration).await).unwrap(); check_credentials(list_with_delimiter(&integration).await).unwrap(); check_credentials(rename_and_copy(&integration).await).unwrap(); + check_credentials(stream_get(&integration).await).unwrap(); } #[tokio::test] diff --git a/object_store/src/azure.rs b/object_store/src/azure.rs index 75dafef8694..25f311a9a39 100644 --- a/object_store/src/azure.rs +++ b/object_store/src/azure.rs @@ -16,10 +16,21 @@ // under the License. //! An object store implementation for Azure blob storage +//! +//! ## Streaming uploads +//! +//! [ObjectStore::put_multipart] will upload data in blocks and write a blob from those +//! blocks. Data is buffered internally to make blocks of at least 5MB and blocks +//! are uploaded concurrently. +//! +//! [ObjectStore::abort_multipart] is a no-op, since Azure Blob Store doesn't provide +//! a way to drop old blocks. Instead unused blocks are automatically cleaned up +//! after 7 days. 
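+//!
+//! ### Example: streaming an upload (sketch)
+//!
+//! A minimal sketch of the write path, assuming `store` is an already-configured
+//! store; the path and payload here are illustrative only:
+//!
+//! ```no_run
+//! # async fn example(store: std::sync::Arc<object_store::DynObjectStore>) -> Result<(), Box<dyn std::error::Error>> {
+//! use object_store::path::Path;
+//! use tokio::io::AsyncWriteExt;
+//!
+//! let location = Path::from("data/large_file.bin");
+//! let (_multipart_id, mut writer) = store.put_multipart(&location).await?;
+//! writer.write_all(b"some large payload").await?;
+//! // The blob only becomes readable once shutdown() has completed,
+//! // which commits the uploaded blocks as a block list
+//! writer.shutdown().await?;
+//! # Ok(())
+//! # }
+//! ```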
use crate::{ + multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, path::{Path, DELIMITER}, util::format_prefix, - GetResult, ListResult, ObjectMeta, ObjectStore, Result, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, }; use async_trait::async_trait; use azure_core::{prelude::*, HttpClient}; @@ -32,12 +43,15 @@ use azure_storage_blobs::{ }; use bytes::Bytes; use futures::{ + future::BoxFuture, stream::{self, BoxStream}, StreamExt, TryStreamExt, }; use snafu::{ResultExt, Snafu}; use std::collections::BTreeSet; +use std::io; use std::{convert::TryInto, sync::Arc}; +use tokio::io::AsyncWrite; use url::Url; /// A specialized `Error` for Azure object store-related errors @@ -232,6 +246,27 @@ impl ObjectStore for MicrosoftAzure { Ok(()) } + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let inner = AzureMultiPartUpload { + container_client: Arc::clone(&self.container_client), + location: location.to_owned(), + }; + Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8)))) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + // There is no way to drop blocks that have been uploaded. Instead, they simply + // expire in 7 days. + Ok(()) + } + async fn get(&self, location: &Path) -> Result { let blob = self .container_client @@ -604,6 +639,94 @@ pub fn new_azure( }) } +// Relevant docs: https://azure.github.io/Storage/docs/application-and-user-data/basics/azure-blob-storage-upload-apis/ +// In Azure Blob Store, parts are "blocks" +// put_multipart_part -> PUT block +// complete -> PUT block list +// abort -> No equivalent; blocks are simply dropped after 7 days +#[derive(Debug, Clone)] +struct AzureMultiPartUpload { + container_client: Arc, + location: Path, +} + +impl AzureMultiPartUpload { + /// Gets the block id corresponding to the part index. + /// + /// In Azure, the user determines what id each block has. They must be + /// unique within an upload and of consistent length. 
+ fn get_block_id(&self, part_idx: usize) -> String { + format!("{:20}", part_idx) + } +} + +impl CloudMultiPartUploadImpl for AzureMultiPartUpload { + fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { + let client = Arc::clone(&self.container_client); + let location = self.location.clone(); + let block_id = self.get_block_id(part_idx); + + Box::pin(async move { + client + .as_blob_client(location.as_ref()) + .put_block(block_id.clone(), buf) + .execute() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + Ok(( + part_idx, + UploadPart { + content_id: block_id, + }, + )) + }) + } + + fn complete( + &self, + completed_parts: Vec>, + ) -> BoxFuture<'static, Result<(), io::Error>> { + let parts = + completed_parts + .into_iter() + .enumerate() + .map(|(part_number, maybe_part)| match maybe_part { + Some(part) => { + Ok(azure_storage_blobs::blob::BlobBlockType::Uncommitted( + azure_storage_blobs::BlockId::new(part.content_id), + )) + } + None => Err(io::Error::new( + io::ErrorKind::Other, + format!("Missing information for upload part {:?}", part_number), + )), + }); + + let client = Arc::clone(&self.container_client); + let location = self.location.clone(); + + Box::pin(async move { + let block_list = azure_storage_blobs::blob::BlockList { + blocks: parts.collect::>()?, + }; + + client + .as_blob_client(location.as_ref()) + .put_block_list(&block_list) + .execute() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + Ok(()) + }) + } +} + #[cfg(test)] mod tests { use crate::azure::new_azure; diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs index e836caba7b4..d740625bd92 100644 --- a/object_store/src/gcp.rs +++ b/object_store/src/gcp.rs @@ -16,27 +16,44 @@ // under the License. //! An object store implementation for Google Cloud Storage +//! +//! ## Multi-part uploads +//! +//! [Multi-part uploads](https://cloud.google.com/storage/docs/multipart-uploads) +//! can be initiated with the [ObjectStore::put_multipart] method. +//! Data passed to the writer is automatically buffered to meet the minimum size +//! requirements for a part. Multiple parts are uploaded concurrently. +//! +//! If the writer fails for any reason, you may have parts uploaded to GCS but not +//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method +//! to abort the upload and drop those unneeded parts. In addition, you may wish to +//! consider implementing automatic clean up of unused parts that are older than one +//! week. 
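+//!
+//! ### Example: aborting an abandoned upload (sketch)
+//!
+//! A minimal sketch mirroring the abort path exercised by the tests, assuming
+//! `store` is an already-configured store (constructor and path are illustrative):
+//!
+//! ```no_run
+//! # async fn abort_example(store: std::sync::Arc<object_store::DynObjectStore>) -> Result<(), Box<dyn std::error::Error>> {
+//! use object_store::path::Path;
+//!
+//! let location = Path::from("data/partial_upload.bin");
+//! let (multipart_id, writer) = store.put_multipart(&location).await?;
+//! // ... if the upload is abandoned, drop the writer without shutdown and abort,
+//! // so the already-uploaded parts do not accrue storage charges
+//! drop(writer);
+//! store.abort_multipart(&location, &multipart_id).await?;
+//! # Ok(())
+//! # }
+//! ```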
use std::collections::BTreeSet; use std::fs::File; -use std::io::BufReader; +use std::io::{self, BufReader}; use std::ops::Range; +use std::sync::Arc; use async_trait::async_trait; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use chrono::{DateTime, Utc}; +use futures::future::BoxFuture; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; use reqwest::header::RANGE; use reqwest::{header, Client, Method, Response, StatusCode}; use snafu::{ResultExt, Snafu}; +use tokio::io::AsyncWrite; +use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; use crate::util::format_http_range; use crate::{ oauth::OAuthProvider, path::{Path, DELIMITER}, token::TokenCache, util::format_prefix, - GetResult, ListResult, ObjectMeta, ObjectStore, Result, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, }; #[derive(Debug, Snafu)] @@ -47,6 +64,14 @@ enum Error { #[snafu(display("Unable to decode service account file: {}", source))] DecodeCredentials { source: serde_json::Error }, + #[snafu(display("Got invalid XML response for {} {}: {}", method, url, source))] + InvalidXMLResponse { + source: quick_xml::de::DeError, + method: String, + url: String, + data: Bytes, + }, + #[snafu(display("Error performing list request: {}", source))] ListRequest { source: reqwest::Error }, @@ -139,9 +164,42 @@ struct Object { updated: DateTime, } +#[derive(serde::Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct InitiateMultipartUploadResult { + upload_id: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase", rename(serialize = "Part"))] +struct MultipartPart { + #[serde(rename = "$unflatten=PartNumber")] + part_number: usize, + #[serde(rename = "$unflatten=ETag")] + e_tag: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct CompleteMultipartUpload { + #[serde(rename = "Part", default)] + parts: Vec, +} + /// Configuration for connecting to [Google Cloud Storage](https://cloud.google.com/storage/). #[derive(Debug)] pub struct GoogleCloudStorage { + client: Arc, +} + +impl std::fmt::Display for GoogleCloudStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "GoogleCloudStorage({})", self.client.bucket_name) + } +} + +#[derive(Debug)] +struct GoogleCloudStorageClient { client: Client, base_url: String, @@ -155,13 +213,7 @@ pub struct GoogleCloudStorage { max_list_results: Option, } -impl std::fmt::Display for GoogleCloudStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "GoogleCloudStorage({})", self.bucket_name) - } -} - -impl GoogleCloudStorage { +impl GoogleCloudStorageClient { async fn get_token(&self) -> Result { if let Some(oauth_provider) = &self.oauth_provider { Ok(self @@ -243,6 +295,61 @@ impl GoogleCloudStorage { Ok(()) } + /// Initiate a multi-part upload + async fn multipart_initiate(&self, path: &Path) -> Result { + let token = self.get_token().await?; + let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); + + let response = self + .client + .request(Method::POST, &url) + .bearer_auth(token) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploads", "")]) + .send() + .await + .context(PutRequestSnafu)? 
+ .error_for_status() + .context(PutRequestSnafu)?; + + let data = response.bytes().await.context(PutRequestSnafu)?; + let result: InitiateMultipartUploadResult = quick_xml::de::from_reader( + data.as_ref().reader(), + ) + .context(InvalidXMLResponseSnafu { + method: "POST".to_string(), + url, + data, + })?; + + Ok(result.upload_id) + } + + /// Cleanup unused parts + async fn multipart_cleanup( + &self, + path: &str, + multipart_id: &MultipartId, + ) -> Result<()> { + let token = self.get_token().await?; + let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); + + self.client + .request(Method::DELETE, &url) + .bearer_auth(token) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploadId", multipart_id)]) + .send() + .await + .context(PutRequestSnafu)? + .error_for_status() + .context(PutRequestSnafu)?; + + Ok(()) + } + /// Perform a delete request async fn delete_request(&self, path: &Path) -> Result<()> { let token = self.get_token().await?; @@ -401,14 +508,184 @@ impl GoogleCloudStorage { } } +fn reqwest_error_as_io(err: reqwest::Error) -> io::Error { + if err.is_builder() || err.is_request() { + io::Error::new(io::ErrorKind::InvalidInput, err) + } else if err.is_status() { + match err.status() { + Some(StatusCode::NOT_FOUND) => io::Error::new(io::ErrorKind::NotFound, err), + Some(StatusCode::BAD_REQUEST) => { + io::Error::new(io::ErrorKind::InvalidInput, err) + } + Some(_) => io::Error::new(io::ErrorKind::Other, err), + None => io::Error::new(io::ErrorKind::Other, err), + } + } else if err.is_timeout() { + io::Error::new(io::ErrorKind::TimedOut, err) + } else if err.is_connect() { + io::Error::new(io::ErrorKind::NotConnected, err) + } else { + io::Error::new(io::ErrorKind::Other, err) + } +} + +struct GCSMultipartUpload { + client: Arc, + encoded_path: String, + multipart_id: MultipartId, +} + +impl CloudMultiPartUploadImpl for GCSMultipartUpload { + /// Upload an object part + fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> { + let upload_id = self.multipart_id.clone(); + let url = format!( + "{}/{}/{}", + self.client.base_url, self.client.bucket_name_encoded, self.encoded_path + ); + let client = Arc::clone(&self.client); + + Box::pin(async move { + let token = client + .get_token() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let response = client + .client + .request(Method::PUT, &url) + .bearer_auth(token) + .query(&[ + ("partNumber", format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ]) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, format!("{}", buf.len())) + .body(buf) + .send() + .await + .map_err(reqwest_error_as_io)? + .error_for_status() + .map_err(reqwest_error_as_io)?; + + let content_id = response + .headers() + .get("ETag") + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "response headers missing ETag", + ) + })? + .to_str() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? 
+                .to_string();
+
+            Ok((part_idx, UploadPart { content_id }))
+        })
+    }
+
+    /// Complete a multipart upload
+    fn complete(
+        &self,
+        completed_parts: Vec<Option<UploadPart>>,
+    ) -> BoxFuture<'static, Result<(), io::Error>> {
+        let client = Arc::clone(&self.client);
+        let upload_id = self.multipart_id.clone();
+        let url = format!(
+            "{}/{}/{}",
+            self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
+        );
+
+        Box::pin(async move {
+            let parts: Vec<MultipartPart> = completed_parts
+                .into_iter()
+                .enumerate()
+                .map(|(part_number, maybe_part)| match maybe_part {
+                    Some(part) => Ok(MultipartPart {
+                        e_tag: part.content_id,
+                        part_number: part_number + 1,
+                    }),
+                    None => Err(io::Error::new(
+                        io::ErrorKind::Other,
+                        format!("Missing information for upload part {:?}", part_number),
+                    )),
+                })
+                .collect::<Result<Vec<MultipartPart>, io::Error>>()?;
+
+            let token = client
+                .get_token()
+                .await
+                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+            let upload_info = CompleteMultipartUpload { parts };
+
+            let data = quick_xml::se::to_string(&upload_info)
+                .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
+                // We cannot disable the escaping that transforms "\"" to "&quot;" :(
+                // https://github.com/tafia/quick-xml/issues/362
+                // https://github.com/tafia/quick-xml/issues/350
+                .replace("&quot;", "\"");
+
+            client
+                .client
+                .request(Method::POST, &url)
+                .bearer_auth(token)
+                .query(&[("uploadId", upload_id)])
+                .body(data)
+                .send()
+                .await
+                .map_err(reqwest_error_as_io)?
+                .error_for_status()
+                .map_err(reqwest_error_as_io)?;
+
+            Ok(())
+        })
+    }
+}
+
 #[async_trait]
 impl ObjectStore for GoogleCloudStorage {
     async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
-        self.put_request(location, bytes).await
+        self.client.put_request(location, bytes).await
+    }
+
+    async fn put_multipart(
+        &self,
+        location: &Path,
+    ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+        let upload_id = self.client.multipart_initiate(location).await?;
+
+        let encoded_path =
+            percent_encode(location.to_string().as_bytes(), NON_ALPHANUMERIC).to_string();
+
+        let inner = GCSMultipartUpload {
+            client: Arc::clone(&self.client),
+            encoded_path,
+            multipart_id: upload_id.clone(),
+        };
+
+        Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
+    }
+
+    async fn abort_multipart(
+        &self,
+        location: &Path,
+        multipart_id: &MultipartId,
+    ) -> Result<()> {
+        self.client
+            .multipart_cleanup(location.as_ref(), multipart_id)
+            .await?;
+
+        Ok(())
     }
 
     async fn get(&self, location: &Path) -> Result<GetResult> {
-        let response = self.get_request(location, None, false).await?;
+        let response = self.client.get_request(location, None, false).await?;
         let stream = response
             .bytes_stream()
             .map_err(|source| crate::Error::Generic {
@@ -421,14 +698,17 @@ impl ObjectStore for GoogleCloudStorage {
     }
 
     async fn get_range(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
-        let response = self.get_request(location, Some(range), false).await?;
+        let response = self
+            .client
+            .get_request(location, Some(range), false)
+            .await?;
         Ok(response.bytes().await.context(GetRequestSnafu {
             path: location.as_ref(),
         })?)
} async fn head(&self, location: &Path) -> Result { - let response = self.get_request(location, None, true).await?; + let response = self.client.get_request(location, None, true).await?; let object = response.json().await.context(GetRequestSnafu { path: location.as_ref(), })?; @@ -436,7 +716,7 @@ impl ObjectStore for GoogleCloudStorage { } async fn delete(&self, location: &Path) -> Result<()> { - self.delete_request(location).await + self.client.delete_request(location).await } async fn list( @@ -444,6 +724,7 @@ impl ObjectStore for GoogleCloudStorage { prefix: Option<&Path>, ) -> Result>> { let stream = self + .client .list_paginated(prefix, false)? .map_ok(|r| { futures::stream::iter( @@ -457,7 +738,7 @@ impl ObjectStore for GoogleCloudStorage { } async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { - let mut stream = self.list_paginated(prefix, true)?; + let mut stream = self.client.list_paginated(prefix, true)?; let mut common_prefixes = BTreeSet::new(); let mut objects = Vec::new(); @@ -482,11 +763,11 @@ impl ObjectStore for GoogleCloudStorage { } async fn copy(&self, from: &Path, to: &Path) -> Result<()> { - self.copy_request(from, to, false).await + self.client.copy_request(from, to, false).await } async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { - self.copy_request(from, to, true).await + self.client.copy_request(from, to, true).await } } @@ -537,13 +818,15 @@ pub fn new_gcs_with_client( // environment variables. Set the environment variable explicitly so // that we can optionally accept command line arguments instead. Ok(GoogleCloudStorage { - client, - base_url: credentials.gcs_base_url, - oauth_provider, - token_cache: Default::default(), - bucket_name, - bucket_name_encoded: encoded_bucket_name, - max_list_results: None, + client: Arc::new(GoogleCloudStorageClient { + client, + base_url: credentials.gcs_base_url, + oauth_provider, + token_cache: Default::default(), + bucket_name, + bucket_name_encoded: encoded_bucket_name, + max_list_results: None, + }), }) } @@ -568,7 +851,7 @@ mod test { use crate::{ tests::{ get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, - put_get_delete_list, rename_and_copy, + put_get_delete_list, rename_and_copy, stream_get, }, Error as ObjectStoreError, ObjectStore, }; @@ -648,6 +931,11 @@ mod test { list_uses_directories_correctly(&integration).await.unwrap(); list_with_delimiter(&integration).await.unwrap(); rename_and_copy(&integration).await.unwrap(); + if integration.client.base_url == default_gcs_base_url() { + // Fake GCS server does not yet implement XML Multipart uploads + // https://github.com/fsouza/fake-gcs-server/issues/852 + stream_get(&integration).await.unwrap(); + } } #[tokio::test] diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 2dc65069a99..54d28273fa9 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -30,7 +30,7 @@ //! //! This crate provides APIs for interacting with object storage services. //! -//! It currently supports PUT, GET, DELETE, HEAD and list for: +//! It currently supports PUT (single or chunked/concurrent), GET, DELETE, HEAD and list for: //! //! * [Google Cloud Storage](https://cloud.google.com/storage/) //! 
* [Amazon S3](https://aws.amazon.com/s3/)
@@ -56,6 +56,8 @@ mod oauth;
 #[cfg(feature = "gcp")]
 mod token;
+#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
+mod multipart;
 mod util;
 
 use crate::path::Path;
@@ -68,16 +70,45 @@ use snafu::Snafu;
 use std::fmt::{Debug, Formatter};
 use std::io::{Read, Seek, SeekFrom};
 use std::ops::Range;
+use tokio::io::AsyncWrite;
 
 /// An alias for a dynamically dispatched object store implementation.
 pub type DynObjectStore = dyn ObjectStore;
 
+/// Id type for multi-part uploads.
+pub type MultipartId = String;
+
 /// Universal API to multiple object store services.
 #[async_trait]
 pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static {
     /// Save the provided bytes to the specified location.
     async fn put(&self, location: &Path, bytes: Bytes) -> Result<()>;
 
+    /// Get a multi-part upload that allows writing data in chunks
+    ///
+    /// Most cloud-based uploads will buffer and upload parts in parallel.
+    ///
+    /// To complete the upload, [AsyncWrite::poll_shutdown] must be called
+    /// to completion.
+    ///
+    /// For some object stores (S3, GCS, and local in particular), if the
+    /// writer fails or panics, you must call [ObjectStore::abort_multipart]
+    /// to clean up partially written data.
+    async fn put_multipart(
+        &self,
+        location: &Path,
+    ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)>;
+
+    /// Clean up an aborted upload.
+    ///
+    /// See documentation for individual stores for exact behavior, as capabilities
+    /// vary by object store.
+    async fn abort_multipart(
+        &self,
+        location: &Path,
+        multipart_id: &MultipartId,
+    ) -> Result<()>;
+
     /// Return the bytes that are stored at the specified location.
     async fn get(&self, location: &Path) -> Result<GetResult>;
 
@@ -330,6 +361,7 @@ mod test_util {
 mod tests {
     use super::*;
     use crate::test_util::flatten_list_stream;
+    use tokio::io::AsyncWriteExt;
 
     type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
     type Result<T, E = Error> = std::result::Result<T, E>;
@@ -497,6 +529,77 @@ mod tests {
         Ok(())
     }
 
+    fn get_vec_of_bytes(chunk_length: usize, num_chunks: usize) -> Vec<Bytes> {
+        std::iter::repeat(Bytes::from_iter(std::iter::repeat(b'x').take(chunk_length)))
+            .take(num_chunks)
+            .collect()
+    }
+
+    pub(crate) async fn stream_get(storage: &DynObjectStore) -> Result<()> {
+        let location = Path::from("test_dir/test_upload_file.txt");
+
+        // Can write to storage
+        let data = get_vec_of_bytes(5_000_000, 10);
+        let bytes_expected = data.concat();
+        let (_, mut writer) = storage.put_multipart(&location).await?;
+        for chunk in &data {
+            writer.write_all(chunk).await?;
+        }
+
+        // Object should not yet exist in store
+        let meta_res = storage.head(&location).await;
+        assert!(meta_res.is_err());
+        assert!(matches!(
+            meta_res.unwrap_err(),
+            crate::Error::NotFound { ..
} + )); + + writer.shutdown().await?; + let bytes_written = storage.get(&location).await?.bytes().await?; + assert_eq!(bytes_expected, bytes_written); + + // Can overwrite some storage + let data = get_vec_of_bytes(5_000, 5); + let bytes_expected = data.concat(); + let (_, mut writer) = storage.put_multipart(&location).await?; + for chunk in &data { + writer.write_all(chunk).await?; + } + writer.shutdown().await?; + let bytes_written = storage.get(&location).await?.bytes().await?; + assert_eq!(bytes_expected, bytes_written); + + // We can abort an empty write + let location = Path::from("test_dir/test_abort_upload.txt"); + let (upload_id, writer) = storage.put_multipart(&location).await?; + drop(writer); + storage.abort_multipart(&location, &upload_id).await?; + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!( + get_res.unwrap_err(), + crate::Error::NotFound { .. } + )); + + // We can abort an in-progress write + let (upload_id, mut writer) = storage.put_multipart(&location).await?; + if let Some(chunk) = data.get(0) { + writer.write_all(chunk).await?; + let _ = writer.write(chunk).await?; + } + drop(writer); + + storage.abort_multipart(&location, &upload_id).await?; + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!( + get_res.unwrap_err(), + crate::Error::NotFound { .. } + )); + + Ok(()) + } + pub(crate) async fn list_uses_directories_correctly( storage: &DynObjectStore, ) -> Result<()> { diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 8a9462eba9b..798edef6f37 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -19,18 +19,23 @@ use crate::{ maybe_spawn_blocking, path::{filesystem_path_to_url, Path}, - GetResult, ListResult, ObjectMeta, ObjectStore, Result, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, }; use async_trait::async_trait; use bytes::Bytes; +use futures::future::BoxFuture; +use futures::FutureExt; use futures::{stream::BoxStream, StreamExt}; use snafu::{ensure, OptionExt, ResultExt, Snafu}; -use std::collections::VecDeque; use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::ops::Range; +use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; use std::{collections::BTreeSet, convert::TryFrom, io}; +use std::{collections::VecDeque, path::PathBuf}; +use tokio::io::AsyncWrite; use url::Url; use walkdir::{DirEntry, WalkDir}; @@ -233,24 +238,7 @@ impl ObjectStore for LocalFileSystem { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || { - let mut file = match File::create(&path) { - Ok(f) => f, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - let parent = path - .parent() - .context(UnableToCreateFileSnafu { path: &path, err })?; - std::fs::create_dir_all(&parent) - .context(UnableToCreateDirSnafu { path: parent })?; - - match File::create(&path) { - Ok(f) => f, - Err(err) => { - return Err(Error::UnableToCreateFile { path, err }.into()) - } - } - } - Err(err) => return Err(Error::UnableToCreateFile { path, err }.into()), - }; + let mut file = open_writable_file(&path)?; file.write_all(&bytes) .context(UnableToCopyDataToFileSnafu)?; @@ -260,6 +248,53 @@ impl ObjectStore for LocalFileSystem { .await } + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let dest = self.config.path_to_filesystem(location)?; + + // Generate an id in case of concurrent writes + let mut multipart_id = 1; + + // Will write to a temporary 
path
+        let staging_path = loop {
+            let staging_path = get_upload_stage_path(&dest, &multipart_id.to_string());
+
+            match std::fs::metadata(&staging_path) {
+                Err(err) if err.kind() == io::ErrorKind::NotFound => break staging_path,
+                Err(err) => {
+                    return Err(Error::UnableToCopyDataToFile { source: err }.into())
+                }
+                Ok(_) => multipart_id += 1,
+            }
+        };
+        let multipart_id = multipart_id.to_string();
+
+        let file = open_writable_file(&staging_path)?;
+
+        Ok((
+            multipart_id.clone(),
+            Box::new(LocalUpload::new(dest, multipart_id, Arc::new(file))),
+        ))
+    }
+
+    async fn abort_multipart(
+        &self,
+        location: &Path,
+        multipart_id: &MultipartId,
+    ) -> Result<()> {
+        let dest = self.config.path_to_filesystem(location)?;
+        let staging_path: PathBuf = get_upload_stage_path(&dest, multipart_id);
+
+        maybe_spawn_blocking(move || {
+            std::fs::remove_file(&staging_path)
+                .context(UnableToDeleteFileSnafu { path: staging_path })?;
+            Ok(())
+        })
+        .await
+    }
+
     async fn get(&self, location: &Path) -> Result<GetResult> {
         let path = self.config.path_to_filesystem(location)?;
         maybe_spawn_blocking(move || {
@@ -343,7 +378,12 @@ impl ObjectStore for LocalFileSystem {
                 Err(e) => Some(Err(e)),
                 Ok(None) => None,
                 Ok(entry @ Some(_)) => entry
-                    .filter(|dir_entry| dir_entry.file_type().is_file())
+                    .filter(|dir_entry| {
+                        dir_entry.file_type().is_file()
+                            // Ignore file names with # in them, since they might be in-progress uploads.
+                            // They would be rejected anyway by filesystem_to_path below.
+                            && !dir_entry.file_name().to_string_lossy().contains('#')
+                    })
                     .map(|entry| {
                         let location = config.filesystem_to_path(entry.path())?;
                         convert_entry(entry, location)
@@ -400,6 +440,13 @@ impl ObjectStore for LocalFileSystem {
 
         for entry_res in walkdir.into_iter().map(convert_walkdir_result) {
             if let Some(entry) = entry_res? {
+                if entry.file_type().is_file()
+                    // Ignore file names with # in them, since they might be in-progress uploads.
+                    // They would be rejected anyway by filesystem_to_path below.
+                    && entry.file_name().to_string_lossy().contains('#')
+                {
+                    continue;
+                }
                 let is_directory = entry.file_type().is_dir();
                 let entry_location = config.filesystem_to_path(entry.path())?;
 
@@ -475,6 +522,216 @@ impl ObjectStore for LocalFileSystem {
     }
 }
 
+fn get_upload_stage_path(dest: &std::path::Path, multipart_id: &MultipartId) -> PathBuf {
+    let mut staging_path = dest.as_os_str().to_owned();
+    staging_path.push(format!("#{}", multipart_id));
+    staging_path.into()
+}
+
+enum LocalUploadState {
+    /// Upload is ready to send new data
+    Idle(Arc<std::fs::File>),
+    /// In the middle of a write
+    Writing(
+        Arc<std::fs::File>,
+        BoxFuture<'static, Result<usize, io::Error>>,
+    ),
+    /// In the middle of syncing data and closing the file.
+    ///
+    /// The future holds the last reference to the file, so it is closed (dropped) on completion.
+    ShuttingDown(BoxFuture<'static, Result<(), io::Error>>),
+    /// File is being moved from its temporary location to the final location
+    Committing(BoxFuture<'static, Result<(), io::Error>>),
+    /// Upload is complete
+    Complete,
+}
+
+struct LocalUpload {
+    inner_state: LocalUploadState,
+    dest: PathBuf,
+    multipart_id: MultipartId,
+}
+
+impl LocalUpload {
+    pub fn new(
+        dest: PathBuf,
+        multipart_id: MultipartId,
+        file: Arc<std::fs::File>,
+    ) -> Self {
+        Self {
+            inner_state: LocalUploadState::Idle(file),
+            dest,
+            multipart_id,
+        }
+    }
+}
+
+impl AsyncWrite for LocalUpload {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, io::Error>> {
+        let invalid_state =
+            |condition: &str| -> std::task::Poll<Result<usize, io::Error>> {
+                Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("Tried to write to file {}.", condition),
+                )))
+            };
+
+        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+            let mut data: Vec<u8> = buf.to_vec();
+            let data_len = data.len();
+
+            loop {
+                match &mut self.inner_state {
+                    LocalUploadState::Idle(file) => {
+                        let file = Arc::clone(file);
+                        let file2 = Arc::clone(&file);
+                        let data: Vec<u8> = std::mem::take(&mut data);
+                        self.inner_state = LocalUploadState::Writing(
+                            file,
+                            Box::pin(
+                                runtime
+                                    .spawn_blocking(move || (&*file2).write_all(&data))
+                                    .map(move |res| match res {
+                                        Err(err) => {
+                                            Err(io::Error::new(io::ErrorKind::Other, err))
+                                        }
+                                        Ok(res) => res.map(move |_| data_len),
+                                    }),
+                            ),
+                        );
+                    }
+                    LocalUploadState::Writing(file, inner_write) => {
+                        match inner_write.poll_unpin(cx) {
+                            Poll::Ready(res) => {
+                                self.inner_state =
+                                    LocalUploadState::Idle(Arc::clone(file));
+                                return Poll::Ready(res);
+                            }
+                            Poll::Pending => {
+                                return Poll::Pending;
+                            }
+                        }
+                    }
+                    LocalUploadState::ShuttingDown(_) => {
+                        return invalid_state("when writer is shutting down");
+                    }
+                    LocalUploadState::Committing(_) => {
+                        return invalid_state("when writer is committing data");
+                    }
+                    LocalUploadState::Complete => {
+                        return invalid_state("when writer is complete");
+                    }
+                }
+            }
+        } else if let LocalUploadState::Idle(file) = &self.inner_state {
+            let file = Arc::clone(file);
+            (&*file).write_all(buf)?;
+            Poll::Ready(Ok(buf.len()))
+        } else {
+            // With no runtime on this thread, writes happen inline, so the only
+            // possible states here are Idle and Complete.
+            invalid_state("when writer is already complete")
+        }
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+            loop {
+                match &mut self.inner_state {
+                    LocalUploadState::Idle(file) => {
+                        // We are moving the file into the future, and it will be dropped on its completion, closing the file.
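+                        // Shutdown proceeds through three states: sync_all() the
+                        // staging file (ShuttingDown), rename it onto the final
+                        // destination (Committing), and only then mark Complete.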
+ let file = Arc::clone(file); + self.inner_state = LocalUploadState::ShuttingDown(Box::pin( + runtime.spawn_blocking(move || (*file).sync_all()).map( + move |res| match res { + Err(err) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + Ok(res) => res, + }, + ), + )); + } + LocalUploadState::ShuttingDown(fut) => match fut.poll_unpin(cx) { + Poll::Ready(res) => { + res?; + let staging_path = + get_upload_stage_path(&self.dest, &self.multipart_id); + let dest = self.dest.clone(); + self.inner_state = LocalUploadState::Committing(Box::pin( + runtime + .spawn_blocking(move || { + std::fs::rename(&staging_path, &dest) + }) + .map(move |res| match res { + Err(err) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + Ok(res) => res, + }), + )); + } + Poll::Pending => { + return Poll::Pending; + } + }, + LocalUploadState::Writing(_, _) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Tried to commit a file where a write is in progress.", + ))); + } + LocalUploadState::Committing(fut) => match fut.poll_unpin(cx) { + Poll::Ready(res) => { + self.inner_state = LocalUploadState::Complete; + return Poll::Ready(res); + } + Poll::Pending => return Poll::Pending, + }, + LocalUploadState::Complete => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Already complete", + ))) + } + } + } + } else { + let staging_path = get_upload_stage_path(&self.dest, &self.multipart_id); + match &mut self.inner_state { + LocalUploadState::Idle(file) => { + let file = Arc::clone(file); + self.inner_state = LocalUploadState::Complete; + file.sync_all()?; + std::mem::drop(file); + std::fs::rename(&staging_path, &self.dest)?; + Poll::Ready(Ok(())) + } + _ => { + // If we are running on this thread, then only possible states are Idle and Complete. 
+ Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Already complete", + ))) + } + } + } + } +} + fn open_file(path: &std::path::PathBuf) -> Result { let file = File::open(path).map_err(|e| { if e.kind() == std::io::ErrorKind::NotFound { @@ -492,6 +749,33 @@ fn open_file(path: &std::path::PathBuf) -> Result { Ok(file) } +fn open_writable_file(path: &std::path::PathBuf) -> Result { + match File::create(&path) { + Ok(f) => Ok(f), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + let parent = path + .parent() + .context(UnableToCreateFileSnafu { path: &path, err })?; + std::fs::create_dir_all(&parent) + .context(UnableToCreateDirSnafu { path: parent })?; + + match File::create(&path) { + Ok(f) => Ok(f), + Err(err) => Err(Error::UnableToCreateFile { + path: path.to_path_buf(), + err, + } + .into()), + } + } + Err(err) => Err(Error::UnableToCreateFile { + path: path.to_path_buf(), + err, + } + .into()), + } +} + fn convert_entry(entry: DirEntry, location: Path) -> Result { let metadata = entry .metadata() @@ -548,11 +832,12 @@ mod tests { use crate::{ tests::{ copy_if_not_exists, get_nonexistent_object, list_uses_directories_correctly, - list_with_delimiter, put_get_delete_list, rename_and_copy, + list_with_delimiter, put_get_delete_list, rename_and_copy, stream_get, }, Error as ObjectStoreError, ObjectStore, }; use tempfile::TempDir; + use tokio::io::AsyncWriteExt; #[tokio::test] async fn file_test() { @@ -564,6 +849,7 @@ mod tests { list_with_delimiter(&integration).await.unwrap(); rename_and_copy(&integration).await.unwrap(); copy_if_not_exists(&integration).await.unwrap(); + stream_get(&integration).await.unwrap(); } #[test] @@ -574,6 +860,7 @@ mod tests { put_get_delete_list(&integration).await.unwrap(); list_uses_directories_correctly(&integration).await.unwrap(); list_with_delimiter(&integration).await.unwrap(); + stream_get(&integration).await.unwrap(); }); } @@ -770,4 +1057,34 @@ mod tests { err ); } + + #[tokio::test] + async fn list_hides_incomplete_uploads() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + let (multipart_id, mut writer) = + integration.put_multipart(&location).await.unwrap(); + writer.write_all(&data).await.unwrap(); + + let (multipart_id_2, mut writer_2) = + integration.put_multipart(&location).await.unwrap(); + assert_ne!(multipart_id, multipart_id_2); + writer_2.write_all(&data).await.unwrap(); + + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list.len(), 0); + + assert_eq!( + integration + .list_with_delimiter(None) + .await + .unwrap() + .objects + .len(), + 0 + ); + } } diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs index ffd8e3a5207..dc3967d9915 100644 --- a/object_store/src/memory.rs +++ b/object_store/src/memory.rs @@ -16,6 +16,7 @@ // under the License. //! 
An in-memory object store implementation +use crate::MultipartId; use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -25,7 +26,12 @@ use parking_lot::RwLock; use snafu::{ensure, OptionExt, Snafu}; use std::collections::BTreeMap; use std::collections::BTreeSet; +use std::io; use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use tokio::io::AsyncWrite; /// A specialized `Error` for in-memory object store-related errors #[derive(Debug, Snafu)] @@ -67,7 +73,7 @@ impl From for super::Error { /// storage provider. #[derive(Debug, Default)] pub struct InMemory { - storage: RwLock>, + storage: Arc>>, } impl std::fmt::Display for InMemory { @@ -83,6 +89,29 @@ impl ObjectStore for InMemory { Ok(()) } + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + Ok(( + String::new(), + Box::new(InMemoryUpload { + location: location.clone(), + data: Vec::new(), + storage: Arc::clone(&self.storage), + }), + )) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + // Nothing to clean up + Ok(()) + } + async fn get(&self, location: &Path) -> Result { let data = self.get_bytes(location).await?; @@ -211,7 +240,7 @@ impl InMemory { let storage = storage.clone(); Self { - storage: RwLock::new(storage), + storage: Arc::new(RwLock::new(storage)), } } @@ -227,6 +256,39 @@ impl InMemory { } } +struct InMemoryUpload { + location: Path, + data: Vec, + storage: Arc>>, +} + +impl AsyncWrite for InMemoryUpload { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.data.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let data = Bytes::from(std::mem::take(&mut self.data)); + self.storage.write().insert(self.location.clone(), data); + Poll::Ready(Ok(())) + } +} + #[cfg(test)] mod tests { use super::*; @@ -234,7 +296,7 @@ mod tests { use crate::{ tests::{ copy_if_not_exists, get_nonexistent_object, list_uses_directories_correctly, - list_with_delimiter, put_get_delete_list, rename_and_copy, + list_with_delimiter, put_get_delete_list, rename_and_copy, stream_get, }, Error as ObjectStoreError, ObjectStore, }; @@ -248,6 +310,7 @@ mod tests { list_with_delimiter(&integration).await.unwrap(); rename_and_copy(&integration).await.unwrap(); copy_if_not_exists(&integration).await.unwrap(); + stream_get(&integration).await.unwrap(); } #[tokio::test] diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs new file mode 100644 index 00000000000..c16022d3735 --- /dev/null +++ b/object_store/src/multipart.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use futures::{future::BoxFuture, stream::FuturesUnordered, Future, StreamExt};
+use std::{io, pin::Pin, sync::Arc, task::Poll};
+use tokio::io::AsyncWrite;
+
+use crate::Result;
+
+type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Send>>;
+
+/// A trait that can be implemented by cloud-based object stores
+/// and used in combination with [`CloudMultiPartUpload`] to provide
+/// multipart upload support
+///
+/// Note: this does not use AsyncTrait as the lifetimes are difficult to manage
+pub(crate) trait CloudMultiPartUploadImpl {
+    /// Upload a single part
+    fn put_multipart_part(
+        &self,
+        buf: Vec<u8>,
+        part_idx: usize,
+    ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>>;
+
+    /// Complete the upload with the provided parts
+    ///
+    /// `completed_parts` is in order of part number
+    fn complete(
+        &self,
+        completed_parts: Vec<Option<UploadPart>>,
+    ) -> BoxFuture<'static, Result<(), io::Error>>;
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct UploadPart {
+    pub content_id: String,
+}
+
+pub(crate) struct CloudMultiPartUpload<T>
+where
+    T: CloudMultiPartUploadImpl,
+{
+    inner: Arc<T>,
+    /// A list of completed parts, in sequential order.
+    completed_parts: Vec<Option<UploadPart>>,
+    /// Part upload tasks currently running
+    tasks: FuturesUnordered<BoxedTryFuture<(usize, UploadPart)>>,
+    /// Maximum number of upload tasks to run concurrently
+    max_concurrency: usize,
+    /// Buffer that will be sent in next upload.
+    current_buffer: Vec<u8>,
+    /// Minimum size of a part in bytes
+    min_part_size: usize,
+    /// Index of current part
+    current_part_idx: usize,
+    /// The completion task
+    completion_task: Option<BoxedTryFuture<()>>,
+}
+
+impl<T> CloudMultiPartUpload<T>
+where
+    T: CloudMultiPartUploadImpl,
+{
+    pub fn new(inner: T, max_concurrency: usize) -> Self {
+        Self {
+            inner: Arc::new(inner),
+            completed_parts: Vec::new(),
+            tasks: FuturesUnordered::new(),
+            max_concurrency,
+            current_buffer: Vec::new(),
+            // TODO: Should this vary by provider?
+            // TODO: Should we automatically increase this when the part index gets large?
+            min_part_size: 5_000_000,
+            current_part_idx: 0,
+            completion_task: None,
+        }
+    }
+
+    pub fn poll_tasks(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Result<(), io::Error> {
+        if self.tasks.is_empty() {
+            return Ok(());
+        }
+        let total_parts = self.completed_parts.len();
+        while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) {
+            let (part_idx, part) = res?;
+            self.completed_parts
+                .resize(std::cmp::max(part_idx + 1, total_parts), None);
+            self.completed_parts[part_idx] = Some(part);
+        }
+        Ok(())
+    }
+}
+
+impl<T> AsyncWrite for CloudMultiPartUpload<T>
+where
+    T: CloudMultiPartUploadImpl + Send + Sync,
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, io::Error>> {
+        // Poll current tasks
+        self.as_mut().poll_tasks(cx)?;
+
+        // If adding buf to pending buffer would trigger send, check
+        // whether we have capacity for another task.
+        let enough_to_send = (buf.len() + self.current_buffer.len()) > self.min_part_size;
+        if enough_to_send && self.tasks.len() < self.max_concurrency {
+            // If we do, copy into the buffer and submit the task, and return ready.
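+            // (If instead every task slot is busy, we fall through to the final
+            // `else` branch and return Pending; the poll_tasks call at the top of
+            // this function has already registered the waker, so this task is
+            // woken again when an in-flight part completes.)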
+ self.current_buffer.extend_from_slice(buf); + + let out_buffer = std::mem::take(&mut self.current_buffer); + let task = self + .inner + .put_multipart_part(out_buffer, self.current_part_idx); + self.tasks.push(task); + self.current_part_idx += 1; + + // We need to poll immediately after adding to setup waker + self.as_mut().poll_tasks(cx)?; + + Poll::Ready(Ok(buf.len())) + } else if !enough_to_send { + self.current_buffer.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } else { + // Waker registered by call to poll_tasks at beginning + Poll::Pending + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Poll current tasks + self.as_mut().poll_tasks(cx)?; + + // If current_buffer is not empty, see if it can be submitted + if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency { + let out_buffer: Vec = std::mem::take(&mut self.current_buffer); + let task = self + .inner + .put_multipart_part(out_buffer, self.current_part_idx); + self.tasks.push(task); + } + + self.as_mut().poll_tasks(cx)?; + + // If tasks and current_buffer are empty, return Ready + if self.tasks.is_empty() && self.current_buffer.is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // First, poll flush + match self.as_mut().poll_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(res) => res?, + }; + + // If shutdown task is not set, set it + let parts = std::mem::take(&mut self.completed_parts); + let inner = Arc::clone(&self.inner); + let completion_task = self + .completion_task + .get_or_insert_with(|| inner.complete(parts)); + + Pin::new(completion_task).poll(cx) + } +} diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs index 6560296516d..6789f0e68df 100644 --- a/object_store/src/throttle.rs +++ b/object_store/src/throttle.rs @@ -20,11 +20,13 @@ use parking_lot::Mutex; use std::ops::Range; use std::{convert::TryInto, sync::Arc}; +use crate::MultipartId; use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; use async_trait::async_trait; use bytes::Bytes; use futures::{stream::BoxStream, StreamExt}; use std::time::Duration; +use tokio::io::AsyncWrite; /// Configuration settings for throttled store #[derive(Debug, Default, Clone, Copy)] @@ -149,6 +151,21 @@ impl ObjectStore for ThrottledStore { self.inner.put(location, bytes).await } + async fn put_multipart( + &self, + _location: &Path, + ) -> Result<(MultipartId, Box)> { + Err(super::Error::NotImplemented) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + Err(super::Error::NotImplemented) + } + async fn get(&self, location: &Path) -> Result { sleep(self.config().wait_get_per_call).await;