Skip to content

Commit

Permalink
Handle SystemExit raised in the task. (#36986)
Browse files Browse the repository at this point in the history
* Handle SystemExit raised in the task.

* Add handling of system exit in tasks:
* Exiting with a zero or None code signifies success, and the task does not return any value.
* Exiting with other codes signifies an error.

(cherry picked from commit 574d90f)
  • Loading branch information
avkirilishin authored and ephraimbuddy committed Feb 20, 2024
1 parent ff10401 commit 98cc571
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
22 changes: 20 additions & 2 deletions airflow/models/taskinstance.py
Expand Up @@ -408,6 +408,17 @@ def _execute_task(task_instance, context, task_orig):
execute_callable_kwargs["next_kwargs"] = task_instance.next_kwargs
else:
execute_callable = task_to_execute.execute

def _execute_callable(context, **execute_callable_kwargs):
    """Run ``execute_callable`` and treat a clean ``SystemExit`` as success.

    A ``SystemExit`` whose code is ``None`` or ``0`` is absorbed and ``None``
    is returned, so the task yields no value. A ``SystemExit`` carrying any
    other code is re-raised unchanged.

    :param context: passed through to ``execute_callable`` as ``context``.
    :param execute_callable_kwargs: extra keyword arguments forwarded as-is.
    :return: whatever ``execute_callable`` returns, or ``None`` on a clean exit.
    """
    try:
        outcome = execute_callable(context=context, **execute_callable_kwargs)
    except SystemExit as exit_exc:
        exit_code = exit_exc.code
        is_failure = exit_code is not None and exit_code != 0
        if not is_failure:
            # Clean exit: only the successful case is handled at this level.
            return None
        # Error exit codes are left for a handler further up the call chain.
        raise
    return outcome

# If a timeout is specified for the task, make it fail
# if it goes beyond
if task_to_execute.execution_timeout:
Expand All @@ -425,12 +436,12 @@ def _execute_task(task_instance, context, task_orig):
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = execute_callable(context=context, **execute_callable_kwargs)
result = _execute_callable(context=context, **execute_callable_kwargs)
except AirflowTaskTimeout:
task_to_execute.on_kill()
raise
else:
result = execute_callable(context=context, **execute_callable_kwargs)
result = _execute_callable(context=context, **execute_callable_kwargs)
with create_session() as session:
if task_to_execute.do_xcom_push:
xcom_value = result
Expand Down Expand Up @@ -2402,6 +2413,13 @@ def _run_raw_task(
self.handle_failure(e, test_mode, context, session=session)
session.commit()
raise
except SystemExit as e:
# We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
# Therefore, here we must handle only error codes.
msg = f"Task failed due to SystemExit({e.code})"
self.handle_failure(msg, test_mode, context, session=session)
session.commit()
raise Exception(msg)
finally:
Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
# Same metric with tagging
Expand Down
31 changes: 31 additions & 0 deletions tests/models/test_taskinstance.py
Expand Up @@ -2924,6 +2924,37 @@ def test_echo_env_variables(self, dag_maker):
ti.refresh_from_db()
assert ti.state == State.SUCCESS

@pytest.mark.parametrize(
    "code, expected_state",
    [
        (1, State.FAILED),
        (-1, State.FAILED),
        ("error", State.FAILED),
        (0, State.SUCCESS),
        (None, State.SUCCESS),
    ],
)
def test_handle_system_exit(self, dag_maker, code, expected_state):
    """Check how a ``SystemExit`` raised inside a task maps to the TI state.

    Exiting with ``0`` or ``None`` must leave the task instance in
    ``State.SUCCESS`` without raising; any other exit code must mark it
    ``State.FAILED`` and surface an exception whose message names the code.
    """
    # Local imports keep this test self-contained regardless of the module's
    # top-level import list.
    import re
    import sys

    with dag_maker():

        def f(*args, **kwargs):
            # ``sys.exit`` rather than the ``site``-provided ``exit`` helper,
            # which is meant for the interactive shell and is not guaranteed
            # to exist (e.g. under ``python -S``).
            sys.exit(code)

        task = PythonOperator(task_id="mytask", python_callable=f)

    dr = dag_maker.create_dagrun()
    ti = TI(task=task, run_id=dr.run_id)
    ti.state = State.RUNNING
    session = settings.Session()
    session.merge(ti)
    session.commit()

    if expected_state == State.FAILED:
        # Failure codes must propagate as an error; assert it explicitly
        # instead of silently swallowing whatever happens to be raised.
        expected_msg = re.escape(f"Task failed due to SystemExit({code})")
        with pytest.raises(Exception, match=expected_msg):
            ti._run_raw_task()
    else:
        # Success codes must not raise at all, so any exception fails the test.
        ti._run_raw_task()

    ti.refresh_from_db()
    assert ti.state == expected_state

def test_get_current_context_works_in_template(self, dag_maker):
def user_defined_macro():
from airflow.operators.python import get_current_context
Expand Down

0 comments on commit 98cc571

Please sign in to comment.