diff --git a/sdc/datatypes/common_functions.py b/sdc/datatypes/common_functions.py index 7cb6e023c..0d1a8757b 100644 --- a/sdc/datatypes/common_functions.py +++ b/sdc/datatypes/common_functions.py @@ -41,12 +41,14 @@ from numba import numpy_support import sdc -from sdc.str_arr_type import string_array_type -from sdc.str_arr_ext import (num_total_chars, append_string_array_to, - str_arr_is_na, pre_alloc_string_array, str_arr_set_na, - cp_str_list_to_array) +from sdc.hiframes.pd_series_type import SeriesType +from sdc.str_arr_ext import ( + append_string_array_to, cp_str_list_to_array, num_total_chars, + pre_alloc_string_array, str_arr_is_na, str_arr_set_na, string_array_type +) from sdc.utilities.utils import sdc_overload, sdc_register_jitable -from sdc.utilities.sdc_typing_utils import find_common_dtype_from_numpy_dtypes +from sdc.utilities.sdc_typing_utils import (find_common_dtype_from_numpy_dtypes, + TypeChecker) def hpat_arrays_append(A, B): @@ -537,3 +539,55 @@ def _sdc_pandas_series_check_axis_impl(axis): return _sdc_pandas_series_check_axis_impl return None + + +def _sdc_pandas_series_align(series, other, size='max', finiteness=False): + """ + Align series and other series by + size where size of output series is max/min size of input series + finiteness where all the infinite and matched finite values are replaced with nans, e.g. + series: [1., inf, inf, -1., 0.] -> [1., nan, nan, -1., 0.] + other: [1., -1., 0., 0.1, -0.1] -> [1., nan, nan, 0.1, -0.1] + """ + pass + + +@sdc_overload(_sdc_pandas_series_align, jit_options={'parallel': False}) +def _sdc_pandas_series_align_overload(series, other, size='max', finiteness=False): + ty_checker = TypeChecker('Function sdc.common_functions._sdc_pandas_series_align().') + ty_checker.check(series, SeriesType) + ty_checker.check(other, SeriesType) + + str_types = (str, types.StringLiteral, types.UnicodeType, types.Omitted) + if not isinstance(size, str_types): + ty_checker.raise_exc(size, 'str', 'size') + + if not isinstance(finiteness, (bool, types.Boolean, types.Omitted)): + ty_checker.raise_exc(finiteness, 'bool', 'finiteness') + + def _sdc_pandas_series_align_impl(series, other, size='max', finiteness=False): + if size != 'max' and size != 'min': + raise ValueError("Function sdc.common_functions._sdc_pandas_series_align(). " + "The object size\n expected: 'max' or 'min'") + + arr, other_arr = series._data, other._data + arr_len, other_arr_len = len(arr), len(other_arr) + min_length = min(arr_len, other_arr_len) + length = max(arr_len, other_arr_len) if size == 'max' else min_length + + aligned_arr = numpy.repeat([numpy.nan], length) + aligned_other_arr = numpy.repeat([numpy.nan], length) + + for i in numba.prange(min_length): + if not finiteness or (numpy.isfinite(arr[i]) and numpy.isfinite(other_arr[i])): + aligned_arr[i] = arr[i] + aligned_other_arr[i] = other_arr[i] + else: + aligned_arr[i] = aligned_other_arr[i] = numpy.nan + + aligned = pandas.Series(aligned_arr, name=series._name) + aligned_other = pandas.Series(aligned_other_arr, name=other._name) + + return aligned, aligned_other + + return _sdc_pandas_series_align_impl diff --git a/sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py b/sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py index 29dcc5153..85d78514b 100644 --- a/sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py +++ b/sdc/datatypes/hpat_pandas_dataframe_rolling_functions.py @@ -120,7 +120,7 @@ def df_rolling_method_other_df_codegen(method_name, self, other, args=None, kws= ' else:', ' _pairwise = pairwise', ' if _pairwise:', - ' raise ValueError("Method rolling.corr(). The object pairwise\\n expected: False, None")' + f' raise ValueError("Method rolling.{method_name}(). The object pairwise\\n expected: False, None")' ] data_length = 'len(get_dataframe_data(self._data, 0))' if data_columns else '0' @@ -139,7 +139,7 @@ def df_rolling_method_other_df_codegen(method_name, self, other, args=None, kws= f' series_{col} = pandas.Series(data_{col})', f' {other_series} = pandas.Series(other_data_{col})', f' rolling_{col} = series_{col}.rolling({rolling_params})', - f' result_{col} = rolling_{col}.corr({method_params})', + f' result_{col} = rolling_{col}.{method_name}({method_params})', f' {res_data} = result_{col}._data[:length]' ] else: @@ -182,32 +182,41 @@ def df_rolling_method_main_codegen(method_params, df_columns, method_name): return func_lines -def df_rolling_method_other_none_codegen(method_name, self, args=None, kws=None): - args = args or [] - kwargs = kws or {} +def gen_df_rolling_method_other_none_codegen(rewrite_name=None): + """Generate df.rolling method code generator based on name of the method""" + def df_rolling_method_other_none_codegen(method_name, self, args=None, kws=None): + _method_name = rewrite_name or method_name + args = args or [] + kwargs = kws or {} - impl_params = ['self'] + args + params2list(kwargs) - impl_params_as_str = ', '.join(impl_params) + impl_params = ['self'] + args + params2list(kwargs) + impl_params_as_str = ', '.join(impl_params) - impl_name = f'_df_rolling_{method_name}_other_none_impl' - func_lines = [f'def {impl_name}({impl_params_as_str}):'] + impl_name = f'_df_rolling_{_method_name}_other_none_impl' + func_lines = [f'def {impl_name}({impl_params_as_str}):'] - if 'pairwise' in kwargs: - func_lines += [ - ' if pairwise is None:', - ' _pairwise = True', - ' else:', - ' _pairwise = pairwise', - ' if _pairwise:', - ' raise ValueError("Method rolling.corr(). The object pairwise\\n expected: False")' - ] - method_params = args + ['{}={}'.format(k, k) for k in kwargs if k != 'other'] - func_lines += df_rolling_method_main_codegen(method_params, self.data.columns, method_name) - func_text = '\n'.join(func_lines) + if 'pairwise' in kwargs: + func_lines += [ + ' if pairwise is None:', + ' _pairwise = True', + ' else:', + ' _pairwise = pairwise', + ' if _pairwise:', + f' raise ValueError("Method rolling.{_method_name}(). The object pairwise\\n expected: False")' + ] + method_params = args + ['{}={}'.format(k, k) for k in kwargs if k != 'other'] + func_lines += df_rolling_method_main_codegen(method_params, self.data.columns, method_name) + func_text = '\n'.join(func_lines) - global_vars = {'pandas': pandas, 'get_dataframe_data': get_dataframe_data} + global_vars = {'pandas': pandas, 'get_dataframe_data': get_dataframe_data} - return func_text, global_vars + return func_text, global_vars + + return df_rolling_method_other_none_codegen + + +df_rolling_method_other_none_codegen = gen_df_rolling_method_other_none_codegen() +df_rolling_cov_other_none_codegen = gen_df_rolling_method_other_none_codegen('cov') def df_rolling_method_codegen(method_name, self, args=None, kws=None): @@ -249,6 +258,16 @@ def gen_df_rolling_method_other_none_impl(method_name, self, args=None, kws=None return _impl +def gen_df_rolling_cov_other_none_impl(method_name, self, args=None, kws=None): + func_text, global_vars = df_rolling_cov_other_none_codegen(method_name, self, + args=args, kws=kws) + loc_vars = {} + exec(func_text, global_vars, loc_vars) + _impl = loc_vars[f'_df_rolling_cov_other_none_impl'] + + return _impl + + def gen_df_rolling_method_impl(method_name, self, args=None, kws=None): func_text, global_vars = df_rolling_method_codegen(method_name, self, args=args, kws=kws) @@ -308,6 +327,37 @@ def sdc_pandas_dataframe_rolling_count(self): return gen_df_rolling_method_impl('count', self) +@sdc_overload_method(DataFrameRollingType, 'cov') +def sdc_pandas_dataframe_rolling_cov(self, other=None, pairwise=None, ddof=1): + + ty_checker = TypeChecker('Method rolling.cov().') + ty_checker.check(self, DataFrameRollingType) + + accepted_other = (Omitted, NoneType, DataFrameType, SeriesType) + if not isinstance(other, accepted_other) and other is not None: + ty_checker.raise_exc(other, 'DataFrame, Series', 'other') + + accepted_pairwise = (bool, Boolean, Omitted, NoneType) + if not isinstance(pairwise, accepted_pairwise) and pairwise is not None: + ty_checker.raise_exc(pairwise, 'bool', 'pairwise') + + if not isinstance(ddof, (int, Integer, Omitted)): + ty_checker.raise_exc(ddof, 'int', 'ddof') + + none_other = isinstance(other, (Omitted, NoneType)) or other is None + kws = {'other': 'None', 'pairwise': 'None', 'ddof': '1'} + + if none_other: + # method _df_cov in comparison to method cov doesn't align input data + # by replacing infinite and matched finite values with nans + return gen_df_rolling_cov_other_none_impl('_df_cov', self, kws=kws) + + if isinstance(other, DataFrameType): + return gen_df_rolling_method_other_df_impl('cov', self, other, kws=kws) + + return gen_df_rolling_method_impl('cov', self, kws=kws) + + @sdc_overload_method(DataFrameRollingType, 'kurt') def sdc_pandas_dataframe_rolling_kurt(self): @@ -457,6 +507,28 @@ def sdc_pandas_dataframe_rolling_var(self, ddof=1): 'extra_params': '' }) +sdc_pandas_dataframe_rolling_cov.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{ + 'method_name': 'cov', + 'example_caption': 'Calculate rolling covariance.', + 'limitations_block': + """ + Limitations + ----------- + DataFrame elements cannot be max/min float/integer. Otherwise SDC and Pandas results are different. + Different size of `self` and `other` can produce result different from the result of Pandas + due to different float rounding in Python and SDC. + """, + 'extra_params': + """ + other: :obj:`Series` or :obj:`DataFrame` + Other Series or DataFrame. + pairwise: :obj:`bool` + Calculate pairwise combinations of columns within a DataFrame. + ddof: :obj:`int` + Delta Degrees of Freedom. + """ +}) + sdc_pandas_dataframe_rolling_kurt.__doc__ = sdc_pandas_dataframe_rolling_docstring_tmpl.format(**{ 'method_name': 'kurt', 'example_caption': 'Calculate unbiased rolling kurtosis.', diff --git a/sdc/datatypes/hpat_pandas_series_rolling_functions.py b/sdc/datatypes/hpat_pandas_series_rolling_functions.py index 3bb25d18b..57d9a549e 100644 --- a/sdc/datatypes/hpat_pandas_series_rolling_functions.py +++ b/sdc/datatypes/hpat_pandas_series_rolling_functions.py @@ -34,8 +34,10 @@ from numba.types import (float64, Boolean, Integer, NoneType, Number, Omitted, StringLiteral, UnicodeType) -from sdc.utilities.sdc_typing_utils import TypeChecker +from sdc.datatypes.common_functions import _sdc_pandas_series_align from sdc.datatypes.hpat_pandas_series_rolling_types import SeriesRollingType +from sdc.hiframes.pd_series_type import SeriesType +from sdc.utilities.sdc_typing_utils import TypeChecker from sdc.utilities.utils import sdc_overload_method, sdc_register_jitable @@ -111,15 +113,6 @@ def arr_nonnan_count(arr): return len(arr) - numpy.isnan(arr).sum() -@sdc_register_jitable -def arr_cov(x, y, ddof): - """Calculate covariance of values 1D arrays x and y of the same size""" - if len(x) == 0: - return numpy.nan - - return numpy.cov(x, y, ddof=ddof)[0, 1] - - @sdc_register_jitable def _moment(arr, moment): mn = numpy.mean(arr) @@ -451,16 +444,15 @@ def hpat_pandas_rolling_series_count_impl(self): return hpat_pandas_rolling_series_count_impl -@sdc_rolling_overload(SeriesRollingType, 'cov') -def hpat_pandas_series_rolling_cov(self, other=None, pairwise=None, ddof=1): - +def _hpat_pandas_series_rolling_cov_check_types(self, other=None, + pairwise=None, ddof=1): + """Check types of parameters of series.rolling.cov()""" ty_checker = TypeChecker('Method rolling.cov().') ty_checker.check(self, SeriesRollingType) - # TODO: check `other` is Series after a circular import of SeriesType fixed - # accepted_other = (bool, Omitted, NoneType, SeriesType) - # if not isinstance(other, accepted_other) and other is not None: - # ty_checker.raise_exc(other, 'Series', 'other') + accepted_other = (bool, Omitted, NoneType, SeriesType) + if not isinstance(other, accepted_other) and other is not None: + ty_checker.raise_exc(other, 'Series', 'other') accepted_pairwise = (bool, Boolean, Omitted, NoneType) if not isinstance(pairwise, accepted_pairwise) and pairwise is not None: @@ -469,50 +461,48 @@ def hpat_pandas_series_rolling_cov(self, other=None, pairwise=None, ddof=1): if not isinstance(ddof, (int, Integer, Omitted)): ty_checker.raise_exc(ddof, 'int', 'ddof') + +def _gen_hpat_pandas_rolling_series_cov_impl(other, align_finiteness=False): + """Generate series.rolling.cov() implementation based on series alignment""" nan_other = isinstance(other, (Omitted, NoneType)) or other is None - def hpat_pandas_rolling_series_cov_impl(self, other=None, pairwise=None, ddof=1): + def _impl(self, other=None, pairwise=None, ddof=1): win = self._window minp = self._min_periods main_series = self._data - main_arr = main_series._data - main_arr_length = len(main_arr) - if nan_other == True: # noqa - other_arr = main_arr + other_series = main_series else: - other_arr = other._data + other_series = other - other_arr_length = len(other_arr) - length = max(main_arr_length, other_arr_length) - output_arr = numpy.empty(length, dtype=float64) + main_aligned, other_aligned = _sdc_pandas_series_align(main_series, other_series, + finiteness=align_finiteness) + count = (main_aligned + other_aligned).rolling(win).count() + bias_adj = count / (count - ddof) - def calc_cov(main, other, ddof, minp): - # align arrays `main` and `other` by size and finiteness - min_length = min(len(main), len(other)) - main_valid_indices = numpy.isfinite(main[:min_length]) - other_valid_indices = numpy.isfinite(other[:min_length]) - valid = main_valid_indices & other_valid_indices + def mean(series): + return series.rolling(win, min_periods=minp).mean() - if len(main[valid]) < minp: - return numpy.nan - else: - return arr_cov(main[valid], other[valid], ddof) + return (mean(main_aligned * other_aligned) - mean(main_aligned) * mean(other_aligned)) * bias_adj - for i in prange(min(win, length)): - main_arr_range = main_arr[:i + 1] - other_arr_range = other_arr[:i + 1] - output_arr[i] = calc_cov(main_arr_range, other_arr_range, ddof, minp) + return _impl - for i in prange(win, length): - main_arr_range = main_arr[i + 1 - win:i + 1] - other_arr_range = other_arr[i + 1 - win:i + 1] - output_arr[i] = calc_cov(main_arr_range, other_arr_range, ddof, minp) - return pandas.Series(output_arr) +@sdc_rolling_overload(SeriesRollingType, 'cov') +def hpat_pandas_series_rolling_cov(self, other=None, pairwise=None, ddof=1): + _hpat_pandas_series_rolling_cov_check_types(self, other=other, + pairwise=pairwise, ddof=ddof) + + return _gen_hpat_pandas_rolling_series_cov_impl(other, align_finiteness=True) + + +@sdc_rolling_overload(SeriesRollingType, '_df_cov') +def hpat_pandas_series_rolling_cov(self, other=None, pairwise=None, ddof=1): + _hpat_pandas_series_rolling_cov_check_types(self, other=other, + pairwise=pairwise, ddof=ddof) - return hpat_pandas_rolling_series_cov_impl + return _gen_hpat_pandas_rolling_series_cov_impl(other) @sdc_rolling_overload(SeriesRollingType, 'kurt') diff --git a/sdc/tests/test_rolling.py b/sdc/tests/test_rolling.py index 058ae0c2a..9d3acd92e 100644 --- a/sdc/tests/test_rolling.py +++ b/sdc/tests/test_rolling.py @@ -629,6 +629,54 @@ def test_impl(obj, window, min_periods): ref_result = test_impl(obj, window, min_periods) assert_equal(jit_result, ref_result) + def _test_rolling_cov(self, obj, other): + def test_impl(obj, window, min_periods, other, ddof): + return obj.rolling(window, min_periods).cov(other, ddof=ddof) + + hpat_func = self.jit(test_impl) + assert_equal = self._get_assert_equal(obj) + + for window in range(0, len(obj) + 3, 2): + for min_periods, ddof in product(range(0, window, 2), [0, 1]): + with self.subTest(obj=obj, other=other, window=window, + min_periods=min_periods, ddof=ddof): + jit_result = hpat_func(obj, window, min_periods, other, ddof) + ref_result = test_impl(obj, window, min_periods, other, ddof) + assert_equal(jit_result, ref_result) + + def _test_rolling_cov_with_no_other(self, obj): + def test_impl(obj, window, min_periods): + return obj.rolling(window, min_periods).cov(pairwise=False) + + hpat_func = self.jit(test_impl) + assert_equal = self._get_assert_equal(obj) + + for window in range(0, len(obj) + 3, 2): + for min_periods in range(0, window, 2): + with self.subTest(obj=obj, window=window, + min_periods=min_periods): + jit_result = hpat_func(obj, window, min_periods) + ref_result = test_impl(obj, window, min_periods) + assert_equal(jit_result, ref_result) + + def _test_rolling_cov_unsupported_types(self, obj): + def test_impl(obj, pairwise, ddof): + return obj.rolling(3, 3).cov(pairwise=pairwise, ddof=ddof) + + hpat_func = self.jit(test_impl) + + msg_tmpl = 'Method rolling.cov(). The object {}\n given: {}\n expected: {}' + + with self.assertRaises(TypingError) as raises: + hpat_func(obj, 1, 1) + msg = msg_tmpl.format('pairwise', 'int64', 'bool') + self.assertIn(msg, str(raises.exception)) + + with self.assertRaises(TypingError) as raises: + hpat_func(obj, None, '1') + msg = msg_tmpl.format('ddof', 'unicode_type', 'int') + self.assertIn(msg, str(raises.exception)) + def _test_rolling_kurt(self, obj): def test_impl(obj, window, min_periods): return obj.rolling(window, min_periods).kurt() @@ -957,6 +1005,102 @@ def test_df_rolling_count(self): self._test_rolling_count(df) + @skip_sdc_jit('DataFrame.rolling.cov() unsupported') + def test_df_rolling_cov(self): + all_data = [ + list(range(10)), [1., -1., 0., 0.1, -0.1], + [1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF], + [np.nan, np.inf, np.inf, np.nan, np.nan, np.nan, np.NINF, np.NZERO] + ] + length = min(len(d) for d in all_data) + data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)} + df = pd.DataFrame(data) + for d in all_data: + other = pd.Series(d) + self._test_rolling_cov(df, other) + + other_all_data = deepcopy(all_data) + [list(range(10))[::-1]] + other_all_data[1] = [-1., 1., 0., -0.1, 0.1] + other_length = min(len(d) for d in other_all_data) + other_data = {n: d[:other_length] for n, d in zip(string.ascii_uppercase, other_all_data)} + other = pd.DataFrame(other_data) + + self._test_rolling_cov(df, other) + + @skip_sdc_jit('DataFrame.rolling.cov() unsupported') + def test_df_rolling_cov_no_other(self): + all_data = [ + list(range(10)), [1., -1., 0., 0.1, -0.1], + [1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF], + [np.nan, np.inf, np.inf, np.nan, np.nan, np.nan, np.NINF, np.NZERO] + ] + length = min(len(d) for d in all_data) + data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)} + df = pd.DataFrame(data) + + self._test_rolling_cov_with_no_other(df) + + @skip_sdc_jit('DataFrame.rolling.cov() unsupported exceptions') + def test_df_rolling_cov_unsupported_types(self): + all_data = [[1., -1., 0., 0.1, -0.1], [-1., 1., 0., -0.1, 0.1]] + length = min(len(d) for d in all_data) + data = {n: d[:length] for n, d in zip(string.ascii_uppercase, all_data)} + df = pd.DataFrame(data) + + self._test_rolling_cov_unsupported_types(df) + + @skip_sdc_jit('DataFrame.rolling.cov() unsupported exceptions') + def test_df_rolling_cov_unsupported_values(self): + def test_impl(df, other, pairwise): + return df.rolling(3, 3).cov(other=other, pairwise=pairwise) + + hpat_func = self.jit(test_impl) + msg_tmpl = 'Method rolling.cov(). The object pairwise\n expected: {}' + + df = pd.DataFrame({'A': [1., -1., 0., 0.1, -0.1], + 'B': [-1., 1., 0., -0.1, 0.1]}) + for pairwise in [None, True]: + with self.assertRaises(ValueError) as raises: + hpat_func(df, None, pairwise) + self.assertIn(msg_tmpl.format('False'), str(raises.exception)) + + other = pd.DataFrame({'A': [-1., 1., 0., -0.1, 0.1], + 'C': [1., -1., 0., 0.1, -0.1]}) + with self.assertRaises(ValueError) as raises: + hpat_func(df, other, True) + self.assertIn(msg_tmpl.format('False, None'), str(raises.exception)) + + @skip_sdc_jit('Series.rolling.cov() unsupported Series index') + @unittest.expectedFailure + def test_df_rolling_cov_issue_floating_point_rounding(self): + """ + Cover issue of different float rounding in Python and SDC/Numba: + + s = np.Series([1., -1., 0., 0.1, -0.1]) + s.rolling(2, 0).mean() + + Python: SDC/Numba: + 0 1.000000e+00 0 1.00 + 1 0.000000e+00 1 0.00 + 2 -5.000000e-01 2 -0.50 + 3 5.000000e-02 3 0.05 + 4 -1.387779e-17 4 0.00 + dtype: float64 dtype: float64 + + BTW: cov uses mean inside itself + """ + def test_impl(df, window, min_periods, other, ddof): + return df.rolling(window, min_periods).cov(other, ddof=ddof) + + hpat_func = self.jit(test_impl) + + df = pd.DataFrame({'A': [1., -1., 0., 0.1, -0.1]}) + other = pd.DataFrame({'A': [-1., 1., 0., -0.1, 0.1, 0.]}) + + jit_result = hpat_func(df, 2, 0, other, 1) + ref_result = test_impl(df, 2, 0, other, 1) + pd.testing.assert_frame_equal(jit_result, ref_result) + @skip_sdc_jit('DataFrame.rolling.kurt() unsupported') def test_df_rolling_kurt(self): all_data = test_global_input_data_float64 @@ -1218,11 +1362,6 @@ def test_series_rolling_count(self): @skip_sdc_jit('Series.rolling.cov() unsupported Series index') def test_series_rolling_cov(self): - def test_impl(series, window, min_periods, other, ddof): - return series.rolling(window, min_periods).cov(other, ddof=ddof) - - hpat_func = self.jit(test_impl) - all_data = [ list(range(5)), [1., -1., 0., 0.1, -0.1], [1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF], @@ -1231,21 +1370,10 @@ def test_impl(series, window, min_periods, other, ddof): for main_data, other_data in product(all_data, all_data): series = pd.Series(main_data) other = pd.Series(other_data) - for window in range(0, len(series) + 3, 2): - for min_periods, ddof in product(range(0, window, 2), [0, 1]): - with self.subTest(series=series, other=other, window=window, - min_periods=min_periods, ddof=ddof): - jit_result = hpat_func(series, window, min_periods, other, ddof) - ref_result = test_impl(series, window, min_periods, other, ddof) - pd.testing.assert_series_equal(jit_result, ref_result) + self._test_rolling_cov(series, other) @skip_sdc_jit('Series.rolling.cov() unsupported Series index') - def test_series_rolling_cov_default(self): - def test_impl(series, window, min_periods): - return series.rolling(window, min_periods).cov() - - hpat_func = self.jit(test_impl) - + def test_series_rolling_cov_no_other(self): all_data = [ list(range(5)), [1., -1., 0., 0.1, -0.1], [1., np.inf, np.inf, -1., 0., np.inf, np.NINF, np.NINF], @@ -1253,13 +1381,7 @@ def test_impl(series, window, min_periods): ] for data in all_data: series = pd.Series(data) - for window in range(0, len(series) + 3, 2): - for min_periods in range(0, window, 2): - with self.subTest(series=series, window=window, - min_periods=min_periods): - jit_result = hpat_func(series, window, min_periods) - ref_result = test_impl(series, window, min_periods) - pd.testing.assert_series_equal(jit_result, ref_result) + self._test_rolling_cov_with_no_other(series) @skip_sdc_jit('Series.rolling.cov() unsupported Series index') @unittest.expectedFailure @@ -1278,23 +1400,8 @@ def test_impl(series, window, min_periods, other, ddof): @skip_sdc_jit('Series.rolling.cov() unsupported exceptions') def test_series_rolling_cov_unsupported_types(self): - def test_impl(pairwise, ddof): - series = pd.Series([1., -1., 0., 0.1, -0.1]) - return series.rolling(3, 3).cov(pairwise=pairwise, ddof=ddof) - - hpat_func = self.jit(test_impl) - - msg_tmpl = 'Method rolling.cov(). The object {}\n given: {}\n expected: {}' - - with self.assertRaises(TypingError) as raises: - hpat_func(1, 1) - msg = msg_tmpl.format('pairwise', 'int64', 'bool') - self.assertIn(msg, str(raises.exception)) - - with self.assertRaises(TypingError) as raises: - hpat_func(None, '1') - msg = msg_tmpl.format('ddof', 'unicode_type', 'int') - self.assertIn(msg, str(raises.exception)) + series = pd.Series([1., -1., 0., 0.1, -0.1]) + self._test_rolling_cov_unsupported_types(series) @skip_sdc_jit('Series.rolling.kurt() unsupported Series index') def test_series_rolling_kurt(self): diff --git a/sdc/tests/tests_perf/test_perf_df_rolling.py b/sdc/tests/tests_perf/test_perf_df_rolling.py index 9d2920df4..8f84b16e4 100644 --- a/sdc/tests/tests_perf/test_perf_df_rolling.py +++ b/sdc/tests/tests_perf/test_perf_df_rolling.py @@ -117,8 +117,7 @@ def _test_python(self, pyfunc, record, *args, **kwargs): def _gen_df(self, data, columns_num=10): """Generate DataFrame based on input data""" - return pandas.DataFrame( - {col: data for col in string.ascii_uppercase[:columns_num]}) + return pandas.DataFrame({col: data for col in string.ascii_uppercase[:columns_num]}) def _test_case(self, pyfunc, name, input_data=test_global_input_data_float64, @@ -181,6 +180,10 @@ def test_df_rolling_corr(self): def test_df_rolling_count(self): self._test_df_rolling_method('count') + def test_df_rolling_cov(self): + self._test_df_rolling_method('cov', extra_usecase_params='other', + method_params='other=other') + def test_df_rolling_kurt(self): self._test_df_rolling_method('kurt')