From d7cfbe7e9c72aaeae0e7f740a784118ed1439777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Staniewski?= Date: Tue, 1 Aug 2023 10:45:37 +0200 Subject: [PATCH 1/2] Work in progress MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Michał Staniewski --- dali/operators/reader/file_reader_op.h | 14 +-- dali/operators/reader/loader/coco_loader.h | 4 +- .../reader/loader/file_label_loader.cc | 12 +-- .../reader/loader/file_label_loader.h | 44 +-------- dali/operators/reader/loader/loader.h | 96 +------------------ dali/operators/reader/reader_op.h | 68 +++++++++---- dali/operators/reader/reader_op_test.cc | 47 +++++---- 7 files changed, 86 insertions(+), 199 deletions(-) diff --git a/dali/operators/reader/file_reader_op.h b/dali/operators/reader/file_reader_op.h index 2491c32710..91a37294b7 100644 --- a/dali/operators/reader/file_reader_op.h +++ b/dali/operators/reader/file_reader_op.h @@ -24,10 +24,10 @@ namespace dali { -class FileReader : public DataReader { +class FileReader : public DataReader { public: explicit FileReader(const OpSpec& spec) - : DataReader(spec) { + : DataReader(spec) { bool shuffle_after_epoch = spec.GetArgument("shuffle_after_epoch"); loader_ = InitLoader(spec, shuffle_after_epoch); } @@ -54,16 +54,8 @@ class FileReader : public DataReader()[0] = image_label.label; } - void SaveState(OpCheckpoint &cpt, std::optional stream) override { - cpt.MutableCheckpointState() = loader_->PopStateSnapshot(); - } - - void RestoreState(const OpCheckpoint &cpt) override { - loader_->RestoreStateFromSnapshot(cpt.CheckpointState()); - } - protected: - USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true); + USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper); }; } // namespace dali diff --git a/dali/operators/reader/loader/coco_loader.h b/dali/operators/reader/loader/coco_loader.h index b65ee13bce..bcbfe92298 100644 --- a/dali/operators/reader/loader/coco_loader.h +++ b/dali/operators/reader/loader/coco_loader.h @@ -97,10 +97,10 @@ struct RLEMask : public UniqueHandle { using RLEMaskPtr = std::shared_ptr; -class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase { +class DLL_PUBLIC CocoLoader : public FileLabelLoader { public: explicit inline CocoLoader(const OpSpec &spec) - : FileLabelLoaderBase(spec, spec.GetArgument("shuffle_after_epoch")) + : FileLabelLoader(spec, spec.GetArgument("shuffle_after_epoch")) , spec_(spec) { has_preprocessed_annotations_ = HasPreprocessedAnnotations(spec); DALI_ENFORCE(has_preprocessed_annotations_ || spec.HasArgument("annotations_file"), diff --git a/dali/operators/reader/loader/file_label_loader.cc b/dali/operators/reader/loader/file_label_loader.cc index 46c67c4ada..8356d5582b 100644 --- a/dali/operators/reader/loader/file_label_loader.cc +++ b/dali/operators/reader/loader/file_label_loader.cc @@ -23,13 +23,11 @@ namespace dali { using filesystem::dir_sep; -template -void FileLabelLoaderBase::PrepareEmpty(ImageLabelWrapper &image_label) { +void FileLabelLoader::PrepareEmpty(ImageLabelWrapper &image_label) { PrepareEmptyTensor(image_label.image); } -template -void FileLabelLoaderBase::ReadSample(ImageLabelWrapper &image_label) { +void FileLabelLoader::ReadSample(ImageLabelWrapper &image_label) { auto image_pair = image_label_pairs_[current_index_++]; // handle wrap-around @@ -75,12 +73,8 @@ void FileLabelLoaderBase::ReadSample(ImageLabelWrapper image_label.image.SetMeta(meta); } -template -Index FileLabelLoaderBase::SizeImpl() { +Index FileLabelLoader::SizeImpl() { return static_cast(image_label_pairs_.size()); } -template class FileLabelLoaderBase; -template class FileLabelLoaderBase; - } // namespace dali diff --git a/dali/operators/reader/loader/file_label_loader.h b/dali/operators/reader/loader/file_label_loader.h index 51a90ea369..6e56c40dbe 100755 --- a/dali/operators/reader/loader/file_label_loader.h +++ b/dali/operators/reader/loader/file_label_loader.h @@ -38,15 +38,12 @@ struct ImageLabelWrapper { int label; }; -template -class DLL_PUBLIC FileLabelLoaderBase : public Loader { +class DLL_PUBLIC FileLabelLoader : public Loader { public: - using Base = Loader; - explicit inline FileLabelLoaderBase( + explicit inline FileLabelLoader( const OpSpec& spec, bool shuffle_after_epoch = false) - : Base(spec), + : Loader(spec), shuffle_after_epoch_(shuffle_after_epoch), current_index_(0), current_epoch_(0) { @@ -186,18 +183,12 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader::shard_id_; string file_root_, file_list_; vector> image_label_pairs_; - vector> backup_image_label_pairs_; vector filters_; bool has_files_arg_ = false; @@ -251,8 +219,6 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader; - } // namespace dali #endif // DALI_OPERATORS_READER_LOADER_FILE_LABEL_LOADER_H_ diff --git a/dali/operators/reader/loader/loader.h b/dali/operators/reader/loader/loader.h index 868097a301..60caea4e24 100644 --- a/dali/operators/reader/loader/loader.h +++ b/dali/operators/reader/loader/loader.h @@ -43,14 +43,6 @@ DLL_PUBLIC size_t start_index(const size_t shard_id, DLL_PUBLIC Index num_samples(const size_t shard_num, const size_t size); -/** - * @brief Structure describing Loader base state, at the begining of an epoch. -*/ -struct LoaderStateSnapshot { - std::default_random_engine rng; - int current_epoch; -}; - /** * @brief Base class for Loaders, responsible for reading samples from resource of some kind * into memory. @@ -59,8 +51,7 @@ struct LoaderStateSnapshot { * @tparam LoadTarget Type into which samples are loaded. * @tparam supports_checkpointing A marker for checkpointing support. */ -template +template class Loader { public: using LoadTargetUniquePtr = std::unique_ptr; @@ -71,9 +62,6 @@ class Loader { initial_empty_size_(2 * options.GetArgument("prefetch_queue_depth") * options.GetArgument("max_batch_size")), tensor_init_bytes_(options.GetArgument("tensor_init_bytes")), - state_queue_front_(0), - state_queue_back_(0), - checkpoint_epoch_(0), seed_(options.GetArgument("seed")), shard_id_(options.GetArgument("shard_id")), num_shards_(options.GetArgument("num_shards")), @@ -95,31 +83,6 @@ class Loader { std::seed_seq seq({seed_}); e_ = std::default_random_engine(seq); virtual_shard_id_ = shard_id_; - - // TODO(mstaniewski): add a proper internal argument in schema - if (!options.TryGetArgument(checkpointing_, "checkpointing")) { - checkpointing_ = false; - } - - if (checkpointing_) { - DALI_ENFORCE(supports_checkpointing, "Checkpointing is disabled for this loader. "); - - // TODO(mstaniewski): support pad_last_batch=false - DALI_ENFORCE(pad_last_batch_, - "Currently, checkpointing is only supported with pad_last_batch=true"); - - /* - * A checkpoint is created every time the prefetching thread starts working - * on a new epoch. Therefore, we are guaranteed, there will be at most - * prefetch_queue_depth checkpoints waiting in the queue at a time. - * - * The +1 is added, because there could be a situation, where a batch is - * collected from the prefetch_queue (possibly leading to creation of another checkpoint), - * but the checkpoint from the corresponding epoch is not yet collected. - */ - state_queue_.resize(options.GetArgument("prefetch_queue_depth") + 1); - PushStateSnapshot(); - } } virtual ~Loader() { @@ -155,47 +118,6 @@ class Loader { DALI_ERROR("Please overload PrepareEmpty for custom LoadTarget type other than Tensor"); } - /** - * @brief Called when loader is moving to the next shard, - * to create a new snapshot and store it in the inner queue. - */ - void PushStateSnapshot() { - std::lock_guard lock(state_queue_mutex_); - state_queue_[state_queue_back_].rng = e_; - state_queue_[state_queue_back_].current_epoch = checkpoint_epoch_++; - state_queue_back_ = (state_queue_back_ + 1) % state_queue_.size(); - } - - /** - * @brief Collects a state snapshot from the inner queue. - */ - LoaderStateSnapshot PopStateSnapshot() { - DALI_ENFORCE(checkpointing_, "PopStateSnapshot called, but checkpointing is not enabled. "); - std::lock_guard lock(state_queue_mutex_); - auto result = state_queue_[state_queue_front_]; - state_queue_front_ = (state_queue_front_ + 1) % state_queue_.size(); - return result; - } - - /** - * @brief Restores the loader's state from a snapshot. - */ - void RestoreStateFromSnapshot(const LoaderStateSnapshot &state) { - e_ = state.rng; - checkpoint_epoch_ = state.current_epoch; - if (!stick_to_shard_) - virtual_shard_id_ = (shard_id_ + state.current_epoch) % num_shards_; - - RestoreStateImpl(state); - - // Re-run reset - Reset(true); - - // Reset checkpointing - state_queue_front_ = state_queue_back_; - PushStateSnapshot(); - } - // Get a random read sample LoadTargetSharedPtr ReadOne(bool is_new_batch) { PrepareMetadata(); @@ -244,11 +166,6 @@ class Loader { // remove shard that was fully consumed shards_.pop_front(); returned_sample_counter_ = 0; - - if (checkpointing_) { - // Create a checkpoint before processing the new shard - PushStateSnapshot(); - } } // choose the random index @@ -349,9 +266,6 @@ class Loader { // Reset reader to the first sample virtual void Reset(bool wrap_to_shard) = 0; - // Overloadable method to handle restoring state in subclasses - virtual void RestoreStateImpl(const LoaderStateSnapshot &state) {} - // Check if given reader moved to the next shard virtual inline bool IsNextShard(Index current_index) { return current_index >= Size() || @@ -410,14 +324,6 @@ class Loader { const int tensor_init_bytes_; bool initial_buffer_filled_ = false; - // when enabled, the loader creates a checkpoint at the start of every epoch. - bool checkpointing_; - std::vector state_queue_; - Index state_queue_front_; - Index state_queue_back_; - int checkpoint_epoch_; - std::mutex state_queue_mutex_; - // rng std::default_random_engine e_; Index seed_; diff --git a/dali/operators/reader/reader_op.h b/dali/operators/reader/reader_op.h index 2fc905c3c2..2bdf6036bc 100644 --- a/dali/operators/reader/reader_op.h +++ b/dali/operators/reader/reader_op.h @@ -31,6 +31,19 @@ namespace dali { +template +struct DataReaderCheckpoint { + using LoadTargetPtr = std::shared_ptr; + using BatchQueueElement = std::vector; + + std::vector prefetched_batch_queue_; + int curr_batch_consumer_; + int curr_batch_producer_; + bool consumer_cycle_; + bool producer_cycle_; +}; + /** * @brief BaseClass for operators that perform prefetching work * @@ -49,7 +62,7 @@ namespace dali { * @tparam supports_checkpointing A marker for checkpointing support. */ template + typename ParseTarget = LoadTarget> class DataReader : public Operator { public: using LoadTargetPtr = std::shared_ptr; @@ -64,8 +77,7 @@ class DataReader : public Operator { curr_batch_producer_(0), consumer_cycle_(false), producer_cycle_(false), - device_id_(-1), - samples_processed_(0) { + device_id_(-1) { if (std::is_same::value) { device_id_ = spec.GetArgument("device_id"); } @@ -224,6 +236,36 @@ class DataReader : public Operator { prefetched_batch_queue_[curr_batch_consumer_].clear(); } + using Checkpoint = DataReaderCheckpoint; + + void SaveState(OpCheckpoint &cpt, std::optional stream) override { + { + // Wait until the prefetch queue is full. + std::unique_lock prefetch_lock(prefetch_access_mutex_); + consumer_.wait(prefetch_lock, [this]() { return finished_ || IsPrefetchQueueFull(); }); + if (prefetch_error_) std::rethrow_exception(prefetch_error_); + } + + // Now, the prefetching thread is idle, because there is no space in queue. + // We can safely access prefetching data. + cpt.MutableCheckpointState() = Checkpoint { + prefetched_batch_queue_, + curr_batch_consumer_, + curr_batch_producer_, + consumer_cycle_, + producer_cycle_, + }; + } + + void RestoreState(const OpCheckpoint &cpt) override { + const auto &state = cpt.CheckpointState(); + prefetched_batch_queue_ = state.prefetched_batch_queue_; + curr_batch_consumer_ = state.curr_batch_consumer_; + curr_batch_producer_ = state.curr_batch_producer_; + consumer_cycle_ = state.consumer_cycle_; + producer_cycle_ = state.producer_cycle_; + } + protected: void ParseIfNeeded(const Tensor& tensor, SampleWorkspace* ws) { using OutputCache = std::unordered_map>>; @@ -352,14 +394,11 @@ class DataReader : public Operator { bool producer_cycle_; int device_id_; - // keep track of how many samples have been processed over all threads. - std::atomic samples_processed_; - // stores any catched exceptions in the prefetch worker std::exception_ptr prefetch_error_; // Loader - std::unique_ptr> loader_; + std::unique_ptr> loader_; // Parser std::unique_ptr> parser_; @@ -375,19 +414,10 @@ class DataReader : public Operator { using DataReader::parser_; \ using DataReader::prefetched_batch_queue_; -#define USE_READER_OPERATOR_MEMBERS_3(Backend, LoadTarget, ParseTarget, supports_checkpointing) \ - using DataReader::loader_; \ - using DataReader::parser_; \ - using DataReader::prefetched_batch_queue_; - -#define GET_MACRO3(_1, _2, _3, NAME, ...) NAME - #define USE_READER_OPERATOR_MEMBERS(Backend, ...) \ - GET_MACRO3(__VA_ARGS__, \ - USE_READER_OPERATOR_MEMBERS_3, \ - USE_READER_OPERATOR_MEMBERS_2, \ - USE_READER_OPERATOR_MEMBERS_1)(Backend, __VA_ARGS__) + GET_MACRO(__VA_ARGS__, \ + USE_READER_OPERATOR_MEMBERS_2, \ + USE_READER_OPERATOR_MEMBERS_1)(Backend, __VA_ARGS__) }; // namespace dali diff --git a/dali/operators/reader/reader_op_test.cc b/dali/operators/reader/reader_op_test.cc index 8067687f3e..f7aa516cd1 100644 --- a/dali/operators/reader/reader_op_test.cc +++ b/dali/operators/reader/reader_op_test.cc @@ -512,31 +512,24 @@ class FileReaderTest : public DALITest { pipe.Build(outputs); } - std::vector RunEpoch(Pipeline &pipe, int batch_size, - int num_shards = 1, bool stick_to_shard = false) { + std::vector RunIteration(Pipeline &pipe) { std::vector result; - int samples_per_shard = (filepaths_.size() + num_shards - 1) / num_shards; - int batches_per_shard = (samples_per_shard + batch_size - 1) / batch_size; - for (int it = 0; it < batches_per_shard; it++) { - pipe.RunCPU(); - pipe.RunGPU(); - pipe.Outputs(&ws_); - - auto shape = ws_.Output(0).AsTensor().shape(); - for (int nr = 0; nr < shape[0]; nr++) - result.push_back(ws_.Output(0).tensor(0)[nr]); - } + pipe.RunCPU(); + pipe.RunGPU(); + pipe.Outputs(&ws_); + + auto shape = ws_.Output(0).AsTensor().shape(); + for (int nr = 0; nr < shape[0]; nr++) + result.push_back(ws_.Output(0).tensor(0)[nr]); + return result; } - std::pair, OpCheckpoint> CheckpointEpoch(Pipeline &pipe, int batch_size, - int epoch_nr, int num_shards = 1, - bool stick_to_shard = false) { + std::pair, OpCheckpoint> CheckpointIteration(Pipeline &pipe) { auto node = pipe.GetOperatorNode("file_reader"); OpCheckpoint cpt(node->spec); node->op->SaveState(cpt, std::nullopt); - EXPECT_EQ(cpt.CheckpointState().current_epoch, epoch_nr); - return {RunEpoch(pipe, batch_size, num_shards, stick_to_shard), cpt}; + return {RunIteration(pipe), cpt}; } void RestoreCheckpointedState(Pipeline &pipe, const OpCheckpoint &cpt) { @@ -550,34 +543,39 @@ class FileReaderTest : public DALITest { TEST_F(FileReaderTest, SimpleCheckpointing) { constexpr int batch_size = 3; - constexpr int epochs = 4; + constexpr int iterations = 15; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() - .AddArg("prefetch_queue_depth", 20) + .AddArg("prefetch_queue_depth", 5) .AddArg("initial_fill", 3), "file_reader"); BuildPipeline(pipe); }; Pipeline pipe(batch_size, 1, 0); prepare_pipeline(pipe); + + // Initial run + RunIteration(pipe); + std::vector> results; std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); + for (int i = 0; i < iterations; i++) { + auto [res, cpt] = CheckpointIteration(pipe); results.push_back(res); checkpoints.push_back(cpt); } - for (int i = 0; i < epochs; i++) { + for (int i = 0; i < iterations; i++) { Pipeline fresh_pipe(batch_size, 1, 0); prepare_pipeline(fresh_pipe); RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); + EXPECT_EQ(RunIteration(fresh_pipe), results[i]); } } +/* TEST_F(FileReaderTest, CheckpointingRandomShuffle) { constexpr int batch_size = 7; constexpr int epochs = 8; @@ -749,5 +747,6 @@ TEST_F(FileReaderTest, CheckpointingResumeThenSave) { } } } +*/ }; // namespace dali From 62e3497efb514013888a0d9f84a24b89273be7b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Staniewski?= Date: Wed, 2 Aug 2023 11:16:42 +0200 Subject: [PATCH 2/2] Working draft MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Michał Staniewski --- .../reader/loader/file_label_loader.h | 32 ++- dali/operators/reader/loader/loader.h | 81 ++++++- dali/operators/reader/reader_op.h | 36 +++- dali/operators/reader/reader_op_test.cc | 198 ++++++------------ 4 files changed, 207 insertions(+), 140 deletions(-) diff --git a/dali/operators/reader/loader/file_label_loader.h b/dali/operators/reader/loader/file_label_loader.h index 6e56c40dbe..a53fcc5e58 100755 --- a/dali/operators/reader/loader/file_label_loader.h +++ b/dali/operators/reader/loader/file_label_loader.h @@ -38,6 +38,12 @@ struct ImageLabelWrapper { int label; }; +struct FileLabelLoaderCheckpoint { + vector> image_label_pairs_; + Index current_index_; + int current_epoch_; +}; + class DLL_PUBLIC FileLabelLoader : public Loader { public: explicit inline FileLabelLoader( @@ -105,7 +111,7 @@ class DLL_PUBLIC FileLabelLoader : public Loader * DALI instances will do shuffling after each epoch */ DALI_ENFORCE(!(shuffle_after_epoch_ && stick_to_shard_), - "shuffle_after_epoch and stick_to_shard cannot be both true"); + "shu/fileffle_after_epoch and stick_to_shard cannot be both true"); DALI_ENFORCE(!(shuffle_after_epoch_ && shuffle_), "shuffle_after_epoch and random_shuffle cannot be both true"); /* @@ -124,6 +130,15 @@ class DLL_PUBLIC FileLabelLoader : public Loader void PrepareEmpty(ImageLabelWrapper &tensor) override; void ReadSample(ImageLabelWrapper &tensor) override; + ImageLabelWrapper CopyTarget(const ImageLabelWrapper &target) override { + Tensor tensor; + tensor.Copy(target.image); + return ImageLabelWrapper { + std::move(tensor), + target.label + }; + } + protected: Index SizeImpl() override; @@ -201,6 +216,21 @@ class DLL_PUBLIC FileLabelLoader : public Loader } } + std::any SaveStateImpl() override { + return FileLabelLoaderCheckpoint { + image_label_pairs_, + current_index_, + current_epoch_, + }; + } + + void RestoreStateImpl(const std::any &opaque_state) override { + const auto &state = std::any_cast(opaque_state); + image_label_pairs_ = state.image_label_pairs_; + current_index_ = state.current_index_; + current_epoch_ = state.current_epoch_; + } + using Loader::shard_id_; string file_root_, file_list_; diff --git a/dali/operators/reader/loader/loader.h b/dali/operators/reader/loader/loader.h index 60caea4e24..957a1aa8ba 100644 --- a/dali/operators/reader/loader/loader.h +++ b/dali/operators/reader/loader/loader.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "dali/core/nvtx.h" #include "dali/core/common.h" @@ -43,6 +44,27 @@ DLL_PUBLIC size_t start_index(const size_t shard_id, DLL_PUBLIC Index num_samples(const size_t shard_num, const size_t size); +struct ShardBoundaries { + Index start; + Index end; +}; + +template +struct LoaderCheckpoint { + using LoadTargetSharedPtr = std::shared_ptr; + + std::default_random_engine e_; + std::vector sample_buffer_; + bool initial_buffer_filled_ = false; + Index read_sample_counter_; + Index returned_sample_counter_; + int virtual_shard_id_; + LoadTargetSharedPtr last_sample_ptr_tmp; + std::deque shards_; + + std::any subclass_state_; +}; + /** * @brief Base class for Loaders, responsible for reading samples from resource of some kind * into memory. @@ -253,6 +275,57 @@ class Loader { return stick_to_shard_; } + using Checkpoint = LoaderCheckpoint; + + Checkpoint SaveState() { + std::vector copied_buffer_; + for (const auto &sample : sample_buffer_) + copied_buffer_.push_back(std::make_shared(CopyTarget(*sample))); + + auto copied_last_sample_= std::make_shared(CopyTarget(*last_sample_ptr_tmp)); + + return Checkpoint { + e_, + std::move(copied_buffer_), + initial_buffer_filled_, + read_sample_counter_, + returned_sample_counter_, + virtual_shard_id_, + copied_last_sample_, + shards_, + SaveStateImpl(), + }; + } + + void RestoreState(const Checkpoint &state) { + e_ = state.e_; + // sample_buffer_ = std::move(state.sample_buffer_); + initial_buffer_filled_ = state.initial_buffer_filled_; + read_sample_counter_ = state.read_sample_counter_; + returned_sample_counter_ = state.returned_sample_counter_; + virtual_shard_id_ = state.virtual_shard_id_; + last_sample_ptr_tmp = state.last_sample_ptr_tmp; + shards_ = state.shards_; + + for (const auto &sample : state.sample_buffer_) + sample_buffer_.push_back(std::make_unique(CopyTarget(*sample))); + + if (initial_buffer_filled_) { + std::lock_guard lock(empty_tensors_mutex_); + for (int i = 0; i < initial_empty_size_; ++i) { + auto tensor_ptr = LoadTargetUniquePtr(new LoadTarget()); + PrepareEmpty(*tensor_ptr); + empty_tensors_.push_back(std::move(tensor_ptr)); + } + } + + RestoreStateImpl(state.subclass_state_); + } + + virtual LoadTarget CopyTarget(const LoadTarget &target) { + DALI_FAIL("This loader does not support checkpointing. "); + } + protected: virtual Index SizeImpl() = 0; @@ -312,6 +385,9 @@ class Loader { return cache_ && cache_->IsCached(key); } + virtual std::any SaveStateImpl() { return {}; } + virtual void RestoreStateImpl(const std::any &opaque_state) {} + std::vector sample_buffer_; std::vector empty_tensors_; @@ -375,11 +451,6 @@ class Loader { // Keeps pointer to the last returned sample just in case it needs to be cloned LoadTargetSharedPtr last_sample_ptr_tmp; - struct ShardBoundaries { - Index start; - Index end; - }; - std::deque shards_; }; diff --git a/dali/operators/reader/reader_op.h b/dali/operators/reader/reader_op.h index 2bdf6036bc..db154e6a7b 100644 --- a/dali/operators/reader/reader_op.h +++ b/dali/operators/reader/reader_op.h @@ -37,11 +37,28 @@ struct DataReaderCheckpoint { using LoadTargetPtr = std::shared_ptr; using BatchQueueElement = std::vector; + /* DataReader state */ std::vector prefetched_batch_queue_; int curr_batch_consumer_; int curr_batch_producer_; bool consumer_cycle_; bool producer_cycle_; + + LoaderCheckpoint loader_state_; + + DataReaderCheckpoint( + std::vector &&batch_queue_copy, + int curr_batch_consumer, + int curr_batch_producer, + bool consumer_cycle, + bool producer_cycle, + LoaderCheckpoint &&loader_state) + : prefetched_batch_queue_(std::move(batch_queue_copy)) + , curr_batch_consumer_(curr_batch_consumer) + , curr_batch_producer_(curr_batch_producer) + , consumer_cycle_(consumer_cycle) + , producer_cycle_(producer_cycle) + , loader_state_(std::move(loader_state)) {} }; /** @@ -239,6 +256,9 @@ class DataReader : public Operator { using Checkpoint = DataReaderCheckpoint; void SaveState(OpCheckpoint &cpt, std::optional stream) override { + // Make sure the prefetch thread is running + StartPrefetchThread(); + { // Wait until the prefetch queue is full. std::unique_lock prefetch_lock(prefetch_access_mutex_); @@ -246,15 +266,24 @@ class DataReader : public Operator { if (prefetch_error_) std::rethrow_exception(prefetch_error_); } + std::vector batch_queue_copy; + for (const auto &batch : prefetched_batch_queue_) { + batch_queue_copy.emplace_back(); + for (const auto &sample : batch) + batch_queue_copy.back().push_back( + std::make_shared(loader_->CopyTarget(*sample))); + } + // Now, the prefetching thread is idle, because there is no space in queue. // We can safely access prefetching data. - cpt.MutableCheckpointState() = Checkpoint { - prefetched_batch_queue_, + cpt.MutableCheckpointState() = (Checkpoint { + std::move(batch_queue_copy), curr_batch_consumer_, curr_batch_producer_, consumer_cycle_, producer_cycle_, - }; + loader_->SaveState(), + }); } void RestoreState(const OpCheckpoint &cpt) override { @@ -264,6 +293,7 @@ class DataReader : public Operator { curr_batch_producer_ = state.curr_batch_producer_; consumer_cycle_ = state.consumer_cycle_; producer_cycle_ = state.producer_cycle_; + loader_->RestoreState(state.loader_state_); } protected: diff --git a/dali/operators/reader/reader_op_test.cc b/dali/operators/reader/reader_op_test.cc index f7aa516cd1..a5bb2a2ab0 100644 --- a/dali/operators/reader/reader_op_test.cc +++ b/dali/operators/reader/reader_op_test.cc @@ -525,11 +525,11 @@ class FileReaderTest : public DALITest { return result; } - std::pair, OpCheckpoint> CheckpointIteration(Pipeline &pipe) { + OpCheckpoint MakeCheckpoint(Pipeline &pipe) { auto node = pipe.GetOperatorNode("file_reader"); OpCheckpoint cpt(node->spec); node->op->SaveState(cpt, std::nullopt); - return {RunIteration(pipe), cpt}; + return cpt; } void RestoreCheckpointedState(Pipeline &pipe, const OpCheckpoint &cpt) { @@ -543,7 +543,7 @@ class FileReaderTest : public DALITest { TEST_F(FileReaderTest, SimpleCheckpointing) { constexpr int batch_size = 3; - constexpr int iterations = 15; + constexpr int iterations = 25; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( @@ -553,93 +553,80 @@ TEST_F(FileReaderTest, SimpleCheckpointing) { BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - // Initial run - RunIteration(pipe); + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < iterations; i++) { - auto [res, cpt] = CheckpointIteration(pipe); - results.push_back(res); - checkpoints.push_back(cpt); - } + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); - for (int i = 0; i < iterations; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunIteration(fresh_pipe), results[i]); - } + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } -/* TEST_F(FileReaderTest, CheckpointingRandomShuffle) { constexpr int batch_size = 7; - constexpr int epochs = 8; + constexpr int iterations = 15; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() .AddArg("random_shuffle", true) + .AddArg("prefetch_queue_depth", 1) .AddArg("initial_fill", 3), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } TEST_F(FileReaderTest, CheckpointingShuffleAfterEpoch) { constexpr int batch_size = 5; - constexpr int epochs = 8; + constexpr int iterations = 50; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() .AddArg("shuffle_after_epoch", true) - .AddArg("prefetch_queue_depth", 9) + .AddArg("prefetch_queue_depth", 3) .AddArg("initial_fill", 10), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } TEST_F(FileReaderTest, CheckpointingMultipleShards) { constexpr int batch_size = 5; - constexpr int epochs = 8; + constexpr int iterations = 50; constexpr int num_shards = 5; auto prepare_pipeline = [this, num_shards](Pipeline &pipe) { @@ -652,27 +639,24 @@ TEST_F(FileReaderTest, CheckpointingMultipleShards) { BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i, num_shards); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size, num_shards), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } -TEST_F(FileReaderTest, CheckpointingStickStoShard) { +TEST_F(FileReaderTest, CheckpointingSticktoShard) { constexpr int batch_size = 3; - constexpr int epochs = 8; + constexpr int iterations = 50; constexpr int num_shards = 3; auto prepare_pipeline = [this, num_shards](Pipeline &pipe) { @@ -681,72 +665,24 @@ TEST_F(FileReaderTest, CheckpointingStickStoShard) { .AddArg("stick_to_shard", true) .AddArg("shard_id", 1) .AddArg("num_shards", num_shards) - .AddArg("initial_fill", 5), "file_reader"); - BuildPipeline(pipe); - }; - - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i, num_shards, true); - results.push_back(res); - checkpoints.push_back(cpt); - } - - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size, num_shards, true), results[i]); - } -} - - -TEST_F(FileReaderTest, CheckpointingResumeThenSave) { - constexpr int batch_size = 3; - constexpr int epochs = 4; - - auto prepare_pipeline = [this](Pipeline &pipe) { - pipe.AddOperator( - MakeOpSpec() - .AddArg("shuffle_after_epoch", true) .AddArg("prefetch_queue_depth", 2) - .AddArg("initial_fill", 3), "file_reader"); + .AddArg("initial_fill", 5), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline intermediate_pipe(batch_size, 1, 0); - prepare_pipeline(intermediate_pipe); - RestoreCheckpointedState(intermediate_pipe, checkpoints[i]); + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); - std::vector cpts_after_resume; - for (int j = 0; i + j < epochs; j++) { - auto [res, cpt] = CheckpointEpoch(intermediate_pipe, batch_size, i + j); - EXPECT_EQ(res, results[i + j]); - cpts_after_resume.push_back(cpt); - } + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); - for (int j = 0; i + j < epochs; j++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, cpts_after_resume[j]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i + j]); - } - } + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } -*/ }; // namespace dali