Skip to content

Commit 19e1ceb

Browse files
committed
Add Lambda Extension crate
This new crate encapsulates the logic to create Lambda Extensions. It includes reference examples. Signed-off-by: David Calavera <david.calavera@gmail.com>
1 parent 046d380 commit 19e1ceb

File tree

6 files changed

+324
-0
lines changed

6 files changed

+324
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ members = [
33
"lambda-http",
44
"lambda-runtime-client",
55
"lambda-runtime",
6+
"lambda-extension"
67
]

lambda-extension/Cargo.toml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
[package]
2+
name = "lambda_extension"
3+
version = "0.1.0"
4+
edition = "2021"
5+
authors = ["David Calavera <david.calavera@gmail.com>"]
6+
description = "AWS Lambda Extension API"
7+
license = "Apache-2.0"
8+
repository = "https://github.com/awslabs/aws-lambda-rust-runtime"
9+
categories = ["web-programming::http-server"]
10+
keywords = ["AWS", "Lambda", "API"]
11+
readme = "../README.md"
12+
13+
[dependencies]
14+
anyhow = "1.0.48"
15+
async-trait = "0.1.51"
16+
tokio = { version = "1.0", features = ["macros", "io-util", "sync", "rt-multi-thread"] }
17+
hyper = { version = "0.14", features = ["http1", "client", "server", "stream", "runtime"] }
18+
serde = { version = "1", features = ["derive"] }
19+
serde_json = "^1"
20+
bytes = "1.0"
21+
http = "0.2"
22+
async-stream = "0.3"
23+
futures = "0.3"
24+
tracing-error = "0.2"
25+
tracing = { version = "0.1", features = ["log"] }
26+
tower-service = "0.3"
27+
tokio-stream = "0.1.2"
28+
lambda_runtime_client = { version = "*", path = "../lambda-runtime-client" }
29+
30+
[dev-dependencies]
31+
tracing-subscriber = "0.3"
32+
once_cell = "1.4.0"
33+
simple_logger = "1.6.0"
34+
log = "^0.4"
35+
simple-error = "0.2"

lambda-extension/examples/basic.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use lambda_extension::{run, Error, Extension, InvokeEvent, ShutdownEvent};
2+
use log::LevelFilter;
3+
use simple_logger::SimpleLogger;
4+
5+
struct BasicExtension {}
6+
7+
#[async_trait::async_trait]
8+
impl Extension for BasicExtension {
9+
async fn on_invoke(&self, _extension_id: &str, _event: InvokeEvent) -> Result<(), Error> {
10+
Ok(())
11+
}
12+
13+
async fn on_shutdown(&self, _extension_id: &str, _event: ShutdownEvent) -> Result<(), Error> {
14+
Ok(())
15+
}
16+
}
17+
18+
#[tokio::main]
19+
async fn main() -> Result<(), Error> {
20+
// required to enable CloudWatch error logging by the runtime
21+
// can be replaced with any other method of initializing `log`
22+
SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap();
23+
24+
run(BasicExtension {}).await?;
25+
Ok(())
26+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
use lambda_extension::{run, Error, Extension, InvokeEvent, ShutdownEvent, requests::exit_error, requests::init_error};
2+
use lambda_runtime_client::Client;
3+
use log::LevelFilter;
4+
use simple_logger::SimpleLogger;
5+
6+
#[derive(Debug)]
7+
enum ErrorExample {
8+
OnInvokeError,
9+
OnShutdownError,
10+
}
11+
12+
impl std::error::Error for ErrorExample {}
13+
14+
impl std::fmt::Display for ErrorExample {
15+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
16+
match self {
17+
ErrorExample::OnInvokeError => write!(f, "error processing invocation call"),
18+
ErrorExample::OnShutdownError => write!(f, "error processing shutdown call"),
19+
}
20+
}
21+
}
22+
23+
struct ErrorHandlingExtension {
24+
client: Client
25+
}
26+
27+
#[async_trait::async_trait]
28+
impl Extension for ErrorHandlingExtension {
29+
async fn on_invoke(&self, extension_id: &str, _event: InvokeEvent) -> Result<(), Error> {
30+
let err = ErrorExample::OnInvokeError;
31+
let req = init_error(extension_id, &format!("{}", err), None)?;
32+
self.client.call(req).await?;
33+
Err(Box::new(err))
34+
}
35+
36+
async fn on_shutdown(&self, extension_id: &str, _event: ShutdownEvent) -> Result<(), Error> {
37+
let err = ErrorExample::OnShutdownError;
38+
let req = exit_error(extension_id, &format!("{}", err), None)?;
39+
self.client.call(req).await?;
40+
Err(Box::new(err))
41+
}
42+
}
43+
44+
#[tokio::main]
45+
async fn main() -> Result<(), Error> {
46+
// required to enable CloudWatch error logging by the runtime
47+
// can be replaced with any other method of initializing `log`
48+
SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap();
49+
50+
let client = Client::builder().build()?;
51+
let extension = ErrorHandlingExtension { client };
52+
53+
run(extension).await?;
54+
Ok(())
55+
}

lambda-extension/src/lib.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// #![deny(clippy::all, clippy::cargo)]
2+
// #![warn(missing_docs,? nonstandard_style, rust_2018_idioms)]
3+
4+
use async_trait::async_trait;
5+
use hyper::client::{connect::Connection, HttpConnector};
6+
use lambda_runtime_client::Client;
7+
use serde::Deserialize;
8+
use tokio::io::{AsyncRead, AsyncWrite};
9+
use tokio_stream::{StreamExt};
10+
use tower_service::Service;
11+
use tracing::trace;
12+
13+
pub mod requests;
14+
15+
pub type Error = lambda_runtime_client::Error;
16+
17+
#[derive(Debug, Deserialize)]
18+
#[serde(rename_all = "camelCase")]
19+
pub struct Tracing {
20+
pub r#type: String,
21+
pub value: String,
22+
}
23+
24+
#[derive(Debug, Deserialize)]
25+
#[serde(rename_all = "camelCase")]
26+
pub struct InvokeEvent {
27+
deadline_ms: u64,
28+
request_id: String,
29+
invoked_function_arn: String,
30+
tracing: Tracing,
31+
}
32+
33+
#[derive(Debug, Deserialize)]
34+
#[serde(rename_all = "camelCase")]
35+
pub struct ShutdownEvent {
36+
shutdown_reason: String,
37+
deadline_ms: u64,
38+
}
39+
40+
#[derive(Debug, Deserialize)]
41+
#[serde(rename_all = "UPPERCASE", tag = "eventType")]
42+
pub enum NextEvent {
43+
Invoke(InvokeEvent),
44+
Shutdown(ShutdownEvent),
45+
}
46+
47+
/// A trait describing an asynchronous extension.
48+
#[async_trait]
49+
pub trait Extension {
50+
async fn on_invoke(&self, extension_id: &str, event: InvokeEvent) -> Result<(), Error>;
51+
async fn on_shutdown(&self, extension_id: &str, event: ShutdownEvent) -> Result<(), Error>;
52+
}
53+
54+
struct Runtime<'a, C: Service<http::Uri> = HttpConnector> {
55+
extension_id: &'a str,
56+
client: Client<C>,
57+
}
58+
59+
impl<'a, C> Runtime<'a, C>
60+
where
61+
C: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
62+
<C as Service<http::Uri>>::Future: Unpin + Send,
63+
<C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
64+
<C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
65+
{
66+
pub async fn run(&self, extension: impl Extension) -> Result<(), Error> {
67+
let client = &self.client;
68+
let extension_id = self.extension_id;
69+
70+
let incoming = async_stream::stream! {
71+
loop {
72+
trace!("Waiting for next event (incoming loop)");
73+
let req = requests::next_event_request(extension_id)?;
74+
let res = client.call(req).await;
75+
yield res;
76+
}
77+
};
78+
79+
tokio::pin!(incoming);
80+
while let Some(event) = incoming.next().await {
81+
trace!("New event arrived (run loop)");
82+
let event = event?;
83+
let (_parts, body) = event.into_parts();
84+
85+
let body = hyper::body::to_bytes(body).await?;
86+
trace!("{}", std::str::from_utf8(&body)?); // this may be very verbose
87+
let event: NextEvent = serde_json::from_slice(&body)?;
88+
89+
match event {
90+
NextEvent::Invoke(event) => {
91+
extension.on_invoke(extension_id, event).await?;
92+
}
93+
NextEvent::Shutdown(event) => {
94+
extension.on_shutdown(extension_id, event).await?;
95+
}
96+
};
97+
}
98+
99+
Ok(())
100+
}
101+
}
102+
103+
async fn register<C>(client: &Client<C>, extension_name: &str) -> Result<String, Error>
104+
where
105+
C: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
106+
<C as Service<http::Uri>>::Future: Unpin + Send,
107+
<C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
108+
<C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
109+
{
110+
let req = requests::register_request(extension_name)?;
111+
let res = client.call(req).await?;
112+
// ensure!(res.status() == http::StatusCode::OK, "Unable to register extension",);
113+
114+
let ext_id = res.headers().get(requests::EXTENSION_ID_HEADER).unwrap().to_str()?;
115+
Ok(ext_id.into())
116+
}
117+
118+
pub async fn run(extension: impl Extension) -> Result<(), Error> {
119+
let args: Vec<String> = std::env::args().collect();
120+
121+
let client = Client::builder().build().expect("Unable to create a runtime client");
122+
let extension_id = register(&client, &args[0]).await?;
123+
let runtime = Runtime {
124+
extension_id: &extension_id,
125+
client,
126+
};
127+
128+
runtime.run(extension).await
129+
}

