Skip to content
This repository has been archived by the owner on Feb 2, 2024. It is now read-only.

Commit

Permalink
Fix for string tests failed due to bad str_overload (#94)
Browse files Browse the repository at this point in the history
Fix for string tests failed due to bad str_overload
  • Loading branch information
kozlov-alexey authored and fschlimb committed Jul 22, 2019
1 parent 9f0aca7 commit db2736f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
47 changes: 31 additions & 16 deletions hpat/str_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,37 @@ def unliteral_all(args):

## use objmode for string methods for now

# string methods that just return another string
str2str_methods = ('capitalize', 'casefold', 'lower', 'lstrip', 'rstrip',
'strip', 'swapcase', 'title', 'upper')

for method in str2str_methods:
func_text = "def str_overload(in_str):\n"
func_text += " def _str_impl(in_str):\n"
func_text += " with numba.objmode(out='unicode_type'):\n"
func_text += " out = in_str.{}()\n".format(method)
func_text += " return out\n"
func_text += " return _str_impl\n"
loc_vars = {}
exec(func_text, {'numba': numba}, loc_vars)
str_overload = loc_vars['str_overload']
overload_method(types.UnicodeType, method)(str_overload)

# string methods that take no arguments and return a string
str2str_noargs = ('capitalize', 'casefold', 'lower', 'swapcase', 'title', 'upper')

def str_overload_noargs(method):
@overload_method(types.UnicodeType, method)
def str_overload(in_str):
def _str_impl(in_str):
with numba.objmode(out='unicode_type'):
out = getattr(in_str, method)()
return out

return _str_impl

for method in str2str_noargs:
str_overload_noargs(method)

# strip string methods that take one argument and return a string
str2str_1arg = ('lstrip', 'rstrip', 'strip')

def str_overload_1arg(method):
@overload_method(types.UnicodeType, method)
def str_overload(in_str, arg1):
def _str_impl(in_str, arg1):
with numba.objmode(out='unicode_type'):
out = getattr(in_str, method)(arg1)
return out

return _str_impl

for method in str2str_1arg:
str_overload_1arg(method)

@overload_method(types.UnicodeType, 'replace')
def str_replace_overload(in_str, old, new, count=-1):
Expand Down
5 changes: 0 additions & 5 deletions hpat/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def test_impl():
hpat_func = hpat.jit(test_impl)
self.assertEqual(hpat_func(), test_impl())

@unittest.skip('numba.errors.LoweringError - fix needed\n'
'Failed in hpat mode pipeline'
'(step: nopython mode backend)\n'
'str_overload() takes 1 positional argument '
'but 2 were given\n')
def test_str2str(self):
str2str_methods = ['capitalize', 'casefold', 'lower', 'lstrip',
'rstrip', 'strip', 'swapcase', 'title', 'upper']
Expand Down

0 comments on commit db2736f

Please sign in to comment.