Skip to content

Commit

Permalink
Fix LambdaInvokeFunctionOperator payload parameter type (#32259)
Browse files Browse the repository at this point in the history
* Fixing issue - Fix payload parameter of amazon LambdaCreateFunctionOperator

---------

Co-authored-by: Elad Galili <eladg@kahoona.io>
  • Loading branch information
eladi99 and elad-galili-ka committed Jul 3, 2023
1 parent 88da71e commit 5c72bef
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/hooks/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def invoke_lambda(
invocation_type: str | None = None,
log_type: str | None = None,
client_context: str | None = None,
payload: str | None = None,
payload: bytes | str | None = None,
qualifier: str | None = None,
):
"""
Expand All @@ -65,6 +65,9 @@ def invoke_lambda(
:param payload: The JSON that you want to provide to your Lambda function as input.
:param qualifier: AWS Lambda Function Version or Alias Name
"""
if isinstance(payload, str):
payload = payload.encode()

invoke_args = {
"FunctionName": function_name,
"InvocationType": invocation_type,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
qualifier: str | None = None,
invocation_type: str | None = None,
client_context: str | None = None,
payload: str | None = None,
payload: bytes | str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
Expand Down
11 changes: 8 additions & 3 deletions tests/providers/amazon/aws/hooks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

FUNCTION_NAME = "test_function"
PAYLOAD = '{"hello": "airflow"}'
BYTES_PAYLOAD = b'{"hello": "airflow"}'
RUNTIME = "python3.9"
ROLE = "role"
HANDLER = "handler"
Expand All @@ -48,13 +49,17 @@ def test_get_conn_returns_a_boto3_connection(self, hook):
@mock.patch(
"airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook.conn", new_callable=mock.PropertyMock
)
def test_invoke_lambda(self, mock_conn):
@pytest.mark.parametrize(
"payload, invoke_payload",
[(PAYLOAD, BYTES_PAYLOAD), (BYTES_PAYLOAD, BYTES_PAYLOAD)],
)
def test_invoke_lambda(self, mock_conn, payload, invoke_payload):
hook = LambdaHook()
hook.invoke_lambda(function_name=FUNCTION_NAME, payload=PAYLOAD)
hook.invoke_lambda(function_name=FUNCTION_NAME, payload=payload)

mock_conn().invoke.assert_called_once_with(
FunctionName=FUNCTION_NAME,
Payload=PAYLOAD,
Payload=invoke_payload,
)

@pytest.mark.parametrize(
Expand Down
23 changes: 16 additions & 7 deletions tests/providers/amazon/aws/operators/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import json
from unittest import mock
from unittest.mock import Mock, patch

Expand All @@ -30,6 +29,8 @@
)

FUNCTION_NAME = "function_name"
PAYLOAD = '{"hello": "airflow"}'
BYTES_PAYLOAD = b'{"hello": "airflow"}'
ROLE_ARN = "role_arn"
IMAGE_URI = "image_uri"

Expand Down Expand Up @@ -70,29 +71,37 @@ def test_create_lambda_with_wait_for_completion(self, mock_hook_conn, mock_hook_


class TestLambdaInvokeFunctionOperator:
def test_init(self):
@pytest.mark.parametrize(
"payload",
[PAYLOAD, BYTES_PAYLOAD],
)
def test_init(self, payload):
lambda_operator = LambdaInvokeFunctionOperator(
task_id="test",
function_name="test",
payload=json.dumps({"TestInput": "Testdata"}),
payload=payload,
log_type="None",
aws_conn_id="aws_conn_test",
)
assert lambda_operator.task_id == "test"
assert lambda_operator.function_name == "test"
assert lambda_operator.payload == json.dumps({"TestInput": "Testdata"})
assert lambda_operator.payload == payload
assert lambda_operator.log_type == "None"
assert lambda_operator.aws_conn_id == "aws_conn_test"

@patch.object(LambdaInvokeFunctionOperator, "hook", new_callable=mock.PropertyMock)
def test_invoke_lambda(self, hook_mock):
@pytest.mark.parametrize(
"payload",
[PAYLOAD, BYTES_PAYLOAD],
)
def test_invoke_lambda(self, hook_mock, payload):
operator = LambdaInvokeFunctionOperator(
task_id="task_test",
function_name="a",
invocation_type="b",
log_type="c",
client_context="d",
payload="e",
payload=payload,
qualifier="f",
)
returned_payload = Mock()
Expand All @@ -111,7 +120,7 @@ def test_invoke_lambda(self, hook_mock):
invocation_type="b",
log_type="c",
client_context="d",
payload="e",
payload=payload,
qualifier="f",
)

Expand Down

0 comments on commit 5c72bef

Please sign in to comment.