From 218e571c04e258db6fdd75e309bd45b1120a9797 Mon Sep 17 00:00:00 2001 From: XD-DENG Date: Tue, 3 Dec 2019 23:48:46 +0800 Subject: [PATCH] [AIRFLOW-6165] Housekeep utils.dates.date_range & add tests --- airflow/utils/dates.py | 9 +++++++- tests/utils/test_dates.py | 43 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py index 6ebbeb6a1ece5..9c5c5565550f9 100644 --- a/airflow/utils/dates.py +++ b/airflow/utils/dates.py @@ -59,6 +59,8 @@ def date_range(start_date, end_date=None, num=None, delta=None): number of entries you want in the range. This number can be negative, output will always be sorted regardless :type num: int + :param delta: step length. It can be datetime.timedelta or cron expression as string + :type delta: datetime.timedelta or str """ if not delta: return [] @@ -71,12 +73,17 @@ def date_range(start_date, end_date=None, num=None, delta=None): delta_iscron = False tz = start_date.tzinfo + if isinstance(delta, str): delta_iscron = True - start_date = timezone.make_naive(start_date, tz) + if timezone.is_localized(start_date): + start_date = timezone.make_naive(start_date, tz) cron = croniter(delta, start_date) elif isinstance(delta, timedelta): delta = abs(delta) + else: + raise Exception("Wait. delta must be either datetime.timedelta or cron expression as str") + dates = [] if end_date: if timezone.is_naive(start_date) and not timezone.is_naive(end_date): diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py index a4e0af917501d..910a458f6a9a3 100644 --- a/tests/utils/test_dates.py +++ b/tests/utils/test_dates.py @@ -52,3 +52,46 @@ def test_parse_execution_date(self): timezone.datetime(2017, 11, 5, 16, 18, 30, 989729), dates.parse_execution_date(execution_date_str_w_ms)) self.assertRaises(ValueError, dates.parse_execution_date, bad_execution_date_str) + + +class TestUtilsDatesDateRange(unittest.TestCase): + + def test_no_delta(self): + self.assertEqual(dates.date_range(datetime(2016, 1, 1), datetime(2016, 1, 3)), + []) + + def test_end_date_before_start_date(self): + with self.assertRaisesRegex(Exception, "Wait. start_date needs to be before end_date"): + dates.date_range(datetime(2016, 2, 1), + datetime(2016, 1, 1), + delta=timedelta(seconds=1)) + + def test_both_end_date_and_num_given(self): + with self.assertRaisesRegex(Exception, "Wait. Either specify end_date OR num"): + dates.date_range(datetime(2016, 1, 1), + datetime(2016, 1, 3), + num=2, + delta=timedelta(seconds=1)) + + def test_invalid_delta(self): + exception_msg = "Wait. delta must be either datetime.timedelta or cron expression as str" + with self.assertRaisesRegex(Exception, exception_msg): + dates.date_range(datetime(2016, 1, 1), + datetime(2016, 1, 3), + delta=1) + + def test_positive_num_given(self): + for num in range(1, 10): + result = dates.date_range(datetime(2016, 1, 1), num=num, delta=timedelta(1)) + self.assertEqual(len(result), num) + + for i in range(num): + self.assertTrue(timezone.is_localized(result[i])) + + def test_negative_num_given(self): + for num in range(-1, -5, -10): + result = dates.date_range(datetime(2016, 1, 1), num=num, delta=timedelta(1)) + self.assertEqual(len(result), -num) + + for i in range(num): + self.assertTrue(timezone.is_localized(result[i]))