diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91169e5..5f25414 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: 21.9b0 + rev: 22.8.0 hooks: - id: black - repo: https://github.com/PyCQA/pydocstyle diff --git a/README.md b/README.md index fbde80f..52fcb53 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,7 @@ week_range = WeekRange( TimeRange(time(0), time(2)), TimeRange(time(4), time(8)), ] - ) + ), } ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0395994..d34cd21 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ -r docs/requirements.txt -black==21.9b0 +black==22.8.0 build==0.7.0 coverage==6.0.1 flake8==3.9.2 diff --git a/src/timeranges/__init__.py b/src/timeranges/__init__.py index 3ce723e..2b3a737 100644 --- a/src/timeranges/__init__.py +++ b/src/timeranges/__init__.py @@ -2,7 +2,8 @@ __version__ = "0.7.1" +from ._datetimeranges import DatetimeRange, DatetimeRanges from ._timeranges import TimeRange, TimeRanges, WeekRange # TODO Maybe generate it programmatically? -__all__ = ["TimeRange", "TimeRanges", "WeekRange"] +__all__ = ["TimeRange", "TimeRanges", "WeekRange", "DatetimeRange", "DatetimeRanges"] diff --git a/src/timeranges/_base.py b/src/timeranges/_base.py new file mode 100644 index 0000000..a06c8a4 --- /dev/null +++ b/src/timeranges/_base.py @@ -0,0 +1,5 @@ +from abc import ABC + + +class BaseRange(ABC): + pass diff --git a/src/timeranges/_datetimeranges.py b/src/timeranges/_datetimeranges.py new file mode 100644 index 0000000..fc97740 --- /dev/null +++ b/src/timeranges/_datetimeranges.py @@ -0,0 +1,186 @@ +from copy import deepcopy +from datetime import datetime, time, timedelta, timezone, tzinfo +from functools import reduce +from typing import List, Optional, TypeVar, Union + +import attr +from timematic.enums import Weekday + +from ._base import BaseRange +from ._timeranges import TimeRange, TimeRanges, WeekRange + +_T_DatetimeRange = TypeVar("_T_DatetimeRange", bound="DatetimeRange") + + +@attr.define(order=True, on_setattr=attr.setters.validate) +class DatetimeRange(BaseRange): + def _validate_start( + instance: _T_DatetimeRange, attribute: attr.Attribute, start: datetime + ) -> None: + instance._validate_datetime(start) + instance._validate_range(start, instance.end) + + def _validate_end( + instance: _T_DatetimeRange, attribute: attr.Attribute, end: datetime + ) -> None: + instance._validate_datetime(end) + instance._validate_range(instance.start, end) + + start: datetime = attr.ib( + default=datetime.min.replace(tzinfo=timezone.utc), validator=_validate_start + ) + end: datetime = attr.ib( + default=datetime.max.replace(tzinfo=timezone.utc), + order=False, + validator=_validate_end, + ) + + @staticmethod + def _validate_datetime(dt: datetime) -> None: + if dt.tzinfo is None: + raise ValueError(f"Datetime {dt} has no timezone information") + + @staticmethod + def _validate_range(start: datetime, end: datetime) -> None: + if start > end: # This automatically ensures they're both offset-naive or aware + raise ValueError(f"Start datetime {start} is after end datetime {end}") + + def validate(self) -> None: + for dt in (self.start, self.end): + self._validate_datetime(dt) + + self._validate_range(self.start, self.end) + + def __attrs_post_init__(self) -> None: + self.validate() + + def _contains_datetime(self, other: datetime, /) -> bool: + return self.start <= other <= self.end + + def _contains_datetime_range(self, other: "DatetimeRange", /) -> bool: + scdt = self._contains_datetime + return scdt(other.start) and scdt(other.end) + + _contains_types = Union[datetime, "DatetimeRange"] + + def contains(self, other: _contains_types, /) -> bool: + if isinstance(other, datetime): + return self._contains_datetime(other) + elif isinstance(other, DatetimeRange): + return self._contains_datetime_range(other) + else: + raise TypeError + + def __contains__(self, other: _contains_types) -> bool: + return self.contains(other) + + def to_week_range(self, replace_timezone: Optional[tzinfo] = None) -> WeekRange: + start = self.start + if replace_timezone is not None: + start = start.astimezone(replace_timezone) + + tz = start.tzinfo + end = self.end.astimezone(tz) + + date_start = start.date() + date_end = end.date() + + week_range = WeekRange(timezone=tz) + + d = date_start + while d <= date_end: + # TODO Skip unnecessary iterations if week is already full + tr_start: Optional[time] = None + tr_end: Optional[time] = None + if d == date_start: + tr_start = start.time() + if d == date_end: + tr_end = end.time() + + time_range = TimeRange() + if tr_start is not None: + time_range.start = tr_start + if tr_end is not None: + time_range.end = tr_end + + week_range.day_ranges[Weekday(d.weekday())] = TimeRanges([time_range]) + + d += timedelta(days=1) + + return week_range + + +@attr.define +class DatetimeRanges(BaseRange): + datetime_ranges: List[DatetimeRange] = attr.Factory(list) + + def validate(self) -> None: + for datetime_range in self.datetime_ranges: + datetime_range.validate() + + def sort(self) -> None: + self.validate() + self.datetime_ranges.sort() + + def merge(self, interpolate: timedelta = timedelta(0)) -> None: + assert interpolate >= timedelta(0), "Interpolation must be positive" + self.sort() + datetime_ranges = deepcopy(self.datetime_ranges) + aux: List[DatetimeRange] = [] + + # Merge overlapping time ranges + for datetime_range in datetime_ranges: + if not aux: + aux.append(datetime_range) + continue + aux_last = aux[-1] + if (datetime_range.start - aux_last.end) <= interpolate: + if datetime_range.end > aux_last.end: + aux_last.end = datetime_range.end + else: + aux.append(datetime_range) + + # TODO Interpolate to `time.max` + + self.datetime_ranges = aux + self.sort() + + def __attrs_post_init__(self) -> None: + self.validate() + + def __bool__(self) -> bool: + return bool(self.datetime_ranges) + + def _contains_datetime(self, other: datetime, /) -> bool: + return any(other in datetime_range for datetime_range in self.datetime_ranges) + + def _contains_datetime_range(self, other: DatetimeRange, /) -> bool: + return any(other in datetime_range for datetime_range in self.datetime_ranges) + + def _contains_datetime_ranges(self, other: "DatetimeRanges", /) -> bool: + return all( + self._contains_datetime_range(datetime_range) + for datetime_range in other.datetime_ranges + ) + + _contains_types = Union[time, DatetimeRange, "DatetimeRanges"] + + def contains(self, other: _contains_types, /) -> bool: + if isinstance(other, datetime): + return self._contains_datetime(other) + elif isinstance(other, DatetimeRange): + return self._contains_datetime_range(other) + elif isinstance(other, DatetimeRanges): + return self._contains_datetime_ranges(other) + else: + raise TypeError + + def __contains__(self, other: _contains_types) -> bool: + return self.contains(other) + + def to_week_range(self, replace_timezone: Optional[tzinfo] = None) -> WeekRange: + week_ranges: list[WeekRange] = [] + for datetime_range in self.datetime_ranges: + week_ranges.append(datetime_range.to_week_range(replace_timezone)) + + return reduce(lambda a, b: a | b, week_ranges) diff --git a/src/timeranges/_timeranges.py b/src/timeranges/_timeranges.py index f1bf112..d8d12bc 100644 --- a/src/timeranges/_timeranges.py +++ b/src/timeranges/_timeranges.py @@ -1,16 +1,20 @@ -from copy import deepcopy +from collections import defaultdict +from copy import copy, deepcopy from datetime import datetime, time, timedelta, tzinfo -from typing import Dict, List, Optional, TypeVar +from itertools import product +from typing import DefaultDict, Dict, List, Optional, TypeVar, Union import attr from timematic.enums import Weekday from timematic.utils import subtract_times +from ._base import BaseRange + _T_TimeRange = TypeVar("_T_TimeRange", bound="TimeRange") @attr.define(order=True, on_setattr=attr.setters.validate) -class TimeRange: +class TimeRange(BaseRange): def _validate_start( instance: _T_TimeRange, attribute: attr.Attribute, start: time ) -> None: @@ -28,7 +32,7 @@ def _validate_end( end: time = attr.ib(default=time.max, order=False, validator=_validate_end) @staticmethod - def _validate_time(time: time) -> None: + def _validate_time(time: time, /) -> None: if time.tzinfo is not None: raise ValueError(f"Time {time} has timezone info") @@ -46,15 +50,37 @@ def validate(self) -> None: def __attrs_post_init__(self) -> None: self.validate() - def contains(self, t: time) -> bool: - return self.start <= t <= self.end + def _contains_time(self, other: time, /) -> bool: + return self.start <= other <= self.end + + def _contains_time_range(self, other: "TimeRange", /) -> bool: + sct = self._contains_time + return sct(other.start) and sct(other.end) + + _contains_types = Union[time, "TimeRange"] + + def contains(self, other: _contains_types, /) -> bool: + if isinstance(other, time): + return self._contains_time(other) + elif isinstance(other, TimeRange): + return self._contains_time_range(other) + else: + raise TypeError + + def __contains__(self, other: _contains_types) -> bool: + return self.contains(other) - def __contains__(self, t: time) -> bool: - return self.contains(t) + def intersection(self, other: "TimeRange", /) -> Optional["TimeRange"]: + start = max(self.start, other.start) + end = min(self.end, other.end) + return TimeRange(start, end) if start <= end else None + + def __and__(self, other: "TimeRange") -> Optional["TimeRange"]: + return self.intersection(other) @attr.define -class TimeRanges: +class TimeRanges(BaseRange): time_ranges: List[TimeRange] = attr.Factory(list) def validate(self) -> None: @@ -93,16 +119,65 @@ def merge(self, interpolate: timedelta = timedelta(0)) -> None: def __attrs_post_init__(self) -> None: self.validate() - def contains(self, t: time) -> bool: - return any(t in time_range for time_range in self.time_ranges) + def __bool__(self) -> bool: + return bool(self.time_ranges) - def __contains__(self, t: time) -> bool: - return self.contains(t) + def _contains_time(self, other: time, /) -> bool: + return any(other in time_range for time_range in self.time_ranges) + def _contains_time_range(self, other: TimeRange, /) -> bool: + return any(other in time_range for time_range in self.time_ranges) -@attr.define -class WeekRange: - day_ranges: Dict[Weekday, TimeRanges] = attr.Factory(dict) + def _contains_time_ranges(self, other: "TimeRanges", /) -> bool: + return all( + self._contains_time_range(time_range) for time_range in other.time_ranges + ) + + _contains_types = Union[time, TimeRange, "TimeRanges"] + + def contains(self, other: _contains_types, /) -> bool: + if isinstance(other, time): + return self._contains_time(other) + elif isinstance(other, TimeRange): + return self._contains_time_range(other) + elif isinstance(other, TimeRanges): + return self._contains_time_ranges(other) + else: + raise TypeError + + def __contains__(self, other: _contains_types) -> bool: + return self.contains(other) + + def union(self, other: "TimeRanges", /) -> "TimeRanges": + time_ranges_list = self.time_ranges + other.time_ranges + time_ranges = TimeRanges(time_ranges_list) + time_ranges.merge() + return time_ranges + + def __or__(self, other: "TimeRanges") -> "TimeRanges": + return self.union(other) if isinstance(other, TimeRanges) else NotImplemented + + def intersection(self, other: "TimeRanges", /) -> "TimeRanges": + time_ranges_list: List[Optional[TimeRange]] = [ + a & b for a, b in product(self.time_ranges, other.time_ranges) + ] + + time_ranges = TimeRanges([tr for tr in time_ranges_list if tr is not None]) + time_ranges.merge() + return time_ranges + + def __and__(self, other: "TimeRanges") -> "TimeRanges": + return self.intersection(other) + + +@attr.define(on_setattr=attr.setters.convert) +class WeekRange(BaseRange): + def _convert_day_ranges(day_ranges: Dict[Weekday, TimeRanges]) -> DefaultDict[Weekday, TimeRanges]: # type: ignore + return defaultdict(TimeRanges, day_ranges) + + day_ranges: DefaultDict[Weekday, TimeRanges] = attr.ib( + factory=dict, converter=_convert_day_ranges + ) timezone: Optional[tzinfo] = None def validate(self) -> None: @@ -116,15 +191,64 @@ def merge(self, interpolate: timedelta = timedelta(0)) -> None: def __attrs_post_init__(self) -> None: self.validate() - def contains(self, dt: datetime) -> bool: + def __bool__(self) -> bool: + return any(self.day_ranges.values()) + + def _assert_timezone(self, other: "WeekRange", /) -> None: + if (tz := self.timezone) != (otz := other.timezone): + raise ValueError(f"Different timezones ({tz} and {otz})") + + def _contains_datetime(self, other: datetime, /) -> bool: tz = self.timezone if tz is not None: - dt = dt.astimezone(tz) - weekday = Weekday.from_datetime(dt) - day_range = self.day_ranges.get(weekday) - if day_range is not None: - return dt.time() in day_range - return False - - def __contains__(self, dt: datetime) -> bool: - return self.contains(dt) + other = other.astimezone(tz) + weekday = Weekday.from_datetime(other) + return other.time() in self.day_ranges[weekday] + + def _contains_week_range(self, other: "WeekRange", /) -> bool: + self._assert_timezone(other) + + return all( + day_range in self.day_ranges[weekday] + for weekday, day_range in other.day_ranges.items() + ) + + _contains_types = Union[datetime, "WeekRange"] + + def contains(self, other: _contains_types, /) -> bool: + if isinstance(other, datetime): + return self._contains_datetime(other) + elif isinstance(other, WeekRange): + return self._contains_week_range(other) + else: + raise TypeError + + def __contains__(self, other: _contains_types) -> bool: + return self.contains(other) + + def union(self, other: "WeekRange", /) -> "WeekRange": + self._assert_timezone(other) + + day_ranges = copy(self.day_ranges) + for weekday, day_range in other.day_ranges.items(): + day_ranges[weekday] |= day_range + + return WeekRange(day_ranges, timezone=self.timezone) + + def __or__(self, other: "WeekRange") -> "WeekRange": + return self.union(other) if isinstance(other, WeekRange) else NotImplemented + + def intersection(self, other: "WeekRange", /) -> "WeekRange": + self._assert_timezone(other) + + week_range = WeekRange(timezone=self.timezone) + for weekday, day_range in self.day_ranges.items(): + week_range.day_ranges[weekday] = day_range & other.day_ranges[weekday] + + return week_range + + def __and__(self, other: "WeekRange") -> "WeekRange": + return self.intersection(other) + + def has_transition(self, other: "WeekRange") -> bool: + return bool(other not in self and other & self) diff --git a/tests/test_datetimeranges.py b/tests/test_datetimeranges.py new file mode 100644 index 0000000..079e335 --- /dev/null +++ b/tests/test_datetimeranges.py @@ -0,0 +1,59 @@ +from datetime import datetime, timezone + +from pytest import raises + +from timeranges import DatetimeRange, DatetimeRanges + + +def utc(*args, **kwargs) -> datetime: + kwargs["tzinfo"] = timezone.utc + return datetime(*args, **kwargs) + + +def test_datetime_range_invalid(): + with raises(ValueError): + DatetimeRange(datetime(2022, 1, 1), datetime(2022, 2, 2)) + with raises(ValueError): + DatetimeRange(utc(2022, 2, 2), utc(2022, 1, 1)) + + +def test_datetime_range_contains_invalid(): + datetime_range = DatetimeRange(utc(2022, 1, 1), utc(2022, 2, 2)) + + with raises(TypeError): + 1 in datetime_range + + +def test_datetime_range_contains_datetime(): + datetime_range = DatetimeRange(utc(2022, 1, 1), utc(2022, 2, 2)) + yes = [utc(2022, 1, 1), utc(2022, 1, 2), utc(2022, 2, 2)] + no = [utc(2021, 1, 1), utc(2022, 2, 3)] + + for dt in yes: + assert dt in datetime_range + assert datetime_range.contains(dt) + + for dt in no: + assert dt not in datetime_range + assert not datetime_range.contains(dt) + + +def test_datetime_range_contains_datetime_range(): + datetime_range = DatetimeRange(utc(2022, 1, 1), utc(2022, 2, 2)) + yes = [ + DatetimeRange(utc(2022, 1, 1), utc(2022, 1, 2)), + DatetimeRange(utc(2022, 2, 1), utc(2022, 2, 2)), + ] + no = [ + DatetimeRange(utc(2021, 1, 1), utc(2023, 1, 1)), + DatetimeRange(utc(2022, 3, 3), utc(2022, 4, 4)), + DatetimeRange(utc(2022, 2, 1), utc(2022, 2, 3)), + ] + + for dt in yes: + assert dt in datetime_range + assert datetime_range.contains(dt) + + for dt in no: + assert dt not in datetime_range + assert not datetime_range.contains(dt) diff --git a/tests/test_timeranges.py b/tests/test_timeranges.py index 57fd65c..97f4824 100644 --- a/tests/test_timeranges.py +++ b/tests/test_timeranges.py @@ -1,4 +1,5 @@ from datetime import datetime, time, timedelta, timezone +from typing import Tuple from pytest import raises from timematic.enums import Weekday @@ -10,7 +11,12 @@ def test_timerange_invalid(): with raises(ValueError): TimeRange(time(2), time(1)) with raises(ValueError): - TimeRange(time(1), time(2, tzinfo=timezone(timedelta(1)))) + TimeRange( + time(1, tzinfo=timezone(timedelta(hours=2))), + time(2, tzinfo=timezone(timedelta(hours=1))), + ) + with raises(TypeError): + TimeRange(time(1), time(2, tzinfo=timezone(timedelta(hours=1)))) with raises(ValueError): tz = timezone(timedelta(1)) TimeRange(time(1, tzinfo=tz), time(2, tzinfo=tz)) @@ -23,7 +29,14 @@ def test_timerange_invalid(): assert tr == TimeRange(time(1), time(2)) -def test_timerange_contains(): +def test_timerange_contains_invalid(): + timerange = TimeRange(time(2), time(4)) + + with raises(TypeError): + 1 in timerange + + +def test_timerange_contains_time(): timerange = TimeRange(time(2), time(4)) yes = [time(2), time(3), time(4)] no = [time(1), time(5)] @@ -37,7 +50,7 @@ def test_timerange_contains(): assert not timerange.contains(t) -def test_timeranges_contains(): +def test_timeranges_contains_time(): timeranges = TimeRanges( [ TimeRange(time(2), time(4)), @@ -54,3 +67,50 @@ def test_timeranges_contains(): for t in no: assert t not in timeranges assert not timeranges.contains(t) + + +def test_timerange_contains_timerange(): + timerange = TimeRange(time(2), time(8)) + yes = [ + TimeRange(time(3), time(4)), + TimeRange(time(5), time(8)), + TimeRange(time(2), time(8)), + ] + no = [ + TimeRange(time(1), time(3)), + TimeRange(time(3), time(9)), + TimeRange(time(1), time(9)), + ] + + for tr in yes: + assert tr in timerange + assert timerange.contains(tr) + assert timerange._contains_time_range(tr) + + for tr in no: + assert tr not in timerange + assert not timerange.contains(tr) + assert not timerange._contains_time_range(tr) + + +def test_has_transition(): + def _make_week_range(interval: Tuple[int, int], /) -> WeekRange: + s, e = interval + return WeekRange({Weekday.MONDAY: TimeRanges([TimeRange(time(s), time(e))])}) + + yes = [_make_week_range(interval) for interval in [(3, 5), (6, 8), (3, 8)]] + no = [_make_week_range(interval) for interval in [(5, 6), (3, 4)]] + + target = WeekRange( + { + Weekday.MONDAY: TimeRanges( + [TimeRange(time(2), time(4)), TimeRange(time(7), time(9))] + ) + } + ) + + for wr in yes: + assert target.has_transition(wr) + + for wr in no: + assert not target.has_transition(wr)