Skip to content

Commit

Permalink
Avoid writer EOF until fast store complete (#480)
Browse files Browse the repository at this point in the history
Currently we send the EOF to the fast store and the requester at the same time.
The bytestream_server drops the get_part future as soon as it receives the EOF,
which means that the fast store doesn't have time to sync and populate.

This is resolved by only sending the EOF to the requester in the fast_slow_store
once all of the futures have completed.  The alternative is to spawn a task for the
fast store or to complete the get_part future in bytestream_server, but this seems like
the most lightweight solution, and it doesn't require any thought by the user.
  • Loading branch information
chrisstaite-menlo committed Dec 14, 2023
1 parent f2bd770 commit 2de8867
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 8 deletions.
24 changes: 16 additions & 8 deletions nativelink-store/src/fast_slow_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,10 @@ impl Store for FastSlowStore {
.err_tip(|| "Failed to read data data buffer from slow store")?;
if output_buf.is_empty() {
// Write out our EOF.
// It is possible for the client to disconnect the stream because they got
// all the data they wanted, which could lead to an error when writing this
// EOF. If that was to happen, we could end up terminating this early and
// the resulting upload to the fast store might fail.
let (fast_res, slow_res) = join!(fast_tx.send_eof(), writer_pin.send_eof());
return fast_res.merge(slow_res);
// We are dropped as soon as we send_eof to writer_pin, so
// we wait until we've finished all of our joins to do that.
let fast_res = fast_tx.send_eof().await;
return Ok::<_, Error>((fast_res, writer_pin));
}

let writer_fut = if let Some(range) =
Expand All @@ -256,8 +254,18 @@ impl Store for FastSlowStore {
let fast_store_fut = fast_store.update(digest, fast_rx, UploadSizeInfo::ExactSize(sz));

let (data_stream_res, slow_res, fast_res) = join!(data_stream_fut, slow_store_fut, fast_store_fut);
data_stream_res.merge(fast_res).merge(slow_res)?;
Ok(())
match data_stream_res {
Ok((fast_eof_res, mut writer_pin)) =>
// Sending the EOF will drop us almost immediately in bytestream_server
// so we perform it as the very last action in this method.
{
fast_eof_res
.merge(fast_res)
.merge(slow_res)
.merge(writer_pin.send_eof().await)
}
Err(err) => fast_res.merge(slow_res).merge(Err(err)),
}
}

fn inner_store(self: Arc<Self>, _digest: Option<DigestInfo>) -> Arc<dyn Store> {
Expand Down
140 changes: 140 additions & 0 deletions nativelink-store/tests/fast_slow_store_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ async fn check_data<S: Store>(

#[cfg(test)]
mod fast_slow_store_tests {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;

use async_trait::async_trait;
use bytes::Bytes;
use error::{make_err, Code, ResultExt};
use nativelink_util::buf_channel::make_buf_channel_pair;
use pretty_assertions::assert_eq;

use super::*; // Must be declared in every module.
Expand Down Expand Up @@ -215,4 +222,137 @@ mod fast_slow_store_tests {
assert_eq!(test(received_range, send_range), expected_results);
}
}

/// Regression test: the fast store's `update` future must not be dropped
/// before it finishes, even when the caller stops polling `get_part_arc`
/// as soon as its own reader has drained.
///
/// The test wires two `DropCheckStore` instances into a `FastSlowStore`,
/// races `get_part_arc` against `rx.drain()` with `select!`, and fails if
/// the fast store is dropped between it seeing EOF and being released.
#[tokio::test]
async fn drop_on_eof_completes_store_futures() -> Result<(), Error> {
/// Test double used as both the fast and the slow store.
struct DropCheckStore {
// Set to `true` by the `Drop` impl below.
drop_flag: Arc<AtomicBool>,
// If present, `update` awaits this after draining before it returns.
read_rx: Mutex<Option<tokio::sync::oneshot::Receiver<()>>>,
// If present, `update` fires this once its reader has drained.
eof_tx: Mutex<Option<tokio::sync::oneshot::Sender<()>>>,
// The single digest this store reports as present, if any.
digest: Option<DigestInfo>,
}

#[async_trait]
impl Store for DropCheckStore {
// Reports `Some(size)` for every requested digest whose hash string
// matches `self.digest`; all other result slots are left untouched.
async fn has_with_results(
self: Pin<&Self>,
digests: &[DigestInfo],
results: &mut [Option<usize>],
) -> Result<(), Error> {
if let Some(has_digest) = self.digest {
for (digest, result) in digests.iter().zip(results.iter_mut()) {
if digest.hash_str() == has_digest.hash_str() {
*result = Some(has_digest.size_bytes as usize);
}
}
}
Ok(())
}

// Consumes the incoming data, signals `eof_tx`, then waits on `read_rx`
// so the test controls exactly when this future is allowed to complete.
async fn update(
self: Pin<&Self>,
_digest: DigestInfo,
mut reader: nativelink_util::buf_channel::DropCloserReadHalf,
_size_info: nativelink_util::store_trait::UploadSizeInfo,
) -> Result<(), Error> {
// Gets called in the fast store and we don't need to do
// anything. Should only complete when drain has finished.
reader.drain().await?;
// The lock guard is a temporary that is released at the end of this
// statement, so it is not held across the awaits below.
let eof_tx = self.eof_tx.lock().unwrap().take();
if let Some(tx) = eof_tx {
// Tell the test that the reader has been fully drained.
tx.send(()).map_err(|e| make_err!(Code::Internal, "{:?}", e))?;
}
let read_rx = self.read_rx.lock().unwrap().take();
if let Some(rx) = read_rx {
// Stay pending until the test releases us. If this future is dropped
// while waiting here, the sender side of the test observes an error.
rx.await.map_err(|e| make_err!(Code::Internal, "{:?}", e))?;
}
Ok(())
}

// Writes a zero-filled buffer followed by EOF to `writer`.
async fn get_part_ref(
self: Pin<&Self>,
digest: DigestInfo,
writer: &mut nativelink_util::buf_channel::DropCloserWriteHalf,
offset: usize,
length: Option<usize>,
) -> Result<(), Error> {
// Gets called in the slow store and we provide the data that's
// sent to the upstream and the fast store.
// NOTE(review): `offset` is subtracted even when `length` is `Some`;
// this looks harmless for the offset-0 request made by this test, but
// confirm it is intended before reusing this store with other offsets.
let bytes = length.unwrap_or(digest.size_bytes as usize) - offset;
let data = vec![0_u8; bytes];
writer.send(Bytes::copy_from_slice(&data)).await?;
writer.send_eof().await
}

fn inner_store(self: Arc<Self>, _digest: Option<DigestInfo>) -> Arc<dyn Store> {
self
}

fn as_any(self: Arc<Self>) -> Box<dyn std::any::Any + Send> {
Box::new(self)
}

// Intentionally a no-op for this test double.
fn register_metrics(self: Arc<Self>, _registry: &mut nativelink_util::metrics_utils::Registry) {}
}

// Records the drop so the test can detect the store being released early.
impl Drop for DropCheckStore {
fn drop(&mut self) {
self.drop_flag.store(true, Ordering::Release);
}
}

let digest = DigestInfo::try_new(VALID_HASH, 100).unwrap();
// `read` channel: test -> fast store, releases `update` so it may return.
let (fast_store_read_tx, fast_store_read_rx) = tokio::sync::oneshot::channel();
// `eof` channel: fast store -> test, fired once `update` has drained its reader.
let (fast_store_eof_tx, fast_store_eof_rx) = tokio::sync::oneshot::channel();
let fast_store_dropped = Arc::new(AtomicBool::new(false));
// The fast store holds no digest and owns both synchronisation channels.
let fast_store: Arc<DropCheckStore> = Arc::new(DropCheckStore {
drop_flag: fast_store_dropped.clone(),
eof_tx: Mutex::new(Some(fast_store_eof_tx)),
read_rx: Mutex::new(Some(fast_store_read_rx)),
digest: None,
});
// The slow store reports the digest as present; its drop flag is never
// inspected by this test.
let slow_store_dropped = Arc::new(AtomicBool::new(false));
let slow_store: Arc<DropCheckStore> = Arc::new(DropCheckStore {
drop_flag: slow_store_dropped,
eof_tx: Mutex::new(None),
read_rx: Mutex::new(None),
digest: Some(digest),
});

// The config values are only placeholders here; the store instances
// actually used are the two `DropCheckStore`s passed explicitly below.
let fast_slow_store = Arc::new(FastSlowStore::new(
&nativelink_config::stores::FastSlowStore {
fast: nativelink_config::stores::StoreConfig::memory(nativelink_config::stores::MemoryStore::default()),
slow: nativelink_config::stores::StoreConfig::memory(nativelink_config::stores::MemoryStore::default()),
},
fast_store,
slow_store,
));

let (tx, mut rx) = make_buf_channel_pair();
let (get_res, read_res) = tokio::join!(
async move {
// Drop get_part_arc as soon as rx.drain() completes
// (`select!` drops the branch that did not finish first).
tokio::select!(
res = rx.drain() => res,
res = fast_slow_store.get_part_arc(digest, tx, 0, Some(digest.size_bytes as usize)) => res,
)
},
async move {
// Wait until the fast store's `update` has drained its reader.
fast_store_eof_rx
.await
.map_err(|e| make_err!(Code::Internal, "{:?}", e))?;
// Give a couple of cycles for dropping to occur if it's going to.
tokio::task::yield_now().await;
tokio::task::yield_now().await;
if fast_store_dropped.load(Ordering::Acquire) {
return Err(make_err!(Code::Internal, "Fast store was dropped!"));
}
// Release the fast store's `update`; this send fails if the receiver
// has already been dropped.
fast_store_read_tx
.send(())
.map_err(|e| make_err!(Code::Internal, "{:?}", e))?;
Ok::<_, Error>(())
}
);
// Surface a failure from either side of the join.
get_res.merge(read_res)
}
}

0 comments on commit 2de8867

Please sign in to comment.