Fix s3_store
allada committed Nov 3, 2021
1 parent 3a4d743 commit efcb653
Showing 5 changed files with 77 additions and 101 deletions.
45 changes: 27 additions & 18 deletions cas/store/s3_store.rs
@@ -3,7 +3,6 @@
use std::cmp;
use std::marker::Send;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

@@ -14,8 +13,9 @@ use http::status::StatusCode;
use rand::{rngs::OsRng, Rng};
use rusoto_core::{region::Region, ByteStream, RusotoError};
use rusoto_s3::{
AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CreateMultipartUploadRequest, GetObjectRequest,
HeadObjectError, HeadObjectRequest, PutObjectRequest, S3Client, UploadPartRequest, S3,
AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart,
CreateMultipartUploadRequest, GetObjectError, GetObjectRequest, HeadObjectError, HeadObjectRequest,
PutObjectRequest, S3Client, UploadPartRequest, S3,
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::time::sleep;
@@ -199,37 +199,42 @@ impl StoreTrait for S3Store {
.err_tip(|| "Expected upload_id to be set by s3 response")?;

let complete_result = {
let mut part_number = 1;
let mut part_number: i64 = 1;

let reader = Arc::new(Mutex::new(reader));
let is_done = Arc::new(AtomicBool::new(false));
// We might end up with one more capacity unit than needed, but that is the worst case.
let mut completed_parts = Vec::with_capacity((expected_size / bytes_per_upload_part) + 1);
loop {
let is_done_clone = is_done.clone();
let possible_last_chunk_size = expected_size - bytes_per_upload_part * ((part_number as usize) - 1);
let content_length = cmp::min(possible_last_chunk_size, bytes_per_upload_part);
let is_last_chunk = bytes_per_upload_part * (part_number as usize) >= expected_size;
// Wrap `AsyncRead` so we can hold a copy of it in this scope between iterations.
// This is quite difficult because we need to give full ownership of an AsyncRead
// to `ByteStream` which has an unknown lifetime.
// This wrapper will also ensure we only send `bytes_per_upload_part` then close the
// stream.
let taker = AsyncReadTaker::new(
reader.clone(),
Some(move || is_done_clone.store(true, Ordering::Relaxed)),
bytes_per_upload_part,
);
let taker = AsyncReadTaker::new(reader.clone(), content_length);
{
let body = Some(ByteStream::new(ReaderStream::new(taker)));
self.s3_client
let response = self
.s3_client
.upload_part(UploadPartRequest {
bucket: self.bucket.to_owned(),
key: s3_path.to_owned(),
content_length: Some(content_length as i64),
body,
part_number,
upload_id: upload_id.clone(),
..Default::default()
})
.await
.map_err(|e| make_err!(Code::Unknown, "Failed to upload part: {:?}", e))?;
completed_parts.push(CompletedPart {
e_tag: response.e_tag,
part_number: Some(part_number),
});
}
if is_done.load(Ordering::Relaxed) {
if is_last_chunk {
break;
}
part_number += 1;
@@ -240,6 +245,9 @@ impl StoreTrait for S3Store {
bucket: self.bucket.to_owned(),
key: s3_path.to_owned(),
upload_id: upload_id.clone(),
multipart_upload: Some(CompletedMultipartUpload {
parts: Some(completed_parts),
}),
..Default::default()
})
.await
@@ -288,11 +296,12 @@ impl StoreTrait for S3Store {
..Default::default()
};

let get_object_output = self
.s3_client
.get_object(get_req)
.await
.map_err(|e| make_err!(Code::Unavailable, "Error uploading to S3: {:?}", e))?;
let get_object_output = self.s3_client.get_object(get_req).await.map_err(|e| match e {
RusotoError::Service(GetObjectError::NoSuchKey(err)) => {
return make_err!(Code::NotFound, "Error reading from S3: {:?}", err)
}
_ => make_err!(Code::Unknown, "Error reading from S3: {:?}", e),
})?;
let s3_in_stream = get_object_output
.body
.err_tip(|| "Expected body to be set in s3 get request")?;
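The core of the fix above is that each part upload's ETag is now captured and fed back into the final CompleteMultipartUpload call. Below is a minimal, self-contained sketch of that pattern, not the store's actual code: the `upload_parts_sketch` helper, the pre-split `chunks` argument, and the boxed error type are illustrative, and the real store streams each part from an `AsyncRead` via `AsyncReadTaker` instead of buffering it.

use rusoto_core::ByteStream;
use rusoto_s3::{
    CompleteMultipartUploadRequest, CompletedMultipartUpload, CompletedPart, S3Client,
    UploadPartRequest, S3,
};

// Sketch: upload pre-split chunks as numbered parts, recording each part's ETag,
// then complete the upload with an explicit part list (the piece the old code omitted).
async fn upload_parts_sketch(
    s3_client: &S3Client,
    bucket: &str,
    key: &str,
    upload_id: &str,
    chunks: Vec<Vec<u8>>,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut completed_parts = Vec::with_capacity(chunks.len());
    for (i, chunk) in chunks.into_iter().enumerate() {
        let part_number = (i as i64) + 1; // S3 part numbers start at 1.
        let response = s3_client
            .upload_part(UploadPartRequest {
                bucket: bucket.to_owned(),
                key: key.to_owned(),
                content_length: Some(chunk.len() as i64),
                body: Some(ByteStream::from(chunk)),
                part_number,
                upload_id: upload_id.to_owned(),
                ..Default::default()
            })
            .await?;
        // Record the ETag so CompleteMultipartUpload can reference every part.
        completed_parts.push(CompletedPart {
            e_tag: response.e_tag,
            part_number: Some(part_number),
        });
    }
    s3_client
        .complete_multipart_upload(CompleteMultipartUploadRequest {
            bucket: bucket.to_owned(),
            key: key.to_owned(),
            upload_id: upload_id.to_owned(),
            multipart_upload: Some(CompletedMultipartUpload {
                parts: Some(completed_parts),
            }),
            ..Default::default()
        })
        .await?;
    Ok(())
}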
41 changes: 39 additions & 2 deletions cas/store/tests/s3_store_test.rs
@@ -34,7 +34,7 @@ fn receive_request(sender: mpsc::Sender<(SignedRequest, Vec<u8>)>) -> impl Fn(Si
let mut async_reader = stream.into_async_read();
assert!(block_on(async_reader.read_to_end(&mut raw_payload)).is_ok());
}
Some(SignedRequestPayload::Buffer(buffer)) => raw_payload.copy_from_slice(&buffer[..]),
Some(SignedRequestPayload::Buffer(buffer)) => raw_payload.extend_from_slice(&buffer[..]),
None => {}
}
sender.try_send((request, raw_payload)).expect("Failed to send payload");
@@ -361,6 +361,10 @@ mod s3_store_tests {
from_utf8(&request.headers["host"][0]).unwrap(),
"s3.us-east-1.amazonaws.com"
);
assert_eq!(
from_utf8(&request.headers["content-length"][0]).unwrap(),
format!("{}", rt_data.len())
);
assert_eq!(request.canonical_query_string, "uploads=");
assert_eq!(
request.canonical_uri,
@@ -380,6 +384,10 @@
from_utf8(&request.headers["host"][0]).unwrap(),
"s3.us-east-1.amazonaws.com"
);
assert_eq!(
from_utf8(&request.headers["content-length"][0]).unwrap(),
format!("{}", rt_data.len())
);
assert_eq!(request.canonical_query_string, "partNumber=1&uploadId=Dummy-uploadid");
assert_eq!(
request.canonical_uri,
@@ -393,6 +401,10 @@
rt_data,
"Expected data to match"
);
assert_eq!(
from_utf8(&request.headers["content-length"][0]).unwrap(),
format!("{}", rt_data.len())
);
assert_eq!(request.canonical_query_string, "partNumber=2&uploadId=Dummy-uploadid");
assert_eq!(
request.canonical_uri,
@@ -406,6 +418,10 @@
rt_data,
"Expected data to match"
);
assert_eq!(
from_utf8(&request.headers["content-length"][0]).unwrap(),
format!("{}", rt_data.len())
);
assert_eq!(request.canonical_query_string, "partNumber=3&uploadId=Dummy-uploadid");
assert_eq!(
request.canonical_uri,
@@ -415,7 +431,28 @@
{
// Final payload is the complete_multipart_upload request.
let (request, rt_data) = receiver.next().await.err_tip(|| "Could not get next payload")?;
assert_eq!(&send_data[0..0], rt_data, "Expected data to match");
const COMPLETE_MULTIPART_PAYLOAD_DATA: &str = concat!(
r#"<?xml version="1.0" encoding="utf-8"?>"#,
"<CompleteMultipartUpload>",
"<Part><PartNumber>1</PartNumber></Part>",
"<Part><PartNumber>2</PartNumber></Part>",
"<Part><PartNumber>3</PartNumber></Part>",
"</CompleteMultipartUpload>",
);
assert_eq!(
from_utf8(&rt_data).unwrap(),
COMPLETE_MULTIPART_PAYLOAD_DATA,
"Expected last payload to be empty"
);
assert_eq!(request.method, "POST");
assert_eq!(
from_utf8(&request.headers["content-length"][0]).unwrap(),
format!("{}", COMPLETE_MULTIPART_PAYLOAD_DATA.len())
);
assert_eq!(
from_utf8(&request.headers["x-amz-content-sha256"][0]).unwrap(),
"730f96c9a87580c7930b5bd4fd0457fbe01b34f2261dcdde877d09b06d937b5e"
);
assert_eq!(request.canonical_query_string, "uploadId=Dummy-uploadid");
assert_eq!(
request.canonical_uri,
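The `x-amz-content-sha256` header asserted in the new test is simply the SHA-256 digest of the CompleteMultipartUpload XML payload shown above. A small sketch for regenerating that value, assuming the `sha2` and `hex` crates (neither is a dependency of this repository):

use sha2::{Digest, Sha256};

fn main() {
    // Same XML body the test expects for the CompleteMultipartUpload request.
    let payload = concat!(
        r#"<?xml version="1.0" encoding="utf-8"?>"#,
        "<CompleteMultipartUpload>",
        "<Part><PartNumber>1</PartNumber></Part>",
        "<Part><PartNumber>2</PartNumber></Part>",
        "<Part><PartNumber>3</PartNumber></Part>",
        "</CompleteMultipartUpload>",
    );
    let digest = Sha256::digest(payload.as_bytes());
    // Hex-encode to compare against the header value asserted in the test.
    println!("{}", hex::encode(&digest[..]));
}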
6 changes: 3 additions & 3 deletions config/examples/basic_cas.json
@@ -6,7 +6,7 @@
"s3_store": {
"region": "us-west-1",
"bucket": "blaisebruer-cas-store",
"key_prefix": "test-prefix-cas",
"key_prefix": "test-prefix-cas/",
"retry": {
"max_retries": 0,
"delay": 0.1,
@@ -28,7 +28,7 @@
"s3_store": {
"region": "us-west-1",
"bucket": "blaisebruer-cas-store",
"key_prefix": "test-prefix-ac",
"key_prefix": "test-prefix-ac/",
"retry": {
"max_retries": 0,
"delay": 0.1,
@@ -56,7 +56,7 @@
"cas_stores": {
"main": "CAS_MAIN_STORE",
},
// This value was choosen only because it is a common mem page size.
// This value was chosen only because it is a common mem page size.
"write_buffer_stream_size": 2e6, // 2mb.
"read_buffer_stream_size": 2e6, // 2mb.
// According to https://github.com/grpc/grpc.github.io/issues/371 16KiB - 64KiB is optimal.
22 changes: 6 additions & 16 deletions util/async_read_taker.rs
@@ -13,31 +13,26 @@ use tokio::io::{AsyncRead, ReadBuf};

pub type ArcMutexAsyncRead = Arc<Mutex<Box<dyn AsyncRead + Send + Unpin + Sync + 'static>>>;

// TODO(blaise.bruer) It does not look like this class is needed any more. Should consider removing it.
pin_project! {
/// Useful object that can be used to chunk an AsyncReader by a specific size.
/// This also requires the inner reader be sharable between threads. This allows
/// the caller to still "own" the underlying reader in a way that once `limit` is
/// reached the caller can keep using it, but still use this struct to read the data.
pub struct AsyncReadTaker<F: FnOnce()> {
pub struct AsyncReadTaker {
inner: ArcMutexAsyncRead,
done_fn: Option<F>,
// Add '_' to avoid conflicts with `limit` method.
limit_: usize,
}
}

impl<F: FnOnce()> AsyncReadTaker<F> {
/// `done_fn` can be used to pass a functor that will fire when the stream has no more data.
pub fn new(inner: ArcMutexAsyncRead, done_fn: Option<F>, limit: usize) -> Self {
AsyncReadTaker {
inner,
done_fn: done_fn,
limit_: limit,
}
impl AsyncReadTaker {
pub fn new(inner: ArcMutexAsyncRead, limit: usize) -> Self {
AsyncReadTaker { inner, limit_: limit }
}
}

impl<F: FnOnce()> AsyncRead for AsyncReadTaker<F> {
impl AsyncRead for AsyncReadTaker {
/// Note: This function is modeled after tokio::Take::poll_read.
/// see: https://docs.rs/tokio/1.12.0/src/tokio/io/util/take.rs.html#77
fn poll_read(
@@ -62,11 +57,6 @@ impl<F: FnOnce()> AsyncRead for AsyncReadTaker<F> {
unsafe {
buf.assume_init(n);
}
if n == 0 {
if let Some(done_fn) = self.done_fn.take() {
done_fn();
}
}
buf.advance(n);
self.limit_ -= n;

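With the `done_fn` callback removed, `AsyncReadTaker::new(inner, limit)` is the whole API: it reads at most `limit` bytes from a shared reader and then reports EOF, leaving the reader positioned for the next taker. A minimal usage sketch under those assumptions (the in-memory byte slice is illustrative; the store wraps the upload's `AsyncRead`):

use std::sync::Arc;

use fast_async_mutex::mutex::Mutex;
use tokio::io::{AsyncRead, AsyncReadExt};

// Assumes AsyncReadTaker and ArcMutexAsyncRead are imported from this util crate.
#[tokio::main]
async fn main() -> std::io::Result<()> {
    let data: &'static [u8] = b"0123456789";
    let inner: Box<dyn AsyncRead + Send + Unpin + Sync> = Box::new(data);
    let reader: ArcMutexAsyncRead = Arc::new(Mutex::new(inner));

    // First taker yields at most 4 bytes, then reports EOF to its consumer...
    let mut taker = AsyncReadTaker::new(reader.clone(), 4);
    let mut first_chunk = Vec::new();
    taker.read_to_end(&mut first_chunk).await?;
    assert_eq!(&first_chunk, b"0123");

    // ...while the shared reader keeps its position for the next taker.
    let mut taker = AsyncReadTaker::new(reader.clone(), 4);
    let mut second_chunk = Vec::new();
    taker.read_to_end(&mut second_chunk).await?;
    assert_eq!(&second_chunk, b"4567");
    Ok(())
}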
64 changes: 2 additions & 62 deletions util/tests/async_read_taker_test.rs
@@ -1,6 +1,5 @@
// Copyright 2021 Nathan (Blaise) Bruer. All rights reserved.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use fast_async_mutex::mutex::Mutex;
@@ -21,7 +20,7 @@ mod async_read_taker_tests {
let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
let (rx, mut tx) = tokio::io::split(raw_fixed_buffer);

let mut taker = AsyncReadTaker::new(Arc::new(Mutex::new(Box::new(rx))), None::<Box<fn()>>, 1024);
let mut taker = AsyncReadTaker::new(Arc::new(Mutex::new(Box::new(rx))), 1024);
let write_data = vec![97u8; 50];
{
// Send our data.
@@ -38,58 +37,6 @@ mod async_read_taker_tests {
Ok(())
}

#[tokio::test]
async fn done_fn_with_split() -> Result<(), Error> {
let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
let (rx, mut tx) = tokio::io::split(raw_fixed_buffer);

const WRITE_DATA: &[u8] = &[97u8; 50];
const READ_AMOUNT: usize = 40;

let reader: ArcMutexAsyncRead = Arc::new(Mutex::new(Box::new(rx)));
let done = Arc::new(AtomicBool::new(false));
{
// Send our data.
tx.write_all(&WRITE_DATA).await?;
tx.write(&vec![]).await?; // Write EOF.
}
{
// Receive first chunk and test our data.
let done_clone = done.clone();
let mut taker = AsyncReadTaker::new(
reader.clone(),
Some(move || done_clone.store(true, Ordering::Relaxed)),
READ_AMOUNT,
);

let mut read_buffer = Vec::new();
let read_sz = taker.read_to_end(&mut read_buffer).await?;
assert_eq!(read_sz, READ_AMOUNT);
assert_eq!(read_buffer.len(), READ_AMOUNT);
assert_eq!(done.load(Ordering::Relaxed), false, "Should not be done");
assert_eq!(&read_buffer, &WRITE_DATA[0..READ_AMOUNT]);
}
{
// Receive last chunk and test our data.
let done_clone = done.clone();
let mut taker = AsyncReadTaker::new(
reader.clone(),
Some(move || done_clone.store(true, Ordering::Relaxed)),
READ_AMOUNT,
);

let mut read_buffer = Vec::new();
let read_sz = taker.read_to_end(&mut read_buffer).await?;
const REMAINING_AMT: usize = WRITE_DATA.len() - READ_AMOUNT;
assert_eq!(read_sz, REMAINING_AMT);
assert_eq!(read_buffer.len(), REMAINING_AMT);
assert_eq!(done.load(Ordering::Relaxed), true, "Should not be done");
assert_eq!(&read_buffer, &WRITE_DATA[READ_AMOUNT..WRITE_DATA.len()]);
}

Ok(())
}

#[tokio::test]
async fn shutdown_during_read() -> Result<(), Error> {
let raw_fixed_buffer = AsyncFixedBuf::new(vec![0u8; 100].into_boxed_slice());
@@ -99,16 +46,10 @@ mod async_read_taker_tests {
const READ_AMOUNT: usize = 50;

let reader: ArcMutexAsyncRead = Arc::new(Mutex::new(Box::new(rx)));
let done = Arc::new(AtomicBool::new(false));

tx.write_all(&WRITE_DATA).await?;

let done_clone = done.clone();
let mut taker = Box::pin(AsyncReadTaker::new(
reader.clone(),
Some(move || done_clone.store(true, Ordering::Relaxed)),
READ_AMOUNT,
));
let mut taker = Box::pin(AsyncReadTaker::new(reader.clone(), READ_AMOUNT));

let mut read_buffer = Vec::new();
let mut read_fut = taker.read_to_end(&mut read_buffer).boxed();
@@ -129,7 +70,6 @@ mod async_read_taker_tests {
&read_buffer, &WRITE_DATA,
"Expected poll!() macro to have processed the data we wrote"
);
assert_eq!(done.load(Ordering::Relaxed), false, "Should not have called done_fn");
}

Ok(())
