Skip to content

Commit

Permalink
[Feature] Support a user defined function name in the window transfor…
Browse files Browse the repository at this point in the history
…mation output (unit8co#1676)

* Support user-defined name in window transformation

* Update doc string

* Add tests

* Review: Test respects the built-in function name

---------

Co-authored-by: madtoinou <32447896+madtoinou@users.noreply.github.com>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
  • Loading branch information
3 people authored and alexcolpitts96 committed May 31, 2023
1 parent 69b48a9 commit f415e95
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 1 deletion.
8 changes: 8 additions & 0 deletions darts/dataprocessing/transformers/window_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def __init__(
transformation should be applied. If not specified, the transformation will be
applied on all components.
:``"function_name"``: Optional. A string specifying the function name referenced as part of
the transformation output name. For example, given a user-provided function
transformation on rolling window size of 5 on the component "comp", the
default transformation output name is "rolling_udf_5_comp" whereby "udf"
refers to "user defined function". If specified, the ``"function_name"`` will
replace the default name "udf". Similarly, the ``"function_name"`` will replace
the name of the pandas builtin transformation function name in the output name.
All other dictionary items provided will be treated as keyword arguments for the windowing mode
(i.e., ``rolling/ewm/expanding``) or for the specific function
in that mode (i.e., ``pandas.DataFrame.rolling.mean/std/max/min...`` or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@ def test_ts_windowtransf_output_series(self):
],
)

# test customized function name that overwrites the pandas builtin transformation
transforms = {
"function": "sum",
"mode": "rolling",
"window": 1,
"function_name": "customized_name",
}
transformed_ts = self.series_univ_det.window_transform(transforms=transforms)
self.assertEqual(
transformed_ts.components.to_list(),
[
f"{transforms['mode']}_{transforms['function_name']}_{str(transforms['window'])}_{comp}"
for comp in self.series_univ_det.components
],
)
del transforms["function_name"]

# multivariate deterministic input
# transform one component
transforms.update({"components": "0"})
Expand Down Expand Up @@ -242,6 +259,39 @@ def test_ts_windowtransf_output_series(self):
transformed_ts = self.series_multi_prob.window_transform(transforms=transforms)
self.assertEqual(transformed_ts.n_samples, 2)

def test_user_defined_function_behavior(self):
def count_above_mean(array):
mean = np.mean(array)
return np.where(array > mean)[0].size

transformation = {
"function": count_above_mean,
"mode": "rolling",
"window": 5,
}
transformed_ts = self.target.window_transform(
transformation,
)
expected_transformed_series = TimeSeries.from_times_and_values(
self.times,
[0, 1, 1, 2, 2, 2, 2, 2, 2, 2],
columns=["rolling_udf_5_0"],
)
self.assertEqual(transformed_ts, expected_transformed_series)

# test if a customized function name is provided
transformation.update({"function_name": "count_above_mean"})
transformed_ts = self.target.window_transform(
transformation,
)
self.assertEqual(
transformed_ts.components.to_list(),
[
f"{transformation['mode']}_{transformation['function_name']}_{str(transformation['window'])}_{comp}"
for comp in self.target.components
],
)

def test_ts_windowtransf_output_nabehavior(self):
window_transformations = {
"function": "sum",
Expand Down
16 changes: 15 additions & 1 deletion darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3255,6 +3255,14 @@ def window_transform(
transformation should be applied. If not specified, the transformation will be
applied on all components.
:``"function_name"``: Optional. A string specifying the function name referenced as part of
the transformation output name. For example, given a user-provided function
transformation on rolling window size of 5 on the component "comp", the
default transformation output name is "rolling_udf_5_comp" whereby "udf"
refers to "user defined function". If specified, the ``"function_name"`` will
replace the default name "udf". Similarly, the ``"function_name"`` will replace
the name of the pandas builtin transformation function name in the output name.
All other dictionary items provided will be treated as keyword arguments for the windowing mode
(i.e., ``rolling/ewm/expanding``) or for the specific function
in that mode (i.e., ``pandas.DataFrame.rolling.mean/std/max/min...`` or
Expand Down Expand Up @@ -3409,6 +3417,7 @@ def _get_kwargs(transformation, forecasting_safe):
"function",
"group",
"components",
"function_name",
}

window_mode_expected_args = set(window_mode.__code__.co_varnames)
Expand Down Expand Up @@ -3536,8 +3545,13 @@ def _get_kwargs(transformation, forecasting_safe):
)
min_periods = transformation["min_periods"]
# set new columns names
fn_name = transformation.get("function_name")
if fn_name:
function_name = fn_name
else:
function_name = fn if fn != "apply" else "udf"
name_prefix = (
f"{window_mode}_{fn if fn != 'apply' else 'udf'}"
f"{window_mode}_{function_name}"
f"{'_'+str(transformation['window']) if 'window' in transformation else ''}"
f"{'_'+str(min_periods) if min_periods>1 else ''}"
)
Expand Down

0 comments on commit f415e95

Please sign in to comment.