Skip to content
This repository has been archived by the owner on Feb 2, 2024. It is now read-only.

Commit

Permalink
Merge 7c6ea14 into c6397f4
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov-alexey committed Jul 17, 2019
2 parents c6397f4 + 7c6ea14 commit be7c073
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
25 changes: 21 additions & 4 deletions hpat/str_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def unliteral_all(args):

## use objmode for string methods for now

# string methods that just return another string
str2str_methods = ('capitalize', 'casefold', 'lower', 'lstrip', 'rstrip',
'strip', 'swapcase', 'title', 'upper')
# string methods that take no arguments and return a string
str2str_noargs = ('capitalize', 'casefold', 'lower', 'swapcase', 'title', 'upper')

for method in str2str_methods:
# define overload methods for all string functions in str2str_noargs
# to call corresponding Numba methods
for method in str2str_noargs:
func_text = "def str_overload(in_str):\n"
func_text += " def _str_impl(in_str):\n"
func_text += " with numba.objmode(out='unicode_type'):\n"
Expand All @@ -68,6 +69,22 @@ def unliteral_all(args):
str_overload = loc_vars['str_overload']
overload_method(types.UnicodeType, method)(str_overload)

# strip string methods that take one argument and return a string
str2str_1arg = ('lstrip', 'rstrip', 'strip')

# for strip methods define overloads to call Numba implementation
# forwarding arg1 as a necessary 'chars' argument
for method in str2str_1arg:
func_text = "def str_overload(in_str, arg1):\n"
func_text += " def _str_impl(in_str, arg1):\n"
func_text += " with numba.objmode(out='unicode_type'):\n"
func_text += " out = in_str.{}(arg1)\n".format(method)
func_text += " return out\n"
func_text += " return _str_impl\n"
loc_vars = {}
exec(func_text, {'numba': numba}, loc_vars)
str_overload = loc_vars['str_overload']
overload_method(types.UnicodeType, method)(str_overload)

@overload_method(types.UnicodeType, 'replace')
def str_replace_overload(in_str, old, new, count=-1):
Expand Down
5 changes: 0 additions & 5 deletions hpat/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ def test_impl():
hpat_func = hpat.jit(test_impl)
self.assertEqual(hpat_func(), test_impl())

@unittest.skip('numba.errors.LoweringError - fix needed\n'
'Failed in hpat mode pipeline'
'(step: nopython mode backend)\n'
'str_overload() takes 1 positional argument '
'but 2 were given\n')
def test_str2str(self):
str2str_methods = ['capitalize', 'casefold', 'lower', 'lstrip',
'rstrip', 'strip', 'swapcase', 'title', 'upper']
Expand Down

0 comments on commit be7c073

Please sign in to comment.