Skip to content
This repository has been archived by the owner on Feb 2, 2024. It is now read-only.

Commit

Permalink
Replace rand chararray generator with strlist one
Browse files Browse the repository at this point in the history
  • Loading branch information
densmirn committed Nov 1, 2019
1 parent 776827c commit 3038203
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
1 change: 1 addition & 0 deletions hpat/datatypes/hpat_pandas_series_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def hpat_pandas_series_nlargest_impl(self, n=5, keep='first'):

# data: [0, 1, -1, 1, 0] -> [1, 1, 0, 0, -1]
# index: [0, 1, 2, 3, 4] -> [1, 3, 0, 4, 2] (not [3, 1, 4, 0, 2])
# subtract 1 to ensure reverse ordering at boundaries
indices = (-self._data - 1).argsort(kind='mergesort')[:max(n, 0)]

return self.take(indices)
Expand Down
52 changes: 32 additions & 20 deletions hpat/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pyarrow.parquet as pq
import hpat
from itertools import islice, permutations
from hpat.tests.test_utils import (
count_array_REPs, count_parfor_REPs, count_array_OneDs, get_start_end)
from hpat.tests.gen_test_data import ParquetGenerator
Expand Down Expand Up @@ -73,13 +74,12 @@
]


def gen_srand_array(size, nchars=8):
"""Generate array of strings of specified size based on [a-zA-Z] + [0-9]"""
accepted_chars = list(string.ascii_letters + string.digits)
rands_chars = np.array(accepted_chars, dtype=(np.str_, 1))
def gen_strlist(size, nchars=8):
    """Generate a list of `size` distinct strings of length `nchars`.

    The strings are drawn from the alphabet [a-zA-Z] + [0-9] by taking the
    first `size` results of itertools.permutations over that alphabet, so the
    output is deterministic and contains no duplicates.

    Parameters
    ----------
    size : int
        Number of strings to generate.
    nchars : int, optional
        Length of every generated string (default 8).

    Returns
    -------
    list of str
        Exactly `size` strings, in the order permutations() yields them.

    Raises
    ------
    ValueError
        If fewer than `size` permutations of length `nchars` exist (for
        example when `nchars` exceeds the alphabet length), or if `size` is
        negative (raised by islice itself).
    """
    accepted_chars = string.ascii_letters + string.digits
    generated_chars = islice(permutations(accepted_chars, nchars), size)
    strings = [''.join(chars) for chars in generated_chars]

    # islice silently stops early when permutations() runs dry; a short list
    # would only surface later as a confusing length mismatch in the caller.
    if len(strings) != size:
        raise ValueError(
            'cannot generate {} strings of length {} from {} characters'.format(
                size, nchars, len(accepted_chars)))

    return strings


def _make_func_from_text(func_text, func_name='test_impl'):
Expand Down Expand Up @@ -2282,25 +2282,31 @@ def test_impl(series, n, keep):
jit_result = hpat_func(series, n, keep)
pd.testing.assert_series_equal(ref_result, jit_result)

def test_series_nlargest_index(self):
    """Compare Series.nlargest(n) between pandas and the hpat-jitted function.

    Every numeric input is checked with a range index; a string index from
    gen_strlist() is added only when the hpat default pipeline is disabled.
    Under the default pipeline only the values are compared
    (assert_array_equal); otherwise the full Series, index included, is
    compared (assert_series_equal).
    """
    def test_impl(series, n):
        return series.nlargest(n)
    hpat_func = hpat.jit(test_impl)

    # TODO: check data == [] after index is fixed
    for data in test_global_input_data_numeric:
        # Build a tripled copy instead of `data *= 3`, which would modify the
        # shared module-level input in place and leak into other tests.
        tripled = data * 3
        # TODO: add integer index not equal to range after index is fixed
        indexes = [range(len(tripled))]
        if not hpat.config.config_pipeline_hpat_default:
            indexes.append(gen_strlist(len(tripled)))

        for index in indexes:
            series = pd.Series(tripled, index)
            for n in range(-1, 10):
                ref_result = test_impl(series, n)
                jit_result = hpat_func(series, n)
                if hpat.config.config_pipeline_hpat_default:
                    np.testing.assert_array_equal(ref_result, jit_result)
                else:
                    pd.testing.assert_series_equal(ref_result, jit_result)

@unittest.skipIf(hpat.config.config_pipeline_hpat_default,
'Series.nlargest() types validation unsupported')
'Series.nlargest() does not raise an exception')
def test_series_nlargest_typing(self):
_func_name = 'Method nlargest().'

Expand All @@ -2324,7 +2330,7 @@ def test_impl(series, n, keep):
self.assertIn(msg.format(_func_name, dtype), str(raises.exception))

@unittest.skipIf(hpat.config.config_pipeline_hpat_default,
'Series.nlargest() values validation unsupported')
'Series.nlargest() does not raise an exception')
def test_series_nlargest_unsupported(self):
msg = "Method nlargest(). Unsupported parameter. Given 'keep' != 'first'"

Expand Down Expand Up @@ -2402,25 +2408,31 @@ def test_impl(series, n, keep):
jit_result = hpat_func(series, n, keep)
pd.testing.assert_series_equal(ref_result, jit_result)

def test_series_nsmallest_index(self):
    """Compare Series.nsmallest(n) between pandas and the hpat-jitted function.

    Every numeric input is checked with a range index; a string index from
    gen_strlist() is added only when the hpat default pipeline is disabled.
    Under the default pipeline only the values are compared
    (assert_array_equal); otherwise the full Series, index included, is
    compared (assert_series_equal).
    """
    def test_impl(series, n):
        return series.nsmallest(n)
    hpat_func = hpat.jit(test_impl)

    # TODO: check data == [] after index is fixed
    for data in test_global_input_data_numeric:
        # Build a tripled copy instead of `data *= 3`, which would modify the
        # shared module-level input in place and leak into other tests.
        tripled = data * 3
        # TODO: add integer index not equal to range after index is fixed
        indexes = [range(len(tripled))]
        if not hpat.config.config_pipeline_hpat_default:
            indexes.append(gen_strlist(len(tripled)))

        for index in indexes:
            series = pd.Series(tripled, index)
            for n in range(-1, 10):
                ref_result = test_impl(series, n)
                jit_result = hpat_func(series, n)
                if hpat.config.config_pipeline_hpat_default:
                    np.testing.assert_array_equal(ref_result, jit_result)
                else:
                    pd.testing.assert_series_equal(ref_result, jit_result)

@unittest.skipIf(hpat.config.config_pipeline_hpat_default,
'Series.nsmallest() types validation unsupported')
'Series.nsmallest() does not raise an exception')
def test_series_nsmallest_typing(self):
_func_name = 'Method nsmallest().'

Expand All @@ -2444,7 +2456,7 @@ def test_impl(series, n, keep):
self.assertIn(msg.format(_func_name, dtype), str(raises.exception))

@unittest.skipIf(hpat.config.config_pipeline_hpat_default,
'Series.nsmallest() values validation unsupported')
'Series.nsmallest() does not raise an exception')
def test_series_nsmallest_unsupported(self):
msg = "Method nsmallest(). Unsupported parameter. Given 'keep' != 'first'"

Expand Down

0 comments on commit 3038203

Please sign in to comment.