From e11b91c8e01b38023f209983a81aee23439a34a3 Mon Sep 17 00:00:00 2001 From: Noam Cohen Date: Wed, 10 Jan 2024 23:21:21 +0200 Subject: [PATCH] Make sure `multiple_outputs` is inferred correctly even when using `TypedDict` (#36652) * Use `issubclass()` to check if return type is a dictionary * Compare type to `typing.Mapping` instead of `typing.Dict` * Add documentation --- airflow/decorators/base.py | 3 +-- docs/apache-airflow/tutorial/taskflow.rst | 4 ++-- tests/decorators/test_python.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index d3ec556f05339..119672dd427aa 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -27,7 +27,6 @@ Callable, ClassVar, Collection, - Dict, Generic, Iterator, Mapping, @@ -351,7 +350,7 @@ def fake(): except TypeError: # Can't evaluate return type. return False ttype = getattr(return_type, "__origin__", return_type) - return ttype is dict or ttype is Dict + return issubclass(ttype, Mapping) def __attrs_post_init__(self): if "self" in self.function_signature.parameters: diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst index 5d71576b5982d..6d5b3bed25fad 100644 --- a/docs/apache-airflow/tutorial/taskflow.rst +++ b/docs/apache-airflow/tutorial/taskflow.rst @@ -428,8 +428,8 @@ Tasks can also infer multiple outputs by using dict Python typing. def identity_dict(x: int, y: int) -> dict[str, int]: return {"x": x, "y": y} -By using the typing ``Dict`` for the function return type, the ``multiple_outputs`` parameter -is automatically set to true. +By using the typing ``dict``, or any other class that conforms to the ``typing.Mapping`` protocol, +for the function return type, the ``multiple_outputs`` parameter is automatically set to true. Note, If you manually set the ``multiple_outputs`` parameter the inference is disabled and the parameter value is used. diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 78203ca6efeba..98aab562b8ce1 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -97,6 +97,19 @@ def identity_dict_with_decorator_call(x: int, y: int) -> resolve(annotation): assert identity_dict_with_decorator_call(5, 5).operator.multiple_outputs is True + @pytest.mark.skipif(sys.version_info < (3, 8), reason="PEP 589 is implemented in Python 3.8") + def test_infer_multiple_outputs_typed_dict(self): + from typing import TypedDict + + class TypeDictClass(TypedDict): + pass + + @task_decorator + def t1() -> TypeDictClass: + return {} + + assert t1().operator.multiple_outputs is True + def test_infer_multiple_outputs_forward_annotation(self): if TYPE_CHECKING: