diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py index 4533217bff464..9800ccdb5b09b 100644 --- a/airflow/providers/celery/sensors/celery_queue.py +++ b/airflow/providers/celery/sensors/celery_queue.py @@ -21,6 +21,7 @@ from celery.app import control +from airflow.exceptions import AirflowSkipException from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -39,7 +40,6 @@ class CeleryQueueSensor(BaseSensorOperator): """ def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None: - super().__init__(**kwargs) self.celery_queue = celery_queue self.target_task_id = target_task_id @@ -56,7 +56,6 @@ def _check_task_id(self, context: Context) -> bool: return celery_result.ready() def poke(self, context: Context) -> bool: - if self.target_task_id: return self._check_task_id(context) @@ -74,4 +73,13 @@ def poke(self, context: Context) -> bool: return reserved == 0 and scheduled == 0 and active == 0 except KeyError: - raise KeyError(f"Could not locate Celery queue {self.celery_queue}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Could not locate Celery queue {self.celery_queue}" + if self.soft_fail: + raise AirflowSkipException(message) + raise KeyError(message) + except Exception as err: + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException from err + raise diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py index 8d09085352adf..f2d619e50ca9d 100644 --- a/tests/providers/celery/sensors/test_celery_queue.py +++ b/tests/providers/celery/sensors/test_celery_queue.py @@ -19,6 +19,9 @@ from unittest.mock import patch +import pytest + +from airflow.exceptions import AirflowSkipException from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor @@ -54,6 +57,20 @@ def test_poke_fail(self, mock_inspect): test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task") assert not test_sensor.poke(None) + @pytest.mark.parametrize( + "soft_fail, expected_exception", ((False, KeyError), (True, AirflowSkipException)) + ) + @patch("celery.app.control.Inspect") + def test_poke_fail_with_exception(self, mock_inspect, soft_fail, expected_exception): + mock_inspect_result = mock_inspect.return_value + mock_inspect_result.reserved.return_value = {} + mock_inspect_result.scheduled.return_value = {} + mock_inspect_result.active.return_value = {} + + with pytest.raises(expected_exception): + test_sensor = self.sensor(celery_queue="test_queue", task_id="test-task", soft_fail=soft_fail) + test_sensor.poke(None) + @patch("celery.app.control.Inspect") def test_poke_success_with_taskid(self, mock_inspect): test_sensor = self.sensor(