From c169023933c5889e971f5734e5f4970c223744d6 Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Wed, 24 Apr 2024 14:06:00 +0200 Subject: [PATCH 1/4] feat: Add option to set a request timeout --- Cargo.toml | 1 + src/client.rs | 168 +++++++++++++++++++++++++++++++++++++++----------- src/error.rs | 4 ++ 3 files changed, 136 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9251603..d554f65 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ hyper-rustls = { version = "0.26.0", default-features = false, features = ["http rustls-pemfile = "2.1.1" rustls = "0.22.4" parking_lot = "0.12" +tokio = { version = "1", features = ["time"] } [dev-dependencies] argparse = "0.2" diff --git a/src/client.rs b/src/client.rs index e331f95..9b481ff 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ use crate::error::Error; use crate::error::Error::ResponseError; use crate::signer::Signer; +use tokio::time::timeout; use crate::request::payload::PayloadLike; use crate::response::Response; @@ -20,6 +21,8 @@ use std::io::Read; use std::time::Duration; use std::{fmt, io}; +const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 20; + type HyperConnector = HttpsConnector; /// The APNs service endpoint to connect. @@ -52,23 +55,96 @@ impl fmt::Display for Endpoint { /// holds the response for handling. #[derive(Debug, Clone)] pub struct Client { - endpoint: Endpoint, - signer: Option, + options: ConnectionOptions, http_client: HttpClient>, } -impl Client { - fn new(connector: HyperConnector, signer: Option, endpoint: Endpoint) -> Client { - let mut builder = HttpClient::builder(TokioExecutor::new()); - builder.pool_idle_timeout(Some(Duration::from_secs(600))); - builder.http2_only(true); +/// Uses [`Endpoint::Production`] by default. +#[derive(Debug, Clone)] +pub struct ClientOptions { + /// The timeout of the HTTP requests + pub request_timeout_secs: Option, + /// The timeout for idle sockets being kept alive + pub pool_idle_timeout_secs: Option, + /// The endpoint where the requests are sent to + pub endpoint: Endpoint, + /// See [`crate::signer::Signer`] + pub signer: Option, +} + +impl Default for ClientOptions { + fn default() -> Self { + Self { + pool_idle_timeout_secs: Some(600), + request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), + endpoint: Endpoint::Production, + signer: None, + } + } +} + +impl ClientOptions { + pub fn new(endpoint: Endpoint) -> Self { + Self { + endpoint, + ..Default::default() + } + } - Client { - http_client: builder.build(connector), + pub fn with_signer(mut self, signer: Signer) -> Self { + self.signer = Some(signer); + self + } + + pub fn with_request_timeout(mut self, seconds: u64) -> Self { + self.request_timeout_secs = Some(seconds); + self + } + + pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self { + self.pool_idle_timeout_secs = Some(seconds); + self + } +} + +#[derive(Debug, Clone)] +struct ConnectionOptions { + endpoint: Endpoint, + request_timeout: Duration, + signer: Option, +} + +impl From for ConnectionOptions { + fn from(value: ClientOptions) -> Self { + let ClientOptions { + endpoint, + pool_idle_timeout_secs: _, signer, + request_timeout_secs, + } = value; + let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS)); + Self { endpoint, + request_timeout, + signer, } } +} + +impl Client { + /// If `options` is not set, a default using [`Endpoint::Production`] will + /// be initialized. + fn new(connector: HyperConnector, options: Option) -> Client { + let options = options.unwrap_or_default(); + let http_client = HttpClient::builder(TokioExecutor::new()) + .pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs)) + .http2_only(true) + .build(connector); + + let options = options.into(); + + Client { http_client, options } + } /// Create a connection to APNs using the provider client certificate which /// you obtain from your [Apple developer @@ -89,7 +165,7 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -98,7 +174,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(connector, None, endpoint)) + Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) } /// Create a connection to APNs using system certificates, signing every @@ -113,9 +189,16 @@ impl Client { { let connector = default_connector(); let signature_ttl = Duration::from_secs(60 * 55); - let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; + let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); - Ok(Self::new(connector, Some(signer), endpoint)) + Ok(Self::new( + connector, + Some(ClientOptions { + endpoint, + signer, + ..Default::default() + }), + )) } /// Send a notification payload. @@ -126,7 +209,11 @@ impl Client { let request = self.build_request(payload)?; let requesting = self.http_client.request(request); - let response = requesting.await?; + let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else { + return Err(Error::RequestTimeout(self.options.request_timeout.as_secs())); + }; + + let response = response_result?; let apns_id = response .headers() @@ -153,7 +240,11 @@ impl Client { } fn build_request(&self, payload: T) -> Result>, Error> { - let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token()); + let path = format!( + "https://{}/3/device/{}", + self.options.endpoint, + payload.get_device_token() + ); let mut builder = hyper::Request::builder() .uri(&path) @@ -179,7 +270,7 @@ impl Client { if let Some(apns_topic) = options.apns_topic { builder = builder.header("apns-topic", apns_topic.as_bytes()); } - if let Some(ref signer) = self.signer { + if let Some(ref signer) = self.options.signer { let auth = signer.with_signature(|signature| format!("Bearer {}", signature))?; builder = builder.header(AUTHORIZATION, auth.as_bytes()); @@ -244,7 +335,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -255,7 +346,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Sandbox); + let client = Client::new(default_connector(), Some(ClientOptions::new(Endpoint::Sandbox))); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -266,7 +357,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); @@ -286,7 +377,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -296,7 +387,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -308,7 +399,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -326,7 +417,10 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), Some(signer), Endpoint::Production); + let client = Client::new( + default_connector(), + Some(ClientOptions::new(Endpoint::Production).with_signer(signer)), + ); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -340,7 +434,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -351,7 +445,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); @@ -370,7 +464,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -389,7 +483,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -402,7 +496,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); @@ -421,7 +515,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); @@ -434,7 +528,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); @@ -453,7 +547,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -466,7 +560,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -485,7 +579,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -498,7 +592,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); @@ -517,7 +611,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -528,7 +622,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::new(default_connector(), None); let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes(); @@ -546,7 +640,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let cert: Vec = include_str!("../test_cert/test.crt").bytes().collect(); let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?; - assert!(c.signer.is_none()); + assert!(c.options.signer.is_none()); Ok(()) } } diff --git a/src/error.rs b/src/error.rs index b818f74..2204421 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,6 +48,10 @@ pub enum Error { #[error("Failed to construct HTTP request: {0}")] BuildRequestError(#[source] http::Error), + /// No repsonse from APNs after the given amount of time + #[error("The request timed out after {0} s")] + RequestTimeout(u64), + /// Unexpected private key (only EC keys are supported). #[cfg(all(not(feature = "openssl"), feature = "ring"))] #[error("Unexpected private key: {0}")] From 8f47453013310fef5a487c79d708af211b4707ac Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Fri, 26 Apr 2024 11:11:11 +0200 Subject: [PATCH 2/4] fixup! feat: Add option to set a request timeout --- src/client.rs | 109 +++++++++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 51 deletions(-) diff --git a/src/client.rs b/src/client.rs index 9b481ff..e407cd6 100644 --- a/src/client.rs +++ b/src/client.rs @@ -70,6 +70,8 @@ pub struct ClientOptions { pub endpoint: Endpoint, /// See [`crate::signer::Signer`] pub signer: Option, + /// The HTTPS connector used to connect to APNs + pub connector: Option, } impl Default for ClientOptions { @@ -79,6 +81,7 @@ impl Default for ClientOptions { request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), endpoint: Endpoint::Production, signer: None, + connector: Some(default_connector()), } } } @@ -91,6 +94,11 @@ impl ClientOptions { } } + pub fn with_connector(mut self, connector: HyperConnector) -> Self { + self.connector = Some(connector); + self + } + pub fn with_signer(mut self, signer: Signer) -> Self { self.signer = Some(signer); self @@ -114,14 +122,8 @@ struct ConnectionOptions { signer: Option, } -impl From for ConnectionOptions { - fn from(value: ClientOptions) -> Self { - let ClientOptions { - endpoint, - pool_idle_timeout_secs: _, - signer, - request_timeout_secs, - } = value; +impl ConnectionOptions { + fn new(endpoint: Endpoint, signer: Option, request_timeout_secs: Option) -> Self { let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(DEFAULT_REQUEST_TIMEOUT_SECS)); Self { endpoint, @@ -132,18 +134,23 @@ impl From for ConnectionOptions { } impl Client { - /// If `options` is not set, a default using [`Endpoint::Production`] will - /// be initialized. - fn new(connector: HyperConnector, options: Option) -> Client { - let options = options.unwrap_or_default(); + fn new(options: ClientOptions) -> Self { + let ClientOptions { + request_timeout_secs, + pool_idle_timeout_secs, + endpoint, + signer, + connector, + } = options; let http_client = HttpClient::builder(TokioExecutor::new()) - .pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs)) + .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) .http2_only(true) - .build(connector); + .build(connector.unwrap_or_else(default_connector)); - let options = options.into(); - - Client { http_client, options } + Client { + http_client, + options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), + } } /// Create a connection to APNs using the provider client certificate which @@ -165,7 +172,7 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) + Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -174,7 +181,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(connector, Some(ClientOptions::new(endpoint)))) + Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) } /// Create a connection to APNs using system certificates, signing every @@ -187,18 +194,14 @@ impl Client { T: Into, R: Read, { - let connector = default_connector(); let signature_ttl = Duration::from_secs(60 * 55); let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); - Ok(Self::new( - connector, - Some(ClientOptions { - endpoint, - signer, - ..Default::default() - }), - )) + Ok(Self::new(ClientOptions { + endpoint, + signer, + ..Default::default() + })) } /// Send a notification payload. @@ -331,11 +334,18 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3 jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ -----END PRIVATE KEY-----"; + impl Client { + fn new_with_defaults() -> Self { + let options = ClientOptions::default(); + Self::new(options) + } + } + #[test] fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -346,7 +356,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), Some(ClientOptions::new(Endpoint::Sandbox))); + let client = Client::new(ClientOptions::new(Endpoint::Sandbox)); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -357,7 +367,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); @@ -377,7 +387,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -387,7 +397,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -399,7 +409,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -417,10 +427,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new( - default_connector(), - Some(ClientOptions::new(Endpoint::Production).with_signer(signer)), - ); + let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer)); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -434,7 +441,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -445,7 +452,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); @@ -464,7 +471,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -483,7 +490,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -496,7 +503,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); @@ -515,7 +522,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); @@ -528,7 +535,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); @@ -547,7 +554,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -560,7 +567,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -579,7 +586,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -592,7 +599,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); @@ -611,7 +618,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -622,7 +629,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(default_connector(), None); + let client = Client::new_with_defaults(); let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes(); From 2a1fe70b3c12e5b025a209eba27063b40eb182d4 Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Tue, 30 Apr 2024 08:45:28 +0200 Subject: [PATCH 3/4] fixup! feat: Add option to set a request timeout --- src/client.rs | 127 ++++++++++++++++++++++++-------------------------- 1 file changed, 60 insertions(+), 67 deletions(-) diff --git a/src/client.rs b/src/client.rs index e407cd6..f556468 100644 --- a/src/client.rs +++ b/src/client.rs @@ -61,7 +61,7 @@ pub struct Client { /// Uses [`Endpoint::Production`] by default. #[derive(Debug, Clone)] -pub struct ClientOptions { +pub struct ClientBuilder { /// The timeout of the HTTP requests pub request_timeout_secs: Option, /// The timeout for idle sockets being kept alive @@ -74,7 +74,7 @@ pub struct ClientOptions { pub connector: Option, } -impl Default for ClientOptions { +impl Default for ClientBuilder { fn default() -> Self { Self { pool_idle_timeout_secs: Some(600), @@ -86,33 +86,50 @@ impl Default for ClientOptions { } } -impl ClientOptions { - pub fn new(endpoint: Endpoint) -> Self { - Self { - endpoint, - ..Default::default() - } - } - - pub fn with_connector(mut self, connector: HyperConnector) -> Self { +impl ClientBuilder { + pub fn connector(mut self, connector: HyperConnector) -> Self { self.connector = Some(connector); self } - pub fn with_signer(mut self, signer: Signer) -> Self { + pub fn signer(mut self, signer: Signer) -> Self { self.signer = Some(signer); self } - pub fn with_request_timeout(mut self, seconds: u64) -> Self { + pub fn request_timeout(mut self, seconds: u64) -> Self { self.request_timeout_secs = Some(seconds); self } - pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self { + pub fn pool_idle_timeout(mut self, seconds: u64) -> Self { self.pool_idle_timeout_secs = Some(seconds); self } + + pub fn endpoint(mut self, endpoint: Endpoint) -> Self { + self.endpoint = endpoint; + self + } + + pub fn build(self) -> Client { + let ClientBuilder { + request_timeout_secs, + pool_idle_timeout_secs, + endpoint, + signer, + connector, + } = self; + let http_client = HttpClient::builder(TokioExecutor::new()) + .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) + .http2_only(true) + .build(connector.unwrap_or_else(default_connector)); + + Client { + http_client, + options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), + } + } } #[derive(Debug, Clone)] @@ -134,23 +151,10 @@ impl ConnectionOptions { } impl Client { - fn new(options: ClientOptions) -> Self { - let ClientOptions { - request_timeout_secs, - pool_idle_timeout_secs, - endpoint, - signer, - connector, - } = options; - let http_client = HttpClient::builder(TokioExecutor::new()) - .pool_idle_timeout(pool_idle_timeout_secs.map(Duration::from_secs)) - .http2_only(true) - .build(connector.unwrap_or_else(default_connector)); - - Client { - http_client, - options: ConnectionOptions::new(endpoint, signer, request_timeout_secs), - } + /// Creates a builder for the [`Client`] that uses the default connector and + /// [`Endpoint::Production`] + fn builder() -> ClientBuilder { + ClientBuilder::default() } /// Create a connection to APNs using the provider client certificate which @@ -172,7 +176,7 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) + Ok(Self::builder().connector(connector).endpoint(endpoint).build()) } /// Create a connection to APNs using the raw PEM-formatted certificate and @@ -181,7 +185,7 @@ impl Client { pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::new(ClientOptions::new(endpoint).with_connector(connector))) + Ok(Self::builder().endpoint(endpoint).connector(connector).build()) } /// Create a connection to APNs using system certificates, signing every @@ -195,13 +199,9 @@ impl Client { R: Read, { let signature_ttl = Duration::from_secs(60 * 55); - let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?); + let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; - Ok(Self::new(ClientOptions { - endpoint, - signer, - ..Default::default() - })) + Ok(Self::builder().endpoint(endpoint).signer(signer).build()) } /// Send a notification payload. @@ -334,18 +334,11 @@ lCEIvbDqlUhA5FOzcakkG90E8L+hRANCAATKS2ZExEybUvchRDuKBftotMwVEus3 jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ -----END PRIVATE KEY-----"; - impl Client { - fn new_with_defaults() -> Self { - let options = ClientOptions::default(); - Self::new(options) - } - } - #[test] fn test_production_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -356,7 +349,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(ClientOptions::new(Endpoint::Sandbox)); + let client = Client::builder().endpoint(Endpoint::Sandbox).build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -367,7 +360,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_method() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(&Method::POST, request.method()); @@ -377,7 +370,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_invalid() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("\r\n", Default::default()); - let client = Client::new(default_connector(), None, Endpoint::Production); + let client = Client::builder().endpoint(Endpoint::Production).build(); let request = client.build_request(payload); assert!(matches!(request, Err(Error::BuildRequestError(_)))); @@ -387,7 +380,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_type() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap()); @@ -397,7 +390,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_content_length() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let payload_json = payload.to_json_string().unwrap(); let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap(); @@ -409,7 +402,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_authorization_with_no_signer() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); assert_eq!(None, request.headers().get(AUTHORIZATION)); @@ -427,7 +420,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new(ClientOptions::new(Endpoint::Production).with_signer(signer)); + let client = Client::builder().endpoint(Endpoint::Production).signer(signer).build(); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -441,7 +434,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ ..Default::default() }; let payload = builder.build("a_test_id", options); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_push_type = request.headers().get("apns-push-type").unwrap(); @@ -452,7 +445,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_with_default_priority() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority"); @@ -471,7 +464,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -490,7 +483,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_priority = request.headers().get("apns-priority").unwrap(); @@ -503,7 +496,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id"); @@ -522,7 +515,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_id = request.headers().get("apns-id").unwrap(); @@ -535,7 +528,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration"); @@ -554,7 +547,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_expiration = request.headers().get("apns-expiration").unwrap(); @@ -567,7 +560,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id"); @@ -586,7 +579,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap(); @@ -599,7 +592,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic"); @@ -618,7 +611,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ }, ); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload).unwrap(); let apns_topic = request.headers().get("apns-topic").unwrap(); @@ -629,7 +622,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ async fn test_request_body() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::new_with_defaults(); + let client = Client::builder().build(); let request = client.build_request(payload.clone()).unwrap(); let body = request.into_body().collect().await.unwrap().to_bytes(); From b938e2dce24a88b66f3380fab53df2a0aace6452 Mon Sep 17 00:00:00 2001 From: threema-donat <129288638+threema-donat@users.noreply.github.com> Date: Tue, 30 Apr 2024 10:51:42 +0200 Subject: [PATCH 4/4] fixup! feat: Add option to set a request timeout --- examples/certificate_client.rs | 6 ++- examples/token_client.rs | 9 +++- src/client.rs | 97 ++++++++++++++++++++-------------- src/lib.rs | 10 ++-- src/request/payload.rs | 6 +-- 5 files changed, 77 insertions(+), 51 deletions(-) diff --git a/examples/certificate_client.rs b/examples/certificate_client.rs index 9f08d8b..5d5acc4 100644 --- a/examples/certificate_client.rs +++ b/examples/certificate_client.rs @@ -43,7 +43,11 @@ async fn main() -> Result<(), Box> { }; let mut certificate = std::fs::File::open(certificate_file)?; - Ok(Client::certificate(&mut certificate, &password, endpoint)?) + + // Create config with the given endpoint and default timeouts + let client_config = a2::ClientConfig::new(endpoint); + + Ok(Client::certificate(&mut certificate, &password, client_config)?) } #[cfg(all(not(feature = "openssl"), feature = "ring"))] { diff --git a/examples/token_client.rs b/examples/token_client.rs index b60cff0..747d059 100644 --- a/examples/token_client.rs +++ b/examples/token_client.rs @@ -1,7 +1,9 @@ use argparse::{ArgumentParser, Store, StoreOption, StoreTrue}; use std::fs::File; -use a2::{Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions}; +use a2::{ + client::ClientConfig, Client, DefaultNotificationBuilder, Endpoint, NotificationBuilder, NotificationOptions, +}; // An example client connectiong to APNs with a JWT token #[tokio::main] @@ -46,8 +48,11 @@ async fn main() -> Result<(), Box> { Endpoint::Production }; + // Create config with the given endpoint and default timeouts + let client_config = ClientConfig::new(endpoint); + // Connecting to APNs - let client = Client::token(&mut private_key, key_id, team_id, endpoint).unwrap(); + let client = Client::token(&mut private_key, key_id, team_id, client_config).unwrap(); let options = NotificationOptions { apns_topic: topic.as_deref(), diff --git a/src/client.rs b/src/client.rs index f556468..9da3c8d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -59,27 +59,48 @@ pub struct Client { http_client: HttpClient>, } -/// Uses [`Endpoint::Production`] by default. #[derive(Debug, Clone)] -pub struct ClientBuilder { +/// The default implementation uses [`Endpoint::Production`] and can be created +/// trough calling [`ClientConfig::default`]. +pub struct ClientConfig { + /// The endpoint where the requests are sent to + pub endpoint: Endpoint, /// The timeout of the HTTP requests pub request_timeout_secs: Option, /// The timeout for idle sockets being kept alive pub pool_idle_timeout_secs: Option, - /// The endpoint where the requests are sent to - pub endpoint: Endpoint, - /// See [`crate::signer::Signer`] - pub signer: Option, - /// The HTTPS connector used to connect to APNs - pub connector: Option, } -impl Default for ClientBuilder { +impl Default for ClientConfig { fn default() -> Self { Self { - pool_idle_timeout_secs: Some(600), - request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), endpoint: Endpoint::Production, + request_timeout_secs: Some(DEFAULT_REQUEST_TIMEOUT_SECS), + pool_idle_timeout_secs: Some(600), + } + } +} + +impl ClientConfig { + pub fn new(endpoint: Endpoint) -> Self { + ClientConfig { + endpoint, + ..Default::default() + } + } +} + +#[derive(Debug, Clone)] +struct ClientBuilder { + config: ClientConfig, + signer: Option, + connector: Option, +} + +impl Default for ClientBuilder { + fn default() -> Self { + Self { + config: Default::default(), signer: None, connector: Some(default_connector()), } @@ -87,36 +108,29 @@ impl Default for ClientBuilder { } impl ClientBuilder { - pub fn connector(mut self, connector: HyperConnector) -> Self { + fn connector(mut self, connector: HyperConnector) -> Self { self.connector = Some(connector); self } - pub fn signer(mut self, signer: Signer) -> Self { + fn signer(mut self, signer: Signer) -> Self { self.signer = Some(signer); self } - pub fn request_timeout(mut self, seconds: u64) -> Self { - self.request_timeout_secs = Some(seconds); + fn config(mut self, config: ClientConfig) -> Self { + self.config = config; self } - pub fn pool_idle_timeout(mut self, seconds: u64) -> Self { - self.pool_idle_timeout_secs = Some(seconds); - self - } - - pub fn endpoint(mut self, endpoint: Endpoint) -> Self { - self.endpoint = endpoint; - self - } - - pub fn build(self) -> Client { + fn build(self) -> Client { let ClientBuilder { - request_timeout_secs, - pool_idle_timeout_secs, - endpoint, + config: + ClientConfig { + endpoint, + request_timeout_secs, + pool_idle_timeout_secs, + }, signer, connector, } = self; @@ -163,7 +177,7 @@ impl Client { /// /// Only works with the `openssl` feature. #[cfg(feature = "openssl")] - pub fn certificate(certificate: &mut R, password: &str, endpoint: Endpoint) -> Result + pub fn certificate(certificate: &mut R, password: &str, config: ClientConfig) -> Result where R: Read, { @@ -176,23 +190,23 @@ impl Client { }; let connector = client_cert_connector(&cert.to_pem()?, &pkey.private_key_to_pem_pkcs8()?)?; - Ok(Self::builder().connector(connector).endpoint(endpoint).build()) + Ok(Self::builder().connector(connector).config(config).build()) } /// Create a connection to APNs using the raw PEM-formatted certificate and /// key, extracted from the provider client certificate you obtain from your /// [Apple developer account](https://developer.apple.com/account/) - pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result { + pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], config: ClientConfig) -> Result { let connector = client_cert_connector(cert_pem, key_pem)?; - Ok(Self::builder().endpoint(endpoint).connector(connector).build()) + Ok(Self::builder().config(config).connector(connector).build()) } /// Create a connection to APNs using system certificates, signing every /// request with a signature using a private key, key id and team id /// provisioned from your [Apple developer /// account](https://developer.apple.com/account/). - pub fn token(pkcs8_pem: R, key_id: S, team_id: T, endpoint: Endpoint) -> Result + pub fn token(pkcs8_pem: R, key_id: S, team_id: T, config: ClientConfig) -> Result where S: Into, T: Into, @@ -201,7 +215,7 @@ impl Client { let signature_ttl = Duration::from_secs(60 * 55); let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?; - Ok(Self::builder().endpoint(endpoint).signer(signer).build()) + Ok(Self::builder().config(config).signer(signer).build()) } /// Send a notification payload. @@ -349,7 +363,12 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_sandbox_request_uri() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::builder().endpoint(Endpoint::Sandbox).build(); + let client = Client::builder() + .config(ClientConfig { + endpoint: Endpoint::Sandbox, + ..Default::default() + }) + .build(); let request = client.build_request(payload).unwrap(); let uri = format!("{}", request.uri()); @@ -370,7 +389,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ fn test_request_invalid() { let builder = DefaultNotificationBuilder::new(); let payload = builder.build("\r\n", Default::default()); - let client = Client::builder().endpoint(Endpoint::Production).build(); + let client = Client::builder().build(); let request = client.build_request(payload); assert!(matches!(request, Err(Error::BuildRequestError(_)))); @@ -420,7 +439,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let builder = DefaultNotificationBuilder::new(); let payload = builder.build("a_test_id", Default::default()); - let client = Client::builder().endpoint(Endpoint::Production).signer(signer).build(); + let client = Client::builder().signer(signer).build(); let request = client.build_request(payload).unwrap(); assert_ne!(None, request.headers().get(AUTHORIZATION)); @@ -639,7 +658,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ let key: Vec = include_str!("../test_cert/test.key").bytes().collect(); let cert: Vec = include_str!("../test_cert/test.crt").bytes().collect(); - let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?; + let c = Client::certificate_parts(&cert, &key, ClientConfig::default())?; assert!(c.options.signer.is_none()); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index a59beb5..513fac0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,7 +31,7 @@ //! ## Example sending a plain notification using token authentication: //! //! ```no_run -//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, Endpoint}; +//! # use a2::{DefaultNotificationBuilder, NotificationBuilder, Client, ClientConfig, Endpoint}; //! # use std::fs::File; //! # #[tokio::main] //! # async fn main() -> Result<(), Box> { @@ -48,7 +48,7 @@ //! &mut file, //! "KEY_ID", //! "TEAM_ID", -//! Endpoint::Production).unwrap(); +//! ClientConfig::default()).unwrap(); //! //! let response = client.send(payload).await?; //! println!("Sent: {:?}", response); @@ -64,7 +64,7 @@ //! # { //! //! use a2::{ -//! Client, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions, +//! Client, ClientConfig, Endpoint, DefaultNotificationBuilder, NotificationBuilder, NotificationOptions, //! Priority, //! }; //! use std::fs::File; @@ -97,7 +97,7 @@ //! let client = Client::certificate( //! &mut file, //! "Correct Horse Battery Stable", -//! Endpoint::Production)?; +//! ClientConfig::default())?; //! //! let response = client.send(payload).await?; //! println!("Sent: {:?}", response); @@ -131,6 +131,6 @@ pub use crate::request::notification::{ pub use crate::response::{ErrorBody, ErrorReason, Response}; -pub use crate::client::{Client, Endpoint}; +pub use crate::client::{Client, ClientConfig, Endpoint}; pub use crate::error::Error; diff --git a/src/request/payload.rs b/src/request/payload.rs index db50987..dcdcba1 100644 --- a/src/request/payload.rs +++ b/src/request/payload.rs @@ -27,11 +27,9 @@ pub struct Payload<'a> { /// /// # Example /// ```no_run -/// use a2::client::Endpoint; /// use a2::request::notification::{NotificationBuilder, NotificationOptions}; /// use a2::request::payload::{PayloadLike, APS}; -/// use a2::Client; -/// use a2::DefaultNotificationBuilder; +/// use a2::{Client, ClientConfig, DefaultNotificationBuilder, Endpoint}; /// use serde::Serialize; /// use std::fs::File; /// @@ -45,7 +43,7 @@ pub struct Payload<'a> { /// let payload = builder.build("device-token-from-the-user", Default::default()); /// let mut file = File::open("/path/to/private_key.p8")?; /// -/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", Endpoint::Production).unwrap(); +/// let client = Client::token(&mut file, "KEY_ID", "TEAM_ID", ClientConfig::default()).unwrap(); /// /// let response = client.send(payload).await?; /// println!("Sent: {:?}", response);