diff --git a/sdc/datatypes/hpat_pandas_stringmethods_types.py b/sdc/datatypes/hpat_pandas_stringmethods_types.py index 2457c71dd..bb747ea4e 100644 --- a/sdc/datatypes/hpat_pandas_stringmethods_types.py +++ b/sdc/datatypes/hpat_pandas_stringmethods_types.py @@ -36,6 +36,7 @@ from numba.extending import (models, overload, register_model, make_attribute_wrapper, intrinsic) from numba.datamodel import (register_default, StructModel) from numba.typing.templates import signature +from sdc.hiframes.split_impl import SplitViewStringMethodsType, StringArraySplitViewType class StringMethodsType(types.IterableType): @@ -50,7 +51,8 @@ class StringMethodsType(types.IterableType): def __init__(self, data): self.data = data - super(StringMethodsType, self).__init__('StringMethodsType') + name = 'StringMethodsType({})'.format(self.data) + super(StringMethodsType, self).__init__(name) @property def iterator_type(self): @@ -74,37 +76,47 @@ def __init__(self, dmm, fe_type): make_attribute_wrapper(StringMethodsType, 'data', '_data') -@intrinsic -def _hpat_pandas_stringmethods_init(typingctx, data): - """ - Internal Numba required function to register StringMethodsType and - connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods) - """ +def _gen_hpat_pandas_stringmethods_init(string_methods_type=None): + string_methods_type = string_methods_type or StringMethodsType - def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args): + def _hpat_pandas_stringmethods_init(typingctx, data): """ - It is looks like it creates StringMethodsModel structure - - - Fixed number of parameters. Must be 4 - - increase reference count for the data + Internal Numba required function to register StringMethodsType and + connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods) """ - [data_val] = args - stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder) - stringmethod.data = data_val + def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args): + """ + It is looks like it creates StringMethodsModel structure - if context.enable_nrt: - context.nrt.incref(builder, data, stringmethod.data) + - Fixed number of parameters. Must be 4 + - increase reference count for the data + """ - return stringmethod._getvalue() + [data_val] = args + stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder) + stringmethod.data = data_val - ret_typ = StringMethodsType(data) - sig = signature(ret_typ, data) - """ - Construct signature of the Numba SeriesGroupByType::ctor() - """ + if context.enable_nrt: + context.nrt.incref(builder, data, stringmethod.data) - return sig, _hpat_pandas_stringmethods_init_codegen + return stringmethod._getvalue() + + ret_typ = string_methods_type(data) + sig = signature(ret_typ, data) + """ + Construct signature of the Numba SeriesGroupByType::ctor() + """ + + return sig, _hpat_pandas_stringmethods_init_codegen + + return _hpat_pandas_stringmethods_init + + +_hpat_pandas_stringmethods_init = intrinsic( + _gen_hpat_pandas_stringmethods_init(string_methods_type=StringMethodsType)) +_hpat_pandas_split_view_stringmethods_init = intrinsic( + _gen_hpat_pandas_stringmethods_init(string_methods_type=SplitViewStringMethodsType)) @overload(pandas.core.strings.StringMethods) @@ -113,6 +125,11 @@ def hpat_pandas_stringmethods(obj): Special Numba procedure to overload Python type pandas.core.strings.StringMethods::ctor() with Numba registered model """ + if isinstance(obj.data, StringArraySplitViewType): + def hpat_pandas_split_view_stringmethods_impl(obj): + return _hpat_pandas_split_view_stringmethods_init(obj) + + return hpat_pandas_split_view_stringmethods_impl def hpat_pandas_stringmethods_impl(obj): return _hpat_pandas_stringmethods_init(obj) diff --git a/sdc/hiframes/hiframes_typed.py b/sdc/hiframes/hiframes_typed.py index 9dcb82387..e27e81a29 100644 --- a/sdc/hiframes/hiframes_typed.py +++ b/sdc/hiframes/hiframes_typed.py @@ -65,9 +65,10 @@ from sdc.hiframes.rolling import get_rolling_setup_args from sdc.hiframes.aggregate import Aggregate from sdc.hiframes.series_kernels import series_replace_funcs -from sdc.hiframes.split_impl import (string_array_split_view_type, - StringArraySplitViewType, getitem_c_arr, get_array_ctypes_ptr, - get_split_view_index, get_split_view_data_ptr) +from sdc.hiframes.split_impl import (SplitViewStringMethodsType, + string_array_split_view_type, StringArraySplitViewType, + getitem_c_arr, get_array_ctypes_ptr, + get_split_view_index, get_split_view_data_ptr) _dt_index_binops = ('==', '!=', '>=', '>', '<=', '<', '-', @@ -480,7 +481,8 @@ def _run_call(self, assign, lhs, rhs): else: func_name, func_mod = fdef - if (isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], StringMethodsType)): + string_methods_types = (SplitViewStringMethodsType, StringMethodsType) + if isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], string_methods_types): f_def = guard(get_definition, self.state.func_ir, rhs.func) str_def = guard(get_definition, self.state.func_ir, f_def.value) if str_def is None: # TODO: check for errors diff --git a/sdc/hiframes/pd_series_ext.py b/sdc/hiframes/pd_series_ext.py index 6900efbd3..4b510bb92 100644 --- a/sdc/hiframes/pd_series_ext.py +++ b/sdc/hiframes/pd_series_ext.py @@ -57,7 +57,9 @@ from sdc.hiframes.pd_categorical_ext import (PDCategoricalDtype, CategoricalArray) from sdc.hiframes.pd_timestamp_ext import (pandas_timestamp_type, datetime_date_type) from sdc.hiframes.rolling import supported_rolling_funcs -from sdc.hiframes.split_impl import (string_array_split_view_type, GetItemStringArraySplitView) +from sdc.hiframes.split_impl import (SplitViewStringMethodsType, + string_array_split_view_type, + GetItemStringArraySplitView) from sdc.str_arr_ext import ( string_array_type, iternext_str_array, @@ -423,9 +425,9 @@ def resolve_T(self, ary): # def resolve_index(self, ary): # return ary.index - def resolve_str(self, ary): - assert ary.dtype in (string_type, types.List(string_type)) - return StringMethodsType(ary) + # def resolve_str(self, ary): + # assert ary.dtype in (string_type, types.List(string_type)) + # return StringMethodsType(ary) def resolve_dt(self, ary): assert ary.dtype == types.NPDatetime('ns') @@ -780,9 +782,9 @@ class SeriesStrMethodAttribute(AttributeTemplate): def resolve_contains(self, ary, args, kws): return signature(SeriesType(types.bool_), *args) - @bound_function("strmethod.len") - def resolve_len(self, ary, args, kws): - return signature(SeriesType(types.int64), *args) + # @bound_function("strmethod.len") + # def resolve_len(self, ary, args, kws): + # return signature(SeriesType(types.int64), *args) @bound_function("strmethod.replace") def resolve_replace(self, ary, args, kws): @@ -820,6 +822,16 @@ def generic(self, args, kws): raise NotImplementedError('Series.str.{} is not supported yet'.format(func_name)) +@infer_getattr +class SplitViewSeriesStrMethodAttribute(AttributeTemplate): + key = SplitViewStringMethodsType + + @bound_function('strmethod.get') + def resolve_get(self, ary, args, kws): + # XXX only list(list(str)) supported + return signature(SeriesType(string_type), *args) + + class SeriesDtMethodType(types.Type): def __init__(self): name = "SeriesDtMethodType" diff --git a/sdc/hiframes/split_impl.py b/sdc/hiframes/split_impl.py index 8a589a6eb..fd75a2223 100644 --- a/sdc/hiframes/split_impl.py +++ b/sdc/hiframes/split_impl.py @@ -26,12 +26,16 @@ import operator +import numpy +import pandas import numba import sdc from numba import types from numba.typing.templates import (infer_global, AbstractTemplate, infer, signature, AttributeTemplate, infer_getattr, bound_function) import numba.typing.typeof +from numba.datamodel import StructModel +from numba.errors import TypingError from numba.extending import (typeof_impl, type_callable, models, register_model, NativeValue, make_attribute_wrapper, lower_builtin, box, unbox, lower_getattr, intrinsic, overload_method, overload, overload_attribute) @@ -131,6 +135,43 @@ def __init__(self, dmm, fe_type): make_attribute_wrapper(StringArraySplitViewType, 'data', '_data') +class SplitViewStringMethodsType(types.IterableType): + """ + Type definition for pandas.core.strings.StringMethods functions handling. + + Members + ---------- + _data: :class:`SeriesType` + input arg + """ + + def __init__(self, data): + self.data = data + name = 'SplitViewStringMethodsType({})'.format(self.data) + super(SplitViewStringMethodsType, self).__init__(name) + + @property + def iterator_type(self): + return None + + +@register_model(SplitViewStringMethodsType) +class SplitViewStringMethodsTypeModel(StructModel): + """ + Model for SplitViewStringMethodsType type + All members must be the same as main type for this model + """ + + def __init__(self, dmm, fe_type): + members = [ + ('data', fe_type.data) + ] + models.StructModel.__init__(self, dmm, fe_type, members) + + +make_attribute_wrapper(SplitViewStringMethodsType, 'data', '_data') + + def construct_str_arr_split_view(context, builder): """Creates meminfo and sets dtor. """ @@ -404,6 +445,44 @@ def str_arr_split_view_len_overload(arr): return lambda arr: arr._num_items +@overload_method(SplitViewStringMethodsType, 'len') +def hpat_pandas_spliview_stringmethods_len(self): + """ + Pandas Series method :meth:`pandas.core.strings.StringMethods.len()` implementation. + + Note: Unicode type of list elements are supported only. Numpy.NaN is not supported as elements. + + .. only:: developer + + Test: python -m sdc.runtests sdc.tests.test_hiframes.TestHiFrames.test_str_split_filter + + Parameters + ---------- + self: :class:`pandas.core.strings.StringMethods` + input arg + + Returns + ------- + :obj:`pandas.Series` + returns :obj:`pandas.Series` object + """ + + if not isinstance(self, SplitViewStringMethodsType): + msg = 'Method len(). The object must be a pandas.core.strings. Given: {}' + raise TypingError(msg.format(self)) + + def hpat_pandas_spliview_stringmethods_len_impl(self): + item_count = len(self._data) + result = numpy.empty(item_count, numba.types.int64) + local_data = self._data._data + for i in range(len(local_data)): + result[i] = len(local_data[i]) + + return pandas.Series(result, name=self._data._name) + + return hpat_pandas_spliview_stringmethods_len_impl + + # @infer_global(operator.getitem) class GetItemStringArraySplitView(AbstractTemplate): key = operator.getitem diff --git a/sdc/tests/test_hiframes.py b/sdc/tests/test_hiframes.py index ae7412093..3a296f3b1 100644 --- a/sdc/tests/test_hiframes.py +++ b/sdc/tests/test_hiframes.py @@ -424,7 +424,6 @@ def test_impl(df): pd.testing.assert_series_equal( hpat_func(df), test_impl(df), check_names=False) - @skip_sdc_jit("Could not get length of Series(StringArraySplitView)") @skip_numba_jit def test_str_split_filter(self): def test_impl(df):