Use async db calls in WorkflowTrigger #38689

Merged

Changes from 2 commits
42 changes: 21 additions & 21 deletions airflow/triggers/external_task.py
@@ -21,6 +21,7 @@
from typing import Any

from asgiref.sync import sync_to_async
from deprecated import deprecated
from sqlalchemy import func

from airflow.models import DagRun, TaskInstance
@@ -98,44 +99,43 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
"""Check periodically tasks, task group or dag status."""
while True:
if self.failed_states:
failed_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.failed_states,
)
failed_count = await self._get_count(self.failed_states)
if failed_count > 0:
yield TriggerEvent({"status": "failed"})
return
else:
yield TriggerEvent({"status": "success"})
return
if self.skipped_states:
skipped_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.skipped_states,
)
skipped_count = await self._get_count(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
allowed_count = _get_count(
self.execution_dates,
self.external_task_ids,
self.external_task_group_id,
self.external_dag_id,
self.allowed_states,
)
allowed_count = await self._get_count(self.allowed_states)
if allowed_count == len(self.execution_dates):
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)

@sync_to_async
def _get_count(self, states: typing.Iterable[str] | None) -> int:
"""
Get the count of records against dttm filter and states. Async wrapper for _get_count.

:param states: task or dag states
:return: The count of records.
"""
return _get_count(
dttm_filter=self.execution_dates,
external_task_ids=self.external_task_ids,
external_task_group_id=self.external_task_group_id,
external_dag_id=self.external_dag_id,
states=states,
)
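
The wrapper above is the heart of the change: the blocking SQLAlchemy query in _get_count is pushed onto a worker thread via asgiref's sync_to_async, so the triggerer's event loop is not blocked while the count query runs. A minimal, self-contained sketch of that pattern (slow_count, heartbeat, and the sleep durations are illustrative stand-ins, not Airflow code):

import asyncio
import time

from asgiref.sync import sync_to_async


@sync_to_async
def slow_count() -> int:
    # Stand-in for the blocking SQLAlchemy query inside _get_count.
    # sync_to_async runs it in a worker thread, so awaiting it does not
    # block the event loop.
    time.sleep(0.1)
    return 1


async def heartbeat() -> None:
    # Other coroutines keep running while slow_count() executes in its thread.
    for _ in range(3):
        print("event loop is still responsive")
        await asyncio.sleep(0.02)


async def main() -> None:
    count, _ = await asyncio.gather(slow_count(), heartbeat())
    print("count =", count)


asyncio.run(main())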


@deprecated(reason=("TaskStateTrigger has been deprecated and will be removed in future."))
class TaskStateTrigger(BaseTrigger):
"""
Waits asynchronously for a task in a different DAG to complete for a specific logical date.
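
TaskStateTrigger is marked with the deprecated package's decorator, which turns instantiation into a DeprecationWarning. A small sketch of how that decorator behaves, using an illustrative class name and message rather than the real trigger:

import warnings

from deprecated import deprecated


@deprecated(reason="OldTrigger is deprecated; use a newer trigger instead.")  # illustrative message
class OldTrigger:
    pass


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OldTrigger()  # instantiating the class emits a DeprecationWarning carrying the reason
    print(caught[0].category.__name__, "-", caught[0].message)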
67 changes: 51 additions & 16 deletions tests/triggers/test_external_task.py
@@ -18,6 +18,7 @@

import asyncio
import datetime
import time
from unittest import mock

import pytest
@@ -41,11 +42,10 @@ class TestWorkflowTrigger:
STATES = ["success", "fail"]

@mock.patch("airflow.triggers.external_task._get_count")
@mock.patch("asyncio.sleep")
@pytest.mark.asyncio
async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count):
async def test_task_workflow_trigger_success(self, mock_get_count):
"""check the db count get called correctly."""
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -54,19 +54,29 @@ async def test_task_workflow_trigger_success(self, mock_sleep, mock_get_count):
poke_interval=0.2,
)

generator = trigger.run()
await generator.asend(None)
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
await generator.__anext__()
await gen.__anext__()

@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_failed(self, mock_get_count):
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -77,13 +87,19 @@ async def test_task_workflow_trigger_failed(self, mock_get_count):

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert trigger_task.done() is True
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "failed"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
@@ -104,12 +120,16 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
assert trigger_task.done() is True
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)
# test that it returns after yielding
with pytest.raises(StopAsyncIteration):
@@ -118,7 +138,7 @@ async def test_task_workflow_trigger_fail_count_eq_0(self, mock_get_count):
@mock.patch("airflow.triggers.external_task._get_count")
@pytest.mark.asyncio
async def test_task_workflow_trigger_skipped(self, mock_get_count):
mock_get_count.return_value = 1
mock_get_count.side_effect = mocked_get_count
trigger = WorkflowTrigger(
external_dag_id=self.DAG_ID,
execution_dates=[timezone.datetime(2022, 1, 1)],
@@ -129,13 +149,19 @@ async def test_task_workflow_trigger_skipped(self, mock_get_count):

gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
fake_task = asyncio.create_task(fake_async_fun())
await trigger_task
assert trigger_task.done() is True
assert fake_task.done() # confirm that get_count is done in an async fashion
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "skipped"}
mock_get_count.assert_called_once_with(
[timezone.datetime(2022, 1, 1)], ["external_task_op"], None, "external_task", ["success", "fail"]
dttm_filter=[timezone.datetime(2022, 1, 1)],
external_task_ids=["external_task_op"],
external_task_group_id=None,
external_dag_id="external_task",
states=["success", "fail"],
)

@mock.patch("airflow.triggers.external_task._get_count")
@@ -153,7 +179,7 @@ async def test_task_workflow_trigger_sleep_success(self, mock_sleep, mock_get_count):
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
assert trigger_task.done() is True
assert trigger_task.done()
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == {"status": "success"}
@@ -438,3 +464,12 @@ def test_serialization(self):
"execution_dates": [timezone.datetime(2022, 1, 1)],
"poll_interval": 5,
}


def mocked_get_count(*args, **kwargs):
time.sleep(0.0001)
return 1


async def fake_async_fun():
await asyncio.sleep(0.00005)
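

These helpers are what the updated tests use to prove the trigger no longer blocks the event loop: mocked_get_count does a short blocking sleep standing in for the DB query, while fake_async_fun is an even shorter coroutine started alongside the trigger task; if the blocking call ran on the loop itself, the coroutine could not finish first. A self-contained sketch of the same reasoning, independent of the Airflow trigger (blocking_count, short_coroutine, and check_loop_not_blocked are illustrative names):

import asyncio
import time

from asgiref.sync import sync_to_async


@sync_to_async
def blocking_count():
    # Plays the role of the mocked _get_count: a short blocking sleep.
    time.sleep(0.0001)
    return 1


async def short_coroutine():
    # Plays the role of fake_async_fun: an even shorter awaitable.
    await asyncio.sleep(0.00005)


async def check_loop_not_blocked():
    blocking_task = asyncio.create_task(blocking_count())
    side_task = asyncio.create_task(short_coroutine())
    await blocking_task
    # Had blocking_count() run on the event loop itself, side_task could not
    # have completed yet; with sync_to_async it runs in a worker thread, so
    # the shorter coroutine finishes first.
    assert side_task.done()


asyncio.run(check_loop_not_blocked())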