diff --git a/sdc/datatypes/hpat_pandas_series_functions.py b/sdc/datatypes/hpat_pandas_series_functions.py index 3f36f6fdb..f52931332 100644 --- a/sdc/datatypes/hpat_pandas_series_functions.py +++ b/sdc/datatypes/hpat_pandas_series_functions.py @@ -156,6 +156,61 @@ def hpat_pandas_series_getitem_idx_series_impl(self, idx): raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx)) +@overload(operator.setitem) +def hpat_pandas_series_setitem(self, idx, value): + """ + Pandas Series operator :attr:`pandas.Series.get` implementation + ''' + Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_unsupported + ''' + Parameters + ---------- + series: :obj:`pandas.Series` + input series + idx: :obj:`int`, :obj:`slice` or :obj:`pandas.Series` + input index + value: :object + input value + Returns + ------- + :class:`pandas.Series` or an element of the underneath type + object of :class:`pandas.Series` + """ + + _func_name = 'Operator setitem().' + if not isinstance(self, SeriesType): + raise TypingError('{} The object must be a pandas.series. Given: {}'.format(_func_name, self)) + + if not isinstance(self.dtype, type(value)): + raise TypingError('{} Value must be one type with series. Given: {}, self.dtype={}'.format(_func_name, + value, self.dtype)) + + if isinstance(idx, types.Integer) or isinstance(idx, types.SliceType): + def hpat_pandas_series_setitem_idx_integer_impl(self, idx, value): + """ + Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_value + Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_slice + """ + + self._data[idx] = value + return self + + return hpat_pandas_series_setitem_idx_integer_impl + + if isinstance(idx, SeriesType): + def hpat_pandas_series_getitem_idx_series_impl(self, idx, value): + """ + Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_series + """ + super_index = idx._data + self._data[super_index] = value + return self + + return hpat_pandas_series_getitem_idx_series_impl + + raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx)) + + @overload_attribute(SeriesType, 'at') @overload_attribute(SeriesType, 'iat') @overload_attribute(SeriesType, 'iloc') diff --git a/sdc/hiframes/pd_series_ext.py b/sdc/hiframes/pd_series_ext.py index dede5d4b6..d90b19897 100644 --- a/sdc/hiframes/pd_series_ext.py +++ b/sdc/hiframes/pd_series_ext.py @@ -441,9 +441,9 @@ def resolve_iloc(self, ary): return SeriesIatType(ary) # PR135. This needs to be commented out - def resolve_loc(self, ary): - # TODO: support iat/iloc differences - return SeriesIatType(ary) +# def resolve_loc(self, ary): +# # TODO: support iat/iloc differences +# return SeriesIatType(ary) # @bound_function("array.astype") # def resolve_astype(self, ary, args, kws): @@ -900,14 +900,14 @@ def __init__(self, stype): # PR135. This needs to be commented out -@infer_global(operator.getitem) -class GetItemSeriesIat(AbstractTemplate): - key = operator.getitem - - def generic(self, args, kws): - # iat[] is the same as regular getitem - if isinstance(args[0], SeriesIatType): - return GetItemSeries.generic(self, (args[0].stype, args[1]), kws) +# @infer_global(operator.getitem) +# class GetItemSeriesIat(AbstractTemplate): +# key = operator.getitem +# +# def generic(self, args, kws): +# # iat[] is the same as regular getitem +# if isinstance(args[0], SeriesIatType): +# return GetItemSeries.generic(self, (args[0].stype, args[1]), kws) @infer @@ -1033,110 +1033,110 @@ def generic_expand_cumulative_series(self, args, kws): delattr(SeriesAttribute, attr) # PR135. This needs to be commented out -@infer_global(operator.getitem) -class GetItemSeries(AbstractTemplate): - key = operator.getitem +# @infer_global(operator.getitem) +# class GetItemSeries(AbstractTemplate): +# key = operator.getitem +# +# def generic(self, args, kws): +# assert not kws +# [in_arr, in_idx] = args +# is_arr_series = False +# is_idx_series = False +# is_arr_dt_index = False +# +# if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType): +# return None +# +# if isinstance(in_arr, SeriesType): +# in_arr = series_to_array_type(in_arr) +# is_arr_series = True +# if in_arr.dtype == types.NPDatetime('ns'): +# is_arr_dt_index = True +# +# if isinstance(in_idx, SeriesType): +# in_idx = series_to_array_type(in_idx) +# is_idx_series = True +# +# # TODO: dt_index +# if in_arr == string_array_type: +# # XXX fails due in overload +# # compile_internal version results in symbol not found! +# # sig = self.context.resolve_function_type( +# # operator.getitem, (in_arr, in_idx), kws) +# # HACK to get avoid issues for now +# if isinstance(in_idx, (types.Integer, types.IntegerLiteral)): +# sig = string_type(in_arr, in_idx) +# else: +# sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws) +# elif in_arr == list_string_array_type: +# # TODO: split view +# # mimic array indexing for list +# if (isinstance(in_idx, types.Array) and in_idx.ndim == 1 +# and isinstance( +# in_idx.dtype, (types.Integer, types.Boolean))): +# sig = signature(in_arr, in_arr, in_idx) +# else: +# sig = numba.typing.collections.GetItemSequence.generic( +# self, (in_arr, in_idx), kws) +# elif in_arr == string_array_split_view_type: +# sig = GetItemStringArraySplitView.generic( +# self, (in_arr, in_idx), kws) +# else: +# out = get_array_index_type(in_arr, in_idx) +# sig = signature(out.result, in_arr, out.index) +# +# if sig is not None: +# arg1 = sig.args[0] +# arg2 = sig.args[1] +# if is_arr_series: +# sig.return_type = if_arr_to_series_type(sig.return_type) +# arg1 = if_arr_to_series_type(arg1) +# if is_idx_series: +# arg2 = if_arr_to_series_type(arg2) +# sig.args = (arg1, arg2) +# # dt_index and Series(dt64) should return Timestamp +# if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'): +# sig.return_type = pandas_timestamp_type +# return sig - def generic(self, args, kws): - assert not kws - [in_arr, in_idx] = args - is_arr_series = False - is_idx_series = False - is_arr_dt_index = False - - if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType): - return None - - if isinstance(in_arr, SeriesType): - in_arr = series_to_array_type(in_arr) - is_arr_series = True - if in_arr.dtype == types.NPDatetime('ns'): - is_arr_dt_index = True - - if isinstance(in_idx, SeriesType): - in_idx = series_to_array_type(in_idx) - is_idx_series = True - - # TODO: dt_index - if in_arr == string_array_type: - # XXX fails due in overload - # compile_internal version results in symbol not found! - # sig = self.context.resolve_function_type( - # operator.getitem, (in_arr, in_idx), kws) - # HACK to get avoid issues for now - if isinstance(in_idx, (types.Integer, types.IntegerLiteral)): - sig = string_type(in_arr, in_idx) - else: - sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws) - elif in_arr == list_string_array_type: - # TODO: split view - # mimic array indexing for list - if (isinstance(in_idx, types.Array) and in_idx.ndim == 1 - and isinstance( - in_idx.dtype, (types.Integer, types.Boolean))): - sig = signature(in_arr, in_arr, in_idx) - else: - sig = numba.typing.collections.GetItemSequence.generic( - self, (in_arr, in_idx), kws) - elif in_arr == string_array_split_view_type: - sig = GetItemStringArraySplitView.generic( - self, (in_arr, in_idx), kws) - else: - out = get_array_index_type(in_arr, in_idx) - sig = signature(out.result, in_arr, out.index) - - if sig is not None: - arg1 = sig.args[0] - arg2 = sig.args[1] - if is_arr_series: - sig.return_type = if_arr_to_series_type(sig.return_type) - arg1 = if_arr_to_series_type(arg1) - if is_idx_series: - arg2 = if_arr_to_series_type(arg2) - sig.args = (arg1, arg2) - # dt_index and Series(dt64) should return Timestamp - if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'): - sig.return_type = pandas_timestamp_type - return sig - - -@infer_global(operator.setitem) -class SetItemSeries(SetItemBuffer): - def generic(self, args, kws): - assert not kws - series, idx, val = args - if not isinstance(series, SeriesType): - return None - # TODO: handle any of args being Series independently - ary = series_to_array_type(series) - is_idx_series = False - if isinstance(idx, SeriesType): - idx = series_to_array_type(idx) - is_idx_series = True - is_val_series = False - if isinstance(val, SeriesType): - val = series_to_array_type(val) - is_val_series = True - # TODO: strings, dt_index - res = super(SetItemSeries, self).generic((ary, idx, val), kws) - if res is not None: - new_series = if_arr_to_series_type(res.args[0]) - idx = res.args[1] - val = res.args[2] - if is_idx_series: - idx = if_arr_to_series_type(idx) - if is_val_series: - val = if_arr_to_series_type(val) - res.args = (new_series, idx, val) - return res - - -@infer_global(operator.setitem) -class SetItemSeriesIat(SetItemSeries): - def generic(self, args, kws): - # iat[] is the same as regular setitem - if isinstance(args[0], SeriesIatType): - return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws) + +# @infer_global(operator.setitem) +# class SetItemSeries(SetItemBuffer): +# def generic(self, args, kws): +# assert not kws +# series, idx, val = args +# if not isinstance(series, SeriesType): +# return None +# # TODO: handle any of args being Series independently +# ary = series_to_array_type(series) +# is_idx_series = False +# if isinstance(idx, SeriesType): +# idx = series_to_array_type(idx) +# is_idx_series = True +# is_val_series = False +# if isinstance(val, SeriesType): +# val = series_to_array_type(val) +# is_val_series = True +# # TODO: strings, dt_index +# res = super(SetItemSeries, self).generic((ary, idx, val), kws) +# if res is not None: +# new_series = if_arr_to_series_type(res.args[0]) +# idx = res.args[1] +# val = res.args[2] +# if is_idx_series: +# idx = if_arr_to_series_type(idx) +# if is_val_series: +# val = if_arr_to_series_type(val) +# res.args = (new_series, idx, val) +# return res +# +# +# @infer_global(operator.setitem) +# class SetItemSeriesIat(SetItemSeries): +# def generic(self, args, kws): +# # iat[] is the same as regular setitem +# if isinstance(args[0], SeriesIatType): +# return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws) inplace_ops = [ diff --git a/sdc/tests/test_series.py b/sdc/tests/test_series.py index 9447d1056..066965802 100644 --- a/sdc/tests/test_series.py +++ b/sdc/tests/test_series.py @@ -4305,6 +4305,65 @@ def test_series_pct_change_impl(S, periods=1, fill_method='pad', limit=None, fre msg = 'Method pct_change(). The object periods' self.assertIn(msg, str(raises.exception)) + def test_series_setitem_for_value(self): + def test_impl(S, val): + S[3] = val + return S + + hpat_func = self.jit(test_impl) + S = pd.Series([0, 1, 2, 3, 4]) + value = 50 + result_ref = test_impl(S, value) + result = hpat_func(S, value) + pd.testing.assert_series_equal(result_ref, result) + + def test_series_setitem_for_slice(self): + def test_impl(S, val): + S[2:] = val + return S + + hpat_func = self.jit(test_impl) + S = pd.Series([0, 1, 2, 3, 4]) + value = 50 + result_ref = test_impl(S, value) + result = hpat_func(S, value) + pd.testing.assert_series_equal(result_ref, result) + + def test_series_setitem_for_series(self): + def test_impl(S, ind, val): + S[ind] = val + return S + + hpat_func = self.jit(test_impl) + S = pd.Series([0, 1, 2, 3, 4]) + ind = pd.Series([0, 2, 4]) + value = 50 + result_ref = test_impl(S, ind, value) + result = hpat_func(S, ind, value) + pd.testing.assert_series_equal(result_ref, result) + + def test_series_setitem_unsupported(self): + def test_impl(S, ind, val): + S[ind] = val + return S + + hpat_func = self.jit(test_impl) + S = pd.Series([0, 1, 2, 3, 4, 5]) + ind1 = 5 + ind2 = '3' + value1 = 'ababa' + value2 = 101 + + with self.assertRaises(TypingError) as raises: + hpat_func(S, ind1, value1) + msg = 'Operator setitem(). Value must be one type with series.' + self.assertIn(msg, str(raises.exception)) + + with self.assertRaises(TypingError) as raises: + hpat_func(S, ind2, value2) + msg = 'Operator setitem(). The index must be an Integer, Slice or a pandas.series.' + self.assertIn(msg, str(raises.exception)) + if __name__ == "__main__": unittest.main()