diff --git a/cas/store/s3_store.rs b/cas/store/s3_store.rs
index 517a2c11b..4c0d07dfc 100644
--- a/cas/store/s3_store.rs
+++ b/cas/store/s3_store.rs
@@ -3,7 +3,6 @@
 use std::cmp;
 use std::marker::Send;
 use std::pin::Pin;
-use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -14,8 +13,9 @@ use http::status::StatusCode;
 use rand::{rngs::OsRng, Rng};
 use rusoto_core::{region::Region, ByteStream, RusotoError};
 use rusoto_s3::{
-    AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CreateMultipartUploadRequest, GetObjectRequest,
-    HeadObjectError, HeadObjectRequest, PutObjectRequest, S3Client, UploadPartRequest, S3,
+    AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart,
+    CreateMultipartUploadRequest, GetObjectError, GetObjectRequest, HeadObjectError, HeadObjectRequest,
+    PutObjectRequest, S3Client, UploadPartRequest, S3,
 };
 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 use tokio::time::sleep;
@@ -199,28 +199,29 @@ impl StoreTrait for S3Store {
             .err_tip(|| "Expected upload_id to be set by s3 response")?;
 
         let complete_result = {
-            let mut part_number = 1;
+            let mut part_number: i64 = 1;
             let reader = Arc::new(Mutex::new(reader));
-            let is_done = Arc::new(AtomicBool::new(false));
+            // We might end up with one more capacity slot than needed, but that is the worst case.
+            let mut completed_parts = Vec::with_capacity((expected_size / bytes_per_upload_part) + 1);
             loop {
-                let is_done_clone = is_done.clone();
+                let possible_last_chunk_size = expected_size - bytes_per_upload_part * ((part_number as usize) - 1);
+                let content_length = cmp::min(possible_last_chunk_size, bytes_per_upload_part);
+                let is_last_chunk = bytes_per_upload_part * (part_number as usize) >= expected_size;
                 // Wrap `AsyncRead` so we can hold a copy of it in this scope between iterations.
                 // This is quite difficult because we need to give full ownership of an AsyncRead
                 // to `ByteStream` which has an unknown lifetime.
                 // This wrapper will also ensure we only send `bytes_per_upload_part` then close the
                 // stream.
-                let taker = AsyncReadTaker::new(
-                    reader.clone(),
-                    Some(move || is_done_clone.store(true, Ordering::Relaxed)),
-                    bytes_per_upload_part,
-                );
+                let taker = AsyncReadTaker::new(reader.clone(), content_length);
                 {
                     let body = Some(ByteStream::new(ReaderStream::new(taker)));
-                    self.s3_client
+                    let response = self
+                        .s3_client
                         .upload_part(UploadPartRequest {
                             bucket: self.bucket.to_owned(),
                             key: s3_path.to_owned(),
+                            content_length: Some(content_length as i64),
                             body,
                             part_number,
                             upload_id: upload_id.clone(),
@@ -228,8 +229,12 @@ impl StoreTrait for S3Store {
                         })
                         .await
                         .map_err(|e| make_err!(Code::Unknown, "Failed to upload part: {:?}", e))?;
+                    completed_parts.push(CompletedPart {
+                        e_tag: response.e_tag,
+                        part_number: Some(part_number),
+                    });
                 }
-                if is_done.load(Ordering::Relaxed) {
+                if is_last_chunk {
                     break;
                 }
                 part_number += 1;
@@ -240,6 +245,9 @@ impl StoreTrait for S3Store {
                     bucket: self.bucket.to_owned(),
                     key: s3_path.to_owned(),
                     upload_id: upload_id.clone(),
+                    multipart_upload: Some(CompletedMultipartUpload {
+                        parts: Some(completed_parts),
+                    }),
                     ..Default::default()
                 })
                 .await
@@ -288,11 +296,12 @@ impl StoreTrait for S3Store {
             ..Default::default()
         };
 
-        let get_object_output = self
-            .s3_client
-            .get_object(get_req)
-            .await
-            .map_err(|e| make_err!(Code::Unavailable, "Error uploading to S3: {:?}", e))?;
+        let get_object_output = self.s3_client.get_object(get_req).await.map_err(|e| match e {
+            RusotoError::Service(GetObjectError::NoSuchKey(err)) => {
+                return make_err!(Code::NotFound, "Error reading from S3: {:?}", err)
+            }
+            _ => make_err!(Code::Unknown, "Error reading from S3: {:?}", e),
+        })?;
         let s3_in_stream = get_object_output
             .body
             .err_tip(|| "Expected body to be set in s3 get request")?;
diff --git a/cas/store/tests/s3_store_test.rs b/cas/store/tests/s3_store_test.rs
index fa0a33b78..09217d636 100644
--- a/cas/store/tests/s3_store_test.rs
+++ b/cas/store/tests/s3_store_test.rs
@@ -34,7 +34,7 @@ fn receive_request(sender: mpsc::Sender<(SignedRequest, Vec<u8>)>) -> impl Fn(Si
                 let mut async_reader = stream.into_async_read();
                 assert!(block_on(async_reader.read_to_end(&mut raw_payload)).is_ok());
             }
-            Some(SignedRequestPayload::Buffer(buffer)) => raw_payload.copy_from_slice(&buffer[..]),
+            Some(SignedRequestPayload::Buffer(buffer)) => raw_payload.extend_from_slice(&buffer[..]),
             None => {}
         }
         sender.try_send((request, raw_payload)).expect("Failed to send payload");
@@ -361,6 +361,10 @@ mod s3_store_tests {
                 from_utf8(&request.headers["host"][0]).unwrap(),
                 "s3.us-east-1.amazonaws.com"
             );
+            assert_eq!(
+                from_utf8(&request.headers["content-length"][0]).unwrap(),
+                format!("{}", rt_data.len())
+            );
             assert_eq!(request.canonical_query_string, "uploads=");
             assert_eq!(
                 request.canonical_uri,
@@ -380,6 +384,10 @@ mod s3_store_tests {
                 from_utf8(&request.headers["host"][0]).unwrap(),
                 "s3.us-east-1.amazonaws.com"
            );
+            assert_eq!(
+                from_utf8(&request.headers["content-length"][0]).unwrap(),
+                format!("{}", rt_data.len())
+            );
             assert_eq!(request.canonical_query_string, "partNumber=1&uploadId=Dummy-uploadid");
             assert_eq!(
                 request.canonical_uri,
@@ -393,6 +401,10 @@ mod s3_store_tests {
                 rt_data,
                 "Expected data to match"
             );
+            assert_eq!(
+                from_utf8(&request.headers["content-length"][0]).unwrap(),
+                format!("{}", rt_data.len())
+            );
             assert_eq!(request.canonical_query_string, "partNumber=2&uploadId=Dummy-uploadid");
             assert_eq!(
                 request.canonical_uri,
@@ -406,6 +418,10 @@ mod s3_store_tests {
                 rt_data,
                 "Expected data to match"
             );
+            assert_eq!(
+                from_utf8(&request.headers["content-length"][0]).unwrap(),
+                format!("{}", rt_data.len())
+            );
             assert_eq!(request.canonical_query_string, "partNumber=3&uploadId=Dummy-uploadid");
             assert_eq!(
                 request.canonical_uri,
@@ -415,7 +431,28 @@ mod s3_store_tests {
         {
             // Final payload is the complete_multipart_upload request.
             let (request, rt_data) = receiver.next().await.err_tip(|| "Could not get next payload")?;
-            assert_eq!(&send_data[0..0], rt_data, "Expected data to match");
+            const COMPLETE_MULTIPART_PAYLOAD_DATA: &str = concat!(
+                r#"<?xml version="1.0" encoding="utf-8"?>"#,
+                "<CompleteMultipartUpload>",
+                "<Part><PartNumber>1</PartNumber></Part>",
+                "<Part><PartNumber>2</PartNumber></Part>",
+                "<Part><PartNumber>3</PartNumber></Part>",
+                "</CompleteMultipartUpload>",
+            );
+            assert_eq!(
+                from_utf8(&rt_data).unwrap(),
+                COMPLETE_MULTIPART_PAYLOAD_DATA,
+                "Expected complete multipart payload to match"
+            );
+            assert_eq!(request.method, "POST");
+            assert_eq!(
+                from_utf8(&request.headers["content-length"][0]).unwrap(),
+                format!("{}", COMPLETE_MULTIPART_PAYLOAD_DATA.len())
+            );
+            assert_eq!(
+                from_utf8(&request.headers["x-amz-content-sha256"][0]).unwrap(),
+                "730f96c9a87580c7930b5bd4fd0457fbe01b34f2261dcdde877d09b06d937b5e"
+            );
             assert_eq!(request.canonical_query_string, "uploadId=Dummy-uploadid");
             assert_eq!(
                 request.canonical_uri,
diff --git a/config/examples/basic_cas.json b/config/examples/basic_cas.json
index 4c3162705..38ae76528 100644
--- a/config/examples/basic_cas.json
+++ b/config/examples/basic_cas.json
@@ -6,7 +6,7 @@
       "s3_store": {
         "region": "us-west-1",
         "bucket": "blaisebruer-cas-store",
-        "key_prefix": "test-prefix-cas",
+        "key_prefix": "test-prefix-cas/",
         "retry": {
           "max_retries": 0,
           "delay": 0.1,
@@ -28,7 +28,7 @@
       "s3_store": {
         "region": "us-west-1",
         "bucket": "blaisebruer-cas-store",
-        "key_prefix": "test-prefix-ac",
+        "key_prefix": "test-prefix-ac/",
         "retry": {
           "max_retries": 0,
           "delay": 0.1,
@@ -56,7 +56,7 @@
     "cas_stores": {
       "main": "CAS_MAIN_STORE",
     },
-    // This value was choosen only because it is a common mem page size.
+    // This value was chosen only because it is a common mem page size.
     "write_buffer_stream_size": 2e6, // 2mb.
     "read_buffer_stream_size": 2e6, // 2mb.
     // According to https://github.com/grpc/grpc.github.io/issues/371 16KiB - 64KiB is optimal.
diff --git a/util/async_read_taker.rs b/util/async_read_taker.rs
index 683fe6a4e..f7071285f 100644
--- a/util/async_read_taker.rs
+++ b/util/async_read_taker.rs
@@ -13,31 +13,26 @@ use tokio::io::{AsyncRead, ReadBuf};
 
 pub type ArcMutexAsyncRead = Arc>>;
 
+// TODO(blaise.bruer) It does not look like this class is needed any more. Should consider removing it.
 pin_project! {
     /// Useful object that can be used to chunk an AsyncReader by a specific size.
     /// This also requires the inner reader be sharable between threads. This allows
     /// the caller to still "own" the underlying reader in a way that once `limit` is
     /// reached the caller can keep using it, but still use this struct to read the data.
-    pub struct AsyncReadTaker {
+    pub struct AsyncReadTaker {
         inner: ArcMutexAsyncRead,
-        done_fn: Option,
         // Add '_' to avoid conflicts with `limit` method.
         limit_: usize,
     }
 }
 
-impl AsyncReadTaker {
-    /// `done_fn` can be used to pass a functor that will fire when the stream has no more data.
-    pub fn new(inner: ArcMutexAsyncRead, done_fn: Option, limit: usize) -> Self {
-        AsyncReadTaker {
-            inner,
-            done_fn: done_fn,
-            limit_: limit,
-        }
+impl AsyncReadTaker {
+    pub fn new(inner: ArcMutexAsyncRead, limit: usize) -> Self {
+        AsyncReadTaker { inner, limit_: limit }
     }
 }
 
-impl AsyncRead for AsyncReadTaker {
+impl AsyncRead for AsyncReadTaker {
     /// Note: This function is modeled after tokio::Take::poll_read.
     /// see: https://docs.rs/tokio/1.12.0/src/tokio/io/util/take.rs.html#77
     fn poll_read(
@@ -62,11 +57,6 @@ impl AsyncRead for AsyncReadTaker {
             unsafe {
                 buf.assume_init(n);
             }
-            if n == 0 {
-                if let Some(done_fn) = self.done_fn.take() {
-                    done_fn();
-                }
-            }
             buf.advance(n);
             self.limit_ -= n;
 
diff --git a/util/tests/async_read_taker_test.rs b/util/tests/async_read_taker_test.rs
index 42a4a47e1..54aa39921 100644
--- a/util/tests/async_read_taker_test.rs
+++ b/util/tests/async_read_taker_test.rs
@@ -1,6 +1,5 @@
 // Copyright 2021 Nathan (Blaise) Bruer. All rights reserved.
 
-use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 
 use fast_async_mutex::mutex::Mutex;
@@ -21,7 +20,7 @@ mod async_read_taker_tests {
         let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
         let (rx, mut tx) = tokio::io::split(raw_fixed_buffer);
 
-        let mut taker = AsyncReadTaker::new(Arc::new(Mutex::new(Box::new(rx))), None::>, 1024);
+        let mut taker = AsyncReadTaker::new(Arc::new(Mutex::new(Box::new(rx))), 1024);
         let write_data = vec![97u8; 50];
         {
             // Send our data.
@@ -38,58 +37,6 @@ mod async_read_taker_tests {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn done_fn_with_split() -> Result<(), Error> {
-        let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
-        let (rx, mut tx) = tokio::io::split(raw_fixed_buffer);
-
-        const WRITE_DATA: &[u8] = &[97u8; 50];
-        const READ_AMOUNT: usize = 40;
-
-        let reader: ArcMutexAsyncRead = Arc::new(Mutex::new(Box::new(rx)));
-        let done = Arc::new(AtomicBool::new(false));
-        {
-            // Send our data.
-            tx.write_all(&WRITE_DATA).await?;
-            tx.write(&vec![]).await?; // Write EOF.
-        }
-        {
-            // Receive first chunk and test our data.
-            let done_clone = done.clone();
-            let mut taker = AsyncReadTaker::new(
-                reader.clone(),
-                Some(move || done_clone.store(true, Ordering::Relaxed)),
-                READ_AMOUNT,
-            );
-
-            let mut read_buffer = Vec::new();
-            let read_sz = taker.read_to_end(&mut read_buffer).await?;
-            assert_eq!(read_sz, READ_AMOUNT);
-            assert_eq!(read_buffer.len(), READ_AMOUNT);
-            assert_eq!(done.load(Ordering::Relaxed), false, "Should not be done");
-            assert_eq!(&read_buffer, &WRITE_DATA[0..READ_AMOUNT]);
-        }
-        {
-            // Receive last chunk and test our data.
-            let done_clone = done.clone();
-            let mut taker = AsyncReadTaker::new(
-                reader.clone(),
-                Some(move || done_clone.store(true, Ordering::Relaxed)),
-                READ_AMOUNT,
-            );
-
-            let mut read_buffer = Vec::new();
-            let read_sz = taker.read_to_end(&mut read_buffer).await?;
-            const REMAINING_AMT: usize = WRITE_DATA.len() - READ_AMOUNT;
-            assert_eq!(read_sz, REMAINING_AMT);
-            assert_eq!(read_buffer.len(), REMAINING_AMT);
-            assert_eq!(done.load(Ordering::Relaxed), true, "Should not be done");
-            assert_eq!(&read_buffer, &WRITE_DATA[READ_AMOUNT..WRITE_DATA.len()]);
-        }
-
-        Ok(())
-    }
-
     #[tokio::test]
     async fn shutdown_during_read() -> Result<(), Error> {
         let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
@@ -99,16 +46,10 @@ mod async_read_taker_tests {
         const READ_AMOUNT: usize = 50;
 
         let reader: ArcMutexAsyncRead = Arc::new(Mutex::new(Box::new(rx)));
-        let done = Arc::new(AtomicBool::new(false));
 
         tx.write_all(&WRITE_DATA).await?;
 
-        let done_clone = done.clone();
-        let mut taker = Box::pin(AsyncReadTaker::new(
-            reader.clone(),
-            Some(move || done_clone.store(true, Ordering::Relaxed)),
-            READ_AMOUNT,
-        ));
+        let mut taker = Box::pin(AsyncReadTaker::new(reader.clone(), READ_AMOUNT));
 
         let mut read_buffer = Vec::new();
         let mut read_fut = taker.read_to_end(&mut read_buffer).boxed();
@@ -129,7 +70,6 @@ mod async_read_taker_tests {
                 &read_buffer, &WRITE_DATA,
                 "Expected poll!() macro to have processed the data we wrote"
             );
-            assert_eq!(done.load(Ordering::Relaxed), false, "Should not have called done_fn");
         }
 
         Ok(())