-
Notifications
You must be signed in to change notification settings - Fork 13.6k
/
base.py
664 lines (562 loc) · 26.5 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import inspect
import itertools
import textwrap
import warnings
from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Generic,
Iterator,
Mapping,
Sequence,
TypeVar,
cast,
overload,
)
import attr
import re2
import typing_extensions
from airflow.datasets import Dataset
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
BaseOperator,
coerce_resources,
coerce_timedelta,
get_merged_defaults,
parse_retries,
)
from airflow.models.dag import DagContext
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
DictOfListsExpandInput,
ListOfDictsExpandInput,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
from airflow.utils.decorators import remove_task_decorator
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroupContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.dag import DAG
from airflow.models.expandinput import (
ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup
class ExpandableFactory(Protocol):
    """Protocol exposing signature inspection of a wrapped function.

    It backs the argument validation performed for ``expand()`` and
    ``partial()`` and is implemented by function decorators like ``@task``
    and ``@task_group``.

    :meta private:
    """

    function: Callable

    @cached_property
    def function_signature(self) -> inspect.Signature:
        """Signature of the wrapped function, computed once per instance."""
        return inspect.signature(self.function)

    @cached_property
    def _mappable_function_argument_names(self) -> set[str]:
        """Names of the wrapped function's parameters; these can be mapped against."""
        return set(self.function_signature.parameters)

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
        """Check that every keyword handed to an operator-mapping call is accounted for.

        :param func: Name of the calling API; it is interpolated into error
            messages, and the mappability check only runs for ``"expand"``.
        :param kwargs: Keyword arguments received by that API (not modified).
        :raises ValueError: If ``func`` is ``"expand"`` and a value is not mappable.
        :raises TypeError: If a keyword matches no parameter of the function.
        """
        # A function accepting **kwargs can take any keyword; nothing to verify.
        for param in self.function_signature.parameters.values():
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                return

        checking_expand = func == "expand"
        unexpected = kwargs.copy()
        for name in self._mappable_function_argument_names:
            candidate = unexpected.pop(name, NOTSET)
            if not checking_expand or candidate is NOTSET or is_mappable(candidate):
                continue
            type_name = type(candidate).__name__
            raise ValueError(
                f"expand() got an unexpected type {type_name!r} for keyword argument {name!r}"
            )

        # Whatever was not popped above does not correspond to any parameter.
        if not unexpected:
            return
        if len(unexpected) == 1:
            (only_name,) = unexpected
            raise TypeError(f"{func}() got an unexpected keyword argument {only_name!r}")
        names = ", ".join(map(repr, unexpected))
        raise TypeError(f"{func}() got unexpected keyword arguments {names}")
def get_unique_task_id(
    task_id: str,
    dag: DAG | None = None,
    task_group: TaskGroup | None = None,
) -> str:
    """
    Generate unique task id given a DAG (or if run in a DAG context).

    IDs are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20

    :param task_id: The requested task id.
    :param dag: DAG whose existing task ids are checked; when omitted, the
        current DAG from ``DagContext`` is used.
    :param task_group: Task group used to qualify ``task_id``; when omitted,
        the current group from ``TaskGroupContext`` is used.
    :return: ``task_id`` unchanged when there is no DAG or the (group
        qualified) id is not taken yet, otherwise ``<core>__<n>`` where ``n``
        is one more than the highest numeric suffix already in use.
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    numeric_suffix = r"__\d+$"
    prefix = re2.split(numeric_suffix, tg_task_id)[0]
    # The prefix is literal text, not a pattern: escape it so that regex
    # metacharacters inside an id (e.g. the "." in "group.task") cannot make
    # the pattern match unrelated task ids.
    numbered_id = rf"^{re2.escape(prefix)}__(\d+)$"

    highest_suffix = 0  # Default if there's no matching task ID.
    for existing_id in dag.task_ids:
        match = re2.match(numbered_id, existing_id)
        if match:
            highest_suffix = max(highest_suffix, int(match.group(1)))

    core = re2.split(numeric_suffix, task_id)[0]
    return f"{core}__{highest_suffix + 1}"
class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :param task_id: Requested task id; it is passed through
        ``get_unique_task_id`` before being handed to the base operator.
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :param op_args: a list of positional arguments that will get unpacked when
        calling your callable (templated)
    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
        PythonOperator). This gives a user the option to upstream kwargs as needed.
    """

    template_fields: Sequence[str] = ("op_args", "op_kwargs")
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs: Sequence[str] = ("python_callable",)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        multiple_outputs: bool = False,
        kwargs_to_upstream: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}
        op_args = op_args or []
        op_kwargs = op_kwargs or {}

        # Check the decorated function's signature. We go through the argument
        # list and "fill in" defaults to arguments that are known context keys,
        # since values for those will be provided when the task is run. Since
        # we're not actually running the function, None is good enough here.
        signature = inspect.signature(python_callable)
        parameters = [
            param.replace(default=None) if param.name in KNOWN_CONTEXT_KEYS else param
            for param in signature.parameters.values()
        ]
        signature = signature.replace(parameters=parameters)

        # Check that arguments can be binded. There's a slight difference when
        # we do validation for task-mapping: Since there's no guarantee we can
        # receive enough arguments at parse time, we use bind_partial to simply
        # check all the arguments we know are valid. Whether these are enough
        # can only be known at execution time, when unmapping happens, and this
        # is called without the _airflow_mapped_validation_only flag.
        if kwargs.get("_airflow_mapped_validation_only"):
            signature.bind_partial(*op_args, **op_kwargs)
        else:
            signature.bind(*op_args, **op_kwargs)

        self.multiple_outputs = multiple_outputs
        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)

    def execute(self, context: Context):
        """Register Dataset arguments as inlets, run the parent ``execute`` and post-process its result.

        :param context: Task execution context, forwarded to the parent class.
        :return: Whatever ``_handle_output`` returns for the parent's result.
        """
        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
        #  as well
        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Dataset):
                self.inlets.append(arg)
        return_value = super().execute(context)
        return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)

    def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable):
        """
        Handle logic for whether a decorator needs to push a single return value or multiple return values.

        It sets outlets if any datasets are found in the returned value(s)

        :param return_value: Value produced by the wrapped callable.
        :param context: Task execution context, passed as first argument to ``xcom_push``.
        :param xcom_push: Callable invoked as ``xcom_push(context, key, value)`` for each
            dictionary entry when ``multiple_outputs`` is enabled.
        :return: ``return_value``, unchanged.
        :raises AirflowException: If ``multiple_outputs`` is enabled and the value is neither
            ``None`` nor a dict, or the dict has a non-string key.
        """
        if isinstance(return_value, Dataset):
            self.outlets.append(return_value)
        if isinstance(return_value, list):
            for item in return_value:
                if isinstance(item, Dataset):
                    self.outlets.append(item)
        if not self.multiple_outputs or return_value is None:
            return return_value
        if isinstance(return_value, dict):
            # Validate every key before pushing anything, so a bad key cannot
            # leave a partially pushed result behind.
            for key in return_value:
                if not isinstance(key, str):
                    raise AirflowException(
                        "Returned dictionary keys must be strings when using "
                        f"multiple_outputs, found {key} ({type(key)}) instead"
                    )
            for key, value in return_value.items():
                if isinstance(value, Dataset):
                    self.outlets.append(value)
                xcom_push(context, key, value)
        else:
            raise AirflowException(
                f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs"
            )
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        """Fill ``op_kwargs`` from ``default_args`` for parameters of the wrapped callable.

        For each parameter of ``python_callable`` that is missing from
        ``op_kwargs`` but present in ``default_args``, the default is copied in.

        :return: The ``(args, kwargs)`` pair, with ``kwargs["op_kwargs"]`` replaced by
            the completed dictionary. Returned untouched when no
            ``python_callable`` keyword was given.
        """
        if "python_callable" not in kwargs:
            return args, kwargs

        python_callable = kwargs["python_callable"]
        default_args = kwargs.get("default_args") or {}
        # Copy first: the mapping belongs to the caller (and may be read-only),
        # so defaults must not be injected into it in place.
        op_kwargs = dict(kwargs.get("op_kwargs") or {})
        f_sig = inspect.signature(python_callable)
        for arg in f_sig.parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs["op_kwargs"] = op_kwargs
        return args, kwargs

    def get_python_source(self):
        """Return the dedented source of the wrapped callable with its task decorator removed."""
        raw_source = inspect.getsource(self.python_callable)
        res = textwrap.dedent(raw_source)
        res = remove_task_decorator(res, self.custom_operator_name)
        return res
