From 8cd21fc99c9c967b66d7ed3be2e77699084d7ecd Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Sun, 5 May 2024 14:50:26 +0200 Subject: [PATCH] feat: Add option to set a request timeout (#81) --- Cargo.toml | 1 + examples/certificate_client.rs | 6 +- examples/token_client.rs | 9 +- src/client.rs | 197 ++++++++++++++++++++++++++------- src/error.rs | 4 + src/lib.rs | 10 +- src/request/payload.rs | 6 +- 7 files changed, 179 insertions(+), 54 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 92516031..d554f656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ hyper-rustls = { version = "0.26.0", default-features = false, features = ["http rustls-pemfile = "2.1.1" rustls = "0.22.4" parking_lot = "0.12" +tokio = { version = "1", features = ["time"] } [dev-dependencies] argparse = "0.2" diff --git a/examples/certificate_client.rs b/examples/certificate_client.rs index 9f08d8bb..5d5acc49 100644 --- a/examples/certificate_client.rs +++ b/examples/certificate_client.rs @@ -43,7 +43,11 @@ async fn main() -> Result<(), Box> { }; let mut certificate = std::fs::File::open(certificate_file)?; - Ok(Client::certificate(&mut certificate, &password, endpoint)?) + + // Create config with the given endpoint and default timeouts + let client_config = a2::ClientConfig::new(endpoint); + + Ok(Client::certificate(&mut certificate, &password, client_config)?) } #[cfg(all(not(feature = "openssl"), feature = "ring"))] { diff --git a/examples/token_client.rs b/examples/token_client.rs index b60cff02..747d0590 100644 --- a/examples/token_client.rs +++ b/examples/token_client.rs @@ -1,7 +1,9 @@ use argparse::{ArgumentParser, Store, StoreOption, StoreTrue}; use std::fs::File; -use a2::{Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions}; +use a2::{ + client::ClientConfig, Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, +}; // An example client connectiong to APNs with a JWT token #[tokio::main] @@ -46,8 +48,11 @@ async fn main() -> Result<(), Box> { Endpoint::Production }; + // Create config with the given endpoint and default timeouts + let client_config = ClientConfig::new(endpoint); + // Connecting to APNs - let client = Client::token(&mut private_key, key_id, team_id, endpoint).unwrap(); + let client = Client::token(&mut private_key, key_id, team_id, client_config).unwrap(); let options = NotificationOptions { apns_topic: topic.as_deref(), diff --git a/src/client.rs b/src/client.rs index e331f952..9da3c8d4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ use crate::error::Error; use crate::error::Error::ResponseError; use crate::signer::Signer; +use tokio::time::timeout; use crate::request::payload::PayloadLike; use crate::response::Response; @@ -20,6 +21,8 @@ use std::io::Read; use std::time::Duration; use std::{fmt, io}; +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20; + type HyperConnector = HttpsConnector; /// The APNs service endpoint to connect. @@ -52,23 +55,121 @@ impl fmt::Display for Endpoint { /// holds the response for handling. #[derive(Debug, Clone)] pub struct Client { - endpoint: Endpoint, - signer: Option, + options: ConnectionOptions, http_client: HttpClient>, } -impl Client { - fn new(connector: HyperConnector, signer: Option, endpoint: Endpoint) -> Client { - let mut builder = HttpClient::builder(TokioExecutor::new()); - builder.pool_idle_timeout(Some(Duration::from_secs(600))); - builder.http2_only(true); +#[derive(Debug, Clone)] +/// The default implementation uses [`Endpoint::Production`] and can be created +/// trough calling [`ClientConfig::default`]. +pub struct ClientConfig { + /// The endpoint where the requests are sent to + pub endpoint: Endpoint, + /// The timeout of the HTTP requests + pub request_timeout_secs: Option, + /// The timeout for idle sockets being kept alive + pub pool_idle_timeout_secs: Option, +} - Client { - http_client: builder.build(connector), +impl Default for ClientConfig { + fn default() -> Self { + Self { + endpoint: Endpoint::Production, + request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), + pool_idle_timeout_secs: Some(600), + } + } +} + +impl ClientConfig { + pub fn new(endpoint: Endpoint) -> Self { + ClientConfig { + endpoint, + ..Default::default() + } + } +} + +#[derive(Debug, Clone)] +struct ClientBuilder { + config: ClientConfig, + signer: Option, + connector: Option, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self { + config: Default::default(), + signer: None, + connector: Some(default_connector()), + } + } +} + +impl ClientBuilder { + fn connector(mut self, connector: HyperConnector) -> Self { + self.connector = Some(connector); + self + } + + fn signer(mut self, signer: Signer) -> Self { + self.signer = Some(signer); + self + } + + fn config(mut self, config: ClientConfig) -> Self { + self.config = config; + self + } + + fn build(self) -> Client { + let ClientBuilder { + config: + ClientConfig { + endpoint, + request_timeout_secs, + pool_idle_timeout_secs, + }, signer, + connector, + } = self; + let http_client = HttpClient::builder(TokioExecutor::new()) + .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) + .http2_only(true) + .build(connector.unwrap_or_else(default_connector)); + + Client { + http_client, + options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), + } + } +} + +#[derive(Debug, Clone)] +struct ConnectionOptions { + endpoint: Endpoint, + request_timeout: Duration, + signer: Option, +} + +impl ConnectionOptions { + fn new(endpoint: Endpoint, signer: Option, request_timeout_secs: Option) -> Self { + let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS)); + Self { endpoint, + request_timeout, + signer, } } +} + +impl Client { + /// Creates a builder for the [`Client`] that uses the default connector and + /// [`Endpoint::Production`] + fn builder() -> ClientBuilder { + ClientBuilder::default() + } /// Create a connection to APNs using the provider client certificate which /// you obtain from your [Apple developer @@ -76,7 +177,7 @@ impl Client { /// /// Only works with the `openssl` feature. #[cfg(feature = "openssl")] - pub fn certificate(certificate: &mut R, password: &str, endpoint: Endpoint) -> Result + pub fn certificate(certificate: &mut R, password: &str, config: ClientConfig) -> Result where R: Read, { @@ -89,33 +190,32 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::builder().connector(connector).config(config).build()) } /// Create a connection to APNs using the raw PEM-formatted certificate and /// key, extracted from the provider client certificate you obtain from your /// [Apple developer account](https://developer.apple.com/account/) - pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { + pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], config: ClientConfig) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::builder().config(config).connector(connector).build()) } /// Create a connection to APNs using system certificates, signing every /// request with a signature using a private key, key id and team id /// provisioned from your [Apple developer /// account](https://developer.apple.com/account/). - pub fn token(pkcs8_pem: R, key_id: S, team_id: T, endpoint: Endpoint) -> Result + pub fn token(pkcs8_pem: R, key_id: S, team_id: T, config: ClientConfig) -> Result where S: Into, T: Into, R: Read, { - let connector = default_connector(); let signature_ttl = Duration::from_secs(60 * 55); let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; - Ok(Self::new(connector, Some(signer), endpoint)) + Ok(Self::builder().config(config).signer(signer).build()) } /// Send a notification payload. @@ -126,7 +226,11 @@ impl Client { let request = self.build_request(payload)?; let requesting = self.http_client.request(request); - let response = requesting.await?; + let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else { + return Err(Error::RequestTimeout(self.options.request_timeout.as_secs())); + }; + + let response = response_result?; let apns_id = response .headers() @@ -153,7 +257,11 @@ impl Client { } fn build_request(&self, payload: T) -> Result>, Error> { - let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token()); + let path = format!( + "https://{}/3/device/{}", + self.options.endpoint, + payload.get_device_token() + ); let mut builder = hyper::Request::builder() .uri(&path) @@ -179,7 +287,7 @@ impl Client { if let Some(apns_topic) = options.apns_topic { builder = builder.header("apns-topic", apns_topic.as_bytes()); } - if let Some(ref signer) = self.signer { + if let Some(ref signer) = self.options.signer { let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?; builder = builder.header(AUTHORIZATION, auth.as_bytes()); @@ -244,7 +352,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -255,7 +363,12 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Sandbox); + let client = Client::builder() + .config(ClientConfig { + endpoint: Endpoint::Sandbox, + ..Default::default() + }) + .build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -266,7 +379,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); @@ -276,7 +389,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_invalid() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("\r\n", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload); assert!(matches!(request, Err(Error::BuildRequestError(_)))); @@ -286,7 +399,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -296,7 +409,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -308,7 +421,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -326,7 +439,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), Some(signer), Endpoint::Production); + let client = Client::builder().signer(signer).build(); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -340,7 +453,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -351,7 +464,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); @@ -370,7 +483,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -389,7 +502,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -402,7 +515,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); @@ -421,7 +534,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); @@ -434,7 +547,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); @@ -453,7 +566,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -466,7 +579,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -485,7 +598,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -498,7 +611,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); @@ -517,7 +630,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -528,7 +641,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes(); @@ -545,8 +658,8 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let key: Vec = include_str!("../test_cert/test.key").bytes().collect(); let cert: Vec = include_str!("../test_cert/test.crt").bytes().collect(); - let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?; - assert!(c.signer.is_none()); + let c = Client::certificate_parts(&cert, &key, ClientConfig::default())?; + assert!(c.options.signer.is_none()); Ok(()) } } diff --git a/src/error.rs b/src/error.rs index b818f745..2204421b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,6 +48,10 @@ pub enum Error { #[error("Failed to construct HTTP request: {0}")] BuildRequestError(#[source] http::Error), + /// No repsonse from APNs after the given amount of time + #[error("The request timed out after {0} s")] + RequestTimeout(u64), + /// Unexpected private key (only EC keys are supported). #[cfg(all(not(feature = "openssl"), feature = "ring"))] #[error("Unexpected private key: {0}")] diff --git a/src/lib.rs b/src/lib.rs index a59beb59..513fac04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,7 @@ //! ## Example sending a plain notification using token authentication: //! //! ```no_run -//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, Endpoint}; +//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, ClientConfig, Endpoint}; //! # use std::fs::File; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -48,7 +48,7 @@ //! &mut file, //! "KEY_ID", //! "TEAM_ID", -//! Endpoint::Production).unwrap(); +//! ClientConfig::default()).unwrap(); //! //! let response = client.send(payload).await?; //! println!("Sent: {:?}", response); @@ -64,7 +64,7 @@ //! # { //! //! use a2::{ -//! Client, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions, +//! Client, ClientConfig, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions, //! Priority, //! }; //! use std::fs::File; @@ -97,7 +97,7 @@ //! let client = Client::certificate( //! &mut file, //! "Correct Horse Battery Stable", -//! Endpoint::Production)?; +//! ClientConfig::default())?; //! //! let response = client.send(payload).await?; //! println!("Sent: {:?}", response); @@ -131,6 +131,6 @@ pub use crate::request::notification::{ pub use crate::response::{ErrorBody, ErrorReason, Response}; -pub use crate::client::{Client, Endpoint}; +pub use crate::client::{Client, ClientConfig, Endpoint}; pub use crate::error::Error; diff --git a/src/request/payload.rs b/src/request/payload.rs index db509873..dcdcba16 100644 --- a/src/request/payload.rs +++ b/src/request/payload.rs @@ -27,11 +27,9 @@ pub struct Payload<'a> { /// /// # Example /// ```no_run -/// use a2::client::Endpoint; /// use a2::request::notification::{NotificationBuilder, NotificationOptions}; /// use a2::request::payload::{PayloadLike, APS}; -/// use a2::Client; -/// use a2::DefaultNotificationBuilder; +/// use a2::{Client, ClientConfig, DefaultNotificationBuilder, Endpoint}; /// use serde::Serialize; /// use std::fs::File; /// @@ -45,7 +43,7 @@ pub struct Payload<'a> { /// let payload = builder.build("device-token-from-the-user", Default::default()); /// let mut file = File::open("/path/to/private_key.p8")?; /// -/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", Endpoint::Production).unwrap(); +/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", ClientConfig::default()).unwrap(); /// /// let response = client.send(payload).await?; /// println!("Sent: {:?}", response);