Skip to content

Commit

Permalink
Fix a potential bug in DropCloserReadHalf::take() (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
steedmicro committed Jan 16, 2024
1 parent 3acefc7 commit 70e8525
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 14 deletions.
48 changes: 35 additions & 13 deletions nativelink-util/src/buf_channel.rs
Expand Up @@ -303,16 +303,26 @@ impl DropCloserReadHalf {
let (first_chunk, second_chunk) = {
// This is an optimization for a relatively common case when the first chunk in the
// stream satisfies all the requirements to fill this `take()`.
// This will us from needing to copy the data into a new buffer and instead we can
// This will prevent us from needing to copy the data into a new buffer and instead we can
// just forward on the original Bytes object. If we need more than the first chunk
// we will then go the slow path and actually copy our data.

// 1. Read some data from our stream (or self.partial).
let mut first_chunk = self.recv().await.err_tip(|| "During first buf_channel::take()")?;
assert!(
self.partial.is_none(),
"Partial should have been consumed during the recv()"
);
// 2. Split our data so `first_chunk` is <= `size` and puts any remaining
// in `self.partial` (or set it to None).
populate_partial_if_needed(0, size, &mut first_chunk, &mut self.partial);
if first_chunk.is_empty() || first_chunk.len() >= size {
assert!(
first_chunk.is_empty() || first_chunk.len() == size,
"Length should be exactly size here"
);
// 3a. If our `first_chunk` is EOF, we are done.
if first_chunk.is_empty() {
return Ok(first_chunk);
}
// 3b. If our first_chunk has data and it our self.partial was filled it means our stream has more data.
if self.partial.is_some() {
assert!(first_chunk.len() == size, "Length should be exactly size here");
return Ok(first_chunk);
}

Expand All @@ -332,21 +342,33 @@ impl DropCloserReadHalf {
output.put(second_chunk);

loop {
if self.partial.is_some() {
assert!(
output.len() == size,
"If partial is set expected output length to be {size}"
);
return Ok(output.freeze());
}
assert!(
output.len() <= size,
"Length should never be larger than size in take()"
);

let mut chunk = self.recv().await.err_tip(|| "During buf_channel::take()")?;
assert!(
self.partial.is_none(),
"Partial should have been consumed during the recv()"
);
if chunk.is_empty() {
break; // EOF.
// Forward EOF to next recv() and return our current buffer.
self.partial = Some(Ok(chunk));
return Ok(output.freeze());
}

populate_partial_if_needed(output.len(), size, &mut chunk, &mut self.partial);

output.put(chunk);

if output.len() >= size {
assert!(output.len() == size); // Length should never be larger than size here.
break;
}
}
Ok(output.freeze())
}
}

Expand Down
39 changes: 38 additions & 1 deletion nativelink-util/tests/buf_channel_test.rs
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use nativelink_error::{make_err, Code, Error, ResultExt};
use nativelink_util::buf_channel::make_buf_channel_pair;
use tokio::try_join;
Expand All @@ -25,6 +25,7 @@ mod buf_channel_tests {

const DATA1: &str = "foo";
const DATA2: &str = "bar";
const DATA3: &str = "foobar1234";

#[tokio::test]
async fn smoke_test() -> Result<(), Error> {
Expand Down Expand Up @@ -222,6 +223,42 @@ mod buf_channel_tests {
Ok(())
}

#[tokio::test]
async fn send_and_take_fuzz_test() -> Result<(), Error> {
const DATA3_END_POS: usize = DATA3.len() + 1;
for data_size in 1..DATA3_END_POS {
let data: Vec<u8> = DATA3.as_bytes()[0..data_size].to_vec();

for write_size in 1..DATA3_END_POS {
for read_size in 1..DATA3_END_POS {
let tx_data = Bytes::from(data.clone());
let expected_data = Bytes::from(data.clone());

let (mut tx, mut rx) = make_buf_channel_pair();

let tx_fut = async move {
for i in (0..data_size).step_by(write_size) {
tx.send(tx_data.slice(i..std::cmp::min(data_size, i + write_size)))
.await?;
}
tx.send_eof().await?;
Result::<(), Error>::Ok(())
};
let rx_fut = async move {
let mut round_trip_data = BytesMut::new();
for _ in (0..data_size).step_by(read_size) {
round_trip_data.extend(rx.take(read_size).await?.iter());
}
assert_eq!(round_trip_data.freeze(), expected_data);
Result::<(), Error>::Ok(())
};
try_join!(tx_fut, rx_fut)?;
}
}
}
Ok(())
}

#[tokio::test]
async fn rx_gets_error_if_tx_drops_test() -> Result<(), Error> {
let (mut tx, mut rx) = make_buf_channel_pair();
Expand Down

0 comments on commit 70e8525

Please sign in to comment.