# Parameter specification of the decorated function; used in this module as
# ``Callable[FParams, FReturn]`` and to type ``_TaskDecorator.__call__``.
FParams = ParamSpec("FParams")
# Return type of the decorated function.
FReturn = TypeVar("FReturn")
# Operator class a task decorator instantiates; must derive from BaseOperator.
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
@attr.define(slots=False)
class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
    """
    Helper class for providing dynamic task mapping to decorated functions.

    ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

    An instance is reusable: calling it, or calling ``expand()`` on it, any
    number of times must not change its stored ``kwargs``.

    :meta private:
    """

    # The decorated function.
    function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
    # Operator class instantiated (or mapped) for each task.
    operator_class: type[OperatorSubclass]
    # Inferred from the return annotation when not given (see _infer_multiple_outputs).
    multiple_outputs: bool = attr.ib()
    # Keyword arguments forwarded to the operator.
    kwargs: dict[str, Any] = attr.ib(factory=dict)

    # Only used in the "does not support methods" error message.
    decorator_name: str = attr.ib(repr=False, default="task")

    _airflow_is_task_decorator: ClassVar[bool] = True
    is_setup: bool = False
    is_teardown: bool = False
    on_failure_fail_dagrun: bool = False

    @multiple_outputs.default
    def _infer_multiple_outputs(self):
        """Infer ``multiple_outputs`` from the function's return annotation.

        :return: True only when the (origin of the) return annotation is a
            class that is a subclass of ``Mapping``; False when there is no
            annotation, it cannot be evaluated, or it is not a class.
        """
        if "return" not in self.function.__annotations__:
            # No return type annotation, nothing to infer
            return False

        try:
            # We only care about the return annotation, not anything about the parameters
            def fake():
                ...

            fake.__annotations__ = {"return": self.function.__annotations__["return"]}

            return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any)
        except NameError as e:
            warnings.warn(
                f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward"
                f" type references that are not imported. (Error was {e})",
                stacklevel=4,
            )
            return False
        except TypeError:  # Can't evaluate return type.
            return False
        ttype = getattr(return_type, "__origin__", return_type)
        # issubclass() raises TypeError when its first argument is not a class,
        # which is the case for annotations such as ``Optional[dict]`` or
        # ``int | None``. Those are not mappings, so answer False instead of
        # failing at decoration time.
        return isinstance(ttype, type) and issubclass(ttype, Mapping)

    def __attrs_post_init__(self):
        if "self" in self.function_signature.parameters:
            raise TypeError(f"@{self.decorator_name} does not support methods")
        self.kwargs.setdefault("task_id", self.function.__name__)

    def _teardown_aware_kwargs(self) -> dict[str, Any]:
        """Return a copy of ``self.kwargs`` with the teardown trigger rule applied.

        A copy is returned so that repeated use of the same decorator never
        trips over state left behind by an earlier use.

        :raises ValueError: If this is a teardown task and ``trigger_rule`` was configured.
        """
        task_kwargs = dict(self.kwargs)
        if self.is_teardown:
            if "trigger_rule" in task_kwargs:
                raise ValueError("Trigger rule not configurable for teardown tasks.")
            task_kwargs["trigger_rule"] = TriggerRule.ALL_DONE_SETUP_SUCCESS
        return task_kwargs

    def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
        """Instantiate the operator for one invocation of the decorated function.

        :return: An ``XComArg`` wrapping the created operator.
        :raises ValueError: If this is a teardown task and ``trigger_rule`` was configured.
        """
        # Never mutate self.kwargs here: the decorator may be called again.
        operator_kwargs = self._teardown_aware_kwargs()
        on_failure_fail_dagrun = operator_kwargs.pop("on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        op = self.operator_class(
            python_callable=self.function,
            op_args=args,
            op_kwargs=kwargs,
            multiple_outputs=self.multiple_outputs,
            **operator_kwargs,
        )
        op.is_setup = self.is_setup
        op.is_teardown = self.is_teardown
        op.on_failure_fail_dagrun = on_failure_fail_dagrun
        op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml]
        # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set.
        if self.function.__doc__ and not any(op_doc_attrs):
            op.doc_md = self.function.__doc__
        return XComArg(op)

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]:
        return self.function

    def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
        """Reject keywords that shadow task context variables, then run the inherited validation.

        :raises ValueError: If any key of ``kwargs`` is a known context key.
        """
        # Ensure that context variables are not shadowed.
        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
        if len(context_keys_being_mapped) == 1:
            (name,) = context_keys_being_mapped
            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
        elif context_keys_being_mapped:
            names = ", ".join(repr(n) for n in context_keys_being_mapped)
            raise ValueError(f"cannot call {func}() on task context variables {names}")

        super()._validate_arg_names(func, kwargs)

    def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
        """Create a mapped task, expanding over the given keyword arguments.

        :raises TypeError: If no arguments are given.
        :raises ValueError: If this is a teardown task and ``trigger_rule`` was configured.
        """
        if not map_kwargs:
            raise TypeError("no arguments to expand against")
        self._validate_arg_names("expand", map_kwargs)
        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
        # The teardown trigger rule goes into a copy, leaving self.kwargs untouched.
        task_kwargs = self._teardown_aware_kwargs()
        # Since the input is already checked at parse time, we can set strict
        # to False to skip the checks on execution.
        return self._expand(DictOfListsExpandInput(map_kwargs), strict=False, task_kwargs=task_kwargs)

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Create a mapped task from an ``XComArg`` or a sequence of ``XComArg``/mapping items.

        :raises TypeError: If ``kwargs`` or one of its items has an unsupported type.
        """
        # NOTE(review): unlike expand(), the teardown trigger rule is not
        # applied on this path - confirm whether that is intentional.
        if isinstance(kwargs, Sequence):
            for item in kwargs:
                if not isinstance(item, (XComArg, Mapping)):
                    raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        elif not isinstance(kwargs, XComArg):
            raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

    def _expand(
        self,
        expand_input: ExpandInput,
        *,
        strict: bool,
        task_kwargs: dict[str, Any] | None = None,
    ) -> XComArg:
        """Build the ``DecoratedMappedOperator`` for ``expand_input``.

        :param expand_input: Expansion input, stored as ``op_kwargs_expand_input``.
        :param strict: Passed on as ``disallow_kwargs_override``.
        :param task_kwargs: Operator keyword arguments to use instead of
            ``self.kwargs``; copied before use. Defaults to ``self.kwargs``.
        :return: An ``XComArg`` wrapping the mapped operator.
        """
        ensure_xcomarg_return_value(expand_input.value)

        task_kwargs = dict(self.kwargs if task_kwargs is None else task_kwargs)
        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)

        partial_kwargs, partial_params = get_merged_defaults(
            dag=dag,
            task_group=task_group,
            task_params=task_kwargs.pop("params", None),
            task_default_args=task_kwargs.pop("default_args", None),
        )
        partial_kwargs.update(
            task_kwargs,
            is_setup=self.is_setup,
            is_teardown=self.is_teardown,
            on_failure_fail_dagrun=self.on_failure_fail_dagrun,
        )

        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
        if task_group:
            task_id = task_group.child_id(task_id)

        # Logic here should be kept in sync with BaseOperatorMeta.partial().
        if "task_concurrency" in partial_kwargs:
            raise TypeError("unexpected argument: task_concurrency")
        if partial_kwargs.get("wait_for_downstream"):
            partial_kwargs["depends_on_past"] = True
        start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None))
        end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None))
        if partial_kwargs.get("pool") is None:
            partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME
        partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES))
        partial_kwargs["retry_delay"] = coerce_timedelta(
            partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY),
            key="retry_delay",
        )
        max_retry_delay = partial_kwargs.get("max_retry_delay")
        partial_kwargs["max_retry_delay"] = (
            max_retry_delay
            if max_retry_delay is None
            else coerce_timedelta(max_retry_delay, key="max_retry_delay")
        )
        partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources"))
        partial_kwargs.setdefault("executor_config", {})
        partial_kwargs.setdefault("op_args", [])
        partial_kwargs.setdefault("op_kwargs", {})

        # Mypy does not work well with a subclassed attrs class :(
        _MappedOperator = cast(Any, DecoratedMappedOperator)

        try:
            operator_name = self.operator_class.custom_operator_name  # type: ignore
        except AttributeError:
            operator_name = self.operator_class.__name__

        operator = _MappedOperator(
            operator_class=self.operator_class,
            expand_input=EXPAND_INPUT_EMPTY,  # Don't use this; mapped values go to op_kwargs_expand_input.
            partial_kwargs=partial_kwargs,
            task_id=task_id,
            params=partial_params,
            deps=MappedOperator.deps_for(self.operator_class),
            operator_extra_links=self.operator_class.operator_extra_links,
            template_ext=self.operator_class.template_ext,
            template_fields=self.operator_class.template_fields,
            template_fields_renderers=self.operator_class.template_fields_renderers,
            ui_color=self.operator_class.ui_color,
            ui_fgcolor=self.operator_class.ui_fgcolor,
            is_empty=False,
            task_module=self.operator_class.__module__,
            task_type=self.operator_class.__name__,
            operator_name=operator_name,
            dag=dag,
            task_group=task_group,
            start_date=start_date,
            end_date=end_date,
            multiple_outputs=self.multiple_outputs,
            python_callable=self.function,
            op_kwargs_expand_input=expand_input,
            disallow_kwargs_override=strict,
            # Different from classic operators, kwargs passed to a taskflow
            # task's expand() contribute to the op_kwargs operator argument, not
            # the operator arguments themselves, and should expand against it.
            expand_input_attr="op_kwargs_expand_input",
        )
        return XComArg(operator=operator)

    def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a new decorator with ``kwargs`` merged into the stored ``op_kwargs``."""
        self._validate_arg_names("partial", kwargs)
        old_kwargs = self.kwargs.get("op_kwargs", {})
        prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
        kwargs.update(old_kwargs)
        return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})

    def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
        """Return a new decorator whose operator kwargs are updated with ``kwargs``."""
        result = attr.evolve(self, kwargs={**self.kwargs, **kwargs})
        setattr(result, "is_setup", self.is_setup)
        setattr(result, "is_teardown", self.is_teardown)
        setattr(result, "on_failure_fail_dagrun", self.on_failure_fail_dagrun)
        return result
