Skip to content

Commit

Permalink
feat: Add option to set a request timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
threema-donat committed Apr 24, 2024
1 parent ce701eb commit a1949fe
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 37 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## Unreleased

- feat: Add support to set the request timeout

## v0.6.2

- Add support for Safari web push
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ base64 = "0.20"
tracing = { version = "0.1", optional = true }
pem = { version = "1.0", optional = true }
ring = { version = "0.16", features = ["std"], optional = true }
tokio = { version = "1", features = ["time"] }

[dev-dependencies]
argparse = "0.2"
Expand Down
166 changes: 129 additions & 37 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::error::Error;
use crate::error::Error::ResponseError;
use crate::signer::Signer;
use hyper_alpn::AlpnConnector;
use tokio::time::timeout;

use crate::request::payload::PayloadLike;
use crate::response::Response;
Expand Down Expand Up @@ -43,23 +44,96 @@ impl fmt::Display for Endpoint {
/// holds the response for handling.
#[derive(Debug, Clone)]
pub struct Client {
endpoint: Endpoint,
signer: Option<Signer>,
options: ConnectionOptions,
http_client: HttpClient<AlpnConnector>,
}

impl Client {
fn new(connector: AlpnConnector, signer: Option<Signer>, endpoint: Endpoint) -> Client {
let mut builder = HttpClient::builder();
builder.pool_idle_timeout(Some(Duration::from_secs(600)));
builder.http2_only(true);
/// Uses [`Endpoint::Production`] by default.
#[derive(Debug, Clone)]
pub struct ClientOptions {
/// The timeout of the HTTP requests
pub request_timeout_secs: Option<u64>,
/// The timeout for idle sockets being kept alive
pub pool_idle_timeout_secs: Option<u64>,
/// The endpoint where the requests are sent to
pub endpoint: Endpoint,
/// See [`crate::signer::Signer`]
pub signer: Option<Signer>,
}

impl Default for ClientOptions {
fn default() -> Self {
Self {
pool_idle_timeout_secs: Some(600),
request_timeout_secs: None,
endpoint: Endpoint::Production,
signer: None,
}
}
}

impl ClientOptions {
pub fn new(endpoint: Endpoint) -> Self {
Self {
endpoint,
..Default::default()
}
}

pub fn with_signer(mut self, signer: Signer) -> Self {
self.signer = Some(signer);
self
}

pub fn with_request_timeout(mut self, seconds: u64) -> Self {
self.request_timeout_secs = Some(seconds);
self
}

pub fn with_pool_idle_timeout(mut self, seconds: u64) -> Self {
self.pool_idle_timeout_secs = Some(seconds);
self
}
}

#[derive(Debug, Clone)]
struct ConnectionOptions {
endpoint: Endpoint,
request_timeout: Duration,
signer: Option<Signer>,
}

Client {
http_client: builder.build(connector),
impl From<ClientOptions> for ConnectionOptions {
fn from(value: ClientOptions) -> Self {
let ClientOptions {
endpoint,
pool_idle_timeout_secs: _,
signer,
request_timeout_secs,
} = value;
let request_timeout = Duration::from_secs(request_timeout_secs.unwrap_or(20));
Self {
endpoint,
request_timeout,
signer,
}
}
}

impl Client {
/// If `options` is not set, a default using [`Endpoint::Production`] will
/// be initialized.
fn new(connector: AlpnConnector, options: Option<ClientOptions>) -> Client {
let options = options.unwrap_or_default();
let http_client = HttpClient::builder()
.pool_idle_timeout(options.pool_idle_timeout_secs.map(Duration::from_secs))
.http2_only(true)
.build(connector);

let options = options.into();

Client { http_client, options }
}

/// Create a connection to APNs using the provider client certificate which
/// you obtain from your [Apple developer
Expand All @@ -77,7 +151,7 @@ impl Client {
let pkcs = openssl::pkcs12::Pkcs12::from_der(&cert_der)?.parse(password)?;
let connector = AlpnConnector::with_client_cert(&pkcs.cert.to_pem()?, &pkcs.pkey.private_key_to_pem_pkcs8()?)?;

Ok(Self::new(connector, None, endpoint))
Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
}

/// Create a connection to APNs using the raw PEM-formatted certificate and
Expand All @@ -86,7 +160,7 @@ impl Client {
pub fn certificate_parts(cert_pem: &[u8], key_pem: &[u8], endpoint: Endpoint) -> Result<Client, Error> {
let connector = AlpnConnector::with_client_cert(cert_pem, key_pem)?;

Ok(Self::new(connector, None, endpoint))
Ok(Self::new(connector, Some(ClientOptions::new(endpoint))))
}

/// Create a connection to APNs using system certificates, signing every
Expand All @@ -101,9 +175,16 @@ impl Client {
{
let connector = AlpnConnector::new();
let signature_ttl = Duration::from_secs(60 * 55);
let signer = Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?;
let signer = Some(Signer::new(pkcs8_pem, key_id, team_id, signature_ttl)?);

Ok(Self::new(connector, Some(signer), endpoint))
Ok(Self::new(
connector,
Some(ClientOptions {
endpoint,
signer,
..Default::default()
}),
))
}

/// Send a notification payload.
Expand All @@ -114,7 +195,11 @@ impl Client {
let request = self.build_request(payload);
let requesting = self.http_client.request(request);

let response = requesting.await?;
let Ok(response_result) = timeout(self.options.request_timeout, requesting).await else {
return Err(Error::RequestTimeout(self.options.request_timeout.as_secs()));
};

let response = response_result?;

let apns_id = response
.headers()
Expand All @@ -141,7 +226,11 @@ impl Client {
}

fn build_request<T: PayloadLike>(&self, payload: T) -> hyper::Request<Body> {
let path = format!("https://{}/3/device/{}", self.endpoint, payload.get_device_token());
let path = format!(
"https://{}/3/device/{}",
self.options.endpoint,
payload.get_device_token()
);

let mut builder = hyper::Request::builder()
.uri(&path)
Expand All @@ -167,7 +256,7 @@ impl Client {
if let Some(apns_topic) = options.apns_topic {
builder = builder.header("apns-topic", apns_topic.as_bytes());
}
if let Some(ref signer) = self.signer {
if let Some(ref signer) = self.options.signer {
let auth = signer
.with_signature(|signature| format!("Bearer {}", signature))
.unwrap();
Expand Down Expand Up @@ -205,7 +294,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_production_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -216,7 +305,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_sandbox_request_uri() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Sandbox);
let client = Client::new(AlpnConnector::new(), Some(ClientOptions::new(Endpoint::Sandbox)));
let request = client.build_request(payload);
let uri = format!("{}", request.uri());

Expand All @@ -227,7 +316,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_method() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);

assert_eq!(&Method::POST, request.method());
Expand All @@ -237,7 +326,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_type() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);

assert_eq!("application/json", request.headers().get(CONTENT_TYPE).unwrap());
Expand All @@ -247,7 +336,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_content_length() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload.clone());
let payload_json = payload.to_json_string().unwrap();
let content_length = request.headers().get(CONTENT_LENGTH).unwrap().to_str().unwrap();
Expand All @@ -259,7 +348,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_authorization_with_no_signer() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);

assert_eq!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -277,7 +366,10 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), Some(signer), Endpoint::Production);
let client = Client::new(
AlpnConnector::new(),
Some(ClientOptions::new(Endpoint::Production).with_signer(signer)),
);
let request = client.build_request(payload);

assert_ne!(None, request.headers().get(AUTHORIZATION));
Expand All @@ -291,7 +383,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
..Default::default()
};
let payload = builder.build("a_test_id", options);
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_push_type = request.headers().get("apns-push-type").unwrap();

Expand All @@ -302,7 +394,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
fn test_request_with_default_priority() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority");

Expand All @@ -321,7 +413,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -340,7 +432,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_priority = request.headers().get("apns-priority").unwrap();

Expand All @@ -353,7 +445,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id");

Expand All @@ -372,7 +464,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_id = request.headers().get("apns-id").unwrap();

Expand All @@ -385,7 +477,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration");

Expand All @@ -404,7 +496,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_expiration = request.headers().get("apns-expiration").unwrap();

Expand All @@ -417,7 +509,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id");

Expand All @@ -436,7 +528,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_collapse_id = request.headers().get("apns-collapse-id").unwrap();

Expand All @@ -449,7 +541,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ

let payload = builder.build("a_test_id", Default::default());

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic");

Expand All @@ -468,7 +560,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
},
);

let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload);
let apns_topic = request.headers().get("apns-topic").unwrap();

Expand All @@ -479,7 +571,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
async fn test_request_body() {
let builder = DefaultNotificationBuilder::new();
let payload = builder.build("a_test_id", Default::default());
let client = Client::new(AlpnConnector::new(), None, Endpoint::Production);
let client = Client::new(AlpnConnector::new(), None);
let request = client.build_request(payload.clone());

let body = hyper::body::to_bytes(request).await.unwrap();
Expand All @@ -497,7 +589,7 @@ jDwmlD1Gg0yJt1e38djFwsxsfr5q2hv0Rj9fTEqAPr8H7mGm0wKxZ7iQ
let cert: Vec<u8> = include_str!("../test_cert/test.crt").bytes().collect();

let c = Client::certificate_parts(&cert, &key, Endpoint::Sandbox)?;
assert!(c.signer.is_none());
assert!(c.options.signer.is_none());
Ok(())
}
}
Loading

0 comments on commit a1949fe

Please sign in to comment.