diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 1b76145a469e2..e39f54a8bb0f3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1360,18 +1360,27 @@ def _run_finished_callback(self, error: Optional[Union[str, Exception]] = None) if task.on_failure_callback is not None: context = self.get_template_context() context["exception"] = error - task.on_failure_callback(context) + try: + task.on_failure_callback(context) + except Exception: + self.log.exception("Error when executing on_failure_callback") elif self.state == State.SUCCESS: task = self.task if task.on_success_callback is not None: context = self.get_template_context() - task.on_success_callback(context) + try: + task.on_success_callback(context) + except Exception: + self.log.exception("Error when executing on_success_callback") elif self.state == State.UP_FOR_RETRY: task = self.task if task.on_retry_callback is not None: context = self.get_template_context() context["exception"] = error - task.on_retry_callback(context) + try: + task.on_retry_callback(context) + except Exception: + self.log.exception("Error when executing on_retry_callback") @provide_session def run( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 7d64c76c5a92f..2e5eb8a88d43c 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1743,6 +1743,45 @@ def on_execute_callable(context): ti.refresh_from_db() assert ti.state == State.SUCCESS + @parameterized.expand( + [ + (State.SUCCESS, "Error when executing on_success_callback"), + (State.UP_FOR_RETRY, "Error when executing on_retry_callback"), + (State.FAILED, "Error when executing on_failure_callback"), + ] + ) + def test_finished_callbacks_handle_and_log_exception(self, finished_state, expected_message): + called = completed = False + + def on_finish_callable(context): + nonlocal called, completed + called = True + raise KeyError + completed = True + + dag = DAG( + 'test_success_callback_handles_exception', + start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10), + ) + task = DummyOperator( + task_id='op', + email='test@test.test', + on_success_callback=on_finish_callable, + on_retry_callback=on_finish_callable, + on_failure_callback=on_finish_callable, + dag=dag, + ) + + ti = TI(task=task, execution_date=datetime.datetime.now()) + ti._log = mock.Mock() + ti.state = finished_state + ti._run_finished_callback() + + assert called + assert not completed + ti.log.exception.assert_called_once_with(expected_message) + def test_handle_failure(self): start_date = timezone.datetime(2016, 6, 1) dag = models.DAG(dag_id="test_handle_failure", schedule_interval=None, start_date=start_date)