Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions sdc/datatypes/hpat_pandas_stringmethods_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from numba.extending import (models, overload, register_model, make_attribute_wrapper, intrinsic)
from numba.datamodel import (register_default, StructModel)
from numba.typing.templates import signature
from sdc.hiframes.split_impl import SplitViewStringMethodsType, StringArraySplitViewType


class StringMethodsType(types.IterableType):
Expand All @@ -50,7 +51,8 @@ class StringMethodsType(types.IterableType):

def __init__(self, data):
self.data = data
super(StringMethodsType, self).__init__('StringMethodsType')
name = 'StringMethodsType({})'.format(self.data)
super(StringMethodsType, self).__init__(name)

@property
def iterator_type(self):
Expand All @@ -74,37 +76,47 @@ def __init__(self, dmm, fe_type):
make_attribute_wrapper(StringMethodsType, 'data', '_data')


@intrinsic
def _hpat_pandas_stringmethods_init(typingctx, data):
"""
Internal Numba required function to register StringMethodsType and
connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods)
"""
def _gen_hpat_pandas_stringmethods_init(string_methods_type=None):
string_methods_type = string_methods_type or StringMethodsType

def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args):
def _hpat_pandas_stringmethods_init(typingctx, data):
"""
It is looks like it creates StringMethodsModel structure

- Fixed number of parameters. Must be 4
- increase reference count for the data
Internal Numba required function to register StringMethodsType and
connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods)
"""

[data_val] = args
stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder)
stringmethod.data = data_val
def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args):
"""
It is looks like it creates StringMethodsModel structure

if context.enable_nrt:
context.nrt.incref(builder, data, stringmethod.data)
- Fixed number of parameters. Must be 4
- increase reference count for the data
"""

return stringmethod._getvalue()
[data_val] = args
stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder)
stringmethod.data = data_val

ret_typ = StringMethodsType(data)
sig = signature(ret_typ, data)
"""
Construct signature of the Numba SeriesGroupByType::ctor()
"""
if context.enable_nrt:
context.nrt.incref(builder, data, stringmethod.data)

return sig, _hpat_pandas_stringmethods_init_codegen
return stringmethod._getvalue()

ret_typ = string_methods_type(data)
sig = signature(ret_typ, data)
"""
Construct signature of the Numba SeriesGroupByType::ctor()
"""

return sig, _hpat_pandas_stringmethods_init_codegen

return _hpat_pandas_stringmethods_init


_hpat_pandas_stringmethods_init = intrinsic(
_gen_hpat_pandas_stringmethods_init(string_methods_type=StringMethodsType))
_hpat_pandas_split_view_stringmethods_init = intrinsic(
_gen_hpat_pandas_stringmethods_init(string_methods_type=SplitViewStringMethodsType))


