Skip to content

Commit

Permalink
improve type hinting for celery provider (#9762)
Browse files Browse the repository at this point in the history
  • Loading branch information
morrme committed Jul 11, 2020
1 parent a6b04d7 commit 5bb228d
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions airflow/providers/celery/sensors/celery_queue.py
Expand Up @@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Any, Dict, Optional

from celery.app import control

from airflow.sensors.base_sensor_operator import BaseSensorOperator
Expand All @@ -36,16 +38,16 @@ class CeleryQueueSensor(BaseSensorOperator):
@apply_defaults
def __init__(
self,
celery_queue,
target_task_id=None,
celery_queue: str,
target_task_id: Optional[str] = None,
*args,
**kwargs):
**kwargs) -> None:

super().__init__(*args, **kwargs)
self.celery_queue = celery_queue
self.target_task_id = target_task_id

def _check_task_id(self, context):
def _check_task_id(self, context: Dict[str, Any]) -> bool:
"""
Gets the returned Celery result from the Airflow task
ID provided to the sensor, and returns True if the
Expand All @@ -60,7 +62,7 @@ def _check_task_id(self, context):
celery_result = ti.xcom_pull(task_ids=self.target_task_id)
return celery_result.ready()

def poke(self, context):
def poke(self, context: Dict[str, Any]) -> bool:

if self.target_task_id:
return self._check_task_id(context)
Expand Down

0 comments on commit 5bb228d

Please sign in to comment.