Skip to content

Commit

Permalink
respect soft_fail argument when exception is raised for celery sensors (
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Sep 25, 2023
1 parent a1ef232 commit f19e055
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
14 changes: 11 additions & 3 deletions airflow/providers/celery/sensors/celery_queue.py
Expand Up @@ -21,6 +21,7 @@

from celery.app import control

from airflow.exceptions import AirflowSkipException
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand All @@ -39,7 +40,6 @@ class CeleryQueueSensor(BaseSensorOperator):
"""

def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None:

super().__init__(**kwargs)
self.celery_queue = celery_queue
self.target_task_id = target_task_id
Expand All @@ -56,7 +56,6 @@ def _check_task_id(self, context: Context) -> bool:
return celery_result.ready()

def poke(self, context: Context) -> bool:

if self.target_task_id:
return self._check_task_id(context)

Expand All @@ -74,4 +73,13 @@ def poke(self, context: Context) -> bool:

return reserved == 0 and scheduled == 0 and active == 0
except KeyError:
raise KeyError(f"Could not locate Celery queue {self.celery_queue}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Could not locate Celery queue {self.celery_queue}"
if self.soft_fail:
raise AirflowSkipException(message)
raise KeyError(message)
except Exception as err:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException from err
raise
17 changes: 17 additions & 0 deletions tests/providers/celery/sensors/test_celery_queue.py
Expand Up @@ -19,6 +19,9 @@

from unittest.mock import patch

import pytest

from airflow.exceptions import AirflowSkipException
from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor


Expand Down Expand Up @@ -54,6 +57,20 @@ def test_poke_fail(self, mock_inspect):
test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task")
assert not test_sensor.poke(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, KeyError), (True, AirflowSkipException))
)
@patch("celery.app.control.Inspect")
def test_poke_fail_with_exception(self, mock_inspect, soft_fail, expected_exception):
mock_inspect_result = mock_inspect.return_value
mock_inspect_result.reserved.return_value = {}
mock_inspect_result.scheduled.return_value = {}
mock_inspect_result.active.return_value = {}

with pytest.raises(expected_exception):
test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task", soft_fail=soft_fail)
test_sensor.poke(None)

@patch("celery.app.control.Inspect")
def test_poke_success_with_taskid(self, mock_inspect):
test_sensor = self.sensor(
Expand Down

0 comments on commit f19e055

Please sign in to comment.