Skip to content

Commit

Permalink
DateTimeTrigger typing and tests (apache#37694)
Browse files Browse the repository at this point in the history
  • Loading branch information
drajguru authored and utkarsharma2 committed Apr 22, 2024
1 parent 779e767 commit 445d0c2
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 30 deletions.
14 changes: 8 additions & 6 deletions airflow/triggers/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import asyncio
import datetime
from typing import Any
from typing import Any, AsyncIterator

import pendulum

from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils import timezone
Expand All @@ -42,12 +44,12 @@ def __init__(self, moment: datetime.datetime):
elif moment.tzinfo is None:
raise ValueError("You cannot pass naive datetimes")
else:
self.moment = timezone.convert_to_utc(moment)
self.moment: pendulum.DateTime = timezone.convert_to_utc(moment)

def serialize(self) -> tuple[str, dict[str, Any]]:
return ("airflow.triggers.temporal.DateTimeTrigger", {"moment": self.moment})

async def run(self):
async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Loop until the relevant time is met.
Expand All @@ -59,13 +61,13 @@ async def run(self):
# Sleep in successively smaller increments starting from 1 hour down to 10 seconds at a time
self.log.info("trigger starting")
for step in 3600, 60, 10:
seconds_remaining = (self.moment - timezone.utcnow()).total_seconds()
seconds_remaining = (self.moment - pendulum.instance(timezone.utcnow())).total_seconds()
while seconds_remaining > 2 * step:
self.log.info(f"{int(seconds_remaining)} seconds remaining; sleeping {step} seconds")
await asyncio.sleep(step)
seconds_remaining = (self.moment - timezone.utcnow()).total_seconds()
seconds_remaining = (self.moment - pendulum.instance(timezone.utcnow())).total_seconds()
# Sleep a second at a time otherwise
while self.moment > timezone.utcnow():
while self.moment > pendulum.instance(timezone.utcnow()):
self.log.info("sleeping 1 second...")
await asyncio.sleep(1)
# Send our single event and then we're done
Expand Down
84 changes: 60 additions & 24 deletions scripts/cov/other_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,81 @@

source_files = [
"airflow/dag_processing",
"airflow/triggers",
]
"""
Other potential source file packages to scan for coverage.
You can also compare the stats against those on
https://app.codecov.io/github/apache/airflow
(as it combines the coverage from all tests and so may be a bit higher).
"airflow/auth",
"airflow/callbacks",
"airflow/config_templates",
"airflow/dag_processing",
"airflow/datasets",
"airflow/decorators",
"airflow/hooks",
"airflow/io",
"airflow/lineage",
"airflow/listeners",
"airflow/macros",
"airflow/notifications",
"airflow/secrets",
"airflow/security",
"airflow/sensors",
"airflow/task",
"airflow/template",
"airflow/timetables",
"airflow/triggers",
"""

files_not_fully_covered = [
"airflow/dag_processing/manager.py",
"airflow/dag_processing/processor.py",
"airflow/triggers/base.py",
"airflow/triggers/external_task.py",
"airflow/triggers/file.py",
"airflow/triggers/testing.py",
]

other_tests = [
"tests/dag_processing",
"tests/jobs",
"tests/triggers",
]

"""
These 'other' packages can be added to the above lists
as necessary:
Other tests to potentially run against the source_file packages:
"tests/auth",
"tests/callbacks",
"tests/charts",
"tests/cluster_policies",
"tests/config_templates",
"tests/datasets",
"tests/decorators",
"tests/hooks",
"tests/io",
"tests/lineage",
"tests/listeners",
"tests/macros",
"tests/notifications",
"tests/plugins",
"tests/secrets",
"tests/security",
"tests/sensors",
"tests/task",
"tests/template",
"tests/testconfig",
"tests/timetables",
"tests/triggers",
"tests/api_internal",
"tests/auth",
"tests/callbacks",
"tests/charts",
"tests/cluster_policies",
"tests/config_templates",
"tests/dag_processing",
"tests/datasets",
"tests/decorators",
"tests/hooks",
"tests/io",
"tests/jobs",
"tests/lineage",
"tests/listeners",
"tests/macros",
"tests/notifications",
"tests/plugins",
"tests/secrets",
"tests/security",
"tests/sensors",
"tests/task",
"tests/template",
"tests/testconfig",
"tests/timetables",
"tests/triggers",
"""


if __name__ == "__main__":
args = ["-qq"] + other_tests
run_tests(args, source_files, files_not_fully_covered)
44 changes: 44 additions & 0 deletions tests/triggers/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

import asyncio
import datetime
from unittest import mock

import pendulum
import pytest

from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.utils import timezone
from airflow.utils.timezone import utcnow


def test_input_validation():
Expand All @@ -35,6 +37,16 @@ def test_input_validation():
DateTimeTrigger("2012-01-01T03:03:03+00:00")


def test_input_validation_tz():
"""
Tests that the DateTimeTrigger validates input to moment arg, it shouldn't accept naive datetime.
"""

moment = datetime.datetime(2013, 3, 31, 0, 59, 59)
with pytest.raises(ValueError, match="You cannot pass naive datetimes"):
DateTimeTrigger(moment)


def test_datetime_trigger_serialization():
"""
Tests that the DateTimeTrigger correctly serializes its arguments
Expand Down Expand Up @@ -96,3 +108,35 @@ async def test_datetime_trigger_timing(tz):
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == past_moment


@mock.patch("airflow.triggers.temporal.timezone.utcnow")
@mock.patch("airflow.triggers.temporal.asyncio.sleep")
@pytest.mark.asyncio
async def test_datetime_trigger_mocked(mock_sleep, mock_utcnow):
"""
Tests DateTimeTrigger with time and asyncio mocks
"""
start_moment = utcnow()
trigger_moment = start_moment + datetime.timedelta(seconds=30)

# returns the mock 'current time'. The first 3 calls report the initial time
mock_utcnow.side_effect = [
start_moment,
start_moment,
start_moment,
start_moment + datetime.timedelta(seconds=20),
start_moment + datetime.timedelta(seconds=25),
start_moment + datetime.timedelta(seconds=30),
]

trigger = DateTimeTrigger(trigger_moment)
gen = trigger.run()
trigger_task = asyncio.create_task(gen.__anext__())
await trigger_task
mock_sleep.assert_awaited()
assert mock_sleep.await_count == 2
assert trigger_task.done() is True
result = trigger_task.result()
assert isinstance(result, TriggerEvent)
assert result.payload == trigger_moment

0 comments on commit 445d0c2

Please sign in to comment.