diff --git a/native-link-store/BUILD.bazel b/native-link-store/BUILD.bazel index c59875285..a25668f6c 100644 --- a/native-link-store/BUILD.bazel +++ b/native-link-store/BUILD.bazel @@ -97,6 +97,7 @@ rust_test_suite( "@crate_index//:filetime", "@crate_index//:futures", "@crate_index//:http", + "@crate_index//:hyper", "@crate_index//:memory-stats", "@crate_index//:once_cell", "@crate_index//:pretty_assertions", diff --git a/native-link-store/src/s3_store.rs b/native-link-store/src/s3_store.rs index af0815b22..2917cf802 100644 --- a/native-link-store/src/s3_store.rs +++ b/native-link-store/src/s3_store.rs @@ -12,15 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp; +use std::borrow::Cow; use std::future::Future; use std::marker::Send; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; +use std::{cmp, env}; use async_trait::async_trait; +use aws_sdk_s3::config::Region; use aws_sdk_s3::operation::create_multipart_upload::CreateMultipartUploadOutput; use aws_sdk_s3::operation::get_object::GetObjectError; use aws_sdk_s3::operation::head_object::HeadObjectError; @@ -147,8 +149,16 @@ impl S3Store { }); let s3_client = { let http_client = HyperClientBuilder::new().build(TlsConnector::new(config, jitter_fn.clone())); - let shared_config = aws_config::from_env().http_client(http_client).load().await; - aws_sdk_s3::Client::new(&shared_config) + let mut config_builder = aws_config::from_env() + .region(Region::new(Cow::Owned(config.region.clone()))) + .http_client(http_client); + // TODO(allada) When aws-sdk supports this env variable we should be able + // to remove this. + // See: https://github.com/awslabs/aws-sdk-rust/issues/932 + if let Ok(endpoint_url) = env::var("AWS_ENDPOINT_URL") { + config_builder = config_builder.endpoint_url(endpoint_url); + } + aws_sdk_s3::Client::new(&config_builder.load().await) }; Self::new_with_client_and_jitter(config, s3_client, jitter_fn) } @@ -488,6 +498,11 @@ impl Store for S3Store { while let Some(maybe_bytes) = s3_in_stream.next().await { match maybe_bytes { Ok(bytes) => { + if bytes.is_empty() { + // Ignore possible EOF. Different implimentations of S3 may or may not + // send EOF this way. + continue; + } if let Err(e) = writer.send(bytes).await { return Some(( RetryResult::Err(make_input_err!("Error sending bytes to consumer in S3: {e}")), diff --git a/native-link-store/tests/s3_store_test.rs b/native-link-store/tests/s3_store_test.rs index 13ea8c117..3895e089c 100644 --- a/native-link-store/tests/s3_store_test.rs +++ b/native-link-store/tests/s3_store_test.rs @@ -19,9 +19,12 @@ use std::time::Duration; use aws_sdk_s3::config::{Builder, Region}; use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient}; use aws_smithy_types::body::SdkBody; -use error::Error; +use bytes::Bytes; +use error::{Error, ResultExt}; +use futures::join; use http::header; use http::status::StatusCode; +use hyper::Body; use native_link_store::s3_store::S3Store; use native_link_util::common::DigestInfo; use native_link_util::store_trait::Store; @@ -464,4 +467,59 @@ mod s3_store_tests { mock_client.assert_requests_match(&[]); Ok(()) } + + #[tokio::test] + async fn ensure_empty_string_in_stream_works_test() -> Result<(), Error> { + const CAS_ENTRY_SIZE: usize = 10; // Length of "helloworld". + let (mut tx, channel_body) = Body::channel(); + let mock_client = StaticReplayClient::new(vec![ReplayEvent::new( + http::Request::builder() + .uri(format!( + "https://{BUCKET_NAME}.s3.{REGION}.amazonaws.com/{VALID_HASH1}-{CAS_ENTRY_SIZE}?x-id=GetObject", + )) + .header("range", format!("bytes={}-{}", 0, CAS_ENTRY_SIZE)) + .body(SdkBody::empty()) + .unwrap(), + http::Response::builder() + .status(StatusCode::OK) + .body(SdkBody::from_body_0_4(channel_body)) + .unwrap(), + )]); + let test_config = Builder::new() + .region(Region::from_static(REGION)) + .http_client(mock_client.clone()) + .build(); + let s3_client = aws_sdk_s3::Client::from_conf(test_config); + let store = S3Store::new_with_client_and_jitter( + &native_link_config::stores::S3Store { + bucket: BUCKET_NAME.to_string(), + ..Default::default() + }, + s3_client, + Arc::new(move |_delay| Duration::from_secs(0)), + )?; + let store_pin = Pin::new(&store); + + let (_, get_part_result) = join!( + async move { + tx.send_data(Bytes::from_static(b"hello")).await?; + tx.send_data(Bytes::from_static(b"")).await?; + tx.send_data(Bytes::from_static(b"world")).await?; + Result::<(), hyper::Error>::Ok(()) + }, + store_pin.get_part_unchunked( + DigestInfo::try_new(VALID_HASH1, CAS_ENTRY_SIZE)?, + 0, + Some(CAS_ENTRY_SIZE), + None + ) + ); + assert_eq!( + get_part_result.err_tip(|| "Expected get_part_result to pass")?, + "helloworld".as_bytes() + ); + + mock_client.assert_requests_match(&[]); + Ok(()) + } }