From 17fe483a64d58a8e7bea7dfdf47235477ba059b6 Mon Sep 17 00:00:00 2001 From: "Nathan (Blaise) Bruer" Date: Mon, 27 Nov 2023 10:05:10 -0600 Subject: [PATCH] Fix empty bytes error in s3 store and support AWS_ENDPOINT_URL Some http servers can send empty strings in the stream, but we do not allow this in our code since this is the EOF signal. Also adds ability to use custom AWS_ENDPOINT_URL env to change destination of s3 endpoint. --- native-link-store/BUILD.bazel | 1 + native-link-store/src/s3_store.rs | 21 +++++++-- native-link-store/tests/s3_store_test.rs | 60 +++++++++++++++++++++++- 3 files changed, 78 insertions(+), 4 deletions(-) diff --git a/native-link-store/BUILD.bazel b/native-link-store/BUILD.bazel index c59875285..a25668f6c 100644 --- a/native-link-store/BUILD.bazel +++ b/native-link-store/BUILD.bazel @@ -97,6 +97,7 @@ rust_test_suite( "@crate_index//:filetime", "@crate_index//:futures", "@crate_index//:http", + "@crate_index//:hyper", "@crate_index//:memory-stats", "@crate_index//:once_cell", "@crate_index//:pretty_assertions", diff --git a/native-link-store/src/s3_store.rs b/native-link-store/src/s3_store.rs index af0815b22..2917cf802 100644 --- a/native-link-store/src/s3_store.rs +++ b/native-link-store/src/s3_store.rs @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::cmp; +use std::borrow::Cow; use std::future::Future; use std::marker::Send; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use std::{cmp, env}; use async_trait::async_trait; +use aws_sdk_s3::config::Region; use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use aws_sdk_s3::operation::get_object::GetObjectError; use aws_sdk_s3::operation::head_object::HeadObjectError; @@ -147,8 +149,16 @@ impl S3Store { }); let s3_client = { let http_client = HyperClientBuilder::new().build(TlsConnector::new(config, jitter_fn.clone())); - let shared_config = aws_config::from_env().http_client(http_client).load().await; - aws_sdk_s3::Client::new(&shared_config) + let mut config_builder = aws_config::from_env() + .region(Region::new(Cow::Owned(config.region.clone()))) + .http_client(http_client); + // TODO(allada) When aws-sdk supports this env variable we should be able + // to remove this. + // See: https://github.com/awslabs/aws-sdk-rust/issues/932 + if let Ok(endpoint_url) = env::var("AWS_ENDPOINT_URL") { + config_builder = config_builder.endpoint_url(endpoint_url); + } + aws_sdk_s3::Client::new(&config_builder.load().await) }; Self::new_with_client_and_jitter(config, s3_client, jitter_fn) } @@ -488,6 +498,11 @@ impl Store for S3Store { while let Some(maybe_bytes) = s3_in_stream.next().await { match maybe_bytes { Ok(bytes) => { + if bytes.is_empty() { + // Ignore possible EOF. Different implementations of S3 may or may not + // send EOF this way. 
+ continue; + } if let Err(e) = writer.send(bytes).await { return Some(( RetryResult::Err(make_input_err!("Error sending bytes to consumer in S3: {e}")), diff --git a/native-link-store/tests/s3_store_test.rs b/native-link-store/tests/s3_store_test.rs index 13ea8c117..3895e089c 100644 --- a/native-link-store/tests/s3_store_test.rs +++ b/native-link-store/tests/s3_store_test.rs @@ -19,9 +19,12 @@ use std::time::Duration; use aws_sdk_s3::config::{Builder, Region}; use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient}; use aws_smithy_types::body::SdkBody; -use error::Error; +use bytes::Bytes; +use error::{Error, ResultExt}; +use futures::join; use http::header; use http::status::StatusCode; +use hyper::Body; use native_link_store::s3_store::S3Store; use native_link_util::common::DigestInfo; use native_link_util::store_trait::Store; @@ -464,4 +467,59 @@ mod s3_store_tests { mock_client.assert_requests_match(&[]); Ok(()) } + + #[tokio::test] + async fn ensure_empty_string_in_stream_works_test() -> Result<(), Error> { + const CAS_ENTRY_SIZE: usize = 10; // Length of "helloworld". 
+ let (mut tx, channel_body) = Body::channel(); + let mock_client = StaticReplayClient::new(vec![ReplayEvent::new( + http::Request::builder() + .uri(format!( + "https://{BUCKET_NAME}.s3.{REGION}.amazonaws.com/{VALID_HASH1}-{CAS_ENTRY_SIZE}?x-id=GetObject", + )) + .header("range", format!("bytes={}-{}", 0, CAS_ENTRY_SIZE)) + .body(SdkBody::empty()) + .unwrap(), + http::Response::builder() + .status(StatusCode::OK) + .body(SdkBody::from_body_0_4(channel_body)) + .unwrap(), + )]); + let test_config = Builder::new() + .region(Region::from_static(REGION)) + .http_client(mock_client.clone()) + .build(); + let s3_client = aws_sdk_s3::Client::from_conf(test_config); + let store = S3Store::new_with_client_and_jitter( + &native_link_config::stores::S3Store { + bucket: BUCKET_NAME.to_string(), + ..Default::default() + }, + s3_client, + Arc::new(move |_delay| Duration::from_secs(0)), + )?; + let store_pin = Pin::new(&store); + + let (_, get_part_result) = join!( + async move { + tx.send_data(Bytes::from_static(b"hello")).await?; + tx.send_data(Bytes::from_static(b"")).await?; + tx.send_data(Bytes::from_static(b"world")).await?; + Result::<(), hyper::Error>::Ok(()) + }, + store_pin.get_part_unchunked( + DigestInfo::try_new(VALID_HASH1, CAS_ENTRY_SIZE)?, + 0, + Some(CAS_ENTRY_SIZE), + None + ) + ); + assert_eq!( + get_part_result.err_tip(|| "Expected get_part_result to pass")?, + "helloworld".as_bytes() + ); + + mock_client.assert_requests_match(&[]); + Ok(()) + } }