Skip to content

Commit

Permalink
Make sure multiple_outputs is inferred correctly even when using `T…
Browse files Browse the repository at this point in the history
…ypedDict` (#36652)

* Use `issubclass()` to check if return type is a dictionary

* Compare type to `typing.Mapping` instead of `typing.Dict`

* Add documentation
  • Loading branch information
noamcohen97 committed Jan 10, 2024
1 parent c439ab8 commit e11b91c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
3 changes: 1 addition & 2 deletions airflow/decorators/base.py
Expand Up @@ -27,7 +27,6 @@
Callable,
ClassVar,
Collection,
Dict,
Generic,
Iterator,
Mapping,
Expand Down Expand Up @@ -351,7 +350,7 @@ def fake():
except TypeError: # Can't evaluate return type.
return False
ttype = getattr(return_type, "__origin__", return_type)
return ttype is dict or ttype is Dict
return issubclass(ttype, Mapping)

def __attrs_post_init__(self):
if "self" in self.function_signature.parameters:
Expand Down
4 changes: 2 additions & 2 deletions docs/apache-airflow/tutorial/taskflow.rst
Expand Up @@ -428,8 +428,8 @@ Tasks can also infer multiple outputs by using dict Python typing.
def identity_dict(x: int, y: int) -> dict[str, int]:
return {"x": x, "y": y}
By using the typing ``Dict`` for the function return type, the ``multiple_outputs`` parameter
is automatically set to true.
By using the typing ``dict``, or any other class that conforms to the ``typing.Mapping`` protocol,
for the function return type, the ``multiple_outputs`` parameter is automatically set to true.

Note, If you manually set the ``multiple_outputs`` parameter the inference is disabled and
the parameter value is used.
Expand Down
13 changes: 13 additions & 0 deletions tests/decorators/test_python.py
Expand Up @@ -97,6 +97,19 @@ def identity_dict_with_decorator_call(x: int, y: int) -> resolve(annotation):

assert identity_dict_with_decorator_call(5, 5).operator.multiple_outputs is True

@pytest.mark.skipif(sys.version_info < (3, 8), reason="PEP 589 is implemented in Python 3.8")
def test_infer_multiple_outputs_typed_dict(self):
from typing import TypedDict

class TypeDictClass(TypedDict):
pass

@task_decorator
def t1() -> TypeDictClass:
return {}

assert t1().operator.multiple_outputs is True

def test_infer_multiple_outputs_forward_annotation(self):
if TYPE_CHECKING:

Expand Down

0 comments on commit e11b91c

Please sign in to comment.