/
main.rs
131 lines (109 loc) · 3.96 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use async_trait::async_trait;
use aws_sdk_s3::{output::ListObjectsV2Output, Client as S3Client};
use lambda_runtime::{service_fn, tracing, Error, LambdaEvent};
use serde::{Deserialize, Serialize};
/// Invocation payload: names the S3 bucket whose objects should be listed.
///
/// Deserialized from the incoming event; with no `#[serde(rename)]` attributes,
/// the field name `bucket` is exactly the key the caller must supply.
#[derive(Deserialize)]
struct Request {
    /// Name of the bucket to list; passed unchanged to `list_objects`.
    bucket: String,
}
/// Handler output: the Lambda request ID, the bucket that was listed, and the
/// object keys found in it. The runtime serializes it (to JSON) for the caller.
///
/// Field names and declaration order define the serialized keys and their
/// order, so renaming or reordering fields changes the response format.
#[derive(Serialize)]
struct Response {
    /// Request ID copied from the invocation's Lambda context.
    req_id: String,
    /// Name of the bucket that was listed, echoed from the request.
    bucket: String,
    /// Keys of the objects returned by the listing call; entries without a key are omitted.
    // NOTE(review): the listing comes from a single ListObjectsV2 request, so for a
    // large bucket this may not be every object — confirm whether that is intended.
    objects: Vec<String>,
}
/// Seam over the S3 `ListObjectsV2` call so the handler can be unit-tested.
///
/// Under `cfg(test)`, mockall's `automock` generates `MockListObjects` from this
/// trait. mockall requires `automock` to be listed before `#[async_trait]`, so
/// keep the attribute order as-is.
#[cfg_attr(test, mockall::automock)]
#[async_trait]
trait ListObjects {
    /// Lists the objects in `bucket`, returning the raw SDK output or a boxed error.
    async fn list_objects(&self, bucket: &str) -> Result<ListObjectsV2Output, Error>;
}
/// Production implementation backed by the real S3 client.
///
/// Sends one `ListObjectsV2` request for `bucket`; any SDK failure is converted
/// into the runtime's boxed [`Error`] by `?`.
#[async_trait]
impl ListObjects for S3Client {
    async fn list_objects(&self, bucket: &str) -> Result<ListObjectsV2Output, Error> {
        let request = self.list_objects_v2().bucket(bucket);
        let output = request.send().await?;
        Ok(output)
    }
}
/// Lambda entry point: sets up tracing, builds an S3 client from the
/// environment's AWS configuration, and serves invocations with [`my_handler`].
#[tokio::main]
async fn main() -> Result<(), Error> {
    // The runtime needs a tracing subscriber installed for errors to reach CloudWatch.
    tracing::init_default_subscriber();

    let aws_config = aws_config::load_from_env().await;
    let s3 = S3Client::new(&aws_config);
    // Borrow the client so each invocation reuses it instead of building a new one.
    let s3_ref = &s3;

    lambda_runtime::run(service_fn(move |event| async move {
        my_handler(event, s3_ref).await
    }))
    .await
}
async fn my_handler<T: ListObjects>(event: LambdaEvent<Request>, client: &T) -> Result<Response, Error> {
let bucket = event.payload.bucket;
let objects_rsp = client.list_objects(&bucket).await?;
let objects: Vec<_> = objects_rsp
.contents()
.ok_or("missing objects in list-objects-v2 response")?
.into_iter()
.filter_map(|o| o.key().map(|k| k.to_string()))
.collect();
// prepare the response
let rsp = Response {
req_id: event.context.request_id,
bucket: bucket.clone(),
objects,
};
// return `Response` (it will be serialized to JSON automatically by the runtime)
Ok(rsp)
}
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_s3::model::Object;
    use lambda_runtime::{Context, LambdaEvent};
    use mockall::predicate::eq;

    /// Builds an invocation event for `bucket` with a fixed request ID.
    fn event_for(bucket: &str) -> LambdaEvent<Request> {
        let mut context = Context::default();
        context.request_id = "test-request-id".to_string();
        LambdaEvent {
            payload: Request {
                bucket: bucket.to_string(),
            },
            context,
        }
    }

    /// A successful listing is echoed back as the object keys, in order,
    /// together with the bucket name and request ID.
    #[tokio::test]
    async fn response_is_good_for_good_bucket() {
        const KEYS: [&str; 3] = ["test-key-0", "test-key-1", "test-key-2"];

        let mut mock_client = MockListObjects::default();
        mock_client
            .expect_list_objects()
            .with(eq("test-bucket"))
            .returning(|_| {
                let listing = KEYS
                    .iter()
                    .fold(ListObjectsV2Output::builder(), |builder, key| {
                        builder.contents(Object::builder().key(*key).build())
                    });
                Ok(listing.build())
            });

        let response = my_handler(event_for("test-bucket"), &mock_client)
            .await
            .unwrap();

        assert_eq!(response.req_id, "test-request-id");
        assert_eq!(response.bucket, "test-bucket");
        assert_eq!(response.objects, KEYS);
    }

    /// An error from the listing call is propagated to the caller.
    #[tokio::test]
    async fn response_is_bad_for_bad_bucket() {
        let mut mock_client = MockListObjects::default();
        mock_client
            .expect_list_objects()
            .with(eq("unknown-bucket"))
            .returning(|_| Err(Error::from("test-sdk-error")));

        let outcome = my_handler(event_for("unknown-bucket"), &mock_client).await;
        assert!(outcome.is_err());
    }
}