From b4e47b78daa82847c11e6ee42827da0ea93ed315 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Mon, 27 Jun 2022 15:57:24 +0200 Subject: [PATCH 1/7] Add warnings against suspicious or erroneous end of epoch conditions Signed-off-by: Kamil Tokarski --- dali/python/nvidia/dali/pipeline.py | 20 ++++++++++++++++++- .../nvidia/dali/plugin/base_iterator.py | 9 +++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index 196fbb4b9e..b770fd6286 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -920,6 +920,12 @@ def schedule_run(self): Should not be mixed with :meth:`run` in the same pipeline""" with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED): if self._first_iter and self._exec_pipelined: + if not self.empty(): + warnings.warn( + "Prefetching data into a non-empty pipeline may result in corrupted " + "outputs. Please make sure all batches from previous epoch are consumed " + "before scheduling work for a new epoch.", + Warning) self._prefetch() else: self._run_once() @@ -1078,7 +1084,19 @@ def reset(self): If pipeline iterator reached the end then reset its state to the beginning. """ - if self._last_iter: + if not self._last_iter: + # resetting before some external source raised StopIteration is a no-op + if self._input_callbacks: + warnings.warn( + "Resetting the pipeline before any of the external sources reached " + "the end of epoch (i.e. raised StopIteration) has no effect.", + Warning) + else: + if not self.empty(): + warnings.warn( + "Resetting the pipeline before all scheduled batches have been consumed " + "is discouraged and may be unsupported in the future.", + Warning) self._first_iter = True self._last_iter = False self._iter = 0 diff --git a/dali/python/nvidia/dali/plugin/base_iterator.py b/dali/python/nvidia/dali/plugin/base_iterator.py index d6fa9ca5de..ab34fb7877 100644 --- a/dali/python/nvidia/dali/plugin/base_iterator.py +++ b/dali/python/nvidia/dali/plugin/base_iterator.py @@ -299,6 +299,15 @@ def _get_outputs(self): # in case ExternalSource returns StopIteration if self._size < 0 and self._auto_reset == "yes": self.reset() + if self._size >= 0: + warnings.warn( + f"Pipeline unexpectedly raised StopIteration before reaching the end of " + f"dataset. There were {self._counter} samples returned in this epoch, but " + f"{self.size} was passed as `size`. Please verify the `size` value or " + f"consider alternatives. For DALI readers, please use `reader_name` instead. " + f"For external source, you may rely solely on raising StopIteration " + f"from the source.", + Warning) raise e self._check_batch_size(outputs) return outputs From 2ed59da1c2c525533179b19a0aaab5fefe2ffd59 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Mon, 27 Jun 2022 18:53:59 +0200 Subject: [PATCH 2/7] Add assert_no_warnings utility Signed-off-by: Kamil Tokarski --- dali/test/python/nose_utils.py | 39 ++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/dali/test/python/nose_utils.py b/dali/test/python/nose_utils.py index cb7e72d7b9..eb0d5f8b79 100644 --- a/dali/test/python/nose_utils.py +++ b/dali/test/python/nose_utils.py @@ -24,9 +24,11 @@ nose.loader.collections = collections.abc nose.suite.collections = collections.abc +import contextlib import nose.tools as tools import re import fnmatch +import warnings def glob_to_regex(glob): @@ -91,6 +93,43 @@ def assert_warns(exception=Warning, *args, glob=None, regex=None, match_case=Non return tools.assert_warns_regex(exception, pattern, *args, **kwargs) +@contextlib.contextmanager +def assert_no_warnings(exception=None, glob=None, regex=None, match_case=None): + msg_param_provided = any(param is not None for param in (glob, regex, match_case)) + pattern = None + if msg_param_provided: + pattern = get_pattern(glob, regex, match_case) + assert pattern is not None + pattern = pattern if isinstance(pattern, re.Pattern) else re.compile(pattern) + if exception is None: + exception = Warning + + with warnings.catch_warnings(record=True) as recorder_warnings: + try: + yield recorder_warnings + finally: + if exception is None: + if len(recorder_warnings): + raise AssertionError( + f"Test expected to emit no warnings emitted the following " + f"warnings: {[str(w) for w in recorder_warnings]}") + elif not msg_param_provided: + for m in recorder_warnings: + w = m.message + if isinstance(w, exception): + raise AssertionError( + f"Test expected to emit no warning of type {exception} emitted " + f"the following warning: {str(w)}") + else: + for m in recorder_warnings: + w = m.message + if isinstance(w, exception) and pattern.search(str(w)): + raise AssertionError( + f"Test was expected to emit no warning of type {exception} matching " + f"the pattern {pattern}, but the following warning was " + f"emitted: {str(w)}") + + def raises(exception, glob=None, regex=None, match_case=None): """ From c58c6994e50f53acfac591b8f8cd6041dda2d273 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Mon, 27 Jun 2022 19:10:48 +0200 Subject: [PATCH 3/7] Add test for .run API end of epoch warnings Signed-off-by: Kamil Tokarski --- .../python/test_external_source_parallel.py | 62 ++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/dali/test/python/test_external_source_parallel.py b/dali/test/python/test_external_source_parallel.py index f1856352db..89f6eff14f 100644 --- a/dali/test/python/test_external_source_parallel.py +++ b/dali/test/python/test_external_source_parallel.py @@ -18,7 +18,7 @@ from nvidia.dali.types import SampleInfo, BatchInfo import test_external_source_parallel_utils as utils -from nose_utils import raises +from nose_utils import raises, assert_warns, assert_no_warnings def no_arg_fun(): @@ -732,3 +732,63 @@ def test_permute_dataset(): for reader_queue_depth in (1, 5): yield _test_permute_dataset, batch_size, epoch_size, trailing_samples, \ cb, 4, 1, reader_queue_depth + + +@with_setup(utils.setup_function, utils.teardown_function) +def _test_no_op_reset_warning(stop_earlier, prefetch_queue_depth, source, batch, + batch_size, num_iterations, expected_warning): + + def run_pipeline(): + + @dali.pipeline_def + def pipeline(): + out = dali.fn.external_source(source=source, parallel=True, batch=batch) + return out + + pipe = pipeline( + batch_size=batch_size, device_id=0, num_threads=4, + prefetch_queue_depth=prefetch_queue_depth, + py_start_method="spawn") + pipe.build() + utils.capture_processes(pipe._py_pool) + for _ in range(num_iterations - stop_earlier): + pipe.run() + pipe.reset() + + if expected_warning is None: + with assert_no_warnings(): + run_pipeline() + else: + with assert_warns(Warning, glob=expected_warning): + run_pipeline() + + +def test_no_op_reset_warning(): + num_iterations = 5 + batch_size = 8 + + def gen_source(): + for i in range(num_iterations): + yield [np.full((1024, 1024), batch_size * i + j) for j in range(batch_size)] + + def cb_source(sample_info): + if sample_info.idx_in_epoch >= 42: + raise StopIteration + return np.full((5, 5), sample_info.idx_in_epoch) + + for source, batch in ((gen_source, True), (cb_source, False)): + for prefetch_queue_depth in (1, 2, 3): + for stop_earlier in range(prefetch_queue_depth): + if stop_earlier == prefetch_queue_depth - 1: + expected_warning = ( + "Resetting the pipeline before any of the external sources " + "reached the end of epoch (i.e. raised StopIteration) " + "has no effect.") + elif stop_earlier > 0: + expected_warning = ( + "Resetting the pipeline before all scheduled batches have been " + "consumed is discouraged and may be unsupported in the future.") + else: + expected_warning = None + yield _test_no_op_reset_warning, stop_earlier, prefetch_queue_depth, \ + source, batch, batch_size, num_iterations, expected_warning From 1d1bd8bb8e64191793db45d51c7f662f3e5c3634 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Tue, 28 Jun 2022 19:23:36 +0200 Subject: [PATCH 4/7] Test warnings against unexpected StopIteration in iterators Signed-off-by: Kamil Tokarski --- dali/python/nvidia/dali/pipeline.py | 2 + dali/test/python/test_fw_iterators.py | 89 +++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 6 deletions(-) diff --git a/dali/python/nvidia/dali/pipeline.py b/dali/python/nvidia/dali/pipeline.py index b770fd6286..f6ec873abf 100644 --- a/dali/python/nvidia/dali/pipeline.py +++ b/dali/python/nvidia/dali/pipeline.py @@ -920,6 +920,8 @@ def schedule_run(self): Should not be mixed with :meth:`run` in the same pipeline""" with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED): if self._first_iter and self._exec_pipelined: + # For one, when prefetching, parallel external source will reuse buffers + # that might be still referenced by no_copy input fed to pipeline if not self.empty(): warnings.warn( "Prefetching data into a non-empty pipeline may result in corrupted " diff --git a/dali/test/python/test_fw_iterators.py b/dali/test/python/test_fw_iterators.py index f22a809a70..303a87c873 100644 --- a/dali/test/python/test_fw_iterators.py +++ b/dali/test/python/test_fw_iterators.py @@ -20,7 +20,7 @@ import os from test_utils import get_dali_extra_path from nose.tools import nottest -from nose_utils import raises, assert_raises +from nose_utils import raises, assert_raises, assert_warns, assert_no_warnings from nvidia.dali.plugin.base_iterator import LastBatchPolicy as LastBatchPolicy import random @@ -1031,11 +1031,17 @@ def check_stop_iter(fw_iter, iterator_name, batch_size, epochs, iter_num, total_ iter_size = it.size loader = fw_iter(pipe, iter_size, auto_reset) count = 0 - for _ in range(epochs): - for _ in enumerate(loader): - count += 1 - if not auto_reset: - loader.reset() + if not infinite and total_iter_num >= 0 and epochs * iter_num > total_iter_num: + warning_check = assert_warns + else: + warning_check = assert_no_warnings + with warning_check(glob="Pipeline unexpectedly raised StopIteration before reaching " + "the end of dataset."): + for _ in range(epochs): + for _ in enumerate(loader): + count += 1 + if not auto_reset: + loader.reset() if total_iter_num < 0: # infinite source of data assert(count == iter_num * epochs) @@ -1044,6 +1050,41 @@ def check_stop_iter(fw_iter, iterator_name, batch_size, epochs, iter_num, total_ assert(count == min(total_iter_num, iter_num * epochs)) +def check_too_early_reset(fw_iter, auto_reset, prefetch_queue_depth): + num_source_iters = 10 + num_fw_iter_iters = num_source_iters - prefetch_queue_depth + 1 + batch_size = 6 + if prefetch_queue_depth == 2: + expected_warning = ("Resetting the pipeline before all scheduled batches have been " + "consumed is discouraged and may be unsupported in the future.") + else: + assert prefetch_queue_depth == 3 + expected_warning = ("Prefetching data into a non-empty pipeline may result " + "in corrupted outputs.") + + def source(): + for i in range(num_source_iters): + yield [np.array([i, j]) for j in range(batch_size)] + + @pipeline_def + def pipeline(): + return fn.external_source(source=source, parallel=True, cycle="raise") + + pipe = pipeline( + batch_size=batch_size, device_id=0, num_threads=4, + py_start_method="spawn", prefetch_queue_depth=prefetch_queue_depth) + pipe.build() + loader = fw_iter(pipe, num_fw_iter_iters * batch_size, auto_reset) + with assert_warns(glob=expected_warning): + for i, _ in enumerate(loader): + pass + assert i + 1 == num_fw_iter_iters, f"Expected {num_fw_iter_iters} iterations, got {i + 1}" + if not auto_reset: + loader.reset() + next(loader) + + + @raises(Exception, glob="Negative size is supported only for a single pipeline") def check_stop_iter_fail_multi(fw_iter): batch_size = 1 @@ -1190,6 +1231,15 @@ def fw_iter(pipe, size, auto_reset): return MXNetIterator( check_stop_iter_fail_single(fw_iter) +def test_too_early_reset_mxnet_warning(): + from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator + def fw_iter(pipe, size, auto_reset): return MXNetIterator( + pipe, [("data", MXNetIterator.DATA_TAG)], size=size, auto_reset=auto_reset) + for auto_reset in (True, False): + for prefetch_queue_depth in (2, 3): + yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth + + def test_mxnet_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator check_iterator_wrapper_first_iteration( @@ -1293,6 +1343,15 @@ def fw_iter(pipe, size, auto_reset): return GluonIterator( check_stop_iter_fail_single(fw_iter) +def test_too_early_reset_gluon_warning(): + from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator + def fw_iter(pipe, size, auto_reset): return GluonIterator( + pipe, size=size, auto_reset=auto_reset) + for auto_reset in (True, False): + for prefetch_queue_depth in (2, 3): + yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth + + def test_gluon_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator check_iterator_wrapper_first_iteration(GluonIterator, output_types=[ @@ -1348,6 +1407,15 @@ def fw_iter(pipe, size, auto_reset): return PyTorchIterator( check_stop_iter_fail_single(fw_iter) +def test_too_early_reset_pytorch_warning(): + from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator + def fw_iter(pipe, size, auto_reset): return PyTorchIterator( + pipe, output_map=["data"], size=size, auto_reset=auto_reset) + for auto_reset in (True, False): + for prefetch_queue_depth in (2, 3): + yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth + + def test_pytorch_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator check_iterator_wrapper_first_iteration( @@ -1403,6 +1471,15 @@ def fw_iter(pipe, size, auto_reset): return PaddleIterator( check_stop_iter_fail_single(fw_iter) +def test_too_early_reset_paddle_warning(): + from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator + def fw_iter(pipe, size, auto_reset): return PaddleIterator( + pipe, output_map=["data"], size=size, auto_reset=auto_reset) + for auto_reset in (True, False): + for prefetch_queue_depth in (2, 3): + yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth + + def test_paddle_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator check_iterator_wrapper_first_iteration( From f2ea23bfbb56c8e193d3e1751d1a20d8788e1866 Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 29 Jun 2022 11:05:17 +0200 Subject: [PATCH 5/7] Add parallel external source test to fw iterators Signed-off-by: Kamil Tokarski --- dali/test/python/test_fw_iterators.py | 102 ++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/dali/test/python/test_fw_iterators.py b/dali/test/python/test_fw_iterators.py index 303a87c873..3dde6af4d5 100644 --- a/dali/test/python/test_fw_iterators.py +++ b/dali/test/python/test_fw_iterators.py @@ -1050,6 +1050,79 @@ def check_stop_iter(fw_iter, iterator_name, batch_size, epochs, iter_num, total_ assert(count == min(total_iter_num, iter_num * epochs)) +def create_parallel_pipeline(source_type, lightweight, reader_queue_depth, batch_size, + num_iterations, **kwargs): + sample_shape = (1, 2,) if lightweight else (1024, 1024) + if source_type == "generator": + cycle = None + batch_mode = False + def source(sample_info): + if sample_info.iteration >= num_iterations: + raise StopIteration + a = np.full(sample_shape, sample_info.idx_in_epoch, dtype=np.int32) + a[0, 0] = sample_info.epoch_idx + return a + else: + cycle = "raise" + batch_mode = True + epoch_idx = 0 + def source(): + nonlocal epoch_idx + for i in range(num_iterations): + batch = [] + for j in range(batch_size): + idx_in_epoch = i * batch_size + j + a = np.full(sample_shape, idx_in_epoch, dtype=np.int32) + a[0, 0] = epoch_idx + batch.append(a) + yield batch + epoch_idx += 1 + + @pipeline_def + def pipeline(): + data = fn.external_source( + source=source, batch=batch_mode, cycle=cycle, + prefetch_queue_depth=reader_queue_depth, + parallel=True) + if lightweight: + return data, data + else: + return data, fn.gaussian_blur(data.gpu(), window_size=31) + + return pipeline( + batch_size=batch_size, device_id=0, num_threads=4, + py_start_method="spawn",**kwargs) + + +def _check_parallel_stop_iter(fw_iter, source_type, lightweight, fw_size_aware, + prefetch_queue_depth, reader_queue_depth): + num_iterations = 7 + batch_size = 6 + num_epochs = 3 + pipe = create_parallel_pipeline( + source_type, lightweight, reader_queue_depth, batch_size, num_iterations, + prefetch_queue_depth=prefetch_queue_depth) + pipe.build() + fw_size = -1 if not fw_size_aware else num_iterations * batch_size + it = fw_iter(pipe, fw_size, True) + with assert_no_warnings(regex="Resetting the pipeline|Prefetching data|StopIteration"): + epochs = [[batch for batch in it] for _ in range(num_epochs)] + assert len(epochs) + for epoch_idx, batches in enumerate(epochs): + assert len(batches) == num_iterations, \ + f"Expected {num_iterations} batches in epoch {epoch_idx}, got: {len(batches)}" + + +def check_parallel_stop_iter(fw_iter): + for source_type in ('generator', 'callback'): + for lightweight in (True, False): + for fw_size_aware in (True, False): + for prefetch_queue_depth in (1, 2, 3): + for reader_queue_depth in (1, 2, 3): + yield _check_parallel_stop_iter, fw_iter, source_type, lightweight, \ + fw_size_aware, prefetch_queue_depth, reader_queue_depth + + def check_too_early_reset(fw_iter, auto_reset, prefetch_queue_depth): num_source_iters = 10 num_fw_iter_iters = num_source_iters - prefetch_queue_depth + 1 @@ -1240,6 +1313,14 @@ def fw_iter(pipe, size, auto_reset): return MXNetIterator( yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth +def test_stop_iteration_parallel_mxnet(): + from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator + def fw_iter(pipe, size, auto_reset): return MXNetIterator( + pipe, [("data_0", MXNetIterator.DATA_TAG), ("data_1", MXNetIterator.DATA_TAG)], + size=size, auto_reset=auto_reset) + yield from check_parallel_stop_iter(fw_iter) + + def test_mxnet_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.mxnet import DALIGenericIterator as MXNetIterator check_iterator_wrapper_first_iteration( @@ -1352,6 +1433,13 @@ def fw_iter(pipe, size, auto_reset): return GluonIterator( yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth +def test_stop_iteration_parallel_gluon(): + from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator + def fw_iter(pipe, size, auto_reset): return GluonIterator( + pipe, size=size, auto_reset=auto_reset) + yield from check_parallel_stop_iter(fw_iter) + + def test_gluon_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator check_iterator_wrapper_first_iteration(GluonIterator, output_types=[ @@ -1416,6 +1504,13 @@ def fw_iter(pipe, size, auto_reset): return PyTorchIterator( yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth +def test_stop_iteration_parallel_pytorch(): + from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator + def fw_iter(pipe, size, auto_reset): return PyTorchIterator( + pipe, output_map=["data_0", "data_1"], size=size, auto_reset=auto_reset) + yield from check_parallel_stop_iter(fw_iter) + + def test_pytorch_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator check_iterator_wrapper_first_iteration( @@ -1480,6 +1575,13 @@ def fw_iter(pipe, size, auto_reset): return PaddleIterator( yield check_too_early_reset, fw_iter, auto_reset, prefetch_queue_depth +def test_stop_iteration_parallel_paddle(): + from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator + def fw_iter(pipe, size, auto_reset): return PaddleIterator( + pipe, output_map=["data_0", "data_1"], size=size, auto_reset=auto_reset) + yield from check_parallel_stop_iter(fw_iter) + + def test_paddle_iterator_wrapper_first_iteration(): from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator check_iterator_wrapper_first_iteration( From cd233bb063a0769fa04d4fd3684a1a728bdd648a Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 29 Jun 2022 12:57:24 +0200 Subject: [PATCH 6/7] Check actual data in the returned samples Signed-off-by: Kamil Tokarski --- dali/test/python/test_fw_iterators.py | 37 ++++++++++++++++++--------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/dali/test/python/test_fw_iterators.py b/dali/test/python/test_fw_iterators.py index 3dde6af4d5..91de90e4b5 100644 --- a/dali/test/python/test_fw_iterators.py +++ b/dali/test/python/test_fw_iterators.py @@ -1050,9 +1050,8 @@ def check_stop_iter(fw_iter, iterator_name, batch_size, epochs, iter_num, total_ assert(count == min(total_iter_num, iter_num * epochs)) -def create_parallel_pipeline(source_type, lightweight, reader_queue_depth, batch_size, - num_iterations, **kwargs): - sample_shape = (1, 2,) if lightweight else (1024, 1024) +def create_parallel_pipeline(source_type, lightweight, sample_shape, reader_queue_depth, + batch_size, num_iterations, **kwargs): if source_type == "generator": cycle = None batch_mode = False @@ -1095,13 +1094,15 @@ def pipeline(): def _check_parallel_stop_iter(fw_iter, source_type, lightweight, fw_size_aware, - prefetch_queue_depth, reader_queue_depth): + prefetch_queue_depth, reader_queue_depth, + batch_as_np): + sample_shape = (1, 2,) if lightweight else (1024, 1024) num_iterations = 7 batch_size = 6 num_epochs = 3 pipe = create_parallel_pipeline( - source_type, lightweight, reader_queue_depth, batch_size, num_iterations, - prefetch_queue_depth=prefetch_queue_depth) + source_type, lightweight, sample_shape, reader_queue_depth, batch_size, + num_iterations, prefetch_queue_depth=prefetch_queue_depth) pipe.build() fw_size = -1 if not fw_size_aware else num_iterations * batch_size it = fw_iter(pipe, fw_size, True) @@ -1111,16 +1112,23 @@ def _check_parallel_stop_iter(fw_iter, source_type, lightweight, fw_size_aware, for epoch_idx, batches in enumerate(epochs): assert len(batches) == num_iterations, \ f"Expected {num_iterations} batches in epoch {epoch_idx}, got: {len(batches)}" + for iter_idx, fw_batch in enumerate(batches): + batch = batch_as_np(fw_batch) + for sample_idx, sample in enumerate(batch): + sample_idx = batch_size * iter_idx + sample_idx + ref = np.full(sample_shape, sample_idx, dtype=np.int32) + ref[0, 0] = epoch_idx + np.testing.assert_array_equal(sample, ref) -def check_parallel_stop_iter(fw_iter): +def check_parallel_stop_iter(fw_iter, batch_as_np): for source_type in ('generator', 'callback'): for lightweight in (True, False): for fw_size_aware in (True, False): for prefetch_queue_depth in (1, 2, 3): for reader_queue_depth in (1, 2, 3): yield _check_parallel_stop_iter, fw_iter, source_type, lightweight, \ - fw_size_aware, prefetch_queue_depth, reader_queue_depth + fw_size_aware, prefetch_queue_depth, reader_queue_depth, batch_as_np def check_too_early_reset(fw_iter, auto_reset, prefetch_queue_depth): @@ -1318,7 +1326,8 @@ def test_stop_iteration_parallel_mxnet(): def fw_iter(pipe, size, auto_reset): return MXNetIterator( pipe, [("data_0", MXNetIterator.DATA_TAG), ("data_1", MXNetIterator.DATA_TAG)], size=size, auto_reset=auto_reset) - yield from check_parallel_stop_iter(fw_iter) + def as_np(batch): return batch[0].data[0].asnumpy() + yield from check_parallel_stop_iter(fw_iter, as_np) def test_mxnet_iterator_wrapper_first_iteration(): @@ -1437,7 +1446,8 @@ def test_stop_iteration_parallel_gluon(): from nvidia.dali.plugin.mxnet import DALIGluonIterator as GluonIterator def fw_iter(pipe, size, auto_reset): return GluonIterator( pipe, size=size, auto_reset=auto_reset) - yield from check_parallel_stop_iter(fw_iter) + def as_np(batch): return batch[0][0].asnumpy() + yield from check_parallel_stop_iter(fw_iter, as_np) def test_gluon_iterator_wrapper_first_iteration(): @@ -1508,7 +1518,8 @@ def test_stop_iteration_parallel_pytorch(): from nvidia.dali.plugin.pytorch import DALIGenericIterator as PyTorchIterator def fw_iter(pipe, size, auto_reset): return PyTorchIterator( pipe, output_map=["data_0", "data_1"], size=size, auto_reset=auto_reset) - yield from check_parallel_stop_iter(fw_iter) + def as_np(batch): return batch[0]["data_0"].numpy() + yield from check_parallel_stop_iter(fw_iter, as_np) def test_pytorch_iterator_wrapper_first_iteration(): @@ -1579,7 +1590,9 @@ def test_stop_iteration_parallel_paddle(): from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddleIterator def fw_iter(pipe, size, auto_reset): return PaddleIterator( pipe, output_map=["data_0", "data_1"], size=size, auto_reset=auto_reset) - yield from check_parallel_stop_iter(fw_iter) + def as_np(batch): + return np.array(batch[0]['data_0']) + yield from check_parallel_stop_iter(fw_iter, as_np) def test_paddle_iterator_wrapper_first_iteration(): From b68227243fa766732e847cfb9b7a3cddb7d2592e Mon Sep 17 00:00:00 2001 From: Kamil Tokarski Date: Wed, 29 Jun 2022 15:05:30 +0200 Subject: [PATCH 7/7] Fix pes with generator test resetting one of the piplines too early Signed-off-by: Kamil Tokarski --- ...rnal_source_parallel_custom_serialization.py | 17 +++++++++-------- dali/test/python/test_utils.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/dali/test/python/test_external_source_parallel_custom_serialization.py b/dali/test/python/test_external_source_parallel_custom_serialization.py index 12fab54252..4afbe85216 100644 --- a/dali/test/python/test_external_source_parallel_custom_serialization.py +++ b/dali/test/python/test_external_source_parallel_custom_serialization.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ from pickle import PicklingError from nose_utils import raises -from test_utils import get_dali_extra_path, restrict_python_version +from test_utils import get_dali_extra_path, restrict_python_version, run_pipelines tests_dali_pickling = [] @@ -230,7 +230,8 @@ def create_pipline(): def create_decoding_pipeline(callback, py_callback_pickler, batch_size, parallel=True, - py_num_workers=None, py_start_method="spawn", batch=False): + py_num_workers=None, py_start_method="spawn", batch=False, + cycle=None): extra = {} if parallel: extra["py_num_workers"] = py_num_workers @@ -242,7 +243,8 @@ def create_decoding_pipeline(callback, py_callback_pickler, batch_size, parallel def create_pipline(): jpegs, labels = fn.external_source( source=callback, num_outputs=2, - batch=batch, parallel=parallel) + batch=batch, parallel=parallel, + cycle=cycle) images = fn.decoders.image(jpegs, device="cpu") return images, labels @@ -250,8 +252,7 @@ def create_pipline(): def _run_and_compare_outputs(batch_size, parallel_pipeline, serial_pipeline): - parallel_batch = parallel_pipeline.run() - serial_batch = serial_pipeline.run() + parallel_batch, serial_batch = run_pipelines(parallel_pipeline, serial_pipeline) for (parallel_output, serial_output) in zip(parallel_batch, serial_batch): assert len(parallel_output) == batch_size assert len(serial_output) == batch_size @@ -532,9 +533,9 @@ def _test_generator_closure(name, py_callback_pickler): batch_size=batch_size, data_set_size=batches_in_epoch * batch_size) parallel_pipeline = create_decoding_pipeline(callback, py_callback_pickler, batch_size=batch_size, py_num_workers=1, - parallel=True, batch=True) + parallel=True, batch=True, cycle="raise") serial_pipeline = create_decoding_pipeline(callback, None, batch_size=batch_size, - parallel=False, batch=True) + parallel=False, batch=True, cycle="raise") _build_and_compare_pipelines_epochs(epochs_num, batch_size, parallel_pipeline, serial_pipeline) diff --git a/dali/test/python/test_utils.py b/dali/test/python/test_utils.py index 668fabd1b1..f388b98332 100644 --- a/dali/test/python/test_utils.py +++ b/dali/test/python/test_utils.py @@ -661,3 +661,16 @@ def wrapper(*exec_inputs): return function(*iteration_inputs) return dali.fn.python_function(*node_inputs, function=wrapper, **kwargs) + + +def run_pipelines(*pipelines): + batches = [] + stop_iter = False + for pipeline in pipelines: + try: + batches.append(pipeline.run()) + except StopIteration: + stop_iter = True + if stop_iter: + raise StopIteration + return batches