diff --git a/src/lib.rs b/src/lib.rs index dac48480..68fb0064 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,7 @@ impl AdapterOptions { } } +#[derive(Clone)] pub struct Adapter { client: Arc>, healthcheck_url: Uri, @@ -101,13 +102,15 @@ impl Adapter { let client = Client::builder().pool_idle_timeout(Duration::from_secs(4)).build_http(); let healthcheck_url = format!( - "http://{}:{}{}", - options.host, options.readiness_check_port, options.readiness_check_path + "{}://{}:{}{}", + "http", options.host, options.readiness_check_port, options.readiness_check_path ) .parse() .unwrap(); - let domain = format!("http://{}:{}", options.host, options.port).parse().unwrap(); + let domain = format!("{}://{}:{}", "http", options.host, options.port) + .parse() + .unwrap(); Adapter { client: Arc::new(client), @@ -186,101 +189,73 @@ impl Adapter { pub async fn run(self) -> Result<(), Error> { lambda_http::run(self).await } -} -/// Implement a `Tower.Service` that sends the requests -/// to the web server. -impl Service for Adapter { - type Response = Response; - type Error = Error; - type Future = Pin> + Send>>; + async fn fetch_response(&self, event: Request) -> Result, Error> { + if self.async_init && !self.ready_at_init.load(Ordering::SeqCst) { + is_web_ready(&self.healthcheck_url, &self.healthcheck_protocol).await; + self.ready_at_init.store(true, Ordering::SeqCst); + } - fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll> { - core::task::Poll::Ready(Ok(())) - } + let request_context = event.request_context(); + let path = event.raw_http_path(); + let mut path = path.as_str(); + let (parts, body) = event.into_parts(); - fn call(&mut self, event: Request) -> Self::Future { - let async_init = self.async_init; - let client = self.client.clone(); - let ready_at_init = self.ready_at_init.clone(); - let healthcheck_url = self.healthcheck_url.clone(); - let healthcheck_protocol = self.healthcheck_protocol; - let domain = self.domain.clone(); - let base_path = self.base_path.clone(); - - Box::pin(async move { - fetch_response( - async_init, - ready_at_init, - client, - base_path, - domain, - healthcheck_url, - healthcheck_protocol, - event, - ) - .await - }) - } -} + // strip away Base Path if environment variable REMOVE_BASE_PATH is set. + if let Some(base_path) = self.base_path.as_deref() { + path = path.trim_start_matches(base_path); + } -#[allow(clippy::too_many_arguments)] -async fn fetch_response( - async_init: bool, - ready_at_init: Arc, - client: Arc>, - base_path: Option, - domain: Uri, - healthcheck_url: Uri, - healthcheck_protocol: Protocol, - event: Request, -) -> Result, Error> { - if async_init && !ready_at_init.load(Ordering::SeqCst) { - is_web_ready(&healthcheck_url, &healthcheck_protocol).await; - ready_at_init.store(true, Ordering::SeqCst); - } + let mut req_headers = parts.headers; - let request_context = event.request_context(); - let path = event.raw_http_path(); - let mut path = path.as_str(); - let (parts, body) = event.into_parts(); + // include request context in http header "x-amzn-request-context" + req_headers.append( + HeaderName::from_static("x-amzn-request-context"), + HeaderValue::from_bytes(serde_json::to_string(&request_context)?.as_bytes())?, + ); - // strip away Base Path if environment variable REMOVE_BASE_PATH is set. - if let Some(base_path) = base_path.as_deref() { - path = path.trim_start_matches(base_path); - } + let mut pq = path.to_string(); + if let Some(q) = parts.uri.query() { + pq.push('?'); + pq.push_str(q); + } - let mut req_headers = parts.headers; + let mut app_parts = self.domain.clone().into_parts(); + app_parts.path_and_query = Some(pq.parse()?); + let app_url = Uri::from_parts(app_parts)?; - // include request context in http header "x-amzn-request-context" - req_headers.append( - HeaderName::from_static("x-amzn-request-context"), - HeaderValue::from_bytes(serde_json::to_string(&request_context)?.as_bytes())?, - ); + tracing::debug!(app_url = %app_url, req_headers = ?req_headers, "sending request to app server"); - let mut pq = path.to_string(); - if let Some(q) = parts.uri.query() { - pq.push('?'); - pq.push_str(q); - } + let mut builder = hyper::Request::builder().method(parts.method).uri(app_url); + if let Some(headers) = builder.headers_mut() { + headers.extend(req_headers); + } - let mut app_parts = domain.into_parts(); - app_parts.path_and_query = Some(pq.parse()?); - let app_url = Uri::from_parts(app_parts)?; + let request = builder.body(hyper::Body::from(body.to_vec()))?; - tracing::debug!(app_url = %app_url, req_headers = ?req_headers, "sending request to app server"); + let app_response = self.client.request(request).await?; + tracing::debug!(status = %app_response.status(), body_size = app_response.body().size_hint().lower(), + app_headers = ?app_response.headers().clone(), "responding to lambda event"); - let mut builder = hyper::Request::builder().method(parts.method).uri(app_url); - if let Some(headers) = builder.headers_mut() { - headers.extend(req_headers); + Ok(app_response) } +} + +/// Implement a `Tower.Service` that sends the requests +/// to the web server. +impl Service for Adapter { + type Response = Response; + type Error = Error; + type Future = Pin> + Send>>; - let request = builder.body(hyper::Body::from(body.to_vec()))?; + fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll> { + core::task::Poll::Ready(Ok(())) + } - let app_response = client.request(request).await?; - tracing::debug!(status = %app_response.status(), body_size = app_response.body().size_hint().lower(), - app_headers = ?app_response.headers().clone(), "responding to lambda event"); - Ok(app_response) + fn call(&mut self, event: Request) -> Self::Future { + let adapter = self.clone(); + Box::pin(async move { adapter.fetch_response(event).await }) + } } async fn is_web_ready(url: &Uri, protocol: &Protocol) -> bool {