lambda-extension/src/requests.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use crate::Error;
2+
use http::{Method, Request};
3+
use hyper::Body;
4+
use lambda_runtime_client::build_request;
5+
use serde::Serialize;
6+
7+
const EXTENSION_NAME_HEADER: &str = "Lambda-Extension-Name";
8+
pub(crate) const EXTENSION_ID_HEADER: &str = "Lambda-Extension-Identifier";
9+
const EXTENSION_ERROR_TYPE_HEADER: &str = "Lambda-Extension-Function-Error-Type";
10+
11+
pub(crate) fn next_event_request(extension_id: &str) -> Result<Request<Body>, Error> {
12+
let req = build_request()
13+
.method(Method::GET)
14+
.header(EXTENSION_ID_HEADER, extension_id)
15+
.uri("/2020-01-01/extension/event/next")
16+
.body(Body::empty())?;
17+
Ok(req)
18+
}
19+
20+
pub(crate) fn register_request(extension_name: &str) -> Result<Request<Body>, Error> {
21+
let events = serde_json::json!({
22+
"events": ["INVOKE", "SHUTDOWN"]
23+
});
24+
25+
let req = build_request()
26+
.method(Method::POST)
27+
.uri("/2020-01-01/extension/register")
28+
.header(EXTENSION_NAME_HEADER, extension_name)
29+
.body(Body::from(serde_json::to_string(&events)?))?;
30+
31+
Ok(req)
32+
}
33+
34+
#[derive(Debug, Serialize)]
35+
#[serde(rename_all = "camelCase")]
36+
pub struct ErrorRequest<'a> {
37+
error_message: &'a str,
38+
error_type: &'a str,
39+
stack_trace: Vec<&'a str>,
40+
}
41+
42+
pub fn init_error<'a>(
43+
extension_id: &str,
44+
error_type: &str,
45+
request: Option<ErrorRequest<'a>>,
46+
) -> Result<Request<Body>, Error> {
47+
error_request("init", extension_id, error_type, request)
48+
}
49+
50+
pub fn exit_error<'a>(
51+
extension_id: &str,
52+
error_type: &str,
53+
request: Option<ErrorRequest<'a>>,
54+
) -> Result<Request<Body>, Error> {
55+
error_request("exit", extension_id, error_type, request)
56+
}
57+
58+
fn error_request<'a>(
59+
error_type: &str,
60+
extension_id: &str,
61+
error_str: &str,
62+
request: Option<ErrorRequest<'a>>,
63+
) -> Result<Request<Body>, Error> {
64+
let uri = format!("/2020-01-01/extension/{}/error", error_type);
65+
66+
let body = match request {
67+
None => Body::empty(),
68+
Some(err) => Body::from(serde_json::to_string(&err)?),
69+
};
70+
71+
let req = build_request()
72+
.method(Method::POST)
73+
.uri(uri)
74+
.header(EXTENSION_ID_HEADER, extension_id)
75+
.header(EXTENSION_ERROR_TYPE_HEADER, error_str)
76+
.body(body)?;
77+
Ok(req)
78+
}

0 commit comments

Comments
 (0)