Skip to content

Commit

Permalink
Fix bug in BytestreamServer where it would ignore finish_write
Browse files Browse the repository at this point in the history
resolves #245
  • Loading branch information
allada committed Sep 8, 2023
1 parent f42f150 commit f645d69
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 12 deletions.
9 changes: 4 additions & 5 deletions cas/grpc_service/bytestream_server.rs
Expand Up @@ -388,14 +388,13 @@ impl ByteStreamServer {
err.code = Code::Internal;
return Err(err);
}
outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release);
}
let bytes_written = tx.get_bytes_written();
outer_bytes_received.store(bytes_written, Ordering::Relaxed);

if expected_size < bytes_written {
if expected_size < tx.get_bytes_written() {
return Err(make_input_err!("Received more bytes than expected"));
}
if expected_size == bytes_written {
if write_request.finish_write {
// Gracefully close our stream.
tx.send_eof()
.await
Expand Down Expand Up @@ -454,7 +453,7 @@ impl ByteStreamServer {
let active_uploads = self.active_uploads.lock();
if let Some((received_bytes, _maybe_idle_stream)) = active_uploads.get(uuid) {
return Ok(Response::new(QueryWriteStatusResponse {
committed_size: received_bytes.load(Ordering::Relaxed) as i64,
committed_size: received_bytes.load(Ordering::Acquire) as i64,
// If we are in the active_uploads map, but the value is None,
// it means the stream is not complete.
complete: false,
Expand Down
88 changes: 88 additions & 0 deletions cas/grpc_service/tests/bytestream_server_test.rs
Expand Up @@ -17,6 +17,8 @@ use std::pin::Pin;
use std::sync::Arc;

use bytestream_server::ByteStreamServer;
use futures::poll;
use futures::task::Poll;
use hyper::body::Sender;
use maplit::hashmap;
use prometheus_client::registry::Registry;
Expand Down Expand Up @@ -228,6 +230,7 @@ pub mod write_tests {
{
// Write the remainder of our data.
write_request.write_offset = BYTE_SPLIT_OFFSET as i64;
write_request.finish_write = true;
write_request.data = WRITE_DATA[BYTE_SPLIT_OFFSET..].into();
tx.send_data(encode_stream_proto(&write_request)?).await?;
}
Expand All @@ -249,6 +252,91 @@ pub mod write_tests {
Ok(())
}

// Regression test for the bug fixed in this commit: the write future must NOT
// complete merely because all expected bytes have arrived — it must wait until
// a WriteRequest with `finish_write == true` is received (resolves #245).
#[tokio::test]
pub async fn ensure_write_is_not_done_until_write_request_is_set() -> Result<(), Box<dyn std::error::Error>> {
let store_manager = make_store_manager().await?;
let bs_server = make_bytestream_server(store_manager.as_ref())?;
let store_owned = store_manager.get_store("main_cas").unwrap();

let store = Pin::new(store_owned.as_ref());

// Setup stream.
// Build a raw hyper Body channel and wrap it in a tonic Streaming decoder so
// we can hand-feed encoded WriteRequest frames and poll the server future
// manually (rather than driving it to completion with .await).
let (mut tx, mut write_fut) = {
let (tx, body) = Body::channel();
let mut codec = ProstCodec::<WriteRequest, WriteRequest>::default();
// Note: This is an undocumented function.
let stream = Streaming::new_request(codec.decoder(), body, Some(CompressionEncoding::Gzip), None);

(tx, bs_server.write(Request::new(stream)))
};
const WRITE_DATA: &str = "12456789abcdefghijk";
// Resource name format follows the ByteStream upload convention:
// {instance}/uploads/{uuid}/blobs/{hash}/{size}.
let resource_name = format!(
"{}/uploads/{}/blobs/{}/{}",
INSTANCE_NAME,
"4dcec57e-1389-4ab5-b188-4a59f22ceb4b", // Randomly generated.
HASH1,
WRITE_DATA.len()
);
let mut write_request = WriteRequest {
resource_name,
write_offset: 0,
finish_write: false,
data: vec![].into(),
};
{
// Write our data.
// All of the payload is sent here, but finish_write is still false — the
// old (buggy) behavior would consider the write complete at this point.
write_request.write_offset = 0;
write_request.data = WRITE_DATA[..].into();
tx.send_data(encode_stream_proto(&write_request)?).await?;
}
// Note: We have to poll multiple times because there are multiple futures
// joined onto this one future and we need to ensure we run the state machine as
// far as possible.
// The assertion is the heart of the test: with all bytes received but no
// finish_write, the server future must still be Pending.
for _ in 0..100 {
assert!(
poll!(&mut write_fut).is_pending(),
"Expected the future to not be completed yet"
);
}
{
// Write our EOF.
// An empty frame at the end offset with finish_write == true is the
// client's explicit commit signal.
write_request.write_offset = WRITE_DATA.len() as i64;
write_request.finish_write = true;
write_request.data.clear();
tx.send_data(encode_stream_proto(&write_request)?).await?;
}
// Poll until the future resolves (bounded so a regression hangs the
// assertion below instead of the test runner).
let mut result = None;
for _ in 0..100 {
if let Poll::Ready(r) = poll!(&mut write_fut) {
result = Some(r);
break;
}
}
{
// Check our results.
// committed_size must equal the full payload length now that the stream
// was finished gracefully.
assert_eq!(
result
.err_tip(|| "bs_server.write never returned a value")?
.err_tip(|| "bs_server.write returned an error")?
.into_inner(),
WriteResponse {
committed_size: WRITE_DATA.len() as i64
},
"Expected Responses to match"
);
}
{
// Check to make sure our store recorded the data properly.
let digest = DigestInfo::try_new(HASH1, WRITE_DATA.len())?;
assert_eq!(
store.get_part_unchunked(digest, 0, None, None).await?,
WRITE_DATA,
"Data written to store did not match expected data",
);
}
Ok(())
}

#[tokio::test]
pub async fn out_of_order_data_fails() -> Result<(), Box<dyn std::error::Error>> {
let store_manager = make_store_manager().await?;
Expand Down
14 changes: 9 additions & 5 deletions cas/store/ac_utils.rs
Expand Up @@ -17,7 +17,7 @@ use std::io::Cursor;
use std::pin::Pin;

use bytes::BytesMut;
use futures::{future::try_join, Future, FutureExt, TryFutureExt};
use futures::{future::join, Future, FutureExt};
use prost::Message;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncRead, AsyncReadExt};
Expand Down Expand Up @@ -108,10 +108,14 @@ fn inner_upload_file_to_store<'a, Fut: Future<Output = Result<(), Error>> + 'a>(
read_data_fn: impl FnOnce(DropCloserWriteHalf) -> Fut,
) -> impl Future<Output = Result<(), Error>> + 'a {
let (tx, rx) = make_buf_channel_pair();
let upload_file_to_store_fut = cas_store
.update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes as usize))
.map(|r| r.err_tip(|| "Could not upload data to store in upload_file_to_store"));
try_join(read_data_fn(tx), upload_file_to_store_fut).map_ok(|(_, _)| ())
join(
cas_store
.update(digest, rx, UploadSizeInfo::ExactSize(digest.size_bytes as usize))
.map(|r| r.err_tip(|| "Could not upload data to store in upload_file_to_store")),
read_data_fn(tx),
)
// Ensure we get errors reported from both sides
.map(|(upload_result, read_result)| upload_result.merge(read_result))
}

/// Uploads data to our store for given digest.
Expand Down
4 changes: 2 additions & 2 deletions cas/store/fast_slow_store.rs
Expand Up @@ -158,8 +158,8 @@ impl StoreTrait for FastSlowStore {
}
};

let fast_store_fut = self.pin_slow_store().update(digest, fast_rx, size_info);
let slow_store_fut = self.pin_fast_store().update(digest, slow_rx, size_info);
let fast_store_fut = self.pin_fast_store().update(digest, fast_rx, size_info);
let slow_store_fut = self.pin_slow_store().update(digest, slow_rx, size_info);

let (data_stream_res, fast_res, slow_res) = join!(data_stream_fut, fast_store_fut, slow_store_fut);
data_stream_res.merge(fast_res).merge(slow_res)?;
Expand Down

0 comments on commit f645d69

Please sign in to comment.