From 98343ee79a380b866c7c089e2a77557d3c386767 Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 25 Apr 2024 07:06:41 -0700 Subject: [PATCH] Determine fail_stop on client side when db isolated This is needed because we do not ser the dag on Operator objects. (cherry picked from commit 00ff95c27f68e1e1564b01dbd3fbc22207976ab7) --- airflow/models/taskinstance.py | 15 +++++++++++++-- airflow/serialization/pydantic/taskinstance.py | 4 ++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2c6bc2d100a5a0..b0868c69a7079a 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -878,6 +878,7 @@ def _handle_failure( test_mode: bool | None = None, context: Context | None = None, force_fail: bool = False, + fail_stop: bool = False, ) -> None: """ Handle Failure for a task instance. @@ -901,6 +902,7 @@ def _handle_failure( context=context, force_fail=force_fail, session=session, + fail_stop=fail_stop, ) _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "") @@ -2961,9 +2963,14 @@ def fetch_handle_failure_context( test_mode: bool | None = None, context: Context | None = None, force_fail: bool = False, + fail_stop: bool = False, session: Session = NEW_SESSION, ): - """Handle Failure for the TaskInstance.""" + """ + Handle Failure for the TaskInstance. + + :param fail_stop: if true, stop remaining tasks in dag + """ get_listener_manager().hook.on_task_instance_failed( previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session ) @@ -3026,7 +3033,7 @@ def fetch_handle_failure_context( email_for_state = operator.attrgetter("email_on_failure") callbacks = task.on_failure_callback if task else None - if task and task.dag and task.dag.fail_stop: + if task and fail_stop: _stop_remaining_tasks(task_instance=ti, session=session) else: if ti.state == TaskInstanceState.QUEUED: @@ -3075,6 +3082,9 @@ def handle_failure( :param context: Jinja2 context :param force_fail: if True, task does not retry """ + if TYPE_CHECKING: + assert self.task + assert self.task.dag _handle_failure( task_instance=self, error=error, @@ -3082,6 +3092,7 @@ def handle_failure( test_mode=test_mode, context=context, force_fail=force_fail, + fail_stop=self.task.dag.fail_stop, ) def is_eligible_to_retry(self): diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index cc52aa9989f3d6..1d09b37a8dadcd 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -276,6 +276,9 @@ def handle_failure( """ from airflow.models.taskinstance import _handle_failure + if TYPE_CHECKING: + assert self.task + assert self.task.dag _handle_failure( task_instance=self, error=error, @@ -283,6 +286,7 @@ def handle_failure( test_mode=test_mode, context=context, force_fail=force_fail, + fail_stop=self.task.dag.fail_stop, ) def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None: