diff --git a/sdc/datatypes/hpat_pandas_series_functions.py b/sdc/datatypes/hpat_pandas_series_functions.py index 52901e627..de5a48d1c 100644 --- a/sdc/datatypes/hpat_pandas_series_functions.py +++ b/sdc/datatypes/hpat_pandas_series_functions.py @@ -156,61 +156,6 @@ def hpat_pandas_series_getitem_idx_series_impl(self, idx): raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx)) -@overload(operator.setitem) -def hpat_pandas_series_setitem(self, idx, value): - """ - Pandas Series operator :attr:`pandas.Series.get` implementation - ''' - Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_unsupported - ''' - Parameters - ---------- - series: :obj:`pandas.Series` - input series - idx: :obj:`int`, :obj:`slice` or :obj:`pandas.Series` - input index - value: :object - input value - Returns - ------- - :class:`pandas.Series` or an element of the underneath type - object of :class:`pandas.Series` - """ - - _func_name = 'Operator setitem().' - if not isinstance(self, SeriesType): - raise TypingError('{} The object must be a pandas.series. Given: {}'.format(_func_name, self)) - - if not isinstance(self.dtype, type(value)): - raise TypingError('{} Value must be one type with series. Given: {}, self.dtype={}'.format(_func_name, - value, self.dtype)) - - if isinstance(idx, types.Integer) or isinstance(idx, types.SliceType): - def hpat_pandas_series_setitem_idx_integer_impl(self, idx, value): - """ - Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_value - Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_slice - """ - - self._data[idx] = value - return self - - return hpat_pandas_series_setitem_idx_integer_impl - - if isinstance(idx, SeriesType): - def hpat_pandas_series_getitem_idx_series_impl(self, idx, value): - """ - Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_series - """ - super_index = idx._data - self._data[super_index] = value - return self - - return hpat_pandas_series_getitem_idx_series_impl - - raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx)) - - @overload_attribute(SeriesType, 'at') @overload_attribute(SeriesType, 'iat') @overload_attribute(SeriesType, 'iloc') diff --git a/sdc/hiframes/pd_series_ext.py b/sdc/hiframes/pd_series_ext.py index ce403b58e..6900efbd3 100644 --- a/sdc/hiframes/pd_series_ext.py +++ b/sdc/hiframes/pd_series_ext.py @@ -441,9 +441,9 @@ def resolve_iloc(self, ary): return SeriesIatType(ary) # PR135. This needs to be commented out -# def resolve_loc(self, ary): -# # TODO: support iat/iloc differences -# return SeriesIatType(ary) + def resolve_loc(self, ary): + # TODO: support iat/iloc differences + return SeriesIatType(ary) # @bound_function("array.astype") # def resolve_astype(self, ary, args, kws): @@ -900,14 +900,14 @@ def __init__(self, stype): # PR135. This needs to be commented out -# @infer_global(operator.getitem) -# class GetItemSeriesIat(AbstractTemplate): -# key = operator.getitem -# -# def generic(self, args, kws): -# # iat[] is the same as regular getitem -# if isinstance(args[0], SeriesIatType): -# return GetItemSeries.generic(self, (args[0].stype, args[1]), kws) +@infer_global(operator.getitem) +class GetItemSeriesIat(AbstractTemplate): + key = operator.getitem + + def generic(self, args, kws): + # iat[] is the same as regular getitem + if isinstance(args[0], SeriesIatType): + return GetItemSeries.generic(self, (args[0].stype, args[1]), kws) @infer @@ -1033,110 +1033,110 @@ def generic_expand_cumulative_series(self, args, kws): delattr(SeriesAttribute, attr) # PR135. This needs to be commented out -# @infer_global(operator.getitem) -# class GetItemSeries(AbstractTemplate): -# key = operator.getitem -# -# def generic(self, args, kws): -# assert not kws -# [in_arr, in_idx] = args -# is_arr_series = False -# is_idx_series = False -# is_arr_dt_index = False -# -# if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType): -# return None -# -# if isinstance(in_arr, SeriesType): -# in_arr = series_to_array_type(in_arr) -# is_arr_series = True -# if in_arr.dtype == types.NPDatetime('ns'): -# is_arr_dt_index = True -# -# if isinstance(in_idx, SeriesType): -# in_idx = series_to_array_type(in_idx) -# is_idx_series = True -# -# # TODO: dt_index -# if in_arr == string_array_type: -# # XXX fails due in overload -# # compile_internal version results in symbol not found! -# # sig = self.context.resolve_function_type( -# # operator.getitem, (in_arr, in_idx), kws) -# # HACK to get avoid issues for now -# if isinstance(in_idx, (types.Integer, types.IntegerLiteral)): -# sig = string_type(in_arr, in_idx) -# else: -# sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws) -# elif in_arr == list_string_array_type: -# # TODO: split view -# # mimic array indexing for list -# if (isinstance(in_idx, types.Array) and in_idx.ndim == 1 -# and isinstance( -# in_idx.dtype, (types.Integer, types.Boolean))): -# sig = signature(in_arr, in_arr, in_idx) -# else: -# sig = numba.typing.collections.GetItemSequence.generic( -# self, (in_arr, in_idx), kws) -# elif in_arr == string_array_split_view_type: -# sig = GetItemStringArraySplitView.generic( -# self, (in_arr, in_idx), kws) -# else: -# out = get_array_index_type(in_arr, in_idx) -# sig = signature(out.result, in_arr, out.index) -# -# if sig is not None: -# arg1 = sig.args[0] -# arg2 = sig.args[1] -# if is_arr_series: -# sig.return_type = if_arr_to_series_type(sig.return_type) -# arg1 = if_arr_to_series_type(arg1) -# if is_idx_series: -# arg2 = if_arr_to_series_type(arg2) -# sig.args = (arg1, arg2) -# # dt_index and Series(dt64) should return Timestamp -# if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'): -# sig.return_type = pandas_timestamp_type -# return sig - +@infer_global(operator.getitem) +class GetItemSeries(AbstractTemplate): + key = operator.getitem -# @infer_global(operator.setitem) -# class SetItemSeries(SetItemBuffer): -# def generic(self, args, kws): -# assert not kws -# series, idx, val = args -# if not isinstance(series, SeriesType): -# return None -# # TODO: handle any of args being Series independently -# ary = series_to_array_type(series) -# is_idx_series = False -# if isinstance(idx, SeriesType): -# idx = series_to_array_type(idx) -# is_idx_series = True -# is_val_series = False -# if isinstance(val, SeriesType): -# val = series_to_array_type(val) -# is_val_series = True -# # TODO: strings, dt_index -# res = super(SetItemSeries, self).generic((ary, idx, val), kws) -# if res is not None: -# new_series = if_arr_to_series_type(res.args[0]) -# idx = res.args[1] -# val = res.args[2] -# if is_idx_series: -# idx = if_arr_to_series_type(idx) -# if is_val_series: -# val = if_arr_to_series_type(val) -# res.args = (new_series, idx, val) -# return res -# -# -# @infer_global(operator.setitem) -# class SetItemSeriesIat(SetItemSeries): -# def generic(self, args, kws): -# # iat[] is the same as regular setitem -# if isinstance(args[0], SeriesIatType): -# return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws) + def generic(self, args, kws): + assert not kws + [in_arr, in_idx] = args + is_arr_series = False + is_idx_series = False + is_arr_dt_index = False + + if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType): + return None + + if isinstance(in_arr, SeriesType): + in_arr = series_to_array_type(in_arr) + is_arr_series = True + if in_arr.dtype == types.NPDatetime('ns'): + is_arr_dt_index = True + + if isinstance(in_idx, SeriesType): + in_idx = series_to_array_type(in_idx) + is_idx_series = True + + # TODO: dt_index + if in_arr == string_array_type: + # XXX fails due in overload + # compile_internal version results in symbol not found! + # sig = self.context.resolve_function_type( + # operator.getitem, (in_arr, in_idx), kws) + # HACK to get avoid issues for now + if isinstance(in_idx, (types.Integer, types.IntegerLiteral)): + sig = string_type(in_arr, in_idx) + else: + sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws) + elif in_arr == list_string_array_type: + # TODO: split view + # mimic array indexing for list + if (isinstance(in_idx, types.Array) and in_idx.ndim == 1 + and isinstance( + in_idx.dtype, (types.Integer, types.Boolean))): + sig = signature(in_arr, in_arr, in_idx) + else: + sig = numba.typing.collections.GetItemSequence.generic( + self, (in_arr, in_idx), kws) + elif in_arr == string_array_split_view_type: + sig = GetItemStringArraySplitView.generic( + self, (in_arr, in_idx), kws) + else: + out = get_array_index_type(in_arr, in_idx) + sig = signature(out.result, in_arr, out.index) + + if sig is not None: + arg1 = sig.args[0] + arg2 = sig.args[1] + if is_arr_series: + sig.return_type = if_arr_to_series_type(sig.return_type) + arg1 = if_arr_to_series_type(arg1) + if is_idx_series: + arg2 = if_arr_to_series_type(arg2) + sig.args = (arg1, arg2) + # dt_index and Series(dt64) should return Timestamp + if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'): + sig.return_type = pandas_timestamp_type + return sig + + +@infer_global(operator.setitem) +class SetItemSeries(SetItemBuffer): + def generic(self, args, kws): + assert not kws + series, idx, val = args + if not isinstance(series, SeriesType): + return None + # TODO: handle any of args being Series independently + ary = series_to_array_type(series) + is_idx_series = False + if isinstance(idx, SeriesType): + idx = series_to_array_type(idx) + is_idx_series = True + is_val_series = False + if isinstance(val, SeriesType): + val = series_to_array_type(val) + is_val_series = True + # TODO: strings, dt_index + res = super(SetItemSeries, self).generic((ary, idx, val), kws) + if res is not None: + new_series = if_arr_to_series_type(res.args[0]) + idx = res.args[1] + val = res.args[2] + if is_idx_series: + idx = if_arr_to_series_type(idx) + if is_val_series: + val = if_arr_to_series_type(val) + res.args = (new_series, idx, val) + return res + + +@infer_global(operator.setitem) +class SetItemSeriesIat(SetItemSeries): + def generic(self, args, kws): + # iat[] is the same as regular setitem + if isinstance(args[0], SeriesIatType): + return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws) inplace_ops = [ diff --git a/sdc/tests/test_series.py b/sdc/tests/test_series.py index bfa040c41..05d23fe1c 100644 --- a/sdc/tests/test_series.py +++ b/sdc/tests/test_series.py @@ -4388,65 +4388,6 @@ def test_series_pct_change_impl(S, periods=1, fill_method='pad', limit=None, fre msg = 'Method pct_change(). The object periods' self.assertIn(msg, str(raises.exception)) - def test_series_setitem_for_value(self): - def test_impl(S, val): - S[3] = val - return S - - hpat_func = self.jit(test_impl) - S = pd.Series([0, 1, 2, 3, 4]) - value = 50 - result_ref = test_impl(S, value) - result = hpat_func(S, value) - pd.testing.assert_series_equal(result_ref, result) - - def test_series_setitem_for_slice(self): - def test_impl(S, val): - S[2:] = val - return S - - hpat_func = self.jit(test_impl) - S = pd.Series([0, 1, 2, 3, 4]) - value = 50 - result_ref = test_impl(S, value) - result = hpat_func(S, value) - pd.testing.assert_series_equal(result_ref, result) - - def test_series_setitem_for_series(self): - def test_impl(S, ind, val): - S[ind] = val - return S - - hpat_func = self.jit(test_impl) - S = pd.Series([0, 1, 2, 3, 4]) - ind = pd.Series([0, 2, 4]) - value = 50 - result_ref = test_impl(S, ind, value) - result = hpat_func(S, ind, value) - pd.testing.assert_series_equal(result_ref, result) - - def test_series_setitem_unsupported(self): - def test_impl(S, ind, val): - S[ind] = val - return S - - hpat_func = self.jit(test_impl) - S = pd.Series([0, 1, 2, 3, 4, 5]) - ind1 = 5 - ind2 = '3' - value1 = 'ababa' - value2 = 101 - - with self.assertRaises(TypingError) as raises: - hpat_func(S, ind1, value1) - msg = 'Operator setitem(). Value must be one type with series.' - self.assertIn(msg, str(raises.exception)) - - with self.assertRaises(TypingError) as raises: - hpat_func(S, ind2, value2) - msg = 'Operator setitem(). The index must be an Integer, Slice or a pandas.series.' - self.assertIn(msg, str(raises.exception)) - if __name__ == "__main__": unittest.main()