Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions sdc/datatypes/hpat_pandas_stringmethods_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def hpat_pandas_stringmethods_upper_impl(self):

import numba
from numba.extending import overload_method
from numba.types import (Integer, NoneType, Omitted, StringLiteral, UnicodeType)

from sdc.datatypes.common_functions import TypeChecker
from sdc.datatypes.hpat_pandas_stringmethods_types import StringMethodsType
Expand Down Expand Up @@ -187,6 +188,65 @@ def hpat_pandas_stringmethods_{methodname}_impl(self{methodparams}):
"""


@overload_method(StringMethodsType, 'find')
def hpat_pandas_stringmethods_find(self, sub, start=0, end=None):
    """
    Pandas Series method :meth:`pandas.core.strings.StringMethods.find()` implementation.

    Note: Only unicode strings are supported as Series elements. Numpy.NaN is not supported as an element.

    .. only:: developer

        Test: python -m sdc.runtests -k sdc.tests.test_series.TestSeries.test_series_str_find

    Parameters
    ----------
    self: :class:`pandas.core.strings.StringMethods`
        input arg
    sub: :obj:`str`
        Substring being searched
    start: :obj:`int`
        Left edge index
        *unsupported*: only the default value 0 is accepted
    end: :obj:`int`
        Right edge index
        *unsupported*: only the default value None is accepted

    Returns
    -------
    :obj:`pandas.Series`
        returns :obj:`pandas.Series` object holding, for every element, the int64 result of
        ``element.find(sub)``; the index and the name of the underlying Series are reused

    Raises
    ------
    ValueError
        At run time, if `start` is not 0 or `end` is not None.

    Arguments of an unexpected type are reported at compile time through
    ``TypeChecker.raise_exc``.
    """

    ty_checker = TypeChecker('Method find().')
    ty_checker.check(self, StringMethodsType)

    if not isinstance(sub, (StringLiteral, UnicodeType)):
        ty_checker.raise_exc(sub, 'str', 'sub')

    # `start` and `end` may arrive either as Numba types or as their plain default values.
    accepted_types = (Integer, NoneType, Omitted)
    if not isinstance(start, accepted_types) and start != 0:
        ty_checker.raise_exc(start, 'None, int', 'start')

    if not isinstance(end, accepted_types) and end is not None:
        ty_checker.raise_exc(end, 'None, int', 'end')

    def hpat_pandas_stringmethods_find_impl(self, sub, start=0, end=None):
        # Numba 0.46 doesn't support str.find with extra parameters, so any
        # non-default `start` / `end` is rejected instead of being silently ignored.
        if start != 0:
            raise ValueError('Method find(). The object start\n expected: 0')
        if end is not None:
            raise ValueError('Method find(). The object end\n expected: None')

        item_count = len(self._data)
        result = numpy.empty(item_count, numba.types.int64)
        for idx, item in enumerate(self._data._data):
            result[idx] = item.find(sub)

        return pandas.Series(result, self._data._index, name=self._data._name)

    return hpat_pandas_stringmethods_find_impl


@overload_method(StringMethodsType, 'isupper')
def hpat_pandas_stringmethods_isupper(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion sdc/hiframes/pd_series_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def resolve_head(self, ary, args, kws):
Functions which are still overloaded by HPAT compiler pipeline
"""

str2str_methods_excluded = ['upper', 'isupper', 'len', 'lower',
str2str_methods_excluded = ['upper', 'find', 'isupper', 'len', 'lower',
'lstrip', 'rstrip', 'strip']
"""
Functions which are used from Numba directly by calling from StringMethodsType
Expand Down
53 changes: 53 additions & 0 deletions sdc/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,59 @@ def test_impl(S1, S2):
hpat_func(S1, S2), test_impl(S1, S2),
err_msg='S1={}\nS2={}'.format(S1, S2))

def test_series_str_find(self):
    """Compare the jitted ``Series.str.find(sub)`` with pandas.

    Every combination of index (default, reversed integers, reversed data)
    and name (None, 'A') is checked against each substring: the empty
    string, a prefix of every element and every full element.
    """
    def test_impl(series, sub):
        return series.str.find(sub)
    hpat_func = self.jit(test_impl)

    data = test_global_input_data_unicode_kind4
    # Length of the shortest element; loop-invariant, so compute it once
    # instead of once per element. `default=0` keeps empty data harmless.
    shortest_len = min((len(s) for s in data), default=0)
    subs = [''] + [s[:shortest_len] for s in data] + data
    indices = [None, list(range(len(data)))[::-1], data[::-1]]
    names = [None, 'A']
    for index, name in product(indices, names):
        series = pd.Series(data, index, name=name)
        for sub in subs:
            pd.testing.assert_series_equal(hpat_func(series, sub),
                                           test_impl(series, sub))

def test_series_str_find_exception_unsupported_start(self):
    """Check the errors reported by ``Series.str.find`` for an unsupported `start`.

    A string `start` is expected to fail with TypingError, an integer other
    than 0 with ValueError; the message of each is checked.
    """
    def test_impl(series, sub, start):
        return series.str.find(sub, start)
    hpat_func = self.jit(test_impl)

    series = pd.Series(test_global_input_data_unicode_kind4)
    msg_tmpl = 'Method {}(). The object {}\n {}'

    # (expected exception, `start` argument, trailing part of the message)
    cases = (
        (TypingError, '0', ('given: unicode_type\n '
                            'expected: None, int')),
        (ValueError, 1, 'expected: 0'),
    )
    for exc_type, bad_start, details in cases:
        with self.assertRaises(exc_type) as ctx:
            hpat_func(series, '', bad_start)
        expected = msg_tmpl.format('find', 'start', details)
        self.assertIn(expected, str(ctx.exception))

def test_series_str_find_exception_unsupported_end(self):
    """Check the errors reported by ``Series.str.find`` for an unsupported `end`.

    A string `end` is expected to fail with TypingError, an integer `end`
    with ValueError; the message of each is checked.
    """
    def test_impl(series, sub, start, end):
        return series.str.find(sub, start, end)
    hpat_func = self.jit(test_impl)

    series = pd.Series(test_global_input_data_unicode_kind4)
    msg_tmpl = 'Method {}(). The object {}\n {}'

    # (expected exception, `end` argument, trailing part of the message)
    cases = (
        (TypingError, 'None', ('given: unicode_type\n '
                               'expected: None, int')),
        (ValueError, 0, 'expected: None'),
    )
    for exc_type, bad_end, details in cases:
        with self.assertRaises(exc_type) as ctx:
            hpat_func(series, '', 0, bad_end)
        expected = msg_tmpl.format('find', 'end', details)
        self.assertIn(expected, str(ctx.exception))

def test_series_str_len1(self):
def test_impl(S):
return S.str.len()
Expand Down