Skip to content

Commit

Permalink
Include AWS Lambda execution logs to task logs (#34692)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Oct 3, 2023
1 parent 1928498 commit 3064812
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 14 deletions.
19 changes: 18 additions & 1 deletion airflow/providers/amazon/aws/hooks/lambda_function.py
Expand Up @@ -18,10 +18,12 @@
"""This module contains AWS Lambda hook."""
from __future__ import annotations

import base64
from typing import Any

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.suppress import return_on_error


class LambdaHook(AwsBaseHook):
Expand Down Expand Up @@ -59,7 +61,8 @@ def invoke_lambda(
:param function_name: AWS Lambda Function Name
:param invocation_type: AWS Lambda Invocation Type (RequestResponse, Event etc)
:param log_type: Tail Invocation Request
:param log_type: Set to Tail to include the execution log in the response.
Applies to synchronously invoked functions only.
:param client_context: Up to 3,583 bytes of base64-encoded data about the invoking client
to pass to the function in the context object.
:param payload: The JSON that you want to provide to your Lambda function as input.
Expand Down Expand Up @@ -179,3 +182,17 @@ def create_lambda(
"Architectures": architectures,
}
return self.conn.create_function(**trim_none_values(create_function_args))

@staticmethod
@return_on_error(None)
def encode_log_result(log_result: str, *, keep_empty_lines: bool = True) -> list[str] | None:
"""
Encode execution log from the response and return list of log records.
Returns ``None`` on error, e.g. invalid base64-encoded string
:param log_result: base64-encoded string which contain Lambda execution Log.
:param keep_empty_lines: Whether or not keep empty lines.
"""
encoded_log_result = base64.b64decode(log_result.encode("ascii")).decode()
return [log_row for log_row in encoded_log_result.splitlines() if keep_empty_lines or log_row]
21 changes: 20 additions & 1 deletion airflow/providers/amazon/aws/operators/lambda_function.py
Expand Up @@ -165,7 +165,10 @@ class LambdaInvokeFunctionOperator(BaseOperator):
:ref:`howto/operator:LambdaInvokeFunctionOperator`
:param function_name: The name of the AWS Lambda function, version, or alias.
:param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None".
:param log_type: Set to Tail to include the execution log in the response and task logs.
Otherwise, set to "None". Applies to synchronously invoked functions only,
and returns the last 4 KB of the execution log.
:param keep_empty_log_lines: Whether or not keep empty lines in the execution log.
:param qualifier: Specify a version or alias to invoke a published version of the function.
:param invocation_type: AWS Lambda invocation type (RequestResponse, Event, DryRun)
:param client_context: Data about the invoking client to pass to the function in the context object
Expand All @@ -181,6 +184,7 @@ def __init__(
*,
function_name: str,
log_type: str | None = None,
keep_empty_log_lines: bool = True,
qualifier: str | None = None,
invocation_type: str | None = None,
client_context: str | None = None,
Expand All @@ -192,6 +196,7 @@ def __init__(
self.function_name = function_name
self.payload = payload
self.log_type = log_type
self.keep_empty_log_lines = keep_empty_log_lines
self.qualifier = qualifier
self.invocation_type = invocation_type
self.client_context = client_context
Expand All @@ -218,6 +223,20 @@ def execute(self, context: Context):
qualifier=self.qualifier,
)
self.log.info("Lambda response metadata: %r", response.get("ResponseMetadata"))

if log_result := response.get("LogResult"):
log_records = self.hook.encode_log_result(
log_result,
keep_empty_lines=self.keep_empty_log_lines,
)
if log_records:
self.log.info(
"The last 4 KB of the Lambda execution log (keep_empty_log_lines=%s).",
self.keep_empty_log_lines,
)
for log_record in log_records:
self.log.info(log_record)

if response.get("StatusCode") not in success_status_codes:
raise ValueError("Lambda function did not execute", json.dumps(response.get("ResponseMetadata")))
payload_stream = response.get("Payload")
Expand Down
18 changes: 18 additions & 0 deletions tests/providers/amazon/aws/hooks/test_lambda_function.py
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import base64
from unittest import mock
from unittest.mock import MagicMock

Expand All @@ -31,6 +32,8 @@
ROLE = "role"
HANDLER = "handler"
CODE = {}
LOG_RESPONSE = base64.b64encode(b"FOO\n\nBAR\n\n").decode()
BAD_LOG_RESPONSE = LOG_RESPONSE[:-3]


class LambdaHookForTests(LambdaHook):
Expand Down Expand Up @@ -136,3 +139,18 @@ def test_create_lambda_with_zip_package_type_and_missing_args(self, params, hook
package_type="Zip",
**params,
)

def test_encode_log_result(self):
assert LambdaHook.encode_log_result(LOG_RESPONSE) == ["FOO", "", "BAR", ""]
assert LambdaHook.encode_log_result(LOG_RESPONSE, keep_empty_lines=False) == ["FOO", "BAR"]
assert LambdaHook.encode_log_result("") == []

@pytest.mark.parametrize(
"log_result",
[
pytest.param(BAD_LOG_RESPONSE, id="corrupted"),
pytest.param(None, id="none"),
],
)
def test_encode_corrupted_log_result(self, log_result):
assert LambdaHook.encode_log_result(log_result) is None
63 changes: 51 additions & 12 deletions tests/providers/amazon/aws/operators/test_lambda_function.py
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import base64
from unittest import mock
from unittest.mock import Mock, patch

Expand All @@ -30,10 +31,15 @@
)

FUNCTION_NAME = "function_name"
PAYLOAD = '{"hello": "airflow"}'
BYTES_PAYLOAD = b'{"hello": "airflow"}'
PAYLOADS = [
pytest.param('{"hello": "airflow"}', id="string-payload"),
pytest.param(b'{"hello": "airflow"}', id="bytes-payload"),
]
ROLE_ARN = "role_arn"
IMAGE_URI = "image_uri"
LOG_RESPONSE = base64.b64encode(b"FOO\n\nBAR\n\n").decode()
BAD_LOG_RESPONSE = LOG_RESPONSE[:-3]
NO_LOG_RESPONSE_SENTINEL = type("NoLogResponseSentinel", (), {})()


class TestLambdaCreateFunctionOperator:
Expand Down Expand Up @@ -86,10 +92,7 @@ def test_create_lambda_deferrable(self, _):


class TestLambdaInvokeFunctionOperator:
@pytest.mark.parametrize(
"payload",
[PAYLOAD, BYTES_PAYLOAD],
)
@pytest.mark.parametrize("payload", PAYLOADS)
def test_init(self, payload):
lambda_operator = LambdaInvokeFunctionOperator(
task_id="test",
Expand All @@ -104,33 +107,57 @@ def test_init(self, payload):
assert lambda_operator.log_type == "None"
assert lambda_operator.aws_conn_id == "aws_conn_test"

@patch.object(LambdaInvokeFunctionOperator, "hook", new_callable=mock.PropertyMock)
@mock.patch.object(LambdaHook, "invoke_lambda")
@mock.patch.object(LambdaHook, "conn")
@pytest.mark.parametrize(
"payload",
[PAYLOAD, BYTES_PAYLOAD],
"keep_empty_log_lines", [pytest.param(True, id="keep"), pytest.param(False, id="truncate")]
)
def test_invoke_lambda(self, hook_mock, payload):
@pytest.mark.parametrize(
"log_result, expected_execution_logs",
[
pytest.param(LOG_RESPONSE, True, id="log-result"),
pytest.param(BAD_LOG_RESPONSE, False, id="corrupted-log-result"),
pytest.param(None, False, id="none-log-result"),
pytest.param(NO_LOG_RESPONSE_SENTINEL, False, id="no-response"),
],
)
@pytest.mark.parametrize("payload", PAYLOADS)
def test_invoke_lambda(
self,
mock_conn,
mock_invoke,
payload,
keep_empty_log_lines,
log_result,
expected_execution_logs,
caplog,
):
operator = LambdaInvokeFunctionOperator(
task_id="task_test",
function_name="a",
invocation_type="b",
log_type="c",
keep_empty_log_lines=keep_empty_log_lines,
client_context="d",
payload=payload,
qualifier="f",
)
returned_payload = Mock()
returned_payload.read().decode.return_value = "data was read"
hook_mock().invoke_lambda.return_value = {
fake_response = {
"ResponseMetadata": "",
"StatusCode": 200,
"Payload": returned_payload,
}
if log_result is not NO_LOG_RESPONSE_SENTINEL:
fake_response["LogResult"] = log_result
mock_invoke.return_value = fake_response

caplog.set_level("INFO", "airflow.task.operators")
value = operator.execute(None)

assert value == "data was read"
hook_mock().invoke_lambda.assert_called_once_with(
mock_invoke.assert_called_once_with(
function_name="a",
invocation_type="b",
log_type="c",
Expand All @@ -139,6 +166,18 @@ def test_invoke_lambda(self, hook_mock, payload):
qualifier="f",
)

# Validate log messages in task logs
if expected_execution_logs:
assert "The last 4 KB of the Lambda execution log" in caplog.text
assert "FOO" in caplog.messages
assert "BAR" in caplog.messages
if keep_empty_log_lines:
assert "" in caplog.messages
else:
assert "" not in caplog.messages
else:
assert "The last 4 KB of the Lambda execution log" not in caplog.text

@patch.object(LambdaInvokeFunctionOperator, "hook", new_callable=mock.PropertyMock)
def test_invoke_lambda_bad_http_code(self, hook_mock):
operator = LambdaInvokeFunctionOperator(
Expand Down

0 comments on commit 3064812

Please sign in to comment.