@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
    """Mapped operator created when a @task-decorated function is expanded."""

    multiple_outputs: bool
    python_callable: Callable

    # The mapped values live here rather than in expand_input: op_kwargs has to
    # stay in partial_kwargs, and MappedOperator prevents duplication.
    op_kwargs_expand_input: ExpandInput

    def __hash__(self):
        return id(self)

    def __attrs_post_init__(self):
        # Zero-argument super() does not work in this method, hence the explicit
        # form; super(..., self) is avoided to work around a pyupgrade bug.
        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
        XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

    def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
        """Expand via the parent class and nest the result under the ``op_kwargs`` key."""
        # Only op_kwargs_expand_input is used, so expand_input must be the empty marker.
        if self.expand_input is EXPAND_INPUT_EMPTY:
            expanded, resolved_oids = super()._expand_mapped_kwargs(context, session)
            return {"op_kwargs": expanded}, resolved_oids
        raise AssertionError(f"unexpected expand_input: {self.expand_input}")

    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
        """Merge partial and mapped ``op_kwargs`` and delegate to the parent class.

        With ``strict`` set, overlapping keys between the two are rejected by
        ``prevent_duplicates``; otherwise mapped values win over partial ones.
        """
        from_partial = self.partial_kwargs["op_kwargs"]
        from_mapping = mapped_kwargs["op_kwargs"]

        if strict:
            prevent_duplicates(from_partial, from_mapping, fail_reason="mapping already partial")

        combined = dict(from_partial)
        combined.update(from_mapping)
        unmap_kwargs = {
            "multiple_outputs": self.multiple_outputs,
            "python_callable": self.python_callable,
            "op_kwargs": combined,
        }
        return super()._get_unmap_kwargs(unmap_kwargs, strict=False)
