From 70e852598580e48d54835b6ea7d2be6ec953b7b3 Mon Sep 17 00:00:00 2001 From: Steed <93139780+steed924@users.noreply.github.com> Date: Tue, 16 Jan 2024 16:46:08 +0100 Subject: [PATCH] Fix a potential bug in DropCloserReadHalf::take() (#606) --- nativelink-util/src/buf_channel.rs | 48 +++++++++++++++++------ nativelink-util/tests/buf_channel_test.rs | 39 +++++++++++++++++- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/nativelink-util/src/buf_channel.rs b/nativelink-util/src/buf_channel.rs index d59ef5d0c..535a15e32 100644 --- a/nativelink-util/src/buf_channel.rs +++ b/nativelink-util/src/buf_channel.rs @@ -303,16 +303,26 @@ impl DropCloserReadHalf { let (first_chunk, second_chunk) = { // This is an optimization for a relatively common case when the first chunk in the // stream satisfies all the requirements to fill this `take()`. - // This will us from needing to copy the data into a new buffer and instead we can + // This will prevent us from needing to copy the data into a new buffer and instead we can // just forward on the original Bytes object. If we need more than the first chunk // we will then go the slow path and actually copy our data. + + // 1. Read some data from our stream (or self.partial). let mut first_chunk = self.recv().await.err_tip(|| "During first buf_channel::take()")?; + assert!( + self.partial.is_none(), + "Partial should have been consumed during the recv()" + ); + // 2. Split our data so `first_chunk` is <= `size` and puts any remaining + // in `self.partial` (or set it to None). populate_partial_if_needed(0, size, &mut first_chunk, &mut self.partial); - if first_chunk.is_empty() || first_chunk.len() >= size { - assert!( - first_chunk.is_empty() || first_chunk.len() == size, - "Length should be exactly size here" - ); + // 3a. If our `first_chunk` is EOF, we are done. + if first_chunk.is_empty() { + return Ok(first_chunk); + } + // 3b. 
If our first_chunk has data and our self.partial was filled, it means our stream has more data. + if self.partial.is_some() { + assert!(first_chunk.len() == size, "Length should be exactly size here"); return Ok(first_chunk); } @@ -332,21 +342,33 @@ impl DropCloserReadHalf { output.put(second_chunk); loop { + if self.partial.is_some() { + assert!( + output.len() == size, + "If partial is set expected output length to be {size}" + ); + return Ok(output.freeze()); + } + assert!( + output.len() <= size, + "Length should never be larger than size in take()" + ); + let mut chunk = self.recv().await.err_tip(|| "During buf_channel::take()")?; + assert!( + self.partial.is_none(), + "Partial should have been consumed during the recv()" + ); if chunk.is_empty() { - break; // EOF. + // Forward EOF to next recv() and return our current buffer. + self.partial = Some(Ok(chunk)); + return Ok(output.freeze()); } populate_partial_if_needed(output.len(), size, &mut chunk, &mut self.partial); output.put(chunk); - - if output.len() >= size { - assert!(output.len() == size); // Length should never be larger than size here. - break; - } } - Ok(output.freeze()) } } diff --git a/nativelink-util/tests/buf_channel_test.rs b/nativelink-util/tests/buf_channel_test.rs index c0d297c45..ce2e570cd 100644 --- a/nativelink-util/tests/buf_channel_test.rs +++ b/nativelink-util/tests/buf_channel_test.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_util::buf_channel::make_buf_channel_pair; use tokio::try_join; @@ -25,6 +25,7 @@ mod buf_channel_tests { const DATA1: &str = "foo"; const DATA2: &str = "bar"; + const DATA3: &str = "foobar1234"; #[tokio::test] async fn smoke_test() -> Result<(), Error> { @@ -222,6 +223,42 @@ mod buf_channel_tests { Ok(()) } + #[tokio::test] + async fn send_and_take_fuzz_test() -> Result<(), Error> { + const DATA3_END_POS: usize = DATA3.len() + 1; + for data_size in 1..DATA3_END_POS { + let data: Vec<u8> = DATA3.as_bytes()[0..data_size].to_vec(); + + for write_size in 1..DATA3_END_POS { + for read_size in 1..DATA3_END_POS { + let tx_data = Bytes::from(data.clone()); + let expected_data = Bytes::from(data.clone()); + + let (mut tx, mut rx) = make_buf_channel_pair(); + + let tx_fut = async move { + for i in (0..data_size).step_by(write_size) { + tx.send(tx_data.slice(i..std::cmp::min(data_size, i + write_size))) + .await?; + } + tx.send_eof().await?; + Result::<(), Error>::Ok(()) + }; + let rx_fut = async move { + let mut round_trip_data = BytesMut::new(); + for _ in (0..data_size).step_by(read_size) { + round_trip_data.extend(rx.take(read_size).await?.iter()); + } + assert_eq!(round_trip_data.freeze(), expected_data); + Result::<(), Error>::Ok(()) + }; + try_join!(tx_fut, rx_fut)?; + } + } + } + Ok(()) + } + #[tokio::test] async fn rx_gets_error_if_tx_drops_test() -> Result<(), Error> { let (mut tx, mut rx) = make_buf_channel_pair();