diff --git a/examples/series/rolling/series_rolling_aggregate.py b/examples/series/rolling/series_rolling_aggregate.py new file mode 100644 index 000000000..8af431947 --- /dev/null +++ b/examples/series/rolling/series_rolling_aggregate.py @@ -0,0 +1,57 @@ +# ***************************************************************************** +# Copyright (c) 2019, Intel Corporation All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import numpy as np +import pandas as pd +from numba import njit + + +@njit +def series_rolling_aggregate(): + """ + Expected output: + get_mean get_median + 0 NaN NaN + 1 NaN NaN + 2 4.000000 4.0 + 3 3.333333 3.0 + 4 4.333333 5.0 + """ + series = pd.Series([4, 3, 5, 2, 6]) # Series of 4, 3, 5, 2, 6 + + def get_mean(x): + return x.mean() + + def get_median(x): + return np.median(x) + + # list of inhomogeneous type as argument isn't supported, so tuple passed + out_data_frame = series.rolling(3).aggregate((get_mean, get_median)) + + return out_data_frame + + +print(series_rolling_aggregate()) diff --git a/sdc/datatypes/hpat_pandas_series_rolling_functions.py b/sdc/datatypes/hpat_pandas_series_rolling_functions.py index 918c8254f..bf6f3f824 100644 --- a/sdc/datatypes/hpat_pandas_series_rolling_functions.py +++ b/sdc/datatypes/hpat_pandas_series_rolling_functions.py @@ -29,7 +29,9 @@ from numba import prange from numba.extending import register_jitable -from numba.types import float64, Boolean, Integer, NoneType, Omitted +from numba.types import (float64, Boolean, Integer, + NoneType, Omitted, Tuple, UniTuple) +from numba.types.functions import Callable from sdc.datatypes.common_functions import TypeChecker from sdc.datatypes.hpat_pandas_series_rolling_types import SeriesRollingType @@ -84,6 +86,12 @@ """ +@register_jitable +def arr_apply(arr, func): + """Apply function for values""" + return func(arr) + + @register_jitable def arr_nonnan_count(arr): """Count non-NaN values""" @@ -219,6 +227,37 @@ def impl(self): return impl +@register_jitable +def hpat_pandas_series_rolling_apply_impl(self, arg): + win = self._window + minp = self._min_periods + + input_series = self._data + input_arr = input_series._data + length = len(input_arr) + output_arr = numpy.empty(length, dtype=float64) + + def culc_apply(arr, func, minp): + finite_arr = arr.copy() + finite_arr[numpy.isinf(arr)] = numpy.nan + if len(finite_arr) < minp: + return numpy.nan + else: + return arr_apply(finite_arr, func) + + boundary = min(win, length) + for i in prange(boundary): + arr_range = input_arr[:i + 1] + output_arr[i] = culc_apply(arr_range, arg, minp) + + for i in prange(boundary, length): + arr_range = input_arr[i + 1 - win:i + 1] + output_arr[i] = culc_apply(arr_range, arg, minp) + + return pandas.Series(output_arr, input_series._index, + name=input_series._name) + + hpat_pandas_rolling_series_count_impl = register_jitable( gen_hpat_pandas_series_rolling_zerominp_impl(arr_nonnan_count, float64)) hpat_pandas_rolling_series_max_impl = register_jitable( @@ -233,6 +272,39 @@ def impl(self): gen_hpat_pandas_series_rolling_impl(arr_sum, float64)) +@sdc_overload_method(SeriesRollingType, 'aggregate') +def hpat_pandas_series_rolling_aggregate(self, arg): + + ty_checker = TypeChecker('Method rolling.aggregate().') + ty_checker.check(self, SeriesRollingType) + + if isinstance(arg, (Tuple, UniTuple)): + def hpat_pandas_rolling_series_aggregate_impl(self, arg): + win = self._window + minp = self._min_periods + + input_series = self._data + input_arr = input_series._data + length = len(input_arr) + output_arr = numpy.empty(length, dtype=float64) + + for func in arg: + # TODO: fix issue when iterating over tuple of functions + # most likely the issue happens due to heterogeneous tuple + pass + + # TODO: return Dataframe + return pandas.Series(output_arr, input_series._index, + name=input_series._name) + + return hpat_pandas_rolling_series_aggregate_impl + + elif isinstance(arg, Callable): + return hpat_pandas_series_rolling_apply_impl + + ty_checker.raise_exc(arg, 'callable, tuple of callables', 'arg') + + @sdc_overload_method(SeriesRollingType, 'corr') def hpat_pandas_series_rolling_corr(self, other=None, pairwise=None): diff --git a/sdc/tests/test_rolling.py b/sdc/tests/test_rolling.py index bbdaec123..4b8e7ae9c 100644 --- a/sdc/tests/test_rolling.py +++ b/sdc/tests/test_rolling.py @@ -501,6 +501,27 @@ def test_impl(series, window, min_periods, center, msg = msg_tmpl.format('closed', 'int64', 'str') self.assertIn(msg, str(raises.exception)) + @skip_sdc_jit('Series.rolling.aggregate() unsupported Series index') + def test_series_rolling_aggregate_mean_only(self): + def test_impl(series, window, min_periods): + def func(x): + if len(x) == 0: + return np.nan + return x.mean() + return series.rolling(window, min_periods).aggregate(func) + + hpat_func = self.jit(test_impl) + + data = [1., -1., 0., 0.1, -0.1] + index = list(range(len(data)))[::-1] + series = pd.Series(data, index, name='A') + for window in range(0, len(series) + 3, 2): + for min_periods in range(0, window + 1, 2): + with self.subTest(swindow=window, min_periods=min_periods): + jit_result = hpat_func(series, window, min_periods) + ref_result = test_impl(series, window, min_periods) + pd.testing.assert_series_equal(jit_result, ref_result) + @skip_sdc_jit('Series.rolling.corr() unsupported Series index') def test_series_rolling_corr(self): def test_impl(series, window, min_periods, other):