From f78673647f9ac714f0fec3cd03cdb36e89585654 Mon Sep 17 00:00:00 2001 From: rustikk Date: Mon, 3 Jan 2022 12:02:52 -0700 Subject: [PATCH 01/18] changed macros to correct classes and modules --- docs/apache-airflow/templates-ref.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/apache-airflow/templates-ref.rst b/docs/apache-airflow/templates-ref.rst index 28aba68b04e9a..f3c4a46978a33 100644 --- a/docs/apache-airflow/templates-ref.rst +++ b/docs/apache-airflow/templates-ref.rst @@ -166,9 +166,9 @@ Variable Description ``macros.datetime`` The standard lib's :class:`datetime.datetime` ``macros.timedelta`` The standard lib's :class:`datetime.timedelta` ``macros.dateutil`` A reference to the ``dateutil`` package -``macros.time`` The standard lib's :class:`datetime.time` +``macros.time`` The standard lib's :mod:`time` ``macros.uuid`` The standard lib's :mod:`uuid` -``macros.random`` The standard lib's :mod:`random` +``macros.random`` The standard lib's :class:`random.random` ================================= ============================================== Some airflow specific macros are also defined: From 70d3f70dca59ba012c303c946702833833d772a5 Mon Sep 17 00:00:00 2001 From: rustikk Date: Wed, 23 Feb 2022 19:23:59 -0700 Subject: [PATCH 02/18] adding type check for default_args --- airflow/models/dag.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index f893096a516c8..6b9390f77866c 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -341,7 +341,10 @@ def __init__( self.user_defined_macros = user_defined_macros self.user_defined_filters = user_defined_filters - self.default_args = copy.deepcopy(default_args or {}) + if not isinstance(default_args, Dict) and default_args != None: + raise TypeError("default_args must be a dictionary") + else: + self.default_args = copy.deepcopy(default_args or {}) params = params or {} # merging potentially conflicting 
default_args['params'] into params From f92ffbf5276d3b9a737f3b3ce0c13a5bb31db41e Mon Sep 17 00:00:00 2001 From: rustikk Date: Wed, 23 Feb 2022 20:32:02 -0700 Subject: [PATCH 03/18] flake8 checks --- airflow/models/dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 6b9390f77866c..18d19e35817c1 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -341,7 +341,7 @@ def __init__( self.user_defined_macros = user_defined_macros self.user_defined_filters = user_defined_filters - if not isinstance(default_args, Dict) and default_args != None: + if not isinstance(default_args, Dict) and default_args is not None: raise TypeError("default_args must be a dictionary") else: self.default_args = copy.deepcopy(default_args or {}) From 5b5e35601596898b3eead8e9e5eb3a7ca532f2f1 Mon Sep 17 00:00:00 2001 From: Alexander Chashnikov <6350825+ne1r0n@users.noreply.github.com> Date: Thu, 24 Feb 2022 03:51:35 +0200 Subject: [PATCH 04/18] Add Paxful to INTHEWILD.md (#21766) --- INTHEWILD.md | 1 + 1 file changed, 1 insertion(+) diff --git a/INTHEWILD.md b/INTHEWILD.md index 5f6e0aa4a1f8c..ac753f8a4cdaa 100644 --- a/INTHEWILD.md +++ b/INTHEWILD.md @@ -328,6 +328,7 @@ Currently, **officially** using Airflow: 1. [Paraná Banco](https://paranabanco.com.br/) [[@lopesdiego12](https://github.com/lopesdiego12/)] 1. [Parimatch Tech](https://parimatch.tech/) [[@KulykDmytro](https://github.com/KulykDmytro), [@Tonkonozhenko](https://github.com/Tonkonozhenko)] 1. [Pathstream](https://pathstream.com) [[@pJackDanger](https://github.com/JackDanger)] +1. [Paxful](https://paxful.com) [[@ne1r0n](https://github.com/ne1r0n)] 1. [PayFit](https://payfit.com) [[@pcorbel](https://github.com/pcorbel)] 1. [PAYMILL](https://www.paymill.com/) [[@paymill](https://github.com/paymill) & [@matthiashuschle](https://github.com/matthiashuschle)] 1. 
[PayPal](https://www.paypal.com/) [[@r39132](https://github.com/r39132) & [@jhsenjaliya](https://github.com/jhsenjaliya)] From de323a355cbf517a3a2184b8209644c11953896b Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Wed, 23 Feb 2022 21:14:45 -0700 Subject: [PATCH 05/18] Add `2.2.4` to db migrations map (#21777) --- airflow/utils/db.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 3bd98f1e5b633..e9b2ef2eb58f3 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -82,6 +82,7 @@ "2.2.1": "7b2661a43ba3", "2.2.2": "7b2661a43ba3", "2.2.3": "be2bfac3da23", + "2.2.4": "587bdf053233", } From bc347272157777fc1d38fbbd43132388f4edba85 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 24 Feb 2022 08:12:12 +0100 Subject: [PATCH 06/18] Fix max_active_runs=1 not scheduling runs when min_file_process_interval is high (#21413) The finished dagrun was still being seen as running when we call dag.get_num_active_runs because the session was not flushed. 
This PR fixes it --- airflow/models/dagrun.py | 2 ++ tests/jobs/test_scheduler_job.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 5170ad376a592..36014fd7cf426 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -608,11 +608,13 @@ def update_state( self.data_interval_end, self.dag_hash, ) + session.flush() self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis) self._emit_duration_stats_for_finished_state() session.merge(self) + # We do not flush here for performance reasons(It increases queries count by +20) return schedulable_tis, callback diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 72d60d2e22ee1..d959e3fba1500 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -1259,6 +1259,41 @@ def test_queued_dagruns_stops_creating_when_max_active_is_reached(self, dag_make assert session.query(DagRun.state).filter(DagRun.state == State.QUEUED).count() == 0 assert orm_dag.next_dagrun_create_after is None + def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, session): + """ + Test that when creating runs once max_active_runs is reached the runs does not stick + """ + self.scheduler_job = SchedulerJob(subdir=os.devnull) + self.scheduler_job.executor = MockExecutor(do_update=True) + self.scheduler_job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + + with dag_maker(max_active_runs=1, session=session) as dag: + # Need to use something that doesn't immediately get marked as success by the scheduler + BashOperator(task_id='task', bash_command='true') + + dag_run = dag_maker.create_dagrun( + state=State.RUNNING, + session=session, + ) + + # Reach max_active_runs + for _ in range(3): + self.scheduler_job._do_scheduling(session) + + # Complete dagrun + # Add dag_run back in to the session (_do_scheduling does an expunge_all) + dag_run = 
session.merge(dag_run) + session.refresh(dag_run) + dag_run.get_task_instance(task_id='task', session=session).state = State.SUCCESS + + # create new run + for _ in range(3): + self.scheduler_job._do_scheduling(session) + + # Assert that new runs has created + dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(dag_runs) == 2 + def test_dagrun_timeout_verify_max_active_runs(self, dag_maker): """ Test if a a dagrun will not be scheduled if max_dag_runs From 736394d9117ee4cd612af282a473d9870683f625 Mon Sep 17 00:00:00 2001 From: Davy Date: Thu, 24 Feb 2022 15:40:57 +0800 Subject: [PATCH 07/18] Correctly handle multiple '=' in LocalFileSystem secrets. (#21694) --- airflow/secrets/local_filesystem.py | 7 +++---- tests/secrets/test_local_filesystem.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/airflow/secrets/local_filesystem.py b/airflow/secrets/local_filesystem.py index 7c30aa241cc2f..2c399b30f5714 100644 --- a/airflow/secrets/local_filesystem.py +++ b/airflow/secrets/local_filesystem.py @@ -74,8 +74,8 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt # Ignore comments continue - var_parts: List[str] = line.split("=", 2) - if len(var_parts) != 2: + key, sep, value = line.partition("=") + if not sep: errors.append( FileSyntaxError( line_no=line_no, @@ -84,8 +84,7 @@ def _parse_env_file(file_path: str) -> Tuple[Dict[str, List[str]], List[FileSynt ) continue - key, value = var_parts - if not key: + if not value: errors.append( FileSyntaxError( line_no=line_no, diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py index 85f0aaa0b0c01..5993eb3af2d99 100644 --- a/tests/secrets/test_local_filesystem.py +++ b/tests/secrets/test_local_filesystem.py @@ -153,6 +153,23 @@ def test_env_file_should_load_connection(self, file_content, expected_connection assert expected_connection_uris == connection_uris_by_conn_id + @parameterized.expand( + ( + ( + 
"CONN_ID=mysql://host_1?param1=val1¶m2=val2", + {"CONN_ID": "mysql://host_1?param1=val1¶m2=val2"}, + ), + ) + ) + def test_parsing_with_params(self, content, expected_connection_uris): + with mock_local_file(content): + connections_by_conn_id = local_filesystem.load_connections_dict("a.env") + connection_uris_by_conn_id = { + conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items() + } + + assert expected_connection_uris == connection_uris_by_conn_id + @parameterized.expand( ( ("AA", 'Invalid line format. The line should contain at least one equal sign ("=")'), From bfb3991d2e1a6c0a025b6f3fab9580652dc64686 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 24 Feb 2022 16:38:59 +0800 Subject: [PATCH 08/18] Rewrite taskflow-mapping argument validation (#21759) --- airflow/decorators/base.py | 76 ++++++++++++++++++-------------- airflow/models/mappedoperator.py | 2 +- airflow/models/taskinstance.py | 3 ++ airflow/utils/context.py | 45 +++++++++++++++++++ airflow/utils/context.pyi | 5 ++- tests/decorators/test_python.py | 61 ++++++++++++++++++++----- 6 files changed, 144 insertions(+), 48 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 46e4c0a5c71ed..38354d9157d05 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -19,7 +19,6 @@ import functools import inspect import re -import sys from typing import ( TYPE_CHECKING, Any, @@ -52,12 +51,13 @@ ValidationSource, create_mocked_kwargs, get_mappable_types, + prevent_duplicates, ) from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg from airflow.typing_compat import Protocol from airflow.utils import timezone -from airflow.utils.context import Context +from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.types import NOTSET @@ -227,45 +227,25 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]): 
:meta private: """ - function: Function = attr.ib(validator=attr.validators.is_callable()) + function: Function = attr.ib() operator_class: Type[OperatorSubclass] multiple_outputs: bool = attr.ib() kwargs: Dict[str, Any] = attr.ib(factory=dict) decorator_name: str = attr.ib(repr=False, default="task") - @cached_property - def function_signature(self): - return inspect.signature(self.function) - - @cached_property - def function_arg_names(self) -> Set[str]: - return set(self.function_signature.parameters) - - @function.validator - def _validate_function(self, _, f): - if 'self' in self.function_arg_names: - raise TypeError(f'@{self.decorator_name} does not support methods') - @multiple_outputs.default def _infer_multiple_outputs(self): try: return_type = typing_extensions.get_type_hints(self.function).get("return", Any) except Exception: # Can't evaluate retrurn type. return False - - # Get the non-subscripted type. The ``__origin__`` attribute is not - # stable until 3.7, but we need to use ``__extra__`` instead. - # TODO: Remove the ``__extra__`` branch when support for Python 3.6 is - # dropped in Airflow 2.3. 
- if sys.version_info < (3, 7): - ttype = getattr(return_type, "__extra__", return_type) - else: - ttype = getattr(return_type, "__origin__", return_type) - + ttype = getattr(return_type, "__origin__", return_type) return ttype == dict or ttype == Dict def __attrs_post_init__(self): + if "self" in self.function_signature.parameters: + raise TypeError(f"@{self.decorator_name} does not support methods") self.kwargs.setdefault('task_id', self.function.__name__) def __call__(self, *args, **kwargs) -> XComArg: @@ -280,22 +260,50 @@ def __call__(self, *args, **kwargs) -> XComArg: op.doc_md = self.function.__doc__ return XComArg(op) + @cached_property + def function_signature(self): + return inspect.signature(self.function) + + @cached_property + def _function_is_vararg(self): + return any( + v.kind == inspect.Parameter.VAR_KEYWORD for v in self.function_signature.parameters.values() + ) + + @cached_property + def _mappable_function_argument_names(self) -> Set[str]: + """Arguments that can be mapped against.""" + return set(self.function_signature.parameters) + def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): + # Ensure that context variables are not shadowed. + context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs) + if len(context_keys_being_mapped) == 1: + (name,) = context_keys_being_mapped + raise ValueError(f"cannot call {func}() on task context variable {name!r}") + elif context_keys_being_mapped: + names = ", ".join(repr(n) for n in context_keys_being_mapped) + raise ValueError(f"cannot call {func}() on task context variables {names}") + + # Ensure that all arguments passed in are accounted for. 
+ if self._function_is_vararg: + return kwargs_left = kwargs.copy() - for arg_name in self.function_arg_names: + for arg_name in self._mappable_function_argument_names: value = kwargs_left.pop(arg_name, NOTSET) if func != "map" or value is NOTSET or isinstance(value, get_mappable_types()): continue - raise ValueError(f"{func} got unexpected value{type(value)!r} for keyword argument {arg_name!r}") - + type_name = type(value).__name__ + raise ValueError(f"map() got an unexpected type {type_name!r} for keyword argument {arg_name!r}") if len(kwargs_left) == 1: - raise TypeError(f"{func} got unexpected keyword argument {next(iter(kwargs_left))!r}") + raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") elif kwargs_left: names = ", ".join(repr(n) for n in kwargs_left) - raise TypeError(f"{func} got unexpected keyword arguments {names}") + raise TypeError(f"{func}() got unexpected keyword arguments {names}") - def map(self, **kwargs: "MapArgument") -> XComArg: - self._validate_arg_names("map", kwargs) + def map(self, **map_kwargs: "MapArgument") -> XComArg: + self._validate_arg_names("map", map_kwargs) + prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial") partial_kwargs = self.kwargs.copy() @@ -345,7 +353,7 @@ def map(self, **kwargs: "MapArgument") -> XComArg: end_date=end_date, multiple_outputs=self.multiple_outputs, python_callable=self.function, - mapped_op_kwargs=kwargs, + mapped_op_kwargs=map_kwargs, ) return XComArg(operator=operator) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 99b70fbbf0a92..a0c16203c1e0b 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -111,7 +111,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va if isinstance(value, get_mappable_types()): continue type_name = type(value).__name__ - error = f"{op.__name__}.map() got unexpected type {type_name!r} for keyword 
argument {name}" + error = f"{op.__name__}.map() got an unexpected type {type_name!r} for keyword argument {name}" raise ValueError(error) if not unknown_args: return # If we have no args left ot check: stop looking at the MRO chian. diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 34f5a6e2cee82..954aaf68267ca 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1929,6 +1929,9 @@ def get_prev_ds_nodash() -> Optional[str]: return None return prev_ds.replace('-', '') + # NOTE: If you add anything to this dict, make sure to also update the + # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in + # airflow/utils/context.py! context = { 'conf': conf, 'dag': dag, diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 2cd870852dabe..04dababa24602 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -41,6 +41,51 @@ from airflow.utils.types import NOTSET +# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi. 
+KNOWN_CONTEXT_KEYS = { + "conf", + "conn", + "dag", + "dag_run", + "data_interval_end", + "data_interval_start", + "ds", + "ds_nodash", + "execution_date", + "exception", + "inlets", + "logical_date", + "macros", + "next_ds", + "next_ds_nodash", + "next_execution_date", + "outlets", + "params", + "prev_data_interval_start_success", + "prev_data_interval_end_success", + "prev_ds", + "prev_ds_nodash", + "prev_execution_date", + "prev_execution_date_success", + "prev_start_date_success", + "run_id", + "task", + "task_instance", + "task_instance_key_str", + "test_mode", + "templates_dict", + "ti", + "tomorrow_ds", + "tomorrow_ds_nodash", + "ts", + "ts_nodash", + "ts_nodash_with_tz", + "try_number", + "var", + "yesterday_ds", + "yesterday_ds_nodash", +} + class VariableAccessor: """Wrapper to access Variable values in template.""" diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index f614459a382fc..6003d1d1ec7d6 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -25,7 +25,7 @@ # undefined attribute errors from Mypy. Hopefully there will be a mechanism to # declare "these are defined, but don't error if others are accessed" someday. -from typing import Any, Container, Iterable, Mapping, Optional, Tuple, Union, overload +from typing import Any, Container, Iterable, Mapping, Optional, Set, Tuple, Union, overload from pendulum import DateTime @@ -37,6 +37,8 @@ from airflow.models.param import ParamsDict from airflow.models.taskinstance import TaskInstance from airflow.typing_compat import TypedDict +KNOWN_CONTEXT_KEYS: Set[str] + class _VariableAccessors(TypedDict): json: Any value: Any @@ -48,6 +50,7 @@ class VariableAccessor: class ConnectionAccessor: def get(self, key: str, default_conn: Any = None) -> Any: ... +# NOTE: Please keep this in sync with KNOWN_CONTEXT_KEYS in airflow/utils/context.py. 
class Context(TypedDict): conf: AirflowConfigParser conn: Any diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 4f6d0bced74c4..e127ab6daeb06 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -477,22 +477,59 @@ def add_2(number: int): assert ret.operator.doc_md.strip(), "Adds 2 to number." -def test_mapped_decorator() -> None: +def test_mapped_decorator_shadow_context() -> None: @task_decorator - def double(number: int): - return number * 2 + def print_info(message: str, run_id: str = "") -> None: + print(f"{run_id}: {message}") + + with pytest.raises(ValueError) as ctx: + print_info.partial(run_id="hi") + assert str(ctx.value) == "cannot call partial() on task context variable 'run_id'" + + with pytest.raises(ValueError) as ctx: + print_info.map(run_id=["hi", "there"]) + assert str(ctx.value) == "cannot call map() on task context variable 'run_id'" + + +def test_mapped_decorator_wrong_argument() -> None: + @task_decorator + def print_info(message: str, run_id: str = "") -> None: + print(f"{run_id}: {message}") + + with pytest.raises(TypeError) as ct: + print_info.partial(wrong_name="hi") + assert str(ct.value) == "partial() got an unexpected keyword argument 'wrong_name'" + + with pytest.raises(TypeError) as ct: + print_info.map(wrong_name=["hi", "there"]) + assert str(ct.value) == "map() got an unexpected keyword argument 'wrong_name'" + + with pytest.raises(ValueError) as cv: + print_info.map(message="hi") + assert str(cv.value) == "map() got an unexpected type 'str' for keyword argument 'message'" + + +def test_mapped_decorator(): + @task_decorator + def print_info(m1: str, m2: str, run_id: str = "") -> None: + print(f"{run_id}: {m1} {m2}") + + @task_decorator + def print_everything(**kwargs) -> None: + print(kwargs) - with DAG('test_dag', start_date=DEFAULT_DATE): - literal = [1, 2, 3] - doubled_0 = double.map(number=literal) - doubled_1 = double.map(number=literal) + with 
DAG("test_mapped_decorator", start_date=DEFAULT_DATE): + t0 = print_info.map(m1=["a", "b"], m2={"foo": "bar"}) + t1 = print_info.partial(m1="hi").map(m2=[1, 2, 3]) + t2 = print_everything.partial(whatever="123").map(any_key=[1, 2], works=t1) - assert isinstance(doubled_0, XComArg) - assert isinstance(doubled_0.operator, DecoratedMappedOperator) - assert doubled_0.operator.task_id == "double" - assert doubled_0.operator.mapped_op_kwargs == {"number": literal} + assert isinstance(t2, XComArg) + assert isinstance(t2.operator, DecoratedMappedOperator) + assert t2.operator.task_id == "print_everything" + assert t2.operator.mapped_op_kwargs == {"any_key": [1, 2], "works": t1} - assert doubled_1.operator.task_id == "double__1" + assert t0.operator.task_id == "print_info" + assert t1.operator.task_id == "print_info__1" def test_mapped_decorator_invalid_args() -> None: From adaec6774198e53155c178093fc4ecd97568641c Mon Sep 17 00:00:00 2001 From: Sitao Z Date: Thu, 24 Feb 2022 18:53:07 +0800 Subject: [PATCH 09/18] Fix bigquery_dts parameter docstring typo (#21786) --- airflow/providers/google/cloud/operators/bigquery_dts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py index 3d67879cad7fe..c3bb734ba5e6d 100644 --- a/airflow/providers/google/cloud/operators/bigquery_dts.py +++ b/airflow/providers/google/cloud/operators/bigquery_dts.py @@ -40,7 +40,7 @@ class BigQueryCreateDataTransferOperator(BaseOperator): :param project_id: The BigQuery project id where the transfer configuration should be created. If set to None or missing, the default project_id from the Google Cloud connection is used. - :param: location: BigQuery Transfer Service location for regional transfers. + :param location: BigQuery Transfer Service location for regional transfers. :param authorization_code: authorization code to use with this transfer configuration. 
This is required if new credentials are needed. :param retry: A retry object used to retry requests. If `None` is @@ -123,7 +123,7 @@ class BigQueryDeleteDataTransferConfigOperator(BaseOperator): :param transfer_config_id: Id of transfer config to be used. :param project_id: The BigQuery project id where the transfer configuration should be created. If set to None or missing, the default project_id from the Google Cloud connection is used. - :param: location: BigQuery Transfer Service location for regional transfers. + :param location: BigQuery Transfer Service location for regional transfers. :param retry: A retry object used to retry requests. If `None` is specified, requests will not be retried. :param timeout: The amount of time, in seconds, to wait for the request to @@ -205,7 +205,7 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator): `~google.cloud.bigquery_datatransfer_v1.types.Timestamp` :param project_id: The BigQuery project id where the transfer configuration should be created. If set to None or missing, the default project_id from the Google Cloud connection is used. - :param: location: BigQuery Transfer Service location for regional transfers. + :param location: BigQuery Transfer Service location for regional transfers. :param retry: A retry object used to retry requests. If `None` is specified, requests will not be retried. :param timeout: The amount of time, in seconds, to wait for the request to From 32acd754e0d443e92beba41a93ce817698bd371d Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 24 Feb 2022 14:10:49 +0100 Subject: [PATCH 10/18] Add --platform as parameter of image building (#21695) Building image for now shoudl be forced to linux/amd64 as this is the only supported platform. When using BUILDKIT, the default platform depends on the OS/processor, but until we implement multi-platform images we force them to linux/amd64. 
Setting it as parameter of docker build instead of env variables is more explicit and allows to copy&paste the whole command to reproduce it outside of breeze when verbose is used independently if you are on Linux, MacOS Intel/ARM. --- BREEZE.rst | 56 +++++++++++++++++++++++++ breeze | 24 +++++++++++ breeze-complete | 4 +- scripts/ci/libraries/_build_images.sh | 2 + scripts/ci/libraries/_initialization.sh | 4 +- 5 files changed, 87 insertions(+), 3 deletions(-) diff --git a/BREEZE.rst b/BREEZE.rst index ffb72c74da24f..3fa459f218130 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -1272,6 +1272,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -a, --install-airflow-version INSTALL_AIRFLOW_VERSION Uses different version of Airflow when building PROD image. @@ -1472,6 +1479,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -a, --install-airflow-version INSTALL_AIRFLOW_VERSION Uses different version of Airflow when building PROD image. @@ -1532,6 +1546,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -I, --production-image Use production image for entering the environment and builds (not for tests). @@ -1599,6 +1620,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -v, --verbose Show verbose information about executed docker, kind, kubectl, helm commands. Useful for debugging - when you run breeze with --verbose flags you will be able to see the commands @@ -1635,6 +1663,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. 
+ + One of: + + linux/amd64 + #################################################################################################### @@ -1830,6 +1865,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -b, --backend BACKEND Backend to use for tests - it determines which database is used. One of: @@ -1899,6 +1941,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + -F, --force-build-images Forces building of the local docker images. The images are rebuilt automatically for the first time or when changes are detected in @@ -2298,6 +2347,13 @@ This is the current syntax for `./breeze <./breeze>`_: 3.7 3.8 3.9 + --platform PLATFORM + Builds image for the platform specified. + + One of: + + linux/amd64 + **************************************************************************************************** Choose backend to run for Airflow diff --git a/breeze b/breeze index a4d372eed4e96..64a2daa0b9b05 100755 --- a/breeze +++ b/breeze @@ -499,6 +499,7 @@ EOF Branch name: ${BRANCH_NAME} Docker image: ${AIRFLOW_PROD_IMAGE} + Platform: ${PLATFORM} Airflow source version: $(build_images::get_airflow_version_from_production_image) EOF else @@ -508,6 +509,7 @@ EOF Branch name: ${BRANCH_NAME} Docker image: ${AIRFLOW_CI_IMAGE_WITH_TAG} + Platform: ${PLATFORM} Airflow source version: ${AIRFLOW_VERSION} EOF fi @@ -530,6 +532,7 @@ EOF Branch name: ${BRANCH_NAME} Docker image: ${AIRFLOW_PROD_IMAGE} + Platform: ${PLATFORM} EOF else cat < Date: Thu, 24 Feb 2022 13:46:30 +0000 Subject: [PATCH 11/18] Upgrade and record elasticsearch log_id_template changes (#21734) --- UPDATING.md | 14 ++++++++++++++ airflow/configuration.py | 7 +++++++ airflow/utils/db.py | 41 ++++++++++++++++++++++++++++------------ 3 files changed, 50 insertions(+), 12 deletions(-) diff --git 
a/UPDATING.md b/UPDATING.md index a9ba80821b2d2..bd9e5e32a98b3 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -129,6 +129,20 @@ Previously, a task’s log is dynamically rendered from the `[core] log_filename A new `log_template` table is introduced to solve this problem. This table is synchronised with the aforementioned config values every time Airflow starts, and a new field `log_template_id` is added to every DAG run to point to the format used by tasks (`NULL` indicates the first ever entry for compatibility). +### Default templates for log filenames and elasticsearch log_id changed + +In order to support Dynamic Task Mapping the default templates for per-task instance logging has changed. If your config contains the old default values they will be upgraded-in-place. + +If you are happy with the new config values you should _remove_ the setting in `airflow.cfg` and let the default value be used. Old default values were: + + +- `[core] log_filename_template`: `{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log` +- `[elasticsearch] log_id_template`: `{dag_id}-{task_id}-{execution_date}-{try_number}` + +`[core] log_filename_template` now uses "hive partition style" of `dag_id=/run_id=` by default, which may cause problems on some older FAT filesystems. If this affects you then you will have to change the log template. + +If you have customized the templates you should ensure that they contain `{{ ti.map_index }}` if you want to use dynamically mapped tasks. + ### `airflow.models.base.Operator` is removed Previously, there was an empty class `airflow.models.base.Operator` for “type hinting”. This class was never really useful for anything (everything it did could be done better with `airflow.models.baseoperator.BaseOperator`), and has been removed. 
If you are relying on the class’s existence, use `BaseOperator` (for concrete operators), `airflow.models.abstractoperator.AbstractOperator` (the base class of both `BaseOperator` and the AIP-42 `MappedOperator`), or `airflow.models.operator.Operator` (a union type `BaseOperator | MappedOperator` for type annotation). diff --git a/airflow/configuration.py b/airflow/configuration.py index c48f851c8ad5d..03a1e7a1e8ac2 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -218,6 +218,13 @@ class AirflowConfigParser(ConfigParser): '3.0', ), }, + 'elasticsearch': { + 'log_id_template': ( + re.compile('^' + re.escape('{dag_id}-{task_id}-{run_id}-{try_number}') + '$'), + '{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}', + 3.0, + ) + }, } _available_logging_levels = ['CRITICAL', 'FATAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG'] diff --git a/airflow/utils/db.py b/airflow/utils/db.py index e9b2ef2eb58f3..2038d92cbc5b8 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -748,24 +748,41 @@ def synchronize_log_template(*, session: Session = NEW_SESSION) -> None: This checks if the last row fully matches the current config values, and insert a new row if not. """ - - def check_templates(filename, elasticsearch_id): - stored = session.query(LogTemplate).order_by(LogTemplate.id.desc()).first() - - if not stored or stored.filename != filename or stored.elasticsearch_id != elasticsearch_id: - session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id)) - filename = conf.get("logging", "log_filename_template") elasticsearch_id = conf.get("elasticsearch", "log_id_template") # Before checking if the _current_ value exists, we need to check if the old config value we upgraded in # place exists! 
- pre_upgrade_filename = conf.upgraded_values.get(('logging', 'log_filename_template'), None) - if pre_upgrade_filename is not None: - check_templates(pre_upgrade_filename, elasticsearch_id) - session.flush() + pre_upgrade_filename = conf.upgraded_values.get(("logging", "log_filename_template"), filename) + pre_upgrade_elasticsearch_id = conf.upgraded_values.get( + ("elasticsearch", "log_id_template"), elasticsearch_id + ) + + if pre_upgrade_filename != filename or pre_upgrade_elasticsearch_id != elasticsearch_id: + # The previous non-upgraded value likely won't be the _latest_ value (as after we've recorded the + # upgraded value it will be second-to-newest), so we'll have to just search which is okay + # as this is a table with a tiny number of rows + row = ( + session.query(LogTemplate.id) + .filter( + or_( + LogTemplate.filename == pre_upgrade_filename, + LogTemplate.elasticsearch_id == pre_upgrade_elasticsearch_id, + ) + ) + .order_by(LogTemplate.id.desc()) + .first() + ) + if not row: + session.add( + LogTemplate(filename=pre_upgrade_filename, elasticsearch_id=pre_upgrade_elasticsearch_id) + ) + session.flush() + + stored = session.query(LogTemplate).order_by(LogTemplate.id.desc()).first() - check_templates(filename, elasticsearch_id) + if not stored or stored.filename != filename or stored.elasticsearch_id != elasticsearch_id: + session.add(LogTemplate(filename=filename, elasticsearch_id=elasticsearch_id)) def check_conn_id_duplicates(session: Session) -> Iterable[str]: From 855f9e05f4ec130914f2901e37bb8cb3b589fa99 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 24 Feb 2022 22:07:02 +0800 Subject: [PATCH 12/18] Rename operator mapping map() to apply() (#21754) --- airflow/decorators/base.py | 18 +++++------ airflow/decorators/task_group.py | 17 ++++++---- airflow/models/baseoperator.py | 4 +-- airflow/models/mappedoperator.py | 30 ++++++++--------- airflow/serialization/serialized_objects.py | 24 ++++++++++ 
airflow/utils/task_group.py | 2 +- tests/dags/test_mapped_classic.py | 2 +- tests/dags/test_mapped_taskflow.py | 2 +- tests/decorators/test_python.py | 32 +++++++++---------- tests/models/test_baseoperator.py | 14 ++++---- tests/models/test_dagrun.py | 4 +-- tests/models/test_taskinstance.py | 14 ++++---- tests/serialization/test_dag_serialization.py | 8 ++--- tests/utils/test_task_group.py | 8 ++--- 14 files changed, 97 insertions(+), 82 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 38354d9157d05..d661b5727e2e8 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -62,7 +62,7 @@ from airflow.utils.types import NOTSET if TYPE_CHECKING: - from airflow.models.mappedoperator import MapArgument + from airflow.models.mappedoperator import Mappable def validate_python_callable(python_callable: Any) -> None: @@ -291,18 +291,18 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): kwargs_left = kwargs.copy() for arg_name in self._mappable_function_argument_names: value = kwargs_left.pop(arg_name, NOTSET) - if func != "map" or value is NOTSET or isinstance(value, get_mappable_types()): + if func != "apply" or value is NOTSET or isinstance(value, get_mappable_types()): continue - type_name = type(value).__name__ - raise ValueError(f"map() got an unexpected type {type_name!r} for keyword argument {arg_name!r}") + tname = type(value).__name__ + raise ValueError(f"apply() got an unexpected type {tname!r} for keyword argument {arg_name!r}") if len(kwargs_left) == 1: raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") elif kwargs_left: names = ", ".join(repr(n) for n in kwargs_left) raise TypeError(f"{func}() got unexpected keyword arguments {names}") - def map(self, **map_kwargs: "MapArgument") -> XComArg: - self._validate_arg_names("map", map_kwargs) + def apply(self, **map_kwargs: "Mappable") -> XComArg: + self._validate_arg_names("apply", 
map_kwargs) prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial") partial_kwargs = self.kwargs.copy() @@ -385,7 +385,7 @@ class DecoratedMappedOperator(MappedOperator): # We can't save these in mapped_kwargs because op_kwargs need to be present # in partial_kwargs, and MappedOperator prevents duplication. - mapped_op_kwargs: Dict[str, "MapArgument"] + mapped_op_kwargs: Dict[str, "Mappable"] @classmethod @cache @@ -401,7 +401,7 @@ def __attrs_post_init__(self): super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs) - def _get_expansion_kwargs(self) -> Dict[str, "MapArgument"]: + def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]: """The kwargs to calculate expansion length against. Different from classic operators, a decorated (taskflow) operator's @@ -458,7 +458,7 @@ class Task(Generic[Function]): function: Function - def map(self, **kwargs: "MapArgument") -> XComArg: + def apply(self, **kwargs: "Mappable") -> XComArg: ... 
def partial(self, **kwargs: Any) -> "Task[Function]": diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 56bceed0aa9d1..554ad4e4712c6 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from airflow.models.dag import DAG + from airflow.models.mappedoperator import Mappable F = TypeVar("F", bound=Callable) R = TypeVar("R") @@ -83,8 +84,8 @@ def __call__(self, *args, **kwargs) -> Union[R, TaskGroup]: def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs) - def map(self, **kwargs) -> Union[R, TaskGroup]: - return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs) + def apply(self, **kwargs) -> Union[R, TaskGroup]: + return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).apply(**kwargs) @attr.define @@ -106,12 +107,16 @@ def _make_task_group(self, **kwargs) -> MappedTaskGroup: return tg def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": - if self.partial_kwargs: - raise RuntimeError("Already a partial task group") + duplicated_keys = [k for k in kwargs if k in self.partial_kwargs] + if len(duplicated_keys) == 1: + raise ValueError(f"Cannot overwrite partial argument: {duplicated_keys[0]!r}") + elif duplicated_keys: + joined = ", ".join(repr(k) for k in duplicated_keys) + raise ValueError(f"Cannot overwrite partial arguments: {joined}") self.partial_kwargs.update(kwargs) return self - def map(self, **kwargs) -> Union[R, TaskGroup]: + def apply(self, **kwargs) -> Union[R, TaskGroup]: if self.mapped_kwargs: raise RuntimeError("Already a mapped task group") self.mapped_kwargs = kwargs @@ -145,7 +150,7 @@ class Group(Generic[F]): function: F # Return value should match F's return type, but that's impossible to declare. - def map(self, **kwargs: Any) -> Any: + def apply(self, **kwargs: "Mappable") -> Any: ... 
def partial(self, **kwargs: Any) -> "Group[F]": diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 8cdd3d2d91d57..b14e885f6f2a4 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -261,7 +261,7 @@ def partial( partial_kwargs.setdefault("outlets", outlets) partial_kwargs.setdefault("resources", resources) - # Post-process arguments. Should be kept in sync with _TaskDecorator.map(). + # Post-process arguments. Should be kept in sync with _TaskDecorator.apply(). if "task_concurrency" in kwargs: # Reject deprecated option. raise TypeError("unexpected argument: task_concurrency") if partial_kwargs["wait_for_downstream"]: @@ -671,7 +671,7 @@ def __new__( task_group = task_group or TaskGroupContext.get_current_task_group(dag) if not _airflow_map_validation and isinstance(task_group, MappedTaskGroup): - return cls.partial(dag=dag, task_group=task_group, **kwargs).map() + return cls.partial(dag=dag, task_group=task_group, **kwargs).apply() return super().__new__(cls) def __init__( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index a0c16203c1e0b..e9f614e4efab6 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -78,9 +78,9 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom_arg import XComArg - # BaseOperator.map() can be called on an XComArg, sequence, or dict (not any - # mapping since we need the value to be ordered). - MapArgument = Union[XComArg, Sequence, dict] + # BaseOperator.apply() can be called on an XComArg, sequence, or dict (not + # any mapping since we need the value to be ordered). 
+ Mappable = Union[XComArg, Sequence, dict] ValidationSource = Union[Literal["map"], Literal["partial"]] @@ -104,14 +104,14 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va continue for name in param_names: value = unknown_args.pop(name, NOTSET) - if func != "map": + if func != "apply": continue if value is NOTSET: continue if isinstance(value, get_mappable_types()): continue type_name = type(value).__name__ - error = f"{op.__name__}.map() got an unexpected type {type_name!r} for keyword argument {name}" + error = f"{op.__name__}.apply() got an unexpected type {type_name!r} for keyword argument {name}" raise ValueError(error) if not unknown_args: return # If we have no args left ot check: stop looking at the MRO chian. @@ -134,7 +134,7 @@ def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") -def create_mocked_kwargs(kwargs: Dict[str, "MapArgument"]) -> Dict[str, unittest.mock.MagicMock]: +def create_mocked_kwargs(kwargs: Dict[str, "Mappable"]) -> Dict[str, unittest.mock.MagicMock]: """Create a mapping of mocks for given map arguments. When a mapped operator is created, we want to perform basic validation on @@ -157,14 +157,14 @@ class OperatorPartial: """An "intermediate state" returned by ``BaseOperator.partial()``. This only exists at DAG-parsing time; the only intended usage is for the - user to call ``.map()`` on it at some point (usually in a method chain) to + user to call ``.apply()`` on it at some point (usually in a method chain) to create a ``MappedOperator`` to add into the DAG. """ operator_class: Type["BaseOperator"] kwargs: Dict[str, Any] - _map_called: bool = False # Set when map() is called to ease user debugging. + _apply_called: bool = False # Set when apply() is called to ease user debugging. 
def __attrs_post_init__(self): from airflow.operators.subdag import SubDagOperator @@ -178,13 +178,13 @@ def __repr__(self) -> str: return f"{self.operator_class.__name__}.partial({args})" def __del__(self): - if not self._map_called: + if not self._apply_called: warnings.warn(f"{self!r} was never mapped!") - def map(self, **mapped_kwargs: "MapArgument") -> "MappedOperator": + def apply(self, **mapped_kwargs: "Mappable") -> "MappedOperator": from airflow.operators.dummy import DummyOperator - validate_mapping_kwargs(self.operator_class, "map", mapped_kwargs) + validate_mapping_kwargs(self.operator_class, "apply", mapped_kwargs) partial_kwargs = self.kwargs.copy() task_id = partial_kwargs.pop("task_id") @@ -214,7 +214,7 @@ def map(self, **mapped_kwargs: "MapArgument") -> "MappedOperator": start_date=start_date, end_date=end_date, ) - self._map_called = True + self._apply_called = True return op @@ -223,7 +223,7 @@ class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" operator_class: Union[Type["BaseOperator"], str] - mapped_kwargs: Dict[str, "MapArgument"] + mapped_kwargs: Dict[str, "Mappable"] partial_kwargs: Dict[str, Any] # Needed for serialization. @@ -461,11 +461,11 @@ def unmap(self) -> "BaseOperator": dag._remove_task(self.task_id) return self._create_unmapped_operator(mapped_kwargs=self.mapped_kwargs, real=True) - def _get_expansion_kwargs(self) -> Dict[str, "MapArgument"]: + def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]: """The kwargs to calculate expansion length against. This is ``self.mapped_kwargs`` for classic operators because kwargs to - ``BaseOperator.map()`` contribute to operator arguments. + ``BaseOperator.apply()`` contribute to operator arguments. 
""" return self.mapped_kwargs diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 1dd4b09f6dd98..5b5225f5ebe65 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -76,17 +76,28 @@ @cache -def get_operator_extra_links(): - """ - Returns operator extra links - both the ones that are built in and the ones that come from - the providers. +def get_operator_extra_links() -> Set[str]: + """Get the operator extra links. - :return: set of extra links + This includes both the built-in ones, and those come from the providers. """ _OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names) return _OPERATOR_EXTRA_LINKS +@cache +def _get_default_mapped_partial() -> Dict[str, Any]: + """Get default partial kwargs in a mapped operator. + + This is used to simplify a serialized mapped operator by excluding default + values supplied in the implementation from the serialized dict. Since those + are defaults, they are automatically supplied on de-serialization, so we + don't need to store them. + """ + default_partial_kwargs = BaseOperator.partial(task_id="_").apply().partial_kwargs + return BaseSerialization._serialize(default_partial_kwargs)[Encoding.VAR] + + def encode_relativedelta(var: relativedelta.relativedelta) -> Dict[str, Any]: encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v} if var.weekday and var.weekday.n: @@ -580,9 +591,8 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: # Simplify partial_kwargs by comparing it to the most barebone object. # Remove all entries that are simply default values. 
- default_partial = cls._serialize(BaseOperator.partial(task_id="_").map().partial_kwargs)[Encoding.VAR] serialized_partial = serialized_op["partial_kwargs"] - for k, default in default_partial.items(): + for k, default in _get_default_mapped_partial().items(): try: v = serialized_partial[k] except KeyError: diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 88b956e53c3eb..744110f3b8941 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -380,7 +380,7 @@ def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]: return DagAttributeTypes.TASK_GROUP, SerializedTaskGroup.serialize_task_group(self) - def map(self, arg: Iterable) -> "MappedTaskGroup": + def apply(self, arg: Iterable) -> "MappedTaskGroup": if self.children: raise RuntimeError("Cannot map a TaskGroup that already has children") if not self.group_id: diff --git a/tests/dags/test_mapped_classic.py b/tests/dags/test_mapped_classic.py index ab69d2334a6c7..827efdc61addf 100644 --- a/tests/dags/test_mapped_classic.py +++ b/tests/dags/test_mapped_classic.py @@ -31,4 +31,4 @@ def consumer(value): with DAG(dag_id='test_mapped_classic', start_date=days_ago(2)) as dag: - PythonOperator.partial(task_id='consumer', python_callable=consumer).map(op_args=make_arg_lists()) + PythonOperator.partial(task_id='consumer', python_callable=consumer).apply(op_args=make_arg_lists()) diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py index f21a9a5e8a42d..1f314e23e0605 100644 --- a/tests/dags/test_mapped_taskflow.py +++ b/tests/dags/test_mapped_taskflow.py @@ -28,4 +28,4 @@ def make_list(): def consumer(value): print(repr(value)) - consumer.map(value=make_list()) + consumer.apply(value=make_list()) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index e127ab6daeb06..3e51f61a53f16 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -487,8 +487,8 @@ def print_info(message: 
str, run_id: str = "") -> None: assert str(ctx.value) == "cannot call partial() on task context variable 'run_id'" with pytest.raises(ValueError) as ctx: - print_info.map(run_id=["hi", "there"]) - assert str(ctx.value) == "cannot call map() on task context variable 'run_id'" + print_info.apply(run_id=["hi", "there"]) + assert str(ctx.value) == "cannot call apply() on task context variable 'run_id'" def test_mapped_decorator_wrong_argument() -> None: @@ -501,12 +501,12 @@ def print_info(message: str, run_id: str = "") -> None: assert str(ct.value) == "partial() got an unexpected keyword argument 'wrong_name'" with pytest.raises(TypeError) as ct: - print_info.map(wrong_name=["hi", "there"]) - assert str(ct.value) == "map() got an unexpected keyword argument 'wrong_name'" + print_info.apply(wrong_name=["hi", "there"]) + assert str(ct.value) == "apply() got an unexpected keyword argument 'wrong_name'" with pytest.raises(ValueError) as cv: - print_info.map(message="hi") - assert str(cv.value) == "map() got an unexpected type 'str' for keyword argument 'message'" + print_info.apply(message="hi") + assert str(cv.value) == "apply() got an unexpected type 'str' for keyword argument 'message'" def test_mapped_decorator(): @@ -519,9 +519,9 @@ def print_everything(**kwargs) -> None: print(kwargs) with DAG("test_mapped_decorator", start_date=DEFAULT_DATE): - t0 = print_info.map(m1=["a", "b"], m2={"foo": "bar"}) - t1 = print_info.partial(m1="hi").map(m2=[1, 2, 3]) - t2 = print_everything.partial(whatever="123").map(any_key=[1, 2], works=t1) + t0 = print_info.apply(m1=["a", "b"], m2={"foo": "bar"}) + t1 = print_info.partial(m1="hi").apply(m2=[1, 2, 3]) + t2 = print_everything.partial(whatever="123").apply(any_key=[1, 2], works=t1) assert isinstance(t2, XComArg) assert isinstance(t2.operator, DecoratedMappedOperator) @@ -542,9 +542,9 @@ def double(number: int): with pytest.raises(TypeError, match="arguments 'other', 'b'"): double.partial(other=[1], b=['a']) with 
pytest.raises(TypeError, match="argument 'other'"): - double.map(number=literal, other=[1]) + double.apply(number=literal, other=[1]) with pytest.raises(ValueError, match="argument 'number'"): - double.map(number=1) # type: ignore[arg-type] + double.apply(number=1) # type: ignore[arg-type] def test_partial_mapped_decorator() -> None: @@ -555,9 +555,9 @@ def product(number: int, multiple: int): literal = [1, 2, 3] with DAG('test_dag', start_date=DEFAULT_DATE) as dag: - quadrupled = product.partial(multiple=3).map(number=literal) - doubled = product.partial(multiple=2).map(number=literal) - trippled = product.partial(multiple=3).map(number=literal) + quadrupled = product.partial(multiple=3).apply(number=literal) + doubled = product.partial(multiple=2).apply(number=literal) + trippled = product.partial(multiple=3).apply(number=literal) product.partial(multiple=2) # No operator is actually created. @@ -589,7 +589,7 @@ def task1(): def task2(arg1, arg2): ... - task2.partial(arg1=1).map(arg2=task1()) + task2.partial(arg1=1).apply(arg2=task1()) unmapped = dag.get_task("task2").unmap() assert set(unmapped.op_kwargs) == {"arg1", "arg2"} @@ -606,7 +606,7 @@ def task1(arg): def task2(arg1, arg2): ... 
- task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2])) + task2.partial(arg1=1).apply(arg2=task1.apply(arg=[1, 2])) mapped_task2 = dag.get_task("task2") assert mapped_task2.partial_kwargs["retry_delay"] == timedelta(seconds=30) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index b9d886399912c..e24861c0da608 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -704,7 +704,7 @@ def test_task_mapping_with_dag(): with DAG("test-dag", start_date=DEFAULT_DATE) as dag: task1 = BaseOperator(task_id="op1") literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').map(arg2=literal) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=literal) finish = MockOperator(task_id="finish") task1 >> mapped >> finish @@ -723,7 +723,7 @@ def test_task_mapping_without_dag_context(): with DAG("test-dag", start_date=DEFAULT_DATE) as dag: task1 = BaseOperator(task_id="op1") literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').map(arg2=literal) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=literal) task1 >> mapped @@ -740,7 +740,7 @@ def test_task_mapping_default_args(): with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): task1 = BaseOperator(task_id="op1") literal = ['a', 'b', 'c'] - mapped = MockOperator.partial(task_id='task_2').map(arg2=literal) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=literal) task1 >> mapped @@ -750,7 +750,7 @@ def test_task_mapping_default_args(): def test_map_unknown_arg_raises(): with pytest.raises(TypeError, match=r"argument 'file'"): - BaseOperator.partial(task_id='a').map(file=[1, 2, {'a': 'b'}]) + BaseOperator.partial(task_id='a').apply(file=[1, 2, {'a': 'b'}]) def test_map_xcom_arg(): @@ -758,7 +758,7 @@ def test_map_xcom_arg(): with DAG("test-dag", start_date=DEFAULT_DATE): task1 = BaseOperator(task_id="op1") xcomarg = XComArg(task1, "test_key") - mapped = 
MockOperator.partial(task_id='task_2').map(arg2=xcomarg) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=xcomarg) finish = MockOperator(task_id="finish") mapped >> finish @@ -817,7 +817,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec with dag_maker(session=session): task1 = BaseOperator(task_id="op1") xcomarg = XComArg(task1, "test_key") - mapped = MockOperator.partial(task_id='task_2').map(arg2=xcomarg) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=xcomarg) dr = dag_maker.create_dagrun() @@ -861,7 +861,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): with dag_maker(session=session): task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id='task_2').map(arg2=XComArg(task1, XCOM_RETURN_KEY)) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=XComArg(task1, XCOM_RETURN_KEY)) dr = dag_maker.create_dagrun() diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 4340595d59f4b..9e2349ca713dd 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -910,7 +910,7 @@ def test_verify_integrity_task_start_date(Stats_incr, session, run_type, expecte def test_expand_mapped_task_instance(dag_maker, session): literal = [1, 2, {'a': 'b'}] with dag_maker(session=session): - mapped = MockOperator(task_id='task_2').map(arg2=literal) + mapped = MockOperator(task_id='task_2').apply(arg2=literal) dr = dag_maker.create_dagrun() indices = ( @@ -926,7 +926,7 @@ def test_expand_mapped_task_instance(dag_maker, session): def test_ti_scheduling_mapped_zero_length(dag_maker, session): with dag_maker(session=session): task = BaseOperator(task_id='task_1') - mapped = MockOperator.partial(task_id='task_2').map(arg2=XComArg(task)) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=XComArg(task)) dr: DagRun = dag_maker.create_dagrun() ti1, _ = 
sorted(dr.task_instances, key=lambda ti: ti.task_id) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 64182dfb06682..f35e7feba8e8c 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2289,7 +2289,7 @@ def push_something(): def pull_something(value): print(value) - pull_something.map(value=push_something()) + pull_something.apply(value=push_something()) ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") with pytest.raises(UnmappableXComTypePushed) as ctx: @@ -2312,7 +2312,7 @@ def push_something(): def pull_something(value): print(value) - pull_something.map(value=push_something()) + pull_something.apply(value=push_something()) ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") with pytest.raises(UnmappableXComLengthPushed) as ctx: @@ -2341,7 +2341,7 @@ def push_something(): def pull_something(value): print(value) - pull_something.map(value=push_something()) + pull_something.apply(value=push_something()) dag_run = dag_maker.create_dagrun() ti = next(ti for ti in dag_run.task_instances if ti.task_id == "push_something") @@ -2373,7 +2373,7 @@ def test_map_literal(self, literal, expected_outputs, dag_maker, session): def show(value): outputs.append(value) - show.map(value=literal) + show.apply(value=literal) dag_run = dag_maker.create_dagrun() show_task = dag.get_task("show") @@ -2405,7 +2405,7 @@ def emit(): def show(value): outputs.append(value) - show.map(value=emit()) + show.apply(value=emit()) dag_run = dag_maker.create_dagrun() emit_ti = dag_run.get_task_instance("emit", session=session) @@ -2438,7 +2438,7 @@ def emit_letters(): def show(number, letter): outputs.append((number, letter)) - show.map(number=emit_numbers(), letter=emit_letters()) + show.apply(number=emit_numbers(), letter=emit_letters()) dag_run = dag_maker.create_dagrun() for task_id in ["emit_numbers", "emit_letters"]: @@ 
-2477,7 +2477,7 @@ def show(a, b): outputs.append((a, b)) emit_task = emit_numbers() - show.map(a=emit_task, b=emit_task) + show.apply(a=emit_task, b=emit_task) dag_run = dag_maker.create_dagrun() ti = dag_run.get_task_instance("emit_numbers", session=session) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index fd818f0985317..7bf27f19a90fc 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1600,7 +1600,7 @@ def mock__import__(name, globals_=None, locals_=None, fromlist=(), level=0): def test_mapped_operator_serde(): literal = [1, 2, {'a': 'b'}] - real_op = BashOperator.partial(task_id='a', executor_config={'dict': {'sub': 'value'}}).map( + real_op = BashOperator.partial(task_id='a', executor_config={'dict': {'sub': 'value'}}).apply( bash_command=literal ) @@ -1650,7 +1650,7 @@ def test_mapped_operator_xcomarg_serde(): with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: task1 = BaseOperator(task_id="op1") xcomarg = XComArg(task1, "test_key") - mapped = MockOperator.partial(task_id='task_2').map(arg2=xcomarg) + mapped = MockOperator.partial(task_id='task_2').apply(arg2=xcomarg) serialized = SerializedBaseOperator._serialize(mapped) assert serialized == { @@ -1716,7 +1716,7 @@ def test_mapped_decorator_serde(): def x(arg1, arg2, arg3): print(arg1, arg2, arg3) - x.partial(arg1=[1, 2, {"a": "b"}]).map(arg2={"a": 1, "b": 2}, arg3=xcomarg) + x.partial(arg1=[1, 2, {"a": "b"}]).apply(arg2={"a": 1, "b": 2}, arg3=xcomarg) original = dag.get_task("x") @@ -1767,7 +1767,7 @@ def test_mapped_task_group_serde(): literal = [1, 2, {'a': 'b'}] with DAG("test", start_date=execution_date) as dag: - with TaskGroup("process_one", dag=dag).map(literal) as process_one: + with TaskGroup("process_one", dag=dag).apply(literal) as process_one: BaseOperator(task_id='one') serialized = SerializedTaskGroup.serialize_task_group(process_one) diff --git 
a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 3238aec0763a7..39cf59934caaf 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1044,7 +1044,7 @@ def test_map() -> None: start = MockOperator(task_id="start") end = MockOperator(task_id="end") literal = ['a', 'b', 'c'] - with TaskGroup("process_one").map(literal) as process_one: + with TaskGroup("process_one").apply(literal) as process_one: one = MockOperator(task_id='one') two = MockOperator(task_id='two') three = MockOperator(task_id='three') @@ -1073,10 +1073,10 @@ def test_nested_map() -> None: start = MockOperator(task_id="start") end = MockOperator(task_id="end") literal = ['a', 'b', 'c'] - with TaskGroup("process_one").map(literal) as process_one: + with TaskGroup("process_one").apply(literal) as process_one: one = MockOperator(task_id='one') - with TaskGroup("process_two").map(literal) as process_one_two: + with TaskGroup("process_two").apply(literal) as process_one_two: two = MockOperator(task_id='two') three = MockOperator(task_id='three') two >> three @@ -1147,7 +1147,7 @@ def my_task_group(my_arg_1: str, unmapped: bool): with DAG("test-dag", start_date=DEFAULT_DATE) as dag: lines = ["foo", "bar", "baz"] - (task_1, task_2, task_3) = my_task_group.partial(unmapped=True).map(my_arg_1=lines) + (task_1, task_2, task_3) = my_task_group.partial(unmapped=True).apply(my_arg_1=lines) assert task_1 in dag.tasks From 73ca73369aafa086cc62008223cdafba63bbcc41 Mon Sep 17 00:00:00 2001 From: Josh Fell <48934154+josh-fell@users.noreply.github.com> Date: Thu, 24 Feb 2022 09:38:05 -0500 Subject: [PATCH 13/18] Restore image rendering in AWS Secrets Manager Backend doc (#21772) --- .../secrets-backends/aws-secrets-manager.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/apache-airflow-providers-amazon/secrets-backends/aws-secrets-manager.rst b/docs/apache-airflow-providers-amazon/secrets-backends/aws-secrets-manager.rst index 
5ceaa391669ca..61af5c9f726a6 100644 --- a/docs/apache-airflow-providers-amazon/secrets-backends/aws-secrets-manager.rst +++ b/docs/apache-airflow-providers-amazon/secrets-backends/aws-secrets-manager.rst @@ -37,8 +37,9 @@ environment variables like ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``. Storing and Retrieving Connections """""""""""""""""""""""""""""""""" You can store the different values for a secret in two forms: storing the conn URI in one field (default mode) or using different -fields in Amazon Secrets Manager (setting ``full_url_mode`` as ``false`` in the backend config), as follow: -.. image:: img/aws-secrets-manager.png +fields in Amazon Secrets Manager (setting ``full_url_mode`` as ``false`` in the backend config), as follows: + +.. image:: /img/aws-secrets-manager.png By default you must use some of the following words for each kind of field: From a6f2d7d4449f326522b1ee76f56611d793c5a107 Mon Sep 17 00:00:00 2001 From: Malthe Borch Date: Thu, 24 Feb 2022 15:49:18 +0000 Subject: [PATCH 14/18] Use Pendulum's built-in UTC object (#21732) Co-authored-by: Tzu-ping Chung --- airflow/example_dags/plugins/workday.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/airflow/example_dags/plugins/workday.py b/airflow/example_dags/plugins/workday.py index cea12c0c0e3f1..77111a79396de 100644 --- a/airflow/example_dags/plugins/workday.py +++ b/airflow/example_dags/plugins/workday.py @@ -22,14 +22,11 @@ from datetime import timedelta from typing import Optional -from pendulum import Date, DateTime, Time, timezone +from pendulum import UTC, Date, DateTime, Time from airflow.plugins_manager import AirflowPlugin from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable -# MyPy Does not recognize callable modules https://github.com/python/mypy/issues/9240 -UTC = timezone("UTC") # type: ignore - class AfterWorkdayTimetable(Timetable): From 254a56e3b61cce947c85d1595326746d096a256a Mon Sep 17 00:00:00 2001 From: 
Jarek Potiuk Date: Thu, 24 Feb 2022 20:36:28 +0100 Subject: [PATCH 15/18] Make sure emphasis in UPDATING in .md is consistent (#21804) There is a new rule in markdownlint which has been violated in main when new version of pre-commits is installed introduced in the #21734 --- UPDATING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/UPDATING.md b/UPDATING.md index bd9e5e32a98b3..5ff419edf0f29 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -133,7 +133,7 @@ A new `log_template` table is introduced to solve this problem. This table is sy In order to support Dynamic Task Mapping the default templates for per-task instance logging has changed. If your config contains the old default values they will be upgraded-in-place. -If you are happy with the new config values you should _remove_ the setting in `airflow.cfg` and let the default value be used. Old default values were: +If you are happy with the new config values you should *remove* the setting in `airflow.cfg` and let the default value be used. Old default values were: - `[core] log_filename_template`: `{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log` From 328aaf57de433be2d2911fa7050916cef5c31d9d Mon Sep 17 00:00:00 2001 From: Mocheng Guo Date: Thu, 24 Feb 2022 11:42:28 -0800 Subject: [PATCH 16/18] REST API: add rendered fields in task instance. (#21741) Make task instance rendered template fields available in the REST API. 
Co-authored-by: Mocheng Guo --- .../endpoints/task_instance_endpoint.py | 31 ++++++++++++++++--- airflow/api_connexion/openapi/v1.yaml | 6 ++++ .../api_connexion/schemas/common_schema.py | 15 +++++++++ .../schemas/task_instance_schema.py | 4 +++ .../endpoints/test_task_instance_endpoint.py | 16 ++++++++-- .../schemas/test_task_instance_schema.py | 15 ++++++--- 6 files changed, 75 insertions(+), 12 deletions(-) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index cbcc189e9c22c..af1b4e348fa28 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -39,6 +39,7 @@ from airflow.api_connexion.types import APIResponse from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR +from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session @@ -77,6 +78,14 @@ def get_task_instance( ) .add_entity(SlaMiss) ) + query = query.outerjoin( + RTIF, + and_( + RTIF.dag_id == TI.dag_id, + RTIF.execution_date == DR.execution_date, + RTIF.task_id == TI.task_id, + ), + ).add_entity(RTIF) task_instance = query.one_or_none() if task_instance is None: raise NotFound("Task instance not found") @@ -178,8 +187,15 @@ def get_task_instances( SlaMiss.execution_date == DR.execution_date, ), isouter=True, - ) - ti_query = base_query.add_entity(SlaMiss) + ).add_entity(SlaMiss) + ti_query = base_query.outerjoin( + RTIF, + and_( + RTIF.dag_id == TI.dag_id, + RTIF.task_id == TI.task_id, + RTIF.execution_date == DR.execution_date, + ), + ).add_entity(RTIF) task_instances = ti_query.offset(offset).limit(limit).all() return task_instance_collection_schema.dump( @@ -237,8 +253,15 @@ def get_task_instances_batch(session: 
Session = NEW_SESSION) -> APIResponse: SlaMiss.execution_date == DR.execution_date, ), isouter=True, - ) - ti_query = base_query.add_entity(SlaMiss) + ).add_entity(SlaMiss) + ti_query = base_query.outerjoin( + RTIF, + and_( + RTIF.dag_id == TI.dag_id, + RTIF.task_id == TI.task_id, + RTIF.execution_date == DR.execution_date, + ), + ).add_entity(RTIF) task_instances = ti_query.all() return task_instance_collection_schema.dump( diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 8b2cabd4b1eb1..4472b86dc1fe2 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -2496,6 +2496,12 @@ components: sla_miss: $ref: '#/components/schemas/SLAMiss' nullable: true + rendered_fields: + description: | + JSON object describing rendered fields. + + *New in version 2.3.0* + type: object TaskInstanceCollection: type: object diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py index 1c10421ea201c..502d5b60bdddd 100644 --- a/airflow/api_connexion/schemas/common_schema.py +++ b/airflow/api_connexion/schemas/common_schema.py @@ -17,6 +17,7 @@ import datetime import inspect +import json import typing import marshmallow @@ -165,3 +166,17 @@ def _get_class_name(self, obj): if isinstance(obj, type): return obj.__name__ return type(obj).__name__ + + +class JsonObjectField(fields.Field): + """JSON object field.""" + + def _serialize(self, value, attr, obj, **kwargs): + if not value: + return {} + return json.loads(value) if isinstance(value, str) else value + + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str): + return json.loads(value) + return value diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 2d1b950489417..5e31b47d49496 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ 
b/airflow/api_connexion/schemas/task_instance_schema.py @@ -22,6 +22,7 @@ from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field from airflow.api_connexion.parameters import validate_istimezone +from airflow.api_connexion.schemas.common_schema import JsonObjectField from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema from airflow.models import SlaMiss, TaskInstance @@ -58,6 +59,7 @@ class Meta: pid = auto_field() executor_config = auto_field() sla_miss = fields.Nested(SlaMissSchema, dump_default=None) + rendered_fields = JsonObjectField() def get_attribute(self, obj, attr, default): if attr == "sla_miss": @@ -66,6 +68,8 @@ def get_attribute(self, obj, attr, default): # corresponding to the attr. slamiss_instance = {"sla_miss": obj[1]} return get_value(slamiss_instance, attr, default) + elif attr == "rendered_fields": + return get_value(obj[2], attr, None) return get_value(obj[0], attr, default) diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index ad2e2d5253598..8f5676ed3e2e6 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -22,6 +22,7 @@ from sqlalchemy.orm import contains_eager from airflow.models import DagRun, SlaMiss, TaskInstance +from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session @@ -29,7 +30,7 @@ from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user -from tests.test_utils.db import clear_db_runs, clear_db_sla_miss +from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, 
clear_rendered_ti_fields DEFAULT_DATETIME_1 = datetime(2020, 1, 1) DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" @@ -105,6 +106,7 @@ def setup_attrs(self, configured_app, dagbag) -> None: self.client = self.app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() + clear_rendered_ti_fields() self.dagbag = dagbag def create_task_instances( @@ -127,6 +129,7 @@ def create_task_instances( execution_date = self.ti_init.pop("execution_date", self.default_time) dr = None + tis = [] for i in range(counter): if task_instances is None: pass @@ -155,8 +158,10 @@ def create_task_instances( for key, value in self.ti_extras.items(): setattr(ti, key, value) session.add(ti) + tis.append(ti) session.commit() + return tis class TestGetTaskInstance(TestTaskInstanceEndpoint): @@ -198,6 +203,7 @@ def test_should_respond_200(self, username, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {}, } def test_should_respond_200_with_task_state_in_removed(self, session): @@ -229,10 +235,11 @@ def test_should_respond_200_with_task_state_in_removed(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {}, } - def test_should_respond_200_task_instance_with_sla(self, session): - self.create_task_instances(session) + def test_should_respond_200_task_instance_with_sla_and_rendered(self, session): + tis = self.create_task_instances(session) session.query() sla_miss = SlaMiss( task_id="print_the_context", @@ -241,6 +248,8 @@ def test_should_respond_200_task_instance_with_sla(self, session): timestamp=self.default_time, ) session.add(sla_miss) + rendered_fields = RTIF(tis[0], render_templates=False) + session.add(rendered_fields) session.commit() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", @@ -278,6 +287,7 @@ def test_should_respond_200_task_instance_with_sla(self, session): "try_number": 0, 
"unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {'op_args': [], 'op_kwargs': {}}, } def test_should_raises_401_unauthenticated(self): diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index 24da6e92b82ff..69111ffc3722c 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -27,7 +27,7 @@ set_task_instance_state_form, task_instance_schema, ) -from airflow.models import SlaMiss, TaskInstance as TI +from airflow.models import RenderedTaskInstanceFields as RTIF, SlaMiss, TaskInstance as TI from airflow.operators.dummy import DummyOperator from airflow.utils.platform import getuser from airflow.utils.state import State @@ -62,11 +62,11 @@ def set_attrs(self, session, dag_maker): session.rollback() - def test_task_instance_schema_without_sla(self, session): + def test_task_instance_schema_without_sla_and_rendered(self, session): ti = TI(task=self.task, **self.default_ti_init) for key, value in self.default_ti_extras.items(): setattr(ti, key, value) - serialized_ti = task_instance_schema.dump((ti, None)) + serialized_ti = task_instance_schema.dump((ti, None, None)) expected_json = { "dag_id": "TEST_DAG_ID", "duration": 10000.0, @@ -89,10 +89,11 @@ def test_task_instance_schema_without_sla(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": None, + "rendered_fields": {}, } assert serialized_ti == expected_json - def test_task_instance_schema_with_sla(self, session): + def test_task_instance_schema_with_sla_and_rendered(self, session): sla_miss = SlaMiss( task_id="TEST_TASK_ID", dag_id="TEST_DAG_ID", @@ -103,7 +104,10 @@ def test_task_instance_schema_with_sla(self, session): ti = TI(task=self.task, **self.default_ti_init) for key, value in self.default_ti_extras.items(): setattr(ti, key, value) - serialized_ti = task_instance_schema.dump((ti, sla_miss)) + 
self.task.template_fields = ["partitions"] + setattr(self.task, "partitions", "data/ds=2022-02-17") + rendered_fields = RTIF(ti, render_templates=False) + serialized_ti = task_instance_schema.dump((ti, sla_miss, rendered_fields)) expected_json = { "dag_id": "TEST_DAG_ID", "duration": 10000.0, @@ -134,6 +138,7 @@ def test_task_instance_schema_with_sla(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": None, + "rendered_fields": {"partitions": "data/ds=2022-02-17"}, } assert serialized_ti == expected_json From 5dc0cd57d51c5875146cee47c593508078503485 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 24 Feb 2022 21:45:01 +0000 Subject: [PATCH 17/18] Use DB where possible for quicker ``airflow dag`` subcommands (#21793) Some of the subcommands here don't actually need a full dag, so there is no point paying the (possibly long) time to parse a dagfile if we could go directly to the DB instead. Closes #21450 --- airflow/cli/commands/dag_command.py | 66 +++++++++++++---------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 7889652e791ce..f47b4b1b5d5c5 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -39,15 +39,9 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.utils import cli as cli_utils -from airflow.utils.cli import ( - get_dag, - get_dag_by_file_location, - process_subdir, - sigint_handler, - suppress_logs_and_warning, -) +from airflow.utils.cli import get_dag, process_subdir, sigint_handler, suppress_logs_and_warning from airflow.utils.dot_renderer import render_dag, render_dag_dependencies -from airflow.utils.session import create_session, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState @@ -249,7 +243,8 @@ def _save_dot_to_file(dot: Dot, 
filename: str): @cli_utils.action_cli -def dag_state(args): +@provide_session +def dag_state(args, session=NEW_SESSION): """ Returns the state (and conf if exists) of a DagRun at the command line. >>> airflow dags state tutorial 2015-01-01T00:00:00.000000 @@ -257,15 +252,16 @@ def dag_state(args): >>> airflow dags state a_dag_with_conf_passed 2015-01-01T00:00:00.000000 failed, {"name": "bob", "age": "42"} """ - if args.subdir: - dag = get_dag(args.subdir, args.dag_id) - else: - dag = get_dag_by_file_location(args.dag_id) - dr = DagRun.find(dag.dag_id, execution_date=args.execution_date) - out = dr[0].state if dr else None + + dag = DagModel.get_dagmodel(args.dag_id, session=session) + + if not dag: + raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table") + dr = session.query(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date).one_or_none() + out = dr.state if dr else None conf_out = '' - if out and dr[0].conf: - conf_out = ', ' + json.dumps(dr[0].conf) + if out and dr.conf: + conf_out = ', ' + json.dumps(dr.conf) print(str(out) + conf_out) @@ -351,32 +347,27 @@ def dag_report(args): @cli_utils.action_cli @suppress_logs_and_warning -def dag_list_jobs(args, dag=None): +@provide_session +def dag_list_jobs(args, dag=None, session=NEW_SESSION): """Lists latest n jobs""" queries = [] if dag: args.dag_id = dag.dag_id if args.dag_id: - dagbag = DagBag() + dag = DagModel.get_dagmodel(args.dag_id, session=session) - if args.dag_id not in dagbag.dags: - error_message = f"Dag id {args.dag_id} not found" - raise AirflowException(error_message) + if not dag: + raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table") queries.append(BaseJob.dag_id == args.dag_id) if args.state: queries.append(BaseJob.state == args.state) fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date'] - with create_session() as session: - all_jobs = ( - session.query(BaseJob) - .filter(*queries) - .order_by(BaseJob.start_date.desc()) - 
.limit(args.limit) - .all() - ) - all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs] + all_jobs = ( + session.query(BaseJob).filter(*queries).order_by(BaseJob.start_date.desc()).limit(args.limit).all() + ) + all_jobs = [{f: str(job.__getattribute__(f)) for f in fields} for job in all_jobs] AirflowConsole().print_as( data=all_jobs, @@ -386,16 +377,16 @@ def dag_list_jobs(args, dag=None): @cli_utils.action_cli @suppress_logs_and_warning -def dag_list_dag_runs(args, dag=None): +@provide_session +def dag_list_dag_runs(args, dag=None, session=NEW_SESSION): """Lists dag runs for a given DAG""" if dag: args.dag_id = dag.dag_id + else: + dag = DagModel.get_dagmodel(args.dag_id, session=session) - dagbag = DagBag() - - if args.dag_id is not None and args.dag_id not in dagbag.dags: - error_message = f"Dag id {args.dag_id} not found" - raise AirflowException(error_message) + if not dag: + raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table") state = args.state.lower() if args.state else None dag_runs = DagRun.find( @@ -404,6 +395,7 @@ def dag_list_dag_runs(args, dag=None): no_backfills=args.no_backfill, execution_start_date=args.start_date, execution_end_date=args.end_date, + session=session, ) dag_runs.sort(key=lambda x: x.execution_date, reverse=True) From e08dd258edd5971d178991a706c880c6ba8bd4be Mon Sep 17 00:00:00 2001 From: rustikk Date: Thu, 24 Feb 2022 17:41:26 -0700 Subject: [PATCH 18/18] checking type for taskgroup default_args and task default_args --- airflow/models/baseoperator.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index b14e885f6f2a4..0b86bd61c1147 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -135,7 +135,10 @@ def _get_dag_defaults(dag: Optional["DAG"], task_group: Optional["TaskGroup"]) - dag_args = copy.copy(dag.default_args) dag_params = copy.deepcopy(dag.params) if 
task_group: - dag_args.update(task_group.default_args) + if not isinstance(task_group.default_args, Dict) and task_group.default_args is not None: + raise TypeError("default_args must be a dictionary") + else: + dag_args.update(task_group.default_args) return dag_args, dag_params @@ -147,9 +150,12 @@ def _merge_defaults( ) -> Tuple[dict, ParamsDict]: if task_params: dag_params.update(task_params) - with contextlib.suppress(KeyError): - dag_params.update(task_default_args.pop("params")) - dag_args.update(task_default_args) + if not isinstance(task_default_args, Dict) and task_default_args is not None: + raise TypeError("default_args must be a dictionary") + else: + dag_args.update(task_default_args) + with contextlib.suppress(KeyError): + dag_params.update(task_default_args.pop("params")) return dag_args, dag_params