diff --git a/lambda-http/Cargo.toml b/lambda-http/Cargo.toml index cb1f7a97..8b1d9285 100644 --- a/lambda-http/Cargo.toml +++ b/lambda-http/Cargo.toml @@ -17,6 +17,7 @@ travis-ci = { repository = "awslabs/aws-lambda-rust-runtime" } maintenance = { status = "actively-developed" } [dependencies] +aws_lambda_events = { version = "0.6", default-features = false, features = ["alb", "apigw"]} base64 = "0.13.0" bytes = "1" http = "0.2" @@ -25,6 +26,7 @@ lambda_runtime = { path = "../lambda-runtime", version = "0.5" } serde = { version = "^1", features = ["derive"] } serde_json = "^1" serde_urlencoded = "0.7.0" +query_map = { version = "0.4", features = ["url-query"] } [dev-dependencies] log = "^0.4" diff --git a/lambda-http/examples/hello-cors.rs b/lambda-http/examples/hello-cors.rs index 79d2a986..275873a4 100644 --- a/lambda-http/examples/hello-cors.rs +++ b/lambda-http/examples/hello-cors.rs @@ -19,7 +19,7 @@ async fn main() -> Result<(), Error> { } async fn func(event: Request) -> Result, Error> { - Ok(match event.query_string_parameters().get("first_name") { + Ok(match event.query_string_parameters().first("first_name") { Some(first_name) => format!("Hello, {}!", first_name).into_response(), _ => Response::builder() .status(400) diff --git a/lambda-http/examples/hello-http.rs b/lambda-http/examples/hello-http.rs index 40352dab..5b679196 100644 --- a/lambda-http/examples/hello-http.rs +++ b/lambda-http/examples/hello-http.rs @@ -7,7 +7,7 @@ async fn main() -> Result<(), Error> { } async fn func(event: Request) -> Result { - Ok(match event.query_string_parameters().get("first_name") { + Ok(match event.query_string_parameters().first("first_name") { Some(first_name) => format!("Hello, {}!", first_name).into_response(), _ => Response::builder() .status(400) diff --git a/lambda-http/examples/shared-resources-example.rs b/lambda-http/examples/shared-resources-example.rs index 24e56f97..a90dd815 100644 --- a/lambda-http/examples/shared-resources-example.rs +++ b/lambda-http/examples/shared-resources-example.rs @@ -20,7 +20,7 @@ async fn main() -> Result<(), Error> { // Define a closure here that makes use of the shared client. let handler_func_closure = move |event: Request| async move { - Ok(match event.query_string_parameters().get("first_name") { + Ok(match event.query_string_parameters().first("first_name") { Some(first_name) => shared_client_ref .response(event.lambda_context().request_id, first_name) .into_response(), diff --git a/lambda-http/src/body.rs b/lambda-http/src/body.rs deleted file mode 100644 index 7ac31ce7..00000000 --- a/lambda-http/src/body.rs +++ /dev/null @@ -1,295 +0,0 @@ -//! Provides an ALB / API Gateway oriented request and response body entity interface - -use crate::Error; -use base64::display::Base64Display; -use bytes::Bytes; -use http_body::{Body as HttpBody, SizeHint}; -use serde::ser::{Error as SerError, Serialize, Serializer}; -use std::{borrow::Cow, mem::take, ops::Deref, pin::Pin, task::Poll}; - -/// Representation of http request and response bodies as supported -/// by API Gateway and ALBs. -/// -/// These come in three flavors -/// * `Empty` ( no body ) -/// * `Text` ( text data ) -/// * `Binary` ( binary data ) -/// -/// Body types can be `Deref` and `AsRef`'d into `[u8]` types much like the [hyper crate](https://crates.io/crates/hyper) -/// -/// # Examples -/// -/// Body types are inferred with `From` implementations. -/// -/// ## Text -/// -/// Types like `String`, `str` whose type reflects -/// text produce `Body::Text` variants -/// -/// ``` -/// assert!(match lambda_http::Body::from("text") { -/// lambda_http::Body::Text(_) => true, -/// _ => false -/// }) -/// ``` -/// -/// ## Binary -/// -/// Types like `Vec` and `&[u8]` whose types reflect raw bytes produce `Body::Binary` variants -/// -/// ``` -/// assert!(match lambda_http::Body::from("text".as_bytes()) { -/// lambda_http::Body::Binary(_) => true, -/// _ => false -/// }) -/// ``` -/// -/// `Binary` responses bodies will automatically get based64 encoded to meet API Gateway's response expectations. -/// -/// ## Empty -/// -/// The unit type ( `()` ) whose type represents an empty value produces `Body::Empty` variants -/// -/// ``` -/// assert!(match lambda_http::Body::from(()) { -/// lambda_http::Body::Empty => true, -/// _ => false -/// }) -/// ``` -/// -/// -/// For more information about API Gateway's body types, -/// refer to [this documentation](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-payload-encodings.html). -#[derive(Debug, PartialEq)] -pub enum Body { - /// An empty body - Empty, - /// A body containing string data - Text(String), - /// A body containing binary data - Binary(Vec), -} - -impl Body { - /// Decodes body, if needed. - /// - /// # Panics - /// - /// Panics when aws communicates to handler that request is base64 encoded but - /// it can not be base64 decoded - pub(crate) fn from_maybe_encoded(is_base64_encoded: bool, body: Cow<'_, str>) -> Body { - if is_base64_encoded { - Body::from(::base64::decode(body.as_ref()).expect("failed to decode aws base64 encoded body")) - } else { - Body::from(body.as_ref()) - } - } -} - -impl Default for Body { - fn default() -> Self { - Body::Empty - } -} - -impl From<()> for Body { - fn from(_: ()) -> Self { - Body::Empty - } -} - -impl<'a> From<&'a str> for Body { - fn from(s: &'a str) -> Self { - Body::Text(s.into()) - } -} - -impl From for Body { - fn from(b: String) -> Self { - Body::Text(b) - } -} - -impl From> for Body { - #[inline] - fn from(cow: Cow<'static, str>) -> Body { - match cow { - Cow::Borrowed(b) => Body::from(b.to_owned()), - Cow::Owned(o) => Body::from(o), - } - } -} - -impl From> for Body { - #[inline] - fn from(cow: Cow<'static, [u8]>) -> Body { - match cow { - Cow::Borrowed(b) => Body::from(b), - Cow::Owned(o) => Body::from(o), - } - } -} - -impl From> for Body { - fn from(b: Vec) -> Self { - Body::Binary(b) - } -} - -impl<'a> From<&'a [u8]> for Body { - fn from(b: &'a [u8]) -> Self { - Body::Binary(b.to_vec()) - } -} - -impl Deref for Body { - type Target = [u8]; - - #[inline] - fn deref(&self) -> &Self::Target { - self.as_ref() - } -} - -impl AsRef<[u8]> for Body { - #[inline] - fn as_ref(&self) -> &[u8] { - match self { - Body::Empty => &[], - Body::Text(ref bytes) => bytes.as_ref(), - Body::Binary(ref bytes) => bytes.as_ref(), - } - } -} - -impl<'a> Serialize for Body { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - Body::Text(data) => { - serializer.serialize_str(::std::str::from_utf8(data.as_ref()).map_err(S::Error::custom)?) - } - Body::Binary(data) => serializer.collect_str(&Base64Display::with_config(data, base64::STANDARD)), - Body::Empty => serializer.serialize_unit(), - } - } -} - -impl HttpBody for Body { - type Data = Bytes; - type Error = Error; - - fn poll_data( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll>> { - let body = take(self.get_mut()); - Poll::Ready(match body { - Body::Empty => None, - Body::Text(s) => Some(Ok(s.into())), - Body::Binary(b) => Some(Ok(b.into())), - }) - } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } - - fn is_end_stream(&self) -> bool { - match self { - Body::Empty => true, - _ => false, - } - } - - fn size_hint(&self) -> SizeHint { - match self { - Body::Empty => SizeHint::default(), - Body::Text(ref s) => SizeHint::with_exact(s.len() as u64), - Body::Binary(ref b) => SizeHint::with_exact(b.len() as u64), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json; - use std::collections::HashMap; - - #[test] - fn body_has_default() { - assert_eq!(Body::default(), Body::Empty); - } - - #[test] - fn from_unit() { - assert_eq!(Body::from(()), Body::Empty); - } - - #[test] - fn from_str() { - match Body::from(String::from("foo").as_str()) { - Body::Text(_) => (), - not => assert!(false, "expected Body::Text(...) got {:?}", not), - } - } - - #[test] - fn from_string() { - match Body::from(String::from("foo")) { - Body::Text(_) => (), - not => assert!(false, "expected Body::Text(...) got {:?}", not), - } - } - - #[test] - fn from_cow_str() { - match Body::from(Cow::from("foo")) { - Body::Text(_) => (), - not => assert!(false, "expected Body::Text(...) got {:?}", not), - } - } - - #[test] - fn from_cow_bytes() { - match Body::from(Cow::from("foo".as_bytes())) { - Body::Binary(_) => (), - not => assert!(false, "expected Body::Binary(...) got {:?}", not), - } - } - - #[test] - fn from_bytes() { - match Body::from("foo".as_bytes()) { - Body::Binary(_) => (), - not => assert!(false, "expected Body::Binary(...) got {:?}", not), - } - } - - #[test] - fn serialize_text() { - let mut map = HashMap::new(); - map.insert("foo", Body::from("bar")); - assert_eq!(serde_json::to_string(&map).unwrap(), r#"{"foo":"bar"}"#); - } - - #[test] - fn serialize_binary() { - let mut map = HashMap::new(); - map.insert("foo", Body::from("bar".as_bytes())); - assert_eq!(serde_json::to_string(&map).unwrap(), r#"{"foo":"YmFy"}"#); - } - - #[test] - fn serialize_empty() { - let mut map = HashMap::new(); - map.insert("foo", Body::Empty); - assert_eq!(serde_json::to_string(&map).unwrap(), r#"{"foo":null}"#); - } -} diff --git a/lambda-http/src/ext.rs b/lambda-http/src/ext.rs index 2f56d78c..e33e639c 100644 --- a/lambda-http/src/ext.rs +++ b/lambda-http/src/ext.rs @@ -1,23 +1,24 @@ //! Extension methods for `http::Request` types -use crate::{request::RequestContext, strmap::StrMap, Body}; +use crate::{request::RequestContext, Body}; use lambda_runtime::Context; +use query_map::QueryMap; use serde::{de::value::Error as SerdeError, Deserialize}; use std::{error::Error, fmt}; /// ALB/API gateway pre-parsed http query string parameters -pub(crate) struct QueryStringParameters(pub(crate) StrMap); +pub(crate) struct QueryStringParameters(pub(crate) QueryMap); /// API gateway pre-extracted url path parameters /// /// These will always be empty for ALB requests -pub(crate) struct PathParameters(pub(crate) StrMap); +pub(crate) struct PathParameters(pub(crate) QueryMap); /// API gateway configured /// [stage variables](https://docs.aws.amazon.com/apigateway/latest/developerguide/stage-variables.html) /// /// These will always be empty for ALB requests -pub(crate) struct StageVariables(pub(crate) StrMap); +pub(crate) struct StageVariables(pub(crate) QueryMap); /// Request payload deserialization errors /// @@ -112,37 +113,37 @@ pub trait RequestExt { /// name are expected, `query_string_parameters().get_all("many")` to retrieve them all. /// /// No query parameters - /// will yield an empty `StrMap`. - fn query_string_parameters(&self) -> StrMap; + /// will yield an empty `QueryMap`. + fn query_string_parameters(&self) -> QueryMap; /// Configures instance with query string parameters under #[cfg(test)] configurations /// /// This is intended for use in mock testing contexts. fn with_query_string_parameters(self, parameters: Q) -> Self where - Q: Into; + Q: Into; /// Return pre-extracted path parameters, parameter provided in url placeholders /// `/foo/{bar}/baz/{boom}`, /// associated with the API gateway request. No path parameters - /// will yield an empty `StrMap` + /// will yield an empty `QueryMap` /// /// These will always be empty for ALB triggered requests - fn path_parameters(&self) -> StrMap; + fn path_parameters(&self) -> QueryMap; /// Configures instance with path parameters under #[cfg(test)] configurations /// /// This is intended for use in mock testing contexts. fn with_path_parameters

(self, parameters: P) -> Self where - P: Into; + P: Into; /// Return [stage variables](https://docs.aws.amazon.com/apigateway/latest/developerguide/stage-variables.html) /// associated with the API gateway request. No stage parameters - /// will yield an empty `StrMap` + /// will yield an empty `QueryMap` /// /// These will always be empty for ALB triggered requests - fn stage_variables(&self) -> StrMap; + fn stage_variables(&self) -> QueryMap; /// Configures instance with stage variables under #[cfg(test)] configurations /// @@ -150,7 +151,7 @@ pub trait RequestExt { #[cfg(test)] fn with_stage_variables(self, variables: V) -> Self where - V: Into; + V: Into; /// Return request context data assocaited with the ALB or API gateway request fn request_context(&self) -> RequestContext; @@ -176,7 +177,7 @@ pub trait RequestExt { } impl RequestExt for http::Request { - fn query_string_parameters(&self) -> StrMap { + fn query_string_parameters(&self) -> QueryMap { self.extensions() .get::() .map(|ext| ext.0.clone()) @@ -185,14 +186,14 @@ impl RequestExt for http::Request { fn with_query_string_parameters(self, parameters: Q) -> Self where - Q: Into, + Q: Into, { let mut s = self; s.extensions_mut().insert(QueryStringParameters(parameters.into())); s } - fn path_parameters(&self) -> StrMap { + fn path_parameters(&self) -> QueryMap { self.extensions() .get::() .map(|ext| ext.0.clone()) @@ -201,14 +202,14 @@ impl RequestExt for http::Request { fn with_path_parameters

(self, parameters: P) -> Self where - P: Into, + P: Into, { let mut s = self; s.extensions_mut().insert(PathParameters(parameters.into())); s } - fn stage_variables(&self) -> StrMap { + fn stage_variables(&self) -> QueryMap { self.extensions() .get::() .map(|ext| ext.0.clone()) @@ -218,7 +219,7 @@ impl RequestExt for http::Request { #[cfg(test)] fn with_stage_variables(self, variables: V) -> Self where - V: Into, + V: Into, { let mut s = self; s.extensions_mut().insert(StageVariables(variables.into())); diff --git a/lambda-http/src/lib.rs b/lambda-http/src/lib.rs index e1119dd0..95b26f25 100644 --- a/lambda-http/src/lib.rs +++ b/lambda-http/src/lib.rs @@ -51,7 +51,7 @@ //! "hello {}", //! request //! .query_string_parameters() -//! .get("name") +//! .first("name") //! .unwrap_or_else(|| "stranger") //! )) //! } @@ -66,16 +66,15 @@ pub use http::{self, Response}; use lambda_runtime::LambdaEvent; pub use lambda_runtime::{self, service_fn, tower, Context, Error, Service}; -mod body; pub mod ext; pub mod request; mod response; -mod strmap; -pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap}; +pub use crate::{ext::RequestExt, response::IntoResponse}; use crate::{ request::{LambdaRequest, RequestOrigin}, response::LambdaResponse, }; +pub use aws_lambda_events::encodings::Body; use std::{ future::Future, marker::PhantomData, @@ -134,7 +133,7 @@ where } } -impl<'a, R, S> Service>> for Adapter<'a, R, S> +impl<'a, R, S> Service> for Adapter<'a, R, S> where S: Service + Send, S::Future: Send + 'a, @@ -148,7 +147,7 @@ where core::task::Poll::Ready(Ok(())) } - fn call(&mut self, req: LambdaEvent>) -> Self::Future { + fn call(&mut self, req: LambdaEvent) -> Self::Future { let request_origin = req.payload.request_origin(); let event: Request = req.payload.into(); let fut = Box::pin(self.service.call(event.with_lambda_context(req.context))); diff --git a/lambda-http/src/request.rs b/lambda-http/src/request.rs index fd509412..63c649c6 100644 --- a/lambda-http/src/request.rs +++ b/lambda-http/src/request.rs @@ -3,17 +3,18 @@ //! Typically these are exposed via the `request_context` //! request extension method provided by [lambda_http::RequestExt](../trait.RequestExt.html) //! -use crate::{ - body::Body, - ext::{PathParameters, QueryStringParameters, StageVariables}, - strmap::StrMap, +use crate::ext::{PathParameters, QueryStringParameters, StageVariables}; +use aws_lambda_events::alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext}; +use aws_lambda_events::apigw::{ + ApiGatewayProxyRequest, ApiGatewayProxyRequestContext, ApiGatewayV2httpRequest, ApiGatewayV2httpRequestContext, + ApiGatewayWebsocketProxyRequest, ApiGatewayWebsocketProxyRequestContext, }; -use serde::{ - de::{Deserializer, Error as DeError, MapAccess, Visitor}, - Deserialize, -}; -use serde_json::{error::Error as JsonError, Value}; -use std::{borrow::Cow, collections::HashMap, fmt, io::Read, mem}; +use aws_lambda_events::encodings::Body; +use http::header::HeaderName; +use query_map::QueryMap; +use serde::Deserialize; +use serde_json::error::Error as JsonError; +use std::{io::Read, mem}; /// Internal representation of an Lambda http event from /// ALB, API Gateway REST and HTTP API proxy event perspectives @@ -23,84 +24,23 @@ use std::{borrow::Cow, collections::HashMap, fmt, io::Read, mem}; #[doc(hidden)] #[derive(Deserialize, Debug)] #[serde(untagged)] -pub enum LambdaRequest<'a> { - #[serde(rename_all = "camelCase")] - ApiGatewayV2 { - version: Cow<'a, str>, - route_key: Cow<'a, str>, - raw_path: Cow<'a, str>, - raw_query_string: Cow<'a, str>, - cookies: Option>>, - #[serde(deserialize_with = "deserialize_headers")] - headers: http::HeaderMap, - #[serde(default, deserialize_with = "nullable_default")] - query_string_parameters: StrMap, - #[serde(default, deserialize_with = "nullable_default")] - path_parameters: StrMap, - #[serde(default, deserialize_with = "nullable_default")] - stage_variables: StrMap, - body: Option>, - #[serde(default)] - is_base64_encoded: bool, - request_context: ApiGatewayV2RequestContext, - }, - #[serde(rename_all = "camelCase")] - Alb { - path: Cow<'a, str>, - #[serde(deserialize_with = "deserialize_method")] - http_method: http::Method, - #[serde(deserialize_with = "deserialize_headers")] - headers: http::HeaderMap, - /// For alb events these are only present when - /// the `lambda.multi_value_headers.enabled` target group setting turned on - #[serde(default, deserialize_with = "deserialize_multi_value_headers")] - multi_value_headers: http::HeaderMap, - #[serde(default, deserialize_with = "nullable_default")] - query_string_parameters: StrMap, - /// For alb events these are only present when - /// the `lambda.multi_value_headers.enabled` target group setting turned on - #[serde(default, deserialize_with = "nullable_default")] - multi_value_query_string_parameters: StrMap, - body: Option>, - #[serde(default)] - is_base64_encoded: bool, - request_context: AlbRequestContext, - }, - #[serde(rename_all = "camelCase")] - ApiGateway { - path: Cow<'a, str>, - #[serde(deserialize_with = "deserialize_method")] - http_method: http::Method, - #[serde(deserialize_with = "deserialize_headers")] - headers: http::HeaderMap, - #[serde(default, deserialize_with = "deserialize_multi_value_headers")] - multi_value_headers: http::HeaderMap, - #[serde(default, deserialize_with = "nullable_default")] - query_string_parameters: StrMap, - #[serde(default, deserialize_with = "nullable_default")] - multi_value_query_string_parameters: StrMap, - #[serde(default, deserialize_with = "nullable_default")] - path_parameters: StrMap, - #[serde(default, deserialize_with = "nullable_default")] - stage_variables: StrMap, - body: Option>, - #[serde(default)] - is_base64_encoded: bool, - request_context: ApiGatewayRequestContext, - #[serde(default, deserialize_with = "nullable_default")] - resource: Option, - }, +pub enum LambdaRequest { + ApiGatewayV1(ApiGatewayProxyRequest), + ApiGatewayV2(ApiGatewayV2httpRequest), + Alb(AlbTargetGroupRequest), + WebSocket(ApiGatewayWebsocketProxyRequest), } -impl LambdaRequest<'_> { +impl LambdaRequest { /// Return the `RequestOrigin` of the request to determine where the `LambdaRequest` /// originated from, so that the appropriate response can be selected based on what /// type of response the request origin expects. pub fn request_origin(&self) -> RequestOrigin { match self { + LambdaRequest::ApiGatewayV1 { .. } => RequestOrigin::ApiGatewayV1, LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2, LambdaRequest::Alb { .. } => RequestOrigin::Alb, - LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway, + LambdaRequest::WebSocket { .. } => RequestOrigin::WebSocket, } } } @@ -109,488 +49,287 @@ impl LambdaRequest<'_> { #[doc(hidden)] #[derive(Debug)] pub enum RequestOrigin { + /// API Gateway request origin + ApiGatewayV1, /// API Gateway v2 request origin ApiGatewayV2, - /// API Gateway request origin - ApiGateway, /// ALB request origin Alb, + /// API Gateway WebSocket + WebSocket, } -/// See [context-variable-reference](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html) for more detail. -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ApiGatewayV2RequestContext { - /// The API owner's AWS account ID. - pub account_id: String, - /// The identifier API Gateway assigns to your API. - pub api_id: String, - /// The stringified value of the specified key-value pair of the context map returned from an API Gateway Lambda authorizer function. - #[serde(default)] - pub authorizer: HashMap, - /// The full domain name used to invoke the API. This should be the same as the incoming Host header. - pub domain_name: String, - /// The first label of the $context.domainName. This is often used as a caller/customer identifier. - pub domain_prefix: String, - /// The HTTP method used. - pub http: Http, - /// The ID that API Gateway assigns to the API request. - pub request_id: String, - /// Undocumented, could be resourcePath - pub route_key: String, - /// The deployment stage of the API request (for example, Beta or Prod). - pub stage: String, - /// Undocumented, could be requestTime - pub time: String, - /// Undocumented, could be requestTimeEpoch - pub time_epoch: usize, -} +fn into_api_gateway_v2_request(ag: ApiGatewayV2httpRequest) -> http::Request { + let http_method = ag.request_context.http.method.clone(); + let builder = http::Request::builder() + .uri({ + let scheme = ag + .headers + .get(x_forwarded_proto()) + .and_then(|s| s.to_str().ok()) + .unwrap_or("https"); + let host = ag + .headers + .get(http::header::HOST) + .and_then(|s| s.to_str().ok()) + .or_else(|| ag.request_context.domain_name.as_deref()) + .unwrap_or_default(); + + let path = apigw_path_with_stage(&ag.request_context.stage, ag.raw_path.as_deref().unwrap_or_default()); + let mut url = format!("{}://{}{}", scheme, host, path); + + if let Some(query) = ag.raw_query_string { + url.push('?'); + url.push_str(&query); + } + url + }) + .extension(QueryStringParameters(ag.query_string_parameters)) + .extension(PathParameters(QueryMap::from(ag.path_parameters))) + .extension(StageVariables(QueryMap::from(ag.stage_variables))) + .extension(RequestContext::ApiGatewayV2(ag.request_context)); + + let mut headers = ag.headers; + if let Some(cookies) = ag.cookies { + if let Ok(header_value) = http::header::HeaderValue::from_str(&cookies.join(";")) { + headers.append(http::header::COOKIE, header_value); + } + } -/// See [context-variable-reference](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html) for more detail. -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct ApiGatewayRequestContext { - /// The API owner's AWS account ID. - pub account_id: String, - /// The identifier that API Gateway assigns to your resource. - pub resource_id: String, - /// The deployment stage of the API request (for example, Beta or Prod). - pub stage: String, - /// The full domain name used to invoke the API. This should be the same as the incoming Host header. - pub domain_name: Option, - /// The first label of the $context.domainName. This is often used as a caller/customer identifier. - pub domain_prefix: Option, - /// The ID that API Gateway assigns to the API request. - pub request_id: String, - /// The path to your resource. For example, for the non-proxy request URI of `https://{rest-api-id.execute-api.{region}.amazonaws.com/{stage}/root/child`, The $context.resourcePath value is /root/child. - pub resource_path: String, - /// The request protocol, for example, HTTP/1.1. - pub protocol: Option, - /// The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm). - pub request_time: Option, - /// The Epoch-formatted request time, in milliseconds. - pub request_time_epoch: i64, - /// The identifier API Gateway assigns to your API. - pub apiid: Option, - /// The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT. - pub http_method: String, - /// The stringified value of the specified key-value pair of the context map returned from an API Gateway Lambda authorizer function. - #[serde(default)] - pub authorizer: HashMap, - /// The identifier API Gateway assigns to your API. - pub api_id: String, - /// Cofnito identity information - #[serde(default)] - pub identity: Identity, -} + let base64 = ag.is_base64_encoded; -/// Elastic load balancer context information -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "camelCase")] -pub struct AlbRequestContext { - /// Elastic load balancer context information - pub elb: Elb, -} + let mut req = builder + .body( + ag.body + .as_deref() + .map_or_else(Body::default, |b| Body::from_maybe_encoded(base64, b)), + ) + .expect("failed to build request"); -/// Event request context as an enumeration of request contexts -/// for both ALB and API Gateway and HTTP API events -#[derive(Deserialize, Debug, Clone)] -#[serde(untagged)] -pub enum RequestContext { - /// API Gateway v2 request context - ApiGatewayV2(ApiGatewayV2RequestContext), - /// API Gateway request context - ApiGateway(ApiGatewayRequestContext), - /// ALB request context - Alb(AlbRequestContext), -} + // no builder method that sets headers in batch + let _ = mem::replace(req.headers_mut(), headers); + let _ = mem::replace(req.method_mut(), http_method); -/// Elastic load balancer context information -#[derive(Deserialize, Debug, Default, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Elb { - /// AWS ARN identifier for the ELB Target Group this lambda was triggered by - pub target_group_arn: String, + req } -/// Http information captured API Gateway v2 request context -#[derive(Deserialize, Debug, Default, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Http { - #[serde(deserialize_with = "deserialize_method")] - /// The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT. - pub method: http::Method, - /// The request path. For example, for a non-proxy request URL of - /// `https://{rest-api-id.execute-api.{region}.amazonaws.com/{stage}/root/child`, - /// the $context.path value is `/{stage}/root/child`. - pub path: String, - /// The request protocol, for example, HTTP/1.1. - pub protocol: String, - /// The source IP address of the TCP connection making the request to API Gateway. - pub source_ip: String, - /// The User-Agent header of the API caller. - pub user_agent: String, -} +fn into_proxy_request(ag: ApiGatewayProxyRequest) -> http::Request { + let http_method = ag.http_method; + let builder = http::Request::builder() + .uri({ + let host = ag.headers.get(http::header::HOST).and_then(|s| s.to_str().ok()); + let path = apigw_path_with_stage(&ag.request_context.stage, &ag.path.unwrap_or_default()); + + let mut url = match host { + None => path, + Some(host) => { + let scheme = ag + .headers + .get(x_forwarded_proto()) + .and_then(|s| s.to_str().ok()) + .unwrap_or("https"); + format!("{}://{}{}", scheme, host, path) + } + }; + + if !ag.multi_value_query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&ag.multi_value_query_string_parameters.to_query_string()); + } else if !ag.query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&ag.query_string_parameters.to_query_string()); + } + url + }) + // multi-valued query string parameters are always a super + // set of singly valued query string parameters, + // when present, multi-valued query string parameters are preferred + .extension(QueryStringParameters( + if ag.multi_value_query_string_parameters.is_empty() { + ag.query_string_parameters + } else { + ag.multi_value_query_string_parameters + }, + )) + .extension(PathParameters(QueryMap::from(ag.path_parameters))) + .extension(StageVariables(QueryMap::from(ag.stage_variables))) + .extension(RequestContext::ApiGatewayV1(ag.request_context)); + + // merge headers into multi_value_headers and make + // multi-value_headers our cannoncial source of request headers + let mut headers = ag.multi_value_headers; + headers.extend(ag.headers); + + let base64 = ag.is_base64_encoded.unwrap_or_default(); + let mut req = builder + .body( + ag.body + .as_deref() + .map_or_else(Body::default, |b| Body::from_maybe_encoded(base64, b)), + ) + .expect("failed to build request"); -/// Identity assoicated with API Gateway request -#[derive(Deserialize, Debug, Default, Clone)] -#[serde(rename_all = "camelCase")] -pub struct Identity { - /// The source IP address of the TCP connection making the request to API Gateway. - pub source_ip: String, - /// The Amazon Cognito identity ID of the caller making the request. - /// Available only if the request was signed with Amazon Cognito credentials. - pub cognito_identity_id: Option, - /// The Amazon Cognito identity pool ID of the caller making the request. - /// Available only if the request was signed with Amazon Cognito credentials. - pub cognito_identity_pool_id: Option, - /// A comma-separated list of the Amazon Cognito authentication providers used by the caller making the request. - /// Available only if the request was signed with Amazon Cognito credentials. - pub cognito_authentication_provider: Option, - /// The Amazon Cognito authentication type of the caller making the request. - /// Available only if the request was signed with Amazon Cognito credentials. - pub cognito_authentication_type: Option, - /// The AWS account ID associated with the request. - pub account_id: Option, - /// The principal identifier of the caller making the request. - pub caller: Option, - /// For API methods that require an API key, this variable is the API key associated with the method request. - /// For methods that don't require an API key, this variable is null. - pub api_key: Option, - /// Undocumented. Can be the API key ID associated with an API request that requires an API key. - /// The description of `api_key` and `access_key` may actually be reversed. - pub access_key: Option, - /// The principal identifier of the user making the request. Used in Lambda authorizers. - pub user: Option, - /// The User-Agent header of the API caller. - pub user_agent: Option, - /// The Amazon Resource Name (ARN) of the effective user identified after authentication. - pub user_arn: Option, -} + // no builder method that sets headers in batch + let _ = mem::replace(req.headers_mut(), headers); + let _ = mem::replace(req.method_mut(), http_method); -/// Deserialize a str into an http::Method -fn deserialize_method<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - struct MethodVisitor; + req +} - impl<'de> Visitor<'de> for MethodVisitor { - type Value = http::Method; +fn into_alb_request(alb: AlbTargetGroupRequest) -> http::Request { + let http_method = alb.http_method; + let builder = http::Request::builder() + .uri({ + let scheme = alb + .headers + .get(x_forwarded_proto()) + .and_then(|s| s.to_str().ok()) + .unwrap_or("https"); + let host = alb + .headers + .get(http::header::HOST) + .and_then(|s| s.to_str().ok()) + .unwrap_or_default(); + + let mut url = format!("{}://{}{}", scheme, host, alb.path.unwrap_or_default()); + if !alb.multi_value_query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&alb.multi_value_query_string_parameters.to_query_string()); + } else if !alb.query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&alb.query_string_parameters.to_query_string()); + } - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a Method") - } + url + }) + // multi valued query string parameters are always a super + // set of singly valued query string parameters, + // when present, multi-valued query string parameters are preferred + .extension(QueryStringParameters( + if alb.multi_value_query_string_parameters.is_empty() { + alb.query_string_parameters + } else { + alb.multi_value_query_string_parameters + }, + )) + .extension(RequestContext::Alb(alb.request_context)); - fn visit_str(self, v: &str) -> Result - where - E: DeError, - { - v.parse().map_err(E::custom) - } - } + // merge headers into multi_value_headers and make + // multi-value_headers our cannoncial source of request headers + let mut headers = alb.multi_value_headers; + headers.extend(alb.headers); - deserializer.deserialize_str(MethodVisitor) -} + let base64 = alb.is_base64_encoded; -/// Deserialize a map of Cow<'_, str> => Vec> into an http::HeaderMap -fn deserialize_multi_value_headers<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - struct HeaderVisitor; + let mut req = builder + .body( + alb.body + .as_deref() + .map_or_else(Body::default, |b| Body::from_maybe_encoded(base64, b)), + ) + .expect("failed to build request"); - impl<'de> Visitor<'de> for HeaderVisitor { - type Value = http::HeaderMap; + // no builder method that sets headers in batch + let _ = mem::replace(req.headers_mut(), headers); + let _ = mem::replace(req.method_mut(), http_method); - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a multi valued HeaderMap") - } + req +} - fn visit_map(self, mut map: A) -> Result - where - A: MapAccess<'de>, - { - let mut headers = map - .size_hint() - .map(http::HeaderMap::with_capacity) - .unwrap_or_else(http::HeaderMap::new); - while let Some((key, values)) = map.next_entry::, Vec>>()? { - // note the aws docs for multi value headers include an empty key. I'm not sure if this is a doc bug - // or not by the http crate doesn't handle it - // https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html#api-gateway-simple-proxy-for-lambda-input-format - if !key.is_empty() { - for value in values { - let header_name = key.parse::().map_err(A::Error::custom)?; - let header_value = http::header::HeaderValue::from_maybe_shared(value.into_owned()) - .map_err(A::Error::custom)?; - headers.append(header_name, header_value); - } +fn into_websocket_request(ag: ApiGatewayWebsocketProxyRequest) -> http::Request { + let http_method = ag.http_method; + let builder = http::Request::builder() + .uri({ + let host = ag.headers.get(http::header::HOST).and_then(|s| s.to_str().ok()); + let path = apigw_path_with_stage(&ag.request_context.stage, &ag.path.unwrap_or_default()); + + let mut url = match host { + None => path, + Some(host) => { + let scheme = ag + .headers + .get(x_forwarded_proto()) + .and_then(|s| s.to_str().ok()) + .unwrap_or("https"); + format!("{}://{}{}", scheme, host, path) } + }; + + if !ag.multi_value_query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&ag.multi_value_query_string_parameters.to_query_string()); + } else if !ag.query_string_parameters.is_empty() { + url.push('?'); + url.push_str(&ag.query_string_parameters.to_query_string()); } - Ok(headers) - } - } - - Ok(deserializer.deserialize_map(HeaderVisitor).unwrap_or_default()) -} - -/// Deserialize a map of Cow<'_, str> => Cow<'_, str> into an http::HeaderMap -fn deserialize_headers<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - struct HeaderVisitor; + url + }) + // multi-valued query string parameters are always a super + // set of singly valued query string parameters, + // when present, multi-valued query string parameters are preferred + .extension(QueryStringParameters( + if ag.multi_value_query_string_parameters.is_empty() { + ag.query_string_parameters + } else { + ag.multi_value_query_string_parameters + }, + )) + .extension(PathParameters(QueryMap::from(ag.path_parameters))) + .extension(StageVariables(QueryMap::from(ag.stage_variables))) + .extension(RequestContext::WebSocket(ag.request_context)); + + // merge headers into multi_value_headers and make + // multi-value_headers our cannoncial source of request headers + let mut headers = ag.multi_value_headers; + headers.extend(ag.headers); + + let base64 = ag.is_base64_encoded.unwrap_or_default(); + let mut req = builder + .body( + ag.body + .as_deref() + .map_or_else(Body::default, |b| Body::from_maybe_encoded(base64, b)), + ) + .expect("failed to build request"); - impl<'de> Visitor<'de> for HeaderVisitor { - type Value = http::HeaderMap; + // no builder method that sets headers in batch + let _ = mem::replace(req.headers_mut(), headers); + let _ = mem::replace(req.method_mut(), http_method.unwrap_or(http::Method::GET)); - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a HeaderMap") - } + req +} - fn visit_map(self, mut map: A) -> Result - where - A: MapAccess<'de>, - { - let mut headers = map - .size_hint() - .map(http::HeaderMap::with_capacity) - .unwrap_or_else(http::HeaderMap::new); - while let Some((key, value)) = map.next_entry::, Cow<'_, str>>()? { - let header_name = key.parse::().map_err(A::Error::custom)?; - let header_value = - http::header::HeaderValue::from_maybe_shared(value.into_owned()).map_err(A::Error::custom)?; - headers.append(header_name, header_value); - } - Ok(headers) - } +fn apigw_path_with_stage(stage: &Option, path: &str) -> String { + match stage { + None => path.into(), + Some(stage) if stage == "$default" => path.into(), + Some(stage) => format!("/{}{}", stage, path), } - - Ok(deserializer.deserialize_map(HeaderVisitor).unwrap_or_default()) } -/// deserializes (json) null values to their default values -// https://github.com/serde-rs/serde/issues/1098 -fn nullable_default<'de, T, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, - T: Default + Deserialize<'de>, -{ - let opt = Option::deserialize(deserializer)?; - Ok(opt.unwrap_or_default()) +/// Event request context as an enumeration of request contexts +/// for both ALB and API Gateway and HTTP API events +#[derive(Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum RequestContext { + /// API Gateway proxy request context + ApiGatewayV1(ApiGatewayProxyRequestContext), + /// API Gateway v2 request context + ApiGatewayV2(ApiGatewayV2httpRequestContext), + /// ALB request context + Alb(AlbTargetGroupRequestContext), + /// WebSocket request context + WebSocket(ApiGatewayWebsocketProxyRequestContext), } /// Converts LambdaRequest types into `http::Request` types -impl<'a> From> for http::Request { - fn from(value: LambdaRequest<'_>) -> Self { +impl<'a> From for http::Request { + fn from(value: LambdaRequest) -> Self { match value { - LambdaRequest::ApiGatewayV2 { - raw_path, - raw_query_string, - mut headers, - query_string_parameters, - path_parameters, - stage_variables, - body, - is_base64_encoded, - request_context, - cookies, - .. - } => { - if let Some(cookies) = cookies { - if let Ok(header_value) = http::header::HeaderValue::from_str(&cookies.join(";")) { - headers.append(http::header::COOKIE, header_value); - } - } - - let builder = http::Request::builder() - .method(request_context.http.method.as_ref()) - .uri({ - let mut url = format!( - "{}://{}{}", - headers - .get("X-Forwarded-Proto") - .and_then(|val| val.to_str().ok()) - .unwrap_or("https"), - headers - .get(http::header::HOST) - .and_then(|val| val.to_str().ok()) - .unwrap_or_else(|| request_context.domain_name.as_ref()), - raw_path - ); - if !raw_query_string.is_empty() { - url.push('?'); - url.push_str(raw_query_string.as_ref()); - } - url - }) - .extension(QueryStringParameters(query_string_parameters)) - .extension(PathParameters(path_parameters)) - .extension(StageVariables(stage_variables)) - .extension(RequestContext::ApiGatewayV2(request_context)); - - let mut req = builder - .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b))) - .expect("failed to build request"); - - // no builder method that sets headers in batch - let _ = mem::replace(req.headers_mut(), headers); - - req - } - LambdaRequest::ApiGateway { - path, - http_method, - headers, - mut multi_value_headers, - query_string_parameters, - multi_value_query_string_parameters, - path_parameters, - stage_variables, - body, - is_base64_encoded, - request_context, - resource: _, - } => { - let builder = http::Request::builder() - .method(http_method) - .uri({ - let host = headers.get(http::header::HOST).and_then(|val| val.to_str().ok()); - let mut uri = match host { - Some(host) => { - format!( - "{}://{}{}", - headers - .get("X-Forwarded-Proto") - .and_then(|val| val.to_str().ok()) - .unwrap_or("https"), - host, - path - ) - } - None => path.to_string(), - }; - - if !multi_value_query_string_parameters.is_empty() { - uri.push('?'); - uri.push_str(multi_value_query_string_parameters.to_query_string().as_str()); - } else if !query_string_parameters.is_empty() { - uri.push('?'); - uri.push_str(query_string_parameters.to_query_string().as_str()); - } - - uri - }) - // multi-valued query string parameters are always a super - // set of singly valued query string parameters, - // when present, multi-valued query string parameters are preferred - .extension(QueryStringParameters( - if multi_value_query_string_parameters.is_empty() { - query_string_parameters - } else { - multi_value_query_string_parameters - }, - )) - .extension(PathParameters(path_parameters)) - .extension(StageVariables(stage_variables)) - .extension(RequestContext::ApiGateway(request_context)); - - let mut req = builder - .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b))) - .expect("failed to build request"); - - // merge headers into multi_value_headers and make - // multi-value_headers our cannoncial source of request headers - for (key, value) in headers { - // see HeaderMap#into_iter() docs for cases when key element may be None - if let Some(first_key) = key { - // if it contains the key, avoid appending a duplicate value - if !multi_value_headers.contains_key(&first_key) { - multi_value_headers.append(first_key, value); - } - } - } - - // no builder method that sets headers in batch - let _ = mem::replace(req.headers_mut(), multi_value_headers); - - req - } - LambdaRequest::Alb { - path, - http_method, - headers, - mut multi_value_headers, - query_string_parameters, - multi_value_query_string_parameters, - body, - is_base64_encoded, - request_context, - } => { - // build an http::Request from a lambda_http::LambdaRequest - let builder = http::Request::builder() - .method(http_method) - .uri({ - let host = headers.get(http::header::HOST).and_then(|val| val.to_str().ok()); - let mut uri = match host { - Some(host) => { - format!( - "{}://{}{}", - headers - .get("X-Forwarded-Proto") - .and_then(|val| val.to_str().ok()) - .unwrap_or("https"), - host, - path - ) - } - None => path.to_string(), - }; - - if !multi_value_query_string_parameters.is_empty() { - uri.push('?'); - uri.push_str(multi_value_query_string_parameters.to_query_string().as_str()); - } else if !query_string_parameters.is_empty() { - uri.push('?'); - uri.push_str(query_string_parameters.to_query_string().as_str()); - } - - uri - }) - // multi valued query string parameters are always a super - // set of singly valued query string parameters, - // when present, multi-valued query string parameters are preferred - .extension(QueryStringParameters( - if multi_value_query_string_parameters.is_empty() { - query_string_parameters - } else { - multi_value_query_string_parameters - }, - )) - .extension(RequestContext::Alb(request_context)); - - let mut req = builder - .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b))) - .expect("failed to build request"); - - // merge headers into multi_value_headers and make - // multi-value_headers our cannoncial source of request headers - for (key, value) in headers { - // see HeaderMap#into_iter() docs for cases when key element may be None - if let Some(first_key) = key { - // if it contains the key, avoid appending a duplicate value - if !multi_value_headers.contains_key(&first_key) { - multi_value_headers.append(first_key, value); - } - } - } - - // no builder method that sets headers in batch - let _ = mem::replace(req.headers_mut(), multi_value_headers); - - req - } + LambdaRequest::ApiGatewayV2(ag) => into_api_gateway_v2_request(ag), + LambdaRequest::ApiGatewayV1(ag) => into_proxy_request(ag), + LambdaRequest::Alb(alb) => into_alb_request(alb), + LambdaRequest::WebSocket(ag) => into_websocket_request(ag), } } } @@ -638,12 +377,15 @@ pub fn from_str(s: &str) -> Result { serde_json::from_str(s).map(LambdaRequest::into) } +fn x_forwarded_proto() -> HeaderName { + HeaderName::from_static("x-forwarded-proto") +} + #[cfg(test)] mod tests { use super::*; use crate::RequestExt; - use serde_json; - use std::{collections::HashMap, fs::File}; + use std::fs::File; #[test] fn deserializes_apigw_request_events_from_readables() { @@ -734,14 +476,14 @@ mod tests { assert_eq!(req.method(), "GET"); assert_eq!( req.uri(), - "https://wt6mne2s9k.execute-api.us-west-2.amazonaws.com/test/hello?name=me" + "https://wt6mne2s9k.execute-api.us-west-2.amazonaws.com/test/test/hello?name=me" ); // Ensure this is an APIGW request let req_context = req.request_context(); assert!( match req_context { - RequestContext::ApiGateway(_) => true, + RequestContext::ApiGatewayV1(_) => true, _ => false, }, "expected ApiGateway context, got {:?}", @@ -798,7 +540,7 @@ mod tests { // test RequestExt#query_string_parameters does the right thing assert_eq!( - request.query_string_parameters().get_all("multivalueName"), + request.query_string_parameters().all("multivalueName"), Some(vec!["you", "me"]) ); } @@ -820,7 +562,7 @@ mod tests { // test RequestExt#query_string_parameters does the right thing assert_eq!( - request.query_string_parameters().get_all("myKey"), + request.query_string_parameters().all("myKey"), Some(vec!["val1", "val2"]) ); } @@ -861,66 +603,6 @@ mod tests { ); let req = result.expect("failed to parse request"); assert_eq!(req.method(), "GET"); - assert_eq!(req.uri(), "/test/hello?name=me"); - } - - #[test] - fn deserialize_with_null() { - #[derive(Debug, PartialEq, Deserialize)] - struct Test { - #[serde(deserialize_with = "nullable_default")] - foo: HashMap, - } - - assert_eq!( - serde_json::from_str::(r#"{"foo":null}"#).expect("failed to deserialize"), - Test { foo: HashMap::new() } - ) - } - - #[test] - fn deserialize_with_missing() { - #[derive(Debug, PartialEq, Deserialize)] - struct Test { - #[serde(default, deserialize_with = "nullable_default")] - foo: HashMap, - } - - assert_eq!( - serde_json::from_str::(r#"{}"#).expect("failed to deserialize"), - Test { foo: HashMap::new() } - ) - } - - #[test] - fn deserialize_null_headers() { - #[derive(Debug, PartialEq, Deserialize)] - struct Test { - #[serde(deserialize_with = "deserialize_headers")] - headers: http::HeaderMap, - } - - assert_eq!( - serde_json::from_str::(r#"{"headers":null}"#).expect("failed to deserialize"), - Test { - headers: http::HeaderMap::new() - } - ) - } - - #[test] - fn deserialize_null_multi_value_headers() { - #[derive(Debug, PartialEq, Deserialize)] - struct Test { - #[serde(deserialize_with = "deserialize_multi_value_headers")] - multi_value_headers: http::HeaderMap, - } - - assert_eq!( - serde_json::from_str::(r#"{"multi_value_headers":null}"#).expect("failed to deserialize"), - Test { - multi_value_headers: http::HeaderMap::new() - } - ) + assert_eq!(req.uri(), "/test/test/hello?name=me"); } } diff --git a/lambda-http/src/response.rs b/lambda-http/src/response.rs index 8c3e2a1d..4ea9c895 100644 --- a/lambda-http/src/response.rs +++ b/lambda-http/src/response.rs @@ -1,108 +1,23 @@ //! Response types -use crate::{body::Body, request::RequestOrigin}; +use crate::request::RequestOrigin; +use aws_lambda_events::encodings::Body; +use aws_lambda_events::event::alb::AlbTargetGroupResponse; +use aws_lambda_events::event::apigw::{ApiGatewayProxyResponse, ApiGatewayV2httpResponse}; use http::{ - header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE}, + header::{CONTENT_TYPE, SET_COOKIE}, Response, }; -use serde::{ - ser::{Error as SerError, SerializeMap, SerializeSeq}, - Serialize, Serializer, -}; +use serde::Serialize; /// Representation of Lambda response #[doc(hidden)] #[derive(Serialize, Debug)] #[serde(untagged)] pub enum LambdaResponse { - ApiGatewayV2(ApiGatewayV2Response), - Alb(AlbResponse), - ApiGateway(ApiGatewayResponse), -} - -/// Representation of API Gateway v2 lambda response -#[doc(hidden)] -#[derive(Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct ApiGatewayV2Response { - status_code: u16, - #[serde(serialize_with = "serialize_headers")] - headers: HeaderMap, - #[serde(serialize_with = "serialize_headers_slice")] - cookies: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - body: Option, - is_base64_encoded: bool, -} - -/// Representation of ALB lambda response -#[doc(hidden)] -#[derive(Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct AlbResponse { - status_code: u16, - status_description: String, - #[serde(serialize_with = "serialize_headers")] - headers: HeaderMap, - #[serde(skip_serializing_if = "Option::is_none")] - body: Option, - is_base64_encoded: bool, -} - -/// Representation of API Gateway lambda response -#[doc(hidden)] -#[derive(Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct ApiGatewayResponse { - status_code: u16, - #[serde(serialize_with = "serialize_headers")] - headers: HeaderMap, - #[serde(serialize_with = "serialize_multi_value_headers")] - multi_value_headers: HeaderMap, - #[serde(skip_serializing_if = "Option::is_none")] - body: Option, - is_base64_encoded: bool, -} - -/// Serialize a http::HeaderMap into a serde str => str map -fn serialize_multi_value_headers(headers: &HeaderMap, serializer: S) -> Result -where - S: Serializer, -{ - let mut map = serializer.serialize_map(Some(headers.keys_len()))?; - for key in headers.keys() { - let mut map_values = Vec::new(); - for value in headers.get_all(key) { - map_values.push(value.to_str().map_err(S::Error::custom)?) - } - map.serialize_entry(key.as_str(), &map_values)?; - } - map.end() -} - -/// Serialize a http::HeaderMap into a serde str => Vec map -fn serialize_headers(headers: &HeaderMap, serializer: S) -> Result -where - S: Serializer, -{ - let mut map = serializer.serialize_map(Some(headers.keys_len()))?; - for key in headers.keys() { - let map_value = headers[key].to_str().map_err(S::Error::custom)?; - map.serialize_entry(key.as_str(), map_value)?; - } - map.end() -} - -/// Serialize a &[HeaderValue] into a Vec -fn serialize_headers_slice(headers: &[HeaderValue], serializer: S) -> Result -where - S: Serializer, -{ - let mut seq = serializer.serialize_seq(Some(headers.len()))?; - for header in headers { - seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?; - } - seq.end() + ApiGatewayV2(ApiGatewayV2httpResponse), + ApiGatewayV1(ApiGatewayProxyResponse), + Alb(AlbTargetGroupResponse), } /// tranformation from http type to internal type @@ -125,34 +40,48 @@ impl LambdaResponse { RequestOrigin::ApiGatewayV2 => { // ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute, // so remove them from the headers. - let cookies: Vec = headers.get_all(SET_COOKIE).iter().cloned().collect(); + let cookies = headers + .get_all(SET_COOKIE) + .iter() + .cloned() + .map(|v| v.to_str().ok().unwrap_or_default().to_string()) + .collect(); headers.remove(SET_COOKIE); - LambdaResponse::ApiGatewayV2(ApiGatewayV2Response { + LambdaResponse::ApiGatewayV2(ApiGatewayV2httpResponse { body, - status_code, - is_base64_encoded, + status_code: status_code as i64, + is_base64_encoded: Some(is_base64_encoded), cookies, - headers, + headers: headers.clone(), + multi_value_headers: headers, }) } - RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse { + RequestOrigin::ApiGatewayV1 => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse { body, - status_code, - is_base64_encoded, + status_code: status_code as i64, + is_base64_encoded: Some(is_base64_encoded), headers: headers.clone(), multi_value_headers: headers, }), - RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse { + RequestOrigin::Alb => LambdaResponse::Alb(AlbTargetGroupResponse { body, - status_code, + status_code: status_code as i64, is_base64_encoded, - headers, - status_description: format!( + headers: headers.clone(), + multi_value_headers: headers, + status_description: Some(format!( "{} {}", status_code, parts.status.canonical_reason().unwrap_or_default() - ), + )), + }), + RequestOrigin::WebSocket => LambdaResponse::ApiGatewayV1(ApiGatewayProxyResponse { + body, + status_code: status_code as i64, + is_base64_encoded: Some(is_base64_encoded), + headers: headers.clone(), + multi_value_headers: headers, }), } } @@ -189,12 +118,15 @@ where } } -impl IntoResponse for B -where - B: Into, -{ +impl IntoResponse for String { + fn into_response(self) -> Response { + Response::new(Body::from(self)) + } +} + +impl IntoResponse for &str { fn into_response(self) -> Response { - Response::new(self.into()) + Response::new(Body::from(self)) } } @@ -213,42 +145,10 @@ impl IntoResponse for serde_json::Value { #[cfg(test)] mod tests { - use super::{ - AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin, - }; + use super::{Body, IntoResponse, LambdaResponse, RequestOrigin}; use http::{header::CONTENT_TYPE, Response}; use serde_json::{self, json}; - fn api_gateway_response() -> ApiGatewayResponse { - ApiGatewayResponse { - status_code: 200, - headers: Default::default(), - multi_value_headers: Default::default(), - body: Default::default(), - is_base64_encoded: Default::default(), - } - } - - fn alb_response() -> AlbResponse { - AlbResponse { - status_code: 200, - status_description: "200 OK".to_string(), - headers: Default::default(), - body: Default::default(), - is_base64_encoded: Default::default(), - } - } - - fn api_gateway_v2_response() -> ApiGatewayV2Response { - ApiGatewayV2Response { - status_code: 200, - headers: Default::default(), - body: Default::default(), - cookies: Default::default(), - is_base64_encoded: Default::default(), - } - } - #[test] fn json_into_response() { let response = json!({ "hello": "lambda"}).into_response(); @@ -274,40 +174,10 @@ mod tests { } } - #[test] - fn serialize_body_for_api_gateway() { - let mut resp = api_gateway_response(); - resp.body = Some("foo".into()); - assert_eq!( - serde_json::to_string(&resp).expect("failed to serialize response"), - r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"# - ); - } - - #[test] - fn serialize_body_for_alb() { - let mut resp = alb_response(); - resp.body = Some("foo".into()); - assert_eq!( - serde_json::to_string(&resp).expect("failed to serialize response"), - r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"# - ); - } - - #[test] - fn serialize_body_for_api_gateway_v2() { - let mut resp = api_gateway_v2_response(); - resp.body = Some("foo".into()); - assert_eq!( - serde_json::to_string(&resp).expect("failed to serialize response"), - r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"# - ); - } - #[test] fn serialize_multi_value_headers() { let res = LambdaResponse::from_response( - &RequestOrigin::ApiGateway, + &RequestOrigin::ApiGatewayV1, Response::builder() .header("multi", "a") .header("multi", "b") @@ -333,8 +203,8 @@ mod tests { ); let json = serde_json::to_string(&res).expect("failed to serialize to json"); assert_eq!( - json, - r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"# + "{\"statusCode\":200,\"headers\":{},\"multiValueHeaders\":{},\"isBase64Encoded\":false,\"cookies\":[\"cookie1=a\",\"cookie2=b\"]}", + json ) } } diff --git a/lambda-http/src/strmap.rs b/lambda-http/src/strmap.rs deleted file mode 100644 index 066c575a..00000000 --- a/lambda-http/src/strmap.rs +++ /dev/null @@ -1,221 +0,0 @@ -use serde::{ - de::{MapAccess, Visitor}, - Deserialize, Deserializer, -}; -use std::{ - collections::{hash_map::Keys, HashMap}, - fmt, - sync::Arc, -}; - -/// A read-only view into a map of string data which may contain multiple values -/// -/// Internally data is always represented as many valued -#[derive(Default, Debug, PartialEq)] -pub struct StrMap(pub(crate) Arc>>); - -impl StrMap { - /// Return a named value where available. - /// If there is more than one value associated with this name, - /// the first one will be returned - pub fn get(&self, key: &str) -> Option<&str> { - self.0 - .get(key) - .and_then(|values| values.first().map(|owned| owned.as_str())) - } - - /// Return all values associated with name where available - pub fn get_all(&self, key: &str) -> Option> { - self.0 - .get(key) - .map(|values| values.iter().map(|owned| owned.as_str()).collect::>()) - } - - /// Return true if the underlying map is empty - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - /// Return an iterator over keys and values - pub fn iter(&self) -> StrMapIter<'_> { - StrMapIter { - data: self, - keys: self.0.keys(), - current: None, - next_idx: 0, - } - } - - /// Return the URI query representation for this map - pub fn to_query_string(&self) -> String { - if self.is_empty() { - "".into() - } else { - self.iter() - .map(|(k, v)| format!("{}={}", k, v)) - .collect::>() - .join("&") - } - } -} - -impl Clone for StrMap { - fn clone(&self) -> Self { - // only clone the inner data - StrMap(self.0.clone()) - } -} - -impl From>> for StrMap { - fn from(inner: HashMap>) -> Self { - StrMap(Arc::new(inner)) - } -} - -/// A read only reference to `StrMap` key and value slice pairings -pub struct StrMapIter<'a> { - data: &'a StrMap, - keys: Keys<'a, String, Vec>, - current: Option<(&'a String, Vec<&'a str>)>, - next_idx: usize, -} - -impl<'a> Iterator for StrMapIter<'a> { - type Item = (&'a str, &'a str); - - #[inline] - fn next(&mut self) -> Option<(&'a str, &'a str)> { - if self.current.is_none() { - self.current = self.keys.next().map(|k| (k, self.data.get_all(k).unwrap_or_default())); - }; - - let mut reset = false; - let ret = if let Some((key, values)) = &self.current { - let value = values[self.next_idx]; - - if self.next_idx + 1 < values.len() { - self.next_idx += 1; - } else { - reset = true; - } - - Some((key.as_str(), value)) - } else { - None - }; - - if reset { - self.current = None; - self.next_idx = 0; - } - - ret - } -} - -/// internal type used when deserializing StrMaps from -/// potentially one or many valued maps -#[derive(Deserialize)] -#[serde(untagged)] -enum OneOrMany { - One(String), - Many(Vec), -} - -impl<'de> Deserialize<'de> for StrMap { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct StrMapVisitor; - - impl<'de> Visitor<'de> for StrMapVisitor { - type Value = StrMap; - - fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(formatter, "a StrMap") - } - - fn visit_map(self, mut map: A) -> Result - where - A: MapAccess<'de>, - { - let mut inner = map.size_hint().map(HashMap::with_capacity).unwrap_or_else(HashMap::new); - // values may either be String or Vec - // to handle both single and multi value data - while let Some((key, value)) = map.next_entry::<_, OneOrMany>()? { - inner.insert( - key, - match value { - OneOrMany::One(one) => vec![one], - OneOrMany::Many(many) => many, - }, - ); - } - Ok(StrMap(Arc::new(inner))) - } - } - - deserializer.deserialize_map(StrMapVisitor) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::HashMap; - - #[test] - fn str_map_default_is_empty() { - assert!(StrMap::default().is_empty()) - } - - #[test] - fn str_map_get() { - let mut data = HashMap::new(); - data.insert("foo".into(), vec!["bar".into()]); - let strmap = StrMap(data.into()); - assert_eq!(strmap.get("foo"), Some("bar")); - assert_eq!(strmap.get("bar"), None); - } - - #[test] - fn str_map_get_all() { - let mut data = HashMap::new(); - data.insert("foo".into(), vec!["bar".into(), "baz".into()]); - let strmap = StrMap(data.into()); - assert_eq!(strmap.get_all("foo"), Some(vec!["bar", "baz"])); - assert_eq!(strmap.get_all("bar"), None); - } - - #[test] - fn str_map_iter() { - let mut data = HashMap::new(); - data.insert("foo".into(), vec!["bar".into()]); - data.insert("baz".into(), vec!["boom".into()]); - let strmap = StrMap(data.into()); - let mut values = strmap.iter().map(|(_, v)| v).collect::>(); - values.sort(); - assert_eq!(values, vec!["bar", "boom"]); - } - - #[test] - fn test_empty_str_map_to_query_string() { - let data = HashMap::new(); - let strmap = StrMap(data.into()); - let query = strmap.to_query_string(); - assert_eq!("", &query); - } - - #[test] - fn test_str_map_to_query_string() { - let mut data = HashMap::new(); - data.insert("foo".into(), vec!["bar".into(), "qux".into()]); - data.insert("baz".into(), vec!["quux".into()]); - - let strmap = StrMap(data.into()); - let query = strmap.to_query_string(); - assert!(query.contains("foo=bar&foo=qux")); - assert!(query.contains("baz=quux")); - } -} diff --git a/lambda-http/tests/data/apigw_multi_value_proxy_request.json b/lambda-http/tests/data/apigw_multi_value_proxy_request.json index 5b254c8b..8f84aeb9 100644 --- a/lambda-http/tests/data/apigw_multi_value_proxy_request.json +++ b/lambda-http/tests/data/apigw_multi_value_proxy_request.json @@ -51,9 +51,6 @@ "CloudFront-Viewer-Country":[ "US" ], - "":[ - "" - ], "Content-Type":[ "application/json" ], diff --git a/lambda-integration-tests/src/bin/logs-trait.rs b/lambda-integration-tests/src/bin/logs-trait.rs index a9bbe7d5..3f5a4909 100644 --- a/lambda-integration-tests/src/bin/logs-trait.rs +++ b/lambda-integration-tests/src/bin/logs-trait.rs @@ -28,14 +28,15 @@ impl MyLogsProcessor { } } +type MyLogsFuture = Pin> + Send>>; + /// Implementation of the actual log processor /// /// This receives a `Vec` whenever there are new log entries available. impl Service> for MyLogsProcessor { type Response = (); type Error = Error; - #[allow(clippy::type_complexity)] - type Future = Pin> + Send>>; + type Future = MyLogsFuture; fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll> { Poll::Ready(Ok(()))