From 73dcc7af13cbc3d32c9801a94a5905b6969ffa8a Mon Sep 17 00:00:00 2001 From: David Calavera Date: Fri, 24 Nov 2023 15:11:52 -0800 Subject: [PATCH] Extract the request ID without allocating extra memory. Changes the way that the Context is initialized to receive the request ID as an argument. This way we also avoid allocating additional memory for it. Signed-off-by: David Calavera --- lambda-runtime/src/lib.rs | 19 ++------ lambda-runtime/src/types.rs | 93 ++++++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index 5404fb96..ccd35ab0 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -17,7 +17,6 @@ use hyper::{ use lambda_runtime_api_client::Client; use serde::{Deserialize, Serialize}; use std::{ - convert::TryFrom, env, fmt::{self, Debug, Display}, future::Future, @@ -41,6 +40,8 @@ mod types; use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest}; pub use types::{Context, FunctionResponse, IntoFunctionResponse, LambdaEvent, MetadataPrelude, StreamResponse}; +use types::invoke_request_id; + /// Error type that lambdas may result in pub type Error = lambda_runtime_api_client::Error; @@ -121,6 +122,7 @@ where trace!("New event arrived (run loop)"); let event = next_event_response?; let (parts, body) = event.into_parts(); + let request_id = invoke_request_id(&parts.headers)?; #[cfg(debug_assertions)] if parts.status == http::StatusCode::NO_CONTENT { @@ -130,19 +132,8 @@ where continue; } - let ctx: Context = Context::try_from((self.config.clone(), parts.headers))?; - let request_id = &ctx.request_id.clone(); - - let request_span = match &ctx.xray_trace_id { - Some(trace_id) => { - env::set_var("_X_AMZN_TRACE_ID", trace_id); - tracing::info_span!("Lambda runtime invoke", requestId = request_id, xrayTraceId = trace_id) - } - None => { - env::remove_var("_X_AMZN_TRACE_ID"); - tracing::info_span!("Lambda runtime invoke", requestId = request_id) - } - }; + let ctx: Context = Context::new(request_id, self.config.clone(), &parts.headers)?; + let request_span = ctx.request_span(); // Group the handling in one future and instrument it with the span async { diff --git a/lambda-runtime/src/types.rs b/lambda-runtime/src/types.rs index a252475b..82d9b21f 100644 --- a/lambda-runtime/src/types.rs +++ b/lambda-runtime/src/types.rs @@ -1,15 +1,16 @@ use crate::{Error, RefConfig}; use base64::prelude::*; use bytes::Bytes; -use http::{HeaderMap, HeaderValue, StatusCode}; +use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, - convert::TryFrom, + env, fmt::Debug, time::{Duration, SystemTime}, }; use tokio_stream::Stream; +use tracing::Span; #[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -120,11 +121,10 @@ pub struct Context { pub env_config: RefConfig, } -impl TryFrom<(RefConfig, HeaderMap)> for Context { - type Error = Error; - fn try_from(data: (RefConfig, HeaderMap)) -> Result { - let env_config = data.0; - let headers = data.1; +impl Context { + /// Create a new [Context] struct based on the fuction configuration + /// and the incoming request data. + pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result { let client_context: Option = if let Some(value) = headers.get("lambda-runtime-client-context") { serde_json::from_str(value.to_str()?)? } else { @@ -138,11 +138,7 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context { }; let ctx = Context { - request_id: headers - .get("lambda-runtime-aws-request-id") - .expect("missing lambda-runtime-aws-request-id header") - .to_str()? - .to_owned(), + request_id: request_id.to_owned(), deadline: headers .get("lambda-runtime-deadline-ms") .expect("missing lambda-runtime-deadline-ms header") @@ -165,13 +161,37 @@ impl TryFrom<(RefConfig, HeaderMap)> for Context { Ok(ctx) } -} -impl Context { /// The execution deadline for the current invocation. pub fn deadline(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_millis(self.deadline) } + + /// Create a new [`tracing::Span`] for an incoming invocation. + pub(crate) fn request_span(&self) -> Span { + match &self.xray_trace_id { + Some(trace_id) => { + env::set_var("_X_AMZN_TRACE_ID", trace_id); + tracing::info_span!( + "Lambda runtime invoke", + requestId = &self.request_id, + xrayTraceId = trace_id + ) + } + None => { + env::remove_var("_X_AMZN_TRACE_ID"); + tracing::info_span!("Lambda runtime invoke", requestId = &self.request_id) + } + } + } +} + +/// Extract the invocation request id from the incoming request. +pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> { + headers + .get("lambda-runtime-aws-request-id") + .expect("missing lambda-runtime-aws-request-id header") + .to_str() } /// Incoming Lambda request containing the event payload and context. @@ -313,7 +333,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); } @@ -324,7 +344,7 @@ mod test { let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); } @@ -355,7 +375,7 @@ mod test { ); let config = Arc::new(Config::default()); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.client_context.is_some()); @@ -369,7 +389,7 @@ mod test { headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-client-context", HeaderValue::from_static("{}")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); assert!(tried.unwrap().client_context.is_some()); } @@ -390,7 +410,7 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_str(&cognito_identity_str).unwrap(), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_ok()); let tried = tried.unwrap(); assert!(tried.identity.is_some()); @@ -412,7 +432,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -427,7 +447,7 @@ mod test { "lambda-runtime-client-context", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -439,7 +459,7 @@ mod test { headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert("lambda-runtime-cognito-identity", HeaderValue::from_static("{}")); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } @@ -454,14 +474,13 @@ mod test { "lambda-runtime-cognito-identity", HeaderValue::from_static("BAD-Type,not JSON"), ); - let tried = Context::try_from((config, headers)); + let tried = Context::new("id", config, &headers); assert!(tried.is_err()); } #[test] #[should_panic] - #[allow(unused_must_use)] - fn context_with_missing_request_id_should_panic() { + fn context_with_missing_deadline_should_panic() { let config = Arc::new(Config::default()); let mut headers = HeaderMap::new(); @@ -471,15 +490,26 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from((config, headers)); + let _ = Context::new("id", config, &headers); } #[test] - #[should_panic] - #[allow(unused_must_use)] - fn context_with_missing_deadline_should_panic() { - let config = Arc::new(Config::default()); + fn invoke_request_id_should_not_panic() { + let mut headers = HeaderMap::new(); + headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id")); + headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); + headers.insert( + "lambda-runtime-invoked-function-arn", + HeaderValue::from_static("arn::myarn"), + ); + headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); + + let _ = invoke_request_id(&headers); + } + #[test] + #[should_panic] + fn invoke_request_id_should_panic() { let mut headers = HeaderMap::new(); headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123")); headers.insert( @@ -487,6 +517,7 @@ mod test { HeaderValue::from_static("arn::myarn"), ); headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn")); - Context::try_from((config, headers)); + + let _ = invoke_request_id(&headers); } }