
Add warnings against resetting pipeline before end of epoch and test parallel ES with fw iterator #4023

Draft · wants to merge 7 commits into base: main
dali/python/nvidia/dali/pipeline.py (21 additions, 1 deletion)

@@ -920,6 +920,14 @@ def schedule_run(self):
        Should not be mixed with :meth:`run` in the same pipeline"""
        with self._check_api_type_scope(types.PipelineAPIType.SCHEDULED):
            if self._first_iter and self._exec_pipelined:
+                # For one, when prefetching, parallel external source will reuse buffers
+                # that might still be referenced by the no_copy input fed to the pipeline
+                if not self.empty():
+                    warnings.warn(
Contributor:
I think it is possible: with prefetch queue depth 2 and one batch left to consume, we can still schedule one more run. The native part can overschedule; it will just wait for an empty output buffer. But the ES may fail (parallel mode with no_copy; the regular ES does a copy and has an internal queue).
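
For illustration, a minimal sketch of that scenario with the scheduled API (pipeline construction omitted, the numbers are hypothetical):

    pipe.schedule_run()         # first call prefetches prefetch_queue_depth (2) batches
    out = pipe.share_outputs()  # consume only 1 of the 2 prefetched batches
    pipe.release_outputs()
    pipe.schedule_run()         # overschedules: the native side just waits for a free
                                # output buffer, but parallel ES with no_copy may reuse
                                # buffers still referenced by the unconsumed batch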

Member Author (@stiepan), Jun 29, 2022:
Do you mean that it should be supported? By "possible", do you mean that even when using FW iterators correctly you may still end up in such a situation?

If it should be supported:

  • Should we warn only if we have PES and treat it as an extra limitation of the schedule API that is not there otherwise?
  • Or does the PES need to be adjusted to handle that as well?

Contributor:
> Should we warn only if we have PES and treat it as an extra limitation of the schedule API that is not there otherwise?

I think so. You can still use this API without the FW iterator.

"Prefetching data into a non-empty pipeline may result in corrupted "
"outputs. Please make sure all batches from previous epoch are consumed "
"before scheduling work for a new epoch.",
Warning)
self._prefetch()
else:
self._run_once()
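
A hedged sketch of the misuse this warning targets, assuming a pipeline pipe whose external source eventually raises StopIteration (the exact call at which StopIteration surfaces may differ):

    try:
        while True:
            pipe.schedule_run()     # keep scheduling without consuming any outputs
    except StopIteration:
        pass
    pipe.reset()                    # scheduled batches are still in the pipeline
    pipe.schedule_run()             # first iteration after reset: _prefetch() runs while
                                    # self.empty() is False, so the warning above fires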
@@ -1078,7 +1086,19 @@ def reset(self):

        If pipeline iterator reached the end then reset its state to the beginning.
        """
-        if self._last_iter:
+        if not self._last_iter:
+            # resetting before some external source raised StopIteration is a no-op
+            if self._input_callbacks:
Member Author:
How legitimate a use case is a pipeline with an infinite external source plus an FW iterator with the iterator size passed explicitly? In that case, the warning will be triggered too.

Contributor:
🤷 I guess it still can happen and we should behave consistently.

Member Author:
😟 I thought that one would be useful if one has two or more pipelines with (P)ES that raise StopIteration and should be reset, but for some reason the epochs diverge: either because the number of iterations is really (but unintentionally) different in the two pipelines, or because one uses the .run API with prefetch_queue_depth 1 and resets all pipelines when the first one raises, not letting the others actually reach the end of epoch.
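
A sketch of that divergence, assuming two hypothetical pipelines whose external sources have (unintentionally) different epoch lengths, driven with the run() API:

    try:
        while True:
            out0 = pipe0.run()
            out1 = pipe1.run()      # pipe1 still has iterations left...
    except StopIteration:           # ...when pipe0 raises first
        pipe0.reset()               # legitimate reset
        pipe1.reset()               # no-op: pipe1 never reached its end of epoch,
                                    # which is exactly what the new warning reports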

Member Author (@stiepan), Jun 29, 2022:
Well, I could safeguard this check with pipeline._epoch_idx > 0. It seems that if it has ever been incremented, then there must be an ES that raises StopIteration.
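
The suggested safeguard might look roughly like this (a sketch; _epoch_idx as referred to above):

    if not self._last_iter:
        # _epoch_idx > 0 means some external source has already raised
        # StopIteration before, so the sources are known to be finite
        if self._input_callbacks and self._epoch_idx > 0:
            warnings.warn(...)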

Contributor:
> 🤷 I guess it still can happen and we should behave consistently.

I think that we should warn that something may not be right, and if the user provides -1 as the size it should work silently.

+                warnings.warn(
+                    "Resetting the pipeline before any of the external sources reached "
+                    "the end of epoch (i.e. raised StopIteration) has no effect.",
+                    Warning)
+        else:
+            if not self.empty():
+                warnings.warn(
+                    "Resetting the pipeline before all scheduled batches have been consumed "
+                    "is discouraged and may be unsupported in the future.",
+                    Warning)
            self._first_iter = True
            self._last_iter = False
            self._iter = 0
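
For reference, a minimal sketch of the pattern the two warnings steer towards, assuming a hypothetical pipeline with a finite external source and the run() API:

    while True:
        try:
            out = pipe.run()    # consume every batch of the epoch
        except StopIteration:
            break
    pipe.reset()                # StopIteration was seen and all batches consumed,
                                # so neither of the warnings above fires
    out = pipe.run()            # first batch of the next epoch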
dali/python/nvidia/dali/plugin/base_iterator.py (9 additions)

@@ -299,6 +299,15 @@ def _get_outputs(self):
            # in case ExternalSource returns StopIteration
            if self._size < 0 and self._auto_reset == "yes":
                self.reset()
+            if self._size >= 0:
+                warnings.warn(
+                    f"Pipeline unexpectedly raised StopIteration before reaching the end of "
+                    f"the dataset. There were {self._counter} samples returned in this epoch, "
+                    f"but {self.size} was passed as `size`. Please verify the `size` value or "
+                    f"consider alternatives. For DALI readers, please use `reader_name` instead. "
+                    f"For external source, you may rely solely on raising StopIteration "
+                    f"from the source.",
+                    Warning)
            raise e
        self._check_batch_size(outputs)
        return outputs
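
The alternatives the message points to could look as follows (a sketch using the PyTorch plugin; other framework plugins are analogous, and a pipeline pipe with a reader named "Reader" is assumed):

    from nvidia.dali.plugin.pytorch import DALIGenericIterator

    # discouraged: a hard-coded size may disagree with the dataset and
    # trigger the warning above when StopIteration arrives early
    it = DALIGenericIterator([pipe], ["data"], size=10000)

    # recommended for DALI readers: infer the epoch size from the named reader
    it = DALIGenericIterator([pipe], ["data"], reader_name="Reader")

    # for external source: leave size at -1 and rely on StopIteration alone
    it = DALIGenericIterator([pipe], ["data"], size=-1)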
dali/test/python/nose_utils.py (39 additions)

@@ -24,9 +24,11 @@
nose.loader.collections = collections.abc
nose.suite.collections = collections.abc

+import contextlib
import nose.tools as tools
import re
import fnmatch
+import warnings


def glob_to_regex(glob):
@@ -91,6 +93,43 @@ def assert_warns(exception=Warning, *args, glob=None, regex=None, match_case=None):
    return tools.assert_warns_regex(exception, pattern, *args, **kwargs)


+@contextlib.contextmanager
+def assert_no_warnings(exception=None, glob=None, regex=None, match_case=None):
+    msg_param_provided = any(param is not None for param in (glob, regex, match_case))
+    pattern = None
+    if msg_param_provided:
+        pattern = get_pattern(glob, regex, match_case)
+        assert pattern is not None
+        pattern = pattern if isinstance(pattern, re.Pattern) else re.compile(pattern)
+    if exception is None:
+        # if no explicit warning type is requested, any Warning counts
+        exception = Warning
+
+    with warnings.catch_warnings(record=True) as recorded_warnings:
+        try:
+            yield recorded_warnings
+        finally:
+            if not msg_param_provided:
+                for m in recorded_warnings:
+                    w = m.message
+                    if isinstance(w, exception):
+                        raise AssertionError(
+                            f"Test expected to emit no warning of type {exception}, but "
+                            f"emitted the following warning: {str(w)}")
+            else:
+                for m in recorded_warnings:
+                    w = m.message
+                    if isinstance(w, exception) and pattern.search(str(w)):
+                        raise AssertionError(
+                            f"Test was expected to emit no warning of type {exception} "
+                            f"matching the pattern {pattern}, but the following warning "
+                            f"was emitted: {str(w)}")


def raises(exception, glob=None, regex=None, match_case=None):

    """
dali/test/python/test_external_source_parallel.py (61 additions, 1 deletion)

@@ -18,7 +18,7 @@
from nvidia.dali.types import SampleInfo, BatchInfo

import test_external_source_parallel_utils as utils
-from nose_utils import raises
+from nose_utils import raises, assert_warns, assert_no_warnings


def no_arg_fun():
@@ -732,3 +732,63 @@ def test_permute_dataset():
    for reader_queue_depth in (1, 5):
        yield _test_permute_dataset, batch_size, epoch_size, trailing_samples, \
            cb, 4, 1, reader_queue_depth


+@with_setup(utils.setup_function, utils.teardown_function)
+def _test_no_op_reset_warning(stop_earlier, prefetch_queue_depth, source, batch,
+                              batch_size, num_iterations, expected_warning):
+
+    def run_pipeline():
+
+        @dali.pipeline_def
+        def pipeline():
+            out = dali.fn.external_source(source=source, parallel=True, batch=batch)
+            return out
+
+        pipe = pipeline(
+            batch_size=batch_size, device_id=0, num_threads=4,
+            prefetch_queue_depth=prefetch_queue_depth,
+            py_start_method="spawn")
+        pipe.build()
+        utils.capture_processes(pipe._py_pool)
+        for _ in range(num_iterations - stop_earlier):
+            pipe.run()
+        pipe.reset()
+
+    if expected_warning is None:
+        with assert_no_warnings():
+            run_pipeline()
+    else:
+        with assert_warns(Warning, glob=expected_warning):
+            run_pipeline()
+
+
+def test_no_op_reset_warning():
+    num_iterations = 5
+    batch_size = 8
+
+    def gen_source():
+        for i in range(num_iterations):
+            yield [np.full((1024, 1024), batch_size * i + j) for j in range(batch_size)]
+
+    def cb_source(sample_info):
+        if sample_info.idx_in_epoch >= 42:
+            raise StopIteration
+        return np.full((5, 5), sample_info.idx_in_epoch)
+
+    for source, batch in ((gen_source, True), (cb_source, False)):
+        for prefetch_queue_depth in (1, 2, 3):
+            for stop_earlier in range(prefetch_queue_depth):
+                if stop_earlier == prefetch_queue_depth - 1:
+                    expected_warning = (
+                        "Resetting the pipeline before any of the external sources "
+                        "reached the end of epoch (i.e. raised StopIteration) "
+                        "has no effect.")
+                elif stop_earlier > 0:
+                    expected_warning = (
+                        "Resetting the pipeline before all scheduled batches have been "
+                        "consumed is discouraged and may be unsupported in the future.")
+                else:
+                    expected_warning = None
+                yield _test_no_op_reset_warning, stop_earlier, prefetch_queue_depth, \
+                    source, batch, batch_size, num_iterations, expected_warning