diff --git a/sdc/utilities/prange_utils.py b/sdc/utilities/prange_utils.py index 2380bb513..19c7efbaf 100644 --- a/sdc/utilities/prange_utils.py +++ b/sdc/utilities/prange_utils.py @@ -29,7 +29,7 @@ import sdc from typing import NamedTuple -from sdc.utilities.utils import sdc_overload +from sdc.utilities.utils import sdc_overload, sdc_register_jitable class Chunk(NamedTuple): @@ -54,36 +54,21 @@ def get_pool_size_impl(): return get_pool_size_impl +@sdc_register_jitable def get_chunks(size, pool_size=0): if pool_size == 0: pool_size = get_pool_size() + pool_size = min(pool_size, size) chunk_size = (size - 1) // pool_size + 1 chunks = [] for i in range(pool_size): - start = min(i * chunk_size, size) + start = i * chunk_size + if start >= size: + break stop = min((i + 1) * chunk_size, size) + chunks.append(Chunk(start, stop)) return chunks - - -@sdc_overload(get_chunks) -def get_chunks_overload(size, pool_size=0): - def get_chunks_impl(size, pool_size=0): - if pool_size == 0: - pool_size = get_pool_size() - - chunk_size = (size - 1) // pool_size + 1 - - chunks = [] - for i in range(pool_size): - start = min(i * chunk_size, size) - stop = min((i + 1) * chunk_size, size) - chunk = Chunk(start, stop) - chunks.append(chunk) - - return chunks - - return get_chunks_impl