Skip to content

Commit

Permalink
Add log for running callback (#38892)
Browse files Browse the repository at this point in the history
* add log for running callback

* get callback name before try statement

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>

* add tests

* fix test

* change logging

* fix tests

---------

Co-authored-by: Andrey Anshin <Andrey.Anshin@taragol.is>
  • Loading branch information
romsharon98 and Taragolis committed Apr 17, 2024
1 parent 0fbe711 commit ebd65ce
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
5 changes: 2 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import qualname
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_airflow_vars
from airflow.utils.platform import getuser
Expand Down Expand Up @@ -1230,11 +1229,11 @@ def _run_finished_callback(
if callbacks:
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
for callback in callbacks:
log.info("Executing %s callback", callback.__name__)
try:
callback(context)
except Exception:
callback_name = qualname(callback).split(".")[-1]
log.exception("Error when executing %s callback", callback_name) # type: ignore[attr-defined]
log.exception("Error when executing %s callback", callback.__name__) # type: ignore[attr-defined]


def _log_state(*, task_instance: TaskInstance | TaskInstancePydantic, lead_msg: str = "") -> None:
Expand Down
30 changes: 14 additions & 16 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2849,35 +2849,24 @@ def on_execute_callable(context):
ti.refresh_from_db()
assert ti.state == State.SUCCESS

@pytest.mark.parametrize(
"finished_state",
[
State.SUCCESS,
State.UP_FOR_RETRY,
State.FAILED,
],
)
@patch("logging.Logger.exception")
def test_finished_callbacks_handle_and_log_exception(
self, mock_log, finished_state, create_task_instance
):
called = completed = False

def test_finished_callbacks_handle_and_log_exception(self, caplog):
def on_finish_callable(context):
nonlocal called, completed
called = True
raise KeyError
completed = True

for callback_input in [[on_finish_callable], on_finish_callable]:
called = completed = False
caplog.clear()
_run_finished_callback(callbacks=callback_input, context={})

assert called
assert not completed
callback_name = callback_input[0] if isinstance(callback_input, list) else callback_input
callback_name = qualname(callback_name).split(".")[-1]
expected_message = "Error when executing %s callback"
mock_log.assert_called_with(expected_message, callback_name)
assert "Executing on_finish_callable callback" in caplog.text
assert "Error when executing on_finish_callable callback" in caplog.text

@provide_session
def test_handle_failure(self, create_dummy_dag, session=None):
Expand All @@ -2890,7 +2879,9 @@ def test_handle_failure(self, create_dummy_dag, session=None):
get_listener_manager().pm.hook.on_task_instance_failed = listener_callback_on_error

mock_on_failure_1 = mock.MagicMock()
mock_on_failure_1.__name__ = "mock_on_failure_1"
mock_on_retry_1 = mock.MagicMock()
mock_on_retry_1.__name__ = "mock_on_retry_1"
dag, task1 = create_dummy_dag(
dag_id="test_handle_failure",
schedule=None,
Expand Down Expand Up @@ -2927,7 +2918,9 @@ def test_handle_failure(self, create_dummy_dag, session=None):
mock_on_retry_1.assert_not_called()

mock_on_failure_2 = mock.MagicMock()
mock_on_failure_2.__name__ = "mock_on_failure_2"
mock_on_retry_2 = mock.MagicMock()
mock_on_retry_2.__name__ = "mock_on_retry_2"
task2 = EmptyOperator(
task_id="test_handle_failure_on_retry",
on_failure_callback=mock_on_failure_2,
Expand All @@ -2949,7 +2942,9 @@ def test_handle_failure(self, create_dummy_dag, session=None):

# test the scenario where normally we would retry but have been asked to fail
mock_on_failure_3 = mock.MagicMock()
mock_on_failure_3.__name__ = "mock_on_failure_3"
mock_on_retry_3 = mock.MagicMock()
mock_on_retry_3.__name__ = "mock_on_retry_3"
task3 = EmptyOperator(
task_id="test_handle_failure_on_force_fail",
on_failure_callback=mock_on_failure_3,
Expand Down Expand Up @@ -3465,6 +3460,7 @@ def raise_skip_exception():
raise AirflowSkipException

callback_function = mock.MagicMock()
callback_function.__name__ = "callback_function"

with dag_maker(dag_id="test_skipped_task"):
task = PythonOperator(
Expand Down Expand Up @@ -3560,6 +3556,7 @@ def timeout():
raise AirflowSensorTimeout

mock_on_failure = mock.MagicMock()
mock_on_failure.__name__ = "mock_on_failure"
with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
PythonSensor(
task_id="test_raise_sensor_timeout",
Expand Down Expand Up @@ -3588,6 +3585,7 @@ def timeout():
raise AirflowSensorTimeout

mock_on_failure = mock.MagicMock()
mock_on_failure.__name__ = "mock_on_failure"
with dag_maker(dag_id=f"test_sensor_timeout_{mode}_{retries}"):
PythonSensor.partial(
task_id="test_raise_sensor_timeout",
Expand Down

0 comments on commit ebd65ce

Please sign in to comment.