diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py index d3ca42794c300..f9b09e57d33b6 100644 --- a/airflow/providers/celery/sensors/celery_queue.py +++ b/airflow/providers/celery/sensors/celery_queue.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. +from typing import Any, Dict, Optional + from celery.app import control from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -36,16 +38,16 @@ class CeleryQueueSensor(BaseSensorOperator): @apply_defaults def __init__( self, - celery_queue, - target_task_id=None, + celery_queue: str, + target_task_id: Optional[str] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.celery_queue = celery_queue self.target_task_id = target_task_id - def _check_task_id(self, context): + def _check_task_id(self, context: Dict[str, Any]) -> bool: """ Gets the returned Celery result from the Airflow task ID provided to the sensor, and returns True if the @@ -60,7 +62,7 @@ def _check_task_id(self, context): celery_result = ti.xcom_pull(task_ids=self.target_task_id) return celery_result.ready() - def poke(self, context): + def poke(self, context: Dict[str, Any]) -> bool: if self.target_task_id: return self._check_task_id(context)