diff --git a/sdc/datatypes/hpat_pandas_series_functions.py b/sdc/datatypes/hpat_pandas_series_functions.py index b96464173..795647084 100644 --- a/sdc/datatypes/hpat_pandas_series_functions.py +++ b/sdc/datatypes/hpat_pandas_series_functions.py @@ -1777,42 +1777,20 @@ def hpat_pandas_series_astype(self, dtype, copy=True, errors='raise'): errors in ('raise', 'ignore')): ty_checker.raise_exc(errors, 'str', 'errors') - # Return StringArray for astype(str) or astype('str') - def hpat_pandas_series_astype_to_str_impl(self, dtype, copy=True, errors='raise'): - num_chars = 0 - arr_len = len(self._data) - - # Get total chars for new array - for i in prange(arr_len): - item = self._data[i] - num_chars += len(str(item)) # TODO: check NA - - data = pre_alloc_string_array(arr_len, num_chars) - for i in prange(arr_len): - item = self._data[i] - data[i] = str(item) # TODO: check NA - - return pandas.Series(data=data, index=self._index, name=self._name) - # Return npytypes.Array from npytypes.Array for astype(types.functions.NumberClass), example - astype(np.int64) - def hpat_pandas_series_astype_numba_impl(self, dtype, copy=True, errors='raise'): - return pandas.Series(data=self._data.astype(dtype), index=self._index, name=self._name) - # Return npytypes.Array from npytypes.Array for astype(types.StringLiteral), example - astype('int64') - def hpat_pandas_series_astype_literal_type_numba_impl(self, dtype, copy=True, errors='raise'): - return pandas.Series(data=self._data.astype(numpy.dtype(dtype)), index=self._index, name=self._name) + def hpat_pandas_series_astype_numba_impl(self, dtype, copy=True, errors='raise'): + return pandas.Series(data=numpy_like.astype(self._data, dtype), index=self._index, name=self._name) # Return self def hpat_pandas_series_astype_no_modify_impl(self, dtype, copy=True, errors='raise'): return pandas.Series(data=self._data, index=self._index, name=self._name) - - if ((isinstance(dtype, types.Function) and dtype.typing_key == str) - or (isinstance(dtype, types.StringLiteral) and dtype.literal_value == 'str')): - return hpat_pandas_series_astype_to_str_impl + str_check = ((isinstance(dtype, types.Function) and dtype.typing_key == str) or + (isinstance(dtype, types.StringLiteral) and dtype.literal_value == 'str')) # Needs Numba astype impl support converting unicode_type to NumberClass and other types - if isinstance(self.data, StringArrayType): + if (isinstance(self.data, StringArrayType) and not str_check): if isinstance(dtype, types.functions.NumberClass) and errors == 'raise': raise TypingError(f'Needs Numba astype impl support converting unicode_type to {dtype}') if isinstance(dtype, types.StringLiteral) and errors == 'raise': @@ -1823,18 +1801,12 @@ def hpat_pandas_series_astype_no_modify_impl(self, dtype, copy=True, errors='rai else: raise TypingError(f'Needs Numba astype impl support converting unicode_type to {dtype.literal_value}') - if isinstance(self.data, types.npytypes.Array) and isinstance(dtype, types.functions.NumberClass): - return hpat_pandas_series_astype_numba_impl + data_narr = isinstance(self.data, types.npytypes.Array) + dtype_num_liter = isinstance(dtype, (types.functions.NumberClass, types.StringLiteral)) - if isinstance(self.data, types.npytypes.Array) and isinstance(dtype, types.StringLiteral): - try: - literal_value = numpy.dtype(dtype.literal_value) - except: - pass # Will raise the exception later - else: - return hpat_pandas_series_astype_literal_type_numba_impl + if data_narr and dtype_num_liter or str_check: + return hpat_pandas_series_astype_numba_impl - # Raise error if dtype is not supported if errors == 'raise': raise TypingError(f'{_func_name} The object must be a supported type. Given dtype: {dtype}') else: diff --git a/sdc/functions/numpy_like.py b/sdc/functions/numpy_like.py index 3ac537d4e..82b2a3d60 100644 --- a/sdc/functions/numpy_like.py +++ b/sdc/functions/numpy_like.py @@ -108,7 +108,7 @@ def sdc_astype_overload(self, dtype): """ ty_checker = TypeChecker("numpy-like 'astype'") - if not isinstance(self, types.Array): + if not isinstance(self, (types.Array, StringArrayType)): return None if not isinstance(dtype, (types.functions.NumberClass, types.Function, types.Literal)): diff --git a/sdc/tests/tests_perf/test_perf_series.py b/sdc/tests/tests_perf/test_perf_series.py index bbc2c90ee..6e98fc32b 100644 --- a/sdc/tests/tests_perf/test_perf_series.py +++ b/sdc/tests/tests_perf/test_perf_series.py @@ -68,7 +68,7 @@ def _test_case(self, pyfunc, name, total_data_length, data_num=1, input_data=tes TC(name='append', size=[10 ** 7], params='other', data_num=2), TC(name='apply', size=[10 ** 7], params='lambda x: x'), TC(name='argsort', size=[10 ** 4]), - TC(name='astype', size=[10 ** 5], call_expr='data.astype(np.int8)', usecase_params='data', + TC(name='astype', size=[10 ** 8], call_expr='data.astype(np.int8)', usecase_params='data', input_data=[test_global_input_data_float64[0]]), TC(name='at', size=[10 ** 7], call_expr='data.at[3]', usecase_params='data'), TC(name='chain_add_and_sum', size=[20 * 10 ** 6, 25 * 10 ** 6, 30 * 10 ** 6], call_expr='(A + B).sum()',