Skip to content

Commit

Permalink
Support checkpointing in PaddlePaddle iterator (#5279)
Browse files Browse the repository at this point in the history
This PR adds checkpointing support to the PaddlePaddle iterator.

---------

Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski committed Feb 9, 2024
1 parent 1b60777 commit 08e8ea2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
16 changes: 9 additions & 7 deletions dali/python/nvidia/dali/plugin/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,6 @@ def __init__(
prepare_first_batch=prepare_first_batch,
)

self._counter = 0

self._first_batch = None
if self._prepare_first_batch:
try:
Expand All @@ -269,11 +267,15 @@ def __init__(
# here we should set it to False again
self._ever_consumed = False
except StopIteration:
assert False, (
"It seems that there is no data in the pipeline. This may happen "
"if `last_batch_policy` is set to PARTIAL and the requested batch size is "
"greater than the shard size."
)
# This case might not be an error if we're iterating over a pipeline that is
# currently at the end of an epoch, for example because it was restored from
# a checkpoint.
if all(not p.is_restored_from_checkpoint or p._first_iter for p in self._pipes):
raise RuntimeError(
"It seems that there is no data in the pipeline. This may happen "
"if `last_batch_policy` is set to PARTIAL and the requested batch size is "
"greater than the shard size."
)

def __next__(self):
self._ever_consumed = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import nvidia.dali.fn as fn
from nvidia.dali.pipeline import pipeline_def
from nose2.tools import params, cartesian_params
import numpy as np
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nose import SkipTest

Expand Down Expand Up @@ -410,3 +411,24 @@ def supported_last_batch_policies(self):
(LastBatchPolicy.DROP, False),
(LastBatchPolicy.FILL, True),
)


class TestPaddle(FwTestBase):
    """Checkpointing tests run against the PaddlePaddle DALI iterator.

    Plugs the PaddlePaddle ``DALIGenericIterator`` into the shared
    ``FwTestBase`` machinery and supplies the output comparison and the
    last-batch-policy combinations to exercise.
    """

    def __init__(self):
        super().__init__()
        # Imported here rather than at module level; presumably so that this
        # module can be imported without PaddlePaddle installed -- TODO confirm.
        from nvidia.dali.plugin.paddle import DALIGenericIterator as PaddlePaddleIterator

        self.FwIterator = PaddlePaddleIterator

    def equal(self, a, b):
        """Return True when ``a`` and ``b`` have the same shape and elements.

        Both arguments are converted to NumPy arrays first.
        ``np.array_equal`` is used instead of ``(x == y).all()`` so that a
        shape mismatch yields ``False`` rather than being hidden by
        broadcasting or surfacing as a comparison error.
        """
        return np.array_equal(np.array(a), np.array(b))

    def supported_last_batch_policies(self):
        """Return the ``(last_batch_policy, pad_last_batch)`` pairs to test."""
        return (
            # (last_batch_policy, pad_last_batch)
            (LastBatchPolicy.DROP, True),
            (LastBatchPolicy.DROP, False),
            (LastBatchPolicy.FILL, True),
            (LastBatchPolicy.PARTIAL, False),
            (LastBatchPolicy.PARTIAL, True),
        )
2 changes: 1 addition & 1 deletion dali/test/python/test_fw_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,7 @@ def check_paddle_iterator_pass_reader_name(

if batch_size > data_set_size // shards_num and last_batch_policy == LastBatchPolicy.DROP:
assert_raises(
AssertionError,
RuntimeError,
PaddleIterator,
pipes,
output_map=["data"],
Expand Down

0 comments on commit 08e8ea2

Please sign in to comment.