Skip to content

Commit

Permalink
add on skipped callback
Browse files Browse the repository at this point in the history
  • Loading branch information
romsharon98 committed Dec 22, 2023
1 parent 33ee0b9 commit fc549c5
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 0 deletions.
1 change: 1 addition & 0 deletions airflow/example_dags/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
# 'on_success_callback': some_other_function, # or list of functions
# 'on_retry_callback': another_function, # or list of functions
# 'sla_miss_callback': yet_another_function, # or list of functions
# 'on_skipped_callback': another_function, #or list of functions
# 'trigger_rule': 'all_success'
},
# [END default_args]
Expand Down
7 changes: 7 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def partial(
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
Expand Down Expand Up @@ -310,6 +311,7 @@ def partial(
"on_failure_callback": on_failure_callback,
"on_retry_callback": on_retry_callback,
"on_success_callback": on_success_callback,
"on_skipped_callback": on_skipped_callback,
"run_as_user": run_as_user,
"executor_config": executor_config,
"inlets": inlets,
Expand Down Expand Up @@ -597,6 +599,8 @@ class derived from this one results in the creation of a task object,
that it is executed when retries occur.
:param on_success_callback: much like the ``on_failure_callback`` except
that it is executed when the task succeeds.
:param on_skipped_callback: much like the ``on_failure_callback`` except
that it is executed when skipped occur.
:param pre_execute: a function to be called immediately before task
execution, receiving a context dictionary; raising an exception will
prevent the task from being executed.
Expand Down Expand Up @@ -700,6 +704,7 @@ class derived from this one results in the creation of a task object,
"on_failure_callback",
"on_success_callback",
"on_retry_callback",
"on_skipped_callback",
"do_xcom_push",
}

Expand Down Expand Up @@ -759,6 +764,7 @@ def __init__(
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
pre_execute: TaskPreExecuteHook | None = None,
post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
Expand Down Expand Up @@ -825,6 +831,7 @@ def __init__(
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
self.on_skipped_callback = on_skipped_callback
self._pre_execute_hook = pre_execute
self._post_execute_hook = post_execute

Expand Down
8 changes: 8 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,14 @@ def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskState
def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_success_callback"] = value

@property
def on_skipped_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
return self.partial_kwargs.get("on_skipped_callback")

@on_skipped_callback.setter
def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_skipped_callback"] = value

@property
def run_as_user(self) -> str | None:
return self.partial_kwargs.get("run_as_user")
Expand Down
2 changes: 2 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,8 @@ def _run_raw_task(
self.log.info(e)
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
_run_finished_callback(callbacks=self.task.on_skipped_callback, context=context)
session.commit()
self.state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session)
Expand Down
25 changes: 25 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,31 @@ def test_clear_db_references(self, session, create_task_instance):

assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None

def test_skipped_task_call_on_skipped_callback(self, dag_maker):
"""
test that running task which returns AirflowSkipOperator will end
up in a SKIPPED state.
"""

def raise_skip_exception():
raise AirflowSkipException

callback_function = mock.MagicMock()

with dag_maker(dag_id="test_skipped_task"):
task = PythonOperator(
task_id="test_skipped_task",
python_callable=raise_skip_exception,
on_skipped_callback=callback_function,
)

dr = dag_maker.create_dagrun(execution_date=timezone.utcnow())
ti = dr.task_instances[0]
ti.task = task
ti.run()
assert State.SKIPPED == ti.state
assert callback_function.called


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
Expand Down

0 comments on commit fc549c5

Please sign in to comment.