From 2dcf91e1d712ea4cdaf680cf2785a25155b8d4aa Mon Sep 17 00:00:00 2001 From: akharche Date: Fri, 16 Aug 2019 15:49:47 +0300 Subject: [PATCH] Implemented iterator for pandas.Series --- hpat/hiframes/pd_series_ext.py | 143 ++++++++++++++++++++++++++++++--- hpat/str_arr_ext.py | 40 +++++---- hpat/tests/test_series.py | 24 ++++++ 3 files changed, 181 insertions(+), 26 deletions(-) diff --git a/hpat/hiframes/pd_series_ext.py b/hpat/hiframes/pd_series_ext.py index d1705a27a..ad842265f 100644 --- a/hpat/hiframes/pd_series_ext.py +++ b/hpat/hiframes/pd_series_ext.py @@ -2,9 +2,11 @@ import pandas as pd import numpy as np import numba -from numba import types +import hpat +import llvmlite.llvmpy.core as lc +from numba import types, cgutils from numba.extending import (models, register_model, lower_cast, infer_getattr, - type_callable, infer, overload, make_attribute_wrapper) + type_callable, infer, overload, make_attribute_wrapper, lower_builtin) from numba.typing.templates import (infer_global, AbstractTemplate, signature, AttributeTemplate, bound_function) from numba.typing.arraydecl import (get_array_index_type, _expand_integer, @@ -12,17 +14,17 @@ from numba.typing.npydecl import (Numpy_rules_ufunc, NumpyRulesArrayOperator, NumpyRulesInplaceArrayOperator, NumpyRulesUnaryArrayOperator, NdConstructorLike) -import hpat +from numba.targets.imputils import (impl_ret_new_ref, iternext_impl, RefType) +from numba.targets.arrayobj import (make_array, _getitem_array1d) from hpat.str_ext import string_type, list_string_array_type -from hpat.str_arr_ext import (string_array_type, offset_typ, char_typ, +from hpat.str_arr_ext import (string_array_type, iternext_str_array, offset_typ, char_typ, str_arr_payload_type, StringArrayType, GetItemStringArray) from hpat.hiframes.pd_timestamp_ext import pandas_timestamp_type, datetime_date_type from hpat.hiframes.pd_categorical_ext import (PDCategoricalDtype, CategoricalArray) from hpat.hiframes.rolling import supported_rolling_funcs -import datetime from hpat.hiframes.split_impl import (string_array_split_view_type, - GetItemStringArraySplitView) + GetItemStringArraySplitView) class SeriesType(types.IterableType): @@ -35,9 +37,8 @@ def __init__(self, dtype, data=None, index=None, is_named=False): data = _get_series_array_type(dtype) if data is None else data # convert Record to tuple (for tuple output of map) # TODO: handle actual Record objects in Series? - dtype = (types.Tuple(list(dict(dtype.members).values())) - if isinstance(dtype, types.Record) else dtype) - self.dtype = dtype + self.dtype = (types.Tuple(list(dict(dtype.members).values())) + if isinstance(dtype, types.Record) else dtype) self.data = data if index is None: index = types.none @@ -96,9 +97,38 @@ def is_precise(self): @property def iterator_type(self): - # same as Buffer # TODO: fix timestamp - return types.iterators.ArrayIterator(self.data) + return SeriesIterator(self) + + +class SeriesIterator(types.SimpleIteratorType): + """ + Type class for iterator over dataframe series. + """ + + def __init__(self, series_type): + self.series_type = series_type + self.array_type = series_type.data + + name = f'iter({self.series_type.data})' + yield_type = series_type.dtype + super(SeriesIterator, self).__init__(name, yield_type) + + @property + def _iternext(self): + if isinstance(self.array_type, StringArrayType): + return iternext_str_array + elif isinstance(self.array_type, types.Array): + return iternext_series_array + + +@register_model(SeriesIterator) +class SeriesIteratorModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [('index', types.EphemeralPointer(types.uintp)), + ('array', fe_type.series_type.data)] + + models.StructModel.__init__(self, dmm, fe_type, members) def _get_series_array_type(dtype): @@ -165,6 +195,97 @@ def __init__(self, dmm, fe_type): make_attribute_wrapper(SeriesType, 'name', '_name') +@lower_builtin('getiter', SeriesType) +def getiter_series(context, builder, sig, args): + """ + Getting iterator for the Series type + + :param context: context descriptor + :param builder: llvmlite IR Builder + :param sig: iterator signature + :param args: tuple with iterator arguments, such as instruction, operands and types + :param result: iternext result + :return: reference to iterator + """ + + arraytype = sig.args[0].data + + # Create instruction to get array to iterate + zero_member_pointer = context.get_constant(types.intp, 0) + zero_member = context.get_constant(types.int32, 0) + alloca = args[0].operands[0] + gep_result = builder.gep(alloca, [zero_member_pointer, zero_member]) + array = builder.load(gep_result) + + # TODO: call numba getiter with gep_result for array + iterobj = context.make_helper(builder, sig.return_type) + zero_index = context.get_constant(types.intp, 0) + indexptr = cgutils.alloca_once_value(builder, zero_index) + + iterobj.index = indexptr + iterobj.array = array + + if context.enable_nrt: + context.nrt.incref(builder, arraytype, array) + + result = iterobj._getvalue() + # Note: a decref on the iterator will dereference all internal MemInfo* + out = impl_ret_new_ref(context, builder, sig.return_type, result) + return out + + +# TODO: call it from numba.targets.arrayobj, need separate function in numba +def iternext_series_array(context, builder, sig, args, result): + """ + Implementation of iternext() for the ArrayIterator type + + :param context: context descriptor + :param builder: llvmlite IR Builder + :param sig: iterator signature + :param args: tuple with iterator arguments, such as instruction, operands and types + :param result: iternext result + """ + + [iterty] = sig.args + [iter] = args + arrayty = iterty.array_type + + if arrayty.ndim != 1: + raise NotImplementedError("iterating over %dD array" % arrayty.ndim) + + iterobj = context.make_helper(builder, iterty, value=iter) + ary = make_array(arrayty)(context, builder, value=iterobj.array) + + nitems, = cgutils.unpack_tuple(builder, ary.shape, count=1) + + index = builder.load(iterobj.index) + is_valid = builder.icmp(lc.ICMP_SLT, index, nitems) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + value = _getitem_array1d(context, builder, arrayty, ary, index, + wraparound=False) + result.yield_(value) + nindex = cgutils.increment_index(builder, index) + builder.store(nindex, iterobj.index) + + +@lower_builtin('iternext', SeriesIterator) +@iternext_impl(RefType.BORROWED) +def iternext_series(context, builder, sig, args, result): + """ + Iternext implementation depending on Array type + + :param context: context descriptor + :param builder: llvmlite IR Builder + :param sig: iterator signature + :param args: tuple with iterator arguments, such as instruction, operands and types + :param result: iternext result + """ + iternext_func = sig.args[0]._iternext + iternext_func(context=context, builder=builder, sig=sig, args=args, result=result) + + def series_to_array_type(typ, replace_boxed=False): return typ.data # return _get_series_array_type(typ.dtype) diff --git a/hpat/str_arr_ext.py b/hpat/str_arr_ext.py index 609cb1379..b35bffdac 100644 --- a/hpat/str_arr_ext.py +++ b/hpat/str_arr_ext.py @@ -137,25 +137,21 @@ def __init__(self): super(StringArrayIterator, self).__init__(name, yield_type) -@register_model(StringArrayIterator) -class StrArrayIteratorModel(models.StructModel): - def __init__(self, dmm, fe_type): - # We use an unsigned index to avoid the cost of negative index tests. - members = [('index', types.EphemeralPointer(types.uintp)), - ('array', string_array_type)] - super(StrArrayIteratorModel, self).__init__(dmm, fe_type, members) - - -lower_builtin('getiter', string_array_type)(numba.targets.arrayobj.getiter_array) +def iternext_str_array(context, builder, sig, args, result): + """ + Implementation of iternext() for the StringArrayIterator type + :param context: context descriptor + :param builder: llvmlite IR Builder + :param sig: iterator signature + :param args: tuple with iterator arguments, such as instruction, operands and types + :param result: iternext result + """ -@lower_builtin('iternext', StringArrayIterator) -@iternext_impl(RefType.NEW) -def iternext_str_array(context, builder, sig, args, result): - [iterty] = sig.args + [itertype] = sig.args [iter_arg] = args - iterobj = context.make_helper(builder, iterty, value=iter_arg) + iterobj = context.make_helper(builder, itertype, value=iter_arg) len_sig = signature(types.intp, string_array_type) nitems = context.compile_internal(builder, lambda a: len(a), len_sig, [iterobj.array]) @@ -171,6 +167,20 @@ def iternext_str_array(context, builder, sig, args, result): builder.store(nindex, iterobj.index) +@register_model(StringArrayIterator) +class StrArrayIteratorModel(models.StructModel): + def __init__(self, dmm, fe_type): + # We use an unsigned index to avoid the cost of negative index tests. + members = [('index', types.EphemeralPointer(types.uintp)), + ('array', string_array_type)] + super(StrArrayIteratorModel, self).__init__(dmm, fe_type, members) + + +lower_builtin('getiter', string_array_type)(numba.targets.arrayobj.getiter_array) + +lower_builtin('iternext', StringArrayIterator)(iternext_impl(RefType.NEW)(iternext_str_array)) + + @intrinsic def num_total_chars(typingctx, str_arr_typ=None): # None default to make IntelliSense happy diff --git a/hpat/tests/test_series.py b/hpat/tests/test_series.py index 1ed6fec58..56ffc55b7 100644 --- a/hpat/tests/test_series.py +++ b/hpat/tests/test_series.py @@ -1485,6 +1485,30 @@ def test_impl(): hpat_func = hpat.jit(test_impl) np.testing.assert_array_equal(hpat_func(), test_impl()) + def test_series_iterator_int(self): + def test_impl(): + A = pd.Series([1, 2, 3, 4, 5]) + return [i for i in A] + + hpat_func = hpat.jit(test_impl) + np.testing.assert_array_equal(hpat_func(), test_impl()) + + def test_series_iterator_string(self): + def test_impl(): + A = pd.Series(['a', 'ab', 'abc', '', 'dddd']) + return [i for i in A] + + hpat_func = hpat.jit(test_impl) + np.testing.assert_array_equal(hpat_func(), test_impl()) + + def test_series_iterator_one_value(self): + def test_impl(): + A = pd.Series([5]) + return [i for i in A] + + hpat_func = hpat.jit(test_impl) + np.testing.assert_array_equal(hpat_func(), test_impl()) + if __name__ == "__main__": unittest.main()