Skip to content

Commit

Permalink
Extend test coverage for pandas frequencies (#3179)
Browse files Browse the repository at this point in the history
*Issue #, if available:* #3178

*Description of changes:*
- Add more tests verifying that pandas frequencies are handled correctly
(compatible with both pandas 2.1 and pandas 2.2)

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this PR with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored May 24, 2024
1 parent a132eab commit b1e054a
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 101 deletions.
26 changes: 13 additions & 13 deletions src/gluonts/time_feature/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from packaging.version import Version
from typing import Any, Callable, Dict, List

import numpy as np
Expand Down Expand Up @@ -196,7 +197,10 @@ def norm_freq_str(freq_str: str) -> str:
# Note: Secondly ("S") frequency exists, where we don't want to remove the
# "S"!
if len(base_freq) >= 2 and base_freq.endswith("S"):
return base_freq[:-1]
base_freq = base_freq[:-1]
# In pandas >= 2.2, period end frequencies have been renamed, e.g. "M" -> "ME"
if Version(pd.__version__) >= Version("2.2.0"):
base_freq += "E"

return base_freq

Expand Down Expand Up @@ -252,17 +256,13 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
Unsupported frequency {freq_str}
The following frequencies are supported:
Y - yearly
alias: A
Q - quarterly
M - monthly
W - weekly
D - daily
B - business days
H - hourly
T - minutely
alias: min
S - secondly
"""

for offset_cls in features_by_offsets:
    offset = offset_cls()
    # List one supported frequency per line: the base alias (anything after
    # "-" in the offset's freqstr is dropped) followed by the offset class
    # name. The trailing newline keeps the entries from running together on
    # a single line of the error message.
    supported_freq_msg += (
        f"\t{offset.freqstr.split('-')[0]} - {offset_cls.__name__}\n"
    )

raise RuntimeError(supported_freq_msg)
1 change: 1 addition & 0 deletions src/gluonts/time_feature/seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"ME": 12,
"B": 5,
"Q": 4,
"QE": 4,
}


Expand Down
12 changes: 12 additions & 0 deletions test/time_feature/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
28 changes: 28 additions & 0 deletions test/time_feature/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pandas as pd
from packaging.version import Version

# Frequency alias spellings expected by the tests for the installed pandas.
# The renamed aliases apply from pandas 2.2.0 onwards, so 2.2.0 itself must
# take the ``else`` branch; this matches the ``>= Version("2.2.0")`` check
# used by ``norm_freq_str``. Comparing with ``<=`` would put pandas 2.2.0 on
# the legacy side and make the expectations disagree with the library.
if Version(pd.__version__) < Version("2.2.0"):
    # Legacy aliases (pandas < 2.2.0).
    S = "S"
    H = "H"
    M = "M"
    Q = "Q"
    Y = "A"
else:
    # Renamed aliases (pandas >= 2.2.0).
    S = "s"
    H = "h"
    M = "ME"
    Q = "QE"
    Y = "YE"
1 change: 0 additions & 1 deletion test/time_feature/test_agg_lags.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest

from gluonts.dataset.common import ListDataset

from gluonts.dataset.field_names import FieldName
from gluonts.transform import AddAggregateLags

Expand Down
26 changes: 14 additions & 12 deletions test/time_feature/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest
from pandas.tseries.frequencies import to_offset

from gluonts.time_feature import norm_freq_str

from .common import M, Q, S, Y

def test_norm_freq_str():
assert norm_freq_str(to_offset("Y").name) in ["A", "YE"]
assert norm_freq_str(to_offset("YS").name) in ["A", "Y"]
assert norm_freq_str(to_offset("A").name) in ["A", "YE"]
assert norm_freq_str(to_offset("AS").name) in ["A", "Y"]

assert norm_freq_str(to_offset("Q").name) in ["Q", "QE"]
assert norm_freq_str(to_offset("QS").name) == "Q"

assert norm_freq_str(to_offset("M").name) in ["M", "ME"]
assert norm_freq_str(to_offset("MS").name) in ["M", "ME"]

assert norm_freq_str(to_offset("S").name) in ["S", "s"]
@pytest.mark.parametrize(
    " aliases, normalized_freq_str",
    [
        (["Y", "YS", "A", "AS"], Y),
        (["Q", "QS"], Q),
        (["M", "MS"], M),
        (["S"], S),
    ],
)
def test_norm_freq_str(aliases, normalized_freq_str):
    """Check that every alias in a group normalizes to the same string.

    Each alias is first converted to a pandas offset; the offset's name is
    then passed through ``norm_freq_str`` and compared with the expected
    normalized frequency string for the group.
    """
    for alias in aliases:
        offset_name = to_offset(alias).name
        assert norm_freq_str(offset_name) == normalized_freq_str
1 change: 0 additions & 1 deletion test/time_feature/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest

from gluonts import zebras as zb

from gluonts.time_feature import (
Constant,
TimeFeature,
Expand Down
81 changes: 26 additions & 55 deletions test/time_feature/test_lag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
Test the lags computed for different frequencies.
"""

import pytest

import gluonts.time_feature.lag as date_feature_set

from .common import H, M, Q, Y

# These are the expected lags for common frequencies and corner cases.
# By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].
# Remaining lags correspond to the same `season` (+/- `delta`) in previous `k` cycles.
expected_lags = {
EXPECTED_LAGS = {
# (apart from the default lags) centered around each of the last 3 hours (delta = 2)
"4S": [
1,
Expand Down Expand Up @@ -179,7 +183,7 @@
]
+ [329, 330, 331, 494, 495, 496, 659, 660, 661, 707, 708, 709],
# centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1)
"H": [1, 2, 3, 4, 5, 6, 7]
H: [1, 2, 3, 4, 5, 6, 7]
+ [
23,
24,
Expand All @@ -206,7 +210,7 @@
+ [335, 336, 337, 503, 504, 505, 671, 672, 673, 719, 720, 721],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0)
"6H": [
("6" + H): [
1,
2,
3,
Expand Down Expand Up @@ -237,21 +241,21 @@
+ [224, 336],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last year (delta = 1)
"12H": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
("12" + H): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+ [27, 28, 29, 41, 42, 43, 55, 56, 57]
+ [59, 60, 61]
+ [112, 168]
+ [727, 728, 729],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1)
"23H": [1, 2, 3, 4, 5, 6, 7, 8]
("23" + H): [1, 2, 3, 4, 5, 6, 7, 8]
+ [13, 14, 15, 20, 21, 22, 28, 29]
+ [30, 31, 32]
+ [58, 87]
+ [378, 379, 380, 758, 759, 760, 1138, 1139, 1140],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1)
"25H": [1, 2, 3, 4, 5, 6, 7]
("25" + H): [1, 2, 3, 4, 5, 6, 7]
+ [12, 13, 14, 19, 20, 21, 25, 26, 27]
+ [28, 29]
+ [53, 80]
Expand Down Expand Up @@ -285,64 +289,31 @@
# centered around each of the last 3 years (delta = 1)
"5W": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 19, 20, 21, 30, 31, 32],
# centered around each of the last 3 years (delta = 1)
"M": [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37],
M: [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37],
# default
"6M": [1, 2, 3, 4, 5, 6, 7],
"6" + M: [1, 2, 3, 4, 5, 6, 7],
# default
"12M": [1, 2, 3, 4, 5, 6, 7],
"12" + M: [1, 2, 3, 4, 5, 6, 7],
Q: [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13],
"QS": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13],
Y: [1, 2, 3, 4, 5, 6, 7],
"YS": [1, 2, 3, 4, 5, 6, 7],
}

