Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BaseOperatorLink,
BaseSensorOperator,
conf,
timezone,
)
from airflow.providers.standard.exceptions import (
DuplicateStateError,
Expand Down Expand Up @@ -156,6 +157,9 @@ class ExternalTaskSensor(BaseSensorOperator):
:param allowed_states: Iterable of allowed states, default is ``['success']``
:param skipped_states: Iterable of states to make this task mark as skipped, default is ``None``
:param failed_states: Iterable of failed or dis-allowed states, default is ``None``
:param execution_date supports templated values using either:
``{{ logical_date }}`` (preferred)
``{{ execution_date }}`` (legacy)
:param execution_delta: time difference with the previous execution to
look at, the default is the same logical date as the current task or DAG.
For yesterday, use [positive!] datetime.timedelta(days=1). Either
Expand All @@ -174,7 +178,13 @@ class ExternalTaskSensor(BaseSensorOperator):
:param deferrable: Run sensor in deferrable mode
"""

template_fields = ["external_dag_id", "external_task_id", "external_task_ids", "external_task_group_id"]
template_fields = [
"external_dag_id",
"external_task_id",
"external_task_ids",
"external_task_group_id",
"execution_date",
]
ui_color = "#4db7db"
operator_extra_links = [ExternalDagLink()]

Expand All @@ -188,6 +198,7 @@ def __init__(
allowed_states: Iterable[str] | None = None,
skipped_states: Iterable[str] | None = None,
failed_states: Iterable[str] | None = None,
execution_date: str | datetime.datetime | None = None,
execution_delta: datetime.timedelta | None = None,
execution_date_fn: Callable | None = None,
check_existence: bool = False,
Expand Down Expand Up @@ -248,12 +259,22 @@ def __init__(
f"when `external_task_id` and `external_task_group_id` is `None`: {State.dag_states}"
)

if execution_delta is not None and execution_date_fn is not None:
if sum(x is not None for x in (execution_delta, execution_date_fn, execution_date)) > 1:
raise ValueError(
"Only one of `execution_delta` or `execution_date_fn` may "
"be provided to ExternalTaskSensor; not both."
"Only one of `execution_delta`, `execution_date` or `execution_date_fn` may "
"be provided to ExternalTaskSensor."
)

if execution_delta is not None:
warnings.warn(
"`execution_delta` is deprecated. Use `execution_date`.", DeprecationWarning, stacklevel=1
)

if execution_date_fn is not None:
warnings.warn(
"`execution_date_fn` is deprecated. Use `execution_date`.", DeprecationWarning, stacklevel=1
)
self.execution_date = execution_date
self.execution_delta = execution_delta
self.execution_date_fn = execution_date_fn
self.external_dag_id = external_dag_id
Expand All @@ -267,13 +288,23 @@ def __init__(
self.external_dates_filter: str | None = None

def _get_dttm_filter(self, context: Context) -> Sequence[datetime.datetime]:
execution_date_value = self.execution_date

if execution_date_value is not None:
if isinstance(execution_date_value, datetime.datetime):
return [execution_date_value]

return [timezone.parse(execution_date_value)]

logical_date = self._get_logical_date(context)

if self.execution_delta:
if self.execution_delta is not None:
return [logical_date - self.execution_delta]

if self.execution_date_fn:
result = self._handle_execution_date_fn(context=context)
return result if isinstance(result, list) else [result]

return [logical_date]

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import re
from datetime import time, timedelta
from typing import Any
from unittest import mock

import pytest
Expand Down Expand Up @@ -671,7 +672,7 @@ def test_external_task_sensor_error_delta_and_fn(self):
# Test that providing execution_delta and a function raises an error
with pytest.raises(
ValueError,
match="Only one of `execution_delta` or `execution_date_fn` may be provided to ExternalTaskSensor; not both.",
match="Only one of `execution_delta`, `execution_date` or `execution_date_fn` may be provided to ExternalTaskSensor.",
):
ExternalTaskSensor(
task_id="test_external_task_sensor_check_delta",
Expand Down Expand Up @@ -1395,6 +1396,67 @@ def test_external_task_sensor_execution_delta(self, dag_maker):
)
assert op.external_dates_filter == expected_date.isoformat()

@pytest.mark.execution_timeout(10)
def test_handle_execution_date(self, dag_maker) -> None:
for param in ["logical_date", "execution_date"]:
with dag_maker("test_dag_child"):
op = ExternalTaskSensor(
task_id=f"test_external_task_sensor_check-{param}",
external_dag_id="test_dag_parent",
external_task_id="test_task",
execution_date=f"{{{{ {param} - macros.timedelta(hours=1) }}}}",
allowed_states=["success"],
)

import airflow.macros as macros

ctx: dict[str, Any] = self.context
ctx["macros"] = macros # ensure template rendering works

op.render_template_fields(ctx)
ti = self.context["ti"]
ti.get_ti_count.return_value = 1
op.execute(context=self.context)

expected_date = DEFAULT_DATE - timedelta(hours=1)
ti.get_ti_count.assert_has_calls(
[
mock.call(
dag_id="test_dag_parent",
logical_dates=[expected_date],
states=["success"],
task_ids=["test_task"],
)
]
)
assert op.external_dates_filter == expected_date.isoformat()

@pytest.mark.execution_timeout(3)
def test_external_task_sensor_error_delta_and_execution_date(self) -> None:
override_candidates = [
(
"execution_delta",
timedelta(seconds=123),
),
("execution_date", "{{ logical_date - macros.timedelta(hours=1) }}"),
("execution_date_fn", lambda dt: dt),
]

for r in range(2, 4):
for overrides in itertools.combinations(override_candidates, r):
with DAG("test_external_task_sensor_error_delta_and_execution_date"):
with pytest.raises(
ValueError,
match="Only one of `execution_delta`, `execution_date` or `execution_date_fn` may be provided to ExternalTaskSensor.",
):
ExternalTaskSensor(
task_id="test_external_task_sensor_error_delta_and_execution_date",
external_dag_id="test_dag_parent",
external_task_id="test_task",
allowed_states=["success"],
**dict(overrides),
)

@pytest.mark.execution_timeout(10)
def test_external_task_sensor_duplicate_task_ids(self, dag_maker):
with dag_maker("test_dag_child"):
Expand Down