Skip to content

Commit

Permalink
Make dataset_events access safer
Browse files Browse the repository at this point in the history
The key is not always available (mostly in tests). Create a placeholder
if it's not.
  • Loading branch information
uranusjr committed Apr 3, 2024
1 parent b311549 commit a076ae5
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 13 deletions.
16 changes: 11 additions & 5 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.context import Context, context_get_dataset_events
from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.helpers import validate_key
Expand Down Expand Up @@ -1272,8 +1272,11 @@ def pre_execute(self, context: Any):
"""Execute right before self.execute() is called."""
if self._pre_execute_hook is None:
return
runner = ExecutionCallableRunner(self._pre_execute_hook, context["dataset_events"], logger=self.log)
runner.run(context)
ExecutionCallableRunner(
self._pre_execute_hook,
context_get_dataset_events(context),
logger=self.log,
).run(context)

def execute(self, context: Context) -> Any:
"""
Expand All @@ -1294,8 +1297,11 @@ def post_execute(self, context: Any, result: Any = None):
"""
if self._post_execute_hook is None:
return
runner = ExecutionCallableRunner(self._post_execute_hook, context["dataset_events"], logger=self.log)
runner.run(context, result)
ExecutionCallableRunner(
self._post_execute_hook,
context_get_dataset_events(context),
logger=self.log,
).run(context, result)

def on_kill(self) -> None:
"""
Expand Down
8 changes: 6 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
Context,
DatasetEventAccessors,
VariableAccessor,
context_get_dataset_events,
context_merge,
)
from airflow.utils.email import send_email
Expand Down Expand Up @@ -437,8 +438,11 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
# Print a marker for log grouping of details before task execution
log.info("::endgroup::")

runner = ExecutionCallableRunner(execute_callable, context["dataset_events"], logger=log)
return runner.run(context=context, **execute_callable_kwargs)
return ExecutionCallableRunner(
execute_callable,
context_get_dataset_events(context),
logger=log,
).run(context=context, **execute_callable_kwargs)
except SystemExit as e:
# Handle only successful cases here. Failure cases will be handled upper
# in the exception chain.
Expand Down
7 changes: 2 additions & 5 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import DatasetEventAccessors, context_copy_partial, context_merge
from airflow.utils.context import context_copy_partial, context_get_dataset_events, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
Expand Down Expand Up @@ -231,11 +231,8 @@ def __init__(
def execute(self, context: Context) -> Any:
context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
self.op_kwargs = self.determine_kwargs(context)
self._dataset_events = context_get_dataset_events(context)

try:
self._dataset_events = context["dataset_events"]
except KeyError:
self._dataset_events = DatasetEventAccessors()
return_value = self.execute_callable()
if self.show_return_value_in_logs:
self.log.info("Done. Returned value was: %s", return_value)
Expand Down
7 changes: 7 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,10 @@ def _create_value(k: str, v: Any) -> Any:
return lazy_object_proxy.Proxy(factory)

return {k: _create_value(k, v) for k, v in source._context.items()}


def context_get_dataset_events(context: Context) -> DatasetEventAccessors:
try:
return context["dataset_events"]
except KeyError:
return DatasetEventAccessors()
1 change: 1 addition & 0 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,4 @@ def context_merge(context: Context, **kwargs: Any) -> None: ...
def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: ...
def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
def context_get_dataset_events(context: Context) -> DatasetEventAccessors: ...
1 change: 0 additions & 1 deletion tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,6 @@ def f(
conf,
dag,
dag_run,
dataset_events,
task,
# other
**context,
Expand Down

0 comments on commit a076ae5

Please sign in to comment.