diff --git a/hook-worker/src/dns.rs b/hook-worker/src/dns.rs index 52c55bd..ca444ec 100644 --- a/hook-worker/src/dns.rs +++ b/hook-worker/src/dns.rs @@ -1,11 +1,25 @@ use std::error::Error as StdError; -use std::io; use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use std::{fmt, io}; use futures::FutureExt; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use tokio::task::spawn_blocking; +pub struct NoPublicIPError; + +impl std::error::Error for NoPublicIPError {} +impl fmt::Display for NoPublicIPError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} +impl fmt::Debug for NoPublicIPError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No public IPv4 found for specified host") + } +} + /// Internal reqwest type, copied here as part of Resolving pub(crate) type BoxError = Box; @@ -40,10 +54,18 @@ impl Resolve for PublicIPv4Resolver { // Execute the blocking call in a separate worker thread then process its result asynchronously. // spawn_blocking returns a JoinHandle that implements Future>. let future_result = spawn_blocking(resolve_host).map(|result| match result { - Ok(Ok(addr)) => { - // Resolution succeeded, pass the IPs in a Box after filtering - let addrs: Addrs = Box::new(addr.filter(is_global_ipv4)); - Ok(addrs) + Ok(Ok(all_addrs)) => { + // Resolution succeeded, filter the results + let filtered_addr: Vec = all_addrs.filter(is_global_ipv4).collect(); + if filtered_addr.is_empty() { + // No public IPs found, error out with PermissionDenied + let err: BoxError = Box::new(NoPublicIPError); + Err(err) + } else { + // Pass remaining IPs in a boxed iterator for request to use. + let addrs: Addrs = Box::new(filtered_addr.into_iter()); + Ok(addrs) + } } Ok(Err(err)) => { // Resolution failed, pass error through in a Box diff --git a/hook-worker/src/error.rs b/hook-worker/src/error.rs index 48468bc..51fe468 100644 --- a/hook-worker/src/error.rs +++ b/hook-worker/src/error.rs @@ -1,6 +1,8 @@ +use std::error::Error; use std::fmt; use std::time; +use crate::dns::NoPublicIPError; use hook_common::{pgqueue, webhook::WebhookJobError}; use thiserror::Error; @@ -65,6 +67,9 @@ impl fmt::Display for WebhookRequestError { None => "No response from the server".to_string(), }; writeln!(f, "{}", error)?; + if is_error_source::(error) { + writeln!(f, "{}", NoPublicIPError)?; + } write!(f, "{}", response_message)?; Ok(()) @@ -132,3 +137,14 @@ pub enum WorkerError { #[error("timed out while waiting for jobs to be available")] TimeoutError, } + +/// Check the error and it's sources (recursively) to return true if an error of the given type is found. +pub fn is_error_source(err: &(dyn std::error::Error + 'static)) -> bool { + if err.downcast_ref::().is_some() { + return true; + } + match err.source() { + None => false, + Some(source) => is_error_source::(source), + } +} diff --git a/hook-worker/src/worker.rs b/hook-worker/src/worker.rs index 9a42a90..7fc3d0e 100644 --- a/hook-worker/src/worker.rs +++ b/hook-worker/src/worker.rs @@ -14,7 +14,7 @@ use hook_common::{ webhook::{HttpMethod, WebhookJobError, WebhookJobMetadata, WebhookJobParameters}, }; use http::StatusCode; -use reqwest::header; +use reqwest::{header, Client}; use tokio::sync; use tracing::error; @@ -75,6 +75,25 @@ pub struct WebhookWorker<'p> { liveness: HealthHandle, } +pub fn build_http_client( + request_timeout: time::Duration, + allow_internal_ips: bool, +) -> reqwest::Result { + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + let mut client_builder = reqwest::Client::builder() + .default_headers(headers) + .user_agent("PostHog Webhook Worker") + .timeout(request_timeout); + if !allow_internal_ips { + client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {})) + } + client_builder.build() +} + impl<'p> WebhookWorker<'p> { #[allow(clippy::too_many_arguments)] pub fn new( @@ -88,21 +107,7 @@ impl<'p> WebhookWorker<'p> { allow_internal_ips: bool, liveness: HealthHandle, ) -> Self { - let mut headers = header::HeaderMap::new(); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - let mut client_builder = reqwest::Client::builder() - .default_headers(headers) - .user_agent("PostHog Webhook Worker") - .timeout(request_timeout); - if !allow_internal_ips { - client_builder = client_builder.dns_resolver(Arc::new(PublicIPv4Resolver {})) - } - let client = client_builder - .build() + let client = build_http_client(request_timeout, allow_internal_ips) .expect("failed to construct reqwest client for webhook worker"); Self { @@ -475,6 +480,7 @@ fn parse_retry_after_header(header_map: &reqwest::header::HeaderMap) -> Option Client { + build_http_client(Duration::from_secs(1), true).expect("failed to create client") + } + #[allow(dead_code)] async fn enqueue_job( queue: &PgQueue, @@ -565,8 +577,8 @@ mod tests { webhook_job_parameters.clone(), webhook_job_metadata, ) - .await - .expect("failed to enqueue job"); + .await + .expect("failed to enqueue job"); let worker = WebhookWorker::new( &worker_id, &queue, @@ -601,15 +613,14 @@ mod tests { assert!(registry.get_status().healthy) } - #[sqlx::test(migrations = "../migrations")] - async fn test_send_webhook(_pg: PgPool) { + #[tokio::test] + async fn test_send_webhook() { let method = HttpMethod::POST; let url = "http://localhost:18081/echo"; let headers = collections::HashMap::new(); let body = "a very relevant request body"; - let client = reqwest::Client::new(); - let response = send_webhook(client, &method, url, &headers, body.to_owned()) + let response = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .expect("send_webhook failed"); @@ -620,15 +631,14 @@ mod tests { ); } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); let body = "this is an error message"; - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -645,17 +655,16 @@ mod tests { } } - #[sqlx::test(migrations = "../migrations")] - async fn test_error_message_contains_up_to_n_bytes_of_response_body(_pg: PgPool) { + #[tokio::test] + async fn test_error_message_contains_up_to_n_bytes_of_response_body() { let method = HttpMethod::POST; let url = "http://localhost:18081/fail"; let headers = collections::HashMap::new(); // This is double the current hardcoded amount of bytes. // TODO: Make this configurable and change it here too. let body = (0..20 * 1024).map(|_| "a").collect::>().concat(); - let client = reqwest::Client::new(); - let err = send_webhook(client, &method, url, &headers, body.to_owned()) + let err = send_webhook(localhost_client(), &method, url, &headers, body.to_owned()) .await .err() .expect("request didn't fail when it should have failed"); @@ -673,4 +682,29 @@ mod tests { )); } } + + #[tokio::test] + async fn test_private_ips_denied() { + let method = HttpMethod::POST; + let url = "http://localhost:18081/echo"; + let headers = collections::HashMap::new(); + let body = "a very relevant request body"; + let filtering_client = + build_http_client(Duration::from_secs(1), false).expect("failed to create client"); + + let err = send_webhook(filtering_client, &method, url, &headers, body.to_owned()) + .await + .err() + .expect("request didn't fail when it should have failed"); + + assert!(matches!(err, WebhookError::Request(..))); + if let WebhookError::Request(request_error) = err { + assert_eq!(request_error.status(), None); + assert!(request_error + .to_string() + .contains("No public IPv4 found for specified host")); + } else { + panic!("unexpected error type {:?}", err) + } + } }