diff --git a/Cargo.lock b/Cargo.lock index c2ffecfaf..647654196 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2042,6 +2042,7 @@ dependencies = [ "memchr", "mio", "num_cpus", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", diff --git a/crates/binstalk/Cargo.toml b/crates/binstalk/Cargo.toml index f1054b420..24d097973 100644 --- a/crates/binstalk/Cargo.toml +++ b/crates/binstalk/Cargo.toml @@ -46,7 +46,8 @@ tar = { package = "binstall-tar", version = "0.4.39" } tempfile = "3.3.0" thiserror = "1.0.37" tinytemplate = "1.2.1" -tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time"], default-features = false } +# parking_lot - for OnceCell::const_new +tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time", "parking_lot"], default-features = false } toml_edit = { version = "0.14.4", features = ["easy"] } tower = { version = "0.4.13", features = ["limit", "util"] } trust-dns-resolver = { version = "0.21.2", optional = true, default-features = false, features = ["dnssec-ring"] } diff --git a/crates/binstalk/src/errors.rs b/crates/binstalk/src/errors.rs index 19b8554b7..0fc369702 100644 --- a/crates/binstalk/src/errors.rs +++ b/crates/binstalk/src/errors.rs @@ -1,4 +1,5 @@ use std::{ + io, path::PathBuf, process::{ExitCode, ExitStatus, Termination}, }; @@ -99,7 +100,7 @@ pub enum BinstallError { /// - Exit: 74 #[error(transparent)] #[diagnostic(severity(error), code(binstall::io))] - Io(std::io::Error), + Io(io::Error), /// An error interacting with the crates.io API. /// @@ -392,8 +393,8 @@ impl Termination for BinstallError { } } -impl From for BinstallError { - fn from(err: std::io::Error) -> Self { +impl From for BinstallError { + fn from(err: io::Error) -> Self { if err.get_ref().is_some() { let kind = err.kind(); @@ -404,9 +405,18 @@ impl From for BinstallError { inner .downcast() .map(|b| *b) - .unwrap_or_else(|err| BinstallError::Io(std::io::Error::new(kind, err))) + .unwrap_or_else(|err| BinstallError::Io(io::Error::new(kind, err))) } else { BinstallError::Io(err) } } } + +impl From for io::Error { + fn from(e: BinstallError) -> io::Error { + match e { + BinstallError::Io(io_error) => io_error, + e => io::Error::new(io::ErrorKind::Other, e), + } + } +} diff --git a/crates/binstalk/src/helpers/download/async_extracter.rs b/crates/binstalk/src/helpers/download/async_extracter.rs index 840cbf210..21311b3da 100644 --- a/crates/binstalk/src/helpers/download/async_extracter.rs +++ b/crates/binstalk/src/helpers/download/async_extracter.rs @@ -1,7 +1,7 @@ use std::{ fmt::Debug, fs, - io::{copy, Read, Seek}, + io::{Read, Seek}, path::Path, }; @@ -33,7 +33,7 @@ where fs::remove_file(path).ok(); }); - copy(&mut reader, &mut file)?; + reader.copy(&mut file)?; // Operation isn't aborted and all writes succeed, // disarm the remove_guard. @@ -54,7 +54,7 @@ where let mut file = tempfile()?; - copy(&mut reader, &mut file)?; + reader.copy(&mut file)?; // rewind it so that we can pass it to unzip file.rewind()?; diff --git a/crates/binstalk/src/helpers/download/stream_readable.rs b/crates/binstalk/src/helpers/download/stream_readable.rs index bc450fb5d..6685c6bf6 100644 --- a/crates/binstalk/src/helpers/download/stream_readable.rs +++ b/crates/binstalk/src/helpers/download/stream_readable.rs @@ -1,24 +1,26 @@ use std::{ cmp::min, - io::{self, BufRead, Read}, + future::Future, + io::{self, BufRead, Read, Write}, + pin::Pin, }; use bytes::{Buf, Bytes}; use futures_util::stream::{Stream, StreamExt}; use tokio::runtime::Handle; -use crate::errors::BinstallError; +use crate::{errors::BinstallError, helpers::signal::wait_on_cancellation_signal}; /// This wraps an AsyncIterator as a `Read`able. /// It must be used in non-async context only, /// meaning you have to use it with /// `tokio::task::{block_in_place, spawn_blocking}` or /// `std::thread::spawn`. -#[derive(Debug)] pub struct StreamReadable { stream: S, handle: Handle, bytes: Bytes, + cancellation_future: Pin> + Send>>, } impl StreamReadable { @@ -27,6 +29,39 @@ impl StreamReadable { stream, handle: Handle::current(), bytes: Bytes::new(), + cancellation_future: Box::pin(wait_on_cancellation_signal()), + } + } +} + +impl StreamReadable +where + S: Stream> + Unpin, + BinstallError: From, +{ + /// Copies from `self` to `writer`. + /// + /// Same as `io::copy` but does not allocate any internal buffer + /// since `self` is buffered. + pub(super) fn copy(&mut self, mut writer: W) -> io::Result<()> + where + W: Write, + { + self.copy_inner(&mut writer) + } + + fn copy_inner(&mut self, writer: &mut dyn Write) -> io::Result<()> { + loop { + let buf = self.fill_buf()?; + if buf.is_empty() { + // Eof + break Ok(()); + } + + writer.write_all(buf)?; + + let n = buf.len(); + self.consume(n); } } } @@ -56,6 +91,27 @@ where Ok(n) } } + +/// If `Ok(Some(bytes))` if returned, then `bytes.is_empty() == false`. +async fn next_stream(stream: &mut S) -> io::Result> +where + S: Stream> + Unpin, + BinstallError: From, +{ + loop { + let option = stream + .next() + .await + .transpose() + .map_err(BinstallError::from)?; + + match option { + Some(bytes) if bytes.is_empty() => continue, + option => break Ok(option), + } + } +} + impl BufRead for StreamReadable where S: Stream> + Unpin, @@ -65,13 +121,18 @@ where let bytes = &mut self.bytes; if !bytes.has_remaining() { - match self.handle.block_on(async { self.stream.next().await }) { - Some(Ok(new_bytes)) => *bytes = new_bytes, - Some(Err(e)) => { - let e: BinstallError = e.into(); - return Err(io::Error::new(io::ErrorKind::Other, e)); + let option = self.handle.block_on(async { + tokio::select! { + res = next_stream(&mut self.stream) => res, + res = self.cancellation_future.as_mut() => { + Err(res.err().unwrap_or_else(|| io::Error::from(BinstallError::UserAbort))) + }, } - None => (), + })?; + + if let Some(new_bytes) = option { + // new_bytes are guaranteed to be non-empty. + *bytes = new_bytes; } } Ok(&*bytes) diff --git a/crates/binstalk/src/helpers/signal.rs b/crates/binstalk/src/helpers/signal.rs index e15ed8e18..d01041df8 100644 --- a/crates/binstalk/src/helpers/signal.rs +++ b/crates/binstalk/src/helpers/signal.rs @@ -1,7 +1,7 @@ use std::io; use futures_util::future::pending; -use tokio::signal; +use tokio::{signal, sync::OnceCell}; use super::tasks::AutoAbortJoinHandle; use crate::errors::BinstallError; @@ -24,12 +24,25 @@ pub async fn cancel_on_user_sig_term( tokio::select! { res = handle => res, res = wait_on_cancellation_signal() => { - res.map_err(BinstallError::Io).and(Err(BinstallError::UserAbort)) + res + .map_err(BinstallError::Io) + .and(Err(BinstallError::UserAbort)) } } } -async fn wait_on_cancellation_signal() -> Result<(), io::Error> { +/// If call to it returns `Ok(())`, then all calls to this function after +/// that also returns `Ok(())`. +pub async fn wait_on_cancellation_signal() -> Result<(), io::Error> { + static CANCELLED: OnceCell<()> = OnceCell::const_new(); + + CANCELLED + .get_or_try_init(wait_on_cancellation_signal_inner) + .await + .copied() +} + +async fn wait_on_cancellation_signal_inner() -> Result<(), io::Error> { #[cfg(unix)] async fn inner() -> Result<(), io::Error> { unix::wait_on_cancellation_signal_unix().await diff --git a/crates/binstalk/src/ops/resolve.rs b/crates/binstalk/src/ops/resolve.rs index 59fa9ad82..3335a130f 100644 --- a/crates/binstalk/src/ops/resolve.rs +++ b/crates/binstalk/src/ops/resolve.rs @@ -260,6 +260,9 @@ async fn resolve_inner( } } Err(err) => { + if let BinstallError::UserAbort = err { + return Err(err); + } warn!( "Error while downloading and extracting from fetcher {}: {}", fetcher.source_name(),