diff --git a/sdc/tests/__init__.py b/sdc/tests/__init__.py
index def2a3a4a..41236156a 100644
--- a/sdc/tests/__init__.py
+++ b/sdc/tests/__init__.py
@@ -45,6 +45,7 @@ from sdc.tests.test_hpat_jit import *
 
 
 from sdc.tests.test_sdc_numpy import *
+from sdc.tests.test_prange_utils import *
 
 # performance tests
 import sdc.tests.tests_perf
diff --git a/sdc/tests/test_prange_utils.py b/sdc/tests/test_prange_utils.py
new file mode 100644
index 000000000..5ae0f629b
--- /dev/null
+++ b/sdc/tests/test_prange_utils.py
@@ -0,0 +1,93 @@
+# *****************************************************************************
+# Copyright (c) 2020, Intel Corporation All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#     Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#     Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# *****************************************************************************
+
+from sdc.tests.test_base import TestCase
+
+from sdc.utilities.prange_utils import get_chunks, Chunk
+
+
+class ChunkTest(TestCase):
+    """Check get_chunks both as plain Python and jit-compiled."""
+
+    def _get_chunks_data(self):
+        """Yield ((size, pool_size), expected_chunks) cases."""
+        # Degenerate inputs must give one chunk, not ZeroDivisionError.
+        yield (0, 5), [
+            Chunk(start=0, stop=0),
+        ]
+        yield (9, 0), [
+            Chunk(start=0, stop=9),
+        ]
+        yield (5, 5), [
+            Chunk(start=0, stop=1),
+            Chunk(start=1, stop=2),
+            Chunk(start=2, stop=3),
+            Chunk(start=3, stop=4),
+            Chunk(start=4, stop=5),
+        ]
+        yield (5, 6), [
+            Chunk(start=0, stop=1),
+            Chunk(start=1, stop=2),
+            Chunk(start=2, stop=3),
+            Chunk(start=3, stop=4),
+            Chunk(start=4, stop=5),
+        ]
+        yield (9, 5), [
+            Chunk(start=0, stop=2),
+            Chunk(start=2, stop=4),
+            Chunk(start=4, stop=6),
+            Chunk(start=6, stop=8),
+            Chunk(start=8, stop=9),
+        ]
+        yield (9, 4), [
+            Chunk(start=0, stop=3),
+            Chunk(start=3, stop=5),
+            Chunk(start=5, stop=7),
+            Chunk(start=7, stop=9),
+        ]
+        yield (9, 2), [
+            Chunk(start=0, stop=5),
+            Chunk(start=5, stop=9),
+        ]
+        yield (9, 3), [
+            Chunk(start=0, stop=3),
+            Chunk(start=3, stop=6),
+            Chunk(start=6, stop=9),
+        ]
+
+    def _check_get_chunks(self, args, expected_chunks):
+        """Assert both variants of get_chunks return expected_chunks."""
+        pyfunc = get_chunks
+        cfunc = self.jit(pyfunc)
+
+        self.assertEqual(pyfunc(*args), expected_chunks)
+        self.assertEqual(cfunc(*args), expected_chunks)
+
+    def test_get_chunks(self):
+        """Run every case from _get_chunks_data as a subTest."""
+        for args, expected_chunks in self._get_chunks_data():
+            with self.subTest(args=args):
+                self._check_get_chunks(args, expected_chunks)
diff --git a/sdc/utilities/prange_utils.py b/sdc/utilities/prange_utils.py
index 2380bb513..0f24d34ea 100644
--- a/sdc/utilities/prange_utils.py
+++ b/sdc/utilities/prange_utils.py
@@ -29,7 +29,7 @@ import sdc
 
 from typing import NamedTuple
 
-from sdc.utilities.utils import sdc_overload
+from sdc.utilities.utils import sdc_overload, sdc_register_jitable
 
 
 class Chunk(NamedTuple):
@@ -37,6 +37,8 @@ class Chunk(NamedTuple):
     stop: int
 
 
+@sdc_register_jitable
 def get_pool_size():
+    """Return NUMBA_NUM_THREADS if parallel overloads are enabled, else 1."""
     if sdc.config.config_use_parallel_overloads:
         return numba.config.NUMBA_NUM_THREADS
@@ -44,46 +46,32 @@ def get_pool_size():
     return 1
 
 
-@sdc_overload(get_pool_size)
-def get_pool_size_overload():
-    pool_size = get_pool_size()
-
-    def get_pool_size_impl():
-        return pool_size
-
-    return get_pool_size_impl
-
-
-def get_chunks(size, pool_size=0):
-    if pool_size == 0:
-        pool_size = get_pool_size()
-
-    chunk_size = (size - 1) // pool_size + 1
+@sdc_register_jitable
+def get_chunks(size, pool_size):
+    """Split ``range(size)`` into contiguous chunks of near-equal length.
+
+    Returns a list of ``max(1, min(pool_size, size))`` half-open
+    ``Chunk(start, stop)`` ranges that together cover ``[0, size)``. Chunk
+    lengths differ by at most one, with the longer chunks placed first.
+    ``size == 0`` yields the single empty chunk ``[Chunk(0, 0)]``.
+    """
+    # Keep at least one chunk: with size == 0 (or pool_size < 1) the divisor
+    # below would be zero and raise ZeroDivisionError.
+    pool_size = max(1, min(pool_size, size))
+    chunk_size = size // pool_size
+    overload_size = size % pool_size
 
     chunks = []
     for i in range(pool_size):
-        start = min(i * chunk_size, size)
-        stop = min((i + 1) * chunk_size, size)
+        # The first overload_size chunks each hold one extra element.
+        start = i * chunk_size + min(i, overload_size)
+        stop = (i + 1) * chunk_size + min(i + 1, overload_size)
         chunks.append(Chunk(start, stop))
 
     return chunks
 
 
-@sdc_overload(get_chunks)
-def get_chunks_overload(size, pool_size=0):
-    def get_chunks_impl(size, pool_size=0):
-        if pool_size == 0:
-            pool_size = get_pool_size()
-
-        chunk_size = (size - 1) // pool_size + 1
-
-        chunks = []
-        for i in range(pool_size):
-            start = min(i * chunk_size, size)
-            stop = min((i + 1) * chunk_size, size)
-            chunk = Chunk(start, stop)
-            chunks.append(chunk)
-
-        return chunks
-
-    return get_chunks_impl
+@sdc_register_jitable
+def parallel_chunks(size):
+    """Split ``range(size)`` into chunks for a pool of ``get_pool_size()``."""
+    return get_chunks(size, get_pool_size())