diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 981f908619811..112b239653958 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -880,6 +880,7 @@ def _handle_failure( test_mode: bool | None = None, context: Context | None = None, force_fail: bool = False, + fail_stop: bool = False, ) -> None: """ Handle Failure for a task instance. @@ -903,6 +904,7 @@ def _handle_failure( context=context, force_fail=force_fail, session=session, + fail_stop=fail_stop, ) _log_state(task_instance=task_instance, lead_msg="Immediate failure requested. " if force_fail else "") @@ -2966,8 +2968,13 @@ def fetch_handle_failure_context( context: Context | None = None, force_fail: bool = False, session: Session = NEW_SESSION, + fail_stop: bool = False, ): - """Handle Failure for the TaskInstance.""" + """ + Handle Failure for the TaskInstance. + + :param fail_stop: if true, stop remaining tasks in dag + """ get_listener_manager().hook.on_task_instance_failed( previous_state=TaskInstanceState.RUNNING, task_instance=ti, error=error, session=session ) @@ -3030,7 +3037,7 @@ def fetch_handle_failure_context( email_for_state = operator.attrgetter("email_on_failure") callbacks = task.on_failure_callback if task else None - if task and task.dag and task.dag.fail_stop: + if task and fail_stop: _stop_remaining_tasks(task_instance=ti, session=session) else: if ti.state == TaskInstanceState.QUEUED: @@ -3079,6 +3086,13 @@ def handle_failure( :param context: Jinja2 context :param force_fail: if True, task does not retry """ + if TYPE_CHECKING: + assert self.task + assert self.task.dag + try: + fail_stop = self.task.dag.fail_stop + except Exception: + fail_stop = False _handle_failure( task_instance=self, error=error, @@ -3086,6 +3100,7 @@ def handle_failure( test_mode=test_mode, context=context, force_fail=force_fail, + fail_stop=fail_stop, ) def is_eligible_to_retry(self): diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index cc52aa9989f3d..2e01bf415a141 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -276,6 +276,13 @@ def handle_failure( """ from airflow.models.taskinstance import _handle_failure + if TYPE_CHECKING: + assert self.task + assert self.task.dag + try: + fail_stop = self.task.dag.fail_stop + except Exception: + fail_stop = False _handle_failure( task_instance=self, error=error, @@ -283,6 +290,7 @@ def handle_failure( test_mode=test_mode, context=context, force_fail=force_fail, + fail_stop=fail_stop, ) def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None: