diff --git a/sdk/core/src/bytes_stream.rs b/sdk/core/src/bytes_stream.rs index b7a898ca54..3f192d460b 100644 --- a/sdk/core/src/bytes_stream.rs +++ b/sdk/core/src/bytes_stream.rs @@ -82,10 +82,11 @@ impl AsyncRead for BytesStream { let remaining_bytes = self_mut.bytes.len() - bytes_read; let bytes_to_copy = std::cmp::min(remaining_bytes, buf.len()); + let bytes_to_read_end = self_mut.bytes_read + bytes_to_copy; for (buf_byte, bytes_byte) in buf .iter_mut() - .zip(self_mut.bytes.slice(self_mut.bytes_read..bytes_to_copy)) + .zip(self_mut.bytes.slice(self_mut.bytes_read..bytes_to_read_end)) { *buf_byte = bytes_byte; } @@ -98,3 +99,54 @@ impl AsyncRead for BytesStream { } } } + +// Unit tests +#[cfg(test)] +mod tests { + use super::*; + use futures::io::AsyncReadExt; + use futures::stream::StreamExt; + + // Test BytesStream Stream + #[test] + fn test_bytes_stream() { + let bytes = Bytes::from("hello world"); + let mut stream = BytesStream::new(bytes.clone()); + + let mut buf = Vec::new(); + let mut bytes_read = 0; + while let Some(Ok(bytes)) = futures::executor::block_on(stream.next()) { + buf.extend_from_slice(&bytes); + bytes_read += bytes.len(); + } + + assert_eq!(bytes_read, bytes.len()); + assert_eq!(buf, bytes); + } + + // Test BytesStream AsyncRead, all bytes at once + #[test] + fn test_async_read_all_bytes_at_once() { + let bytes = Bytes::from("hello world"); + let mut stream = BytesStream::new(bytes.clone()); + + let mut buf = [0; 11]; + let bytes_read = futures::executor::block_on(stream.read(&mut buf)).unwrap(); + assert_eq!(bytes_read, 11); + assert_eq!(&buf[..], &bytes); + } + + // Test BytesStream AsyncRead, one byte at a time + #[test] + fn test_async_read_one_byte_at_a_time() { + let bytes = Bytes::from("hello world"); + let mut stream = BytesStream::new(bytes.clone()); + + for i in 0..bytes.len() { + let mut buf = [0; 1]; + let bytes_read = futures::executor::block_on(stream.read(&mut buf)).unwrap(); + assert_eq!(bytes_read, 1); + assert_eq!(buf[0], bytes[i]); + } + } +}