diff --git a/dali/operators/reader/caffe2_reader_op.h b/dali/operators/reader/caffe2_reader_op.h
index dedadf14db..09788a1fea 100644
--- a/dali/operators/reader/caffe2_reader_op.h
+++ b/dali/operators/reader/caffe2_reader_op.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2017-2018, 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -21,12 +21,13 @@
 
 namespace dali {
 
-class Caffe2Reader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
+class Caffe2Reader : public DataReader<CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true> {
  public:
   explicit Caffe2Reader(const OpSpec& spec)
-      : DataReader<CPUBackend, Tensor<CPUBackend>>(spec) {
+      : DataReader<CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true>(spec) {
     loader_ = InitLoader<LMDBLoader>(spec);
     parser_.reset(new Caffe2Parser(spec));
+    this->SetInitialSnapshot();
   }
 
   void RunImpl(SampleWorkspace &ws) override {
@@ -35,7 +36,7 @@ class Caffe2Reader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
   }
 
  protected:
-  USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>);
+  USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true);
 };
 
 }  // namespace dali
diff --git a/dali/operators/reader/caffe_reader_op.h b/dali/operators/reader/caffe_reader_op.h
index dd90ac4629..aca749f981 100644
--- a/dali/operators/reader/caffe_reader_op.h
+++ b/dali/operators/reader/caffe_reader_op.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2017-2018, 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -21,12 +21,13 @@
 
 namespace dali {
 
-class CaffeReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
+class CaffeReader : public DataReader<CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true> {
  public:
   explicit CaffeReader(const OpSpec& spec)
-      : DataReader<CPUBackend, Tensor<CPUBackend>>(spec) {
+      : DataReader<CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true>(spec) {
     loader_ = InitLoader<LMDBLoader>(spec);
     parser_.reset(new CaffeParser(spec));
+    this->SetInitialSnapshot();
   }
 
   void RunImpl(SampleWorkspace &ws) override {
@@ -35,7 +36,7 @@ class CaffeReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
   }
 
  protected:
-  USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>);
+  USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>, Tensor<CPUBackend>, true);
 };
 
 }  // namespace dali
diff --git a/dali/operators/reader/loader/lmdb.h b/dali/operators/reader/loader/lmdb.h
index 51727c185d..f599c0eaf2 100644
--- a/dali/operators/reader/loader/lmdb.h
+++ b/dali/operators/reader/loader/lmdb.h
@@ -145,7 +145,7 @@ static int find_lower_bound(const std::vector<Index>& a, Index x) {
   return -1;
 }
 
-class LMDBLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
+class LMDBLoader : public Loader<CPUBackend, Tensor<CPUBackend>, true> {
  public:
   explicit LMDBLoader(const OpSpec& options)
       : Loader(options) {
@@ -204,6 +204,10 @@ class LMDBLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
                 value.mv_size * sizeof(uint8_t));
   }
 
+  void Skip() override {
+    MoveToNextShard(++current_index_);
+  }
+
  protected:
   Index SizeImpl() override {
     return offsets_.size() > 0 ? offsets_.back() : 0;
@@ -224,7 +228,7 @@ class LMDBLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
   void Reset(bool wrap_to_shard) override {
     // work out how many entries to move forward to handle sharding
     if (wrap_to_shard) {
-      current_index_ = start_index(shard_id_, num_shards_, SizeImpl());
+      current_index_ = start_index(virtual_shard_id_, num_shards_, SizeImpl());
     } else {
       current_index_ = 0;
     }
@@ -233,8 +237,8 @@ class LMDBLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
     mdb_[file_index].SeekByIndex(local_index);
   }
 
-  using Loader<CPUBackend, Tensor<CPUBackend>>::shard_id_;
-  using Loader<CPUBackend, Tensor<CPUBackend>>::num_shards_;
+  using Loader<CPUBackend, Tensor<CPUBackend>, true>::virtual_shard_id_;
+  using Loader<CPUBackend, Tensor<CPUBackend>, true>::num_shards_;
 
   std::vector<IndexedLMDB> mdb_;
 
diff --git a/dali/test/python/checkpointing/test_dali_checkpointing.py b/dali/test/python/checkpointing/test_dali_checkpointing.py
index 9eae2c4384..0f814ce2e3 100644
--- a/dali/test/python/checkpointing/test_dali_checkpointing.py
+++ b/dali/test/python/checkpointing/test_dali_checkpointing.py
@@ -267,6 +267,62 @@ def test_sequence_reader(num_epochs, batch_size, shard_id, num_shards,
         initial_fill=initial_fill)
 
 
+@params(
+    (1, 3, 0, 1, True, True, True, 1),
+    (5, 5, 1, 3, True, True, False, 2),
+    (6, 7, 2, 3, True, False, True, 3),
+    (5, 3, 0, 1, True, False, False, 1),
+    (7, 5, 2, 3, False, True, True, None),
+    (4, 1, 1, 2, False, True, False, 2),
+    (0, 3, 3, 4, False, False, True, None),
+    (1, 4, 2, 3, False, False, False, 3),
+)
+def test_caffe_reader(
+        num_epochs, batch_size, shard_id, num_shards,
+        random_shuffle, stick_to_shard, pad_last_batch,
+        iters_into_epoch=None, initial_fill=1024):
+
+    caffe_dir = os.path.join(data_root, 'db', 'lmdb')
+
+    check_reader_checkpointing(
+        fn.readers.caffe, num_epochs, batch_size, iters_into_epoch,
+        path=caffe_dir,
+        pad_last_batch=pad_last_batch,
+        random_shuffle=random_shuffle,
+        shard_id=shard_id,
+        num_shards=num_shards,
+        stick_to_shard=stick_to_shard,
+        initial_fill=initial_fill)
+
+
+@params(
+    (1, 2, 0, 2, True, True, True, 1),
+    (4, 4, 1, 2, True, True, False, 2),
+    (5, 6, 0, 2, True, False, True, None),
+    (6, 2, 1, 3, True, False, False, 1),
+    (3, 4, 3, 4, False, True, True, 2),
+    (8, 1, 2, 3, False, True, False, None),
+    (0, 2, 4, 5, False, False, True, None),
+    (3, 3, 1, 3, False, False, False, 2),
+)
+def test_caffe2_reader(
+        num_epochs, batch_size, shard_id, num_shards,
+        random_shuffle, stick_to_shard, pad_last_batch,
+        iters_into_epoch=None, initial_fill=1024):
+
+    caffe2_dir = os.path.join(data_root, 'db', 'c2lmdb')
+
+    check_reader_checkpointing(
+        fn.readers.caffe2, num_epochs, batch_size, iters_into_epoch,
+        path=caffe2_dir,
+        pad_last_batch=pad_last_batch,
+        random_shuffle=random_shuffle,
+        shard_id=shard_id,
+        num_shards=num_shards,
+        stick_to_shard=stick_to_shard,
+        initial_fill=initial_fill)
+
+
 @attr('pytorch')
 @params(
     (1, 3, 0, 1, True, False, False),
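
For context, here is a minimal sketch (not part of the diff) of what the checkpointing support added above enables for the LMDB-backed readers, roughly what `check_reader_checkpointing` automates in the tests. It assumes DALI's pipeline-level checkpointing API (`enable_checkpointing=True`, `Pipeline.checkpoint()` and the `checkpoint=` constructor argument) and a placeholder database path `db_dir`:

```python
# Sketch only: `db_dir` is a placeholder and must point at an existing
# Caffe2 LMDB database; the checkpointing API names are assumptions based
# on DALI's pipeline-level checkpointing support.
import numpy as np
import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

db_dir = "/path/to/db/c2lmdb"  # placeholder


@pipeline_def(batch_size=4, num_threads=2, device_id=None, enable_checkpointing=True)
def reader_pipe():
    # With the changes above, the reader snapshots its position in the epoch,
    # including shuffling-buffer and shard state (virtual_shard_id_).
    data, label = fn.readers.caffe2(path=db_dir, random_shuffle=True)
    return data, label


pipe = reader_pipe()
pipe.build()
for _ in range(10):
    pipe.run()

cpt = pipe.checkpoint()  # serialized snapshot of the pipeline state

# A pipeline restored from the checkpoint should continue from the same point
# in the dataset, so both pipelines produce identical batches from here on.
restored = reader_pipe(checkpoint=cpt)
restored.build()
data_a, _ = pipe.run()
data_b, _ = restored.run()
assert np.array_equal(np.array(data_a[0]), np.array(data_b[0]))
```

Because `LMDBLoader::Skip()` only advances `current_index_` without copying the record out of LMDB, restoring a checkpoint can fast-forward the reader to the saved position cheaply instead of replaying full reads.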