From 5ea7a171385e43cb6ed8c9283b2e8f63c55d30b2 Mon Sep 17 00:00:00 2001 From: Alexey Kozlov Date: Fri, 14 May 2021 17:28:28 +0300 Subject: [PATCH 1/2] Moving to numba=0.53.1 (#971) * Moving to numba=0.53 * Workarounds to avoid Numba regressions in 0.53 * Changing Numba 0.53.0 to 0.53.1 --- conda-recipe/meta.yaml | 2 +- requirements.txt | 2 +- sdc/datatypes/hpat_pandas_series_functions.py | 92 +++++++++++-------- sdc/functions/numpy_like.py | 2 +- setup.py | 2 +- 5 files changed, 57 insertions(+), 43 deletions(-) diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 17189b818..4886a6653 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -1,4 +1,4 @@ -{% set NUMBA_VERSION = "==0.52.0" %} +{% set NUMBA_VERSION = "==0.53.1" %} {% set PANDAS_VERSION = "==1.2.0" %} {% set PYARROW_VERSION = "==2.0.0" %} diff --git a/requirements.txt b/requirements.txt index 4e7e3940c..5b123c130 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy>=1.16 pandas==1.2.0 pyarrow==2.0.0 -numba==0.52.0 +numba==0.53.1 tbb tbb-devel diff --git a/sdc/datatypes/hpat_pandas_series_functions.py b/sdc/datatypes/hpat_pandas_series_functions.py index 3594902a8..348cf2665 100644 --- a/sdc/datatypes/hpat_pandas_series_functions.py +++ b/sdc/datatypes/hpat_pandas_series_functions.py @@ -459,13 +459,15 @@ def _series_getitem_idx_bool_indexer_impl(self, idx): if (isinstance(idx, SeriesType) and index_is_positional and not isinstance(idx.data.dtype, (types.Boolean, bool))): def hpat_pandas_series_getitem_idx_list_impl(self, idx): - res = numpy.copy(self._data[:len(idx._data)]) - index = numpy.arange(len(self._data)) + idx_data = idx._data + self_data = self._data + res = numpy.copy(self._data[:len(idx_data)]) + index = numpy.arange(len(self_data)) for i in numba.prange(len(res)): for j in numba.prange(len(index)): - if j == idx._data[i]: - res[i] = self._data[j] - return pandas.Series(data=res, index=index[idx._data], name=self._name) + if j == idx_data[i]: + 
res[i] = self_data[j] + return pandas.Series(data=res, index=index[idx_data], name=self._name) return hpat_pandas_series_getitem_idx_list_impl # idx is Series and it's index is not PositionalIndex, idx.dtype is not Boolean @@ -647,6 +649,7 @@ def sdc_pandas_series_setitem_no_reindexing_impl(self, idx, value): def sdc_pandas_series_setitem_idx_bool_array_align_impl(self, idx, value): + series_data = self._data # FIXME_Numba#6960 # if idx is a Boolean array (and value is a series) it's used as a mask for self.index # and filtered indexes are looked in value.index, and if found corresponding value is set if value_is_series == True: # noqa @@ -659,7 +662,7 @@ def sdc_pandas_series_setitem_idx_bool_array_align_impl(self, idx, value): self_index_has_duplicates = len(unique_self_indices) != len(self_index) value_index_has_duplicates = len(unique_value_indices) != len(value_index) if (self_index_has_duplicates or value_index_has_duplicates): - self._data[idx] = value._data + series_data[idx] = value._data else: map_index_to_position = Dict.empty( key_type=indexes_common_dtype, @@ -674,13 +677,13 @@ def sdc_pandas_series_setitem_idx_bool_array_align_impl(self, idx, value): if idx[i]: self_index_value = self_index[i] if self_index_value in map_index_to_position: - self._data[i] = value._data[map_index_to_position[self_index_value]] + series_data[i] = value._data[map_index_to_position[self_index_value]] else: - sdc.hiframes.join.setitem_arr_nan(self._data, i) + sdc.hiframes.join.setitem_arr_nan(series_data, i) else: # if value has no index - nothing to reindex and assignment is made along positions set by idx mask - self._data[idx] = value + series_data[idx] = value return self @@ -755,21 +758,25 @@ def sdc_pandas_series_setitem_idx_bool_series_align_impl(self, idx, value): value_is_scalar = not (value_is_series or value_is_array) def sdc_pandas_series_setitem_idx_int_series_align_impl(self, idx, value): + # FIXME_Numba#6960: all changes of this commit are unnecessary - 
revert when resolved + self_data = self._data + self_index = self._index + self_index_size = len(self_index) + idx_size = len(idx) + _idx = idx._data if idx_is_series == True else idx # noqa _value = value._data if value_is_series == True else value # noqa - self_index_size = len(self._index) - idx_size = len(_idx) valid_indices = numpy.repeat(-1, self_index_size) for i in numba.prange(self_index_size): for j in numpy.arange(idx_size): - if self._index[i] == _idx[j]: + if self_index[i] == _idx[j]: valid_indices[i] = j valid_indices_positions = numpy.arange(self_index_size)[valid_indices != -1] valid_indices_masked = valid_indices[valid_indices != -1] - indexes_found = self._index[valid_indices_positions] + indexes_found = self_index[valid_indices_positions] if len(numpy.unique(indexes_found)) != len(indexes_found): raise ValueError("Reindexing only valid with uniquely valued Index objects") @@ -777,9 +784,9 @@ def sdc_pandas_series_setitem_idx_int_series_align_impl(self, idx, value): raise KeyError("Reindexing not possible: idx has index not found in Series") if value_is_scalar == True: # noqa - self._data[valid_indices_positions] = _value + self_data[valid_indices_positions] = _value else: - self._data[valid_indices_positions] = numpy.take(_value, valid_indices_masked) + self_data[valid_indices_positions] = numpy.take(_value, valid_indices_masked) return self @@ -1598,17 +1605,18 @@ def hpat_pandas_series_var_impl(self, axis=None, skipna=None, level=None, ddof=1 if skipna is None: skipna = True + self_data = self._data # FIXME_Numba#6960 if skipna: - valuable_length = len(self._data) - numpy.sum(numpy.isnan(self._data)) + valuable_length = len(self_data) - numpy.sum(numpy.isnan(self_data)) if valuable_length <= ddof: return numpy.nan - return numpy_like.nanvar(self._data) * valuable_length / (valuable_length - ddof) + return numpy_like.nanvar(self_data) * valuable_length / (valuable_length - ddof) - if len(self._data) <= ddof: + if len(self_data) <= ddof: return 
numpy.nan - return self._data.var() * len(self._data) / (len(self._data) - ddof) + return self_data.var() * len(self_data) / (len(self_data) - ddof) return hpat_pandas_series_var_impl @@ -2859,8 +2867,9 @@ def hpat_pandas_series_prod_impl(self, axis=None, skipna=None, level=None, numer else: _skipna = skipna + series_data = self._data # FIXME_Numba#6960 if _skipna: - return numpy_like.nanprod(self._data) + return numpy_like.nanprod(series_data) else: return numpy.prod(self._data) @@ -3079,8 +3088,9 @@ def hpat_pandas_series_min_impl(self, axis=None, skipna=None, level=None, numeri else: _skipna = skipna + series_data = self._data # FIXME_Numba#6960 if _skipna: - return numpy_like.nanmin(self._data) + return numpy_like.nanmin(series_data) return self._data.min() @@ -3156,8 +3166,9 @@ def hpat_pandas_series_max_impl(self, axis=None, skipna=None, level=None, numeri else: _skipna = skipna + series_data = self._data # FIXME_Numba#6960 if _skipna: - return numpy_like.nanmax(self._data) + return numpy_like.nanmax(series_data) return self._data.max() @@ -3222,8 +3233,9 @@ def hpat_pandas_series_mean_impl(self, axis=None, skipna=None, level=None, numer else: _skipna = skipna + series_data = self._data # FIXME_Numba#6960 if _skipna: - return numpy_like.nanmean(self._data) + return numpy_like.nanmean(series_data) return self._data.mean() @@ -3780,27 +3792,28 @@ def hpat_pandas_series_argsort(self, axis=0, kind='quicksort', order=None): if not isinstance(self.index, PositionalIndexType): def hpat_pandas_series_argsort_idx_impl(self, axis=0, kind='quicksort', order=None): + series_data = self._data # FIXME_Numba#6960 if kind != 'quicksort' and kind != 'mergesort': raise ValueError("Method argsort(). Unsupported parameter. 
Given 'kind' != 'quicksort' or 'mergesort'") if kind == 'mergesort': #It is impossible to use numpy.argsort(self._data, kind=kind) since numba gives typing error - sort = numpy_like.argsort(self._data, kind='mergesort') + sort = numpy_like.argsort(series_data, kind='mergesort') else: - sort = numpy_like.argsort(self._data) + sort = numpy_like.argsort(series_data) na = self.isna().sum() - result = numpy.empty(len(self._data), dtype=numpy.int64) - na_data_arr = sdc.hiframes.api.get_nan_mask(self._data) + result = numpy.empty(len(series_data), dtype=numpy.int64) + na_data_arr = sdc.hiframes.api.get_nan_mask(series_data) if kind == 'mergesort': - sort_nona = numpy_like.argsort(self._data[~na_data_arr], kind='mergesort') + sort_nona = numpy_like.argsort(series_data[~na_data_arr], kind='mergesort') else: - sort_nona = numpy_like.argsort(self._data[~na_data_arr]) + sort_nona = numpy_like.argsort(series_data[~na_data_arr]) q = 0 for id, i in enumerate(sort): - if id in set(sort[len(self._data) - na:]): + if id in set(sort[len(series_data) - na:]): q += 1 else: result[id] = sort_nona[id - q] - for i in sort[len(self._data) - na:]: + for i in sort[len(series_data) - na:]: result[i] = -1 return pandas.Series(result, self._index) @@ -3808,26 +3821,27 @@ def hpat_pandas_series_argsort_idx_impl(self, axis=0, kind='quicksort', order=No return hpat_pandas_series_argsort_idx_impl def hpat_pandas_series_argsort_noidx_impl(self, axis=0, kind='quicksort', order=None): + series_data = self._data # FIXME_Numba#6960 if kind != 'quicksort' and kind != 'mergesort': raise ValueError("Method argsort(). Unsupported parameter. 
Given 'kind' != 'quicksort' or 'mergesort'") if kind == 'mergesort': - sort = numpy_like.argsort(self._data, kind='mergesort') + sort = numpy_like.argsort(series_data, kind='mergesort') else: - sort = numpy_like.argsort(self._data) + sort = numpy_like.argsort(series_data) na = self.isna().sum() - result = numpy.empty(len(self._data), dtype=numpy.int64) - na_data_arr = sdc.hiframes.api.get_nan_mask(self._data) + result = numpy.empty(len(series_data), dtype=numpy.int64) + na_data_arr = sdc.hiframes.api.get_nan_mask(series_data) if kind == 'mergesort': - sort_nona = numpy_like.argsort(self._data[~na_data_arr], kind='mergesort') + sort_nona = numpy_like.argsort(series_data[~na_data_arr], kind='mergesort') else: - sort_nona = numpy_like.argsort(self._data[~na_data_arr]) + sort_nona = numpy_like.argsort(series_data[~na_data_arr]) q = 0 for id, i in enumerate(sort): - if id in set(sort[len(self._data) - na:]): + if id in set(sort[len(series_data) - na:]): q += 1 else: result[id] = sort_nona[id - q] - for i in sort[len(self._data) - na:]: + for i in sort[len(series_data) - na:]: result[i] = -1 return pandas.Series(result) diff --git a/sdc/functions/numpy_like.py b/sdc/functions/numpy_like.py index 8636e0bc2..96d1f0a5c 100644 --- a/sdc/functions/numpy_like.py +++ b/sdc/functions/numpy_like.py @@ -149,7 +149,7 @@ def sdc_astype_number_to_string_impl(self, dtype): arr_len = len(self) # Get total bytes for new array - for i in prange(arr_len): + for i in np.arange(arr_len): # FIXME_Numba#6969: prange segfaults, use it when resolved item = self[i] num_bytes += get_utf8_size(str(item)) diff --git a/setup.py b/setup.py index 0bccb26fc..9335178d0 100644 --- a/setup.py +++ b/setup.py @@ -382,7 +382,7 @@ def run(self): 'numpy>=1.16', 'pandas==1.2.0', 'pyarrow==2.0.0', - 'numba==0.52.0', + 'numba==0.53.1', 'tbb' ], cmdclass=sdc_build_commands, From 1574682c4ee637de7763b1d711bcf067f3db0c69 Mon Sep 17 00:00:00 2001 From: Alexey Kozlov Date: Fri, 14 May 2021 21:03:32 +0300 Subject: 
[PATCH 2/2] Initial version of ConcurrentDict container via TBB hashmap (#972) * Initial version of ConcurrentDict container via TBB hashmap Motivation: SDC relies on typed.Dict implementation in many core pandas algorithms, and it doesn't support concurrent read/writes. To fill this gap we add ConcurrentDict type which will be used if threading layer is TBB. * Fixing PEP and updating failing import * Fixing builds, warnings and complying to C++11 syntax * Fixing PEP and review comments #1 * Fixing remarks #2 * Applying remarks #3 --- sdc/__init__.py | 2 + sdc/extensions/sdc_hashmap_ext.py | 1125 ++++++++++++++++++++++++++++ sdc/extensions/sdc_hashmap_type.py | 193 +++++ sdc/native/conc_dict_module.cpp | 281 +++++++ sdc/native/hashmap.hpp | 867 +++++++++++++++++++++ sdc/tests/__init__.py | 2 + sdc/tests/test_tbb_hashmap.py | 1049 ++++++++++++++++++++++++++ setup.py | 25 +- 8 files changed, 3543 insertions(+), 1 deletion(-) create mode 100644 sdc/extensions/sdc_hashmap_ext.py create mode 100644 sdc/extensions/sdc_hashmap_type.py create mode 100644 sdc/native/conc_dict_module.cpp create mode 100644 sdc/native/hashmap.hpp create mode 100644 sdc/tests/test_tbb_hashmap.py diff --git a/sdc/__init__.py b/sdc/__init__.py index e9ca063dd..76c29ae97 100644 --- a/sdc/__init__.py +++ b/sdc/__init__.py @@ -50,6 +50,8 @@ import sdc.extensions.indexes.range_index_ext import sdc.extensions.indexes.int64_index_ext +import sdc.extensions.sdc_hashmap_ext + from ._version import get_versions """ diff --git a/sdc/extensions/sdc_hashmap_ext.py b/sdc/extensions/sdc_hashmap_ext.py new file mode 100644 index 000000000..5fea972d8 --- /dev/null +++ b/sdc/extensions/sdc_hashmap_ext.py @@ -0,0 +1,1125 @@ +# ***************************************************************************** +# Copyright (c) 2019-2021, Intel Corporation All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ***************************************************************************** + +import llvmlite.binding as ll +import llvmlite.llvmpy.core as lc +import numba +import numpy as np +import operator +import sdc + +from sdc import hstr_ext +from glob import glob +from llvmlite import ir as lir +from numba import types, cfunc +from numba.core import cgutils +from numba.extending import (typeof_impl, type_callable, models, register_model, NativeValue, + lower_builtin, box, unbox, lower_getattr, intrinsic, + overload_method, overload, overload_attribute) +from numba.cpython.hashing import _Py_hash_t +from numba.core.imputils import (impl_ret_new_ref, impl_ret_borrowed, iternext_impl, RefType) +from numba.cpython.listobj import ListInstance +from numba.core.typing.templates import (infer_global, AbstractTemplate, infer, + signature, AttributeTemplate, infer_getattr, bound_function) +from numba import prange + +from sdc.str_ext import string_type +from sdc.str_arr_type import (StringArray, string_array_type, StringArrayType, + StringArrayPayloadType, str_arr_payload_type, StringArrayIterator, + is_str_arr_typ, offset_typ, data_ctypes_type, offset_ctypes_type) +from sdc.utilities.sdc_typing_utils import check_is_array_of_dtype + +from numba.typed.typedobjectutils import _as_bytes +from sdc import hconc_dict +from sdc.extensions.sdc_hashmap_type import (ConcurrentDict, ConcurrentDictType, + ConcDictKeysIterableType, ConcDictIteratorType, + ConcDictItemsIterableType, ConcDictValuesIterableType) +from numba.extending import register_jitable + +from sdc.extensions.sdc_hashmap_type import SdcTypeRef +from sdc.utilities.sdc_typing_utils import TypingError, TypeChecker, check_types_comparable +from itertools import product + +from numba.typed.dictobject import _cast + + +def gen_func_suffixes(): + key_suffixes = ['int32_t', 'int64_t', 'voidptr'] + val_suffixes = ['int32_t', 'int64_t', 'float', 'double', 'voidptr'] + return map(lambda x: f'{x[0]}_to_{x[1]}', + 
product(key_suffixes, val_suffixes)) + + +def load_native_func(fname, module, skip_check=None): + for suffix in gen_func_suffixes(): + if skip_check and skip_check(suffix): + continue + full_func_name = f'{fname}_{suffix}' + ll.add_symbol(full_func_name, + getattr(module, full_func_name)) + + +load_native_func('hashmap_create', hconc_dict) +load_native_func('hashmap_size', hconc_dict) +load_native_func('hashmap_set', hconc_dict) +load_native_func('hashmap_contains', hconc_dict) +load_native_func('hashmap_lookup', hconc_dict) +load_native_func('hashmap_clear', hconc_dict) +load_native_func('hashmap_pop', hconc_dict) +load_native_func('hashmap_update', hconc_dict) +load_native_func('hashmap_create_from_data', hconc_dict, lambda x: 'voidptr' in x) +load_native_func('hashmap_getiter', hconc_dict) +load_native_func('hashmap_iternext', hconc_dict) + + +supported_numeric_key_types = [ + types.int32, + types.uint32, + types.int64, + types.uint64 +] + + +supported_numeric_value_types = [ + types.int32, + types.uint32, + types.int64, + types.uint64, + types.float32, + types.float64, +] + + +# to avoid over-specialization native hashmap structs always use signed integers +# and function arguments and return values are converted when needed +reduced_type_map = { + types.int32: types.int32, + types.uint32: types.int32, + types.int64: types.int64, + types.uint64: types.int64, +} + + +map_numba_type_to_prefix = { + types.int32: 'int32_t', + types.int64: 'int64_t', + types.float32: 'float', + types.float64: 'double', +} + + +def _get_types_postfixes(key_type, value_type): + _key_type = reduced_type_map[key_type] if isinstance(key_type, types.Integer) else key_type + _value_type = reduced_type_map[value_type] if isinstance(value_type, types.Integer) else value_type + + key_postfix = map_numba_type_to_prefix.get(_key_type, 'voidptr') + value_postfix = map_numba_type_to_prefix.get(_value_type, 'voidptr') + + return (key_postfix, value_postfix) + + +def gen_deref_voidptr(key_type): + 
@intrinsic + def deref_voidptr(typingctx, data_ptr_type): + if data_ptr_type is not types.voidptr: + return None + + ret_type = key_type + + def codegen(context, builder, sig, args): + str_val, = args + + ty_ret_type_pointer = lir.PointerType(context.get_data_type(ret_type)) + casted_ptr = builder.bitcast(str_val, ty_ret_type_pointer) + + return impl_ret_borrowed(context, builder, ret_type, builder.load(casted_ptr)) + + return ret_type(types.voidptr), codegen + + return deref_voidptr + + +def gen_hash_compare_ops(key_type): + + deref_voidptr = gen_deref_voidptr(key_type) + c_sig_hash = types.uintp(types.voidptr) + c_sig_eq = types.boolean(types.voidptr, types.voidptr) + + @cfunc(c_sig_hash) + def hash_func_adaptor(voidptr_to_data): + obj = deref_voidptr(voidptr_to_data) + return hash(obj) + + @cfunc(c_sig_eq) + def eq_func_adaptor(lhs_ptr, rhs_ptr): + lhs_str = deref_voidptr(lhs_ptr) + rhs_str = deref_voidptr(rhs_ptr) + return lhs_str == rhs_str + + hasher_ptr = hash_func_adaptor.address + eq_ptr = eq_func_adaptor.address + + return hasher_ptr, eq_ptr + + +@intrinsic +def call_incref(typingctx, val_type): + ret_type = types.void + + def codegen(context, builder, sig, args): + [arg_val] = args + [arg_type] = sig.args + + if context.enable_nrt: + context.nrt.incref(builder, arg_type, arg_val) + + return ret_type(val_type), codegen + + +@intrinsic +def call_decref(typingctx, val_type): + ret_type = types.void + + def codegen(context, builder, sig, args): + [arg_val] = args + [arg_type] = sig.args + + if context.enable_nrt: + context.nrt.decref(builder, arg_type, arg_val) + + return ret_type(val_type), codegen + + +def gen_incref_decref_ops(key_type): + + deref_voidptr = gen_deref_voidptr(key_type) + c_sig_incref = types.void(types.voidptr) + c_sig_decref = types.void(types.voidptr) + + @cfunc(c_sig_incref) + def incref_func_adaptor(voidptr_to_data): + obj = deref_voidptr(voidptr_to_data) + return call_incref(obj) + + @cfunc(c_sig_decref) + def 
decref_func_adaptor(voidptr_to_data): + obj = deref_voidptr(voidptr_to_data) + return call_decref(obj) + + incref_ptr = incref_func_adaptor.address + decref_ptr = decref_func_adaptor.address + + return incref_ptr, decref_ptr + + +def codegen_get_voidptr(context, builder, ty_var, var_val): + dm_key = context.data_model_manager[ty_var] + data_val = dm_key.as_data(builder, var_val) + ptr_var = cgutils.alloca_once_value(builder, data_val) + val_as_voidptr = _as_bytes(builder, ptr_var) + + return val_as_voidptr + + +def transform_input_arg(context, builder, ty_arg, val): + """ This function should adjust key to satisfy argument type of native function to + which it will be passed later """ + + if isinstance(ty_arg, types.Number): + arg_native_type = reduced_type_map.get(ty_arg, ty_arg) + key_val = val + if ty_arg is not arg_native_type: + key_val = context.cast(builder, key_val, ty_arg, arg_native_type) + lir_key_type = context.get_value_type(arg_native_type) + else: + key_val = codegen_get_voidptr(context, builder, ty_arg, val) + lir_key_type = context.get_value_type(types.voidptr) + + return (key_val, lir_key_type) + + +def alloc_native_value(context, builder, ty_arg): + """ This function allocates argument to be used as return value of a native function """ + + if isinstance(ty_arg, types.Number): + native_arg_type = reduced_type_map.get(ty_arg, ty_arg) + else: + native_arg_type = types.voidptr + + lir_val_type = context.get_value_type(native_arg_type) + ret_val_ptr = cgutils.alloca_once(builder, lir_val_type) + + return (ret_val_ptr, lir_val_type) + + +def transform_native_val(context, builder, ty_arg, val): + """ This function should cast value returned from native func back to dicts value_type """ + + if isinstance(ty_arg, types.Number): + reduced_value_type = reduced_type_map.get(ty_arg, ty_arg) + result_value = context.cast(builder, val, reduced_value_type, ty_arg) + else: + # for values stored as void* in native dict we also need to dereference + 
lir_typed_value_ptr = context.get_value_type(ty_arg).as_pointer() + casted_ptr = builder.bitcast(val, lir_typed_value_ptr) + result_value = builder.load(casted_ptr) + + return result_value + + +@intrinsic +def hashmap_create(typingctx, key, value): + + key_numeric = isinstance(key, types.NumberClass) + val_numeric = isinstance(value, types.NumberClass) + dict_key_type = key.dtype if key_numeric else key.instance_type + dict_val_type = value.dtype if val_numeric else value.instance_type + dict_type = ConcurrentDictType(dict_key_type, dict_val_type) + + hash_func_addr, eq_func_addr = gen_hash_compare_ops(dict_key_type) + key_incref_func_addr, key_decref_func_addr = gen_incref_decref_ops(dict_key_type) + val_incref_func_addr, val_decref_func_addr = gen_incref_decref_ops(dict_val_type) + + key_type_postfix, value_type_postfix = _get_types_postfixes(dict_key_type, dict_val_type) + + def codegen(context, builder, sig, args): + nrt_table = context.nrt.get_nrt_api(builder) + + llptrtype = context.get_value_type(types.intp) + cdict = cgutils.create_struct_proxy(sig.return_type)(context, builder) + fnty = lir.FunctionType(lir.VoidType(), + [cdict.meminfo.type.as_pointer(), # meminfo to fill + lir.IntType(8).as_pointer(), # NRT API func table + lir.IntType(8), lir.IntType(8), # gen_key, gen_value flags + llptrtype, llptrtype, # hash_func, equality func + llptrtype, llptrtype, # key incref, decref + llptrtype, llptrtype, # val incref, decref + lir.IntType(64), lir.IntType(64)]) # key size, val size + func_name = f"hashmap_create_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_create = builder.module.get_or_insert_function( + fnty, name=func_name) + + gen_key = context.get_constant(types.int8, types.int8(not key_numeric)) + gen_val = context.get_constant(types.int8, types.int8(not val_numeric)) + + lir_key_type = context.get_value_type(dict_key_type) + hash_func_addr_const = context.get_constant(types.intp, hash_func_addr) + eq_func_addr_const = 
context.get_constant(types.intp, eq_func_addr) + key_incref = context.get_constant(types.intp, key_incref_func_addr) + key_decref = context.get_constant(types.intp, key_decref_func_addr) + key_type_size = context.get_constant(types.int64, context.get_abi_sizeof(lir_key_type)) + + lir_val_type = context.get_value_type(dict_val_type) + val_incref = context.get_constant(types.intp, val_incref_func_addr) + val_decref = context.get_constant(types.intp, val_decref_func_addr) + val_type_size = context.get_constant(types.int64, context.get_abi_sizeof(lir_val_type)) + + builder.call(fn_hashmap_create, + [cdict._get_ptr_by_name('meminfo'), + nrt_table, + gen_key, + gen_val, + hash_func_addr_const, + eq_func_addr_const, + key_incref, + key_decref, + val_incref, + val_decref, + key_type_size, + val_type_size]) + + cdict.data_ptr = context.nrt.meminfo_data(builder, cdict.meminfo) + return cdict._getvalue() + + return dict_type(key, value), codegen + + +@overload_method(SdcTypeRef, 'empty') +def concurrent_dict_empty(cls, key_type, value_type): + + if cls.instance_type is not ConcurrentDictType: + return + + _func_name = 'Method SdcTypeRef::empty().' + ty_checker = TypeChecker(_func_name) + + supported_key_types = (types.NumberClass, types.TypeRef) + supported_value_types = (types.NumberClass, types.TypeRef) + + if not isinstance(key_type, supported_key_types): + ty_checker.raise_exc(key_type, f'Numba type of dict keys (e.g. types.int32)', 'key_type') + if not isinstance(value_type, supported_value_types): + ty_checker.raise_exc(value_type, f'Numba type of dict values (e.g. 
types.int32)', 'value_type') + + if (isinstance(key_type, types.NumberClass) + and key_type.dtype not in supported_numeric_key_types or + isinstance(key_type, types.TypeRef) + and not isinstance(key_type.instance_type, (types.UnicodeType, types.Hashable) or + isinstance(value_type, types.NumberClass) + and value_type.dtype not in supported_numeric_value_types)): + error_msg = '{} SDC ConcurrentDict({}, {}) is not supported. ' + raise TypingError(error_msg.format(_func_name, key_type, value_type)) + + def concurrent_dict_empty_impl(cls, key_type, value_type): + return hashmap_create(key_type, value_type) + + return concurrent_dict_empty_impl + + +@intrinsic +def hashmap_size(typingctx, dict_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, = args + + cdict = cgutils.create_struct_proxy(dict_type)( + context, builder, value=dict_val) + fnty = lir.FunctionType(lir.IntType(64), + [lir.IntType(8).as_pointer()]) + func_name = f"hashmap_size_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_size = builder.module.get_or_insert_function( + fnty, name=func_name) + ret = builder.call(fn_hashmap_size, [cdict.data_ptr]) + return ret + + return types.uint64(dict_type), codegen + + +@overload(len) +def concurrent_dict_len_ovld(cdict): + if not isinstance(cdict, ConcurrentDictType): + return None + + def concurrent_dict_len_impl(cdict): + return hashmap_size(cdict) + + return concurrent_dict_len_impl + + +@intrinsic +def hashmap_set(typingctx, dict_type, key_type, value_type): + + key_type_postfix, value_type_postfix = _get_types_postfixes(key_type, value_type) + + def codegen(context, builder, sig, args): + dict_val, key_val, value_val = args + + key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val) + val_val, lir_val_type = transform_input_arg(context, builder, value_type, value_val) + + cdict 
= cgutils.create_struct_proxy(dict_type)( + context, builder, value=dict_val) + fnty = lir.FunctionType(lir.VoidType(), + [lir.IntType(8).as_pointer(), + lir_key_type, + lir_val_type]) + + func_name = f"hashmap_set_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_insert = builder.module.get_or_insert_function( + fnty, name=func_name) + + builder.call(fn_hashmap_insert, [cdict.data_ptr, key_val, val_val]) + return + + return types.void(dict_type, key_type, value_type), codegen + + +@overload(operator.setitem, prefer_literal=False) +def concurrent_dict_set_ovld(self, key, value): + if not isinstance(self, ConcurrentDictType): + return None + + dict_key_type, dict_value_type = self.key_type, self.value_type + cast_key = key is not dict_key_type + cast_value = value is not dict_value_type + + def concurrent_dict_set_impl(self, key, value): + _key = key if cast_key == False else _cast(key, dict_key_type) # noqa + _value = value if cast_value == False else _cast(value, dict_value_type) # noqa + return hashmap_set(self, _key, _value) + + return concurrent_dict_set_impl + + +@intrinsic +def hashmap_contains(typingctx, dict_type, key_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, key_val = args + + key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val) + cdict = cgutils.create_struct_proxy(dict_type)( + context, builder, value=dict_val) + fnty = lir.FunctionType(lir.IntType(8), + [lir.IntType(8).as_pointer(), + lir_key_type]) + func_name = f"hashmap_contains_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_contains = builder.module.get_or_insert_function( + fnty, name=func_name) + + res = builder.call(fn_hashmap_contains, [cdict.data_ptr, key_val]) + return context.cast(builder, res, types.uint8, types.bool_) + + return types.bool_(dict_type, key_type), codegen + + 
+@overload(operator.contains, prefer_literal=False) +def concurrent_dict_contains_ovld(self, key): + if not isinstance(self, ConcurrentDictType): + return None + + dict_key_type = self.key_type + cast_key = key is not dict_key_type + + def concurrent_dict_contains_impl(self, key): + _key = key if cast_key == False else _cast(key, dict_key_type) # noqa + return hashmap_contains(self, _key) + + return concurrent_dict_contains_impl + + +@intrinsic +def hashmap_lookup(typingctx, dict_type, key_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + return_type = types.Tuple([types.bool_, types.Optional(ty_val)]) + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, key_val = args + + key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val) + native_value_ptr, lir_value_type = alloc_native_value(context, builder, ty_val) + + cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val) + fnty = lir.FunctionType(lir.IntType(8), + [lir.IntType(8).as_pointer(), + lir_key_type, + lir_value_type.as_pointer() + ]) + func_name = f"hashmap_lookup_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_lookup = builder.module.get_or_insert_function( + fnty, name=func_name) + + status = builder.call(fn_hashmap_lookup, [cdict.data_ptr, key_val, native_value_ptr]) + status_as_bool = context.cast(builder, status, types.uint8, types.bool_) + + # if key was not found nothing would be stored to native_value_ptr, so depending on status + # we either deref it or not, wrapping final result into types.Optional value + result_ptr = cgutils.alloca_once(builder, + context.get_value_type(types.Optional(ty_val))) + with builder.if_else(status_as_bool, likely=True) as (if_ok, if_not_ok): + with if_ok: + native_value = builder.load(native_value_ptr) + result_value = transform_native_val(context, builder, ty_val, native_value) + + if context.enable_nrt: + 
context.nrt.incref(builder, ty_val, result_value) + + builder.store(context.make_optional_value(builder, ty_val, result_value), + result_ptr) + + with if_not_ok: + builder.store(context.make_optional_none(builder, ty_val), + result_ptr) + + opt_result = builder.load(result_ptr) + return context.make_tuple(builder, return_type, [status_as_bool, opt_result]) + + func_sig = return_type(dict_type, key_type) + return func_sig, codegen + + +@overload(operator.getitem, prefer_literal=False) +def concurrent_dict_lookup_ovld(self, key): + if not isinstance(self, ConcurrentDictType): + return None + + dict_key_type = self.key_type + cast_key = key is not dict_key_type + + def concurrent_dict_lookup_impl(self, key): + _key = key if cast_key == False else _cast(key, dict_key_type) # noqa + found, res = hashmap_lookup(self, _key) + + # Note: this function raises exception so expect no scaling if you use it in prange + if not found: + raise KeyError("ConcurrentDict key not found") + return res + + return concurrent_dict_lookup_impl + + +@intrinsic +def hashmap_clear(typingctx, dict_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, = args + + cdict = cgutils.create_struct_proxy(dict_type)( + context, builder, value=dict_val) + fnty = lir.FunctionType(lir.VoidType(), + [lir.IntType(8).as_pointer()]) + func_name = f"hashmap_clear_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_clear = builder.module.get_or_insert_function( + fnty, name=func_name) + builder.call(fn_hashmap_clear, [cdict.data_ptr]) + return + + return types.void(dict_type), codegen + + +@overload_method(ConcurrentDictType, 'clear') +def concurrent_dict_clear_ovld(self): + if not isinstance(self, ConcurrentDictType): + return None + + def concurrent_dict_clear_impl(self): + hashmap_clear(self) + + return concurrent_dict_clear_impl + + 
+@overload_method(ConcurrentDictType, 'get') +def concurrent_dict_get_ovld(self, key, default=None): + if not isinstance(self, ConcurrentDictType): + return None + + _func_name = f'Method {self}::get()' + ty_checker = TypeChecker(_func_name) + + # default value is expected to be of the same (or safely casted) type as dict's value_type + no_default = isinstance(default, (types.NoneType, types.Omitted)) or default is None + default_is_optional = isinstance(default, types.Optional) + if not (no_default or check_types_comparable(default, self.value_type) + or default_is_optional and check_types_comparable(default.type, self.value_type)): + ty_checker.raise_exc(default, f'{self.value_type} or convertible or None', 'default') + + dict_key_type, dict_value_type = self.key_type, self.value_type + cast_key = key is not dict_key_type + + def concurrent_dict_get_impl(self, key, default=None): + _key = key if cast_key == False else _cast(key, dict_key_type) # noqa + found, res = hashmap_lookup(self, _key) + + if not found: + # just to make obvious that return type is types.Optional(dict.value_type) + if no_default == False: # noqa + return _cast(default, dict_value_type) + else: + return None + return res + + return concurrent_dict_get_impl + + +@intrinsic +def hashmap_pop(typingctx, dict_type, key_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + return_type = types.Tuple([types.bool_, types.Optional(ty_val)]) + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, key_val = args + + key_val, lir_key_type = transform_input_arg(context, builder, key_type, key_val) + + # unlike in lookup operation we allocate value here and pass into native function + # voidptr to allocated data, which copies and frees it's copy + if isinstance(ty_val, types.Number): + ret_val_ptr, lir_val_type = alloc_native_value(context, builder, ty_val) + else: + lir_val_type = context.get_value_type(ty_val) 
+ ret_val_ptr = cgutils.alloca_once(builder, lir_val_type) + + llvoidptr = context.get_value_type(types.voidptr) + ret_val_ptr = builder.bitcast(ret_val_ptr, llvoidptr) + + cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val) + fnty = lir.FunctionType(lir.IntType(8), + [lir.IntType(8).as_pointer(), + lir_key_type, + llvoidptr, + ]) + func_name = f"hashmap_pop_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_pop = builder.module.get_or_insert_function( + fnty, name=func_name) + + status = builder.call(fn_hashmap_pop, [cdict.data_ptr, key_val, ret_val_ptr]) + status_as_bool = context.cast(builder, status, types.uint8, types.bool_) + + # same logic to handle non-existing key as in hashmap_lookup + result_ptr = cgutils.alloca_once(builder, + context.get_value_type(types.Optional(ty_val))) + with builder.if_else(status_as_bool, likely=True) as (if_ok, if_not_ok): + with if_ok: + + ret_val_ptr = builder.bitcast(ret_val_ptr, lir_val_type.as_pointer()) + native_value = builder.load(ret_val_ptr) + if isinstance(ty_val, types.Number): + reduced_value_type = reduced_type_map.get(ty_val, ty_val) + native_value = context.cast(builder, native_value, reduced_value_type, ty_val) + + # no incref of the value here, since it was removed from the dict + # w/o decref to consider the case when value in the dict had refcnt == 1 + + builder.store(context.make_optional_value(builder, ty_val, native_value), + result_ptr) + + with if_not_ok: + builder.store(context.make_optional_none(builder, ty_val), + result_ptr) + + opt_result = builder.load(result_ptr) + return context.make_tuple(builder, return_type, [status_as_bool, opt_result]) + + func_sig = return_type(dict_type, key_type) + return func_sig, codegen + + +@overload_method(ConcurrentDictType, 'pop', prefer_literal=False) +def concurrent_dict_pop_ovld(self, key, default=None): + if not isinstance(self, ConcurrentDictType): + return None + + _func_name = f'Method {self}::pop()' + ty_checker = 
TypeChecker(_func_name) + + # default value is expected to be of the same (or safely casted) type as dict's value_type + no_default = isinstance(default, (types.NoneType, types.Omitted)) or default is None + default_is_optional = isinstance(default, types.Optional) + if not (no_default or check_types_comparable(default, self.value_type) + or default_is_optional and check_types_comparable(default.type, self.value_type)): + ty_checker.raise_exc(default, f'{self.value_type} or convertible or None', 'default') + + dict_key_type, dict_value_type = self.key_type, self.value_type + cast_key = key is not dict_key_type + + def concurrent_dict_pop_impl(self, key, default=None): + _key = key if cast_key == False else _cast(key, dict_key_type) # noqa + found, res = hashmap_pop(self, _key) + + if not found: + if no_default == False: # noqa + return _cast(default, dict_value_type) + else: + return None + return res + + return concurrent_dict_pop_impl + + +@intrinsic +def hashmap_update(typingctx, dict_type, other_dict_type): + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + return_type = types.void + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, other_dict_val = args + + self_cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val) + other_cdict = cgutils.create_struct_proxy(other_dict_type)(context, builder, value=other_dict_val) + fnty = lir.FunctionType(lir.IntType(8), + [lir.IntType(8).as_pointer(), + lir.IntType(8).as_pointer() + ]) + func_name = f"hashmap_update_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_update = builder.module.get_or_insert_function( + fnty, name=func_name) + + builder.call(fn_hashmap_update, [self_cdict.data_ptr, other_cdict.data_ptr]) + return + + func_sig = return_type(dict_type, other_dict_type) + return func_sig, codegen + + +@overload_method(ConcurrentDictType, 'update', prefer_literal=False) +def 
concurrent_dict_update_ovld(self, other):
+    # Bugfix: original checked `(self, ConcurrentDictType)` (a tuple, always truthy),
+    # so `self` was never actually type-checked; use isinstance for both arguments.
+    if not (isinstance(self, ConcurrentDictType) and isinstance(other, ConcurrentDictType)):
+        return None
+
+    _func_name = f'Method {self}::update()'
+    ty_checker = TypeChecker(_func_name)
+
+    # Numba types compare by value (name), not identity; `is not` would reject
+    # equal-but-distinct ConcurrentDictType instances, since they are not interned.
+    if self != other:
+        ty_checker.raise_exc(other, f'{self}', 'other')
+
+    def concurrent_dict_update_impl(self, other):
+        return hashmap_update(self, other)
+
+    return concurrent_dict_update_impl
+
+
+@overload_method(ConcurrentDictType, 'fromkeys', prefer_literal=False)
+def concurrent_dict_fromkeys_ovld(self, keys, value):
+    if not isinstance(self, ConcurrentDictType):
+        return None
+
+    def wrapper_impl(self, keys, value):
+        return ConcurrentDict.fromkeys(keys, value)
+
+    return wrapper_impl
+
+
+@register_jitable
+def get_min_size(A, B):
+    return min(len(A), len(B))
+
+
+@intrinsic
+def create_from_arrays(typingctx, keys, values):
+
+    ty_key, ty_val = keys.dtype, values.dtype
+    dict_type = ConcurrentDictType(ty_key, ty_val)
+    key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val)
+    get_min_size_sig = signature(types.int64, keys, values)
+
+    def codegen(context, builder, sig, args):
+        keys_val, values_val = args
+        nrt_table = context.nrt.get_nrt_api(builder)
+
+        keys_ctinfo = context.make_helper(builder, keys, keys_val)
+        values_ctinfo = context.make_helper(builder, values, values_val)
+        size_val = context.compile_internal(
+            builder,
+            lambda k, v: get_min_size(k, v),
+            get_min_size_sig,
+            [keys_val, values_val]
+        )
+
+        # create concurrent dict struct and call native ctor filling meminfo
+        lir_key_type = context.get_value_type(reduced_type_map.get(ty_key, ty_key))
+        lir_value_type = context.get_value_type(reduced_type_map.get(ty_val, ty_val))
+        cdict = cgutils.create_struct_proxy(dict_type)(context, builder)
+        fnty = lir.FunctionType(lir.VoidType(),
+                                [cdict.meminfo.type.as_pointer(),     # meminfo to fill
+                                 lir.IntType(8).as_pointer(),         # NRT API func table
+                                 lir_key_type.as_pointer(),           # array of keys
+                                 
lir_value_type.as_pointer(), # array of values + lir.IntType(64), # size + ]) + func_name = f"hashmap_create_from_data_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_create = builder.module.get_or_insert_function( + fnty, name=func_name) + builder.call(fn_hashmap_create, + [cdict._get_ptr_by_name('meminfo'), + nrt_table, + keys_ctinfo.data, + values_ctinfo.data, + size_val + ]) + cdict.data_ptr = context.nrt.meminfo_data(builder, cdict.meminfo) + return cdict._getvalue() + + return dict_type(keys, values), codegen + + +@overload_method(SdcTypeRef, 'from_arrays') +def concurrent_dict_from_arrays_ovld(cls, keys, values): + if cls.instance_type is not ConcurrentDictType: + return + + _func_name = f'Method ConcurrentDict::from_arrays()' + if not (isinstance(keys, types.Array) and keys.ndim == 1 + and isinstance(values, types.Array) and values.ndim == 1): + raise TypingError('{} Supported only with 1D arrays of keys and values' + 'Given: keys={}, values={}'.format(_func_name, keys, values)) + + def concurrent_dict_from_arrays_impl(cls, keys, values): + return create_from_arrays(keys, values) + + return concurrent_dict_from_arrays_impl + + +@overload_method(SdcTypeRef, 'fromkeys', prefer_literal=False) +def concurrent_dict_type_fromkeys_ovld(cls, keys, value): + if cls.instance_type is not ConcurrentDictType: + return + + _func_name = f'Method ConcurrentDict::fromkeys()' + ty_checker = TypeChecker(_func_name) + + valid_keys_types = (types.Sequence, types.Array, StringArrayType) + if not isinstance(keys, valid_keys_types): + ty_checker.raise_exc(keys, f'array or sequence', 'keys') + + dict_key_type, dict_value_type = keys.dtype, value + if isinstance(keys, (types.Array, StringArrayType)): + def concurrent_dict_fromkeys_impl(cls, keys, value): + res = ConcurrentDict.empty(dict_key_type, dict_value_type) + for i in numba.prange(len(keys)): + res[keys[i]] = value + return res + else: # generic for all other iterables + def concurrent_dict_fromkeys_impl(cls, keys, 
value): + res = ConcurrentDict.empty(dict_key_type, dict_value_type) + for k in keys: + res[k] = value + return res + + return concurrent_dict_fromkeys_impl + + +@intrinsic +def _hashmap_dump(typingctx, dict_type): + + # load hashmap_dump here as otherwise module import will fail + # since it's included in debug build only + load_native_func('hashmap_dump', hconc_dict) + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + def codegen(context, builder, sig, args): + dict_val, = args + + cdict = cgutils.create_struct_proxy(dict_type)( + context, builder, value=dict_val) + fnty = lir.FunctionType(lir.VoidType(), + [lir.IntType(8).as_pointer()]) + func_name = f"hashmap_dump_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_dump = builder.module.get_or_insert_function( + fnty, name=func_name) + builder.call(fn_hashmap_dump, [cdict.data_ptr]) + return + + return types.void(dict_type), codegen + + +def _iterator_codegen(resty): + """The common codegen for iterator intrinsics. + + Populates the iterator struct and increfs. 
+ """ + + def codegen(context, builder, sig, args): + [d] = args + [td] = sig.args + iterhelper = context.make_helper(builder, resty) + iterhelper.parent = d + iterhelper.state = iterhelper.state.type(None) + return impl_ret_borrowed( + context, + builder, + resty, + iterhelper._getvalue(), + ) + + return codegen + + +@intrinsic +def _conc_dict_items(typingctx, d): + """Get dictionary iterator for .items()""" + resty = ConcDictItemsIterableType(d) + sig = resty(d) + codegen = _iterator_codegen(resty) + return sig, codegen + + +@intrinsic +def _conc_dict_keys(typingctx, d): + """Get dictionary iterator for .keys()""" + resty = ConcDictKeysIterableType(d) + sig = resty(d) + codegen = _iterator_codegen(resty) + return sig, codegen + + +@intrinsic +def _conc_dict_values(typingctx, d): + """Get dictionary iterator for .values()""" + resty = ConcDictValuesIterableType(d) + sig = resty(d) + codegen = _iterator_codegen(resty) + return sig, codegen + + +@overload_method(ConcurrentDictType, 'items') +def impl_items(d): + if not isinstance(d, ConcurrentDictType): + return + + def impl(d): + it = _conc_dict_items(d) + return it + + return impl + + +@overload_method(ConcurrentDictType, 'keys') +def impl_keys(d): + if not isinstance(d, ConcurrentDictType): + return + + def impl(d): + return _conc_dict_keys(d) + + return impl + + +@overload_method(ConcurrentDictType, 'values') +def impl_values(d): + if not isinstance(d, ConcurrentDictType): + return + + def impl(d): + return _conc_dict_values(d) + + return impl + + +def call_native_getiter(context, builder, dict_type, dict_val, it): + """ This function should produce LLVM code for calling native + hashmap_getiter and fill iterator data accordingly """ + + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + nrt_table = context.nrt.get_nrt_api(builder) + llvoidptr = context.get_value_type(types.voidptr) + fnty = lir.FunctionType(llvoidptr, + 
[it.meminfo.type.as_pointer(), + llvoidptr, + llvoidptr]) + func_name = f"hashmap_getiter_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_getiter = builder.module.get_or_insert_function( + fnty, name=func_name) + + cdict = cgutils.create_struct_proxy(dict_type)(context, builder, value=dict_val) + it.state = builder.call(fn_hashmap_getiter, + [it._get_ptr_by_name('meminfo'), + nrt_table, + cdict.data_ptr]) + + # store the reference to parent and incref + it.parent = dict_val + if context.enable_nrt: + context.nrt.incref(builder, dict_type, dict_val) + + +@lower_builtin('getiter', ConcDictItemsIterableType) +@lower_builtin('getiter', ConcDictKeysIterableType) +@lower_builtin('getiter', ConcDictValuesIterableType) +def impl_iterable_getiter(context, builder, sig, args): + """Implement iter() for .keys(), .values(), .items() + """ + iterablety, = sig.args + iter_val, = args + + # iter_val is an empty dict iterator created with call to _iterator_codegen() + # this iterator has no state or meminfo filled yet, only parent (i.e. 
dict), + # which we use to make actual call + dict_type = iterablety.parent + it = context.make_helper(builder, iterablety.iterator_type, iter_val) + call_native_getiter(context, builder, dict_type, it.parent, it) + + # this is a new NRT managed iterator, so no need to use impl_ret_borrowed + return it._getvalue() + + +@lower_builtin('getiter', ConcurrentDictType) +def impl_conc_dict_getiter(context, builder, sig, args): + dict_type, = sig.args + dict_val, = args + + iterablety = ConcDictKeysIterableType(dict_type) + it = context.make_helper(builder, iterablety.iterator_type) + call_native_getiter(context, builder, dict_type, dict_val, it) + + # this is a new NRT managed iterator, so no need to use impl_ret_borrowed + return it._getvalue() + + +@lower_builtin('iternext', ConcDictIteratorType) +@iternext_impl(RefType.BORROWED) +def impl_iterator_iternext(context, builder, sig, args, result): + iter_type, = sig.args + iter_val, = args + + dict_type = iter_type.parent + ty_key, ty_val = dict_type.key_type, dict_type.value_type + key_type_postfix, value_type_postfix = _get_types_postfixes(ty_key, ty_val) + + native_key_ptr, lir_key_type = alloc_native_value(context, builder, ty_key) + native_value_ptr, lir_value_type = alloc_native_value(context, builder, ty_val) + + llvoidptr = context.get_value_type(types.voidptr) + fnty = lir.FunctionType(lir.IntType(8), + [llvoidptr, + lir_key_type.as_pointer(), + lir_value_type.as_pointer()]) + func_name = f"hashmap_iternext_{key_type_postfix}_to_{value_type_postfix}" + fn_hashmap_iternext = builder.module.get_or_insert_function( + fnty, name=func_name) + + iter_ctinfo = context.make_helper(builder, iter_type, iter_val) + status = builder.call(fn_hashmap_iternext, + [iter_ctinfo.state, + native_key_ptr, + native_value_ptr]) + + # TODO: no handling of error state i.e. 
mutated dictionary + # all errors are treated as exhausted iterator + is_valid = builder.icmp_unsigned('==', status, status.type(0)) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + yield_type = iter_type.yield_type + + native_key = builder.load(native_key_ptr) + result_key = transform_native_val(context, builder, ty_key, native_key) + + native_val = builder.load(native_value_ptr) + result_val = transform_native_val(context, builder, ty_val, native_val) + + # All dict iterators use this common implementation. + # Their differences are resolved here. + if isinstance(iter_type.iterable, ConcDictItemsIterableType): + # .items() + tup = context.make_tuple(builder, yield_type, [result_key, result_val]) + result.yield_(tup) + elif isinstance(iter_type.iterable, ConcDictKeysIterableType): + # .keys() + result.yield_(result_key) + elif isinstance(iter_type.iterable, ConcDictValuesIterableType): + # .values() + result.yield_(result_val) + else: + # unreachable + raise AssertionError('unknown type: {}'.format(iter_type.iterable)) diff --git a/sdc/extensions/sdc_hashmap_type.py b/sdc/extensions/sdc_hashmap_type.py new file mode 100644 index 000000000..b54c49b56 --- /dev/null +++ b/sdc/extensions/sdc_hashmap_type.py @@ -0,0 +1,193 @@ +# ***************************************************************************** +# Copyright (c) 2021, Intel Corporation All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +from numba.core.typing.templates import ( + infer_global, AbstractTemplate, signature, + ) +from numba.extending import type_callable, lower_builtin +from numba import types +from numba.extending import (models, register_model, make_attribute_wrapper, overload_method) +from sdc.str_ext import string_type + +from collections.abc import MutableMapping +from numba.core.types import Dummy, IterableType, SimpleIterableType, SimpleIteratorType + +from numba.extending import typeof_impl +from numba.typed import Dict +from numba.core.typing.typeof import _typeof_type as numba_typeof_type + + +class ConcDictIteratorType(SimpleIteratorType): + def __init__(self, iterable): + self.parent = iterable.parent + self.iterable = iterable + yield_type = iterable.yield_type + name = "iter[{}->{}],{}".format( + iterable.parent, yield_type, iterable.name + ) + super(ConcDictIteratorType, self).__init__(name, yield_type) + + +class ConcDictKeysIterableType(SimpleIterableType): + """Concurrent Dictionary iterable type for .keys() + """ + + def __init__(self, parent): + assert isinstance(parent, ConcurrentDictType) + self.parent 
= parent + self.yield_type = self.parent.key_type + name = "keys[{}]".format(self.parent.name) + self.name = name + iterator_type = ConcDictIteratorType(self) + super(ConcDictKeysIterableType, self).__init__(name, iterator_type) + + +class ConcDictItemsIterableType(SimpleIterableType): + """Concurrent Dictionary iterable type for .items() + """ + + def __init__(self, parent): + assert isinstance(parent, ConcurrentDictType) + self.parent = parent + self.yield_type = self.parent.keyvalue_type + name = "items[{}]".format(self.parent.name) + self.name = name + iterator_type = ConcDictIteratorType(self) + super(ConcDictItemsIterableType, self).__init__(name, iterator_type) + + +class ConcDictValuesIterableType(SimpleIterableType): + """Concurrent Dictionary iterable type for .values() + """ + + def __init__(self, parent): + assert isinstance(parent, ConcurrentDictType) + self.parent = parent + self.yield_type = self.parent.value_type + name = "values[{}]".format(self.parent.name) + self.name = name + iterator_type = ConcDictIteratorType(self) + super(ConcDictValuesIterableType, self).__init__(name, iterator_type) + + +@register_model(ConcDictItemsIterableType) +@register_model(ConcDictKeysIterableType) +@register_model(ConcDictValuesIterableType) +@register_model(ConcDictIteratorType) +class ConcDictIterModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ('parent', fe_type.parent), # reference to the dict + ('state', types.voidptr), # iterator state in native code + ('meminfo', types.MemInfoPointer(types.voidptr)), # meminfo for allocated iter state + ] + super(ConcDictIterModel, self).__init__(dmm, fe_type, members) + + +class ConcurrentDictType(IterableType): + def __init__(self, keyty, valty): + self.key_type = keyty + self.value_type = valty + self.keyvalue_type = types.Tuple([keyty, valty]) + super(ConcurrentDictType, self).__init__( + name='ConcurrentDictType({}, {})'.format(keyty, valty)) + + @property + def iterator_type(self): + 
return ConcDictKeysIterableType(self).iterator_type + + +@register_model(ConcurrentDictType) +class ConcurrentDictModel(models.StructModel): + def __init__(self, dmm, fe_type): + + members = [ + ('data_ptr', types.CPointer(types.uint8)), + ('meminfo', types.MemInfoPointer(types.voidptr)), + ] + models.StructModel.__init__(self, dmm, fe_type, members) + + +make_attribute_wrapper(ConcurrentDictType, 'data_ptr', '_data_ptr') + + +class ConcurrentDict(MutableMapping): + def __new__(cls, dcttype=None, meminfo=None): + return object.__new__(cls) + + @classmethod + def empty(cls, key_type, value_type): + return cls(dcttype=ConcurrentDictType(key_type, value_type)) + + @classmethod + def from_arrays(cls, keys, values): + return cls(dcttype=ConcurrentDictType(keys.dtype, values.dtype)) + + @classmethod + def fromkeys(cls, keys, value): + return cls(dcttype=ConcurrentDictType(keys.dtype, value)) + + def __init__(self, **kwargs): + if kwargs: + self._dict_type, self._opaque = self._parse_arg(**kwargs) + else: + self._dict_type = None + + @property + def _numba_type_(self): + if self._dict_type is None: + raise TypeError("invalid operation on untyped dictionary") + return self._dict_type + + +# FIXME_Numba#6781: due to overlapping of overload_methods for Numba TypeRef +# we have to use our new SdcTypeRef to type objects created from types.Type +# (i.e. ConcurrentDict meta-type). This should be removed once it's fixed. +class SdcTypeRef(Dummy): + """Reference to a type. + + Used when a type is passed as a value. 
+ """ + def __init__(self, instance_type): + self.instance_type = instance_type + super(SdcTypeRef, self).__init__('sdc_typeref[{}]'.format(self.instance_type)) + + +@register_model(SdcTypeRef) +class SdcTypeRefModel(models.OpaqueModel): + def __init__(self, dmm, fe_type): + + models.OpaqueModel.__init__(self, dmm, fe_type) + + +@typeof_impl.register(type) +def mynew_typeof_type(val, c): + """ This function is a workaround for """ + + if not issubclass(val, ConcurrentDict): + return numba_typeof_type(val, c) + else: + return SdcTypeRef(ConcurrentDictType) diff --git a/sdc/native/conc_dict_module.cpp b/sdc/native/conc_dict_module.cpp new file mode 100644 index 000000000..52d91ba0c --- /dev/null +++ b/sdc/native/conc_dict_module.cpp @@ -0,0 +1,281 @@ +// ***************************************************************************** +// Copyright (c) 2021, Intel Corporation All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// ***************************************************************************** + +#include +#include "hashmap.hpp" + + +#define declare_hashmap_create(key_type, val_type, suffix) \ +void hashmap_create_##suffix(NRT_MemInfo** meminfo, \ + void* nrt_table, \ + int8_t gen_key, \ + int8_t gen_val, \ + void* hash_func_ptr, \ + void* eq_func_ptr, \ + void* key_incref_func_ptr, \ + void* key_decref_func_ptr, \ + void* val_incref_func_ptr, \ + void* val_decref_func_ptr, \ + uint64_t key_size, \ + uint64_t val_size) \ +{ \ + hashmap_create( \ + meminfo, nrt_table, \ + gen_key, gen_val, \ + hash_func_ptr, eq_func_ptr, \ + key_incref_func_ptr, key_decref_func_ptr, \ + val_incref_func_ptr, val_decref_func_ptr, \ + key_size, val_size); \ +} \ + + +#define declare_hashmap_size(key_type, val_type, suffix) \ +uint64_t hashmap_size_##suffix(void* p_hash_map) \ +{ \ + return hashmap_size(p_hash_map); \ +} \ + + +#define declare_hashmap_set(key_type, val_type, suffix) \ +void hashmap_set_##suffix(void* p_hash_map, key_type key, val_type val) \ +{ \ + hashmap_set(p_hash_map, key, val); \ +} \ + + +#define declare_hashmap_contains(key_type, val_type, suffix) \ +int8_t hashmap_contains_##suffix(void* p_hash_map, key_type key) \ +{ \ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); \ + return p_hash_map_spec->contains(key); \ +} \ + + +#define declare_hashmap_lookup(key_type, val_type, suffix) \ +int8_t hashmap_lookup_##suffix(void* 
p_hash_map, key_type key, val_type* res) \ +{ \ + return hashmap_lookup(p_hash_map, key, res); \ +} \ + + +#define declare_hashmap_clear(key_type, val_type, suffix) \ +void hashmap_clear_##suffix(void* p_hash_map) \ +{ \ + return hashmap_clear(p_hash_map); \ +} \ + + +#define declare_hashmap_pop(key_type, val_type, suffix) \ +int8_t hashmap_pop_##suffix(void* p_hash_map, key_type key, val_type* res) \ +{ \ + return hashmap_unsafe_extract(p_hash_map, key, res); \ +} \ + + +#define declare_hashmap_create_from_data(key_type, val_type) \ +void hashmap_create_from_data_##key_type##_to_##val_type(NRT_MemInfo** meminfo, void* nrt_table, key_type* keys, val_type* values, int64_t size) \ +{ \ + hashmap_numeric_from_arrays(meminfo, nrt_table, keys, values, size); \ +} \ + + +#define declare_hashmap_update(key_type, val_type, suffix) \ +void hashmap_update_##suffix(void* p_self_hash_map, void* p_arg_hash_map) \ +{ \ + return hashmap_update(p_self_hash_map, p_arg_hash_map); \ +} \ + + +#ifdef SDC_DEBUG_NATIVE +#define declare_hashmap_dump(key_type, val_type, suffix) \ +void hashmap_dump_##suffix(void* p_hash_map) \ +{ \ + hashmap_dump(p_hash_map); \ +} +#else +#define declare_hashmap_dump(key_type, val_type, suffix) +#endif + + +#define declare_hashmap_getiter(key_type, val_type, suffix) \ +void* hashmap_getiter_##suffix(NRT_MemInfo** meminfo, void* nrt_table, void* p_hash_map) \ +{ \ + return hashmap_getiter(meminfo, nrt_table, p_hash_map); \ +} \ + + +#define declare_hashmap_iternext(key_type, val_type, suffix) \ +int8_t hashmap_iternext_##suffix(void* p_iter_state, key_type* ret_key, val_type* ret_val) \ +{ \ + return hashmap_iternext(p_iter_state, ret_key, ret_val); \ +} \ + + +#define declare_hashmap(key_type, val_type, suffix) \ +declare_hashmap_create(key_type, val_type, suffix); \ +declare_hashmap_size(key_type, val_type, suffix); \ +declare_hashmap_set(key_type, val_type, suffix); \ +declare_hashmap_contains(key_type, val_type, suffix); \ 
+declare_hashmap_lookup(key_type, val_type, suffix); \ +declare_hashmap_clear(key_type, val_type, suffix); \ +declare_hashmap_pop(key_type, val_type, suffix); \ +declare_hashmap_update(key_type, val_type, suffix); \ +declare_hashmap_getiter(key_type, val_type, suffix); \ +declare_hashmap_iternext(key_type, val_type, suffix); \ +declare_hashmap_dump(key_type, val_type, suffix); \ + + +#define REGISTER(func) PyObject_SetAttrString(m, #func, PyLong_FromVoidPtr((void*)(&func))); + + +#define register_release(suffix) \ +REGISTER(hashmap_create_##suffix) \ +REGISTER(hashmap_size_##suffix) \ +REGISTER(hashmap_set_##suffix) \ +REGISTER(hashmap_contains_##suffix) \ +REGISTER(hashmap_lookup_##suffix) \ +REGISTER(hashmap_clear_##suffix) \ +REGISTER(hashmap_pop_##suffix) \ +REGISTER(hashmap_update_##suffix) \ +REGISTER(hashmap_getiter_##suffix) \ +REGISTER(hashmap_iternext_##suffix) \ + +#define register_debug(suffix) \ +REGISTER(hashmap_dump_##suffix) + +#ifndef SDC_DEBUG_NATIVE +#define register_hashmap(suffix) register_release(suffix) +#else +#define register_hashmap(suffix) \ +register_release(suffix) \ +register_debug(suffix) +#endif + + +extern "C" +{ + +// declare all hashmap methods for below combinations of key-value types +declare_hashmap(int32_t, int32_t, int32_t_to_int32_t) +declare_hashmap(int32_t, int64_t, int32_t_to_int64_t) +declare_hashmap(int32_t, float, int32_t_to_float) +declare_hashmap(int32_t, double, int32_t_to_double) + +declare_hashmap(int64_t, int32_t, int64_t_to_int32_t) +declare_hashmap(int64_t, int64_t, int64_t_to_int64_t) +declare_hashmap(int64_t, float, int64_t_to_float) +declare_hashmap(int64_t, double, int64_t_to_double) + +declare_hashmap(void*, int32_t, voidptr_to_int32_t) +declare_hashmap(void*, int64_t, voidptr_to_int64_t) +declare_hashmap(void*, float, voidptr_to_float) +declare_hashmap(void*, double, voidptr_to_double) + +declare_hashmap(int32_t, void*, int32_t_to_voidptr) +declare_hashmap(int64_t, void*, int64_t_to_voidptr) + 
+declare_hashmap(void*, void*, voidptr_to_voidptr)
+
+// additionally declare create_from_data functions for numeric hashmap
+declare_hashmap_create_from_data(int32_t, int32_t)
+declare_hashmap_create_from_data(int32_t, int64_t)
+declare_hashmap_create_from_data(int32_t, float)
+declare_hashmap_create_from_data(int32_t, double)
+
+declare_hashmap_create_from_data(int64_t, int32_t)
+declare_hashmap_create_from_data(int64_t, int64_t)
+declare_hashmap_create_from_data(int64_t, float)
+declare_hashmap_create_from_data(int64_t, double)
+
+PyMODINIT_FUNC PyInit_hconc_dict()
+{
+    static struct PyModuleDef moduledef = {
+        PyModuleDef_HEAD_INIT,
+        "htbb_hashmap",
+        "No docs",
+        -1,
+        NULL,
+    };
+    PyObject* m = PyModule_Create(&moduledef);
+    if (m == NULL)
+    {
+        return NULL;
+    }
+
+    // register previously declared hashmap methods
+    register_hashmap(int32_t_to_int32_t)
+    register_hashmap(int32_t_to_int64_t)
+    register_hashmap(int64_t_to_int32_t)
+    register_hashmap(int64_t_to_int64_t)
+
+    register_hashmap(int32_t_to_float)
+    register_hashmap(int32_t_to_double)
+    register_hashmap(int64_t_to_float)
+    register_hashmap(int64_t_to_double)
+
+    register_hashmap(voidptr_to_int32_t)
+    register_hashmap(voidptr_to_int64_t)
+    register_hashmap(voidptr_to_float)
+    register_hashmap(voidptr_to_double)
+
+    register_hashmap(int32_t_to_voidptr)
+    register_hashmap(int64_t_to_voidptr)
+
+    register_hashmap(voidptr_to_voidptr);
+
+    // hashmap create_from_data functions for numeric hashmap
+    REGISTER(hashmap_create_from_data_int32_t_to_int32_t)
+    REGISTER(hashmap_create_from_data_int32_t_to_int64_t)
+    REGISTER(hashmap_create_from_data_int64_t_to_int32_t)
+    REGISTER(hashmap_create_from_data_int64_t_to_int64_t)
+
+    REGISTER(hashmap_create_from_data_int32_t_to_float)
+    REGISTER(hashmap_create_from_data_int32_t_to_double)
+    REGISTER(hashmap_create_from_data_int64_t_to_float)
+    REGISTER(hashmap_create_from_data_int64_t_to_double)
+
+    utils::tbb_control::init();
+
+    return m;
+}
+
+} // extern "C"
+
+#undef declare_hashmap_create +#undef declare_hashmap_size +#undef declare_hashmap_set +#undef declare_hashmap_contains +#undef declare_hashmap_lookup +#undef declare_hashmap_clear +#undef declare_hashmap_pop +#undef declare_hashmap_create_from_data +#undef declare_hashmap_update +#undef declare_hashmap_getiter +#undef declare_hashmap_iternext +#undef declare_hashmap_dump +#undef register_hashmap +#undef REGISTER +#undef declare_hashmap diff --git a/sdc/native/hashmap.hpp b/sdc/native/hashmap.hpp new file mode 100644 index 000000000..7a168302a --- /dev/null +++ b/sdc/native/hashmap.hpp @@ -0,0 +1,867 @@ +// ***************************************************************************** +// Copyright (c) 2021, Intel Corporation All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +// EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// ***************************************************************************** + +#include +#include +#include +#include + +#ifdef SDC_DEBUG_NATIVE +#include +#endif + +#include "utils.hpp" +#include "tbb/tbb.h" +#include "tbb/concurrent_unordered_map.h" +#include "numba/core/runtime/nrt_external.h" + + +using voidptr_hash_type = size_t (*)(void* key_ptr); +using voidptr_eq_type = bool (*)(void* lhs_ptr, void* rhs_ptr); +using voidptr_refcnt = void (*)(void* key_ptr); + +using iter_state = std::pair; + + +class CustomVoidPtrHasher +{ +private: + voidptr_hash_type ptr_hash_callback; + +public: + CustomVoidPtrHasher(void* ptr_func) { + ptr_hash_callback = reinterpret_cast(ptr_func); + } + + size_t operator()(void* data_ptr) const { + auto res = ptr_hash_callback(data_ptr); + return res; + } +}; + + +class CustomVoidPtrEquality +{ +private: + voidptr_eq_type ptr_eq_callback; + +public: + CustomVoidPtrEquality(void* ptr_func) { + ptr_eq_callback = reinterpret_cast(ptr_func); + } + + size_t operator()(void* lhs, void* rhs) const { + return ptr_eq_callback(lhs, rhs); + } +}; + + +struct VoidPtrHashCompare { +private: + voidptr_hash_type ptr_hash_callback; + voidptr_eq_type ptr_eq_callback; + +public: + size_t hash(void* data_ptr) const { + return ptr_hash_callback(data_ptr); + } + bool equal(void* lhs, void* rhs) const { + return ptr_eq_callback(lhs, rhs); + } + + VoidPtrHashCompare(void* ptr_hash, void* ptr_equality) { + 
ptr_hash_callback = reinterpret_cast(ptr_hash); + ptr_eq_callback = reinterpret_cast(ptr_equality); + } + + VoidPtrHashCompare() = delete; + VoidPtrHashCompare(const VoidPtrHashCompare&) = default; + VoidPtrHashCompare& operator=(const VoidPtrHashCompare&) = default; + VoidPtrHashCompare(VoidPtrHashCompare&&) = default; + VoidPtrHashCompare& operator=(VoidPtrHashCompare&&) = default; + ~VoidPtrHashCompare() = default; +}; + + +struct VoidPtrTypeInfo { + voidptr_refcnt incref; + voidptr_refcnt decref; + uint64_t size; + + VoidPtrTypeInfo(void* incref_addr, void* decref_addr, uint64_t val_size) { + incref = reinterpret_cast(incref_addr); + decref = reinterpret_cast(decref_addr); + size = val_size; + } + + VoidPtrTypeInfo() = delete; + VoidPtrTypeInfo(const VoidPtrTypeInfo&) = default; + VoidPtrTypeInfo& operator=(const VoidPtrTypeInfo&) = default; + VoidPtrTypeInfo& operator=(VoidPtrTypeInfo&&) = default; + ~VoidPtrTypeInfo() = default; + + void delete_voidptr(void* ptr_data) { + this->decref(ptr_data); + free(ptr_data); + } +}; + + +template, + typename Equality=std::equal_to +> +class NumericHashmapType { +public: + using map_type = typename tbb::concurrent_unordered_map; + using iterator_type = typename map_type::iterator; + map_type map; + + NumericHashmapType() + : map(0, Hasher(), Equality()) {} + // TO-DO: support copying for all hashmaps and ConcurrentDict's .copy() method in python? 
+ NumericHashmapType(const NumericHashmapType&) = delete; + NumericHashmapType& operator=(const NumericHashmapType&) = delete; + NumericHashmapType(NumericHashmapType&& rhs) = delete; + NumericHashmapType& operator=(NumericHashmapType&& rhs) = delete; + ~NumericHashmapType() {} + + uint64_t size() { + return this->map.size(); + } + + void set(Key key, Val val) { + this->map[key] = val; + } + + int8_t contains(Key key) { + auto it = this->map.find(key); + return it != this->map.end(); + } + int8_t lookup(Key key, Val* res) { + auto it = this->map.find(key); + bool found = it != this->map.end(); + if (found) + *res = (*it).second; + + return found; + } + void clear() { + this->map.clear(); + } + + int8_t pop(Key key, Val* res) { + auto node_handle = this->map.unsafe_extract(key); + auto found = !node_handle.empty(); + if (found) + *res = node_handle.mapped(); + + return found; + } + + void update(NumericHashmapType& other) { + this->map.merge(other.map); + } + + void* getiter() { + auto p_it = new iterator_type(this->map.begin()); + auto state = new iter_state((void*)p_it, (void*)this); + return state; + } +}; + + +template +> +class GenericHashmapBase { +public: + using map_type = typename tbb::concurrent_hash_map; + using iterator_type = typename map_type::iterator; + map_type map; + + // FIXME: 0 default size is suboptimal, can we optimize this? 
+ GenericHashmapBase() : map(0, HashCompare()) {} + GenericHashmapBase(const HashCompare& hash_compare) : map(0, hash_compare) {} + + GenericHashmapBase(const GenericHashmapBase&) = delete; + GenericHashmapBase& operator=(const GenericHashmapBase&) = delete; + GenericHashmapBase(GenericHashmapBase&& rhs) = delete; + GenericHashmapBase& operator=(GenericHashmapBase&& rhs) = delete; + virtual ~GenericHashmapBase() { + } + + uint64_t size() { + return this->map.size(); + } + + int8_t contains(Key key) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + result.release(); + } + return found; + } + + int8_t lookup(Key key, Val* res) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + if (found) + *res = result->second; + result.release(); + } + + return found; + } + + virtual void set(Key key, Val val) = 0; + + void update(GenericHashmapBase& other) { + tbb::parallel_for( + other.map.range(), + [this](const typename map_type::range_type& r) { + for (typename map_type::iterator i = r.begin(); i != r.end(); ++i) { + this->set(i->first, i->second); + } + }); + } + + void* getiter() { + auto p_it = new iterator_type(this->map.begin()); + auto state = new iter_state((void*)p_it, (void*)this); + return state; + } +}; + + +/* primary template for GenericHashmapType */ +template +> +class GenericHashmapType : public GenericHashmapBase { +public: + // TO-DO: make VoidPtrTypeInfo templates and unify modifiers impl via calls to template funcs + using map_type = typename GenericHashmapBase::map_type; + VoidPtrTypeInfo key_info; + VoidPtrTypeInfo val_info; + + GenericHashmapType(const VoidPtrTypeInfo& ki, + const VoidPtrTypeInfo& vi, + const HashCompare& hash_compare) + : GenericHashmapBase(hash_compare), + key_info(ki), + val_info(vi) {} + GenericHashmapType(const VoidPtrTypeInfo& ki, const VoidPtrTypeInfo& vi) : GenericHashmapType(ki, vi, HashCompare()) {} + + 
GenericHashmapType() = delete; + GenericHashmapType(const GenericHashmapType&) = delete; + GenericHashmapType& operator=(const GenericHashmapType&) = delete; + GenericHashmapType(GenericHashmapType&& rhs) = delete; + GenericHashmapType& operator=(GenericHashmapType&& rhs) = delete; + virtual ~GenericHashmapType() {}; + + + void clear() { + this->map.clear(); + } + + int8_t pop(Key key, void* res) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + if (found) + { + memcpy(res, &(result->second), this->val_info.size); + this->map.erase(result); + } + result.release(); + } + + return found; + } + + void set(Key key, Val val) { + typename map_type::value_type inserted_node(key, val); + { + typename map_type::accessor existing_node; + bool ok = this->map.insert(existing_node, inserted_node); + if (!ok) + { + // insertion failed key already exists + existing_node->second = val; + } + } + } +}; + + +/* generic-value partial specialization */ +template +class GenericHashmapType : public GenericHashmapBase { +public: + using map_type = typename GenericHashmapBase::map_type; + VoidPtrTypeInfo key_info; + VoidPtrTypeInfo val_info; + + GenericHashmapType(const VoidPtrTypeInfo& ki, + const VoidPtrTypeInfo& vi, + const HashCompare& hash_compare) + : GenericHashmapBase(hash_compare), + key_info(ki), + val_info(vi) {} + GenericHashmapType(const VoidPtrTypeInfo& ki, const VoidPtrTypeInfo& vi) : GenericHashmapType(ki, vi, HashCompare()) {} + + GenericHashmapType() = delete; + GenericHashmapType(const GenericHashmapType&) = delete; + GenericHashmapType& operator=(const GenericHashmapType&) = delete; + GenericHashmapType(GenericHashmapType&& rhs) = delete; + GenericHashmapType& operator=(GenericHashmapType&& rhs) = delete; + virtual ~GenericHashmapType() {}; + + void clear(); + virtual void set(Key key, void* val) override; + int8_t pop(Key key, void* val); +}; + + +/* generic-key partial specialization */ +template +class 
GenericHashmapType : public GenericHashmapBase { +public: + using map_type = typename GenericHashmapBase::map_type; + VoidPtrTypeInfo key_info; + VoidPtrTypeInfo val_info; + + GenericHashmapType(const VoidPtrTypeInfo& ki, + const VoidPtrTypeInfo& vi, + const VoidPtrHashCompare& hash_compare) + : GenericHashmapBase(hash_compare), + key_info(ki), + val_info(vi) {} + + GenericHashmapType() = delete; + GenericHashmapType(const GenericHashmapType&) = delete; + GenericHashmapType& operator=(const GenericHashmapType&) = delete; + GenericHashmapType(GenericHashmapType&& rhs) = delete; + GenericHashmapType& operator=(GenericHashmapType&& rhs) = delete; + virtual ~GenericHashmapType() {}; + + void clear(); + virtual void set(void* key, Val val) override; + int8_t pop(void* key, void* val); +}; + + +/* generic-key-and-value partial specialization */ +template<> +class GenericHashmapType : public GenericHashmapBase { +public: + using map_type = typename GenericHashmapType::map_type; + VoidPtrTypeInfo key_info; + VoidPtrTypeInfo val_info; + + GenericHashmapType(const VoidPtrTypeInfo& ki, + const VoidPtrTypeInfo& vi, + const VoidPtrHashCompare& hash_compare) + : GenericHashmapBase(hash_compare), + key_info(ki), + val_info(vi) {} + + GenericHashmapType() = delete; + GenericHashmapType(const GenericHashmapType&) = delete; + GenericHashmapType& operator=(const GenericHashmapType&) = delete; + GenericHashmapType(GenericHashmapType&& rhs) = delete; + GenericHashmapType& operator=(GenericHashmapType&& rhs) = delete; + virtual ~GenericHashmapType() {}; + + void clear(); + virtual void set(void* key, void* val) override; + int8_t pop(void* key, void* val); +}; + + +template +using numeric_hashmap = NumericHashmapType; + +template +using generic_key_hashmap = GenericHashmapType; + +template > +using generic_value_hashmap = GenericHashmapType; + +using generic_hashmap = GenericHashmapType; + + +template +numeric_hashmap* +reinterpet_hashmap_ptr(void* p_hash_map, + typename std::enable_if< 
+ !std::is_same::value && + !std::is_same::value>::type* = 0) +{ + return reinterpret_cast*>(p_hash_map); +} + +template +generic_hashmap* +reinterpet_hashmap_ptr(void* p_hash_map, + typename std::enable_if< + std::is_same::value && + std::is_same::value>::type* = 0) +{ + return reinterpret_cast(p_hash_map); +} + +template +generic_value_hashmap* +reinterpet_hashmap_ptr(void* p_hash_map, + typename std::enable_if< + !std::is_same::value && + std::is_same::value>::type* = 0) +{ + return reinterpret_cast*>(p_hash_map); +} + +template +generic_key_hashmap* +reinterpet_hashmap_ptr(void* p_hash_map, + typename std::enable_if< + std::is_same::value && + !std::is_same::value>::type* = 0) +{ + return reinterpret_cast*>(p_hash_map); +} + + +template +void delete_generic_key_hashmap(void* p_hash_map) +{ + auto p_hash_map_spec = (generic_key_hashmap*)p_hash_map; + for (auto kv_pair: p_hash_map_spec->map) { + p_hash_map_spec->key_info.delete_voidptr(kv_pair.first); + } + delete p_hash_map_spec; +} + +template +void delete_generic_value_hashmap(void* p_hash_map) +{ + + auto p_hash_map_spec = (generic_value_hashmap*)p_hash_map; + for (auto kv_pair: p_hash_map_spec->map) { + p_hash_map_spec->val_info.delete_voidptr(kv_pair.second); + } + delete p_hash_map_spec; +} + +void delete_generic_hashmap(void* p_hash_map) +{ + auto p_hash_map_spec = (generic_hashmap*)p_hash_map; + for (auto kv_pair: p_hash_map_spec->map) { + p_hash_map_spec->key_info.delete_voidptr(kv_pair.first); + p_hash_map_spec->val_info.delete_voidptr(kv_pair.second); + } + delete p_hash_map_spec; +} + +template +void delete_numeric_hashmap(void* p_hash_map) +{ + + auto p_hash_map_spec = (numeric_hashmap*)p_hash_map; + delete p_hash_map_spec; +} + + +template +void delete_iter_state(void* p_iter_state) +{ + auto p_iter_state_spec = reinterpret_cast(p_iter_state); + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_iter_state_spec->second); + using itertype = typename std::remove_reference::type::iterator_type; + auto 
p_hash_map_iter = reinterpret_cast(p_iter_state_spec->first); + + delete p_hash_map_iter; + delete p_iter_state_spec; +} + + +template +void GenericHashmapType::set(Key key, void* val) +{ + auto vsize = this->val_info.size; + void* _val = malloc(vsize); + memcpy(_val, val, vsize); + + typename map_type::value_type inserted_node(key, _val); + { + typename map_type::accessor existing_node; + bool ok = this->map.insert(existing_node, inserted_node); + if (ok) + { + // insertion succeeded need to incref value + this->val_info.incref(val); + } + else + { + // insertion failed key already exists + this->val_info.delete_voidptr(existing_node->second); + existing_node->second = _val; + this->val_info.incref(val); + } + } +} + +template +void GenericHashmapType::clear() +{ + for (auto kv_pair: this->map) { + this->val_info.delete_voidptr(kv_pair.second); + } + this->map.clear(); +} + + +template +int8_t GenericHashmapType::pop(Key key, void* res) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + if (found) + { + memcpy(res, result->second, this->val_info.size); + free(result->second); + // no decref for value since it would be returned (and no incref on python side!) 
+ this->map.erase(result); + } + result.release(); + } + + return found; +} + + +template +void GenericHashmapType::set(void* key, Val val) +{ + auto ksize = this->key_info.size; + void* _key = malloc(ksize); + memcpy(_key, key, ksize); + + typename map_type::value_type inserted_node(_key, val); + { + typename map_type::accessor existing_node; + bool ok = this->map.insert(existing_node, inserted_node); + if (ok) + { + // insertion succeeded need to incref key + this->key_info.incref(key); + } + else + { + // insertion failed key already exists + free(_key); + existing_node->second = val; + } + } +} + +template +void GenericHashmapType::clear() +{ + for (auto kv_pair: this->map) { + this->key_info.delete_voidptr(kv_pair.first); + } + this->map.clear(); +} + +template +int8_t GenericHashmapType::pop(void* key, void* res) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + if (found) + { + memcpy(res, &(result->second), this->val_info.size); + this->key_info.delete_voidptr(result->first); + // no decref for value since it would be returned (and no incref on python side!) 
+ this->map.erase(result); + } + result.release(); + } + + return found; +} + + +void GenericHashmapType::set(void* key, void* val) + +{ + auto ksize = this->key_info.size; + void* _key = malloc(ksize); + memcpy(_key, key, ksize); + + auto vsize = this->val_info.size; + void* _val = malloc(vsize); + memcpy(_val, val, vsize); + + typename map_type::value_type inserted_node(_key, _val); + { + typename map_type::accessor existing_node; + bool ok = this->map.insert(existing_node, inserted_node); + if (ok) + { + this->key_info.incref(key); + this->val_info.incref(val); + } + else + { + // insertion failed key already exists + free(_key); + + this->val_info.delete_voidptr(existing_node->second); + existing_node->second = _val; + this->val_info.incref(val); + } + } +} + +void GenericHashmapType::clear() +{ + for (auto kv_pair: this->map) { + this->key_info.delete_voidptr(kv_pair.first); + this->val_info.delete_voidptr(kv_pair.second); + } + this->map.clear(); +} + + +int8_t GenericHashmapType::pop(void* key, void* res) { + bool found = false; + { + typename map_type::const_accessor result; + found = this->map.find(result, key); + if (found) + { + memcpy(res, result->second, this->val_info.size); + + free(result->second); + this->key_info.delete_voidptr(result->first); + // no decref for value since it would be returned (and no incref on python side!) 
+ this->map.erase(result); + } + result.release(); + } + + return found; +} + + +template +void hashmap_create(NRT_MemInfo** meminfo, + void* nrt_table, + int8_t gen_key, + int8_t gen_val, + void* hash_func_ptr, + void* eq_func_ptr, + void* key_incref_func_ptr, + void* key_decref_func_ptr, + void* val_incref_func_ptr, + void* val_decref_func_ptr, + uint64_t key_size, + uint64_t val_size) +{ + auto nrt = (NRT_api_functions*)nrt_table; + + // it is essential for all specializations to have common ctor signature, taking both key_info and val_info + // since all specializations should be instantiable with different key_type/value_type, so e.g. + // generic_key_hashmap with val_type = void* would match full specialization. TO-DO: consider refactoring + auto key_info = VoidPtrTypeInfo(key_incref_func_ptr, key_decref_func_ptr, key_size); + auto val_info = VoidPtrTypeInfo(val_incref_func_ptr, val_decref_func_ptr, val_size); + if (gen_key && gen_val) + { + auto p_hash_map = new generic_hashmap(key_info, val_info, VoidPtrHashCompare(hash_func_ptr, eq_func_ptr)); + (*meminfo) = nrt->manage_memory((void*)p_hash_map, delete_generic_hashmap); + } + else if (gen_key) + { + auto p_hash_map = new generic_key_hashmap(key_info, val_info, VoidPtrHashCompare(hash_func_ptr, eq_func_ptr)); + (*meminfo) = nrt->manage_memory((void*)p_hash_map, delete_generic_key_hashmap); + } + else if (gen_val) + { + auto p_hash_map = new generic_value_hashmap(key_info, val_info); + (*meminfo) = nrt->manage_memory((void*)p_hash_map, delete_generic_value_hashmap); + } + else + { + // numeric_hashmap is actually an instance of NumericHashmapType, not a specialization of + // GenericHashmapType since it's built upon tbb::concurrent_unordered_map. 
TO-DO: consider + // moving to one impl later if there's no performance penalty + auto p_hash_map = new numeric_hashmap; + (*meminfo) = nrt->manage_memory((void*)p_hash_map, delete_numeric_hashmap); + } + + return; +} + + +template +uint64_t hashmap_size(void* p_hash_map) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + return p_hash_map_spec->size(); +} + + +template +void hashmap_set(void* p_hash_map, key_type key, val_type val) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + p_hash_map_spec->set(key, val); +} + + +template +int8_t hashmap_lookup(void* p_hash_map, + key_type key, + val_type* res) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + return p_hash_map_spec->lookup(key, res); +} + + +template +void hashmap_clear(void* p_hash_map) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + p_hash_map_spec->clear(); +} + + +template +int8_t hashmap_unsafe_extract(void* p_hash_map, key_type key, val_type* res) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + return p_hash_map_spec->pop(key, res); +} + + +template +void hashmap_numeric_from_arrays(NRT_MemInfo** meminfo, void* nrt_table, key_type* keys, val_type* values, uint64_t size) +{ + auto nrt = (NRT_api_functions*)nrt_table; + auto p_hash_map = new numeric_hashmap; + (*meminfo) = nrt->manage_memory((void*)p_hash_map, delete_numeric_hashmap); + + // FIXME: apply arena to make this react on changing NUMBA_NUM_THREADS + tbb::parallel_for(tbb::blocked_range(0, size), + [=](const tbb::blocked_range& r) { + for(size_t i=r.begin(); i!=r.end(); ++i) { + auto kv_pair = std::pair(keys[i], values[i]); + p_hash_map->map.insert( + std::move(kv_pair) + ); + } + } + ); +} + + +template +void hashmap_update(void* p_self_hash_map, void* p_arg_hash_map) +{ + auto p_self_hash_map_spec = reinterpet_hashmap_ptr(p_self_hash_map); + auto p_arg_hash_map_spec = reinterpet_hashmap_ptr(p_arg_hash_map); + 
p_self_hash_map_spec->update(*p_arg_hash_map_spec); + return; +} + + +#ifdef SDC_DEBUG_NATIVE +template +void hashmap_dump(void* p_hash_map) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + auto size = p_hash_map_spec->map.size(); + std::cout << "Hashmap at: " << p_hash_map_spec << ", size = " << size << std::endl; + for (auto kv_pair: p_hash_map_spec->map) + { + std::cout << "key, value: " << kv_pair.first << ", " << kv_pair.second << std::endl; + } + return; +} +#endif + + +template +void* hashmap_getiter(NRT_MemInfo** meminfo, void* nrt_table, void* p_hash_map) +{ + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_hash_map); + auto p_iter_state = p_hash_map_spec->getiter(); + + auto nrt = (NRT_api_functions*)nrt_table; + (*meminfo) = nrt->manage_memory((void*)p_iter_state, delete_iter_state); + return p_iter_state; +} + + +template +int8_t hashmap_iternext(void* p_iter_state, key_type* ret_key, val_type* ret_val) +{ + auto p_iter_state_spec = reinterpret_cast(p_iter_state); + auto p_hash_map_spec = reinterpet_hashmap_ptr(p_iter_state_spec->second); + using itertype = typename std::remove_reference::type::iterator_type; + auto p_hash_map_iter = reinterpret_cast(p_iter_state_spec->first); + + int8_t status = 1; + if (*p_hash_map_iter != p_hash_map_spec->map.end()) + { + *ret_key = (*p_hash_map_iter)->first; + *ret_val = (*p_hash_map_iter)->second; + status = 0; + ++(*p_hash_map_iter); + } + + return status; +} diff --git a/sdc/tests/__init__.py b/sdc/tests/__init__.py index eeb4014b8..56775d77c 100644 --- a/sdc/tests/__init__.py +++ b/sdc/tests/__init__.py @@ -50,5 +50,7 @@ from sdc.tests.test_prange_utils import * from sdc.tests.test_compile_time import * +from sdc.tests.test_tbb_hashmap import * + # performance tests import sdc.tests.tests_perf diff --git a/sdc/tests/test_tbb_hashmap.py b/sdc/tests/test_tbb_hashmap.py new file mode 100644 index 000000000..45625b5ed --- /dev/null +++ b/sdc/tests/test_tbb_hashmap.py @@ -0,0 +1,1049 @@ +# 
***************************************************************************** +# Copyright (c) 2021, Intel Corporation All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ***************************************************************************** + +import numba +import numpy as np +import unittest + +from itertools import product, chain, filterfalse, groupby +from numba.core import types +from numba.core.errors import TypingError +from numba import prange +from numba.np.numpy_support import as_dtype +from numba.tests.support import MemoryLeakMixin +from sdc.extensions.sdc_hashmap_type import ConcurrentDict +from sdc.tests.test_base import TestCase +from sdc.tests.test_series import test_global_input_data_float64 +from sdc.tests.test_utils import gen_strlist + +from sdc.extensions.sdc_hashmap_ext import ( + supported_numeric_key_types, + supported_numeric_value_types, + ) + + +int_limits_min = list(map(lambda x: np.iinfo(x).min, ['int32', 'int64', 'uint32', 'uint64'])) +int_limits_max = list(map(lambda x: np.iinfo(x).max, ['int32', 'int64', 'uint32', 'uint64'])) + + +global_test_cdict_values = { + types.Integer: [0, -5, 17] + int_limits_min + int_limits_max, + types.Float: list(chain.from_iterable(test_global_input_data_float64)), + types.UnicodeType: ['a1', 'a2', 'b', 'sdf', 'q', 're', 'fde', ''] +} + + +def assert_dict_correct(self, result, fromdata): + """ This function checks that result's keys and values match data from which it was created, + i.e. 
keys match strictly and all values are associated with the same key in fromdata """ + + self.assertEqual(set(result.keys()), set(dict(fromdata).keys())) + + def key_func(x): + return x[0] + + fromdata = sorted(fromdata, key=key_func) + for k, g in groupby(fromdata, key_func): + v = result[k] + group_values_arr = np.array(list(zip(*g))[1]) + group_values_set = set(group_values_arr) + if isinstance(v, float) and np.isnan(v): + self.assertTrue(any(np.isnan(group_values_arr)), + f"result[{k}] == {v} not found in source values") + else: + self.assertIn(v, group_values_set) + + +class TestHashmapNumeric(MemoryLeakMixin, TestCase): + """ Verifies correctness of numeric implementation TBB based hashmap, + i.e. specialization that is selected when both keys and values are numeric. """ + + key_types = supported_numeric_key_types + value_types = supported_numeric_value_types + + _default_key = 7 + _default_value = 25 + + def get_default_key(self, nbtype): + return as_dtype(nbtype).type(self._default_key) + + def get_default_value(self, nbtype): + return as_dtype(nbtype).type(self._default_value) + + def get_random_sequence(self, nbtype, n=10): + assert isinstance(nbtype, types.Number), "Non-numeric type used in TestHashmapNumeric" + values = global_test_cdict_values[type(nbtype)] + return np.random.choice(values, n).astype(str(nbtype)) + + # **************************** Tests start **************************** # + + def test_hashmap_numeric_create_empty(self): + + @self.jit + def test_impl(key_type, value_type): + a_dict = ConcurrentDict.empty(key_type, value_type) + return len(a_dict) + + expected_res = 0 + for key_type, value_type in product(self.key_types, self.value_types): + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(key_type, value_type), expected_res) + + def test_hashmap_numeric_create_from_arrays(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.from_arrays(keys, values) + res = 
list(a_dict.items()) # this relies on working iterator + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + source_kv_pairs = list(zip(keys, values)) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + result = test_impl(keys, values) + assert_dict_correct(self, dict(result), source_kv_pairs) + + def test_hashmap_numeric_create_from_typed_dict(self): + + # FIXME_Numba#XXXX: iterating through typed.Dict fails memleak checks! + self.disable_leak_check() + + from numba.typed import Dict + + @self.jit + # FIXME: we still need to implement key_type and value_type properties?? + def test_impl(tdict, key_type, value_type): + a_dict = ConcurrentDict.empty(key_type, value_type) + for k, v in tdict.items(): + a_dict[k] = v + + res = list(a_dict.items()) # this relies on working iterator + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + tdict = Dict.empty(key_type, value_type) + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + source_kv_pairs = list(zip(keys, values)) + for k, v in source_kv_pairs: + tdict[k] = v + with self.subTest(key_type=key_type, value_type=value_type, tdict=tdict): + result = test_impl(tdict, key_type, value_type) + assert_dict_correct(self, dict(result), source_kv_pairs) + + def test_hashmap_numeric_insert(self): + + @self.jit + def test_impl(key_type, value_type, key, value): + a_dict = ConcurrentDict.empty(key_type, value_type) + a_dict[key] = value + return len(a_dict), a_dict[key] + + for key_type, value_type in product(self.key_types, self.value_types): + + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual( + 
test_impl(key_type, value_type, _key, _value), + (1, _value) + ) + + def test_hashmap_numeric_set_value(self): + + @self.jit + def test_impl(key, value, new_value): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, ]), + ) + + a_dict[key] = new_value + return a_dict[key] + + new_value = 11 + for key_type, value_type in product(self.key_types, self.value_types): + + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + _new_value = as_dtype(value_type).type(new_value) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value, _new_value), _new_value) + + def test_hashmap_numeric_lookup_positive(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, ]), + ) + return a_dict[key] + + for key_type, value_type in product(self.key_types, self.value_types): + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), _value) + + def test_hashmap_numeric_lookup_negative(self): + + # this is common for all Numba tests that check exceptions are raised + self.disable_leak_check() + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, ]), + ) + + return a_dict[2*key] + + for key_type, value_type in product(self.key_types, self.value_types): + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + with self.assertRaises(KeyError) as raises: + test_impl(_key, _value) + msg = 'ConcurrentDict key not found' + self.assertIn(msg, str(raises.exception)) + + def test_hashmap_numeric_contains(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, 
]), + ) + return key in a_dict, 2*key in a_dict + + expected_res = (True, False) + for key_type, value_type in product(self.key_types, self.value_types): + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), expected_res) + + def test_hashmap_numeric_pop(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, ]), + ) + a_dict.pop(key) + return len(a_dict), a_dict.get(key, None) + + expected_res = (0, None) + for key_type, value_type in product(self.key_types, self.value_types): + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), expected_res) + + def test_hashmap_numeric_clear(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.from_arrays(keys, values) + r1 = len(a_dict) + a_dict.clear() + r2 = len(a_dict) + return r1, r2 + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + expected_res = (len(set(keys)), 0) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + self.assertEqual(test_impl(keys, values), expected_res) + + def test_hashmap_numeric_get(self): + + @self.jit + def test_impl(key, value, default): + a_dict = ConcurrentDict.from_arrays( + np.array([key, ]), + np.array([value, ]), + ) + r1 = a_dict.get(key, None) + r2 = a_dict.get(2*key, default) + r3 = a_dict.get(2*key) + return r1, r2, r3 + + default_value = 0 + for key_type, value_type in product(self.key_types, self.value_types): + _key = self.get_default_key(key_type) + _value = self.get_default_value(value_type) + _default = 
value_type(default_value) + expected_res = (_value, _default, None) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value, _default), expected_res) + + def test_hashmap_numeric_insert_implicit_cast(self): + + @self.jit + def test_impl(key_type, value_type, key, value): + a_dict = ConcurrentDict.empty(key_type, value_type) + a_dict[key] = value + return len(a_dict), key in a_dict + + key_type, value_type = types.int64, types.int64 + _key = np.dtype('int16').type(self._default_key) + _value = np.dtype('uint16').type(self._default_value) + expected_res = (1, True) + result = test_impl(key_type, value_type, _key, _value) + self.assertEqual(result, expected_res) + + def test_hashmap_numeric_insert_cast_fails(self): + + @self.jit + def test_impl(key_type, value_type, key, value): + a_dict = ConcurrentDict.empty(key_type, value_type) + a_dict[key] = value + return len(a_dict), key in a_dict + + key_type, value_type = types.int64, types.int64 + _key = np.dtype('float32').type(self._default_key) + _value = np.dtype('uint16').type(self._default_value) + with self.subTest(subtest='first', key_type=key_type, value_type=value_type): + with self.assertRaises(TypingError) as raises: + test_impl(key_type, value_type, _key, _value) + msg = 'TypingError: cannot safely cast' + self.assertIn(msg, str(raises.exception)) + + _key = np.dtype('uint16').type(self._default_key) + _value = np.dtype('float64').type(self._default_value) + with self.subTest(subtest='second', key_type=key_type, value_type=value_type): + with self.assertRaises(TypingError) as raises: + test_impl(key_type, value_type, _key, _value) + msg = 'TypingError: cannot safely cast' + self.assertIn(msg, str(raises.exception)) + + def test_hashmap_numeric_use_prange(self): + + @self.jit + def test_impl(key_type, value_type, keys, values): + a_dict = ConcurrentDict.empty(key_type, value_type) + for i in prange(len(keys)): + a_dict[keys[i]] = values[i] + + res = 
list(a_dict.items()) # this relies on working iterator + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + source_kv_pairs = list(zip(keys, values)) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + result = test_impl(key_type, value_type, keys, values) + assert_dict_correct(self, dict(result), source_kv_pairs) + + def test_hashmap_numeric_fromkeys_class(self): + + @self.jit + def test_impl(keys, value): + a_dict = ConcurrentDict.fromkeys(keys, value) + check_keys = np.array([k in a_dict for k in keys]) + return len(a_dict), np.all(check_keys) + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + value = self.get_default_value(value_type) + expected_res = (len(set(keys)), True) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, value=value): + self.assertEqual(test_impl(keys, value), expected_res) + + def test_hashmap_numeric_fromkeys_dictobject(self): + + @self.jit + def test_impl(keys, value): + a_dict = ConcurrentDict.empty(types.int64, types.float64) + res = a_dict.fromkeys(keys, value) + check_keys = np.array([k in res for k in keys]) + return len(res), np.all(check_keys), len(a_dict) + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + value = self.get_default_value(value_type) + expected_res = (len(set(keys)), True, 0) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, value=value): + self.assertEqual(test_impl(keys, value), expected_res) + + def test_hashmap_numeric_update(self): + + @self.jit + def test_impl(keys1, values1, keys2, values2): + a_dict = ConcurrentDict.from_arrays(keys1, values1) + other_dict = 
ConcurrentDict.from_arrays(keys2, values2) + r1 = len(a_dict) + a_dict.update(other_dict) + r2 = len(a_dict) + check_keys = np.array([k in a_dict for k in keys2]) + return r1, r2, np.all(check_keys) + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys1 = self.get_random_sequence(key_type, n) + keys2 = self.get_random_sequence(key_type, 2 * n) + values1 = self.get_random_sequence(value_type, n) + values2 = self.get_random_sequence(value_type, 2 * n) + before_size = len(set(keys1)) + after_size = len(set(keys1).union(set(keys2))) + expected_res = (before_size, after_size, True) + with self.subTest(key_type=key_type, value_type=value_type, + keys1=keys1, values1=values1, + keys2=keys2, values2=values2): + result = test_impl(keys1, values1, keys2, values2) + self.assertEqual(result, expected_res) + + def test_hashmap_numeric_iterator(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.from_arrays(keys, values) + res = [] + for k in a_dict: + res.append(k) + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + # expect a list of keys returned in some (i.e. 
non-fixed) order + result = test_impl(keys, values) + self.assertEqual(set(result), set(keys)) + + def test_hashmap_numeric_iterator_freed(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.from_arrays(keys, values) + dict_iter = iter(a_dict) + r1 = next(dict_iter) + r2 = next(dict_iter) + r3 = next(dict_iter) + return r1, r2, r3 + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + result = test_impl(keys, values) + self.assertTrue(set(result).issubset(set(keys)), + f"Some key ({result}) is not found in source keys: {keys}") + + def test_hashmap_numeric_keys(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.from_arrays(keys, values) + res = [] + for k in a_dict.keys(): + res.append(k) + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in product(self.key_types, self.value_types): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + # expect a list of keys returned in some (i.e. 
non-fixed) order
+                result = test_impl(keys, values)
+                self.assertEqual(set(result), set(keys))
+
+    def test_hashmap_numeric_items(self):
+
+        @self.jit
+        def test_impl(keys, values):
+            a_dict = ConcurrentDict.from_arrays(keys, values)
+            res = []
+            for k, v in a_dict.items():
+                res.append((k, v))
+            return res
+
+        n = 47
+        np.random.seed(0)
+
+        for key_type, value_type in product(self.key_types, self.value_types):
+            keys = self.get_random_sequence(key_type, n)
+            values = self.get_random_sequence(value_type, n)
+            source_kv_pairs = list(zip(keys, values))
+            with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values):
+                result = test_impl(keys, values)
+                assert_dict_correct(self, dict(result), source_kv_pairs)
+
+    def test_hashmap_numeric_values(self):
+
+        @self.jit
+        def test_impl(keys, values):
+            a_dict = ConcurrentDict.from_arrays(keys, values)
+            res = []
+            for v in a_dict.values():
+                res.append(v)
+            return res
+
+        n = 47
+        np.random.seed(0)
+
+        for key_type, value_type in product(self.key_types, self.value_types):
+            keys = self.get_random_sequence(key_type, n)
+            values = self.get_random_sequence(value_type, n)
+            expected_count = len(set(keys))
+            with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values):
+                result = test_impl(keys, values)
+                self.assertEqual(len(result), expected_count)
+
+
+class TestHashmapGeneric(MemoryLeakMixin, TestCase):
+    """ Verifies correctness of the following specializations:
+        generic-key hashmap, generic-value hashmap and generic-key-and-value.
+        Generic means objects are passed as void*. 
""" + + @classmethod + def key_value_combinations(cls): + res = filterfalse( + lambda x: isinstance(x[0], types.Number) and isinstance(x[1], types.Number), + product(cls.key_types, cls.value_types) + ) + return res + + key_types = [ + types.int32, + types.uint32, + types.int64, + types.uint64, + types.unicode_type, + ] + + value_types = [ + types.int32, + types.uint32, + types.int64, + types.uint64, + types.float32, + types.float64, + types.unicode_type, + ] + + _default_data = { + types.Integer: 11, + types.Float: 42.3, + types.UnicodeType: 'sdf', + } + + def get_default_scalar(self, nbtype): + meta_type = type(nbtype) + if isinstance(nbtype, types.Number): + res = as_dtype(nbtype).type(self._default_data[meta_type]) + elif isinstance(nbtype, types.UnicodeType): + res = self._default_data[meta_type] + return res + + # TO-DO: this looks too similar to gen_arr_of_dtype in perf_tests, re-use? + def get_random_sequence(self, nbtype, n=10): + if isinstance(nbtype, types.Number): + values = np.arange(n // 2, dtype=as_dtype(nbtype)) + res = np.random.choice(values, n) + elif isinstance(nbtype, types.UnicodeType): + values = gen_strlist(n // 2) + res = list(np.random.choice(values, n)) + return res + + # **************************** Tests start **************************** # + + def test_hashmap_generic_create_empty(self): + + @self.jit + def test_impl(key_type, value_type): + a_dict = ConcurrentDict.empty(key_type, value_type) + return len(a_dict) + + expected_res = 0 + for key_type, value_type in self.key_value_combinations(): + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(key_type, value_type), expected_res) + + def test_hashmap_generic_create_from_typed_dict(self): + + # FIXME_Numba#XXXX: iterating through typed.Dict fails memleak checks! 
+ self.disable_leak_check() + + from numba.typed import Dict + + @self.jit + def test_impl(tdict, key_type, value_type): + a_dict = ConcurrentDict.empty(key_type, value_type) + for k, v in tdict.items(): + a_dict[k] = v + + res = list(a_dict.items()) # this relies on working iterator + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + tdict = Dict.empty(key_type, value_type) + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + source_kv_pairs = list(zip(keys, values)) + for k, v in source_kv_pairs: + tdict[k] = v + with self.subTest(key_type=key_type, value_type=value_type, tdict=tdict): + result = test_impl(tdict, key_type, value_type) + assert_dict_correct(self, dict(result), source_kv_pairs) + + def test_hashmap_generic_insert(self): + + @self.jit + def test_impl(key_type, value_type, key, value): + a_dict = ConcurrentDict.empty(key_type, value_type) + a_dict[key] = value + return len(a_dict), a_dict[key] + + for key_type, value_type in self.key_value_combinations(): + + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual( + test_impl(key_type, value_type, _key, _value), + (1, _value) + ) + + def test_hashmap_generic_set_value(self): + + @self.jit + def test_impl(key, value, new_value): + a_dict = ConcurrentDict.fromkeys([key], value) + + a_dict[key] = new_value + return a_dict[key] + + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + _new_value = self.get_random_sequence(value_type)[0] + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value, _new_value), _new_value) + + def test_hashmap_generic_lookup_positive(self): + + @self.jit + def test_impl(key, value): + a_dict = 
ConcurrentDict.fromkeys([key], value) + return a_dict[key] + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), _value) + + def test_hashmap_generic_lookup_negative(self): + + # this is common for all Numba tests that check exceptions are raised + self.disable_leak_check() + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.fromkeys([key], value) + + return a_dict[2*key] + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + with self.assertRaises(KeyError) as raises: + test_impl(_key, _value) + msg = 'ConcurrentDict key not found' + self.assertIn(msg, str(raises.exception)) + + def test_hashmap_generic_contains(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.fromkeys([key], value) + return key in a_dict, 2*key in a_dict + + expected_res = (True, False) + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), expected_res) + + def test_hashmap_generic_pop_positive(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.fromkeys([key], value) + r1 = a_dict.pop(key) + return r1, len(a_dict), a_dict.get(key, None) + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + expected_res = (_value, 0, None) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), expected_res) + + def 
test_hashmap_generic_pop_negative(self): + + @self.jit + def test_impl(key, value): + a_dict = ConcurrentDict.fromkeys([key], value) + r1 = a_dict.pop(2*key) + r2 = a_dict.pop(2*key, 2*value) + return r1, r2 + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + expected_res = (None, 2*_value) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value), expected_res) + + def test_hashmap_generic_clear(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.fromkeys(keys, values[0]) + r1 = len(a_dict) + a_dict.clear() + r2 = len(a_dict) + return r1, r2 + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + expected_res = (len(set(keys)), 0) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + self.assertEqual(test_impl(keys, values), expected_res) + + def test_hashmap_generic_get(self): + + @self.jit + def test_impl(key, value, default): + a_dict = ConcurrentDict.fromkeys([key], value) + r1 = a_dict.get(key, None) + r2 = a_dict.get(2*key, default) + r3 = a_dict.get(2*key) + return r1, r2, r3 + + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + _key = self.get_default_scalar(key_type) + _value = self.get_default_scalar(value_type) + _default = self.get_random_sequence(value_type)[0] + expected_res = (_value, _default, None) + with self.subTest(key_type=key_type, value_type=value_type): + self.assertEqual(test_impl(_key, _value, _default), expected_res) + + def test_hashmap_generic_use_prange(self): + + @self.jit + def test_impl(key_type, value_type, keys, values): + a_dict = ConcurrentDict.empty(key_type, value_type) + for i in prange(len(keys)): + a_dict[keys[i]] = values[i] + + res = 
list(a_dict.items()) # this relies on working iterator + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + source_kv_pairs = list(zip(keys, values)) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + result = test_impl(key_type, value_type, keys, values) + assert_dict_correct(self, dict(result), source_kv_pairs) + + def test_hashmap_generic_fromkeys_class(self): + + @self.jit + def test_impl(keys, value): + a_dict = ConcurrentDict.fromkeys(keys, value) + check_keys = np.array([k in a_dict for k in keys]) + return len(a_dict), np.all(check_keys) + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + value = self.get_default_scalar(value_type) + expected_res = (len(set(keys)), True) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys): + self.assertEqual(test_impl(keys, value), expected_res) + + def test_hashmap_generic_fromkeys_dictobject(self): + + @self.jit + def test_impl(keys, value): + a_dict = ConcurrentDict.empty(types.int64, types.float64) + res = a_dict.fromkeys(keys, value) + check_keys = np.array([k in res for k in keys]) + return len(res), np.all(check_keys), len(a_dict) + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + value = self.get_default_scalar(value_type) + expected_res = (len(set(keys)), True, 0) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys): + self.assertEqual(test_impl(keys, value), expected_res) + + def test_hashmap_generic_update(self): + + @self.jit + def test_impl(keys1, values1, keys2, values2): + a_dict = ConcurrentDict.fromkeys(keys1, values1[0]) + for k, v in zip(keys1, values1): + a_dict[k] = v + + other_dict = 
ConcurrentDict.fromkeys(keys2, values2[0]) + for k, v in zip(keys2, values2): + other_dict[k] = v + + r1 = len(a_dict) + a_dict.update(other_dict) + r2 = len(a_dict) + check_keys = np.array([k in a_dict for k in keys2]) + + return r1, r2, np.all(check_keys) + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys1 = self.get_random_sequence(key_type, n) + keys2 = self.get_random_sequence(key_type, 2 * n) + values1 = self.get_random_sequence(value_type, n) + values2 = self.get_random_sequence(value_type, 2 * n) + before_size = len(set(keys1)) + after_size = len(set(keys1).union(set(keys2))) + expected_res = (before_size, after_size, True) + with self.subTest(key_type=key_type, value_type=value_type, + keys1=keys1, values1=values1, + keys2=keys2, values2=values2): + result = test_impl(keys1, values1, keys2, values2) + self.assertEqual(result, expected_res) + + def test_hashmap_generic_iterator(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.fromkeys(keys, values[0]) + for k, v in zip(keys, values): + a_dict[k] = v + + res = [] + for k in a_dict: + res.append(k) + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + # expect a list of keys returned in some (i.e. 
non-fixed) order + result = test_impl(keys, values) + self.assertEqual(set(result), set(keys)) + + def test_hashmap_generic_iterator_freed(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.fromkeys(keys, values[0]) + for k, v in zip(keys, values): + a_dict[k] = v + + dict_iter = iter(a_dict) + r1 = next(dict_iter) + r2 = next(dict_iter) + r3 = next(dict_iter) + return r1, r2, r3 + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + result = test_impl(keys, values) + self.assertTrue(set(result).issubset(set(keys)), + f"Some key ({result}) is not found in source keys: {keys}") + + def test_hashmap_generic_keys(self): + + @self.jit + def test_impl(keys, values): + a_dict = ConcurrentDict.fromkeys(keys, values[0]) + for k, v in zip(keys, values): + a_dict[k] = v + + res = [] + for k in a_dict.keys(): + res.append(k) + return res + + n = 47 + np.random.seed(0) + + for key_type, value_type in self.key_value_combinations(): + keys = self.get_random_sequence(key_type, n) + values = self.get_random_sequence(value_type, n) + with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values): + # expect a list of keys returned in some (i.e. 
non-fixed) order
+                result = test_impl(keys, values)
+                self.assertEqual(set(result), set(keys))
+
+    def test_hashmap_generic_items(self):
+
+        @self.jit
+        def test_impl(keys, values):
+            a_dict = ConcurrentDict.fromkeys(keys, values[0])
+            for k, v in zip(keys, values):
+                a_dict[k] = v
+
+            res = []
+            for k, v in a_dict.items():
+                res.append((k, v))
+            return res
+
+        n = 47
+        np.random.seed(0)
+
+        for key_type, value_type in self.key_value_combinations():
+            keys = self.get_random_sequence(key_type, n)
+            values = self.get_random_sequence(value_type, n)
+            source_kv_pairs = list(zip(keys, values))
+            with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values):
+                result = test_impl(keys, values)
+                assert_dict_correct(self, dict(result), source_kv_pairs)
+
+    def test_hashmap_generic_values(self):
+
+        @self.jit
+        def test_impl(keys, values):
+            a_dict = ConcurrentDict.fromkeys(keys, values[0])
+            for k, v in zip(keys, values):
+                a_dict[k] = v
+
+            res = []
+            for v in a_dict.values():
+                res.append(v)
+            return res
+
+        n = 47
+        np.random.seed(0)
+
+        for key_type, value_type in self.key_value_combinations():
+            keys = self.get_random_sequence(key_type, n)
+            values = self.get_random_sequence(value_type, n)
+            expected_count = len(set(keys))
+            with self.subTest(key_type=key_type, value_type=value_type, keys=keys, values=values):
+                result = test_impl(keys, values)
+                self.assertEqual(len(result), expected_count)
+
+    def test_hashmap_generic_tuple_keys(self):
+
+        @self.jit
+        def test_impl(key_type, value_type, keys, values):
+            a_dict = ConcurrentDict.empty(key_type, value_type)
+            for k, v in zip(keys, values):
+                a_dict[k] = v
+            return len(a_dict)
+
+        n = 47
+        np.random.seed(0)
+
+        key_values = list(product([0, 1], repeat=5))
+        key_type, value_type = numba.typeof(key_values[0]), types.int64
+
+        keys = [key_values[i] for i in np.random.randint(0, len(key_values), n)]
+        values = self.get_random_sequence(types.int64, n)
+        expected_res = 
len(set(keys)) + self.assertEqual( + test_impl(key_type, value_type, keys, values), + expected_res + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py index 9335178d0..a730e7373 100644 --- a/setup.py +++ b/setup.py @@ -221,7 +221,30 @@ def check_file_at_path(path2file): library_dirs=lid, ) -_ext_mods = [ext_hdist, ext_chiframes, ext_set, ext_str, ext_dt, ext_io, ext_transport_seq, ext_sort] +ext_conc_dict = Extension(name="sdc.hconc_dict", + sources=[ + "sdc/native/conc_dict_module.cpp", + "sdc/native/utils.cpp"], + extra_compile_args=eca, + extra_link_args=ela, + libraries=['tbb'], + include_dirs=[ + "sdc/native/", + numba_include_path, + os.path.join(tbb_root, 'include')], + library_dirs=lid + [ + # for Linux + os.path.join(tbb_root, 'lib', 'intel64', 'gcc4.4'), + # for MacOS + os.path.join(tbb_root, 'lib'), + # for Windows + os.path.join(tbb_root, 'lib', 'intel64', 'vc_mt'), + ], + language="c++" + ) + +_ext_mods = [ext_hdist, ext_chiframes, ext_set, ext_str, ext_dt, ext_io, ext_transport_seq, ext_sort, + ext_conc_dict] # Support of Parquet is disabled because HPAT pipeline does not work now # if _has_pyarrow: