Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Provide timeouts for reading from sockets #93

Merged
merged 3 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions src/ctx.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::time::Duration;

use crate::{
splat::SdkHeaders,
util::{ProgressTarget, Sha256},
Expand All @@ -23,33 +25,31 @@ pub struct Ctx {
}

impl Ctx {
fn http_client() -> Result<ureq::Agent, Error> {
fn http_client(read_timeout: Option<Duration>) -> Result<ureq::Agent, Error> {
let mut builder = ureq::builder();

#[cfg(feature = "native-tls")]
{
use std::sync::Arc;
let mut builder =
ureq::builder().tls_connector(Arc::new(native_tls_crate::TlsConnector::new()?));
if let Ok(proxy) = std::env::var("https_proxy") {
let proxy = ureq::Proxy::new(proxy)?;
builder = builder.proxy(proxy);
};
Ok(builder.build())
builder = builder.tls_connector(Arc::new(native_tls_crate::TlsConnector::new()?));
}

#[cfg(not(feature = "native-tls"))]
{
let mut builder = ureq::builder();
if let Ok(proxy) = std::env::var("https_proxy") {
let proxy = ureq::Proxy::new(proxy)?;
builder = builder.proxy(proxy);
};
Ok(builder.build())
// Allow user to specify timeout values in the case of bad/slow proxies
// or MS itself being terrible, but default to a minute, which is _far_
// more than it should take in normal situations, as by default ureq
// sets no timeout on the response
builder = builder.timeout_read(read_timeout.unwrap_or(Duration::from_secs(60)));

if let Ok(proxy) = std::env::var("https_proxy") {
let proxy = ureq::Proxy::new(proxy)?;
builder = builder.proxy(proxy);
}
Ok(builder.build())
}

pub fn with_temp(dt: ProgressTarget) -> Result<Self, Error> {
pub fn with_temp(dt: ProgressTarget, read_timeout: Option<Duration>) -> Result<Self, Error> {
let td = tempfile::TempDir::new()?;
let client = Self::http_client()?;
let client = Self::http_client(read_timeout)?;

Ok(Self {
work_dir: PathBuf::from_path_buf(td.path().to_owned()).map_err(|pb| {
Expand All @@ -61,8 +61,12 @@ impl Ctx {
})
}

pub fn with_dir(mut work_dir: PathBuf, dt: ProgressTarget) -> Result<Self, Error> {
let client = Self::http_client()?;
pub fn with_dir(
mut work_dir: PathBuf,
dt: ProgressTarget,
read_timeout: Option<Duration>,
) -> Result<Self, Error> {
let client = Self::http_client(read_timeout)?;

work_dir.push("dl");
std::fs::create_dir_all(&work_dir)?;
Expand Down
33 changes: 30 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use camino::Utf8PathBuf as PathBuf;
use clap::builder::{PossibleValuesParser, TypedValueParser as _};
use clap::{Parser, Subcommand};
use indicatif as ia;
use std::time::Duration;
use tracing_subscriber::filter::LevelFilter;

fn setup_logger(json: bool, log_level: LevelFilter) -> Result<(), Error> {
Expand Down Expand Up @@ -89,7 +90,30 @@ const LOG_LEVELS: &[&str] = &["off", "error", "warn", "info", "debug", "trace"];

fn parse_level(s: &str) -> Result<LevelFilter, Error> {
s.parse::<LevelFilter>()
.map_err(|_| anyhow::anyhow!("failed to parse level '{}'", s))
.map_err(|_| anyhow::anyhow!("failed to parse level '{s}'"))
}

#[allow(clippy::indexing_slicing)]
fn parse_duration(src: &str) -> anyhow::Result<Duration> {
let suffix_pos = src.find(char::is_alphabetic).unwrap_or(src.len());

let num: u64 = src[..suffix_pos].parse()?;
let suffix = if suffix_pos == src.len() {
"s"
} else {
&src[suffix_pos..]
};

let duration = match suffix {
"ms" => Duration::from_millis(num),
"s" | "S" => Duration::from_secs(num),
"m" | "M" => Duration::from_secs(num * 60),
"h" | "H" => Duration::from_secs(num * 60 * 60),
"d" | "D" => Duration::from_secs(num * 60 * 60 * 24),
s => anyhow::bail!("unknown duration suffix '{s}'"),
};

Ok(duration)
}

#[derive(Parser)]
Expand Down Expand Up @@ -131,6 +155,9 @@ pub struct Args {
/// Whether to include the Active Template Library (ATL) in the installation
#[arg(long)]
include_atl: bool,
/// Specifies a timeout for how long a single download is allowed to take. The default is 60s.
#[arg(short, long, value_parser = parse_duration)]
timeout: Option<Duration>,
/// The architectures to include
#[arg(
long,
Expand Down Expand Up @@ -176,13 +203,13 @@ fn main() -> Result<(), Error> {
let draw_target = xwin::util::ProgressTarget::Stdout;

let ctx = if args.temp {
xwin::Ctx::with_temp(draw_target)?
xwin::Ctx::with_temp(draw_target, args.timeout)?
} else {
let cache_dir = match &args.cache_dir {
Some(cd) => cd.clone(),
None => cwd.join(".xwin-cache"),
};
xwin::Ctx::with_dir(cache_dir, draw_target)?
xwin::Ctx::with_dir(cache_dir, draw_target, args.timeout)?
};

let ctx = std::sync::Arc::new(ctx);
Expand Down
1 change: 1 addition & 0 deletions tests/compiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ fn verify_compiles() {
let ctx = xwin::Ctx::with_dir(
xwin::PathBuf::from(".xwin-cache/compile-test"),
xwin::util::ProgressTarget::Hidden,
None,
)
.unwrap();

Expand Down
1 change: 1 addition & 0 deletions tests/deterministic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ fn verify_deterministic() {
let ctx = xwin::Ctx::with_dir(
PathBuf::from(".xwin-cache/deterministic"),
xwin::util::ProgressTarget::Hidden,
None,
)
.unwrap();

Expand Down
4 changes: 4 additions & 0 deletions tests/snapshots/xwin.snap
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ Options:
Whether to include the Active Template Library (ATL) in the
installation

-t, --timeout <TIMEOUT>
Specifies a timeout for how long a single download is allowed to take.
The default is 60s

--arch <ARCH>
The architectures to include

Expand Down