class Task(Protocol, Generic[FParams, FReturn]):
    """Type-checking declaration of a callable decorated with @task.

    Such an object takes over the call signature of the function it wraps
    (only approximately: calling it really yields an XComArg, which cannot be
    expressed right now) and offers additional methods for task-mapping.

    At runtime this type is implemented by ``_TaskDecorator``.
    """

    __call__: Callable[FParams, XComArg]

    function: Callable[FParams, FReturn]

    @property
    def __wrapped__(self) -> Callable[FParams, FReturn]:
        """Declare access to the wrapped function."""

    def partial(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Declare ``partial()``, which yields another ``Task``."""

    def expand(self, **kwargs: OperatorExpandArgument) -> XComArg:
        """Declare ``expand()``, which yields an ``XComArg``."""

    def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
        """Declare ``expand_kwargs()``, which yields an ``XComArg``."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Declare ``override()``, which yields another ``Task``."""
class TaskDecorator(Protocol):
    """Declares the type of what ``task_decorator_factory`` returns."""

    @overload
    def __call__(  # type: ignore[misc]
        self,
        python_callable: Callable[FParams, FReturn],
    ) -> Task[FParams, FReturn]:
        """Cover the bare ``@task`` usage, where the function is passed directly."""

    @overload
    def __call__(
        self,
        *,
        multiple_outputs: bool | None = None,
        **kwargs: Any,
    ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
        """Cover the ``@task()`` factory usage, where only keywords are passed."""

    def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
        """Declare ``override()``, which yields a ``Task``."""
def task_decorator_factory(
    python_callable: Callable | None = None,
    *,
    multiple_outputs: bool | None = None,
    decorated_operator_class: type[BaseOperator],
    **kwargs,
) -> TaskDecorator:
    """Build a wrapper turning a function into an Airflow operator.

    The result can be reused in a single DAG.

    :param python_callable: Function to decorate.
    :param multiple_outputs: If set to True, the decorated function's return
        value will be unrolled to multiple XCom values. Dict will unroll to XCom
        values with its keys as XCom keys. If set to False (default), only at
        most one XCom value is pushed.
    :param decorated_operator_class: The operator that executes the logic needed
        to run the python function in the correct environment.

    Any other keyword arguments are forwarded as-is to the underlying operator
    class when it gets instantiated.
    """
    if multiple_outputs is None:
        # attr.NOTHING makes _TaskDecorator fall back to its own default.
        multiple_outputs = cast(bool, attr.NOTHING)

    def decorator_factory(python_callable):
        return _TaskDecorator(
            function=python_callable,
            multiple_outputs=multiple_outputs,
            operator_class=decorated_operator_class,
            kwargs=kwargs,
        )

    # Bare usage: a function was given, so wrap it right away.
    if python_callable:
        return cast(TaskDecorator, decorator_factory(python_callable))

    # A falsy value other than None was passed positionally.
    if python_callable is not None:
        raise TypeError("No args allowed while using @task, use kwargs instead")

    # Factory usage: hand back the function that will do the wrapping.
    return cast(TaskDecorator, decorator_factory)