# For the default multiple (1)
for freq in ["min", "H", "D", "W", "M"]:
expected_lags["1" + freq] = expected_lags[freq]
for freq in ["min", H, "D", "W", M]:
EXPECTED_LAGS["1" + freq] = EXPECTED_LAGS[freq]

# For frequencies that do not have a unique form
expected_lags["60min"] = expected_lags["1H"]
expected_lags["24H"] = expected_lags["1D"]
expected_lags["7D"] = expected_lags["1W"]


def test_lags():
freq_strs = [
"4S",
"min",
"1min",
"15min",
"30min",
"59min",
"60min",
"61min",
"H",
"1H",
"6H",
"12H",
"23H",
"24H",
"25H",
"D",
"1D",
"2D",
"6D",
"7D",
"8D",
"W",
"1W",
"3W",
"4W",
"5W",
"M",
"6M",
"12M",
]
EXPECTED_LAGS["60min"] = EXPECTED_LAGS["1" + H]
EXPECTED_LAGS["24" + H] = EXPECTED_LAGS["1D"]
EXPECTED_LAGS["7D"] = EXPECTED_LAGS["1W"]

for freq_str in freq_strs:
lags = date_feature_set.get_lags_for_frequency(freq_str)

assert (
lags == expected_lags[freq_str]
), "lags do not match for the frequency '{}':\nexpected: {},\nprovided: {}".format(
freq_str, expected_lags[freq_str], lags
)
@pytest.mark.parametrize("freq_str, expected_lags", EXPECTED_LAGS.items())
def test_lags(freq_str, expected_lags):
    """Check the lags computed for ``freq_str`` against the reference list."""
    computed_lags = date_feature_set.get_lags_for_frequency(freq_str)
    assert computed_lags == expected_lags


if __name__ == "__main__":
Expand Down
55 changes: 36 additions & 19 deletions test/time_feature/test_seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,42 @@

from gluonts.time_feature import get_seasonality

from .common import H, M, Q, Y

@pytest.mark.parametrize(
"freq, expected_seasonality",
[
("30min", 48),
("1H", 24),
("H", 24),
("2H", 12),
("3H", 8),
("4H", 6),
("15H", 1),
("5B", 1),
("1B", 5),
("2W", 1),
("3M", 4),
("1D", 1),
("7D", 1),
("8D", 1),
],
)
TEST_CASES = [
("30min", 48),
("5B", 1),
("1B", 5),
("2W", 1),
("1D", 1),
("7D", 1),
("8D", 1),
# Monthly
("MS", 12),
("3MS", 4),
(M, 12),
("3" + M, 4),
# Quarterly
("QS", 4),
("2QS", 2),
(Q, 4),
("2" + Q, 2),
("3" + Q, 1),
# Hourly
("1" + H, 24),
(H, 24),
("2" + H, 12),
("3" + H, 8),
("4" + H, 6),
("15" + H, 1),
# Yearly
(Y, 1),
("2" + Y, 1),
("YS", 1),
("2YS", 1),
]


@pytest.mark.parametrize("freq, expected_seasonality", TEST_CASES)
def test_get_seasonality(freq, expected_seasonality):
    """Check the seasonality returned for ``freq`` against the expected one."""
    seasonality = get_seasonality(freq)
    assert seasonality == expected_seasonality

0 comments on commit b1e054a

Please sign in to comment.