Skip to content

Commit

Permalink
Change AirflowTaskTimeout to inherit BaseException (#35653)
Browse files Browse the repository at this point in the history
Code that normally catches Exception should not implicitly ignore
interrupts from AirflowTaskTimeout.

Fixes #35644 #35474
  • Loading branch information
hterik committed Feb 21, 2024
1 parent 69d48ed commit 581e2e4
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 15 deletions.
5 changes: 4 additions & 1 deletion airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ class InvalidStatsNameException(AirflowException):
"""Raise when name of the stats is invalid."""


class AirflowTaskTimeout(AirflowException):
# Important to inherit BaseException instead of AirflowException->Exception, since this Exception is used
# to explicitly interrupt an ongoing task. Code that does normal error-handling should not treat
# such an interrupt as an error that can be handled normally. (Compare with KeyboardInterrupt)
class AirflowTaskTimeout(BaseException):
"""Raise when the task execution times-out."""


Expand Down
16 changes: 8 additions & 8 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic)
def _handle_failure(
*,
task_instance: TaskInstance | TaskInstancePydantic,
error: None | str | Exception | KeyboardInterrupt,
error: None | str | BaseException,
session: Session,
test_mode: bool | None = None,
context: Context | None = None,
Expand Down Expand Up @@ -2411,7 +2411,7 @@ def _run_raw_task(
self.handle_failure(e, test_mode, context, force_fail=True, session=session)
session.commit()
raise
except AirflowException as e:
except (AirflowTaskTimeout, AirflowException) as e:
if not test_mode:
self.refresh_from_db(lock_for_update=True, session=session)
# for case when task is marked as success/failed externally
Expand All @@ -2426,17 +2426,17 @@ def _run_raw_task(
self.handle_failure(e, test_mode, context, session=session)
session.commit()
raise
except (Exception, KeyboardInterrupt) as e:
self.handle_failure(e, test_mode, context, session=session)
session.commit()
raise
except SystemExit as e:
# We have already handled SystemExit with success codes (0 and None) in the `_execute_task`.
# Therefore, here we must handle only error codes.
msg = f"Task failed due to SystemExit({e.code})"
self.handle_failure(msg, test_mode, context, session=session)
session.commit()
raise Exception(msg)
except BaseException as e:
self.handle_failure(e, test_mode, context, session=session)
session.commit()
raise
finally:
Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags)
# Same metric with tagging
Expand Down Expand Up @@ -2743,7 +2743,7 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -
def fetch_handle_failure_context(
cls,
ti: TaskInstance | TaskInstancePydantic,
error: None | str | Exception | KeyboardInterrupt,
error: None | str | BaseException,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
Expand Down Expand Up @@ -2838,7 +2838,7 @@ def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_S
@provide_session
def handle_failure(
self,
error: None | str | Exception | KeyboardInterrupt,
error: None | str | BaseException,
test_mode: bool | None = None,
context: Context | None = None,
force_fail: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/celery/executors/celery_executor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

import airflow.settings as settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
from airflow.executors.base_executor import BaseExecutor
from airflow.stats import Stats
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
Expand Down Expand Up @@ -198,7 +198,7 @@ class ExceptionWithTraceback:
:param exception_traceback: The stacktrace to wrap
"""

def __init__(self, exception: Exception, exception_traceback: str):
def __init__(self, exception: BaseException, exception_traceback: str):
self.exception = exception
self.traceback = exception_traceback

Expand All @@ -211,7 +211,7 @@ def send_task_to_executor(
try:
with timeout(seconds=OPERATION_TIMEOUT):
result = task_to_run.apply_async(args=[command], queue=queue)
except Exception as e:
except (Exception, AirflowTaskTimeout) as e:
exception_traceback = f"Celery Task ID: {key}\n{traceback.format_exc()}"
result = ExceptionWithTraceback(e, exception_traceback)

Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Context(TypedDict, total=False):
data_interval_start: DateTime
ds: str
ds_nodash: str
exception: KeyboardInterrupt | Exception | str | None
exception: BaseException | str | None
execution_date: DateTime
expanded_ti_count: int | None
inlets: list
Expand Down
21 changes: 21 additions & 0 deletions newsfragments/35653.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
``AirflowTaskTimeout`` is no longer caught by default through ``except Exception``

The ``AirflowTaskTimeout`` exception now inherits from ``BaseException`` instead of
``AirflowException``->``Exception``.
See https://docs.python.org/3/library/exceptions.html#exception-hierarchy

This prevents code catching ``Exception`` from accidentally
catching ``AirflowTaskTimeout`` and continuing to run.
``AirflowTaskTimeout`` is an explicit intent to cancel the task, and should not
be caught in attempts to handle the error and return some default value.

Catching ``AirflowTaskTimeout`` is still possible by explicitly catching
``AirflowTaskTimeout`` or ``BaseException``.
This is discouraged, as it may allow the code to continue running even after
such cancellation requests.
Code that previously depended on performing strict cleanup in every situation
after catching ``Exception`` is advised to use ``finally`` blocks or
context managers, which perform only the cleanup and then automatically
re-raise the exception.
See similar considerations about catching ``KeyboardInterrupt`` in
https://docs.python.org/3/library/exceptions.html#KeyboardInterrupt
9 changes: 8 additions & 1 deletion tests/core/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,18 @@ class InvalidTemplateFieldOperator(BaseOperator):
op.dry_run()

def test_timeout(self, dag_maker):
def sleep_and_catch_other_exceptions():
try:
sleep(5)
# Catching Exception should NOT catch AirflowTaskTimeout
except Exception:
pass

with dag_maker():
op = PythonOperator(
task_id="test_timeout",
execution_timeout=timedelta(seconds=1),
python_callable=lambda: sleep(5),
python_callable=sleep_and_catch_other_exceptions,
)
dag_maker.create_dagrun()
with pytest.raises(AirflowTaskTimeout):
Expand Down
3 changes: 2 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest
from azure.synapse.spark import SparkClient

from airflow.exceptions import AirflowTaskTimeout
from airflow.models.connection import Connection
from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus

Expand Down Expand Up @@ -172,7 +173,7 @@ def test_wait_for_job_run_status(hook, job_run_status, expected_status, expected
if expected_output != "timeout":
assert hook.wait_for_job_run_status(**config) == expected_output
else:
with pytest.raises(Exception):
with pytest.raises(AirflowTaskTimeout):
hook.wait_for_job_run_status(**config)


Expand Down

0 comments on commit 581e2e4

Please sign in to comment.