Skip to content

Commit

Permalink
Wasi input-stream: use same errors as output-stream (bytecodealliance…
Browse files Browse the repository at this point in the history
…#7090)

* streams.wit: delete stream-status, rename write-error to stream-error, transform all input-stream methods

* preview2: use StreamError throughout input-stream

* preview2: passes cargo test

* preview1: fixes for input-stream stream-error.

* wasmtime-wasi-http: fixes for HostInputStream trait changes

* component adapter: fixes for input-stream changes

* test programs: fixes for input-stream

* component adapter: handle StreamError::Closed in fd_read

* sync wit definitions to wasi-http

* fix!!

* preview1: handle eof and intr properly

prtest:full

* Fix preview1 stdin reading

* Touch up stream documentation

---------

Co-authored-by: Alex Crichton <alex@alexcrichton.com>
  • Loading branch information
Pat Hickey and alexcrichton committed Oct 4, 2023
1 parent 3c3ea44 commit a5ccfe6
Show file tree
Hide file tree
Showing 16 changed files with 430 additions and 631 deletions.
13 changes: 6 additions & 7 deletions crates/test-programs/wasi-http-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,14 @@ pub fn request(
let input_stream_pollable = input_stream.subscribe();

let mut body = Vec::new();
let mut eof = streams::StreamStatus::Open;
while eof != streams::StreamStatus::Ended {
loop {
poll::poll_list(&[&input_stream_pollable]);

let (mut body_chunk, stream_status) = input_stream
.read(1024 * 1024)
.map_err(|_| anyhow!("input_stream read failed"))?;

eof = stream_status;
let mut body_chunk = match input_stream.read(1024 * 1024) {
Ok(c) => c,
Err(streams::StreamError::Closed) => break,
Err(e) => Err(anyhow!("input_stream read failed: {e:?}"))?,
};

if !body_chunk.is_empty() {
body.append(&mut body_chunk);
Expand Down
45 changes: 11 additions & 34 deletions crates/test-programs/wasi-sockets-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,24 @@ use wasi::io::poll;
use wasi::io::streams;
use wasi::sockets::{network, tcp, tcp_create_socket};

pub fn write(output: &streams::OutputStream, mut bytes: &[u8]) -> (usize, streams::StreamStatus) {
let total = bytes.len();
let mut written = 0;

pub fn write(output: &streams::OutputStream, mut bytes: &[u8]) -> Result<(), streams::StreamError> {
let pollable = output.subscribe();

while !bytes.is_empty() {
poll::poll_list(&[&pollable]);

let permit = match output.check_write() {
Ok(n) => n,
Err(_) => return (written, streams::StreamStatus::Ended),
};
let permit = output.check_write()?;

let len = bytes.len().min(permit as usize);
let (chunk, rest) = bytes.split_at(len);

match output.write(chunk) {
Ok(()) => {}
Err(_) => return (written, streams::StreamStatus::Ended),
}
output.write(chunk)?;

match output.blocking_flush() {
Ok(()) => {}
Err(_) => return (written, streams::StreamStatus::Ended),
}
output.blocking_flush()?;

bytes = rest;
written += len;
}

(total, streams::StreamStatus::Open)
Ok(())
}

pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::IpAddressFamily) {
Expand All @@ -59,13 +45,9 @@ pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::Ip
poll::poll_one(&client_sub);
let (client_input, client_output) = client.finish_connect().unwrap();

let (n, status) = write(&client_output, &[]);
assert_eq!(n, 0);
assert_eq!(status, streams::StreamStatus::Open);
write(&client_output, &[]).unwrap();

let (n, status) = write(&client_output, first_message);
assert_eq!(n, first_message.len());
assert_eq!(status, streams::StreamStatus::Open);
write(&client_output, first_message).unwrap();

drop(client_input);
drop(client_output);
Expand All @@ -75,12 +57,10 @@ pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::Ip
poll::poll_one(&sub);
let (accepted, input, output) = sock.accept().unwrap();

let (empty_data, status) = input.read(0).unwrap();
let empty_data = input.read(0).unwrap();
assert!(empty_data.is_empty());
assert_eq!(status, streams::StreamStatus::Open);

let (data, status) = input.blocking_read(first_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);
let data = input.blocking_read(first_message.len() as u64).unwrap();

drop(input);
drop(output);
Expand All @@ -97,9 +77,7 @@ pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::Ip
poll::poll_one(&client_sub);
let (client_input, client_output) = client.finish_connect().unwrap();

let (n, status) = write(&client_output, second_message);
assert_eq!(n, second_message.len());
assert_eq!(status, streams::StreamStatus::Open);
write(&client_output, second_message).unwrap();

drop(client_input);
drop(client_output);
Expand All @@ -108,8 +86,7 @@ pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::Ip

poll::poll_one(&sub);
let (accepted, input, output) = sock.accept().unwrap();
let (data, status) = input.blocking_read(second_message.len() as u64).unwrap();
assert_eq!(status, streams::StreamStatus::Open);
let data = input.blocking_read(second_message.len() as u64).unwrap();

drop(input);
drop(output);
Expand Down
37 changes: 18 additions & 19 deletions crates/wasi-http/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use std::{
};
use tokio::sync::{mpsc, oneshot};
use wasmtime_wasi::preview2::{
self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, OutputStreamError,
StreamRuntimeError, StreamState, Subscribe,
self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, Subscribe,
};

pub type HyperIncomingBody = BoxBody<Bytes, anyhow::Error>;
Expand Down Expand Up @@ -146,21 +145,21 @@ impl HostIncomingBodyStream {

#[async_trait::async_trait]
impl HostInputStream for HostIncomingBodyStream {
fn read(&mut self, size: usize) -> anyhow::Result<(Bytes, StreamState)> {
fn read(&mut self, size: usize) -> Result<Bytes, StreamError> {
use mpsc::error::TryRecvError;

if !self.buffer.is_empty() {
let len = size.min(self.buffer.len());
let chunk = self.buffer.split_to(len);
return Ok((chunk, StreamState::Open));
return Ok(chunk);
}

if let Some(e) = self.error.take() {
return Err(StreamRuntimeError::from(e).into());
return Err(StreamError::LastOperationFailed(e));
}

if !self.open {
return Ok((Bytes::new(), StreamState::Closed));
return Err(StreamError::Closed);
}

match self.receiver.try_recv() {
Expand All @@ -171,21 +170,21 @@ impl HostInputStream for HostIncomingBodyStream {
self.buffer = bytes;
}

return Ok((chunk, StreamState::Open));
return Ok(chunk);
}

Ok(Err(e)) => {
self.open = false;
return Err(StreamRuntimeError::from(e).into());
return Err(StreamError::LastOperationFailed(e));
}

Err(TryRecvError::Empty) => {
return Ok((Bytes::new(), StreamState::Open));
return Ok(Bytes::new());
}

Err(TryRecvError::Disconnected) => {
self.open = false;
return Ok((Bytes::new(), StreamState::Closed));
return Err(StreamError::Closed);
}
}
}
Expand Down Expand Up @@ -332,12 +331,12 @@ struct WorkerState {
}

impl WorkerState {
fn check_error(&mut self) -> Result<(), OutputStreamError> {
fn check_error(&mut self) -> Result<(), StreamError> {
if let Some(e) = self.error.take() {
return Err(OutputStreamError::LastOperationFailed(e));
return Err(StreamError::LastOperationFailed(e));
}
if !self.alive {
return Err(OutputStreamError::Closed);
return Err(StreamError::Closed);
}
Ok(())
}
Expand Down Expand Up @@ -382,7 +381,7 @@ impl Worker {
self.write_ready_changed.notified().await;
}
}
fn check_write(&self) -> Result<usize, OutputStreamError> {
fn check_write(&self) -> Result<usize, StreamError> {
let mut state = self.state();
if let Err(e) = state.check_error() {
return Err(e);
Expand Down Expand Up @@ -476,11 +475,11 @@ impl BodyWriteStream {

#[async_trait::async_trait]
impl HostOutputStream for BodyWriteStream {
fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> {
fn write(&mut self, bytes: Bytes) -> Result<(), StreamError> {
let mut state = self.worker.state();
state.check_error()?;
if state.flush_pending {
return Err(OutputStreamError::Trap(anyhow!(
return Err(StreamError::Trap(anyhow!(
"write not permitted while flush pending"
)));
}
Expand All @@ -489,13 +488,13 @@ impl HostOutputStream for BodyWriteStream {
state.write_budget = remaining_budget;
state.items.push_back(bytes);
}
None => return Err(OutputStreamError::Trap(anyhow!("write exceeded budget"))),
None => return Err(StreamError::Trap(anyhow!("write exceeded budget"))),
}
drop(state);
self.worker.new_work.notify_one();
Ok(())
}
fn flush(&mut self) -> Result<(), OutputStreamError> {
fn flush(&mut self) -> Result<(), StreamError> {
let mut state = self.worker.state();
state.check_error()?;

Expand All @@ -505,7 +504,7 @@ impl HostOutputStream for BodyWriteStream {
Ok(())
}

fn check_write(&mut self) -> Result<usize, OutputStreamError> {
fn check_write(&mut self) -> Result<usize, StreamError> {
self.worker.check_write()
}
}
Expand Down
Loading

0 comments on commit a5ccfe6

Please sign in to comment.