From a076ae57726434930cf1a206b7f9813c39d86728 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 3 Apr 2024 12:21:50 +0800 Subject: [PATCH] Make dataset_events access safer The key is not always available (mostly in tests). Create a placeholder if it's not. --- airflow/models/baseoperator.py | 16 +++++++++++----- airflow/models/taskinstance.py | 8 ++++++-- airflow/operators/python.py | 7 ++----- airflow/utils/context.py | 7 +++++++ airflow/utils/context.pyi | 1 + tests/operators/test_python.py | 1 - 6 files changed, 27 insertions(+), 13 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 60e61b4305093..5b6626af37bf3 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -91,7 +91,7 @@ from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone -from airflow.utils.context import Context +from airflow.utils.context import Context, context_get_dataset_events from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.helpers import validate_key @@ -1272,8 +1272,11 @@ def pre_execute(self, context: Any): """Execute right before self.execute() is called.""" if self._pre_execute_hook is None: return - runner = ExecutionCallableRunner(self._pre_execute_hook, context["dataset_events"], logger=self.log) - runner.run(context) + ExecutionCallableRunner( + self._pre_execute_hook, + context_get_dataset_events(context), + logger=self.log, + ).run(context) def execute(self, context: Context) -> Any: """ @@ -1294,8 +1297,11 @@ def post_execute(self, context: Any, result: Any = None): """ if self._post_execute_hook is None: return - runner = ExecutionCallableRunner(self._post_execute_hook, context["dataset_events"], logger=self.log) - runner.run(context, result) + ExecutionCallableRunner( + self._post_execute_hook, + context_get_dataset_events(context), + logger=self.log, + ).run(context, result) def on_kill(self) -> None: """ diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7b3ef05d2f598..d52a71c5b2e16 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -111,6 +111,7 @@ Context, DatasetEventAccessors, VariableAccessor, + context_get_dataset_events, context_merge, ) from airflow.utils.email import send_email @@ -437,8 +438,11 @@ def _execute_callable(context: Context, **execute_callable_kwargs): # Print a marker for log grouping of details before task execution log.info("::endgroup::") - runner = ExecutionCallableRunner(execute_callable, context["dataset_events"], logger=log) - return runner.run(context=context, **execute_callable_kwargs) + return ExecutionCallableRunner( + execute_callable, + context_get_dataset_events(context), + logger=log, + ).run(context=context, **execute_callable_kwargs) except SystemExit as e: # Handle only successful cases here. Failure cases will be handled upper # in the exception chain. diff --git a/airflow/operators/python.py b/airflow/operators/python.py index dfafbffc0aa75..998e47c91b53f 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -52,7 +52,7 @@ from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn from airflow.utils import hashlib_wrapper -from airflow.utils.context import DatasetEventAccessors, context_copy_partial, context_merge +from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge from airflow.utils.file import get_unique_dag_module_name from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters from airflow.utils.process_utils import execute_in_subprocess @@ -231,11 +231,8 @@ def __init__( def execute(self, context: Context) -> Any: context_merge(context, self.op_kwargs, templates_dict=self.templates_dict) self.op_kwargs = self.determine_kwargs(context) + self._dataset_events = context_get_dataset_events(context) - try: - self._dataset_events = context["dataset_events"] - except KeyError: - self._dataset_events = DatasetEventAccessors() return_value = self.execute_callable() if self.show_return_value_in_logs: self.log.info("Done. Returned value was: %s", return_value) diff --git a/airflow/utils/context.py b/airflow/utils/context.py index a1f0004ac38ff..78536bc97222a 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -355,3 +355,10 @@ def _create_value(k: str, v: Any) -> Any: return lazy_object_proxy.Proxy(factory) return {k: _create_value(k, v) for k, v in source._context.items()} + + +def context_get_dataset_events(context: Context) -> DatasetEventAccessors: + try: + return context["dataset_events"] + except KeyError: + return DatasetEventAccessors() diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi index 8b5deb4746918..eb2cf6dd3e46f 100644 --- a/airflow/utils/context.pyi +++ b/airflow/utils/context.pyi @@ -127,3 +127,4 @@ def context_merge(context: Context, **kwargs: Any) -> None: ... def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ... def context_copy_partial(source: Context, keys: Container[str]) -> Context: ... def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ... +def context_get_dataset_events(context: Context) -> DatasetEventAccessors: ... diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py index 0aadae6b7b9b1..5ad893b36d1b3 100644 --- a/tests/operators/test_python.py +++ b/tests/operators/test_python.py @@ -1121,7 +1121,6 @@ def f( conf, dag, dag_run, - dataset_events, task, # other **context,