From a247817af73c278ce3d0ff105a410227f829c8cf Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Thu, 20 Jun 2024 16:15:59 +0800 Subject: [PATCH] test: Add tests for positional arguments Resolves: flyteorg/flyte#5320 Signed-off-by: Chi-Sheng Liu --- .../flytekit/unit/core/test_serialization.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 2fcf8bbd94..e06dcb6bf0 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -943,3 +943,102 @@ def wf_with_input() -> typing.Optional[typing.List[int]]: ) assert wf_with_input() == input_val + +def test_positional_args_task(): + arg1 = 5 + arg2 = 6 + ret = 11 + + @task + def t1(x: int, y: int) -> int: + return x + y + + @workflow + def wf_pure_positional_args() -> int: + return t1(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return t1(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_positional_args_workflow(): + arg1 = 5 + arg2 = 6 + ret = 11 + + @task + def t1(x: int, y: int) -> int: + return x + y + + @workflow + def sub_wf(x: int, y: int) -> int: + return t1(x=x, y=y) + + @workflow + def wf_pure_positional_args() -> int: + return sub_wf(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return sub_wf(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_unexpected_kwargs_task(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received unexpected keyword argument"): + t1(b=6) + +def test_too_many_positional_args_task(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received more arguments than expected"): + t1(1, 2) + +def test_both_positional_and_keyword_args_task(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Got multiple values for argument"): + t1(1, a=2)