@overload(pandas.core.strings.StringMethods)
Expand All @@ -113,6 +125,11 @@ def hpat_pandas_stringmethods(obj):
Special Numba procedure to overload Python type pandas.core.strings.StringMethods::ctor()
with Numba registered model
"""
if isinstance(obj.data, StringArraySplitViewType):
def hpat_pandas_split_view_stringmethods_impl(obj):
return _hpat_pandas_split_view_stringmethods_init(obj)

return hpat_pandas_split_view_stringmethods_impl

def hpat_pandas_stringmethods_impl(obj):
return _hpat_pandas_stringmethods_init(obj)
Expand Down
10 changes: 6 additions & 4 deletions sdc/hiframes/hiframes_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@
from sdc.hiframes.rolling import get_rolling_setup_args
from sdc.hiframes.aggregate import Aggregate
from sdc.hiframes.series_kernels import series_replace_funcs
from sdc.hiframes.split_impl import (string_array_split_view_type,
StringArraySplitViewType, getitem_c_arr, get_array_ctypes_ptr,
get_split_view_index, get_split_view_data_ptr)
from sdc.hiframes.split_impl import (SplitViewStringMethodsType,
string_array_split_view_type, StringArraySplitViewType,
getitem_c_arr, get_array_ctypes_ptr,
get_split_view_index, get_split_view_data_ptr)


_dt_index_binops = ('==', '!=', '>=', '>', '<=', '<', '-',
Expand Down Expand Up @@ -480,7 +481,8 @@ def _run_call(self, assign, lhs, rhs):
else:
func_name, func_mod = fdef

if (isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], StringMethodsType)):
string_methods_types = (SplitViewStringMethodsType, StringMethodsType)
if isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], string_methods_types):
f_def = guard(get_definition, self.state.func_ir, rhs.func)
str_def = guard(get_definition, self.state.func_ir, f_def.value)
if str_def is None: # TODO: check for errors
Expand Down
26 changes: 19 additions & 7 deletions sdc/hiframes/pd_series_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
from sdc.hiframes.pd_categorical_ext import (PDCategoricalDtype, CategoricalArray)
from sdc.hiframes.pd_timestamp_ext import (pandas_timestamp_type, datetime_date_type)
from sdc.hiframes.rolling import supported_rolling_funcs
from sdc.hiframes.split_impl import (string_array_split_view_type, GetItemStringArraySplitView)
from sdc.hiframes.split_impl import (SplitViewStringMethodsType,
string_array_split_view_type,
GetItemStringArraySplitView)
from sdc.str_arr_ext import (
string_array_type,
iternext_str_array,
Expand Down Expand Up @@ -423,9 +425,9 @@ def resolve_T(self, ary):
# def resolve_index(self, ary):
# return ary.index

def resolve_str(self, ary):
assert ary.dtype in (string_type, types.List(string_type))
return StringMethodsType(ary)
# def resolve_str(self, ary):
# assert ary.dtype in (string_type, types.List(string_type))
# return StringMethodsType(ary)

def resolve_dt(self, ary):
assert ary.dtype == types.NPDatetime('ns')
Expand Down Expand Up @@ -780,9 +782,9 @@ class SeriesStrMethodAttribute(AttributeTemplate):
def resolve_contains(self, ary, args, kws):
return signature(SeriesType(types.bool_), *args)

@bound_function("strmethod.len")
def resolve_len(self, ary, args, kws):
return signature(SeriesType(types.int64), *args)
# @bound_function("strmethod.len")
# def resolve_len(self, ary, args, kws):
# return signature(SeriesType(types.int64), *args)

@bound_function("strmethod.replace")
def resolve_replace(self, ary, args, kws):
Expand Down Expand Up @@ -820,6 +822,16 @@ def generic(self, args, kws):
raise NotImplementedError('Series.str.{} is not supported yet'.format(func_name))


@infer_getattr
class SplitViewSeriesStrMethodAttribute(AttributeTemplate):
key = SplitViewStringMethodsType

@bound_function('strmethod.get')
def resolve_get(self, ary, args, kws):
# XXX only list(list(str)) supported
return signature(SeriesType(string_type), *args)


class SeriesDtMethodType(types.Type):
def __init__(self):
name = "SeriesDtMethodType"
Expand Down
79 changes: 79 additions & 0 deletions sdc/hiframes/split_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@


import operator
import numpy
import pandas
import numba
import sdc
from numba import types
from numba.typing.templates import (infer_global, AbstractTemplate, infer,
signature, AttributeTemplate, infer_getattr, bound_function)
import numba.typing.typeof
from numba.datamodel import StructModel
from numba.errors import TypingError
from numba.extending import (typeof_impl, type_callable, models, register_model, NativeValue,
make_attribute_wrapper, lower_builtin, box, unbox,
lower_getattr, intrinsic, overload_method, overload, overload_attribute)
Expand Down Expand Up @@ -131,6 +135,43 @@ def __init__(self, dmm, fe_type):
make_attribute_wrapper(StringArraySplitViewType, 'data', '_data')


class SplitViewStringMethodsType(types.IterableType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How this class related to pandas.core.strings.StringMethods? Why we need it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class related to the pandas string methods through existing unit test:

    def test_str_split_filter(self):
        def test_impl(df):
            B = df.A.str.split(',')
            df2 = pd.DataFrame({'B': B})
            return df2[df2.B.str.len() > 1]
        # ...

where split() returns Series with StringArraySplitView. So I made str.len() to work with such data this way. Otherwise we get something like that:
Invalid store of {{i64, i32*, i32*, {i8*, i8*}, i8*}, i8*, {i8*, i64, i32, i32, i64, i8*, i8*}} to {{i64, i64, i32*, i8*, i8*, i8*}, i8*, {i8*, i64, i32, i32, i64, i8*, i8*}}

"""
Type definition for pandas.core.strings.StringMethods functions handling.

Members
----------
_data: :class:`SeriesType`
input arg
"""

def __init__(self, data):
self.data = data
name = 'SplitViewStringMethodsType({})'.format(self.data)
super(SplitViewStringMethodsType, self).__init__(name)

@property
def iterator_type(self):
return None


@register_model(SplitViewStringMethodsType)
class SplitViewStringMethodsTypeModel(StructModel):
"""
Model for SplitViewStringMethodsType type
All members must be the same as main type for this model
"""

def __init__(self, dmm, fe_type):
members = [
('data', fe_type.data)
]
models.StructModel.__init__(self, dmm, fe_type, members)


make_attribute_wrapper(SplitViewStringMethodsType, 'data', '_data')


def construct_str_arr_split_view(context, builder):
"""Creates meminfo and sets dtor.
"""
Expand Down Expand Up @@ -404,6 +445,44 @@ def str_arr_split_view_len_overload(arr):
return lambda arr: arr._num_items


@overload_method(SplitViewStringMethodsType, 'len')
def hpat_pandas_spliview_stringmethods_len(self):
"""
Pandas Series method :meth:`pandas.core.strings.StringMethods.len()` implementation.

Note: Unicode type of list elements are supported only. Numpy.NaN is not supported as elements.

.. only:: developer

Test: python -m sdc.runtests sdc.tests.test_hiframes.TestHiFrames.test_str_split_filter

Parameters
----------
self: :class:`pandas.core.strings.StringMethods`
input arg

Returns
-------
:obj:`pandas.Series`
returns :obj:`pandas.Series` object
"""

if not isinstance(self, SplitViewStringMethodsType):
msg = 'Method len(). The object must be a pandas.core.strings. Given: {}'
raise TypingError(msg.format(self))

def hpat_pandas_spliview_stringmethods_len_impl(self):
item_count = len(self._data)
result = numpy.empty(item_count, numba.types.int64)
local_data = self._data._data
for i in range(len(local_data)):
result[i] = len(local_data[i])

return pandas.Series(result, name=self._data._name)

return hpat_pandas_spliview_stringmethods_len_impl


# @infer_global(operator.getitem)
class GetItemStringArraySplitView(AbstractTemplate):
key = operator.getitem
Expand Down
1 change: 0 additions & 1 deletion sdc/tests/test_hiframes.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ def test_impl(df):
pd.testing.assert_series_equal(
hpat_func(df), test_impl(df), check_names=False)

@skip_sdc_jit("Could not get length of Series(StringArraySplitView)")
@skip_numba_jit
def test_str_split_filter(self):
def test_impl(df):
Expand Down